-
Notifications
You must be signed in to change notification settings - Fork 266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updated basic APG algorithm #476
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, just one nit.
@@ -131,7 +131,7 @@ def forward_log_det_jacobian(self, x): | |||
class NormalTanhDistribution(ParametricDistribution): | |||
"""Normal distribution followed by tanh.""" | |||
|
|||
def __init__(self, event_size, min_std=0.001): | |||
def __init__(self, event_size, min_std=0.001, var_scale=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add var_scale to Args section of docstring
Oh - it looks like you need to update brax/training/agents/apg/train_test.py - changes should hopefully be minimal - please review the failing test. Also, please do sign the CLA. Thank you! |
Hi @erikfrey, I have updated the tests and they pass on my local setup. I've also fixed the nit. I've signed the CLA, and the Checks tab is saying that my signing went through. Please let me know if there's anything missing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost there!
@@ -22,6 +22,9 @@ | |||
from brax.training.agents.apg import networks as apg_networks | |||
from brax.training.agents.apg import train as apg | |||
import jax | |||
from jax import config | |||
config.update("jax_enable_x64", True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these config changes are process-wide, so it's breaking other tests that expect default float width and precision. we run all the tests in a single process.
the envs in this test are simple enough that hopefully they don't need these config changes. can you try removing these config changes or otherwise tweak the tests so that jax_enable_64
and the precision change are not needed?
I removed the double precision toggle and the tests still run fine on my local setup. Let's see if this works :) |
Amazing, thank you! |
The goal of this proposed update is to provide a simple APG algorithm that can solve non-trivial tasks, as a first step for researchers and practitioners to explore Brax's differentiable simulation. It has been tested on MJX. Notes:
1: Algorithm Update
This fork contains an APG implementation that is about as simple as the current one, but reflects a common thread between several recent results that have used differentiable simulators to achieve locomotion: 1, 2, 3.
Brax's current APG algorithm is roughly equivalent to the following pseudocode:
In contrast, the cited results update the policy gradient much more frequently, using the observation that policy gradients that differentiate through the simulator have low variance. Hence, unrolling for an entire episode before updating has limited use. That additional samples past a certain point do not help is seen in that convergence does not increase with with massive parallelization [2]. The proposed APG algorithm essentially performs live stochastic gradient descent on the policy, unrolling it for a short window, doing a gradient update, then continuing where it left off:
Note that n_epochs can be much larger in this case. This modification allows the algorithm to relatively quickly learn quadruped locomotion, albeit with a particular training pipeline and reward design. (Notebook)
Additional notes:
2: Supporting Updates
Configurable initial policy variance: When hotstarting a policy, it benefits to explore around its induced state-space vicinity. This can be done by initializing the policy network weights small. Currently, the softplus disables this possibility, so this fork adds a scaling parameter.
Layer norm: I have found that using layer normalization in the policy neural network has greatly improved the training stability of APG methods and is seen in other implementations.