Adding DML to python cuda package #22606

Merged
62 commits, merged on Dec 5, 2024
Changes from 30 commits

Commits
1fe6bae
Adding new Python package testing pipeline for CUda Alt
jchen351 Oct 24, 2024
e973742
Remove Linux_Test_GPU_x86_64_stage from stage final
jchen351 Oct 24, 2024
2a5be8c
Adding dependencies
jchen351 Oct 24, 2024
7690e26
Distinguish between DML and the generic 'GPU' term. This is needed fo…
pranavsharma Oct 25, 2024
a0918e0
Update dml pkg with python
jchen351 Oct 25, 2024
10dcad9
Update dml pkg with python
jchen351 Oct 25, 2024
3b9630e
Update dml pkg with python
jchen351 Oct 25, 2024
381f55f
Address code review comment
pranavsharma Oct 25, 2024
b671f41
Merge remote-tracking branch 'origin/package_dml' into Cjian/pydml
jchen351 Oct 25, 2024
8196b65
remove --enable_wcos and --cmake_extra_defines "CMAKE_SYSTEM_VERSION=…
jchen351 Oct 25, 2024
6dd4694
Merge branch 'main' into Cjian/pytest
jchen351 Oct 25, 2024
5ba0c12
Split DML test out of cuda
jchen351 Oct 26, 2024
847a7ff
Merge branch 'Cjian/pytest' into Cjian/pydml
jchen351 Oct 26, 2024
36c5bde
Split DML test out of cuda
jchen351 Oct 26, 2024
7113a15
Split DML test out of cuda
jchen351 Oct 26, 2024
bdce800
Split DML test out of cuda
jchen351 Oct 26, 2024
3755562
Split DML test out of cuda
jchen351 Oct 26, 2024
7a2cd9c
Split DML test out of cuda
jchen351 Oct 26, 2024
5188f4c
Split DML test out of cuda
jchen351 Oct 26, 2024
e334bfe
Merge remote-tracking branch 'origin/Cjian/pytest' into Cjian/pytest
jchen351 Oct 26, 2024
ff73270
lintrunner
jchen351 Oct 26, 2024
4c29c54
Merge branch 'Cjian/pytest' into Cjian/pydml
jchen351 Oct 26, 2024
a9b5a90
update ruff
jchen351 Oct 28, 2024
775ff5b
Merge branch 'Cjian/pytest' into Cjian/pydml
jchen351 Oct 28, 2024
9557f40
update cuda
jchen351 Oct 28, 2024
42c8702
exclude test_convtranspose_autopad_same_cuda
jchen351 Oct 28, 2024
2007bb3
update os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"]
jchen351 Oct 29, 2024
a4dddc5
# exclude DML EP when CUDA test is running.
jchen351 Oct 29, 2024
b11b993
Distinguish between DML and the generic 'GPU' term. This is needed fo…
pranavsharma Oct 30, 2024
b3fa3b2
Merge remote-tracking branch 'origin/dml_device' into Cjian/pydml
jchen351 Oct 30, 2024
b66393c
Move py_packaging_test_step into a template
jchen351 Oct 30, 2024
9d61bf4
parameters
jchen351 Oct 30, 2024
e4c3269
$(Agent.TempDirectory)
jchen351 Oct 30, 2024
a1d0985
Merge branch 'main' into Cjian/pydml
jchen351 Nov 1, 2024
2470bd0
Lintrunner -a
jchen351 Nov 1, 2024
4993088
Update DNNL CI python to 310
jchen351 Nov 1, 2024
12b5ff6
Merge branch 'Cjian/dnnl' into Cjian/pydml
jchen351 Nov 1, 2024
f963b6e
Replace reference to python 3.8 with python 3.19
jchen351 Nov 1, 2024
e0b895b
Merge branch 'Cjian/dnnl' into Cjian/pydml
jchen351 Nov 1, 2024
ba7dd01
Replace reference to python 3.8 with python 3.10
jchen351 Nov 1, 2024
c16a1d4
Merge branch 'Cjian/dnnl' into Cjian/pydml
jchen351 Nov 1, 2024
b8b98ea
Enable Win CUDA python test
jchen351 Nov 4, 2024
bcd173b
Enable Win CUDA python test
jchen351 Nov 4, 2024
720fce9
Disable 3 failed test due to upgrading to python 3.10
jchen351 Nov 4, 2024
94e152c
Merge branch 'Cjian/dnnl' into Cjian/pydml
jchen351 Nov 4, 2024
3a66769
Using Iterable instead of list
jchen351 Nov 4, 2024
5bb6617
Undo Iterable instead of list
jchen351 Nov 4, 2024
a48cccd
linrunner -a
jchen351 Nov 4, 2024
0efb9fc
Merge branch 'Cjian/dnnl' into Cjian/pydml
jchen351 Nov 4, 2024
015b32b
exclude failed cuda 12 tests
jchen351 Nov 5, 2024
ef629ad
Adding verbose to run onnx_backend_test_series.py
jchen351 Nov 5, 2024
bc4e825
Adding verbose to run onnx_backend_test_series.py
jchen351 Nov 5, 2024
cb3b48b
Merge branch 'Cjian/enable_cuda_py_test' into Cjian/pydml
jchen351 Nov 5, 2024
f9f0d46
Merge branch 'main' into Cjian/pydml
jchen351 Nov 5, 2024
d462901
rolling back dml texts exclusions,
jchen351 Nov 5, 2024
099d7dd
remove CudaVersion: ${{ parameters.cuda_version }}
jchen351 Nov 5, 2024
057ae40
Merge branch 'Cjian/enable_cuda_py_test' into Cjian/pydml
jchen351 Nov 5, 2024
4cd9ecd
Merge branch 'main' into Cjian/pydml
jchen351 Nov 6, 2024
01a4359
Merge branch 'main' into Cjian/pydml
jchen351 Nov 13, 2024
82d3f6e
Merge remote-tracking branch 'origin/main' into Cjian/pydml
jchen351 Nov 16, 2024
bc8e9d7
Exclude failed cuda test that running during the DML testing
jchen351 Nov 18, 2024
ce4666d
Reduce parallel counts
jchen351 Nov 19, 2024
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
@@ -17,6 +17,7 @@ struct OrtDevice {
static const DeviceType GPU = 1; // Nvidia or AMD
static const DeviceType FPGA = 2;
static const DeviceType NPU = 3; // Ascend
static const DeviceType DML = 4;

struct MemType {
// Pre-defined memory types.
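The new OrtDevice::DML device type above lets DirectML memory be addressed separately from the generic GPU (CUDA/ROCm) type. A minimal sketch of what this enables from Python, assuming a DML-enabled wheel that accepts "dml" as the device string for OrtValue creation (device id 0 is illustrative):

import numpy as np
import onnxruntime as ort

# Sketch only: place an OrtValue directly on the DirectML device.
# Assumes a DML-enabled build; "dml" as a device string and device id 0 are assumptions.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
dml_value = ort.OrtValue.ortvalue_from_numpy(x, "dml", 0)
print(dml_value.device_name())  # expected to report the DML device rather than a generic GPU/CUDA name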
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/allocator.cc
@@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
strcmp(name1, onnxruntime::DML) == 0 ||
strcmp(name1, onnxruntime::HIP) == 0 ||
strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp
@@ -41,23 +41,19 @@
D3D12_HEAP_FLAGS heapFlags,
D3D12_RESOURCE_FLAGS resourceFlags,
D3D12_RESOURCE_STATES initialState,
std::unique_ptr<DmlSubAllocator>&& subAllocator
)
std::unique_ptr<DmlSubAllocator>&& subAllocator)

[cpplint warning at onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp:44] Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
)
),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator))
{
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator)) {
}

/*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)
onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h
@@ -20,15 +20,13 @@
class DmlExternalBufferAllocator : public onnxruntime::IAllocator
{
public:
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
))
{
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(

[cpplint warning at onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h:23] Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) {
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}

void* Alloc(size_t size) final
{
@@ -73,20 +73,17 @@ namespace Dml
bool enableMetacommands,
bool enableGraphCapture,
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
{
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}
bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) {
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) {
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}

ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));

m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
}

std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
@@ -242,8 +242,8 @@ namespace Dml

bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
{
return (srcDevice.Type() == OrtDevice::GPU) ||
(dstDevice.Type() == OrtDevice::GPU);
return (srcDevice.Type() == OrtDevice::DML) ||
(dstDevice.Type() == OrtDevice::DML);
}

private:
22 changes: 22 additions & 0 deletions onnxruntime/core/session/inference_session.cc
@@ -1662,6 +1662,23 @@
}
}
}

// This function is called when the session is being initialized.
// For now, this function only checks for invalid combination of DML EP with other EPs.
// TODO: extend this function to check for other invalid combinations of EPs.

[cpplint warning at onnxruntime/core/session/inference_session.cc:1668] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
// DML EP is only allowed with CPU EP
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
if (has_dml_ep) {
const auto& ep_list = execution_providers_.GetIds();
for (const auto& ep : ep_list) {
if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
}
}
return Status::OK();
}

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
@@ -1719,6 +1736,11 @@
execution_providers_.SetCpuProviderWasImplicitlyAdded(true);
}

// Check for the presence of an invalid combination of execution providers in the session
// For e.g. we don't support DML EP and other GPU EPs to be present in the same session
// This check is placed here because it serves as a common place for all language bindings.
ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders());

// re-acquire mutex
std::lock_guard<std::mutex> l(session_mutex_);

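With HasInvalidCombinationOfExecutionProviders wired into Initialize(), registering the DML EP alongside any provider other than the CPU EP is rejected. A hedged sketch of the resulting behavior from Python; "model.onnx" is a placeholder path:

import onnxruntime as ort

# Allowed: DML EP combined with the (implicitly added) CPU EP.
sess = ort.InferenceSession(
    "model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"])

# Expected to fail session initialization with INVALID_ARGUMENT
# ("DML EP can be used with only CPU EP."):
# ort.InferenceSession(
#     "model.onnx", providers=["DmlExecutionProvider", "CUDAExecutionProvider"])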
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.h
@@ -620,7 +620,7 @@ class InferenceSession {
const Environment& session_env);
void ConstructorCommon(const SessionOptions& session_options,
const Environment& session_env);

[[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const;
[[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model);

#if !defined(ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {

const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
{OrtDevice::GPU, DmlToCpuMemCpy}};
{OrtDevice::DML, DmlToCpuMemCpy}};

return &map;
}
57 changes: 34 additions & 23 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -96,16 +96,22 @@
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
#elif USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML

[cpplint warning at onnxruntime/python/onnxruntime_pybind_ortvalue.cc:108] Lines should be <= 120 characters long [whitespace/line_length] [2]
[cpplint warning at onnxruntime/python/onnxruntime_pybind_ortvalue.cc:108] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::NPU) {
#ifdef USE_CANN
@@ -116,9 +122,9 @@
CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToCannMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
@@ -160,19 +166,24 @@
}

onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#elif USE_DML
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
8 changes: 3 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
case OrtDevice::CPU:
return CPU;
case OrtDevice::GPU:
#ifdef USE_DML
return DML;
#else
return CUDA;
#endif
case OrtDevice::DML:
return DML;
case OrtDevice::FPGA:
return "FPGA";
case OrtDevice::NPU:
@@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def_static("cann", []() { return OrtDevice::NPU; })
.def_static("fpga", []() { return OrtDevice::FPGA; })
.def_static("npu", []() { return OrtDevice::NPU; })
.def_static("dml", []() { return OrtDevice::GPU; })
.def_static("dml", []() { return OrtDevice::DML; })
.def_static("webgpu", []() { return OrtDevice::GPU; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });

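Because OrtDevice.dml() now returns the dedicated DML device type instead of the generic GPU value, a DML device can be constructed and told apart from CUDA or WebGPU devices. A small sketch against the pybind layer, assuming the usual three-argument OrtDevice constructor exposed by the bindings:

from onnxruntime.capi import _pybind_state as C

# Sketch (assumes a DML-enabled build): construct a device that explicitly targets DML.
dml_device = C.OrtDevice(C.OrtDevice.dml(), C.OrtDevice.default_memory(), 0)

# dml() no longer aliases the generic GPU type that webgpu() (and CUDA) report.
assert C.OrtDevice.dml() != C.OrtDevice.webgpu()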
41 changes: 28 additions & 13 deletions onnxruntime/test/python/onnx_backend_test_series.py
@@ -105,7 +105,7 @@ def load_jsonc(basename: str):
return json.loads("\n".join(lines))


def create_backend_test(test_name=None):
def create_backend_test(devices: list[str], test_name=None):
snnn marked this conversation as resolved.
"""Creates an OrtBackendTest and adds its TestCase's to global scope so unittest will find them."""

overrides = load_jsonc("onnx_backend_test_series_overrides.jsonc")
@@ -126,30 +126,29 @@ def create_backend_test(test_name=None):
else:
filters = load_jsonc("onnx_backend_test_series_filters.jsonc")
current_failing_tests = apply_filters(filters, "current_failing_tests")

if platform.architecture()[0] == "32bit":
current_failing_tests += apply_filters(filters, "current_failing_tests_x86")

if backend.supports_device("DNNL"):
if backend.supports_device("DNNL") or "DNNL" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_DNNL")

if backend.supports_device("NNAPI"):
if backend.supports_device("NNAPI") or "NNAPI" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_NNAPI")

if backend.supports_device("OPENVINO_GPU"):
if backend.supports_device("OPENVINO_GPU") or "OPENVINO_GPU" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_GPU")

if backend.supports_device("OPENVINO_CPU"):
if backend.supports_device("OPENVINO_CPU") or "OPENVINO_CPU" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP32")
current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16")

if backend.supports_device("OPENVINO_NPU"):
if backend.supports_device("OPENVINO_NPU") or "OPENVINO_NPU" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU")

if backend.supports_device("OPENVINO"):
if backend.supports_device("OPENVINO") or "OPENVINO" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18")

if backend.supports_device("MIGRAPHX"):
if backend.supports_device("MIGRAPHX") or "MIGRAPHX" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_MIGRAPHX")

if backend.supports_device("WEBGPU"):
@@ -158,8 +157,17 @@
# Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA
# EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precedence than DML
# and the nodes are assigned to only the CUDA EP (which supports these tests)
if backend.supports_device("DML") and not backend.supports_device("GPU"):
if (backend.supports_device("DML") and not backend.supports_device("GPU")) or "DML" in devices:
current_failing_tests += apply_filters(filters, "current_failing_tests_pure_DML")
# exclude CUDA EP when DML test is running.
os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider,CUDAExecutionProvider"
elif backend.supports_device("DML") and "DML" not in devices:
# exclude DML EP when CUDA test is running.
os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider,DmlExecutionProvider"
else:
# exclude TRT EP temporarily and only test CUDA EP to retain previous behavior
os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider"


filters = (
jchen351 marked this conversation as resolved.
current_failing_tests
@@ -172,8 +180,6 @@
backend_test.exclude("(" + "|".join(filters) + ")")
print("excluded tests:", filters)

# exclude TRT EP temporarily and only test CUDA EP to retain previous behavior
os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider"

# import all test cases at global scope to make
jchen351 marked this conversation as resolved.
# them visible to python.unittest.
@@ -199,6 +205,15 @@ def parse_args():
help="Only run tests that match this value. Matching is regex based, and '.*' is automatically appended",
)

parser.add_argument(
"--devices",
type=str,
choices=["CPU", "CUDA", "MIGRAPHX", "DNNL", "DML", "OPENVINO_GPU", "OPENVINO_CPU", "OPENVINO_NPU", "OPENVINO"],
nargs="+", # allows multiple values
default=["CPU"], # default to ["CPU"] if no input is given
help="Select one or more devices CPU, CUDA, MIGRAPHX, DNNL, DML, OPENVINO_GPU, OPENVINO_CPU, OPENVINO_NPU, OPENVINO",
)

# parse just our args. python unittest has its own args and arg parsing, and that runs inside unittest.main()
parsed, unknown = parser.parse_known_args()
sys.argv = sys.argv[:1] + unknown
@@ -209,5 +224,5 @@
if __name__ == "__main__":
args = parse_args()

create_backend_test(args.test_name)
create_backend_test(args.devices, args.test_name)
unittest.main()
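The new --devices argument drives both the failing-test filter selection and the ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS value, so one script run can target either wheel. A usage sketch based on the logic above:

python onnx_backend_test_series.py --devices DML
    # applies the current_failing_tests_pure_DML filters and excludes the TensorRT and CUDA EPs
python onnx_backend_test_series.py --devices CUDA
    # on a build that also supports DML, excludes the TensorRT and DML EPs instead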
@@ -747,7 +747,12 @@
"^test_reduce_prod_empty_set_cpu",
//Bug: DML EP does not execute operators with an empty input tensor
//TODO: Resolve as a graph implementation that returns a constant inf tensor with appropriate strides
"^test_reduce_min_empty_set_cpu"
"^test_reduce_min_empty_set_cpu",
snnn marked this conversation as resolved.
"^test_reduce_min_empty_set_cuda",
"^test_asin_example_cuda",
"^test_dynamicquantizelinear_cuda",
"^test_dynamicquantizelinear_expanded_cuda",
"^test_convtranspose_autopad_same_cuda"
],
// ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported
"tests_with_pre_opset7_dependencies": [