Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve discreteness handling, allow binary outcomes #816

Merged
merged 29 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
06f85fe
initial commit for binary outcome, warn when clf passed but disc_trea…
fverac Sep 22, 2023
6bc0660
add init args to drlearner, causalforestdml
fverac Sep 25, 2023
058c3e8
modify bootstrap test to use np array
fverac Sep 25, 2023
a92d140
bugfix causalforest firststagewrapper
fverac Sep 25, 2023
8929eab
fix test bug ortholearner
fverac Sep 25, 2023
1540a08
fix test bugs treatfeat OL doctest
fverac Sep 25, 2023
d39a091
add tests, allow str y, add warnings/errors
fverac Oct 11, 2023
bfb6e67
Merge branch 'main' into fverac/improve_discreteness_handling
fverac Oct 13, 2023
ee64b0e
bugfixes
fverac Oct 27, 2023
5b36a4a
Merge branch 'main' into fverac/improve_discreteness_handling
fverac Oct 27, 2023
5aaee9d
linting
fverac Oct 27, 2023
9064f8b
indent
fverac Oct 27, 2023
c98edbc
linting
fverac Oct 27, 2023
1ff9505
rlearner doctest
fverac Nov 9, 2023
3c4eac7
Merge branch 'main' into fverac/improve_discreteness_handling
fverac Dec 7, 2023
a67eb54
linting
fverac Dec 7, 2023
e104d73
more typos
fverac Dec 7, 2023
edc0b48
bugfixes, docstrings, enable for intenttotreatdrivs
fverac Dec 15, 2023
79a3b07
fix default
fverac Dec 15, 2023
17a0b36
bugfixes
fverac Dec 15, 2023
6ba3b1f
test_binary_outcome bugfix
fverac Jan 2, 2024
5d75de4
adjust tests
fverac Jan 2, 2024
9e7d701
address comments; binary_outcome->discrete_outcome, improve warnings
fverac Jan 5, 2024
0757d39
line endings
fverac Jan 5, 2024
b848e73
fix tests where clf was used without specifying disc treat
fverac Jan 9, 2024
12dae44
rename function, fix warning
fverac Jan 10, 2024
5014d4c
add test for discrete model constraints, fix warning whitespace
fverac Jan 11, 2024
74842eb
fix test
fverac Jan 11, 2024
961cf24
merge main
fverac Jan 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 46 additions & 20 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class in this module implements the general logic in a very versatile way
TreatmentExpansionMixin)
from .inference import BootstrapInference
from .utilities import (_deprecate_positional, check_input_arrays,
cross_product, filter_none_kwargs,
cross_product, filter_none_kwargs, single_strata_from_discrete_arrays,
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)
from .sklearn_extensions.model_selection import ModelSelector

Expand Down Expand Up @@ -327,6 +327,9 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):

Parameters
----------
discrete_outcome: bool
Whether the outcome should be treated as binary

discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities

Expand Down Expand Up @@ -426,7 +429,7 @@ def _gen_ortho_learner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(100, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None,
est = OrthoLearner(cv=2, discrete_outcome=False, discrete_treatment=False, treatment_featurizer=None,
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])

Expand Down Expand Up @@ -484,7 +487,7 @@ def _gen_ortho_learner_model_final(self):
import scipy.special
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
est = OrthoLearner(cv=2, discrete_outcome=False, discrete_treatment=True, discrete_instrument=False,
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, T, W=W)

