diff --git a/tests/test_parse.py b/tests/test_parse.py index eaa953d1..3dd2336d 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -69,3 +69,47 @@ def test_command_slurm(self): ), ) self.assertEqual(result_dict, parse_arguments(command_lst)) + + def test_command_slurm_user_command(self): + result_dict = { + "host": "127.0.0.1", + "zmqport": "22", + } + command_lst = [ + "srun", + "-n", + "2", + "-D", + os.path.abspath("."), + "--gpus-per-task=1", + "--oversubscribe", + "--account=test", + "--job-name=pympipool", + sys.executable, + "/", + "--host", + result_dict["host"], + "--zmqport", + result_dict["zmqport"], + ] + interface = SrunInterface( + cwd=os.path.abspath("."), + cores=2, + gpus_per_core=1, + oversubscribe=True, + command_line_argument_lst=["--account=test", "--job-name=pympipool"], + ) + self.assertEqual( + command_lst, + interface.generate_command( + command_lst=[ + sys.executable, + "/", + "--host", + result_dict["host"], + "--zmqport", + result_dict["zmqport"], + ] + ), + ) + self.assertEqual(result_dict, parse_arguments(command_lst))