Skip to content

Commit

Permalink
create both sample_weight and freq_weight for fit (#439)
Browse files Browse the repository at this point in the history
* Make StatsmodelsLinearRegresssion allow fractional weight for each individual sample, which is sample_weight. Rename the weights used as the count of observations for corresponding sample_var as freq_weight.
* For all the child classes inherit from _OrthoLearner, only expose freq_weight and sample_weight when the final stage is StatsmodelsLinearRegression or its parent class.
* Fix a few places to make sure it works consistently with un-summarized data when both weights exist.
  • Loading branch information
heimengqi authored Mar 31, 2021
1 parent 056fb30 commit 9630421
Show file tree
Hide file tree
Showing 18 changed files with 875 additions and 645 deletions.
65 changes: 44 additions & 21 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def predict(self, X, y, W=None):


CachedValues = namedtuple('_CachedValues', ['nuisances',
'Y', 'T', 'X', 'W', 'Z', 'sample_weight', 'sample_var', 'groups'])
'Y', 'T', 'X', 'W', 'Z', 'sample_weight', 'freq_weight',
'sample_var', 'groups'])


class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
Expand Down Expand Up @@ -448,9 +449,9 @@ def _gen_ortho_learner_model_nuisance(self):
`fit` and `predict` methods that both have signatures::
model_nuisance.fit(Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight, sample_var=sample_var)
sample_weight=sample_weight)
model_nuisance.predict(Y, T, X=X, W=W, Z=Z,
sample_weight=sample_weight, sample_var=sample_var)
sample_weight=sample_weight)
In fact we allow for the model method signatures to skip any of the keyword arguments
as long as the class is always called with the omitted keyword argument set to ``None``.
Expand All @@ -473,7 +474,7 @@ def _gen_ortho_learner_model_final(self):
Must implement `fit` and `predict` methods that must have signatures::
model_final.fit(Y, T, X=X, W=W, Z=Z, nuisances=nuisances,
sample_weight=sample_weight, sample_var=sample_var)
sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var)
model_final.predict(X=X)
Predict, should just take the features X and return the constant marginal effect. In fact we allow
Expand Down Expand Up @@ -515,7 +516,7 @@ def _subinds_check_none(self, var, inds):
return var[inds] if var is not None else None

