diff --git a/ludwig/hyperopt/execution.py b/ludwig/hyperopt/execution.py
index 04e57990156..90964c1ab2e 100644
--- a/ludwig/hyperopt/execution.py
+++ b/ludwig/hyperopt/execution.py
@@ -44,11 +44,11 @@
 if _ray_200:
     from ray.air import Checkpoint
     from ray.tune.search import SEARCH_ALG_IMPORT
-    from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
+
+    from ludwig.hyperopt.syncer import RemoteSyncer
 else:
     from ray.ml import Checkpoint
     from ray.tune.suggest import SEARCH_ALG_IMPORT
-    from ray.tune.syncer import get_cloud_sync_client

 logger = logging.getLogger(__name__)

@@ -778,12 +778,14 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
             )

         if has_remote_protocol(output_directory):
-            run_experiment_trial = tune.durable(run_experiment_trial)
-            self.sync_config = tune.SyncConfig(sync_to_driver=False, upload_dir=output_directory)
             if _ray_200:
-                self.sync_client = get_node_to_storage_syncer(SyncConfig(upload_dir=output_directory))
+                self.sync_client = RemoteSyncer()
+                self.sync_config = tune.SyncConfig(upload_dir=output_directory, syncer=self.sync_client)
             else:
-                self.sync_client = get_cloud_sync_client(output_directory)
+                raise ValueError(
+                    "Syncing to remote filesystems with hyperopt is not supported with ray<2.0. "
+                    "Please upgrade to ray>=2.0."
+                )
             output_directory = None
         elif self.kubernetes_namespace:
             from ray.tune.integration.kubernetes import KubernetesSyncClient, NamespacedKubernetesSyncer
diff --git a/ludwig/hyperopt/syncer.py b/ludwig/hyperopt/syncer.py
new file mode 100644
index 00000000000..561b3048587
--- /dev/null
+++ b/ludwig/hyperopt/syncer.py
@@ -0,0 +1,52 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from ray.tune.syncer import _BackgroundSyncer
+
+from ludwig.utils.data_utils import use_credentials
+from ludwig.utils.fs_utils import delete, download, upload
+
+
+def _upload_with_creds(lpath: str, rpath: str, creds: Optional[Dict[str, Any]] = None) -> None:
+    # The credentials context is entered here, at execution time, because the
+    # background syncer invokes the command only after the enclosing
+    # `_sync_up_command` has already returned.
+    with use_credentials(creds):
+        upload(lpath, rpath)
+
+
+def _download_with_creds(rpath: str, lpath: str, creds: Optional[Dict[str, Any]] = None) -> None:
+    with use_credentials(creds):
+        download(rpath, lpath)
+
+
+def _delete_with_creds(url: str, recursive: bool = True, creds: Optional[Dict[str, Any]] = None) -> None:
+    with use_credentials(creds):
+        delete(url, recursive=recursive)
+
+
+class RemoteSyncer(_BackgroundSyncer):
+    def __init__(self, sync_period: float = 300.0, creds: Optional[Dict[str, Any]] = None):
+        super().__init__(sync_period=sync_period)
+        self.creds = creds
+
+    def _sync_up_command(self, local_path: str, uri: str, exclude: Optional[List] = None) -> Tuple[Callable, Dict]:
+        # Wrapping the `return` in `use_credentials` would be ineffective: the
+        # context would exit before the command runs. Credentials are therefore
+        # forwarded to the module-level helpers above via kwargs.
+        return _upload_with_creds, dict(lpath=local_path, rpath=uri, creds=self.creds)
+
+    def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
+        return _download_with_creds, dict(rpath=uri, lpath=local_path, creds=self.creds)
+
+    def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
+        return _delete_with_creds, dict(url=uri, recursive=True, creds=self.creds)
+
+    def __reduce__(self):
+        """Custom serialization is required because the thread locks used by the
+        use_credentials context manager cannot be pickled.
+
+        https://docs.ray.io/en/latest/ray-core/objects/serialization.html#customized-serialization
+        """
+        deserializer = RemoteSyncer
+        serialized_data = (self.sync_period, self.creds)
+        return deserializer, serialized_data
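For context, the executor above wires RemoteSyncer into Ray Tune through tune.SyncConfig. A minimal standalone sketch of the same wiring (assumes ray>=2.0; the output URI and trainable are illustrative placeholders):

    from ray import tune

    from ludwig.hyperopt.syncer import RemoteSyncer

    # Any fsspec-compatible remote URI works here; this bucket is hypothetical.
    output_directory = "s3://my-bucket/hyperopt-results"

    syncer = RemoteSyncer(sync_period=300.0, creds=None)
    sync_config = tune.SyncConfig(upload_dir=output_directory, syncer=syncer)

    # `my_trainable` stands in for the per-trial function Tune executes.
    tune.run(my_trainable, sync_config=sync_config)
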
diff --git a/ludwig/utils/fs_utils.py b/ludwig/utils/fs_utils.py
index afcf6423afd..116d2429f8a 100644
--- a/ludwig/utils/fs_utils.py
+++ b/ludwig/utils/fs_utils.py
@@ -202,6 +202,18 @@ def delete(url, recursive=False):
     return fs.delete(path, recursive=recursive)


+def upload(lpath, rpath):
+    """Copies a local file or directory to a remote fsspec-compatible location."""
+    fs, path = get_fs_and_path(rpath)
+    pyarrow.fs.copy_files(lpath, path, destination_filesystem=pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs)))
+
+
+def download(rpath, lpath):
+    """Copies a remote fsspec-compatible file or directory to a local path."""
+    fs, path = get_fs_and_path(rpath)
+    pyarrow.fs.copy_files(path, lpath, source_filesystem=pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs)))
+
+
 def checksum(url):
     fs, path = get_fs_and_path(url)
     return fs.checksum(path)
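These helpers bridge fsspec and pyarrow: get_fs_and_path resolves an fsspec filesystem from the URI, and pyarrow.fs.FSSpecHandler wraps it so pyarrow.fs.copy_files can address it like a native pyarrow filesystem. The same pattern in isolation (a sketch; assumes s3fs is installed, and the bucket path is hypothetical):

    import fsspec
    import pyarrow.fs

    # Resolve an fsspec filesystem for the target protocol.
    fs = fsspec.filesystem("s3")

    # Wrap it so pyarrow can treat it as a pyarrow filesystem.
    pa_fs = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs))

    # Recursively copy a local directory to the remote destination.
    pyarrow.fs.copy_files("/tmp/results", "my-bucket/results", destination_filesystem=pa_fs)
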
diff --git a/tests/integration_tests/test_hyperopt.py b/tests/integration_tests/test_hyperopt.py
index 76d717d6ce9..ebadd93ed0e 100644
--- a/tests/integration_tests/test_hyperopt.py
+++ b/tests/integration_tests/test_hyperopt.py
@@ -15,7 +15,8 @@
 import contextlib
 import json
 import os.path
-from typing import Any, Dict, Optional, Tuple, Union
+import uuid
+from typing import Any, Dict, Optional, Tuple

 import pytest
 import torch
@@ -39,23 +40,21 @@
 from ludwig.globals import HYPEROPT_STATISTICS_FILE_NAME
 from ludwig.hyperopt.results import HyperoptResults
 from ludwig.hyperopt.run import hyperopt, update_hyperopt_params_with_defaults
+from ludwig.utils import fs_utils
 from ludwig.utils.data_utils import load_json
 from ludwig.utils.defaults import merge_with_defaults
-from tests.integration_tests.utils import category_feature, generate_data, text_feature
+from tests.integration_tests.utils import category_feature, generate_data, private_param, remote_tmpdir, text_feature

-try:
-    import ray
+ray = pytest.importorskip("ray")

-    from ludwig.hyperopt.execution import get_build_hyperopt_executor
+from ludwig.hyperopt.execution import get_build_hyperopt_executor  # noqa

-    _ray113 = version.parse(ray.__version__) > version.parse("1.13")
+_ray200 = version.parse(ray.__version__) >= version.parse("2.0")

-except ImportError:
-    ray = None
-    _ray113 = None
+pytestmark = pytest.mark.distributed

-RANDOM_SEARCH_SIZE = 4
+RANDOM_SEARCH_SIZE = 2

 HYPEROPT_CONFIG = {
     "parameters": {
@@ -165,18 +164,6 @@
     return config, rel_path, num_filters_search_space, embedding_size_search_space, reduce_input_search_space


-def _get_trial_parameter_value(parameter_key: str, trial_row: str) -> Union[str, None]:
-    """Returns the parameter value from the Ray trial row, which has slightly different column names depending on
-    the version of Ray. Returns None if the parameter key is not found.
-
-    TODO(#2176): There are different key name delimiters depending on Ray version. The delimiter in future versions of
-    Ray (> 1.13) will be '/' instead of '.' Simplify this as Ray is upgraded.
-    """
-    if _ray113:
-        return trial_row[f"config/{parameter_key}"]
-    return trial_row[f"config.{parameter_key}"]
-
-
 @contextlib.contextmanager
 def ray_start(num_cpus: Optional[int] = None, num_gpus: Optional[int] = None):
     res = ray.init(
@@ -198,7 +185,6 @@
     yield


-@pytest.mark.distributed
 @pytest.mark.parametrize("search_alg", SEARCH_ALGS_FOR_TESTING)
 def test_hyperopt_search_alg(
     search_alg, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
@@ -249,7 +235,6 @@
     assert isinstance(path, str)


-@pytest.mark.distributed
 def test_hyperopt_executor_with_metric(csv_filename, tmpdir, ray_cluster):
     test_hyperopt_search_alg(
         "variant_generator",
@@ -261,7 +246,6 @@
     )


-@pytest.mark.distributed
 @pytest.mark.parametrize("scheduler", SCHEDULERS_FOR_TESTING)
 def test_hyperopt_scheduler(
     scheduler, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
@@ -316,7 +300,6 @@
     assert isinstance(raytune_results, HyperoptResults)


-@pytest.mark.distributed
 @pytest.mark.parametrize("search_space", ["random", "grid"])
 def test_hyperopt_run_hyperopt(csv_filename, search_space, tmpdir, ray_cluster):
     input_features = [
@@ -370,14 +353,19 @@
         "goal": "minimize",
         "output_feature": output_feature_name,
         "validation_metrics": "loss",
-        "executor": {TYPE: "ray", "num_samples": 1 if search_space == "grid" else RANDOM_SEARCH_SIZE},
+        "executor": {
+            TYPE: "ray",
+            "num_samples": 1 if search_space == "grid" else RANDOM_SEARCH_SIZE,
+            "max_concurrent_trials": 1,
+        },
         "search_alg": {TYPE: "variant_generator"},
     }

     # add hyperopt parameter space to the config
     config[HYPEROPT] = hyperopt_configs

-    hyperopt_results = hyperopt(config, dataset=rel_path, output_directory=tmpdir, experiment_name="test_hyperopt")
+    experiment_name = f"test_hyperopt_{uuid.uuid4().hex}"
+    hyperopt_results = hyperopt(config, dataset=rel_path, output_directory=tmpdir, experiment_name=experiment_name)
     if search_space == "random":
         assert hyperopt_results.experiment_analysis.results_df.shape[0] == RANDOM_SEARCH_SIZE
     else:
@@ -391,10 +379,22 @@
     assert isinstance(hyperopt_results, HyperoptResults)

     # check for existence of the hyperopt statistics file
-    assert os.path.isfile(os.path.join(tmpdir, "test_hyperopt", HYPEROPT_STATISTICS_FILE_NAME))
+    assert fs_utils.path_exists(os.path.join(tmpdir, experiment_name, HYPEROPT_STATISTICS_FILE_NAME))
+
+
+@pytest.mark.parametrize("fs_protocol,bucket", [private_param(("s3", "ludwig-tests"))], ids=["s3"])
+def test_hyperopt_sync_remote(fs_protocol, bucket, csv_filename, ray_cluster):
+    with remote_tmpdir(fs_protocol, bucket) as tmpdir:
+        # Remote syncing is only supported with ray>=2.0; older versions should raise.
+        with pytest.raises(ValueError) if not _ray200 else contextlib.nullcontext():
+            test_hyperopt_run_hyperopt(
+                csv_filename,
+                "random",
+                tmpdir,
+                ray_cluster,
+            )


-@pytest.mark.distributed
 def test_hyperopt_with_feature_specific_parameters(csv_filename, tmpdir, ray_cluster):
     input_features = [
         text_feature(name="utterance", reduce_output="sum"),
@@ -446,7 +445,6 @@
     assert input_feature["encoder"]["embedding_size"] in embedding_size_search_space


-@pytest.mark.distributed
 def test_hyperopt_old_config(csv_filename, tmpdir, ray_cluster):
     old_config = {
         "ludwig_version": "0.4",
@@ -500,7 +498,6 @@
     hyperopt(old_config, dataset=rel_path, output_directory=tmpdir, experiment_name="test_hyperopt")


-@pytest.mark.distributed
 def test_hyperopt_nested_parameters(csv_filename, tmpdir, ray_cluster):
     config = {
         INPUT_FEATURES: [
@@ -591,7 +588,6 @@
     assert trial_config[TRAINER]["learning_rate"] in {0.7, 0.42}


-@pytest.mark.distributed
 def test_hyperopt_grid_search_more_than_one_sample(csv_filename, tmpdir, ray_cluster):
     input_features = [
         text_feature(name="utterance", encoder={"reduce_output": "sum"}),
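The version-conditional context manager in test_hyperopt_sync_remote above is a compact pytest idiom: the same call is expected to raise on ray<2.0 and to succeed on ray>=2.0. The pattern in isolation (a sketch; `feature_supported` and `do_work` are illustrative names):

    import contextlib

    import pytest

    def assert_version_dependent_behavior(feature_supported: bool, do_work):
        # Expect ValueError only when the feature is unsupported; nullcontext()
        # is a no-op context manager otherwise.
        ctx = pytest.raises(ValueError) if not feature_supported else contextlib.nullcontext()
        with ctx:
            do_work()
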
diff --git a/tests/integration_tests/test_remote.py b/tests/integration_tests/test_remote.py
index 020d755970f..debf1f18004 100644
--- a/tests/integration_tests/test_remote.py
+++ b/tests/integration_tests/test_remote.py
@@ -1,7 +1,4 @@
-import contextlib
 import os
-import tempfile
-import uuid

 import pytest
 import yaml
@@ -11,22 +8,13 @@
 from ludwig.constants import TRAINER
 from ludwig.globals import DESCRIPTION_FILE_NAME
 from ludwig.utils import fs_utils
-from tests.integration_tests.utils import category_feature, generate_data, private_param, sequence_feature
-
-
-@contextlib.contextmanager
-def remote_tmpdir(fs_protocol, bucket):
-    if bucket is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            yield f"{fs_protocol}://{tmpdir}"
-        return
-
-    prefix = f"tmp_{uuid.uuid4().hex}"
-    tmpdir = f"{fs_protocol}://{bucket}/{prefix}"
-    try:
-        yield tmpdir
-    finally:
-        fs_utils.delete(tmpdir, recursive=True)
+from tests.integration_tests.utils import (
+    category_feature,
+    generate_data,
+    private_param,
+    remote_tmpdir,
+    sequence_feature,
+)


 @pytest.mark.parametrize(
diff --git a/tests/integration_tests/utils.py b/tests/integration_tests/utils.py
index cc9110bf283..64a9b9b4d48 100644
--- a/tests/integration_tests/utils.py
+++ b/tests/integration_tests/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
+import contextlib
 import logging
 import multiprocessing
 import os
@@ -40,6 +41,7 @@
 from ludwig.experiment import experiment_cli
 from ludwig.features.feature_utils import compute_feature_hash
 from ludwig.trainers.trainer import Trainer
+from ludwig.utils import fs_utils
 from ludwig.utils.data_utils import read_csv, replace_file_extension

 logger = logging.getLogger(__name__)
@@ -872,3 +874,21 @@
     finally:
         # Remove results/intermediate data saved to disk
         shutil.rmtree(output_dir, ignore_errors=True)
+
+
+@contextlib.contextmanager
+def remote_tmpdir(fs_protocol, bucket):
+    if bucket is None:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            yield f"{fs_protocol}://{tmpdir}"
+        return
+
+    prefix = f"tmp_{uuid.uuid4().hex}"
+    tmpdir = f"{fs_protocol}://{bucket}/{prefix}"
+    try:
+        yield tmpdir
+    finally:
+        try:
+            fs_utils.delete(tmpdir, recursive=True)
+        except FileNotFoundError as e:
+            logger.info(f"Failed to delete remote tmpdir, it does not exist: {e}")
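For reference, a minimal sketch of how the shared remote_tmpdir helper is used from a test (the protocol and bucket are illustrative; passing `bucket=None` falls back to a local temporary directory, and cleanup happens on exit):

    from tests.integration_tests.utils import remote_tmpdir

    def test_remote_output():
        # Yields a URI like "s3://my-test-bucket/tmp_<hex>" and deletes it afterwards.
        with remote_tmpdir("s3", "my-test-bucket") as tmpdir:
            run_hyperopt_to(tmpdir)  # hypothetical test body writing to tmpdir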