Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix for OPG KNN Classifier & Regressor #2844

Merged
merged 13 commits into from
Oct 6, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
- PR #2842: KNN index preprocessors were using incorrect n_samples
- PR #2848: Fix typo in Python docstring for UMAP
- PR #2855: Updates for RMM being header only
- PR #2844: Fix for OPG KNN Classifier & Regressor

# cuML 0.15.0 (Date TBD)

Expand Down
289 changes: 170 additions & 119 deletions cpp/src/knn/knn_opg_common.cu

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,12 @@ void class_probs(std::vector<float *> &out, const int64_t *knn_indices,
* knn_indices and labels
*/
device_buffer<int> y_normalized(allocator, stream, n_index_rows);
MLCommon::Label::make_monotonic(y_normalized.data(), y[i], n_index_rows,
stream, allocator);
device_buffer<int> y_tmp(allocator, stream, n_index_rows + n_unique_labels);
updateDevice(y_tmp.data(), y[i], n_index_rows, stream);
updateDevice(y_tmp.data() + n_index_rows, uniq_labels[i], n_unique_labels,
stream);
MLCommon::Label::make_monotonic(y_normalized.data(), y_tmp.data(),
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
y_tmp.size(), stream, allocator);
MLCommon::LinAlg::unaryOp<int>(
y_normalized.data(), y_normalized.data(), n_index_rows,
[] __device__(int input) { return input - 1; }, stream);
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def score(self, X, y, convert_dtype=True):
-------
score
"""
labels, _, _ = self.predict(X, convert_dtype=convert_dtype)
diff = (labels == y)
if self.data_handler.datatype == 'cupy':
preds, _, _ = self.predict(X, convert_dtype=convert_dtype)
diff = (preds == y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "matched" or something similar would be a better name.

mean = da.mean(diff)
return mean.compute()
else:
Expand Down
10 changes: 6 additions & 4 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,12 @@ def score(self, X, y):
-------
score
"""
labels, _, _ = self.predict(X, convert_dtype=True)
diff = (labels == y)
if self.data_handler.datatype == 'cupy':
mean = da.mean(diff)
return mean.compute()
preds, _, _ = self.predict(X, convert_dtype=True)
y_mean = y.mean(axis=0)
residual_sss = ((y - preds) ** 2).sum(axis=0)
total_sss = ((y - y_mean) ** 2).sum(axis=0)
r2_score = da.mean(1 - (residual_sss / total_sss))
return r2_score.compute()
else:
raise ValueError("Only Dask arrays are supported")
2 changes: 1 addition & 1 deletion python/cuml/neighbors/kneighbors_regressor_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class KNeighborsRegressorMG(KNeighborsMG):
query, query_parts_to_ranks, query_nrows,
ncols, rank, convert_dtype)

output = self.gen_local_output(data, convert_dtype, dtype='int32')
output = self.gen_local_output(data, convert_dtype, dtype='float32')

