feature(nyz): add PPOF ch4 reward demo support (#608)
* feature(nyz): add ppof acrobot and metadrive demo

* tmp

* feature(nyz): add ppof minigrid demo and fix caller bugs
PaParaZz1 authored Mar 10, 2023
1 parent 3b43ec0 commit 275141b
Showing 11 changed files with 239 additions and 48 deletions.
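For orientation, a minimal sketch of how the new demos added by this commit might be launched through the PPOF agent (the env keys 'acrobot', 'minigrid_fourroom', and 'metadrive' come from this diff; the exp_name/seed keyword arguments and the env counts are assumptions about the bonus API, not guarantees):

    from ding.bonus.ppof import PPOF

    # 'minigrid_fourroom' and 'metadrive' are the new ch4 reward demos; 'acrobot' is also added here.
    agent = PPOF(env='minigrid_fourroom', exp_name='minigrid_fourroom_ppof_demo', seed=0)
    agent.train(collector_env_num=4, evaluator_env_num=4, debug=False)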
52 changes: 50 additions & 2 deletions ding/bonus/config.py
@@ -3,7 +3,7 @@
import gym
from ding.envs import BaseEnv, DingEnvWrapper
from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
EvalEpisodeReturnEnv, TransposeWrapper, TimeLimitWrapper
EvalEpisodeReturnEnv, TransposeWrapper, TimeLimitWrapper, FlatObsWrapper, GymToGymnasiumWrapper
from ding.policy import PPOFPolicy


@@ -18,6 +18,9 @@ def get_instance_config(env: str) -> EasyDict:
cfg.learning_rate = 1e-3
cfg.action_space = 'continuous'
cfg.n_sample = 1024
elif env == 'acrobot':
cfg.learning_rate = 1e-4
cfg.n_sample = 400
elif env == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
@@ -88,6 +91,25 @@ def get_instance_config(env: str) -> EasyDict:
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env == 'minigrid_fourroom':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.learning_rate = 3e-4
cfg.epoch_per_collect = 10
cfg.entropy_weight = 0.001
elif env == 'metadrive':
cfg.learning_rate = 3e-4
cfg.action_space = 'continuous'
cfg.entropy_weight = 0.001
cfg.n_sample = 3000
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
else:
raise KeyError("not supported env type: {}".format(env))
return cfg
@@ -100,6 +122,8 @@ def get_instance_env(env: str) -> BaseEnv:
return DingEnvWrapper(gym.make('LunarLander-v2', continuous=True))
elif env == 'bipedalwalker':
return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True})
elif env == 'acrobot':
return DingEnvWrapper(gym.make('Acrobot-v1'))
elif env == 'rocket_landing':
from dizoo.rocket.envs import RocketEnv
cfg = EasyDict({
@@ -176,8 +200,32 @@ def get_instance_env(env: str) -> BaseEnv:
})
ding_env_atari = DingEnvWrapper(gym.make(atari_env_list[env]), cfg=cfg)
ding_env_atari.enable_save_replay(env + '_log/')
obs = ding_env_atari.reset()
return ding_env_atari
elif env == 'minigrid_fourroom':
import gymnasium
return DingEnvWrapper(
gymnasium.make('MiniGrid-FourRooms-v0'),
cfg={
'env_wrapper': [
lambda env: GymToGymnasiumWrapper(env),
lambda env: FlatObsWrapper(env),
lambda env: TimeLimitWrapper(env, max_limit=300),
lambda env: EvalEpisodeReturnEnv(env),
]
}
)
elif env == 'metadrive':
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
cfg = dict(
map='XSOS',
horizon=4000,
out_of_road_penalty=40.0,
crash_vehicle_penalty=40.0,
out_of_route_done=True,
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
else:
raise KeyError("not supported env type: {}".format(env))
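As a rough usage illustration of the two helpers changed above (both live in ding.bonus.config; the pairing below is an assumption about how callers use them, based on this diff):

    from ding.bonus.config import get_instance_config, get_instance_env

    cfg = get_instance_config('acrobot')   # PPOF hyperparameters: learning_rate=1e-4, n_sample=400
    env = get_instance_env('acrobot')      # DingEnvWrapper around gym.make('Acrobot-v1')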

15 changes: 12 additions & 3 deletions ding/bonus/ppof.py
@@ -4,6 +4,7 @@
from functools import partial
import os
import gym
import gymnasium
import torch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
@@ -22,6 +23,7 @@ class PPOF:
'lunarlander_discrete',
'lunarlander_continuous',
'bipedalwalker',
'acrobot',
# ch2: action
'rocket_landing',
'drone_fly',
@@ -31,6 +33,9 @@ class PPOF:
'mario',
'di_sheep',
'procgen_bigfish',
# ch4: reward
'minigrid_fourroom',
'metadrive',
# atari
'atari_qbert',
'atari_kangaroo',
@@ -64,9 +69,9 @@ def __init__(
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))

action_space = self.env.action_space
if isinstance(action_space, gym.spaces.Discrete):
action_shape = action_space.n
elif isinstance(action_space, gym.spaces.Tuple):
if isinstance(action_space, (gym.spaces.Discrete, gymnasium.spaces.Discrete)):
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
else:
action_shape = action_space.shape
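The int() cast in the new Discrete branch matters because gymnasium stores Discrete.n as a numpy integer; a small illustration (plain gymnasium behavior, not DI-engine code):

    import gymnasium

    space = gymnasium.spaces.Discrete(3)
    print(type(space.n))        # <class 'numpy.int64'>
    print(type(int(space.n)))   # <class 'int'>, safe to use as a network output size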
@@ -84,6 +89,7 @@ def train(
n_iter_log_show: int = 500,
n_iter_save_ckpt: int = 1000,
context: Optional[str] = None,
reward_model: Optional[str] = None,
debug: bool = False
) -> None:
if debug:
@@ -92,6 +98,9 @@
# define env and policy
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
if reward_model is not None:
# self.reward_model = create_reward_model(reward_model, self.cfg.reward_model)
pass

with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
41 changes: 28 additions & 13 deletions ding/envs/env/ding_env_wrapper.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union, Dict
from easydict import EasyDict
import gym
import gymnasium
import copy
import numpy as np
import treetensor.numpy as tnp
@@ -22,6 +23,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
Do not support subprocess env manager; Thus usually used in simple env.
- A config to create an env instance: Parameter `cfg` dict must contain `env_id`.
"""
self._env = None
self._raw_env = env
self._cfg = cfg
self._seed_api = seed_api # some env may disable `env.seed` api
@@ -36,7 +38,6 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
if 'env_id' not in self._cfg:
self._cfg.env_id = None
if env is not None:
self._init_flag = True
self._env = env
self._wrap_env(caller)
self._observation_space = self._env.observation_space
@@ -45,6 +46,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
)
self._init_flag = True
else:
assert 'env_id' in self._cfg
self._init_flag = False
@@ -73,17 +75,30 @@ def reset(self) -> None:
name_prefix='rl-video-{}'.format(id(self))
)
self._replay_path = None
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
if self._seed_api:
self._env.seed(self._seed + np_seed)
self._action_space.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
if self._seed_api:
self._env.seed(self._seed)
self._action_space.seed(self._seed)
obs = self._env.reset()
obs = to_ndarray(obs, np.float32)
if isinstance(self._env, gym.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
if self._seed_api:
self._env.seed(self._seed + np_seed)
self._action_space.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
if self._seed_api:
self._env.seed(self._seed)
self._action_space.seed(self._seed)
obs = self._env.reset()
elif isinstance(self._env, gymnasium.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._action_space.seed(self._seed + np_seed)
obs = self._env.reset(seed=self._seed + np_seed)
elif hasattr(self, '_seed'):
self._action_space.seed(self._seed)
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
else:
raise RuntimeError("not support env type: {}".format(type(self._env)))
obs = to_ndarray(obs)
return obs

# override
@@ -106,7 +121,7 @@ def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
if self._cfg.act_scale:
action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
obs, rew, done, info = self._env.step(action)
obs = to_ndarray(obs, np.float32)
obs = to_ndarray(obs)
rew = to_ndarray([rew], np.float32)
return BaseEnvTimestep(obs, rew, done, info)
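The reset() branching above exists because the two libraries seed and reset differently; a minimal illustration of the two calling conventions (generic classic-gym / gymnasium usage, not DI-engine code):

    import gym
    import gymnasium

    old_env = gym.make('CartPole-v1')
    old_env.seed(0)                       # classic gym (pre-0.26): seed via a separate method
    obs = old_env.reset()                 # reset() returns obs only

    new_env = gymnasium.make('CartPole-v1')
    obs, info = new_env.reset(seed=0)     # gymnasium: seed passed to reset(), which returns (obs, info)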

86 changes: 86 additions & 0 deletions ding/envs/env_wrappers/env_wrappers.py
@@ -3,9 +3,12 @@

from typing import Union, List, Tuple
from easydict import EasyDict
from functools import reduce
from collections import deque
import copy
import operator
import gym
import gymnasium
import numpy as np
from torch import float32

@@ -28,6 +31,7 @@
- FireResetWrapper: Take fire action at environment reset.
- GymHybridDictActionWrapper: Transform Gym-Hybrid's original ``gym.spaces.Tuple`` action space
to ``gym.spaces.Dict``.
- FlatObsWrapper: Flatten image and language observation into a vector.
'''


@@ -1088,6 +1092,88 @@ def step(self, action):
return obs, reward, done, info


class FlatObsWrapper(gym.Wrapper):
"""
Note: only suitable for envs like MiniGrid, whose observation is a dict containing an image and a mission string.
"""

def __init__(self, env, maxStrLen=96):
super().__init__(env)

self.maxStrLen = maxStrLen
self.numCharCodes = 28

imgSpace = env.observation_space.spaces["image"]
imgSize = reduce(operator.mul, imgSpace.shape, 1)

self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(imgSize + self.numCharCodes * self.maxStrLen, ),
dtype="float32",
)

self.cachedStr: str = None

def observation(self, obs):
if isinstance(obs, tuple):  # for compatibility with gymnasium, whose reset() returns (obs, info)
obs = obs[0]
image = obs["image"]
mission = obs["mission"]

# Cache the last-encoded mission string
if mission != self.cachedStr:
assert (len(mission) <= self.maxStrLen), f"mission string too long ({len(mission)} chars)"
mission = mission.lower()

strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="float32")

for idx, ch in enumerate(mission):
if ch >= "a" and ch <= "z":
chNo = ord(ch) - ord("a")
elif ch == " ":
chNo = ord("z") - ord("a") + 1
elif ch == ",":
chNo = ord("z") - ord("a") + 2
else:
raise ValueError(f"Character {ch} is not available in mission string.")
assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
strArray[idx, chNo] = 1

self.cachedStr = mission
self.cachedArray = strArray

obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

return obs

def reset(self, *args, **kwargs):
obs = self.env.reset(*args, **kwargs)
return self.observation(obs)

def step(self, *args, **kwargs):
o, r, d, i = self.env.step(*args, **kwargs)
o = self.observation(o)
return o, r, d, i
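A quick size check for the flat observation produced by FlatObsWrapper, assuming MiniGrid's default 7x7x3 partial-view image (the image shape is a property of MiniGrid, not something this diff sets):

    img_size = 7 * 7 * 3            # 147 values from the flattened image
    mission_size = 28 * 96          # numCharCodes * maxStrLen one-hot slots = 2688
    print(img_size + mission_size)  # 2835-dimensional flat observation vector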


class GymToGymnasiumWrapper(gym.Wrapper):

def __init__(self, env):
assert isinstance(env, gymnasium.Env), type(env)
super().__init__(env)
self._seed = None

def seed(self, seed):
self._seed = seed

def reset(self):
if self._seed is not None:
return self.env.reset(seed=self._seed)
else:
return self.env.reset()


def update_shape(obs_shape, act_shape, rew_shape, wrapper_names):
"""
Overview:
3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/trainer.py
@@ -29,7 +29,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
if ctx.train_data is None:
return
train_output = policy.forward(ctx.train_data)
if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0:
#if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0:
if True:
if isinstance(ctx, OnlineRLContext):
logging.info(
'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
2 changes: 1 addition & 1 deletion dizoo/gym_pybullet_drones/envs/gym_pybullet_drones_env.py
@@ -266,5 +266,5 @@ def plot_observation_curve(self) -> None:
if self._cfg["plot_observation"]:
self.observation_logger.plot()

def clone(self) -> 'GymPybulletDronesEnv':
def clone(self, caller: str) -> 'GymPybulletDronesEnv':
return GymPybulletDronesEnv(self.raw_cfg)