def _strata(self, Y, T, X=None, W=None, Z=None,
sample_weight=None, sample_var=None, groups=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, only_final=False, check_input=True):
if self.discrete_instrument:
Z = LabelEncoder().fit_transform(np.ravel(Z))
Expand Down Expand Up @@ -545,7 +546,7 @@ def _prefit(self, Y, T, *args, only_final=False, **kwargs):
@_deprecate_positional("X, W, and Z should be passed by keyword only. In a future release "
"we will disallow passing X, W, and Z by position.", ['X', 'W', 'Z'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None, only_final=False, check_input=True):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand All @@ -562,10 +563,15 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
Controls for each sample
Z: optional (n, d_z) matrix or None (Default=None)
Instruments for each sample
sample_weight: optional (n,) vector or None (Default=None)
Weights for each samples
sample_var: optional (n,) vector or None (Default=None)
Sample variance for each sample
sample_weight : (n,) array like, default None
Individual weights for each sample. If None, it assumes equal weight.
freq_weight: (n, ) array like of integers, default None
Weight for the observation. Observation i is treated as the mean
outcome of freq_weight[i] independent observations.
When ``sample_var`` is not None, this should be provided.
sample_var : {(n,), (n, d_y)} nd array like, default None
Variance of the outcome(s) of the original freq_weight[i] observations that were used to
compute the mean outcome represented by observation i.
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the cv argument passed to this class's initializer
Expand All @@ -589,10 +595,12 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
self : object
"""
self._random_state = check_random_state(self.random_state)
assert (freq_weight is None) == (
sample_var is None), "Sample variances and frequency weights must be provided together!"
if check_input:
Y, T, X, W, Z, sample_weight, sample_var, groups = check_input_arrays(
Y, T, X, W, Z, sample_weight, sample_var, groups)
self._check_input_dims(Y, T, X, W, Z, sample_weight, sample_var, groups)
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

if not only_final:

Expand All @@ -614,11 +622,21 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No

all_nuisances = []
fitted_inds = None
self._models_nuisance = []
if sample_weight is None:
if freq_weight is not None:
sample_weight_nuisances = freq_weight
else:
sample_weight_nuisances = None
else:
if freq_weight is not None:
sample_weight_nuisances = freq_weight * sample_weight
else:
sample_weight_nuisances = sample_weight

self._models_nuisance = []
for idx in range(self.mc_iters or 1):
nuisances, fitted_models, new_inds, scores = self._fit_nuisances(
Y, T, X, W, Z, sample_weight=sample_weight, groups=groups)
Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups)
all_nuisances.append(nuisances)
self._models_nuisance.append(fitted_models)
if scores is None:
Expand All @@ -644,12 +662,14 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
raise ValueError(
"Parameter `mc_agg` must be one of {'mean', 'median'}. Got {}".format(self.mc_agg))

Y, T, X, W, Z, sample_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
for arr in (Y, T, X, W, Z, sample_weight, sample_var))
Y, T, X, W, Z, sample_weight, freq_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
for arr in (Y, T, X, W, Z, sample_weight,
freq_weight, sample_var))
nuisances = tuple([self._subinds_check_none(nuis, fitted_inds) for nuis in nuisances])
self._cached_values = CachedValues(nuisances=nuisances,
Y=Y, T=T, X=X, W=W, Z=Z,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var,
groups=groups) if cache_values else None
else:
Expand All @@ -664,6 +684,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var)

return self
Expand Down Expand Up @@ -698,7 +719,7 @@ def refit_final(self, inference=None):
cached = self._cached_values
kwargs = filter_none_kwargs(
Y=cached.Y, T=cached.T, X=cached.X, W=cached.W, Z=cached.Z,
sample_weight=cached.sample_weight, sample_var=cached.sample_var,
sample_weight=cached.sample_weight, freq_weight=cached.freq_weight, sample_var=cached.sample_var,
groups=cached.groups,
)
_OrthoLearner.fit(self, **kwargs,
Expand Down Expand Up @@ -748,17 +769,19 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
sample_weight=sample_weight, groups=groups)
return nuisances, fitted_models, fitted_inds, scores

def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None,
freq_weight=None, sample_var=None):
self._ortho_learner_model_final.fit(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var))
self.score_ = None
if hasattr(self._ortho_learner_model_final, 'score'):
self.score_ = self._ortho_learner_model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
sample_var=sample_var))
sample_weight=sample_weight)
)

def const_marginal_effect(self, X=None):
X, = check_input_arrays(X)
Expand Down
28 changes: 17 additions & 11 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,16 @@ class _ModelFinal:
def __init__(self, model_final):
self._model_final = model_final

def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, freq_weight=None, sample_var=None):
Y_res, T_res = nuisances
self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight, sample_var=sample_var)
self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
return self

def predict(self, X=None):
return self._model_final.predict(X)

def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None):
Y_res, T_res = nuisances
if Y_res.ndim == 1:
Y_res = Y_res.reshape((-1, 1))
Expand Down Expand Up @@ -200,7 +201,7 @@ def fit(self, X, W, Y, sample_weight=None):
def predict(self, X, W):
return self._model.predict(np.hstack([X, W]))
class ModelFinal:
def fit(self, X, T, T_res, Y_res, sample_weight=None, sample_var=None):
def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None):
self.model = LinearRegression(fit_intercept=False).fit(X * T_res.reshape(-1, 1),
Y_res)
return self
Expand Down Expand Up @@ -313,7 +314,7 @@ def _gen_rlearner_model_final(self):
should just take the features and return the constant marginal effect. More, concretely::
model_final.fit(X, T_res, Y_res,
sample_weight=sample_weight, sample_var=sample_var)
sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var)
model_final.predict(X)
"""
pass
Expand All @@ -326,7 +327,7 @@ def _gen_ortho_learner_model_final(self):

@_deprecate_positional("X, and should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand All @@ -341,10 +342,15 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
Features for each sample
W: optional(n, d_w) matrix or None (Default=None)
Controls for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each samples
sample_var: optional(n,) vector or None (Default=None)
Sample variance for each sample
sample_weight : (n,) array like, default None
Individual weights for each sample. If None, it assumes equal weight.
freq_weight: (n, ) array like of integers, default None
Weight for the observation. Observation i is treated as the mean
outcome of freq_weight[i] independent observations.
When ``sample_var`` is not None, this should be provided.
sample_var : {(n,), (n, d_y)} nd array like, default None
Variance of the outcome(s) of the original freq_weight[i] observations that were used to
compute the mean outcome represented by observation i.
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the `cv` argument passed to this class's initializer
Expand All @@ -361,7 +367,7 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
"""
# Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values,
inference=inference)

Expand Down
28 changes: 7 additions & 21 deletions econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,11 @@ def _ate_and_stderr(self, drpreds, mask=None):
stderr = (np.nanstd(drpreds, axis=0) / np.sqrt(nonnan)).reshape(self._d_y + self._d_t)
return point, stderr

def fit(self, X, T, T_res, Y_res, sample_weight=None, sample_var=None):
def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None):
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
self._d_t = shape(T_res)[1:]
self._d_y = shape(Y_res)[1:]
fts = self._combine(X)
if sample_var is not None:
raise ValueError("This estimator does not support sample_var!")
if T_res.ndim == 1:
T_res = T_res.reshape((-1, 1))
if Y_res.ndim == 1:
Expand Down Expand Up @@ -608,7 +606,7 @@ def tunable_params(self):
'honest', 'inference', 'fit_intercept', 'subforest_size']

def tune(self, Y, T, *, X=None, W=None,
sample_weight=None, sample_var=None, groups=None,
sample_weight=None, groups=None,
params='auto'):
"""
Tunes the major hyperparameters of the final stage causal forest based on out-of-sample R-score
Expand All @@ -630,9 +628,6 @@ def tune(self, Y, T, *, X=None, W=None,
Controls for each sample
sample_weight: optional (n,) vector
Weights for each row
sample_var: optional (n, n_y) vector
Variance of sample, in case it corresponds to summary of many samples. Currently
not in use by this method (as inference method does not require sample variance info).
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the `cv` argument passed to this class's initializer
Expand Down Expand Up @@ -674,10 +669,6 @@ def tune(self, Y, T, *, X=None, W=None,
sample_weight_train, sample_weight_val = sample_weight[train], sample_weight[test]
else:
sample_weight_train, sample_weight_val = None, None
if sample_var is not None:
sample_var_train, _ = sample_var[train], sample_var[test]
else:
sample_var_train, _ = None, None

est = clone(self, safe=False)
est.n_estimators = 100
Expand All @@ -696,7 +687,7 @@ def tune(self, Y, T, *, X=None, W=None,
setattr(est, key, value)
if it == 0:
est.fit(ytrain, Ttrain, X=Xtrain, W=Wtrain, sample_weight=sample_weight_train,
sample_var=sample_var_train, groups=groups_train, cache_values=True)
groups=groups_train, cache_values=True)
else:
est.refit_final()
scores.append((scorer.score(est), tuple(zip(names, values))))
Expand All @@ -712,7 +703,7 @@ def tune(self, Y, T, *, X=None, W=None,

@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand All @@ -727,11 +718,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
Features for each sample
W: optional (n × d_w) matrix
Controls for each sample
sample_weight: optional (n,) vector
Weights for each row
sample_var: optional (n, n_y) vector
Variance of sample, in case it corresponds to summary of many samples. Currently
not in use by this method (as inference method does not require sample variance info).
sample_weight : (n,) array like or None
Individual weights for each sample. If None, it assumes equal weight.
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the `cv` argument passed to this class's initializer
Expand All @@ -747,13 +735,11 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
-------
self
"""
if sample_var is not None:
raise ValueError("This estimator does not support sample_var!")
if X is None:
raise ValueError("This estimator does not support X=None!")
Y, T, X, W = check_inputs(Y, T, X, W=W, multi_output_T=True, multi_output_Y=True)
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
sample_weight=sample_weight, groups=groups,
cache_values=cache_values,
inference=inference)

Expand Down
Loading

0 comments on commit 9630421

Please sign in to comment.