Skip to content

Commit

Permalink
Remove reliance on sklearn internals
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Nov 20, 2020
1 parent 7c71e6a commit 30cee23
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
1 change: 1 addition & 0 deletions econml/ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
cross_product, inverse_onehot, _EncoderWrapper, check_input_arrays,
_RegressionWrapper, deprecated)
from sklearn.model_selection import check_cv
# TODO: consider working around relying on sklearn implementation details
from .sklearn_extensions.model_selection import _cross_val_predict


Expand Down
19 changes: 7 additions & 12 deletions econml/sklearn_extensions/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,16 @@
import threading
import sparse as sp
import itertools
from joblib import effective_n_jobs, Parallel, delayed
from sklearn.utils import check_array, check_X_y, issparse
from sklearn.ensemble.forest import ForestRegressor, _accumulate_prediction
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.base import RegressorMixin
from warnings import catch_warnings, simplefilter, warn
from sklearn.exceptions import DataConversionWarning, NotFittedError
from sklearn.tree._tree import DTYPE, DOUBLE
from sklearn.utils import check_random_state, check_array, compute_sample_weight
from sklearn.utils._joblib import Parallel, delayed
from sklearn.utils.fixes import _joblib_parallel_args
from sklearn.utils.validation import check_is_fitted
from sklearn.ensemble.base import _partition_estimators

MAX_INT = np.iinfo(np.int32).max

Expand Down Expand Up @@ -462,7 +459,7 @@ def fit(self, X, y, sample_weight=None, sample_var=None):
"""

# Validate or convert input data
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
X = check_array(X, accept_sparse="csc", dtype=np.float32)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
Expand Down Expand Up @@ -490,8 +487,8 @@ def fit(self, X, y, sample_weight=None, sample_var=None):

y, expanded_class_weight = self._validate_y_class_weight(y)

if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)
if getattr(y, "dtype", None) != np.float64 or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=np.float64)

if expanded_class_weight is not None:
if sample_weight is not None:
Expand Down Expand Up @@ -555,8 +552,7 @@ def fit(self, X, y, sample_weight=None, sample_var=None):
int(np.ceil(self.subsample_fr_ *
(X.shape[0] // 2))),
replace=False)])
res = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
**_joblib_parallel_args(prefer='threads'))(
res = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer='threads')(
delayed(_parallel_add_trees)(
t, self, X, y, sample_weight, s_inds[i], i, len(trees),
verbose=self.verbose)
Expand Down Expand Up @@ -590,10 +586,9 @@ def _mean_fn(self, X, fn, acc, slice=None):
n_estimators = slice[1] - slice[0]

# Assign chunk of trees to jobs
n_jobs, _, _ = _partition_estimators(n_estimators, self.n_jobs)
n_jobs = min(effective_n_jobs(self.n_jobs), n_estimators)
lock = threading.Lock()
Parallel(n_jobs=n_jobs, verbose=self.verbose,
**_joblib_parallel_args(require="sharedmem"))(
Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
delayed(_accumulate_prediction)(fn(e, n, d), X, [acc], lock)
for e, n, d in estimator_slice)
acc /= n_estimators
Expand Down
1 change: 1 addition & 0 deletions econml/sklearn_extensions/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sklearn.linear_model import LinearRegression, LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, StratifiedKFold
# TODO: consider working around relying on sklearn implementation details
from sklearn.model_selection._split import _CVIterableWrapper
from sklearn.multioutput import MultiOutputRegressor
from sklearn.utils import check_array, check_X_y
Expand Down
18 changes: 8 additions & 10 deletions econml/sklearn_extensions/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
"""Collection of scikit-learn extensions for model selection techniques."""

import numbers
import numpy as np
import scipy.sparse as sp
import warnings
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils.multiclass import type_of_target

from sklearn.base import is_classifier, clone
from sklearn.utils import (indexable, check_random_state, _safe_indexing,
_message_with_time)
from sklearn.utils.validation import _check_fit_params
import numpy as np
import scipy.sparse as sp
from joblib import Parallel, delayed
from sklearn.model_selection import check_cv
from sklearn.base import clone, is_classifier
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
# TODO: consider working around relying on sklearn implementation details
from sklearn.model_selection._validation import (_check_is_permutation,
_fit_and_predict)
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection._validation import _fit_and_predict, _check_is_permutation
from sklearn.utils import indexable
from sklearn.utils.validation import _num_samples


Expand Down

0 comments on commit 30cee23

Please sign in to comment.