Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-galataud committed Apr 19, 2024
1 parent fb32665 commit 4f732f6
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doc/source/ope/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ Roadmap
- [x] Implement Inverse Probability Weighting (IPW) estimator
- [x] Implement Self-Normalized Inverse Probability Weighting (SNIPW) estimator
- [x] Implement Direct Method (DM) estimator
- [ ] Implement Doubly Robust (DR) estimator
- [X] Implement Trajectory-wise Importance Sampling (TIS) estimator
- [ ] Implement Per-Decision Importance Sampling (PDIS) estimator
- [ ] Implement Doubly Robust (DR) estimator

Implemented estimators
-----------------------
Expand Down
4 changes: 2 additions & 2 deletions hopes/ope/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,11 @@ def estimate_weighted_rewards(self) -> np.ndarray:
# [gamma^0, gamma^1, ..., gamma^(T-1)] = [gamma^1, gamma^2, ..., gamma^T] / gamma
discount_factors = np.cumprod(discount_factors, axis=1) / self.discount_factor

# compute the weighted rewards per trajectory, shape: (n,)
# compute the weighted rewards per trajectory, shape: (n, 1)
weighted_rewards = np.sum(
importance_weights * rewards * discount_factors, # (n, 1) * (n, T) * (n, T)
axis=1, # sum weights over the trajectory length
)
).reshape(-1, 1)

return weighted_rewards

Expand Down
2 changes: 1 addition & 1 deletion tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_tis(self):

wrew = tis.estimate_weighted_rewards()
self.assertIsInstance(wrew, np.ndarray)
self.assertEqual(wrew.shape, (5,))
self.assertEqual(wrew.shape, (5, 1))

policy_value = tis.estimate_policy_value()
self.assertIsInstance(policy_value, float)
Expand Down

0 comments on commit 4f732f6

Please sign in to comment.