Skip to content

Commit

Permalink
feature: add back FI_EFA_USE_DEVICE_RDMA=1 flag, revert 2936f22
Browse files Browse the repository at this point in the history
fix: fixed the black lint, upgraded black to version 22.3.0
fix: remove the u prefix from string literals, as Python 3 strings are Unicode by default

note: EFA is only available on p3dn or p4d instances (ml.p3dn.24xlarge, ml.p4d.24xlarge)
note: EFA version 1.15.1 and OFI 1.1.5-aws have the issue fixed
note: black code style reference for removing the u prefix:
https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#strings
  • Loading branch information
ydaiming committed Apr 4, 2022
1 parent 3196927 commit b5e90ca
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 15 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def read_version():
"pytest-cov",
"mock",
"sagemaker[local]<2",
"black==19.3b0 ; python_version >= '3.7'",
"black==22.3.0 ; python_version >= '3.7'",
]
},
entry_points={"console_scripts": ["train=sagemaker_training.cli.train:main"]},
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker_training/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def to_cmd_args(mapping): # type: (dict) -> list
def arg_name(obj):
string = _decode(obj)
if string:
return u"--%s" % string if len(string) > 1 else u"-%s" % string
return "--%s" % string if len(string) > 1 else "-%s" % string
else:
return u""
return ""

arg_names = [arg_name(argument) for argument in sorted_keys]

Expand All @@ -106,7 +106,7 @@ def _decode(obj): # type: (bytes or str or unicode or object) -> unicode # noqa
Object decoded in unicode.
"""
if obj is None:
return u""
return ""
if six.PY3 and isinstance(obj, six.binary_type):
# transforms a byte string (b'') in unicode
return obj.decode("latin1")
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker_training/smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def _get_mpirun_command(
mpirun_command.extend(additional_options)

instance_type = self._get_instance_type()
# Use EFA's RDMA functionality for one-sided and two-sided transfer
if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]:
mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"])

if smdataparallel_server_addr and smdataparallel_server_port:
# in case of multi-node [distributed] training, smdataparallel_server_addr,
Expand Down
6 changes: 3 additions & 3 deletions test/unit/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@pytest.mark.parametrize(
"target",
([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], [u"42", u"6", u"9"], {42: {"6": 9.0}}),
([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}),
)
def test_npy_to_numpy(target):
buffer = BytesIO()
Expand All @@ -40,7 +40,7 @@ def test_npy_to_numpy(target):

@pytest.mark.parametrize(
"target",
([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], [u"42", u"6", u"9"], {42: {"6": 9.0}}),
([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}),
)
def test_array_to_npy(target):
input_data = np.array(target)
Expand All @@ -60,7 +60,7 @@ def test_array_to_npy(target):
("[42, 6, 9]", np.array([42, 6, 9])),
("[42.0, 6.0, 9.0]", np.array([42.0, 6.0, 9.0])),
('["42", "6", "9"]', np.array(["42", "6", "9"])),
(u'["42", "6", "9"]', np.array([u"42", u"6", u"9"])),
('["42", "6", "9"]', np.array(["42", "6", "9"])),
],
)
def test_json_to_numpy(target, expected):
Expand Down
12 changes: 6 additions & 6 deletions test/unit/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def test_mapping_throws_exception_trying_to_access_non_properties(property, erro
[
(
{"da-sh": "1", "un_der": "2", "un-sh": "3", "da_der": "2"},
[u"--da-sh", u"1", u"--da_der", u"2", u"--un-sh", u"3", u"--un_der", u"2"],
["--da-sh", "1", "--da_der", "2", "--un-sh", "3", "--un_der", "2"],
),
({}, []),
({"": ""}, [u"", u""]),
({"": ""}, ["", ""]),
(
{"unicode": u"¡ø", "bytes": b"2", "floats": 4.0, "int": 2},
[u"--bytes", u"2", u"--floats", u"4.0", u"--int", u"2", u"--unicode", u"¡ø"],
{"unicode": "¡ø", "bytes": b"2", "floats": 4.0, "int": 2},
["--bytes", "2", "--floats", "4.0", "--int", "2", "--unicode", "¡ø"],
),
({"U": u"1", "b": b"2", "T": "", "": "42"}, ["", "42", "-T", "", "-U", "1", "-b", "2"]),
({"U": "1", "b": b"2", "T": "", "": "42"}, ["", "42", "-T", "", "-U", "1", "-b", "2"]),
({"nested": ["1", ["2", "3", [["6"]]]]}, ["--nested", "['1', ['2', '3', [['6']]]]"]),
(
{"map": {"a": [1, 3, 4]}, "channel_dirs": {"train": "foo", "eval": "bar"}},
Expand All @@ -133,7 +133,7 @@ def test_to_cmd_args(target, expected):
{"SM_MODEL_DIR": "/opt/ml/model", "SM_OUTPUT_DIR": "/opt/ml/output"},
),
({}, {}),
({"": None}, {u"": u""}),
({"": None}, {"": ""}),
(
{"bytes": b"2", "floats": 4.0, "int": 2, "unicode": "¡ø"},
{"SM_BYTES": "2", "SM_FLOATS": "4.0", "SM_INT": "2", "SM_UNICODE": "¡ø"},
Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def test_smdataparallel_run_single_node_python(
"RDMAV_FORK_SAFE=1",
"-x",
"LD_PRELOAD=%s" % inspect.getfile(gethostname),
"-x",
"FI_EFA_USE_DEVICE_RDMA=1",
"--verbose",
"smddprun",
"usr/bin/python3",
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ commands = flake8
[testenv:black-format]
# Used during development (before committing) to format .py files.
basepython = python3.7
deps = black==19.3b0
deps = black==22.3.0
commands =
black -l 100 ./

[testenv:black-check]
# Used by automated build steps to check that all files are properly formatted.
basepython = python3.7
deps = black==19.3b0
deps = black==22.3.0
commands =
black -l 100 --check ./

Expand Down

0 comments on commit b5e90ca

Please sign in to comment.