-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: fix several env/policy load bugs
- Loading branch information
Showing
6 changed files
with
160 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Module containing several control related exceptions.""" | ||
|
||
|
||
class EnvLoadError(Exception): | ||
"""Custom exception that is raised when the saved environment could not be loaded. | ||
Attributes: | ||
log_message (str): The full log message. | ||
details (dict): Dictionary containing extra Exception information. | ||
""" | ||
|
||
def __init__(self, message="", log_message="", **details): | ||
"""Initializes the EePoseLookupError exception object. | ||
Args: | ||
message (str, optional): Exception message specifying whether the exception | ||
occurred. Defaults to ``""``. | ||
log_message (str, optional): Full log message. Defaults to ``""``. | ||
details (dict): Additional dictionary that can be used to supply the user | ||
with more details about why the exception occurred. | ||
""" | ||
super().__init__(message) | ||
|
||
self.log_message = log_message | ||
self.details = details | ||
|
||
|
||
class PolicyLoadError(Exception): | ||
"""Custom exception that is raised when the saved policy could not be loaded. | ||
Attributes: | ||
log_message (str): The full log message. | ||
details (dict): Dictionary containing extra Exception information. | ||
""" | ||
|
||
def __init__(self, message="", log_message="", **details): | ||
"""Initializes the EePoseLookupError exception object. | ||
Args: | ||
message (str, optional): Exception message specifying whether the exception | ||
occurred. Defaults to ``""``. | ||
log_message (str, optional): Full log message. Defaults to ``""``. | ||
details (dict): Additional dictionary that can be used to supply the user | ||
with more details about why the exception occurred. | ||
""" | ||
super().__init__(message) | ||
|
||
self.log_message = log_message | ||
self.details = details |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""A small script which shows how to manually load a saved environment and policy when | ||
the CLI fails. | ||
""" | ||
|
||
import gym | ||
import ros_gazebo_gym # noqa: F401 | ||
from bayesian_learning_control.control.utils.test_policy import ( | ||
load_policy_and_env, | ||
load_pytorch_policy, | ||
load_tf_policy, | ||
run_policy, | ||
) | ||
|
||
AGENT_TYPE = "torch" # The type of agent that was trained. Options: 'tf2' and 'torch'. | ||
AGENT_FOLDER = "/home/ricks/Development/work/bayesian-learning-control/data/2022-02-17_staa_lac_panda_reach/2022-02-17_09-35-31-staa_lac_panda_reach_s25" # noqa: E501 | ||
|
||
if __name__ == "__main__": | ||
# NOTE: STEP 1a: Try to load the policy and environment | ||
try: | ||
env, policy = load_policy_and_env(AGENT_FOLDER) | ||
except Exception: | ||
# NOTE: STEP: 1b: If step 1 fails recreate the environment and load the | ||
# Pytorch/TF2 agent separately. | ||
|
||
# Create the environment | ||
# NOTE: Here the 'FlattenObservation' wrapper is used to make sure the alg works | ||
# with dictionary based observation spaces. | ||
env = gym.make("PandaReach-v1") | ||
env = gym.wrappers.FlattenObservation(env) | ||
|
||
# Load the policy | ||
if AGENT_TYPE.lower() == "tf2": | ||
policy = load_tf_policy(AGENT_FOLDER, itr="last", env=env) # Load TF2 agent | ||
else: | ||
policy = load_pytorch_policy( | ||
AGENT_FOLDER, itr="last", env=env | ||
) # Load Pytorch agent | ||
|
||
# Step 2: Try to run the policy on the environment | ||
try: | ||
run_policy(env, policy) | ||
except Exception: | ||
raise Exception( | ||
"Something went wrong while trying to run the inference. Please check the " | ||
"'AGENT_FOLDER' and try again. If the problem persists please open a issue " | ||
"on https://github.com/rickstaa/bayesian-learning-control/issues." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.