Cartpole example and end to end unittest #6
Config for the cartpole implementation
Demo jupyter notebook showing how to use the library
def forward(self, context: torch.Tensor) -> torch.Tensor:
    logits = self.model(context)
    actions = (torch.sigmoid(logits) > 0.5).squeeze().int().cpu().numpy()
The prescriptor's output needs some hacking here so that it fits the action format required by .step()
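As a rough illustration of the conversion described in the comment above, here is a minimal pure-Python sketch that maps one raw model output to a discrete 0/1 action, assuming the environment expects a two-valued discrete action as CartPole does. The name `logit_to_action` is hypothetical and not part of this PR; the real code operates on torch tensors.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function for a single float."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logit_to_action(logit: float) -> int:
    """Hypothetical helper: threshold one logit into a discrete action (0 or 1)."""
    # sigmoid(x) > 0.5 holds exactly when x > 0, so this matches thresholding the logit at 0.
    # The explicit int() cast yields a plain Python integer rather than a bool.
    return int(sigmoid(logit) > 0.5)


print(logit_to_action(2.0), logit_to_action(-2.0))
```

A logit of exactly 0 gives a probability of 0.5, which is not strictly greater than the threshold, so it maps to action 0.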
@@ -33,7 +33,7 @@ def evaluate_population(self, population: list[Prescriptor], force=False, verbos
         :param force: Whether to force evaluation of all prescriptors.
         :param verbose: Whether to show a progress bar.
         """
-        iterator = population if verbose < 1 else tqdm(population, leave=False)
+        iterator = population if verbose < 1 else tqdm(population, leave=False, desc="Evaluating Population")
Added some nice text for the loading bars
@@ -185,5 +185,5 @@ def run_evolution(self):
         Runs the evolutionary process for n_generations.
         """
         self.create_initial_population()
-        for _ in tqdm(range(self.generation, self.n_generations+1)):
+        for _ in tqdm(range(self.generation, self.n_generations+1), desc="Running Evolution"):
Added some more nice text for the loading bars
numpy==1.26.4
pandas==2.2.3
pyyaml==6.0.2
Gymnasium and pyyaml are now dependencies because of the cartpole example. In general they're not strictly needed, so maybe later we can add a flag that controls whether they get installed.
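One possible way to keep these packages optional until such a flag exists is to check for them only when the cartpole example is actually used. The sketch below is an assumption about how that could look, not something this PR implements; `missing_optional_dependencies` and `require_cartpole_extras` are hypothetical names.

```python
import importlib.util


def missing_optional_dependencies(modules: tuple[str, ...]) -> list[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


def require_cartpole_extras() -> None:
    """Hypothetical guard to call at the top of the cartpole example."""
    # "gymnasium" and "yaml" are the import names of the gymnasium and pyyaml packages.
    missing = missing_optional_dependencies(("gymnasium", "yaml"))
    if missing:
        raise ImportError(
            "The cartpole example needs extra packages that are not installed: "
            + ", ".join(missing)
        )
```

With a guard like this, the core library could be installed without the example's requirements, and users would only see the error when they try to run the example.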
A unit test that runs evolution on cartpole end-to-end and makes sure it works and that the output is in the expected format.
# Checks the results file is 101x4
self.assertEqual(len(rows), 101)
for row in rows:
    self.assertEqual(len(row), 4)
We check that the results file has the right shape: a CSV with 101 rows and 4 columns.
# Checks that the first candidate in the file has rank 1, inf distance, and 0 score
self.assertEqual(rows[1][1], "1")
self.assertEqual(rows[1][2], "inf")
self.assertEqual(rows[1][3], "-0.0")
We make sure we found a good solution: the best candidate should have rank 1, infinite distance, and a score of 0 (the maximum possible).
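The checks described in these comments can be sketched outside of unittest as follows. This is a minimal illustration under stated assumptions: the first row is taken to be a header (inferred from the test indexing `rows[1]` as the first candidate), and the header names and first-column identifier in the sample data are made up. It also compares parsed numbers rather than strings, so `-0.0` and `0.0` are treated as the same score.

```python
import csv
import io
import math


def check_results(csv_text: str, n_rows: int = 101, n_cols: int = 4) -> bool:
    """Return True if the CSV has the expected shape and a top-ranked first candidate."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    if len(rows) != n_rows or any(len(row) != n_cols for row in rows):
        return False
    # rows[0] is assumed to be a header, so rows[1] is the first candidate.
    _, rank, distance, score = rows[1]
    return rank == "1" and math.isinf(float(distance)) and float(score) == 0.0


# Synthetic sample: header + best candidate + 99 filler candidates (all values invented).
lines = ["id,rank,distance,score", "a,1,inf,-0.0"]
lines += [f"c{i},2,0.5,-10.0" for i in range(99)]
sample = "\n".join(lines) + "\n"

print(check_results(sample))
```

Parsing the score as a float means the check does not depend on whether the writer formats zero as `0.0` or `-0.0`.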
Added cartpole example implementations and an end-to-end unit test that runs cartpole and ensures it finds a passing solution.