Skip to content

Commit

Permalink
Model selection WIP
Browse files Browse the repository at this point in the history
Signed-off-by: Keith Battocchi <[email protected]>
  • Loading branch information
kbattocchi committed Nov 8, 2023
1 parent 55c5858 commit 356f36e
Show file tree
Hide file tree
Showing 23 changed files with 774 additions and 2,252 deletions.
40 changes: 23 additions & 17 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class in this module implements the general logic in a very versatile way
from .utilities import (_deprecate_positional, check_input_arrays,
cross_product, filter_none_kwargs,
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose)
from .sklearn_extensions.model_selection import ModelSelector

try:
import ray
Expand Down Expand Up @@ -100,7 +101,7 @@ def _fit_fold(model, train_idxs, test_idxs, calculate_scores, args, kwargs):
kwargs_train = {key: var[train_idxs] for key, var in kwargs.items()}
kwargs_test = {key: var[test_idxs] for key, var in kwargs.items()}

model.fit(*args_train, **kwargs_train)
model.train(False, *args_train, **kwargs_train)
nuisance_temp = model.predict(*args_test, **kwargs_test)

if not isinstance(nuisance_temp, tuple):
Expand All @@ -115,17 +116,18 @@ def _fit_fold(model, train_idxs, test_idxs, calculate_scores, args, kwargs):
return nuisance_temp, model, test_idxs, (score_temp if calculate_scores else None)


def _crossfit(model, folds, use_ray, ray_remote_fun_option, *args, **kwargs):
def _crossfit(model: ModelSelector, folds, use_ray, ray_remote_fun_option, *args, **kwargs):
"""
General crossfit based calculation of nuisance parameters.
Parameters
----------
model : object
An object that supports fit and predict. Fit must accept all the args
and the keyword arguments kwargs. Similarly predict must all accept
all the args as arguments and kwards as keyword arguments. The fit
function estimates a model of the nuisance function, based on the input
model : ModelSelector
An object that has train and predict methods.
The train method must take an 'is_selecting' argument first, and then
accept positional arguments `args` and keyword arguments `kwargs`; the predict method
just takes those `args` and `kwargs`. The train
method selects or estimates a model of the nuisance function, based on the input
data to fit. Predict evaluates the fitted nuisance function on the input
data to predict.
folds : list of tuple or None
Expand Down Expand Up @@ -177,7 +179,7 @@ def _crossfit(model, folds, use_ray, ray_remote_fun_option, *args, **kwargs):
class Wrapper:
def __init__(self, model):
self._model = model
def fit(self, X, y, W=None):
def fit(self, is_selecting, X, y, W=None):
self._model.fit(X, y)
return self
def predict(self, X, y, W=None):
Expand All @@ -202,13 +204,17 @@ def predict(self, X, y, W=None):
"""
model_list = []

kwargs = filter_none_kwargs(**kwargs)
model.train(True, *args, **kwargs)

calculate_scores = hasattr(model, 'score')
# remove None arguments
kwargs = filter_none_kwargs(**kwargs)

if folds is None: # skip crossfitting
model_list.append(clone(model, safe=False))
model_list[0].fit(*args, **kwargs)
model_list[0].train(True, *args, **kwargs)
model_list[0].train(False, *args, **kwargs) # fit the selected model
nuisances = model_list[0].predict(*args, **kwargs)
scores = model_list[0].score(*args, **kwargs) if calculate_scores else None

Expand Down Expand Up @@ -394,7 +400,7 @@ class ModelNuisance:
def __init__(self, model_t, model_y):
self._model_t = model_t
self._model_y = model_y
def fit(self, Y, T, W=None):
def train(self, is_selecting, Y, T, W=None):
self._model_t.fit(W, T)
self._model_y.fit(W, Y)
return self
Expand Down Expand Up @@ -448,7 +454,7 @@ class ModelNuisance:
def __init__(self, model_t, model_y):
self._model_t = model_t
self._model_y = model_y
def fit(self, Y, T, W=None):
def train(self, is_selecting, Y, T, W=None):
self._model_t.fit(W, np.matmul(T, np.arange(1, T.shape[1]+1)))
self._model_y.fit(W, Y)
return self
Expand Down Expand Up @@ -532,15 +538,15 @@ def _gen_allowed_missing_vars(self):

