[RLlib] No Preprocessors (part 2). #18468

Merged
20 changes: 19 additions & 1 deletion rllib/BUILD
@@ -1460,7 +1460,7 @@ py_test(
 py_test(
     name = "test_preprocessors",
     tags = ["team:ml", "models"],
-    size = "small",
+    size = "medium",
     srcs = ["models/tests/test_preprocessors.py"]
 )

@@ -2659,6 +2659,24 @@ py_test(
srcs = ["examples/pettingzoo_env.py"],
)

py_test(
name = "examples/preprocessing_disabled_tf",
main = "examples/preprocessing_disabled.py",
tags = ["team:ml", "examples", "examples_P"],
size = "medium",
srcs = ["examples/preprocessing_disabled.py"],
args = ["--stop-iters=2"]
)

py_test(
name = "examples/preprocessing_disabled_torch",
main = "examples/preprocessing_disabled.py",
tags = ["team:ml", "examples", "examples_P"],
size = "medium",
srcs = ["examples/preprocessing_disabled.py"],
args = ["--framework=torch", "--stop-iters=2"]
)

py_test(
name = "examples/remote_envs_with_inference_done_on_main_node_tf",
main = "examples/remote_envs_with_inference_done_on_main_node.py",
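Both new targets run the same example script under TF and Torch. Locally, the Torch variant corresponds roughly to the following invocation (a sketch inferred only from the args above; the script's full flag set is not shown in this diff):

python rllib/examples/preprocessing_disabled.py --framework=torch --stop-iters=2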
34 changes: 22 additions & 12 deletions rllib/agents/trainer.py
@@ -30,12 +30,13 @@
 from ray.rllib.utils.debug import update_global_seed_if_necessary
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
-from ray.rllib.utils.framework import try_import_tf, TensorStructType
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.multi_agent import check_multi_agent
 from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
-    PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
+    PartialTrainerConfigDict, PolicyID, ResultDict, TensorStructType, \
+    TrainerConfigDict
 from ray.tune.logger import Logger, UnifiedLogger
 from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
 from ray.tune.resources import Resources
@@ -113,11 +114,6 @@
"model": MODEL_DEFAULTS,
# Arguments to pass to the policy optimizer. These vary by optimizer.
"optimizer": {},
# Experimental flag, indicating that TFPolicy will handle more than one
# loss/optimizer. Set this to True, if you would like to return more than
# one loss term from your `loss_fn` and an equal number of optimizers
# from your `optimizer_fn`.
"_tf_policy_handles_more_than_one_loss": False,

# === Environment Settings ===
# Number of steps after which the episode is forced to terminate. Defaults
@@ -483,6 +479,20 @@
     # Default value None allows overwriting with nested dicts
     "logger_config": None,

+    # === API deprecations/simplifications/changes ===
+    # Experimental flag.
+    # If True, TFPolicy will handle more than one loss/optimizer.
+    # Set this to True, if you would like to return more than
+    # one loss term from your `loss_fn` and an equal number of optimizers
+    # from your `optimizer_fn`.
+    # In the future, the default for this will be True.
+    "_tf_policy_handles_more_than_one_loss": False,
+    # Experimental flag.
+    # If True, no (observation) preprocessor will be created and
+    # observations will arrive in model as they are returned by the env.
+    # In the future, the default for this will be True.
+    "_disable_preprocessor_api": False,
+
     # === Deprecated keys ===
     # Uses the sync samples optimizer instead of the multi-gpu one. This is
     # usually slower, but you might want to try it if you run into issues with
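As a usage sketch for the new key (not part of the diff): enabling it on a trainer config could look like the snippet below. PPOTrainer and "CartPole-v0" are illustrative choices, not taken from this PR.

# Minimal sketch, assuming `ray[rllib]` with TF installed; PPOTrainer and
# "CartPole-v0" are illustrative, not part of this diff.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        # New experimental flag added by this PR: skip RLlib's built-in
        # observation preprocessors and pass raw env observations to the model.
        "_disable_preprocessor_api": True,
    },
)
print(trainer.train()["episode_reward_mean"])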
@@ -1128,8 +1138,8 @@ def compute_actions(
             tuple: The full output of policy.compute_actions() if
                 full_fetch=True or we have an RNN-based Policy.
         """
-        # Preprocess obs and states
-        stateDefined = state is not None
+        # Preprocess obs and states.
+        state_defined = state is not None
         policy = self.get_policy(policy_id)
         filtered_obs, filtered_state = [], []
         for agent_id, ob in observations.items():
@@ -1174,7 +1184,7 @@ def compute_actions(
             unbatched_states[agent_id] = [s[idx] for s in states]

         # Return only actions or full tuple
-        if stateDefined or full_fetch:
+        if state_defined or full_fetch:
             return actions, unbatched_states, infos
         else:
             return actions
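For reference, a sketch of the two return forms handled above. The trainer/env setup mirrors the earlier sketch and is again only illustrative; "agent_0" is an arbitrary key.

# Sketch only: any dict of per-agent observations works, and the default
# policy is used unless policy_id is passed.
import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
obs = {"agent_0": gym.make("CartPole-v0").reset()}

# No RNN state passed and full_fetch=False: only the actions dict is returned.
actions = trainer.compute_actions(obs)

# With full_fetch=True (or when RNN state is passed), the full
# (actions, states, infos) tuple comes back instead.
actions, states, infos = trainer.compute_actions(obs, full_fetch=True)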
@@ -1529,8 +1539,8 @@ def _validate_config(config: PartialTrainerConfigDict,
         # Check model config.
         # If no preprocessing, propagate into model's config as well
         # (so model will know, whether inputs are preprocessed or not).
-        if config["preprocessor_pref"] is None:
-            model_config["_no_preprocessor"] = True
+        if config["_disable_preprocessor_api"] is True:
+            model_config["_disable_preprocessor_api"] = True

         # Prev_a/r settings.
         prev_a_r = model_config.get("lstm_use_prev_action_reward",
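On the model side, the forwarded key can then be read from `model_config`. A sketch with a made-up custom model class (not part of this PR):

# Illustrative only: MyRawObsModel is a hypothetical name; the key itself is
# the `_disable_preprocessor_api` flag forwarded by _validate_config() above.
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class MyRawObsModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # True -> observations reach the model exactly as the env returned them
        # (e.g. nested dicts/tuples), with no flattening or one-hotting applied.
        self.expects_raw_obs = model_config.get("_disable_preprocessor_api", False)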