From 19024c4716d4137c2918ce4b7cc89cf6f35fc183 Mon Sep 17 00:00:00 2001 From: amarv Date: Tue, 9 Jan 2024 11:04:04 -0800 Subject: [PATCH] fix docstrings + clean up Signed-off-by: amarv --- econml/validate/drtester.py | 16 ++++++++-------- econml/validate/results.py | 24 ++++++++++++++++-------- econml/validate/utils.py | 4 +++- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py index fea403b6c..fd3805100 100644 --- a/econml/validate/drtester.py +++ b/econml/validate/drtester.py @@ -382,7 +382,7 @@ def evaluate_cal( self.get_cate_preds(Xval, Xtrain) cal_r_squared = np.zeros(self.n_treat) - plot_dict = dict() + plot_data_dict = dict() for k in range(self.n_treat): cuts = np.quantile(self.cate_preds_train_[:, k], np.linspace(0, 1, n_groups + 1)) probs = np.zeros(n_groups) @@ -409,7 +409,7 @@ def evaluate_cal( # Calculate R-square calibration score cal_r_squared[k] = 1 - (cal_score_g / cal_score_o) - df_plot1 = pd.DataFrame({ + df_plot = pd.DataFrame({ 'ind': np.array(range(n_groups)), 'gate': gate, 'se_gate': se_gate, @@ -417,11 +417,11 @@ def evaluate_cal( 'se_g_cate': se_g_cate }) - plot_dict[self.treatments[k + 1]] = df_plot1 + plot_data_dict[self.treatments[k + 1]] = df_plot self.cal_res = CalibrationEvaluationResults( cal_r_squared=cal_r_squared, - plot_dict=plot_dict, + plot_data_dict=plot_data_dict, treatments=self.treatments ) @@ -528,7 +528,7 @@ def evaluate_uplift( raise Exception('CATE predictions not yet calculated - must provide both Xval, Xtrain') self.get_cate_preds(Xval, Xtrain) - curve_dict = dict() + curve_data_dict = dict() if self.n_treat == 1: coeff, err, curve_df = calc_uplift( self.cate_preds_train_, @@ -539,7 +539,7 @@ def evaluate_uplift( ) coeffs = [coeff] errs = [err] - curve_dict[self.treatments[1]] = curve_df + curve_data_dict[self.treatments[1]] = curve_df else: coeffs = [] errs = [] @@ -553,7 +553,7 @@ def evaluate_uplift( ) coeffs.append(coeff) errs.append(err) - curve_dict[self.treatments[k + 1]] = curve_df + curve_data_dict[self.treatments[k + 1]] = curve_df pvals = [st.norm.sf(abs(q / e)) for q, e in zip(coeffs, errs)] @@ -562,7 +562,7 @@ def evaluate_uplift( errs=errs, pvals=pvals, treatments=self.treatments, - curve_dict=curve_dict + curve_data_dict=curve_data_dict ) return self.uplift_res diff --git a/econml/validate/results.py b/econml/validate/results.py index 074860496..d236757a7 100644 --- a/econml/validate/results.py +++ b/econml/validate/results.py @@ -13,8 +13,9 @@ class CalibrationEvaluationResults: cal_r_squared: list or numpy array of floats Sequence of calibration R^2 values - df_plot: pandas dataframe - Dataframe containing necessary data for plotting calibration test GATE results + plot_data_dict: dict + Dictionary mapping treatment levels to dataframes containing necessary + data for plotting calibration test GATE results treatments: list or numpy array of floats Sequence of treatment labels @@ -22,11 +23,11 @@ class CalibrationEvaluationResults: def __init__( self, cal_r_squared: np.array, - plot_dict: Dict[Any, pd.DataFrame], + plot_data_dict: Dict[Any, pd.DataFrame], treatments: np.array ): self.cal_r_squared = cal_r_squared - self.plot_dict = plot_dict + self.plot_data_dict = plot_data_dict self.treatments = treatments def summary(self) -> pd.DataFrame: @@ -64,7 +65,7 @@ def plot_cal(self, tmt: Any): if tmt not in self.treatments[1:]: raise ValueError(f'Invalid treatment; must be one of {self.treatments[1:]}') - df = self.plot_dict[tmt].copy() + df = self.plot_data_dict[tmt].copy() rsq = round(self.cal_r_squared[np.where(self.treatments == tmt)[0][0] - 1], 3) df['95_err'] = 1.96 * df['se_gate'] fig = df.plot( @@ -148,6 +149,10 @@ class UpliftEvaluationResults: treatments: list or numpy array of floats Sequence of treatment labels + + curve_data_dict: dict + Dictionary mapping treatment levels to dataframes containing + necessary data for plotting uplift curves """ def __init__( self, @@ -155,13 +160,13 @@ def __init__( errs: List[float], pvals: List[float], treatments: np.array, - curve_dict: Dict[Any, pd.DataFrame] + curve_data_dict: Dict[Any, pd.DataFrame] ): self.params = params self.errs = errs self.pvals = pvals self.treatments = treatments - self.curves = curve_dict + self.curves = curve_data_dict def summary(self): """ @@ -228,8 +233,11 @@ class EvaluationResults: blp_res: BLPEvaluationResults object Results object for BLP test - qini_res: QiniEvaluationResults object + qini_res: UpliftEvaluationResults object Results object for QINI test + + toc_res: UpliftEvaluationResults object + Results object for TOC test """ def __init__( self, diff --git a/econml/validate/utils.py b/econml/validate/utils.py index a473a0b7d..8cd69cae8 100644 --- a/econml/validate/utils.py +++ b/econml/validate/utils.py @@ -90,9 +90,11 @@ def calc_uplift( np.mean(dr_val[inds]) - ate) # tau(q) = q * E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)] toc_psi[it, :] = np.squeeze( (dr_val - ate) * (inds - group_prob) - toc[it]) # influence function for the tau(q) - else: + elif metric == 'toc': toc[it] = np.mean(dr_val[inds]) - ate # tau(q) := E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)] toc_psi[it, :] = np.squeeze((dr_val - ate) * (inds / group_prob - 1) - toc[it]) + else: + raise ValueError("Unsupported metric! Must be one of ['toc', 'qini']") toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)