Skip to content

Commit

Permalink
change uplift metric error handling
Browse files Browse the repository at this point in the history
Signed-off-by: amarv <[email protected]>
  • Loading branch information
amarvenu committed Jan 10, 2024
1 parent 19024c4 commit 41c4068
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 4 deletions.
3 changes: 0 additions & 3 deletions econml/validate/drtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,6 @@ def evaluate_uplift(
if not hasattr(self, 'dr_val_'):
raise Exception("Must fit nuisances before evaluating")

if not (metric in ['qini', 'toc']):
raise ValueError("Uplift metric must be one of ['qini', 'toc']")

if (not hasattr(self, 'cate_preds_train_')) or (not hasattr(self, 'cate_preds_val_')):
if (Xval is None) or (Xtrain is None):
raise Exception('CATE predictions not yet calculated - must provide both Xval, Xtrain')
Expand Down
2 changes: 1 addition & 1 deletion econml/validate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def calc_uplift(
toc[it] = np.mean(dr_val[inds]) - ate # tau(q) := E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)]
toc_psi[it, :] = np.squeeze((dr_val - ate) * (inds / group_prob - 1) - toc[it])
else:
raise ValueError("Unsupported metric! Must be one of ['toc', 'qini']")
raise ValueError("Unsupported metric - must be one of ['toc', 'qini']")

toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)

Expand Down

0 comments on commit 41c4068

Please sign in to comment.