Make accelerator-friendly versions of a subset of the gym wrappers #89
So I had a discussion with Costa, and I don't know what the best way to proceed forward here is. Right now, literally no current Gym wrappers can be used with Brax environments, because they're all numpy based, which gives almost no interoperability. This means that, for full support (Phys2D, Phys3D, classic control, etc.), there would need to be duplicates of all the Gym wrappers written in Jax. This is rather undesirable for several reasons:

I know you wrote conversion code between Jax and torch tensors to get things to work; is there any way to shim conversion code to get the Gym wrappers to work as-is and avoid this whole situation?

An additional related point: where should the Jax->torch (and possibly other) Gym adapters live, given things like the plan for Jax-based classic control? One reasonable candidate location is https://github.com/Farama-Foundation/SuperSuit. SuperSuit is a collection of Gym and PettingZoo wrappers; once gym.wrappers is complete and stable, its Gym wrappers will be removed, and the PettingZoo wrappers will be merged into pettingzoo.wrappers, like Gym's. However, SuperSuit is still going to retain various preprocessing wrappers that don't make sense to house in either Gym or PettingZoo, such as ones purely for manipulating Gym or PettingZoo vector environments, or for converting PettingZoo environments to Gym environments (something you actually want to do in certain odd cases). It doesn't have to go there if there's a better option or one that you'd prefer; it would just be odd to use a Brax wrapper for classic control, for instance.
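For reference, a minimal sketch of the kind of Jax<->torch conversion being referenced, using the DLPack protocol so tensors can move between frameworks without leaving the GPU. This is illustrative only; Brax's actual converter is the one linked later in this thread (brax/io/torch.py), and the helper names here are made up.

```python
import jax.dlpack
import torch.utils.dlpack

def jax_to_torch(x):
    # Zero-copy handoff of a jax array to torch via a DLPack capsule.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch_to_jax(t):
    # And the reverse direction, torch -> jax.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
```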
@jkterry1 please take a look at this shim we put in Brax a while back: https://github.com/google/brax/blob/main/brax/jumpy.py

I think it does what you're asking: if you pass in numpy arrays, you get out numpy arrays; if you pass in jax arrays, you get back jax arrays. So:

```python
import numpy as np
from jax import numpy as jnp
from brax import jumpy as jp

def myfunc(x: jp.ndarray) -> jp.ndarray:
    # do all kinds of numpy ops here, but use 'jp' instead of 'np', e.g.:
    x = jp.cos(x)
    return x

for x in [np.array(1), jnp.array(1)]:
    print(f'input type: {type(x)} output type: {type(myfunc(x))}')
```

This prints matching input and output types: numpy in, numpy out, and jax in, jax out.
Using such a technique, any gym wrapper could be written once and could handle any needed input/output types. Btw, such a shim could also be extended to handle torch tensors, if needed. Re: where to put a jax->torch converter: it's a handful of lines of code, so feel free to put it wherever you find handy.
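For readers unfamiliar with jumpy, the dispatch idea is roughly the following (a simplified sketch; see brax/jumpy.py for the real implementation, which covers many more ops):

```python
import numpy as onp
from jax import numpy as jnp

def _which_np(x):
    # Pick the numpy backend that matches the input's type.
    return jnp if isinstance(x, jnp.ndarray) else onp

def cos(x):
    # Dispatches to jnp.cos for jax arrays and onp.cos for numpy arrays,
    # so the output type always matches the input type.
    return _which_np(x).cos(x)
```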
So if I understand correctly, the prudent move would be to port all of the Gym wrappers to use Jumpy? I have no problem doing that, especially since gym.wrappers was planned for a massive overhaul anyway; we'd just have to actually do it. Jumpy would also have to be a separate package on PyPI, as it would become a core Gym dependency. We'd also have to sort out who would do this: the big wrapper fixes are one of the things I wanted a proper paid maintainer for, because they're labor intensive and no one on my side in the community has been willing to do them. Regarding shim code, I have a few questions:

On further contemplation, I also think that putting the shim in gym.wrappers would make more sense.
At some point in the future we'd also likely want jax-based image resizing code, which tends to be pretty performance sensitive.
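As a point of reference, jax already ships a resize op (jax.image.resize) that such code could build on; a hedged sketch, with the 84x84 Atari-style target shape chosen purely for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def resize_observation(obs: jnp.ndarray) -> jnp.ndarray:
    # Downsample an HxWxC frame to 84x84, keeping the channel count.
    return jax.image.resize(obs, (84, 84, obs.shape[-1]), method="bilinear")
```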
@jkterry1 certainly, feel free to pull in jumpy! As for whether it needs to be in a package, I defer to you, but you can also just pull in the file, mangle it, and give it a less silly name; anything that is useful to gym. It may be missing a few ops that gym wrappers would use, but it is easy to add more.

Brax to TF: we have a converter for models here: https://github.com/google/brax/blob/main/brax/io/export.py but haven't had any requests to actually interoperate with tensorflow tensors as input/output, interestingly enough. It could probably also be done through jax2tf.

The torch converter lives here: https://github.com/google/brax/blob/main/brax/io/torch.py (code kindly donated by @lebrice)
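For completeness, a minimal sketch of the jax2tf route mentioned above (jax.experimental.jax2tf is a real jax module; myfunc is just an illustrative function):

```python
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf

def myfunc(x):
    return jnp.cos(x)

# Convert the jax function into one that consumes and produces TF tensors.
tf_myfunc = tf.function(jax2tf.convert(myfunc), autograph=False)
y = tf_myfunc(tf.constant([0.0, 1.0]))  # TF tensor in, TF tensor out
```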
I definitely would rather jumpy be a package than live in Gym, because a numpy/jax shim is more of a general-purpose tool than a Gym-specific thing. I can release a package for it if needed, but if you'd be willing to host and maintain the package, I think that would make more sense, since it's ultimately a Jax tool and you guys created it. I also honestly like the name. FWIW, I briefly mentioned the idea of releasing a standalone jumpy to Daniel Suo in a meeting just a moment ago (thanks for that introduction, by the way) and he also thought it would be valuable as a tool for the Jax community. As part of the merging of Brax into Gym, we'd need to have two wrappers that would function something like this:
A ton of people use TF in Gym, and this has to be natively supported and easy to use. I think the only to-dos regarding wrappers on your end would be contributing those two wrappers to Gym as part of the Brax PR and releasing jumpy for the gym.wrappers rewrite.
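To make the shape of such a wrapper concrete, here is a purely hypothetical sketch of the torch-facing one, built on the DLPack converters sketched earlier in the thread; the class name and helpers are illustrative, not an actual Gym or Brax API, and it uses the four-tuple step signature current at the time of this discussion:

```python
import gym
import jax.dlpack
import torch.utils.dlpack

def jax_to_torch(x):
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch_to_jax(t):
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

class JaxToTorchWrapper(gym.Wrapper):
    """Exposes a jax-based env to torch code by converting at the boundary."""

    def reset(self):
        return jax_to_torch(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(torch_to_jax(action))
        return jax_to_torch(obs), reward, done, info
```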
Hey guys, I've got a proposition: what about converting the transformation functions in Gym to single-dispatch callables (see the singledispatch function from the functools package) and letting brax/other external packages register custom handlers for their types?
Oh, and about a jax2tf wrapper: I've written one before, a few months back, but I wasn't able to get an in-place CUDA tensor conversion using the dlpack interface of tensorflow. The tensors seemed to always end up on the CPU for some reason. I might be wrong though; maybe it was something about my setup, or they have patched that issue since.
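For reference, the DLPack route between jax and TF being described looks roughly like this (tf.experimental.dlpack and jax.dlpack are real APIs; as noted above, whether the result actually stays on the GPU may depend on the versions involved):

```python
import jax.dlpack
import tensorflow as tf

def jax_to_tf(x):
    # Hand a jax array to TF through a DLPack capsule.
    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def tf_to_jax(t):
    # And the reverse, TF -> jax.
    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(t))
```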
@lebrice Could you please elaborate on your proposal with single-dispatch callables a bit more?
Gladly @jkterry1 :) For example, consider something like this:

```python
from functools import singledispatch
from typing import Tuple, TypeVar

import numpy as np

from gym import Env
from gym.wrappers.transform_observation import TransformObservation

T = TypeVar("T")

# 1. Extract the core "transformation" from the wrapper and decorate it
#    with @singledispatch:
@singledispatch
def resize_image(image: T, size: Tuple[int, int]) -> T:
    # Default handler: nothing registered for this type.
    raise NotImplementedError(f"No handler registered for images of type {type(image)}.")

@resize_image.register(np.ndarray)
def _resize_image_array(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
    # numpy handler (some_numpy_resizing_logic is a placeholder).
    return some_numpy_resizing_logic(image, size)

# 2. Have the wrappers in gym use these single-dispatch callables:
class ResizeImageWrapper(TransformObservation):
    def __init__(self, env: Env, size: Tuple[int, int]):
        self.size = size
        super().__init__(env=env, f=lambda img: resize_image(img, size=size))
```

Then, somewhere in the brax repo, for instance, you'd have something like this:

```python
from typing import Tuple

import jax.numpy as jnp

from gym.wrappers.somewhere import resize_image

@resize_image.register(jnp.ndarray)
def resize_jax_image(image: jnp.ndarray, size: Tuple[int, int]) -> jnp.ndarray:
    # jax handler (some_jax_resizing_logic is a placeholder).
    return some_jax_resizing_logic(image, size)
```

Let me know if this makes sense.
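To make the dispatch concrete, a usage sketch (the env names are placeholders, and it assumes real resizing logic has been plugged into the placeholder functions above):

```python
# The same wrapper now serves both backends, dispatching on the
# observation's type at runtime:
numpy_env = ResizeImageWrapper(some_numpy_env, size=(84, 84))  # numpy handler
brax_env = ResizeImageWrapper(some_brax_env, size=(84, 84))    # jax handler
```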
For more complicated, "stateful" wrappers / functions:

```python
from functools import singledispatchmethod
from typing import Any

import numpy as np

from gym import Env, Wrapper

class DoSomethingComplicatedWrapper(Wrapper):
    def __init__(self, env: Env):
        super().__init__(env=env)
        self.state = None

    def reset(self):
        self.state = None
        obs = self.env.reset()
        obs = self.some_method(obs)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action=action)
        obs = self.some_method(obs)
        return obs, reward, done, info

    @singledispatchmethod
    def some_method(self, obs: Any):
        raise NotImplementedError(f"No registered handler for observations of type {type(obs)}.")

    @some_method.register(np.ndarray)
    def _some_method_numpy(self, obs: np.ndarray):
        # Some kind of operation that depends on or affects the state of the wrapper
        if self.state is None:
            self.state = obs
        else:
            p = np.random.uniform(0, 1, size=obs.shape)
            self.state = (1 - p) * self.state + p * obs
        return self.state

    ...

# then in another repo (e.g. brax)
import jax.numpy as jnp

@DoSomethingComplicatedWrapper.some_method.register(jnp.ndarray)
def _do_something_stateful_jax(self, obs: jnp.ndarray):
    # the jax equivalent, etc.
    ...
```
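One practical note worth adding: functools.singledispatchmethod only exists on Python 3.8+, while plain functools.singledispatch has been available since 3.4, so the stateful-wrapper variant could constrain Gym's minimum supported Python version or require a backport.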
Sorry for the late follow-up. What I'm essentially suggesting is two-fold:
Here are the main advantages of this solution IMO:
Candidates are:
gym.wrappers.RecordEpisodeStatistics
gym.wrappers.ClipAction
gym.wrappers.NormalizeObservation
gym.wrappers.TransformObservation
gym.wrappers.NormalizeReward
gym.wrappers.TransformReward
And also possibly a new wrapper:
gym.wrappers.RecordHTML
/cc @vwxyzjn @jkterry1
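For concreteness, a hedged sketch of what one of these accelerator-friendly candidates (ClipAction) might look like on top of the jumpy-style shim discussed above; it assumes jumpy exposes (or gains) a clip op, and is a sketch rather than the final API:

```python
import gym
from brax import jumpy as jp

class ClipAction(gym.ActionWrapper):
    """Clips actions to the action space bounds, for numpy or jax inputs."""

    def action(self, action):
        # jumpy routes to numpy or jax.numpy based on the input's type, so
        # the same wrapper works for classic envs and Brax envs alike.
        return jp.clip(action, self.action_space.low, self.action_space.high)
```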