Skip to content

Commit

Permalink
Support per-tensor device mesh at op level.
Browse files Browse the repository at this point in the history
Since Reshape may change device mesh from, e.g., [0, 1]
to [0, 1, 0, 1], we can't assume since device mesh per op.

Lint
  • Loading branch information
wschin committed Oct 19, 2023
1 parent a2c6283 commit db85449
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 40 deletions.
47 changes: 43 additions & 4 deletions onnxruntime/contrib_ops/cuda/collective/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,55 @@ void ReshardTensor(
}

DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) {
std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");
// input_device_mesh_shapes[i] is the shape of device mesh for the i-th input.
// E.g., device_mesh_shapes = ["[2]", "[1]"] means the first input is
// stored on a 1-D mesh with 2 devices and the second input on another 1-D
// mesh with 1 device.
std::vector<std::string> attr_input_device_mesh_shapes;
ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK());

// input_device_mesh_elements[i] is the flattened device mesh for the i-th input.
// Note that its actual shape is input_device_mesh_shapes[i].
// Example:
// Assume
// device_mesh_shapes = ["[2]", "[1]"]
// device_mesh_elements = ["[0,1]", "[0]"]
// Then the first input is stored on a 1-D mesh with 2 devices and the second
// input on another 1-D mesh with 1 device.
std::vector<std::string> attr_input_device_mesh_elements;
ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK());

// input_shard_specs[i] is the sharding spec of the i-th input; e.g.,
// "RR" if the i-th input is not sharded.
std::vector<std::string> input_shard_specs;
ORT_ENFORCE(info.GetAttrs<std::string>("input_shard_specs", input_shard_specs).IsOK());

ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size());
ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size());

// Begin parsing sharding metadata for inputs.
for (size_t i = 0; i < input_shard_specs.size(); ++i) {
auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]);
auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]);
auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
input_shard_specs_.push_back(spec);
}

std::vector<std::string> attr_output_device_mesh_shapes;
ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK());

std::vector<std::string> attr_output_device_mesh_elements;
ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK());

std::vector<std::string> output_shard_specs;

Check warning on line 280 in onnxruntime/contrib_ops/cuda/collective/sharding.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/sharding.cc#L280

Add #include <vector> for vector<> [build/include_what_you_use] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/sharding.cc:280:  Add #include <vector> for vector<>  [build/include_what_you_use] [4]
ORT_ENFORCE(info.GetAttrs<std::string>("output_shard_specs", output_shard_specs).IsOK());

Check warning on line 281 in onnxruntime/contrib_ops/cuda/collective/sharding.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/sharding.cc#L281

Add #include <string> for string [build/include_what_you_use] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/sharding.cc:281:  Add #include <string> for string  [build/include_what_you_use] [4]

ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size());
ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size());

for (size_t i = 0; i < output_shard_specs.size(); ++i) {
auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]);
auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]);
auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
output_shard_specs_.push_back(spec);
}
Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@ void ValidateAxisIndex(const int64_t axis, const int64_t rank) {
ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ").");
}

std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str) {
if (str.empty() || str.front() != '[' || str.back() != ']') {
throw std::invalid_argument("Invalid input string format");
}
// Parsed vector.
// If input is "[0, 1, 2]", result should be {0, 1, 2}.
std::vector<int64_t> result;
// Skip '[' and ']'
std::istringstream iss(str.substr(1, str.size() - 2));

// Extract integers separated by ',' or whitespaces.
int64_t num = -1;
while (/* Read one number at a time */ iss >> num) {
result.push_back(num);
// Skip the comma
if (iss.peek() == ',') {
iss.ignore();
}
}
return result;
}

DeviceMesh CreateDeviceMesh(
std::vector<int64_t> device_mesh_shape,
std::vector<int64_t> device_mesh_elements) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ class TensorPartitionSpec {
}
};

// Parse "[0, 1, 2, 3]" as std::vector<int64_t>{0, 1, 2, 3}.
std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str);

DeviceMesh CreateDeviceMesh(
std::vector<int64_t> device_mesh_shape,
std::vector<int64_t> device_mesh_elements);
Expand Down
50 changes: 34 additions & 16 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,26 @@ void RegisterCollectiveOps() {
ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("device_mesh_elements",
"",
AttributeProto::INTS)
.Attr("device_mesh_shape",
"",
AttributeProto::INTS)
.Attr("input_device_mesh_elements",
"device_mesh_elements[i] defines the device mesh's value for the i-th input. "
"E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
" inputs are stored on the 0-th and the 1st devices, respectively.",
AttributeProto::STRINGS)
.Attr("input_device_mesh_shapes",
"device_mesh_shape[i] defines the device mesh's shape for the i-th input.",
AttributeProto::STRINGS)
.Attr("input_shard_specs",
"The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
"The sharding spec of inputs. "
"E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_elements",
"Similar to input_device_mesh_elments but for outputs.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_shapes",
"Similar to input_device_mesh_shapes but for outputs.",
AttributeProto::STRINGS)
.Attr("output_shard_specs",
"The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
"Similar to input_shard_specs but for outputs.",
AttributeProto::STRINGS)
.Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
Expand All @@ -109,17 +118,26 @@ void RegisterCollectiveOps() {
ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("device_mesh_elements",
"",
AttributeProto::INTS)
.Attr("device_mesh_shape",
"",
AttributeProto::INTS)
.Attr("input_device_mesh_elements",
"device_mesh_elements[i] defines the device mesh's value for the i-th input. "
"E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
" inputs are stored on the 0-th and the 1st devices, respectively.",
AttributeProto::STRINGS)
.Attr("input_device_mesh_shapes",
"device_mesh_shape[i] defines the device mesh's shape for the i-th input.",
AttributeProto::STRINGS)
.Attr("input_shard_specs",
"The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
"The sharding spec of inputs. "
"E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_elements",
"Similar to input_device_mesh_elments but for outputs.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_shapes",
"Similar to input_device_mesh_shapes but for outputs.",
AttributeProto::STRINGS)
.Attr("output_shard_specs",
"The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
"Similar to input_shard_specs but for outputs.",
AttributeProto::STRINGS)
.Input(
0,
Expand Down
Loading

0 comments on commit db85449

Please sign in to comment.