-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4979301
commit 12933b7
Showing
4 changed files
with
364 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
defmodule Scholar.Cluster.DBSCAN do
  @moduledoc """
  Perform DBSCAN clustering from vector array or distance matrix.

  DBSCAN - Density-Based Spatial Clustering of Applications with Noise.
  Finds core samples of high density and expands clusters from them.
  Good for data which contains clusters of similar density.
  """
  import Nx.Defn
  import Scholar.Shared

  @derive {Nx.Container, containers: [:core_sample_indices, :labels]}
  defstruct [:core_sample_indices, :labels]

  opts = [
    eps: [
      default: 0.5,
      doc: """
      The maximum distance between two samples for them to be considered as in the same neighborhood.
      """,
      type: {:custom, Scholar.Options, :positive_number, []}
    ],
    min_samples: [
      default: 5,
      doc: """
      The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
      This includes the point itself.
      """,
      type: :integer
    ],
    metric: [
      type: {:custom, Scholar.Options, :metric, []},
      default: {:minkowski, 2},
      doc: ~S"""
      Name of the metric. Possible values:

      * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or :infinity)
        we can set Manhattan (1), Euclidean (2), Chebyshev (:infinity), or any arbitrary $L_p$ metric.

      * `:cosine` - Cosine metric.
      """
    ],
    weights: [
      type: {:custom, Scholar.Options, :weights, []},
      doc: """
      The weights for each observation in x. If equals to `nil`,
      all observations are assigned equal weight.
      """
    ]
  ]

  @opts_schema NimbleOptions.new!(opts)

  @doc """
  Perform DBSCAN clustering from vector array or distance matrix.

  ## Options

  #{NimbleOptions.docs(@opts_schema)}

  ## Return Values

  The function returns a struct with the following parameters:

    * `:core_sample_indices` - Indices of core samples represented as a mask.
      The mask is a boolean array of shape `{num_samples}` where `1` indicates
      that the corresponding sample is a core sample and `0` otherwise.

    * `:labels` - Cluster labels for each point in the dataset given to fit().
      Noisy samples are given the label -1.

  ## Examples

      iex> x = Nx.tensor([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
      iex> Scholar.Cluster.DBSCAN.fit(x, eps: 3, min_samples: 2)
      %Scholar.Cluster.DBSCAN{
        core_sample_indices: Nx.tensor(
          [1, 1, 1, 1, 1, 0], type: :u8
        ),
        labels: Nx.tensor(
          [0, 0, 0, 1, 1, -1]
        )
      }
  """
  deftransform fit(x, opts \\ []) do
    fit_n(x, NimbleOptions.validate!(opts, @opts_schema))
  end

  # Numerical part of `fit/2`; `opts` has already been validated against `@opts_schema`.
  #
  # Builds the radius-neighbour mask for every sample, marks the samples whose
  # neighbourhood weight reaches `:min_samples` as core samples, and then lets
  # `dbscan_inner/2` grow the clusters from those core samples.
  defnp fit_n(x, opts) do
    num_samples = Nx.axis_size(x, 0)

    # NOTE(review): `validate_weights/3` comes from `Scholar.Shared`; it is assumed to
    # return either a scalar or a `{num_samples}` tensor - confirm against that helper.
    weights = validate_weights(opts[:weights], num_samples, type: to_float_type(x))

    # The neighbour model requires targets; DBSCAN has none, so a single dummy class is used.
    y_dummy = Nx.broadcast(Nx.tensor(0), {num_samples})

    neighbor_model =
      Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y_dummy,
        num_classes: 1,
        radius: opts[:eps],
        metric: opts[:metric]
      )

    {_dist, indices} =
      Scholar.Neighbors.RadiusNearestNeighbors.radius_neighbors(neighbor_model, x)

    # Row `i` of `indices` is the neighbour mask of sample `i` (this is how `dbscan_inner/2`
    # reads it), so the weights must line up with the *last* axis: entry `{i, j}` becomes
    # `indices[i][j] * weights[j]` and the row sum is the total weight found inside the
    # neighbourhood of `i`, as documented for `:min_samples`. A scalar weight simply scales
    # the neighbour count.
    n_neighbors = Nx.sum(indices * weights, axes: [1])
    core_samples = n_neighbors >= opts[:min_samples]
    labels = dbscan_inner(core_samples, indices)

    %__MODULE__{
      core_sample_indices: core_samples,
      labels: labels
    }
  end

  # Assigns a cluster label to every sample by a depth-first expansion from core samples.
  #
  #   * `is_core?` - mask telling, per sample, whether it is a core sample.
  #   * `indices`  - neighbour mask; row `k` marks the neighbours of sample `k`.
  #
  # Internally labels start at 1 and `0` means "not labelled yet"; the final result is
  # shifted by one so that clusters are numbered from 0 and unreached samples get -1.
  defnp dbscan_inner(is_core?, indices) do
    {labels, _} =
      while {labels = Nx.broadcast(0, {Nx.axis_size(indices, 0)}),
             {indices, is_core?, label_num = 1, i = 0}},
            i < Nx.axis_size(indices, 0) do
        # Fixed-size stack for the depth-first walk. A sample pushes its neighbours only
        # at the moment it is labelled, and it is labelled at most once, so at most
        # num_samples * num_samples entries are ever pushed for one cluster.
        stack = Nx.broadcast(0, {Nx.axis_size(indices, 0) ** 2})
        stack_ptr = 0

        # Only an unlabelled core sample can seed a new cluster.
        if Nx.take(labels, i) != 0 or not Nx.take(is_core?, i) do
          {labels, {indices, is_core?, label_num, i + 1}}
        else
          {labels, _} =
            while {labels, {k = i, label_num, indices, is_core?, stack, stack_ptr}},
                  stack_ptr >= 0 do
              {labels, stack, stack_ptr} =
                if Nx.take(labels, k) == 0 do
                  # First visit of `k`: attach it to the cluster being grown.
                  labels =
                    Nx.indexed_put(
                      labels,
                      Nx.new_axis(Nx.new_axis(k, 0), 0),
                      Nx.new_axis(label_num, 0)
                    )

                  {stack, stack_ptr} =
                    if Nx.take(is_core?, k) do
                      # Only core samples propagate the cluster: push every neighbour
                      # of `k` that has not been labelled yet.
                      neighb = Nx.take(indices, k)
                      mask = neighb * (labels == 0)

                      {stack, stack_ptr, _} =
                        while {stack, stack_ptr, {mask, j = 0}}, j < Nx.axis_size(mask, 0) do
                          if Nx.take(mask, j) != 0 do
                            stack =
                              Nx.indexed_put(
                                stack,
                                Nx.new_axis(Nx.new_axis(stack_ptr, 0), 0),
                                Nx.new_axis(j, 0)
                              )

                            {stack, stack_ptr + 1, {mask, j + 1}}
                          else
                            {stack, stack_ptr, {mask, j + 1}}
                          end
                        end

                      {stack, stack_ptr}
                    else
                      {stack, stack_ptr}
                    end

                  {labels, stack, stack_ptr}
                else
                  {labels, stack, stack_ptr}
                end

              # Pop the next candidate. With an empty stack the pointer drops to -1,
              # which terminates the loop; `k = -1` is then never used.
              k = if stack_ptr > 0, do: Nx.take(stack, stack_ptr - 1), else: -1
              stack_ptr = stack_ptr - 1
              {labels, {k, label_num, indices, is_core?, stack, stack_ptr}}
            end

          {labels, {indices, is_core?, label_num + 1, i + 1}}
        end
      end

    # we need to subtract 1 from labels because we started from label_num=1 which simplifies operations
    labels - 1
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
defmodule Scholar.Cluster.DBSCANTest do
  use Scholar.Case, async: true
  alias Scholar.Cluster.DBSCAN
  doctest DBSCAN

  # Ten 4-dimensional integer points shared by the small `fit` tests below.
  defp sample_points do
    [
      [3, 6, 7, 5],
      [9, 8, 5, 4],
      [4, 4, 4, 1],
      [9, 4, 5, 6],
      [6, 4, 5, 7],
      [4, 5, 3, 3],
      [4, 5, 7, 8],
      [9, 4, 4, 5],
      [8, 4, 3, 9],
      [2, 8, 4, 4]
    ]
    |> Nx.tensor()
  end

  describe "fit" do
    test "fit with default parameters" do
      # With the default `eps`/`min_samples` no point is dense enough: everything is noise.
      %DBSCAN{labels: labels, core_sample_indices: core_mask} = DBSCAN.fit(sample_points())

      assert labels == Nx.tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1])
      assert core_mask == Nx.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], type: :u8)
    end

    test "fit with custom parameters" do
      %DBSCAN{labels: labels, core_sample_indices: core_mask} =
        DBSCAN.fit(sample_points(), eps: 3.3, min_samples: 2)

      assert labels == Nx.tensor([-1, -1, 0, 1, 1, 0, 1, 1, -1, -1])
      assert core_mask == Nx.tensor([0, 0, 1, 1, 1, 1, 1, 1, 0, 0], type: :u8)
    end

    test "fit with custom metric" do
      %DBSCAN{labels: labels, core_sample_indices: core_mask} =
        DBSCAN.fit(sample_points(), metric: :cosine, eps: 0.025, min_samples: 2)

      assert labels == Nx.tensor([-1, 0, -1, 1, -1, 0, -1, 1, -1, -1])
      assert core_mask == Nx.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 0], type: :u8)
    end

    test "test with artifically created 3 blobs" do
      # 100 two-dimensional points forming three blobs plus a few outliers.
      blobs =
        Nx.tensor([
          [0.21944999, 1.43491283],
          [0.76205536, -1.50701033],
          [0.17951214, 1.40902498],
          [0.55395946, -0.72871634],
          [0.84700085, -0.74057968],
          [0.84172206, 1.80812576],
          [0.22960209, 1.16815515],
          [0.76556559, -0.51520299],
          [-0.67304942, -0.80809616],
          [-0.62113428, 1.97278075],
          [1.58928515, -0.63815665],
          [-1.61518766, -0.72879148],
          [-0.90179385, -0.6071971],
          [1.00826938, -0.97348199],
          [2.02971885, -0.28943621],
          [0.21452547, 1.80296627],
          [1.02191358, -0.16055734],
          [-0.92808574, -0.43114313],
          [1.31314013, -0.81780066],
          [-1.51911189, -0.79773621],
          [-1.55314151, -0.56212295],
          [-0.08122475, 0.90177102],
          [-1.11135974, -0.41252607],
          [-1.02371315, -0.98077793],
          [-1.04447701, -0.50489449],
          [0.53987738, 1.05173636],
          [-1.38751061, -0.32845505],
          [0.99291866, 0.97132263],
          [-1.73164148, -0.43845562],
          [-0.35493331, -1.20752094],
          [0.70510875, 1.71681314],
          [1.17456328, 0.8081482],
          [-1.73524259, -0.08179708],
          [1.33510777, -1.29449661],
          [1.40129845, -0.46490155],
          [0.60100499, -1.27490336],
          [0.46883888, -0.65522071],
          [-0.2177715, 0.71514667],
          [0.58563711, -0.51343774],
          [2.12903658, -0.42352534],
          [-1.00353943, 1.52912482],
          [1.56155776, -1.11591808],
          [0.45842855, -0.90442994],
          [1.46862236, -0.63754612],
          [-1.07072954, -0.29290502],
          [0.59149966, 2.07200737],
          [-0.53328187, -0.09812247],
          [0.57856061, 1.25380552],
          [-1.20068112, -0.13575126],
          [1.23158268, -1.12040882],
          [-0.00765717, 1.35904775],
          [-0.69083416, -1.16889531],
          [-0.02545807, 1.20217886],
          [0.94617226, 1.44243549],
          [0.29086207, 1.01345361],
          [1.03499227, -1.1460714],
          [-0.99511549, -1.2719868],
          [0.68988861, -0.98239901],
          [0.4992885, -0.45776737],
          [1.43970521, -0.96922289],
          [-1.40195225, -0.11406939],
          [-0.36197503, -0.43667089],
          [-1.39222656, -1.00220379],
          [-1.77921414, -0.41500311],
          [-0.58677785, 1.46385462],
          [-0.13700889, 1.1815076],
          [0.13674289, 1.45207184],
          [0.71637898, -1.02921795],
          [-1.32646199, -0.47210034],
          [-0.32408855, 0.81989587],
          [-0.25494992, 1.43790209],
          [-0.3369623, -0.24018957],
          [-0.15423336, 1.18259797],
          [1.28906178, -0.80523052],
          [0.102854, 1.44600664],
          [-0.57941603, 1.23281117],
          [-0.41633082, 1.57149166],
          [-1.5943798, -0.09327404],
          [-0.25148103, 0.62809738],
          [-1.03541767, -1.122577],
          [-1.50997304, -0.73567119],
          [1.12800861, -0.52936333],
          [1.48152185, -0.7166135],
          [0.49317939, 1.3471885],
          [0.89285887, -1.21682128],
          [0.53450078, -0.89630194],
          [-0.08070404, 1.15574202],
          [0.82780654, -0.08661587],
          [0.34993344, 1.41969689],
          [0.71872377, -0.21881327],
          [0.82424386, 1.23540439],
          [-1.34613015, -0.72734816],
          [-1.59870257, -0.94453715],
          [0.37923491, -0.59627639],
          [0.17013686, 1.24155203],
          [-1.68900873, -0.38591244],
          [0.64047447, -0.36451399],
          [-1.14810577, -0.6090688],
          [-0.77761455, -0.79957477],
          [-0.87614308, -0.70244297]
        ])

      expected_labels =
        [
          [0, 1, 0, 1, 1, 0, 0, 1, 2, -1, 1, 2, 2, 1],
          [-1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 2, 0],
          [2, 2, 0, -1, 2, 1, 1, 1, 1, 0, 1, -1, -1, 1],
          [1, 1, 2, 0, -1, 0, 2, 1, 0, 2, 0, 0, 0, 1],
          [2, 1, 1, 1, 2, -1, 2, 2, 0, 0, 0, 1, 2, 0],
          [0, -1, 0, 1, 0, 0, 0, 2, 0, 2, 2, 1, 1, 0],
          [1, 1, 0, 1, 0, 1, 0, 2, 2, 1, 0, 2, 1, 2, 2, 2]
        ]
        |> Enum.concat()
        |> Nx.tensor()

      expected_core_samples =
        [
          [1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1],
          [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
          [1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1],
          [1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        ]
        |> Enum.concat()
        |> Nx.tensor(type: :u8)

      %DBSCAN{labels: labels, core_sample_indices: core_mask} =
        DBSCAN.fit(blobs, eps: 0.4, min_samples: 4)

      assert labels == expected_labels
      assert core_mask == expected_core_samples

      # Check if algorithm predicted the correct number of clusters
      # (labels are zero-based, so the cluster count is the largest label plus one).
      found_clusters = labels |> Nx.reduce_max() |> Nx.add(1)
      assert found_clusters == Nx.tensor(3)
    end
  end
end