KMeans Explained

Published on: October 19, 2020

Clustering is a machine learning technique that involves grouping similar data points together into so-called clusters. It is an unsupervised learning method commonly used in data science and other fields.

KMeans is probably the most well-known of all clustering algorithms. Its goal is to separate the data into K distinct, non-overlapping subgroups (clusters) of equal variance by minimizing a criterion known as the inertia or within-cluster sum-of-squares.
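Written out, for n data points x_1 to x_n and a set C of K centroids μ_j, the inertia adds up each point's squared distance to its nearest centroid (this is the form given in the Scikit-Learn documentation):

$$\sum_{i=1}^{n} \min_{\mu_j \in C} \lVert x_i - \mu_j \rVert^2$$

KMeans searches for the centroids that make this quantity as small as possible.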

KMeans theory

KMeans works as follows:

  1. First, pick the number of clusters K (for more info, check the "Choosing K" section).
  2. Initialize the center points of the clusters (centroids) by shuffling the dataset and then selecting K data points as the initial centroids.
  3. Assign each data point to the cluster with the nearest centroid.
  4. Recompute each centroid's position by taking the mean of all data points assigned to its cluster.
  5. Repeat steps 3 and 4 for a set number of iterations or until the centroids stop moving significantly between iterations.
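The steps above can be sketched in NumPy as follows. This is a minimal illustration rather than a production implementation: the function name `kmeans` and its parameters `max_iter`, `tol`, and `seed` are hypothetical, initialization follows step 2 (randomly chosen data points), and an empty cluster simply keeps its previous centroid (an assumption; libraries handle this case differently). Note that Scikit-Learn's `KMeans` uses the k-means++ initialization by default instead of plain random selection.

```python
import numpy as np


def kmeans(X, k, max_iter=100, tol=1e-6, seed=0):
    """Hypothetical minimal KMeans (Lloyd's algorithm) following steps 2-5."""
    rng = np.random.default_rng(seed)
    # Step 2: shuffle the data and take K points as the initial centroids
    centroids = X[rng.permutation(len(X))[:k]].astype(float)
    for _ in range(max_iter):
        # Step 3: assign every point to its nearest centroid (Euclidean distance)
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Step 4: move each centroid to the mean of its assigned points
        # (assumption: an empty cluster keeps its old centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 5: stop once the centroids barely move anymore
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break
    # Final assignment that matches the returned centroids
    labels = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2).argmin(axis=1)
    return centroids, labels


# Two small, well-separated groups of points (toy data)
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
centroids, labels = kmeans(X, k=2)
print(labels)
print(centroids)
```

On this toy data the two groups end up in separate clusters, with centroids at the group means (1/3, 1/3) and (31/3, 31/3).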

Choosing K

Choosing the right K value by hand can get quite tricky, especially if you're working with data that has three or more dimensions. If you choose K too small or too large, the result can be quite underwhelming.


In this section, I'll show you two methods commonly used to choose the right K value:

  • The Elbow Method
  • Silhouette Analysis

Elbow Method

The Elbow Method suggests a good value for K based on the sum of squared errors (SSE), i.e., the sum of squared distances between the data points and the centroids of their assigned clusters. We compute the SSE for a range of K values and pick K at the point where the curve starts to flatten out, which looks like an elbow. Below you can see an example created using Yellowbrick.
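A minimal way to compute the values behind an elbow plot with Scikit-Learn, assuming a synthetic dataset of four well-separated blobs: a fitted `KMeans` model exposes the SSE as its `inertia_` attribute. Yellowbrick's `KElbowVisualizer` draws the same kind of curve; this sketch simply prints the values.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data: four well-separated blobs (assumed for illustration)
centers = [[-5, -5], [-5, 5], [5, -5], [5, 5]]
X, _ = make_blobs(n_samples=400, centers=centers, cluster_std=0.6, random_state=0)

# Fit KMeans for K = 1..8 and record the SSE (inertia_) of each fit
k_values = range(1, 9)
sse = []
for k in k_values:
    model = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X)
    sse.append(model.inertia_)

for k, value in zip(k_values, sse):
    print(f"K={k}: SSE={value:.1f}")
```

The SSE drops sharply up to K=4 and only marginally afterwards, so the elbow sits at K=4.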


Silhouette Analysis

The Silhouette Analysis can be used to study the separation distance between the resulting clusters. It measures how similar a data point is to its own cluster (cohesion) compared to other clusters (separation) and can therefore be used to assess the number of clusters K.

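The silhouette coefficient of a single data point i is computed as follows:

  • Compute a(i), the average distance d(i, j) between i and all other data points j in its own cluster C_i:

$$a(i) = \frac{1}{|C_i| - 1} \sum_{j \in C_i,\, j \neq i} d(i, j)$$

  • For every other cluster C_k (where C_k ≠ C_i), compute the average distance from i to all points in C_k, and take the smallest of these averages:

$$b(i) = \min_{C_k \neq C_i} \frac{1}{|C_k|} \sum_{j \in C_k} d(i, j)$$

The minimum is used because we want to know the average distance to the closest cluster that i is not a member of.

With a(i) and b(i), we can now calculate the silhouette coefficient:

$$s(i) = \frac{b(i) - a(i)}{\max\{a(i), b(i)\}}$$

The coefficient can take values in the interval [-1, 1]. A value near zero means the sample lies on or very close to the boundary between two neighboring clusters. A value near one means the sample is far away from the neighboring clusters, and a value near negative one means the sample has probably been assigned to the wrong cluster.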
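To make the computation concrete, here is a direct translation into NumPy, cross-checked against Scikit-Learn's `silhouette_samples`. For each point it computes a(i), the mean distance to the other members of its own cluster, and b(i), the smallest mean distance to the members of any other cluster, then combines them as s(i) = (b(i) - a(i)) / max(a(i), b(i)). The helper name `silhouette_coefficients` is hypothetical, the data is synthetic, and a singleton cluster gets s(i) = 0 (the convention Scikit-Learn also uses).

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances, silhouette_samples


def silhouette_coefficients(X, labels):
    """Hypothetical helper: s(i) = (b(i) - a(i)) / max(a(i), b(i)) for every point."""
    D = pairwise_distances(X)  # Euclidean distances d(i, j)
    s = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        if own.sum() == 1:
            continue  # singleton cluster: s(i) stays 0
        # a(i): mean distance to the other members of the same cluster (d(i, i) = 0)
        a = D[i, own].sum() / (own.sum() - 1)
        # b(i): smallest mean distance to the members of any other cluster
        b = min(D[i, labels == c].mean() for c in np.unique(labels) if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return s


X, _ = make_blobs(n_samples=200, centers=3, random_state=0)
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

manual = silhouette_coefficients(X, labels)
reference = silhouette_samples(X, labels)
print(np.allclose(manual, reference))
```

In practice, `silhouette_samples` (per point) and `silhouette_score` (the mean over all points) are the more efficient choice.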

Below you can see an example of silhouette analysis using Scikit-Learn:

[Figure: silhouette analysis with 3, 4, and 5 clusters]
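A common way to use the silhouette coefficient for choosing K is to compare the mean coefficient (`silhouette_score` in Scikit-Learn) across several candidate values and prefer the K with the highest score. This sketch assumes synthetic data with four well-separated blobs; it does not reproduce the plots above.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic data: four well-separated blobs (assumed for illustration)
centers = [[-5, -5], [-5, 5], [5, -5], [5, 5]]
X, _ = make_blobs(n_samples=400, centers=centers, cluster_std=0.6, random_state=0)

# The silhouette coefficient needs at least 2 clusters, so start at K=2
scores = {}
for k in range(2, 7):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    scores[k] = silhouette_score(X, labels)  # mean s(i) over all points

best_k = max(scores, key=scores.get)
for k, score in scores.items():
    print(f"K={k}: mean silhouette={score:.3f}")
print("best K:", best_k)
```

For this data the highest mean silhouette is reached at K=4, matching the four blobs.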


KMeans is an easy-to-implement algorithm that is also quite fast, with an average complexity of O(k n T), where k is the number of clusters, n is the number of samples, and T is the number of iterations.


As mentioned above, KMeans minimizes the sum-of-squares criterion, which works well if the clusters are roughly spherical. It doesn't work well on many other types of data, such as clusters with complicated shapes, though. In this section, we'll go over a few cases where KMeans performs poorly.

First, KMeans doesn't put data points that are far away from each other into the same cluster, even when they clearly belong together because they follow some underlying structure, such as points on a line.


In the image above, you can see that KMeans draws the boundary between its two clusters across both lines, so each line is split between the two clusters instead of each line forming its own cluster. On the right side, you can see the DBSCAN (Density-Based Spatial Clustering of Applications with Noise) algorithm, which separates the two lines without any issues.
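A small sketch of the two-lines case, assuming two hypothetical parallel lines of 100 evenly spaced points each, one unit apart. Cutting this data into a left and a right half gives a lower within-cluster sum of squares than splitting it into the two lines, so KMeans prefers the former, and its adjusted Rand index (ARI) against the true line labels ends up near zero. DBSCAN with `eps=0.3` (an assumption: larger than the spacing between points on a line, smaller than the gap between the lines) chains each line into one cluster.

```python
import numpy as np
from sklearn.cluster import DBSCAN, KMeans
from sklearn.metrics import adjusted_rand_score

# Two parallel horizontal lines, y = 0 and y = 1 (hypothetical toy data)
x = np.linspace(0, 10, 100)
X = np.vstack([
    np.column_stack([x, np.zeros_like(x)]),
    np.column_stack([x, np.ones_like(x)]),
])
y_true = np.repeat([0, 1], 100)  # true label = which line a point is on

km_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
db_labels = DBSCAN(eps=0.3, min_samples=3).fit_predict(X)

# ARI: 1.0 means a perfect match with the true labels, around 0 means chance level
km_ari = adjusted_rand_score(y_true, km_labels)
db_ari = adjusted_rand_score(y_true, db_labels)
print(f"KMeans ARI: {km_ari:.2f}, DBSCAN ARI: {db_ari:.2f}")
```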

Also, as mentioned at the start of this section, KMeans performs poorly on complicated geometric shapes such as the moons and circles shown below.

[Figure: noisy moons and noisy circles datasets with their true cluster labels]

Other clustering algorithms like Spectral Clustering, Agglomerative Clustering (for example with single linkage), or DBSCAN can handle such data well when suitably configured. For a more in-depth analysis of how different clustering algorithms perform on different interesting 2D datasets, I recommend checking out 'Comparing different clustering algorithms on toy datasets' from Scikit-Learn.
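The moons and circles effect can be reproduced with Scikit-Learn's `make_moons` and `make_circles` generators. This sketch compares KMeans with DBSCAN using the adjusted Rand index (ARI; 1.0 means the true labels are recovered exactly, values around 0 mean chance level). The dataset sizes, noise levels, and the DBSCAN parameters `eps=0.15` and `min_samples=5` are assumptions chosen for this unscaled data, not general recommendations.

```python
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_circles, make_moons
from sklearn.metrics import adjusted_rand_score

datasets = {
    "moons": make_moons(n_samples=500, noise=0.05, random_state=0),
    "circles": make_circles(n_samples=500, factor=0.5, noise=0.05, random_state=0),
}

results = {}
for name, (X, y_true) in datasets.items():
    km_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
    # DBSCAN marks outliers as -1; ARI treats them as one extra group
    db_labels = DBSCAN(eps=0.15, min_samples=5).fit_predict(X)
    results[name] = (
        adjusted_rand_score(y_true, km_labels),
        adjusted_rand_score(y_true, db_labels),
    )
    print(f"{name}: KMeans ARI={results[name][0]:.2f}, DBSCAN ARI={results[name][1]:.2f}")
```

KMeans can only draw a straight boundary between two centroids, so it cannot separate interleaved or nested shapes, whereas DBSCAN follows the dense regions wherever they go.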

