diff --git a/modin/experimental/cloud/rayscale.py b/modin/experimental/cloud/rayscale.py index ba5f4e9673a..c77676616b9 100644 --- a/modin/experimental/cloud/rayscale.py +++ b/modin/experimental/cloud/rayscale.py @@ -149,12 +149,12 @@ def _conda_requirements(self): reqs.extend(self._get_python_version()) - if not any(re.match(r"modin(\W|$)", p) for p in self.add_conda_packages): - # user didn't define modin release; - # use automatically detected modin release from local context - reqs.append(self._get_modin_version()) - if self.add_conda_packages: + if not any(re.match(r"modin(\W|$)", p) for p in self.add_conda_packages): + # user didn't define modin release; + # use automatically detected modin release from local context + reqs.append(self._get_modin_version()) + reqs.extend(self.add_conda_packages) # this is needed, for example, for dependencies that diff --git a/modin/experimental/cloud/test/test_cloud.py b/modin/experimental/cloud/test/test_cloud.py index f0f0d75fabf..a7e4c5b3c83 100644 --- a/modin/experimental/cloud/test/test_cloud.py +++ b/modin/experimental/cloud/test/test_cloud.py @@ -57,14 +57,14 @@ def make_create_or_update_cluster_mock(): @pytest.fixture def make_ray_cluster(make_bootstrap_config_mock): - def ray_cluster(): + def ray_cluster(conda_packages=None): with mock.patch( "modin.experimental.cloud.rayscale._bootstrap_config", make_bootstrap_config_mock, ): ray_cluster = RayCluster( Provider(name="aws"), - add_conda_packages=["scikit-learn>=0.23", "modin==0.8.0"], + add_conda_packages=conda_packages, ) return ray_cluster @@ -115,9 +115,9 @@ def test_create_or_update_cluster(make_ray_cluster, make_create_or_update_cluste def test_update_conda_requirements(setup_commands_source, make_ray_cluster): fake_version = namedtuple("FakeVersion", "major minor micro")(7, 12, 45) with mock.patch("sys.version_info", fake_version): - setup_commands_result = make_ray_cluster()._update_conda_requirements( - setup_commands_source - ) + setup_commands_result = make_ray_cluster( + ["scikit-learn>=0.23", "modin==0.8.0"] + )._update_conda_requirements(setup_commands_source) assert f"python>={fake_version.major}.{fake_version.minor}" in setup_commands_result assert (