Machine learning 022 — Modeling using mean-shift clustering algorithms

(Python libraries and versions used in this article: Python 3.5, NumPy 1.14, Scikit-learn 0.19, Matplotlib 2.2)

There are many kinds of unsupervised learning algorithms. The K-means clustering algorithm was explained in an earlier article, where it was used for vector quantization compression of images. Now let's learn a second unsupervised learning algorithm: the mean shift algorithm.


1. Introduction to mean shift algorithm

The mean shift algorithm is a non-parametric method based on density gradient ascent. It is often used in target tracking in image recognition, data clustering, classification, and similar scenarios.

Its core idea is: first choose a center point, then compute the offset vectors from that center to all points within a certain range and take their mean; next, move the center point by this mean offset. By repeating these moves, the center gradually approaches the best position. The idea is similar to gradient descent, which reaches a local or global optimum by repeatedly moving in the direction of the descending gradient.

The figure below illustrates the idea of mean shift. First a center point is chosen at random (green). Then the mean of the offset vectors from this center to all points within a certain range is computed, and the center is moved to the resulting position (yellow). In the same way, the mean offset of all points within range of the yellow point is computed and the center is moved again. After calculating the mean and moving the center many times, the center gradually approaches the best center point, namely the red point in the figure.

1.1 Basic formula of mean shift algorithm

As can be seen from the core idea above, the mean shift process repeatedly calculates the offset mean and moves the center point, so calculating the offset mean and the moving distance are the two key steps. The basic formula for the offset mean is:

$$M_h(x) = \frac{1}{K}\sum_{x_i \in S_h}(x_i - x)$$

where $S_h$ is the high-dimensional sphere centered at $x$ with radius $h$, $K$ is the number of points contained in $S_h$, and $x_i$ are the points contained in $S_h$.

The second step is to calculate the new position of the center point after the move:

$$x_{t+1} = x_t + M_t$$

where $M_t$ is the offset mean obtained in state $t$, and $x_t$ is the center in state $t$.

Obviously, the center position after the move is the position before the move plus the offset mean.
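To make the update concrete, here is a minimal sketch of a single mean-shift step with the basic, unweighted formula (plain NumPy; the toy points, the center, and the radius are invented for illustration):

import numpy as np

# Toy 2-D points (invented for illustration)
points = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0]])

x = np.array([0.0, 0.0])   # current center x_t
h = 2.0                    # radius of the sphere S_h

# Points falling inside the sphere S_h around x
in_sphere = points[np.linalg.norm(points - x, axis=1) < h]

# Offset mean M_t: average of the vectors from x to each point in S_h
M_t = (in_sphere - x).mean(axis=0)

# Update rule: x_{t+1} = x_t + M_t
x_next = x + M_t
print(x_next)   # moves toward the dense group near (1, 1)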

1.2 Offset mean with a kernel function

Although the basic formula of the mean shift algorithm has been introduced, it has a problem. We know that the sample points inside the high-dimensional sphere do not all contribute equally to the solution, yet the basic formula treats every contribution as the same, i.e., all points get equal weight, which is not reasonable. How can this be improved? We can introduce a kernel function to compute a contribution weight for each sample point. There are many kernel functions to choose from, and the Gaussian function is one of them. The following is the formula for the offset mean after introducing a Gaussian kernel:

$$M_h(x) = \frac{\sum_{i=1}^{n} K\!\left(\frac{x_i - x}{h}\right)(x_i - x)}{\sum_{i=1}^{n} K\!\left(\frac{x_i - x}{h}\right)}$$

where the Gaussian kernel is

$$K(u) = \frac{1}{\sqrt{2\pi}}\,e^{-\frac{\|u\|^2}{2}}$$

(the normalization constant cancels between numerator and denominator).
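A minimal sketch of the kernel-weighted offset mean (plain NumPy, reusing the toy data from the previous sketch; since the Gaussian's normalization constant cancels in the ratio, it is omitted):

import numpy as np

points = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [5.0, 5.0]])
x = np.array([0.0, 0.0])   # current center
h = 2.0                    # bandwidth

# Gaussian weights: points near x contribute more than distant ones
d = np.linalg.norm((points - x) / h, axis=1)
w = np.exp(-0.5 * d ** 2)

# Weighted offset mean: closer points dominate the move direction
M = (w[:, None] * (points - x)).sum(axis=0) / w.sum()
print(x + M)   # the far point (5, 5) now has almost no influence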

1.3 Calculation procedure of mean shift algorithm

The mean shift algorithm is widely used, for example in clustering, image segmentation, and target tracking. Its procedure usually includes the following steps (a code sketch of steps 1-5 follows the list):

1. Select a point randomly from the data points as the initial center point.

2. Find all points within the bandwidth distance of the center point, record them as set M, and consider these points as belonging to cluster C.

3. Calculate the vectors from the center point to each element in set M and add these vectors to get the offset vector.

4. Move the center point along the offset direction, and the moving distance is the modulus of the offset vector.

5. Repeat steps 2, 3, and 4 until the magnitude of the offset vector falls below a set threshold, and record the center point at that time.

6. Repeat steps 1-5 until all points have been classified.

7. Classification: for each point, count how often it was visited by each class during the iterations, and assign the point to the class that visited it most frequently.
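Here is a minimal sketch of steps 1-5 for a single center point (plain NumPy; the toy data, the stopping threshold eps, and the helper name shift_one_center are invented for illustration, and the cluster bookkeeping of steps 6 and 7 is omitted for brevity):

import numpy as np

def shift_one_center(points, start, bandwidth, eps=1e-3, max_iter=100):
    """Move one center uphill until the offset vector is smaller than eps."""
    center = start.astype(float)
    for _ in range(max_iter):
        # Step 2: points within `bandwidth` of the current center (set M)
        neighbors = points[np.linalg.norm(points - center, axis=1) < bandwidth]
        # Steps 3-4: offset vector = mean of vectors from center to neighbors
        offset = (neighbors - center).mean(axis=0)
        center = center + offset
        # Step 5: stop when the offset magnitude falls below the threshold
        if np.linalg.norm(offset) < eps:
            break
    return center

# Toy usage: two dense groups; the center converges to the nearer one
rng = np.random.RandomState(0)
points = np.vstack([rng.randn(50, 2), rng.randn(50, 2) + [5, 5]])
print(shift_one_center(points, points[0], bandwidth=2.0))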

1.4 Advantages of mean shift algorithm

When clustering data points, the mean shift algorithm treats the distribution of the data points as a probability density function and tries to find the modes of this function in feature space from the distribution of the samples. Each mode corresponds to a group of locally most densely distributed points.

We explained the K-means algorithm earlier. In practice, K-means requires us to know in advance how many clusters to divide the data into; if this number is wrong, it is often difficult to obtain a satisfactory result, and the right number of clusters is often hard to determine beforehand. This is the main difficulty in applying the K-means algorithm.

The mean shift algorithm, however, does not need to know the number of clusters in advance. It can automatically find the most suitable clusters even when we do not know how many to look for, which is an obvious advantage of the mean shift algorithm.
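The contrast shows up directly in the scikit-learn constructors: KMeans must be told how many clusters to find, while MeanShift discovers the number itself. A small sketch (toy data invented for illustration; the exact cluster count found depends on the bandwidth MeanShift estimates internally):

import numpy as np
from sklearn.cluster import KMeans, MeanShift

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(50, 2), rng.randn(50, 2) + [5, 5]])

kmeans = KMeans(n_clusters=2).fit(X)      # number of clusters must be given
meanshift = MeanShift().fit(X)            # number of clusters is discovered
print(len(np.unique(meanshift.labels_)))  # expected to be 2 here, found automatically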

Some of the content above is adapted from a blog post; thanks to its author.


2. Build mean shift model to cluster data

The data set used in this article, and the way it is read, are exactly the same as in the previous article, [Furnace AI] machine learning 020, where the k-means algorithm was used for cluster analysis, so that part is omitted here.
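For readers who do not have the previous article at hand, a minimal loading sketch; the file name data_multivar.txt and the comma delimiter are assumptions based on that article's setup, not confirmed here:

import numpy as np

# Assumed file name and format from the earlier k-means article; adjust as needed
dataset_X = np.loadtxt('data_multivar.txt', delimiter=',')
print(dataset_X.shape)   # expect (N, 2): columns feature_0 and feature_1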

The following code builds the MeanShift object. Before using MeanShift, we need to estimate the bandwidth, which is the radius mentioned above: all points within this distance of the center point are put into the set M used to compute the offset vector.

# Build the MeanShift object; the bandwidth has to be estimated first
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

bandwidth = estimate_bandwidth(dataset_X, quantile=0.1,
                               n_samples=len(dataset_X))
meanshift = MeanShift(bandwidth=bandwidth, bin_seeding=True)  # Build the object
meanshift.fit(dataset_X)  # Train the MeanShift object on the dataset

centroids = meanshift.cluster_centers_  # Centroid coordinates: feature_0, feature_1
print(centroids)  # Four rows can be seen, i.e., four centroids
labels = meanshift.labels_  # Label of each data point in the dataset
# print(labels)

cluster_num = len(np.unique(labels))  # Number of distinct labels = number of clusters found automatically
print('cluster num: {}'.format(cluster_num))

-------------------------------- output --------------------------------

[[ 8.22338235  1.34779412]
 [ 4.10104478 -0.81164179]
 [ 1.18820896  2.10716418]
 [ 4.995       4.99967742]]
cluster num: 4

-------------------------------------------------------------------------

It can be seen that we obtain four centroids here. Their coordinates are available through meanshift.cluster_centers_, and meanshift.labels_ gives the label of each original sample, i.e., the labels found by the mean shift algorithm itself. This is the advantage of unsupervised learning: although no labels are assigned to the sample data, the algorithm can find the corresponding labels by itself.
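Since meanshift.labels_ holds one label per sample, the clusters can be pulled apart with a simple boolean mask. A small sketch continuing from the code above:

import numpy as np

for cluster_id in np.unique(meanshift.labels_):
    members = dataset_X[meanshift.labels_ == cluster_id]  # samples in this cluster
    print('cluster {}: {} samples'.format(cluster_id, len(members)))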

In the same way as before, we can check the clustering effect of the MeanShift algorithm visually with the following function.

import numpy as np
import matplotlib.pyplot as plt

def visual_meanshift_effect(meanshift, dataset):
    assert dataset.shape[1] == 2, 'only support dataset with 2 features'
    X = dataset[:, 0]
    Y = dataset[:, 1]
    X_min, X_max = np.min(X) - 1, np.max(X) + 1
    Y_min, Y_max = np.min(Y) - 1, np.max(Y) + 1
    X_values, Y_values = np.meshgrid(np.arange(X_min, X_max, 0.01),
                                     np.arange(Y_min, Y_max, 0.01))
    # Predict the label of every grid point
    predict_labels = meanshift.predict(np.c_[X_values.ravel(), Y_values.ravel()])
    predict_labels = predict_labels.reshape(X_values.shape)
    plt.figure()
    plt.imshow(predict_labels, interpolation='nearest',
               extent=(X_values.min(), X_values.max(),
                       Y_values.min(), Y_values.max()),
               cmap=plt.cm.Paired,
               aspect='auto',
               origin='lower')

    # Draw the data set into the chart
    plt.scatter(X, Y, marker='v', facecolors='none', edgecolors='k', s=30)

    # Draw the centroids into the chart
    centroids = meanshift.cluster_centers_
    plt.scatter(centroids[:, 0], centroids[:, 1], marker='o',
                s=100, linewidths=2, edgecolors='k', zorder=5, facecolors='b')
    plt.title('MeanShift effect graph')
    plt.xlim(X_min, X_max)
    plt.ylim(Y_min, Y_max)
    plt.xlabel('feature_0')
    plt.ylabel('feature_1')
    plt.show()

visual_meanshift_effect(meanshift, dataset_X)

############################## Summary ##############################

1. The construction and training of MeanShift are almost the same as for K-means, but MeanShift can automatically determine the number of clusters in a data set without it being specified beforehand, which makes MeanShift easier to use than K-means.

2. After training, the MeanShift object contains the centroid coordinates of the data set and the label of each sample, both of which can be obtained easily.

######################################################################


Note: the code for this article has been uploaded to (my GitHub); you are welcome to download it.

References:

1. Classic Examples of Python Machine Learning, by Prateek Joshi, translated by Tao Junjie and Chen Xiaoli.