Expand Down Expand Up @@ -516,11 +519,20 @@ def _gen_ortho_learner_model_final(self):
"""

def __init__(self, *,
discrete_treatment, treatment_featurizer,
discrete_instrument, categories, cv, random_state,
mc_iters=None, mc_agg='mean', allow_missing=False, use_ray=False, ray_remote_func_options=None):
self.actors = []
fverac marked this conversation as resolved.
Show resolved Hide resolved
discrete_outcome,
discrete_treatment,
treatment_featurizer,
discrete_instrument,
categories,
cv,
random_state,
mc_iters=None,
mc_agg='mean',
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
self.cv = cv
self.discrete_outcome = discrete_outcome
self.discrete_treatment = discrete_treatment
self.treatment_featurizer = treatment_featurizer
self.discrete_instrument = discrete_instrument
Expand Down Expand Up @@ -616,20 +628,15 @@ def _subinds_check_none(self, var, inds):
def _strata(self, Y, T, X=None, W=None, Z=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, only_final=False, check_input=True):
arrs = []
if self.discrete_outcome:
arrs.append(Y)
if self.discrete_treatment:
arrs.append(T)
if self.discrete_instrument:
Z = LabelEncoder().fit_transform(np.ravel(Z))
arrs.append(Z)

if self.discrete_treatment:
enc = LabelEncoder()
T = enc.fit_transform(np.ravel(T))
if self.discrete_instrument:
return T + Z * len(enc.classes_)
else:
return T
elif self.discrete_instrument:
return Z
else:
return None
return single_strata_from_discrete_arrays(arrs)

def _prefit(self, Y, T, *args, only_final=False, **kwargs):

Expand Down Expand Up @@ -706,6 +713,20 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N

if not only_final:

if self.discrete_outcome:
self.outcome_transformer = LabelEncoder()
self.outcome_transformer.fit(Y)
if Y.shape[1:] and Y.shape[1] > 1:
raise ValueError(
f"Only one outcome variable is supported when discrete_outcome=True. Got Y of shape {Y.shape}")
if len(self.outcome_transformer.classes_) > 2:
raise AttributeError(
f"({self.outcome_transformer.classes_} outcome classes detected. \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're including the classes themselves rather than their count here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

Currently, only 2 outcome classes are allowed when discrete_outcome=True. \
Classes provided include {self.outcome_transformer.classes_[:5]}")
else:
self.outcome_transformer = None

if self.discrete_treatment:
categories = self.categories
if categories != 'auto':
Expand Down Expand Up @@ -865,7 +886,7 @@ def refit_final(self, inference=None):
def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):

# use a binary array to get stratified split in case of discrete treatment
stratify = self.discrete_treatment or self.discrete_instrument
stratify = self.discrete_treatment or self.discrete_instrument or self.discrete_outcome
strata = self._strata(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight, groups=groups)
if strata is None:
strata = T # always safe to pass T as second arg to split even if we're not actually stratifying
Expand All @@ -878,6 +899,9 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
if self.discrete_instrument:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))

if self.discrete_outcome:
Y = self.outcome_transformer.transform(Y).reshape(-1, 1)

if self.cv == 1: # special case, no cross validation
folds = None
else:
Expand Down Expand Up @@ -1008,6 +1032,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
X, T = self._expand_treatments(X, T)
if self.z_transformer is not None:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))
if self.discrete_outcome:
Y = self.outcome_transformer.transform(Y).reshape(-1, 1)
n_iters = len(self._models_nuisance)
n_splits = len(self._models_nuisance[0])

Expand Down
24 changes: 19 additions & 5 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ class _RLearner(_OrthoLearner):

Parameters
----------
discrete_outcome: bool
Whether the outcome should be treated as binary

discrete_treatment: bool
Whether the treatment values should be treated as categorical, rather than continuous, quantities

Expand Down Expand Up @@ -242,7 +245,7 @@ def _gen_rlearner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
est = RLearner(cv=2, discrete_treatment=False,
est = RLearner(cv=2, discrete_outcome=False, discrete_treatment=False,
treatment_featurizer=None, categories='auto', random_state=None)
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])

Expand Down Expand Up @@ -290,10 +293,21 @@ def _gen_rlearner_model_final(self):
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
"""

