From 41c4068b6cb7314e569c083c35b6186c40e8403b Mon Sep 17 00:00:00 2001 From: amarv Date: Tue, 9 Jan 2024 11:06:42 -0800 Subject: [PATCH] change uplift metric error handling Signed-off-by: amarv --- econml/validate/drtester.py | 3 --- econml/validate/utils.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py index fd3805100..bcf4c0141 100644 --- a/econml/validate/drtester.py +++ b/econml/validate/drtester.py @@ -520,9 +520,6 @@ def evaluate_uplift( if not hasattr(self, 'dr_val_'): raise Exception("Must fit nuisances before evaluating") - if not (metric in ['qini', 'toc']): - raise ValueError("Uplift metric must be one of ['qini', 'toc']") - if (not hasattr(self, 'cate_preds_train_')) or (not hasattr(self, 'cate_preds_val_')): if (Xval is None) or (Xtrain is None): raise Exception('CATE predictions not yet calculated - must provide both Xval, Xtrain') diff --git a/econml/validate/utils.py b/econml/validate/utils.py index 8cd69cae8..50dc3235d 100644 --- a/econml/validate/utils.py +++ b/econml/validate/utils.py @@ -94,7 +94,7 @@ def calc_uplift( toc[it] = np.mean(dr_val[inds]) - ate # tau(q) := E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)] toc_psi[it, :] = np.squeeze((dr_val - ate) * (inds / group_prob - 1) - toc[it]) else: - raise ValueError("Unsupported metric! Must be one of ['toc', 'qini']") + raise ValueError("Unsupported metric - must be one of ['toc', 'qini']") toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)