diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc
index 7d106fd75e2d0..fd26a1f430de7 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc
@@ -237,16 +237,55 @@ void ReshardTensor(
 }
 
 DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) {
-  std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
-  std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
-  std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
-  std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");
+  // input_device_mesh_shapes[i] is the shape of the device mesh for the i-th input.
+  // E.g., device_mesh_shapes = ["[2]", "[1]"] means the first input is
+  // stored on a 1-D mesh with 2 devices and the second input on another 1-D
+  // mesh with 1 device.
+  std::vector<std::string> attr_input_device_mesh_shapes;
+  ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK());
+  // input_device_mesh_elements[i] is the flattened device mesh for the i-th input.
+  // Note that its actual shape is input_device_mesh_shapes[i].
+  // Example:
+  //  Assume
+  //   device_mesh_shapes = ["[2]", "[1]"]
+  //   device_mesh_elements = ["[0,1]", "[0]"]
+  //  Then the first input is stored on a 1-D mesh with 2 devices and the second
+  //  input on another 1-D mesh with 1 device.
+  std::vector<std::string> attr_input_device_mesh_elements;
+  ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK());
+
+  // input_shard_specs[i] is the sharding spec of the i-th input; e.g.,
+  // "RR" if the i-th input is not sharded.
+  std::vector<std::string> input_shard_specs;
+  ORT_ENFORCE(info.GetAttrs<std::string>("input_shard_specs", input_shard_specs).IsOK());
+
+  ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size());
+  ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size());
+
+  // Begin parsing sharding metadata for inputs.
   for (size_t i = 0; i < input_shard_specs.size(); ++i) {
+    auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]);
+    auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]);
     auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
     input_shard_specs_.push_back(spec);
   }
+
+  std::vector<std::string> attr_output_device_mesh_shapes;
+  ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK());
+
+  std::vector<std::string> attr_output_device_mesh_elements;
+  ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK());
+
+  std::vector<std::string> output_shard_specs;
+  ORT_ENFORCE(info.GetAttrs<std::string>("output_shard_specs", output_shard_specs).IsOK());
+
+  ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size());
+  ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size());
+
   for (size_t i = 0; i < output_shard_specs.size(); ++i) {
+    auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]);
+    auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]);
     auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
     output_shard_specs_.push_back(spec);
   }
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc
index f1d399077e37b..220938f3ceaef 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc
@@ -27,6 +27,28 @@ void ValidateAxisIndex(const int64_t axis, const int64_t rank) {
   ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ").");
 }
 
+std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str) {
+  if (str.empty() || str.front() != '[' || str.back() != ']') {
+    throw std::invalid_argument("Invalid input string format");
+  }
+  // Parsed vector.
+  // If input is "[0, 1, 2]", result should be {0, 1, 2}.
+  std::vector<int64_t> result;
+  // Skip '[' and ']'.
+  std::istringstream iss(str.substr(1, str.size() - 2));
+
+  // Extract integers separated by ',' or whitespaces.
+  int64_t num = -1;
+  while (/* Read one number at a time */ iss >> num) {
+    result.push_back(num);
+    // Skip the comma.
+    if (iss.peek() == ',') {
+      iss.ignore();
+    }
+  }
+  return result;
+}
+
 DeviceMesh CreateDeviceMesh(
     std::vector<int64_t> device_mesh_shape,
     std::vector<int64_t> device_mesh_elements) {
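For reference, the accepted attribute format is a bracketed list of integers. The following is a minimal Python sketch of the same parsing behavior as ParseStringAsInt64Vector above; the Python function name is illustrative only and is not part of this patch.

    # Parse "[0, 1, 2]" into [0, 1, 2], mirroring ParseStringAsInt64Vector.
    def parse_string_as_int64_vector(text: str) -> list[int]:
        if not (text.startswith("[") and text.endswith("]")):
            raise ValueError("Invalid input string format")
        body = text[1:-1].strip()
        # Integers are separated by commas; surrounding whitespace is ignored.
        return [int(token) for token in body.split(",")] if body else []

    assert parse_string_as_int64_vector("[0, 1, 2]") == [0, 1, 2]
    assert parse_string_as_int64_vector("[2]") == [2]
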
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
index 0f5ef6927a545..451d44b4bd434 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
+++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
@@ -314,6 +314,9 @@ class TensorPartitionSpec {
   }
 };
 
+// Parse "[0, 1, 2, 3]" as std::vector<int64_t>{0, 1, 2, 3}.
+std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str);
+
 DeviceMesh CreateDeviceMesh(
     std::vector<int64_t> device_mesh_shape,
     std::vector<int64_t> device_mesh_elements);
diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
index 7cdd71014c02e..97befe2a58301 100644
--- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -83,17 +83,26 @@ void RegisterCollectiveOps() {
   ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul)
       .SetDomain(kMSDomain)
       .SinceVersion(1)
-      .Attr("device_mesh_elements",
-            "",
-            AttributeProto::INTS)
-      .Attr("device_mesh_shape",
-            "",
-            AttributeProto::INTS)
+      .Attr("input_device_mesh_elements",
+            "device_mesh_elements[i] defines the flattened device mesh for the i-th input. "
+            "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
+            "inputs are both stored on devices 0 and 1.",
+            AttributeProto::STRINGS)
+      .Attr("input_device_mesh_shapes",
+            "device_mesh_shapes[i] defines the device mesh's shape for the i-th input.",
+            AttributeProto::STRINGS)
       .Attr("input_shard_specs",
-            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
+            "The sharding spec of the inputs. "
+            "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
+            AttributeProto::STRINGS)
+      .Attr("output_device_mesh_elements",
+            "Similar to input_device_mesh_elements but for outputs.",
+            AttributeProto::STRINGS)
+      .Attr("output_device_mesh_shapes",
+            "Similar to input_device_mesh_shapes but for outputs.",
             AttributeProto::STRINGS)
       .Attr("output_shard_specs",
-            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
+            "Similar to input_shard_specs but for outputs.",
             AttributeProto::STRINGS)
       .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
       .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
@@ -109,17 +118,26 @@ void RegisterCollectiveOps() {
   ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice)
      .SetDomain(kMSDomain)
      .SinceVersion(1)
-      .Attr("device_mesh_elements",
-            "",
-            AttributeProto::INTS)
-      .Attr("device_mesh_shape",
-            "",
-            AttributeProto::INTS)
+      .Attr("input_device_mesh_elements",
+            "device_mesh_elements[i] defines the flattened device mesh for the i-th input. "
+            "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
+            "inputs are both stored on devices 0 and 1.",
+            AttributeProto::STRINGS)
+      .Attr("input_device_mesh_shapes",
+            "device_mesh_shapes[i] defines the device mesh's shape for the i-th input.",
+            AttributeProto::STRINGS)
       .Attr("input_shard_specs",
-            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
+            "The sharding spec of the inputs. "
+            "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
+            AttributeProto::STRINGS)
+      .Attr("output_device_mesh_elements",
+            "Similar to input_device_mesh_elements but for outputs.",
+            AttributeProto::STRINGS)
+      .Attr("output_device_mesh_shapes",
+            "Similar to input_device_mesh_shapes but for outputs.",
             AttributeProto::STRINGS)
       .Attr("output_shard_specs",
-            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
+            "Similar to input_shard_specs but for outputs.",
             AttributeProto::STRINGS)
       .Input(
           0,
diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py
index 1baec80cb7c45..a9b55122c6806 100644
--- a/onnxruntime/test/python/onnxruntime_test_distributed.py
+++ b/onnxruntime/test/python/onnxruntime_test_distributed.py
@@ -20,15 +20,22 @@ def shard_tensor(X, rank, axis, num_shards):
 
 class TestDistributed(unittest.TestCase):
     def test_matmul_rs_sr_rr(self):
+        # It means 1-D tensor with single element: [2].
+        device_mesh_shape = "[2]"
+        # It means 1-D tensor with two elements: [0, 1].
+        device_mesh_elements = "[0,1]"
+
         @onnxscript.script()
         def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RS[0]", "S[0]R"],
                 output_shard_specs=["RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -55,15 +62,20 @@ def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul2d_rs_rs_rr(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_rs_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RS[0]", "RS[0]"],
                 output_shard_specs=["RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -93,15 +105,20 @@ def matmul_rs_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul2d_rs_rs_rs(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul2d_rs_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RS[0]", "RS[0]"],
                 output_shard_specs=["RS[0]"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -128,15 +145,20 @@ def matmul2d_rs_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul_srr_rr_srr(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_srr_rr_srr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["S[0]RR", "RR"],
                 output_shard_specs=["S[0]RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
            )
 
         rank = comm.Get_rank()
@@ -165,15 +187,20 @@ def matmul_srr_rr_srr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul_srr_rrrr_rsrr(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_srr_rrrr_rsrr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["S[0]RR", "RRRR"],
                 output_shard_specs=["RS[0]RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -202,15 +229,20 @@ def matmul_srr_rrrr_rsrr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul_sr_rs_rr(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_sr_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["S[0]R", "RS[0]"],
                 output_shard_specs=["RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -239,15 +271,20 @@ def matmul_sr_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul_rr_rs_rs(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_rr_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RR", "RS[0]"],
                 output_shard_specs=["RS[0]"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -276,15 +313,20 @@ def matmul_rr_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_matmul_rr_sr_rr(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
             return MICROSOFT_OPSET.DistributedMatMul(
                 tensor_x,
                 tensor_w,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RR", "S[0]R"],
                 output_shard_specs=["RR"],
+                input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape],
+                input_device_mesh_elements=[device_mesh_elements, device_mesh_elements],
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -313,6 +355,9 @@ def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_slice_sr_axis1(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
         @onnxscript.script()
         def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT:
             return MICROSOFT_OPSET.DistributedSlice(
@@ -320,10 +365,12 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te
                 tensor_starts,
                 tensor_ends,
                 tensor_axes,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["S[0]R", "R", "R", "R", "R"],
                 output_shard_specs=["S[0]R"],
+                input_device_mesh_shapes=[device_mesh_shape] * 5,
+                input_device_mesh_elements=[device_mesh_elements] * 5,
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
@@ -360,6 +407,9 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te
         np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8)
 
     def test_slice_rs_axis1(self):
+        device_mesh_shape = "[2]"
+        device_mesh_elements = "[0, 1]"
+
        @onnxscript.script()
         def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT:
             return MICROSOFT_OPSET.DistributedSlice(
@@ -367,10 +417,12 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te
                 tensor_starts,
                 tensor_ends,
                 tensor_axes,
-                device_mesh_shape=[2],
-                device_mesh_elements=[0, 1],
                 input_shard_specs=["RS[0]", "R", "R", "R", "R"],
                 output_shard_specs=["RS[0]"],
+                input_device_mesh_shapes=[device_mesh_shape] * 5,
+                input_device_mesh_elements=[device_mesh_elements] * 5,
+                output_device_mesh_shapes=[device_mesh_shape],
+                output_device_mesh_elements=[device_mesh_elements],
             )
 
         rank = comm.Get_rank()
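The sharding-spec strings used throughout these tests mark each tensor axis as either replicated ("R") or sharded along a mesh axis ("S[0]"). The test file's shard_tensor helper presumably slices the local shard out of the full tensor along the sharded axis; a minimal sketch of that behavior, with illustrative data, is shown below.

    import numpy as np

    # Sketch of obtaining the local shard for an "S[0]"-style axis on a 2-device mesh;
    # the actual helper in the test file may differ in detail.
    def shard_tensor(X, rank, axis, num_shards):
        return np.split(X, num_shards, axis=axis)[rank]

    X = np.arange(8).reshape(2, 4)
    # "RS[0]": axis 1 is sharded across 2 devices, so each rank holds a (2, 2) piece.
    assert shard_tensor(X, rank=0, axis=1, num_shards=2).shape == (2, 2)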