[train v2+tune] Add an environment variable to disable running the `TrainController` as an actor (ray-project#49522)

For the Train v2 + Tune integration, the `TrainController` cannot run as
a separate actor, since callbacks would run in a separate process and
would not be able to call `ray.tune.report` to propagate intermediate
metrics/checkpoints to Tune. Therefore, Train needs to be able to run in
a mode where the `TrainController` just runs in the process that
`trainer.fit()` was called in. For Tune, this is the function Trainable
that acts as the Ray Train driver. This is an internal implementation
detail, which is why I introduce this as an environment variable that
Ray Tune will set automatically.
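As a rough, hypothetical sketch of the Tune-side usage described above (the real integration lives inside Ray Tune and sets the variable automatically; the names `train_driver_fn` and `train_fn_per_worker` are illustrative only), the function Trainable acting as the Ray Train driver would disable actor mode before calling `trainer.fit()`, so the controller and its callbacks stay in that process:

import os

import ray.train
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer


def train_driver_fn(config):
    # Hypothetical Tune function Trainable acting as the Ray Train driver.
    # Keeping the TrainController in this process lets its callbacks report
    # intermediate metrics/checkpoints back to Tune.
    os.environ["RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR"] = "0"

    def train_fn_per_worker(worker_config):
        ray.train.report({"loss": 0.0})

    trainer = DataParallelTrainer(
        train_fn_per_worker,
        scaling_config=ray.train.ScalingConfig(num_workers=2),
    )
    trainer.fit()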

---------

Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Puyuan Yao <[email protected]>
justinvyu authored and anyadontfly committed Feb 13, 2025
1 parent 1064061 commit dc9d358
Showing 3 changed files with 32 additions and 12 deletions.
4 changes: 4 additions & 0 deletions python/ray/train/v2/_internal/constants.py
@@ -48,6 +48,10 @@
 ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH"
 DEFAULT_ENABLE_PRINT_PATCH = "1"
 
+# Whether or not to run the controller as an actor.
+RUN_CONTROLLER_AS_ACTOR_ENV_VAR = "RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR"
+DEFAULT_RUN_CONTROLLER_AS_ACTOR = "1"
+
 # V2 feature flag.
 V2_ENABLED_ENV_VAR = "RAY_TRAIN_V2_ENABLED"
 
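For illustration only, a standalone restatement of the semantics of the two constants above (this helper is hypothetical and is not how Ray itself reads the flag; the trainer uses its own `env_bool` utility, as shown in the next file):

import os

RUN_CONTROLLER_AS_ACTOR_ENV_VAR = "RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR"
DEFAULT_RUN_CONTROLLER_AS_ACTOR = "1"


def _run_controller_as_actor() -> bool:
    # Default to actor mode ("1") and only fall back to running the
    # controller in-process when the variable is explicitly disabled.
    value = os.environ.get(
        RUN_CONTROLLER_AS_ACTOR_ENV_VAR, DEFAULT_RUN_CONTROLLER_AS_ACTOR
    )
    return value.strip().lower() not in ("0", "false")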
37 changes: 25 additions & 12 deletions python/ray/train/v2/api/data_parallel_trainer.py
@@ -19,7 +19,9 @@
 )
 from ray.train.v2._internal.constants import (
     _UNSUPPORTED,
+    DEFAULT_RUN_CONTROLLER_AS_ACTOR,
     METRICS_ENABLED_ENV_VAR,
+    RUN_CONTROLLER_AS_ACTOR_ENV_VAR,
     get_env_vars_to_propagate,
 )
 from ray.train.v2._internal.execution.context import TrainRunContext
@@ -195,26 +197,15 @@ def fit(self) -> Result:
 
         # TODO: Add support for user-defined callbacks
 
-        # By default, attach the controller to the node running the driver script.
-        controller_actor_cls = ray.remote(
-            num_cpus=0,
-            scheduling_strategy=NodeAffinitySchedulingStrategy(
-                node_id=ray.get_runtime_context().get_node_id(), soft=False
-            ),
-            runtime_env={"env_vars": get_env_vars_to_propagate()},
-        )(TrainController)
-
-        controller = controller_actor_cls.remote(
+        result = self._initialize_and_run_controller(
             train_fn=train_fn,
             scaling_policy=create_scaling_policy(self.scaling_config),
             failure_policy=DefaultFailurePolicy(self.run_config.failure_config),
             train_run_context=self.train_run_context,
             callbacks=callbacks,
             resume_from_checkpoint=self.resume_from_checkpoint,
         )
-        ray.get(controller.run.remote())
 
-        result = ray.get(controller.get_result.remote())
         if result.error:
             # NOTE: If the training run errored out, raise an error back to the
             # user's driver script.
@@ -225,6 +216,28 @@
 
         return result
 
+    def _initialize_and_run_controller(self, **controller_init_kwargs) -> Result:
+        run_controller_as_actor = env_bool(
+            RUN_CONTROLLER_AS_ACTOR_ENV_VAR, DEFAULT_RUN_CONTROLLER_AS_ACTOR
+        )
+        if run_controller_as_actor:
+            # Attach the controller to the node running the driver script.
+            controller_actor_cls = ray.remote(
+                num_cpus=0,
+                scheduling_strategy=NodeAffinitySchedulingStrategy(
+                    node_id=ray.get_runtime_context().get_node_id(), soft=False
+                ),
+                runtime_env={"env_vars": get_env_vars_to_propagate()},
+            )(TrainController)
+
+            controller = controller_actor_cls.remote(**controller_init_kwargs)
+            ray.get(controller.run.remote())
+            return ray.get(controller.get_result.remote())
+        else:
+            controller = TrainController(**controller_init_kwargs)
+            controller.run()
+            return controller.get_result()
+
     @classmethod
     def restore(cls, *args, **kwargs):
         raise DeprecationWarning(TRAINER_RESTORE_DEPRECATION_WARNING)
3 changes: 3 additions & 0 deletions python/ray/train/v2/tests/test_data_parallel_trainer.py
@@ -10,10 +10,13 @@
 from ray.train.backend import Backend
 from ray.train.constants import RAY_CHDIR_TO_TRIAL_DIR, _get_ray_train_session_dir
 from ray.train.tests.util import create_dict_checkpoint
+from ray.train.v2._internal.constants import is_v2_enabled
 from ray.train.v2._internal.exceptions import TrainingFailedError
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.train.v2.api.result import Result
 
+assert is_v2_enabled()
+
 
 @pytest.fixture(scope="module", autouse=True)
 def ray_start_4_cpus():
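A minimal test sketch (not part of this commit) of how the new flag could be exercised, assuming the imports and the module-scoped `ray_start_4_cpus` fixture from this test file plus pytest's built-in `monkeypatch` and `tmp_path` fixtures:

import ray.train


def test_fit_with_in_process_controller(ray_start_4_cpus, monkeypatch, tmp_path):
    # Force the TrainController to run in the driver process instead of
    # as a separate actor.
    monkeypatch.setenv("RAY_TRAIN_RUN_CONTROLLER_AS_ACTOR", "0")

    def train_fn(config):
        ray.train.report({"metric": 1})

    trainer = DataParallelTrainer(
        train_fn,
        scaling_config=ray.train.ScalingConfig(num_workers=2),
        run_config=ray.train.RunConfig(storage_path=str(tmp_path)),
    )
    result = trainer.fit()
    assert result.error is None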
