
Commit

Address feedback
kbattocchi committed May 5, 2021
1 parent 968aea9 commit da0ff45
Showing 2 changed files with 276 additions and 70 deletions.
281 changes: 211 additions & 70 deletions econml/solutions/causal_analysis/_causal_analysis.py
@@ -14,10 +14,11 @@
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV
from sklearn.preprocessing import OneHotEncoder, PolynomialFeatures
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import OneHotEncoder, PolynomialFeatures, StandardScaler
from sklearn.utils.validation import column_or_1d
from ...cate_interpreter import SingleTreeCateInterpreter, SingleTreePolicyInterpreter
from ...dml import LinearDML
from ...dml import LinearDML, CausalForestDML
from ...sklearn_extensions.linear_model import WeightedLasso
from ...sklearn_extensions.model_selection import GridSearchCVList
from ...utilities import _RegressionWrapper, inverse_onehot
@@ -101,7 +102,7 @@ def _get_metadata_causal_insights_keys():

def _first_stage_reg(X, y, *, automl=True):
if automl:
model = GridSearchCVList([LassoCV(),
model = GridSearchCVList([make_pipeline(StandardScaler(), LassoCV()),
RandomForestRegressor(
n_estimators=100, random_state=123, min_samples_leaf=10),
lgb.LGBMRegressor()],
@@ -112,17 +113,18 @@ def _first_stage_reg(X, y, *, automl=True):
cv=2,
scoring='neg_mean_squared_error')
best_est = model.fit(X, y).best_estimator_
if isinstance(best_est, LassoCV):
return Lasso(alpha=best_est.alpha_)
return best_est
if isinstance(best_est, Pipeline):
return make_pipeline(StandardScaler(), Lasso(alpha=best_est.steps[1][1].alpha_))
else:
return best_est
else:
model = LassoCV(cv=5).fit(X, y)
return Lasso(alpha=model.alpha_)
model = make_pipeline(StandardScaler(), LassoCV(cv=5)).fit(X, y)
return make_pipeline(StandardScaler(), Lasso(alpha=model.steps[1][1].alpha_))
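
A rough standalone sketch of the selection pattern above (not part of this commit; the synthetic data is illustrative): the fitted LassoCV sits at steps[1][1] of the pipeline, and its chosen alpha_ is reused to build the final scaled Lasso.

import numpy as np
from sklearn.linear_model import Lasso, LassoCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 5))
y_demo = X_demo[:, 0] + 0.1 * rng.normal(size=200)

# cross-validate the regularization strength on standardized features
selector = make_pipeline(StandardScaler(), LassoCV(cv=5)).fit(X_demo, y_demo)
# steps[1][1] is the fitted LassoCV; reuse its alpha_ in a plain Lasso
final_model = make_pipeline(StandardScaler(), Lasso(alpha=selector.steps[1][1].alpha_))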


def _first_stage_clf(X, y, *, make_regressor=False, automl=True):
if automl:
model = GridSearchCVList([LogisticRegression(),
model = GridSearchCVList([make_pipeline(StandardScaler(), LogisticRegression()),
RandomForestClassifier(
n_estimators=100, random_state=123),
GradientBoostingClassifier(random_state=123)],
@@ -136,8 +138,8 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True):
scoring='neg_log_loss')
est = model.fit(X, y).best_estimator_
else:
model = LogisticRegressionCV(cv=5, max_iter=1000).fit(X, y)
est = LogisticRegression(C=model.C_[0])
model = make_pipeline(StandardScaler(), LogisticRegressionCV(cv=5, max_iter=1000)).fit(X, y)
est = make_pipeline(StandardScaler(), LogisticRegression(C=model.steps[1][1].C_[0]))
if make_regressor:
return _RegressionWrapper(est)
else:
@@ -216,7 +218,9 @@ def get_feature_names(self, names=None):

class CausalAnalysis:
"""
Gets causal importance of features
Note: this class is experimental and the API may evolve over our next few releases.
Gets causal importance of features.
Parameters
----------
@@ -251,10 +255,10 @@ class CausalAnalysis:
among several models and the best is chosen.
TODO. Add other options, such as {'azure_automl', 'forests', 'boosting'} that will use particular sub-cases
of models or also integrate with azure autoML. (post-MVP)
heterogeneity_model: one of {'linear'}, optional (default='linear')
heterogeneity_model: one of {'linear', 'forest'}, optional (default='linear')
What type of model to use for treatment effect heterogeneity. 'linear' means that a heterogeneity model
of the form theta(X)=<a, X> will be used.
TODO. Add other options, such as {'forest'}, for the use of a causal forest, or {'automl'} for performing
of the form theta(X)=<a, X> will be used, while 'forest' means that a forest model will be trained instead.
TODO. Add other options, such as {'automl'} for performing
model selection for the causal effect, or {'sparse_linear'} for using a debiased lasso. (post-MVP)
automl: bool, default True
Whether to automatically perform model selection over a variety of models
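
A hedged usage sketch of the options documented above; the constructor arguments shown (feature_inds, categorical) and the public import path are assumptions about parts of the API not visible in this diff.

import numpy as np
from econml.solutions.causal_analysis import CausalAnalysis  # assumed public export

X = np.random.normal(size=(500, 3))
y = X[:, 0] + np.random.normal(size=500)

ca = CausalAnalysis(feature_inds=[0, 1, 2],        # hypothetical: which columns to analyze as treatments
                    categorical=[],                # hypothetical: which of those are categorical
                    nuisance_models='automl',      # model selection for first stages, per the docstring
                    heterogeneity_model='forest')  # forest-based heterogeneity, added in this commit
ca.fit(X, y)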
@@ -303,8 +307,9 @@ def fit(self, X, y, warm_start=False):
"The only supported nuisance models are 'linear' and 'automl', "
f"but was given {self.nuisance_models}")

assert self.heterogeneity_model in ['linear'], ("The only supported heterogeneity model is 'linear', "
f"but was given {self.heterogeneity_models}")
assert self.heterogeneity_model in ['linear', 'forest'], (
"The only supported heterogeneity models are 'linear' and, 'forest but received "
f"{self.heterogeneity_model}")

assert np.ndim(X) == 2, f"X must be a 2-dimensional array, but here had shape {np.shape(X)}"

@@ -405,6 +410,12 @@ def process_feature(name, feat_ind):
# can't use X[:, feat_ind] when X is a DataFrame
T = _safe_indexing(X, feat_ind, axis=1)

W = W_transformer.fit_transform(X)
X_xf = X_transformer.fit_transform(X)
if W.shape[1] == 0:
# array checking routines don't accept 0-width arrays
W = None

# perform model selection
model_t = (_first_stage_clf(WX, T, automl=self.nuisance_models == 'automl')
if discrete_treatment else _first_stage_reg(WX, T, automl=self.nuisance_models == 'automl'))
@@ -416,18 +427,21 @@ def process_feature(name, feat_ind):

# TODO: support other types of heterogeneity via an initializer arg
# e.g. 'forest' -> ForestDML
est = LinearDML(model_y=self._model_y,
model_t=model_t,
featurizer=featurizer,
discrete_treatment=discrete_treatment,
fit_cate_intercept=False,
linear_first_stages=False,
random_state=123)
W = W_transformer.fit_transform(X)
X_xf = X_transformer.fit_transform(X)
if W.shape[1] == 0:
# array checking routines don't accept 0-width arrays
W = None
if self.heterogeneity_model == 'linear':
est = LinearDML(model_y=self._model_y,
model_t=model_t,
featurizer=featurizer,
discrete_treatment=discrete_treatment,
fit_cate_intercept=False,
linear_first_stages=False,
random_state=123)
else:
est = CausalForestDML(model_y=self._model_y,
model_t=model_t,
featurizer=featurizer,
discrete_treatment=discrete_treatment,
random_state=123)
est.tune(y, T, X=X_xf, W=W)
est.fit(y, T, X=X_xf, W=W, cache_values=True)

# effect doesn't depend on W, so only pass in first row
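
For reference, a minimal standalone sketch of the 'forest' branch above, tuning and then fitting a CausalForestDML directly; the data and nuisance models here are illustrative choices, not what CausalAnalysis configures internally.

import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(123)
X = rng.normal(size=(500, 4))
T = X[:, 0] + rng.normal(size=500)
y = 2 * T * (X[:, 1] > 0) + rng.normal(size=500)

est = CausalForestDML(model_y=RandomForestRegressor(min_samples_leaf=10),
                      model_t=RandomForestRegressor(min_samples_leaf=10),
                      random_state=123)
est.tune(y, T, X=X)                    # hyperparameter tuning, mirroring est.tune above
est.fit(y, T, X=X, cache_values=True)  # cache_values=True keeps residuals for later inference
effects = est.effect(X)                # per-sample treatment effect estimates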
@@ -521,6 +535,19 @@ def _summary_props(alpha):
(_CausalInsightsConstants.ConfidenceIntervalLowerKey, lambda inf: inf.conf_int_mean(alpha=alpha)[0]),
(_CausalInsightsConstants.ConfidenceIntervalUpperKey, lambda inf: inf.conf_int_mean(alpha=alpha)[1])]

@staticmethod
def _make_accessor(attr):
if isinstance(attr, str):
s = attr

def attr(o):
val = getattr(o, s)
if callable(val):
return val()
else:
return val
return attr
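
A tiny illustration (not part of the diff) of how this accessor helper behaves; the _FakeInference class and its attribute names are hypothetical.

class _FakeInference:
    point_estimate = 1.5
    def stderr(self):
        return 0.2

get_point = CausalAnalysis._make_accessor('point_estimate')
get_err = CausalAnalysis._make_accessor('stderr')
assert get_point(_FakeInference()) == 1.5         # plain attribute is returned as-is
assert get_err(_FakeInference()) == 0.2           # callable attribute is invoked
assert CausalAnalysis._make_accessor(len) is len  # non-strings pass through unchanged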

def _summarize(self, *, summary, get_inference, props, n, expand_arr, drop_sample):

assert hasattr(self, "_results"), "This object has not been fit, so cannot get results"
@@ -540,16 +567,7 @@ def ensure_proper_dims(arr):
# each attr has dimension (m,y) or (m,y,t)
def coalesce(attr):
"""Join together the arrays for each feature"""
if isinstance(attr, str):
s = attr

def attr(o):
val = getattr(o, s)
if callable(val):
return val()
else:
return val

attr = self._make_accessor(attr)
# concatenate along treatment dimension
arr = np.concatenate([ensure_proper_dims(attr(get_inference(res)))
for res in self._results], axis=2)
@@ -798,10 +816,73 @@ def whatif(self, X, Xnew, feature_index, y):

return inf

def policy_tree(self, Xtest, feature_index, *, treatment_cost=0,
max_depth=3, min_samples_leaf=2, min_value_increase=1e-4, alpha=.1):
def _whatif_dict(self, X, Xnew, feature_index, y):
"""
Get counterfactual predictions when feature_index is changed to Xnew from its observational counterpart.
Note that this only applies to regression use cases; for classification what-if analysis is not supported.
Parameters
----------
X: array-like
Features
Xnew: array-like
New values of a single column of X
feature_index: int or string
The index of the feature being varied to Xnew, either as a numeric index or
the string name if the input is a dataframe
y: array-like
Observed labels or outcome of a predictive model for baseline y values
Returns
-------
dict : dict
The counterfactual predictions, as a dictionary
"""

inf = self.whatif(X, Xnew, feature_index, y)
props = self._point_props(0.05)
res = _get_default_specific_insights('whatif')
res.update([(key, self._make_accessor(attr)(inf).tolist()) for key, attr in props])
return res
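
A hedged usage sketch of the what-if helpers; `ca`, `X`, and `y` are assumed to be a fitted CausalAnalysis and its (regression) training data, which are not shown in this hunk.

import numpy as np

# counterfactual values for the column selected by feature_index
Xnew = np.full(np.shape(X)[0], 10.0)
inf = ca.whatif(X, Xnew, feature_index=0, y=y)      # inference object (point estimates, intervals)
d = ca._whatif_dict(X, Xnew, feature_index=0, y=y)  # same results serialized into a plain dict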

def _tree(self, is_policy, Xtest, feature_index, *, treatment_cost=0,
max_depth=3, min_samples_leaf=2, min_impurity_decrease=1e-4, alpha=.1):

result = self._check_feature_index(Xtest, feature_index)
Xtest = result.X_transformer.transform(Xtest)

if result.feature_baseline is None:
treatment_names = ['low', 'high']
else:
treatment_names = [f"{result.feature_baseline}"] + \
[f"{lvl}" for lvl in result.feature_levels]

if len(treatment_names) > 2 and is_policy:
raise AssertionError("Can't create policy trees for multi-class features, "
f"but this feature has values {treatment_names}")

TreeType = SingleTreePolicyInterpreter if is_policy else SingleTreeCateInterpreter
intrp = TreeType(include_model_uncertainty=True,
uncertainty_level=alpha,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_impurity_decrease)

if is_policy:
intrp.interpret(result.estimator, Xtest,
sample_treatment_costs=treatment_cost)
else: # no treatment cost for CATE trees
intrp.interpret(result.estimator, Xtest)

return intrp, result.X_transformer.get_feature_names(self.feature_names_), treatment_names

# TODO: it seems like it would be better to just return the tree itself rather than plot it;
# however, the tree can't store the feature and treatment names we compute here...
def plot_policy_tree(self, Xtest, feature_index, *, treatment_cost=0,
max_depth=3, min_samples_leaf=2, min_value_increase=1e-4, alpha=.1):
"""
Get a recommended policy tree in graphviz format.
Plot a recommended policy tree using matplotlib.
Parameters
----------
@@ -821,29 +902,59 @@ def policy_tree(self, Xtest, feature_index, *, treatment_cost=0,
Confidence level of the confidence intervals displayed in the leaf nodes.
A (1-alpha)*100% confidence interval is displayed.
"""
intrp, feature_names, treatment_names = self._tree(True, Xtest, feature_index,
treatment_cost=treatment_cost,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_value_increase,
alpha=alpha)
return intrp.plot(feature_names=feature_names, treatment_names=treatment_names)
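
A hedged usage example of the plotting entry point above; `ca` and `X_test` are assumed to exist, and feature 0 is assumed to be binary or continuous (multi-valued categorical features raise an error in _tree above).

ca.plot_policy_tree(X_test, feature_index=0,
                    treatment_cost=0.1,       # cost per unit of treatment (illustrative)
                    max_depth=2,
                    min_samples_leaf=5,
                    min_value_increase=1e-3,
                    alpha=0.1)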

def _policy_tree_string(self, Xtest, feature_index, *, treatment_cost=0,
max_depth=3, min_samples_leaf=2, min_value_increase=1e-4, alpha=.1):
"""
Get a recommended policy tree in graphviz format as a string.
result = self._check_feature_index(Xtest, feature_index)
Xtest = result.X_transformer.transform(Xtest)
intrp = SingleTreePolicyInterpreter(include_model_uncertainty=True,
uncertainty_level=alpha,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_value_increase)
if result.feature_baseline is None:
treatment_names = ['low', 'high']
else:
treatment_names = [f"{result.feature_baseline}"] + \
[f"{lvl}" for lvl in result.feature_levels]
intrp.interpret(result.estimator, Xtest,
sample_treatment_costs=treatment_cost)
return intrp.export_graphviz(feature_names=result.X_transformer.get_feature_names(self.feature_names_),
Parameters
----------
X : array-like
Features
feature_index
Index of the feature to be considered as treatment
treatment_cost : int, or array-like of same length as number of rows of X, optional (default=0)
Cost of treatment, or cost of treatment for each sample
max_depth : int, optional (default=3)
maximum depth of the tree
min_samples_leaf : int, optional (default=2)
minimum number of samples on each leaf
min_value_increase : float, optional (default=1e-4)
The minimum increase in the policy value that a split needs to create to construct it
alpha : float in [0, 1], optional (default=.1)
Confidence level of the confidence intervals displayed in the leaf nodes.
A (1-alpha)*100% confidence interval is displayed.
Returns
-------
tree : string
The policy tree represented as a graphviz string
"""

intrp, feature_names, treatment_names = self._tree(True, Xtest, feature_index,
treatment_cost=treatment_cost,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_value_increase,
alpha=alpha)
return intrp.export_graphviz(feature_names=feature_names,
treatment_names=treatment_names)
# TODO: it seems like it would be better to just return the tree itself rather than plot it;
# however, the tree can't store the feature and treatment names we compute here...

def heterogeneity_tree(self, Xtest, feature_index, *,
max_depth=3, min_samples_leaf=2, min_impurity_decrease=1e-4,
alpha=.1):
def plot_heterogeneity_tree(self, Xtest, feature_index, *,
max_depth=3, min_samples_leaf=2, min_impurity_decrease=1e-4,
alpha=.1):
"""
Get an effect heterogeneity tree in graphviz format.
Plot an effect heterogeneity tree using matplotlib.
Parameters
----------
@@ -863,15 +974,45 @@ def heterogeneity_tree(self, Xtest, feature_index, *,
A (1-alpha)*100% confidence interval is displayed.
"""

result = self._check_feature_index(Xtest, feature_index)
Xtest = result.X_transformer.transform(Xtest)
intrp = SingleTreeCateInterpreter(include_model_uncertainty=True,
uncertainty_level=alpha,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_impurity_decrease)
intrp.interpret(result.estimator, Xtest)
return intrp.export_graphviz()
intrp, feature_names, treatment_names = self._tree(False, Xtest, feature_index,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_impurity_decrease,
alpha=alpha)
return intrp.plot(feature_names=feature_names,
treatment_names=treatment_names)

def _heterogeneity_tree_string(self, Xtest, feature_index, *,
max_depth=3, min_samples_leaf=2, min_impurity_decrease=1e-4,
alpha=.1):
"""
Get an effect heterogeneity tree in graphviz format as a string.
Parameters
----------
X : array-like
Features
feature_index
Index of the feature to be considered as treatment
max_depth : int, optional (default=3)
maximum depth of the tree
min_samples_leaf : int, optional (default=2)
minimum number of samples on each leaf
min_impurity_decrease : float, optional (default=1e-4)
The minimum decrease in the impurity/uniformity of the causal effect that a split needs to
achieve to construct it
alpha : float in [0, 1], optional (default=.1)
Confidence level of the confidence intervals displayed in the leaf nodes.
A (1-alpha)*100% confidence interval is displayed.
"""

intrp, feature_names, treatment_names = self._tree(False, Xtest, feature_index,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
min_impurity_decrease=min_impurity_decrease,
alpha=alpha)
return intrp.export_graphviz(feature_names=feature_names,
treatment_names=treatment_names)
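
A hedged usage sketch of the string variant; `ca` and `X_test` are assumptions, and rendering the .dot output requires the external graphviz tool.

dot_source = ca._heterogeneity_tree_string(X_test, feature_index=0, max_depth=3)
with open("hetero_tree.dot", "w") as f:
    f.write(dot_source)
# render externally, e.g.: dot -Tpng hetero_tree.dot -o hetero_tree.png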

@property
def cate_models_(self):
