Skip to content

Commit

Permalink
Implement details for d-expand
Browse files Browse the repository at this point in the history
Fix a function call
  • Loading branch information
wschin committed Oct 27, 2023
1 parent 6ad5a2e commit d92dca2
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,47 @@ DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedK
// Expand for a possibly-sharded input tensor.
//
// Assumptions.
//  - Shape (input 1) is not sharded.
// Algorithm.
//  - Compute logical (un-sharded) input shape.
//  - Compute logical output shape.
//  - Compute local output shape.
//  - Expand from local input to local output.
//
// Inputs (read from `context`):
//  - input 0: the local input tensor.
//  - input 1: an int64 tensor holding the logical output shape.
// Output 0 is allocated with the local output shape.
//
// Returns the failing Status if the output shape cannot be computed,
// otherwise whatever FuncExpand returns. Preconditions (non-null context and
// inputs, un-sharded shape tensor) are checked with ORT_ENFORCE.
template <typename T>
Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_ENFORCE(context != nullptr);

  const auto* input_tensor = context->Input<Tensor>(0);
  const auto* shape_tensor = context->Input<Tensor>(1);
  // Both pointers are dereferenced below; fail with a clear message instead of
  // crashing if either input is absent.
  ORT_ENFORCE(input_tensor != nullptr, "DistributedExpand requires input 0 (data tensor).");
  ORT_ENFORCE(shape_tensor != nullptr, "DistributedExpand requires input 1 (shape tensor).");

  const auto& input_sharding_spec = input_shard_specs_.at(0);
  const auto& shape_sharding_spec = input_shard_specs_.at(1);
  const auto& output_sharding_spec = output_shard_specs_.at(0);

  ORT_ENFORCE(shape_sharding_spec.HasNoShard(),
              "It's not worth to shard Shape tensor. "
              "If sharding shape is needed, please submit a feature request.");

  // Compute logical input shape from the local shard's shape and its sharding spec.
  const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec);

  // Compute logical output shape.
  // This `shape_tensor` stores the logical output shape, one int64 per dimension.
  const auto* p_shape = shape_tensor->Data<int64_t>();
  TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
  TensorShape original_output_shape(original_output_dims);
  // `original_output_shape` is overwritten with the result computed from the
  // logical input shape and the requested dims.
  // NOTE(review): presumably this applies Expand's broadcasting rules — confirm
  // against onnxruntime::cuda::ComputeOutputShape.
  const Status output_shape_status = onnxruntime::cuda::ComputeOutputShape(
      Node().Name(),
      original_input_shape,
      original_output_dims,
      original_output_shape);
  if (!output_shape_status.IsOK()) {
    // Propagate the original error (with its message) instead of collapsing it
    // into a generic enforce failure.
    return output_shape_status;
  }

  // Compute local output shape from the logical output shape and the output sharding spec.
  const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec);

  auto output_tensor = context->Output(0, local_output_shape);

  return FuncExpand(
      this,
      context,
      input_tensor,
      shape_tensor,
      output_tensor);
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
Expand Down

0 comments on commit d92dca2

Please sign in to comment.