feat: add support for dictionary type observation spaces
This commit ensures that the BLC package also works with dictionary-type
observation spaces.
rickstaa committed Jan 26, 2022
1 parent 6155872 commit e3bf761
Showing 18 changed files with 323 additions and 399 deletions.
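
For context on the mechanism used in the diff below: `gym.wrappers.FlattenObservation` converts a `Dict` observation space into a single flat `Box`, so the rest of the training loop can keep treating each observation as one vector. A minimal sketch of that behaviour, not part of this commit — the toy `DictObsEnv` and its spaces are made up for illustration, and the classic (pre-0.26) gym API is assumed:

```python
import gym
import numpy as np
from gym import spaces


class DictObsEnv(gym.Env):
    """Toy environment with a dictionary observation space (hypothetical)."""

    observation_space = spaces.Dict(
        {
            "position": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
            "velocity": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
        }
    )
    action_space = spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, {}


env = gym.wrappers.FlattenObservation(DictObsEnv())
print(env.observation_space.shape)  # (5,): the 2 + 3 entries concatenated
print(env.reset().shape)            # (5,): each dict observation comes out flat
```

Because the wrapped space is a plain `Box`, lines like `obs_dim = env.observation_space.shape` in the diff keep working unchanged.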
11 changes: 0 additions & 11 deletions .github/dependabot.yml

This file was deleted.

3 changes: 1 addition & 2 deletions TODOS.md
@@ -1,4 +1,3 @@
 # TODOS
-
-* Add openai\_ros enivronments. + virtual env explanation.
+* Use gym wrappers in the disturber.
 * Try to add torch script?
86 changes: 46 additions & 40 deletions bayesian_learning_control/control/algos/pytorch/lac/lac.py
@@ -816,7 +816,7 @@ def lac(  # noqa: C901
         actor_critic (torch.nn.Module, optional): The constructor method for a
             Torch Module with an ``act`` method, a ``pi`` module and several
             ``Q`` or ``L`` modules. The ``act`` method and ``pi`` module should
-            accept batches of observations as inputs, and the ``Q*`` and ``L``
+            accept batches of observations as inputs, and the ``Q*`` and ``L``
             modules should accept a batch of observations and a batch of actions as
             inputs. When called, these modules should return:
@@ -963,7 +963,13 @@ def lac(  # noqa: C901
     }  # Retrieve hyperparameters (Ignore logger object)
     logger.save_config(hyper_paramet_dict)  # Write hyperparameters to logger

-    env, test_env = env_fn(), env_fn()
+    env = env_fn()
+    env = gym.wrappers.FlattenObservation(
+        env
+    )  # NOTE: Done to make sure the alg works with dict observation spaces
+    if num_test_episodes != 0:
+        test_env = env_fn()
+        test_env = gym.wrappers.FlattenObservation(test_env)
     obs_dim = env.observation_space.shape
     act_dim = env.action_space.shape[0]
     rew_dim = (
@@ -972,20 +978,21 @@ def lac(  # noqa: C901

     # Retrieve max episode length
     if max_ep_len is None:
-        max_ep_len = env._max_episode_steps
+        max_ep_len = env.env._max_episode_steps
     else:
-        if max_ep_len > env._max_episode_steps:
+        if max_ep_len > env.env._max_episode_steps:
             logger.log(
                 (
                     f"You defined your 'max_ep_len' to be {max_ep_len} "
-                    "while the environment 'max_epsiode_steps' is "
-                    f"{env._max_episode_steps}. As a result the environment "
+                    "while the environment 'max_episode_steps' is "
+                    f"{env.env._max_episode_steps}. As a result the environment "
                     f"'max_episode_steps' has been increased to {max_ep_len}"
                 ),
                 type="warning",
             )
-            env._max_episode_steps = max_ep_len
-            test_env._max_episode_steps = max_ep_len
+            env.env._max_episode_steps = max_ep_len
+            if num_test_episodes != 0:
+                test_env.env._max_episode_steps = max_ep_len

     # Get default actor critic if no 'actor_critic' was supplied
     actor_critic = LyapunovActorCritic if actor_critic is None else actor_critic
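
Why the extra `.env` hop in the hunk above (an inference from the diff, not stated in the commit message): after wrapping, `env` is a `FlattenObservation` wrapper, and `_max_episode_steps` belongs to the `TimeLimit` environment it wraps. A small sketch of the distinction, assuming a classic-gym `TimeLimit`-wrapped environment (`Pendulum-v1` is only a stand-in):

```python
import gym

# gym.make returns TimeLimit(<env>); FlattenObservation wraps that again.
env = gym.wrappers.FlattenObservation(gym.make("Pendulum-v1"))

# Depending on the gym version, reading `env._max_episode_steps` either raises
# an AttributeError (recent classic gym does not forward private attributes
# through gym.Wrapper) or reads through to the TimeLimit wrapper. Assigning
# through the outer wrapper would at best create a new attribute on
# FlattenObservation, so the diff reaches one level down instead:
env.env._max_episode_steps = 400   # update the inner TimeLimit directly
print(env.env._max_episode_steps)  # 400
```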
@@ -997,7 +1004,8 @@ def lac(  # noqa: C901
     np.random.seed(seed)
     random.seed(seed)
     env.seed(seed)
-    test_env.seed(seed)
+    if num_test_episodes != 0:
+        test_env.seed(seed)

     policy = LAC(
         env,
@@ -1024,10 +1032,8 @@ def lac(  # noqa: C901
             logger.log("Model successfully restored.", type="info")
         except Exception as e:
             logger.log(
-                (
-                    "Shutting down training since {}.".format(
-                        e.args[0].lower().rstrip(".")
-                    )
+                "Shutting down training since {}.".format(
+                    e.args[0].lower().rstrip(".")
                 ),
                 type="error",
             )
@@ -1113,7 +1119,6 @@ def lac(  # noqa: C901
     start_time = time.time()
     o, ep_ret, ep_len = env.reset(), 0, 0
     for t in range(total_steps):
-
         # Until start_steps have elapsed, randomly sample actions
         # from a uniform distribution for better exploration. Afterwards,
         # use the learned policy.
@@ -1156,10 +1161,7 @@ def lac(  # noqa: C901
             for _ in range(steps_per_update):
                 batch = replay_buffer.sample_batch(batch_size)
                 update_diagnostics = policy.update(data=batch)
-
-                # Log diagnostics
-                logger.store(**update_diagnostics)
-
+                logger.store(**update_diagnostics)  # Log diagnostics
                 # SGD batch tb logging
                 if use_tensorboard and not tb_low_log_freq:
                     logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
@@ -1173,14 +1175,15 @@ def lac(  # noqa: C901
             logger.save_state({"env": env}, itr=epoch)

             # Test the performance of the deterministic version of the agent
-            eps_ret, eps_len = test_agent(
-                policy, test_env, num_test_episodes, max_ep_len=max_ep_len
-            )
-            logger.store(
-                TestEpRet=eps_ret,
-                TestEpLen=eps_len,
-                extend=True,
-            )
+            if num_test_episodes != 0:
+                eps_ret, eps_len = test_agent(
+                    policy, test_env, num_test_episodes, max_ep_len=max_ep_len
+                )
+                logger.store(
+                    TestEpRet=eps_ret,
+                    TestEpLen=eps_len,
+                    extend=True,
+                )

             # Epoch based learning rate decay
             if lr_decay_ref.lower() != "step":
@@ -1207,17 +1210,18 @@ def lac(  # noqa: C901
                 with_min_and_max=True,
                 tb_write=use_tensorboard,
             )
-            logger.log_tabular(
-                "TestEpRet",
-                with_min_and_max=True,
-                tb_write=use_tensorboard,
-            )
-            logger.log_tabular("EpLen", average_only=True, tb_write=use_tensorboard)
-            logger.log_tabular(
-                "TestEpLen",
-                average_only=True,
-                tb_write=use_tensorboard,
-            )
+            if num_test_episodes != 0:
+                logger.log_tabular(
+                    "TestEpRet",
+                    with_min_and_max=True,
+                    tb_write=use_tensorboard,
+                )
+                logger.log_tabular("EpLen", average_only=True, tb_write=use_tensorboard)
+                logger.log_tabular(
+                    "TestEpLen",
+                    average_only=True,
+                    tb_write=use_tensorboard,
+                )
             logger.log_tabular(
                 "Lr_a",
                 policy._pi_optimizer.param_groups[0]["lr"],
@@ -1305,8 +1309,7 @@ def lac(  # noqa: C901
     parser.add_argument(
         "--env",
         type=str,
-        # default="Oscillator-v1",
-        default="CartPoleCost-v0",  # DEBUG
+        default="Oscillator-v1",
         help="the gym env (default: Oscillator-v1)",
     )
     parser.add_argument(
@@ -1403,7 +1406,10 @@ def lac(  # noqa: C901
         "--num_test_episodes",
         type=int,
         default=10,
-        help="the number of episodes for the performance analysis (default: 10)",
+        help=(
+            "the number of episodes for the performance analysis (default: 10). "
+            "When set to zero, no test episodes will be performed."
+        ),
     )
     parser.add_argument(
         "--alpha",
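
A hedged usage sketch of the new behaviour: with `num_test_episodes=0`, the training loop above skips `test_agent` and the `TestEpRet`/`TestEpLen` logging entirely. The `lac` import path matches the file changed in this diff; passing `env_fn` and `num_test_episodes` as keywords is an assumption based on the function body and argparse flag above, and `Oscillator-v1` is assumed to be registered (e.g. via the BLC simzoo environments):

```python
import gym

from bayesian_learning_control.control.algos.pytorch.lac.lac import lac

# Train LAC without any evaluation episodes (new in this commit). Environments
# with dict observation spaces also work now, since lac() wraps the result of
# env_fn in gym.wrappers.FlattenObservation.
lac(
    env_fn=lambda: gym.make("Oscillator-v1"),
    num_test_episodes=0,  # skip test_agent() and TestEpRet/TestEpLen logging
)
```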
@@ -11,9 +11,6 @@
 from bayesian_learning_control.control.algos.pytorch.policies.lyapunov_actor_critic import (
     LyapunovActorCritic,
 )
-from bayesian_learning_control.control.algos.pytorch.policies.lyapunov_actor_critic2 import (
-    LyapunovActorCritic2,
-)
 from bayesian_learning_control.control.algos.pytorch.policies.soft_actor_critic import (
     SoftActorCritic,
 )

