Skip to content

Commit

Permalink
fix docstrings + clean up
Browse files Browse the repository at this point in the history
Signed-off-by: amarv <[email protected]>
  • Loading branch information
amarvenu committed Jan 10, 2024
1 parent 041bcaf commit 19024c4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
16 changes: 8 additions & 8 deletions econml/validate/drtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def evaluate_cal(
self.get_cate_preds(Xval, Xtrain)

cal_r_squared = np.zeros(self.n_treat)
plot_dict = dict()
plot_data_dict = dict()
for k in range(self.n_treat):
cuts = np.quantile(self.cate_preds_train_[:, k], np.linspace(0, 1, n_groups + 1))
probs = np.zeros(n_groups)
Expand All @@ -409,19 +409,19 @@ def evaluate_cal(
# Calculate R-squared calibration score
cal_r_squared[k] = 1 - (cal_score_g / cal_score_o)

df_plot1 = pd.DataFrame({
df_plot = pd.DataFrame({
'ind': np.array(range(n_groups)),
'gate': gate,
'se_gate': se_gate,
'g_cate': g_cate,
'se_g_cate': se_g_cate
})

plot_dict[self.treatments[k + 1]] = df_plot1
plot_data_dict[self.treatments[k + 1]] = df_plot

self.cal_res = CalibrationEvaluationResults(
cal_r_squared=cal_r_squared,
plot_dict=plot_dict,
plot_data_dict=plot_data_dict,
treatments=self.treatments
)

Expand Down Expand Up @@ -528,7 +528,7 @@ def evaluate_uplift(
raise Exception('CATE predictions not yet calculated - must provide both Xval, Xtrain')
self.get_cate_preds(Xval, Xtrain)

curve_dict = dict()
curve_data_dict = dict()
if self.n_treat == 1:
coeff, err, curve_df = calc_uplift(
self.cate_preds_train_,
Expand All @@ -539,7 +539,7 @@ def evaluate_uplift(
)
coeffs = [coeff]
errs = [err]
curve_dict[self.treatments[1]] = curve_df
curve_data_dict[self.treatments[1]] = curve_df
else:
coeffs = []
errs = []
Expand All @@ -553,7 +553,7 @@ def evaluate_uplift(
)
coeffs.append(coeff)
errs.append(err)
curve_dict[self.treatments[k + 1]] = curve_df
curve_data_dict[self.treatments[k + 1]] = curve_df

pvals = [st.norm.sf(abs(q / e)) for q, e in zip(coeffs, errs)]

Expand All @@ -562,7 +562,7 @@ def evaluate_uplift(
errs=errs,
pvals=pvals,
treatments=self.treatments,
curve_dict=curve_dict
curve_data_dict=curve_data_dict
)

return self.uplift_res
Expand Down
24 changes: 16 additions & 8 deletions econml/validate/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@ class CalibrationEvaluationResults:
cal_r_squared: list or numpy array of floats
Sequence of calibration R^2 values
df_plot: pandas dataframe
Dataframe containing necessary data for plotting calibration test GATE results
plot_data_dict: dict
Dictionary mapping treatment levels to dataframes containing necessary
data for plotting calibration test GATE results
treatments: list or numpy array of floats
Sequence of treatment labels
"""
def __init__(
self,
cal_r_squared: np.array,
plot_dict: Dict[Any, pd.DataFrame],
plot_data_dict: Dict[Any, pd.DataFrame],
treatments: np.array
):
self.cal_r_squared = cal_r_squared
self.plot_dict = plot_dict
self.plot_data_dict = plot_data_dict
self.treatments = treatments

def summary(self) -> pd.DataFrame:
Expand Down Expand Up @@ -64,7 +65,7 @@ def plot_cal(self, tmt: Any):
if tmt not in self.treatments[1:]:
raise ValueError(f'Invalid treatment; must be one of {self.treatments[1:]}')

df = self.plot_dict[tmt].copy()
df = self.plot_data_dict[tmt].copy()
rsq = round(self.cal_r_squared[np.where(self.treatments == tmt)[0][0] - 1], 3)
df['95_err'] = 1.96 * df['se_gate']
fig = df.plot(
Expand Down Expand Up @@ -148,20 +149,24 @@ class UpliftEvaluationResults:
treatments: list or numpy array of floats
Sequence of treatment labels
curve_data_dict: dict
Dictionary mapping treatment levels to dataframes containing
necessary data for plotting uplift curves
"""
def __init__(
self,
params: List[float],
errs: List[float],
pvals: List[float],
treatments: np.array,
curve_dict: Dict[Any, pd.DataFrame]
curve_data_dict: Dict[Any, pd.DataFrame]
):
self.params = params
self.errs = errs
self.pvals = pvals
self.treatments = treatments
self.curves = curve_dict
self.curves = curve_data_dict

def summary(self):
"""
Expand Down Expand Up @@ -228,8 +233,11 @@ class EvaluationResults:
blp_res: BLPEvaluationResults object
Results object for BLP test
qini_res: QiniEvaluationResults object
qini_res: UpliftEvaluationResults object
Results object for QINI test
toc_res: UpliftEvaluationResults object
Results object for TOC test
"""
def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion econml/validate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ def calc_uplift(
np.mean(dr_val[inds]) - ate) # tau(q) = q * (E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)])
toc_psi[it, :] = np.squeeze(
(dr_val - ate) * (inds - group_prob) - toc[it]) # influence function for tau(q)
else:
elif metric == 'toc':
toc[it] = np.mean(dr_val[inds]) - ate # tau(q) := E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)]
toc_psi[it, :] = np.squeeze((dr_val - ate) * (inds / group_prob - 1) - toc[it])
else:
raise ValueError("Unsupported metric! Must be one of ['toc', 'qini']")

toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)

Expand Down

0 comments on commit 19024c4

Please sign in to comment.