diff --git a/.gitignore b/.gitignore index 8108ae2a..1753d5f3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea +.vscode .cache/ build/ dist/ diff --git a/setup.py b/setup.py index c593c565..994e2c8f 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -77,6 +77,7 @@ def read_version(): "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], install_requires=required_packages, extras_require={ diff --git a/src/sagemaker_training/entry_point.py b/src/sagemaker_training/entry_point.py index 96d68285..f74349e0 100644 --- a/src/sagemaker_training/entry_point.py +++ b/src/sagemaker_training/entry_point.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of diff --git a/src/sagemaker_training/environment.py b/src/sagemaker_training/environment.py index 12daab01..a91ea9fb 100644 --- a/src/sagemaker_training/environment.py +++ b/src/sagemaker_training/environment.py @@ -36,6 +36,10 @@ SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str BASE_PATH_ENV = "SAGEMAKER_BASE_DIR" # type: str +HYPERPARAMETERS_FILE = "hyperparameters.json" # type: str +RESOURCE_CONFIG_FILE = "resourceconfig.json" # type: str +INPUT_DATA_CONFIG_FILE = "inputdataconfig.json" # type: str + def _write_json(obj, path): # type: (object, str) -> None """Write a serializeable object as a JSON file.""" @@ -65,10 +69,11 @@ def _is_training_path_configured(): # type: () -> bool def _set_base_path_env(): # type: () -> None """Set the environment variable SAGEMAKER_BASE_DIR as - ~/sagemaker_local/{timestamp}/opt/ml + ~/sagemaker_local/jobs/{timestamp}/opt/ml """ + timestamp = str(time.time()) local_config_dir = os.path.join( - os.path.expanduser("~"), "sagemaker_local", "jobs", str(time.time()), "opt", "ml" + os.path.expanduser("~"), "sagemaker_local", "jobs", timestamp, "opt", "ml" ) logger.info("Setting environment variable SAGEMAKER_BASE_DIR as %s ." % local_config_dir) @@ -139,10 +144,6 @@ def _set_base_path_env(): # type: () -> None str: the path to the intermediate output directory, e.g. /opt/ml/output/intermediate. """ -HYPERPARAMETERS_FILE = "hyperparameters.json" # type: str -RESOURCE_CONFIG_FILE = "resourceconfig.json" # type: str -INPUT_DATA_CONFIG_FILE = "inputdataconfig.json" # type: str - hyperparameters_file_dir = os.path.join(input_config_dir, HYPERPARAMETERS_FILE) # type: str input_data_config_file_dir = os.path.join(input_config_dir, INPUT_DATA_CONFIG_FILE) # type: str resource_config_file_dir = os.path.join(input_config_dir, RESOURCE_CONFIG_FILE) # type: str @@ -196,7 +197,7 @@ def read_hyperparameters(): # type: () -> dict """Read the hyperparameters from /opt/ml/input/config/hyperparameters.json. For more information about hyperparameters.json: - https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-hyperparameters + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-hyperparameters Returns: (dict[string, object]): A dictionary containing the hyperparameters. @@ -225,7 +226,7 @@ def read_resource_config(): # type: () -> dict """Read the resource configuration from /opt/ml/input/config/resourceconfig.json. For more information about resourceconfig.json: -https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-dist-training + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-dist-training Returns: resource_config (dict[string, object]): the contents from /opt/ml/input/config/resourceconfig.json. @@ -264,7 +265,7 @@ def read_input_data_config(): # type: () -> dict }} For more information about inpudataconfig.json: -https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-dist-training + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-inputdataconfig Returns: input_data_config (dict[string, object]): Contents from /opt/ml/input/config/inputdataconfig.json. @@ -305,6 +306,7 @@ def num_cpus(): # type: () -> int Returns: int: Number of CPUs available in the current container. """ + # TODO: https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python return multiprocessing.cpu_count() @@ -326,7 +328,7 @@ class Environment(mapping.MappingMixin): # pylint:disable=too-many-public-metho get the path of the channel 'training' from the inputdataconfig.json file >>>training_dir = environment.channel_input_dirs['training'] - get a the hyperparameter 'training_data_file' from hyperparameters.json file + get the hyperparameter 'training_data_file' from hyperparameters.json file >>>file_name = environment.hyperparameters['training_data_file'] get the folder where the model should be saved @@ -407,7 +409,7 @@ class Environment(mapping.MappingMixin): # pylint:disable=too-many-public-metho }} You can find more information about /opt/ml/input/config/inputdataconfig.json here: - https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-inputdataconfig + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-inputdataconfig output_data_dir (str): The dir to write non-model training artifacts (e.g. evaluation results) which will be retained by SageMaker, @@ -476,7 +478,7 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters }} You can find more information about /opt/ml/input/config/inputdataconfig.json here: - https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-inputdataconfig + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html#your-algorithms-training-algo-running-container-inputdataconfig hyperparameters (dict[string, object]): An instance of `HyperParameters` containing the training job hyperparameters. diff --git a/src/sagemaker_training/errors.py b/src/sagemaker_training/errors.py index 127571ee..51da6a15 100644 --- a/src/sagemaker_training/errors.py +++ b/src/sagemaker_training/errors.py @@ -30,21 +30,42 @@ class _CalledProcessError(ClientError): cmd, return_code, output """ - def __init__(self, cmd, return_code=None, output=None): - self.return_code = return_code + def __init__(self, cmd, return_code=None, output=None, info=None): + self.return_code = str(return_code) self.cmd = cmd self.output = output + self.extra_info = info super(_CalledProcessError, self).__init__() def __str__(self): if six.PY3 and self.output: - error_msg = "\n%s" % self.output.decode("latin1") + # error_msg = "%s" % self.output.decode("latin1") + if isinstance(self.output, bytes): + error_msg = "%s" % self.output.decode("utf-8") + else: + error_msg = "%s" % self.output elif self.output: - error_msg = "\n%s" % self.output + error_msg = "%s" % self.output else: error_msg = "" - - message = '%s:\nCommand "%s"%s' % (type(self).__name__, self.cmd, error_msg) + if self.extra_info is None: + message = '%s:\nExitCode %s\nErrorMessage "%s"\nCommand "%s"' % ( + type(self).__name__, + self.return_code, + error_msg, + self.cmd, + ) + else: + message = ( + '%s:\nExitCode %s\nErrorMessage "%s"\nExtraInfo "%s"\nCommand "%s"' + % ( + type(self).__name__, + self.return_code, + error_msg, + self.extra_info, + self.cmd, + ) + ) return message.strip() diff --git a/src/sagemaker_training/modules.py b/src/sagemaker_training/modules.py index e484b8fd..9cf52050 100644 --- a/src/sagemaker_training/modules.py +++ b/src/sagemaker_training/modules.py @@ -80,7 +80,9 @@ def prepare(path, name): # type: (str, str) -> None % name ) - logger.info("Module %s does not provide a setup.py. \nGenerating setup.py" % name) + logger.info( + "Module %s does not provide a setup.py. \nGenerating setup.py" % name + ) files.write_file(setup_path, data) @@ -125,7 +127,11 @@ def install(path, capture_error=False): # type: (str, bool) -> None logger.info("Installing module with the following command:\n%s", cmd) process.check_error( - shlex.split(cmd), errors.InstallModuleError, cwd=path, capture_error=capture_error + shlex.split(cmd), + errors.InstallModuleError, + 1, + cwd=path, + capture_error=capture_error, ) @@ -142,7 +148,11 @@ def install_requirements(path, capture_error=False): # type: (str, bool) -> Non logger.info("Installing dependencies from requirements.txt:\n{}".format(cmd)) process.check_error( - shlex.split(cmd), errors.InstallRequirementsError, cwd=path, capture_error=capture_error + shlex.split(cmd), + errors.InstallRequirementsError, + 1, + cwd=path, + capture_error=capture_error, ) @@ -171,4 +181,6 @@ def import_module(uri, name=DEFAULT_MODULE_NAME): # type: (str, str) -> module return module except Exception as e: # pylint: disable=broad-except - six.reraise(errors.ImportModuleError, errors.ImportModuleError(e), sys.exc_info()[2]) + six.reraise( + errors.ImportModuleError, errors.ImportModuleError(e), sys.exc_info()[2] + ) diff --git a/src/sagemaker_training/mpi.py b/src/sagemaker_training/mpi.py index e5b844de..ddbe6c3d 100644 --- a/src/sagemaker_training/mpi.py +++ b/src/sagemaker_training/mpi.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -31,10 +31,12 @@ class WorkerRunner(process.ProcessRunner): """Runner responsible for preparing MPI distributed training and waiting for MPI - master execution to finish. + master execution to finish. """ - def __init__(self, user_entry_point, args, env_vars, master_hostname): + def __init__( + self, user_entry_point, args, env_vars, processes_per_host, master_hostname + ): """Initialize a WorkerRunner, which is responsible for preparing distributed training with MPI and waiting for MPI master execution to finish. @@ -44,7 +46,9 @@ def __init__(self, user_entry_point, args, env_vars, master_hostname): env_vars (dict(str,str)): A dictionary of environment variables. master_hostname (str): The master hostname. """ - super(WorkerRunner, self).__init__(user_entry_point, args, env_vars) + super(WorkerRunner, self).__init__( + user_entry_point, args, env_vars, processes_per_host + ) self._master_hostname = str(master_hostname) def run( @@ -62,7 +66,9 @@ def run( self._wait_master_to_start() logger.info("MPI Master online, creating SSH daemon.") - logger.info("Writing environment variables to /etc/environment for the MPI process.") + logger.info( + "Writing environment variables to /etc/environment for the MPI process." + ) _write_env_vars_to_file() _start_sshd_daemon() @@ -99,7 +105,9 @@ def _wait_orted_process_to_finish(): # type: () -> None def _orted_process(): # pylint: disable=inconsistent-return-statements """Wait a maximum of 5 minutes for orted process to start.""" for _ in range(5 * 60): - procs = [p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted"] + procs = [ + p for p in psutil.process_iter(attrs=["name"]) if p.info["name"] == "orted" + ] if procs: logger.info("Process[es]: %s", procs) return procs @@ -116,14 +124,14 @@ def __init__( user_entry_point, args, env_vars, + processes_per_host, master_hostname, hosts, - process_per_host, custom_mpi_options, network_interface_name, interval=1, timeout_in_seconds=60 * 60, - num_processes=None, + num_processes=1, ): """Initialize a MasterRunner, which is responsible for preparing distributed training with MPI and synchronizing work among the Workers. @@ -134,7 +142,7 @@ def __init__( env_vars (dict(str,str)): A dictionary of environment variables. master_hostname (str): The master hostname. hosts ([str]): A list of hosts. - process_per_host (int): Number of processes per host. + processes_per_host (int): Number of processes per host. custom_mpi_options (str): A string of custom MPI options to be parsed. network_interface_name (str): The network interface name. interval (int or float): The interval at which to check the connection in seconds. @@ -144,11 +152,12 @@ def __init__( num_processes (int): The total number of processes. """ - super(MasterRunner, self).__init__(user_entry_point, args, env_vars) + super(MasterRunner, self).__init__( + user_entry_point, args, env_vars, processes_per_host + ) self._master_hostname = master_hostname self._hosts = hosts - self._process_per_host = process_per_host self._num_processes = num_processes self._custom_mpi_options = custom_mpi_options self._network_interface_name = network_interface_name @@ -174,16 +183,20 @@ def _wait_for_workers(self): # type: () -> None def _create_command(self): num_hosts = len(self._hosts) - num_processes = self._num_processes or self._process_per_host * num_hosts + num_processes = self._num_processes or self._processes_per_host * num_hosts # By default, use one process per GPU, or one process per node (if training with CPU). - if self._process_per_host == 1: + if self._processes_per_host == 1: host_list = self._hosts else: - host_list = ["%s:%s" % (host, self._process_per_host) for host in self._hosts] + host_list = [ + "%s:%s" % (host, self._processes_per_host) for host in self._hosts + ] msg = "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s" - logger.info(msg, self._hosts, host_list, self._process_per_host, num_processes) + logger.info( + msg, self._hosts, host_list, self._processes_per_host, num_processes + ) overridden_known_options, additional_options = _parse_custom_mpi_options( self._custom_mpi_options @@ -241,7 +254,11 @@ def _create_command(self): command.extend(additional_options) - for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]: + for credential in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ]: if credential in os.environ: command.extend(["-x", credential]) @@ -291,9 +308,9 @@ def _start_sshd_daemon(): # type: () -> None def _can_connect(host, port=22): # type: (str, int) -> bool """Check if the connection to provided ``host`` and ``port`` is possible. - Args: - host (str): Hostname for the host to check connection. - port (int): Port name of the host to check connection on. + Args: + host (str): Hostname for the host to check connection. + port (int): Port name of the host to check connection on. """ try: logger.debug("Testing connection to host %s", host) diff --git a/src/sagemaker_training/process.py b/src/sagemaker_training/process.py index 8ba83326..84188f62 100644 --- a/src/sagemaker_training/process.py +++ b/src/sagemaker_training/process.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -14,18 +14,152 @@ and execute the user entry point within a process. """ from __future__ import absolute_import - +import asyncio import os +import re import subprocess import sys - +from asyncio.subprocess import PIPE import six from sagemaker_training import _entry_point_type, environment, errors, logging_config +# Default limit of the stream is 2**16 KB, we can increase it to 128KB in subproc call +_DEFAULT_BUF_SIZE = 1024 * 64 +# [x for x in dir(__builtins__) if 'Error' in x] +_PYTHON_ERRORS_ = [ + "ArithmeticError", + "AssertionError", + "AttributeError", + "BlockingIOError", + "BrokenPipeError", + "BufferError", + "ChildProcessError", + "ConnectionAbortedError", + "ConnectionError", + "ConnectionRefusedError", + "ConnectionResetError", + "EOFError", + "EnvironmentError", + "FileExistsError", + "FileNotFoundError", + "FloatingPointError", + "IOError", + "ImportError", + "IndentationError", + "IndexError", + "InterruptedError", + "IsADirectoryError", + "KeyError", + "LookupError", + "MemoryError", + "ModuleNotFoundError", + "NameError", + "NotADirectoryError", + "NotImplementedError", + "OSError", + "OverflowError", + "PermissionError", + "ProcessLookupError", + "RecursionError", + "ReferenceError", + "RuntimeError", + "SyntaxError", + "SystemError", + "TabError", + "TimeoutError", + "TypeError", + "UnboundLocalError", + "UnicodeDecodeError", + "UnicodeEncodeError", + "UnicodeError", + "UnicodeTranslateError", + "ValueError", + "ZeroDivisionError", +] + + +async def watch(stream, num_processes_per_host): + """Process the stdout and stderr streams on the fly. + Decode the output lines + Remove new line characters (if any) + Prepend tags for easier search on CloudWatch + Look for errors in the stderr + Returns: + output: Filtered stderr + """ + output = [] + buf_size = _DEFAULT_BUF_SIZE + start = False + while True: + lines = await stream.read(buf_size) + if not lines or lines == "": + break + + lines = lines.decode("utf-8").strip().split("\n") + for line in lines: + err_line = line + if "" in line: + line = re.sub( + r"\[(\d),(\d)\]", + lambda x: f"[{x[1]},mpirank:{x[2]}, algo-{(int(x[2])//num_processes_per_host)+1}]", + line, + ) + elif "" in line: + line = re.sub( + r"\[(\d),(\d)\]", + lambda x: f"[{x[1]},mpirank:{x[2]}, algo-{(int(x[2])//num_processes_per_host)+1}]", + line, + ) + print(line) + # log only if necessary + err_line = re.sub(r"\[(\d),(\d)\]", "", err_line) + if start: + if line not in output: + output.append(err_line) + else: + if any(err in err_line for err in _PYTHON_ERRORS_): + start = True + output.append(err_line + "\n") + + return " ".join(output) + + +async def run_async(cmd, error_class, processes_per_host, env, cwd, stderr, **kwargs): + """Method responsible for launching asyncio subprocess shell + Watching proc stdout and stderr + Returns: + return_code: Launched Process's return code + output: Processed [stdout, stderr] + asyncio.subprocess.Process: The asyncio process for the given command. + Raises: + error_class: If there is an exception raised when creating the process. + """ + cmd = " ".join(cmd) + try: + proc = await asyncio.create_subprocess_shell( + cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs + ) + output = await asyncio.gather( + watch(proc.stdout, processes_per_host), + watch(proc.stderr, processes_per_host), + ) + return_code = proc.returncode + return return_code, output, proc + except Exception as e: + six.reraise(error_class, error_class(e), sys.exc_info()[2]) + -def create(cmd, error_class, cwd=None, capture_error=False, **kwargs): - """Spawn a process with subprocess.Popen for the given command. +def create( + cmd, + error_class, + processes_per_host, + cwd=None, + env=None, + capture_error=False, + **kwargs, +): + """Spawn a process with asyncio for the given command. Args: cmd (list): The command to be run. @@ -37,51 +171,83 @@ def create(cmd, error_class, cwd=None, capture_error=False, **kwargs): **kwargs: Extra arguments that are passed to the subprocess.Popen constructor. Returns: - subprocess.Popen: The process for the given command. + asyncio.subprocess.Process: The asyncio process for the given command. Raises: error_class: If there is an exception raised when creating the process. """ try: - stderr = subprocess.PIPE if capture_error else None - return subprocess.Popen( - cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs + stderr = PIPE if capture_error else None + loop = asyncio.get_event_loop() + rc, output, proc = loop.run_until_complete( + run_async( + cmd, + error_class, + processes_per_host, + env=env or os.environ, + cwd=cwd or environment.code_dir, + stderr=stderr, + **kwargs, + ) ) except Exception as e: # pylint: disable=broad-except six.reraise(error_class, error_class(e), sys.exc_info()[2]) + finally: + loop.close() + return rc, output, proc -def check_error(cmd, error_class, capture_error=False, **kwargs): +def check_error( + cmd, error_class, processes_per_host, cwd=None, capture_error=False, **kwargs +): """Run a commmand, raising an exception if there is an error. - Args: cmd ([str]): The command to be run. error_class (cls): The class to use when raising an exception. + processes_per_host (int): Number of processes per host capture_error (bool): Whether or not to include stderr in the exception message (default: False). In either case, stderr is streamed to the process's output. **kwargs: Extra arguments that are passed to the subprocess.Popen constructor. - Returns: subprocess.Popen: The process for the given command. - Raises: error_class: If there is an exception raised when creating the process. """ - process = create(cmd, error_class, capture_error=capture_error, **kwargs) if capture_error: - _, stderr = process.communicate() - # This will force the stderr to be printed after stdout - # If wait is false and cature error is true, we will never see the stderr. - print(stderr.decode(errors="replace")) - return_code = process.poll() + return_code, output, process = create( + cmd, + error_class, + processes_per_host, + env=os.environ, + cwd=cwd or environment.code_dir, + capture_error=True, + **kwargs, + ) + stderr = output[1] else: stderr = None + process = subprocess.Popen( + cmd, + env=os.environ, + cwd=cwd or environment.code_dir, + stderr=stderr, + **kwargs, + ) return_code = process.wait() if return_code: - raise error_class(return_code=return_code, cmd=" ".join(cmd), output=stderr) + extra_info = None + if return_code == 137: + extra_info = "Out of memory: Process killed by SIGKILL (signal 9)" + raise error_class( + cmd=" ".join(cmd) if isinstance(cmd, list) else cmd, + return_code=return_code, + output=stderr, + info=extra_info, + ) + return process @@ -93,15 +259,16 @@ def python_executable(): (str): The real path of the current Python executable. """ if not sys.executable: - raise RuntimeError("Failed to retrieve the real path for the Python executable binary") + raise RuntimeError( + "Failed to retrieve the real path for the Python executable binary" + ) return sys.executable class ProcessRunner(object): - """Responsible for executing the user entry point within a process. - """ + """Responsible for executing the user entry point within a process.""" - def __init__(self, user_entry_point, args, env_vars): + def __init__(self, user_entry_point, args, env_vars, processes_per_host): """Initialize a ProcessRunner, which is responsible for executing the user entry point within a process. @@ -113,9 +280,12 @@ def __init__(self, user_entry_point, args, env_vars): self._user_entry_point = user_entry_point self._args = args self._env_vars = env_vars + self._processes_per_host = processes_per_host def _create_command(self): - entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point) + entrypoint_type = _entry_point_type.get( + environment.code_dir, self._user_entry_point + ) if entrypoint_type is _entry_point_type.PYTHON_PACKAGE: entry_module = self._user_entry_point.replace(".py", "") @@ -127,7 +297,11 @@ def _create_command(self): six.moves.shlex_quote(arg) # pylint: disable=too-many-function-args for arg in self._args ] - return ["/bin/sh", "-c", "./%s %s" % (self._user_entry_point, " ".join(args))] + return [ + "/bin/sh", + "-c", + "./%s %s" % (self._user_entry_point, " ".join(args)), + ] def _python_command(self): # pylint: disable=no-self-use return [python_executable()] @@ -160,17 +334,18 @@ def run(self, wait=True, capture_error=False): process = check_error( cmd, errors.ExecuteUserScriptError, + self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, ) else: - process = create( + _, _, process = create( cmd, errors.ExecuteUserScriptError, + self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, ) self._tear_down() - return process diff --git a/src/sagemaker_training/runner.py b/src/sagemaker_training/runner.py index bb15126b..5a78e1fa 100644 --- a/src/sagemaker_training/runner.py +++ b/src/sagemaker_training/runner.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -49,7 +49,9 @@ def get(identifier, user_entry_point=None, args=None, env_vars=None, extra_opts= if isinstance(identifier, process.ProcessRunner): return identifier else: - return _get_by_runner_type(identifier, user_entry_point, args, env_vars, extra_opts) + return _get_by_runner_type( + identifier, user_entry_point, args, env_vars, extra_opts + ) def _get_by_runner_type( @@ -60,6 +62,13 @@ def _get_by_runner_type( args = args or env.to_cmd_args() env_vars = env_vars or env.to_env_vars() mpi_args = extra_opts or {} + num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES) + ## Processes per host + ## Default to single process for CPU + default_processes_per_host = int(env.num_gpus) if int(env.num_gpus) > 0 else 1 + processes_per_host = _mpi_param_value( + mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host + ) if identifier is RunnerType.SMDataParallel and env.is_master: custom_mpi_options = _mpi_param_value( @@ -69,41 +78,44 @@ def _get_by_runner_type( user_entry_point, args, env_vars, + processes_per_host, env.master_hostname, env.hosts, custom_mpi_options, env.network_interface_name, ) elif identifier is RunnerType.SMDataParallel: - return mpi.WorkerRunner(user_entry_point, args, env_vars, env.master_hostname) + return mpi.WorkerRunner( + user_entry_point, args, env_vars, processes_per_host, env.master_hostname + ) elif identifier is RunnerType.MPI and env.is_master: - - # Default to single process for CPU - default_processes_per_host = env.num_gpus if env.num_gpus > 0 else 1 - processes_per_host = _mpi_param_value( - mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host + custom_mpi_options = _mpi_param_value( + mpi_args, env, params.MPI_CUSTOM_OPTIONS, "" ) - num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES) - custom_mpi_options = _mpi_param_value(mpi_args, env, params.MPI_CUSTOM_OPTIONS, "") - return mpi.MasterRunner( user_entry_point, args, env_vars, + processes_per_host, env.master_hostname, env.hosts, - processes_per_host, custom_mpi_options, env.network_interface_name, num_processes=num_processes, ) elif identifier is RunnerType.MPI: - return mpi.WorkerRunner(user_entry_point, args, env_vars, env.master_hostname) + return mpi.WorkerRunner( + user_entry_point, args, env_vars, processes_per_host, env.master_hostname + ) elif identifier is RunnerType.Process: - return process.ProcessRunner(user_entry_point, args, env_vars) + return process.ProcessRunner( + user_entry_point, args, env_vars, processes_per_host + ) else: raise ValueError("Invalid identifier %s" % identifier) def _mpi_param_value(mpi_args, env, param_name, default=None): - return mpi_args.get(param_name) or env.additional_framework_parameters.get(param_name, default) + return mpi_args.get(param_name) or env.additional_framework_parameters.get( + param_name, default + ) diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 62763583..e797e9f9 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -40,6 +40,7 @@ def __init__( user_entry_point, args, env_vars, + processes_per_host, master_hostname, hosts, custom_mpi_options, @@ -66,7 +67,9 @@ def __init__( 3600 seconds (ie. 1 hour). """ - super(SMDataParallelRunner, self).__init__(user_entry_point, args, env_vars) + super(SMDataParallelRunner, self).__init__( + user_entry_point, args, env_vars, processes_per_host + ) self._master_hostname = master_hostname self._hosts = hosts @@ -107,9 +110,7 @@ def _get_mpirun_command( smdataparallel_server_addr=None, smdataparallel_server_port=None, ): - """Fetch mpi command for SMDataParallel - - """ + """Fetch mpi command for SMDataParallel""" overridden_known_options, additional_options = _parse_custom_mpi_options( self._custom_mpi_options ) @@ -217,7 +218,9 @@ def _create_command(self): # homogeneous mode uses 16 processes per host; 8 server; 8 worker smdataparallel_server_addr = self._master_hostname smdataparallel_server_port = 7592 - host_list = ["{}:{}".format(host, num_processes_per_host) for host in self._hosts] + host_list = [ + "{}:{}".format(host, num_processes_per_host) for host in self._hosts + ] smdataparallel_flag = "SMDATAPARALLEL_USE_HOMOGENEOUS=1" command = self._get_mpirun_command( num_hosts, @@ -264,6 +267,7 @@ def run(self, wait=True, capture_error=False): process_spawned = process.check_error( cmd, errors.ExecuteUserScriptError, + self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, ) @@ -271,6 +275,7 @@ def run(self, wait=True, capture_error=False): process_spawned = process.create( cmd, errors.ExecuteUserScriptError, + self._processes_per_host, capture_error=capture_error, cwd=environment.code_dir, ) @@ -311,9 +316,9 @@ def _can_connect(host, port=22): # type: (str, int) -> bool """Check if the connection to provided ``host`` and ``port`` is possible. - Args: - host (str): Hostname for the host to check connection. - port (int): Port name of the host to check connection on. + Args: + host (str): Hostname for the host to check connection. + port (int): Port name of the host to check connection on. """ try: logger.debug("Testing connection to host %s at port %s", host, port) diff --git a/src/sagemaker_training/trainer.py b/src/sagemaker_training/trainer.py index 2311c9ae..1f5f5266 100644 --- a/src/sagemaker_training/trainer.py +++ b/src/sagemaker_training/trainer.py @@ -17,6 +17,7 @@ import importlib import os +import sys import traceback from sagemaker_training import ( @@ -55,7 +56,9 @@ def _exit_processes(exit_code): # type: (int) -> None Args: exit_code (int): exit code """ - os._exit(exit_code) # pylint: disable=protected-access + if exit_code != 0: + logger.error(f"Encountered exit_code {exit_code}") + sys.exit(exit_code) def train(): @@ -80,14 +83,15 @@ def train(): # the framework to configure logging at import time. logging_config.configure_logger(env.log_level) logger.info("Imported framework %s", framework_name) - entrypoint = getattr(framework, entry_point_name) entrypoint() else: logging_config.configure_logger(env.log_level) mpi_enabled = env.additional_framework_parameters.get(params.MPI_ENABLED) - runner_type = runner.RunnerType.MPI if mpi_enabled else runner.RunnerType.Process + runner_type = ( + runner.RunnerType.MPI if mpi_enabled else runner.RunnerType.Process + ) entry_point.run( env.module_dir, @@ -96,7 +100,6 @@ def train(): env.to_env_vars(), runner_type=runner_type, ) - logger.info("Reporting training SUCCESS") files.write_success_file() @@ -104,6 +107,7 @@ def train(): failure_message = str(e) files.write_failure_file(failure_message) + logger.error("Reporting training FAILURE") logger.error(failure_message) @@ -112,7 +116,7 @@ def train(): exit_code = DEFAULT_FAILURE_CODE except Exception as e: # pylint: disable=broad-except - failure_msg = "framework error: \n%s\n%s" % (traceback.format_exc(), str(e)) + failure_msg = "Framework Error: \n%s\n%s" % (traceback.format_exc(), str(e)) files.write_failure_file(failure_msg) logger.error("Reporting training FAILURE") @@ -124,5 +128,4 @@ def train(): finally: if intermediate_sync: intermediate_sync.join() - _exit_processes(exit_code) diff --git a/test/conftest.py b/test/conftest.py index bad1cf63..c9a93a1a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import asyncio import json import logging import os @@ -78,3 +79,10 @@ def fix_protobuf_installation_for_python_2(): site_packages = re.match(r"[\S\s]*Location: (.*)\s", protobuf_info).group(1) with open(os.path.join(site_packages, "google", "__init__.py"), "w"): pass + + +@pytest.fixture(autouse=True) +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() diff --git a/test/unit/test_entry_point.py b/test/unit/test_entry_point.py index 1a531cfd..a8e944d3 100644 --- a/test/unit/test_entry_point.py +++ b/test/unit/test_entry_point.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -49,7 +49,9 @@ def test_install_module(check_error, prepare, entry_point_type_module): entry_point.install("python_module.py", path) cmd = [sys.executable, "-m", "pip", "install", "."] - check_error.assert_called_with(cmd, errors.InstallModuleError, capture_error=False, cwd=path) + check_error.assert_called_with( + cmd, errors.InstallModuleError, 1, capture_error=False, cwd=path + ) with patch("os.path.exists", return_value=True): entry_point.install("python_module.py", path) @@ -57,6 +59,7 @@ def test_install_module(check_error, prepare, entry_point_type_module): check_error.assert_called_with( cmd + ["-r", "requirements.txt"], errors.InstallModuleError, + 1, cwd=path, capture_error=False, ) @@ -64,7 +67,9 @@ def test_install_module(check_error, prepare, entry_point_type_module): @patch("sagemaker_training.modules.prepare") @patch("sagemaker_training.process.check_error", autospec=True) -def test_install_script(check_error, prepare, entry_point_type_module, has_requirements): +def test_install_script( + check_error, prepare, entry_point_type_module, has_requirements +): path = "c://sagemaker-pytorch-container" entry_point.install("train.py", path) @@ -88,7 +93,10 @@ def test_install_no_python_executable( ): with pytest.raises(RuntimeError) as e: entry_point.install("train.py", "git://aws/container-support") - assert str(e.value) == "Failed to retrieve the real path for the Python executable binary" + assert ( + str(e.value) + == "Failed to retrieve the real path for the Python executable binary" + ) @patch("os.chmod") @@ -132,13 +140,21 @@ def test_run_module_wait(gethostbyname, check_error, chmod, download_and_extract @patch("sagemaker_training.files.download_and_extract") @patch("sagemaker_training.modules.install") @patch.object( - environment.Environment, "hosts", return_value=["algo-1", "algo-2"], new_callable=PropertyMock + environment.Environment, + "hosts", + return_value=["algo-1", "algo-2"], + new_callable=PropertyMock, ) @patch("socket.gethostbyname") -def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_and_extract): +def test_run_calls_hostname_resolution( + gethostbyname, install, hosts, download_and_extract +): runner_mock = MagicMock(spec=process.ProcessRunner) entry_point.run( - uri="s3://url", user_entry_point="launcher.py", args=["42"], runner_type=runner_mock + uri="s3://url", + user_entry_point="launcher.py", + args=["42"], + runner_type=runner_mock, ) gethostbyname.assert_called_with("algo-2") @@ -148,19 +164,29 @@ def test_run_calls_hostname_resolution(gethostbyname, install, hosts, download_a @patch("sagemaker_training.files.download_and_extract") @patch("sagemaker_training.modules.install") @patch.object( - environment.Environment, "hosts", return_value=["algo-1", "algo-2"], new_callable=PropertyMock + environment.Environment, + "hosts", + return_value=["algo-1", "algo-2"], + new_callable=PropertyMock, ) @patch("socket.gethostbyname") -def test_run_waits_hostname_resolution(gethostbyname, hosts, install, download_and_extract): +def test_run_waits_hostname_resolution( + gethostbyname, hosts, install, download_and_extract +): gethostbyname.side_effect = [ValueError(), ValueError(), True, True] runner_mock = MagicMock(spec=process.ProcessRunner) entry_point.run( - uri="s3://url", user_entry_point="launcher.py", args=["42"], runner_type=runner_mock + uri="s3://url", + user_entry_point="launcher.py", + args=["42"], + runner_type=runner_mock, ) - gethostbyname.assert_has_calls([call("algo-1"), call("algo-1"), call("algo-1"), call("algo-2")]) + gethostbyname.assert_has_calls( + [call("algo-1"), call("algo-1"), call("algo-1"), call("algo-2")] + ) @patch("sagemaker_training.files.download_and_extract") @@ -186,11 +212,16 @@ def test_run_module_no_wait(gethostbyname, chmod, download_and_extract): @patch("sagemaker_training.files.download_and_extract") @patch("os.chmod") @patch("socket.gethostbyname") -def test_run_module_with_env_vars(gethostbyname, chmod, download_and_extract, get_runner, sys_path): +def test_run_module_with_env_vars( + gethostbyname, chmod, download_and_extract, get_runner, sys_path +): module_name = "default_user_module_name" args = ["--some-arg", "42"] entry_point.run( - uri="s3://url", user_entry_point=module_name, args=args, env_vars={"FOO": "BAR"} + uri="s3://url", + user_entry_point=module_name, + args=args, + env_vars={"FOO": "BAR"}, ) expected_env_vars = {"FOO": "BAR", "PYTHONPATH": ""} @@ -211,5 +242,9 @@ def test_run_module_with_extra_opts( args = ["--some-arg", "42"] extra_opts = {"foo": "bar"} - entry_point.run(uri="s3://url", user_entry_point=module_name, args=args, extra_opts=extra_opts) - get_runner.assert_called_with(runner.ProcessRunnerType, module_name, args, {}, extra_opts) + entry_point.run( + uri="s3://url", user_entry_point=module_name, args=args, extra_opts=extra_opts + ) + get_runner.assert_called_with( + runner.ProcessRunnerType, module_name, args, {}, extra_opts + ) diff --git a/test/unit/test_errors.py b/test/unit/test_errors.py index 365c3be7..052e8ec8 100644 --- a/test/unit/test_errors.py +++ b/test/unit/test_errors.py @@ -17,33 +17,49 @@ def test_install_module_error(): error = errors.InstallModuleError(["python", "-m", "42"], return_code=42) - - assert str(error) == "InstallModuleError:\nCommand \"['python', '-m', '42']\"" + assert ( + str(error) + == "InstallModuleError:\nExitCode 42\nErrorMessage \"\"\nCommand \"['python', '-m', '42']\"" + ) def test_execute_user_script_error(): error = errors.ExecuteUserScriptError(["python", "-m", "42"], return_code=42) - assert str(error) == "ExecuteUserScriptError:\nCommand \"['python', '-m', '42']\"" + assert ( + str(error) + == "ExecuteUserScriptError:\nExitCode 42\nErrorMessage \"\"\nCommand \"['python', '-m', '42']\"" + ) def test_install_module_error_with_output(): - error = errors.InstallModuleError(["python", "-m", "42"], return_code=42, output=b"42") + error = errors.InstallModuleError( + ["python", "-m", "42"], return_code=42, output="42" + ) assert ( str(error) - == """InstallModuleError: -Command "['python', '-m', '42']" -42""" + == "InstallModuleError:\nExitCode 42\nErrorMessage \"42\"\nCommand \"['python', '-m', '42']\"" ) def test_execute_user_script_error_with_output(): - error = errors.ExecuteUserScriptError(["python", "-m", "42"], return_code=42, output=b"42") + error = errors.ExecuteUserScriptError( + ["python", "-m", "42"], return_code=137, output=b"42" + ) + + assert ( + str(error) + == "ExecuteUserScriptError:\nExitCode 137\nErrorMessage \"42\"\nCommand \"['python', '-m', '42']\"" + ) + + +def test_execute_user_script_error_with_output_and_info(): + error = errors.ExecuteUserScriptError( + ["python", "-m", "42"], return_code=137, output="42", info="SIGKILL" + ) assert ( str(error) - == """ExecuteUserScriptError: -Command "['python', '-m', '42']" -42""" + == "ExecuteUserScriptError:\nExitCode 137\nErrorMessage \"42\"\nExtraInfo \"SIGKILL\"\nCommand \"['python', '-m', '42']\"" ) diff --git a/test/unit/test_modules.py b/test/unit/test_modules.py index e37a48c9..4ae58daf 100644 --- a/test/unit/test_modules.py +++ b/test/unit/test_modules.py @@ -30,8 +30,20 @@ @pytest.mark.parametrize( "url,bucket_name,key,dst,endpoint", [ - ("S3://my-bucket/path/to/my-file", "my-bucket", "path/to/my-file", "/tmp/my-file", None), - ("s3://my-bucket/my-file", "my-bucket", "my-file", "/tmp/my-file", "http://localhost:9000"), + ( + "S3://my-bucket/path/to/my-file", + "my-bucket", + "path/to/my-file", + "/tmp/my-file", + None, + ), + ( + "s3://my-bucket/my-file", + "my-bucket", + "my-file", + "/tmp/my-file", + "http://localhost:9000", + ), ], ) def test_s3_download(resource, url, bucket_name, key, dst, endpoint): @@ -63,7 +75,9 @@ def test_install(check_error): modules.install(path) cmd = [sys.executable, "-m", "pip", "install", "."] - check_error.assert_called_with(cmd, errors.InstallModuleError, cwd=path, capture_error=False) + check_error.assert_called_with( + cmd, errors.InstallModuleError, 1, cwd=path, capture_error=False + ) with patch("os.path.exists", return_value=True): modules.install(path) @@ -71,6 +85,7 @@ def test_install(check_error): check_error.assert_called_with( cmd + ["-r", "requirements.txt"], errors.InstallModuleError, + 1, capture_error=False, cwd=path, ) @@ -86,7 +101,7 @@ def test_install_requirements(check_error): modules.install_requirements(path) check_error.assert_called_with( - cmd, errors.InstallRequirementsError, cwd=path, capture_error=False + cmd, errors.InstallRequirementsError, 1, cwd=path, capture_error=False ) @@ -101,7 +116,10 @@ def test_install_fails(check_error): def test_install_no_python_executable(): with pytest.raises(RuntimeError) as e: modules.install("git://aws/container-support") - assert str(e.value) == "Failed to retrieve the real path for the Python executable binary" + assert ( + str(e.value) + == "Failed to retrieve the real path for the Python executable binary" + ) @contextlib.contextmanager @@ -173,7 +191,9 @@ def test_import_module(reload, import_module, install, download_and_extract): modules.import_module("s3://bucket/my-module") - download_and_extract.assert_called_with("s3://bucket/my-module", environment.code_dir) + download_and_extract.assert_called_with( + "s3://bucket/my-module", environment.code_dir + ) install.assert_called_with(environment.code_dir) reload.assert_called_with(import_module(modules.DEFAULT_MODULE_NAME)) @@ -191,6 +211,8 @@ def test_import_module_local_directory( modules.import_module(uri) s3_download.assert_not_called() - tarfile.assert_called_with(name="/opt/ml/input/data/code/sourcedir.tar.gz", mode="r:gz") + tarfile.assert_called_with( + name="/opt/ml/input/data/code/sourcedir.tar.gz", mode="r:gz" + ) prepare.assert_called_once() install.assert_called_once() diff --git a/test/unit/test_mpi.py b/test/unit/test_mpi.py index 7d9361e0..a1345ef3 100644 --- a/test/unit/test_mpi.py +++ b/test/unit/test_mpi.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -12,13 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import asyncio import inspect import os +from test.unit.test_process import AsyncMock from mock import ANY, MagicMock, patch - +import unittest import gethostname +import pytest from sagemaker_training import environment, mpi +import nest_asyncio def does_not_connect(): @@ -29,10 +33,17 @@ def connect(): pass +class AsyncMockCall(MagicMock): + async def __call__(self, *args, **kwargs): + super().__call__(*args, **kwargs) + + class MockSSHClient(MagicMock): def __init__(self, *args, **kw): super(MockSSHClient, self).__init__(*args, **kw) - self.connect = MagicMock(side_effect=[does_not_connect, connect, does_not_connect]) + self.connect = MagicMock( + side_effect=[does_not_connect, connect, does_not_connect] + ) @patch("sagemaker_training.mpi._write_env_vars_to_file") @@ -44,7 +55,14 @@ def __init__(self, *args, **kw): @patch("paramiko.AutoAddPolicy") @patch("subprocess.Popen") def test_mpi_worker_run( - popen, policy, process_iter, wait_procs, ssh_client, sleep, path_exists, write_env_vars + popen, + policy, + process_iter, + wait_procs, + ssh_client, + sleep, + path_exists, + write_env_vars, ): process = MagicMock(info={"name": "orted"}) @@ -54,6 +72,7 @@ def test_mpi_worker_run( user_entry_point="train.sh", args=["-v", "--lr", "35"], env_vars={"LD_CONFIG_PATH": "/etc/ld"}, + processes_per_host="1", master_hostname="algo-1", ) @@ -80,6 +99,7 @@ def test_mpi_worker_run_no_wait(popen, ssh_client, path_exists, write_env_vars): user_entry_point="train.sh", args=["-v", "--lr", "35"], env_vars={"LD_CONFIG_PATH": "/etc/ld"}, + processes_per_host=1, master_hostname="algo-1", ) diff --git a/test/unit/test_process.py b/test/unit/test_process.py index 6704f426..419de64e 100644 --- a/test/unit/test_process.py +++ b/test/unit/test_process.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -11,17 +11,23 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from io import StringIO import os import subprocess import sys - -from mock import MagicMock, patch +import unittest +from mock import ANY, MagicMock, patch import pytest - +import nest_asyncio from sagemaker_training import environment, errors, process +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + @pytest.fixture def entry_point_type_module(): with patch("os.listdir", lambda x: ("setup.py",)): @@ -49,7 +55,7 @@ def test_python_executable_exception(): @patch("subprocess.Popen", MagicMock(side_effect=ValueError("FAIL"))) def test_create_error(): with pytest.raises(errors.ExecuteUserScriptError): - process.create(["run"], errors.ExecuteUserScriptError) + process.create(["run"], errors.ExecuteUserScriptError, 1) @patch("subprocess.Popen") @@ -57,14 +63,16 @@ def test_check_error(popen): test_process = MagicMock(wait=MagicMock(return_value=0)) popen.return_value = test_process - assert test_process == process.check_error(["run"], errors.ExecuteUserScriptError) + assert test_process == process.check_error( + ["run"], errors.ExecuteUserScriptError, 1 + ) @patch("subprocess.Popen") @patch("sagemaker_training.logging_config.log_script_invocation") def test_run_bash(log, popen, entry_point_type_script): with pytest.raises(errors.ExecuteUserScriptError): - process.ProcessRunner("launcher.sh", ["--lr", "1 3"], {}).run() + process.ProcessRunner("launcher.sh", ["--lr", "1 3"], {}, 1).run() cmd = ["/bin/sh", "-c", "./launcher.sh --lr '1 3'"] popen.assert_called_with(cmd, cwd=environment.code_dir, env=os.environ, stderr=None) @@ -77,10 +85,14 @@ def test_run_python(log, popen, entry_point_type_script): popen().communicate.return_value = (None, b"this is stderr") with pytest.raises(errors.ExecuteUserScriptError): - process.ProcessRunner("launcher.py", ["--lr", "13"], {}).run(capture_error=True) + process.ProcessRunner("launcher.py", ["--lr", "13"], {}, 1).run( + capture_error=False + ) cmd = [sys.executable, "launcher.py", "--lr", "13"] - popen.assert_called_with(cmd, cwd=environment.code_dir, env=os.environ, stderr=subprocess.PIPE) + popen.assert_called_with( + cmd, cwd=environment.code_dir, env=os.environ, stderr=subprocess.PIPE + ) log.assert_called_with(cmd, {}) @@ -88,7 +100,7 @@ def test_run_python(log, popen, entry_point_type_script): @patch("sagemaker_training.logging_config.log_script_invocation") def test_run_module(log, popen, entry_point_type_module): with pytest.raises(errors.ExecuteUserScriptError): - process.ProcessRunner("module.py", ["--lr", "13"], {}).run() + process.ProcessRunner("module.py", ["--lr", "13"], {}, 1).run() cmd = [sys.executable, "-m", "module", "--lr", "13"] popen.assert_called_with(cmd, cwd=environment.code_dir, env=os.environ, stderr=None) @@ -98,7 +110,20 @@ def test_run_module(log, popen, entry_point_type_module): @patch("sagemaker_training.environment.Environment", lambda: {}) def test_run_error(): with pytest.raises(errors.ExecuteUserScriptError) as e: - process.ProcessRunner("wrong module", [], {}).run() + process.ProcessRunner("wrong module", [], {}, 1).run() message = str(e.value) assert "ExecuteUserScriptError:" in message + + +@pytest.mark.asyncio +async def test_watch(event_loop): + expected_outcome = "[1, mpirank:2, algo-2]:This is stdout" + + def write_to_stream(): + print("[1,2]:This is stdout") + + with patch("sys.stdout", new=StringIO()) as mock_stdout: + write_to_stream() + output = await process.watch(mock_stdout.getvalue(), 2) + assert output.return_value == expected_outcome diff --git a/test/unit/test_runner.py b/test/unit/test_runner.py index b3eda13c..4d4bd68a 100644 --- a/test/unit/test_runner.py +++ b/test/unit/test_runner.py @@ -1,4 +1,4 @@ -# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of @@ -20,6 +20,7 @@ USER_SCRIPT = "script" CMD_ARGS = ["--some-arg", 42] ENV_VARS = {"FOO": "BAR"} +PROC_PER_HOST = 1 NCCL_DEBUG_MPI_OPT = "-X NCCL_DEBUG=WARN" MPI_OPTS = { @@ -88,7 +89,7 @@ def test_runnner_with_default_cpu_processes_per_host(training_env): test_runner = runner.get(runner.MPIRunnerType) assert isinstance(test_runner, mpi.MasterRunner) - assert test_runner._process_per_host == 1 + assert test_runner._processes_per_host == 1 @patch("sagemaker_training.environment.Environment") @@ -99,21 +100,23 @@ def test_runnner_with_default_gpu_processes_per_host(training_env): test_runner = runner.get(runner.MPIRunnerType) assert isinstance(test_runner, mpi.MasterRunner) - assert test_runner._process_per_host == 2 + assert test_runner._processes_per_host == 2 @patch("sagemaker_training.environment.Environment") def test_get_runner_by_mpi_with_extra_args(training_env): training_env().num_gpus = 0 - test_runner = runner.get(runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS, MPI_OPTS) + test_runner = runner.get( + runner.MPIRunnerType, USER_SCRIPT, CMD_ARGS, ENV_VARS, MPI_OPTS + ) assert isinstance(test_runner, mpi.MasterRunner) assert test_runner._user_entry_point == USER_SCRIPT assert test_runner._args == CMD_ARGS assert test_runner._env_vars == ENV_VARS - assert test_runner._process_per_host == 2 + assert test_runner._processes_per_host == 2 assert test_runner._num_processes == 4 assert test_runner._custom_mpi_options == NCCL_DEBUG_MPI_OPT diff --git a/test/unit/test_smdataparallel.py b/test/unit/test_smdataparallel.py index 081fe4b9..3685d90c 100644 --- a/test/unit/test_smdataparallel.py +++ b/test/unit/test_smdataparallel.py @@ -49,6 +49,7 @@ def test_smdataparallel_run_multi_node_python( env_vars={ "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p3.16xlarge"}}' }, + processes_per_host=num_processes_per_host, master_hostname=master_hostname, hosts=hosts, custom_mpi_options="--verbose", @@ -162,6 +163,7 @@ def test_smdataparallel_run_single_node_python( env_vars={ "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p4d.24xlarge"}}' }, + processes_per_host=num_processes_per_host, master_hostname=master_hostname, hosts=hosts, custom_mpi_options="--verbose", diff --git a/tox.ini b/tox.ini index b4b451bc..4742bf34 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = black-format,flake8,pylint,twine,py27,py36,py37 +envlist = black-format,flake8,pylint,twine,py27,py36,py37,py38 skip_missing_interpreters = False @@ -46,6 +46,7 @@ deps = pytest==4.4.1 pytest-cov pytest-xdist + pytest-asyncio mock awslogs sagemaker[local]==2.18.0 @@ -55,6 +56,7 @@ deps = gevent paramiko==2.4.2 psutil==5.6.7 + nest_asyncio [testenv:twine] basepython = python3