From a7b624ca6f34cb244e805b92aa4c6fcd18066684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Jan=C3=9Fen?= Date: Thu, 29 Feb 2024 17:28:28 +0100 Subject: [PATCH] add unit test --- tests/test_parse.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_parse.py b/tests/test_parse.py index eaa953d1..3dd2336d 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -69,3 +69,47 @@ def test_command_slurm(self): ), ) self.assertEqual(result_dict, parse_arguments(command_lst)) + + def test_command_slurm_user_command(self): + result_dict = { + "host": "127.0.0.1", + "zmqport": "22", + } + command_lst = [ + "srun", + "-n", + "2", + "-D", + os.path.abspath("."), + "--gpus-per-task=1", + "--oversubscribe", + "--account=test", + "--job-name=pympipool", + sys.executable, + "/", + "--host", + result_dict["host"], + "--zmqport", + result_dict["zmqport"], + ] + interface = SrunInterface( + cwd=os.path.abspath("."), + cores=2, + gpus_per_core=1, + oversubscribe=True, + command_line_argument_lst=["--account=test", "--job-name=pympipool"], + ) + self.assertEqual( + command_lst, + interface.generate_command( + command_lst=[ + sys.executable, + "/", + "--host", + result_dict["host"], + "--zmqport", + result_dict["zmqport"], + ] + ), + ) + self.assertEqual(result_dict, parse_arguments(command_lst))