Merge branch 'master' into ovep-develop-lnl-1.2
preetha-intel committed Aug 26, 2024
2 parents 985a4e9 + 983c4d5 commit 3291c49
Showing 14 changed files with 218 additions and 251 deletions.
3 changes: 2 additions & 1 deletion include/onnxruntime/core/graph/graph_nodes.h
@@ -117,13 +117,14 @@ class ValidNodes {
return (current_ != other.current_);
}

void operator++() {
NodeIterator<TIterator>& operator++() {
if (current_ < end_) {
while (++current_ != end_) {
if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false))
break;
}
}
return *this;
}

NodeIterator<TIterator> operator++(int) {
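For context on the graph_nodes.h change above: pre-increment must return a reference to the iterator so that it satisfies the standard iterator requirements and composes with idioms like *++it, which a void return breaks. A minimal self-contained sketch (a hypothetical IntIterator, not the ORT type) of the pattern:

```cpp
#include <cassert>
#include <vector>

// Hypothetical minimal iterator; illustrates why operator++ should
// return a reference rather than void (as the old code did).
class IntIterator {
 public:
  explicit IntIterator(const int* p) : p_(p) {}

  // Pre-increment returns *this so expressions like *++it compile
  // and the type satisfies the standard iterator requirements.
  IntIterator& operator++() {
    ++p_;
    return *this;
  }

  int operator*() const { return *p_; }
  bool operator!=(const IntIterator& other) const { return p_ != other.p_; }

 private:
  const int* p_;
};

int main() {
  std::vector<int> v{1, 2, 3};
  IntIterator it(v.data());
  // With a void-returning operator++, the next line would not compile.
  assert(*++it == 2);
  return 0;
}
```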
311 changes: 124 additions & 187 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

Large diffs are not rendered by default.

onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
@@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// 2. Otherwise, we add Argmax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index()));
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If Argmax's successive node is a Cast from int64 to int32 output
// The 'cast to' type is checked in operator supported related, omit the check here
if (succ_node->OpType() == "Cast") {
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
// so we omit the check here
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
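A note on the argmax_op_builder.cc change above: graph_viewer.GetNode() can return nullptr when the looked-up node falls outside the viewer (for example, the successor was not assigned to the current partition), so the result must be null-checked before dereferencing. A minimal sketch of that defensive pattern, using hypothetical stand-in types rather than the real ORT graph API:

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the ORT graph types (not the real API),
// illustrating the defensive pattern the diff adds.
struct Node {
  std::string op_type;
  const std::string& OpType() const { return op_type; }
};

class GraphViewerLike {
 public:
  explicit GraphViewerLike(std::vector<const Node*> nodes) : nodes_(std::move(nodes)) {}
  // May return nullptr, e.g. when the index is not in this view.
  const Node* GetNode(size_t index) const {
    return index < nodes_.size() ? nodes_[index] : nullptr;
  }

 private:
  std::vector<const Node*> nodes_;
};

bool SuccessorIsCast(const GraphViewerLike& viewer, size_t successor_index) {
  const Node* next_node_in_partition = viewer.GetNode(successor_index);
  // Null check first; dereferencing a missing node was the latent bug.
  return next_node_in_partition != nullptr &&
         next_node_in_partition->OpType() == "Cast";
}

int main() {
  Node cast{"Cast"};
  GraphViewerLike viewer({&cast});
  return SuccessorIsCast(viewer, 0) && !SuccessorIsCast(viewer, 1) ? 0 : 1;
}
```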
onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
@@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
return false;
}

if (node.GetInputEdgesCount() > 1) {
LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input.";
return false;
}

const auto& prec_node = node.InputEdgesBegin()->GetNode();

/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
@@ -223,28 +223,4 @@ namespace Dml
{
m_defaultRoundingMode = roundingMode;
}

CPUAllocator::CPUAllocator(OrtMemType memType)
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML CPU",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0),
0,
memType
)
)
{
}

void* CPUAllocator::Alloc(size_t size)
{
return onnxruntime::AllocatorDefaultAlloc(size);
}

void CPUAllocator::Free(void* p)
{
return onnxruntime::AllocatorDefaultFree(p);
}

} // namespace Dml
@@ -239,7 +239,9 @@ namespace Dml
std::make_unique<DmlCommittedResourceAllocator>(m_d3d12Device.Get()));
m_context->SetAllocator(m_allocator);
// CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators.
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator);
memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput;
m_cpuInputAllocator = std::make_shared<onnxruntime::CPUAllocator>(memoryInfo);
}

return std::vector<onnxruntime::AllocatorPtr>{m_allocator, m_cpuInputAllocator,};
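The two DML changes above delete the provider's bespoke CPUAllocator and construct the shared onnxruntime::CPUAllocator instead, tagged OrtMemTypeCPUInput via OrtMemoryInfo. A minimal sketch of how such an allocator could be created and used through the IAllocator interface, assuming the constructor signatures shown in this diff; the include path is the usual ORT location but is an assumption here, not verified against a specific revision:

```cpp
#include <memory>

#include "core/framework/allocator.h"  // assumed home of onnxruntime::CPUAllocator / OrtMemoryInfo

std::shared_ptr<onnxruntime::IAllocator> MakeCpuInputAllocator() {
  // Tag the memory as CPU input so the runtime treats these buffers as
  // inputs for operators such as MemcpyFromHost, Shape and Size.
  OrtMemoryInfo memory_info(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator);
  memory_info.mem_type = ::OrtMemType::OrtMemTypeCPUInput;
  return std::make_shared<onnxruntime::CPUAllocator>(memory_info);
}

void UseAllocator() {
  auto allocator = MakeCpuInputAllocator();
  // Alloc/Free are the IAllocator methods the removed DML class overrode.
  void* buffer = allocator->Alloc(1024);
  // ... fill the buffer and hand it to the runtime ...
  allocator->Free(buffer);
}
```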
@@ -25,7 +25,6 @@ namespace Dml
class ReadbackHeap;
class ExecutionContext;
class BucketizedBufferAllocator;
class CPUAllocator;
class ExecutionProvider;

class ExecutionProviderImpl : public WRL::Base<Dml::IExecutionProvider,
@@ -213,7 +212,7 @@ namespace Dml
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
std::shared_ptr<BucketizedBufferAllocator> m_allocator;
std::shared_ptr<CPUAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::IAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::KernelRegistry> m_kernelRegistry;
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap> m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;
@@ -1062,7 +1062,6 @@ def parse_args():
)
parser.add_argument(
"--op_types_to_quantize",
default="MatMul",
type=str,
nargs="+",
choices=["MatMul", "Gather"],
@@ -1089,7 +1088,7 @@ def parse_args():
input_model_path = args.input_model
output_model_path = args.output_model
quant_format = QuantFormat[args.quant_format]
op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else None
op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",)
quant_axes = tuple(args.quant_axes) if args.quant_axes else None

if os.path.exists(output_model_path):
9 changes: 0 additions & 9 deletions onnxruntime/python/tools/transformers/benchmark.py
@@ -135,15 +135,6 @@ def run_onnxruntime(
)
return results

if provider == "migraphx":
optimizer_info = OptimizerInfo.NOOPT
warm_up_repeat = 5
if "MIGraphXExecutionProvider" not in onnxruntime.get_available_providers():
logger.error(
"Please install onnxruntime-rocm package, and use a machine with GPU for testing gpu performance."
)
return results

if optimizer_info == OptimizerInfo.NOOPT:
logger.warning(
f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied."
51 changes: 47 additions & 4 deletions onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -2,6 +2,8 @@
// Licensed under the MIT License.

#include "core/common/logging/logging.h"
#include "core/graph/graph.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/coreml/coreml_execution_provider.h"
#include "core/providers/coreml/coreml_provider_factory.h"
#include "core/session/inference_session.h"
@@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) {
feeds.insert(std::make_pair("Y", ml_value_y));
feeds.insert(std::make_pair("Z", ml_value_z));

RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest",
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
#else
@@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) {
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));

RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest",
EPVerificationParams verification_params{};
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All;

RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
feeds,
verification_params);
#else
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All);
#endif
}

TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) {
const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx");

#if defined(__APPLE__)
std::vector<int64_t> dims_mul_x = {3, 2, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
OrtValue ml_value_x;
AllocatorPtr allocator = std::make_shared<CPUAllocator>();
CreateMLValue<float>(allocator, dims_mul_x, values_mul_x, &ml_value_x);

NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));

const std::function<void(const Graph&)> graph_verifier = [](const Graph& graph) {
GraphViewer graph_viewer{graph};
const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder();
ASSERT_EQ(node_indices_in_order.size(), size_t{2});
// second node should be an unsupported Cast
const auto* cast_node = graph.GetNode(node_indices_in_order[1]);
ASSERT_NE(cast_node, nullptr);
ASSERT_EQ(cast_node->OpType(), "Cast");
ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider);
};

EPVerificationParams verification_params{};
verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some;
verification_params.graph_verifier = &graph_verifier;

RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds,
verification_params);
#else
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some);
#endif
@@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) {
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value));

RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel",
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
MakeCoreMLExecutionProvider(),
feeds);
#else
5 changes: 3 additions & 2 deletions onnxruntime/test/testdata/coreml_argmax_cast_test.onnx
@@ -1,4 +1,5 @@
(binary ONNX model; diff content not human-readable)
19 changes: 11 additions & 8 deletions onnxruntime/test/testdata/coreml_argmax_cast_test.py
@@ -1,16 +1,18 @@
import onnx
from onnx import TensorProto, helper

# CoreML EP currently handles a special case for supporting ArgMax op
# Please see in <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type
# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32.
# Please see <repo_root>/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
# <repo_root>/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc.
# This script generates graphs for these cases:
# - An ArgMax followed by a supported Cast to int32 type
# - An ArgMax followed by an unsupported Cast to a type other than int32


def GenerateModel(model_name): # noqa: N802
def GenerateModel(model_name, cast_to_dtype): # noqa: N802
nodes = [
helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1),
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type
helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype),
]

graph = helper.make_graph(
@@ -20,7 +22,7 @@ def GenerateModel(model_name): # noqa: N802
helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]),
],
[ # output
helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]),
helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]),
],
)

@@ -29,4 +31,5 @@ def GenerateModel(model_name): # noqa: N802


if __name__ == "__main__":
GenerateModel("coreml_argmax_cast_test.onnx")
GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32)
GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32)
19 changes: 19 additions & 0 deletions onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx
@@ -0,0 +1,19 @@
(new binary ONNX model; content not human-readable)
@@ -6,15 +6,15 @@ parameters:

steps:
- task: CmdLine@2
input:
inputs:
script: |
if [ -f $(Build.BinariesDirectory)/emulator.pid ]; then
echo "Emulator is running."
echo "##vso[task.setvariable variable=isEmulatorRunning]True"
else
echo "Emulator is not running."
fi
name: Determine if emulator is running
displayName: "Determine if emulator is running"

- task: CmdLine@2
inputs:
