
[RLlib] Reinstate trajectory view API tests. #18809

Merged: 3 commits merged on Sep 23, 2021.
Changes from 2 commits
13 changes: 6 additions & 7 deletions rllib/BUILD
@@ -1399,13 +1399,12 @@ py_test(
srcs = ["evaluation/tests/test_rollout_worker.py"]
)

# Mysteriously times out on Travis.
#py_test(
# name = "evaluation/tests/test_trajectory_view_api",
# tags = ["team:ml", "evaluation"],
# size = "medium",
# srcs = ["evaluation/tests/test_trajectory_view_api.py"]
#)
py_test(
name = "evaluation/tests/test_trajectory_view_api",
tags = ["team:ml", "evaluation"],
size = "medium",
srcs = ["evaluation/tests/test_trajectory_view_api.py"]
)


# --------------------------------------------------------------------
87 changes: 44 additions & 43 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -1,16 +1,14 @@
import copy
import gym
from gym.spaces import Box, Discrete
import numpy as np
import unittest

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
@@ -28,7 +26,7 @@ def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
assert train_batch.count == 201
assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201
for k, v in train_batch.items():
if k == "state_in_0":
if k in ["state_in_0", SampleBatch.SEQ_LENS]:
assert len(v) == len(train_batch[SampleBatch.SEQ_LENS])
else:
assert len(v) == 201
@@ -65,7 +63,7 @@ def test_traj_view_normal_case(self):
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 8, view_req_policy
assert len(view_req_policy) == 10, view_req_policy
for key in [
SampleBatch.OBS,
SampleBatch.ACTIONS,
@@ -111,7 +109,8 @@ def test_traj_view_lstm_prev_actions_and_rewards(self):
view_req_policy = policy.view_requirements
# 7=obs, prev-a + r, 2x state-in, 2x state-out.
assert len(view_req_model) == 7, view_req_model
assert len(view_req_policy) == 19, view_req_policy
assert len(view_req_policy) == 20,\
(len(view_req_policy), view_req_policy)
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
@@ -171,9 +170,9 @@ def test_traj_view_attention_net(self):
)
rw = trainer.workers.local_worker()
sample = rw.sample()
assert sample.count == config["rollout_fragment_length"]
assert sample.count == trainer.config["rollout_fragment_length"]
results = trainer.train()
assert results["train_batch_size"] == config["train_batch_size"]
assert results["timesteps_total"] == config["train_batch_size"]

Member:
I have a random question I've been curious about for a while: how closely do we honor the train_batch_size param here?
For example, in complete_episodes mode, or when there is sample replay, will we ever produce a training batch of a very different size?
Thanks.

Contributor Author:
We do our best to honor it, but it's not guaranteed to always be exact.
The reason is that we do parallel rollouts with a fixed (or full-episode-length) step limit per vectorized(!) environment. Depending on the number of vectorized sub-envs per worker and the number of workers, the final train batch may be slightly off. For PPO, for example, we auto-correct the rollout_fragment_length (since a few releases ago) based on these factors to better match the train_batch_size, but if you have lots of odd numbers in these settings, the train batch will not come out exactly right.
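
For illustration, a rough sketch of the auto-correction described above (hypothetical helper, not RLlib's actual implementation; the exact rounding is an assumption):

# Pick a rollout_fragment_length so that one full round of parallel rollouts
# (num_workers * num_envs_per_worker * rollout_fragment_length timesteps)
# lands as close as possible to train_batch_size.
def adjusted_rollout_fragment_length(train_batch_size: int,
                                     num_workers: int,
                                     num_envs_per_worker: int) -> int:
    parallel_envs = max(num_workers, 1) * num_envs_per_worker
    # Round up so a sampling round never under-fills the train batch; the
    # resulting batch may therefore be slightly larger than requested.
    return -(-train_batch_size // parallel_envs)  # ceil division

# Example: train_batch_size=4000 with 3 workers and 3 envs per worker gives
# ceil(4000 / 9) = 445, so one round collects 9 * 445 = 4005 timesteps.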

Member:
Thanks for the explanation; that's my impression as well.
I actually have a feeling we may sometimes be off by a lot.
I can probably do some testing when I get a chance.
Thanks.

trainer.stop()

def test_traj_view_next_action(self):
@@ -188,9 +187,14 @@ def test_traj_view_next_action(self):
)
# Add the next action to the view reqs of the policy.
# This should be visible then in postprocessing and train batches.
# Switch off for action computations (can't be there as we don't know
# the next action already at action computation time).
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"next_actions"] = ViewRequirement(
SampleBatch.ACTIONS, shift=1, space=action_space)
SampleBatch.ACTIONS,
shift=1,
space=action_space,
used_for_compute_actions=False)

Member:
Maybe we should have validation for this field somewhere? It seems easy to miss, and not straightforward for regular users.

Contributor Author:
You mean to check whether it's even possible to have this in the action-computation event, even though the shift is >0 from a "collected" field, like actions? Great idea!
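
For illustration, a minimal sketch of the kind of check being discussed (hypothetical helper; only ViewRequirement's shift and used_for_compute_actions fields are taken from RLlib, everything else is assumed):

# Hypothetical validation sketch, not RLlib code: a view requirement that
# looks into the future (shift > 0) of a column that is only produced during
# the rollout cannot also be used at action-computation time.
def validate_view_requirements(view_requirements: dict) -> None:
    for col, vr in view_requirements.items():
        # Only plain integer shifts are considered; range shifts such as
        # "-50:-1" are ignored in this sketch.
        if isinstance(vr.shift, int) and vr.shift > 0 and \
                vr.used_for_compute_actions:
            raise ValueError(
                f"View requirement '{col}' points {vr.shift} step(s) into "
                "the future but has used_for_compute_actions=True; that "
                "data cannot exist yet when actions are computed.")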

# Make sure, we have DONEs as well.
rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
"dones"] = ViewRequirement()
@@ -209,7 +213,7 @@ def test_traj_view_next_action(self):
expected_a_ = a_

def test_traj_view_lstm_functionality(self):
action_space = Box(-float("inf"), float("inf"), shape=(3, ))
action_space = Box(float("-inf"), float("inf"), shape=(3, ))
obs_space = Box(float("-inf"), float("inf"), (4, ))
max_seq_len = 50
rollout_fragment_length = 200
@@ -230,37 +234,30 @@ def policy_fn(agent_id, episode, **kwargs):
"use_lstm": True,
"max_seq_len": max_seq_len,
},
},
}

rollout_worker_w_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=config,
rollout_fragment_length=rollout_fragment_length,
policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
rollout_worker_wo_api = RolloutWorker(
rw = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=config,
rollout_fragment_length=rollout_fragment_length,
policy_spec=policies,
policy_mapping_fn=policy_fn,
normalize_actions=False,
num_envs=1,
)

for iteration in range(20):
result = rollout_worker_w_api.sample()
result = rw.sample()
check(result.count, rollout_fragment_length)
pol_batch_w = result.policy_batches["pol0"]
assert pol_batch_w.count >= rollout_fragment_length
analyze_rnn_batch(pol_batch_w, max_seq_len)

result = rollout_worker_wo_api.sample()
pol_batch_wo = result.policy_batches["pol0"]
check(pol_batch_w, pol_batch_wo)
analyze_rnn_batch(
pol_batch_w,
max_seq_len,
view_requirements=rw.policy_map["pol0"].view_requirements)

def test_traj_view_attention_functionality(self):
action_space = Box(-float("inf"), float("inf"), shape=(3, ))
action_space = Box(float("-inf"), float("inf"), shape=(3, ))
obs_space = Box(float("-inf"), float("inf"), (4, ))
max_seq_len = 50
rollout_fragment_length = 201
@@ -280,14 +277,15 @@ def policy_fn(agent_id, episode, **kwargs):
"model": {
"max_seq_len": max_seq_len,
},
},
}

rollout_worker_w_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=config,
rollout_fragment_length=rollout_fragment_length,
policy_spec=policies,
policy_mapping_fn=policy_fn,
normalize_actions=False,
num_envs=1,
)
batch = rollout_worker_w_api.sample()
@@ -296,37 +294,38 @@ def policy_fn(agent_id, episode, **kwargs):
def test_counting_by_agent_steps(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
action_space = Discrete(2)
obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)

num_agents = 3

config["num_workers"] = 2
config["num_sgd_iter"] = 2
config["framework"] = "torch"
config["rollout_fragment_length"] = 21
config["train_batch_size"] = 147
config["multiagent"] = {
"policies": {
"p0": (None, obs_space, action_space, {}),
"p1": (None, obs_space, action_space, {}),
},
"policies": {"p0", "p1"},
"policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
"count_steps_by": "agent_steps",
}
tune.register_env(
"ma_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))
# Env setup.
config["env"] = MultiAgentPendulum
config["env_config"] = {"num_agents": num_agents}

num_iterations = 2
trainer = ppo.PPOTrainer(config=config, env="ma_cartpole")
trainer = ppo.PPOTrainer(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
self.assertGreater(results["agent_timesteps_total"],
num_iterations * config["train_batch_size"])
self.assertLess(results["agent_timesteps_total"],
(num_iterations + 1) * config["train_batch_size"])
self.assertEqual(results["agent_timesteps_total"],
results["timesteps_total"] * num_agents)
self.assertGreaterEqual(results["agent_timesteps_total"],
num_iterations * config["train_batch_size"])
self.assertLessEqual(results["agent_timesteps_total"],
(num_iterations + 1) * config["train_batch_size"])
trainer.stop()


def analyze_rnn_batch(batch, max_seq_len):
def analyze_rnn_batch(batch, max_seq_len, view_requirements):
count = batch.count

# Check prev_reward/action, next_obs consistency.
@@ -399,7 +398,9 @@ def analyze_rnn_batch(batch, max_seq_len):
batch,
max_seq_len=max_seq_len,
shuffle=False,
batch_divisibility_req=1)
batch_divisibility_req=1,
view_requirements=view_requirements,
)

# Check after seq-len 0-padding.
cursor = 0
4 changes: 3 additions & 1 deletion rllib/examples/policy/episode_env_aware_policy.py
@@ -110,7 +110,9 @@ class _fake_model:
# Repeat the incoming state every n time steps (usually max seq
# len).
batch_repeat_value=self.config["model"]["max_seq_len"],
space=self.state_space)
space=self.state_space),
"state_out_0": ViewRequirement(
space=self.state_space, used_for_compute_actions=False),
}

self.view_requirements = dict(super()._get_default_view_requirements(),
4 changes: 2 additions & 2 deletions rllib/policy/eager_tf_policy.py
@@ -421,8 +421,6 @@ def compute_actions(self,
**kwargs):

self._is_training = False
self._is_recurrent = \
state_batches is not None and state_batches != []

if not tf1.executing_eagerly():
tf1.enable_eager_execution()
@@ -475,6 +473,8 @@ def _compute_action_helper(self, input_dict, state_batches, episodes,
self.global_timestep
if isinstance(timestep, tf.Tensor):
timestep = int(timestep.numpy())
self._is_recurrent = state_batches is not None and \
state_batches != []
self._is_training = False
self._state_in = state_batches or []
# Calculate RNN sequence lengths.
3 changes: 2 additions & 1 deletion rllib/policy/policy.py
@@ -803,7 +803,8 @@ def _initialize_loss_from_dummy_batch(
self._dummy_batch.accessed_keys | \
self._dummy_batch.added_keys
for key in all_accessed_keys:
if key not in self.view_requirements:
if key not in self.view_requirements and \
key != SampleBatch.SEQ_LENS:
self.view_requirements[key] = ViewRequirement()
if self._loss:
# Tag those only needed for post-processing (with some
16 changes: 13 additions & 3 deletions rllib/policy/sample_batch.py
@@ -688,9 +688,6 @@ def __getitem__(self, key: Union[str, slice]) -> TensorType:
if isinstance(key, slice):
return self._slice(key)

if not hasattr(self, key) and key in self:
self.accessed_keys.add(key)

# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
@@ -700,6 +697,9 @@ def __getitem__(self, key: Union[str, slice]) -> TensorType:
error=False)
return self.is_training

if not hasattr(self, key) and key in self:
self.accessed_keys.add(key)

value = dict.__getitem__(self, key)
if self.get_interceptor is not None:
if key not in self.intercepted_values:
@@ -721,6 +721,16 @@ def __setitem__(self, key, item) -> None:
dict.__setitem__(self, key, item)
return

# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
deprecation_warning(
old="SampleBatch['is_training']",
new="SampleBatch.is_training",
error=False)
self.is_training = item
return

if key not in self:
self.added_keys.add(key)

18 changes: 12 additions & 6 deletions rllib/utils/sgd.py
@@ -44,15 +44,21 @@ def standardized(array):
return (array - array.mean()) / max(1e-4, array.std())


def minibatches(samples, sgd_minibatch_size, shuffle=True):
def minibatches(samples: SampleBatch,
sgd_minibatch_size: int,
shuffle: bool = True):
"""Return a generator yielding minibatches from a sample batch.

Args:
samples (SampleBatch): batch of samples to split up.
sgd_minibatch_size (int): size of minibatches to return.

Returns:
generator that returns mini-SampleBatches of size sgd_minibatch_size.
samples: SampleBatch to split up.
sgd_minibatch_size: Size of minibatches to return.
shuffle: Whether to shuffle the order of the generated minibatches.
Note that in case of a non-recurrent policy, the incoming batch
is globally shuffled first regardless of this setting, before
the minibatches are generated from it!

Yields:
SampleBatch: Each of size `sgd_minibatch_size`.
"""
if not sgd_minibatch_size:
yield samples
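
For reference, a short usage sketch of the minibatches generator documented above (assumes a flat, non-recurrent SampleBatch; exact shuffling behavior may vary across Ray versions):

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.sgd import minibatches

# An 8-timestep batch split into two minibatches of 4 timesteps each.
batch = SampleBatch({
    "obs": np.arange(8, dtype=np.float32),
    "actions": np.zeros(8, dtype=np.float32),
})
for mb in minibatches(batch, sgd_minibatch_size=4, shuffle=False):
    # Prints 4 twice; since the batch has no RNN state columns, its rows are
    # globally shuffled before slicing, regardless of the `shuffle` flag.
    print(mb.count)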