def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean', allow_missing=False,
use_ray=False, ray_remote_func_options=None):
super().__init__(discrete_treatment=discrete_treatment,
def __init__(self,
*,
discrete_outcome,
discrete_treatment,
treatment_featurizer,
categories,
cv,
random_state,
mc_iters=None,
mc_agg='mean',
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
super().__init__(discrete_outcome=discrete_outcome,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
Expand Down
49 changes: 36 additions & 13 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,35 @@ class CausalForestDML(_BaseDML):

Parameters
----------
model_y: estimator or 'auto', default 'auto'
The estimator for fitting the response to the features. Must implement
`fit` and `predict` methods.
If 'auto' :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV` will be chosen.

model_t: estimator or 'auto', default 'auto'
The estimator for fitting the treatment to the features.
If estimator, it must implement `fit` and `predict` methods;
If 'auto', :class:`~sklearn.linear_model.LogisticRegressionCV` will be applied for discrete treatment,
and :class:`.WeightedLassoCV`/:class:`.WeightedMultiTaskLassoCV`
will be applied for continuous treatment.
model_y: estimator, {'linear', 'forest'}, list of str/estimator, or 'auto', default 'auto'
Determines how to fit the outcome to the features.

- If an estimator, will use the model as is for fitting.
- If str, will use model associated with the keyword.

- 'linear' - LogisticRegressionCV if discrete_outcome=True else WeightedLassoCVWrapper
- 'forest' - RandomForestClassifier if discrete_outcome=True else RandomForestRegressor
- If list, will perform model selection on the supplied list, which can be a mix of str and estimators, \
and then use the best estimator for fitting.
- If 'auto', model will select over linear and forest models

User-supplied estimators should support 'fit' and 'predict' methods,
and additionally 'predict_proba' if discrete_outcome=True.

model_t: estimator, {'linear', 'forest'}, list of str/estimator, or 'auto', default 'auto'
Determines how to fit the treatment to the features.

- If an estimator, will use the model as is for fitting.
- If str, will use model associated with the keyword.

- 'linear' - LogisticRegressionCV if discrete_treatment=True else WeightedLassoCVWrapper
- 'forest' - RandomForestClassifier if discrete_treatment=True else RandomForestRegressor
- If list, will perform model selection on the supplied list, which can be a mix of str and estimators, \
and then use the best estimator for fitting.
- If 'auto', model will select over linear and forest models

User-supplied estimators should support 'fit' and 'predict' methods,
and additionally 'predict_proba' if discrete_treatment=True.

featurizer : :term:`transformer`, optional
Must support fit_transform and transform. Used to create composite features in the final CATE regression.
Expand All @@ -290,6 +308,9 @@ class CausalForestDML(_BaseDML):
The final CATE will be trained on the outcome of featurizer.fit_transform(T).
If featurizer=None, then CATE is trained on T.

discrete_outcome: bool, default ``False``
Whether the outcome should be treated as binary

discrete_treatment: bool, default ``False``
Whether the treatment values should be treated as categorical, rather than continuous, quantities

Expand Down Expand Up @@ -588,6 +609,7 @@ def __init__(self, *,
model_t='auto',
featurizer=None,
treatment_featurizer=None,
discrete_outcome=False,
discrete_treatment=False,
categories='auto',
cv=2,
Expand Down Expand Up @@ -644,7 +666,8 @@ def __init__(self, *,
self.subforest_size = subforest_size
self.n_jobs = n_jobs
self.verbose = verbose
super().__init__(discrete_treatment=discrete_treatment,
super().__init__(discrete_outcome=discrete_outcome,
discrete_treatment=discrete_treatment,
treatment_featurizer=treatment_featurizer,
categories=categories,
cv=cv,
Expand All @@ -668,7 +691,7 @@ def _gen_featurizer(self):
return clone(self.featurizer, safe=False)

def _gen_model_y(self):
return _make_first_stage_selector(self.model_y, False, self.random_state)
return _make_first_stage_selector(self.model_y, self.discrete_outcome, self.random_state)

def _gen_model_t(self):
return _make_first_stage_selector(self.model_t, self.discrete_treatment, self.random_state)
Expand Down
Loading