Make accelerator-friendly versions of a subset of the gym wrappers #89

Open
erikfrey opened this issue Oct 28, 2021 · 13 comments
Labels
enhancement New feature or request

Comments

@erikfrey
Collaborator

erikfrey commented Oct 28, 2021

Candidates are:

gym.wrappers.RecordEpisodeStatistics
gym.wrappers.ClipAction
gym.wrappers.NormalizeObservation
gym.wrappers.TransformObservation
gym.wrappers.NormalizeReward
gym.wrappers.TransformReward

And also possibly a new wrapper:

gym.wrappers.RecordHTML

/cc @vwxyzjn @jkterry1

@jkterry1

So I had a discussion with Costa, and I'm not sure what the best way to proceed here is.

Right now, literally no current Gym wrappers can be used with Brax environments because they're all numpy-based, which offers almost no interoperability with JAX. This would mean that, for full support (Phys2D, Phys3D, classic control, etc.), there would need to be duplicates of all the Gym wrappers written in Jax.

This is rather undesirable for several reasons:
- Copying and maintaining copied code sucks
- gym.wrappers needs a big overhaul: it doesn't have various key wrappers, certain wrappers are missing key features, and so on, so there's no reference set to copy
- We'd have to make sure that the Jax and numpy gym wrappers have identical behavior, for reproducibility when people describe wrappers in the literature
- There isn't a clear place for the Jax wrappers to even live if classic control is jax-ified

I know you wrote conversion code between Jax and torch tensors to get things to work; is there any way to shim conversion code so that the Gym wrappers work as-is, to avoid this whole situation?

An additional related point: where should the Jax->torch (and possibly other) Gym adapters live, given things like the plan for Jax-based classic control?

One reasonable candidate location is https://github.com/Farama-Foundation/SuperSuit. SuperSuit is a collection of Gym and PettingZoo wrappers; once gym.wrappers is complete and stable, those will be removed, and the PettingZoo wrappers will be merged into pettingzoo.wrappers, like Gym. However, SuperSuit is still going to retain various preprocessing wrappers that don't make sense to house in either Gym or PettingZoo, such as ones purely for manipulating Gym or PettingZoo vector environments or for converting PettingZoo environments to Gym environments (something you actually want to do in certain odd cases). It doesn't have to go there if there's a better option or one that you'd prefer; it would just be odd to use a Brax wrapper for classic control, for instance.

@erikfrey
Collaborator Author

erikfrey commented Nov 15, 2021

@jkterry1 please take a look at this shim we put in Brax a while back:

https://github.com/google/brax/blob/main/brax/jumpy.py

I think it does what you're asking: if you pass in numpy arrays you get numpy arrays out, and if you pass in jax arrays you get jax arrays out. So:

import numpy as np
from jax import numpy as jnp
from brax import jumpy as jp

def myfunc(x: jp.ndarray) -> jp.ndarray:
  # do all kinds of numpy ops here but use 'jp' instead of 'np' e.g.
  x = jp.cos(x)
  return x

for x in [np.array(1),  jnp.array(1)]:
  print(f'input type: {type(x)} output type: {type(myfunc(x))}')

prints:

input type: <class 'numpy.ndarray'> output type: <class 'numpy.float64'>
input type: <class 'google3.third_party.tensorflow.compiler.xla.python.xla_extension.DeviceArray'> output type: <class 'google3.third_party.tensorflow.compiler.xla.python.xla_extension.DeviceArray'>

Using such a technique, any gym wrapper could be written once and could handle any needed input/output types. By the way, such a shim could also be extended to handle torch tensors, if needed.
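
For example, a wrapper like ClipAction could be written once against jumpy and work for both array types. A minimal sketch (assuming jumpy exposes, or is extended with, a clip op):

import gym
from brax import jumpy as jp

class JumpyClipAction(gym.ActionWrapper):
  """Illustrative: a ClipAction variant written once against jumpy."""

  def action(self, action: jp.ndarray) -> jp.ndarray:
    # jumpy dispatches to numpy or jax.numpy based on the type of `action`.
    return jp.clip(action, self.action_space.low, self.action_space.high)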

re: where to put a jax->torch converter, it's a handful of lines of code, so feel free to put it wherever you find handy.
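For reference, the core of such a converter is roughly the following sketch (a simplification; the actual Brax converter may differ), using the dlpack protocol so data stays on device:

import jax.numpy as jnp
from jax import dlpack as jax_dlpack
import torch
from torch.utils import dlpack as torch_dlpack

def jax_to_torch(x: jnp.ndarray) -> torch.Tensor:
  # Zero-copy handoff via a dlpack capsule; the tensor should stay on the same device.
  return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x))

def torch_to_jax(t: torch.Tensor) -> jnp.ndarray:
  return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(t))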

@jkterry1

So if I understand correctly then, the prudent move would be to port all of the Gym wrappers to use Jumpy?

I have no problem doing that, especially since a massive overhaul of gym.wrappers was planned anyway; we'd just have to actually do it. Jumpy would also have to be a separate package on PyPI, as it would become a core Gym dependency. We'd also have to sort out who would do this; the big wrapper fixes are one of the things I wanted a proper paid maintainer for, because they're labor-intensive and no one on my side in the community has been willing to do them.

Regarding shim code, I have a few questions:
- Does a Brax-to-TF shim need to exist, or does XLA give us this for free?
- I apologize if I'm just missing it, but could you please link me to the torch shim code?

On further contemplation, I also think that putting the shim in gym.wrappers would make more sense.

@jkterry1

At some point in the future we'd also likely want jax-based image-resizing code, which tends to be pretty performance-sensitive.

@erikfrey
Collaborator Author

@jkterry1 certainly, feel free to pull in jumpy! As for whether it needs to be in a package, I defer to you, but you can also just pull in the file, mangle it, and give it a less silly name - whatever is useful to gym. It may be missing a few ops that gym wrappers would use, but it is easy to add more.

Brax to TF: we have a converter for models here: https://github.com/google/brax/blob/main/brax/io/export.py but, interestingly enough, we haven't had any requests to actually interoperate with tensorflow tensors as input/output. It could probably also be done through jax2tf.

Torch converter lives here: https://github.com/google/brax/blob/main/brax/io/torch.py (code kindly donated by @lebrice)

@jkterry1

jkterry1 commented Nov 17, 2021

I definitely would rather jumpy be a package than live in Gym, because a numpy/jax shim is more of a general-purpose tool than a Gym-specific thing. I can release a package for it if needed, but if you'd be willing to host and maintain the package, I think that would make more sense, since it's ultimately a Jax tool and you created it. I also honestly like the name. FWIW, I briefly mentioned the idea of releasing jumpy stand-alone to Daniel Suo in a meeting just a moment ago (thanks for that introduction, by the way) and he also thought it would be valuable as a tool for the Jax community.

As part of the merging of Brax into gym, we'd need to have two wrappers that would function something like this:

from gym.wrappers import jax2torch
jax_env = ...
torch_env = jax2torch(jax_env)
...
<torch code>

from gym.wrappers import jax2tf
jax_env = ...
tf_env = jax2tf(jax_env)
...
<tf code>

A ton of people use TF with Gym, and this has to be natively supported and easy to use. I think the only to-dos regarding wrappers on your end would be contributing those two wrappers to Gym as part of the Brax PR and releasing jumpy for the gym wrappers rewrite.
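
A rough sketch of how such a jax2torch wrapper might be structured (illustrative only; it assumes jax_to_torch/torch_to_jax conversion helpers along the lines of brax/io/torch.py):

import gym

class Jax2TorchWrapper(gym.Wrapper):
    """Illustrative: expose a jax-based env to torch code by converting tensors at the boundary."""

    def reset(self):
        return jax_to_torch(self.env.reset())

    def step(self, action):
        # torch_to_jax / jax_to_torch are assumed conversion helpers, not part of gym.
        obs, reward, done, info = self.env.step(torch_to_jax(action))
        return jax_to_torch(obs), reward, done, info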

@lebrice
Contributor

lebrice commented Nov 18, 2021

Hey guys, I've got a proposition: what about converting the transformation functions in Gym to single-dispatch callables (see the singledispatch function from the functools package) and letting brax/other external packages register custom handlers for their types?

@lebrice
Contributor

lebrice commented Nov 18, 2021

Oh, and about a jax2tf wrapper: I've written one before, a few months back, but I wasn't able to get an in-place CUDA tensor conversion using the dlpack interface of TensorFlow. The tensors always seemed to end up on the CPU for some reason. I might be wrong, though; maybe it was something about my setup, or they have patched that issue since.
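
For context, the conversion itself is along these lines (a sketch; whether the result actually stays on the GPU can depend on the jax/TF versions and the dlpack issue described above):

import jax.numpy as jnp
from jax import dlpack as jax_dlpack
import tensorflow as tf

def jax_to_tf(x: jnp.ndarray) -> tf.Tensor:
    # Hand the array to TF through a dlpack capsule (ideally zero-copy, staying on device).
    return tf.experimental.dlpack.from_dlpack(jax_dlpack.to_dlpack(x))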

@erikfrey
Collaborator Author

@jkterry1 OK, we can make a jumpy PyPI package after we've cleared out some of the other remaining TODOs.

@lebrice thanks for the heads-up about dlpack; it's definitely critical to stay on device. We'll poke around and see what's up there.

@jkterry1

@lebrice Could you please elaborate on your proposal with single-dispatch callables a bit more?

@lebrice
Contributor

lebrice commented Nov 21, 2021

Gladly @jkterry1 :)
Here's what I mean: we could use singledispatch and singledispatchmethod from the built-in functools module to make the gym wrappers agnostic to exactly what kind of data they are handling.

For example, consider something like this:

from functools import singledispatch
from typing import Tuple, TypeVar

import numpy as np
from gym import Env
from gym.wrappers.transform_observation import TransformObservation

T = TypeVar("T")

# 1. Extract the core "transformation" from the wrapper and decorate it with @singledispatch:

@singledispatch
def resize_image(image: T, size: Tuple[int, int]) -> T:
    # Default handler: no implementation registered for this observation type.
    raise NotImplementedError(f"No handler registered for images of type {type(image)}.")

@resize_image.register(np.ndarray)
def _resize_image_array(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
    # Handler for plain numpy arrays; some_numpy_resizing_logic is a placeholder.
    return some_numpy_resizing_logic(image, size)

# 2. Have the wrappers in gym use these single-dispatch callables:

class ResizeImageWrapper(TransformObservation):
    def __init__(self, env: Env, size: Tuple[int, int]):
        self.size = size
        super().__init__(env=env, f=lambda img: resize_image(img, size=size))

Then, somewhere in the brax repo, for instance, you'd have something like this:

from typing import Tuple

import jax.numpy as jnp
from gym.wrappers.somewhere import resize_image

@resize_image.register(jnp.ndarray)
def resize_jax_image(image: jnp.ndarray, size: Tuple[int, int]) -> jnp.ndarray:
    # some_jax_resizing_logic is a placeholder for the jax implementation.
    return some_jax_resizing_logic(image, size)
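
With both handlers registered, the same resize_image callable dispatches on the observation type. Illustrative usage (assuming the placeholder resizing functions above are filled in, and that dispatch on jnp.ndarray resolves for jax arrays as expected):

resize_image(np.zeros((64, 64, 3)), size=(32, 32))   # -> handled by _resize_image_array
resize_image(jnp.zeros((64, 64, 3)), size=(32, 32))  # -> handled by resize_jax_image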

singledispatch is great for any kind of transformation or stateless function. singledispatchmethod can be used to register new implementations of a method in the same way, so it is more useful for customizing behavior that depends on or affects the wrapper's state.

Let me know if this makes sense.

@lebrice
Contributor

lebrice commented Nov 21, 2021

For more complicated, "stateful" wrappers / functions, singledispatchmethod can be used:

from functools import singledispatchmethod
from typing import Any

import numpy as np
from gym import Env, Wrapper

class DoSomethingComplicatedWrapper(Wrapper):
    def __init__(self, env: Env):
        super().__init__(env=env)
        self.state = None

    def reset(self):
        self.state = None
        obs = self.env.reset()
        obs = self.some_method(obs)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action=action)
        obs = self.some_method(obs)
        return obs, reward, done, info

    @singledispatchmethod
    def some_method(self, obs: Any):
        raise NotImplementedError(f"No registered handler for observations of type {type(obs)}.")
    
    @some_method.register(np.ndarray)
    def _some_method_numpy(self, obs: np.ndarray):
        # Some kind of operation that depends on or affects the state of the wrapper
        if self.state is None:
            self.state = obs
        else:    
            p = np.random.uniform(0, 1, size=obs.shape)
            self.state = (1 - p) * self.state + p * obs
        return self.state
...

# then in another repo (e.g. brax):
import jax.numpy as jnp

@DoSomethingComplicatedWrapper.some_method.register(jnp.ndarray)
def _do_something_stateful_jax(self, obs: jnp.ndarray):
    # the jax equivalent of the numpy handler above, etc.
    ...

@lebrice
Contributor

lebrice commented Feb 2, 2022

Sorry for the late follow-up.

What I'm essentially suggesting is two-fold:

  1. make the transformation-style wrappers in Gym more agnostic to the exact type of data they are manipulating;
  2. decorate the main transformation functions or methods with functools.singledispatch and functools.singledispatchmethod, respectively, so that packages outside gym can register new handlers for their data types. (Also, a detail: these functions/methods should raise a NotImplementedError when they encounter a data type that isn't supported, as above, so the failure is always clear to users.)

Here are the main advantages of this solution IMO:

  • No need to duplicate all the "fluff" (wrapper classes, etc) in multiple packages / repositories.
    The same wrapper (e.g. NormalizeObservation) can be used everywhere. The only thing that would be different about using such a transformation wrapper on a gym.Env, an Env from Brax that has jax.numpy.ndarray observations, or yet another with torch.Tensors, is the transformation itself!
  • In some cases, it's a very simple change: just decorate the functions of interest with @singledispatch. Specialized handlers can then be implemented in the repos that extend gym with more environments. This is somewhat akin to the plugin system used for the environments.

Any thoughts @erikfrey , @jkterry1 ?
