Skip to content

Commit

Permalink
Refactor distance function used in Hierarchical Clustering class.
Browse files Browse the repository at this point in the history
  • Loading branch information
hoanganhngo610 committed Sep 13, 2023
1 parent 43884f0 commit 2a80d3c
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions river/cluster/hcluster.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Callable

import functools
import math
import numpy as np

from river import base, utils
from river.neighbors.base import DistanceFunc


# Node of a binary tree for Hierarchical Clustering
Expand Down Expand Up @@ -137,7 +138,7 @@ class HierarchicalClustering(base.Clusterer):
def __init__(
self,
window_size: int = 100,
distance_func: Callable[[BinaryTreeNode, BinaryTreeNode], float] = None,
distance_func: DistanceFunc = functools.partial(utils.math.minkowski_distance, p=2),
):
# Number of nodes
self.n = 0
Expand All @@ -150,10 +151,7 @@ def __init__(
# First node of the tree
self.root = None
# Distance function
if distance_func is not None:
self.distance = distance_func
else:
self.distance = euclidean_distance
self.distance = distance_func

def otd_clustering(self, tree, x):
# Online top down clustering (OTD), the first algorithm for online hierarchical clustering.
Expand Down Expand Up @@ -347,7 +345,7 @@ def inter_subtree_similarity(self, tree_a, tree_b):
for i, w_i in enumerate(leaves_a):
for j, w_j in enumerate(leaves_b):
nb += 1
r += self.distance(w_i, w_j)
r += self.distance(utils.numpy2dict(w_i.data), utils.numpy2dict(w_j.data))
return r / nb

def intra_subtree_similarity(self, tree):
Expand All @@ -361,7 +359,7 @@ def intra_subtree_similarity(self, tree):
for j, w_j in enumerate(leaves):
if i < j:
nb += 1
r += self.distance(w_i, w_j)
r += self.distance(utils.numpy2dict(w_i.data), utils.numpy2dict(w_j.data))
return r / nb

def __str__(self):
Expand Down

0 comments on commit 2a80d3c

Please sign in to comment.