[RLlib] Reinstate trajectory view API tests. #18809
Changes from 2 commits
@@ -1,16 +1,14 @@
 import copy
-import gym
 from gym.spaces import Box, Discrete
 import numpy as np
 import unittest

 import ray
-from ray import tune
 from ray.rllib.agents.callbacks import DefaultCallbacks
 import ray.rllib.agents.dqn as dqn
 import ray.rllib.agents.ppo as ppo
 from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
-from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
+from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.examples.policy.episode_env_aware_policy import \
     EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
@@ -28,7 +26,7 @@ def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
         assert train_batch.count == 201
         assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201
         for k, v in train_batch.items():
-            if k == "state_in_0":
+            if k in ["state_in_0", SampleBatch.SEQ_LENS]:
                 assert len(v) == len(train_batch[SampleBatch.SEQ_LENS])
             else:
                 assert len(v) == 201
@@ -65,7 +63,7 @@ def test_traj_view_normal_case(self):
             view_req_model = policy.model.view_requirements
             view_req_policy = policy.view_requirements
             assert len(view_req_model) == 1, view_req_model
-            assert len(view_req_policy) == 8, view_req_policy
+            assert len(view_req_policy) == 10, view_req_policy
             for key in [
                     SampleBatch.OBS,
                     SampleBatch.ACTIONS,
@@ -111,7 +109,8 @@ def test_traj_view_lstm_prev_actions_and_rewards(self):
             view_req_policy = policy.view_requirements
             # 7=obs, prev-a + r, 2x state-in, 2x state-out.
             assert len(view_req_model) == 7, view_req_model
-            assert len(view_req_policy) == 19, view_req_policy
+            assert len(view_req_policy) == 20,\
+                (len(view_req_policy), view_req_policy)
             for key in [
                     SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                     SampleBatch.DONES, SampleBatch.NEXT_OBS,
@@ -171,9 +170,9 @@ def test_traj_view_attention_net(self):
             )
             rw = trainer.workers.local_worker()
             sample = rw.sample()
-            assert sample.count == config["rollout_fragment_length"]
+            assert sample.count == trainer.config["rollout_fragment_length"]
             results = trainer.train()
-            assert results["train_batch_size"] == config["train_batch_size"]
+            assert results["timesteps_total"] == config["train_batch_size"]
             trainer.stop()

     def test_traj_view_next_action(self):
@@ -188,9 +187,14 @@ def test_traj_view_next_action(self):
         )
         # Add the next action to the view reqs of the policy.
         # This should be visible then in postprocessing and train batches.
+        # Switch off for action computations (can't be there as we don't know
+        # the next action already at action computation time).
         rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
             "next_actions"] = ViewRequirement(
-                SampleBatch.ACTIONS, shift=1, space=action_space)
+                SampleBatch.ACTIONS,
+                shift=1,
+                space=action_space,
+                used_for_compute_actions=False)
Review comment: Maybe we should have validation for this field somewhere? It seems easy to miss, and not straightforward for regular users.

Reply: You mean a check whether it's even possible to have this in the action-computation path, even though the shift is >0 from a "collected" field, like actions? Great idea!
         # Make sure, we have DONEs as well.
         rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
             "dones"] = ViewRequirement()
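A minimal sketch of the validation suggested in the comment above. This is a hypothetical helper, not part of RLlib; it only assumes that ViewRequirement objects expose data_col, shift, and used_for_compute_actions, as they do in the diff above.

# Hypothetical helper (not RLlib code) sketching the suggested validation:
# a view requirement that shifts a collected column into the future
# (shift > 0) cannot also be flagged as usable at action-computation time,
# because that data does not exist yet when actions are computed.
def validate_view_requirements(view_requirements):
    for view_col, vr in view_requirements.items():
        # `shift` can also be a range string such as "-50:-1"; only a plain
        # positive int points at future data.
        looks_ahead = isinstance(vr.shift, int) and vr.shift > 0
        if looks_ahead and vr.used_for_compute_actions:
            raise ValueError(
                f"ViewRequirement '{view_col}' shifts '{vr.data_col}' by "
                f"+{vr.shift}; this cannot be available when computing "
                f"actions. Set used_for_compute_actions=False.")

Called on the policy's view_requirements dict right after the test adds the "next_actions" entry, such a check would fail loudly whenever used_for_compute_actions is left at its default.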
@@ -209,7 +213,7 @@ def test_traj_view_next_action(self):
             expected_a_ = a_

     def test_traj_view_lstm_functionality(self):
-        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
+        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
         obs_space = Box(float("-inf"), float("inf"), (4, ))
         max_seq_len = 50
         rollout_fragment_length = 200
@@ -230,37 +234,30 @@ def policy_fn(agent_id, episode, **kwargs):
                     "use_lstm": True,
                     "max_seq_len": max_seq_len,
                 },
             },
         }

-        rollout_worker_w_api = RolloutWorker(
-            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
-            policy_config=config,
-            rollout_fragment_length=rollout_fragment_length,
-            policy_spec=policies,
-            policy_mapping_fn=policy_fn,
-            num_envs=1,
-        )
-        rollout_worker_wo_api = RolloutWorker(
+        rw = RolloutWorker(
             env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
             policy_config=config,
             rollout_fragment_length=rollout_fragment_length,
             policy_spec=policies,
             policy_mapping_fn=policy_fn,
+            normalize_actions=False,
             num_envs=1,
         )

         for iteration in range(20):
-            result = rollout_worker_w_api.sample()
+            result = rw.sample()
             check(result.count, rollout_fragment_length)
             pol_batch_w = result.policy_batches["pol0"]
             assert pol_batch_w.count >= rollout_fragment_length
-            analyze_rnn_batch(pol_batch_w, max_seq_len)
-
-            result = rollout_worker_wo_api.sample()
-            pol_batch_wo = result.policy_batches["pol0"]
-            check(pol_batch_w, pol_batch_wo)
+            analyze_rnn_batch(
+                pol_batch_w,
+                max_seq_len,
+                view_requirements=rw.policy_map["pol0"].view_requirements)

     def test_traj_view_attention_functionality(self):
-        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
+        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
         obs_space = Box(float("-inf"), float("inf"), (4, ))
         max_seq_len = 50
         rollout_fragment_length = 201
@@ -280,14 +277,15 @@ def policy_fn(agent_id, episode, **kwargs):
                 "model": {
                     "max_seq_len": max_seq_len,
                 },
             },
         }

         rollout_worker_w_api = RolloutWorker(
             env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
             policy_config=config,
             rollout_fragment_length=rollout_fragment_length,
             policy_spec=policies,
             policy_mapping_fn=policy_fn,
+            normalize_actions=False,
             num_envs=1,
         )
         batch = rollout_worker_w_api.sample()
@@ -296,37 +294,38 @@ def policy_fn(agent_id, episode, **kwargs):
     def test_counting_by_agent_steps(self):
         """Test whether a PPOTrainer can be built with all frameworks."""
         config = copy.deepcopy(ppo.DEFAULT_CONFIG)
-        action_space = Discrete(2)
-        obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
+        num_agents = 3

         config["num_workers"] = 2
         config["num_sgd_iter"] = 2
         config["framework"] = "torch"
         config["rollout_fragment_length"] = 21
         config["train_batch_size"] = 147
         config["multiagent"] = {
-            "policies": {
-                "p0": (None, obs_space, action_space, {}),
-                "p1": (None, obs_space, action_space, {}),
-            },
+            "policies": {"p0", "p1"},
             "policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
             "count_steps_by": "agent_steps",
         }
-        tune.register_env(
-            "ma_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))
+        # Env setup.
+        config["env"] = MultiAgentPendulum
+        config["env_config"] = {"num_agents": num_agents}

         num_iterations = 2
-        trainer = ppo.PPOTrainer(config=config, env="ma_cartpole")
+        trainer = ppo.PPOTrainer(config=config)
+        results = None
         for i in range(num_iterations):
             results = trainer.train()
-        self.assertGreater(results["agent_timesteps_total"],
-                           num_iterations * config["train_batch_size"])
-        self.assertLess(results["agent_timesteps_total"],
-                        (num_iterations + 1) * config["train_batch_size"])
+        self.assertEqual(results["agent_timesteps_total"],
+                         results["timesteps_total"] * num_agents)
+        self.assertGreaterEqual(results["agent_timesteps_total"],
+                                num_iterations * config["train_batch_size"])
+        self.assertLessEqual(results["agent_timesteps_total"],
+                             (num_iterations + 1) * config["train_batch_size"])
         trainer.stop()


-def analyze_rnn_batch(batch, max_seq_len):
+def analyze_rnn_batch(batch, max_seq_len, view_requirements):
     count = batch.count

     # Check prev_reward/action, next_obs consistency.
@@ -399,7 +398,9 @@ def analyze_rnn_batch(batch, max_seq_len):
         batch,
         max_seq_len=max_seq_len,
         shuffle=False,
-        batch_divisibility_req=1)
+        batch_divisibility_req=1,
+        view_requirements=view_requirements,
+    )

     # Check after seq-len 0-padding.
     cursor = 0
Review comment: I have a random question that I have been curious about for a while: how closely do we honor the train_batch_size param here? For example, in complete_episodes mode, or if there is sample replay, will we ever produce a training batch of a very different size? Thanks.
Reply: We try our best to honor it, but it's not always guaranteed to be exact. The reason is that we do parallel rollouts with a fixed (or full-episode-length) step limit per vectorized(!) environment. Depending on the number of vectorized sub-envs per worker and the number of workers, the final train batch may be slightly off. For PPO, for example, we auto-correct the rollout_fragment_length (since a few releases ago) based on these factors to better match the train_batch_size, but of course, if you have lots of odd numbers in these settings, you will not get the train batch size exactly right.
Follow-up: Thanks for the explanation, that's my impression as well. I actually have a feeling we may sometimes be off by a lot. I can probably do some testing when I get a chance. Thanks.
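As a rough illustration of the relationship described in the reply above, here is a small sketch (not RLlib's actual adjustment code) under the assumption that one sampling round contributes roughly num_workers * num_envs_per_worker * rollout_fragment_length timesteps per iteration:

import math

# Sketch: pick a rollout_fragment_length so that one round of parallel
# rollouts lands at, or just above, the requested train_batch_size.
def adjusted_rollout_fragment_length(train_batch_size, num_workers,
                                     num_envs_per_worker=1):
    # With no remote workers, the local worker does the sampling.
    sample_slots = max(num_workers, 1) * num_envs_per_worker
    return math.ceil(train_batch_size / sample_slots)

# E.g., train_batch_size=147 with 2 workers and 1 env each gives a fragment
# length of 74, i.e. 2 * 74 = 148 collected timesteps -- close to, but not
# exactly, the configured 147.
print(adjusted_rollout_fragment_length(147, 2))  # -> 74

In complete_episodes mode the overshoot can be much larger, since every episode runs to termination regardless of the fragment length.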