From f3a8d7321b1fe76e1f59554b132a9980bc355aaf Mon Sep 17 00:00:00 2001 From: Daiming Yang Date: Fri, 1 Apr 2022 18:23:09 -0700 Subject: [PATCH] feature: add back FI_EFA_USE_DEVICE_RDMA=1 flag, revert 2936f22 fix: fixed the black lint, upgraded black to version 21.3.0 fix: remove u prefix of strings, as python3 defaults to unicode strings note: EFA is only available on p3dn or p4dn instances note: EFA version 1.15.1 and OFI 1.1.5-aws have the issue fixed note: black format reference on remove u prefix https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#strings --- CHANGELOG.md | 2 ++ VERSION | 2 +- setup.py | 2 +- src/sagemaker_training/mapping.py | 6 +++--- src/sagemaker_training/smdataparallel.py | 3 +++ test/unit/test_encoder.py | 6 +++--- test/unit/test_mapping.py | 12 ++++++------ test/unit/test_smdataparallel.py | 4 ++++ tox.ini | 4 ++-- 9 files changed, 25 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 992c891b5..e593b51e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog +## v4.0.1 (2022-01-29) + ## v4.0.0 (2021-10-08) ### Breaking Changes diff --git a/VERSION b/VERSION index 39c84160e..5d5e81ba5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -4.0.1.dev0 +4.0.2.dev0 diff --git a/setup.py b/setup.py index f79aca147..81fab38a2 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ def read_version(): "pytest-cov", "mock", "sagemaker[local]<2", - "black==19.3b0 ; python_version >= '3.7'", + "black==22.3.0 ; python_version >= '3.7'", ] }, entry_points={"console_scripts": ["train=sagemaker_training.cli.train:main"]}, diff --git a/src/sagemaker_training/mapping.py b/src/sagemaker_training/mapping.py index d0794f3ed..3284e3982 100644 --- a/src/sagemaker_training/mapping.py +++ b/src/sagemaker_training/mapping.py @@ -79,9 +79,9 @@ def to_cmd_args(mapping): # type: (dict) -> list def arg_name(obj): string = _decode(obj) if string: - return u"--%s" % string if len(string) > 1 else u"-%s" % string + return "--%s" % string if len(string) > 1 else "-%s" % string else: - return u"" + return "" arg_names = [arg_name(argument) for argument in sorted_keys] @@ -106,7 +106,7 @@ def _decode(obj): # type: (bytes or str or unicode or object) -> unicode # noqa Object decoded in unicode. """ if obj is None: - return u"" + return "" if six.PY3 and isinstance(obj, six.binary_type): # transforms a byte string (b'') in unicode return obj.decode("latin1") diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 11126e057..694982542 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -169,6 +169,9 @@ def _get_mpirun_command( mpirun_command.extend(additional_options) instance_type = self._get_instance_type() + # Use EFA's RDMA functionality for one-sided and two-sided transfer + if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]: + mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) if smdataparallel_server_addr and smdataparallel_server_port: # in case of multi-node [distributed] training, smdataparallel_server_addr, diff --git a/test/unit/test_encoder.py b/test/unit/test_encoder.py index d55332ba4..9203c86ae 100644 --- a/test/unit/test_encoder.py +++ b/test/unit/test_encoder.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize( "target", - ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], [u"42", u"6", u"9"], {42: {"6": 9.0}}), + ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}), ) def test_npy_to_numpy(target): buffer = BytesIO() @@ -40,7 +40,7 @@ def test_npy_to_numpy(target): @pytest.mark.parametrize( "target", - ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], [u"42", u"6", u"9"], {42: {"6": 9.0}}), + ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}), ) def test_array_to_npy(target): input_data = np.array(target) @@ -60,7 +60,7 @@ def test_array_to_npy(target): ("[42, 6, 9]", np.array([42, 6, 9])), ("[42.0, 6.0, 9.0]", np.array([42.0, 6.0, 9.0])), ('["42", "6", "9"]', np.array(["42", "6", "9"])), - (u'["42", "6", "9"]', np.array([u"42", u"6", u"9"])), + ('["42", "6", "9"]', np.array(["42", "6", "9"])), ], ) def test_json_to_numpy(target, expected): diff --git a/test/unit/test_mapping.py b/test/unit/test_mapping.py index 71051702d..f8d4f68d8 100644 --- a/test/unit/test_mapping.py +++ b/test/unit/test_mapping.py @@ -102,15 +102,15 @@ def test_mapping_throws_exception_trying_to_access_non_properties(property, erro [ ( {"da-sh": "1", "un_der": "2", "un-sh": "3", "da_der": "2"}, - [u"--da-sh", u"1", u"--da_der", u"2", u"--un-sh", u"3", u"--un_der", u"2"], + ["--da-sh", "1", "--da_der", "2", "--un-sh", "3", "--un_der", "2"], ), ({}, []), - ({"": ""}, [u"", u""]), + ({"": ""}, ["", ""]), ( - {"unicode": u"¡ø", "bytes": b"2", "floats": 4.0, "int": 2}, - [u"--bytes", u"2", u"--floats", u"4.0", u"--int", u"2", u"--unicode", u"¡ø"], + {"unicode": "¡ø", "bytes": b"2", "floats": 4.0, "int": 2}, + ["--bytes", "2", "--floats", "4.0", "--int", "2", "--unicode", "¡ø"], ), - ({"U": u"1", "b": b"2", "T": "", "": "42"}, ["", "42", "-T", "", "-U", "1", "-b", "2"]), + ({"U": "1", "b": b"2", "T": "", "": "42"}, ["", "42", "-T", "", "-U", "1", "-b", "2"]), ({"nested": ["1", ["2", "3", [["6"]]]]}, ["--nested", "['1', ['2', '3', [['6']]]]"]), ( {"map": {"a": [1, 3, 4]}, "channel_dirs": {"train": "foo", "eval": "bar"}}, @@ -133,7 +133,7 @@ def test_to_cmd_args(target, expected): {"SM_MODEL_DIR": "/opt/ml/model", "SM_OUTPUT_DIR": "/opt/ml/output"}, ), ({}, {}), - ({"": None}, {u"": u""}), + ({"": None}, {"": ""}), ( {"bytes": b"2", "floats": 4.0, "int": 2, "unicode": "¡ø"}, {"SM_BYTES": "2", "SM_FLOATS": "4.0", "SM_INT": "2", "SM_UNICODE": "¡ø"}, diff --git a/test/unit/test_smdataparallel.py b/test/unit/test_smdataparallel.py index 054b3b9cf..9d3a23f59 100644 --- a/test/unit/test_smdataparallel.py +++ b/test/unit/test_smdataparallel.py @@ -126,6 +126,8 @@ def test_smdataparallel_run_multi_node_python( "RDMAV_FORK_SAFE=1", "-x", "LD_PRELOAD=%s" % inspect.getfile(gethostname), + "-x", + "FI_EFA_USE_DEVICE_RDMA=1", "--verbose", "-x", "SMDATAPARALLEL_SERVER_ADDR=%s" % smdataparallel_server_addr, @@ -245,6 +247,8 @@ def test_smdataparallel_run_single_node_python( "RDMAV_FORK_SAFE=1", "-x", "LD_PRELOAD=%s" % inspect.getfile(gethostname), + "-x", + "FI_EFA_USE_DEVICE_RDMA=1", "--verbose", "smddprun", "usr/bin/python3", diff --git a/tox.ini b/tox.ini index e5a0e4774..3640ff188 100644 --- a/tox.ini +++ b/tox.ini @@ -80,14 +80,14 @@ commands = flake8 [testenv:black-format] # Used during development (before committing) to format .py files. basepython = python3.7 -deps = black==19.3b0 +deps = black==22.3.0 commands = black -l 100 ./ [testenv:black-check] # Used by automated build steps to check that all files are properly formatted. basepython = python3.7 -deps = black==19.3b0 +deps = black==22.3.0 commands = black -l 100 --check ./