Skip to content

Commit

Permalink
Regression errors failing with mixed data type combinations (rapidsai…
Browse files Browse the repository at this point in the history
…#4770)

Resolves rapidsai#4442

This PR fixes the issue with using mixed data types in regression errors like `mean_squared_error`, `mean_absolute_error` and `mean_squared_log_error`.

Authors:
  - Shaswat Anand (https://github.com/shaswat-indian)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: rapidsai#4770
  • Loading branch information
shaswat-indian authored Jun 29, 2022
1 parent 5255757 commit dbc8731
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
12 changes: 6 additions & 6 deletions python/cuml/metrics/regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,19 @@ def _prepare_input_reg(y_true, y_pred, sample_weight, multioutput):
Helper function to avoid code duplication for regression metrics.
Converts inputs to CumlArray and check multioutput parameter validity.
"""
allowed_d_types = [np.float32, np.float64, np.int32, np.int64]
y_true = y_true.squeeze() if len(y_true.shape) > 1 else y_true
y_true, n_rows, n_cols, ytype = \
input_to_cuml_array(y_true, check_dtype=[np.float32, np.float64,
np.int32, np.int64])
input_to_cuml_array(y_true, check_dtype=allowed_d_types)

y_pred = y_pred.squeeze() if len(y_pred.shape) > 1 else y_pred
y_pred, _, _, _ = \
input_to_cuml_array(y_pred, check_dtype=ytype, check_rows=n_rows,
check_cols=n_cols)
input_to_cuml_array(y_pred, check_dtype=allowed_d_types,
check_rows=n_rows, check_cols=n_cols)

if sample_weight is not None:
sample_weight, _, _, _ = \
input_to_cuml_array(sample_weight, check_dtype=ytype,
input_to_cuml_array(sample_weight, check_dtype=allowed_d_types,
check_rows=n_rows, check_cols=n_cols)

raw_multioutput = False
Expand All @@ -134,7 +134,7 @@ def _prepare_input_reg(y_true, y_pred, sample_weight, multioutput):
multioutput = None
elif multioutput is not None:
multioutput, _, _, _ = \
input_to_cuml_array(multioutput, check_dtype=ytype)
input_to_cuml_array(multioutput, check_dtype=allowed_d_types)
if n_cols == 1:
raise ValueError("Custom weights are useful only in "
"multi-output cases.")
Expand Down
17 changes: 10 additions & 7 deletions python/cuml/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,15 +481,18 @@ def test_regression_metrics():


@pytest.mark.parametrize('n_samples', [50, stress_param(500000)])
@pytest.mark.parametrize('dtype', [np.int32, np.int64, np.float32, np.float64])
@pytest.mark.parametrize('y_dtype',
[np.int32, np.int64, np.float32, np.float64])
@pytest.mark.parametrize('pred_dtype',
[np.int32, np.int64, np.float32, np.float64])
@pytest.mark.parametrize('function', ['mse', 'mae', 'msle'])
def test_regression_metrics_random(n_samples, dtype, function):
if dtype == np.float32 and n_samples == 500000:
# stress test for float32 fails because of floating point precision
pytest.xfail()
def test_regression_metrics_random_with_mixed_dtypes(n_samples, y_dtype,
pred_dtype, function):
y_true, _, _, _ = generate_random_labels(
lambda rng: rng.randint(0, 1000, n_samples).astype(y_dtype))

y_true, y_pred, _, _ = generate_random_labels(
lambda rng: rng.randint(0, 1000, n_samples).astype(dtype))
y_pred, _, _, _ = generate_random_labels(
lambda rng: rng.randint(0, 1000, n_samples).astype(pred_dtype))

cuml_reg, sklearn_reg = {
'mse': (mean_squared_error, sklearn_mse),
Expand Down

0 comments on commit dbc8731

Please sign in to comment.