Skip to content

Commit

Permalink
Fix asyncio unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
satishpasumarthi committed Sep 29, 2021
1 parent 2607202 commit 00aac29
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions test/unit/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def test_mpi_worker_run_no_wait(popen, ssh_client, path_exists, write_env_vars):
path_exists.assert_called_with("/usr/sbin/sshd")


@patch("process.asyncio.gather", new_callable=AsyncMock)
@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
def test_mpi_master_run(
training_env, async_shell, policy, ssh_client, path_exists, async_gather, event_loop
Expand Down Expand Up @@ -198,12 +198,12 @@ def test_mpi_master_run(
path_exists.assert_called_with("/usr/sbin/sshd")


@patch("process.asyncio.gather", new_callable=AsyncMock)
@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("sagemaker_training.process.python_executable", return_value="usr/bin/python3")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
def test_mpi_master_run_python(
training_env,
Expand Down
8 changes: 4 additions & 4 deletions test/unit/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def test_create_error():
process.create(["run"], errors.ExecuteUserScriptError, 1)


@patch("process.asyncio.gather", new_callable=AsyncMock1)
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.gather", new_callable=AsyncMock1)
@patch("asyncio.create_subprocess_shell")
@pytest.mark.asyncio
async def test_run_async(async_shell, async_gather):
processes_per_host = 2
Expand All @@ -154,8 +154,8 @@ async def test_run_async(async_shell, async_gather):
assert output == "test"


@patch("process.asyncio.gather", new_callable=AsyncMock1)
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.gather", new_callable=AsyncMock1)
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.logging_config.log_script_invocation")
def test_run_python(log, async_shell, async_gather, entry_point_type_script, event_loop):

Expand Down
8 changes: 4 additions & 4 deletions test/unit/test_smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)


@patch("process.asyncio.gather", new_callable=AsyncMock)
@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("sagemaker_training.process.python_executable", return_value="usr/bin/python3")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
def test_smdataparallel_run_multi_node_python(
training_env,
Expand Down Expand Up @@ -154,12 +154,12 @@ def test_smdataparallel_run_multi_node_python(
path_exists.assert_called_with("/usr/sbin/sshd")


@patch("process.asyncio.gather", new_callable=AsyncMock)
@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("sagemaker_training.process.python_executable", return_value="usr/bin/python3")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("process.asyncio.create_subprocess_shell")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
def test_smdataparallel_run_single_node_python(
training_env,
Expand Down

0 comments on commit 00aac29

Please sign in to comment.