I have implemented a function to find the nearest data point to each centroid calculated after running the K-Means clustering algorithm. I wanted to know if there's a scikit-learn function that allows me to find the M nearest points to each of the centroids.
M nearest points to centroid in K-Means clustering
After running K-means, we can fit sklearn.neighbors.NearestNeighbors on our dataset and then query that model with the K-means centroids to retrieve each centroid's nearest points. Like this:
# Copyright 2024 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
# random dense embeddings for 100 points with 10 dimensions.
dataset = np.random.rand(100,10)
# fit K-means with 3 clusters on our dataset.
kme = KMeans(n_clusters=3)
kme.fit(dataset)
# we should have 3 vectors for 3 centroids.
print(kme.cluster_centers_.shape) # (3, 10)
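# initialize NearestNeighbors with 5 neighbors and fit our dataset.
# note: K-means assigns points by Euclidean distance; drop metric='cosine' (the default is Euclidean) to rank neighbors the same way.
knn = NearestNeighbors(n_neighbors=5, metric='cosine')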
knn.fit(dataset)
# Use the model to query the centroids' neighbors.
distances, indices = knn.kneighbors(kme.cluster_centers_)
for centroid, distance_from_centroid, index in zip(kme.cluster_centers_, distances, indices):
    print(centroid, distance_from_centroid, index)
The last loop will output 3 lines, one per centroid. Each line shows the centroid's vector, followed by the distances to its 5 closest neighbors and the indices of those neighbors in the dataset.
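If the goal is specifically the M points closest to each centroid in the Euclidean sense that K-means itself uses, a possible alternative is to skip the separate neighbor model and sort the distances that KMeans.transform already returns. This is only a sketch: the names M, dist_to_centroids, nearest_idx and nearest_dist are made up for illustration, and the random data stands in for a real dataset.

```python
import numpy as np
from sklearn.cluster import KMeans

# Illustrative data: 100 random points with 10 dimensions.
rng = np.random.default_rng(0)
dataset = rng.random((100, 10))

kme = KMeans(n_clusters=3, n_init=10, random_state=0).fit(dataset)

M = 5  # number of nearest points wanted per centroid (assumed value)

# transform() gives the Euclidean distance from every point to every
# centroid, shape (n_samples, n_clusters) = (100, 3).
dist_to_centroids = kme.transform(dataset)

# Sort each column (one column per centroid) and keep the first M rows.
# After transposing, row k holds the dataset indices nearest to centroid k.
nearest_idx = np.argsort(dist_to_centroids, axis=0)[:M].T  # (3, M)

# Look up the matching distances, also as one row per centroid.
nearest_dist = np.take_along_axis(dist_to_centroids, nearest_idx.T, axis=0).T

for k in range(kme.n_clusters):
    print(k, nearest_idx[k], nearest_dist[k])
```

Sorting the full distance matrix is fine for small datasets; for large ones the NearestNeighbors index shown above is likely the better fit.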
Yes, you want to check out the following tutorial from sklearn: http://scikit-learn.org/stable/modules/neighbors.html
The class sklearn.neighbors.NearestNeighbors finds them for you: http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors
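One thing to keep in mind is that querying a model fitted on the whole dataset can return points that K-means assigned to a different cluster. If only members of each cluster should be considered, one option is to fit NearestNeighbors separately on each cluster's points. The helper below, m_nearest_per_cluster, is a hypothetical name and not part of scikit-learn; treat it as a rough sketch under that assumption.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors


def m_nearest_per_cluster(X, kmeans, m):
    """Hypothetical helper: for each cluster, return the dataset indices of
    the (up to) m member points closest to that cluster's centroid."""
    result = {}
    for k, center in enumerate(kmeans.cluster_centers_):
        # Indices of the points that K-means assigned to cluster k.
        member_idx = np.flatnonzero(kmeans.labels_ == k)
        if member_idx.size == 0:
            result[k] = member_idx
            continue
        # A cluster may have fewer than m members.
        n = min(m, member_idx.size)
        # Default metric is Euclidean, matching K-means.
        nn = NearestNeighbors(n_neighbors=n).fit(X[member_idx])
        _, local = nn.kneighbors(center.reshape(1, -1))
        # Map positions within the cluster back to dataset indices.
        result[k] = member_idx[local[0]]
    return result


# Illustrative data: 100 random points with 10 dimensions.
rng = np.random.default_rng(0)
X = rng.random((100, 10))
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

nearest = m_nearest_per_cluster(X, kmeans, m=5)
for k, idx in nearest.items():
    print(k, idx)
```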
Use sklearn.neighbors.NearestNeighbors and its kneighbors function to find the nearest neighbours of the cluster centroids. – Clinkscales