Cherry-pick 1st Round (#17308)
centwang authored Aug 28, 2023
1 parent cbaa008 commit 198fc90
Showing 49 changed files with 1,121 additions and 220 deletions.
2 changes: 1 addition & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem
         /// <summary>
         /// Creates an OrtValue that contains a string tensor of specified shape, and
         /// containing empty strings. String tensors are always on CPU.
-        /// Use FillStringTensorElement to assign individual elements values.
+        /// Use StringTensorSetElementAt to assign individual elements values.
         /// </summary>
         /// <param name="allocator"></param>
         /// <returns>disposable OrtValue</returns>
@@ -15,6 +15,7 @@ public struct OrtTrainingApi
     public IntPtr LoadCheckpoint;
     public IntPtr SaveCheckpoint;
     public IntPtr CreateTrainingSession;
+    public IntPtr CreateTrainingSessionFromBuffer;
     public IntPtr TrainingSessionGetTrainingModelOutputCount;
     public IntPtr TrainingSessionGetEvalModelOutputCount;
     public IntPtr TrainingSessionGetTrainingModelOutputName;
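Note that the new `CreateTrainingSessionFromBuffer` pointer is appended after `CreateTrainingSession` rather than inserted alphabetically: these managed structs mirror the native API function table, which callers consume by member offset, so existing members must keep their positions. A minimal C++ sketch of that constraint (illustrative stand-in structs, not the real ORT definitions):

```cpp
// Illustrative only: offset-based layout is why new API entries are append-only.
#include <cstddef>
#include <cstdio>

struct ApiV1 {
  void* LoadCheckpoint;
  void* SaveCheckpoint;
  void* CreateTrainingSession;
};

struct ApiV2 {  // new member appended: offsets of the V1 members are unchanged
  void* LoadCheckpoint;
  void* SaveCheckpoint;
  void* CreateTrainingSession;
  void* CreateTrainingSessionFromBuffer;
};

int main() {
  // An old caller reading CreateTrainingSession at its V1 offset still finds it in V2.
  std::printf("V1 offset: %zu, V2 offset: %zu\n",
              offsetof(ApiV1, CreateTrainingSession),
              offsetof(ApiV2, CreateTrainingSession));
}
```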
4 changes: 0 additions & 4 deletions include/onnxruntime/core/framework/ort_value.h
@@ -68,11 +68,7 @@ struct OrtValue {
   }
 
   bool IsSparseTensor() const {
-#if !defined(DISABLE_SPARSE_TENSORS)
     return (type_ != nullptr && type_->IsSparseTensorType());
-#else
-    ORT_THROW("Sparse tensor is not supported in this build.");
-#endif
   }
 
   onnxruntime::MLDataType Type() const {
2 changes: 1 addition & 1 deletion js/web/docs/webgpu-operators.md
@@ -38,7 +38,7 @@ Do not modify directly.*
 | Floor | ai.onnx(6-12,13+) | |
 | Gather | ai.onnx(1-10,11-12,13+) | |
 | Gelu | com.microsoft(1+) | |
-| Gemm | ai.onnx(7-8,9-10,11+) | |
+| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
 | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
 | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
 | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
6 changes: 4 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
 const createSplitAttributesFromInputs =
     (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
       const splitSizes: number[] = [];
+      let numOutputs: number = attributes.numOutputs;
       if (inputs[1].dims[0] > 0) {
         inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
+        numOutputs = splitSizes.length;
       }
-      return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
+      return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
     };
 
 const calculateOutputIndexImpl = (numberOfTensors: number): string => `
@@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
       const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
       const metadata:
           ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
-      return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
+      return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
     };
 
 export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
44 changes: 26 additions & 18 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -1715,31 +1715,39 @@ class PlannerImpl {
   void PartitionIntoStreams(const logging::Logger& /*logger*/,
                             const ExecutionProviders& /*execution_providers*/,
                             const PathString& /*partition_config_file*/) {
-    stream_nodes_.push_back({});
-    node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
-    for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
-      stream_nodes_[0].push_back(node_index);
-      node_stream_map_[node_index] = 0;
+    if (graph_viewer_.NumberOfNodes() > 0) {
+      stream_nodes_.push_back({});
+      node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
+      for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
+        stream_nodes_[0].push_back(node_index);
+        node_stream_map_[node_index] = 0;
+      }
+      num_logic_streams_ = 1;
     }
-    num_logic_streams_ = 1;
   }
 
   Status BuildExecutionPlan(const ExecutionProviders& execution_providers) {
     // 1. create logic stream instance
     auto& execution_plan = plan_.execution_plan;
-    ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
-    execution_plan.reserve(1);
-    auto first_node_index = stream_nodes_[0][0];
-    auto* node = graph_viewer_.GetNode(first_node_index);
-    onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
-    const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
-    ORT_ENFORCE(ep);
-    auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
-    execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
-    // 2. add steps to the execution plan
-    for (auto node_index : stream_nodes_[0]) {
-      execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+
+    if (graph_viewer_.NumberOfNodes() > 0) {
+      ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
+      execution_plan.reserve(1);
+      auto first_node_index = stream_nodes_[0][0];
+      auto* node = graph_viewer_.GetNode(first_node_index);
+      onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
+      const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
+      ORT_ENFORCE(ep);
+      auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
+      execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
+      // 2. add steps to the execution plan
+      for (auto node_index : stream_nodes_[0]) {
+        execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+      }
+    } else {
+      // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer
     }
 
     return Status::OK();
   }

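Both hunks wrap the single-stream logic in a `NumberOfNodes() > 0` check so that an empty subgraph produces an empty (but still valid) plan instead of tripping `ORT_ENFORCE` on `stream_nodes_[0]`. A stand-alone C++ sketch of the guard, using simplified stand-ins for the planner types rather than ORT's own:

```cpp
#include <cstdio>
#include <vector>

using NodeIndex = size_t;

// Simplified stand-in for PartitionIntoStreams: only create logic stream 0
// when there is at least one node, so stream_nodes[0] is never read when empty.
std::vector<std::vector<NodeIndex>> Partition(const std::vector<NodeIndex>& topo_order) {
  std::vector<std::vector<NodeIndex>> stream_nodes;
  if (!topo_order.empty()) {
    stream_nodes.push_back(topo_order);  // every node lands in stream 0
  }
  return stream_nodes;  // empty plan for an empty subgraph, no enforce failure
}

int main() {
  auto empty = Partition({});            // e.g. an If branch that just forwards an input
  auto single = Partition({NodeIndex{3}});
  std::printf("streams: %zu and %zu\n", empty.size(), single.size());
}
```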
5 changes: 3 additions & 2 deletions onnxruntime/core/framework/session_state_utils.cc
@@ -254,10 +254,11 @@ common::Status SaveInitializedTensors(
   auto initialized_tensors_to_allocate = id_to_initialized_tensor;
   for (int ort_value_index : initializer_allocation_order) {
     const auto entry = initialized_tensors_to_allocate.find(ort_value_index);
+    ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(),
+                "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors");
     if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) {
       // can not trace string tensor
-      ORT_ENFORCE(entry != initialized_tensors_to_allocate.end() &&
-                  entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING);
+      ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor");
       ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second));
     }
     initialized_tensors_to_allocate.erase(entry);
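Beyond clearer messages, the reordering fixes a lookup-before-validate hazard: the old code dereferenced `entry->second` in the `if` condition before the combined `ORT_ENFORCE` ever checked that the entry existed. A minimal sketch of the corrected order, using the plain standard library rather than ORT code:

```cpp
#include <cstdio>
#include <map>
#include <string>

int main() {
  std::map<int, std::string> initializers{{0, "weight"}};
  const int ort_value_index = 7;  // an index the map does not contain

  const auto entry = initializers.find(ort_value_index);
  // Validate the iterator before anything dereferences entry->second; the
  // pre-fix code only enforced this inside a condition that had already
  // dereferenced it.
  if (entry == initializers.end()) {
    std::printf("OrtValue index: %d not found among initialized tensors\n", ort_value_index);
    return 1;
  }
  std::printf("found: %s\n", entry->second.c_str());  // safe only after the check
}
```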
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/utils.h
@@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float8E5M2FNUZ>
 
 int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
 
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING
 common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
 #endif
 
9 changes: 7 additions & 2 deletions onnxruntime/core/graph/function_utils.cc
@@ -344,10 +344,15 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
   std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> map_copy(model_local_functions.begin(),
                                                                                  model_local_functions.end());
   std::unordered_map<std::string, TensorShapeProto> empty_map;
-  ONNX_NAMESPACE::shape_inference::SymbolTableImpl symbolTable;
+
+  // https://github.com/microsoft/onnxruntime/issues/17061
+  // We are passing a nullptr for the symbol table, because symbol table must be global
+  // for all the shape inferencing to work correctly. Otherwise, unrelated shapes get
+  // the same symbolic shapes and are marked for memory re-use. This is a Temp fix.
+  constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr;
   ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version,
                                                              schema_registry, ctx, options, map_copy,
-                                                             &symbolTable, &empty_map);
+                                                             symbolTable, &empty_map);
 });
 
 op_schema->Finalize();
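The failure mode described in the new comment is easy to reproduce in miniature: independent symbol tables restart their counters, so unrelated unknown dimensions receive identical generated names, and a later pass may treat those shapes as equal and mark their buffers for re-use. A toy C++ sketch, with a counter-based table standing in for `SymbolTableImpl` (the `unk__` naming scheme is an assumption for illustration):

```cpp
#include <iostream>
#include <string>

// Toy stand-in for a shape-inference symbol table that names unknown dims.
struct SymbolTable {
  int next = 0;
  std::string NewSymbol() { return "unk__" + std::to_string(next++); }
};

int main() {
  SymbolTable per_function_a, per_function_b;  // one table per function body
  // Both emit "unk__0": two unrelated dims now carry the same symbol, so a
  // downstream pass can wrongly conclude the shapes match and reuse memory.
  std::cout << per_function_a.NewSymbol() << " == " << per_function_b.NewSymbol() << '\n';

  SymbolTable global;  // a single global table keeps every symbol distinct
  std::cout << global.NewSymbol() << " != " << global.NewSymbol() << '\n';
}
```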
@@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */,
   return true;
 }
 
+bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node,
+                                                const SliceInfo& info_without_node,
+                                                const logging::Logger& /*logger*/,
+                                                const std::unordered_map<int, int>& /*propagate_input_indices*/,
+                                                const std::unordered_map<int, std::vector<DimCompare>>&
+                                                /*all_input_cmp_rets*/,
+                                                const std::unordered_map<int, SliceInfo>& /*new_gather_infos*/) {
+  // Update LayerNormalization's axis attribute if it is scalar slice.
+  if (info_without_node.is_scalar_slice) {
+    auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+    auto original_ln_input_rank = info_without_node.input_rank;
+    axis = axis < 0 ? axis + original_ln_input_rank : axis;
+    auto new_axis = axis - 1;
+
+    auto& attributes = current_node.GetMutableAttributes();
+    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast<int64_t>(new_axis));
+  }
+
+  return true;
+}
+
 bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
                                   const logging::Logger& logger,
                                   std::unordered_map<int, int>& propagate_input_indices,
@@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node,
                            propagate_input_indices, all_input_cmp_rets, shape_update_func);
 }
 
+bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node,
+                                     const logging::Logger& logger,
+                                     const std::unordered_map<int, int>& propagate_input_indices,
+                                     const std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
+                                     const std::unordered_map<int, SliceInfo>& new_gather_infos) {
+  SimplePointwiseGatherActor<true>::PostProcess(graph, current_node, info_without_node, logger,
+                                                propagate_input_indices, all_input_cmp_rets, new_gather_infos);
+
+  // Update Softmax's axis attribute if it is scalar slice.
+  if (info_without_node.is_scalar_slice) {
+    auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+    auto original_ln_input_rank = info_without_node.input_rank;
+    axis = axis < 0 ? axis + original_ln_input_rank : axis;
+    auto new_axis = axis - 1;
+
+    auto& attributes = current_node.GetMutableAttributes();
+    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast<int64_t>(new_axis));
+  }
+
+  return true;
+}
+
 bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
                                   const logging::Logger& logger,
                                   std::unordered_map<int, int>& propagate_input_indices,
@@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node,
     return true;
   }
 
+  LOG_DEBUG_INFO(logger, "Skip handle the Reshape, new_shape_const_values[info.non_negative_axis]:" +
+                             std::to_string(new_shape_const_values[info.non_negative_axis]) +
+                             ", info.output_dim_on_axis.has_dim_value(): " +
+                             std::to_string(info.output_dim_on_axis.has_dim_value()) + ".");
+
   return false;
 }

@@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess(
     return true;
   }
 
-  // If it selected shape is a dim value, we can update the shape tensor directory.
+  // If the selected shape is a dim value, we can update the shape tensor directory.
   if (info_without_node.output_dim_on_axis.has_dim_value()) {
     new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value();
     auto new_shape_arg =
-        CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())}, new_shape_const_values,
+        CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())},
+                                    new_shape_const_values,
                                     graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name()));
     graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg);
     return true;
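Both new PostProcess bodies above perform the same arithmetic: once the scalar Gather has removed one leading dimension, an axis recorded against the original rank must be normalized and then shifted down by one. A small worked example of that adjustment:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the axis rewrite in the PostProcess bodies above: normalize a
// possibly negative axis against the original rank, then subtract the one
// dimension the scalar slice removed.
int64_t AdjustAxisAfterScalarSlice(int64_t axis, int64_t original_input_rank) {
  if (axis < 0) axis += original_input_rank;
  return axis - 1;
}

int main() {
  // LayerNorm/Softmax over the last axis of a [batch, seq, hidden] input:
  assert(AdjustAxisAfterScalarSlice(-1, 3) == 1);  // -1 -> 2 -> 1
  assert(AdjustAxisAfterScalarSlice(2, 3) == 1);   // same axis, given positively
}
```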
@@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase {
                    const logging::Logger& /* logger */,
                    const std::unordered_map<int, int>& /* propagate_input_indices */,
                    const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
-                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override { return true; }
+                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
 };
 
 class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
@@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
                 std::unordered_map<int, int>& propagate_input_indices,
                 std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
                 std::function<void(Node& node)>& shape_update_func) override;
+
+  bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */,
+                   const logging::Logger& /* logger */,
+                   const std::unordered_map<int, int>& /* propagate_input_indices */,
+                   const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
+                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
 };
 
 class ReshapeGatherActor : public UpStreamGatherOperatorActorBase {
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/js/js_execution_provider.cc
@@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);

@@ -464,7 +465,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,

10 changes: 9 additions & 1 deletion onnxruntime/core/providers/js/operators/gemm.cc
@@ -12,7 +12,15 @@ namespace js {
   ONNX_OPERATOR_TYPED_KERNEL_EX( \
       Gemm, \
       kOnnxDomain, \
-      11, \
+      13, \
       T, \
       kJsExecutionProvider, \
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
       Gemm<T>); \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+      Gemm, \
+      kOnnxDomain, \
+      11, 12, \
+      T, \
+      kJsExecutionProvider, \
+      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/js/operators/split.h
@@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase {
       if (num_outputs_ < 0) {
         num_outputs_ = split_sizes.size();
       }
-    } else if (split_sizes_.size() == 0) {
-      // Compute split_sizes from input shape and num_outputs
+    } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) {
+      // Compute split_sizes from input shape and num_outputs.
+      // TODO: Shape might not be known at this point, better to handle this in javascript
       auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis_)).dim_value();
       int64_t split_size_sum = 0;
       if (num_outputs_ < 0) {
@@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase {
       ORT_ENFORCE(split_size_sum == total_split_size,
                   "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
     }
+    // else: let javascript handle all other cases, ie. split_sizes come as input[1]
 
     JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
                                         "numOutputs" : $2,
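When only num_outputs is known, the branch above derives the per-output sizes from the input shape. Assuming the ONNX Split rule (equal-sized chunks, with the last chunk absorbing any smaller remainder), the computation reduces to a few lines; this sketch illustrates the rule and is not the kernel's actual code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed ONNX Split semantics for num_outputs: ceil-sized equal chunks,
// with the last chunk holding the (possibly smaller) remainder.
std::vector<int64_t> ComputeSplitSizes(int64_t total, int64_t num_outputs) {
  const int64_t chunk = (total + num_outputs - 1) / num_outputs;  // ceil division
  std::vector<int64_t> sizes(static_cast<size_t>(num_outputs), chunk);
  sizes.back() = total - chunk * (num_outputs - 1);
  return sizes;
}

int main() {
  const auto sizes = ComputeSplitSizes(10, 3);
  assert(sizes[0] == 4 && sizes[1] == 4 && sizes[2] == 2);  // 4 + 4 + 2 == 10
}
```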
1 change: 1 addition & 0 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -138,6 +138,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
     model_proto = model.ToProto();
   } else {
     model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat,
+                                                             ToPathString(filename),
                                                              initializer_size_threshold);
   }
   auto& metadata = model.MetaData();