@abstractmethod
def _gen_ortho_learner_model_nuisance(self):
""" Must return a fresh instance of a nuisance model
"""Must return a fresh instance of a nuisance model selector
Returns
-------
model_nuisance: estimator
The estimator for fitting the nuisance function. Must implement
`fit` and `predict` methods that both have signatures::
model_nuisance: selector
The selector for fitting the nuisance function. The returned estimator must implement
`train` and `predict` methods that both have signatures::
model_nuisance.fit(Y, T, X=X, W=W, Z=Z,
model_nuisance.train(is_selecting, Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight)
model_nuisance.predict(Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight)
Expand Down
33 changes: 14 additions & 19 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,35 @@
import numpy as np
import copy
from warnings import warn

from ..sklearn_extensions.model_selection import ModelSelector
from ..utilities import (shape, reshape, ndim, hstack, filter_none_kwargs, _deprecate_positional)
from sklearn.linear_model import LinearRegression
from sklearn.base import clone
from .._ortho_learner import _OrthoLearner


class _ModelNuisance:
class _ModelNuisance(ModelSelector):
"""
Nuisance model fits the model_y and model_t at fit time and at predict time
calculates the residual Y and residual T based on the fitted models and returns
the residuals as two nuisance parameters.
"""

def __init__(self, model_y, model_t):
def __init__(self, model_y: ModelSelector, model_t: ModelSelector):
self._model_y = model_y
self._model_t = model_t

def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
def train(self, is_selecting, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
assert Z is None, "Cannot accept instrument!"
self._model_t.fit(X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_y.fit(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_t.train(is_selecting, X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
self._model_y.train(is_selecting, X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups))
return self

def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
if hasattr(self._model_y, 'score'):
# note that groups are not passed to score because they are only used for fitting
Y_score = self._model_y.score(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight))
else:
Y_score = None
if hasattr(self._model_t, 'score'):
# note that groups are not passed to score because they are only used for fitting
T_score = self._model_t.score(X, W, T, **filter_none_kwargs(sample_weight=sample_weight))
else:
T_score = None
# note that groups are not passed to score because they are only used for fitting
T_score = self._model_t.score(X, W, T, **filter_none_kwargs(sample_weight=sample_weight))
Y_score = self._model_y.score(X, W, Y, **filter_none_kwargs(sample_weight=sample_weight))
return Y_score, T_score

def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
Expand Down Expand Up @@ -302,7 +297,7 @@ def _gen_model_y(self):
"""
Returns
-------
model_y: estimator of E[Y | X, W]
model_y: selector for the estimator of E[Y | X, W]
The estimator for fitting the response to the features and controls. Must implement
`fit` and `predict` methods. Unlike sklearn estimators both methods must
take an extra second argument (the controls), i.e. ::
Expand All @@ -317,7 +312,7 @@ def _gen_model_t(self):
"""
Returns
-------
model_t: estimator of E[T | X, W]
model_t: selector for the estimator of E[T | X, W]
The estimator for fitting the treatment to the features and controls. Must implement
`fit` and `predict` methods. Unlike sklearn estimators both methods must
take an extra second argument (the controls), i.e. ::
Expand Down Expand Up @@ -432,11 +427,11 @@ def rlearner_model_final_(self):

@property
def models_y(self):
return [[mdl._model_y for mdl in mdls] for mdls in super().models_nuisance_]
return [[mdl._model_y.best_model for mdl in mdls] for mdls in super().models_nuisance_]

@property
def models_t(self):
return [[mdl._model_t for mdl in mdls] for mdls in super().models_nuisance_]
return [[mdl._model_t.best_model for mdl in mdls] for mdls in super().models_nuisance_]

@property
def nuisance_scores_y(self):
Expand Down
18 changes: 3 additions & 15 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.model_selection import train_test_split
from itertools import product
from .dml import _BaseDML
from .dml import _FirstStageWrapper
from .dml import _make_first_stage_selector
from ..sklearn_extensions.linear_model import WeightedLassoCVWrapper
from ..sklearn_extensions.model_selection import WeightedStratifiedKFold
from ..inference import NormalInferenceResults
Expand Down Expand Up @@ -668,22 +668,10 @@ def _gen_featurizer(self):
return clone(self.featurizer, safe=False)

def _gen_model_y(self):
if self.model_y == 'auto':
model_y = WeightedLassoCVWrapper(random_state=self.random_state)
else:
model_y = clone(self.model_y, safe=False)
return _FirstStageWrapper(model_y, True, self._gen_featurizer(), False, self.discrete_treatment)
return _make_first_stage_selector(self.model_y, False, self.random_state)

def _gen_model_t(self):
if self.model_t == 'auto':
if self.discrete_treatment:
model_t = LogisticRegressionCV(cv=WeightedStratifiedKFold(random_state=self.random_state),
random_state=self.random_state)
else:
model_t = WeightedLassoCVWrapper(random_state=self.random_state)
else:
model_t = clone(self.model_t, safe=False)
return _FirstStageWrapper(model_t, False, self._gen_featurizer(), False, self.discrete_treatment)
return _make_first_stage_selector(self.model_t, self.discrete_treatment, self.random_state)

def _gen_model_final(self):
return MultiOutputGRF(CausalForest(n_estimators=self.n_estimators,
Expand Down
Loading

0 comments on commit 356f36e

Please sign in to comment.