[RLlib] New ConnectorV2 API #6: Changes in SingleAgentEpisode & SingleAgentEnvRunner. (#42296)
sven1977 authored Jan 12, 2024
1 parent 3a306ef commit 6da2636
Showing 15 changed files with 1,114 additions and 630 deletions.
11 changes: 11 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -332,6 +332,7 @@ def __init__(self, algo_class=None):
        self.enable_connectors = True
        self._env_to_module_connector = None
        self._module_to_env_connector = None
+        self.episode_lookback_horizon = 1
        # TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
        # and `sample_duration_unit` (replacing batch_mode), like we do it
        # in the evaluation config).
@@ -1405,6 +1406,7 @@ def rollouts(
        module_to_env_connector: Optional[
            Callable[[EnvType, "RLModule"], "ConnectorV2"]
        ] = NotProvided,
+        episode_lookback_horizon: Optional[int] = NotProvided,
        use_worker_filter_stats: Optional[bool] = NotProvided,
        update_worker_filter_stats: Optional[bool] = NotProvided,
        rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
@@ -1455,6 +1457,13 @@ def rollouts(
            module_to_env_connector: A callable taking an Env and an RLModule as input
                args and returning a module-to-env ConnectorV2 (might be a pipeline)
                object.
+            episode_lookback_horizon: The amount of data (in timesteps) to keep from the
+                preceding episode chunk when a new chunk (for the same episode) is
+                generated to continue sampling at a later time. The larger this value,
+                the more an env-to-module connector will be able to look back in time
+                and compile RLModule input data from this information. For example, if
+                your custom env-to-module connector (and your custom RLModule) requires
+                the previous 10 rewards as inputs, you must set this to at least 10.
            use_worker_filter_stats: Whether to use the workers in the WorkerSet to
                update the central filters (held by the local worker). If False, stats
                from the workers will not be used and discarded.
@@ -1550,6 +1559,8 @@ def rollouts(
            self._env_to_module_connector = env_to_module_connector
        if module_to_env_connector is not NotProvided:
            self._module_to_env_connector = module_to_env_connector
+        if episode_lookback_horizon is not NotProvided:
+            self.episode_lookback_horizon = episode_lookback_horizon
        if use_worker_filter_stats is not NotProvided:
            self.use_worker_filter_stats = use_worker_filter_stats
        if update_worker_filter_stats is not NotProvided:
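For illustration, a minimal sketch of how the new episode_lookback_horizon setting would be set on an algorithm config (the CartPole env and the value 10 are assumptions for the example, mirroring the 10-previous-rewards case from the docstring above):

from ray.rllib.algorithms.ppo import PPOConfig

# Keep 10 timesteps from the preceding episode chunk so that an
# env-to-module connector can look back at the previous 10 rewards,
# even right after an episode chunk was cut for a new sampling round.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(episode_lookback_horizon=10)
)
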
6 changes: 3 additions & 3 deletions rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -99,11 +99,11 @@ def test_ppo_compilation_and_schedule_mixins(self):

        num_iterations = 2

-        for fw in framework_iterator(config, frameworks=("torch", "tf2")):
+        for fw in framework_iterator(config, frameworks=("tf2", "torch")):
            # TODO (Kourosh) Bring back "FrozenLake-v1"
            for env in ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]:
                print("Env={}".format(env))
-                for lstm in [False, True]:
+                for lstm in [True, False]:
                    print("LSTM={}".format(lstm))
                    config.training(model=get_model_config(fw, lstm=lstm))

@@ -175,7 +175,7 @@ def test_ppo_exploration_setup(self):
                        obs, prev_action=np.array(2), prev_reward=np.array(1.0)
                    )
                )
-            check(np.mean(actions), 1.5, atol=0.2)
+            check(np.mean(actions), 1.5, atol=0.49)
            algo.stop()

    def test_ppo_free_log_std_with_rl_modules(self):
5 changes: 1 addition & 4 deletions rllib/connectors/connector_v2.py
@@ -134,10 +134,7 @@ def __call__(
        """Method for transforming input data into output data.

        Args:
-            rl_module: An optional RLModule object that the connector might need to know
-                about. Note that normally, only module-to-env connectors get this
-                information at construction time, but env-to-module and learner
-                connectors won't (b/c they get constructed before the RLModule).
+            rl_module: The RLModule object that the connector connects to or from.
            data: The input data abiding to `self.input_type` to be transformed by
                this connector. Transformations might either be done in-place or a new
                structure may be returned that matches `self.output_type`.
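To make the __call__ contract concrete, here is a hypothetical env-to-module connector piece that compiles the previous 10 rewards into the batch. The PrevNRewards class, the episodes kwarg, and the get_rewards accessor are assumptions based on the docstring above and the new SingleAgentEpisode API, not code from this commit:

from ray.rllib.connectors.connector_v2 import ConnectorV2


class PrevNRewards(ConnectorV2):
    """Hypothetical piece adding each episode's last N rewards to `data`."""

    def __init__(self, *args, n: int = 10, **kwargs):
        super().__init__(*args, **kwargs)
        self.n = n

    def __call__(self, *, rl_module, data, episodes, **kwargs):
        # Works only if `episode_lookback_horizon >= n`, so that enough
        # reward history survives when an episode chunk is cut (see the
        # config docstring earlier in this commit).
        data["prev_n_rewards"] = [
            episode.get_rewards(slice(-self.n, None)) for episode in episodes
        ]
        return data
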
14 changes: 14 additions & 0 deletions rllib/connectors/env_to_module/default_env_to_module.py
@@ -8,11 +8,16 @@
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import batch
+from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI


+_, tf, _ = try_import_tf()
+
+
@PublicAPI(stability="alpha")
class DefaultEnvToModule(ConnectorV2):
    """Default connector piece added by RLlib to the end of any env-to-module pipeline.
@@ -77,4 +82,13 @@ def __call__(
        # Note that state ins should NOT have the extra time dimension.
        data[STATE_IN] = batch(states)

+        # Convert data to proper tensor formats, depending on framework used by the
+        # RLModule.
+        # TODO (sven): Support GPU-based EnvRunners + RLModules for sampling. Right
+        # now we assume EnvRunners are always only on the CPU.
+        if rl_module.framework == "torch":
+            data = convert_to_torch_tensor(data)
+        elif rl_module.framework == "tf2":
+            data = tree.map_structure(lambda s: tf.convert_to_tensor(s), data)
+
        return data
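The added conversion follows a simple dispatch-on-framework pattern. The standalone sketch below shows the same idea outside the connector; the to_framework_tensors helper name is made up for illustration, but the utilities are the ones imported in the diff above:

import numpy as np
import tree  # dm_tree; RLlib uses it for nested batch structures

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

_, tf, _ = try_import_tf()


def to_framework_tensors(data, framework: str):
    # Convert a (possibly nested) dict of numpy arrays into framework
    # tensors; mirrors the branch added to DefaultEnvToModule above.
    if framework == "torch":
        return convert_to_torch_tensor(data)  # handles nested structures
    elif framework == "tf2":
        return tree.map_structure(tf.convert_to_tensor, data)
    return data  # e.g. framework="np": leave numpy arrays untouched


batch = {"obs": np.zeros((4, 8), np.float32)}
torch_batch = to_framework_tensors(batch, "torch")
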
9 changes: 9 additions & 0 deletions rllib/connectors/learner/default_learner_connector.py
@@ -161,6 +161,15 @@ def __call__(
            T=T,
        )

+        # TODO (sven): Convert data to proper tensor formats, depending on framework
+        # used by the RLModule. We cannot do this right now as the RLModule does NOT
+        # know its own device. Only the Learner knows the device. Also, on the
+        # EnvRunner side, we assume that it's always the CPU (even though one could
+        # imagine a GPU-based EnvRunner + RLModule for sampling).
+        # if rl_module.framework == "torch":
+        #     data = convert_to_torch_tensor(data, device=??)
+        # elif rl_module.framework == "tf2":
+        #     data =
        return data

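A hypothetical completion of the TODO above, assuming the Learner passes its device down. The device parameter and the tf device scope are assumptions; this commit deliberately leaves the conversion commented out:

import contextlib

import tree

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

_, tf, _ = try_import_tf()


def convert_for_learner(data, framework: str, device=None):
    # Unlike on the (CPU-only) EnvRunner, the Learner knows its device,
    # e.g. "cuda:0", and could pass it into the conversion.
    if framework == "torch":
        # `convert_to_torch_tensor` accepts an optional target device.
        return convert_to_torch_tensor(data, device=device)
    elif framework == "tf2":
        # TF places tensors via device scopes rather than per-call args.
        ctx = tf.device(device) if device else contextlib.nullcontext()
        with ctx:
            return tree.map_structure(tf.convert_to_tensor, data)
    return data
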