Skip to content

Commit

Permalink
[Jobs] [CI] Deflake test_sdk.py by setting test env vars for head node, not just workers (ray-project#29806)
Browse files Browse the repository at this point in the history

Why are these changes needed?
The environment variables for a test were supposed to be set for all workers in a cluster. However, the head node of the cluster was started by a pytest fixture before the call to monkeypatch.setenv that sets the environment variables.
Thus, the environment variables were only being set in the worker nodes, since they were started after monkeypatch.setenv.

This PR adds a fixture to set environment variables before starting the cluster head and uses that fixture in the tests.

This change will likely fix the flakiness, but the incorrect environment variable setup should be fixed regardless.

Related issue number
May address ray-project#29006

Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
architkulkarni authored and WeichenXu123 committed Dec 19, 2022
1 parent 5fdbd9d commit cc10a0d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
49 changes: 30 additions & 19 deletions dashboard/modules/job/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,20 @@ def get_register_agents_number(webui_url):


@pytest.mark.parametrize(
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"ray_start_cluster_head_with_env_vars",
[
{
"include_dashboard": True,
"env_vars": {
"CANDIDATE_AGENT_NUMBER": "2",
RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR: "1",
},
}
],
indirect=True,
)
def test_job_head_choose_job_agent_E2E(
mock_candidate_number, ray_start_cluster_head, monkeypatch
):
monkeypatch.setenv(RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, "1")
cluster = ray_start_cluster_head
def test_job_head_choose_job_agent_E2E(ray_start_cluster_head_with_env_vars):
cluster = ray_start_cluster_head_with_env_vars
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
Expand Down Expand Up @@ -287,21 +294,25 @@ def get_all_new_supervisor_actor_info(old_supervisor_actor):


@pytest.mark.parametrize(
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"ray_start_cluster_head_with_env_vars",
[
{
"include_dashboard": True,
"env_vars": {RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR: "1"},
},
{
"include_dashboard": True,
"env_vars": {RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR: "0"},
},
],
indirect=True,
)
@pytest.mark.parametrize("allow_driver_on_worker_nodes", [True, False])
def test_jobs_run_on_head_by_default_E2E(
ray_start_cluster_head, monkeypatch, allow_driver_on_worker_nodes
):
"""This test makes sure that the job will be run on the head node by default,
unless the environment variable `RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES` is set
to `1`.
"""
if allow_driver_on_worker_nodes:
monkeypatch.setenv(RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, "1")

def test_jobs_run_on_head_by_default_E2E(ray_start_cluster_head_with_env_vars):
allow_driver_on_worker_nodes = (
os.environ.get(RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR) == "1"
)
# Cluster setup
cluster = ray_start_cluster_head
cluster = ray_start_cluster_head_with_env_vars
cluster.add_node(dashboard_agent_listen_port=52366)
cluster.add_node(dashboard_agent_listen_port=52367)
assert wait_until_server_available(cluster.webui_url) is True
Expand Down
10 changes: 10 additions & 0 deletions python/ray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ def ray_start_cluster_head_with_external_redis(request, external_redis):
yield res


@pytest.fixture
def ray_start_cluster_head_with_env_vars(request, maybe_external_redis, monkeypatch):
    """Start a one-node cluster after setting the requested environment variables.

    The fixture is meant to be parametrized indirectly with a dict. The
    optional ``"env_vars"`` entry maps variable names to values; each pair is
    applied with ``monkeypatch.setenv`` *before* ``_ray_start_cluster`` is
    called, so the variables are already present in the environment when the
    cluster is started. Every other entry of the dict is forwarded to
    ``_ray_start_cluster`` as a keyword argument.

    ``maybe_external_redis`` is requested only so that it runs first; it is not
    used directly here (presumably it configures the Redis backend -- confirm
    against its definition).

    Yields:
        The object produced by the ``_ray_start_cluster`` context manager.
    """
    # Work on a shallow copy: ``request.param`` is the very dict object listed
    # in ``@pytest.mark.parametrize``. Popping from it in place would strip
    # "env_vars" permanently, so a repeated setup of the same test item (for
    # example a rerun of a flaky test) would start the cluster without them.
    param = dict(getattr(request, "param", {}))
    env_vars = param.pop("env_vars", {})
    for k, v in env_vars.items():
        monkeypatch.setenv(k, v)
    with _ray_start_cluster(do_init=True, num_nodes=1, **param) as res:
        yield res


@pytest.fixture
def ray_start_cluster_2_nodes(request, maybe_external_redis):
param = getattr(request, "param", {})
Expand Down

0 comments on commit cc10a0d

Please sign in to comment.