Scaling control actions for BRAX environments #473
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks @nic-barbara, appreciate the PR!
LGTM! modulo getting some before/after training curves
Thanks for adding this!
@btaba just a heads up, I haven't forgotten about this and will run ablation studies soon! All my GPUs are currently in use for some other research.
@btaba I've just run ablation studies on two of the affected environments. Results are attached. Action scaling doesn't seem to harm performance in either case. Would you like me to run them for the other environments as well?

For reference, here's the code I used to generate these results. It's essentially the same as the Brax training notebook, but with some additional data logging/plotting.

import functools
import matplotlib.pyplot as plt
import numpy as np
from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo
from datetime import datetime
from pathlib import Path

# suffix = ""
suffix = "_scaled"
env_names = ["humanoid"]  # ["inverted_pendulum", "humanoid", "pusher", "humanoidstandup"]
seeds = list(range(3))


def _get_fname(env_name, seed):
    dirpath = Path(__file__).resolve().parent
    fname = dirpath / f"{env_name}_results{suffix}_v{seed}"
    return fname


def train_model(env_name, seed, backend='positional'):
    # Environments and training functions from the Brax tutorial
    env = envs.get_environment(env_name=env_name, backend=backend)
    train_fn = {
        'inverted_pendulum': functools.partial(ppo.train, num_timesteps=2_000_000, num_evals=20, reward_scaling=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, batch_size=1024, seed=seed),
        'humanoid': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=0.1, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=10, num_minibatches=32, num_updates_per_batch=8, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048, batch_size=1024, seed=seed),
        'humanoidstandup': functools.partial(ppo.train, num_timesteps=100_000_000, num_evals=20, reward_scaling=0.1, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=15, num_minibatches=32, num_updates_per_batch=8, discounting=0.97, learning_rate=6e-4, entropy_cost=1e-2, num_envs=2048, batch_size=1024, seed=seed),
        'pusher': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=5, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=30, num_minibatches=16, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, batch_size=512, seed=seed),
    }[env_name]

    # Log rewards and print progress during training
    results = {"rewards": [], "stdev": [], "steps": [], "times": []}
    results["times"].append(datetime.now())

    def progress(num_steps, metrics):
        results["times"].append(datetime.now())
        results["steps"].append(num_steps)
        results["rewards"].append(metrics["eval/episode_reward"])
        results["stdev"].append(metrics["eval/episode_reward_std"])
        print("step: {} \t reward: {:.2f} \t stdev: {:.2f} \t time: {}".format(
            num_steps,
            metrics["eval/episode_reward"],
            metrics["eval/episode_reward_std"],
            results["times"][-1],
        ))

    # Train
    _, params, _ = train_fn(environment=env, progress_fn=progress)
    times = results["times"]
    print(f'time to jit: {times[1] - times[0]}')
    print(f'time to train: {times[-1] - times[1]}')

    # Save params and metrics
    data = (*params, results)
    fname = _get_fname(env_name, seed)
    model.save_params(fname, data)


def plot_rewards(env_name, seeds):
    # Read in results from all seeds
    xs, ys = [], []
    for seed in seeds:
        fname = _get_fname(env_name, seed)
        results = model.load_params(fname)[2]
        xs.append(np.array(results["steps"]))
        ys.append(np.array(results["rewards"]))
    x = np.vstack(xs).T
    y = np.vstack(ys).T

    # Plot formatting
    max_y = {'inverted_pendulum': 1100, 'humanoid': 13000, 'humanoidstandup': 75_000, 'pusher': 0}[env_name]
    min_y = {'reacher': -100, 'pusher': -150}.get(env_name, 0)

    # Make the plot and save it
    _, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_ylim(min_y, max_y)
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Reward")
    ax.set_title(f"{env_name}{suffix} ({len(seeds)} random seeds)")
    plt.tight_layout()
    plt.savefig(f"{fname}.pdf")
    plt.close()


for env_name in env_names:
    for seed in seeds:
        train_model(env_name, seed)
    plot_rewards(env_name, seeds)
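As a side note, if before/after curves on the same axes are more convenient, a minimal sketch along the following lines could overlay the baseline and scaled results. It assumes both sets of result files produced by the script above (with suffix "" and "_scaled") already exist on disk; the comparison environment and seed here are just illustrative.

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from brax.io import model

def _load_rewards(env_name, seed, suffix):
    # Same file-naming convention as the training script above.
    fname = Path(__file__).resolve().parent / f"{env_name}_results{suffix}_v{seed}"
    results = model.load_params(fname)[2]  # third element is the metrics dict
    return np.array(results["steps"]), np.array(results["rewards"])

# Overlay original vs scaled-action training curves for one env/seed
_, ax = plt.subplots()
for suffix, label in [("", "original"), ("_scaled", "scaled")]:
    steps, rewards = _load_rewards("humanoid", seed=0, suffix=suffix)
    ax.plot(steps, rewards, label=label)
ax.set_xlabel("Environment steps")
ax.set_ylabel("Reward")
ax.set_title("humanoid: original vs scaled actions (seed 0)")
ax.legend()
plt.tight_layout()
plt.savefig("humanoid_comparison.pdf")
plt.close()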
Hey @nic-barbara, thanks for checking on those two envs, scaled plots LGTM!
This PR is an attempt to fix #472. All environments assume the control action is restricted to [-1, 1]. For the few environments whose action space is not [-1, 1] (humanoid, humanoidstandup, pusher, inverted_pendulum), the input action is linearly scaled to the limits of the action space inside the environment's step() function.

Please let me know if this is an appropriate solution; I'd be happy to iterate on it. Note also that scaling the actions means that the "best" hyperparameters for these environments will likely change. Does this need to be addressed anywhere?