Skip to content

Commit

Permalink
[REVIEW] Fix for OPG KNN Classifier & Regressor (#2844)
Browse files Browse the repository at this point in the history
* Fix for OPG KNN Classifier & Regressor

* Multiple additional fixes

* Pytests update

* Code style update

* Changelog update

* Requested changes

* Trying something to make CI pass

* Requested changes

* Dealing with distance imprecisions

* Updating tests so that they pass on all platforms

Co-authored-by: John Zedlewski <[email protected]>
  • Loading branch information
viclafargue and JohnZed authored Oct 6, 2020
1 parent d7e1c46 commit 726096a
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 305 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
- PR #2848: Fix typo in Python docstring for UMAP
- PR #2856: Fix LabelEncoder for filtered input
- PR #2855: Updates for RMM being header only
- PR #2844: Fix for OPG KNN Classifier & Regressor
- PR #2880: Fix bugs in Auto-ARIMA when s==None
- PR #2877: TSNE exception for n_components > 2
- PR #2879: Update unit test for LabelEncoder on filtered input
Expand Down
356 changes: 226 additions & 130 deletions cpp/src/knn/knn_opg_common.cu

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,19 @@ void class_probs(std::vector<float *> &out, const int64_t *knn_indices,
* knn_indices and labels
*/
device_buffer<int> y_normalized(allocator, stream, n_index_rows);
MLCommon::Label::make_monotonic(y_normalized.data(), y[i], n_index_rows,
stream, allocator);

/*
* Append the array of unique labels to the original labels array so that
* make_monotonic does not produce misleading results when some of the
* unique labels are absent from the labels array
*/
device_buffer<int> y_tmp(allocator, stream, n_index_rows + n_unique_labels);
raft::update_device(y_tmp.data(), y[i], n_index_rows, stream);
raft::update_device(y_tmp.data() + n_index_rows, uniq_labels[i],
n_unique_labels, stream);

MLCommon::Label::make_monotonic(y_normalized.data(), y_tmp.data(),
y_tmp.size(), stream, allocator);
raft::linalg::unaryOp<int>(
y_normalized.data(), y_normalized.data(), n_index_rows,
[] __device__(int input) { return input - 1; }, stream);
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def score(self, X, y, convert_dtype=True):
-------
score
"""
labels, _, _ = self.predict(X, convert_dtype=convert_dtype)
diff = (labels == y)
if self.data_handler.datatype == 'cupy':
preds, _, _ = self.predict(X, convert_dtype=convert_dtype)
diff = (preds == y)
mean = da.mean(diff)
return mean.compute()
else:
Expand Down
10 changes: 6 additions & 4 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,12 @@ def score(self, X, y):
-------
score
"""
labels, _, _ = self.predict(X, convert_dtype=True)
diff = (labels == y)
if self.data_handler.datatype == 'cupy':
mean = da.mean(diff)
return mean.compute()
preds, _, _ = self.predict(X, convert_dtype=True)
y_mean = y.mean(axis=0)
residual_sss = ((y - preds) ** 2).sum(axis=0)
total_sss = ((y - y_mean) ** 2).sum(axis=0)
r2_score = da.mean(1 - (residual_sss / total_sss))
return r2_score.compute()
else:
raise ValueError("Only Dask arrays are supported")
2 changes: 1 addition & 1 deletion python/cuml/neighbors/kneighbors_regressor_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class KNeighborsRegressorMG(KNeighborsMG):
query, query_parts_to_ranks, query_nrows,
ncols, rank, convert_dtype)

output = self.gen_local_output(data, convert_dtype, dtype='int32')
output = self.gen_local_output(data, convert_dtype, dtype='float32')

query_cais = input['cais']['query']
local_query_rows = list(map(lambda x: x.shape[0], query_cais))
Expand Down
130 changes: 45 additions & 85 deletions python/cuml/test/dask/test_kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def generate_dask_array(np_array, n_parts):
@pytest.fixture(
scope="module",
params=[
unit_param({'n_samples': 1000, 'n_features': 30,
unit_param({'n_samples': 3000, 'n_features': 30,
'n_classes': 5, 'n_targets': 2}),
quality_param({'n_samples': 5000, 'n_features': 100,
'n_classes': 12, 'n_targets': 4}),
stress_param({'n_samples': 12000, 'n_features': 40,
'n_classes': 5, 'n_targets': 2})
quality_param({'n_samples': 8000, 'n_features': 35,
'n_classes': 12, 'n_targets': 3}),
stress_param({'n_samples': 20000, 'n_features': 40,
'n_classes': 12, 'n_targets': 4})
])
def dataset(request):
X, y = make_multilabel_classification(
Expand All @@ -69,18 +69,14 @@ def dataset(request):
if len(new_x) >= request.param['n_samples']:
break
X = X[new_x]
noise = np.random.normal(0, 1.2, X.shape)
X += noise
y = np.array(new_y)

return train_test_split(X, y, test_size=0.33)
return train_test_split(X, y, test_size=0.1)


def accuracy_score(y_true, y_pred):
    """Return the fraction of matching entries between two label arrays.

    Parameters
    ----------
    y_true : numpy.ndarray
        Reference labels.
    y_pred : numpy.ndarray
        Predicted labels; its first two dimensions must match ``y_true``.

    Returns
    -------
    float
        Mean of the element-wise equality ``y_pred == y_true``.

    Raises
    ------
    AssertionError
        If the leading (up to two) dimensions of the inputs differ.
    """
    # Compare the leading dimensions as a tuple slice rather than indexing
    # shape[0] and shape[1] separately, so 1-D label arrays are accepted
    # instead of failing with IndexError on the missing second axis.
    assert y_pred.shape[:2] == y_true.shape[:2]
    return np.mean(y_pred == y_true)


def match_test(output1, output2):
def exact_match(output1, output2):
l1, i1, d1 = output1
l2, i2, d2 = output2
l2 = l2.squeeze()
Expand All @@ -90,130 +86,94 @@ def match_test(output1, output2):
assert i1.shape == i2.shape
assert d1.shape == d2.shape

# Distances should strictly match
assert np.array_equal(d1, d2)
# Distances should match
d1 = np.round(d1, 4)
d2 = np.round(d2, 4)
assert np.mean(d1 == d2) > 0.98

# Indices might differ for equivalent distances
for i in range(d1.shape[0]):
idx_set1, idx_set2 = (set(), set())
dist = 0.
for j in range(d1.shape[1]):
if d1[i, j] > dist:
assert idx_set1 == idx_set2
idx_set1, idx_set2 = (set(), set())
dist = d1[i, j]
idx_set1.add(i1[i, j])
idx_set2.add(i2[i, j])
# the last set of indices is not guaranteed
# Indices should match
correct_queries = (i1 == i2).all(axis=1)
assert np.mean(correct_queries) > 0.95

# As indices might differ, labels can also differ
# assert np.mean((l1 == l2)) > 0.6
# Labels should match
correct_queries = (l1 == l2).all(axis=1)
assert np.mean(correct_queries) > 0.95


def check_probabilities(l_probas, d_probas):
    """Assert that two collections of probability arrays are identical.

    Parameters
    ----------
    l_probas : sequence of numpy.ndarray
        First collection of probability arrays (presumably the local
        model's ``predict_proba`` output -- confirm against the caller).
    d_probas : sequence of numpy.ndarray
        Second collection, compared pairwise against ``l_probas``
        (presumably the distributed model's output -- confirm against
        the caller).

    Raises
    ------
    AssertionError
        If the collections differ in length, or any pair of arrays differs
        in shape or in content.
    """
    # The length check comes first because zip() would silently stop at the
    # shorter collection and hide a missing array.
    assert len(l_probas) == len(d_probas)
    for l_proba, d_proba in zip(l_probas, d_probas):
        assert l_proba.shape == d_proba.shape
        assert np.array_equal(l_proba, d_proba)


@pytest.mark.parametrize("datatype", ['dask_array', 'dask_cudf'])
@pytest.mark.parametrize("n_neighbors", [1, 3, 6])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
@pytest.mark.parametrize("batch_size", [256, 512, 1024])
def test_predict(dataset, datatype, n_neighbors, n_parts, batch_size, client):
@pytest.mark.parametrize("n_neighbors", [1, 3, 8])
@pytest.mark.parametrize("n_parts", [2, 4, 12])
@pytest.mark.parametrize("batch_size", [128, 1024])
def test_predict_and_score(dataset, datatype, n_neighbors,
n_parts, batch_size, client):
X_train, X_test, y_train, y_test = dataset
np_y_test = y_test

l_model = lKNNClf(n_neighbors=n_neighbors)
l_model.fit(X_train, y_train)
l_distances, l_indices = l_model.kneighbors(X_test)
l_labels = l_model.predict(X_test)
local_out = (l_labels, l_indices, l_distances)

if not n_parts:
n_parts = len(client.has_what().keys())
handmade_local_score = np.mean(y_test == l_labels)
handmade_local_score = round(handmade_local_score, 3)

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
y_test = generate_dask_array(y_test, n_parts)

if datatype == 'dask_cudf':
X_train = to_dask_cudf(X_train, client)
X_test = to_dask_cudf(X_test, client)
y_train = to_dask_cudf(y_train, client)
y_test = to_dask_cudf(y_test, client)

d_model = dKNNClf(client=client, n_neighbors=n_neighbors,
batch_size=batch_size)
d_model.fit(X_train, y_train)
d_labels, d_indices, d_distances = \
d_model.predict(X_test, convert_dtype=True)
distributed_out = da.compute(d_labels, d_indices, d_distances)
if datatype == 'dask_array':
distributed_score = d_model.score(X_test, y_test)
distributed_score = round(distributed_score, 3)

if datatype == 'dask_cudf':
distributed_out = list(map(lambda o: o.as_matrix()
if isinstance(o, DataFrame)
else o.to_array()[..., np.newaxis],
distributed_out))

match_test(local_out, distributed_out)
assert accuracy_score(y_test, distributed_out[0]) > 0.12


@pytest.mark.skip(reason="Sometimes incorrect labels are returned")
@pytest.mark.parametrize("datatype", ['dask_array'])
@pytest.mark.parametrize("n_neighbors", [1, 2, 3])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
def test_score(dataset, datatype, n_neighbors, n_parts, client):
X_train, X_test, y_train, y_test = dataset

if not n_parts:
n_parts = len(client.has_what().keys())
exact_match(local_out, distributed_out)

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
y_test = generate_dask_array(y_test, n_parts)

if datatype == 'dask_cudf':
X_train = to_dask_cudf(X_train, client)
X_test = to_dask_cudf(X_test, client)
y_train = to_dask_cudf(y_train, client)
y_test = to_dask_cudf(y_test, client)

d_model = dKNNClf(client=client, n_neighbors=n_neighbors)
d_model.fit(X_train, y_train)
d_labels, d_indices, d_distances = \
d_model.predict(X_test, convert_dtype=True)
distributed_out = da.compute(d_labels, d_indices, d_distances)

if datatype == 'dask_cudf':
distributed_out = list(map(lambda o: o.as_matrix()
if isinstance(o, DataFrame)
else o.to_array()[..., np.newaxis],
distributed_out))
cuml_score = d_model.score(X_test, y_test)

if datatype == 'dask_cudf':
y_test = y_test.compute().as_matrix()
if datatype == 'dask_array':
assert distributed_score == handmade_local_score
else:
y_test = y_test.compute()
manual_score = np.mean(y_test == distributed_out[0])

assert cuml_score == manual_score
y_pred = distributed_out[0]
handmade_distributed_score = np.mean(np_y_test == y_pred)
handmade_distributed_score = round(handmade_distributed_score, 3)
assert handmade_distributed_score == handmade_local_score


@pytest.mark.parametrize("datatype", ['dask_array', 'dask_cudf'])
@pytest.mark.parametrize("n_neighbors", [1, 3, 6])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
def test_predict_proba(dataset, datatype, n_neighbors, n_parts, client):
@pytest.mark.parametrize("n_neighbors", [1, 3, 8])
@pytest.mark.parametrize("n_parts", [2, 4, 12])
@pytest.mark.parametrize("batch_size", [128, 1024])
def test_predict_proba(dataset, datatype, n_neighbors,
n_parts, batch_size, client):
X_train, X_test, y_train, y_test = dataset

l_model = lKNNClf(n_neighbors=n_neighbors)
l_model.fit(X_train, y_train)
l_probas = l_model.predict_proba(X_test)

if not n_parts:
n_parts = len(client.has_what().keys())

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
Expand Down
Loading

0 comments on commit 726096a

Please sign in to comment.