diff --git a/python/cuml/linear_model/base.pyx b/python/cuml/linear_model/base.pyx index 71bd79ed79..4e43a2ee29 100644 --- a/python/cuml/linear_model/base.pyx +++ b/python/cuml/linear_model/base.pyx @@ -68,9 +68,14 @@ class LinearPredictMixin: Predicts `y` values for `X`. """ - coef_cp, n_feat, n_targets, _ = input_to_cupy_array(self.coef_) - if 1 < n_targets: + if self.coef_ is None: + raise ValueError( + "LinearModel.predict() cannot be called before fit(). " + "Please fit the model first." + ) + if len(self.coef_.shape) == 2 and self.coef_.shape[1] > 1: # Handle multi-target prediction in Python. + coef_cp = input_to_cupy_array(self.coef_).array X_cp = input_to_cupy_array( X, check_dtype=self.dtype, @@ -79,8 +84,8 @@ class LinearPredictMixin: ).array intercept_cp = input_to_cupy_array(self.intercept_).array preds_cp = X_cp @ coef_cp + intercept_cp - preds = input_to_cuml_array(preds_cp).array - return preds + # preds = input_to_cuml_array(preds_cp).array # TODO:remove + return preds_cp # Handle single-target prediction in C++ X_m, n_rows, n_cols, dtype = \ diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index f6e616d1d4..5da18cba70 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -322,12 +322,10 @@ class LinearRegression(Base, self.algo = 0 if 1 < y_cols: - del X_m - del y_m - if sample_weight is not None: - del sample_weight_m + if sample_weight is None: + sample_weight_m = None - return self._fit_multi_target(X, y, convert_dtype, sample_weight) + return self._fit_multi_target(X_m, y_m, convert_dtype, sample_weight_m) self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr @@ -443,3 +441,7 @@ class LinearRegression(Base, def get_attributes_names(self): return ['coef_', 'intercept_'] + + @staticmethod + def _more_static_tags(): + return {"multioutput": True}