Skip to content

Commit

Permalink
modify scale_ to use sample scaler for better interept matching
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Mar 12, 2024
1 parent 2a79828 commit 69bc42f
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ def to_dask_data(X_train, X_test, y_train, y_test):
# if fit_intercept is false, scale the dataset without mean center
scaler = StandardScaler(with_mean=fit_intercept, with_std=True)
scaler.fit(X_train)
scaler.scale_ = np.sqrt(scaler.var_ * len(X_train) / (len(X_train) - 1))
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

Expand Down Expand Up @@ -842,10 +843,6 @@ def to_dask_data(X_train, X_test, y_train, y_test):
np.abs(mgon_accuracy - mgoff_accuracy) < 1e-3
)

print(f"mgon_coef_origin: {mgon_coef_origin}")
print(f"mgoff.coef_: {mgoff.coef_.to_numpy()}")
print(f"mgon_intercept_origin: {mgon_intercept_origin}")
print(f"mgoff.intercept_: {mgoff.intercept_.to_numpy()}")
assert array_equal(mgon_coef_origin, mgoff.coef_.to_numpy(), tolerance)
assert array_equal(
mgon_intercept_origin, mgoff.intercept_.to_numpy(), tolerance
Expand Down

0 comments on commit 69bc42f

Please sign in to comment.