From 4f732f63265b0d1c9c18757c135e7a4376d283f8 Mon Sep 17 00:00:00 2001
From: antoine_galataud
Date: Fri, 19 Apr 2024 09:20:35 +0200
Subject: [PATCH] cleanup

---
 doc/source/ope/index.rst | 2 +-
 hopes/ope/estimators.py  | 4 ++--
 tests/test_estimators.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/doc/source/ope/index.rst b/doc/source/ope/index.rst
index 6b2f1fc..1c46b66 100644
--- a/doc/source/ope/index.rst
+++ b/doc/source/ope/index.rst
@@ -7,9 +7,9 @@ Roadmap
 - [x] Implement Inverse Probability Weighting (IPW) estimator
 - [x] Implement Self-Normalized Inverse Probability Weighting (SNIPW) estimator
 - [x] Implement Direct Method (DM) estimator
-- [ ] Implement Doubly Robust (DR) estimator
 - [X] Implement Trajectory-wise Importance Sampling (TIS) estimator
 - [ ] Implement Per-Decision Importance Sampling (PDIS) estimator
+- [ ] Implement Doubly Robust (DR) estimator
 
 Implemented estimators
 -----------------------
diff --git a/hopes/ope/estimators.py b/hopes/ope/estimators.py
index 84ae7e2..c6aaf06 100644
--- a/hopes/ope/estimators.py
+++ b/hopes/ope/estimators.py
@@ -499,11 +499,11 @@ def estimate_weighted_rewards(self) -> np.ndarray:
 
         # [gamma^0, gamma^1, ..., gamma^(T-1)] = [gamma^1, gamma^2, ..., gamma^T] / gamma
         discount_factors = np.cumprod(discount_factors, axis=1) / self.discount_factor
 
-        # compute the weighted rewards per trajectory, shape: (n,)
+        # compute the weighted rewards per trajectory, shape: (n, 1)
         weighted_rewards = np.sum(
             importance_weights * rewards * discount_factors,  # (n, 1) * (n, T) * (n, T)
             axis=1,  # sum weights over the trajectory length
-        )
+        ).reshape(-1, 1)
 
         return weighted_rewards
diff --git a/tests/test_estimators.py b/tests/test_estimators.py
index f8f1c6c..2eac37b 100644
--- a/tests/test_estimators.py
+++ b/tests/test_estimators.py
@@ -164,7 +164,7 @@ def test_tis(self):
 
         wrew = tis.estimate_weighted_rewards()
         self.assertIsInstance(wrew, np.ndarray)
-        self.assertEqual(wrew.shape, (5,))
+        self.assertEqual(wrew.shape, (5, 1))
 
         policy_value = tis.estimate_policy_value()
         self.assertIsInstance(policy_value, float)
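
Note (addendum, not part of the patch): the hopes/ope/estimators.py hunk is a
pure shape change: the per-trajectory TIS weighted rewards go from (n,) to
(n, 1). Below is a minimal standalone numpy sketch of that computation,
assuming importance_weights is the per-trajectory product of policy
probability ratios with shape (n, 1) and rewards has shape (n, T); the random
inputs and the sizes n, T are illustrative placeholders, not taken from the
hopes source.

    import numpy as np

    n, T = 5, 10            # placeholder: 5 trajectories of length 10
    discount_factor = 0.99  # placeholder gamma

    rng = np.random.default_rng(0)
    rewards = rng.random((n, T))             # per-step rewards, shape (n, T)
    importance_weights = rng.random((n, 1))  # one weight per trajectory, shape (n, 1)

    # [gamma^1, gamma^2, ..., gamma^T] / gamma = [gamma^0, gamma^1, ..., gamma^(T-1)]
    discount_factors = np.full((n, T), discount_factor)
    discount_factors = np.cumprod(discount_factors, axis=1) / discount_factor

    # after this patch: keep the trajectory axis explicit, shape (n, 1)
    weighted_rewards = np.sum(
        importance_weights * rewards * discount_factors,  # (n, 1) * (n, T) * (n, T)
        axis=1,
    ).reshape(-1, 1)

    assert weighted_rewards.shape == (n, 1)  # was (n,) before this patch

One plausible motivation for the column shape is that per-trajectory values
stay broadcastable against (n, T) arrays downstream; in any case, test_tis now
expects (5, 1) to match.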