Skip to content

Commit

Permalink
fix gpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahyurick committed Jan 31, 2023
1 parent 4717bde commit 98c42d5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
26 changes: 13 additions & 13 deletions tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ def test_training_and_prediction(c, gpu_client):
)
check_trained_model(c)

c.sql(
f"""
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'LogisticRegression',
wrap_predict = True,
wrap_fit = False,
target_column = 'target'
) AS (
SELECT x, y, x*y > 0 AS target
FROM {timeseries}
c.sql(
f"""
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'LogisticRegression',
wrap_predict = True,
wrap_fit = False,
target_column = 'target'
) AS (
SELECT x, y, x*y > 0 AS target
FROM {timeseries}
)
"""
)
"""
)
check_trained_model(c, df_name=timeseries)
check_trained_model(c, df_name=timeseries)

c.sql(
f"""
Expand Down
7 changes: 1 addition & 6 deletions tests/unit/test_ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ def test_ml_class_mappings(gpu):
if not ("XGB" in key and xgboost is None) and not (
"LGBM" in key and lightgbm is None
):
if gpu and key == "LogisticRegression":
# dask-glm >= 0.2.1.dev needed to use multi-GPU logistic regression
with pytest.raises(ImportError):
import_class(classes_dict[key])
else:
import_class(classes_dict[key])
import_class(classes_dict[key])


def _check_axis_partitioning(chunks, n_features):
Expand Down

0 comments on commit 98c42d5

Please sign in to comment.