Support per-tensor device mesh at op level.
Since Reshape may change the device mesh from, e.g., [0, 1]
to [0, 1, 0, 1], we can't assume a single device mesh per op.

Lint

resolve style warnings
wschin committed Oct 25, 2023
1 parent ae85619 commit d358409
Showing 5 changed files with 176 additions and 40 deletions.
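
To make the motivating Reshape case concrete, here is a minimal C++ sketch (not part of the commit; the DeviceMesh struct below is a simplified stand-in for ONNX Runtime's DeviceMesh class):

// Hedged sketch: a Reshape-like op whose input and output live on
// different device meshes, so the mesh must be tracked per tensor
// rather than per op.
#include <cstdint>
#include <vector>

struct DeviceMesh {
  std::vector<int64_t> shape;     // e.g., {2} for a 1-D mesh of 2 devices
  std::vector<int64_t> elements;  // flattened device ids
};

int main() {
  DeviceMesh input_mesh{{2}, {0, 1}};         // input spans devices 0 and 1
  DeviceMesh output_mesh{{4}, {0, 1, 0, 1}};  // output repeats the same devices
  // A single per-op mesh cannot describe both tensors at once, hence the
  // per-tensor attributes introduced in this commit.
  return 0;
}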
49 changes: 45 additions & 4 deletions onnxruntime/contrib_ops/cuda/collective/sharding.cc
@@ -5,6 +5,8 @@
#include "mpi_include.h"
#include "sharding_spec.h"

#include <vector>

cpplint warning at onnxruntime/contrib_ops/cuda/collective/sharding.cc:8: Found C++ system header after other header. Should be: sharding.h, c system, c++ system, other. [build/include_order] [4]
#include <string>

cpplint warning at onnxruntime/contrib_ops/cuda/collective/sharding.cc:9: Found C++ system header after other header. Should be: sharding.h, c system, c++ system, other. [build/include_order] [4]
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/cuda/tensor/slice.h"
#include "core/providers/cuda/math/matmul.h"
@@ -237,16 +239,55 @@ void ReshardTensor(
}

DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) {
  std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
  std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
  std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
  std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");
  // input_device_mesh_shapes[i] is the shape of the device mesh for the i-th input.
  // E.g., input_device_mesh_shapes = ["[2]", "[1]"] means the first input is
  // stored on a 1-D mesh with 2 devices and the second input on another 1-D
  // mesh with 1 device.
  std::vector<std::string> attr_input_device_mesh_shapes;
  ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK());

  // input_device_mesh_elements[i] is the flattened device mesh for the i-th input.
  // Note that its actual shape is input_device_mesh_shapes[i].
  // Example:
  //   Assume
  //     input_device_mesh_shapes = ["[2]", "[1]"]
  //     input_device_mesh_elements = ["[0,1]", "[0]"]
  //   Then the first input is stored on a 1-D mesh with 2 devices and the second
  //   input on another 1-D mesh with 1 device.
  std::vector<std::string> attr_input_device_mesh_elements;
  ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK());

  // input_shard_specs[i] is the sharding spec of the i-th input; e.g.,
  // "RR" if the i-th input is not sharded.
  std::vector<std::string> input_shard_specs;
  ORT_ENFORCE(info.GetAttrs<std::string>("input_shard_specs", input_shard_specs).IsOK());

  ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size());
  ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size());

  // Begin parsing sharding metadata for inputs.
  for (size_t i = 0; i < input_shard_specs.size(); ++i) {
    auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]);
    auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]);
    auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
    input_shard_specs_.push_back(spec);
  }

  std::vector<std::string> attr_output_device_mesh_shapes;
  ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK());

  std::vector<std::string> attr_output_device_mesh_elements;
  ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK());

  std::vector<std::string> output_shard_specs;
  ORT_ENFORCE(info.GetAttrs<std::string>("output_shard_specs", output_shard_specs).IsOK());

  ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size());
  ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size());

  for (size_t i = 0; i < output_shard_specs.size(); ++i) {
    auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]);
    auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]);
    auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
    output_shard_specs_.push_back(spec);
  }
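For illustration, the following standalone sketch mirrors how the constructor above consumes the per-tensor attributes. The attribute values and shard-spec strings are hypothetical, and the ParseStringAsInt64Vector/CreateTensorPartitionSpec calls are replaced by a print so the sketch compiles on its own:

// Hedged sketch of the per-input parsing loop in DistributedKernel's
// constructor; the attribute values below are made up.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

int main() {
  // One entry per input, exactly as the STRINGS attributes store them.
  std::vector<std::string> shapes = {"[2]", "[1]"};
  std::vector<std::string> elements = {"[0, 1]", "[0]"};
  std::vector<std::string> specs = {"S[0]R", "RR"};

  // Mirrors the ORT_ENFORCE alignment checks in the constructor.
  if (shapes.size() != elements.size() || shapes.size() != specs.size()) {
    throw std::runtime_error("per-tensor attribute lists must have equal length");
  }

  for (std::size_t i = 0; i < specs.size(); ++i) {
    // The real kernel parses shapes[i] and elements[i] with
    // ParseStringAsInt64Vector and builds a TensorPartitionSpec here.
    std::cout << "input " << i << ": mesh shape " << shapes[i]
              << ", mesh elements " << elements[i]
              << ", shard spec " << specs[i] << "\n";
  }
  return 0;
}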
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc
@@ -27,6 +27,28 @@ void ValidateAxisIndex(const int64_t axis, const int64_t rank) {
  ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ").");
}

std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str) {
  if (str.empty() || str.front() != '[' || str.back() != ']') {
    throw std::invalid_argument("Invalid input string format");
  }
  // Parsed vector.
  // If input is "[0, 1, 2]", result should be {0, 1, 2}.
  std::vector<int64_t> result;
  // Skip '[' and ']'.
  std::istringstream iss(str.substr(1, str.size() - 2));

  // Extract integers separated by ',' or whitespace.
  int64_t num = -1;
  while (/* Read one number at a time */ iss >> num) {
    result.push_back(num);
    // Skip the comma.
    if (iss.peek() == ',') {
      iss.ignore();
    }
  }
  return result;
}

DeviceMesh CreateDeviceMesh(
    std::vector<int64_t> device_mesh_shape,
    std::vector<int64_t> device_mesh_elements) {
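As a quick sanity check of the parser's behavior, the sketch below copies ParseStringAsInt64Vector from the diff above and exercises it on two inputs; the test harness itself is not part of the commit:

// Standalone test of ParseStringAsInt64Vector (body copied from the diff).
#include <cassert>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str) {
  if (str.empty() || str.front() != '[' || str.back() != ']') {
    throw std::invalid_argument("Invalid input string format");
  }
  std::vector<int64_t> result;
  std::istringstream iss(str.substr(1, str.size() - 2));  // skip '[' and ']'
  int64_t num = -1;
  while (iss >> num) {
    result.push_back(num);
    if (iss.peek() == ',') iss.ignore();  // skip the comma
  }
  return result;
}

int main() {
  // The repeated-mesh case from the commit message.
  assert((ParseStringAsInt64Vector("[0, 1, 0, 1]") == std::vector<int64_t>{0, 1, 0, 1}));
  // An empty mesh string parses to an empty vector.
  assert(ParseStringAsInt64Vector("[]").empty());
  return 0;
}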
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
@@ -314,6 +314,9 @@ class TensorPartitionSpec {
  }
};

// Parse "[0, 1, 2, 3]" as std::vector<int64_t>{0, 1, 2, 3}.
std::vector<int64_t> ParseStringAsInt64Vector(const std::string& str);

DeviceMesh CreateDeviceMesh(
    std::vector<int64_t> device_mesh_shape,
    std::vector<int64_t> device_mesh_elements);
50 changes: 34 additions & 16 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -83,17 +83,26 @@ void RegisterCollectiveOps() {
  ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul)
      .SetDomain(kMSDomain)
      .SinceVersion(1)
      .Attr("device_mesh_elements",
            "",
            AttributeProto::INTS)
      .Attr("device_mesh_shape",
            "",
            AttributeProto::INTS)
      .Attr("input_device_mesh_elements",
            "input_device_mesh_elements[i] defines the flattened device mesh for the i-th input. "
            "E.g., input_device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means both the 1st and the 2nd "
            "inputs are stored on devices 0 and 1.",
            AttributeProto::STRINGS)
      .Attr("input_device_mesh_shapes",
            "input_device_mesh_shapes[i] defines the device mesh's shape for the i-th input.",
            AttributeProto::STRINGS)
      .Attr("input_shard_specs",
            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
            "The sharding spec of the inputs. "
            "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_elements",
            "Similar to input_device_mesh_elements but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_shapes",
            "Similar to input_device_mesh_shapes but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_shard_specs",
            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
            "Similar to input_shard_specs but for outputs.",
            AttributeProto::STRINGS)
      .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
      .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
@@ -109,17 +118,26 @@
  ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice)
      .SetDomain(kMSDomain)
      .SinceVersion(1)
      .Attr("device_mesh_elements",
            "",
            AttributeProto::INTS)
      .Attr("device_mesh_shape",
            "",
            AttributeProto::INTS)
      .Attr("input_device_mesh_elements",
            "input_device_mesh_elements[i] defines the flattened device mesh for the i-th input. "
            "E.g., input_device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means both the 1st and the 2nd "
            "inputs are stored on devices 0 and 1.",
            AttributeProto::STRINGS)
      .Attr("input_device_mesh_shapes",
            "input_device_mesh_shapes[i] defines the device mesh's shape for the i-th input.",
            AttributeProto::STRINGS)
      .Attr("input_shard_specs",
            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
            "The sharding spec of the inputs. "
            "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_elements",
            "Similar to input_device_mesh_elements but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_shapes",
            "Similar to input_device_mesh_shapes but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_shard_specs",
            "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.",
            "Similar to input_shard_specs but for outputs.",
            AttributeProto::STRINGS)
      .Input(
          0,
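To see the new schema from the user's side, here is a hedged example of the attribute values a DistributedMatMul node might carry after this change; the shard-spec strings (e.g., "S[0]R") are illustrative rather than taken from the commit:

// Hypothetical per-tensor attributes for a DistributedMatMul with two
// inputs (A row-sharded, B replicated) and one sharded output, all on
// the same 1-D mesh of devices 0 and 1.
#include <string>
#include <vector>

int main() {
  std::vector<std::string> input_device_mesh_shapes = {"[2]", "[2]"};
  std::vector<std::string> input_device_mesh_elements = {"[0, 1]", "[0, 1]"};
  std::vector<std::string> input_shard_specs = {"S[0]R", "RR"};

  std::vector<std::string> output_device_mesh_shapes = {"[2]"};
  std::vector<std::string> output_device_mesh_elements = {"[0, 1]"};
  std::vector<std::string> output_shard_specs = {"S[0]R"};
  return 0;
}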