query_cais = input['cais']['query']
local_query_rows = list(map(lambda x: x.shape[0], query_cais))
Expand Down
126 changes: 42 additions & 84 deletions python/cuml/test/dask/test_kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def generate_dask_array(np_array, n_parts):
@pytest.fixture(
scope="module",
params=[
unit_param({'n_samples': 1000, 'n_features': 30,
unit_param({'n_samples': 3000, 'n_features': 30,
'n_classes': 5, 'n_targets': 2}),
quality_param({'n_samples': 5000, 'n_features': 100,
'n_classes': 12, 'n_targets': 4}),
stress_param({'n_samples': 12000, 'n_features': 40,
'n_classes': 5, 'n_targets': 2})
quality_param({'n_samples': 8000, 'n_features': 35,
'n_classes': 12, 'n_targets': 3}),
stress_param({'n_samples': 20000, 'n_features': 40,
'n_classes': 12, 'n_targets': 4})
])
def dataset(request):
X, y = make_multilabel_classification(
Expand All @@ -69,18 +69,14 @@ def dataset(request):
if len(new_x) >= request.param['n_samples']:
break
X = X[new_x]
noise = np.random.normal(0, 1.2, X.shape)
X += noise
y = np.array(new_y)

return train_test_split(X, y, test_size=0.33)
return train_test_split(X, y, test_size=0.1)


def accuracy_score(y_true, y_pred):
assert y_pred.shape[0] == y_true.shape[0]
assert y_pred.shape[1] == y_true.shape[1]
return np.mean(y_pred == y_true)


def match_test(output1, output2):
def exact_match(output1, output2):
l1, i1, d1 = output1
l2, i2, d2 = output2
l2 = l2.squeeze()
Expand All @@ -90,130 +86,92 @@ def match_test(output1, output2):
assert i1.shape == i2.shape
assert d1.shape == d2.shape

# Distances should strictly match
# Distances should match
d1 = np.round(d1, 4)
d2 = np.round(d2, 4)
assert np.array_equal(d1, d2)

# Indices might differ for equivalent distances
for i in range(d1.shape[0]):
idx_set1, idx_set2 = (set(), set())
dist = 0.
for j in range(d1.shape[1]):
if d1[i, j] > dist:
assert idx_set1 == idx_set2
idx_set1, idx_set2 = (set(), set())
dist = d1[i, j]
idx_set1.add(i1[i, j])
idx_set2.add(i2[i, j])
# the last set of indices is not guaranteed
# Indices should strictly match
assert np.array_equal(i1, i2)

# As indices might differ, labels can also differ
# assert np.mean((l1 == l2)) > 0.6
# Labels should strictly match
assert np.array_equal(l1, l2)


def check_probabilities(l_probas, d_probas):
assert len(l_probas) == len(d_probas)
for i in range(len(l_probas)):
assert l_probas[i].shape == d_probas[i].shape
assert np.array_equal(l_probas[i], d_probas[i])


@pytest.mark.parametrize("datatype", ['dask_array', 'dask_cudf'])
@pytest.mark.parametrize("n_neighbors", [1, 3, 6])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
@pytest.mark.parametrize("batch_size", [256, 512, 1024])
def test_predict(dataset, datatype, n_neighbors, n_parts, batch_size, client):
@pytest.mark.parametrize("n_neighbors", [1, 3, 8])
@pytest.mark.parametrize("n_parts", [2, 4, 12])
@pytest.mark.parametrize("batch_size", [128, 1024])
def test_predict_and_score(dataset, datatype, n_neighbors,
n_parts, batch_size, client):
X_train, X_test, y_train, y_test = dataset
np_y_test = y_test

l_model = lKNNClf(n_neighbors=n_neighbors)
l_model.fit(X_train, y_train)
l_distances, l_indices = l_model.kneighbors(X_test)
l_labels = l_model.predict(X_test)
local_out = (l_labels, l_indices, l_distances)

if not n_parts:
n_parts = len(client.has_what().keys())
handmade_local_score = np.mean(y_test == l_labels)
handmade_local_score = round(handmade_local_score, 3)

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
y_test = generate_dask_array(y_test, n_parts)

if datatype == 'dask_cudf':
X_train = to_dask_cudf(X_train, client)
X_test = to_dask_cudf(X_test, client)
y_train = to_dask_cudf(y_train, client)
y_test = to_dask_cudf(y_test, client)

d_model = dKNNClf(client=client, n_neighbors=n_neighbors,
batch_size=batch_size)
d_model.fit(X_train, y_train)
d_labels, d_indices, d_distances = \
d_model.predict(X_test, convert_dtype=True)
distributed_out = da.compute(d_labels, d_indices, d_distances)
if datatype == 'dask_array':
distributed_score = d_model.score(X_test, y_test)
distributed_score = round(distributed_score, 3)

if datatype == 'dask_cudf':
distributed_out = list(map(lambda o: o.as_matrix()
if isinstance(o, DataFrame)
else o.to_array()[..., np.newaxis],
distributed_out))

match_test(local_out, distributed_out)
assert accuracy_score(y_test, distributed_out[0]) > 0.12


@pytest.mark.skip(reason="Sometimes incorrect labels are returned")
@pytest.mark.parametrize("datatype", ['dask_array'])
@pytest.mark.parametrize("n_neighbors", [1, 2, 3])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
def test_score(dataset, datatype, n_neighbors, n_parts, client):
X_train, X_test, y_train, y_test = dataset

if not n_parts:
n_parts = len(client.has_what().keys())
exact_match(local_out, distributed_out)

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
y_test = generate_dask_array(y_test, n_parts)

if datatype == 'dask_cudf':
X_train = to_dask_cudf(X_train, client)
X_test = to_dask_cudf(X_test, client)
y_train = to_dask_cudf(y_train, client)
y_test = to_dask_cudf(y_test, client)

d_model = dKNNClf(client=client, n_neighbors=n_neighbors)
d_model.fit(X_train, y_train)
d_labels, d_indices, d_distances = \
d_model.predict(X_test, convert_dtype=True)
distributed_out = da.compute(d_labels, d_indices, d_distances)

if datatype == 'dask_cudf':
distributed_out = list(map(lambda o: o.as_matrix()
if isinstance(o, DataFrame)
else o.to_array()[..., np.newaxis],
distributed_out))
cuml_score = d_model.score(X_test, y_test)

if datatype == 'dask_cudf':
y_test = y_test.compute().as_matrix()
if datatype == 'dask_array':
assert distributed_score == handmade_local_score
else:
y_test = y_test.compute()
manual_score = np.mean(y_test == distributed_out[0])

assert cuml_score == manual_score
y_pred = distributed_out[0]
handmade_distributed_score = np.mean(np_y_test == y_pred)
handmade_distributed_score = round(handmade_distributed_score, 3)
assert handmade_distributed_score == handmade_local_score


@pytest.mark.parametrize("datatype", ['dask_array', 'dask_cudf'])
@pytest.mark.parametrize("n_neighbors", [1, 3, 6])
@pytest.mark.parametrize("n_parts", [None, 2, 3, 5])
def test_predict_proba(dataset, datatype, n_neighbors, n_parts, client):
@pytest.mark.parametrize("n_neighbors", [1, 3, 8])
@pytest.mark.parametrize("n_parts", [2, 4, 12])
@pytest.mark.parametrize("batch_size", [128, 1024])
def test_predict_proba(dataset, datatype, n_neighbors,
n_parts, batch_size, client):
X_train, X_test, y_train, y_test = dataset

l_model = lKNNClf(n_neighbors=n_neighbors)
l_model.fit(X_train, y_train)
l_probas = l_model.predict_proba(X_test)

if not n_parts:
n_parts = len(client.has_what().keys())

X_train = generate_dask_array(X_train, n_parts)
X_test = generate_dask_array(X_test, n_parts)
y_train = generate_dask_array(y_train, n_parts)
Expand Down
Loading