-
Notifications
You must be signed in to change notification settings - Fork 6
/
metric.py
32 lines (23 loc) · 1.04 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
def compute_purity(y_pred, y_true):
    """
    Calculate the purity, a measurement of quality for the clustering
    results.

    Each cluster is assigned to the class which is most frequent in the
    cluster. Using these assignments, the fraction of correctly assigned
    samples is then calculated.

    Args:
        y_pred: Sequence/array of predicted cluster ids, one per sample.
            Any hashable, comparable values work (ints, strings, ...).
        y_true: Sequence/array of ground-truth class labels, aligned with
            ``y_pred``. Labels need not be non-negative integers.

    Returns:
        A float in (0, 1]. Purity is never lower than the fraction of the
        most frequent class (reached when everything is in one cluster), so
        poor clusterings score close to that baseline, while a perfect
        clustering has a purity of 1. Note that putting every sample in its
        own cluster also yields 1, so purity alone does not penalize
        over-segmentation.

    Raises:
        ValueError: If the inputs are empty or have different lengths.
    """
    # Convert so elementwise comparison works for plain Python lists too
    # (list == scalar is a single bool, which would silently match nothing).
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if len(y_pred) != len(y_true):
        raise ValueError(
            "y_pred and y_true must have the same length, got %d and %d"
            % (len(y_pred), len(y_true))
        )
    if len(y_pred) == 0:
        raise ValueError("cannot compute purity of an empty clustering")

    correct = 0
    for cluster in np.unique(y_pred):
        # True labels of the samples assigned to this cluster.
        cluster_labels = y_true[y_pred == cluster]
        # The majority class contributes its count as "correct". np.unique
        # is used instead of np.bincount so that negative, non-integer or
        # string labels are supported; only the maximum count matters, so
        # tie-breaking between equally frequent classes is irrelevant.
        _, counts = np.unique(cluster_labels, return_counts=True)
        correct += counts.max()
    return float(correct) / len(y_pred)