please add observation and action spaces #241

Closed
bionicles opened this issue Jul 2, 2020 · 9 comments

@bionicles

bionicles commented Jul 2, 2020

Observation and action spaces are a key component of the gym API.

Here's a custom space for strings (license: MIT, author: Bion Howard):

import random

import gym

# default alphabet (the same LETTERS as in the updated version further down)
LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~' "


class String(gym.Space):
    def __init__(self, length=None, letters=LETTERS, min_length=1, max_length=280):
        super().__init__()
        self.min_length = min_length
        self.max_length = max_length
        self.letters = letters
        self.length = length

    def sample(self):
        length = self.length if self.length else random.randint(self.min_length, self.max_length)
        return "".join(random.choice(self.letters) for _ in range(length))

    def contains(self, x):
        # check the type first, otherwise len(x) raises on non-sized inputs
        if not isinstance(x, str):
            return False
        # a fixed-length space only contains strings of exactly that length
        if self.length and len(x) != self.length:
            return False
        correct_length = self.min_length <= len(x) <= self.max_length
        correct_letters = all(l in self.letters for l in x)
        return correct_length and correct_letters

    def __repr__(self):
        return f"String(min_length={self.min_length},length={self.length},max_length={self.max_length},letters={self.letters})"
@MarcCote
Contributor

MarcCote commented Jul 2, 2020

Hi @bionicles, I'd be happy to integrate it into TextWorld. Can you make a PR to add it to https://github.com/microsoft/TextWorld/blob/master/textworld/gym/spaces/text_spaces.py ? Also, note the existing textworld.gym.spaces.Char.

@bionicles
Author

bionicles commented Jul 2, 2020

What space would be a reasonable default? It might be better to use an existing space if one is built for this (less code).

(just curious because spaces help with random actions and normalization)

@MarcCote
Contributor

MarcCote commented Aug 4, 2020

Sorry for the delay in getting back to you, I just got back from paternity leave.

What default are you referring to? If you are talking about TextworldGymEnv, I've set it to None to force users to think about what makes sense in their case. The main reason is that I wasn't sure how to pick good values for LETTERS (or WORDS/VOCAB) and max_length. If you have some ideas, I'm all ears.

@bionicles
Author

bionicles commented Aug 5, 2020

Congrats on being a new dad! Here's what I wound up doing so far: self.observation_space = String(). I've posted the updated String space below.

however, that's gonna pass raw strings to the agent, so the agent needs a string sensor to handle string observations

Another option, which plays better with frameworks, would be to convert the text into a NumPy array of UTF-8 bytes (uint8), then cast to float32 and normalize. This could go in a wrapper, and the observation could then just be a float32 gym.spaces.Box (with the byte array padded or truncated to a fixed length, since a Box has a fixed shape).

# within nature/sense.py
import jax.numpy as jnp
sense_str = lambda mystring: jnp.array(list(bytes(mystring, "utf-8")), dtype=jnp.float32) / 255
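A minimal NumPy sketch of the fixed-length variant, assuming a hypothetical encode_observation helper (not part of TextWorld or gym) and reusing 4096 bytes, the String default max_len, as the maximum. Padding uses zeros, so a literal NUL byte and padding look the same to the agent:

```python
import numpy as np


def encode_observation(text, max_len=4096):
    """Hypothetical helper: UTF-8 bytes of `text` scaled to [0, 1], zero-padded.

    The output always has shape (max_len,) so it fits a fixed-shape Box.
    Text longer than max_len bytes is truncated (possibly mid-character).
    """
    data = np.array(list(text.encode("utf-8")[:max_len]), dtype=np.float32)
    obs = np.zeros(max_len, dtype=np.float32)
    obs[: data.size] = data / 255
    return obs
```

Inside a gym.ObservationWrapper, observation() could return encode_observation(text) and the wrapper's observation_space would be Box(low=0.0, high=1.0, shape=(max_len,), dtype=np.float32).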

Here's a UTF-8 actuator (this could also go in a wrapper). It stops writing at the first byte listed in non_unicode_bytes, i.e. NUL or a value that never appears in valid UTF-8.

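# within nature/actuate.py
import jax.numpy as jnp

from tricks.rescale import rescale

# NUL plus every byte value that never occurs in valid UTF-8 (0xC0, 0xC1, 0xF5-0xFF)
non_unicode_bytes = jnp.array([0, 192, 193, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255])


def _actuate_string(space, values, xmin=-1.0, xmax=1.0):
    "decode floats in [xmin, xmax] into a string, stopping at the first invalid byte"
    decimal = rescale(xmin, xmax, 0, 255, values.flatten()).astype("int32")
    bad_bytes = jnp.where(jnp.isin(decimal, non_unicode_bytes))[0]
    if bad_bytes.size > 0:
        decimal = decimal[: int(bad_bytes[0])]
    # convert to plain Python ints before building the byte string
    result = bytes(decimal.tolist()).decode("utf-8", errors="ignore")
    if space.letters:
        result = "".join(c for c in result if c in space.letterset)
    return result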
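To check the actuator (for example, that a known command survives the trip), it helps to have the inverse mapping. Below is a minimal NumPy sketch with two hypothetical helpers: encode_action turns a string into action values in [xmin, xmax], and decode_action is a NumPy re-implementation of the actuator's decoding step (without the letters filter), so the round trip can be tested without JAX:

```python
import numpy as np

# NUL (used here as padding) plus byte values that never occur in valid UTF-8
TERMINATORS = np.array([0, 0xC0, 0xC1] + list(range(0xF5, 0x100)))


def encode_action(text, length, xmin=-1.0, xmax=1.0):
    """Hypothetical helper: map a string to `length` float32 values in [xmin, xmax].

    Each UTF-8 byte b is placed at the centre of its bucket, (b + 0.5) / 255 of
    the way from xmin to xmax, so truncating with astype(int) recovers b.
    The tail is padded with byte 0, which the decoder treats as end-of-string.
    """
    data = list(text.encode("utf-8"))[:length]
    data += [0] * (length - len(data))
    b = np.array(data, dtype=np.float64)
    values = xmin + (b + 0.5) * (xmax - xmin) / 255
    return np.clip(values, xmin, xmax).astype(np.float32)


def decode_action(values, xmin=-1.0, xmax=1.0):
    """Hypothetical NumPy mirror of the actuator's decoding (no letter filter)."""
    scaled = (np.asarray(values, dtype=np.float64).ravel() - xmin) * 255 / (xmax - xmin)
    decimal = np.clip(scaled, 0, 255).astype(int)  # truncates, like the actuator
    bad = np.flatnonzero(np.isin(decimal, TERMINATORS))
    if bad.size > 0:
        decimal = decimal[: bad[0]]
    return bytes(decimal.tolist()).decode("utf-8", errors="ignore")


print(decode_action(encode_action("go north", 16)))  # go north
```

Placing each byte at the centre of its bucket matters because the decoder truncates instead of rounding; a value computed exactly at b / 255 could land just below b after float32 rounding and decode as b - 1.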

And here is the rescale function it relies on:

# tricks/rescale.py
"min max scale function"
import jax.numpy as jnp
from jax import jit


@jit
def maybe_replace(z):
    return (
        jnp.nan_to_num(z, neginf=-1000.0, posinf=1000.0)
        if jnp.issubdtype(jnp.result_type(z), jnp.inexact)
        else z
    )


@jit
def rescale(xmin, xmax, ymin, ymax, inputs):
    "rescales inputs from [xmin, xmax] to [ymin, ymax]"
    xmin, xmax, ymin, ymax = [maybe_replace(z) for z in [xmin, xmax, ymin, ymax]]
    return jnp.nan_to_num(
        jnp.clip(
            ((((inputs - xmin) * (ymax - ymin)) / (xmax - xmin)) + ymin), ymin, ymax,
        )
    )
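For reference, rescale is ordinary min-max scaling followed by a clip (after maybe_replace has swapped infinite bounds for ±1000 and NaN for 0):

```latex
y = \operatorname{clip}\!\left(\frac{(x - x_{\min})\,(y_{\max} - y_{\min})}{x_{\max} - x_{\min}} + y_{\min},\; y_{\min},\; y_{\max}\right)
```

Worked example with the actuator's defaults (x_min = -1, x_max = 1) and the byte range [0, 255], for an agent output x = 0.5:

```latex
\frac{(0.5 - (-1))\,(255 - 0)}{1 - (-1)} + 0 = \frac{1.5 \cdot 255}{2} = 191.25
```

which the actuator's astype then truncates to byte 191; x = 1.0 maps to 255, one of the terminator bytes.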

Here's an updated version of the String space:

# nurture/spaces/string.py
import re

from jax import random
from gym import Space

from tricks import RNG  # just an iterator over jax.random.PRNGKey 

LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~' "

regex = re.compile(
    r"[^abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[\]^_`{\|}~' ]"
)


def sanitize_string(string, letters=None):
    "filter string so it only contains letters in 'letters' kwarg"
    if letters is None:
        return string
    else:
        return "".join(char for char in string if char in letters)


class String(Space):
    "a space of potential strings from min to max length with certain set of letters"

    def __init__(self, length=None, letters=LETTERS, min_len=0, max_len=4096):
        super().__init__()
        self.min_len = min_len
        self.max_len = max_len
        self.letters = letters
        self.len = length
        self.rng = None
        self.seed()

    def seed(self, initial=None):
        _seed = 420 if initial is None else initial
        self.rng = RNG(_seed)
        return [_seed]  # gym convention: seed() returns the list of seeds used

    def sample(self):
        letters = self.letters if self.letters else LETTERS
        if self.len:
            length = self.len
        else:
            # like the first version: draw the length uniformly from [min_len, max_len]
            length = int(random.randint(next(self.rng), (), self.min_len, self.max_len + 1))
        # draw all letter indices with one key instead of one key per character
        indices = random.randint(next(self.rng), (length,), 0, len(letters))
        return "".join(letters[i] for i in indices.tolist())

    def contains(self, x):
        if not isinstance(x, str):
            return False
        if not self.min_len <= len(x) <= self.max_len:
            return False
        if self.letters:
            if not all([l in self.letters for l in x]):
                return False
        return True

    def __repr__(self):
        letters = "DEFAULT" if self.letters == LETTERS else self.letters
        return f"String(min_len={self.min_len},len={self.len},max_len={self.max_len},letters={letters})"

    @property
    def letterset(self):
        if self.letters is None:
            return set()
        return set(self.letters)
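The module compiles regex but never uses it in this snippet. As far as I can tell, its negated character class lists exactly the characters in LETTERS (escaping ] and |), so regex.sub("", s) should give the same result as sanitize_string(s, LETTERS) while doing the filtering inside the re engine. A quick self-contained check, with LETTERS, regex and sanitize_string copied from above:

```python
import re

LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~' "

# matches any single character NOT in LETTERS
regex = re.compile(
    r"[^abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[\]^_`{\|}~' ]"
)


def sanitize_string(string, letters=None):
    "filter string so it only contains letters in 'letters' kwarg"
    if letters is None:
        return string
    return "".join(char for char in string if char in letters)


s = 'say "hi" to\\bob\n\tnow ü'
print(repr(regex.sub("", s)))  # 'say hi tobobnow '
print(regex.sub("", s) == sanitize_string(s, LETTERS))  # True
```

If that equivalence holds, sanitize_string could call regex.sub when letters is LETTERS; whether that is measurably faster for short observations is untested.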

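Here's a wrapper that provides several difficulty levels. The easy and hard levels (hard generates bigger games) show more information, namely the objective, description, inventory and admissible commands, and use a Discrete action space over those commands. The expert level only shows the objective on reset and the feedback afterwards, and keeps the String action space.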

# nurture/textworld/env.py
"wrap the microsoft textworld MUD-style game"

from random import randint
import shutil
import os

import gym
import textworld
import textworld.gym

from nurture import String, sanitize_string


def _get_easy_options():
    options = textworld.GameOptions()
    options.nb_objects = randint(1, 8)
    options.quest_length = randint(1, 8)
    options.nb_rooms = randint(1, 8)
    return options


def _get_hard_options():
    options = textworld.GameOptions()
    options.nb_rooms = randint(8, 10)
    options.nb_objects = randint(8, 10)
    options.nb_parallel_quests = randint(1, 2)
    options.quest_length = randint(8, 10)
    options.quest_breadth = randint(1, 2)
    options.quest_depth = randint(2, 8)
    return options


GET_OPTIONS = dict(
    easy=_get_easy_options, hard=_get_hard_options, expert=_get_hard_options
)
INFOS = textworld.EnvInfos(
    objective=True,
    inventory=True,
    description=True,
    admissible_commands=True,
    feedback=True,
)
MAX_STEPS = dict(easy=50, hard=200, expert=420)

fp = os.path.dirname(__file__)


def _make_textworld(difficulty):
    shutil.rmtree(os.path.join(fp, "tw_games", difficulty), ignore_errors=True)
    options = GET_OPTIONS[difficulty]()
    options.path = os.path.join(fp, "tw_games", difficulty, "game.ulx")
    game_file, _ = textworld.make(options)
    return textworld.gym.register_game(
        game_file, INFOS, max_episode_steps=MAX_STEPS[difficulty]
    )


def _parse_inventory(inv):
    return inv.replace("You are carrying", "You carry").replace(":", "") + "."


class TextWorldWrapper(gym.Env):
    "a custom env to adjust the textworld env observation space and difficulty levels"

    def __init__(self, difficulty="easy"):
        self.observation_space = String()
        self.action_space = String()
        self.difficulty = difficulty
        if difficulty in ["easy", "hard"]:
            self.stringify = self._stringify_easy
        else:
            self.stringify = _stringify_hard
        self.commands = None
        self._env = None

    def reset(self):
        env_id = _make_textworld(self.difficulty)
        self._env = gym.make(env_id)
        _, i = self._env.reset()
        return self.stringify(i, stepping=False)

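    def step(self, action):
        # Discrete samples and policy outputs may be NumPy integers rather than int,
        # so treat anything that isn't a command string as an index
        if not isinstance(action, str):
            action = self.commands[int(action)]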
        _, r, d, i = self._env.step(action)
        return self.stringify(i), r, d, {}

    def render(self):
        self._env.render()

    def _stringify_easy(self, i, stepping=True):
        inventory = _parse_inventory(i["inventory"])
        self.commands = [
            c for c in i["admissible_commands"] if c not in ["inventory", "look"]
        ]
        self.action_space = gym.spaces.Discrete(len(self.commands))
        obs = (
            f'{i["objective"]} {i["description"]} {inventory} Commands: {self.commands}'
        )
        if stepping:
            obs += f" Feedback: {i['feedback']}"
        return sanitize_string(
            obs.replace("\n", " ")
            .replace("  ", " ")
            .replace("  ", " ")
            .replace("\\", "")
            .strip()
        )


def _stringify_hard(i, stepping=True):
    obs = i["feedback"] if stepping else i["objective"]
    obs = (
        obs.strip()
        .replace("\n", " ")
        .replace("  ", " ")
        .replace("  ", " ")
        .replace("\\", "")
    )
    if stepping and "You are carrying:" in obs:
        obs = _parse_inventory(obs)
    return sanitize_string(obs)

@bionicles
Author

Just for completeness, here is the RNG class:

# tricks/rng.py
import jax


@jax.tree_util.register_pytree_node_class
class RNG:
    "PRNGKey iterator"

    def __init__(self, seed):
        self.seed = seed
        self.key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        self.key, output = jax.random.split(self.key)
        return output

    def tree_flatten(self):
        return ((self.seed, self.key), None)

    @classmethod
    def tree_unflatten(cls, _, rng_state):
        # bypass __init__ so no new PRNGKey is derived from a (possibly traced) seed
        new = object.__new__(cls)
        new.seed, new.key = rng_state
        return new

    def __eq__(self, other):
        if not isinstance(other, RNG):
            return False
        same_key = jax.numpy.all(other.key == self.key)
        same_seed = other.seed == self.seed
        if same_seed and same_key:
            return True
        return False

@MarcCote
Contributor

MarcCote commented Aug 5, 2020

Thanks for sharing your code. I like the style and it is very insightful.

I never thought of changing env.action_space at every step (i.e. a choice-based setting), but that might not play well with some existing algorithms, e.g. PolicyWithValue in OpenAI's baselines repo, where .n would change throughout the episode.

I'd be happy to integrate your String space into textworld.gym.spaces. Or maybe it could be added to the Gym codebase directly?

@ai-nikolai

ai-nikolai commented Nov 17, 2023

@MarcCote @bionicles, any updates on the above? It seems like gym emits a few warnings about this these days.

Maybe it is now related to:
#324

@MarcCote
Contributor

This was not integrated in TextWorld yet. I'd be happy to review any PR though.
Note that gym is no longer under development; it has been replaced by gymnasium, which seems to have a Text space:
https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.Text

@MarcCote
Contributor

The dependency on gym has been dropped, so those spaces are no longer needed.
See #341
