
Replacing gym's Mujoco envs with brax envs #49

Open
vwxyzjn opened this issue Sep 1, 2021 · 70 comments

@vwxyzjn
Contributor

vwxyzjn commented Sep 1, 2021

Had a conversation with @jkterry1 on openai/gym#2366, and it appears Brax would also be a great alternative for replacing the MuJoCo envs.

To help with this transition, I made an attempt to try out Brax with PyTorch. Here is a basic report: https://wandb.ai/costa-huang/brax/reports/Brax-as-Pybullet-replacement--Vmlldzo5ODI4MDk. The source code is here: https://github.com/vwxyzjn/cleanrl/blob/mybranch/cleanrl/brax/readme.md

One of the biggest issues with Brax adoption is env normalization:

I think the best way to fix this going forward is to move the normalization from Brax's training side to the environment side. In the future this will also help throughput with the JaxToTorchWrapper. Otherwise, the observation goes from GPU to CPU for gym's or SB3's normalization wrapper, then back to GPU for torch, which just doesn't make sense.

One small thing: since the Brax environment directly produces the vector env, there is also no way to inject a ClipActionsWrapper(env), which may or may not have a performance impact. That said, this can be implemented on the training side with ease.
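
For reference, here's a minimal sketch of what environment-side normalization could look like in JAX, so observations never leave the device. The function names are mine, not Brax's; the update follows the standard parallel mean/variance combination (Chan et al.):

import jax.numpy as jnp

def update_stats(mean, var, count, obs_batch):
    # Combine running stats with a new batch of observations (Welford/Chan).
    batch_mean = jnp.mean(obs_batch, axis=0)
    batch_var = jnp.var(obs_batch, axis=0)
    batch_count = obs_batch.shape[0]
    delta = batch_mean - mean
    new_count = count + batch_count
    new_mean = mean + delta * batch_count / new_count
    m2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / new_count
    return new_mean, m2 / new_count, new_count

def normalize(obs_batch, mean, var, clip=10.0):
    # Normalize and clip, entirely on device.
    return jnp.clip((obs_batch - mean) / jnp.sqrt(var + 1e-8), -clip, clip)

Since these are pure jnp functions, they can be jitted together with the env step, and a JaxToTorchWrapper would only ever see already-normalized, on-device arrays.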

@erwincoumans

erwincoumans commented Sep 1, 2021

Yes, as I suggested previously, Brax seems like a good option for OpenAI Gym, since it allows for GPU and TPU accelerators (training in minutes instead of hours) in addition to CPU. We can use this issue to track progress and add an itemized todo.

@jkterry1

jkterry1 commented Sep 5, 2021

To recap the to do list:

  1. Add suitable rendering
  2. Further tune observation/action spaces to make them as close as possible
  3. Make sure we are not reproducing the list of bugs in the MuJoCo environments from Antonin Raffin that I sent you

I feel like there may have been a 4th issue, but I don't sleep very much and can no longer recall it. @erwincoumans @benelot do you remember?

@vwxyzjn
Contributor Author

vwxyzjn commented Sep 5, 2021

One note on the suitable rendering: I feel implementing env.render("rgb_array") might be too expensive and counterproductive. Implementing env.render("html") at the end of an episode might be preferable.

@jkterry1

jkterry1 commented Sep 5, 2021

They're planning to add a new rendering engine such that "rgb_array" will be suitable.

@jkterry1

jkterry1 commented Sep 5, 2021

I don't know if this is the 4th feature I can't remember, but another thing we'll eventually need to deal with, which I briefly discussed, is action/observation space documentation for the new Gym website we're working on, in the style of https://www.pettingzoo.ml/classic/chess

@joaogui1

joaogui1 commented Sep 5, 2021

I would like to help with this; what can I do?

@jkterry1

jkterry1 commented Sep 6, 2021

@joaogui1 Probably nothing, at least at the moment. Right now I'm waiting on the Brax team to do some work and for the guy who created the pybullet replacement envs to get back from vacation; this will take 4-6 weeks. If you'd like to help with gym maintenance problems in general, though, please email me and we can coordinate some things ([email protected])

@joaogui1

joaogui1 commented Sep 7, 2021

Got it, will wait a little then, thanks!

@sgillen
Contributor

sgillen commented Sep 8, 2021

I'm also happy to help on this; I've spent a lot of time with the mujoco/pybullet environments at this point. I can certainly help with points 2 and 3 that @jkterry1 posted in this thread.

@erikfrey
Collaborator

We have started working on 1) the renderer. We're looking at porting a simple technique like https://github.com/rougier/tiny-renderer to JAX as a new module in brax.io

Tuning observation/action space could start in parallel if anyone is interested. I think the steps would involve:

  1. reset a Gym Mujoco env (say Ant) to default state and inspect the observation space and its description
  2. compare to Brax Ant env and make adjustments
  3. step both and compare dynamic observations (e.g. contact forces)

I think the envs are already ~80% comparable, and the last 20% is just sleuthing through the mujoco docs to confirm the formats match. I think we can get to the point where the meaning of each observation dimension is the same in both envs, even if the dynamics are still different.
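
For anyone picking this up, a rough sketch of steps 1 and 2, assuming the current brax.envs API (create / reset(rng)) and the gym Ant-v2 registration:

import jax
import gym
import brax.envs

# Gym MuJoCo Ant: obs is qpos[2:] + qvel + flattened contact forces.
gym_env = gym.make("Ant-v2")
gym_obs = gym_env.reset()
print("gym Ant obs:", gym_obs.shape)  # (111,) for Ant-v2

# Brax Ant: compare sizes first, then dimension by dimension.
brax_env = brax.envs.create("ant")
state = brax_env.reset(rng=jax.random.PRNGKey(0))
print("brax Ant obs:", state.obs.shape)

Stepping both with the same (e.g. zero) actions and diffing the position/velocity/contact-force slices should surface exactly where the layouts disagree.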

@sgillen
Contributor

sgillen commented Sep 11, 2021

I can get that going next week. I will use Mujoco 1.5 due to this issue. It looks like the Brax environments are based on the v2 versions of the Mujoco environments, so I'll start by comparing to those. Based on openai/gym#1304 I think the v3 versions are supposed to be identical when using default args, though I'm not 100% sure that's true.

@vwxyzjn
Contributor Author

vwxyzjn commented Sep 12, 2021

This is so great to hear! I also have a quick update: Gym now has a normalization wrapper: openai/gym#2387. The usage is roughly:

import numpy as np
import gym
import pybullet_envs  # registers the Bullet envs (e.g. HalfCheetahBulletEnv-v0) with gym

env = gym.make("HalfCheetahBulletEnv-v0")
env = gym.wrappers.RecordEpisodeStatistics(env)  # track episode return/length
env = gym.wrappers.ClipAction(env)
env = gym.wrappers.NormalizeObservation(env)  # running mean/std normalization
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
env = gym.wrappers.NormalizeReward(env)
env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

However, as I suggested earlier, this might not be as fast as implementing the normalization on Brax's side. Another issue is that directly applying these wrappers to a Brax environment won't work, because jax device arrays override the numpy arrays inside the wrappers.

A typical example is gym.wrappers.RecordEpisodeStatistics: its episode_returns array gets cast to a jax array, which causes problems because jax arrays are immutable.
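
Concretely, the in-place update that RecordEpisodeStatistics performs is fine on numpy but fails once the array has silently become a jax array:

import numpy as np
import jax.numpy as jnp

returns = np.zeros(4)
returns[0] += 1.0                          # fine: numpy arrays are mutable

jax_returns = jnp.zeros(4)
# jax_returns[0] += 1.0                    # TypeError: jax arrays are immutable
jax_returns = jax_returns.at[0].add(1.0)   # functional update instead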

@sgillen
Contributor

sgillen commented Sep 16, 2021

Ok, I was a bit busier than I expected this week, but as promised I did start comparing the ant environments this evening.
Here is a notebook I was using that may be useful to anyone else who wants to compare and tweak the envs.

With regard to the observations:

  1. I believe all the state position and velocity information matches up. For Mujoco it seems to be: z + quaternion for the torso (5), 8 joint angles, dxyz/drot (6) for the torso, 8 more joint velocities, which matches exactly what brax has.
  2. The contact information is where big differences appear. Brax seems to be missing some internal bodies that are present in the Mujoco model, which accounts for the difference in observation size (the Brax team was already aware of this).
  3. I'm not sure what the ordering of the contact forces is in brax. It doesn't match what mujoco does (see the notebook linked above), and it also doesn't seem to match up with the bodies in env.sys.body_idx.keys().

With regard to rewards:

  1. The rewards also exclude any contact force penalty, since those forces are missing due to a bug with gym + Mujoco 2.0 (see the issue I posted above), but I think it would be best to put them back.

If the goal is to make as faithful a representation of the mujoco envs as possible (which IMO it shouldn't necessarily be), then we will at least need to address the following:

  1. The mj ant starts life suspended .75m in the air, the brax ant at .5m.
  2. mj adds a relatively large amount of random noise to its initial state on reset.
  3. Inertial parameters for the two envs are different. Does brax have a way to infer inertia from geometry? This is what mj does.
  4. No matter the ordering, the magnitudes of the force and moment are substantially different, but that may be because of the difference in mass.
  5. Torque limits appear different: 300 in brax vs 150 in mj (units? That would be a lot of N*m).
  6. These are minor, but we may want to find out which brax integrator settings are closest to an rk4 with dt = .01.
  7. We may also want to tune friction parameters, which will probably need to be done empirically.

TLDR: For the ant, the difference in observations is in the ordering and number of contact forces. To make them match exactly we would need to reorder the existing forces and insert some dummy, zeroed elements into the observation. That said, the "missing" contact forces weren't useful in the old env, and the ordering of contacts shouldn't matter to an RL agent, so IMHO it would be enough to adjust the mass, inertia, and torque limit, add back the contact force reward/penalty, and maybe add the wider distribution to the initial state.

@sgillen
Contributor

sgillen commented Sep 16, 2021

@vwxyzjn good to hear about the normalization wrapper; I agree that the normalization and clipping should all be done on the Brax side. This makes things awkward with respect to saving and loading environments/agents, since it will make Brax a special case for gym, sb3, etc. Relatedly, I also think that if the Brax envs aren't going to be extremely fast, it would be better to just use pybullet.

@erikfrey
Collaborator

@vwxyzjn we recently started using a similar Wrapper concept for wrapping envs in Brax, inspired by Gym. e.g. EpisodeWrapper collects episode statistics and sets done at the episode boundary, and so on:

https://github.com/google/brax/blob/main/brax/envs/wrappers.py#L43

I don't think it would be too hard to make the brax API mirror what gym is doing, and still keep it all on device.

@erikfrey
Collaborator

@sgillen this is super helpful, thanks for putting together this thorough comparison. I hear you that our envs don't need to be exactly 1:1 with MuJoCo's; that said, we'd be happy to prioritize fixes to the differences you brought up, according to whether they:

  • impact training curves significantly (e.g. we noticed that adding noise to initial states does sometimes impact training)
  • produce a more pleasing gait (e.g. perhaps re-adding the contact force reward penalty)
  • ... some other reason?

Of the differences you found, do you have a suggestion for which might be the most important to address?

@benelot
Contributor

benelot commented Sep 16, 2021

I agree with @sgillen on the tasks, but would reorder to:

  1. add back in the contact force reward/penalty
  2. adjust the mass, inertia, and torque limit
  3. add the wider distribution to initial state

On 1: If we want to copy the previous env, we need it whether it helps with training or not; otherwise we diverge.
On 2: Is there any reason these were set the way they are in the brax ant env? The torque limit looks like the result of f(mass, inertia, mujoco_engine_details), so we should be able to set values similar to mujoco's. If they cannot be matched exactly, I would suggest falling back to the metric of "similar learning curve". In pybullet I once looked at the metric of "similar observation distribution shape", which says something about the observational manifold in which the ant moves.
On 3: This is certainly important for higher robustness of the learned policy. Especially with the humanoids, adding some noise during testing but not training easily messes them up.

On my side, I started to play a bit with brax and built an initial version of the humanoid standup, but being on vacation I am not done yet. I plan to begin building a first version of all required mujoco envs in brax next week, just to see how they perform. Then we can do for every env what @sgillen did for ant.

@jkterry1

Just to confirm, does the list of inconsistencies include the list of bugs in MuJoCo that I sent, the ones we want to make sure we aren't reproducing?

@sgillen
Contributor

sgillen commented Sep 16, 2021

@erikfrey I agree with @benelot's list on what to prioritize. They will probably impact training, making the environment slightly harder if anything, but also closer to the original. The contact reward might lead to more pleasing gaits, but it's hard to say.

@jkterry1 I am not sure, can you post that list of bugs here?

@benelot
Contributor

benelot commented Sep 17, 2021

@jkterry1 possibly means these (according to Antonin Raffin):

@jkterry1

@benelot that's the list, thanks a ton

@erikfrey
Collaborator

Can confirm that our HalfCheetah is at least not broken in the ways discussed in those blogs. In fact, this is something we had to address in our paper comparing our envs to Mujoco's; see section E1 in the appendix for a brief discussion of this problem.

That said, I am quite prepared for folks to find new and interesting bugs as these envs get more attention! We'll be happy to address them when they come up :-)

We are 90% done on hopper. If someone would like to take a pass at Walker2d or Swimmer, please let me know. Otherwise we'll get to them soon.

@erikfrey
Collaborator

Quick update - we now have the Hopper env, and tomorrow we will land Walker2d. We'll also add them soon to the colab with good default hparams. Other things in flight:

  • We've added back the contact force penalty and shown it works well for Ant; we're confirming for the other envs and then will push
  • @erwincoumans is making great progress on a simple, small software renderer that we will use for env.render
  • Wider distributions to initial state is also in progress

@erikfrey
Collaborator

OK! We now support state to pixels for env.render:

https://github.com/google/brax/blob/main/brax/io/image.py

Please keep in mind this is CPU rendering, so it's better suited to eval rendering and other programmatic use cases than to training. We will move to GPU/TPU rendering in the future, which should be suitable for training.

In the coming days we'll update our colabs with an example of how to use it.
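
Until then, a rough sketch of how it can be called, based on the render call shown later in this thread (the exact arguments may change):

import jax
import jax.numpy as jnp
from brax import envs
from brax.io import image

env = envs.create("ant")
state = env.reset(rng=jax.random.PRNGKey(0))
rollout = [state.qp]
for _ in range(100):
    state = env.step(state, jnp.zeros((env.action_size,)))
    rollout.append(state.qp)

# CPU-rendered image of the trajectory; keep this out of the training loop.
frames = image.render(env.sys, rollout, width=320, height=240)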

@slerman12

slerman12 commented Oct 14, 2021

I'm trying to make Brax/MuJoCo more apples-to-apples in their setup. I'm not sure what major differences need to be accounted for. Is there a set of operations that need to be called on Brax to get settings as similar to MuJoCo as possible (e.g. the normalization mentioned in this issue)?

@sgillen
Contributor

sgillen commented Oct 16, 2021

Hi @slerman12, the process of making the brax environments similar to Mujoco is still ongoing, I think. This thread has some info on the major differences at this point; you can use the notebook I posted above as a starting point for comparing the environments in an "apples-to-apples" way. The normalization is not a difference by itself: the Mujoco envs don't have normalization built in. Usually training frameworks like stable baselines normalize observations from environments, but that presents some difficulty in brax.

@jkterry1

jkterry1 commented Oct 20, 2021

Per the meeting, we still need the following things before merging into Gym:

Adding missing environments:
-Swimmer (Benjamin Ellenberger)
-Standup (Brax team)
-Inverted pendulum (Daniel Freeman)
-Inverted double pendulum (Daniel Freeman)

-Remove 0s where applicable (Brax team)
-Remove unnecessary inheritance regarding hopper (Brax team)

@benelot
Contributor

benelot commented Oct 20, 2021

I have not found pusher, reacher, striker, thrower anywhere in the brax repo. I think they are required as well, @jkterry1. Are they somewhere internal, @cdfreeman-google?

@erikfrey
Collaborator

Reacher is here: https://github.com/google/brax/blob/main/brax/envs/reacher.py

Ah, I wasn't aware of pusher, striker, thrower as they are not here: https://gym.openai.com/envs/#mujoco

BUT I do see them here: https://github.com/openai/gym/tree/master/gym/envs/mujoco

We'll look into those on the Brax side unless anyone jumps in and would like to claim them.

@erikfrey
Collaborator

OK more updates:

@erikfrey
Collaborator

@vwxyzjn also in case you were wondering, the SPS increasing over time is because the first call to reset() and step() is where the compilation happens; JIT compilation is pretty slow. If you want the "stable" SPS, you can add something like this:

    # env warmup
    next_obs = envs.reset()
    next_obs, reward, done, info = envs.step(actions[0])
    next_obs = envs.reset()

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    ...

@vwxyzjn
Contributor Author

vwxyzjn commented Oct 27, 2021

@erikfrey thanks for the reply! The SPS thing makes sense to me.

I was able to get the notebook working in CPU mode after incorporating your suggested fix:

gym_env = gym.make("brax-ant-v0", batch_size=args.num_envs)

However, I still have trouble running under the GPU runtime, with the same error presented in the screenshot above. Did you manage to get the notebook working under the GPU runtime?

UnfilteredStackTrace: RuntimeError: INTERNAL: Failed to launch CUDA kernel: fusion_162 with block dimensions: 128x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_OUT_OF_MEMORY: out of memory

@erikfrey
Collaborator

@vwxyzjn hmm, yes I am using a GPU runtime, and I am not able to reproduce that issue. The only two things I can think of:

  • maybe try Runtime -> Factory reset runtime?
  • it's possible we are getting different devices. If you run this in the colab, what do you see:
import jax
print(jax.devices()[0].device_kind)

I get assigned a Tesla K80.

@vwxyzjn
Contributor Author

vwxyzjn commented Oct 27, 2021

@erikfrey the assigned GPU seems to be the difference. I had the Pro subscription and it was giving me a Tesla P100-PCIE-16GB; I switched back to a normal account, got assigned a Tesla K80, and it worked fine.

I also tested out the env.render('rgb_array') API and it works great with PR #84.


@vwxyzjn
Contributor Author

vwxyzjn commented Oct 27, 2021

As a sidenote, rendering images does significantly slow down the throughput. If rendering HTML is faster, I personally would prefer doing that instead.

This could also be achieved with a wrapper, say RecordHTML, that collects the rollouts from the first sub-environment of the vector env and, at the end of training, outputs an HTML file labeled by episode (basically doing HTML(html.render(env.sys, [s.qp for s in rollout]))).
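
A rough sketch of what that could look like (RecordHTML is a hypothetical name, and reaching the underlying Brax state via _state is an assumption; html.render is the same call as above):

import gym
from brax.io import html

class RecordHTML(gym.Wrapper):
    # Hypothetical wrapper: cache the physics state every step and write an
    # HTML visualization when the first sub-environment finishes its episode.
    def __init__(self, env, path="episode.html"):
        super().__init__(env)
        self.path = path
        self.rollout = []

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Assumption: the gym-wrapped brax env keeps its current brax State
        # reachable somewhere, e.g. as `_state`.
        self.rollout.append(self.env._state.qp)
        if done[0]:  # episode boundary of the first sub-environment
            with open(self.path, "w") as f:
                f.write(html.render(self.env.sys, self.rollout))
            self.rollout = []
        return obs, reward, done, info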

@erikfrey
Collaborator

erikfrey commented Oct 27, 2021

@vwxyzjn sure, we can make such a wrapper. Quick update: I think some part of that colab's PPO algorithm is still causing device copies; when I comment out the optimization code block, SPS goes from 8k to 250k. I have more time tomorrow to look into why that's happening, but if you'd like to take a look in the meantime, please let me know if anything obvious sticks out.

@erikfrey
Collaborator

@vwxyzjn fyi I pulled some of the remaining work items out into separate issues, as this issue is getting large:

I'll update with progress on #88, and #89 we can do after that. Feel free to hop over to those issues to discuss more, and also see https://github.com/google/brax/projects/1 for what we're tracking overall for this effort.

@vwxyzjn
Contributor Author

vwxyzjn commented Oct 29, 2021

Thank you @erikfrey! This is very exciting. I'll try to help as much as I can :)

@erwincoumans

erwincoumans commented Oct 29, 2021

As a sidenote, rendering images does significantly slow down the throughput.

Yes, the CPU pytinyrenderer is not intended for use during training, only afterwards to see the rollout.

Two things make it much faster: halve the resolution (width and height) and disable anti-aliasing:

Image(image.render(env.sys, [s.qp for s in rollout], width=160, height=120, ssaa=1))

@erikfrey What is needed to make those efficiency changes when using a Gym environment wrapper?

Perhaps add some members to the Gym wrappers to tune width, height and ssaa (instead of the hardcoded 256x256)?

@vwxyzjn
Contributor Author

vwxyzjn commented Oct 29, 2021

Yes, the CPU pytinyrenderer is not intended for use during training, only afterwards to see the rollout.

@erwincoumans That makes sense!

I think the desired solution also depends on whether rendering the whole episode at once is faster. If the episode has 100 frames, is it more expensive to (1) render the image at each frame, or (2) render the images for all 100 frames as a batch?

gym.wrappers.RecordVideo takes approach (1), which grabs an rgb_array from brax at every step. So if there are 1000 sub-environments running, approach (1) could considerably slow things down, since the other 999 envs are basically blocked.

Maybe approach (2) is preferable because you can have a wrapper that caches the states for the episode; when the episode ends, the wrapper spits out a video or an HTML file. This way the other 999 sub-environments are not blocked at all during the rollouts.

@vwxyzjn
Contributor Author

vwxyzjn commented Oct 29, 2021

Here is a quick example.

Imagine we have 10 envs and 10 steps (unrolls), recording 1 frame takes 1 second, and each step takes 1 second.

Then approach (1) would take num_envs * (t_record + t_step) * num_steps = 10 * (1 + 1) * 10 = 200s,

whereas approach (2) would take num_envs * t_step * num_steps + t_record * num_steps = 10 * 1 * 10 + 1 * 10 = 110s.

@jkterry1

Because I forgot to post it earlier, here's my list of remaining needed features in Brax per the last group call we had:

-Swimmer
-Remove unnecessary inheritance regarding hopper
-Joint dislocation
--Adjust joint stiffness
--Generalized constraint system
-Proof of concept with fast training in PyTorch
-Brax compatible wrappers

@cdfreeman-google
Collaborator

Just as a little mini-update on joint dislocation: currently, Brax handles joint constraints using springy forces, but you can also rephrase these kinds of equality constraints as velocity-level updates, which directly update the velocities of the system to satisfy the constraints (instead of relying on strong springs to enforce them). There's a bit of a duality between these approaches, but it turns out it's much easier to reduce jitter in the joints with velocity-level methods. As a proof of concept, here's Ant running around with its joints governed by this new method:
[GIF: Ant running with the new velocity-level joint method]
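
For intuition, a standard velocity-level formulation (not necessarily Brax's exact math): a spring-based approach applies a penalty force on the constraint error $C(q)$,

$$f = -k\,C(q) - c\,\dot{C}(q),$$

while a velocity-level approach, with constraint Jacobian $J$ (wanting $Jv = 0$), solves for an impulse and updates the velocity directly:

$$\lambda = -\left(J M^{-1} J^{\top}\right)^{-1} J v, \qquad v \leftarrow v + M^{-1} J^{\top} \lambda.$$

Because the update leaves the system with $Jv = 0$ exactly, there is no stiff spring left to oscillate, which is where the reduced jitter comes from.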

@erikfrey
Collaborator

erikfrey commented Nov 8, 2021

OK! I have uploaded an example of PPO in pure PyTorch, running at 150,000 steps/sec on a Colab GPU. It can train Ant in a few minutes. We are far from Torch experts; I suspect there is an easy doubling in performance for someone who knows PyTorch better than we do:

Brax Training with PyTorch on GPU

I think this is a good enough demonstration of using Brax from PyTorch via Gym, but if anyone out there would like to speed it up further, please have a go. @vwxyzjn, your CleanRL was a helpful reference, thank you! We tried a few things in that Colab (including torch.jit) that may help bridge the perf gap in CleanRL.

@jkterry1

For those following this thread, the remaining items are:
-Swimmer (Benjamin is working on this)
-Generalized constraint system
-Brax compatible wrappers (this is waiting on me)
-PR to Gym
-Phys2D environments to replace Box2D (future work)

@benelot
Contributor

benelot commented Dec 11, 2021

We finished the swimmer and it is going to be in the repo soon.

@benelot
Contributor

benelot commented Dec 11, 2021

[GIF of the trained swimmer]

Looks pretty decent!

@jkterry1

Awesome!

@jkterry1

The remaining action items from when we last met are:

-Grid line missing in rendering problem (Erik Frey)
-See if shadow jigging can be fixed (Erik Frey)
-Generalized constraint system
-Jumpy release
-PR to Gym
-Phys2D environments to replace Box2D
-Tensorflow conversion snippet so we can create a gym wrapper for it

@erikfrey
Collaborator

  • OK Swimmer is live in 4f1e15a
  • Grid lines fixed in cb59468
  • Texture jiggling I was able to repro in Ubuntu Cinnamon / Chrome, but it's not a problem on other platforms. Poking around in three.js issues, the consensus seemed to be "it's a driver issue", but if I find a better answer, I'll be sure to address it.
  • Jumpy release underway!

@jkterry1

That's awesome!

A minor update to the Box2D environments on my end:

-PRs are up for the needed refactoring of lunar lander and bipedal walker
-People are ostensibly currently working on the needed pygame rendering rewrite
-It turns out that the car racing environment has been used in a bunch of important work by David Ha and others that I didn't know of, so it will need to be ported instead of deleted as I'd planned

@jkterry1

There's now a PR fixing the (large) outstanding bug in car racing, though.

@jkterry1

jkterry1 commented Dec 21, 2021

Updated notes before I forget:

-Generalized constraint system
-Separate Jumpy repo
-Documentation files for the new website
-PR to Gym
-Phys2D environments to replace Box2D
-Tensorflow conversion snippet so we can create a gym wrapper for it
-openai/gym#2456 (comment)

@jkterry1

Hey, I had a discussion with a few people. The current to-do AFAIK is:

-Generalized constraint system
-Separate Jumpy repo
-Documentation files for the new website
-PR to Gym
-Phys2D environments to replace Box2D
-Tensorflow conversion snippet so we can create a gym wrapper for it
-openai/gym#2456 (comment)
-Supporting the non-vector API
-Upstreaming environments
-Pusher environment (Benjamin is working on this)
-#164

@erikfrey
Collaborator

OK! Here is an example snippet for converting tensorflow tensors to JAX ndarrays:

import numpy as np
import tensorflow as tf
import jax.dlpack

tf_arr = tf.random.uniform((10,))
print(f'tensorflow tensor on device {tf_arr.device}')

dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
print(f'jax ndarray on device {jax_arr.device()}')

np.testing.assert_array_equal(tf_arr, jax_arr)

I tried this on a colab with a GPU and it works great. Here's the output:

tensorflow tensor on device /job:localhost/replica:0/task:0/device:GPU:0
jax ndarray on device gpu:0
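
For completeness, the reverse direction (JAX to TensorFlow) should work the same way through dlpack:

import jax
import jax.dlpack
import jax.numpy as jnp
import tensorflow as tf

jax_arr = jnp.arange(10.0)
tf_arr = tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_arr))
print(f'tensorflow tensor on device {tf_arr.device}')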

@jkterry1

Plumbing fixes:
-Final merge of generalized constraint system
-Pusher environment (Benjamin is working on this)
-PR to Jumpy repo
-Fix default camera view for reacher (briefly check the others)

-PR to Gym
--Documentation files for the new website
--Adapt MuJoCo md files (https://github.com/Farama-Foundation/gym-docs/tree/main/docs/pages/environments/mujoco)
--#164

Future:
-Phys2D environments to replace Box2D

@erikfrey
Collaborator

OK, we have a fix for the default camera view in 748229c. Next time we do a version bump, it'll land in everyone's viewers.

@erikfrey
Collaborator

PR to Jumpy repo is here: Farama-Foundation/Jumpy#1

@bycn

bycn commented May 2, 2022

Apologies if this is the wrong place to comment, feel free to redirect.

Presumably the porting of environments hasn't started yet, but what's the status of 2D simulation with Brax ("Phys2D")? How would the physics compare to Box2D's?
