diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 6efa8a5592337..aecc05c91d736 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -1,192 +1,197 @@ -name: Mac_CI - -on: - push: - branches: - - main - - rel-* - pull_request: - branches: - - main - - rel-* - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - python_version: 3.11 - xcode_version: 15.2 - -jobs: - ARM64: - runs-on: macos-14 - - timeout-minutes: 60 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Verify ARM64 machine - shell: python - run: | - import platform - assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." - - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Build and test - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --update \ - --build --parallel \ - --test \ - --build_shared_lib \ - --build_objc \ - --use_coreml \ - --use_xnnpack \ - --use_binskim_compliant_compile_flags - - Vcpkg: - runs-on: macos-13 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: "Run vcpkg(x64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" - env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "x64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run compile_schema.py" - run: | - # Runner's host triplet should be x64-osx or arm64-osx - export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" - export PATH="$FLATC_DIR:$PATH" - flatc --version - python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" - - - name: "Detect protoc" - id: protoc-detect - run: | - export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" - export PATH="$PROTOC_DIR:$PATH" - protoc --version - echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" - - - name: "Run build.py(x64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/x64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch x86_64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - - name: "Run vcpkg(arm64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" 
- env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "arm64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run build.py(arm64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/arm64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch arm64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - Objective-C-StaticAnalysis: - runs-on: macos-14 - - timeout-minutes: 30 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Generate compile_commands.json and ONNX protobuf files - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --cmake_generator "Unix Makefiles" \ - --config Debug \ - --build_shared_lib \ - --use_coreml \ - --build_objc \ - --enable_training_apis \ - --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ - --use_binskim_compliant_compile_flags \ - --update \ - --build --parallel \ - --target onnx_proto - - - name: Analyze Objective-C/C++ source code - shell: bash - run: | - CLANG_TIDY_CHECKS="-*,clang-analyzer-*" - - "$(brew --prefix llvm@15)/bin/clang-tidy" \ - -p=./build/Debug \ - --checks="${CLANG_TIDY_CHECKS}" \ - --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ - --header-filter="objectivec/include|objectivec|onnxruntime/core" \ - ./objectivec/*.mm \ - ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ - ./onnxruntime/core/providers/coreml/model/*.mm +name: Mac_CI + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + python_version: 3.11 + +jobs: + ARM64: + runs-on: macos-14 + + env: + xcode_version: 16 + + timeout-minutes: 60 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Verify ARM64 machine + shell: python + run: | + import platform + assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." 
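A note on the `Verify ARM64 machine` step above: `shell: python` runs the step body with the Python interpreter that `actions/setup-python` placed on the PATH, so the check is plain Python. A minimal local sketch of the same check, assuming CPython on macOS, where `platform.machine()` reports `arm64` on Apple Silicon and `x86_64` on Intel:

```python
import platform

# platform.machine() reports the hardware architecture string:
# "arm64" on Apple Silicon macOS, "x86_64" on Intel macOS.
assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine."
print("Running on", platform.machine())
```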
+ + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Build and test + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --update \ + --build --parallel \ + --test \ + --build_shared_lib \ + --build_objc \ + --use_coreml \ + --use_xnnpack \ + --use_binskim_compliant_compile_flags + + Vcpkg: + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: "Run vcpkg(x64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-osx or arm64-osx + export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" + export PATH="$FLATC_DIR:$PATH" + flatc --version + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" + + - name: "Detect protoc" + id: protoc-detect + run: | + export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" + export PATH="$PROTOC_DIR:$PATH" + protoc --version + echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" + + - name: "Run build.py(x64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/x64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch x86_64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + - name: "Run vcpkg(arm64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/arm64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch arm64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines 
"VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + Objective-C-StaticAnalysis: + runs-on: macos-14 + + env: + xcode_version: 15.2 + + timeout-minutes: 30 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Generate compile_commands.json and ONNX protobuf files + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --cmake_generator "Unix Makefiles" \ + --config Debug \ + --build_shared_lib \ + --use_coreml \ + --build_objc \ + --enable_training_apis \ + --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ + --use_binskim_compliant_compile_flags \ + --update \ + --build --parallel \ + --target onnx_proto + + - name: Analyze Objective-C/C++ source code + shell: bash + run: | + CLANG_TIDY_CHECKS="-*,clang-analyzer-*" + + "$(brew --prefix llvm@15)/bin/clang-tidy" \ + -p=./build/Debug \ + --checks="${CLANG_TIDY_CHECKS}" \ + --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ + --header-filter="objectivec/include|objectivec|onnxruntime/core" \ + ./objectivec/*.mm \ + ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ + ./onnxruntime/core/providers/coreml/model/*.mm diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 96bb053a13f29..294bd926a34cb 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 6bf842ac18037..3528545dfb06e 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/README.md b/README.md index 24c3e191c115b..cde039cec52a8 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,8 @@ |Web|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)|| |Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)|| +This project is tested with [BrowserStack](https://www.browserstack.com/home). 
+ ## Third-party Pipeline Status |System|Inference|Training| diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 654099958b21b..4fc4a369051d5 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -216,7 +216,7 @@ "component": { "type": "git", "git": { - "commitHash": "62bdde2a04fcd53c2409cb895ee18db445b7e755", + "commitHash": "9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 02ec0d93e783f..cd6d0669c67f5 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -642,10 +642,12 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) check_cxx_compiler_flag(-Wdeprecated-copy HAS_DEPRECATED_COPY) check_cxx_compiler_flag(-Wdeprecated-declarations HAS_DEPRECATED_DECLARATIONS) + check_cxx_compiler_flag(-Wdeprecated-this-capture HAS_DEPRECATED_THIS_CAPTURE) check_cxx_compiler_flag(-Wenum-constexpr-conversion HAS_ENUM_CONSTEXPR_CONVERSION) check_cxx_compiler_flag(-Wformat-truncation HAS_FORMAT_TRUNCATION) check_cxx_compiler_flag(-Wignored-attributes HAS_IGNORED_ATTRIBUTES) @@ -656,15 +658,15 @@ else() check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32) check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING) check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE) check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE) check_cxx_compiler_flag(-Wunused-but-set-parameter HAS_UNUSED_BUT_SET_PARAMETER) check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) - check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) + if(onnxruntime_ENABLE_TRAINING_APIS) - check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) if(HAS_DANGLING_REFERENCE) list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index 342184bda2f0e..3646c14587ff7 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -38,8 +38,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c -#use the latest commit of 10.3-GA -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/62bdde2a04fcd53c2409cb895ee18db445b7e755.zip;980a455b07dfa67aa70b9e49d37dd9d4cdf690a0 +# Use the latest commit of 10.4-GA-ORT-DDS 
+onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7.zip;1d92137f424513bce20033ab4fb31cc0be8d1185 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 8b5f602643c0b..e03506de12728 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.1) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.2) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 43f18abbe9522..cb737ee53639f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -91,6 +91,7 @@ if (NOT WIN32) google_nsync URL ${DEP_URL_google_nsync} URL_HASH SHA1=${DEP_SHA1_google_nsync} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/nsync/nsync_1.26.0.patch FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync ) #nsync tests failed on Mac Build diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index e3ea767401ddc..bbddefe531cb8 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -9,6 +9,7 @@ #cmakedefine HAS_CLASS_MEMACCESS #cmakedefine HAS_DEPRECATED_COPY #cmakedefine HAS_DEPRECATED_DECLARATIONS +#cmakedefine HAS_DEPRECATED_THIS_CAPTURE #cmakedefine HAS_FORMAT_TRUNCATION #cmakedefine HAS_IGNORED_ATTRIBUTES #cmakedefine HAS_MAYBE_UNINITIALIZED diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e35c83ba45952..0ba4694c329e3 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -88,6 +88,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set(mlas_platform_preprocess_srcs @@ -382,6 +383,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -391,6 +393,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 
") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 87cd9e64e778d..0148861d42761 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -899,8 +899,6 @@ if (MSVC) set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" APPEND PROPERTY COMPILE_OPTIONS "/bigobj") - set_property(SOURCE "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" - APPEND PROPERTY COMPILE_OPTIONS "/bigobj") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -1339,7 +1337,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2) + list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 onnx) endif() AddTest(DYN @@ -1643,7 +1641,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) + list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) endif() AddTest(DYN TARGET onnxruntime_customopregistration_test @@ -1762,7 +1760,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils) if (${CMAKE_SYSTEM_NAME} MATCHES "AIX") - list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp) + list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp) endif() if(NOT WIN32) diff --git a/cmake/patches/nsync/nsync_1.26.0.patch b/cmake/patches/nsync/nsync_1.26.0.patch new file mode 100644 index 0000000000000..78ef2b3cb20d4 --- /dev/null +++ b/cmake/patches/nsync/nsync_1.26.0.patch @@ -0,0 +1,14 @@ +diff --git a/public/nsync_atomic.h b/public/nsync_atomic.h +index aebe4f7..466a262 100644 +--- a/public/nsync_atomic.h ++++ b/public/nsync_atomic.h +@@ -45,7 +45,8 @@ NSYNC_CPP_END_ + NSYNC_CPP_START_ + typedef std::atomic nsync_atomic_uint32_; + NSYNC_CPP_END_ +-#define 
NSYNC_ATOMIC_UINT32_INIT_ ATOMIC_VAR_INIT (0) ++// Replace deprecated ATOMIC_VAR_INIT with std::atomic brace initialization ++#define NSYNC_ATOMIC_UINT32_INIT_ { 0 } + #define NSYNC_ATOMIC_UINT32_LOAD_(p) (std::atomic_load (p)) + #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (std::atomic_store ((p), (uint32_t) (v))) + #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs index 7c22e1b213b41..a059208ef6373 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtFloat16.shared.cs @@ -60,9 +60,9 @@ internal static int LeadingZeroCount(uint num) /// /// Extracts single precision number bit representation as uint /// so its bits can be manipulated. - /// + /// /// This API is the reverse of UInt32BitsToSingle(). - /// + /// /// /// float value /// @@ -79,11 +79,11 @@ internal static uint SingleToUInt32Bits(float single) /// /// Needed because BitConverter impl is not available until /// later versions. This API is the reverse of SingleToUInt32Bits(). - /// + /// /// For the exact bit representation of float see IEEE 754 standard for single precision. - /// + /// /// - /// bit representation of float either obtained from + /// bit representation of float either obtained from /// SingleToUInt32Bits or assembled using bitwise operators /// internal static float UInt32BitsToSingle(uint singleBits) @@ -99,7 +99,7 @@ internal static float UInt32BitsToSingle(uint singleBits) /// /// Converts single precision bits representation which can be obtained using /// SingleToUInt32Bits() or manually constructed according to IEEE 754 standard. - /// + /// /// /// bits representation of a single precision number (float) /// @@ -177,8 +177,8 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus, /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library. /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching. - /// - /// The implementation is derived from + /// + /// The implementation is derived from /// https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Half.cs,7895d5942d33f974 /// [StructLayout(LayoutKind.Sequential)] @@ -215,6 +215,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) private const ushort OneBits = 0x3C00; + // Minimum positive normalized value. It corresponds to numeric_limits::min() in C++.
private const ushort EpsilonBits = 0x0400; private const ushort PositiveInfinityBits = 0x7C00; @@ -238,7 +239,7 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// /// Float16 Epsilon value /// - public static Float16 Epsilon => new Float16(EpsilonBits); // 5.9604645E-08 + public static Float16 Epsilon => new Float16(EpsilonBits); // 0.00006103515625 /// /// Float16 Pi value @@ -248,17 +249,17 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) /// /// Float16 Positive Infinity value /// - public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); // 1.0 / 0.0; + public static Float16 PositiveInfinity => new Float16(PositiveInfinityBits); /// /// Float16 Negative Infinity value /// - public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); // -1.0 / 0.0 + public static Float16 NegativeInfinity => new Float16(NegativeInfinityBits); /// /// Float16 NaN /// - public static Float16 NaN => new Float16(NegativeQNaNBits); // 0.0 / 0.0 + public static Float16 NaN => new Float16(NegativeQNaNBits); // Same as System.Half.NaN /// /// Float16 Zero value @@ -276,14 +277,14 @@ internal static float CreateSingle(bool sign, byte exponent, uint significand) public static Float16 NegativeZero => new Float16(NegativeZeroBits); // -0.0 /// - /// Float16 Min value + /// Float16 Lowest value /// - public static Float16 MinValue => new Float16(MinValueBits); // 64,511 + public static Float16 MinValue => new Float16(MinValueBits); // -65504.0 /// /// Float16 Max value /// - public static Float16 MaxValue => new Float16(MaxValueBits); // 31,743 + public static Float16 MaxValue => new Float16(MaxValueBits); // 65504.0 /// /// float16 representation bits @@ -348,7 +349,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -376,7 +377,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -388,7 +389,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 - /// + /// /// /// left hand side /// right hand side @@ -429,7 +430,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two Float16 for binary equality. /// If either of the values is NaN, this will return false. - /// + /// /// /// left hand side /// right hand side @@ -479,7 +480,7 @@ public static bool IsInfinity(Float16 value) /// /// Determines whether the specified value is NaN. /// - /// + /// /// Float16 instance /// true if the value is not a number public static bool IsNaN(Float16 value) @@ -500,7 +501,7 @@ public static bool IsNegative(Float16 value) /// /// Determines whether the specified value is negative infinity. /// - /// + /// /// Float16 instance /// true if the value is negative infinity public static bool IsNegativeInfinity(Float16 value) @@ -549,7 +550,7 @@ public static bool IsSubnormal(Float16 value) /// /// Compares this object to another object, returning an integer that indicates the relationship. 
/// - /// Object to compare to /// A value less than zero if this is less than , - /// zero if this is equal to , + /// zero if this is equal to , /// or a value greater than zero if this is greater than . public int CompareTo(object obj) { @@ -570,7 +571,7 @@ /// /// Object to compare to /// A value less than zero if this is less than , - /// zero if this is equal to , + /// zero if this is equal to , /// or a value greater than zero if this is greater than . public int CompareTo(Float16 other) { @@ -864,10 +865,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) private const ushort PositiveQNaNBits = 0x7FC1; private const ushort NegativeQNaNBits = 0xFFC1; + // Lowest finite value. It corresponds to numeric_limits::lowest() in C++. private const ushort MinValueBits = 0xFF7F; // 0b1_11111110_1111111 + private const ushort MaxValueBits = 0x7F7F; // 0b0_11111110_1111111 - private const ushort EpsilonBits = 0x0080; // the smallest positive normal value + // Minimum positive normalized value. It corresponds to numeric_limits::min() in C++. + private const ushort EpsilonBits = 0x0080; private const ushort PiBits = 0x4049; // 0b0_10000000_1001001 @@ -899,7 +903,7 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) /// /// BFloat16 NaN /// - public static BFloat16 NaN => new BFloat16(NegativeQNaNBits); + public static BFloat16 NaN => new BFloat16(NegativeQNaNBits); // .Net has no BFloat16. Follow Float16 style. /// /// BFloat16 Positive Zero @@ -919,13 +923,13 @@ private static ushort RoundPackToFloat16(bool sign, short exp, ushort sig) /// /// BFloat16 Min value /// - public static BFloat16 MinValue => new BFloat16(MinValueBits); // 65,407 + public static BFloat16 MinValue => new BFloat16(MinValueBits); // -3.38953139e38 /// /// BFloat16 Max value /// - public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 32,639 + public static BFloat16 MaxValue => new BFloat16(MaxValueBits); // 3.38953139e38 /// /// bfloat16 representation bits @@ -1051,7 +1055,7 @@ internal static ushort ExtractTrailingSignificandFromBits(ushort bits) /// /// Compares values of two BFloat16 for binary equality. /// If either of the values is NaN, this will return false. - /// + /// /// /// left hand side /// right hand side @@ -1102,7 +1106,7 @@ public static bool IsInfinity(BFloat16 value) /// /// Determines whether the specified value is NaN. /// - /// + /// /// BFloat16 instance /// true if the value is not a number public static bool IsNaN(BFloat16 value) @@ -1123,7 +1127,7 @@ public static bool IsNegative(BFloat16 value) /// /// Determines whether the specified value is negative infinity. /// - /// + /// /// BFloat16 instance /// true if the value is negative infinity public static bool IsNegativeInfinity(BFloat16 value) @@ -1170,7 +1174,7 @@ public static bool IsSubnormal(BFloat16 value) /// /// Compares this object to another object, returning an integer that indicates the relationship. /// - /// + /// /// Object to compare to /// A value less than zero if this is less than , /// zero if this is equal to , or a value greater than zero @@ -1191,7 +1195,7 @@ public int CompareTo(object obj) /// /// Object to compare to /// A value less than zero if this is less than , - /// zero if this is equal to , + /// zero if this is equal to , /// or a value greater than zero if this is greater than .
 public int CompareTo(BFloat16 other)
 {
@@ -1368,4 +1372,4 @@ private static uint StripSign(BFloat16 value)
 #endregion
 }
-}
\ No newline at end of file
+}
diff --git a/docs/Coding_Conventions_and_Standards.md b/docs/Coding_Conventions_and_Standards.md
index e8e1e7dc9ccd8..f18f1036efee8 100644
--- a/docs/Coding_Conventions_and_Standards.md
+++ b/docs/Coding_Conventions_and_Standards.md
@@ -155,7 +155,7 @@ Using `Show Code Coverage Coloring` will allow you to visually inspect which lin
 This project uses [lintrunner](https://github.com/suo/lintrunner) for linting. It provides a consistent linting experience locally and in CI. You can install the dependencies and initialize with

 ```sh
-pip install lintrunner lintrunner-adapters
+pip install -r requirements-lintrunner.txt
 lintrunner init
 ```

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 407e08c96a891..a9176605d9175 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -175,8 +175,8 @@ Do not modify directly.*
 |||[1, 12]|**T** = tensor(float)|
 |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
 |||[7, 13]|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
-|LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**

or

*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float)
**U** = tensor(float)|
-|||[1, 16]|**T** = tensor(double), tensor(float)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float)|
+|LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**

or

*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(float)|
+|||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)|
 |LeakyRelu|*in* X:**T**
*out* Y:**T**|16+|**T** = tensor(float)|
 |||[6, 15]|**T** = tensor(float)|
 |Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
@@ -369,7 +369,7 @@ Do not modify directly.*
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)|
 |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float)|
 |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)|
 |Size|*in* data:**T**
*out* size:**T1**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
@@ -482,7 +482,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
@@ -508,11 +508,11 @@ Do not modify directly.*
 |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
 |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
-|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)|
 |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
 |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)|
-|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
-|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)|
 |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)|
@@ -1055,9 +1055,9 @@ Do not modify directly.*
 |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|If|*in* cond:**B**
*out* outputs:**V**|19+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||16+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||13+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|If|*in* cond:**B**
*out* outputs:**V**|19+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||16+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||11+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||7+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/common/eigen_common_wrapper.h b/include/onnxruntime/core/common/eigen_common_wrapper.h index 57599e04037dc..19efa7bcff107 100644 --- a/include/onnxruntime/core/common/eigen_common_wrapper.h +++ b/include/onnxruntime/core/common/eigen_common_wrapper.h @@ -49,6 +49,12 @@ #pragma GCC diagnostic ignored "-Wshorten-64-to-32" #endif +// eigen-src/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h:215:9: +// error: implicit capture of 'this' with a capture default of '=' is deprecated [-Werror,-Wdeprecated-this-capture] +#ifdef HAS_DEPRECATED_THIS_CAPTURE +#pragma GCC diagnostic ignored "-Wdeprecated-this-capture" +#endif + #elif defined(_MSC_VER) // build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): // warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index abab118efd04f..57b332ce65b93 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -53,6 +53,7 @@ constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; constexpr const char* OpenVINO_RT = "OpenVINO_RT"; constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; +constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; constexpr size_t kAllocAlignment = 256; diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h index 1f2f175c6e691..dac0a01fbc3fe 100644 --- a/include/onnxruntime/core/framework/float16.h +++ b/include/onnxruntime/core/framework/float16.h @@ -295,3 +295,147 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { } } // namespace onnxruntime + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::MLFloat16 min() noexcept { + return onnxruntime::MLFloat16::FromBits(0x0400U); // Minimum positive normalized value: 0.00006103515625 + } + + static constexpr onnxruntime::MLFloat16 max() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7BFFU); // Largest representable value: 65504 + } + + static constexpr onnxruntime::MLFloat16 lowest() noexcept { + return onnxruntime::MLFloat16::FromBits(0xFBFFU); // Smallest representable value: -65504 + } + + static constexpr onnxruntime::MLFloat16 infinity() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7C00U); // Bits: sign(0), exponent(111,11), fraction(00,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 quiet_NaN() noexcept { + // The most significant fraction bit shall be 1, and no limitation on other fraction bits. 
+ // Note that most frameworks use 0x7E00; while CUDA uses 0x7FFF; .Net System.Half.NaN uses 0xFE00; + return onnxruntime::MLFloat16::FromBits(0x7E00U); // Bits: sign(0), exponent(111,11), fraction(10,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 signaling_NaN() noexcept { + return onnxruntime::MLFloat16::FromBits(0x7D00U); // Bits: sign(0), exponent(111,11), fraction(01,0000,0000) + } + + static constexpr onnxruntime::MLFloat16 denorm_min() noexcept { + return onnxruntime::MLFloat16::FromBits(0x0001U); // Minimum subnormal value: 0.000000059604645 + } + + static constexpr onnxruntime::MLFloat16 epsilon() noexcept { + return onnxruntime::MLFloat16::FromBits(0x1400U); // Difference between 1.0 and the next value: 2^-10 = 0.0009765625 + } + + static constexpr onnxruntime::MLFloat16 round_error() noexcept { + return onnxruntime::MLFloat16::FromBits(0x3800U); // 0.5 + } + + static constexpr bool is_specialized = true; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_present; + static constexpr bool has_denorm_loss = false; + + static constexpr bool is_bounded = true; + static constexpr bool is_iec559 = true; + static constexpr bool is_modulo = false; + + static constexpr int digits = 11; // Number of significant digits (mantissa) + static constexpr int digits10 = 3; // Decimal digits of precision + static constexpr int max_digits10 = 5; // Max decimal digits required for precision + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr std::float_round_style round_style = std::round_to_nearest; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::BFloat16 min() noexcept { + return onnxruntime::BFloat16::FromBits(0x0080U); // Minimum positive normalized value: 1.175494e-38 + } + + static constexpr onnxruntime::BFloat16 max() noexcept { + return onnxruntime::BFloat16::FromBits(0x7F7FU); // Largest representable value: 3.38953139e38 + } + + static constexpr onnxruntime::BFloat16 lowest() noexcept { + return onnxruntime::BFloat16::FromBits(0xFF7FU); // Smallest representable value: -3.38953139e38 + } + + static constexpr onnxruntime::BFloat16 infinity() noexcept { + return onnxruntime::BFloat16::FromBits(0x7F80U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0000) + } + + static constexpr onnxruntime::BFloat16 quiet_NaN() noexcept { + // The most significant fraction bit shall be 1, and no limitation on other fraction bits. + // Note that Torch, Tensorflow, OpenVino, nGraph uses 0x7FC0; Paddle uses 0x7FC1; CUDA uses 0x7FFF. + return onnxruntime::BFloat16::FromBits(0x7FC1U); // Bits: sign(0), exponent(111,1111,1), fraction(100,0001) + } + + static constexpr onnxruntime::BFloat16 signaling_NaN() noexcept { + // The most significant fraction bit shall be 0, and there is at least one 1 in other fraction bits. 
+ return onnxruntime::BFloat16::FromBits(0x7F81U); // Bits: sign(0), exponent(111,1111,1), fraction(000,0001) + } + + static constexpr onnxruntime::BFloat16 denorm_min() noexcept { + return onnxruntime::BFloat16::FromBits(0x0001U); // Minimum subnormal value: 9.1835e-41 + } + + static constexpr onnxruntime::BFloat16 epsilon() noexcept { + return onnxruntime::BFloat16::FromBits(0x3C00U); // Difference between 1.0 and the next value: 2^-7 = 0.0078125 + } + + static constexpr onnxruntime::BFloat16 round_error() noexcept { + return onnxruntime::BFloat16::FromBits(0x3F00U); // 0.5 + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_present; + static constexpr bool has_denorm_loss = false; + + static constexpr bool is_bounded = true; + static constexpr bool is_iec559 = false; + static constexpr bool is_modulo = false; + + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 5e39849186756..5d92ee86af864 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -102,6 +102,10 @@ struct Float8E4M3FN { #endif } + inline ORT_HOST_DEVICE bool IsNaN() const { + return (val & 0b01111111) == 0b01111111; + } + inline ORT_HOST_DEVICE float ToFloat() const { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)); @@ -266,6 +270,10 @@ struct Float8E4M3FNUZ { } } + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + inline ORT_HOST_DEVICE float ToFloat() const { // This type does not exist on CUDA. uint32_t res; @@ -416,6 +424,16 @@ struct Float8E5M2 { #endif } + inline ORT_HOST_DEVICE bool IsNaN() const { + // 7D, 7E, 7F are positive NaNs; FD, FE, FF are negative NaNs + return (val & 0b01111111) > 0b01111100; + } + + inline ORT_HOST_DEVICE bool IsInfinity() const { + // 7C and FC are infinity + return (val & 0b01111111) == 0b01111100; + } + inline ORT_HOST_DEVICE float ToFloat() const { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E5M2)); @@ -575,6 +593,10 @@ struct Float8E5M2FNUZ { } } + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + inline ORT_HOST_DEVICE float ToFloat() const { // This type does not exist on CUDA. 
uint32_t res; @@ -648,4 +670,251 @@ inline void FloatToFloat8E5M2FNUZ(const float* flt, Float8E5M2FNUZ* blf, size_t } // namespace onnxruntime +namespace std { + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E4M3FN lowest() { + return onnxruntime::Float8E4M3FN(0xFE, onnxruntime::Float8E4M3FN::FromBits()); // -448 + } + + static constexpr onnxruntime::Float8E4M3FN max() { + return onnxruntime::Float8E4M3FN(0x7E, onnxruntime::Float8E4M3FN::FromBits()); // 448 + } + + static constexpr onnxruntime::Float8E4M3FN min() { + return onnxruntime::Float8E4M3FN(0x08, onnxruntime::Float8E4M3FN::FromBits()); // 2^-6 = 0.015625 + } + + static constexpr onnxruntime::Float8E4M3FN denorm_min() { + return onnxruntime::Float8E4M3FN(0x01, onnxruntime::Float8E4M3FN::FromBits()); // 2^-9 = 0.001953125 + } + + static constexpr onnxruntime::Float8E4M3FN epsilon() { + return onnxruntime::Float8E4M3FN(0x20, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN round_error() { + return onnxruntime::Float8E4M3FN(0x30, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FN quiet_NaN() { + return onnxruntime::Float8E4M3FN(0x7F, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E5M2 lowest() { + return onnxruntime::Float8E5M2(0xFB, onnxruntime::Float8E5M2::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2 max() { + return onnxruntime::Float8E5M2(0x7B, onnxruntime::Float8E5M2::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2 min() { + return onnxruntime::Float8E5M2(0x4, onnxruntime::Float8E5M2::FromBits()); // 2^-14 = 0.00006103515 + } + + static constexpr onnxruntime::Float8E5M2 denorm_min() { + return onnxruntime::Float8E5M2(0x01, onnxruntime::Float8E5M2::FromBits()); // 2^-16 = 0.00001525878 + } + + static constexpr onnxruntime::Float8E5M2 epsilon() { + return onnxruntime::Float8E5M2(0x34, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 round_error() { + return onnxruntime::Float8E5M2(0x38, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 infinity() { + return onnxruntime::Float8E5M2(0x7C, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 quiet_NaN() { + 
return onnxruntime::Float8E5M2(0x7F, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits<onnxruntime::Float8E4M3FNUZ> { + public: + static constexpr onnxruntime::Float8E4M3FNUZ lowest() { + return onnxruntime::Float8E4M3FNUZ(0xFF, onnxruntime::Float8E4M3FNUZ::FromBits()); // -240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ max() { + return onnxruntime::Float8E4M3FNUZ(0x7F, onnxruntime::Float8E4M3FNUZ::FromBits()); // 240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ min() { + return onnxruntime::Float8E4M3FNUZ(0x08, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-7 = 0.0078125 + } + + static constexpr onnxruntime::Float8E4M3FNUZ denorm_min() { + return onnxruntime::Float8E4M3FNUZ(0x01, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-10 = 0.0009765625 + } + + static constexpr onnxruntime::Float8E4M3FNUZ epsilon() { + return onnxruntime::Float8E4M3FNUZ(0x28, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ round_error() { + return onnxruntime::Float8E4M3FNUZ(0x38, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FNUZ quiet_NaN() { + return onnxruntime::Float8E4M3FNUZ(0x80, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits<onnxruntime::Float8E5M2FNUZ> { + public: + static constexpr onnxruntime::Float8E5M2FNUZ lowest() { + return onnxruntime::Float8E5M2FNUZ(0xFF, 
onnxruntime::Float8E5M2FNUZ::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ max() { + return onnxruntime::Float8E5M2FNUZ(0x7F, onnxruntime::Float8E5M2FNUZ::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ min() { + return onnxruntime::Float8E5M2FNUZ(0x04, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-15 = 0.00003051757 + } + + static constexpr onnxruntime::Float8E5M2FNUZ denorm_min() { + return onnxruntime::Float8E5M2FNUZ(0x01, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-17 = 0.00000762939 + } + + static constexpr onnxruntime::Float8E5M2FNUZ epsilon() { + return onnxruntime::Float8E5M2FNUZ(0x34, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2FNUZ round_error() { + return onnxruntime::Float8E5M2FNUZ(0x38, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E5M2FNUZ quiet_NaN() { + return onnxruntime::Float8E5M2FNUZ(0x80, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +} // namespace std + #endif // DISABLE_FLOAT8_TYPES diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8237ac4220f24..39e0361b7ff4f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3651,10 +3651,10 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Only used for float32 model. + "enable_htp_fp16_precision": Used for float32 model for HTP backend. Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": Default. With fp32 precision. - - "1": With fp16 precision. + - "0": With fp32 precision. + - "1": Default. With fp16 precision. "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - "0": Default. Disabled. - "1": Enabled. 
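The OnnxTensor changes below rework how the Java bindings move data across the JNI boundary: tensors created from Java arrays are now copied once into a tensor-owned direct buffer (rather than being filled natively by recursive JNI array walks), getBufferRef() returns a duplicate()d view so concurrent readers cannot corrupt the buffer position, getValue() reads back through that buffer, and a new createTensor(env, ShortBuffer, shape, type) overload creates INT16, FLOAT16 or BFLOAT16 tensors from raw 16-bit values. A minimal usage sketch of the resulting behavior, using only APIs visible in this diff; the class name and printed values are illustrative, not part of the change:

```java
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;

/** Illustrative sketch only; not part of the PR. */
public final class BufferBackedTensorSketch {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();

    // Array-created tensors are now backed by a direct buffer owned by the
    // tensor, so getBufferRef() is present and writes through the duplicated
    // view are visible to the native OrtValue.
    float[][] data = {{0f, 1f, 2f}, {3f, 4f, 5f}};
    try (OnnxTensor t = OnnxTensor.createTensor(env, data)) {
      FloatBuffer view = (FloatBuffer) t.getBufferRef().get();
      view.put(0, 42f); // in-place update, no new tensor allocation
      float[][] updated = (float[][]) t.getValue();
      System.out.println(updated[0][0]); // 42.0
    }

    // The new ShortBuffer overload takes an explicit element type, so fp16
    // tensors can be built directly from raw IEEE 754 half-precision bits.
    ShortBuffer fp16 =
        ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.nativeOrder()).asShortBuffer();
    fp16.put(new short[] {0x3C00, 0x4000, 0x4200, 0x4400}); // 1.0, 2.0, 3.0, 4.0
    fp16.rewind();
    try (OnnxTensor t =
        OnnxTensor.createTensor(env, fp16, new long[] {2, 2}, OnnxJavaType.FLOAT16)) {
      // fp16 tensors now read back as widened float arrays via getValue().
      float[][] widened = (float[][]) t.getValue();
      System.out.println(widened[1][1]); // 4.0
    }
  }
}
```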
diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index e1ee2c14fd9d1..3f276a3670156 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -54,7 +54,7 @@ public class OnnxTensor extends OnnxTensorLike { * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is - * a copy of a user buffer.) + * a copy of a user buffer or array.) */ public boolean ownsBuffer() { return this.ownsBuffer; @@ -62,8 +62,8 @@ public boolean ownsBuffer() { /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not - * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by - * ORT) this method returns an empty {@link Optional}. + * backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty + * {@link Optional}. * *
<p>
Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be * used to repeatedly update a single tensor for multiple different inferences without allocating @@ -77,7 +77,116 @@ public boolean ownsBuffer() { * @return A reference to the buffer. */ public Optional<Buffer> getBufferRef() { - return Optional.ofNullable(buffer); + return Optional.ofNullable(duplicate(buffer)); + } + + /** + * Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent + * writes will modify the underlying memory in a racy way; don't do that. + * + *
<p>
Can be replaced by a call to buf.duplicate() in Java 9+. + * + * @param buf The buffer to duplicate. + * @return A copy of the buffer which refers to the same underlying memory, but has an independent + * position, limit and mark, or null if the supplied buffer was null. + */ + private static Buffer duplicate(Buffer buf) { + if (buf == null) { + // Tensors backed by ORT-allocated memory have no Java buffer to duplicate, + // so getBufferRef() returns an empty Optional for them. + return null; + } + if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder()); + } else if (buf instanceof ShortBuffer) { + return ((ShortBuffer) buf).duplicate(); + } else if (buf instanceof IntBuffer) { + return ((IntBuffer) buf).duplicate(); + } else if (buf instanceof LongBuffer) { + return ((LongBuffer) buf).duplicate(); + } else if (buf instanceof FloatBuffer) { + return ((FloatBuffer) buf).duplicate(); + } else if (buf instanceof DoubleBuffer) { + return ((DoubleBuffer) buf).duplicate(); + } else { + throw new IllegalStateException("Unknown buffer type " + buf.getClass()); + } + } + + /** + * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link + * ByteBuffer} then converts it to the right type. If it's not convertible, it throws {@link + * IllegalStateException}. + * + *
<p>
Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve + * compatibility with existing {@link #getValue} calls. + * + * @param buf The buffer to convert. + * @return The buffer with the expected type. + */ + private Buffer castBuffer(Buffer buf) { + switch (info.type) { + case FLOAT: + if (buf instanceof FloatBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asFloatBuffer(); + } + break; + case DOUBLE: + if (buf instanceof DoubleBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asDoubleBuffer(); + } + break; + case BOOL: + case INT8: + case UINT8: + if (buf instanceof ByteBuffer) { + return buf; + } + break; + case BFLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer bf16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } + break; + case FLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer fp16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } + break; + case INT16: + if (buf instanceof ShortBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asShortBuffer(); + } + break; + case INT32: + if (buf instanceof IntBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asIntBuffer(); + } + break; + case INT64: + if (buf instanceof LongBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asLongBuffer(); + } + break; + } + throw new IllegalStateException( + "Invalid buffer type for cast operation, found " + + buf.getClass() + + " expected something convertible to " + + info.type); } @Override @@ -133,15 +242,26 @@ public Object getValue() throws OrtException { Object carrier = info.makeCarrier(); if (info.getNumElements() > 0) { // If the tensor has values copy them out - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); - } - if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { - // We read the strings out from native code in a flat array and then reshape - // to the desired output shape. - return OrtUtil.reshape((String[]) carrier, info.shape); - } else { - return carrier; + if (info.type == OnnxJavaType.STRING) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape if necessary. 
+ getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier); + if (info.shape.length != 1) { + carrier = OrtUtil.reshape((String[]) carrier, info.shape); + } + } else { + // Wrap ORT owned memory in buffer, otherwise use our reference + Buffer buf; + if (buffer == null) { + buf = castBuffer(getBuffer()); + } else { + buf = castBuffer(duplicate(buffer)); + } + // Copy out buffer into arrays + OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier); + } } + return carrier; } } @@ -175,8 +295,8 @@ public synchronized void close() { public ByteBuffer getByteBuffer() { checkClosed(); if (info.type != OnnxJavaType.STRING) { - ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); - ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); + ByteBuffer buffer = getBuffer(); + ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder()); output.put(buffer); output.rewind(); return output; @@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() { output.rewind(); return output; } else if (info.type == OnnxJavaType.FLOAT16) { - // if it's fp16 we need to copy it out by hand. + // if it's fp16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { - // if it's bf16 we need to copy it out by hand. + // if it's bf16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); @@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray(long apiHandle, long nativeHandle, Object carrier) + private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier) throws OrtException; private native void close(long apiHandle, long nativeHandle); @@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec info); } } else { + Buffer buf; if (info.shape.length == 0) { - data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data); - if (data == null) { + buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data); + if (buf == null) { throw new OrtException( "Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + data); } + } else { + buf = OrtUtil.convertArrayToBuffer(info, data); } return new OnnxTensor( - createTensor( - OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value), + createTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + buf, + 0, + info.type.size * info.numElements, + info.shape, + info.onnxType.value), allocator.handle, - info); + info, + buf, + true); } } else { throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator."); @@ -627,7 +758,26 @@ static OnnxTensor createTensor( */ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape) throws OrtException { - return createTensor(env, env.defaultAllocator, data, shape); + return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16); + } + + /** + * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. + * + *
<p>
If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime + * of the tensor. Uses the default allocator. + * + * @param env The current OrtEnvironment. + * @param data The tensor data. + * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. + * @return An OnnxTensor of the required shape. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + public static OnnxTensor createTensor( + OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { + return createTensor(env, env.defaultAllocator, data, shape, type); } /** @@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( - OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) + OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { if (!allocator.isClosed()) { - OnnxJavaType type = OnnxJavaType.INT16; - return createTensor(type, allocator, data, shape); + if ((type == OnnxJavaType.BFLOAT16) + || (type == OnnxJavaType.FLOAT16) + || (type == OnnxJavaType.INT16)) { + return createTensor(type, allocator, data, shape); + } else { + throw new IllegalArgumentException( + "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer."); + } } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -768,10 +926,6 @@ private static OnnxTensor createTensor( tuple.isCopy); } - private static native long createTensor( - long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType) - throws OrtException; - private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 4f3dee3c00b91..2f44236e4ef67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -26,10 +26,10 @@ public final class OrtUtil { private OrtUtil() {} /** - * Converts an long shape into a int shape. + * Converts a long shape into an int shape. * - *
<p>
Validates that the shape has more than 1 elements, less than 9 elements, each element is - * less than {@link Integer#MAX_VALUE} and that each entry is non-negative. + *
<p>
Validates that the shape has more than 1 element, less than 9 elements, each element is less + * than {@link Integer#MAX_VALUE} and that each entry is non-negative. * * @param shape The long shape. * @return The int shape. @@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) { } } + /** + * Stores a boxed primitive in a single element buffer of the unboxed type. + * + *
<p>
If it's not a boxed primitive then it returns null. + * + * @param javaType The type of the boxed primitive. + * @param data The boxed primitive. + * @return The primitive in a direct buffer. + */ + static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) { + switch (javaType) { + case FLOAT: + { + FloatBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + buf.put(0, (Float) data); + return buf; + } + case DOUBLE: + { + DoubleBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + buf.put(0, (Double) data); + return buf; + } + case BOOL: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0); + return buf; + } + case UINT8: + case INT8: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, (Byte) data); + return buf; + } + case FLOAT16: + case BFLOAT16: + case INT16: + { + ShortBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asShortBuffer(); + buf.put(0, (Short) data); + return buf; + } + case INT32: + { + IntBuffer buf = + ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer(); + buf.put(0, (Integer) data); + return buf; + } + case INT64: + { + LongBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + buf.put(0, (Long) data); + return buf; + } + case STRING: + case UNKNOWN: + default: + return null; + } + } + + /** + * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}. + * + *
<p>
Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or + * if the array is ragged. + * + * @param info The tensor info object containing the types and shape of the array. + * @param array The array object. + * @return A direct buffer containing all the elements. + */ + static Buffer convertArrayToBuffer(TensorInfo info, Object array) { + ByteBuffer byteBuffer = + ByteBuffer.allocateDirect((int) info.numElements * info.type.size) + .order(ByteOrder.nativeOrder()); + + Buffer buffer; + switch (info.type) { + case FLOAT: + buffer = byteBuffer.asFloatBuffer(); + break; + case DOUBLE: + buffer = byteBuffer.asDoubleBuffer(); + break; + case BOOL: + case INT8: + case UINT8: + // no-op, it's already a bytebuffer + buffer = byteBuffer; + break; + case BFLOAT16: + case FLOAT16: + case INT16: + buffer = byteBuffer.asShortBuffer(); + break; + case INT32: + buffer = byteBuffer.asIntBuffer(); + break; + case INT64: + buffer = byteBuffer.asLongBuffer(); + break; + case STRING: + case UNKNOWN: + default: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + + fillBufferFromArray(info, array, 0, buffer); + + if (buffer.remaining() != 0) { + throw new IllegalArgumentException( + "Failed to copy all elements into the buffer, expected to copy " + + info.numElements + + " into a buffer of capacity " + + buffer.capacity() + + " but had " + + buffer.remaining() + + " values left over."); + } + buffer.rewind(); + + return buffer; + } + + /** + * Fills the provided buffer with the values from the array, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param array The array object to read from. + * @param curDim The current dimension we're processing. + * @param buffer The buffer to write to. + */ + private static void fillBufferFromArray( + TensorInfo info, Object array, int curDim, Buffer buffer) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.put(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.put(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.put(bArr); + break; + case FLOAT16: + case BFLOAT16: + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.put(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.put(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.put(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + boolBuf.put(boolArr[i] ? 
(byte) 1 : (byte) 0); + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer); + } + } + } + } + + /** + * Fills the provided array with the values from the buffer, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param buffer The buffer to read from. + * @param curDim The current dimension we're processing. + * @param array The array object to write to. + */ + static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT16: + case BFLOAT16: + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.get(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.get(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.get(bArr); + break; + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.get(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.get(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.get(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false. + boolArr[i] = boolBuf.get() != 0; + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i)); + } + } + } + } + /** * Returns expected JDK map capacity for a given size, this factors in the default JDK load factor * diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 1c21387b50455..f3e9f21ef408d 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -323,6 +323,9 @@ public long getNumElements() { * all elements as that's the expected format of the native code. It can be reshaped to the * correct shape using {@link OrtUtil#reshape(String[],long[])}. * + *
<p>
For fp16 and bf16 tensors the output carrier type is float, and so this method produces + * multidimensional float arrays. + * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). @@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException { + Arrays.toString(shape)); } switch (type) { + case BFLOAT16: + case FLOAT16: case FLOAT: return OrtUtil.newFloatArray(shape); case DOUBLE: diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 7b26291581395..6a3c279073860 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque return sequenceInfo; } -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) { - int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray); - int64_t consumedSize = inputLength * onnxTypeSize(onnxType); - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray)inputArray; - (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray)inputArray; - (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray)inputArray; - (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray)inputArray; - (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported."); - return -1; - /* - float *floatArr = malloc(sizeof(float) * inputLength); - uint16_t *halfArr = (uint16_t *) outputTensor; - for (uint32_t i = 0; i < inputLength; i++) { - floatArr[i] = convertHalfToFloat(halfArr[i]); - } - jfloatArray typedArr = (jfloatArray) inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr); - free(floatArr); - return consumedSize; - */ - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray)inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray)inputArray; - (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor); - return consumedSize; - } - case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported."); - return -1; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray)inputArray; - (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - default: { - throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type."); - return -1; - } - } -} - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray inputObjArr = (jobjectArray)inputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i); - int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. - if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) { int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray); if (outputLength == 0) return 0; @@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT } } -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, - size_t dimensionsRemaining, jarray outputArray) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray outputObjArr = (jobjectArray)outputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i); - int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. 
- if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - jobject tempString = NULL; - // Get the buffer size needed - size_t totalStringLength = 0; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); - if (code != ORT_OK) { - return NULL; - } - - // Create the character and offset buffers, character is one larger to allow zero termination. - char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1)); - if (characterBuffer == NULL) { - throwOrtException(jniEnv, 1, "OOM error"); - } else { - size_t * offsets = malloc(sizeof(size_t)); - if (offsets != NULL) { - // Get a view on the String data - code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1)); - - if (code == ORT_OK) { - size_t curSize = (offsets[0]) + 1; - characterBuffer[curSize-1] = '\0'; - tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer); - } - - free((void*)characterBuffer); - free((void*)offsets); - } - } - - return tempString; -} - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { size_t bufferSize = 16; char * tempBuffer = malloc(bufferSize); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 023bc0c739583..7f41e06371f2a 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor); - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor); - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray); -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index b694f57357bb5..d757bd6281499 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -8,72 +8,6 @@ #include "OrtJniUtil.h" #include "ai_onnxruntime_OnnxTensor.h" -/* - * Class: ai_onnxruntime_OnnxTensor - * Method: createTensor - * Signature: (JJLjava/lang/Object;[JI)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, - jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
- const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Convert type to ONNX C enum - ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); - - // Extract the shape information - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); - - // Create the OrtValue - OrtValue* ortValue = NULL; - OrtErrorCode code = checkOrtStatus(jniEnv, api, - api->CreateTensorAsOrtValue( - allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue - ) - ); - (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - - int failed = 0; - if (code == ORT_OK) { - // Get a reference to the OrtValue's data - uint8_t* tensorData = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData)); - if (code == ORT_OK) { - // Check if we're copying a scalar or not - if (shapeLen == 0) { - // Scalars are passed in as a single element array - int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - // Extract the tensor shape information - JavaTensorTypeShape typeShape; - code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - - if (code == ORT_OK) { - // Copy the java array into the tensor - int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount, - typeShape.dimensions, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - failed = 1; - } - } - } else { - failed = 1; - } - } - - if (failed) { - api->ReleaseValue(ortValue); - ortValue = NULL; - } - - // Return the pointer to the OrtValue - return (jlong) ortValue; -} - /* * Class: ai_onnxruntime_OnnxTensor * Method: createTensorFromBuffer @@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer size_t sizeBytes = typeShape.elementCount * typeSize; uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr)); if (code == ORT_OK) { return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes); @@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool /* * Class: ai_onnxruntime_OnnxTensor - * Method: getArray - * Signature: (JJLjava/lang/Object;)V + * Method: getStringArray + * Signature: (JJ[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { - uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr)); - if (code == ORT_OK) { - copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount, - typeShape.dimensions, (jarray)carrier); - } + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here."); } } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 11141a3a65a3e..7cb6305923279 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException { container.put("wrong_name", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect name."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unknown input name")); + } finally { + OnnxValue.close(container.values()); } } } @@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException { container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect type."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected input data type")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongSizeInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + float[] wrongSizeData = Arrays.copyOf(inputData, 2 * 224 * 224); + Object tensor = OrtUtil.reshape(wrongSizeData, new long[] {1, 2, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Got invalid dimensions for input")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongRankInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + Object tensor = OrtUtil.reshape(inputData, new long[] {1, 1, 3, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Invalid rank for input")); + } finally { + OnnxValue.close(container.values()); } } } @@ 
-550,12 +595,12 @@ public void throwExtraInputs() throws OrtException { container.put("extra", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for too many inputs."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected number of inputs")); + } finally { + OnnxValue.close(container.values()); } } } @@ -565,12 +610,11 @@ public void testMultiThreads() throws OrtException, InterruptedException { int numThreads = 10; int loop = 10; SqueezeNetTuple tuple = openSessionSqueezeNet(); + Map container = new HashMap<>(); try (OrtSession session = tuple.session) { - float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - Map container = new HashMap<>(); long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; Object tensor = OrtUtil.reshape(inputData, inputShape); container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); @@ -592,8 +636,9 @@ public void testMultiThreads() throws OrtException, InterruptedException { } executor.shutdown(); executor.awaitTermination(1, TimeUnit.MINUTES); - OnnxValue.close(container.values()); assertTrue(executor.isTerminated()); + } finally { + OnnxValue.close(container.values()); } } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index ea210d96c1507..064f14f3b51ff 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -12,8 +12,11 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.IntBuffer; import java.nio.ShortBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.SplittableRandom; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -93,30 +96,108 @@ public void testScalarCreation() throws OrtException { } @Test - public void testBufferCreation() throws OrtException { + public void testArrayCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Test creating a value from an array - // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + // Test creating a value from a single dimensional array float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { - // array creation isn't backed by buffers - assertFalse(t.ownsBuffer()); - assertFalse(t.getBufferRef().isPresent()); - FloatBuffer buf = t.getFloatBuffer(); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); float[] output = new float[arrValues.length]; buf.get(output); Assertions.assertArrayEquals(arrValues, output); - // Can't modify the tensor through this buffer. + // Can modify the tensor through this buffer. 
buf.put(0, 25); - Assertions.assertArrayEquals(arrValues, output); + Assertions.assertArrayEquals(new float[] {25, 1, 2, 3, 4}, (float[]) t.getValue()); } + // Test creating a value from a multidimensional float array + float[][][] arr3dValues = + new float[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, arr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + float[][][] output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + + // Can modify the tensor through the buffer. + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); + buf.put(0, 25); + buf.put(12, 32); + buf.put(13, 33); + buf.put(23, 35); + arr3dValues[0][0][0] = 25; + arr3dValues[2][0][0] = 32; + arr3dValues[2][0][1] = 33; + arr3dValues[3][1][2] = 35; + output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + } + + // Test creating a value from a multidimensional int array + int[][][] iArr3dValues = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, iArr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + int[][][] output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + + // Can modify the tensor through the buffer. + IntBuffer buf = (IntBuffer) t.getBufferRef().get(); + buf.put(0, 25); + iArr3dValues[0][0][0] = 25; + output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + } + + // Test creating a value from a ragged array throws + int[][][] ragged = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}}, + {{12, 13}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, ragged)) { + Assertions.fail("Can't create tensors from ragged arrays"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("ragged")); + } + + // Test creating a value from a non-array, non-primitive type throws. 
+ List list = new ArrayList<>(5); + list.add(5); + try (OnnxTensor t = OnnxTensor.createTensor(env, list)) { + Assertions.fail("Can't create tensors from lists"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("Cannot convert")); + } + } + + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Test creating a value from a non-direct byte buffer // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap - // direct byte buffers - // which can be directly passed to ORT + // direct byte buffers which can be directly passed to ORT + float[] arrValues = new float[] {0, 1, 2, 3, 4}; FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); nonDirectBuffer.put(arrValues); nonDirectBuffer.rewind(); @@ -335,10 +416,12 @@ public void testFp32ToFp16() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -347,6 +430,8 @@ public void testFp32ToFp16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToFp16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.fp16ToFloat(Fp16Conversions.floatToFp16(input[i][j])); } } floatBuf.rewind(); @@ -354,25 +439,31 @@ public void testFp32ToFp16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound fp16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } @@ -382,10 +473,12 @@ public void testFp32ToBf16() throws OrtException { String 
modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -394,6 +487,8 @@ public void testFp32ToBf16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToBf16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.bf16ToFloat(Fp16Conversions.floatToBf16(input[i][j])); } } floatBuf.rewind(); @@ -401,25 +496,31 @@ try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound bf16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts index 52e028a9fcd31..cbc0270091818 100644 --- a/js/common/lib/tensor-factory-impl.ts +++ b/js/common/lib/tensor-factory-impl.ts @@ -11,6 +11,7 @@ import { TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, + TensorFromMLTensorOptions, TensorFromTextureOptions, TensorFromUrlOptions, } from './tensor-factory.js'; @@ -152,7 +153,7 @@ export const tensorFromImage = async ( } }; const createCanvasContext = (canvas: HTMLCanvasElement | OffscreenCanvas) => { - if (canvas instanceof HTMLCanvasElement) { + if (typeof HTMLCanvasElement !== 'undefined' && canvas instanceof HTMLCanvasElement) { return canvas.getContext('2d'); } else if (canvas instanceof OffscreenCanvas) { return canvas.getContext('2d') as OffscreenCanvasRenderingContext2D; @@ -310,6 +311,17 @@ export const tensorFromGpuBuffer = <T extends TensorInterface.GpuBufferDataTypes>( +/** + * implementation of Tensor.fromMLTensor(). + */ +export const tensorFromMLTensor = <T extends TensorInterface.MLTensorDataTypes>( + mlTensor: TensorInterface.MLTensorType, + options: 
TensorFromMLTensorOptions, +): Tensor => { + const { dataType, dims, download, dispose } = options; + return new Tensor({ location: 'ml-tensor', type: dataType ?? 'float32', mlTensor, dims, download, dispose }); +}; + /** * implementation of Tensor.fromPinnedBuffer(). */ diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 7938b4a4eb927..f66684112623e 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -86,6 +86,20 @@ export interface GpuBufferConstructorParameters + extends CommonConstructorParameters, + GpuResourceConstructorParameters { + /** + * Specify the location of the data to be 'ml-tensor'. + */ + readonly location: 'ml-tensor'; + + /** + * Specify the WebNN MLTensor that holds the tensor data. + */ + readonly mlTensor: Tensor.MLTensorType; +} + // #endregion // the following region contains type definitions of each individual options. @@ -219,6 +233,15 @@ export interface TensorFromGpuBufferOptions dataType?: T; } +export interface TensorFromMLTensorOptions + extends Pick, + GpuResourceConstructorParameters { + /** + * Describes the data type of the tensor. + */ + dataType?: T; +} + // #endregion /** @@ -336,6 +359,29 @@ export interface TensorFactory { options: TensorFromGpuBufferOptions, ): TypedTensor; + /** + * create a tensor from a WebNN MLTensor + * + * @param tensor - the MLTensor object to create the tensor from + * @param options - An object representing options for creating a tensor from a WebNN MLTensor. + * + * The options include the following properties: + * - `dataType`: the data type of the tensor. If omitted, assume 'float32'. + * - `dims`: the dimension of the tensor. Required. + * - `download`: an optional function to download the tensor data from the MLTensor to CPU. If omitted, the MLTensor + * data cannot be downloaded. Usually, this is provided by the WebNN backend for the inference outputs. + * Users don't need to provide this function. + * - `dispose`: an optional function to dispose the tensor data on the WebNN MLTensor. If omitted, the MLTensor will + * not be disposed. Usually, this is provided by the WebNN backend for the inference outputs. Users don't need to + * provide this function. + * + * @returns a tensor object + */ + fromMLTensor( + tensor: Tensor.MLTensorType, + options: TensorFromMLTensorOptions, + ): TypedTensor; + /** * create a tensor from a pre-allocated buffer. The buffer will be used as a pinned buffer.
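For illustration, a minimal usage sketch of the fromMLTensor factory documented above, assuming a browser that exposes the experimental WebNN MLTensor API and the MLTensorUsage flags declared later in this PR (variable names are hypothetical):

  const mlContext = await navigator.ml.createContext({ deviceType: 'gpu' });
  const mlTensor = await mlContext.createTensor({
    dataType: 'float32',
    shape: [2, 2],
    // assign both fields while the WebNN API transitions from 'dimensions' to 'shape'
    dimensions: [2, 2],
    usage: MLTensorUsage.READ | MLTensorUsage.WRITE,
  });
  mlContext.writeTensor(mlTensor, new Float32Array([1, 2, 3, 4]));
  const tensor = ort.Tensor.fromMLTensor(mlTensor, {
    dataType: 'float32',
    dims: [2, 2],
    dispose: () => mlTensor.destroy(),
  });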
* diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 342f5e3a467eb..c0e1582c17de5 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -6,16 +6,19 @@ import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conve import { tensorFromGpuBuffer, tensorFromImage, + tensorFromMLTensor, tensorFromPinnedBuffer, tensorFromTexture, } from './tensor-factory-impl.js'; import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, + MLTensorConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, + TensorFromMLTensorOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters, @@ -37,6 +40,7 @@ type TensorDataType = TensorInterface.DataType; type TensorDataLocation = TensorInterface.DataLocation; type TensorTextureType = TensorInterface.TextureType; type TensorGpuBufferType = TensorInterface.GpuBufferType; +type TensorMLTensorType = TensorInterface.MLTensorType; /** * the implementation of Tensor interface. @@ -86,6 +90,15 @@ export class Tensor implements TensorInterface { */ constructor(params: GpuBufferConstructorParameters); + /** + * Construct a new tensor object from the WebNN MLTensor with the given type and dims. + * + * Tensor's location will be set to 'ml-tensor'. + * + * @param params - Specify the parameters to construct the tensor. + */ + constructor(params: MLTensorConstructorParameters); + /** * implementation. */ @@ -98,7 +111,8 @@ export class Tensor implements TensorInterface { | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters - | GpuBufferConstructorParameters, + | GpuBufferConstructorParameters + | MLTensorConstructorParameters, arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { @@ -155,6 +169,25 @@ export class Tensor implements TensorInterface { this.disposer = arg0.dispose; break; } + case 'ml-tensor': { + if ( + type !== 'float32' && + type !== 'float16' && + type !== 'int32' && + type !== 'int64' && + type !== 'uint32' && + type !== 'uint64' && + type !== 'int8' && + type !== 'uint8' && + type !== 'bool' + ) { + throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`); + } + this.mlTensorData = arg0.mlTensor; + this.downloader = arg0.download; + this.disposer = arg0.dispose; + break; + } default: throw new Error(`Tensor constructor: unsupported location '${this.dataLocation}'`); } @@ -325,6 +358,13 @@ export class Tensor implements TensorInterface { return tensorFromGpuBuffer(gpuBuffer, options); } + static fromMLTensor( + mlTensor: TensorMLTensorType, + options: TensorFromMLTensorOptions, + ): TensorInterface { + return tensorFromMLTensor(mlTensor, options); + } + static fromPinnedBuffer( type: T, buffer: TensorInterface.DataTypeMap[T], @@ -373,6 +413,11 @@ export class Tensor implements TensorInterface { */ private gpuBufferData?: TensorGpuBufferType; + /** + * stores the underlying WebNN MLTensor when location is 'ml-tensor'. otherwise empty. + */ + private mlTensorData?: TensorMLTensorType; + /** * stores an optional downloader function to download data from GPU to CPU. 
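As a sketch, the 'ml-tensor' constructor branch above can also be exercised directly with MLTensorConstructorParameters, assuming an existing mlTensor handle (hypothetical):

  const t = new Tensor({ location: 'ml-tensor', type: 'float32', mlTensor, dims: [1, 4] });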
*/ @@ -420,6 +465,14 @@ export class Tensor implements TensorInterface { } return this.gpuBufferData; } + + get mlTensor(): TensorMLTensorType { + this.ensureValid(); + if (!this.mlTensorData) { + throw new Error('The data is not stored as a WebNN MLTensor.'); + } + return this.mlTensorData; + } // #endregion // #region methods @@ -431,7 +484,8 @@ export class Tensor implements TensorInterface { case 'cpu-pinned': return this.data; case 'texture': - case 'gpu-buffer': { + case 'gpu-buffer': + case 'ml-tensor': { if (!this.downloader) { throw new Error('The current tensor is not created with a specified data downloader.'); } @@ -472,6 +526,7 @@ export class Tensor implements TensorInterface { this.cpuData = undefined; this.gpuTextureData = undefined; this.gpuBufferData = undefined; + this.mlTensorData = undefined; this.downloader = undefined; this.isDownloading = undefined; diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts index 9c633cd95fac3..97b1735e6eac5 100644 --- a/js/common/lib/tensor-utils-impl.ts +++ b/js/common/lib/tensor-utils-impl.ts @@ -4,6 +4,7 @@ import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, + MLTensorConstructorParameters, TextureConstructorParameters, } from './tensor-factory.js'; import { Tensor } from './tensor-impl.js'; @@ -56,6 +57,13 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor = type: tensor.type as GpuBufferConstructorParameters['type'], dims, }); + case 'ml-tensor': + return new Tensor({ + location: 'ml-tensor', + mlTensor: tensor.mlTensor, + type: tensor.type as MLTensorConstructorParameters['type'], + dims, + }); default: throw new Error(`tensorReshape: tensor location ${tensor.location} is not supported`); } diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 8a1197994393b..17e2f4d37c91f 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -42,6 +42,13 @@ interface TypedTensorBase { */ readonly gpuBuffer: Tensor.GpuBufferType; + /** + * Get the WebNN MLTensor that holds the tensor data. + * + * If the data is not in a WebNN MLTensor, throw error. + */ + readonly mlTensor: Tensor.MLTensorType; + /** * Get the buffer data of the tensor. * @@ -136,15 +143,36 @@ export declare namespace Tensor { */ export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; + /** + * type alias for WebNN MLTensor + * + * The specification for WebNN's MLTensor is currently in flux. 
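A sketch of how a consumer interacts with the accessors added above once a tensor lives on an MLTensor (hypothetical variable names):

  const handle = tensor.mlTensor;         // throws if the data is not stored as a WebNN MLTensor
  const cpuData = await tensor.getData(); // runs the registered downloader (throws if none was provided)
  tensor.dispose();                       // clears mlTensorData and invokes the disposer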
+ */ + export type MLTensorType = unknown; + /** * supported data types for constructing a tensor from a WebGPU buffer */ export type GpuBufferDataTypes = 'float32' | 'float16' | 'int32' | 'int64' | 'uint32' | 'uint8' | 'bool'; + /** + * supported data types for constructing a tensor from a WebNN MLTensor + */ + export type MLTensorDataTypes = + | 'float32' + | 'float16' + | 'int8' + | 'uint8' + | 'int32' + | 'uint32' + | 'int64' + | 'uint64' + | 'bool'; + /** * represent where the tensor data is stored */ - export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer'; + export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer' | 'ml-tensor'; /** * represent the data type of a tensor diff --git a/js/react_native/e2e/.detoxrc.js b/js/react_native/e2e/.detoxrc.js index 0792c3d528585..e886a363d378b 100644 --- a/js/react_native/e2e/.detoxrc.js +++ b/js/react_native/e2e/.detoxrc.js @@ -6,7 +6,7 @@ module.exports = { config: 'test/jest.config.js', }, jest: { - setupTimeout: 120000, + setupTimeout: 240000, }, }, apps: { diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 6fd4f9af20432..6c50f3752737b 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -53,6 +53,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | | | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 | +| LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | | | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | | | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output | diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 39f8c2a6d0db3..bfb74355b0d70 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -785,15 +785,20 @@ export class WebGpuBackend { this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); } + // The buffer may be user-created or managed by the GPU data manager. + // Either way, the GPU data manager will not manage the lifecycle of these buffers, so we register them as external buffers. + // + // The map `sessionInputOutputMapping` is used to store the data ID and buffer for each input/output. Once a + // specific input/output is registered, the data ID will not change.
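+ // Concretely: re-registering the same GPUBuffer for an index is a cheap no-op, and a different buffer is
+ // remapped under the index's existing data ID (graph capture mode rejects swapping in a new buffer).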
const previousBuffer = sessionInputOutputMapping.get(index); - const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer); sessionInputOutputMapping.set(index, [id, buffer]); return id; } unregisterBuffers(sessionId: number): void { const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); if (sessionInputOutputMapping) { - sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[0])); this.sessionExternalDataMapping.delete(sessionId); } } diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts new file mode 100644 index 0000000000000..685f3dc019461 --- /dev/null +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + +import { Env, Tensor } from 'onnxruntime-common'; + +import { DataType } from '../wasm-common'; +import { getInstance } from '../wasm-factory'; + +import { createView } from './tensor-view'; + +import { TensorId, createTensorManager } from './webnn/tensor-manager'; +import { configureLogger, LOG_DEBUG } from './log'; + +/* + * TensorProto::data_type to WebNN OperandType mapping. + */ +const onnxDataTypeToWebnnDataType = new Map([ + [DataType.float, 'float32'], + [DataType.float16, 'float16'], + [DataType.int32, 'int32'], + [DataType.uint32, 'uint32'], + [DataType.int64, 'int64'], + [DataType.uint64, 'uint64'], + [DataType.int8, 'int8'], + [DataType.uint8, 'uint8'], + [DataType.bool, 'uint8'], +]); + +/** + * WebNN backend implementation. This class keeps track of the MLTensors created by the backend and of the + * current MLContext being used by the sessions. + */ +export class WebNNBackend { + /** + * Tensor manager for the MLTensors created by this backend. + */ + private tensorManager = createTensorManager(this); + /** + * Maps from session id to MLContexts. + */ + private mlContextBySessionId = new Map(); + /** + * Maps from MLContext to session ids. + */ + private sessionIdsByMLContext = new Map>(); + /** + * Current session id.
+ */ + private activeSessionId?: number; + + constructor(env: Env) { + configureLogger(env.logLevel!, !!env.debug); + } + + public get currentSessionId(): number { + if (this.activeSessionId === undefined) { + throw new Error('No active session'); + } + return this.activeSessionId; + } + + public onRunStart(sessionId: number): void { + this.activeSessionId = sessionId; + } + + public get currentContext(): MLContext { + const mlContext = this.getMLContext(this.currentSessionId); + if (!mlContext) { + throw new Error(`No MLContext found for session ${this.currentSessionId}`); + } + return mlContext; + } + + public registerMLContext(sessionId: number, mlContext: MLContext): void { + this.mlContextBySessionId.set(sessionId, mlContext); + let sessionIds = this.sessionIdsByMLContext.get(mlContext); + if (!sessionIds) { + sessionIds = new Set(); + this.sessionIdsByMLContext.set(mlContext, sessionIds); + } + sessionIds.add(sessionId); + } + + public onReleaseSession(sessionId: number): void { + const mlContext = this.mlContextBySessionId.get(sessionId)!; + if (!mlContext) { + // Current session is not a WebNN session. + return; + } + this.mlContextBySessionId.delete(sessionId); + const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; + sessionIds.delete(sessionId); + if (sessionIds.size === 0) { + this.sessionIdsByMLContext.delete(mlContext); + this.tensorManager.releaseTensorsForContext(mlContext); + } + } + + public getMLContext(sessionId: number): MLContext | undefined { + return this.mlContextBySessionId.get(sessionId); + } + + public reserveTensorId(): TensorId { + return this.tensorManager.reserveTensorId(); + } + + public releaseTensorId(tensorId: TensorId): void { + LOG_DEBUG('verbose', () => `[WebNN] releaseTensorId {tensorId: ${tensorId}}`); + this.tensorManager.releaseTensorId(tensorId); + } + + public async ensureTensor( + tensorId: TensorId, + onnxDataType: DataType, + dimensions: number[], + copyOld: boolean, + ): Promise { + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); + if (!webnnDataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld); + } + + public uploadTensor(tensorId: TensorId, data: Uint8Array): void { + const wasm = getInstance(); + if (!wasm.shouldTransferToMLTensor) { + throw new Error('Trying to upload to a MLTensor while shouldTransferToMLTensor is false'); + } + LOG_DEBUG('verbose', () => `[WebNN] uploadTensor {tensorId: ${tensorId}, data: ${data.byteLength}}`); + this.tensorManager.upload(tensorId, data); + } + + public async downloadTensor(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise { + return this.tensorManager.download(tensorId, dstBuffer); + } + + public createMLTensorDownloader(tensorId: TensorId, type: Tensor.MLTensorDataTypes): () => Promise { + return async () => { + const data = await this.tensorManager.download(tensorId); + return createView(data, type); + }; + } + + public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); + if (!webnnDataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + + const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions); + LOG_DEBUG( + 'verbose', + () => + `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${ + dimensions + }} -> {tensorId: 
${id}}`, + ); + return id; + } + + public flush(): void { + // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. + } +} diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 2f0e5da2b3f27..7bce5ff9390e8 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -12,6 +12,7 @@ import { LOG_DEBUG } from './log'; import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -266,6 +267,22 @@ export const init = async ( () => backend.replay(), ]); } else { - jsepInit('webnn'); + const backend = new WebNNBackend(env); + jsepInit('webnn', [ + backend, + // jsepReserveTensorId + () => backend.reserveTensorId(), + // jsepReleaseTensorId, + (tensorId: number) => backend.releaseTensorId(tensorId), + // jsepEnsureTensor + async (tensorId: number, onnxDataType: number, shape: number[], copyOld) => + backend.ensureTensor(tensorId, onnxDataType, shape, copyOld), + // jsepUploadTensor + (tensorId: number, data: Uint8Array) => { + backend.uploadTensor(tensorId, data); + }, + // jsepDownloadTensor + async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer), + ]); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 8e18a28acc364..33e8c95c141ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -52,12 +52,12 @@ export interface GpuDataManager { * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of * the external buffer. */ - registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number; /** * unregister an external buffer for IO Binding. */ - unregisterExternalBuffer(buffer: GPUBuffer): void; + unregisterExternalBuffer(id: GpuDataId): void; /** * destroy all gpu buffers. @@ -196,9 +196,6 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable uniform buffers private freeUniformBuffers: Map; - // The external buffers registered users for IO Binding. - private externalBuffers: Map; - // The pendingBuffers for capture graph. // a SessionID -> GPUBuffer[] mapping. 
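// Design note: the former `externalBuffers` reverse map (GPUBuffer -> id) is intentionally gone in this
// change; the WebGPU backend now keeps the [GpuDataId, GPUBuffer] pair per input/output itself, so
// external buffers are registered and unregistered by data ID directly.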
private capturedPendingBuffers: Map; @@ -209,7 +206,6 @@ class GpuDataManagerImpl implements GpuDataManager { this.freeUniformBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; - this.externalBuffers = new Map(); this.capturedPendingBuffers = new Map(); for (const [key] of bucketFreelist) { @@ -284,14 +280,11 @@ class GpuDataManagerImpl implements GpuDataManager { ); } - registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number { let id: number | undefined; - if (previousBuffer) { - id = this.externalBuffers.get(previousBuffer); - if (id === undefined) { - throw new Error('previous buffer is not registered'); - } - if (buffer === previousBuffer) { + if (previous) { + id = previous[0]; + if (buffer === previous[1]) { LOG_DEBUG( 'verbose', () => @@ -304,13 +297,11 @@ class GpuDataManagerImpl implements GpuDataManager { throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. Please use the previous external buffer!`); } - this.externalBuffers.delete(previousBuffer); } else { id = createNewGpuDataId(); } this.storageCache.set(id, { gpuData: { id, type: GpuDataType.default, buffer }, originalSize }); - this.externalBuffers.set(buffer, id); LOG_DEBUG( 'verbose', () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`, @@ -318,11 +309,9 @@ class GpuDataManagerImpl implements GpuDataManager { return id; } - unregisterExternalBuffer(buffer: GPUBuffer): void { - const id = this.externalBuffers.get(buffer); + unregisterExternalBuffer(id: GpuDataId): void { if (id !== undefined) { this.storageCache.delete(id); - this.externalBuffers.delete(buffer); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); } } diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts new file mode 100644 index 0000000000000..9475de019ed1d --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -0,0 +1,303 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { WebNNBackend } from '../backend-webnn'; +import { LOG_DEBUG } from '../log'; + +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + +export type TensorId = number; + +/** + * Manages TensorId to MLTensor mapping. + */ +export interface TensorManager { + /** + * Reserve a new TensorId. + */ + reserveTensorId(): TensorId; + /** + * Release a TensorId. + */ + releaseTensorId(tensorId: TensorId): void; + /** + * Ensure a MLTensor is created for the TensorId. + */ + ensureTensor( + tensorId: TensorId, + dataType: MLOperandDataType, + shape: readonly number[], + copyOld: boolean, + ): Promise; + /** + * Upload data to a MLTensor. + */ + upload(tensorId: TensorId, data: Uint8Array): void; + /** + * Download data from a MLTensor. + */ + download(tensorId: TensorId): Promise; + download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise; + /** + * Release all tensors for a MLContext. + */ + releaseTensorsForContext(mlContext: MLContext): void; + /** + * Register an externally created MLTensor with a given MLContext and return a TensorId. 
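+ * (For orientation: the WebNN backend calls this from its registerMLTensor() entry point when a
+ * user-provided MLTensor enters a run via IO binding; registering the same MLTensor again returns
+ * the same TensorId.)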
+ */ + registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; +} + +let tensorGuid = 1; +const createNewTensorId = (): TensorId => tensorGuid++; + +export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]]; + +/** + * TensorTracker tracks the MLTensor and pending upload data. + * + * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until + * we know the data type and shape. This is because future implementations of WebNN will only support creating + * MLTensors with dataTypes and shape. + */ +class TensorTracker { + private tensorEntry?: MLTensorEntry; + private activeUpload?: Uint8Array; + private tensorCache: MLTensorEntry[]; + + constructor( + private mlContext?: MLContext, + tensorEntry?: MLTensorEntry, + ) { + this.tensorEntry = tensorEntry; + this.tensorCache = tensorEntry ? [tensorEntry] : []; + } + + public get tensor(): MLTensor | undefined { + return this.tensorEntry?.[0]; + } + + public get context(): MLContext { + if (!this.mlContext) { + throw new Error('MLContext has not been set.'); + } + return this.mlContext; + } + + public set context(mlContext: MLContext) { + if (this.mlContext && this.mlContext !== mlContext) { + throw new Error('MLTensor in use in a different MLContext.'); + } + this.mlContext = mlContext; + } + + public destroy(): void { + for (const [mlTensor] of this.tensorCache) { + mlTensor.destroy(); + } + this.tensorCache = []; + this.tensorEntry = undefined; + } + + public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean { + for (const [mlTensor, dataType, shape] of this.tensorCache) { + if (tryMLTensor === mlTensor) { + if (this.context !== context) { + throw new Error('MLTensor cannot be registered with a different MLContext.'); + } + this.tensorEntry = [mlTensor, dataType, shape]; + return true; + } + } + return false; + } + + public async ensureTensor( + dataType: MLOperandDataType, + shape: readonly number[], + copyOld: boolean, + ): Promise { + if (this.tensorEntry) { + const [mlTensor, existingDataType, existingShape] = this.tensorEntry; + if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { + return mlTensor; + } + } + + for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) { + if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { + if (copyOld && this.tensorEntry) { + // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. + LOG_DEBUG( + 'verbose', + () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`, + ); + const data = await this.context.readTensor(this.tensorEntry[0]); + this.context.writeTensor(mlTensor, data); + } + this.tensorEntry = [mlTensor, existingDataType, existingShape]; + return mlTensor; + } + } + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); + // eslint-disable-next-line no-bitwise + const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; + const tensor = await this.context.createTensor({ + dataType, + shape, + // Assign both shape and dimensions while transitioning to new API. 
+ dimensions: shape, + usage, + }); + this.tensorEntry = [tensor, dataType, shape]; + this.tensorCache.push(this.tensorEntry); + + if (this.activeUpload) { + this.mlContext?.writeTensor(tensor, this.activeUpload); + this.activeUpload = undefined; + } + + return tensor; + } + + public upload(data: Uint8Array): void { + if (!this.tensorEntry) { + this.activeUpload = new Uint8Array(data); + return; + } + this.mlContext?.writeTensor(this.tensorEntry[0], data); + } + + public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + if (this.activeUpload) { + if (dstBuffer) { + if (dstBuffer instanceof ArrayBuffer) { + new Uint8Array(dstBuffer).set(this.activeUpload); + } else { + new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); + } + + return; + } else { + return this.activeUpload.buffer; + } + } + if (!this.tensorEntry) { + throw new Error('Tensor has not been created.'); + } + if (dstBuffer) { + return this.context.readTensor(this.tensorEntry[0], dstBuffer); + } + return this.context.readTensor(this.tensorEntry[0]); + } +} + +class TensorManagerImpl implements TensorManager { + private tensorsById = new Map(); + private tensorIdsByContext = new Map>(); + + constructor(private backend: WebNNBackend) {} + + public reserveTensorId(): TensorId { + const tensorId = createNewTensorId(); + this.tensorsById.set(tensorId, new TensorTracker()); + return tensorId; + } + + public releaseTensorId(tensorId: TensorId): void { + const tensorTracker = this.tensorsById.get(tensorId); + if (!tensorTracker) { + return; + } + tensorTracker.destroy(); + this.tensorsById.delete(tensorId); + for (const [mlContext, tensors] of this.tensorIdsByContext) { + if (tensors.has(tensorId)) { + tensors.delete(tensorId); + if (tensors.size === 0) { + this.tensorIdsByContext.delete(mlContext); + } + break; + } + } + } + + public async ensureTensor( + tensorId: TensorId, + dataType: MLOperandDataType, + shape: number[], + copyOld: boolean, + ): Promise { + LOG_DEBUG( + 'verbose', + () => + `[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${ + dataType + }, shape: ${shape}, copyOld: ${copyOld}}`, + ); + const tensor = this.tensorsById.get(tensorId); + if (!tensor) { + throw new Error('Tensor not found.'); + } + tensor.context = this.backend.currentContext; + if (!this.tensorIdsByContext.has(this.backend.currentContext)) { + this.tensorIdsByContext.set(this.backend.currentContext, new Set()); + } + this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId); + return tensor.ensureTensor(dataType, shape, copyOld); + } + + public upload(tensorId: TensorId, data: Uint8Array): void { + this.tensorsById.get(tensorId)!.upload(data); + } + + public async download(tensorId: TensorId): Promise; + public async download(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; + async download(tensorId: TensorId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + LOG_DEBUG( + 'verbose', + () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`, + ); + return this.tensorsById.get(tensorId)!.download(dstBuffer); + } + + public releaseTensorsForContext(mlContext: MLContext): void { + const tensors = this.tensorIdsByContext.get(mlContext); + if (!tensors) { + return; + } + for (const tensorId of tensors) { + this.tensorsById.get(tensorId)!.destroy(); + this.tensorsById.delete(tensorId); + } + this.tensorIdsByContext.delete(mlContext); + } + + public registerTensor( + mlContext: 
MLContext, + mlTensor: MLTensor, + dataType: MLOperandDataType, + shape: readonly number[], + ): TensorId { + for (const [tensorId, tensorTracker] of this.tensorsById) { + if (tensorTracker.trySelectTensor(mlContext, mlTensor)) { + return tensorId; + } + } + const tensorId = createNewTensorId(); + this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape])); + let tensors = this.tensorIdsByContext.get(mlContext); + if (!tensors) { + tensors = new Set(); + this.tensorIdsByContext.set(mlContext, tensors); + } + tensors.add(tensorId); + return tensorId; + } +} + +export const createTensorManager = (...args: ConstructorParameters): TensorManager => + new TensorManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index f8a1e1966fd4c..5cb0f4e74c3df 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +/* eslint-disable @typescript-eslint/naming-convention */ + interface NavigatorML { readonly ml: ML; } @@ -30,7 +32,9 @@ type MLInputOperandLayout = 'nchw'|'nhwc'; type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; interface MLOperandDescriptor { dataType: MLOperandDataType; - dimensions?: number[]; + shape?: readonly number[]; + /** @deprecated Use shape instead of dimensions */ + dimensions?: readonly number[]; } interface MLOperand { dataType(): MLOperandDataType; @@ -379,23 +383,32 @@ interface MLGraphBuilder { where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand; } -// Experimental MLBuffer interface +// Experimental MLTensor interface -type MLSize64Out = number; -interface MLBuffer { - readonly size: MLSize64Out; +interface MLTensor { destroy(): void; } -type MLSize64 = number; -interface MLBufferDescriptor { - size: MLSize64; + +type MLNamedTensor = Record; + +type MLTensorUsageFlags = number; + +declare const MLTensorUsage: { + readonly WEBGPU_INTEROP: MLTensorUsageFlags; + readonly READ: MLTensorUsageFlags; + readonly WRITE: MLTensorUsageFlags; +}; + +interface MLTensorDescriptor extends MLOperandDescriptor { + usage: MLTensorUsageFlags; } -type MLNamedBuffers = Record; + interface MLContext { - createBuffer(descriptor: MLBufferDescriptor): MLBuffer; - writeBuffer( - dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64, - srcElementSize?: MLSize64): void; - readBuffer(srcBuffer: MLBuffer): Promise; - dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; + createTensor(descriptor: MLTensorDescriptor): Promise; + writeTensor( + destinationTensor: MLTensor, sourceData: ArrayBufferView|ArrayBuffer, sourceElementOffset?: number, + sourceElementSize?: number): void; + readTensor(sourceTensor: MLTensor): Promise; + readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise; + dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void; } diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 8f3acdd582445..559f319a10f66 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -19,11 +19,18 @@ export type GpuBufferMetadata = { dispose?: () => void; }; +export type MLTensorMetadata = { + mlTensor: Tensor.MLTensorType; + download?: () => Promise; + dispose?: () => void; +}; + /** - * Tensors on location 'cpu-pinned' and 
'gpu-buffer' are not serializable. + * Tensors on location 'cpu-pinned', 'gpu-buffer', and 'ml-tensor' are not serializable. */ export type UnserializableTensorMetadata = | [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer'] + | [dataType: Tensor.Type, dims: readonly number[], data: MLTensorMetadata, location: 'ml-tensor'] | [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; /** @@ -34,6 +41,7 @@ export type UnserializableTensorMetadata = * - cpu: Uint8Array * - cpu-pinned: Uint8Array * - gpu-buffer: GpuBufferMetadata + * - ml-tensor: MLTensorMetadata * - location: tensor data location */ export type TensorMetadata = SerializableTensorMetadata | UnserializableTensorMetadata; diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index eff3e91389c98..c19043cc3637f 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -12,7 +12,7 @@ import { import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper'; -import { isGpuBufferSupportedType } from './wasm-common'; +import { isGpuBufferSupportedType, isMLTensorSupportedType } from './wasm-common'; import { isNode } from './wasm-utils-env'; import { loadFile } from './wasm-utils-load-file'; @@ -22,6 +22,8 @@ export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): Ten return [tensor.type, tensor.dims, tensor.data, 'cpu']; case 'gpu-buffer': return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer']; + case 'ml-tensor': + return [tensor.type, tensor.dims, { mlTensor: tensor.mlTensor }, 'ml-tensor']; default: throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); } @@ -39,6 +41,14 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { const { gpuBuffer, download, dispose } = tensor[2]; return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose }); } + case 'ml-tensor': { + const dataType = tensor[0]; + if (!isMLTensorSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing MLTensor tensor`); + } + const { mlTensor, download, dispose } = tensor[2]; + return Tensor.fromMLTensor(mlTensor, { dataType, dims: tensor[1], download, dispose }); + } default: throw new Error(`invalid data location: ${tensor[3]}`); } diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 78ff14540d8cb..ad2ff62587252 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -240,6 +240,20 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB type === 'uint4' || type === 'int4'; +/** + * Check whether the given tensor type is supported by WebNN MLTensor + */ +export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTensorDataTypes => + type === 'float32' || + type === 'float16' || + type === 'int32' || + type === 'int64' || + type === 'uint32' || + type === 'uint64' || + type === 'int8' || + type === 'uint8' || + type === 'bool'; + /** * Map string data location to integer value */ @@ -255,6 +269,8 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number return 3; case 'gpu-buffer': return 4; + case 'ml-tensor': + return 5; default: throw new Error(`unsupported data location: 
${location}`); } @@ -264,4 +280,4 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number * Map integer data location to string value */ export const dataLocationEnumToString = (location: number): Tensor.DataLocation | undefined => - (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer', 'ml-tensor'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index ed001cfa90f59..0668ac1931988 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -20,6 +20,7 @@ import { calculateTensorSizeInBytes, dataLocationStringToEnum, isGpuBufferSupportedType, + isMLTensorSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, @@ -162,7 +163,7 @@ export const initEp = async (env: Env, epName: string): Promise => { /** * valid data locations for input/output tensors. */ -type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer'; +type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor'; type IOBindingState = { /** @@ -173,7 +174,7 @@ type IOBindingState = { /** * the preferred location for each output tensor. * - * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-tensor'. */ readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; @@ -287,6 +288,7 @@ export const createSession = async ( for (const provider of options?.executionProviders ?? []) { const providerName = typeof provider === 'string' ? provider : provider.name; if (providerName === 'webnn') { + wasm.shouldTransferToMLTensor = false; if (wasm.currentContext) { throw new Error('WebNN execution provider is already set.'); } @@ -318,7 +320,9 @@ export const createSession = async ( // clear current MLContext after session creation if (wasm.currentContext) { + wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext); wasm.currentContext = undefined; + wasm.shouldTransferToMLTensor = true; } const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); @@ -354,7 +358,7 @@ export const createSession = async ( typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : (options?.preferredOutputLocation?.[nameString] ?? 'cpu'); - if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') { throw new Error(`Not supported preferred output location: ${location}.`); } if (enableGraphCapture && location !== 'gpu-buffer') { @@ -366,9 +370,9 @@ export const createSession = async ( } } - // use IO binding only when at least one output is preffered to be on GPU. + // use IO binding only when at least one output is preferred to be on GPU. 
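+ // For illustration, a hedged usage sketch (hypothetical model path): a caller opts WebNN outputs into
+ // IO binding by requesting the 'ml-tensor' location, which now takes this branch as well:
+ //
+ //   const session = await ort.InferenceSession.create('model.onnx', {
+ //     executionProviders: [{ name: 'webnn', deviceType: 'gpu' }],
+ //     preferredOutputLocation: 'ml-tensor',
+ //   });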
let bindingState: IOBindingState | null = null; - if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) { + if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); if (ioBindingHandle === 0) { checkLastError("Can't create IO binding."); @@ -459,7 +463,7 @@ export const prepareInputOutputTensor = ( let rawData: number; let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { + if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-tensor')) { throw new Error('String tensor is not supported on GPU.'); } @@ -478,6 +482,15 @@ throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); } rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else if (location === 'ml-tensor') { + const mlTensor = tensor[2].mlTensor as MLTensor; + dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; + + const registerMLTensor = wasm.jsepRegisterMLTensor; + if (!registerMLTensor) { + throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); + } + rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims); } else { const data = tensor[2]; @@ -563,6 +576,9 @@ const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { + // WebNN backend needs the active session to check MLTensors with the current context. + wasm.jsepOnRunStart?.(sessionHandle); + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors @@ -654,7 +670,6 @@ ]); } - wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -726,7 +741,7 @@ const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; if (type === 'string') { - if (preferredLocation === 'gpu-buffer') { + if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-tensor') { throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; @@ -766,6 +781,37 @@ }, 'gpu-buffer', ]); + } else if (preferredLocation === 'ml-tensor' && size > 0) { + const ensureTensor = wasm.jsepEnsureTensor; + if (!ensureTensor) { + throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.'); + } + const tensorSize = calculateTensorSizeInBytes(dataType, size); + if (tensorSize === undefined || !isMLTensorSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); + } + + // If the graph has been partitioned, the output tensor may not have been created yet. For this reason, we use + // ensureTensor to get or create the MLTensor. In that case, we don't need to copy the data if a new tensor + // has been created. + const mlTensor = await ensureTensor(dataOffset, dataType, dims, false); + + // Do not release the tensor right now; it will be released when the user calls tensor.dispose().
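+ // Downstream, a caller would consume such an output roughly as follows (hypothetical output name):
+ //
+ //   const out = results.output;        // tensor with location 'ml-tensor'
+ //   const value = await out.getData(); // invokes the downloader created below
+ //   out.dispose();                     // releases the tensor ID and the underlying OrtValue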
+ keepOutputTensor = true; + + output.push([ + type, + dims, + { + mlTensor, + download: wasm.jsepCreateMLTensorDownloader!(dataOffset, type), + dispose: () => { + wasm.jsepReleaseTensorId!(dataOffset); + wasm._OrtReleaseTensor(tensor); + }, + }, + 'ml-tensor', + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 828cd3cfd94fa..3e08fe97f559d 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -7,6 +7,7 @@ /// import type { Tensor } from 'onnxruntime-common'; +import { DataType } from './wasm-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -27,6 +28,16 @@ export declare namespace JSEP { type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; + type ReserveTensorIdFunction = () => number; + type ReleaseTensorIdFunction = (tensorId: number) => void; + type EnsureTensorFunction = ( + tensorId: number, + dataType: DataType, + shape: readonly number[], + copyOld: boolean, + ) => Promise; + type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void; + type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { /** @@ -62,7 +73,17 @@ export declare namespace JSEP { replay: ReplayFunction, ], ): void; - jsepInit(name: 'webnn', initParams?: never): void; + jsepInit( + name: 'webnn', + initParams: [ + backend: BackendType, + reserveTensorId: ReserveTensorIdFunction, + releaseTensorId: ReleaseTensorIdFunction, + ensureTensor: EnsureTensorFunction, + uploadTensor: UploadTensorFunction, + downloadTensor: DownloadTensorFunction, + ], + ): void; } export interface WebGpuModule { @@ -134,6 +155,70 @@ export declare namespace JSEP { * Active MLContext used to create WebNN EP. */ currentContext: MLContext; + + /** + * Disables creating MLTensors. This is used to avoid creating MLTensors for graph initializers. + */ + shouldTransferToMLTensor: boolean; + + /** + * [exported from pre-jsep.js] Register MLContext for a session. + * @param sessionId - specify the session ID. + * @param context - specify the MLContext. + * @returns + */ + jsepRegisterMLContext: (sessionId: number, context: MLContext) => void; + /** + * [exported from pre-jsep.js] Reserve a MLTensor ID attached to the current session. + * @returns the MLTensor ID. + */ + jsepReserveTensorId: () => number; + /** + * [exported from pre-jsep.js] Release an MLTensor ID from use and destroys underlying MLTensor if no longer in use. + * @param tensorId - specify the MLTensor ID. + * @returns + */ + jsepReleaseTensorId: (tensorId: number) => void; + /** + * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID. + * @param tensorId - specify the MLTensor ID. + * @param onnxDataType - specify the data type. + * @param shape - specify the dimensions (WebNN shape) of the tensor. + * @param copyOld - specify whether to copy the old tensor if a new tensor was created. + * @returns the MLTensor associated with the tensor ID. + */ + jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise; + /** + * [exported from pre-jsep.js] Upload data to an MLTensor. + * @param tensorId - specify the MLTensor ID. + * @param data - specify the data to upload. 
+ * @returns + */ + jsepUploadTensor: (tensorId: number, data: Uint8Array) => void; + /** + * [exported from pre-jsep.js] Download data from an MLTensor. + * @param tensorId - specify the MLTensor ID. + * @returns the downloaded data. + */ + jsepDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; + /** + * [exported from pre-jsep.js] Creates a downloader function to download data from an MLTensor. + * @param tensorId - specify the MLTensor ID. + * @param type - specify the data type. + * @returns the downloader function. + */ + jsepCreateMLTensorDownloader: ( + tensorId: number, + type: Tensor.MLTensorDataTypes, + ) => () => Promise; + /** + * [exported from pre-jsep.js] Registers an external MLTensor to a session. + * @param tensor - specify the MLTensor. + * @param onnxDataType - specify the data type. + * @param dimensions - specify the dimensions. + * @returns the MLTensor ID for the external MLTensor. + */ + jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index d237293dbb192..e94e11d0ace56 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -62,6 +62,8 @@ Options: none (default) gpu-tensor use pre-allocated GPU tensors for inputs and outputs gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' + ml-tensor use pre-allocated MLTensor tensors for inputs and outputs + ml-location use pre-allocated MLTensor tensors for inputs and set preferredOutputLocation to 'ml-tensor' *** Logging Options *** @@ -133,7 +135,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu' | 'webgl' | 'webgpu' | 'wasm' | 'onnxruntime' | 'webnn'; type Environment = 'chrome' | 'chromecanary' | 'edge' | 'firefox' | 'electron' | 'safari' | 'node' | 'bs'; type BundleMode = 'dev' | 'perf'; - type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; + type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location'; } export interface TestRunnerCliArgs { @@ -455,7 +457,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: -i=<...>, --io-binding=<...> const ioBindingArg = args['io-binding'] || args.i; const ioBindingMode = typeof ioBindingArg !== 'string' ?
'none' : ioBindingArg; - if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + if (['none', 'gpu-tensor', 'gpu-location', 'ml-tensor', 'ml-location'].indexOf(ioBindingMode) === -1) { throw new Error(`not supported io binding mode ${ioBindingMode}`); } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index a9fcd7b876b2f..68ee58dab7094 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -380,7 +380,7 @@ async function main() { } let ioBinding: Test.IOBindingMode; - if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + if (!['webgpu', 'webnn'].includes(backend) && args.ioBindingMode !== 'none') { npmlog.warn( 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`, diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 5c1e2e27a6eff..ae708467be8a2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1912,9 +1912,9 @@ // "test_lrn_default", // "test_lrn", // // "test_lstm_batchwise", - // // "test_lstm_defaults", - // // "test_lstm_with_initial_bias", - // // "test_lstm_with_peepholes", + "test_lstm_defaults", + "test_lstm_with_initial_bias", + "test_lstm_with_peepholes", "test_matmul_2d", "test_matmul_3d", "test_matmul_4d", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index aa9555c191501..2176a776a0192 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import { Float16Array as Float16ArrayPolyfill } from '@petamoriken/float16'; import { expect } from 'chai'; import * as ort from 'onnxruntime-common'; @@ -19,6 +24,7 @@ import { createView } from '../lib/wasm/jsep/tensor-view'; import { calculateTensorSizeInBytes, isGpuBufferSupportedType, + isMLTensorSupportedType, tensorDataTypeStringToEnum, } from '../lib/wasm/wasm-common'; @@ -170,13 +176,20 @@ async function initializeSession( }`, ); + let preferredOutputLocation: ort.Tensor.DataLocation | undefined; + if (ioBindingMode === 'gpu-location') { + preferredOutputLocation = 'gpu-buffer'; + } else if (ioBindingMode === 'ml-location') { + preferredOutputLocation = 'ml-tensor'; + } + const profilerConfig = profile ? { maxNumberEvents: 65536 } : undefined; const sessionConfig = { ...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, - preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + preferredOutputLocation, externalData, }; @@ -219,6 +232,7 @@ export class ModelTestContext { readonly perfData: ModelTestContext.ModelTestPerfData, readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, + public readonly mlContext?: MLContext, ) {} /** @@ -272,7 +286,24 @@ export class ModelTestContext { const initStart = now(); const executionProviderConfig = - modelTest.backend === 'webnn' ? testOptions?.webnnOptions || 'webnn' : modelTest.backend!; + modelTest.backend === 'webnn' ? 
testOptions?.webnnOptions || { name: 'webnn' } : modelTest.backend!; + let mlContext: MLContext | undefined; + if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) { + const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption; + const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.powerPreference; + + mlContext = await navigator.ml.createContext({ + deviceType, + numThreads, + powerPreference, + }); + (executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption).context = mlContext; + if (!deviceType) { + (executionProviderConfig as ort.InferenceSession.WebNNContextOptions).deviceType = deviceType; + } + } const session = await initializeSession( modelTest.modelUrl, executionProviderConfig, @@ -295,6 +326,7 @@ export class ModelTestContext { { init: initEnd - initStart, firstRun: -1, runs: [], count: 0 }, modelTest.ioBinding, profile, + mlContext, ); } finally { this.initializing = false; @@ -622,30 +654,82 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] }); } +async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Type, dims: readonly number[]) { + if (!isMLTensorSupportedType(type)) { + throw new Error(`createMLTensorForOutput can not work with ${type} tensor`); + } + + const dataType = type === 'bool' ? 'uint8' : type; + + const mlTensor = await mlContext.createTensor({ + dataType, + shape: dims as number[], + // Assign both shape and dimensions while transitioning to new API. + dimensions: dims as number[], + usage: MLTensorUsage.READ, + }); + + return ort.Tensor.fromMLTensor(mlTensor, { + dataType: type, + dims, + dispose: () => mlTensor.destroy(), + download: async () => { + const arrayBuffer = await mlContext.readTensor(mlTensor); + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.MLTensorDataTypes]; + }, + }); +} + +async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tensor): Promise { + if (!isMLTensorSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type; + const mlTensor = await mlContext.createTensor({ + dataType, + shape: cpuTensor.dims as number[], + // Assign both shape and dimensions while transitioning to new API. + dimensions: cpuTensor.dims as number[], + usage: MLTensorUsage.WRITE, + }); + mlContext.writeTensor(mlTensor, cpuTensor.data); + return ort.Tensor.fromMLTensor(mlTensor, { + dataType: cpuTensor.type, + dims: cpuTensor.dims, + dispose: () => mlTensor.destroy(), + }); +} + export async function sessionRun(options: { session: ort.InferenceSession; feeds: Record; outputsMetaInfo: Record>; ioBinding: Test.IOBindingMode; + mlContext?: MLContext; }): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { const session = options.session; const feeds = options.feeds; const fetches: Record = {}; - // currently we only support IO Binding for WebGPU + // currently we only support IO Binding for WebGPU and WebNN // - // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. - // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. 
+ // For inputs, we create tensors on 'gpu-tensor', 'gpu-location', 'ml-tensor', and 'ml-location' binding testing + // modes. + // For outputs, we create tensors on 'gpu-tensor' and 'ml-tensor' binding testing modes. // in 'gpu-device' binding mode, outputs are not pre-allocated. - const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; - const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + const shouldUploadInput = ['gpu-tensor', 'gpu-location', 'ml-location', 'ml-tensor'].includes(options.ioBinding); + const shouldUploadOutput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'ml-tensor'; try { if (shouldUploadInput) { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { if (feeds[name].size > 0) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') { + feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]); + } else { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -658,7 +742,11 @@ export async function sessionRun(options: { if (dims.some((d) => d === 0)) { fetches[name] = new ort.Tensor(type, [], dims); } else { - fetches[name] = createGpuTensorForOutput(type, dims); + if (options.ioBinding === 'ml-tensor') { + fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } @@ -714,6 +802,7 @@ export async function runModelTestSet( feeds, outputsMetaInfo, ioBinding: context.ioBinding, + mlContext: context.mlContext, }); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index be1e56485ec5a..29a11f969ffea 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -52,8 +52,12 @@ export declare namespace Test { * `preferredOutputLocation` will be set to `gpu-buffer`. * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` * will not be set. + * - ml-location: inputs will be pre-allocated as ML tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `ml-tensor`. + * - ml-tensor: inputs and outputs will all be pre-allocated as MLTensor tensors. `preferredOutputLocation` + * will not be set. 
*/ - export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; + export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location'; export interface ModelTestCase { name: string; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc index 7b84971585f9f..c8fe9c77d8ff8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc @@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>(); }}; // For element-wise add // Allocate space for output of Q(BS, D) + bias(D) @@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>(); + per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>(); }}; // For element-wise add // Get Q's bias from combined bias @@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); +template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); + template Status MaybeTransposeToBNSH(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, @@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, OrtValue& out); +template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 570f4108c3f62..72adfa025da57 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ 
b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -86,6 +86,11 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { std::atomic_bool failed{false}; int n = batch_size * sequence_length; + + // Put epsilon into local variable here to avoid the need to capture 'this' in the TryBatchParallelFor() lambda. + // Using the copy capture default (=) to implicitly capture 'this' is deprecated. + const float epsilon_value = epsilon(); + concurrency::ThreadPool::TryBatchParallelFor( context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { int word_col_index = input_ids_data[index]; @@ -136,7 +141,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { y[i] = a; sum += a * a; } - T e = sqrt(sum / hidden_size + static_cast(epsilon())); + T e = sqrt(sum / hidden_size + static_cast(epsilon_value)); for (int i = 0; i < hidden_size; i++) { y[i] = y[i] / e * gamma_data[i] + beta_data[i]; } diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index bfec9aef56727..ccaeb6654e286 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -75,7 +75,7 @@ class GQAAttentionBase { int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]); // Compute the attention score. - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -87,16 +87,17 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp); + is_prompt, tp, allocator); return Status::OK(); } @@ -106,7 +107,7 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. 
Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor @@ -120,7 +121,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - ThreadPool* tp) const { // thread pool + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt<ptrdiff_t>(0); @@ -131,7 +133,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_key, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -164,7 +168,7 @@ class GQAAttentionBase { const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length; - T* output = attention_probs + output_offset; + float* output = attention_probs + output_offset; const T* k; if (packed_qkv) { @@ -190,12 +194,28 @@ class GQAAttentionBase { q = Q + q_input_chunk_length * i; } - math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output, - static_cast<int>(present_buffer_sequence_length), nullptr); + if constexpr (std::is_same<T, float>::value) { + math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, + static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, + output, static_cast<int>(present_buffer_sequence_length), nullptr); + } else { + size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); + auto q_k_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator)); + + float* q_fp32 = static_cast<float*>(q_k_fp32); + MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length); + + float* k_fp32 = q_fp32 + head_size * sequence_length; + MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen); + + math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32, + static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*beta*/, + output, static_cast<int>(present_buffer_sequence_length), nullptr); + } // compute Softmax - T* output_softmax = output; + float* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) { @@ -237,7 +257,7 @@ class GQAAttentionBase { template <typename T> void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT + const float* attention_probs, // Attention probs with size BxNxSxT const T* V, // V value with size BxN_kvxSxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor const size_t batch_size, // batch size @@ -251,7 +271,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // 
whether it is prompt - ThreadPool* tp) const { + ThreadPool* tp, + AllocatorPtr allocator) const { const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -261,7 +282,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_value, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -285,6 +308,13 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; + size_t output_fp32_bytes = 0; + if constexpr (std::is_same::value) { + output_fp32_bytes = SafeInt(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float); + } + auto output_fp32 = allocator->Alloc(output_fp32_bytes); + BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator)); + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const size_t batch_index = i / num_heads_; @@ -305,15 +335,39 @@ class GQAAttentionBase { i / kv_num_heads_factor); } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ - attention_probs + attention_probs_offset, - static_cast(present_buffer_sequence_length), v, static_cast(head_size), - 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + if constexpr (std::is_same::value) { + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, + static_cast(head_size), 0.0f /*beta*/, output_current, + static_cast(hidden_size), nullptr); + } else { + size_t bytes = head_size * total_seqlen * sizeof(float); + auto v_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); + + float* v_fp32_ptr = static_cast(v_fp32); + MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen); + + float* output_fp32_current = static_cast(output_fp32) + + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v_fp32_ptr, + static_cast(head_size), 0.0f /*beta*/, output_fp32_current, + static_cast(hidden_size), nullptr); + } } }); + + if constexpr (std::is_same::value) { + MlasConvertFloatToHalfBuffer(static_cast(output_fp32), + output, + SafeInt(sequence_length) * batch_size * num_heads_ * head_size); + } } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 2a38e4a1ac636..a1ed35e54b008 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ 
b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -22,16 +22,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -ONNX_OPERATOR_TYPED_KERNEL_EX( - GroupQueryAttention, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - GroupQueryAttention); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + GroupQueryAttention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 6732f8b96cce2..cbfd2f0949363 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -13,16 +13,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -ONNX_OPERATOR_TYPED_KERNEL_EX( - RotaryEmbedding, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - RotaryEmbedding); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { @@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* sin_data = sin_cache + cache_offset; int cache_idx = 0; - T sign = 0; + bool sign = false; int j = 0; for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { cache_idx = (i / 2) % half_rotary_emb_dim; - sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); - j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + sign = i & 1; + j = sign ? i - 1 : i + 1; // i - sign } else { cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? 
static_cast(-1) : static_cast(1); + sign = (i >= half_rotary_emb_dim); j = (i + half_rotary_emb_dim) % rotary_emb_dim; } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); } for (int i = rotary_emb_dim; i < head_size; i++) { output_data[i] = input_data[i]; @@ -102,6 +114,10 @@ template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryPar const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output, bool interleaved); +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input, + const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache, + MLFloat16* output, bool interleaved); + template Status RotaryEmbedding::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index dcd1f5ec22b52..6ffe861d19931 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -134,12 +136,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai // LayerNormalization is now in the ONNX spec. 
As the contrib op (incorrectly) used kOnnxDomain we need to version it class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -288,8 +294,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -338,12 +346,16 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/layer_norm.cc b/onnxruntime/contrib_ops/cpu/layer_norm.cc index 94f32360bd2f4..c949fcddad093 100644 --- a/onnxruntime/contrib_ops/cpu/layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/layer_norm.cc @@ -25,6 +25,7 @@ namespace contrib { REGISTER_CONTRIB_KERNELS(float) REGISTER_CONTRIB_KERNELS(double) +REGISTER_CONTRIB_KERNELS(MLFloat16) } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index f8f07b6e2827d..67af00beaba06 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -142,6 +142,8 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + IAllocatorUniquePtr scales_fp32_{}; + IAllocatorUniquePtr bias_fp32_{}; bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) @@ -175,30 +177,9 @@ class MatMulNBits final : public OpKernel { 
const MatMulComputeHelper& helper) const { ORT_THROW("ComputeBPacked is not supported for T1 type."); } - - void PackScale(const Tensor& tensor) { - ORT_THROW("PackScale is not supported for T1 type."); - } }; -#ifdef MLAS_TARGET_AMD64_IX86 -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - std::vector scales_v(static_cast(tensor.Shape().Size())); - MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), &scales_v[0], - has_zp_input_, nullptr, nullptr); -} - -template <> -void MatMulNBits::PackScale(const Tensor& tensor) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); -} -#endif - +#if defined(ORT_NEURAL_SPEED) template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -207,7 +188,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All if (has_g_idx_ || has_unquantized_zero_point_) { return Status::OK(); } -#if defined(ORT_NEURAL_SPEED) if (!all_constant_) { return Status::OK(); @@ -259,8 +239,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } + return Status::OK(); +} + #else // defined(ORT_NEURAL_SPEED) + +template +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } @@ -276,20 +269,77 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } else if (compute_type_ == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - PackScale(tensor); + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } -#endif +#endif // MLAS_TARGET_AMD64_IX86 + } + + return Status::OK(); +} + +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + ORT_UNUSED_PARAMETER(prepacked_weights); + + if (input_idx == InputIndex::scales || input_idx == InputIndex::bias) { + auto sptr = tensor.Data(); + auto tensor_size = static_cast(tensor.Shape().Size()); + auto ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); + MlasConvertHalfToFloatBuffer(sptr, ptr.get(), tensor_size); + if (input_idx == InputIndex::scales) { + scales_fp32_ = std::move(ptr); + } else { + bias_fp32_ = std::move(ptr); + } + } + + is_packed = false; + if (has_g_idx_ || has_unquantized_zero_point_) { + return Status::OK(); + } + + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + return Status::OK(); + } + if (input_idx == InputIndex::B) { + packed_b_size_ = 
MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); + is_packed = true; + } else if (compute_type_ == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } +#endif // MLAS_TARGET_AMD64_IX86 } -#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } +#endif // !defined(ORT_NEURAL_SPEED) + template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { @@ -348,7 +398,8 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } InlinedVector data(batch_count); @@ -397,22 +448,36 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + // Use reserve since no caching is needed + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); - MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); - auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_temp = IAllocator::MakeUniquePtr(allocator, static_cast(scales->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(scales_data, scales_temp.get(), static_cast(scales->Shape().Size())); + scales_ptr = scales_temp.get(); + } else { + scales_ptr = scales_fp32_.get(); + } - std::vector bias_data_v; - if (bias_data != nullptr) { - bias_data_v.resize(static_cast(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + float* bias_ptr = nullptr; + if (bias_data) { + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, static_cast(bias->Shape().Size()), true); + MlasConvertHalfToFloatBuffer(bias_data, bias_temp.get(), static_cast(bias->Shape().Size())); + bias_ptr = bias_temp.get(); + } else { + bias_ptr 
= bias_fp32_.get(); + } } - std::vector C_v(static_cast(y->Shape().Size())); + size_t c_size = static_cast(y->Shape().Size()); + std::vector c_v(c_size); InlinedVector data(batch_count); for (size_t i = 0; i < batch_count; ++i) { @@ -424,15 +489,15 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, } #endif data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBScale = scales_ptr; data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data != nullptr ? &bias_data_v[0] : nullptr; - data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].Bias = bias ? bias_ptr : nullptr; + data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), thread_pool); - MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } @@ -461,7 +526,8 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now @@ -561,12 +627,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); auto* y_data = y->MutableData(); - const float* scales_data_; - std::vector scales_data_v; - scales_data_v.resize(static_cast(scales->Shape().Size())); - MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); - scales_data_ = &scales_data_v[0]; - const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); const size_t N = static_cast(helper.N()); @@ -574,14 +634,25 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t lda = helper.Lda(false); const size_t ldb = helper.Ldb(true); - auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + float* scales_ptr = nullptr; + if (!scales_fp32_) { + auto scales_size = static_cast(scales->Shape().Size()); + auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); + scales_ptr = temp_scales.get(); + } else { + scales_ptr = scales_fp32_.get(); + } + + // TODO(fajin): move B dequant to prepack + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -595,7 +666,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales 
static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -607,7 +678,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data_, // quantization scales + scales_ptr, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -623,9 +694,14 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, #endif std::vector data(batch_count); - auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); - MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + + auto a_size = static_cast(a->Shape().Size()); + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, a_size, true); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), a_size); + + auto c_size = static_cast(y->Shape().Size()); + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, c_size, true); + for (size_t i = 0; i < batch_count; i++) { data[i].BIsPacked = false; data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; @@ -640,24 +716,28 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, // if there is a bias input, copy bias values into C and set beta to 1.0f if (bias) { - auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); - MlasConvertHalfToFloatBuffer(bias->Data(), - tmp_bias_data_ptr.get(), - static_cast(bias->Shape().Size())); + float* bias_ptr = nullptr; + const size_t bias_size = static_cast(bias->Shape().Size()); + if (!bias_fp32_) { + auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); + bias_ptr = bias_temp.get(); + } else { + bias_ptr = bias_fp32_.get(); + } for (size_t i = 0; i < batch_count; ++i) { float* C_row = data[i].C; const size_t ldc = data[i].ldc; for (size_t m = 0; m < M; ++m) { - std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + std::copy(bias_ptr, bias_ptr + bias_size, C_row); C_row += ldc; } data[i].beta = 1.0f; } } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); - MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, c_size); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 4e103c2556a7a..faf78cae80ee1 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -5,6 +5,7 @@ #include "core/util/math_cpuonly.h" #include "core/providers/common.h" #include "core/platform/threadpool.h" +#include "core/util/force_inline.h" #include "skip_layer_norm.h" #include "skip_layer_norm_helper.h" @@ -33,6 +34,50 @@ namespace contrib { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) +REGISTER_KERNEL_TYPED(MLFloat16) + +// Utility to convert from MLFloat16 to float only when the input type is MLFloat16. 
+template +ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val); + +template <> +ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { + return val.ToFloat(); +} + +template <> +ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { + return static_cast(ConvertMLFloat16ToDoubleOrFloatIfNeeded(val)); +} + +template <> +ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded(float val) { + return val; +} + +template <> +ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded(double val) { + return val; +} + +// Function template that only converts the input value to MLFloat16 if T is MLFloat16. +template +ORT_FORCEINLINE constexpr typename std::enable_if_t || std::is_same_v, T> +ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) { + return val; +} + +template +ORT_FORCEINLINE constexpr typename std::enable_if_t, T> +ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) { + return MLFloat16(val); +} + +template +ORT_FORCEINLINE constexpr typename std::enable_if_t, T> +ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) { + return MLFloat16(static_cast(val)); +} template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) @@ -91,21 +136,32 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { T* p_output = output_data + offset; T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr; - T mean = 0; - T mean_square = 0; + using DoubleOrFloat = typename std::conditional< + std::is_same::value, // If T is double + double, // Use double + float // Otherwise, use float (covers float and MLFloat16) + >::type; + + DoubleOrFloat mean(0.0f); + DoubleOrFloat mean_square(0.0f); + + std::unique_ptr output_buffer = std::make_unique(hidden_size); + for (size_t h = 0; h < static_cast(hidden_size); h++) { + DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); + DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_skip[h]); - for (int64_t h = 0; h < hidden_size; h++) { - T value = p_input[h] + p_skip[h]; + DoubleOrFloat value = input_value + skip_value; if (nullptr != bias_data) { - value += bias_data[h]; + value += ConvertMLFloat16ToDoubleOrFloatIfNeeded(bias_data[h]); } + output_buffer[h] = value; + T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded(value); if (nullptr != p_skip_input_bias_add_output_data) { - p_skip_input_bias_add_output_data[h] = value; + p_skip_input_bias_add_output_data[h] = converted_value; } - p_output[h] = value; mean += value; mean_square += value * value; } @@ -117,13 +173,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); } - for (int64_t h = 0; h < hidden_size; h++) { + for (size_t h = 0; h < static_cast(hidden_size); h++) { + DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(gamma_data[h]); if (simplified) { - p_output[h] = p_output[h] / mean_square * gamma_data[h]; + p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded(output_buffer[h] / mean_square * gamma_value); } else if (nullptr == beta_data) { - p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; + p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded((output_buffer[h] - mean) / mean_square * gamma_value); } else { - p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; + DoubleOrFloat beta_value = 
ConvertMLFloat16ToDoubleOrFloatIfNeeded(beta_data[h]); + p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded((output_buffer[h] - mean) / mean_square * gamma_value + beta_value); } } }, diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 5e66f2b99fded..b6dc8ad56f257 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -141,7 +141,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || - strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0) { + strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || + strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 7665a90448520..607969cd4cdc4 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -333,7 +333,7 @@ void DumpTensor( } else { std::cout << tensor_location << "\n"; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) const auto data_type = tensor.DataType(); // Dumping GPU only when cuda is enabled. if (tensor_location.device.Type() == OrtDevice::GPU) { diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp new file mode 100644 index 0000000000000..29734c2277667 --- /dev/null +++ b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp @@ -0,0 +1,164 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + fp16_neon_common.cpp + +Abstract: + + This module implements the common kernels for ARM NEON specific to float16. + +--*/ + +#include "mlasi.h" + +#include "arm_neon.h" + +// This file is enabled in cmake only if ARM64 is defined and not on Apple platforms +// The cmake condition is equivalent to MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64. +// Therefore omit the MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 macro in this file. 
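A bit-exact scalar reference is useful when reviewing or testing the NEON cast kernels that follow. The sketch below is not part of this change; the names are illustrative, and it assumes the IEEE binary16 layout (1 sign, 5 exponent, 10 mantissa bits), covering zeros, subnormals, infinities and NaNs.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative test oracle for MlasCastF16ToF32KernelNeon (not MLAS code).
// Every fp16 value is exactly representable in fp32, so the conversion is exact.
static float HalfBitsToFloat(uint16_t h)
{
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exponent = (h >> 10) & 0x1Fu;
    uint32_t mantissa = h & 0x3FFu;
    uint32_t bits;
    if (exponent == 0x1Fu) {
        bits = sign | 0x7F800000u | (mantissa << 13);  // infinity or NaN
    } else if (exponent != 0) {
        bits = sign | ((exponent + (127u - 15u)) << 23) | (mantissa << 13);  // normal
    } else if (mantissa == 0) {
        bits = sign;  // signed zero
    } else {
        // Subnormal: value is mantissa * 2^-24; renormalize for the fp32 encoding.
        uint32_t shift = 0;
        while ((mantissa & 0x400u) == 0) {
            mantissa <<= 1;
            ++shift;
        }
        mantissa &= 0x3FFu;
        bits = sign | ((127u - 14u - shift) << 23) | (mantissa << 13);
    }
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}

// Scalar counterpart of the vectorized kernel; comparing the two over all
// 65536 fp16 bit patterns also exercises the unaligned head and tail paths.
void CastF16ToF32Scalar(const unsigned short* src, float* dest, size_t count)
{
    for (size_t i = 0; i < count; ++i) {
        dest[i] = HalfBitsToFloat(static_cast<uint16_t>(src[i]));
    }
}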
+ +MLAS_FORCEINLINE +size_t +StoreFp32Lane(float* dest, float32x4_t src, size_t count) +{ + if (count == 3) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + vst1q_lane_f32(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1q_lane_f32(dest + 0, src, 0); + vst1q_lane_f32(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1q_lane_f32(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count) +{ + // 4 float16 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 7) & ~7); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src, pre_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + i = StoreFp32Lane(dest, fp32v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + + float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4)); + float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1); + vst1q_f32(dest + i + 4, fp32v4_1); + } + + if (i + 3 < count) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + i += 4; + } + + // Handle trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float16x4_t fp16v4; + std::memcpy(&fp16v4, src + i, post_count * sizeof(unsigned short)); + float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); + + StoreFp32Lane(dest + i, fp32v4, post_count); + } +} + +MLAS_FORCEINLINE +size_t +StoreU16Lane(unsigned short* dest, uint16x4_t src, size_t count) +{ + if (count == 3) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + vst1_lane_u16(dest + 2, src, 2); + return 3; + } else if (count == 2) { + vst1_lane_u16(dest + 0, src, 0); + vst1_lane_u16(dest + 1, src, 1); + return 2; + } else if (count == 1) { + vst1_lane_u16(dest + 0, src, 0); + return 1; + } + + return 0; +} + +void +MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count) +{ + // 4 float32 alignment + auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 15) & ~15); + auto pre_count = std::min(static_cast(src_aligned - src), count); + size_t i = 0; + + // Handle leading unaligned src + if (pre_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src, pre_count * sizeof(float)); + uint16x4_t u16v4 = vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + i = StoreU16Lane(dest, u16v4, pre_count); + } + + // aligned src + for (; i + 7 < count; i += 8) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + + float32x4_t fp32v4_1 = vld1q_f32(src + i + 4); + float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1); + vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1)); + } + + if (i + 3 < count) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + i += 4; + } + + // Handle trailing unaligned src + auto post_count = count - i; + if (post_count > 0) + { + float32x4_t fp32v4; + std::memcpy(&fp32v4, src + i, post_count * sizeof(float)); + uint16x4_t u16v4 = 
vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); + + StoreU16Lane(dest + i, u16v4, post_count); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 96ba8c6c92b26..13ea8d96c20e4 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -893,6 +893,10 @@ extern "C" { MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; +#endif } // @@ -2603,4 +2607,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH static_assert(std::is_same_v || std::is_same_v); *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); } - diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 102d6052276bc..23d29fd02fa5a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -20,7 +20,7 @@ Module Name: #include #include -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) #if defined(__linux__) #include #elif defined(_AIX) @@ -576,6 +576,11 @@ Return Value: } #endif +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon; +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index ca64ebe3b1137..12ddc42506e98 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -12,6 +12,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. --*/ diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index ec5cdbc75220a..0d62ea37b7e26 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -12,6 +12,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to + input type T1 as float32 and MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. --*/ diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 5d689a9d933e8..470838d36ec1c 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -2749,7 +2749,9 @@ static bool CanModifyNode(const OptimizerCtx& ctx, const api::NodeRef& node) { ///

/// Try to remove empty DQ -> Q pair that results from moving a Transpose downstream or a Transpose being canceled out. -/// (DQ -> Q -> consumer node) => consumer node +/// Handles the following scenarios: +/// - (DQ -> Q -> consumer node) => consumer node +/// - (parent node -> DQ -> Q -> graph output) => parent node -> graph output /// /// Optimizer context /// QuantizeLinear node @@ -2764,12 +2766,27 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) { } auto& dq_node = *input_node; - std::unique_ptr single_consumer_node; - // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp. - if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) && - OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) && - CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + // DQ should have a single consumer (the Q) + std::unique_ptr dq_consumer_node; + if (!OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, dq_consumer_node)) { + return false; + } + + // The DQ and Q should have matching types, scale and zp. + if (!CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + return false; + } + + std::string_view q_output = q_node.Outputs()[0]; + auto q_consumers = ctx.graph.GetValueConsumers(q_output); + const size_t num_q_consumers = q_consumers->nodes.size(); + const bool q_has_single_consumer = q_consumers->comprehensive && (num_q_consumers == 1); + + // (DQ -> Q -> consumer node) => consumer node + if (q_has_single_consumer) { + std::unique_ptr single_consumer_node = std::move(q_consumers->nodes[0]); + // connect Q consumer to DQ input for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) { if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) { @@ -2787,6 +2804,40 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) { return true; } + // (parent node -> DQ -> Q -> graph output) => (parent node -> graph output) + if (num_q_consumers == 0 && ctx.graph.IsGraphOutput(q_output)) { + // Get the DQ's parent node. + std::string_view dq_input = dq_node.Inputs()[0]; + auto dq_parent_node = ctx.graph.GetNodeProducingOutput(dq_input); + if (!dq_parent_node) { + return false; // Don't handle DQ that consumes a graph input. + } + + // Find index of output from DQ's parent node + auto dq_parent_outputs = dq_parent_node->Outputs(); + size_t dq_parent_output_index = 0; + for (dq_parent_output_index = 0; dq_parent_output_index < dq_parent_outputs.size(); ++dq_parent_output_index) { + if (dq_parent_outputs[dq_parent_output_index] == dq_input) break; + } + + // The DQ's parent should only have a single consumer (i.e., the DQ itself). + std::unique_ptr dq_parent_consumer; + if (!OutputValueHasSingleConsumerNode(ctx.graph, *dq_parent_node, dq_parent_output_index, dq_parent_consumer)) { + return false; + } + + // Move Q's output to come out of DQ's parent node so the graph output value name is maintained. + dq_node.SetInput(0, ""); // Disconnect DQ from its parent first. + ctx.graph.MoveOutput(q_node, 0, *dq_parent_node, dq_parent_output_index); + + // Disconnect Q and remove both DQ and Q from the graph. 
+ q_node.SetInput(0, ""); + ctx.graph.RemoveNode(dq_node); + ctx.graph.RemoveNode(q_node); + + return true; + } + return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index c8670cd546253..5389eb5ab7e95 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -104,7 +104,13 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (add_alpha) { NodeAttrHelper helper(node); const auto alpha = helper.Get("alpha", 0.01f); - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } AddOperationOutput(*op, *node.OutputDefs()[0]); diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 2cae85a0a1c8d..f185a80de3cbf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include <set> #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" @@ -12,6 +13,15 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { +// Once all ops support FP16, we can remove this set. Before that, we keep a set of ops to +// filter supported ones. 
+static std::set Float16Ops = { + "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", + "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", + "Clip", "DepthToSpace", "Resize", "Slice", "Conv", + "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", + "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; + namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, @@ -83,8 +93,9 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, - const logging::Logger& logger) { +bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, + [[maybe_unused]] const OpBuilderInputParams& input_params, + const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; return false; @@ -94,20 +105,33 @@ bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderIn int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - // currently only float is supported - if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; + if (!GetType(input, input_type, logger)) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Get Input type failed"; return false; } - return true; + // float is supported + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return true; + } + +// only support MLProgram for FP16 +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && + Float16Ops.count(node.OpType())) { + return true; + } +#endif + + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; + return false; } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // We only check the type of input 0 by default // specific op builder can override this - return IsInputFloat(node, 0, input_params, logger); + return IsInputDtypeSupport(node, 0, input_params, logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 071008520fbdc..153ae841b238f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -32,9 +32,9 @@ class BaseOpBuilder : public IOpBuilder { : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { } - // currently we only support float - static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + // currently we support float/float16 + static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + const logging::Logger& logger); private: virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc 
b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index fb8e07633621f..8aa2dbae2531c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -73,7 +73,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { - // we only support fp32 currently. when we add support for integers we need to check the type and use + // we support fp32/fp16 currently. when we add support for integers we need to check the type and use // "floor_div" or "real_div" accordingly coreml_op_type = "real_div"; } else if (op_type == "Pow") { @@ -138,9 +138,22 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn const logging::Logger& logger) const { // Add/Sub/Mul/Div spec says inputs must be of the same type. // Pow spec says inputs can be different types. - // We only support float for all of these inputs. - if (!IsInputFloat(node, 0, input_params, logger) || - ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { + // We support float/float16 for all of these inputs. + + if (node.OpType() == "Pow") { + const auto& input0 = *node.InputDefs()[0]; + const auto& input1 = *node.InputDefs()[1]; + int32_t input_type0 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + int32_t input_type1 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + if (!GetType(input0, input_type0, logger)) { + return false; + } + if (!GetType(input1, input_type1, logger) || input_type1 != input_type0) { + return false; + } + } + + if (!IsInputDtypeSupport(node, 0, input_params, logger)) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index e02186d3aee89..6f9bb35c27d80 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -96,6 +96,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; @@ -114,6 +117,11 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::spanAssign(data.begin(), data.end()); } +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { + const char* data_byte_ptr = reinterpret_cast(data.data()); + weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr + data.size_bytes()); +} + namespace { template void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -123,6 +131,15 @@ void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParam [](T v) { return narrow(v); }); *weight.mutable_floatvalue() = std::move(weight_floats); } + +template +void CreateCoreMLWeightConvertingDataToFloat16s(CoreML::Specification::WeightParams& weight, gsl::span data) { + std::vector weight_float16s{}; + weight_float16s.reserve(data.size()); + std::transform(data.begin(), data.end(), std::back_inserter(weight_float16s), + [](T v) { return MLFloat16(float(v)); 
}); + CreateCoreMLWeight(weight, weight_float16s); +} } // namespace void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -195,6 +212,13 @@ void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span< tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); } +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + const char* begin = reinterpret_cast<const char*>(data.data()); + const char* end = begin + (data.size() * sizeof(MLFloat16)); + tensor_value.mutable_bytes()->mutable_values()->assign(begin, end); +} + template <> void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); @@ -290,6 +314,14 @@ MILSpec::Value CreateScalarTensorValue(const T& data) { // explicit specializations for types we handle so the implementation can be in the .cc file template MILSpec::Value CreateTensorValue(gsl::span data, std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); template MILSpec::Value CreateScalarTensorValue(const float& data); template MILSpec::Value CreateScalarTensorValue(const int32_t& data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 475ce79b0a812..f38afc0ec181d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -41,6 +41,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONN // Copy the float array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// Copy the MLFloat16 array to a coreml weight +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); + // Copy the int32_t array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 41f4041ef1181..bc9e2f10296ed 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -92,16 +92,30 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, Operation& clip_op = *op; AddOperationInput(clip_op, "x", input_name); + // we already checked this; the dtype must exist. + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // if min and max were attributes we need to add initializers. otherwise we use the existing inputs const bool min_max_attribs = node.SinceVersion() < 11; - std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) - : node.InputDefs()[1]->Name(); + std::string_view min_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + } else { + min_name = min_max_attribs ?
model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) + : node.InputDefs()[1]->Name(); + } AddOperationInput(clip_op, "alpha", min_name); if (has_max) { - std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) - : node.InputDefs()[2]->Name(); + std::string_view max_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + } else { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) + : node.InputDefs()[2]->Name(); + } AddOperationInput(clip_op, "beta", max_name); } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index bec2461ffbc52..ddaa19c7fab18 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -67,7 +67,9 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // we checked shape was static in IsOpSupportedImpl so this should never fail std::vector<int64_t> input_shape; ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Failed to get input shape"); - const int32_t elem_type = static_cast<int32_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + + const int32_t elem_type = static_cast<int32_t>(input_dtype); // reshape to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 7338fc18fe779..e685c09ef43ca 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -70,16 +70,17 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod } } -// This is an internal function, requires input tensor to be 2d float tensor -// TODO, add support of other data types -static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor, - std::vector<float>& transposed_data) { +// This is an internal function; it requires the input tensor to be a 2-D float/float16 tensor +template <typename T> +static Status GetTensorDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor, + std::vector<T>& transposed_data) { Initializer unpacked_tensor(tensor); - auto src_data = unpacked_tensor.DataAsSpan<float>(); + const auto src_data = unpacked_tensor.DataAsSpan<T>(); const auto& tensor_shape = tensor.dims(); auto x_t = SafeInt<size_t>(tensor_shape[0]); auto y_t = SafeInt<size_t>(tensor_shape[1]); transposed_data.resize(x_t * y_t); + for (size_t x = 0; x < x_t; x++) { for (size_t y = 0; y < y_t; y++) { transposed_data[y * x_t + x] = src_data[x * y_t + y]; @@ -121,8 +122,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true const auto K = transB ? b1 : b0; const auto N = transB ? b0 : b1; - + // we already checked this; the dtype must exist.
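// ---- editorial note (not part of the diff) --------------------------------
// GetTensorDataTransposed above is a plain row-major {K, N} -> {N, K}
// transpose: element (x, y) of the source lands at (y, x) of the destination.
// The same index arithmetic as a self-contained sketch (standard library only;
// hypothetical helper name, illustrative only):
#include <cstddef>
#include <vector>
template <typename T>
std::vector<T> TransposeRowMajor(const std::vector<T>& src, size_t rows, size_t cols) {
  std::vector<T> dst(src.size());
  for (size_t x = 0; x < rows; ++x) {
    for (size_t y = 0; y < cols; ++y) {
      dst[y * rows + x] = src[x * cols + y];  // (x, y) -> (y, x)
    }
  }
  return dst;
}
// e.g. the 2x3 input {1, 2, 3, 4, 5, 6} becomes the 3x2 output {1, 4, 2, 5, 3, 6}.
// ----------------------------------------------------------------------------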
#if defined(COREML_ENABLE_MLPROGRAM) + auto input_dtype = a.TypeAsProto()->tensor_type().elem_type(); if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; @@ -136,13 +138,19 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (transB) { AddOperationInput(*gemm_op, "weight", b.Name()); } else { - // transpose from {K, N} to {N, K} - std::vector weight_nk; std::vector weight_nk_shape = {N, K}; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk)); - - AddOperationInput(*gemm_op, "weight", - model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + // transpose from {K, N} to {N, K} + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector weight_nk; + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } else { // TensorProto_DataType_FLOAT16 + std::vector weight_nk; + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } } if (input_defs.size() == 3) { @@ -155,15 +163,28 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N AddOperationInput(*gemm_op, "bias", bias_arg.Name()); } else { Initializer unpacked_tensor(bias); - auto bias_data = unpacked_tensor.DataAsSpan(); std::string_view bias_data_name; - if (bias_data.size() == 1) { - // expand scalar to N - std::vector expanded_bias_data(N, bias_data[0]); - bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); - } else { - // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) - bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + } else { // TensorProto_DataType_FLOAT16 + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } } AddOperationInput(*gemm_op, "bias", bias_data_name); @@ -202,7 +223,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); } else { std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed)); + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, b_transposed)); CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); } diff --git 
a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index 9caec290ea5a2..6dcf14c16f111 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -49,6 +49,9 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& const auto input_defs = node.InputDefs(); const auto output_defs = node.OutputDefs(); + // we already checked this; the dtype must exist. + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + NodeAttrHelper helper(node); std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant std::string padding_mode = helper.Get("padding_mode", "zeros"); @@ -65,7 +68,11 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& AddOperationInput(*op, "coordinates", input_defs[1]->Name()); AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode)); AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode)); - AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + } else { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", MLFloat16(0.0f))); + } AddOperationInput(*op, "coordinates_mode", model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode)); AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners)); diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 51fc3f2c11c73..6b3fe75fa592d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -144,7 +144,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } } - // Only int32 and float are supported by CoreML slice_by_index. + // Int32, float and float16 are supported by CoreML slice_by_index. // We convert any int64 model input to int32 when running the CoreML model for the partition. // Any other integer data created at runtime is the output from CoreML operations, and should be int32, not int64.
// Based on that, we assume that the actual input when running will be int32, so we override the output data @@ -214,18 +214,29 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) { return false; } - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; - return false; - } +#ifdef COREML_ENABLE_MLPROGRAM + // The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) + // says ML Program slice_by_index supports fp16 in CoreML 5 (iOS 15). + // That is incorrect: CoreML 6+ (iOS 16, CoreML spec version >= 7) is required; otherwise only float is supported. + // CoreML 5: https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 + // CoreML 6: https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 + if (input_params.create_mlprogram && input_params.coreml_version >= 6 && + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + } else +#endif // nolint + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; + return false; + } return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 3403378d59114..335ca737081b2 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/common.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" @@ -14,6 +15,7 @@ namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -21,21 +23,54 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = model_builder.CreateNNLayer(node); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - if (op_type == "Sqrt") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); - } else if (op_type ==
"Reciprocal") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + std::string_view coreml_op_type; + if (op_type == "Sqrt") { + coreml_op_type = "sqrt"; + } else if (op_type == "Reciprocal") { + coreml_op_type = "inverse"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (op_type == "Reciprocal") { + float epsilon = 1e-4; // epsilon: const T (Optional, default=1e-4) + auto dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", MLFloat16(epsilon))); + } + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); - model_builder.AddLayer(std::move(layer)); + if (op_type == "Sqrt") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); + } else if (op_type == "Reciprocal") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9668bfcd09adf..50faebf06875d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -639,6 +639,14 @@ std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::st return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + template <> std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, @@ -811,6 +819,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + multi_array->set_datatype(ArrayFeatureType::FLOAT16); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: multi_array->set_datatype(ArrayFeatureType::INT32); break; diff --git 
a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index bb791fb902908..b3dfec29872a2 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -107,11 +107,12 @@ class ModelBuilder { std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, std::optional> shape = std::nullopt) { static_assert(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, // add specialization in AddConstantImpl for new types if needed - "AddConstant currently supports float, int64_t, std::string and bool."); + "AddConstant currently supports float, MLFloat16, int64_t, std::string and bool."); return AddConstantImpl(op_type, value_type, value, shape); } diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 68460ff7c9b31..1401cbe95fd56 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -120,6 +120,10 @@ Status CreateInputFeatureProvider(const std::unordered_map +void StridedCopy(const T* src_buffer, T* dst_buffer, size_t block_size, + size_t num_blocks, size_t src_stride, size_t dst_stride) { + for (size_t idx = 0; idx < num_blocks; ++idx) { + std::copy_n(src_buffer, block_size, dst_buffer); + src_buffer += src_stride; + dst_buffer += dst_stride; + } +} + Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buffer, const MLMultiArray* array, const int64_t num_blocks, const int64_t block_size, const int64_t stride, @@ -196,25 +210,21 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(float); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + const auto* src_buffer = static_cast(mlmultiarray_buffer); + auto* dst_buffer = static_cast(tensor_buffer); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(int32_t); - - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7b1b136eb091e..424bee63511ad 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -903,6 +903,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, Me class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, STFT); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); // Opset 18 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, float, Resize); @@ -2465,6 +2466,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { LayerNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 18 BuildKernelCreateInfo output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.min(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template min( + static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.max(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template max( + static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template min( + static_cast(per_iter_bh.ScalarInput1())); } else { - output_vec_map = input_0_vec_map.max(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template max( + static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(input_1_vec_map); + output_vec_map = input_0_vec_map.template min(input_1_vec_map); } else { - output_vec_map = input_0_vec_map.max(input_1_vec_map); + output_vec_map = input_0_vec_map.template max(input_1_vec_map); } }}; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm.cc b/onnxruntime/core/providers/cpu/nn/layer_norm.cc index 56e5042fc7408..56463d00840cd 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm.cc @@ -15,5 +15,6 @@ namespace onnxruntime { REGISTER_ONNX_KERNEL_TYPED(float) REGISTER_ONNX_KERNEL_TYPED(double) +REGISTER_ONNX_KERNEL_TYPED(MLFloat16) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index e01f7f27c3596..23630dcb63efa 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -7,10 +7,62 @@ #include "core/framework/tensor.h" #include "core/platform/threadpool.h" #include "core/providers/common.h" +#include "core/util/force_inline.h" #include "core/util/math_cpuonly.h" namespace onnxruntime { +// Utility to convert from MLFloat16 to float only when the input type is MLFloat16. 
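// ---- editorial note (not part of the diff) --------------------------------
// The conversion helpers added below let the LayerNorm kernel keep reading
// MLFloat16 input while accumulating the mean and variance in float, which
// avoids the accumulation error that summing many fp16 values directly would
// introduce. The essence of that idea as a standalone sketch (assumes ORT's
// MLFloat16 type with its ToFloat() member; hypothetical function,
// illustrative only):
#include <cstddef>
float MeanOfHalfSpan(const MLFloat16* data, size_t n) {
  float sum = 0.0f;            // accumulate in float, not half
  for (size_t i = 0; i < n; ++i) {
    sum += data[i].ToFloat();  // widen each element before adding
  }
  return sum / static_cast<float>(n);
}
// ----------------------------------------------------------------------------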
+template +ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val); + +template <> +ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { + return val.ToFloat(); +} + +template <> +ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded(MLFloat16 val) { + return double(ConvertMLFloat16ToDoubleOrFloatIfNeeded(val)); +} + +template <> +ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded(float val) { + return val; +} + +template <> +ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded(double val) { + return val; +} + +ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(float val) { + return val; +} + +ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(double val) { + // ONNX spec doesn't support 'double' for 'Ret' so when 'T' == double, 'Ret' == float and we need to narrow + return gsl::narrow_cast(val); +} + +// Function template that only converts the input value to MLFloat16 if T is MLFloat16. +template +ORT_FORCEINLINE constexpr typename std::enable_if_t || std::is_same_v, float> +ConvertToMLFloat16IfNeeded(float val) { + return val; +} + +template +ORT_FORCEINLINE constexpr typename std::enable_if_t, MLFloat16> +ConvertToMLFloat16IfNeeded(float val) { + return MLFloat16(val); +} + +template +ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) { + return val; +} + LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op) : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op} { ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); @@ -24,14 +76,14 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo const Tensor* X = p_ctx->Input(0); const Tensor* scale = p_ctx->Input(1); const Tensor* bias = p_ctx->Input(2); - auto X_data = X->Data(); - auto scale_data = scale->Data(); - auto bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); + const T* X_data = X->Data(); + const T* scale_data = scale->Data(); + const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); const TensorShape& x_shape = X->Shape(); const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions()); - auto norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); - auto norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); + int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); + int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias_data) ? 
bias->Shape().Size() : 0; @@ -80,12 +132,19 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo const T* p_input = X_data + task_idx * norm_size; T* p_output = Y_data + task_idx * norm_size; - T mean = 0; - T mean_square = 0; + using DoubleOrFloat = typename std::conditional< + std::is_same::value, // If T is double + double, // Use double + float // Otherwise, use float (covers float and MLFloat16) + >::type; + + DoubleOrFloat mean(0.0f); + DoubleOrFloat mean_square(0.0f); for (int64_t h = 0; h < norm_size; h++) { - mean += p_input[h]; - mean_square += p_input[h] * p_input[h]; + DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); + mean += input_value; + mean_square += input_value * input_value; } mean = mean / norm_size; @@ -96,22 +155,25 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo } for (int64_t h = 0; h < norm_size; h++) { + DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(p_input[h]); + DoubleOrFloat scale_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(scale_data[h]); if (simplified) { - p_output[h] = p_input[h] / mean_square * scale_data[h]; + p_output[h] = ConvertToMLFloat16IfNeeded(input_value / mean_square * scale_value); } else if (nullptr == bias) { - p_output[h] = (p_input[h] - mean) / mean_square * scale_data[h]; + p_output[h] = ConvertToMLFloat16IfNeeded((input_value - mean) / mean_square * scale_value); } else { - p_output[h] = (p_input[h] - mean) / mean_square * scale_data[h] + bias_data[h]; + DoubleOrFloat bias_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded(bias_data[h]); + p_output[h] = ConvertToMLFloat16IfNeeded((input_value - mean) / mean_square * scale_value + bias_value); } } if (mean_data != nullptr) { // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow - mean_data[task_idx] = gsl::narrow_cast(mean); + mean_data[task_idx] = ConvertToMLFloat16IfNeeded(ConvertToFloatIfNeeded(mean)); } if (inv_std_dev_data != nullptr) { - inv_std_dev_data[task_idx] = gsl::narrow_cast(1 / mean_square); + inv_std_dev_data[task_idx] = ConvertToMLFloat16IfNeeded(ConvertToFloatIfNeeded(1 / mean_square)); } }, 0); @@ -141,7 +203,7 @@ struct SrcDispatcher { Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { const auto elem_type = p_ctx->Input(0)->GetElementType(); - using SupportedTypeList = boost::mp11::mp_list; + using SupportedTypeList = boost::mp11::mp_list; utils::MLTypeCallDispatcherFromTypeList t_disp(elem_type); return t_disp.InvokeRet(p_ctx, axis_, epsilon_, simplified_, contrib_op_); diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index db36754319309..55935a9eae86d 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -10,13 +10,10 @@ #include #include #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#if CUDA_VERSION >= 11000 -#include -#endif - namespace onnxruntime { namespace cuda { @@ -347,6 +344,21 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } +#define ISNAN_HALF(v__) static_cast(*reinterpret_cast(&v__) & ~MLFloat16::kSignMask) \ + > MLFloat16::kPositiveInfinityBits + +#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ + > 
BFloat16::kPositiveInfinityBits + +// CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2, +// so define our own equivalent constants to support older versions. +// Note that there is no consistent canonical NaN for FP16 and BF16; +// CUDA uses 0x7FFF for both, but ONNX Runtime uses 0x7E00 and 0x7FC1 +// for FP16 and BF16 respectively +// (see Float16Impl::kPositiveQNaNBits and BFloat16Impl::kPositiveQNaNBits). +#define NAN_HALF __ushort_as_half((unsigned short)0x7FFFU) +#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) + template __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } @@ -360,6 +372,24 @@ __device__ __inline__ double _Min(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); } +template <> +__device__ __inline__ half _Min(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a < b ? a : b); +#else + return __hmin_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); +#else + return BFloat16(__hmin_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + template __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } @@ -373,6 +403,29 @@ __device__ __inline__ double _Max(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); } +template <> +__device__ __inline__ half _Max(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a > b ? a : b); +#else + return __hmax_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); +#else + return BFloat16(__hmax_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + +#undef ISNAN_HALF +#undef ISNAN_BFLOAT16 +#undef NAN_HALF +#undef NAN_BFLOAT16 + template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? 
a : -a; } diff --git a/onnxruntime/core/providers/cuda/math/clip.cc b/onnxruntime/core/providers/cuda/math/clip.cc index ea986798659e7..71096cb2d1705 100644 --- a/onnxruntime/core/providers/cuda/math/clip.cc +++ b/onnxruntime/core/providers/cuda/math/clip.cc @@ -59,33 +59,11 @@ Status Clip_6::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } -namespace clip_internal { -template <typename T> -struct LowMax { - constexpr static T low() { - return std::numeric_limits<T>::lowest(); - } - constexpr static T max() { - return std::numeric_limits<T>::max(); - } -}; - -template <> -struct LowMax<MLFloat16> { - static MLFloat16 low() { - return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::lowest())); - } - static MLFloat16 max() { - return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::max())); - } -}; -} // namespace clip_internal - template <typename T> struct Clip<T>::ComputeImpl { void operator()(cudaStream_t stream, const Tensor* X, const Tensor* min, const Tensor* max, Tensor* Y) const { - auto min_default = clip_internal::LowMax<T>::low(); - auto max_default = clip_internal::LowMax<T>::max(); + constexpr T min_default = std::numeric_limits<T>::lowest(); + constexpr T max_default = std::numeric_limits<T>::max(); const T* min_data = nullptr; const T* max_data = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index cb6fc165a932f..9c01df13741e1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -676,7 +676,7 @@ namespace Dml bool IsCpuOnDmlOperator(const onnxruntime::Node& node) { - auto cpuOnDmlOperators = std::array{ + auto cpuOnDmlOperators = std::array{ "SequenceAt", "SequenceConstruct", "SequenceEmpty", @@ -684,7 +684,8 @@ namespace Dml "SequenceErase", "SequenceInsert", "OptionalGetElement", - "OptionalHasElement" + "OptionalHasElement", + "If", }; for (auto& cpuOnDmlOperator : cpuOnDmlOperators) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 412207fd3cbd2..d4d7ee1311874 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -451,8 +451,19 @@ class DmlOperatorElementwiseClip11 : public DmlOperator // logic for some corner test case // Same applies to min and max value. opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType(); - CastToClampedScalarUnion(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min); - CastToClampedScalarUnion(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max); + + if (opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT16 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT32 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT64) + { + CastToClampedScalarUnion(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min); + CastToClampedScalarUnion(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max); + } + else + { + // It's not safe to use DBL_MAX for non-float datatypes because not all integers can be represented in that range. + // For example, static_cast(static_cast(INT64_MAX)) will yield a negative number.
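// ---- editorial note (not part of the diff) --------------------------------
// A standalone demonstration of the hazard described in the comment above:
// INT64_MAX is not exactly representable as a double, so it rounds up to 2^63,
// and casting that back to int64_t is out of range (commonly observed as
// INT64_MIN, i.e. a negative number). Illustrative only:
#include <cstdint>
#include <limits>
int64_t RoundTripThroughDouble() {
  double d = static_cast<double>(std::numeric_limits<int64_t>::max());  // rounds up to 2^63
  return static_cast<int64_t>(d);  // out of range: not a safe clamp default
}
// ----------------------------------------------------------------------------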
+ CastToClampedScalarUnion(opDesc.MinMaxDataType, -INT64_MAX, /*out*/&opDesc.Min); + CastToClampedScalarUnion(opDesc.MinMaxDataType, UINT64_MAX, /*out*/&opDesc.Max); + } if (kernelInfo.IsInputValid(1)) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index db8922439ed8a..2375131cb34ea 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -196,7 +196,7 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), If); ONNX_OPERATOR_KERNEL_EX( @@ -207,7 +207,7 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), If); ONNX_OPERATOR_KERNEL_EX( @@ -218,7 +218,7 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), If); } diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index f15a3008895aa..dd5f7e0525669 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -29,9 +29,8 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a scalar or a vector tensor."); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - axes.assign(data, data + nDims); + auto data_span = axes_tensor->template DataAsSpan(); + axes.assign(data_span.begin(), data_span.end()); } else { axes.assign(axes_.begin(), axes_.end()); } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 1f9c61780f27a..71a02f076c8cc 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -303,33 +303,20 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); } else { auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); - ov_tensor_data_t ov_tensor_key; - ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; - if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { - ov_tensor_key = it->second; - } else { - // Does this make sense for both types of allocators? 
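// ---- editorial note (not part of the diff) --------------------------------
// The rewritten OpenVINO binding below keys the cache by tensor name and
// rebuilds the ov::Tensor only when the ORT buffer address changes, wrapping
// ORT memory directly instead of copying through a staging tensor. The core of
// that pattern with simplified types (assumes OpenVINO's
// ov::Tensor(type, shape, host_ptr) constructor; hypothetical names,
// illustrative only):
#include <map>
#include <memory>
#include <string>
#include <openvino/runtime/tensor.hpp>
struct CachedBinding {
  std::shared_ptr<ov::Tensor> tensor_ptr;
  const void* ort_ptr = nullptr;
};
std::map<std::string, CachedBinding> binding_cache;
std::shared_ptr<ov::Tensor> BindZeroCopy(const std::string& name, const void* ort_data,
                                         const ov::element::Type& type, const ov::Shape& shape) {
  auto it = binding_cache.find(name);
  if (it == binding_cache.end() || it->second.ort_ptr != ort_data) {
    CachedBinding entry;
    // zero-copy: OpenVINO reads and writes ORT's buffer directly
    entry.tensor_ptr = std::make_shared<ov::Tensor>(type, shape, const_cast<void*>(ort_data));
    entry.ort_ptr = ort_data;
    it = binding_cache.insert_or_assign(name, std::move(entry)).first;
  }
  return it->second.tensor_ptr;
}
// ----------------------------------------------------------------------------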
+ ort_tensor_key_t ort_tensor_key{input_name}; + auto it = ort_ov_tensor_map.find(ort_tensor_key); + if ((it == ort_ov_tensor_map.end()) || + (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { + ov_tensor_data_t ov_tensor_data; auto input = graph_input_info.at(input_idx); - if (allocator_name == OpenVINO_RT_NPU) { - ov_tensor_key.copy_needed = false; - ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), - (void*)tensor.GetTensorRawData()); - } else { - ov_tensor_key.copy_needed = true; - ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape()); - } - ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_key); + ov_tensor_data.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), + const_cast(tensor.GetTensorRawData())); - if (ov_tensor_key.copy_needed) { - const char* ort_tensor_data = tensor.GetTensorData(); - size_t tensor_data_size = ov_tensor_key.tensor_ptr->get_byte_size(); - auto ort_batch_memory_offset = ort_tensor_data + tensor_data_size * batch_slice_idx; - std::memcpy(ov_tensor_key.tensor_ptr->data(), ort_batch_memory_offset, tensor_data_size); - } + ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); + ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(input_name, ov_tensor_key.tensor_ptr); + infer_request->SetTensor(input_name, ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } @@ -362,23 +349,16 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque infer_request, output_name, subgraph_context_.output_names); - auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); - - ov_tensor_data_t ov_tensor_data; - ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; - if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { - ov_tensor_data = it->second; - } else { + ort_tensor_key_t ort_tensor_key{output_name}; + const auto& it = ort_ov_tensor_map.find(ort_tensor_key); + if ((it == ort_ov_tensor_map.end()) || + (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { + ov_tensor_data_t ov_tensor_data; auto output = graph_output_info.at(output_idx); - if (allocator_name == OpenVINO_RT_NPU) { - ov_tensor_data.copy_needed = false; - ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), - (void*)tensor.GetTensorRawData()); - } else { - ov_tensor_data.copy_needed = true; - ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape()); - } - ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_data); + ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), + const_cast(tensor.GetTensorRawData())); + ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr); @@ -556,25 +536,6 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe size_t batch_slice = 0; FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); } - } else { - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); - auto allocator_name = output_tensor.GetTensorMemoryInfo().GetAllocatorName(); - ov_tensor_data_t ov_tensor_data; - 
ort_tensor_key_t ort_tensor_key{output_tensor.GetTensorRawData(), allocator_name}; - if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { - ov_tensor_data = it->second; - } else { - ORT_THROW(log_tag + "Expected all outputs to have associated OV::Tensor's"); - } - - if (ov_tensor_data.copy_needed) { - auto ort_tensor_data = output_tensor.GetTensorMutableData(); - size_t tensor_data_size = ov_tensor_data.tensor_ptr->get_byte_size(); - auto ort_batch_memory_offset = ort_tensor_data /*+ tensor_data_size * batch_size*/; - std::memcpy(ort_batch_memory_offset, ov_tensor_data.tensor_ptr->data(), tensor_data_size); - } } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index cd69e88f994b9..12502a1d83c5d 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -23,7 +23,7 @@ namespace openvino_ep { struct ov_tensor_data_t { OVTensorPtr tensor_ptr; - bool copy_needed; + const void* ort_ptr; }; class InferRequestsQueue; @@ -67,7 +67,7 @@ class BasicBackend : public IBackend { OVRemoteContextPtr remote_context_; #endif - using ort_tensor_key_t = std::pair; + using ort_tensor_key_t = const std::string; std::map ort_ov_tensor_map; }; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 285781aaa3559..0358fae3c2115 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -325,159 +325,6 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -static Status GetFloatBytes(float f32_val, Qnn_DataType_t qnn_data_type, std::vector& bytes) { - switch (qnn_data_type) { - case QNN_DATATYPE_FLOAT_32: { - bytes.resize(sizeof(float)); - std::memcpy(bytes.data(), &f32_val, bytes.size()); - break; - } - case QNN_DATATYPE_FLOAT_16: { - bytes.resize(sizeof(MLFloat16)); - const MLFloat16 f16_val(f32_val); - std::memcpy(bytes.data(), &f16_val, bytes.size()); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " is not supported"); - } - - return Status::OK(); -} - -static Status DecomposeHardSigmoid(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) { - ORT_UNUSED_PARAMETER(logger); - const auto& onnx_node_name = utils::GetNodeName(node_unit); - const auto& input = node_unit.Inputs()[0]; - const auto& output = node_unit.Outputs()[0]; - - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, input_shape), "Cannot get shape of input 0"); - - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false /*is_quantized*/, input.node_arg.TypeAsProto(), qnn_data_type)); - - NodeAttrHelper node_helper(node_unit); - - // - // Create Mul node. 
- // - const OnnxAttrInfo onnx_alpha_attr{"alpha", 0.2f}; - const OnnxAttrInfo onnx_beta_attr{"beta", 0.5}; - std::string alpha_input_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Mul_alpha"); - std::vector alpha_bytes; - ORT_RETURN_IF_ERROR(GetFloatBytes(GetOnnxAttr(node_helper, onnx_alpha_attr), qnn_data_type, alpha_bytes)); - - QnnTensorWrapper alpha_input(alpha_input_name, - QNN_TENSOR_TYPE_STATIC, - qnn_data_type, - QnnQuantParamsWrapper(), - {1}, // shape - std::move(alpha_bytes)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_input)), "Failed to add alpha input tensor."); - - std::string mul_output_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Mul_output"); - std::string mul_node_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Mul_node"); - QnnTensorWrapper mul_output(mul_output_name, - QNN_TENSOR_TYPE_NATIVE, - qnn_data_type, - QnnQuantParamsWrapper(), - std::vector(input_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(mul_output)), "Failed to add Mul output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(mul_node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_ELEMENT_WISE_MULTIPLY, - {input_names[0], alpha_input_name}, // input names - {mul_output_name}, // output names - {}, - do_op_validation), - "Failed to add Mul node."); - - // - // Create Add node. - // - - std::string beta_input_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Mul_beta"); - std::vector beta_bytes; - ORT_RETURN_IF_ERROR(GetFloatBytes(GetOnnxAttr(node_helper, onnx_beta_attr), qnn_data_type, beta_bytes)); - - QnnTensorWrapper beta_input(beta_input_name, - QNN_TENSOR_TYPE_STATIC, - qnn_data_type, - QnnQuantParamsWrapper(), - {1}, // shape - std::move(beta_bytes)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(beta_input)), "Failed to add beta input tensor."); - - std::string add_output_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Add_output"); - std::string add_node_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_Add_node"); - QnnTensorWrapper add_output(add_output_name, - QNN_TENSOR_TYPE_NATIVE, - qnn_data_type, - QnnQuantParamsWrapper(), - std::vector(input_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_output)), "Failed to add Add output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(add_node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_ELEMENT_WISE_ADD, - {mul_output_name, beta_input_name}, // input names - {add_output_name}, // output names - {}, - do_op_validation), - "Failed to add Add node."); - - // - // Create ReluMinMax node. 
- // - - std::vector param_tensor_names; - - // Parameter 'min_value' - { - Qnn_Scalar_t min_value = QNN_SCALAR_INIT; - min_value.dataType = QNN_DATATYPE_FLOAT_32; - min_value.floatValue = 0.0f; - - QnnParamWrapper qnn_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, min_value); - param_tensor_names.push_back(qnn_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); - } - - // Parameter 'max_value' - { - Qnn_Scalar_t max_value = QNN_SCALAR_INIT; - max_value.dataType = QNN_DATATYPE_FLOAT_32; - max_value.floatValue = 1.0f; - - QnnParamWrapper qnn_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, max_value); - param_tensor_names.push_back(qnn_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); - } - - const std::string& output_name = output.node_arg.Name(); - std::string relu_min_max_node_name = MakeString("ort_qnn_ep_", onnx_node_name, "_HardSigmoid_ReluMinMax_node"); - QnnTensorWrapper output_tensor(output_name, - qnn_model_wrapper.GetTensorType(output_name), - qnn_data_type, - QnnQuantParamsWrapper(), - std::vector(input_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(relu_min_max_node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RELU_MIN_MAX, - {add_output_name}, // input names - {output_name}, // output names - std::move(param_tensor_names), - do_op_validation), - "Failed to add ReluMinMax node."); - - return Status::OK(); -} - Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -546,13 +393,8 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w } if (op_type == "HardSigmoid") { - // direct conversion to ElementWiseNeuron has issue to finalize the graph for FP16 data type - // still decompose it to Mul, Add, ReluMinMax int32_t onnx_data_type = 0; ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(node_unit.Inputs()[0].node_arg, onnx_data_type)); - if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - return DecomposeHardSigmoid(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation); - } ORT_RETURN_IF_ERROR(ProcessNodeAttribute(qnn_model_wrapper, node_unit, "alpha", QNN_OP_ELEMENT_WISE_NEURON_PARAM_ALPHA, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index db5c2c5cb32ba..eaffe1e2ac224 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1179,17 +1179,16 @@ Status QnnBackendManager::ExtractProfilingEventExtended( #endif if (!tracelogging_provider_ep_enabled) { - // QNN issue, the version number not correct, ticket created - // if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) { - outfile << event_data_extended.v1.timestamp << "," - << message << "," - << ExtractQnnScalarValue(event_data_extended.v1.value) << "," - << unit << "," - << "BACKEND" - << "," - << eventLevel << "," - << (event_data_extended.v1.identifier ? 
event_data_extended.v1.identifier : "NULL") << "\n"; - //} + if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) { + outfile << event_data_extended.v1.timestamp << "," + << message << "," + << ExtractQnnScalarValue(event_data_extended.v1.value) << "," + << unit << "," + << "BACKEND" + << "," + << eventLevel << "," + << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n"; + } } else { #ifdef _WIN32 LogQnnProfileEventAsTraceLogging( diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index ce87ac4a3d21c..caf4725626338 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -170,9 +170,11 @@ static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node return false; } - // check Q/DQ have same scale type and different zero point type - return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) && - (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); + // For scale, ensure that the Q/DQ have same scale type. + // + // For zero-point: we previously only fused (DQ -> Q) into a Convert op if the quantization types differed. + // However, a single Convert op is faster than (DQ -> Q), so we should always fuse regardless of the zero-point type. + return (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); } } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 698ceaea7c3b7..24132b98e3757 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -546,8 +546,8 @@ static bool EpSharedContextsHasAllGraphs(const onnxruntime::GraphViewer& graph_v if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { const std::string& graph_name = node.Name(); - auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name); - if (nullptr == shared_qnn_model) { + bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name); + if (!has_shared_qnn_model) { LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts."; return false; } @@ -566,8 +566,8 @@ static bool EpSharedContextsHasAllGraphs(const std::vectorName(); - auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name); - if (nullptr == shared_qnn_model) { + bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name); + if (!has_shared_qnn_model) { LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts."; return false; } @@ -776,10 +776,6 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod NodeComputeInfo compute_info; compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) { LOGS(logger, VERBOSE) << "compute_info.create_state_func context->node_name: " << context->node_name; - if (use_shared_model_) { - *state = qnn_models_shared_[context->node_name].get(); - return 0; - } *state = qnn_models_[context->node_name].get(); return 0; }; @@ -895,8 +891,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not 
found from shared EP contexts."); ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger)); - qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared); - use_shared_model_ = true; + qnn_models_.emplace(graph_meta_id, std::move(qnn_model_shared)); ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); } return Status::OK(); @@ -940,12 +935,12 @@ Status QNNExecutionProvider::Compile(const std::vector& fused } if (share_ep_contexts_ && qnn_models.size() > 0) { - std::vector> shared_qnn_models; + std::vector> shared_qnn_models; for (auto& [key, value] : qnn_models) { shared_qnn_models.push_back(std::move(qnn_models[key])); } std::string duplicate_graph_names; - bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(shared_qnn_models, + bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(std::move(shared_qnn_models), duplicate_graph_names); ORT_RETURN_IF(has_duplicate_graph, "Duplicate graph names detect across sessions: " + duplicate_graph_names); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 9cd73edbff0e0..e0eaf31c94a36 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -35,26 +35,34 @@ class SharedContext { return !shared_qnn_models_.empty(); } - std::shared_ptr GetSharedQnnModel(const std::string& model_name) { + bool HasQnnModel(const std::string& model_name) { + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + return it != shared_qnn_models_.end(); + } + + std::unique_ptr GetSharedQnnModel(const std::string& model_name) { const std::lock_guard lock(mtx_); auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::shared_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); if (it == shared_qnn_models_.end()) { return nullptr; } - return *it; + auto qnn_model = std::move(*it); + shared_qnn_models_.erase(it); + return qnn_model; } - bool SetSharedQnnModel(std::vector>& shared_qnn_models, + bool SetSharedQnnModel(std::vector>&& shared_qnn_models, std::string& duplicate_graph_names) { const std::lock_guard lock(mtx_); bool graph_exist = false; for (auto& shared_qnn_model : shared_qnn_models) { auto& model_name = shared_qnn_model->Name(); auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::shared_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); if (it == shared_qnn_models_.end()) { - shared_qnn_models_.push_back(shared_qnn_model); + shared_qnn_models_.push_back(std::move(shared_qnn_model)); } else { duplicate_graph_names.append(model_name + " "); graph_exist = true; @@ -70,7 +78,7 @@ class SharedContext { SharedContext(const SharedContext&) = delete; SharedContext& operator=(const SharedContext&) = delete; - std::vector> shared_qnn_models_; + std::vector> shared_qnn_models_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized OrtMutex mtx_; @@ -128,8 +136,6 @@ class QNNExecutionProvider : public 
IExecutionProvider { qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; std::unique_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; - std::unordered_map> qnn_models_shared_; - bool use_shared_model_ = false; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; std::string context_node_name_prefix_ = ""; @@ -142,7 +148,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; - bool enable_HTP_FP16_precision_ = false; + bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 11073ab3584eb..a1f5eba9a24c8 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -731,10 +731,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR std::vector axes; size_t num_inputs = ctx->InputCount(); - if (num_inputs == 2) { + const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. + if (axes_tensor != nullptr) { // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input(1); - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 6f8e153cd1232..de16bff7a49e0 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -49,6 +49,9 @@ struct OrtVitisAIEpAPI { void (*create_ep_context_nodes)( const std::vector>& eps, vaip_core::DllSafe>* ret_value) = nullptr; + int (*vitisai_ep_on_run_start)( + const std::vector>& eps, const void* state, + vaip_core::DllSafe (*get_config_entry)(const void* state, const char* entry_name)) = nullptr; void Ensure() { if (handle_) return; @@ -73,6 +76,7 @@ struct OrtVitisAIEpAPI { std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start)); } private: @@ -105,6 +109,15 @@ std::optional> create_ep_context_nodes( return std::nullopt; } +int vitisai_ep_on_run_start( + const std::vector>& eps, const void* state, + vaip_core::DllSafe (*get_config_entry)(const void* state, const char* entry_name)) { + if (s_library_vitisaiep.vitisai_ep_on_run_start) { + return s_library_vitisaiep.vitisai_ep_on_run_start(eps, state, get_config_entry); + } + return 100; +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h 
b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index ec2b98e5b6eda..1a90f4c7fdebb 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -16,3 +16,7 @@ std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); std::optional> create_ep_context_nodes( const std::vector>& eps); + +int vitisai_ep_on_run_start( + const std::vector>& eps, const void* state, + vaip_core::DllSafe (*get_config_entry)(const void* state, const char* entry_name)); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 57c3e21b70104..09b115b4a57fc 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -97,4 +97,22 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector ep_context_node_ptrs; + auto get_config_entry = [](const void* state, const char* entry_name) -> vaip_core::DllSafe { + const onnxruntime::RunOptions& run_options = *static_cast(state); + auto ret = run_options.GetConfigOptions().GetConfigEntry(std::string(entry_name)); + if (ret) { + return vaip_core::DllSafe(new std::string(ret.value())); + } else { + return {}; + }; + }; + auto error_code = vitisai_ep_on_run_start(**execution_providers_, (const void*)&run_options, get_config_entry); + if (error_code) { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL, std::to_string(error_code)); + } + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 24692dd45ca49..05d2a976815b9 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -31,7 +31,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } - + common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc new file mode 100644 index 0000000000000..9c5cd651e1f00 --- /dev/null +++ b/onnxruntime/core/providers/webnn/allocator.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webnn/allocator.h" + +#include "core/common/safeint.h" + +namespace onnxruntime { +namespace webnn { + +void* WebNNTensorAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { + // We don't need to transfer the tensor to an MLTensor, so we don't need to allocate an MLTensor id. 
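+    // (Free() below treats nullptr as a no-op, so callers on this path incur no id bookkeeping.)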
+ return nullptr; + } + void* p = EM_ASM_PTR({ return Module.jsepReserveTensorId(); }); + allocations_[p] = size; + stats_.num_allocs++; + stats_.bytes_in_use += SafeInt(size); + return p; +} + +void WebNNTensorAllocator::Free(void* p) { + if (p == nullptr) { + return; + } + EM_ASM({ Module.jsepReleaseTensorId($0); }, p); + size_t size = allocations_[p]; + stats_.bytes_in_use -= size; + allocations_.erase(p); +} + +void WebNNTensorAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/allocator.h b/onnxruntime/core/providers/webnn/allocator.h new file mode 100644 index 0000000000000..c06da909801cc --- /dev/null +++ b/onnxruntime/core/providers/webnn/allocator.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webnn { + +class WebNNTensorAllocator : public IAllocator { + public: + WebNNTensorAllocator() : IAllocator(OrtMemoryInfo(WEBNN_TENSOR, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) {} + + void* Alloc(size_t size) override; + + void Free(void* p) override; + + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + InlinedHashMap allocations_; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index c4a633fcc92bb..b90c7d76a6507 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -12,6 +12,19 @@ namespace onnxruntime { namespace webnn { +WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type) { + if (device_type == "gpu") { + return WebnnDeviceType::GPU; + } + if (device_type == "cpu") { + return WebnnDeviceType::CPU; + } + if (device_type == "npu") { + return WebnnDeviceType::NPU; + } + ORT_THROW("Unknown WebNN deviceType."); +} + InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer) { InitializedTensorSet all_initializers; if (graph_viewer.IsSubgraph()) { @@ -243,5 +256,10 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { } } +bool IsMLTensorSupported() { + static bool is_supported = !emscripten::val::global("MLTensor").isUndefined(); + return is_supported; +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index dd4a8acc662ef..529463f0808ad 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -31,6 +31,8 @@ enum class WebnnDeviceType { NPU, }; +WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type); + // Collects all the initializer tensors in the subGraph and its ancestor graphs. 
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); @@ -195,6 +197,7 @@ static const InlinedHashMap op_map = { {"LessOrEqual", "lesserOrEqual"}, {"Log", "log"}, {"LpPool", "l2Pool2d"}, + {"LSTM", "lstm"}, {"MatMul", "matmul"}, {"MatMulInteger", "matmulInteger"}, {"Max", "max"}, @@ -291,5 +294,7 @@ bool GetBidirectionalBroadcastShape(std::vector& shape_a, bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); +bool IsMLTensorSupported(); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc new file mode 100644 index 0000000000000..6213b039fb2f9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime::webnn { + +class LstmOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +void LstmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens + model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); + } +} + +Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + uint32_t hidden_size = helper.Get("hidden_size", 1); + + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape"); + uint32_t steps = static_cast(input_shape[0]); + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val weight = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name()); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + options.set("layout", emscripten::val("iofg")); + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); + emscripten::val split_options = emscripten::val::object(); + split_options.set("axis", 1); + split_options.set("label", 
node.Name() + "_split"); + // Split it to bias and recurrentBias. + emscripten::val splitted_biases = + model_builder.GetBuilder().call("split", bias, /*splits*/ 2, split_options); + options.set("bias", splitted_biases[0]); + options.set("recurrentBias", splitted_biases[1]); + } + if (input_defs.size() > 5 && input_defs[5]->Exists()) { + options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name())); + } + if (input_defs.size() > 6 && input_defs[6]->Exists()) { + options.set("initialCellState", model_builder.GetOperand(input_defs[6]->Name())); + } + if (input_defs.size() > 7 && input_defs[7]->Exists()) { + options.set("peepholeWeight", model_builder.GetOperand(input_defs[7]->Name())); + } + + std::string direction = helper.Get("direction", "forward"); + if (direction == "forward") { + options.set("direction", emscripten::val("forward")); + } else if (direction == "reverse") { + options.set("direction", emscripten::val("backward")); + } else if (direction == "bidirectional") { + options.set("direction", emscripten::val("both")); + } + + const auto& output_defs = node.OutputDefs(); + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists(); + options.set("returnSequence", has_Y); + + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh", "Tanh"}); + emscripten::val opt_activations = emscripten::val::array(); + for (size_t i = 0; i < 3; ++i) { + const std::string& activation = activations[i]; + if (activation == "Relu") { + opt_activations.call("push", emscripten::val("relu")); + } else if (activation == "Sigmoid") { + opt_activations.call("push", emscripten::val("sigmoid")); + } else if (activation == "Tanh") { + opt_activations.call("push", emscripten::val("tanh")); + } + } + + options.set("activations", opt_activations); + } + + emscripten::val outputs = model_builder.GetBuilder().call("lstm", input, weight, recurrent_weight, + steps, hidden_size, options); + + if (has_Y) { + model_builder.AddOperand(output_defs[0]->Name(), outputs[2]); + } + if (has_Y_h) { + model_builder.AddOperand(output_defs[1]->Name(), outputs[0]); + } + if (has_Y_c) { + model_builder.AddOperand(output_defs[2]->Name(), outputs[1]); + } + + return Status::OK(); +} + +bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) { + LOGS(logger, ERROR) << "LSTM: input size must be greater than or equal to 3"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, ERROR) << "Cannot get input's shape"; + return false; + } + int32_t steps = static_cast(input_shape[0]); + + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (!Contains(initializers, input_defs[4]->Name())) { + LOGS(logger, ERROR) << "LSTM: sequence_lens must be constant"; + return false; + } + + const auto& sequence_lens_tensor = *initializers.at(input_defs[4]->Name()); + std::vector sequence_lens; + if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) { + LOGS(logger, ERROR) << "Cannot read sequence lens tensor"; + return false; + } + if (std::any_of(sequence_lens.begin(), sequence_lens.end(), + [steps](int32_t lens) -> bool { return steps != lens; 
})) { + LOGS(logger, ERROR) << "LSTM: every sequence length must be equal to input shape[0]"; + return false; + } + } + + NodeAttrHelper helper(node); + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh", "Tanh"}); + + if (activations.size() >= 6) { + if (activations[0] != activations[3] || activations[1] != activations[4] || activations[2] != activations[5]) { + LOGS(logger, ERROR) << "LSTM: forward and backward activations must be the same"; + return false; + } + } + + const InlinedHashSet supported_activations = {"Relu", "Tanh", "Sigmoid"}; + if (std::any_of(activations.begin(), activations.end(), + [&supported_activations](const std::string& activation) -> bool { + return !supported_activations.contains(activation); + })) { + LOGS(logger, ERROR) << "LSTM: activations must be one of Relu, Tanh, Sigmoid"; + return false; + } + } + + if (helper.Get("clip", std::numeric_limits::max()) != std::numeric_limits::max()) { + LOGS(logger, ERROR) << "LSTM: clip is not supported"; + return false; + } + + if (helper.Get("input_forget", 0) != 0) { + LOGS(logger, ERROR) << "LSTM: input_forget == 1 is not supported"; + return false; + } + + if (helper.Get("layout", 0) != 0) { + LOGS(logger, ERROR) << "LSTM: batchwise (layout == 1) is not supported"; + return false; + } + + return true; +} + +bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type = 0; // input data type + int32_t input1_type = 0; // weight data type + int32_t input2_type = 0; // recurrentWeight data type + int32_t input3_type = 0; // bias data type + // input4 sequence_lens is skipped. 
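+  // (IsOpSupportedImpl above rejects a non-constant sequence_lens and requires every entry to equal seq_length, so it adds no type constraint here.)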
+ int32_t input5_type = 0; // initialHiddenState data type + int32_t input6_type = 0; // initialCellState data type + int32_t input7_type = 0; // peepholeWeight data type + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); + bool has_input6 = input_defs.size() > 6 && input_defs[6]->Exists(); + bool has_input7 = input_defs.size() > 7 && input_defs[7]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + !GetType(*input_defs[2], input2_type, logger) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input5 && !GetType(*input_defs[5], input5_type, logger)) || + (has_input6 && !GetType(*input_defs[6], input6_type, logger)) || + (has_input7 && !GetType(*input_defs[7], input7_type, logger))) { + return false; + } + + InlinedVector input_types = {input0_type, input1_type, input2_type}; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input5) { + input_types.push_back(input5_type); + } + if (has_input6) { + input_types.push_back(input6_type); + } + if (has_input7) { + input_types.push_back(input7_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); +} + +bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const auto& op_type = node.OpType(); + int32_t Y_type = 0; + int32_t Y_h_type = 0; + int32_t Y_c_type = 0; + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists(); + + if (has_Y && GetType(*output_defs[0], Y_type, logger)) { + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + } + if (has_Y_h && GetType(*output_defs[1], Y_h_type, logger)) { + return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger); + } + if (has_Y_c && GetType(*output_defs[2], Y_c_type, logger)) { + return IsDataTypeSupportedByOp(op_type, Y_c_type, wnn_limits, "outputs", "Y_c", logger); + } + + return false; +} + +void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace onnxruntime::webnn diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 8cd2e8d0ffad3..fcfdb146bff34 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -11,21 +11,30 @@ #include "core/common/safeint.h" #include "core/graph/onnx_protobuf.h" #include "core/providers/common.h" -#include "core/providers/webnn/builders/helper.h" #include "model.h" namespace onnxruntime { namespace webnn { -Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger) +Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger, bool use_dispatch) : wnn_context_(context), wnn_graph_(graph), - logger_(logger) {} + logger_(logger), + use_dispatch_(use_dispatch) {} 
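Stepping back to the LSTM builder above: it addresses ONNX LSTM inputs purely by position. A reference sketch of that mapping (indices follow the ONNX LSTM spec; the enum is illustrative and not part of the source):

```cpp
#include <cstddef>

// How lstm_op_builder.cc interprets node.InputDefs() when calling
// builder.lstm(input, weight, recurrentWeight, steps, hiddenSize, options).
enum OnnxLstmInput : size_t {
  kX = 0,             // [seq_length, batch_size, input_size] -> lstm() input
  kW = 1,             // -> lstm() weight
  kR = 2,             // -> lstm() recurrentWeight
  kB = 3,             // [Wb, Rb] concatenated; split in two -> options "bias" / "recurrentBias"
  kSequenceLens = 4,  // skipped; must be a constant filled with seq_length
  kInitialH = 5,      // -> options "initialHiddenState"
  kInitialC = 6,      // -> options "initialCellState"
  kP = 7,             // peepholes -> options "peepholeWeight"
};
```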
Model::~Model() {} Status Model::Predict(const InlinedHashMap& inputs, const InlinedHashMap& outputs) { + if (use_dispatch_) { + return Dispatch(inputs, outputs); + } else { + return Compute(inputs, outputs); + } +} + +onnxruntime::common::Status Model::Compute(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { for (const auto& input : inputs) { const std::string& name = input.first; const struct OnnxTensorData tensor = input.second; @@ -142,6 +151,40 @@ Status Model::Predict(const InlinedHashMap& inputs, return Status::OK(); } +onnxruntime::common::Status Model::Dispatch(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { + auto jsepEnsureTensor = emscripten::val::module_property("jsepEnsureTensor"); + auto promises = emscripten::val::array(); + for (const auto& [_, tensor] : inputs) { + emscripten::val shape = emscripten::val::array(); + for (const auto& dim : tensor.tensor_info.shape) { + uint32_t dim_val = SafeInt(dim); + shape.call("push", dim_val); + } + auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); + promises.call("push", ml_tensor); + } + for (const auto& [_, tensor] : outputs) { + emscripten::val shape = emscripten::val::array(); + for (const auto& dim : tensor.tensor_info.shape) { + uint32_t dim_val = SafeInt(dim); + shape.call("push", dim_val); + } + auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); + promises.call("push", ml_tensor); + } + auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); + for (const auto& [name, _] : inputs) { + wnn_inputs_.set(name, ml_tensors.call("shift")); + } + for (const auto& [name, _] : outputs) { + wnn_outputs_.set(name, ml_tensors.call("shift")); + } + wnn_context_.call("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_); + + return Status::OK(); +} + const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { return input_output_info_.at(name); } @@ -156,6 +199,10 @@ void Model::SetOutputMap(InlinedHashMap&& output_map) { // Pre-allocate the input and output buffers for the WebNN graph. void Model::AllocateInputOutputBuffers() { + // We don't need to allocate JS ArrayBuffers if the WebNN API supports MLTensor. 
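+  // (On the MLTensor path, Dispatch() above binds tensors directly via wnn_context_.dispatch(), so these staging buffers would never be read or written.)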
+ if (use_dispatch_) { + return; + } for (const auto& input : inputs_) { const auto& input_info = input_output_info_.at(input); const auto input_shape = input_info.shape; diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h index 5119dbbbc9858..c554dcb6f6877 100644 --- a/onnxruntime/core/providers/webnn/builders/model.h +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -56,6 +56,12 @@ class Model { size_t GetMappedOutputIdx(const std::string& name) const; private: + onnxruntime::common::Status Dispatch(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + + onnxruntime::common::Status Compute(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + emscripten::val wnn_context_ = emscripten::val::object(); emscripten::val wnn_graph_ = emscripten::val::object(); const logging::Logger& logger_; @@ -73,7 +79,9 @@ class Model { OrtMutex mutex_; - Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger); + bool use_dispatch_; + + Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger, bool use_dispatch); void SetInputOutputInfo(InlinedHashMap&& input_output_info) { input_output_info_ = std::move(input_output_info); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index f92fda8c74717..044baa738e8c4 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -340,7 +340,7 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { } // Explicitly release the WebNN builder to free memory. wnn_builder_ = emscripten::val::undefined(); - model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); + model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLTensorSupported())); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); model->SetInputOutputInfo(std::move(input_output_info_)); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 93a2b232a7d51..9df09af01ba67 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -121,6 +121,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("Not", op_registrations); } + { // LSTM + CreateLstmOpBuilder("LSTM", op_registrations); + } + { // Max/Min CreateMaxMinOpBuilder("Max", op_registrations); CreateMaxMinOpBuilder("Min", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 61fe6d936e9d1..398dfc2d3f1c7 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -34,6 +34,7 @@ void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& 
op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc new file mode 100644 index 0000000000000..44e9bf9edf3d9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webnn/data_transfer.h" + +#include +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webnn { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + // Copying data between MLTensors is not supported by WebNN. + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { + // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the data. + return Status::OK(); + } + + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + const auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + EM_ASM({ Module.jsepUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); + } else { + auto jsepDownloadTensor = emscripten::val::module_property("jsepDownloadTensor"); + auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); + jsepDownloadTensor(reinterpret_cast(src_data), subarray).await(); + } + } + + return Status::OK(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/data_transfer.h b/onnxruntime/core/providers/webnn/data_transfer.h new file mode 100644 index 0000000000000..03cfada46d1a0 --- /dev/null +++ b/onnxruntime/core/providers/webnn/data_transfer.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
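+// IDataTransfer implementation that moves tensor data between the wasm heap (CPU) and MLTensor-backed storage; see data_transfer.cc for the jsepUploadTensor / jsepDownloadTensor calls.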
+ +#pragma once + +#include + +#include "core/framework/data_transfer.h" + +namespace onnxruntime { +namespace webnn { + +class DataTransfer : public IDataTransfer { + public: + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b729623c5d3d8..2258d1ac1cd8f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -5,11 +5,14 @@ #include "webnn_execution_provider.h" #include "core/framework/compute_capability.h" +#include "core/framework/data_transfer_manager.h" #include "core/framework/memcpy.h" #include "core/framework/kernel_registry.h" #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" +#include "core/providers/webnn/allocator.h" +#include "core/providers/webnn/data_transfer.h" #include "builders/model.h" #include "builders/helper.h" @@ -18,20 +21,14 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) - : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { - // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. - if (webnn_device_flags.compare("cpu") == 0) { - wnn_device_type_ = webnn::WebnnDeviceType::CPU; - } else { - if (webnn_device_flags.compare("gpu") == 0) { - wnn_device_type_ = webnn::WebnnDeviceType::GPU; - } else if (webnn_device_flags.compare("npu") == 0) { - wnn_device_type_ = webnn::WebnnDeviceType::NPU; - } else { - ORT_THROW("Unknown WebNN deviceType."); - } - } - + : IExecutionProvider{ + onnxruntime::kWebNNExecutionProvider, + // If MLTensor is supported, we force all the tensors to be allocated as MLTensor. + OrtDevice( + webnn::IsMLTensorSupported() ? 
OrtDevice::GPU : OrtDevice::CPU, + OrtDevice::MemType::DEFAULT, + 0)}, + wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) { wnn_context_ = emscripten::val::module_property("currentContext"); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } @@ -322,6 +319,32 @@ common::Status WebNNExecutionProvider::Compile(const std::vectorInput(0); + ORT_ENFORCE(X != nullptr, "Memcpy: input tensor is null"); + auto* Y = context->Output(0, X->Shape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: output tensor is null"); + emscripten::val shape = emscripten::val::array(); + for (auto dim : X->Shape().GetDims()) { + shape.call("push", SafeInt(dim).Ref()); + } + + jsepEnsureTensor(reinterpret_cast(Y->MutableDataRaw()), + Y->GetElementType(), + shape, false) + .await(); + + const auto* data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + + return data_transfer->CopyTensor(*X, *Y); + } +}; + ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, @@ -330,7 +353,7 @@ ONNX_OPERATOR_KERNEL_EX( KernelDefBuilder() .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), - Memcpy); + WebNNMemcpy); ONNX_OPERATOR_KERNEL_EX( MemcpyToHost, @@ -373,4 +396,22 @@ WebNNExecutionProvider::GetKernelRegistry() const { return kernel_registry; } +std::unique_ptr WebNNExecutionProvider::GetDataTransfer() const { + if (!webnn::IsMLTensorSupported()) { + return nullptr; + } + return std::make_unique(); +} + +std::vector WebNNExecutionProvider::CreatePreferredAllocators() { + if (!webnn::IsMLTensorSupported()) { + return {}; + } + AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) { + return std::make_unique(); + }, + 0, false); + return {CreateAllocator(customAllocatorCreationInfo)}; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index 8ea8cedf04300..26c5e476bcc4f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -40,6 +40,8 @@ class WebNNExecutionProvider : public IExecutionProvider { #endif std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + std::vector CreatePreferredAllocators() override; private: emscripten::val wnn_context_ = emscripten::val::undefined(); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9e017df5baa3..83e7596d2f6b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2770,7 +2770,8 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options, if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync"); } - std::function run_fn = [=]() { + std::function run_fn = [run_options, feed_names, feeds, fetch_names, fetches, num_fetches, + callback, user_data, this]() { Status status = Status::OK(); ORT_TRY { if (run_options) { diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index b94c2cb76a635..1180945d5b5dc 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++
b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -38,7 +38,7 @@ def __init__(self, args): # Deserialize an TRT engine runtime = trt.Runtime(logger) engine = runtime.deserialize_cuda_engine(engine_buffer) - num_bindings = engine.num_bindings + num_bindings = engine.num_io_tensors input_tensors = [] output_tensors = [] diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index b8b80942c2dcf..763d160fa56b5 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -17,8 +17,8 @@ TRT_DOCKER_FILES = { "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6", - "10.3.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", - "10.3.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", + "10.4.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", + "10.4.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin", } diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index 83c0c51f09929..b0d35ac79f2fa 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -26,6 +26,7 @@ Clone the SAM2 git repository and download the checkpoints: ```bash git clone https://github.com/facebookresearch/segment-anything-2.git cd segment-anything-2 +export sam2_dir=$PWD python3 -m pip install -e . cd checkpoints sh ./download_ckpts.sh @@ -42,7 +43,7 @@ curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.p ## Export ONNX To export ONNX models, run the convert_to_onnx.py script and specify the segment-anything-2 directory created by the above git clone command: ```bash -python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 +python3 convert_to_onnx.py --sam2_dir $sam2_dir ``` The exported ONNX models will be found in the sam2_onnx_models sub-directory. You can change the output directory using the `--output_dir` option. @@ -58,12 +59,12 @@ python3 convert_to_onnx.py -h To optimize the onnx models for CPU with float32 data type: ```bash -python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp32 +python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp32 ``` To optimize the onnx models for GPU with float16 data type: ```bash -python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu +python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp16 --use_gpu ``` Another option is to use optimizer.py like the following: @@ -80,13 +81,22 @@ The optimizer.py could be helpful when you have SAM2 onnx models that is exporte The exported ONNX models can run on a CPU. The demo will output sam2_demo.png. 
```bash curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg -python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo +python3 convert_to_onnx.py --sam2_dir $sam2_dir --demo ``` It is able to run demo on optimized model as well. For example, ```bash -python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu --demo +python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp16 --use_gpu --demo ``` +## Benchmark +To prepare an environment for benchmark, follow [Setup Environment](#setup-environment) and [Download Checkpoints](#download-checkpoints). + +Run the benchmark like the following: +```bash +sh benchmark_sam2.sh +``` +The result is in sam2.csv, which can be loaded into Excel. + ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index 74048f90424cd..94e57ecb89fc1 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -11,7 +11,8 @@ dir="$( cd "$( dirname "$0" )" && pwd )" onnx_dir=$dir/sam2_onnx_models # Directory of the sam2 code by "git clone https://github.com/facebookresearch/segment-anything-2" -sam2_dir=~/segment-anything-2 +# It reads from the sam2_dir environment variable, or defaults to ~/segment-anything-2. +sam2_dir=${sam2_dir:-~/segment-anything-2} # model name to benchmark model=sam2_hiera_large @@ -65,8 +66,18 @@ run_gpu() python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_gpu.onnx --component image_decoder --use_gpu } +if ! [ -f truck.jpg ]; then + curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg +fi + if python3 -c "import torch; assert torch.cuda.is_available()" 2>/dev/null; then run_gpu 1000 else run_cpu 100 fi + +cat benchmark*.csv > combined_csv +awk '!x[$0]++' combined_csv > sam2.csv +rm combined_csv + +echo "Benchmarking SAM2 model $model done. Results are saved in sam2.csv" diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index 8ad69dee0a763..40c408e851638 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -125,7 +125,7 @@ def parse_arguments(): return args -def optimize_sam2_model(onnx_model_path, optimized_model_path, use_gpu: bool, float16: bool): +def optimize_sam2_model(onnx_model_path, optimized_model_path, float16: bool, use_gpu: bool): print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...") # Import from source directory. 
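    # NOTE: float16 and use_gpu swapped positions in the signature above; any callers passing these arguments positionally must be updated in the same change.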
diff --git a/onnxruntime/test/framework/data_types_test.cc b/onnxruntime/test/framework/data_types_test.cc index 871b255831029..0ec0ecc285931 100644 --- a/onnxruntime/test/framework/data_types_test.cc +++ b/onnxruntime/test/framework/data_types_test.cc @@ -8,6 +8,7 @@ #include "core/framework/data_types.h" #include "core/framework/data_types_internal.h" #include "core/framework/float16.h" +#include "core/framework/float8.h" #include "core/graph/onnx_protobuf.h" #include "gtest/gtest.h" @@ -494,6 +495,25 @@ TEST_F(DataTypeTest, MLFloat16Comparision) { } TEST_F(DataTypeTest, MLFloat16TestNAN) { + const MLFloat16 quiet_NaN = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(quiet_NaN.IsNaN()); + EXPECT_TRUE(quiet_NaN.IsNaNOrZero()); + EXPECT_NE(MLFloat16::NaN, quiet_NaN); // NaN are not equal to each other + EXPECT_TRUE(std::isnan(quiet_NaN.ToFloat())); + + const MLFloat16 signaling_NaN = std::numeric_limits::signaling_NaN(); + EXPECT_TRUE(signaling_NaN.IsNaN()); + EXPECT_TRUE(signaling_NaN.IsNaNOrZero()); + EXPECT_NE(MLFloat16::NaN, signaling_NaN); // NaN are not equal to each other + EXPECT_TRUE(std::isnan(signaling_NaN.ToFloat())); + + // NaN used in C# has negative sign + const MLFloat16 csharp_NaN = MLFloat16::FromBits(0xFE00U); + EXPECT_TRUE(csharp_NaN.IsNaN()); + EXPECT_TRUE(csharp_NaN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, csharp_NaN); + EXPECT_TRUE(std::isnan(csharp_NaN.ToFloat())); + const MLFloat16 fp16NANFromSingle(std::numeric_limits::quiet_NaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero()); @@ -520,6 +540,11 @@ TEST_F(DataTypeTest, MLFloat16NaNComparision) { } TEST_F(DataTypeTest, MLFloat16Infinity) { + const MLFloat16 fp16_infinity(std::numeric_limits::infinity()); + EXPECT_TRUE(fp16_infinity.IsInfinity()); + EXPECT_FALSE(fp16_infinity.IsFinite()); + EXPECT_FALSE(fp16_infinity.IsNegative()); + EXPECT_FALSE(MLFloat16::MaxValue.Negate().IsInfinity()); EXPECT_FALSE(MLFloat16::MaxValue.IsInfinity()); EXPECT_TRUE(MLFloat16::MaxValue.IsFinite()); @@ -550,6 +575,8 @@ TEST_F(DataTypeTest, MLFloat16NormalSubnormal) { EXPECT_TRUE(smallest_subnormal.IsSubnormal()); EXPECT_FALSE(smallest_subnormal.IsNormal()); + EXPECT_EQ(smallest_subnormal, std::numeric_limits::denorm_min()); + // float smallest positive subnormal is ~1.40129846432481707092E-45, and // in float the same number above would be normal const float float_from_smallest_subnormal = static_cast(smallest_subnormal); @@ -639,6 +666,24 @@ TEST_F(DataTypeTest, BFloat16Comparision) { } TEST_F(DataTypeTest, BFloat16TestNAN) { + const BFloat16 quiet_NaN = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(quiet_NaN.IsNaN()); + EXPECT_TRUE(quiet_NaN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, quiet_NaN); + EXPECT_TRUE(std::isnan(quiet_NaN.ToFloat())); + + const BFloat16 signaling_NaN = std::numeric_limits::signaling_NaN(); + EXPECT_TRUE(signaling_NaN.IsNaN()); + EXPECT_TRUE(signaling_NaN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, signaling_NaN); + EXPECT_TRUE(std::isnan(signaling_NaN.ToFloat())); + + const BFloat16 csharp_NaN = BFloat16::FromBits(0xFFC1U); + EXPECT_TRUE(csharp_NaN.IsNaN()); + EXPECT_TRUE(csharp_NaN.IsNaNOrZero()); + EXPECT_NE(BFloat16::NaN, csharp_NaN); + EXPECT_TRUE(std::isnan(csharp_NaN.ToFloat())); + const BFloat16 fp16NANFromSingle = std::numeric_limits::quiet_NaN(); EXPECT_TRUE(fp16NANFromSingle.IsNaN()); EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero()); @@ -695,6 +740,8 @@ TEST_F(DataTypeTest, BFloat16NormalSubnormal) { EXPECT_TRUE(smallest_subnormal.IsSubnormal()); 
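A quick sanity check on the denorm_min expectation added in the next few lines: IEEE fp16 has a 5-bit exponent (bias 15) and 10 stored mantissa bits, so its smallest positive subnormal is 2^(1-15-10) = 2^-24. An illustrative one-liner (not part of the test):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // 2^-24 = 5.9604644775390625e-08: fp16 denorm_min, i.e. the value the
  // EXPECT_EQ below compares against.
  std::printf("%.17g\n", std::ldexp(1.0, -24));
}
```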
EXPECT_FALSE(smallest_subnormal.IsNormal()); + EXPECT_EQ(smallest_subnormal, std::numeric_limits::denorm_min()); + const float float_from_smallest_subnormal = (float)smallest_subnormal; EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal)); @@ -708,6 +755,59 @@ TEST_F(DataTypeTest, BFloat16NormalSubnormal) { EXPECT_FALSE(std::isnormal(float_from_largest_subnormal)); } +#if !defined(DISABLE_FLOAT8_TYPES) +TEST_F(DataTypeTest, Float8TestNAN) { + const auto fp8_e4m3fn_nan = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(fp8_e4m3fn_nan.IsNaN()); + EXPECT_TRUE(std::isnan(fp8_e4m3fn_nan.ToFloat())); + + const auto fp8_e5m2_nan = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(fp8_e5m2_nan.IsNaN()); + EXPECT_TRUE(std::isnan(fp8_e5m2_nan.ToFloat())); + + const auto fp8_e4m3fnuz_nan = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(fp8_e4m3fnuz_nan.IsNaN()); + EXPECT_TRUE(std::isnan(fp8_e4m3fnuz_nan.ToFloat())); + + const auto fp8_e5m2fnuz_nan = std::numeric_limits::quiet_NaN(); + EXPECT_TRUE(fp8_e5m2fnuz_nan.IsNaN()); + EXPECT_TRUE(std::isnan(fp8_e5m2fnuz_nan.ToFloat())); +} + +TEST_F(DataTypeTest, Float8TestInf) { + const auto fp8_e5m2_inf = std::numeric_limits::infinity(); + EXPECT_TRUE(fp8_e5m2_inf.IsInfinity()); + EXPECT_TRUE(std::isinf(fp8_e5m2_inf.ToFloat())); + + EXPECT_FALSE(std::numeric_limits::has_infinity); + EXPECT_FALSE(std::numeric_limits::has_infinity); + EXPECT_FALSE(std::numeric_limits::has_infinity); +} + +TEST_F(DataTypeTest, Float8TestLimits) { + constexpr float abs_tolerance = 1e-6f; + EXPECT_NEAR(std::numeric_limits::min().ToFloat(), 0.015625f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::max().ToFloat(), 448.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::lowest().ToFloat(), -448.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::denorm_min().ToFloat(), 0.001953125f, abs_tolerance); + + EXPECT_NEAR(std::numeric_limits::min().ToFloat(), 0.00006103515f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::max().ToFloat(), 57344.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::lowest().ToFloat(), -57344.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::denorm_min().ToFloat(), 0.00001525878f, abs_tolerance); + + EXPECT_NEAR(std::numeric_limits::min().ToFloat(), 0.0078125f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::max().ToFloat(), 240.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::lowest().ToFloat(), -240.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::denorm_min().ToFloat(), 0.0009765625f, abs_tolerance); + + EXPECT_NEAR(std::numeric_limits::min().ToFloat(), 0.00003051757f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::max().ToFloat(), 57344.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::lowest().ToFloat(), -57344.0f, abs_tolerance); + EXPECT_NEAR(std::numeric_limits::denorm_min().ToFloat(), 0.00000762939f, abs_tolerance); +} +#endif + TEST_F(DataTypeTest, DataUtilsTest) { using namespace ONNX_NAMESPACE::Utils; // Test simple seq diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp new file mode 100644 index 0000000000000..1dccbe44aafaf --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp @@ -0,0 +1,54 @@ +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +void BM_ConvertF16ToF32(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, 0, 
60000); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + float* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + } +} + +void BM_ConvertF32ToF16(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + unsigned short* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + } +} + +BENCHMARK(BM_ConvertF16ToF32) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +BENCHMARK(BM_ConvertF32ToF16) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 73c78b8cc3d47..2a14ee0e6ff04 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -80,9 +80,11 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) - params.QuantBDataWorkspace = static_cast(PackedQuantBData.get()); + params.QuantBDataWorkspace = PackedQuantBData.get(); else params.QuantBDataWorkspace = static_cast(QuantBData.data()); + + params.PackedQuantBData = PackedQuantBData.get(); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); params.Bias = HasBias ? Bias.data() : nullptr; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp new file mode 100644 index 0000000000000..243752bbea24e --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp @@ -0,0 +1,82 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm_neon_fp16.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. 
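+    Also covers the NEON fp16/fp32 cast kernels (MlasCastF16ToF32KernelNeon and MlasCastF32ToF16KernelNeon) exercised by MlasNeonFp16CastTest below.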
+      if ((src[i] & 0x1c00) == 0x1c00) continue;  // skip inf and nan
+      ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat());
+    }
+  }
+
+  void TestFp32ToFp16(size_t count) {
+    std::vector<float> src(count);
+    std::vector<unsigned short> dest(count);
+
+    for (size_t i = 0; i < count; i++) {
+      src[i] = static_cast<float>(i) + 0.125f;
+    }
+
+    MlasCastF32ToF16KernelNeon(src.data(), dest.data(), count);
+
+    for (size_t i = 0; i < count; i++) {
+      ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val);
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    return "NeonFp16Cast";
+  }
+
+  void ExecuteShort(void) override {
+    TestFp16ToFp32(1 << 16);
+    TestFp16ToFp32(1);
+    TestFp16ToFp32(4);
+    TestFp16ToFp32(7);
+    TestFp32ToFp16(1 << 16);
+    TestFp32ToFp16(3);
+    TestFp32ToFp16(4);
+    TestFp32ToFp16(6);
+  }
+};
+
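+// Static registration with the MLAS unit-test harness; the callback is invoked
+// by the harness and returns the number of tests registered for this run mode.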
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  size_t count = 0;
+  if (is_short_execute) {
+    count += MlasDirectShortExecuteTests<MlasNeonFp16CastTest>::RegisterShortExecute();
+  }
+  return count;
+});
+
+#endif  // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 924616f49ab25..e8c948ade1068 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -73,7 +73,7 @@ void usage() {
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
       "\t    [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
-      "\t    Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
+      "\t    Otherwise, it will be fp32 precision. Only works for float32 models on the HTP backend. Defaults to '1' (with FP16 precision). \n"
      "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
      "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
      "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc
index 79704f2cc79e3..81c1a4ace1e33 100644
--- a/onnxruntime/test/optimizer/optimizer_test.cc
+++ b/onnxruntime/test/optimizer/optimizer_test.cc
@@ -24,8 +24,6 @@ using namespace ONNX_NAMESPACE;
 namespace onnxruntime {
 namespace test {
 
-static const std::string MODEL_FOLDER = "testdata/transform/";
-
 TEST(OptimizerTest, Basic) {
   Model model("OptimizerBasic", false, ModelMetaData(), PathString(),
               IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {},
               DefaultLoggingManager().DefaultLogger());
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index 8d90e48db97c1..35ba1a3369597 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4964,6 +4964,90 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerAxisDQUnsqueezeTranspose) {
               testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
 }
 
+// Test that the TransposeOptimizer's qdq-fixup pass converts the sequence (Op -> DQ -> Q -> GRAPH_OUTPUT) to
+// (Op -> GRAPH_OUTPUT).
+TEST(TransposeOptimizerTests, RemoveEmptyDQQAtGraphOutput) {
+  auto model_uri = ORT_TSTR("testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx");
+
+  RandomValueGenerator random{123};
+  std::vector<int64_t> input_dims{1, 3, 4, 4};
+  std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
+
+  auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
+  OrtValue input0;
+  CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);
+
+  NameMLValMap feeds{{"input0", input0}};
+
+  std::vector<std::string> output_names{"output0"};
+  std::vector<OrtValue> fetches_orig;
+  std::vector<OrtValue> fetches;
+
+  SessionOptions so;
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
+  so.graph_optimization_level = TransformerLevel::Default;  // off
+
+  // get results with no modifications to the model
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
+  }
+
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+
+    Graph& graph = session.GetMutableGraph();
+    CPUAllocator allocator;
+
+    namespace alias_oto = onnx_transpose_optimization;
+    auto api_graph = MakeApiGraph(graph,
+                                  TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
+                                  /*new_node_ep*/ nullptr);
+
+    alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
+    ASSERT_EQ(result.error_msg, std::nullopt);
+    ASSERT_TRUE(result.graph_modified);
+    ASSERT_TRUE(graph.GraphResolveNeeded());
+    ASSERT_STATUS_OK(graph.Resolve());
+
+    // Use this hack to save the model for viewing if needed
+    // ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
+    //                              ToPathString("updated_model_empty_dqq_graph_output.onnx")));
+
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    EXPECT_EQ(op_to_count["Transpose"], 0) << "2 pre-existing Transposes at the I/O cancel.";
+
+    // Check that the graph ends in the sequence (Mul -> Q -> GRAPH_OUTPUT)
+    Node* mul_node = nullptr;
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Mul") {
+        mul_node = &node;
+        break;
+      }
+    }
+
+    // Mul should be followed by a Q node.
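+    // (The no-op DQ/Q pair was removed, so the remaining QuantizeLinear should
+    // now produce the graph output directly.)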
+    ASSERT_TRUE(mul_node != nullptr);
+    const auto& last_q_node = *(mul_node->OutputNodesBegin());
+    EXPECT_EQ(last_q_node.OpType(), "QuantizeLinear");
+
+    // The Q node should generate the graph's output.
+    const std::string& q_out_name = last_q_node.OutputDefs()[0]->Name();
+    const std::string& graph_out_name = graph.GetOutputs()[0]->Name();
+    EXPECT_EQ(q_out_name, graph_out_name);
+
+    // Run optimized model.
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
+  }
+
+  ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<uint8_t>(),
+              testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<uint8_t>()));
+}
+
 // Tests the in-place unsqueeze and transpose of a constant consumed by a per-axis DQ.
 TEST(TransposeOptimizerTests, InPlaceUnsqueezeTransposePerAxisDQ) {
   // Model contains a Mul with a constant/broadcastable/per-axis DQ input[1].
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index c1c48d4945a4d..6e811f4596eab 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -98,7 +98,7 @@ namespace perftest {
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
       "\t    [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
-      "\t    Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
+      "\t    Otherwise, it will be fp32 precision. Only works for float32 models on the HTP backend. Defaults to '1' (with FP16 precision). \n"
      "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
      "\n"
      "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc
index 182fa4729a88f..ff5895623fc9b 100644
--- a/onnxruntime/test/providers/checkers.cc
+++ b/onnxruntime/test/providers/checkers.cc
@@ -385,6 +385,8 @@ void InternalNumericalCheck(const Tensor& expected,
       EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i;
     } else if (std::isinf(cur_expected[i])) {  // Test infinity for equality
       EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i;
+    } else if (std::isinf(cur_actual[i])) {  // Handle the case where cur_actual is inf but cur_expected is FLT_MAX
+      EXPECT_TRUE(cur_expected[i] == FLT_MAX) << "Expected infinity.
i:" << i; } else { T tolerance = get_tolerance(tolerance_params, cur_expected[i]); EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i; diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d2e883331acd4..724118d7419d2 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -125,7 +125,7 @@ TEST_F(ActivationOpTest, Relu) { {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TestActivationOp( "Relu", input_values_fp16, @@ -139,7 +139,7 @@ TEST_F(ActivationOpTest, Relu) { #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, Sigmoid_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -413,7 +413,7 @@ TEST_F(ActivationOpTest, LeakyRelu) { {{"alpha", alpha}}, {}); } -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, LeakyRelu_fp16) { OpTester test("LeakyRelu", 11); float alpha = 0.01f; // oneDNN set alpha equal to 0.01 diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 409409f56c51c..8ca0f6d845a09 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -90,7 +90,6 @@ class ActivationOpTest : public ::testing::Test { DBL_MAX, -DBL_MAX, std::numeric_limits::infinity()}}; // max, -max, inf std::vector> input_values_int8{{-1, -5, 0, 1, 5, 100, -100, // normal input values for activation std::numeric_limits::min(), std::numeric_limits::max()}}; // min, max -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED std::vector> input_values_fp16{{MLFloat16(-1.0f), MLFloat16(-5.f), MLFloat16(), @@ -100,7 +99,6 @@ class ActivationOpTest : public ::testing::Test { MLFloat16(-100.f), MLFloat16(65504.f), MLFloat16(-65504.f)}}; -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void SetUp() override { float low = -1.0f, high = 1.0f; diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 9948a6cc8a681..c1452ab686279 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -120,19 +120,62 @@ TEST(MathOpTest, Clip_Default_uint64) { } TEST(MathOpTest, Clip_MLFloat16) { + auto run_test = [](bool min_max_are_initializer) { + OpTester test("Clip", 12); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}, min_max_are_initializer); + test.AddInput("max", {}, {MLFloat16(6.0f)}, min_max_are_initializer); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); + }; + run_test(true); // coreml requires constant max/min + run_test(false); +} + +TEST(MathOpTest, Clip_MLFloat16_NoMin_NoMax) { OpTester 
test("Clip", 12); - std::vector dims{3, 3}; + std::vector dims{3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)}); + test.AddOutput("Y", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)}); + + test.Run(); +} + +TEST(MathOpTest, Clip_MLFloat16_NoMax) { + OpTester test("Clip", 12); + + std::vector dims{3}; test.AddInput("X", dims, - {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), - MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)}); test.AddInput("min", {}, {MLFloat16(0.0f)}); - test.AddInput("max", {}, {MLFloat16(6.0f)}); test.AddOutput("Y", dims, - {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), - MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(3.0f)}); + + test.Run(); +} + +TEST(MathOpTest, Clip_MLFloat16_NoMin) { + OpTester test("Clip", 12); + + std::vector dims{3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(3.0f)}); + test.AddOptionalInputEdge(); // no min + test.AddInput("max", {}, {MLFloat16(0.0f)}); + test.AddOutput("Y", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(0.0f)}); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb914646942fe..b2e9034653746 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -22,40 +22,93 @@ std::vector MakeMLFloat16(const std::initializer_list& input) return output; } -#if defined(USE_CUDA) || defined(USE_ROCM) -void TestFloat16(const char* op_name, const std::vector& lhs_dim, - const std::initializer_list& lhs_values, const std::vector& rhs_dim, - const std::initializer_list& rhs_values, const std::vector& out_dim, - const std::initializer_list& out_values) { +void TestBinaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + bool enable_bf16 = true) { + { + std::vector> execution_providers; +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + if (enable_bf16 && 
execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +void TestUnaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + int opset = 14, + bool run_bf16 = true) { + { + std::vector> execution_providers; +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + if (run_bf16 && execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } -#endif void TestBFloat16(const char* op_name, const std::vector& lhs_dim, const std::initializer_list& lhs_values, const std::vector& rhs_dim, @@ -163,9 +216,7 @@ TEST(MathOpTest, Add_float) { test.Run(); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -202,9 +253,7 @@ TEST(MathOpTest, Add_Broadcast_Axis) { test.AddOutput("C", dims, out_values); test.Run(OpTester::ExpectResult::kExpectSuccess, ""); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); @@ -228,9 +277,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) { {kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); @@ -254,9 +301,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) { 
{kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); @@ -527,9 +572,7 @@ TEST(MathOpTest, Sub) { test.AddOutput("C", dims, out_values); test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -584,9 +627,7 @@ TEST(MathOpTest, Mul) { test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -622,9 +663,7 @@ TEST(MathOpTest, Div) { test.AddOutput("C", dims, out_values); test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -772,13 +811,12 @@ TEST(MathOpTest, Ceil_double) { TEST(MathOpTest, Reciprocal) { OpTester test("Reciprocal"); std::vector dims{2, 2}; - test.AddInput("X", dims, - {1.0f, 2.0f, - -1.0f, -2.0f}); - test.AddOutput("Y", dims, - {1.0f, 0.5f, - -1.0f, -0.5f}); + std::initializer_list inputs = {1.0f, 2.0f, -1.0f, -2.0f}; + std::initializer_list outputs = {1.0f, 0.5f, -1.0f, -0.5f}; + test.AddInput("X", dims, inputs); + test.AddOutput("Y", dims, outputs); test.Run(); + TestUnaryFloat16("Reciprocal", dims, inputs, dims, outputs, 12, false); } TEST(MathOpTest, Reciprocal_double) { @@ -795,14 +833,13 @@ TEST(MathOpTest, Reciprocal_double) { TEST(MathOpTest, Sqrt_Float) { OpTester test("Sqrt"); + std::initializer_list inputs = {1.0f, 4.0f, 0.0f, 9.0f}; + std::initializer_list outputs = {1.0f, 2.0f, 0.0f, 3.0f}; std::vector dims{2, 2}; - test.AddInput("X", dims, - {1.0f, 4.0f, - 0.0f, 9.0f}); - test.AddOutput("Y", dims, - {1.0f, 2.0f, - 0.0f, 3.0f}); + test.AddInput("X", dims, inputs); + test.AddOutput("Y", dims, outputs); test.Run(); + TestUnaryFloat16("Sqrt", dims, inputs, dims, outputs); } #if defined(USE_DNNL) || defined(USE_CUDA) @@ -1056,24 +1093,13 @@ TEST(MathOpTest, Pow_double_int64) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, Pow_float16_float16) { - OpTester test("Pow", 12); std::vector dims{4}; - - test.AddInput("X", dims, MakeMLFloat16({2.0f, 2.0f, std::sqrt(2.0f), 1.0f})); - test.AddInput("Y", dims, MakeMLFloat16({0.0f, 8.0f, 2.0f, 9.0f})); - test.AddOutput("Z", dims, MakeMLFloat16({1.0f, 256.0f, 2.0f, 1.0f})); - - std::vector> execution_providers; -#ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); -#endif - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + TestBinaryFloat16("Pow", dims, {2.0f, 2.0f, std::sqrt(2.0f), 1.0f}, dims, {0.0f, 8.0f, 2.0f, 9.0f}, + dims, {1.0f, 
256.0f, 2.0f, 1.0f}, false); } +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(MathOpTest, Pow_float_float16) { OpTester test("Pow", 12); std::vector dims{4}; @@ -1087,6 +1113,8 @@ TEST(MathOpTest, Pow_float_float16) { execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -1787,54 +1815,90 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddOutput("min", {3, 3}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, - -1.0f, -1.0f, -2.0f, - 0.5f, 0.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { +void TestFloat16MinMax( + const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values) { + { std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); + if (nullptr != DefaultCpuExecutionProvider()) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (nullptr != DefaultCudaExecutionProvider()) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeMLFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeMLFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeMLFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} -TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 4}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f, - -0.5f, 0.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 2.0f, 1.5f})); - test.AddOutput("min", {3, 4}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f, - -1.0f, -1.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 1.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeBFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeBFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeBFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } +TEST(MathOpTest, 
Min_13_Float16_MatrixVector) { + TestFloat16MinMax("Min", + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + {0.0f, 0.0f, 0.0f, + -1.0f, -1.0f, -2.0f, + 0.5f, 0.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_VectorMatrix) { + TestFloat16MinMax("Min", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 4}, + {1.0f, 1.0f, 1.0f, -1.0f, + -0.5f, 0.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 2.0f, 1.5f}, + {3, 4}, + {0.0f, 0.0f, 0.0f, -1.0f, + -1.0f, -1.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 1.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_Nan) { + TestFloat16MinMax("Min", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Min_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Min", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f}); +} + +TEST(MathOpTest, Min_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Min", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); +} TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -2185,54 +2249,57 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) { - OpTester test("Max", 12); - test.AddInput("data_0", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.0f, 0.5f, 0.75f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {4, 1}, - MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f})); - test.AddOutput("max", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 0.5f, 0.5f, 0.75f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) { - OpTester test("Max", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddOutput("max", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(MathOpTest, Max_13_Float16_MatrixVector) { + TestFloat16MinMax("Max", + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.0f, 
0.5f, 0.75f, + 0.5f, 0.0f, 2.0f}, + {4, 1}, {0.0f, -1.0f, 0.5f, 1.0f}, + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 0.5f, 0.5f, 0.75f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_VectorMatrix) { + TestFloat16MinMax("Max", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_Nan) { + TestFloat16MinMax("Max", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {0.5f, std::numeric_limits::quiet_NaN(), 1.0f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Max_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Max", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {0.25f, std::numeric_limits::quiet_NaN(), 1.0f}); +} + +TEST(MathOpTest, Max_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Max", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); } TEST(MathOpTest, Not) { @@ -3758,5 +3825,6 @@ TEST(MathOpTest, BitwiseNot_uint8) { test.AddOutput("Y", dims, {254, 251, 250, 252}); test.Run(); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 625ff29d4ccf9..66408e6adfbc5 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -25,7 +25,7 @@ const constexpr auto run_with_tunable_op = &run_options; } // namespace -// Only CUDA and ROCM kernel has float 16 support +// Only CUDA, ROCM and CoreML kernels have float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -34,36 +34,142 @@ TEST(GemmOpTest, GemmNoTrans_f16) { return; } #endif - OpTester test("Gemm", 13); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + std::vector A{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; + + std::vector f_A(8); + std::vector f_B(12); + ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); + ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); + + { + // bias has same shape as output + std::vector f_Y(6); + std::vector Y{19.8f, 0.7f, -25.7f, + -19.6f, 0.2f, 27.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(6); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B); + test.AddInput("C", {2, 3}, f_C); + test.AddOutput("Y", {2, 3}, f_Y); + // we used float data with decimal instead of only integer, increase Tolerance to make test pass + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // bias has shape {1, output_features} + std::vector f_Y(6); + 
std::vector Y{19.8f, 0.7f, -25.7f, + -18.8f, 3.5f, 28.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(3); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 3); + // CoreML program require B/C are constant + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddInput("C", {3}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // bias is a scalar + std::vector f_Y(6); + std::vector Y{19.8f, -0.9f, -26.4f, + -18.8f, 1.9f, 27.4f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } +} + +// Only CUDA, ROCM and CoreML kernels have float 16 support +TEST(GemmOpTest, GemmTransB_f16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; - std::vector B(12, 1.0f); - std::vector C(6, 1.0f); - std::vector Y{11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; std::vector f_A(8); std::vector f_B(12); - std::vector f_C(6); - std::vector f_Y(6); ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); - ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); - ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); - - test.AddInput("A", {2, 4}, f_A); - test.AddInput("B", {4, 3}, f_B); - test.AddInput("C", {2, 3}, f_C); - test.AddOutput("Y", {2, 3}, f_Y); - test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported - .Config(run_with_tunable_op) - .RunWithConfig(); + { + // bias is a scalar and transB is True + std::vector f_Y(6); + std::vector Y{7.6f, -5.7f, -18.5f, -6.6f, 6.7f, 19.5f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {3, 4}, f_B, true); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } } #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) diff --git 
a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc
index 90370560859aa..a7d2281ac19f8 100644
--- a/onnxruntime/test/providers/cpu/math/matmul_test.cc
+++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc
@@ -246,7 +246,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) {
   RunMatMulZeroKTest<int32_t>();
 }
 
-#if defined(USE_CUDA) || defined(USE_ROCM)
+#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
 TEST(MathOpTest, MatMul_Float16) {
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
index 95b274966fbbb..ce1ac7591ec34 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -3,7 +3,7 @@
 
 #include "core/mlas/inc/mlas.h"
 
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
 
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
@@ -28,6 +28,15 @@ struct ConvOpAndTestAttributes {
   vector<float> activation_parameters = {};
 };
 
+/*
+Note: this file is guarded by the macros defined at its top:
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
+These tests are compiled and run only when at least one of those macros is defined.
+
+If attributes.activation is set, the NhwcFusedConv contrib op is used.
+If you are adding support for a new EP to this test and the EP does not support NhwcFusedConv,
+please add the EP to the excluded_providers list.
+*/
 void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
                     const vector<vector<MLFloat16>>& inputs,
                     const vector<vector<int64_t>>& input_shapes,
@@ -81,11 +90,13 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
   std::unordered_set<std::string> excluded_providers(attributes.excluded_providers);
   // Disable TensorRT because weight as input is not supported
   excluded_providers.insert(kTensorrtExecutionProvider);
-  // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER
+  // QNN has issues with dynamic weights and with auto_pad SAME_UPPER / SAME_LOWER
   if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") {
     excluded_providers.insert(kQnnExecutionProvider);
   }
-
+  if (!weight_is_initializer || !attributes.activation.empty()) {
+    excluded_providers.insert(kCoreMLExecutionProvider);
+  }
   tester->Run(expect_result, err_str, excluded_providers);
 }
 
@@ -1147,6 +1158,7 @@ TEST(ConvFp16Test, Pointwise_Relu) {
                          MLFloat16(17.5f), MLFloat16(9.5f)};
 
   TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
+  TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
 }
 
 TEST(ConvFp16Test, Conv2D_HardSigmoid) {
@@ -1176,6 +1188,7 @@ TEST(ConvFp16Test, Conv2D_HardSigmoid) {
                          MLFloat16(1.0f), MLFloat16(0.0f),
                          MLFloat16(1.0f), MLFloat16(0.0f)};
   TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
+  TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
 }
 
 TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) {
@@ -1205,6 +1218,7 @@ TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) {
   vector<int64_t> Z_shape = {1, 2, 2, 2};
   auto expected_vals = {MLFloat16(12.0f), MLFloat16(11.0f), MLFloat16(17.0f), MLFloat16(15.0f),
                         MLFloat16(25.0f), MLFloat16(23.0f), MLFloat16(29.0f), MLFloat16(28.0f)};
   TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape);
+  TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape,
W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true); } #endif // CONTRIB_OPS diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 2bf53ce5b5986..29525f89ef544 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -22,10 +22,11 @@ struct ConvTransposeOpAttributes { string auto_pad; }; +template void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, @@ -61,17 +62,18 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const char* input_names[] = {"X", "W", "B"}; bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { - test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); + test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } +template void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -87,6 +89,13 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, } // namespace +template +class ConvTransposeTest : public ::testing::Test { +}; + +using ConvTransposeTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConvTransposeTest, ConvTransposeTestTypes); + TEST(ConvTransposeTest, ConvTranspose_1D) { ConvTransposeOpAttributes attrs{ vector{3}, // kernel_shape @@ -108,13 +117,24 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; vector Y_shape = {1, 2, 5}; - auto expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, - 9.4f, 22.5f, 39.6f, 30.f, 17.f}; + vector expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, + 9.4f, 22.5f, 39.6f, 30.f, 17.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{1, 1}, // output_padding @@ -137,17 +157,27 @@ TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { 0.04118127f, -0.44696793f, 0.06373066f}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 
0.06301362f, 0.03732984f, -0.01919658f, - -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, - 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, - -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, - 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, - -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + vector expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, + -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, + 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, + -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, + 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, + -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + } else { + vector X_fp16(X.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + vector W_fp16(W.size()); + ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); + std::vector expected_vals_fp16(expected_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + TestConvTransposeOp(attrs, {X_fp16, W_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); + } } // 2D input with C > 1 -TEST(ConvTransposeTest, ConvTranspose_2D_C2) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_C2) { ConvTransposeOpAttributes attrs = { vector{2, 2}, // kernel_shape {}, // output_padding @@ -176,16 +206,17 @@ TEST(ConvTransposeTest, ConvTranspose_2D_C2) { 0.44524362f, 0.6056068f}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = { + vector expected_vals = { 0.50678771f, 1.10413539f, 0.74340409f, 0.14989006f, 0.34063845f, 1.19294512f, 1.85030293f, 0.63518577f, 0.58575004f, 1.25774109f, 1.23472511f, 0.77670550f, 0.25844323f, 0.88953220f, 0.77098041f, 0.27468451f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W)}, + {X_shape, W_shape}, GetTypedArray(expected_vals), Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{0, 0}, // output_padding @@ -209,12 +240,13 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { vector B = {0.04676145f}; vector B_shape = {1}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, - 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, - 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, - 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, - 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; - TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + vector expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, + 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, + 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, + 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 
0.01566098f, + 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { @@ -247,22 +279,22 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { vector B = {0.17402864f}; vector B_shape = {1}; vector Y_shape = {1, 1, 8, 8}; - auto expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, - 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, - 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, - 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, - 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, - 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, - 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, - 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, - 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, - 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, - 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, - 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, - 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, - 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, - 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, - 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; + vector expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, + 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, + 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, + 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, + 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, + 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, + 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, + 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, + 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, + 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, + 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, + 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, + 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, + 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, + 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, + 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } @@ -292,18 +324,18 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1) { vector W_shape = {3, 3, 3, 3}; vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); @@ -338,12 +370,12 @@ TEST(ConvTransposeTest, ConvTranspose_1D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3}; vector 
Y_shape = {1, 6, 4}; - auto expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f}; + vector expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -376,30 +408,30 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3, 3}; vector Y_shape = {1, 6, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -424,7 +456,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {1, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -449,8 +481,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; 
TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -475,8 +507,8 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectFailure, // error message will end in "W: {1,1,1,5}" or "W: {1,1,5,1} depending on whether NCHW or NHWC, @@ -502,7 +534,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.}; vector W_shape = {1, 2, 3, 3}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = { + vector expected_vals = { 0.f, 0.f, 1.f, 4.f, 4.f, 0.f, 6.f, 20.f, 26.f, 20.f, 9.f, 36.f, 84.f, 84.f, 57.f, @@ -533,7 +565,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx2) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -566,7 +598,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) { vector W = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.0f}; vector W_shape = {16, 2, 1, 1}; vector Y_shape = {1, 8, 1, 1}; - auto expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; + vector expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -586,10 +618,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_1) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f, - 11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f, + 11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -609,11 +641,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_2) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; 
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -633,11 +665,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, - 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, - 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, + 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, + 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -658,12 +690,12 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, - 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, - 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, - 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, + 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, + 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, + 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -684,9 +716,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_1) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {42.0f, 6.0f, 4.0f, - 1.0f, 27.0f, 72.0f, - 7.0f, 81.0f, 45.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, + 1.0f, 27.0f, 72.0f, + 7.0f, 81.0f, 45.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -707,9 +739,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_2) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {35.0f, 49.0f, 18.0f, - 14.0f, 42.0f, 6.0f, - 8.0f, 1.0f, 27.0f}; + vector expected_vals = {35.0f, 49.0f, 18.0f, + 14.0f, 42.0f, 6.0f, + 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -730,10 +762,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, - 1.0f, 27.0f, 72.0f, 9.0f, - 7.0f, 81.0f, 45.0f, 63.0f, - 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, + 1.0f, 27.0f, 72.0f, 9.0f, + 7.0f, 81.0f, 45.0f, 63.0f, + 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -754,10 +786,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, - 63.0f, 35.0f, 49.0f, 18.0f, - 21.0f, 14.0f, 42.0f, 6.0f, - 3.0f, 8.0f, 1.0f, 27.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, + 63.0f, 35.0f, 49.0f, 18.0f, + 21.0f, 14.0f, 42.0f, 6.0f, + 3.0f, 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ 
-778,16 +810,16 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_Group_1) { vector W = {9.0f, 3.0f, 1.0f, 2.0f, 3.0f, 7.0f, 0.0f, 8.0f}; vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, - 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, - 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, - 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, - 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, - 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, - 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, - 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, - 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, - 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; + vector expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, + 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, + 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, + 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, + 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, + 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, + 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, + 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, + 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, + 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -808,7 +840,7 @@ TEST(ConvTransposeTest, ConvTranspose_DefaultStridesAndDilations) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -841,7 +873,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 1, 4}; vector Y_shape = {1, 1, 1, 12}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -862,7 +894,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations_T) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 4, 1}; vector Y_shape = {1, 1, 12, 1}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -885,7 +917,7 @@ TEST(ConvTransposeTest, DimWithZero) { 0.04118127f, -0.44696793f, 0.06373066f}; vector W_shape = {1, 1, 3, 3}; vector Y_shape = {0, 1, 6, 6}; - initializer_list expected_vals = {}; + vector expected_vals = {}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -952,75 +984,75 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { vector B = {-0.11784090101718903f, -0.060990236699581146f}; vector Y_shape = {1, 2, 5, 6, 7}; - auto expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, - -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, - -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, - -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, - -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, - -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, - - -0.13070306f, -0.09003539f, -0.29147533f, 
-0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, - -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, - -0.11307171f, -0.056943115f, -0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, - 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, - -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, - -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, - - -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, - -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, - 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, - -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, - 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, - -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, - - -0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, - 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, - -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, - 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, - -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, - -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, - - -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, - -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, - 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, - -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, - 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, - -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, - - -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, - -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, - -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, -0.106886685f, - -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, - 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, - 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, - - -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, - 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, - -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, - -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, - 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, - -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, - - 0.21691665f, -0.26076952f, -0.5040901f, 
0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, - -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, - -0.04908372f, -0.36729553f, -0.346236f, -0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, - -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, - 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, - -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, - - -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, - -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, - -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, - -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, - -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, - -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, - - -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, - -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, - -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, - -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, - -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, - -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; + vector expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, + -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, + -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, + -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, + -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, + -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, + + -0.13070306f, -0.09003539f, -0.29147533f, -0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, + -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, + -0.11307171f, -0.056943115f, -0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, + 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, + -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, + -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, + + -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, + -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, + 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, + -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, + 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, + -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, + + 
-0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, + 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, + -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, + 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, + -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, + -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, + + -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, + -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, + 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, + -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, + 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, + -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, + + -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, + -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, + -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, -0.106886685f, + -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, + 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, + 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, + + -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, + 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, + -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, + -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, + 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, + -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, + + 0.21691665f, -0.26076952f, -0.5040901f, 0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, + -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, + -0.04908372f, -0.36729553f, -0.346236f, -0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, + -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, + 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, + -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, + + -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, + -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, + -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, + -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, + -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, + -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, 
+ + -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, + -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, + -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, + -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, + -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, + -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1045,7 +1077,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); @@ -1068,7 +1100,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameUpper) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; + vector expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1092,7 +1124,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameLower) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1125,19 +1157,19 @@ TEST(ConvTransposeTest, ConvTranspose_AutoPad_with_non_default_strides) { 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 3, 3}; - auto expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, - - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; + vector expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, + + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; vector Y_shape = {1, 2, 6, 6}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, @@ -1170,7 +1202,7 @@ TEST(ConvTransposeTest, SharedPrepackedWeights) { W.push_back(1.0f); test.AddInput("W", {6, 3, 3, 3}, W, true); // Trigger pre-packing - auto 
expected_vals = {
+  vector<float> expected_vals = {
       12.0f,
       18.0f,
       18.0f,
diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
index b033ddbca23d6..7d736d41e804b 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
@@ -3,7 +3,7 @@
 #include "core/mlas/inc/mlas.h"
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
 #include "core/providers/cpu/nn/pool.h"
 #include "gtest/gtest.h"
@@ -567,4 +567,4 @@ TEST(PoolFp16Test, GlobalAveragePool) {
 }  // namespace test
 }  // namespace onnxruntime
-#endif
\ No newline at end of file
+#endif
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 885fb11c6e999..a340f975ec91a 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -9,6 +9,13 @@ using namespace std;
 namespace onnxruntime {
 namespace test {
+template <typename T>
+class PoolTest : public ::testing::Test {
+};
+
+using PoolTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(PoolTest, PoolTestTypes);
+
 // Disable TensorRT on some of the tests because "pads" attribute is not supported
 TEST(PoolTest, MaxPool) {
@@ -63,13 +70,15 @@ TEST(PoolTest, MaxPool) {
 // Only CUDA kernel has float 16 support
 // Disable for now, still investigating the issue with cudnn lib
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(COREML_ENABLE_MLPROGRAM)
 TEST(PoolTest, MaxPool_F16) {
+#if defined(USE_CUDA)
   int min_cuda_architecture = 530;
   if (!HasCudaEnvironment(min_cuda_architecture)) {
     LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
     return;
   }
+#endif
   OpTester test("MaxPool");
   test.AddAttribute("auto_pad", "");
@@ -672,7 +681,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) {
                      {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
-TEST(PoolTest, GlobalMaxPool) {
+TYPED_TEST(PoolTest, GlobalMaxPool) {
   OpTester test("GlobalMaxPool");
   std::vector<float> x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677,
@@ -743,12 +752,23 @@
   std::vector<int64_t> expected_dims = {1, 3, 1, 1};
   std::vector<float> expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161};
-  test.AddInput<float>("X", x_dims, x_vals);
-  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  if constexpr (std::is_same<TypeParam, float>::value) {
+    test.AddInput<float>("X", x_dims, x_vals);
+    test.AddOutput<float>("Y", expected_dims, expected_vals);
+  } else {
+    std::vector<MLFloat16> x_vals_fp16(x_vals.size());
+    std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
+
+    ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size());
+    ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size());
+    test.AddInput<MLFloat16>("X", x_dims, x_vals_fp16);
+    test.AddOutput<MLFloat16>("Y", expected_dims, expected_vals_fp16);
+  }
+
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }

-TEST(PoolTest, GlobalMaxPool3D) {
+TYPED_TEST(PoolTest, GlobalMaxPool3D) {
   OpTester test("GlobalMaxPool");
   std::vector<float> x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677,
@@ -819,8 +839,19 @@
   std::vector<int64_t> expected_dims = {1, 3, 1, 1, 1};
   std::vector<float> expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161};
-  test.AddInput<float>("X", x_dims, x_vals);
-  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  if constexpr (std::is_same<TypeParam, float>::value) {
+    test.AddInput<float>("X", x_dims, x_vals);
+    test.AddOutput<float>("Y", expected_dims, expected_vals);
+  } else {
+    std::vector<MLFloat16> x_vals_fp16(x_vals.size());
+    std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
+
+    ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size());
+    ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size());
+    test.AddInput<MLFloat16>("X", x_dims, x_vals_fp16);
+    test.AddOutput<MLFloat16>("Y", expected_dims, expected_vals_fp16);
+  }
+
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc
index 98f57f4573540..4a1888a5ca7d6 100644
--- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc
@@ -7,6 +7,13 @@
 namespace onnxruntime {
 namespace test {
+template <typename T>
+class ConcatOpTest : public ::testing::Test {
+};
+
+using ConcatOpTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(ConcatOpTest, ConcatOpTestTypes);
+
 // Some of the tests can't run on TensorrtExecutionProvider because of unsupported data types or limits
 // in its parser: axis >=0 && axis < nbDims. Those Tests will fallback to other EPs
@@ -68,34 +75,45 @@ TEST(ConcatOpTest, Concat1D_2) {
                        kQnnExecutionProvider});  // QNN: not support dynamic shape tensor
 }
-TEST(ConcatOpTest, Concat2D_1) {
+template <typename T>
+static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
+  if constexpr (std::is_same<T, float>::value) {
+    return inputs;
+  } else {
+    std::vector<MLFloat16> inputs_fp16(inputs.size());
+    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
+    return inputs_fp16;
+  }
+}
+
+TYPED_TEST(ConcatOpTest, Concat2D_1) {
   OpTester test("Concat");
   test.AddAttribute("axis", int64_t{0});
   std::vector<int64_t> dims{1, 4};
-  test.AddInput<float>("input1", dims, {11.0f, 12.0f, 13.0f, 14.0f});
-  test.AddInput<float>("input2", dims, {21.0f, 22.0f, 23.0f, 24.0f});
-  test.AddInput<float>("input3", dims, {31.0f, 32.0f, 33.0f, 34.0f});
-  test.AddOutput<float>("concat_result", {3, 4},
-                        {11.0f, 12.0f, 13.0f, 14.0f,
-                         21.0f, 22.0f, 23.0f, 24.0f,
-                         31.0f, 32.0f, 33.0f, 34.0f});
+  test.AddInput<TypeParam>("input1", dims, GetTypedArray<TypeParam>({11.0f, 12.0f, 13.0f, 14.0f}));
+  test.AddInput<TypeParam>("input2", dims, GetTypedArray<TypeParam>({21.0f, 22.0f, 23.0f, 24.0f}));
+  test.AddInput<TypeParam>("input3", dims, GetTypedArray<TypeParam>({31.0f, 32.0f, 33.0f, 34.0f}));
+  test.AddOutput<TypeParam>("concat_result", {3, 4},
+                            GetTypedArray<TypeParam>({11.0f, 12.0f, 13.0f, 14.0f,
+                                                      21.0f, 22.0f, 23.0f, 24.0f,
+                                                      31.0f, 32.0f, 33.0f, 34.0f}));
   test.Run();
 }

-TEST(ConcatOpTest, Concat2D_2) {
+TYPED_TEST(ConcatOpTest, Concat2D_2) {
   OpTester test("Concat");
   test.AddAttribute("axis", int64_t{1});
   std::vector<int64_t> dims{4, 1};
-  test.AddInput<float>("input1", dims, {11.0f, 21.0f, 31.0f, 41.0f});
-  test.AddInput<float>("input2", {4, 2}, {12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f});
-  test.AddInput<float>("input3", dims, {14.0f, 24.0f, 34.0f, 44.0f});
-  test.AddOutput<float>("concat_result", {4, 4},
-                        {11.0f, 12.0f, 13.0f, 14.0f,
-                         21.0f, 22.0f, 23.0f, 24.0f,
-                         31.0f, 32.0f, 33.0f, 34.0f,
-                         41.0f, 42.0f, 43.0f, 44.0f});
+  test.AddInput<TypeParam>("input1", dims, GetTypedArray<TypeParam>({11.0f, 21.0f, 31.0f, 41.0f}));
+  test.AddInput<TypeParam>("input2", {4, 2}, GetTypedArray<TypeParam>({12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f}));
+  test.AddInput<TypeParam>("input3", dims, GetTypedArray<TypeParam>({14.0f, 24.0f, 34.0f, 44.0f}));
+  test.AddOutput<TypeParam>("concat_result", {4, 4},
+                            GetTypedArray<TypeParam>({11.0f, 12.0f, 13.0f, 14.0f,
+                                                      21.0f, 22.0f, 23.0f, 24.0f,
+
31.0f, 32.0f, 33.0f, 34.0f, + 41.0f, 42.0f, 43.0f, 44.0f})); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 540dc6dee68fb..05cfb5c13d689 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -40,963 +40,970 @@ void RunTests(T& test, std::vector>&& execut // DO NOT edit following tests. They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.125840f, -1.152360f, -0.250579f, -0.433879f, 0.848710f, 0.692009f, -0.316013f, -2.115219f, 0.468096f, -0.157712f, 1.443660f, 0.266049f, 0.166455f, 0.874382f, -0.143474f, -0.111609f, 0.931827f, 1.259009f, 2.004981f, 0.053737f, 0.618057f, -0.412802f, -0.841065f, -2.316042f}; + std::initializer_list X_data{TypeParam(-1.125840f), TypeParam(-1.152360f), TypeParam(-0.250579f), TypeParam(-0.433879f), TypeParam(0.848710f), TypeParam(0.692009f), TypeParam(-0.316013f), TypeParam(-2.115219f), TypeParam(0.468096f), TypeParam(-0.157712f), TypeParam(1.443660f), TypeParam(0.266049f), TypeParam(0.166455f), TypeParam(0.874382f), TypeParam(-0.143474f), TypeParam(-0.111609f), TypeParam(0.931827f), TypeParam(1.259009f), TypeParam(2.004981f), TypeParam(0.053737f), TypeParam(0.618057f), TypeParam(-0.412802f), TypeParam(-0.841065f), TypeParam(-2.316042f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.063110f, -0.615220f, 0.203022f, -1.120434f, -0.867079f, -0.618636f, 0.757125f, 0.703586f, -0.532194f, -0.043299f, 0.767473f, 1.192960f, 0.476259f, 0.162111f, 0.804584f, -0.706563f, 0.223613f, -0.930367f, -0.831703f, -0.619900f, 0.542968f, 0.482592f, -0.710823f, 0.362529f}; + std::initializer_list Grid_data{TypeParam(0.063110f), TypeParam(-0.615220f), TypeParam(0.203022f), TypeParam(-1.120434f), TypeParam(-0.867079f), TypeParam(-0.618636f), TypeParam(0.757125f), TypeParam(0.703586f), TypeParam(-0.532194f), TypeParam(-0.043299f), TypeParam(0.767473f), TypeParam(1.192960f), TypeParam(0.476259f), TypeParam(0.162111f), TypeParam(0.804584f), TypeParam(-0.706563f), TypeParam(0.223613f), TypeParam(-0.930367f), TypeParam(-0.831703f), TypeParam(-0.619900f), TypeParam(0.542968f), TypeParam(0.482592f), TypeParam(-0.710823f), TypeParam(0.362529f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.152360f, -1.152360f, -1.125840f, 0.692009f, -0.250579f, 0.692009f, -2.115219f, -2.115219f, -0.316013f, 0.266049f, 0.468096f, 0.266049f, -0.111609f, 0.874382f, 0.874382f, 0.166455f, -0.111609f, -0.143474f, -0.412802f, 0.053737f, 0.053737f, 2.004981f, -0.412802f, 0.618057f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.152360f), TypeParam(-1.152360f), TypeParam(-1.125840f), TypeParam(0.692009f), TypeParam(-0.250579f), TypeParam(0.692009f), TypeParam(-2.115219f), TypeParam(-2.115219f), TypeParam(-0.316013f), 
TypeParam(0.266049f), TypeParam(0.468096f), TypeParam(0.266049f), TypeParam(-0.111609f), TypeParam(0.874382f), TypeParam(0.874382f), TypeParam(0.166455f), TypeParam(-0.111609f), TypeParam(-0.143474f), TypeParam(-0.412802f), TypeParam(0.053737f), TypeParam(0.053737f), TypeParam(2.004981f), TypeParam(-0.412802f), TypeParam(0.618057f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.569248f, 0.919971f, 1.110816f, 1.289874f, -1.478174f, 2.567233f, -0.473120f, 0.335551f, -0.003304f, -0.534441f, 1.168688f, 0.394503f, 1.941462f, 0.791498f, -0.020252f, -0.437170f, -1.535287f, -0.412679f, 0.966303f, 1.624783f, -0.365619f, -1.302440f, 0.099403f, 0.441822f}; + std::initializer_list X_data{TypeParam(-0.569248f), TypeParam(0.919971f), TypeParam(1.110816f), TypeParam(1.289874f), TypeParam(-1.478174f), TypeParam(2.567233f), TypeParam(-0.473120f), TypeParam(0.335551f), TypeParam(-0.003304f), TypeParam(-0.534441f), TypeParam(1.168688f), TypeParam(0.394503f), TypeParam(1.941462f), TypeParam(0.791498f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(-0.412679f), TypeParam(0.966303f), TypeParam(1.624783f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.441822f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.143118f, -0.021569f, -0.903671f, -0.925628f, -0.066120f, 0.180174f, -0.491436f, 0.712053f, -0.730247f, 1.088844f, 0.822360f, -1.011940f, -0.298661f, 0.054147f, 0.175081f, 0.284609f, 0.470914f, 0.071880f, -0.585515f, 0.567827f, -1.151099f, -0.711248f, -0.300396f, -0.584536f}; + std::initializer_list Grid_data{TypeParam(-1.143118f), TypeParam(-0.021569f), TypeParam(-0.903671f), TypeParam(-0.925628f), TypeParam(-0.066120f), TypeParam(0.180174f), TypeParam(-0.491436f), TypeParam(0.712053f), TypeParam(-0.730247f), TypeParam(1.088844f), TypeParam(0.822360f), TypeParam(-1.011940f), TypeParam(-0.298661f), TypeParam(0.054147f), TypeParam(0.175081f), TypeParam(0.284609f), TypeParam(0.470914f), TypeParam(0.071880f), TypeParam(-0.585515f), TypeParam(0.567827f), TypeParam(-1.151099f), TypeParam(-0.711248f), TypeParam(-0.300396f), TypeParam(-0.584536f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.000000f, -0.569248f, 1.110816f, -1.478174f, 0.000000f, 0.000000f, 0.000000f, -0.473120f, -0.003304f, 1.168688f, 0.000000f, 0.000000f, -0.020252f, -0.437170f, -0.437170f, -1.535287f, 0.000000f, 1.941462f, -0.365619f, -1.302440f, -1.302440f, 0.099403f, 0.000000f, 0.966303f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.000000f), TypeParam(-0.569248f), TypeParam(1.110816f), TypeParam(-1.478174f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.473120f), TypeParam(-0.003304f), TypeParam(1.168688f), TypeParam(0.000000f), 
TypeParam(0.000000f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(0.000000f), TypeParam(1.941462f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.000000f), TypeParam(0.966303f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.883376f, -0.418913f, -0.804826f, 0.565610f, 0.610365f, 0.466884f, 1.950657f, -1.063099f, -0.829367f, -1.407257f, 1.626847f, 0.172273f, -1.611502f, -0.479448f, -0.143351f, -0.317295f, 0.573655f, 0.997931f, 0.543609f, 0.078804f, 0.862860f, -0.019490f, 0.991047f, -0.777735f}; + std::initializer_list X_data{TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(0.466884f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(0.172273f), TypeParam(-1.611502f), TypeParam(-0.479448f), TypeParam(-0.143351f), TypeParam(-0.317295f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(0.543609f), TypeParam(0.078804f), TypeParam(0.862860f), TypeParam(-0.019490f), TypeParam(0.991047f), TypeParam(-0.777735f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.080070f, -0.080985f, 1.055303f, -0.489470f, 1.083604f, 0.434584f, -1.082953f, 0.759237f, -0.138473f, -0.535688f, 0.959584f, -0.969714f, 0.128766f, -0.251242f, 0.856935f, 0.334973f, 0.576606f, 0.423791f, -0.288570f, -0.252367f, -0.988898f, 0.650213f, 0.952774f, 0.821070f}; + std::initializer_list Grid_data{TypeParam(-1.080070f), TypeParam(-0.080985f), TypeParam(1.055303f), TypeParam(-0.489470f), TypeParam(1.083604f), TypeParam(0.434584f), TypeParam(-1.082953f), TypeParam(0.759237f), TypeParam(-0.138473f), TypeParam(-0.535688f), TypeParam(0.959584f), TypeParam(-0.969714f), TypeParam(0.128766f), TypeParam(-0.251242f), TypeParam(0.856935f), TypeParam(0.334973f), TypeParam(0.576606f), TypeParam(0.423791f), TypeParam(-0.288570f), TypeParam(-0.252367f), TypeParam(-0.988898f), TypeParam(0.650213f), TypeParam(0.952774f), TypeParam(0.821070f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.804826f, 0.565610f, 0.565610f, 0.610365f, -0.883376f, -0.418913f, -0.829367f, -1.407257f, -1.407257f, 1.626847f, 1.950657f, -1.063099f, -0.317295f, -0.317295f, -0.317295f, -0.143351f, 0.573655f, 0.997931f, -0.019490f, -0.019490f, -0.019490f, 0.862860f, 0.991047f, -0.777735f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.317295f), 
TypeParam(-0.317295f), TypeParam(-0.317295f), TypeParam(-0.143351f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(0.862860f), TypeParam(0.991047f), TypeParam(-0.777735f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.559630f, 0.533472f, 0.406887f, 0.394587f, 0.171511f, 0.876045f, -0.287087f, 1.021640f, 0.438649f, -0.010704f, 1.338354f, -0.279405f, -0.551834f, -2.889061f, -1.509981f, 1.024115f, 0.195393f, -0.737109f, 1.700101f, 0.346216f, 0.971125f, 1.450250f, -0.051909f, -0.628431f}; + std::initializer_list X_data{TypeParam(-0.559630f), TypeParam(0.533472f), TypeParam(0.406887f), TypeParam(0.394587f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(-0.287087f), TypeParam(1.021640f), TypeParam(0.438649f), TypeParam(-0.010704f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(-0.551834f), TypeParam(-2.889061f), TypeParam(-1.509981f), TypeParam(1.024115f), TypeParam(0.195393f), TypeParam(-0.737109f), TypeParam(1.700101f), TypeParam(0.346216f), TypeParam(0.971125f), TypeParam(1.450250f), TypeParam(-0.051909f), TypeParam(-0.628431f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.149807f, 1.074831f, 0.734055f, -0.758657f, 0.538205f, -0.848275f, -0.508590f, 0.352947f, 0.396231f, 0.900274f, -0.386299f, 0.001921f, 0.617788f, -1.160511f, 0.867577f, -0.992307f, 0.016539f, -0.204020f, -0.632008f, 0.158605f, 0.992302f, -0.350783f, -0.712433f, -0.443807f}; + std::initializer_list Grid_data{TypeParam(0.149807f), TypeParam(1.074831f), TypeParam(0.734055f), TypeParam(-0.758657f), TypeParam(0.538205f), TypeParam(-0.848275f), TypeParam(-0.508590f), TypeParam(0.352947f), TypeParam(0.396231f), TypeParam(0.900274f), TypeParam(-0.386299f), TypeParam(0.001921f), TypeParam(0.617788f), TypeParam(-1.160511f), TypeParam(0.867577f), TypeParam(-0.992307f), TypeParam(0.016539f), TypeParam(-0.204020f), TypeParam(-0.632008f), TypeParam(0.158605f), TypeParam(0.992302f), TypeParam(-0.350783f), TypeParam(-0.712433f), TypeParam(-0.443807f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.876045f, 0.533472f, 0.533472f, 0.171511f, 0.876045f, 0.406887f, -0.279405f, 1.021640f, 1.021640f, 1.338354f, -0.279405f, 0.438649f, -2.889061f, -2.889061f, 1.024115f, -1.509981f, -2.889061f, -0.551834f, 0.346216f, 0.346216f, 1.450250f, 0.971125f, 0.346216f, 1.700101f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.876045f), TypeParam(0.533472f), TypeParam(0.533472f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(0.406887f), TypeParam(-0.279405f), TypeParam(1.021640f), TypeParam(1.021640f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(0.438649f), TypeParam(-2.889061f), TypeParam(-2.889061f), TypeParam(1.024115f), 
TypeParam(-1.509981f), TypeParam(-2.889061f), TypeParam(-0.551834f), TypeParam(0.346216f), TypeParam(0.346216f), TypeParam(1.450250f), TypeParam(0.971125f), TypeParam(0.346216f), TypeParam(1.700101f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.039373f, -0.801472f, -0.495544f, -0.361514f, 0.585113f, -1.156007f, -0.143365f, -0.194741f, -0.906885f, -0.591838f, 0.150785f, -1.041149f, -0.720534f, -2.214754f, -0.683730f, 0.516358f, 0.792848f, 0.083228f, 0.422800f, -1.868747f, -1.105713f, 0.143731f, 0.583597f, 1.348155f}; + std::initializer_list X_data{TypeParam(-0.039373f), TypeParam(-0.801472f), TypeParam(-0.495544f), TypeParam(-0.361514f), TypeParam(0.585113f), TypeParam(-1.156007f), TypeParam(-0.143365f), TypeParam(-0.194741f), TypeParam(-0.906885f), TypeParam(-0.591838f), TypeParam(0.150785f), TypeParam(-1.041149f), TypeParam(-0.720534f), TypeParam(-2.214754f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.792848f), TypeParam(0.083228f), TypeParam(0.422800f), TypeParam(-1.868747f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(0.583597f), TypeParam(1.348155f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.829854f, -0.893309f, 0.491599f, -0.403504f, -0.578962f, 0.215574f, -0.623348f, 0.276486f, 0.235657f, -0.890987f, 0.199798f, 0.511115f, 0.474997f, -0.151054f, -0.983745f, -0.184985f, 0.416769f, -0.437853f, 0.455497f, 0.799155f, -0.626582f, 0.011834f, 0.496199f, 0.094053f}; + std::initializer_list Grid_data{TypeParam(0.829854f), TypeParam(-0.893309f), TypeParam(0.491599f), TypeParam(-0.403504f), TypeParam(-0.578962f), TypeParam(0.215574f), TypeParam(-0.623348f), TypeParam(0.276486f), TypeParam(0.235657f), TypeParam(-0.890987f), TypeParam(0.199798f), TypeParam(0.511115f), TypeParam(0.474997f), TypeParam(-0.151054f), TypeParam(-0.983745f), TypeParam(-0.184985f), TypeParam(0.416769f), TypeParam(-0.437853f), TypeParam(0.455497f), TypeParam(0.799155f), TypeParam(-0.626582f), TypeParam(0.011834f), TypeParam(0.496199f), TypeParam(0.094053f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.801472f, -0.361514f, -0.495544f, -0.495544f, -0.801472f, -1.156007f, -0.194741f, -0.591838f, -0.906885f, -0.906885f, -0.194741f, -1.041149f, 0.516358f, -0.683730f, 0.516358f, 0.083228f, -0.683730f, 0.516358f, 0.143731f, -1.105713f, 0.143731f, 1.348155f, -1.105713f, 0.143731f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.801472f), TypeParam(-0.361514f), TypeParam(-0.495544f), TypeParam(-0.495544f), TypeParam(-0.801472f), TypeParam(-1.156007f), TypeParam(-0.194741f), TypeParam(-0.591838f), TypeParam(-0.906885f), TypeParam(-0.906885f), TypeParam(-0.194741f), TypeParam(-1.041149f), TypeParam(0.516358f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.083228f), 
TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.143731f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(1.348155f), TypeParam(-1.105713f), TypeParam(0.143731f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.129230f, -0.054595f, 0.408347f, 1.126366f, 1.935057f, 1.007685f, 1.004642f, -0.433520f, -0.562711f, -0.832754f, -1.395545f, -0.399295f, -0.309940f, -0.056062f, 0.517413f, -1.596237f, 0.356960f, -2.297482f, -0.871083f, -1.674028f, 0.563055f, -1.435067f, 0.719400f, -1.370747f}; + std::initializer_list X_data{TypeParam(-0.129230f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.126366f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(1.004642f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-0.832754f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.309940f), TypeParam(-0.056062f), TypeParam(0.517413f), TypeParam(-1.596237f), TypeParam(0.356960f), TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-1.674028f), TypeParam(0.563055f), TypeParam(-1.435067f), TypeParam(0.719400f), TypeParam(-1.370747f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.811910f, -1.183845f, -0.963667f, 0.947364f, 0.649243f, 1.125859f, 0.961345f, -1.071655f, -0.818917f, -0.193899f, -0.779319f, 0.833276f, -0.907209f, -0.585482f, -1.159310f, -0.681295f, 0.986973f, 0.982512f, 0.859005f, 0.926553f, 1.067024f, -0.307276f, 0.528003f, 1.069117f}; + std::initializer_list Grid_data{TypeParam(-0.811910f), TypeParam(-1.183845f), TypeParam(-0.963667f), TypeParam(0.947364f), TypeParam(0.649243f), TypeParam(1.125859f), TypeParam(0.961345f), TypeParam(-1.071655f), TypeParam(-0.818917f), TypeParam(-0.193899f), TypeParam(-0.779319f), TypeParam(0.833276f), TypeParam(-0.907209f), TypeParam(-0.585482f), TypeParam(-1.159310f), TypeParam(-0.681295f), TypeParam(0.986973f), TypeParam(0.982512f), TypeParam(0.859005f), TypeParam(0.926553f), TypeParam(1.067024f), TypeParam(-0.307276f), TypeParam(0.528003f), TypeParam(1.069117f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.129230f, 1.935057f, 1.007685f, -0.054595f, 0.408347f, 1.935057f, 1.004642f, -1.395545f, -0.399295f, -0.433520f, -0.562711f, -1.395545f, -0.309940f, -0.309940f, -2.297482f, -2.297482f, -1.596237f, -2.297482f, -0.871083f, -0.871083f, -1.370747f, -1.370747f, -1.435067f, -1.370747f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.129230f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.935057f), TypeParam(1.004642f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-1.395545f), TypeParam(-0.309940f), TypeParam(-0.309940f), TypeParam(-2.297482f), TypeParam(-2.297482f), TypeParam(-1.596237f), 
TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-0.871083f), TypeParam(-1.370747f), TypeParam(-1.370747f), TypeParam(-1.435067f), TypeParam(-1.370747f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.294201f, 0.797322f, 1.264215f, 0.935492f, 0.545464f, -1.537389f, 0.312439f, 0.740060f, -0.575326f, -1.432532f, -0.666175f, 1.017438f, -2.241368f, 0.437349f, -0.555362f, -0.057943f, 0.658583f, 0.992938f, -0.206548f, -0.244841f, -0.380599f, 1.131112f, -0.090205f, -0.897900f}; + std::initializer_list X_data{TypeParam(0.294201f), TypeParam(0.797322f), TypeParam(1.264215f), TypeParam(0.935492f), TypeParam(0.545464f), TypeParam(-1.537389f), TypeParam(0.312439f), TypeParam(0.740060f), TypeParam(-0.575326f), TypeParam(-1.432532f), TypeParam(-0.666175f), TypeParam(1.017438f), TypeParam(-2.241368f), TypeParam(0.437349f), TypeParam(-0.555362f), TypeParam(-0.057943f), TypeParam(0.658583f), TypeParam(0.992938f), TypeParam(-0.206548f), TypeParam(-0.244841f), TypeParam(-0.380599f), TypeParam(1.131112f), TypeParam(-0.090205f), TypeParam(-0.897900f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.595248f, -1.096726f, -0.214731f, -0.891773f, -0.512023f, 0.432352f, -0.852156f, 0.446072f, 1.018534f, 0.078706f, -0.799785f, -0.429942f, 0.262037f, -0.914782f, 0.596172f, -1.089444f, -1.153552f, -1.165993f, -0.243436f, 0.806920f, -1.135775f, 0.997425f, -0.480027f, 0.351461f}; + std::initializer_list Grid_data{TypeParam(0.595248f), TypeParam(-1.096726f), TypeParam(-0.214731f), TypeParam(-0.891773f), TypeParam(-0.512023f), TypeParam(0.432352f), TypeParam(-0.852156f), TypeParam(0.446072f), TypeParam(1.018534f), TypeParam(0.078706f), TypeParam(-0.799785f), TypeParam(-0.429942f), TypeParam(0.262037f), TypeParam(-0.914782f), TypeParam(0.596172f), TypeParam(-1.089444f), TypeParam(-1.153552f), TypeParam(-1.165993f), TypeParam(-0.243436f), TypeParam(0.806920f), TypeParam(-1.135775f), TypeParam(0.997425f), TypeParam(-0.480027f), TypeParam(0.351461f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.628229f, 0.561377f, 0.688215f, 0.861459f, 0.733996f, 0.850061f, 0.590307f, 0.329661f, -0.555725f, -0.595435f, -1.228216f, -0.224152f, -0.524667f, -0.094262f, -1.725798f, 0.562584f, 0.610959f, -0.014286f, -0.162194f, -0.215901f, -0.159037f, -0.282404f, -0.084779f, -0.097448f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.628229f), TypeParam(0.561377f), TypeParam(0.688215f), TypeParam(0.861459f), TypeParam(0.733996f), TypeParam(0.850061f), TypeParam(0.590307f), TypeParam(0.329661f), TypeParam(-0.555725f), TypeParam(-0.595435f), TypeParam(-1.228216f), TypeParam(-0.224152f), TypeParam(-0.524667f), TypeParam(-0.094262f), TypeParam(-1.725798f), TypeParam(0.562584f), TypeParam(0.610959f), TypeParam(-0.014286f), TypeParam(-0.162194f), 
TypeParam(-0.215901f), TypeParam(-0.159037f), TypeParam(-0.282404f), TypeParam(-0.084779f), TypeParam(-0.097448f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bilinear";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-1.199109f, -0.025686f, 1.802375f, -1.059653f, 3.402826f, -0.568670f, -0.475489f, 1.743163f, 1.060884f, -0.015953f, 1.275653f, 0.009457f, -0.369450f, 1.218198f, 0.255044f, 0.273993f, 1.404381f, 1.082878f, 0.788966f, -0.137615f, 0.122478f, -1.076701f, -0.650897f, -1.619658f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.199109f), TypeParam(-0.025686f), TypeParam(1.802375f), TypeParam(-1.059653f), TypeParam(3.402826f), TypeParam(-0.568670f), TypeParam(-0.475489f), TypeParam(1.743163f), TypeParam(1.060884f), TypeParam(-0.015953f), TypeParam(1.275653f), TypeParam(0.009457f), TypeParam(-0.369450f), TypeParam(1.218198f), TypeParam(0.255044f), TypeParam(0.273993f), TypeParam(1.404381f), TypeParam(1.082878f), TypeParam(0.788966f), TypeParam(-0.137615f), TypeParam(0.122478f), TypeParam(-1.076701f), TypeParam(-0.650897f), TypeParam(-1.619658f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.038587f, -0.371014f, -0.260918f, 0.159481f, 0.594851f, -0.840708f, 1.007133f, -0.130476f, -1.005535f, -0.649269f, 1.061781f, 1.097433f, -1.111536f, 0.846358f, 0.601391f, 0.710302f, 1.015835f, -0.646740f, 0.378931f, 0.491080f, -0.354592f, 0.401584f, -0.345256f, 0.741914f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.038587f), TypeParam(-0.371014f), TypeParam(-0.260918f), TypeParam(0.159481f), TypeParam(0.594851f), TypeParam(-0.840708f), TypeParam(1.007133f), TypeParam(-0.130476f), TypeParam(-1.005535f), TypeParam(-0.649269f), TypeParam(1.061781f), TypeParam(1.097433f), TypeParam(-1.111536f), TypeParam(0.846358f), TypeParam(0.601391f), TypeParam(0.710302f), TypeParam(1.015835f), TypeParam(-0.646740f), TypeParam(0.378931f), TypeParam(0.491080f), TypeParam(-0.354592f), TypeParam(0.401584f), TypeParam(-0.345256f), TypeParam(0.741914f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.199899f, 1.437523f, -0.017180f, -0.422530f, -0.554188f, -0.088180f, 0.613663f, 0.843979f, 1.165913f, 0.161823f, -0.215288f, 0.001466f, 0.398506f, 0.909392f, 0.576145f, 0.897902f, 0.920312f, 1.201733f, -0.184698f, -1.360176f, -0.080218f, -1.352020f, -0.497572f, -0.710420f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.199899f), TypeParam(1.437523f), TypeParam(-0.017180f), TypeParam(-0.422530f), TypeParam(-0.554188f), TypeParam(-0.088180f), TypeParam(0.613663f), TypeParam(0.843979f), TypeParam(1.165913f), TypeParam(0.161823f), TypeParam(-0.215288f), TypeParam(0.001466f), TypeParam(0.398506f), TypeParam(0.909392f), TypeParam(0.576145f), TypeParam(0.897902f), TypeParam(0.920312f), TypeParam(1.201733f), TypeParam(-0.184698f), TypeParam(-1.360176f), TypeParam(-0.080218f), TypeParam(-1.352020f), TypeParam(-0.497572f), TypeParam(-0.710420f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bilinear";
   std::string padding_mode = "border";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-0.546073f, -0.630178f, -0.634650f, 0.974665f, 0.209843f, 0.029890f, 1.709235f, -0.725759f, -0.876951f, 0.522287f, 0.462005f, -1.329269f, -0.295974f, 1.371414f, 0.973846f, 0.765543f, -0.403897f, -0.326279f, 0.748218f, -0.195299f, 0.676756f, -0.080633f, 0.158123f, 0.099984f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.546073f), TypeParam(-0.630178f), TypeParam(-0.634650f), TypeParam(0.974665f), TypeParam(0.209843f), TypeParam(0.029890f), TypeParam(1.709235f), TypeParam(-0.725759f), TypeParam(-0.876951f), TypeParam(0.522287f), TypeParam(0.462005f), TypeParam(-1.329269f), TypeParam(-0.295974f), TypeParam(1.371414f), TypeParam(0.973846f), TypeParam(0.765543f), TypeParam(-0.403897f), TypeParam(-0.326279f), TypeParam(0.748218f), TypeParam(-0.195299f), TypeParam(0.676756f), TypeParam(-0.080633f), TypeParam(0.158123f), TypeParam(0.099984f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{1.182462f, -0.759228f, 0.230068f, -0.103567f, -0.252788f, -0.268017f, 0.762529f, 0.057356f, -1.168338f, -0.708432f, -0.409080f, 0.603860f, -0.776560f, 1.131504f, -0.267275f, -0.215474f, 0.940270f, 0.603129f, 1.017745f, 0.694133f, -0.364025f, -0.796167f, -0.089284f, 0.993165f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(1.182462f), TypeParam(-0.759228f), TypeParam(0.230068f), TypeParam(-0.103567f), TypeParam(-0.252788f), TypeParam(-0.268017f), TypeParam(0.762529f), TypeParam(0.057356f), TypeParam(-1.168338f), TypeParam(-0.708432f), TypeParam(-0.409080f), TypeParam(0.603860f), TypeParam(-0.776560f), TypeParam(1.131504f), TypeParam(-0.267275f), TypeParam(-0.215474f), TypeParam(0.940270f), TypeParam(0.603129f), TypeParam(1.017745f), TypeParam(0.694133f), TypeParam(-0.364025f), TypeParam(-0.796167f), TypeParam(-0.089284f), TypeParam(0.993165f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.243777f, 0.256440f, -0.179228f, 0.741578f, -0.571899f, 0.031558f, -0.425264f, 0.007242f, -0.044977f, 0.271677f, 0.955187f, -0.224230f, -0.395226f, 0.771988f, 0.108104f, 0.007673f, 0.371491f, -0.360026f, 0.151628f, 0.399982f, 0.038327f, 0.044739f, 0.445689f, 0.133017f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.243777f), TypeParam(0.256440f), TypeParam(-0.179228f), TypeParam(0.741578f), TypeParam(-0.571899f), TypeParam(0.031558f), TypeParam(-0.425264f), TypeParam(0.007242f), TypeParam(-0.044977f), TypeParam(0.271677f), TypeParam(0.955187f), TypeParam(-0.224230f), TypeParam(-0.395226f), TypeParam(0.771988f), TypeParam(0.108104f), TypeParam(0.007673f), TypeParam(0.371491f), TypeParam(-0.360026f), TypeParam(0.151628f), TypeParam(0.399982f), TypeParam(0.038327f), TypeParam(0.044739f), TypeParam(0.445689f), TypeParam(0.133017f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bilinear";
   std::string padding_mode = "border";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-0.873307f, 0.004261f, -1.257887f, -1.084466f, 0.752979f, 0.323648f, -0.275010f, 1.305612f, -0.009480f, -0.831312f, -0.556290f, 2.070567f, 0.710039f, -0.146461f, -0.746745f, 0.725842f, 0.403461f, 0.234374f, 0.173281f, 1.724145f, -0.408946f, 0.782749f, -1.520847f, -0.314686f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.873307f), TypeParam(0.004261f), TypeParam(-1.257887f), TypeParam(-1.084466f), TypeParam(0.752979f), TypeParam(0.323648f), TypeParam(-0.275010f), TypeParam(1.305612f), TypeParam(-0.009480f), TypeParam(-0.831312f), TypeParam(-0.556290f), TypeParam(2.070567f), TypeParam(0.710039f), TypeParam(-0.146461f), TypeParam(-0.746745f), TypeParam(0.725842f), TypeParam(0.403461f), TypeParam(0.234374f), TypeParam(0.173281f), TypeParam(1.724145f), TypeParam(-0.408946f), TypeParam(0.782749f), TypeParam(-1.520847f), TypeParam(-0.314686f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.605180f, 0.169896f, 1.021029f, 0.161312f, -0.555188f, 1.135200f, 0.284017f, -1.170817f, -0.341630f, -0.817401f, 1.052104f, -0.198175f, -1.093830f, -0.075436f, 0.753615f, 0.311761f, 0.379445f, 0.111448f, 0.447382f, -0.292382f, -0.477360f, -1.121650f, -0.904004f, 0.520083f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.605180f), TypeParam(0.169896f), TypeParam(1.021029f), TypeParam(0.161312f), TypeParam(-0.555188f), TypeParam(1.135200f), TypeParam(0.284017f), TypeParam(-1.170817f), TypeParam(-0.341630f), TypeParam(-0.817401f), TypeParam(1.052104f), TypeParam(-0.198175f), TypeParam(-1.093830f), TypeParam(-0.075436f), TypeParam(0.753615f), TypeParam(0.311761f), TypeParam(0.379445f), TypeParam(0.111448f), TypeParam(0.447382f), TypeParam(-0.292382f), TypeParam(-0.477360f), TypeParam(-1.121650f), TypeParam(-0.904004f), TypeParam(0.520083f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.725617f, -0.743749f, 0.752979f, -0.185279f, -0.734326f, -0.760828f, -0.091786f, -0.129152f, -0.556290f, 0.964224f, -0.024687f, -0.196084f, -0.581904f, 0.496011f, 0.499240f, 0.319537f, 0.690648f, 0.150559f, -0.343065f, 0.269544f, 0.455333f, 1.124628f, 0.208392f, -1.276367f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.725617f), TypeParam(-0.743749f), TypeParam(0.752979f), TypeParam(-0.185279f), TypeParam(-0.734326f), TypeParam(-0.760828f), TypeParam(-0.091786f), TypeParam(-0.129152f), TypeParam(-0.556290f), TypeParam(0.964224f), TypeParam(-0.024687f), TypeParam(-0.196084f), TypeParam(-0.581904f), TypeParam(0.496011f), TypeParam(0.499240f), TypeParam(0.319537f), TypeParam(0.690648f), TypeParam(0.150559f), TypeParam(-0.343065f), TypeParam(0.269544f), TypeParam(0.455333f), TypeParam(1.124628f), TypeParam(0.208392f), TypeParam(-1.276367f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bilinear";
   std::string padding_mode = "reflection";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.540757f, -0.947807f, 0.202144f, -0.350748f, 0.545005f, 1.541211f, 0.600239f, -0.338015f, -1.080823f, -1.391537f, -0.352570f, 1.560770f, -0.822488f, -2.140920f, 0.099553f, -0.697505f, 0.665352f, -2.256198f, -1.002236f, -1.395144f, 0.415783f, 0.268104f, -0.151752f, 0.794042f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.540757f), TypeParam(-0.947807f), TypeParam(0.202144f), TypeParam(-0.350748f), TypeParam(0.545005f), TypeParam(1.541211f), TypeParam(0.600239f), TypeParam(-0.338015f), TypeParam(-1.080823f), TypeParam(-1.391537f), TypeParam(-0.352570f), TypeParam(1.560770f), TypeParam(-0.822488f), TypeParam(-2.140920f), TypeParam(0.099553f), TypeParam(-0.697505f), TypeParam(0.665352f), TypeParam(-2.256198f), TypeParam(-1.002236f), TypeParam(-1.395144f), TypeParam(0.415783f), TypeParam(0.268104f), TypeParam(-0.151752f), TypeParam(0.794042f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{1.051960f, -0.798975f, -0.129852f, -0.064453f, 0.535452f, 0.820411f, -0.190205f, -0.994177f, 0.594591f, 0.358958f, 0.482039f, -0.740241f, 0.772315f, 1.136586f, 0.104126f, -1.120858f, 0.842388f, -0.889742f, 0.275846f, 0.174381f, -0.561644f, 0.417835f, -1.073319f, 0.273311f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(1.051960f), TypeParam(-0.798975f), TypeParam(-0.129852f), TypeParam(-0.064453f), TypeParam(0.535452f), TypeParam(0.820411f), TypeParam(-0.190205f), TypeParam(-0.994177f), TypeParam(0.594591f), TypeParam(0.358958f), TypeParam(0.482039f), TypeParam(-0.740241f), TypeParam(0.772315f), TypeParam(1.136586f), TypeParam(0.104126f), TypeParam(-1.120858f), TypeParam(0.842388f), TypeParam(-0.889742f), TypeParam(0.275846f), TypeParam(0.174381f), TypeParam(-0.561644f), TypeParam(0.417835f), TypeParam(-1.073319f), TypeParam(0.273311f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.793997f, -0.042818f, 1.034663f, -0.061725f, 0.327743f, -0.470152f, -0.528701f, -1.125254f, 0.678924f, 0.212033f, -0.430627f, -0.410903f, -1.743740f, -1.404122f, -1.882401f, -0.546577f, -0.033295f, 0.203686f, 0.631537f, -1.031405f, -1.182924f, 0.344248f, 0.246420f, 0.266212f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.793997f), TypeParam(-0.042818f), TypeParam(1.034663f), TypeParam(-0.061725f), TypeParam(0.327743f), TypeParam(-0.470152f), TypeParam(-0.528701f), TypeParam(-1.125254f), TypeParam(0.678924f), TypeParam(0.212033f), TypeParam(-0.430627f), TypeParam(-0.410903f), TypeParam(-1.743740f), TypeParam(-1.404122f), TypeParam(-1.882401f), TypeParam(-0.546577f), TypeParam(-0.033295f), TypeParam(0.203686f), TypeParam(0.631537f), TypeParam(-1.031405f), TypeParam(-1.182924f), TypeParam(0.344248f), TypeParam(0.246420f), TypeParam(0.266212f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bilinear";
   std::string padding_mode = "reflection";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.584178f, 1.050431f, 1.285579f, -1.616520f, -0.768962f, -1.220462f, 0.573128f, 0.699197f, -1.654887f, 0.493267f, -0.615042f, 1.311865f, 0.788249f, -1.232951f, 0.454381f, -1.436621f, 0.711631f, 0.554599f, -0.807529f, 1.680131f, 0.597634f, -0.238890f, -0.345997f, 1.770104f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.584178f), TypeParam(1.050431f), TypeParam(1.285579f), TypeParam(-1.616520f), TypeParam(-0.768962f), TypeParam(-1.220462f), TypeParam(0.573128f), TypeParam(0.699197f), TypeParam(-1.654887f), TypeParam(0.493267f), TypeParam(-0.615042f), TypeParam(1.311865f), TypeParam(0.788249f), TypeParam(-1.232951f), TypeParam(0.454381f), TypeParam(-1.436621f), TypeParam(0.711631f), TypeParam(0.554599f), TypeParam(-0.807529f), TypeParam(1.680131f), TypeParam(0.597634f), TypeParam(-0.238890f), TypeParam(-0.345997f), TypeParam(1.770104f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.564800f, 1.031186f, 0.795913f, -0.629473f, -0.131544f, -0.377622f, -0.964948f, 0.000496f, 0.902922f, 1.011019f, 0.111961f, 0.272548f, -0.519506f, 0.905811f, -0.499330f, -0.833583f, 0.184792f, 0.719262f, -1.081910f, 1.084761f, 0.431677f, -0.840735f, -0.258489f, 1.041096f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.564800f), TypeParam(1.031186f), TypeParam(0.795913f), TypeParam(-0.629473f), TypeParam(-0.131544f), TypeParam(-0.377622f), TypeParam(-0.964948f), TypeParam(0.000496f), TypeParam(0.902922f), TypeParam(1.011019f), TypeParam(0.111961f), TypeParam(0.272548f), TypeParam(-0.519506f), TypeParam(0.905811f), TypeParam(-0.499330f), TypeParam(-0.833583f), TypeParam(0.184792f), TypeParam(0.719262f), TypeParam(-1.081910f), TypeParam(1.084761f), TypeParam(0.431677f), TypeParam(-0.840735f), TypeParam(-0.258489f), TypeParam(1.041096f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-1.220462f, 0.901641f, 0.521980f, 1.284051f, -1.220462f, -0.717235f, 1.311865f, 0.687708f, -0.023386f, -1.654114f, 1.311865f, 0.029458f, 0.711631f, 0.786895f, 0.604097f, 0.711631f, -1.094857f, 0.673706f, -0.345997f, -0.805863f, 1.103092f, -0.345997f, 1.510167f, 0.165064f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-1.220462f), TypeParam(0.901641f), TypeParam(0.521980f), TypeParam(1.284051f), TypeParam(-1.220462f), TypeParam(-0.717235f), TypeParam(1.311865f), TypeParam(0.687708f), TypeParam(-0.023386f), TypeParam(-1.654114f), TypeParam(1.311865f), TypeParam(0.029458f), TypeParam(0.711631f), TypeParam(0.786895f), TypeParam(0.604097f), TypeParam(0.711631f), TypeParam(-1.094857f), TypeParam(0.673706f), TypeParam(-0.345997f), TypeParam(-0.805863f), TypeParam(1.103092f), TypeParam(-0.345997f), TypeParam(1.510167f), TypeParam(0.165064f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "zeros";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.497417f, 0.268522f, 1.476879f, 0.354795f, 1.624709f, 0.593423f, -1.725412f, -0.622016f, -0.466707f, -0.319962f, 0.701868f, 0.494252f, -0.630165f, 0.548236f, 1.042740f, 0.253800f, -2.667303f, 1.379165f, -0.519418f, 0.672783f, -0.005627f, -0.180192f, -0.018395f, 0.998084f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.497417f), TypeParam(0.268522f), TypeParam(1.476879f), TypeParam(0.354795f), TypeParam(1.624709f), TypeParam(0.593423f), TypeParam(-1.725412f), TypeParam(-0.622016f), TypeParam(-0.466707f), TypeParam(-0.319962f), TypeParam(0.701868f), TypeParam(0.494252f), TypeParam(-0.630165f), TypeParam(0.548236f), TypeParam(1.042740f), TypeParam(0.253800f), TypeParam(-2.667303f), TypeParam(1.379165f), TypeParam(-0.519418f), TypeParam(0.672783f), TypeParam(-0.005627f), TypeParam(-0.180192f), TypeParam(-0.018395f), TypeParam(0.998084f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.213755f, 0.141747f, -0.562622f, -0.414594f, 0.325025f, -0.834438f, 0.197995f, 0.519270f, -0.472884f, 0.996769f, -0.078973f, 0.544455f, 1.188368f, -0.366802f, 0.652090f, -0.343235f, -0.175288f, -0.203365f, -0.007455f, -0.453322f, 0.281264f, 0.045216f, 0.760668f, -0.242886f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.213755f), TypeParam(0.141747f), TypeParam(-0.562622f), TypeParam(-0.414594f), TypeParam(0.325025f), TypeParam(-0.834438f), TypeParam(0.197995f), TypeParam(0.519270f), TypeParam(-0.472884f), TypeParam(0.996769f), TypeParam(-0.078973f), TypeParam(0.544455f), TypeParam(1.188368f), TypeParam(-0.366802f), TypeParam(0.652090f), TypeParam(-0.343235f), TypeParam(-0.175288f), TypeParam(-0.203365f), TypeParam(-0.007455f), TypeParam(-0.453322f), TypeParam(0.281264f), TypeParam(0.045216f), TypeParam(0.760668f), TypeParam(-0.242886f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{1.007407f, 1.068583f, 0.492134f, 1.222040f, 1.576835f, 1.464183f, -0.238652f, -1.242164f, -1.156880f, 0.279082f, 0.744912f, 0.338287f, 0.215322f, 0.388598f, 0.866571f, 0.556826f, 0.608617f, 0.326312f, 0.044527f, -0.028766f, -0.136528f, -0.084880f, -0.121429f, -0.105516f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(1.007407f), TypeParam(1.068583f), TypeParam(0.492134f), TypeParam(1.222040f), TypeParam(1.576835f), TypeParam(1.464183f), TypeParam(-0.238652f), TypeParam(-1.242164f), TypeParam(-1.156880f), TypeParam(0.279082f), TypeParam(0.744912f), TypeParam(0.338287f), TypeParam(0.215322f), TypeParam(0.388598f), TypeParam(0.866571f), TypeParam(0.556826f), TypeParam(0.608617f), TypeParam(0.326312f), TypeParam(0.044527f), TypeParam(-0.028766f), TypeParam(-0.136528f), TypeParam(-0.084880f), TypeParam(-0.121429f), TypeParam(-0.105516f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-1.065470f, 0.402578f, -0.405242f, -0.583366f, -0.258523f, -0.605559f, -0.188242f, 0.959607f, 1.189619f, -0.179522f, -1.823240f, -0.051351f, -1.636092f, -2.510569f, -1.238273f, -0.929619f, -0.058536f, 0.772879f, 0.468944f, 0.259886f, 0.757624f, -2.041813f, -0.552378f, 0.626977f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.065470f), TypeParam(0.402578f), TypeParam(-0.405242f), TypeParam(-0.583366f), TypeParam(-0.258523f), TypeParam(-0.605559f), TypeParam(-0.188242f), TypeParam(0.959607f), TypeParam(1.189619f), TypeParam(-0.179522f), TypeParam(-1.823240f), TypeParam(-0.051351f), TypeParam(-1.636092f), TypeParam(-2.510569f), TypeParam(-1.238273f), TypeParam(-0.929619f), TypeParam(-0.058536f), TypeParam(0.772879f), TypeParam(0.468944f), TypeParam(0.259886f), TypeParam(0.757624f), TypeParam(-2.041813f), TypeParam(-0.552378f), TypeParam(0.626977f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{-1.199809f, 0.061445f, -0.035546f, 0.180524f, 0.919500f, 1.166411f, -0.711939f, -0.074825f, -0.480808f, -1.105975f, -0.873191f, 1.126273f, 0.699673f, 0.644581f, 0.666892f, -0.953375f, 0.126023f, 1.116858f, -0.669703f, 1.067513f, 0.315406f, 0.844252f, -0.514065f, 0.553221f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-1.199809f), TypeParam(0.061445f), TypeParam(-0.035546f), TypeParam(0.180524f), TypeParam(0.919500f), TypeParam(1.166411f), TypeParam(-0.711939f), TypeParam(-0.074825f), TypeParam(-0.480808f), TypeParam(-1.105975f), TypeParam(-0.873191f), TypeParam(1.126273f), TypeParam(0.699673f), TypeParam(0.644581f), TypeParam(0.666892f), TypeParam(-0.953375f), TypeParam(0.126023f), TypeParam(1.116858f), TypeParam(-0.669703f), TypeParam(1.067513f), TypeParam(0.315406f), TypeParam(0.844252f), TypeParam(-0.514065f), TypeParam(0.553221f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.086429f, -0.590424f, -0.090572f, -0.393926f, -0.379182f, -0.031455f, 0.347836f, 0.182097f, 0.050161f, 1.154870f, -0.134312f, -0.509844f, 0.697346f, -1.440179f, 0.264668f, 0.021389f, 0.729883f, -0.236038f, 0.576661f, 0.348301f, 0.149351f, -0.327477f, 0.607344f, -0.405680f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.086429f), TypeParam(-0.590424f), TypeParam(-0.090572f), TypeParam(-0.393926f), TypeParam(-0.379182f), TypeParam(-0.031455f), TypeParam(0.347836f), TypeParam(0.182097f), TypeParam(0.050161f), TypeParam(1.154870f), TypeParam(-0.134312f), TypeParam(-0.509844f), TypeParam(0.697346f), TypeParam(-1.440179f), TypeParam(0.264668f), TypeParam(0.021389f), TypeParam(0.729883f), TypeParam(-0.236038f), TypeParam(0.576661f), TypeParam(0.348301f), TypeParam(0.149351f), TypeParam(-0.327477f), TypeParam(0.607344f), TypeParam(-0.405680f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "border";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.203585f, -1.032829f, 1.130481f, -0.570301f, -2.100938f, 0.389922f, 0.087343f, -0.857360f, 1.193520f, -0.019760f, 0.280285f, 1.811013f, 1.838673f, 0.164184f, 1.436009f, 0.167011f, -1.139939f, -0.029833f, -0.009878f, 0.079750f, 0.216590f, -0.265852f, -0.528116f, -0.451915f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.203585f), TypeParam(-1.032829f), TypeParam(1.130481f), TypeParam(-0.570301f), TypeParam(-2.100938f), TypeParam(0.389922f), TypeParam(0.087343f), TypeParam(-0.857360f), TypeParam(1.193520f), TypeParam(-0.019760f), TypeParam(0.280285f), TypeParam(1.811013f), TypeParam(1.838673f), TypeParam(0.164184f), TypeParam(1.436009f), TypeParam(0.167011f), TypeParam(-1.139939f), TypeParam(-0.029833f), TypeParam(-0.009878f), TypeParam(0.079750f), TypeParam(0.216590f), TypeParam(-0.265852f), TypeParam(-0.528116f), TypeParam(-0.451915f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.797796f, -1.010726f, 0.868577f, -1.132977f, 0.268082f, -0.786042f, -0.476635f, 0.212483f, -0.471816f, -0.189867f, -1.137389f, -1.131448f, 0.464836f, -0.507934f, -0.730068f, -0.473499f, -0.981082f, -0.959280f, 0.718047f, 0.609891f, 0.159844f, -0.655512f, 0.399241f, 0.053910f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.797796f), TypeParam(-1.010726f), TypeParam(0.868577f), TypeParam(-1.132977f), TypeParam(0.268082f), TypeParam(-0.786042f), TypeParam(-0.476635f), TypeParam(0.212483f), TypeParam(-0.471816f), TypeParam(-0.189867f), TypeParam(-1.137389f), TypeParam(-1.131448f), TypeParam(0.464836f), TypeParam(-0.507934f), TypeParam(-0.730068f), TypeParam(-0.473499f), TypeParam(-0.981082f), TypeParam(-0.959280f), TypeParam(0.718047f), TypeParam(0.609891f), TypeParam(0.159844f), TypeParam(-0.655512f), TypeParam(0.399241f), TypeParam(0.053910f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.934180f, -1.004565f, -0.467118f, 0.384839f, 0.792549f, 0.188357f, -0.785741f, -0.871727f, -0.372851f, 0.958270f, 0.751528f, 0.046397f, 0.598629f, 1.686400f, 1.817043f, 0.015806f, 0.866266f, 0.480930f, -0.013358f, 0.152904f, -0.001292f, -0.385043f, 0.030959f, -0.152332f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.934180f), TypeParam(-1.004565f), TypeParam(-0.467118f), TypeParam(0.384839f), TypeParam(0.792549f), TypeParam(0.188357f), TypeParam(-0.785741f), TypeParam(-0.871727f), TypeParam(-0.372851f), TypeParam(0.958270f), TypeParam(0.751528f), TypeParam(0.046397f), TypeParam(0.598629f), TypeParam(1.686400f), TypeParam(1.817043f), TypeParam(0.015806f), TypeParam(0.866266f), TypeParam(0.480930f), TypeParam(-0.013358f), TypeParam(0.152904f), TypeParam(-0.001292f), TypeParam(-0.385043f), TypeParam(0.030959f), TypeParam(-0.152332f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "border";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-0.427361f, 0.814325f, -1.412076f, -0.099774f, 0.074936f, 0.590322f, 0.398556f, -0.635891f, -1.081747f, -0.330179f, 0.271759f, -1.089819f, -0.746656f, -0.942538f, -1.251568f, -1.730282f, -0.722323f, 0.525964f, -0.436259f, -0.188952f, -0.499550f, 1.502071f, -0.014112f, 1.194050f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.427361f), TypeParam(0.814325f), TypeParam(-1.412076f), TypeParam(-0.099774f), TypeParam(0.074936f), TypeParam(0.590322f), TypeParam(0.398556f), TypeParam(-0.635891f), TypeParam(-1.081747f), TypeParam(-0.330179f), TypeParam(0.271759f), TypeParam(-1.089819f), TypeParam(-0.746656f), TypeParam(-0.942538f), TypeParam(-1.251568f), TypeParam(-1.730282f), TypeParam(-0.722323f), TypeParam(0.525964f), TypeParam(-0.436259f), TypeParam(-0.188952f), TypeParam(-0.499550f), TypeParam(1.502071f), TypeParam(-0.014112f), TypeParam(1.194050f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{-0.102021f, -0.935855f, -0.007380f, -0.996053f, -0.258157f, 0.695455f, -0.834420f, -0.808862f, -0.293012f, -0.328961f, 0.203145f, 0.199219f, 0.608516f, -0.826657f, -0.084685f, 0.671149f, 1.037966f, -0.087535f, -0.694344f, 0.344955f, 0.683373f, -0.749700f, -0.696352f, 0.530398f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-0.102021f), TypeParam(-0.935855f), TypeParam(-0.007380f), TypeParam(-0.996053f), TypeParam(-0.258157f), TypeParam(0.695455f), TypeParam(-0.834420f), TypeParam(-0.808862f), TypeParam(-0.293012f), TypeParam(-0.328961f), TypeParam(0.203145f), TypeParam(0.199219f), TypeParam(0.608516f), TypeParam(-0.826657f), TypeParam(-0.084685f), TypeParam(0.671149f), TypeParam(1.037966f), TypeParam(-0.087535f), TypeParam(-0.694344f), TypeParam(0.344955f), TypeParam(0.683373f), TypeParam(-0.749700f), TypeParam(-0.696352f), TypeParam(0.530398f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{0.154701f, 0.273277f, 0.226316f, -0.467055f, -0.820643f, -0.311691f, 0.084699f, -0.052970f, 0.001158f, 0.679701f, -0.467804f, -0.607116f, -0.871407f, -0.210613f, -1.860685f, -1.059387f, -0.902250f, -0.918798f, -0.360562f, 0.476049f, 1.499304f, -0.418396f, -0.298854f, -0.235927f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.154701f), TypeParam(0.273277f), TypeParam(0.226316f), TypeParam(-0.467055f), TypeParam(-0.820643f), TypeParam(-0.311691f), TypeParam(0.084699f), TypeParam(-0.052970f), TypeParam(0.001158f), TypeParam(0.679701f), TypeParam(-0.467804f), TypeParam(-0.607116f), TypeParam(-0.871407f), TypeParam(-0.210613f), TypeParam(-1.860685f), TypeParam(-1.059387f), TypeParam(-0.902250f), TypeParam(-0.918798f), TypeParam(-0.360562f), TypeParam(0.476049f), TypeParam(1.499304f), TypeParam(-0.418396f), TypeParam(-0.298854f), TypeParam(-0.235927f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "reflection";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-1.084082f, -0.128738f, -0.681077f, -1.309896f, 0.660269f, -1.412063f, 1.834581f, 0.456195f, 0.162801f, -0.638266f, 0.897973f, -0.383653f, 0.297945f, 1.809414f, -0.091298f, 1.092744f, -0.102453f, -1.726535f, -0.484632f, 0.712097f, 1.820312f, -0.852073f, -0.341399f, -0.138106f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.084082f), TypeParam(-0.128738f), TypeParam(-0.681077f), TypeParam(-1.309896f), TypeParam(0.660269f), TypeParam(-1.412063f), TypeParam(1.834581f), TypeParam(0.456195f), TypeParam(0.162801f), TypeParam(-0.638266f), TypeParam(0.897973f), TypeParam(-0.383653f), TypeParam(0.297945f), TypeParam(1.809414f), TypeParam(-0.091298f), TypeParam(1.092744f), TypeParam(-0.102453f), TypeParam(-1.726535f), TypeParam(-0.484632f), TypeParam(0.712097f), TypeParam(1.820312f), TypeParam(-0.852073f), TypeParam(-0.341399f), TypeParam(-0.138106f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{-0.501236f, -0.770480f, -0.140656f, -1.129896f, 0.470370f, 0.885106f, 0.288068f, -0.118568f, 0.594968f, -0.761702f, 1.173892f, -1.193212f, -1.149534f, -0.283562f, 0.980213f, 0.120151f, 0.460855f, -0.879608f, 0.437623f, -0.134092f, 0.480988f, 0.847491f, 0.521616f, -0.102077f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-0.501236f), TypeParam(-0.770480f), TypeParam(-0.140656f), TypeParam(-1.129896f), TypeParam(0.470370f), TypeParam(0.885106f), TypeParam(0.288068f), TypeParam(-0.118568f), TypeParam(0.594968f), TypeParam(-0.761702f), TypeParam(1.173892f), TypeParam(-1.193212f), TypeParam(-1.149534f), TypeParam(-0.283562f), TypeParam(0.980213f), TypeParam(0.120151f), TypeParam(0.460855f), TypeParam(-0.879608f), TypeParam(0.437623f), TypeParam(-0.134092f), TypeParam(0.480988f), TypeParam(0.847491f), TypeParam(0.521616f), TypeParam(-0.102077f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-0.953278f, -0.722872f, -1.065112f, -1.071529f, -0.344328f, -0.233562f, 1.436462f, 1.232983f, -0.181487f, -0.297043f, 0.464837f, 0.396673f, 0.053896f, 0.733510f, 1.541248f, 1.117701f, -1.352406f, 1.131762f, 1.324986f, -0.882173f, 0.469635f, -0.247133f, -0.196824f, -0.393592f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.953278f), TypeParam(-0.722872f), TypeParam(-1.065112f), TypeParam(-1.071529f), TypeParam(-0.344328f), TypeParam(-0.233562f), TypeParam(1.436462f), TypeParam(1.232983f), TypeParam(-0.181487f), TypeParam(-0.297043f), TypeParam(0.464837f), TypeParam(0.396673f), TypeParam(0.053896f), TypeParam(0.733510f), TypeParam(1.541248f), TypeParam(1.117701f), TypeParam(-1.352406f), TypeParam(1.131762f), TypeParam(1.324986f), TypeParam(-0.882173f), TypeParam(0.469635f), TypeParam(-0.247133f), TypeParam(-0.196824f), TypeParam(-0.393592f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
   OpTester test("GridSample", 16);
   std::string mode = "bicubic";
   std::string padding_mode = "reflection";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-1.122981f, 0.620969f, -0.876394f, -1.774003f, -0.810376f, -1.475962f, 0.667025f, 0.668804f, -0.748346f, 1.422400f, 0.138469f, -0.165945f, 1.266886f, -0.496157f, 0.158060f, 0.488900f, 0.414476f, 0.419527f, 0.238000f, -0.034674f, 0.229435f, 0.234530f, 0.320846f, 0.703888f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.122981f), TypeParam(0.620969f), TypeParam(-0.876394f), TypeParam(-1.774003f), TypeParam(-0.810376f), TypeParam(-1.475962f), TypeParam(0.667025f), TypeParam(0.668804f), TypeParam(-0.748346f), TypeParam(1.422400f), TypeParam(0.138469f), TypeParam(-0.165945f), TypeParam(1.266886f), TypeParam(-0.496157f), TypeParam(0.158060f), TypeParam(0.488900f), TypeParam(0.414476f), TypeParam(0.419527f), TypeParam(0.238000f), TypeParam(-0.034674f), TypeParam(0.229435f), TypeParam(0.234530f), TypeParam(0.320846f), TypeParam(0.703888f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.471637f, -0.923628f, -0.909401f, 0.684338f, 0.224360f, 1.092855f, -0.320755f, -0.579618f, -0.111056f, 0.006071f, 0.915173f, -1.195296f, -0.085441f, 0.530823f, -0.660820f, -0.609769f, 0.579921f, -1.149822f, 0.284347f, -0.929024f, 0.596474f, -1.026049f, 0.737766f, -1.135959f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.471637f), TypeParam(-0.923628f), TypeParam(-0.909401f), TypeParam(0.684338f), TypeParam(0.224360f), TypeParam(1.092855f), TypeParam(-0.320755f), TypeParam(-0.579618f), TypeParam(-0.111056f), TypeParam(0.006071f), TypeParam(0.915173f), TypeParam(-1.195296f), TypeParam(-0.085441f), TypeParam(0.530823f), TypeParam(-0.660820f), TypeParam(-0.609769f), TypeParam(0.579921f), TypeParam(-1.149822f), TypeParam(0.284347f), TypeParam(-0.929024f), TypeParam(0.596474f), TypeParam(-1.026049f), TypeParam(0.737766f), TypeParam(-1.135959f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{0.998063f, -0.689213f, -1.266024f, -0.870706f, -1.217616f, 1.292693f, 0.543307f, 0.219521f, -0.255151f, 0.543599f, 0.062982f, 0.527696f, 0.387590f, 1.352544f, -0.758053f, -0.262859f, -0.820496f, -0.934255f, 0.434353f, 0.262797f, -0.092283f, -0.021089f, -0.106052f, -0.119717f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.998063f), TypeParam(-0.689213f), TypeParam(-1.266024f), TypeParam(-0.870706f), TypeParam(-1.217616f), TypeParam(1.292693f), TypeParam(0.543307f), TypeParam(0.219521f), TypeParam(-0.255151f), TypeParam(0.543599f), TypeParam(0.062982f), TypeParam(0.527696f), TypeParam(0.387590f), TypeParam(1.352544f), TypeParam(-0.758053f), TypeParam(-0.262859f), TypeParam(-0.820496f), TypeParam(-0.934255f), TypeParam(0.434353f), TypeParam(0.262797f), TypeParam(-0.092283f), TypeParam(-0.021089f), TypeParam(-0.106052f), TypeParam(-0.119717f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(16));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
  OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "zeros";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.404710f, -0.654932f, 0.052124f, 0.340055f, -0.212416f, 1.562917f, -0.907159f, -1.566185f, 0.596746f, 1.002548f, -0.820504f, 0.509186f, 0.951389f, 0.773736f, -2.144711f, 0.044147f, 1.290612f, 0.664926f, 0.530731f, -0.423196f, -0.388699f, 0.333224f, 0.293744f, -0.157543f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.404710f), TypeParam(-0.654932f), TypeParam(0.052124f), TypeParam(0.340055f), TypeParam(-0.212416f), TypeParam(1.562917f), TypeParam(-0.907159f), TypeParam(-1.566185f), TypeParam(0.596746f), TypeParam(1.002548f), TypeParam(-0.820504f), TypeParam(0.509186f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(-2.144711f), TypeParam(0.044147f), TypeParam(1.290612f), TypeParam(0.664926f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(-0.388699f), TypeParam(0.333224f), TypeParam(0.293744f), TypeParam(-0.157543f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.528957f, 0.982925f, -0.033286f, -0.806271f, 0.793837f, -0.411498f, 0.621343f, -0.295724f, 0.510113f, 1.079311f, 1.115827f, -1.092078f, -0.793776f, -0.496160f, -0.765241f, 1.151400f, -0.105983f, -0.796009f, -0.533987f, -0.662838f, 0.489587f, -1.046701f, -1.118884f, -1.182913f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.528957f), TypeParam(0.982925f), TypeParam(-0.033286f), TypeParam(-0.806271f), TypeParam(0.793837f), TypeParam(-0.411498f), TypeParam(0.621343f), TypeParam(-0.295724f), TypeParam(0.510113f), TypeParam(1.079311f), TypeParam(1.115827f), TypeParam(-1.092078f), TypeParam(-0.793776f), TypeParam(-0.496160f), TypeParam(-0.765241f), TypeParam(1.151400f), TypeParam(-0.105983f), TypeParam(-0.796009f), TypeParam(-0.533987f), TypeParam(-0.662838f), TypeParam(0.489587f), TypeParam(-1.046701f), TypeParam(-1.118884f), TypeParam(-1.182913f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{1.562917f, 0.404710f, 0.340055f, 0.340055f, 1.562917f, -0.654932f, 0.509186f, -0.907159f, 1.002548f, 1.002548f, 0.509186f, -1.566185f, -2.144711f, 1.290612f, 0.951389f, 0.951389f, 0.773736f, 0.951389f, -0.388699f, 0.293744f, 0.530731f, 0.530731f, -0.423196f, 0.530731f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(1.562917f), TypeParam(0.404710f), TypeParam(0.340055f), TypeParam(0.340055f), TypeParam(1.562917f), TypeParam(-0.654932f), TypeParam(0.509186f), TypeParam(-0.907159f), TypeParam(1.002548f), TypeParam(1.002548f), TypeParam(0.509186f), TypeParam(-1.566185f), TypeParam(-2.144711f), TypeParam(1.290612f), TypeParam(0.951389f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(0.951389f), TypeParam(-0.388699f), TypeParam(0.293744f), TypeParam(0.530731f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(0.530731f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "zeros";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> X_data{-1.495959f, 0.018231f, 0.345600f, 0.031206f, 0.400390f, 0.425763f, 0.839517f, 1.238945f, 0.523906f, -1.658372f, 0.548335f, -1.398321f, -1.976414f, 1.232491f, -0.545575f, -0.069414f, 0.732245f, -0.150333f, -0.707132f, 0.467497f, 0.278677f, 1.335679f, 1.155313f, -0.056298f, 0.430615f, -0.932645f, -1.505319f, 0.103317f, 1.521579f, 0.365497f, 1.428928f, 0.364333f, 1.683777f, 1.010632f, 0.621895f, 2.284701f, 1.574905f, -0.310514f, 1.495724f, 1.003370f, -1.437482f, 0.043097f, -1.645546f, -1.464643f, 0.350139f, -0.105905f, -0.740495f, 1.157691f, 1.443377f, 0.198399f, -1.105180f, -2.037115f, 2.128767f, -0.204457f, 0.468464f, 1.203629f, -0.362309f, -0.130520f, 1.532353f, 1.547599f, -0.831847f, -1.008509f, 0.023218f, 0.342626f, -0.882915f, 0.560640f, -1.142297f, 1.119107f, 0.385787f, -0.068515f, -0.529550f, -0.233903f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.495959f), TypeParam(0.018231f), TypeParam(0.345600f), TypeParam(0.031206f), TypeParam(0.400390f), TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(1.238945f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(0.548335f), TypeParam(-1.398321f), TypeParam(-1.976414f), TypeParam(1.232491f), TypeParam(-0.545575f), TypeParam(-0.069414f), TypeParam(0.732245f), TypeParam(-0.150333f), TypeParam(-0.707132f), TypeParam(0.467497f), TypeParam(0.278677f), TypeParam(1.335679f), TypeParam(1.155313f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(-0.932645f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(1.521579f), TypeParam(0.365497f), TypeParam(1.428928f), TypeParam(0.364333f), TypeParam(1.683777f), TypeParam(1.010632f), TypeParam(0.621895f), TypeParam(2.284701f), TypeParam(1.574905f), TypeParam(-0.310514f), TypeParam(1.495724f), TypeParam(1.003370f), TypeParam(-1.437482f), TypeParam(0.043097f), TypeParam(-1.645546f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(1.157691f), TypeParam(1.443377f), TypeParam(0.198399f), TypeParam(-1.105180f), TypeParam(-2.037115f), TypeParam(2.128767f), TypeParam(-0.204457f), TypeParam(0.468464f), TypeParam(1.203629f), TypeParam(-0.362309f), TypeParam(-0.130520f), TypeParam(1.532353f), TypeParam(1.547599f), TypeParam(-0.831847f), TypeParam(-1.008509f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(0.560640f), TypeParam(-1.142297f), TypeParam(1.119107f), TypeParam(0.385787f), TypeParam(-0.068515f), TypeParam(-0.529550f), TypeParam(-0.233903f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 3, 2, 3};
-  std::initializer_list<float> Grid_data{0.812645f, 0.528235f, -0.550793f, -0.856977f, -1.073535f, 0.059526f, 1.163856f, -0.227931f, -0.050518f, -0.872033f, 0.368412f, 0.760780f, -1.183099f, -0.844947f, 0.888849f, 0.284117f, -0.074815f, 0.214510f, -0.182450f, -0.838758f, -1.121316f, 0.789250f, -0.142724f, -0.445665f, -0.309738f, -0.654508f, -0.355420f, -1.030097f, 0.898012f, 0.490011f, -0.605186f, -0.409576f, 0.538365f, -0.444367f, 0.316432f, 0.330410f, -0.755392f, 0.300602f, 0.073421f, 1.048061f, -0.434184f, -0.308482f, 1.033921f, -0.979923f, 0.086698f, 1.156203f, -0.538042f, 1.150419f, 1.064809f, 1.116408f, -0.114508f, 1.085560f, -0.522863f, -0.410766f, 0.453879f, 0.253497f, 0.661531f, 1.140383f, -0.751187f, 0.636872f, 0.401477f, 0.633082f, 0.569007f, -0.448884f, -0.948427f, 0.960462f, -0.684283f, 0.767193f, -1.143172f, -0.207603f, 0.012719f, 0.207628f, 0.096998f, 0.378128f, -0.133613f, 0.293885f, 1.187501f, -0.776462f, -0.065516f, -0.458068f, 1.052916f, 1.027248f, -0.032723f, -0.415959f, -0.741439f, 0.858648f, -0.082636f, 1.130172f, 0.684314f, 1.050365f, 0.949108f, -0.779811f, 0.351243f, -0.497591f, 0.602104f, -0.107892f, 0.103884f, -0.829931f, -1.072471f, 0.451888f, 0.278862f, 0.104235f, 0.815033f, -0.501089f, 0.425977f, -0.660914f, 0.248640f, -0.273958f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.812645f), TypeParam(0.528235f), TypeParam(-0.550793f), TypeParam(-0.856977f), TypeParam(-1.073535f), TypeParam(0.059526f), TypeParam(1.163856f), TypeParam(-0.227931f), TypeParam(-0.050518f), TypeParam(-0.872033f), TypeParam(0.368412f), TypeParam(0.760780f), TypeParam(-1.183099f), TypeParam(-0.844947f), TypeParam(0.888849f), TypeParam(0.284117f), TypeParam(-0.074815f), TypeParam(0.214510f), TypeParam(-0.182450f), TypeParam(-0.838758f), TypeParam(-1.121316f), TypeParam(0.789250f), TypeParam(-0.142724f), TypeParam(-0.445665f), TypeParam(-0.309738f), TypeParam(-0.654508f), TypeParam(-0.355420f), TypeParam(-1.030097f), TypeParam(0.898012f), TypeParam(0.490011f), TypeParam(-0.605186f), TypeParam(-0.409576f), TypeParam(0.538365f), TypeParam(-0.444367f), TypeParam(0.316432f), TypeParam(0.330410f), TypeParam(-0.755392f), TypeParam(0.300602f), TypeParam(0.073421f), TypeParam(1.048061f), TypeParam(-0.434184f), TypeParam(-0.308482f), TypeParam(1.033921f), TypeParam(-0.979923f), TypeParam(0.086698f), TypeParam(1.156203f), TypeParam(-0.538042f), TypeParam(1.150419f), TypeParam(1.064809f), TypeParam(1.116408f), TypeParam(-0.114508f), TypeParam(1.085560f), TypeParam(-0.522863f), TypeParam(-0.410766f), TypeParam(0.453879f), TypeParam(0.253497f), TypeParam(0.661531f), TypeParam(1.140383f), TypeParam(-0.751187f), TypeParam(0.636872f), TypeParam(0.401477f), TypeParam(0.633082f), TypeParam(0.569007f), TypeParam(-0.448884f), TypeParam(-0.948427f), TypeParam(0.960462f), TypeParam(-0.684283f), TypeParam(0.767193f), TypeParam(-1.143172f), TypeParam(-0.207603f), TypeParam(0.012719f), TypeParam(0.207628f), TypeParam(0.096998f), TypeParam(0.378128f), TypeParam(-0.133613f), TypeParam(0.293885f), TypeParam(1.187501f), TypeParam(-0.776462f), TypeParam(-0.065516f), TypeParam(-0.458068f), TypeParam(1.052916f), TypeParam(1.027248f), TypeParam(-0.032723f), TypeParam(-0.415959f), TypeParam(-0.741439f), TypeParam(0.858648f), TypeParam(-0.082636f), TypeParam(1.130172f), TypeParam(0.684314f), TypeParam(1.050365f), TypeParam(0.949108f), TypeParam(-0.779811f), TypeParam(0.351243f), TypeParam(-0.497591f), TypeParam(0.602104f), TypeParam(-0.107892f), TypeParam(0.103884f), TypeParam(-0.829931f), TypeParam(-1.072471f), TypeParam(0.451888f), TypeParam(0.278862f), TypeParam(0.104235f), TypeParam(0.815033f), TypeParam(-0.501089f), TypeParam(0.425977f), TypeParam(-0.660914f), TypeParam(0.248640f), TypeParam(-0.273958f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> Y_data{0.425763f, 0.839517f, -1.658372f, -0.545575f, -1.976414f, -1.658372f, -1.495959f, -1.658372f, 0.839517f, 0.548335f, -0.545575f, 0.523906f, 0.523906f, -1.658372f, 1.238945f, 1.232491f, -1.398321f, 1.238945f, -0.056298f, 0.430615f, 0.103317f, 1.683777f, 1.428928f, 0.103317f, -0.707132f, 0.103317f, 0.430615f, 1.521579f, 1.683777f, -1.505319f, -1.505319f, 0.103317f, -0.932645f, 0.364333f, 0.365497f, -0.932645f, -2.037115f, 0.198399f, -0.204457f, 1.443377f, -1.437482f, 0.350139f, -0.105905f, 0.043097f, -1.105180f, -0.105905f, -0.740495f, -0.204457f, -1.464643f, -0.740495f, -0.310514f, -0.105905f, -1.464643f, 0.350139f, -0.068515f, 1.119107f, -0.233903f, -1.142297f, 1.532353f, 0.023218f, 0.342626f, 1.547599f, 0.385787f, 0.342626f, -0.882915f, -0.233903f, -1.008509f, -0.882915f, 1.203629f, 0.342626f, -1.008509f, 0.023218f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(-1.658372f), TypeParam(-0.545575f), TypeParam(-1.976414f), TypeParam(-1.658372f), TypeParam(-1.495959f), TypeParam(-1.658372f), TypeParam(0.839517f), TypeParam(0.548335f), TypeParam(-0.545575f), TypeParam(0.523906f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(1.238945f), TypeParam(1.232491f), TypeParam(-1.398321f), TypeParam(1.238945f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(0.103317f), TypeParam(1.683777f), TypeParam(1.428928f), TypeParam(0.103317f), TypeParam(-0.707132f), TypeParam(0.103317f), TypeParam(0.430615f), TypeParam(1.521579f), TypeParam(1.683777f), TypeParam(-1.505319f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(-0.932645f), TypeParam(0.364333f), TypeParam(0.365497f), TypeParam(-0.932645f), TypeParam(-2.037115f), TypeParam(0.198399f), TypeParam(-0.204457f), TypeParam(1.443377f), TypeParam(-1.437482f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(0.043097f), TypeParam(-1.105180f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(-0.204457f), TypeParam(-1.464643f), TypeParam(-0.740495f), TypeParam(-0.310514f), TypeParam(-0.105905f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.068515f), TypeParam(1.119107f), TypeParam(-0.233903f), TypeParam(-1.142297f), TypeParam(1.532353f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(1.547599f), TypeParam(0.385787f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(-0.233903f), TypeParam(-1.008509f), TypeParam(-0.882915f), TypeParam(1.203629f), TypeParam(0.342626f), TypeParam(-1.008509f), TypeParam(0.023218f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-1.948141f, 1.836740f, -0.418393f, -0.125621f, 1.779137f, -0.028049f, 0.367697f, -0.388847f, -0.939514f, -0.129193f, -0.101240f, -3.087570f, -0.778617f, 1.026859f, 0.624162f, 0.291416f, 0.580998f, -0.185200f, 0.333020f, 0.415896f, 0.011702f, 0.014502f, -0.722870f, -0.201041f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.948141f), TypeParam(1.836740f), TypeParam(-0.418393f), TypeParam(-0.125621f), TypeParam(1.779137f), TypeParam(-0.028049f), TypeParam(0.367697f), TypeParam(-0.388847f), TypeParam(-0.939514f), TypeParam(-0.129193f), TypeParam(-0.101240f), TypeParam(-3.087570f), TypeParam(-0.778617f), TypeParam(1.026859f), TypeParam(0.624162f), TypeParam(0.291416f), TypeParam(0.580998f), TypeParam(-0.185200f), TypeParam(0.333020f), TypeParam(0.415896f), TypeParam(0.011702f), TypeParam(0.014502f), TypeParam(-0.722870f), TypeParam(-0.201041f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.818167f, -0.394078f, 0.627076f, -1.124307f, -0.296864f, -0.244061f, -0.423780f, 0.504000f, -0.546789f, -0.139085f, -0.346504f, -1.126900f, -0.198169f, -1.016972f, 0.699725f, 0.641356f, 1.124151f, -0.402963f, 0.061023f, 0.235069f, 1.197862f, 1.099936f, -0.621047f, -1.021083f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.818167f), TypeParam(-0.394078f), TypeParam(0.627076f), TypeParam(-1.124307f), TypeParam(-0.296864f), TypeParam(-0.244061f), TypeParam(-0.423780f), TypeParam(0.504000f), TypeParam(-0.546789f), TypeParam(-0.139085f), TypeParam(-0.346504f), TypeParam(-1.126900f), TypeParam(-0.198169f), TypeParam(-1.016972f), TypeParam(0.699725f), TypeParam(0.641356f), TypeParam(1.124151f), TypeParam(-0.402963f), TypeParam(0.061023f), TypeParam(0.235069f), TypeParam(1.197862f), TypeParam(1.099936f), TypeParam(-0.621047f), TypeParam(-1.021083f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{1.836740f, 0.000000f, -0.418393f, 1.779137f, -0.418393f, 0.000000f, -0.388847f, 0.000000f, -0.939514f, -0.101240f, -0.939514f, 0.000000f, 0.000000f, -0.185200f, 0.000000f, 0.291416f, 0.000000f, 0.000000f, 0.000000f, -0.201041f, 0.000000f, 0.014502f, 0.000000f, 0.000000f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(1.836740f), TypeParam(0.000000f), TypeParam(-0.418393f), TypeParam(1.779137f), TypeParam(-0.418393f), TypeParam(0.000000f), TypeParam(-0.388847f), TypeParam(0.000000f), TypeParam(-0.939514f), TypeParam(-0.101240f), TypeParam(-0.939514f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.185200f), TypeParam(0.000000f), TypeParam(0.291416f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.201041f), TypeParam(0.000000f), TypeParam(0.014502f), TypeParam(0.000000f), TypeParam(0.000000f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> X_data{0.317302f, 0.629807f, -0.470444f, 0.215051f, 2.234212f, -1.940229f, 0.577203f, -0.166697f, -0.023467f, -0.451050f, -2.199999f, 1.469197f, -1.758133f, -0.570410f, -1.040355f, -0.627640f, 1.398573f, 0.275127f, -0.333592f, -0.677762f, -0.247167f, -0.290725f, -0.986956f, 0.173983f, -0.971920f, 0.225261f, -0.626680f, 1.660835f, 0.972993f, 0.223424f, 2.283593f, -1.145964f, -0.851223f, -2.052948f, -1.351783f, -0.028922f, 0.394421f, 0.057878f, -0.668671f, -0.088841f, 0.560186f, -0.105506f, 0.277478f, 1.047901f, -0.564728f, -0.287761f, 0.653621f, 0.259766f, 1.629452f, -2.337903f, -0.276703f, 0.258084f, -0.552200f, -0.464470f, -0.412042f, -1.047346f, 0.169468f, 1.334588f, 0.580615f, 1.217562f, -2.487876f, -1.218598f, -0.256617f, 1.397251f, 0.694875f, 0.732315f, 0.574448f, 0.673838f, -1.870634f, -0.855206f, 1.068415f, 0.096061f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.317302f), TypeParam(0.629807f), TypeParam(-0.470444f), TypeParam(0.215051f), TypeParam(2.234212f), TypeParam(-1.940229f), TypeParam(0.577203f), TypeParam(-0.166697f), TypeParam(-0.023467f), TypeParam(-0.451050f), TypeParam(-2.199999f), TypeParam(1.469197f), TypeParam(-1.758133f), TypeParam(-0.570410f), TypeParam(-1.040355f), TypeParam(-0.627640f), TypeParam(1.398573f), TypeParam(0.275127f), TypeParam(-0.333592f), TypeParam(-0.677762f), TypeParam(-0.247167f), TypeParam(-0.290725f), TypeParam(-0.986956f), TypeParam(0.173983f), TypeParam(-0.971920f), TypeParam(0.225261f), TypeParam(-0.626680f), TypeParam(1.660835f), TypeParam(0.972993f), TypeParam(0.223424f), TypeParam(2.283593f), TypeParam(-1.145964f), TypeParam(-0.851223f), TypeParam(-2.052948f), TypeParam(-1.351783f), TypeParam(-0.028922f), TypeParam(0.394421f), TypeParam(0.057878f), TypeParam(-0.668671f), TypeParam(-0.088841f), TypeParam(0.560186f), TypeParam(-0.105506f), TypeParam(0.277478f), TypeParam(1.047901f), TypeParam(-0.564728f), TypeParam(-0.287761f), TypeParam(0.653621f), TypeParam(0.259766f), TypeParam(1.629452f), TypeParam(-2.337903f), TypeParam(-0.276703f), TypeParam(0.258084f), TypeParam(-0.552200f), TypeParam(-0.464470f), TypeParam(-0.412042f), TypeParam(-1.047346f), TypeParam(0.169468f), TypeParam(1.334588f), TypeParam(0.580615f), TypeParam(1.217562f), TypeParam(-2.487876f), TypeParam(-1.218598f), TypeParam(-0.256617f), TypeParam(1.397251f), TypeParam(0.694875f), TypeParam(0.732315f), TypeParam(0.574448f), TypeParam(0.673838f), TypeParam(-1.870634f), TypeParam(-0.855206f), TypeParam(1.068415f), TypeParam(0.096061f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 3, 2, 3};
-  std::initializer_list<float> Grid_data{0.650046f, -0.680891f, -0.200337f, -1.006178f, -0.676990f, 0.500592f, -1.118072f, -0.684288f, 0.899676f, -0.615418f, -0.499387f, -0.336929f, 0.512951f, -0.787164f, 0.120318f, 0.490083f, -0.087112f, 0.216982f, -0.915417f, 0.542519f, 0.448475f, -0.150519f, -0.992244f, 0.479971f, 0.783050f, -0.209890f, 0.565605f, 0.444791f, -0.479961f, -0.083304f, 1.194526f, 0.005665f, -0.955336f, -0.087514f, 0.596991f, -0.391708f, -0.628420f, 0.988534f, 0.634814f, -0.203871f, 0.061307f, -0.126915f, 0.278599f, 0.042647f, -0.726162f, 0.222329f, 0.031386f, 0.077584f, -0.457305f, 0.307467f, -0.970375f, 0.358708f, 0.650272f, -0.132064f, -0.932160f, -0.004362f, 0.001704f, -1.037046f, -0.848754f, 1.109926f, 0.897382f, 0.665044f, 0.831311f, 0.461956f, 0.675346f, 0.794786f, -0.280329f, -0.152546f, 0.855656f, -0.000432f, -0.780824f, -0.930479f, 0.671131f, 0.993983f, 0.931935f, 0.199703f, 0.828337f, -1.101760f, -0.864556f, -1.154677f, 0.966824f, -0.010858f, -0.552558f, 0.406048f, -0.449199f, -0.769613f, 0.462838f, 0.219719f, -0.859342f, -0.790394f, 0.562644f, 0.912452f, 0.097688f, -0.602742f, 0.579449f, 0.209287f, -1.050575f, -0.777654f, 0.262652f, 0.742529f, -0.385517f, 0.580240f, -0.743175f, 1.148320f, 0.855053f, 0.224769f, 0.533871f, 0.417788f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.650046f), TypeParam(-0.680891f), TypeParam(-0.200337f), TypeParam(-1.006178f), TypeParam(-0.676990f), TypeParam(0.500592f), TypeParam(-1.118072f), TypeParam(-0.684288f), TypeParam(0.899676f), TypeParam(-0.615418f), TypeParam(-0.499387f), TypeParam(-0.336929f), TypeParam(0.512951f), TypeParam(-0.787164f), TypeParam(0.120318f), TypeParam(0.490083f), TypeParam(-0.087112f), TypeParam(0.216982f), TypeParam(-0.915417f), TypeParam(0.542519f), TypeParam(0.448475f), TypeParam(-0.150519f), TypeParam(-0.992244f), TypeParam(0.479971f), TypeParam(0.783050f), TypeParam(-0.209890f), TypeParam(0.565605f), TypeParam(0.444791f), TypeParam(-0.479961f), TypeParam(-0.083304f), TypeParam(1.194526f), TypeParam(0.005665f), TypeParam(-0.955336f), TypeParam(-0.087514f), TypeParam(0.596991f), TypeParam(-0.391708f), TypeParam(-0.628420f), TypeParam(0.988534f), TypeParam(0.634814f), TypeParam(-0.203871f), TypeParam(0.061307f), TypeParam(-0.126915f), TypeParam(0.278599f), TypeParam(0.042647f), TypeParam(-0.726162f), TypeParam(0.222329f), TypeParam(0.031386f), TypeParam(0.077584f), TypeParam(-0.457305f), TypeParam(0.307467f), TypeParam(-0.970375f), TypeParam(0.358708f), TypeParam(0.650272f), TypeParam(-0.132064f), TypeParam(-0.932160f), TypeParam(-0.004362f), TypeParam(0.001704f), TypeParam(-1.037046f), TypeParam(-0.848754f), TypeParam(1.109926f), TypeParam(0.897382f), TypeParam(0.665044f), TypeParam(0.831311f), TypeParam(0.461956f), TypeParam(0.675346f), TypeParam(0.794786f), TypeParam(-0.280329f), TypeParam(-0.152546f), TypeParam(0.855656f), TypeParam(-0.000432f), TypeParam(-0.780824f), TypeParam(-0.930479f), TypeParam(0.671131f), TypeParam(0.993983f), TypeParam(0.931935f), TypeParam(0.199703f), TypeParam(0.828337f), TypeParam(-1.101760f), TypeParam(-0.864556f), TypeParam(-1.154677f), TypeParam(0.966824f), TypeParam(-0.010858f), TypeParam(-0.552558f), TypeParam(0.406048f), TypeParam(-0.449199f), TypeParam(-0.769613f), TypeParam(0.462838f), TypeParam(0.219719f), TypeParam(-0.859342f), TypeParam(-0.790394f), TypeParam(0.562644f), TypeParam(0.912452f), TypeParam(0.097688f), TypeParam(-0.602742f), TypeParam(0.579449f), TypeParam(0.209287f), TypeParam(-1.050575f), TypeParam(-0.777654f), TypeParam(0.262652f), TypeParam(0.742529f), TypeParam(-0.385517f), TypeParam(0.580240f), TypeParam(-0.743175f), TypeParam(1.148320f), TypeParam(0.855053f), TypeParam(0.224769f), TypeParam(0.533871f), TypeParam(0.417788f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> Y_data{-0.166697f, 0.000000f, 0.000000f, 0.317302f, -0.166697f, -0.451050f, 1.398573f, -1.758133f, -0.627640f, -0.166697f, 0.000000f, 2.234212f, 1.398573f, -0.023467f, 0.215051f, -0.451050f, -0.470444f, 1.469197f, 0.225261f, 0.000000f, 0.000000f, -0.333592f, 0.225261f, 1.660835f, -1.351783f, 2.283593f, -2.052948f, 0.225261f, 0.000000f, -0.986956f, -1.351783f, -0.626680f, -0.290725f, 1.660835f, -0.247167f, 0.223424f, -0.564728f, 0.000000f, -0.464470f, -0.464470f, -0.276703f, 0.394421f, -0.464470f, 0.000000f, 0.000000f, 1.629452f, 1.629452f, 0.057878f, 0.259766f, 0.653621f, 0.000000f, -2.337903f, 0.000000f, -0.464470f, -0.256617f, 0.000000f, 0.096061f, 0.096061f, -1.870634f, -0.412042f, 0.096061f, 0.000000f, 0.000000f, 0.574448f, 0.574448f, -1.047346f, 0.732315f, 0.694875f, 0.000000f, 0.673838f, 0.000000f, 0.096061f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.317302f), TypeParam(-0.166697f), TypeParam(-0.451050f), TypeParam(1.398573f), TypeParam(-1.758133f), TypeParam(-0.627640f), TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(2.234212f), TypeParam(1.398573f), TypeParam(-0.023467f), TypeParam(0.215051f), TypeParam(-0.451050f), TypeParam(-0.470444f), TypeParam(1.469197f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.333592f), TypeParam(0.225261f), TypeParam(1.660835f), TypeParam(-1.351783f), TypeParam(2.283593f), TypeParam(-2.052948f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(-0.986956f), TypeParam(-1.351783f), TypeParam(-0.626680f), TypeParam(-0.290725f), TypeParam(1.660835f), TypeParam(-0.247167f), TypeParam(0.223424f), TypeParam(-0.564728f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.464470f), TypeParam(-0.276703f), TypeParam(0.394421f), TypeParam(-0.464470f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(1.629452f), TypeParam(1.629452f), TypeParam(0.057878f), TypeParam(0.259766f), TypeParam(0.653621f), TypeParam(0.000000f), TypeParam(-2.337903f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.256617f), TypeParam(0.000000f), TypeParam(0.096061f), TypeParam(0.096061f), TypeParam(-1.870634f), TypeParam(-0.412042f), TypeParam(0.096061f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.574448f), TypeParam(0.574448f), TypeParam(-1.047346f), TypeParam(0.732315f), TypeParam(0.694875f), TypeParam(0.000000f), TypeParam(0.673838f), TypeParam(0.000000f), TypeParam(0.096061f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "border";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.660065f, 0.995767f, -0.226389f, 0.590604f, -2.628610f, 0.444899f, 0.023282f, 0.024018f, -0.584701f, 1.988638f, -0.023379f, 0.711650f, -1.062933f, -0.064113f, 1.178346f, -0.652373f, 1.259795f, 1.508661f, -0.079368f, 0.819443f, 0.836356f, -0.362184f, -1.153828f, -0.561180f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.660065f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.590604f), TypeParam(-2.628610f), TypeParam(0.444899f), TypeParam(0.023282f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.988638f), TypeParam(-0.023379f), TypeParam(0.711650f), TypeParam(-1.062933f), TypeParam(-0.064113f), TypeParam(1.178346f), TypeParam(-0.652373f), TypeParam(1.259795f), TypeParam(1.508661f), TypeParam(-0.079368f), TypeParam(0.819443f), TypeParam(0.836356f), TypeParam(-0.362184f), TypeParam(-1.153828f), TypeParam(-0.561180f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{-0.447651f, -0.521958f, 0.673539f, 0.222645f, 1.010165f, 0.451903f, 0.966699f, -0.966970f, 0.964714f, -0.551345f, -0.321222f, 0.007182f, -0.225038f, 0.237367f, 1.069316f, -0.716982f, 0.370785f, -0.964445f, 0.188419f, 0.988574f, 0.809140f, 1.027635f, 0.649589f, -0.099282f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-0.447651f), TypeParam(-0.521958f), TypeParam(0.673539f), TypeParam(0.222645f), TypeParam(1.010165f), TypeParam(0.451903f), TypeParam(0.966699f), TypeParam(-0.966970f), TypeParam(0.964714f), TypeParam(-0.551345f), TypeParam(-0.321222f), TypeParam(0.007182f), TypeParam(-0.225038f), TypeParam(0.237367f), TypeParam(1.069316f), TypeParam(-0.716982f), TypeParam(0.370785f), TypeParam(-0.964445f), TypeParam(0.188419f), TypeParam(0.988574f), TypeParam(0.809140f), TypeParam(1.027635f), TypeParam(0.649589f), TypeParam(-0.099282f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{0.660065f, 0.590604f, 0.590604f, 0.995767f, 0.995767f, -0.226389f, 0.023282f, 1.988638f, 1.988638f, 0.024018f, 0.024018f, -0.584701f, 1.178346f, -0.064113f, -0.064113f, 1.508661f, 1.508661f, -0.652373f, 0.836356f, 0.819443f, 0.819443f, -0.561180f, -0.561180f, -0.362184f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.660065f), TypeParam(0.590604f), TypeParam(0.590604f), TypeParam(0.995767f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.023282f), TypeParam(1.988638f), TypeParam(1.988638f), TypeParam(0.024018f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.178346f), TypeParam(-0.064113f), TypeParam(-0.064113f), TypeParam(1.508661f), TypeParam(1.508661f), TypeParam(-0.652373f), TypeParam(0.836356f), TypeParam(0.819443f), TypeParam(0.819443f), TypeParam(-0.561180f), TypeParam(-0.561180f), TypeParam(-0.362184f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }
 
-TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "nearest";
   std::string padding_mode = "border";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> X_data{-0.920922f, -0.560469f, -2.244605f, -0.061799f, 0.523656f, 0.110097f, -0.944521f, 0.818932f, 1.069286f, 0.611457f, -0.355875f, 1.664810f, 0.116694f, 2.318200f, 0.681699f, -0.792880f, -0.025672f, -0.592222f, 0.229768f, -0.521888f, 0.570937f, -0.029345f, -0.873323f, 1.721509f, 2.011626f, -0.310838f, 1.121670f, 0.778967f, -0.450894f, 1.030269f, 0.166967f, -0.244737f, 0.227200f, -0.416612f, -0.276513f, 0.714623f, 0.908783f, -1.393580f, -0.983675f, -0.366833f, 1.473970f, 0.624368f, -0.607720f, -0.523833f, -0.124702f, -0.766457f, -0.131027f, 2.227047f, 1.399269f, 0.053366f, -0.295771f, -0.283811f, 0.019280f, -0.104450f, -0.574185f, -2.130628f, 0.617878f, -1.728151f, -0.272528f, 1.299354f, -1.109310f, -1.881107f, -1.300843f, -0.765376f, -0.477722f, -1.230664f, -0.495792f, 1.061688f, 1.244247f, -0.550821f, -0.520524f, 1.541448f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.920922f), TypeParam(-0.560469f), TypeParam(-2.244605f), TypeParam(-0.061799f), TypeParam(0.523656f), TypeParam(0.110097f), TypeParam(-0.944521f), TypeParam(0.818932f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(1.664810f), TypeParam(0.116694f), TypeParam(2.318200f), TypeParam(0.681699f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.592222f), TypeParam(0.229768f), TypeParam(-0.521888f), TypeParam(0.570937f), TypeParam(-0.029345f), TypeParam(-0.873323f), TypeParam(1.721509f), TypeParam(2.011626f), TypeParam(-0.310838f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(1.030269f), TypeParam(0.166967f), TypeParam(-0.244737f), TypeParam(0.227200f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(0.714623f), TypeParam(0.908783f), TypeParam(-1.393580f), TypeParam(-0.983675f), TypeParam(-0.366833f), TypeParam(1.473970f), TypeParam(0.624368f), TypeParam(-0.607720f), TypeParam(-0.523833f), TypeParam(-0.124702f), TypeParam(-0.766457f), TypeParam(-0.131027f), TypeParam(2.227047f), TypeParam(1.399269f), TypeParam(0.053366f), TypeParam(-0.295771f), TypeParam(-0.283811f), TypeParam(0.019280f), TypeParam(-0.104450f), TypeParam(-0.574185f), TypeParam(-2.130628f), TypeParam(0.617878f), TypeParam(-1.728151f), TypeParam(-0.272528f), TypeParam(1.299354f), TypeParam(-1.109310f), TypeParam(-1.881107f), TypeParam(-1.300843f), TypeParam(-0.765376f), TypeParam(-0.477722f), TypeParam(-1.230664f), TypeParam(-0.495792f), TypeParam(1.061688f), TypeParam(1.244247f), TypeParam(-0.550821f), TypeParam(-0.520524f), TypeParam(1.541448f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 3, 2, 3};
-  std::initializer_list<float> Grid_data{-1.189605f, -0.312072f, 0.459409f, 1.033285f, -1.083635f, 0.572921f, -1.138649f, -1.147562f, -0.751493f, -0.158500f, 0.335153f, -0.912613f, 0.924528f, 1.085165f, 0.073832f, 0.976781f, -0.543258f, -0.474714f, -0.154854f, 0.131118f, -0.837104f, -0.960885f, 0.474040f, 0.345992f, 1.173923f, -0.489256f, 0.423768f, -0.484246f, 0.592379f, -0.066474f, 0.889570f, 0.666682f, 0.998817f, 0.616675f, 0.045084f, 1.034127f, -0.704858f, 1.131824f, 1.172625f, 1.146321f, -0.560545f, -0.635830f, 0.075922f, 0.373677f, 0.601953f, 0.488043f, 1.021787f, -0.300648f, -0.393688f, 0.402240f, 0.334401f, -0.699993f, 0.116070f, -0.911100f, -0.352043f, -0.470968f, 1.051900f, -1.080208f, -0.708510f, -1.174356f, 0.302647f, -0.923627f, 0.388249f, -0.833533f, -0.768697f, -0.613051f, 0.180083f, 1.102657f, 1.124055f, -0.090660f, -1.175396f, -0.396450f, -0.457333f, -0.255235f, 0.458506f, 0.603882f, 0.532050f, 0.342802f, -0.485794f, -0.012730f, 0.152721f, -0.612948f, -0.107348f, -0.149795f, -1.133775f, 0.813507f, -0.121323f, -1.037352f, 0.949408f, -0.645689f, 0.424853f, 1.190055f, 0.055551f, 0.345244f, 0.476794f, 0.906949f, -0.368187f, -0.675263f, -0.093908f, 0.938461f, 0.103178f, 0.833774f, -0.008922f, 0.368184f, 0.041727f, 0.032575f, -1.141943f, -1.049081f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-1.189605f), TypeParam(-0.312072f), TypeParam(0.459409f), TypeParam(1.033285f), TypeParam(-1.083635f), TypeParam(0.572921f), TypeParam(-1.138649f), TypeParam(-1.147562f), TypeParam(-0.751493f), TypeParam(-0.158500f), TypeParam(0.335153f), TypeParam(-0.912613f), TypeParam(0.924528f), TypeParam(1.085165f), TypeParam(0.073832f), TypeParam(0.976781f), TypeParam(-0.543258f), TypeParam(-0.474714f), TypeParam(-0.154854f), TypeParam(0.131118f), TypeParam(-0.837104f), TypeParam(-0.960885f), TypeParam(0.474040f), TypeParam(0.345992f), TypeParam(1.173923f), TypeParam(-0.489256f), TypeParam(0.423768f), TypeParam(-0.484246f), TypeParam(0.592379f), TypeParam(-0.066474f), TypeParam(0.889570f), TypeParam(0.666682f), TypeParam(0.998817f), TypeParam(0.616675f), TypeParam(0.045084f), TypeParam(1.034127f), TypeParam(-0.704858f), TypeParam(1.131824f), TypeParam(1.172625f), TypeParam(1.146321f), TypeParam(-0.560545f), TypeParam(-0.635830f), TypeParam(0.075922f), TypeParam(0.373677f), TypeParam(0.601953f), TypeParam(0.488043f), TypeParam(1.021787f), TypeParam(-0.300648f), TypeParam(-0.393688f), TypeParam(0.402240f), TypeParam(0.334401f), TypeParam(-0.699993f), TypeParam(0.116070f), TypeParam(-0.911100f), TypeParam(-0.352043f), TypeParam(-0.470968f), TypeParam(1.051900f), TypeParam(-1.080208f), TypeParam(-0.708510f), TypeParam(-1.174356f), TypeParam(0.302647f), TypeParam(-0.923627f), TypeParam(0.388249f), TypeParam(-0.833533f), TypeParam(-0.768697f), TypeParam(-0.613051f), TypeParam(0.180083f), TypeParam(1.102657f), TypeParam(1.124055f),
TypeParam(-0.090660f), TypeParam(-1.175396f), TypeParam(-0.396450f), TypeParam(-0.457333f), TypeParam(-0.255235f), TypeParam(0.458506f), TypeParam(0.603882f), TypeParam(0.532050f), TypeParam(0.342802f), TypeParam(-0.485794f), TypeParam(-0.012730f), TypeParam(0.152721f), TypeParam(-0.612948f), TypeParam(-0.107348f), TypeParam(-0.149795f), TypeParam(-1.133775f), TypeParam(0.813507f), TypeParam(-0.121323f), TypeParam(-1.037352f), TypeParam(0.949408f), TypeParam(-0.645689f), TypeParam(0.424853f), TypeParam(1.190055f), TypeParam(0.055551f), TypeParam(0.345244f), TypeParam(0.476794f), TypeParam(0.906949f), TypeParam(-0.368187f), TypeParam(-0.675263f), TypeParam(-0.093908f), TypeParam(0.938461f), TypeParam(0.103178f), TypeParam(0.833774f), TypeParam(-0.008922f), TypeParam(0.368184f), TypeParam(0.041727f), TypeParam(0.032575f), TypeParam(-1.141943f), TypeParam(-1.049081f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{1.069286f, 2.318200f, -0.920922f, -2.244605f, 1.664810f, 0.818932f, -2.244605f, 1.069286f, 0.611457f, -0.355875f, -0.592222f, -0.792880f, -0.025672f, -0.560469f, -0.792880f, 1.664810f, 1.069286f, -2.244605f, 1.121670f, -0.244737f, 0.229768f, 0.570937f, 1.030269f, -0.310838f, 0.570937f, 1.121670f, 0.778967f, -0.450894f, 0.714623f, -0.416612f, -0.276513f, -0.521888f, -0.416612f, 1.030269f, 1.121670f, 0.570937f, -0.295771f, 0.908783f, -0.523833f, 0.908783f, -0.104450f, -0.607720f, -0.124702f, 2.227047f, -0.124702f, -0.124702f, -0.131027f, 1.473970f, 2.227047f, -0.283811f, -0.607720f, -0.283811f, -0.124702f, -1.393580f, 1.244247f, -0.574185f, -1.881107f, -0.574185f, 1.541448f, -1.109310f, -1.300843f, -1.230664f, -1.300843f, -1.300843f, -0.477722f, -0.272528f, -1.230664f, -0.550821f, -1.109310f, -0.550821f, -1.300843f, -2.130628f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.069286f), TypeParam(2.318200f), TypeParam(-0.920922f), TypeParam(-2.244605f), TypeParam(1.664810f), TypeParam(0.818932f), TypeParam(-2.244605f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(-0.592222f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.560469f), TypeParam(-0.792880f), TypeParam(1.664810f), TypeParam(1.069286f), TypeParam(-2.244605f), TypeParam(1.121670f), TypeParam(-0.244737f), TypeParam(0.229768f), TypeParam(0.570937f), TypeParam(1.030269f), TypeParam(-0.310838f), TypeParam(0.570937f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(0.714623f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(-0.521888f), TypeParam(-0.416612f), TypeParam(1.030269f), TypeParam(1.121670f), TypeParam(0.570937f), TypeParam(-0.295771f), TypeParam(0.908783f), TypeParam(-0.523833f), TypeParam(0.908783f), TypeParam(-0.104450f), TypeParam(-0.607720f), TypeParam(-0.124702f), TypeParam(2.227047f), TypeParam(-0.124702f), TypeParam(-0.124702f), TypeParam(-0.131027f), TypeParam(1.473970f), TypeParam(2.227047f), TypeParam(-0.283811f), TypeParam(-0.607720f), TypeParam(-0.283811f), TypeParam(-0.124702f), TypeParam(-1.393580f), TypeParam(1.244247f), TypeParam(-0.574185f), TypeParam(-1.881107f), TypeParam(-0.574185f), TypeParam(1.541448f), TypeParam(-1.109310f), TypeParam(-1.300843f), TypeParam(-1.230664f), TypeParam(-1.300843f), TypeParam(-1.300843f), TypeParam(-0.477722f), TypeParam(-0.272528f), TypeParam(-1.230664f), TypeParam(-0.550821f), TypeParam(-1.109310f), TypeParam(-0.550821f), TypeParam(-1.300843f), TypeParam(-2.130628f)}; + 
test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.950589f, -1.656624f, 0.767704f, -0.650720f, -1.404308f, -0.531582f, -0.280854f, 0.344309f, -0.959146f, -0.115645f, 0.515696f, -0.114243f, 1.971614f, 0.274268f, 0.543080f, -1.758563f, 1.771011f, 0.934901f, 0.695798f, 1.905137f, 1.598307f, 1.108385f, 0.156008f, 1.290824f}; + std::initializer_list X_data{TypeParam(0.950589f), TypeParam(-1.656624f), TypeParam(0.767704f), TypeParam(-0.650720f), TypeParam(-1.404308f), TypeParam(-0.531582f), TypeParam(-0.280854f), TypeParam(0.344309f), TypeParam(-0.959146f), TypeParam(-0.115645f), TypeParam(0.515696f), TypeParam(-0.114243f), TypeParam(1.971614f), TypeParam(0.274268f), TypeParam(0.543080f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(0.934901f), TypeParam(0.695798f), TypeParam(1.905137f), TypeParam(1.598307f), TypeParam(1.108385f), TypeParam(0.156008f), TypeParam(1.290824f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.482490f, -0.910951f, -0.001676f, -0.442514f, 0.580438f, 1.039346f, -0.159076f, -0.603960f, -0.922037f, -0.705026f, 0.346468f, 0.275332f, 0.646235f, -0.178307f, 0.616600f, -1.069108f, 0.322583f, 1.164952f, -1.187638f, -0.622953f, 0.768203f, -0.187618f, -0.639652f, 0.732078f}; + std::initializer_list Grid_data{TypeParam(0.482490f), TypeParam(-0.910951f), TypeParam(-0.001676f), TypeParam(-0.442514f), TypeParam(0.580438f), TypeParam(1.039346f), TypeParam(-0.159076f), TypeParam(-0.603960f), TypeParam(-0.922037f), TypeParam(-0.705026f), TypeParam(0.346468f), TypeParam(0.275332f), TypeParam(0.646235f), TypeParam(-0.178307f), TypeParam(0.616600f), TypeParam(-1.069108f), TypeParam(0.322583f), TypeParam(1.164952f), TypeParam(-1.187638f), TypeParam(-0.622953f), TypeParam(0.768203f), TypeParam(-0.187618f), TypeParam(-0.639652f), TypeParam(0.732078f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.656624f, 0.950589f, -0.531582f, 0.950589f, 0.950589f, -0.650720f, 0.344309f, -0.280854f, -0.114243f, -0.280854f, -0.280854f, -0.115645f, -1.758563f, 0.274268f, 0.934901f, 1.971614f, -1.758563f, 1.771011f, 1.108385f, 1.905137f, 1.290824f, 0.695798f, 1.108385f, 0.156008f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.656624f), TypeParam(0.950589f), TypeParam(-0.531582f), TypeParam(0.950589f), TypeParam(0.950589f), TypeParam(-0.650720f), TypeParam(0.344309f), TypeParam(-0.280854f), TypeParam(-0.114243f), TypeParam(-0.280854f), TypeParam(-0.280854f), TypeParam(-0.115645f), TypeParam(-1.758563f), TypeParam(0.274268f), TypeParam(0.934901f), TypeParam(1.971614f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(1.108385f), TypeParam(1.905137f), TypeParam(1.290824f), TypeParam(0.695798f), TypeParam(1.108385f), TypeParam(0.156008f)}; + test.AddInput("X", X_shape, X_data); + 
test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.465448f, -0.337086f, -0.870849f, -0.389573f, -0.083941f, 1.306894f, 0.719508f, -0.203690f, -1.143864f, 1.163003f, 0.312170f, -2.008687f, 1.731257f, -0.270431f, 1.095352f, -1.673520f, 0.492743f, 0.521962f, -1.938783f, -0.186813f, -0.836257f, -1.835450f, 0.476500f, -0.123386f, 0.246604f, 1.374159f, -0.158435f, 1.268192f, -0.704226f, -0.195314f, -0.277259f, 0.582961f, -0.340940f, 0.192264f, 0.463124f, -2.719402f, -0.593470f, -1.165777f, 0.566071f, 1.622836f, -0.886798f, 1.874877f, -0.849095f, 0.550185f, 0.604298f, 0.073976f, -0.800372f, -0.097283f, -1.576251f, -0.633278f, -1.776745f, -0.827586f, 0.665697f, 0.884698f, 0.467112f, -0.645219f, -0.510110f, 0.032418f, -1.056009f, -0.206175f, -0.173385f, 0.947787f, 1.937234f, 0.615880f, -0.311580f, 0.770921f, -0.841602f, 1.796220f, 0.479491f, 1.609346f, 1.113868f, -0.453360f}; + std::initializer_list X_data{TypeParam(0.465448f), TypeParam(-0.337086f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.083941f), TypeParam(1.306894f), TypeParam(0.719508f), TypeParam(-0.203690f), TypeParam(-1.143864f), TypeParam(1.163003f), TypeParam(0.312170f), TypeParam(-2.008687f), TypeParam(1.731257f), TypeParam(-0.270431f), TypeParam(1.095352f), TypeParam(-1.673520f), TypeParam(0.492743f), TypeParam(0.521962f), TypeParam(-1.938783f), TypeParam(-0.186813f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(0.476500f), TypeParam(-0.123386f), TypeParam(0.246604f), TypeParam(1.374159f), TypeParam(-0.158435f), TypeParam(1.268192f), TypeParam(-0.704226f), TypeParam(-0.195314f), TypeParam(-0.277259f), TypeParam(0.582961f), TypeParam(-0.340940f), TypeParam(0.192264f), TypeParam(0.463124f), TypeParam(-2.719402f), TypeParam(-0.593470f), TypeParam(-1.165777f), TypeParam(0.566071f), TypeParam(1.622836f), TypeParam(-0.886798f), TypeParam(1.874877f), TypeParam(-0.849095f), TypeParam(0.550185f), TypeParam(0.604298f), TypeParam(0.073976f), TypeParam(-0.800372f), TypeParam(-0.097283f), TypeParam(-1.576251f), TypeParam(-0.633278f), TypeParam(-1.776745f), TypeParam(-0.827586f), TypeParam(0.665697f), TypeParam(0.884698f), TypeParam(0.467112f), TypeParam(-0.645219f), TypeParam(-0.510110f), TypeParam(0.032418f), TypeParam(-1.056009f), TypeParam(-0.206175f), TypeParam(-0.173385f), TypeParam(0.947787f), TypeParam(1.937234f), TypeParam(0.615880f), TypeParam(-0.311580f), TypeParam(0.770921f), TypeParam(-0.841602f), TypeParam(1.796220f), TypeParam(0.479491f), TypeParam(1.609346f), TypeParam(1.113868f), TypeParam(-0.453360f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.151540f, -0.033291f, -0.597203f, 0.836404f, -0.686848f, -0.485355f, -0.936738f, -1.009057f, 1.065352f, -0.926635f, -0.165670f, -0.347352f, 0.439545f, 0.320963f, -0.919909f, 1.077689f, -1.195359f, 0.118687f, -0.100253f, -0.278089f, 0.817760f, 1.013180f, 0.156316f, -0.423839f, 0.892139f, 0.753924f, 0.215530f, 
-0.328214f, 0.050592f, 1.069553f, 0.130134f, -0.236478f, -1.015986f, -0.643059f, 0.866682f, -0.042256f, -0.079912f, 0.467233f, -0.789513f, -0.081063f, -0.337505f, 0.627865f, 0.976589f, 0.753489f, 0.894667f, -1.072442f, -0.426020f, 0.142099f, -1.019226f, 0.325527f, -0.786578f, 0.514215f, 0.971223f, -1.026539f, 1.005531f, 0.559922f, -0.791906f, 1.148613f, -1.039306f, -0.807864f, -0.596935f, -0.060766f, 0.215484f, -0.352165f, -1.137417f, -0.138518f, 0.910459f, 0.923925f, 0.600710f, 0.174227f, 0.298169f, -0.925092f, 0.485927f, -1.194283f, -0.495564f, -0.315357f, 0.881199f, -0.034981f, -0.546611f, 0.209651f, -0.995724f, -0.317709f, 0.332343f, -0.079474f, -0.126024f, 0.733410f, -0.911554f, -0.605911f, 1.161566f, 0.238787f, -0.194293f, 0.621583f, 0.721901f, -0.200521f, -0.499850f, -0.196149f, 0.435730f, -0.153196f, 0.698401f, -0.978582f, -0.588758f, 0.914808f, 0.157427f, 0.241646f, 0.394674f, -0.283552f, -0.479889f, 0.344261f}; + std::initializer_list Grid_data{TypeParam(-0.151540f), TypeParam(-0.033291f), TypeParam(-0.597203f), TypeParam(0.836404f), TypeParam(-0.686848f), TypeParam(-0.485355f), TypeParam(-0.936738f), TypeParam(-1.009057f), TypeParam(1.065352f), TypeParam(-0.926635f), TypeParam(-0.165670f), TypeParam(-0.347352f), TypeParam(0.439545f), TypeParam(0.320963f), TypeParam(-0.919909f), TypeParam(1.077689f), TypeParam(-1.195359f), TypeParam(0.118687f), TypeParam(-0.100253f), TypeParam(-0.278089f), TypeParam(0.817760f), TypeParam(1.013180f), TypeParam(0.156316f), TypeParam(-0.423839f), TypeParam(0.892139f), TypeParam(0.753924f), TypeParam(0.215530f), TypeParam(-0.328214f), TypeParam(0.050592f), TypeParam(1.069553f), TypeParam(0.130134f), TypeParam(-0.236478f), TypeParam(-1.015986f), TypeParam(-0.643059f), TypeParam(0.866682f), TypeParam(-0.042256f), TypeParam(-0.079912f), TypeParam(0.467233f), TypeParam(-0.789513f), TypeParam(-0.081063f), TypeParam(-0.337505f), TypeParam(0.627865f), TypeParam(0.976589f), TypeParam(0.753489f), TypeParam(0.894667f), TypeParam(-1.072442f), TypeParam(-0.426020f), TypeParam(0.142099f), TypeParam(-1.019226f), TypeParam(0.325527f), TypeParam(-0.786578f), TypeParam(0.514215f), TypeParam(0.971223f), TypeParam(-1.026539f), TypeParam(1.005531f), TypeParam(0.559922f), TypeParam(-0.791906f), TypeParam(1.148613f), TypeParam(-1.039306f), TypeParam(-0.807864f), TypeParam(-0.596935f), TypeParam(-0.060766f), TypeParam(0.215484f), TypeParam(-0.352165f), TypeParam(-1.137417f), TypeParam(-0.138518f), TypeParam(0.910459f), TypeParam(0.923925f), TypeParam(0.600710f), TypeParam(0.174227f), TypeParam(0.298169f), TypeParam(-0.925092f), TypeParam(0.485927f), TypeParam(-1.194283f), TypeParam(-0.495564f), TypeParam(-0.315357f), TypeParam(0.881199f), TypeParam(-0.034981f), TypeParam(-0.546611f), TypeParam(0.209651f), TypeParam(-0.995724f), TypeParam(-0.317709f), TypeParam(0.332343f), TypeParam(-0.079474f), TypeParam(-0.126024f), TypeParam(0.733410f), TypeParam(-0.911554f), TypeParam(-0.605911f), TypeParam(1.161566f), TypeParam(0.238787f), TypeParam(-0.194293f), TypeParam(0.621583f), TypeParam(0.721901f), TypeParam(-0.200521f), TypeParam(-0.499850f), TypeParam(-0.196149f), TypeParam(0.435730f), TypeParam(-0.153196f), TypeParam(0.698401f), TypeParam(-0.978582f), TypeParam(-0.588758f), TypeParam(0.914808f), TypeParam(0.157427f), TypeParam(0.241646f), TypeParam(0.394674f), TypeParam(-0.283552f), TypeParam(-0.479889f), TypeParam(0.344261f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.870849f, -0.337086f, 1.731257f, -0.870849f, -0.389573f, -0.203690f, 
1.095352f, -0.389573f, -2.008687f, 1.095352f, -0.389573f, 0.312170f, -0.083941f, 1.731257f, 0.521962f, 0.719508f, -0.870849f, 1.306894f, -0.836257f, -0.186813f, -0.277259f, -0.836257f, -1.835450f, 1.374159f, -0.340940f, -1.835450f, -0.195314f, -0.340940f, -1.835450f, -0.704226f, 0.476500f, -0.277259f, -2.719402f, 0.246604f, -0.836257f, -0.123386f, 1.874877f, -1.165777f, 0.604298f, -0.849095f, 0.884698f, 1.622836f, -1.165777f, -0.800372f, 0.566071f, 0.604298f, -0.886798f, -0.800372f, 0.665697f, -0.849095f, -0.827586f, -1.576251f, -0.827586f, -1.576251f, -0.206175f, -0.645219f, 1.937234f, -0.173385f, -0.453360f, 0.032418f, -0.645219f, -0.311580f, -0.510110f, 1.937234f, -1.056009f, -0.311580f, 1.113868f, -0.173385f, 1.609346f, -0.841602f, 1.609346f, -0.841602f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.870849f), TypeParam(-0.337086f), TypeParam(1.731257f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.203690f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(-2.008687f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(0.312170f), TypeParam(-0.083941f), TypeParam(1.731257f), TypeParam(0.521962f), TypeParam(0.719508f), TypeParam(-0.870849f), TypeParam(1.306894f), TypeParam(-0.836257f), TypeParam(-0.186813f), TypeParam(-0.277259f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(1.374159f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.195314f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.704226f), TypeParam(0.476500f), TypeParam(-0.277259f), TypeParam(-2.719402f), TypeParam(0.246604f), TypeParam(-0.836257f), TypeParam(-0.123386f), TypeParam(1.874877f), TypeParam(-1.165777f), TypeParam(0.604298f), TypeParam(-0.849095f), TypeParam(0.884698f), TypeParam(1.622836f), TypeParam(-1.165777f), TypeParam(-0.800372f), TypeParam(0.566071f), TypeParam(0.604298f), TypeParam(-0.886798f), TypeParam(-0.800372f), TypeParam(0.665697f), TypeParam(-0.849095f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.206175f), TypeParam(-0.645219f), TypeParam(1.937234f), TypeParam(-0.173385f), TypeParam(-0.453360f), TypeParam(0.032418f), TypeParam(-0.645219f), TypeParam(-0.311580f), TypeParam(-0.510110f), TypeParam(1.937234f), TypeParam(-1.056009f), TypeParam(-0.311580f), TypeParam(1.113868f), TypeParam(-0.173385f), TypeParam(1.609346f), TypeParam(-0.841602f), TypeParam(1.609346f), TypeParam(-0.841602f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.079043f, 0.407494f, 1.038992f, -0.437542f, 0.991216f, 0.409636f, 1.050403f, -0.687172f, -2.021689f, 0.789633f, 0.538178f, 0.414847f, 2.221617f, -0.254833f, -0.179968f, -0.952356f, -1.213159f, 0.499103f, -0.374865f, 0.441938f, -0.114847f, 0.716887f, 1.059090f, 0.438870f}; + std::initializer_list 
X_data{TypeParam(0.079043f), TypeParam(0.407494f), TypeParam(1.038992f), TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(1.050403f), TypeParam(-0.687172f), TypeParam(-2.021689f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(2.221617f), TypeParam(-0.254833f), TypeParam(-0.179968f), TypeParam(-0.952356f), TypeParam(-1.213159f), TypeParam(0.499103f), TypeParam(-0.374865f), TypeParam(0.441938f), TypeParam(-0.114847f), TypeParam(0.716887f), TypeParam(1.059090f), TypeParam(0.438870f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.355147f, -0.222342f, -1.197658f, 0.844060f, 1.188586f, 0.605435f, 1.174232f, 0.327060f, -0.094032f, -0.955794f, -1.048806f, -0.826196f, -0.304468f, 0.698768f, -0.495101f, -0.046607f, -0.016936f, -0.784415f, -0.032484f, 1.158664f, 0.959105f, 0.913943f, -0.118352f, 0.021282f}; + std::initializer_list Grid_data{TypeParam(0.355147f), TypeParam(-0.222342f), TypeParam(-1.197658f), TypeParam(0.844060f), TypeParam(1.188586f), TypeParam(0.605435f), TypeParam(1.174232f), TypeParam(0.327060f), TypeParam(-0.094032f), TypeParam(-0.955794f), TypeParam(-1.048806f), TypeParam(-0.826196f), TypeParam(-0.304468f), TypeParam(0.698768f), TypeParam(-0.495101f), TypeParam(-0.046607f), TypeParam(-0.016936f), TypeParam(-0.784415f), TypeParam(-0.032484f), TypeParam(1.158664f), TypeParam(0.959105f), TypeParam(0.913943f), TypeParam(-0.118352f), TypeParam(0.021282f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.437542f, 0.991216f, 0.409636f, -0.437542f, 0.079043f, 0.079043f, 0.789633f, 0.538178f, 0.414847f, 0.789633f, 1.050403f, 1.050403f, -1.213159f, -0.179968f, 2.221617f, -1.213159f, 0.499103f, -0.179968f, 1.059090f, -0.114847f, -0.374865f, 1.059090f, 0.438870f, -0.114847f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(-0.437542f), TypeParam(0.079043f), TypeParam(0.079043f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(0.789633f), TypeParam(1.050403f), TypeParam(1.050403f), TypeParam(-1.213159f), TypeParam(-0.179968f), TypeParam(2.221617f), TypeParam(-1.213159f), TypeParam(0.499103f), TypeParam(-0.179968f), TypeParam(1.059090f), TypeParam(-0.114847f), TypeParam(-0.374865f), TypeParam(1.059090f), TypeParam(0.438870f), TypeParam(-0.114847f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.189379f, 0.825309f, -0.701365f, 0.787800f, -1.102514f, 0.126954f, 1.824453f, -0.144635f, -1.712534f, 0.361739f, -0.462516f, -2.153102f, 0.536963f, 0.581639f, -1.325014f, -1.314673f, -0.524797f, -1.304159f, -1.093757f, -1.703444f, -0.672976f, 0.505303f, 1.497654f, -0.545441f, -1.334648f, 0.474489f, 0.484384f, 0.434399f, -0.733471f, 0.452991f, 
0.324606f, -1.307459f, -0.640603f, -0.450100f, 0.772854f, 1.281813f, -0.481714f, 1.224667f, -0.437546f, 0.371986f, -0.320368f, -1.011020f, -1.199298f, 0.213302f, 1.795444f, 0.409271f, 1.328065f, -1.037527f, 0.224494f, 0.217863f, -0.925740f, 0.344755f, -1.445667f, -0.935542f, -0.427280f, -2.010803f, -1.174929f, 1.434105f, -1.168630f, 0.321896f, -0.561974f, -0.209305f, -1.063838f, 1.451708f, 0.266913f, -0.132535f, 0.798299f, 0.619547f, -0.324459f, 0.255630f, 0.488773f, -0.142060f}; + std::initializer_list X_data{TypeParam(0.189379f), TypeParam(0.825309f), TypeParam(-0.701365f), TypeParam(0.787800f), TypeParam(-1.102514f), TypeParam(0.126954f), TypeParam(1.824453f), TypeParam(-0.144635f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(-0.462516f), TypeParam(-2.153102f), TypeParam(0.536963f), TypeParam(0.581639f), TypeParam(-1.325014f), TypeParam(-1.314673f), TypeParam(-0.524797f), TypeParam(-1.304159f), TypeParam(-1.093757f), TypeParam(-1.703444f), TypeParam(-0.672976f), TypeParam(0.505303f), TypeParam(1.497654f), TypeParam(-0.545441f), TypeParam(-1.334648f), TypeParam(0.474489f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(-0.733471f), TypeParam(0.452991f), TypeParam(0.324606f), TypeParam(-1.307459f), TypeParam(-0.640603f), TypeParam(-0.450100f), TypeParam(0.772854f), TypeParam(1.281813f), TypeParam(-0.481714f), TypeParam(1.224667f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.320368f), TypeParam(-1.011020f), TypeParam(-1.199298f), TypeParam(0.213302f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(1.328065f), TypeParam(-1.037527f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.925740f), TypeParam(0.344755f), TypeParam(-1.445667f), TypeParam(-0.935542f), TypeParam(-0.427280f), TypeParam(-2.010803f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-1.168630f), TypeParam(0.321896f), TypeParam(-0.561974f), TypeParam(-0.209305f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.266913f), TypeParam(-0.132535f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-0.324459f), TypeParam(0.255630f), TypeParam(0.488773f), TypeParam(-0.142060f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.034431f, 1.048250f, 0.160255f, -0.446426f, 0.879791f, -0.683555f, 0.039704f, 0.269729f, 0.538601f, -1.107191f, 0.058867f, -0.310704f, 0.778040f, 0.403733f, 0.480956f, 0.721512f, -0.268657f, -0.076883f, 0.962704f, -0.967187f, -0.829464f, 0.087786f, -0.475353f, 0.068725f, 1.060032f, -0.139108f, -1.023162f, -0.545493f, 1.102040f, -0.263627f, -0.526173f, 0.540152f, 0.148556f, -1.058015f, 0.999344f, 0.675750f, 1.043022f, 0.525119f, -0.404585f, -0.391737f, 0.581547f, -0.232625f, 0.235264f, -1.162786f, -0.593187f, 0.445737f, -0.059159f, -0.576901f, -1.046721f, 0.762672f, -0.241271f, -1.179040f, 1.157741f, 0.583952f, -0.717767f, -0.875798f, 1.159575f, 0.005010f, -0.721707f, 0.690536f, -0.249959f, 0.082204f, -0.625120f, -1.016394f, -0.796947f, -0.131764f, -0.868737f, 1.182731f, 0.012988f, -0.459398f, 0.474264f, -1.063883f, -0.613791f, 0.450721f, -1.019595f, 0.598084f, 0.100866f, -1.000569f, -1.190919f, 0.379261f, 0.567202f, -0.239888f, -1.061107f, -0.691616f, 0.127540f, 0.043657f, 0.307172f, 0.212184f, -0.062900f, 0.633272f, 1.164016f, 0.999377f, 1.090411f, -0.405004f, -0.409578f, -0.132722f, 0.354671f, 0.485734f, -0.106963f, -0.775112f, -0.905400f, 1.155262f, -0.322627f, -0.162203f, -0.735432f, -0.594912f, 0.263568f, 0.505424f}; + std::initializer_list Grid_data{TypeParam(-0.034431f), TypeParam(1.048250f), 
TypeParam(0.160255f), TypeParam(-0.446426f), TypeParam(0.879791f), TypeParam(-0.683555f), TypeParam(0.039704f), TypeParam(0.269729f), TypeParam(0.538601f), TypeParam(-1.107191f), TypeParam(0.058867f), TypeParam(-0.310704f), TypeParam(0.778040f), TypeParam(0.403733f), TypeParam(0.480956f), TypeParam(0.721512f), TypeParam(-0.268657f), TypeParam(-0.076883f), TypeParam(0.962704f), TypeParam(-0.967187f), TypeParam(-0.829464f), TypeParam(0.087786f), TypeParam(-0.475353f), TypeParam(0.068725f), TypeParam(1.060032f), TypeParam(-0.139108f), TypeParam(-1.023162f), TypeParam(-0.545493f), TypeParam(1.102040f), TypeParam(-0.263627f), TypeParam(-0.526173f), TypeParam(0.540152f), TypeParam(0.148556f), TypeParam(-1.058015f), TypeParam(0.999344f), TypeParam(0.675750f), TypeParam(1.043022f), TypeParam(0.525119f), TypeParam(-0.404585f), TypeParam(-0.391737f), TypeParam(0.581547f), TypeParam(-0.232625f), TypeParam(0.235264f), TypeParam(-1.162786f), TypeParam(-0.593187f), TypeParam(0.445737f), TypeParam(-0.059159f), TypeParam(-0.576901f), TypeParam(-1.046721f), TypeParam(0.762672f), TypeParam(-0.241271f), TypeParam(-1.179040f), TypeParam(1.157741f), TypeParam(0.583952f), TypeParam(-0.717767f), TypeParam(-0.875798f), TypeParam(1.159575f), TypeParam(0.005010f), TypeParam(-0.721707f), TypeParam(0.690536f), TypeParam(-0.249959f), TypeParam(0.082204f), TypeParam(-0.625120f), TypeParam(-1.016394f), TypeParam(-0.796947f), TypeParam(-0.131764f), TypeParam(-0.868737f), TypeParam(1.182731f), TypeParam(0.012988f), TypeParam(-0.459398f), TypeParam(0.474264f), TypeParam(-1.063883f), TypeParam(-0.613791f), TypeParam(0.450721f), TypeParam(-1.019595f), TypeParam(0.598084f), TypeParam(0.100866f), TypeParam(-1.000569f), TypeParam(-1.190919f), TypeParam(0.379261f), TypeParam(0.567202f), TypeParam(-0.239888f), TypeParam(-1.061107f), TypeParam(-0.691616f), TypeParam(0.127540f), TypeParam(0.043657f), TypeParam(0.307172f), TypeParam(0.212184f), TypeParam(-0.062900f), TypeParam(0.633272f), TypeParam(1.164016f), TypeParam(0.999377f), TypeParam(1.090411f), TypeParam(-0.405004f), TypeParam(-0.409578f), TypeParam(-0.132722f), TypeParam(0.354671f), TypeParam(0.485734f), TypeParam(-0.106963f), TypeParam(-0.775112f), TypeParam(-0.905400f), TypeParam(1.155262f), TypeParam(-0.322627f), TypeParam(-0.162203f), TypeParam(-0.735432f), TypeParam(-0.594912f), TypeParam(0.263568f), TypeParam(0.505424f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.462516f, -1.102514f, -1.314673f, -1.712534f, 0.361739f, 0.361739f, 0.825309f, 0.361739f, 0.787800f, -0.462516f, -0.462516f, -0.524797f, -2.153102f, -0.462516f, 0.825309f, 0.787800f, -0.462516f, -0.524797f, -0.733471f, 1.497654f, -0.450100f, 0.484384f, 0.434399f, 0.434399f, -1.703444f, 0.434399f, 0.505303f, -0.733471f, -0.733471f, 0.772854f, 0.452991f, -0.733471f, -1.703444f, 0.505303f, -0.733471f, 0.772854f, 0.224494f, 0.217863f, -0.437546f, -1.199298f, 1.328065f, -0.437546f, -0.437546f, 0.371986f, -0.925740f, -0.481714f, 0.409271f, 0.344755f, -0.935542f, 1.795444f, 0.409271f, 0.224494f, -0.437546f, -0.925740f, 0.798299f, 0.619547f, -1.174929f, -0.561974f, 0.266913f, -1.174929f, -1.174929f, 1.434105f, -0.324459f, -0.427280f, 1.451708f, 0.255630f, -0.142060f, -1.063838f, 1.451708f, 0.798299f, -1.174929f, -0.324459f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.462516f), TypeParam(-1.102514f), TypeParam(-1.314673f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(0.361739f), 
TypeParam(0.825309f), TypeParam(0.361739f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-2.153102f), TypeParam(-0.462516f), TypeParam(0.825309f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-0.733471f), TypeParam(1.497654f), TypeParam(-0.450100f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(0.434399f), TypeParam(-1.703444f), TypeParam(0.434399f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.452991f), TypeParam(-0.733471f), TypeParam(-1.703444f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.437546f), TypeParam(-1.199298f), TypeParam(1.328065f), TypeParam(-0.437546f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.925740f), TypeParam(-0.481714f), TypeParam(0.409271f), TypeParam(0.344755f), TypeParam(-0.935542f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(0.224494f), TypeParam(-0.437546f), TypeParam(-0.925740f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-1.174929f), TypeParam(-0.561974f), TypeParam(0.266913f), TypeParam(-1.174929f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-0.324459f), TypeParam(-0.427280f), TypeParam(1.451708f), TypeParam(0.255630f), TypeParam(-0.142060f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.798299f), TypeParam(-1.174929f), TypeParam(-0.324459f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.769854f, -0.805659f, 0.813652f, -0.010183f, 0.276463f, -0.771678f, -2.563015f, -1.243904f, 2.365071f, 0.730651f, -0.068795f, -1.495438f, 0.211578f, -1.042373f, 0.884036f, -0.746288f, 1.011368f, 0.194463f, -0.307214f, 0.556053f, 0.629364f, 0.083601f, 0.248627f, -0.822453f}; + std::initializer_list X_data{TypeParam(-0.769854f), TypeParam(-0.805659f), TypeParam(0.813652f), TypeParam(-0.010183f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(-2.563015f), TypeParam(-1.243904f), TypeParam(2.365071f), TypeParam(0.730651f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(0.211578f), TypeParam(-1.042373f), TypeParam(0.884036f), TypeParam(-0.746288f), TypeParam(1.011368f), TypeParam(0.194463f), TypeParam(-0.307214f), TypeParam(0.556053f), TypeParam(0.629364f), TypeParam(0.083601f), TypeParam(0.248627f), TypeParam(-0.822453f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.569884f, 1.163780f, -0.977608f, -0.145509f, 0.651234f, 1.099753f, -0.853766f, 0.509955f, 0.495437f, 0.723445f, -0.827299f, 0.856340f, -0.522676f, -0.738659f, 0.238269f, 1.016568f, -0.794666f, 0.640690f, -0.137431f, 0.383085f, 0.936085f, 0.325824f, -0.996188f, -0.361291f}; + std::initializer_list Grid_data{TypeParam(0.569884f), TypeParam(1.163780f), TypeParam(-0.977608f), TypeParam(-0.145509f), 
TypeParam(0.651234f), TypeParam(1.099753f), TypeParam(-0.853766f), TypeParam(0.509955f), TypeParam(0.495437f), TypeParam(0.723445f), TypeParam(-0.827299f), TypeParam(0.856340f), TypeParam(-0.522676f), TypeParam(-0.738659f), TypeParam(0.238269f), TypeParam(1.016568f), TypeParam(-0.794666f), TypeParam(0.640690f), TypeParam(-0.137431f), TypeParam(0.383085f), TypeParam(0.936085f), TypeParam(0.325824f), TypeParam(-0.996188f), TypeParam(-0.361291f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.771678f, 0.813652f, -0.771678f, 0.276463f, -0.771678f, 0.276463f, -1.495438f, 2.365071f, -1.495438f, -0.068795f, -1.495438f, -0.068795f, 0.211578f, 0.194463f, 1.011368f, 1.011368f, -0.746288f, 0.211578f, -0.307214f, -0.822453f, 0.248627f, 0.248627f, 0.083601f, -0.307214f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.771678f), TypeParam(0.813652f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-1.495438f), TypeParam(2.365071f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(0.211578f), TypeParam(0.194463f), TypeParam(1.011368f), TypeParam(1.011368f), TypeParam(-0.746288f), TypeParam(0.211578f), TypeParam(-0.307214f), TypeParam(-0.822453f), TypeParam(0.248627f), TypeParam(0.248627f), TypeParam(0.083601f), TypeParam(-0.307214f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.185898f, 0.403325f, 0.737314f, 0.545995f, -1.010481f, -1.204522f, -0.147342f, 0.232425f, -1.339485f, 0.013892f, -1.098319f, 0.478079f, 0.051159f, -0.906061f, -0.428560f, 0.583460f, 1.137472f, 1.487881f, 1.349931f, -0.118774f, 0.436410f, 1.334689f, -1.115846f, 0.159820f, 0.617671f, 0.546630f, 1.861115f, 0.500044f, 0.623446f, 0.541840f, -0.279259f, -0.573875f, 0.783115f, -1.125017f, -1.166457f, -0.827232f, 0.273074f, 0.702953f, 1.288608f, -1.037043f, 0.021860f, 0.575628f, -0.034170f, 1.400741f, 0.508057f, 0.994702f, -2.267981f, 1.677437f, 0.175134f, 0.712679f, -0.440408f, -1.248550f, 1.618839f, -0.214598f, 0.486398f, -0.478466f, 0.912471f, 0.499651f, -0.886606f, -0.929524f, 0.449260f, 0.017969f, -0.050906f, 1.799695f, -0.033007f, -1.884108f, -1.392415f, -0.852990f, -0.052969f, 0.819434f, 0.089723f, 0.598047f}; + std::initializer_list X_data{TypeParam(-0.185898f), TypeParam(0.403325f), TypeParam(0.737314f), TypeParam(0.545995f), TypeParam(-1.010481f), TypeParam(-1.204522f), TypeParam(-0.147342f), TypeParam(0.232425f), TypeParam(-1.339485f), TypeParam(0.013892f), TypeParam(-1.098319f), TypeParam(0.478079f), TypeParam(0.051159f), TypeParam(-0.906061f), TypeParam(-0.428560f), TypeParam(0.583460f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.349931f), TypeParam(-0.118774f), TypeParam(0.436410f), TypeParam(1.334689f), 
TypeParam(-1.115846f), TypeParam(0.159820f), TypeParam(0.617671f), TypeParam(0.546630f), TypeParam(1.861115f), TypeParam(0.500044f), TypeParam(0.623446f), TypeParam(0.541840f), TypeParam(-0.279259f), TypeParam(-0.573875f), TypeParam(0.783115f), TypeParam(-1.125017f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(0.273074f), TypeParam(0.702953f), TypeParam(1.288608f), TypeParam(-1.037043f), TypeParam(0.021860f), TypeParam(0.575628f), TypeParam(-0.034170f), TypeParam(1.400741f), TypeParam(0.508057f), TypeParam(0.994702f), TypeParam(-2.267981f), TypeParam(1.677437f), TypeParam(0.175134f), TypeParam(0.712679f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-0.214598f), TypeParam(0.486398f), TypeParam(-0.478466f), TypeParam(0.912471f), TypeParam(0.499651f), TypeParam(-0.886606f), TypeParam(-0.929524f), TypeParam(0.449260f), TypeParam(0.017969f), TypeParam(-0.050906f), TypeParam(1.799695f), TypeParam(-0.033007f), TypeParam(-1.884108f), TypeParam(-1.392415f), TypeParam(-0.852990f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.598047f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.118828f, 0.082315f, 0.328488f, -0.834821f, -0.138863f, -0.988801f, -0.976128f, 0.156412f, -1.171383f, 0.319534f, -1.105438f, -0.834991f, -0.248995f, -1.145138f, 0.969159f, 0.983228f, -0.626795f, 0.251376f, 0.613890f, 0.381328f, -0.160747f, -1.131853f, 0.872567f, -1.052516f, -0.222240f, 0.074438f, -0.395210f, -0.438906f, -1.037125f, 0.066119f, -0.136254f, 1.046163f, -0.395065f, 0.927498f, 0.056808f, -0.539139f, -0.285382f, -0.136177f, 0.012430f, -0.197703f, 0.356128f, 0.988219f, 0.188620f, 0.434655f, 0.741024f, 0.258662f, 0.553165f, 0.629461f, 1.123216f, -1.095185f, 0.410630f, -0.054374f, -0.215508f, -0.462650f, 0.721441f, 1.097745f, -0.979308f, 0.648336f, 0.827460f, 0.209729f, 0.014136f, 0.923431f, 0.035578f, -0.299309f, -0.088614f, 0.385002f, 0.300407f, -0.064744f, 0.378800f, 0.323185f, -0.972071f, 0.299012f, 0.734213f, 0.137618f, -0.109532f, 0.919238f, -1.048417f, -0.547724f, -0.542389f, 1.036863f, -1.160666f, 0.119013f, -1.162427f, -0.039461f, 0.447285f, -0.280625f, 1.164882f, 0.003820f, -0.611796f, 0.309439f, 0.624077f, -0.002384f, 1.026569f, -0.759499f, 0.512014f, 0.681403f, 0.596030f, -0.000440f, 0.342557f, -0.941414f, -0.941707f, -0.074588f, -0.150400f, 0.891031f, 0.871352f, 0.813657f, -0.549640f, -0.942044f}; + std::initializer_list Grid_data{TypeParam(-0.118828f), TypeParam(0.082315f), TypeParam(0.328488f), TypeParam(-0.834821f), TypeParam(-0.138863f), TypeParam(-0.988801f), TypeParam(-0.976128f), TypeParam(0.156412f), TypeParam(-1.171383f), TypeParam(0.319534f), TypeParam(-1.105438f), TypeParam(-0.834991f), TypeParam(-0.248995f), TypeParam(-1.145138f), TypeParam(0.969159f), TypeParam(0.983228f), TypeParam(-0.626795f), TypeParam(0.251376f), TypeParam(0.613890f), TypeParam(0.381328f), TypeParam(-0.160747f), TypeParam(-1.131853f), TypeParam(0.872567f), TypeParam(-1.052516f), TypeParam(-0.222240f), TypeParam(0.074438f), TypeParam(-0.395210f), TypeParam(-0.438906f), TypeParam(-1.037125f), TypeParam(0.066119f), TypeParam(-0.136254f), TypeParam(1.046163f), TypeParam(-0.395065f), TypeParam(0.927498f), TypeParam(0.056808f), TypeParam(-0.539139f), TypeParam(-0.285382f), TypeParam(-0.136177f), TypeParam(0.012430f), TypeParam(-0.197703f), TypeParam(0.356128f), TypeParam(0.988219f), TypeParam(0.188620f), TypeParam(0.434655f), TypeParam(0.741024f), TypeParam(0.258662f), TypeParam(0.553165f), 
TypeParam(0.629461f), TypeParam(1.123216f), TypeParam(-1.095185f), TypeParam(0.410630f), TypeParam(-0.054374f), TypeParam(-0.215508f), TypeParam(-0.462650f), TypeParam(0.721441f), TypeParam(1.097745f), TypeParam(-0.979308f), TypeParam(0.648336f), TypeParam(0.827460f), TypeParam(0.209729f), TypeParam(0.014136f), TypeParam(0.923431f), TypeParam(0.035578f), TypeParam(-0.299309f), TypeParam(-0.088614f), TypeParam(0.385002f), TypeParam(0.300407f), TypeParam(-0.064744f), TypeParam(0.378800f), TypeParam(0.323185f), TypeParam(-0.972071f), TypeParam(0.299012f), TypeParam(0.734213f), TypeParam(0.137618f), TypeParam(-0.109532f), TypeParam(0.919238f), TypeParam(-1.048417f), TypeParam(-0.547724f), TypeParam(-0.542389f), TypeParam(1.036863f), TypeParam(-1.160666f), TypeParam(0.119013f), TypeParam(-1.162427f), TypeParam(-0.039461f), TypeParam(0.447285f), TypeParam(-0.280625f), TypeParam(1.164882f), TypeParam(0.003820f), TypeParam(-0.611796f), TypeParam(0.309439f), TypeParam(0.624077f), TypeParam(-0.002384f), TypeParam(1.026569f), TypeParam(-0.759499f), TypeParam(0.512014f), TypeParam(0.681403f), TypeParam(0.596030f), TypeParam(-0.000440f), TypeParam(0.342557f), TypeParam(-0.941414f), TypeParam(-0.941707f), TypeParam(-0.074588f), TypeParam(-0.150400f), TypeParam(0.891031f), TypeParam(0.871352f), TypeParam(0.813657f), TypeParam(-0.549640f), TypeParam(-0.942044f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-1.339485f, 0.737314f, 0.737314f, 0.403325f, 0.051159f, 0.232425f, 0.478079f, -1.010481f, 0.737314f, -0.147342f, -1.010481f, 0.545995f, -1.339485f, 1.137472f, 1.487881f, 1.487881f, -0.906061f, 0.737314f, 1.861115f, 0.436410f, 0.436410f, -0.118774f, -0.279259f, 0.546630f, 0.541840f, -1.115846f, 0.436410f, 0.617671f, -1.115846f, 1.334689f, 1.861115f, -1.166457f, -0.827232f, -0.827232f, -0.573875f, 0.436410f, 0.575628f, 1.677437f, 1.677437f, -0.440408f, -1.248550f, 1.400741f, 0.994702f, 0.702953f, 0.021860f, 1.400741f, -1.248550f, 1.400741f, -1.248550f, 1.618839f, -1.248550f, -0.034170f, 1.618839f, 0.702953f, -0.929524f, -1.884108f, -1.884108f, -0.052969f, 0.819434f, 0.017969f, 1.799695f, -0.478466f, -0.886606f, 0.017969f, 0.819434f, 0.017969f, 0.819434f, 0.089723f, 0.819434f, 0.449260f, 0.089723f, -0.478466f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.339485f), TypeParam(0.737314f), TypeParam(0.737314f), TypeParam(0.403325f), TypeParam(0.051159f), TypeParam(0.232425f), TypeParam(0.478079f), TypeParam(-1.010481f), TypeParam(0.737314f), TypeParam(-0.147342f), TypeParam(-1.010481f), TypeParam(0.545995f), TypeParam(-1.339485f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.487881f), TypeParam(-0.906061f), TypeParam(0.737314f), TypeParam(1.861115f), TypeParam(0.436410f), TypeParam(0.436410f), TypeParam(-0.118774f), TypeParam(-0.279259f), TypeParam(0.546630f), TypeParam(0.541840f), TypeParam(-1.115846f), TypeParam(0.436410f), TypeParam(0.617671f), TypeParam(-1.115846f), TypeParam(1.334689f), TypeParam(1.861115f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(-0.827232f), TypeParam(-0.573875f), TypeParam(0.436410f), TypeParam(0.575628f), TypeParam(1.677437f), TypeParam(1.677437f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(0.994702f), TypeParam(0.702953f), TypeParam(0.021860f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-1.248550f), TypeParam(-0.034170f), 
TypeParam(1.618839f), TypeParam(0.702953f), TypeParam(-0.929524f), TypeParam(-1.884108f), TypeParam(-1.884108f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(1.799695f), TypeParam(-0.478466f), TypeParam(-0.886606f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.819434f), TypeParam(0.449260f), TypeParam(0.089723f), TypeParam(-0.478466f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.010274f, 1.493496f, -0.264303f, 0.035897f, -0.751962f, -0.370195f, -0.514836f, 0.399928f, -0.191651f, -0.239505f, -1.931184f, -1.074773f, -0.121908f, 0.050673f, -0.741501f, -0.229127f, -0.360925f, 0.264077f, 1.537180f, 1.603202f, -1.241810f, -0.388456f, -0.609742f, 0.095097f}; + std::initializer_list X_data{TypeParam(0.010274f), TypeParam(1.493496f), TypeParam(-0.264303f), TypeParam(0.035897f), TypeParam(-0.751962f), TypeParam(-0.370195f), TypeParam(-0.514836f), TypeParam(0.399928f), TypeParam(-0.191651f), TypeParam(-0.239505f), TypeParam(-1.931184f), TypeParam(-1.074773f), TypeParam(-0.121908f), TypeParam(0.050673f), TypeParam(-0.741501f), TypeParam(-0.229127f), TypeParam(-0.360925f), TypeParam(0.264077f), TypeParam(1.537180f), TypeParam(1.603202f), TypeParam(-1.241810f), TypeParam(-0.388456f), TypeParam(-0.609742f), TypeParam(0.095097f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.118589f, -0.020968f, -0.893597f, 1.170924f, -0.517539f, 0.698168f, -0.672718f, 0.008056f, 0.410793f, -1.101817f, 0.550440f, -0.918534f, 0.167456f, -0.237959f, 0.687868f, 1.166281f, 0.270439f, -0.034265f, -0.594534f, 0.447403f, -0.577587f, 0.495680f, -0.520113f, 0.813977f}; + std::initializer_list Grid_data{TypeParam(-0.118589f), TypeParam(-0.020968f), TypeParam(-0.893597f), TypeParam(1.170924f), TypeParam(-0.517539f), TypeParam(0.698168f), TypeParam(-0.672718f), TypeParam(0.008056f), TypeParam(0.410793f), TypeParam(-1.101817f), TypeParam(0.550440f), TypeParam(-0.918534f), TypeParam(0.167456f), TypeParam(-0.237959f), TypeParam(0.687868f), TypeParam(1.166281f), TypeParam(0.270439f), TypeParam(-0.034265f), TypeParam(-0.594534f), TypeParam(0.447403f), TypeParam(-0.577587f), TypeParam(0.495680f), TypeParam(-0.520113f), TypeParam(0.813977f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.115313f, -0.606595f, -0.518616f, -0.218999f, 0.948961f, 1.063015f, -0.210622f, -1.563324f, -1.265386f, -0.212304f, 0.117155f, 0.159843f, -0.342175f, 0.138844f, -0.402196f, -0.457139f, -0.432849f, -0.286783f, -0.191760f, -0.012426f, -0.621658f, -0.799488f, -0.763820f, -0.551571f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.115313f), TypeParam(-0.606595f), TypeParam(-0.518616f), TypeParam(-0.218999f), TypeParam(0.948961f), TypeParam(1.063015f), 
TypeParam(-0.210622f), TypeParam(-1.563324f), TypeParam(-1.265386f), TypeParam(-0.212304f), TypeParam(0.117155f), TypeParam(0.159843f), TypeParam(-0.342175f), TypeParam(0.138844f), TypeParam(-0.402196f), TypeParam(-0.457139f), TypeParam(-0.432849f), TypeParam(-0.286783f), TypeParam(-0.191760f), TypeParam(-0.012426f), TypeParam(-0.621658f), TypeParam(-0.799488f), TypeParam(-0.763820f), TypeParam(-0.551571f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "zeros";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> X_data{-1.787070f, -0.894227f, -0.113069f, 0.713917f, 0.041566f, -1.847208f, 0.013441f, -1.439041f, 1.051864f, 1.576791f, 1.180527f, -1.457019f, 0.298446f, 1.142738f, -0.961347f, -0.471509f, -0.074154f, 0.047739f, -0.679950f, -2.306940f, -0.552171f, -0.357144f, -0.492247f, -0.455872f, 0.399680f, 0.057915f, -0.362704f, 1.083763f, -0.084941f, -1.691393f, -1.913178f, 0.696366f, 1.172833f, 0.901506f, -1.189840f, -1.197158f, 0.007338f, 0.161468f, -1.048452f, -0.480832f, 0.391235f, 1.056413f, -0.116648f, 0.632195f, 0.840261f, -2.187738f, 0.302910f, -0.956190f, -0.362645f, 0.771747f, 0.524840f, -0.954672f, -1.084612f, -0.525794f, -0.969691f, -1.056405f, -0.364709f, 0.336189f, -0.178281f, 1.015025f, -0.532580f, 0.036602f, -0.434395f, -1.208987f, -1.084039f, 0.642844f, -0.819208f, -0.982898f, -0.109210f, -1.231957f, 1.083089f, -0.870451f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-1.787070f), TypeParam(-0.894227f), TypeParam(-0.113069f), TypeParam(0.713917f), TypeParam(0.041566f), TypeParam(-1.847208f), TypeParam(0.013441f), TypeParam(-1.439041f), TypeParam(1.051864f), TypeParam(1.576791f), TypeParam(1.180527f), TypeParam(-1.457019f), TypeParam(0.298446f), TypeParam(1.142738f), TypeParam(-0.961347f), TypeParam(-0.471509f), TypeParam(-0.074154f), TypeParam(0.047739f), TypeParam(-0.679950f), TypeParam(-2.306940f), TypeParam(-0.552171f), TypeParam(-0.357144f), TypeParam(-0.492247f), TypeParam(-0.455872f), TypeParam(0.399680f), TypeParam(0.057915f), TypeParam(-0.362704f), TypeParam(1.083763f), TypeParam(-0.084941f), TypeParam(-1.691393f), TypeParam(-1.913178f), TypeParam(0.696366f), TypeParam(1.172833f), TypeParam(0.901506f), TypeParam(-1.189840f), TypeParam(-1.197158f), TypeParam(0.007338f), TypeParam(0.161468f), TypeParam(-1.048452f), TypeParam(-0.480832f), TypeParam(0.391235f), TypeParam(1.056413f), TypeParam(-0.116648f), TypeParam(0.632195f), TypeParam(0.840261f), TypeParam(-2.187738f), TypeParam(0.302910f), TypeParam(-0.956190f), TypeParam(-0.362645f), TypeParam(0.771747f), TypeParam(0.524840f), TypeParam(-0.954672f), TypeParam(-1.084612f), TypeParam(-0.525794f), TypeParam(-0.969691f), TypeParam(-1.056405f), TypeParam(-0.364709f), TypeParam(0.336189f), TypeParam(-0.178281f), TypeParam(1.015025f), TypeParam(-0.532580f), TypeParam(0.036602f), TypeParam(-0.434395f), TypeParam(-1.208987f), TypeParam(-1.084039f), TypeParam(0.642844f), TypeParam(-0.819208f), TypeParam(-0.982898f), TypeParam(-0.109210f), TypeParam(-1.231957f), TypeParam(1.083089f), TypeParam(-0.870451f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 3, 2, 3};
-  std::initializer_list<float> Grid_data{0.350638f, -0.554259f, 0.740901f, -1.134597f, -0.450763f, -0.706065f, -0.712365f, -0.727142f, -1.130749f, 0.205940f, -0.237380f, -1.010413f, -0.000494f, -0.199898f, 0.495032f, -0.939943f, -0.337590f, 0.247001f, 0.508664f, 0.090780f, 0.325198f, 1.199561f, -0.415694f, 0.817854f, 1.033666f, -1.061540f, 0.290273f, 0.679739f, -0.187185f, 0.662278f, 0.040817f, 0.913540f, 0.025838f, -0.768267f, 0.911326f, 0.356885f, 1.020923f, 0.297892f, 0.637209f, 0.748214f, 0.202064f, -0.278959f, 0.247841f, -0.836700f, 0.040996f, -0.385697f, 0.075869f, -0.950110f, 0.733227f, -1.107135f, 0.513890f, 0.790272f, -1.099795f, 1.084212f, -0.892061f, -0.235640f, 0.621837f, -0.380523f, 1.069422f, -0.529383f, -0.160661f, -0.784422f, -0.556715f, 1.171015f, 0.902476f, 0.088357f, 0.098667f, -1.018314f, 0.905937f, -0.179914f, -0.500513f, -0.954987f, 0.986618f, 0.569025f, 0.722795f, 0.124254f, -0.814285f, 0.491561f, 0.138395f, 0.402690f, -0.298810f, -0.566298f, 0.985118f, 0.402260f, -0.487031f, 0.107159f, -0.260850f, -0.102620f, 0.672911f, -0.955102f, 1.086040f, 0.807667f, 0.001031f, -0.490841f, 0.244670f, -0.794290f, 0.779461f, -0.634633f, 0.229290f, -1.180597f, 0.574650f, 0.812338f, 0.900697f, 0.097950f, 0.708525f, 0.409153f, 0.804739f, 0.677169f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.350638f), TypeParam(-0.554259f), TypeParam(0.740901f), TypeParam(-1.134597f), TypeParam(-0.450763f), TypeParam(-0.706065f), TypeParam(-0.712365f), TypeParam(-0.727142f), TypeParam(-1.130749f), TypeParam(0.205940f), TypeParam(-0.237380f), TypeParam(-1.010413f), TypeParam(-0.000494f), TypeParam(-0.199898f), TypeParam(0.495032f), TypeParam(-0.939943f), TypeParam(-0.337590f), TypeParam(0.247001f), TypeParam(0.508664f), TypeParam(0.090780f), TypeParam(0.325198f), TypeParam(1.199561f), TypeParam(-0.415694f), TypeParam(0.817854f), TypeParam(1.033666f), TypeParam(-1.061540f), TypeParam(0.290273f), TypeParam(0.679739f), TypeParam(-0.187185f), TypeParam(0.662278f), TypeParam(0.040817f), TypeParam(0.913540f), TypeParam(0.025838f), TypeParam(-0.768267f), TypeParam(0.911326f), TypeParam(0.356885f), TypeParam(1.020923f), TypeParam(0.297892f), TypeParam(0.637209f), TypeParam(0.748214f), TypeParam(0.202064f), TypeParam(-0.278959f), TypeParam(0.247841f), TypeParam(-0.836700f), TypeParam(0.040996f), TypeParam(-0.385697f), TypeParam(0.075869f), TypeParam(-0.950110f), TypeParam(0.733227f), TypeParam(-1.107135f), TypeParam(0.513890f), TypeParam(0.790272f), TypeParam(-1.099795f), TypeParam(1.084212f), TypeParam(-0.892061f), TypeParam(-0.235640f), TypeParam(0.621837f), TypeParam(-0.380523f), TypeParam(1.069422f), TypeParam(-0.529383f), TypeParam(-0.160661f), TypeParam(-0.784422f), TypeParam(-0.556715f), TypeParam(1.171015f), TypeParam(0.902476f), TypeParam(0.088357f), TypeParam(0.098667f), TypeParam(-1.018314f), TypeParam(0.905937f), TypeParam(-0.179914f), TypeParam(-0.500513f), TypeParam(-0.954987f), TypeParam(0.986618f), TypeParam(0.569025f), TypeParam(0.722795f), TypeParam(0.124254f), TypeParam(-0.814285f), TypeParam(0.491561f), TypeParam(0.138395f), TypeParam(0.402690f), TypeParam(-0.298810f), TypeParam(-0.566298f), TypeParam(0.985118f), TypeParam(0.402260f), TypeParam(-0.487031f), TypeParam(0.107159f), TypeParam(-0.260850f), TypeParam(-0.102620f), TypeParam(0.672911f), TypeParam(-0.955102f), TypeParam(1.086040f), TypeParam(0.807667f), TypeParam(0.001031f), TypeParam(-0.490841f), TypeParam(0.244670f), TypeParam(-0.794290f), TypeParam(0.779461f), TypeParam(-0.634633f), TypeParam(0.229290f), TypeParam(-1.180597f), TypeParam(0.574650f), TypeParam(0.812338f), TypeParam(0.900697f), TypeParam(0.097950f), TypeParam(0.708525f), TypeParam(0.409153f), TypeParam(0.804739f), TypeParam(0.677169f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> Y_data{0.171946f, -0.411342f, -1.046998f, -0.002345f, 0.246533f, 0.396970f, 0.664278f, 0.199883f, -0.636287f, 0.162358f, -0.061161f, 0.528084f, 0.041846f, 0.750291f, -0.476442f, 0.142258f, -0.067844f, 0.869081f, 0.360025f, -0.406785f, -0.701985f, -0.718142f, 0.519179f, -0.022693f, 0.618451f, 0.708731f, 0.224429f, 0.784241f, -0.812606f, -0.521137f, 0.266524f, 0.190886f, 0.231077f, -0.465330f, 0.204730f, 0.348489f, 0.356190f, 0.256096f, -0.038212f, -0.943162f, 0.258902f, -0.360112f, -0.920536f, 0.126677f, -0.523600f, -0.361337f, -0.154168f, 0.179761f, -1.141155f, -0.423488f, -0.225410f, -0.204886f, -1.162816f, -0.678226f, -0.384409f, -0.146245f, -0.622531f, 0.312188f, -0.828836f, -0.541017f, -0.778291f, -0.602484f, -0.328754f, -0.163964f, -0.508068f, 0.193021f, 0.273133f, -0.217934f, -0.562420f, 0.287725f, -1.097279f, -0.306201f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.171946f), TypeParam(-0.411342f), TypeParam(-1.046998f), TypeParam(-0.002345f), TypeParam(0.246533f), TypeParam(0.396970f), TypeParam(0.664278f), TypeParam(0.199883f), TypeParam(-0.636287f), TypeParam(0.162358f), TypeParam(-0.061161f), TypeParam(0.528084f), TypeParam(0.041846f), TypeParam(0.750291f), TypeParam(-0.476442f), TypeParam(0.142258f), TypeParam(-0.067844f), TypeParam(0.869081f), TypeParam(0.360025f), TypeParam(-0.406785f), TypeParam(-0.701985f), TypeParam(-0.718142f), TypeParam(0.519179f), TypeParam(-0.022693f), TypeParam(0.618451f), TypeParam(0.708731f), TypeParam(0.224429f), TypeParam(0.784241f), TypeParam(-0.812606f), TypeParam(-0.521137f), TypeParam(0.266524f), TypeParam(0.190886f), TypeParam(0.231077f), TypeParam(-0.465330f), TypeParam(0.204730f), TypeParam(0.348489f), TypeParam(0.356190f), TypeParam(0.256096f), TypeParam(-0.038212f), TypeParam(-0.943162f), TypeParam(0.258902f), TypeParam(-0.360112f), TypeParam(-0.920536f), TypeParam(0.126677f), TypeParam(-0.523600f), TypeParam(-0.361337f), TypeParam(-0.154168f), TypeParam(0.179761f), TypeParam(-1.141155f), TypeParam(-0.423488f), TypeParam(-0.225410f), TypeParam(-0.204886f), TypeParam(-1.162816f), TypeParam(-0.678226f), TypeParam(-0.384409f), TypeParam(-0.146245f), TypeParam(-0.622531f), TypeParam(0.312188f), TypeParam(-0.828836f), TypeParam(-0.541017f), TypeParam(-0.778291f), TypeParam(-0.602484f), TypeParam(-0.328754f), TypeParam(-0.163964f), TypeParam(-0.508068f), TypeParam(0.193021f), TypeParam(0.273133f), TypeParam(-0.217934f), TypeParam(-0.562420f), TypeParam(0.287725f), TypeParam(-1.097279f), TypeParam(-0.306201f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{0.185965f, 0.133937f, -0.763030f, 0.733342f, 1.932445f, -0.582571f, -1.312078f, 0.738952f, 0.444459f, 0.742593f, -0.805960f, -0.202535f, 0.970323f, -0.801176f, 0.277655f, -1.938051f, -1.879800f, 0.287116f, 0.261958f, -0.358247f, -0.107750f, 0.748162f, -0.742330f, 0.344665f};
+  std::initializer_list<TypeParam> X_data{TypeParam(0.185965f), TypeParam(0.133937f), TypeParam(-0.763030f), TypeParam(0.733342f), TypeParam(1.932445f), TypeParam(-0.582571f), TypeParam(-1.312078f), TypeParam(0.738952f), TypeParam(0.444459f), TypeParam(0.742593f), TypeParam(-0.805960f), TypeParam(-0.202535f), TypeParam(0.970323f), TypeParam(-0.801176f), TypeParam(0.277655f), TypeParam(-1.938051f), TypeParam(-1.879800f), TypeParam(0.287116f), TypeParam(0.261958f), TypeParam(-0.358247f), TypeParam(-0.107750f), TypeParam(0.748162f), TypeParam(-0.742330f), TypeParam(0.344665f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{-0.460252f, 0.734353f, -1.069308f, 1.005361f, 1.198595f, -0.327629f, 0.474026f, 1.196645f, 0.361782f, 0.469280f, 0.440632f, -0.490951f, 0.292918f, -0.639568f, 1.024697f, -0.514217f, 0.274326f, -0.347614f, 0.600117f, 0.019780f, 0.659824f, -0.324940f, -0.704174f, 0.460072f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(-0.460252f), TypeParam(0.734353f), TypeParam(-1.069308f), TypeParam(1.005361f), TypeParam(1.198595f), TypeParam(-0.327629f), TypeParam(0.474026f), TypeParam(1.196645f), TypeParam(0.361782f), TypeParam(0.469280f), TypeParam(0.440632f), TypeParam(-0.490951f), TypeParam(0.292918f), TypeParam(-0.639568f), TypeParam(1.024697f), TypeParam(-0.514217f), TypeParam(0.274326f), TypeParam(-0.347614f), TypeParam(0.600117f), TypeParam(0.019780f), TypeParam(0.659824f), TypeParam(-0.324940f), TypeParam(-0.704174f), TypeParam(0.460072f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{1.646426f, 0.409452f, 0.132247f, -0.106052f, -0.009495f, 0.270785f, -0.702581f, -0.170769f, 0.223282f, -0.044740f, 0.006388f, 0.645576f, -0.476802f, -0.504368f, -0.897503f, -1.684608f, -1.162742f, -0.963921f, -0.197266f, -0.050021f, 0.151796f, 0.662485f, 0.175502f, -0.434265f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(1.646426f), TypeParam(0.409452f), TypeParam(0.132247f), TypeParam(-0.106052f), TypeParam(-0.009495f), TypeParam(0.270785f), TypeParam(-0.702581f), TypeParam(-0.170769f), TypeParam(0.223282f), TypeParam(-0.044740f), TypeParam(0.006388f), TypeParam(0.645576f), TypeParam(-0.476802f), TypeParam(-0.504368f), TypeParam(-0.897503f), TypeParam(-1.684608f), TypeParam(-1.162742f), TypeParam(-0.963921f), TypeParam(-0.197266f), TypeParam(-0.050021f), TypeParam(0.151796f), TypeParam(0.662485f), TypeParam(0.175502f), TypeParam(-0.434265f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "zeros";
   int64_t align_corners = 0;
   std::initializer_list<int64_t>
X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.299262f, -0.304887f, 0.906636f, -0.392850f, -0.050410f, 0.548199f, -1.235108f, -0.475848f, 0.635455f, 0.307462f, -1.241370f, -0.538672f, 0.863466f, 0.799983f, -0.090064f, -0.751721f, 0.956040f, -0.117709f, -2.183699f, -0.484444f, 1.105900f, 0.164466f, 0.720736f, 0.168044f, -0.656400f, 1.770106f, -0.544832f, 1.358424f, 0.981648f, -1.759268f, -0.526924f, 1.322339f, 0.148774f, 0.321413f, -1.257438f, -0.383775f, -2.117908f, -0.077921f, -0.197889f, 0.555813f, -1.517724f, 1.419652f, -0.891774f, 1.684663f, -1.524669f, -2.055758f, -0.299843f, -0.644860f, 0.428609f, -1.704372f, 1.257671f, -0.886508f, -0.029344f, -1.718824f, -0.294273f, 1.537690f, -1.366837f, -1.610098f, 0.650240f, -0.288219f, 0.837292f, 0.431683f, -0.405852f, 0.492271f, 0.416507f, 0.971658f, -0.183526f, 0.615709f, -0.081615f, 1.160796f, 1.431487f, 0.485687f}; + std::initializer_list X_data{TypeParam(-0.299262f), TypeParam(-0.304887f), TypeParam(0.906636f), TypeParam(-0.392850f), TypeParam(-0.050410f), TypeParam(0.548199f), TypeParam(-1.235108f), TypeParam(-0.475848f), TypeParam(0.635455f), TypeParam(0.307462f), TypeParam(-1.241370f), TypeParam(-0.538672f), TypeParam(0.863466f), TypeParam(0.799983f), TypeParam(-0.090064f), TypeParam(-0.751721f), TypeParam(0.956040f), TypeParam(-0.117709f), TypeParam(-2.183699f), TypeParam(-0.484444f), TypeParam(1.105900f), TypeParam(0.164466f), TypeParam(0.720736f), TypeParam(0.168044f), TypeParam(-0.656400f), TypeParam(1.770106f), TypeParam(-0.544832f), TypeParam(1.358424f), TypeParam(0.981648f), TypeParam(-1.759268f), TypeParam(-0.526924f), TypeParam(1.322339f), TypeParam(0.148774f), TypeParam(0.321413f), TypeParam(-1.257438f), TypeParam(-0.383775f), TypeParam(-2.117908f), TypeParam(-0.077921f), TypeParam(-0.197889f), TypeParam(0.555813f), TypeParam(-1.517724f), TypeParam(1.419652f), TypeParam(-0.891774f), TypeParam(1.684663f), TypeParam(-1.524669f), TypeParam(-2.055758f), TypeParam(-0.299843f), TypeParam(-0.644860f), TypeParam(0.428609f), TypeParam(-1.704372f), TypeParam(1.257671f), TypeParam(-0.886508f), TypeParam(-0.029344f), TypeParam(-1.718824f), TypeParam(-0.294273f), TypeParam(1.537690f), TypeParam(-1.366837f), TypeParam(-1.610098f), TypeParam(0.650240f), TypeParam(-0.288219f), TypeParam(0.837292f), TypeParam(0.431683f), TypeParam(-0.405852f), TypeParam(0.492271f), TypeParam(0.416507f), TypeParam(0.971658f), TypeParam(-0.183526f), TypeParam(0.615709f), TypeParam(-0.081615f), TypeParam(1.160796f), TypeParam(1.431487f), TypeParam(0.485687f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.884040f, -0.825214f, 0.496720f, -0.440955f, 1.195811f, 0.169268f, -1.042100f, 0.206524f, 0.145895f, -1.160650f, 0.240829f, 1.144915f, 0.345332f, -0.006382f, -0.248763f, 0.318888f, -0.534619f, 1.181719f, 1.037350f, 0.560600f, -0.446974f, -1.126746f, -0.690807f, 1.166754f, -1.101454f, -1.145775f, -0.086488f, 0.381780f, -1.194351f, -1.114106f, 0.006524f, -0.402521f, 0.836016f, 0.344533f, -1.041627f, -1.081571f, 0.824102f, -0.212785f, -0.524949f, 0.377977f, -0.235842f, 0.573897f, 0.304308f, -0.519568f, -0.961787f, 0.649611f, -0.720973f, -0.132725f, 0.164074f, -0.698360f, 0.653669f, -0.844065f, 0.294728f, 0.128341f, 0.440293f, -1.177701f, 0.069319f, 0.585007f, -0.768260f, 0.296941f, 0.004702f, 1.018020f, -0.254096f, 0.008198f, -0.521925f, -0.295744f, 0.343532f, -1.157334f, 0.910329f, 0.862921f, 0.508195f, 0.898317f, -0.373544f, 0.273330f, 0.061050f, -0.829794f, -0.461335f, -0.426012f, -0.296704f, -1.065526f, 
-0.843948f, -0.113955f, -0.182548f, -1.089296f, 0.256401f, 0.653393f, 0.999377f, 1.009925f, -0.838519f, -0.384579f, -0.569276f, 0.220093f, 0.321562f, 0.266984f, 0.701244f, 0.633093f, -0.644096f, 0.823778f, 0.809482f, 0.158802f, -1.044029f, -0.735991f, 0.334411f, 0.414891f, 1.118940f, 0.610743f, 0.434932f, -0.040928f}; + std::initializer_list Grid_data{TypeParam(0.884040f), TypeParam(-0.825214f), TypeParam(0.496720f), TypeParam(-0.440955f), TypeParam(1.195811f), TypeParam(0.169268f), TypeParam(-1.042100f), TypeParam(0.206524f), TypeParam(0.145895f), TypeParam(-1.160650f), TypeParam(0.240829f), TypeParam(1.144915f), TypeParam(0.345332f), TypeParam(-0.006382f), TypeParam(-0.248763f), TypeParam(0.318888f), TypeParam(-0.534619f), TypeParam(1.181719f), TypeParam(1.037350f), TypeParam(0.560600f), TypeParam(-0.446974f), TypeParam(-1.126746f), TypeParam(-0.690807f), TypeParam(1.166754f), TypeParam(-1.101454f), TypeParam(-1.145775f), TypeParam(-0.086488f), TypeParam(0.381780f), TypeParam(-1.194351f), TypeParam(-1.114106f), TypeParam(0.006524f), TypeParam(-0.402521f), TypeParam(0.836016f), TypeParam(0.344533f), TypeParam(-1.041627f), TypeParam(-1.081571f), TypeParam(0.824102f), TypeParam(-0.212785f), TypeParam(-0.524949f), TypeParam(0.377977f), TypeParam(-0.235842f), TypeParam(0.573897f), TypeParam(0.304308f), TypeParam(-0.519568f), TypeParam(-0.961787f), TypeParam(0.649611f), TypeParam(-0.720973f), TypeParam(-0.132725f), TypeParam(0.164074f), TypeParam(-0.698360f), TypeParam(0.653669f), TypeParam(-0.844065f), TypeParam(0.294728f), TypeParam(0.128341f), TypeParam(0.440293f), TypeParam(-1.177701f), TypeParam(0.069319f), TypeParam(0.585007f), TypeParam(-0.768260f), TypeParam(0.296941f), TypeParam(0.004702f), TypeParam(1.018020f), TypeParam(-0.254096f), TypeParam(0.008198f), TypeParam(-0.521925f), TypeParam(-0.295744f), TypeParam(0.343532f), TypeParam(-1.157334f), TypeParam(0.910329f), TypeParam(0.862921f), TypeParam(0.508195f), TypeParam(0.898317f), TypeParam(-0.373544f), TypeParam(0.273330f), TypeParam(0.061050f), TypeParam(-0.829794f), TypeParam(-0.461335f), TypeParam(-0.426012f), TypeParam(-0.296704f), TypeParam(-1.065526f), TypeParam(-0.843948f), TypeParam(-0.113955f), TypeParam(-0.182548f), TypeParam(-1.089296f), TypeParam(0.256401f), TypeParam(0.653393f), TypeParam(0.999377f), TypeParam(1.009925f), TypeParam(-0.838519f), TypeParam(-0.384579f), TypeParam(-0.569276f), TypeParam(0.220093f), TypeParam(0.321562f), TypeParam(0.266984f), TypeParam(0.701244f), TypeParam(0.633093f), TypeParam(-0.644096f), TypeParam(0.823778f), TypeParam(0.809482f), TypeParam(0.158802f), TypeParam(-1.044029f), TypeParam(-0.735991f), TypeParam(0.334411f), TypeParam(0.414891f), TypeParam(1.118940f), TypeParam(0.610743f), TypeParam(0.434932f), TypeParam(-0.040928f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.222880f, -0.137918f, 0.042779f, 0.027606f, 0.146833f, 0.119531f, 0.062001f, 0.077615f, -0.124874f, -0.020856f, 0.248748f, -0.050235f, -0.185885f, -0.124030f, -0.148987f, -0.345107f, 0.753440f, -0.055873f, 0.674388f, 0.063018f, -0.054480f, -0.034452f, 0.780917f, 0.193151f, -0.140647f, -0.047364f, -0.095816f, -0.046983f, 0.254384f, -0.123703f, 0.191358f, 0.674903f, -0.311971f, 1.032054f, 0.672506f, 0.009147f, 0.281933f, 0.135835f, -0.145082f, -0.392560f, -0.229593f, -0.632284f, -0.936929f, -0.916689f, -0.502247f, -0.108609f, -0.645451f, 0.242939f, -0.165902f, -1.220095f, -0.015084f, -0.300940f, -0.352557f, -0.886474f, 0.109150f, 0.398365f, 0.235757f, 0.358618f, 0.082189f, 0.268617f, 
0.077955f, -0.157573f, 0.023048f, -0.346908f, 0.360128f, 0.389098f, 0.122882f, 0.675956f, 0.735857f, 0.354858f, 0.244544f, 0.631102f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.222880f), TypeParam(-0.137918f), TypeParam(0.042779f), TypeParam(0.027606f), TypeParam(0.146833f), TypeParam(0.119531f), TypeParam(0.062001f), TypeParam(0.077615f), TypeParam(-0.124874f), TypeParam(-0.020856f), TypeParam(0.248748f), TypeParam(-0.050235f), TypeParam(-0.185885f), TypeParam(-0.124030f), TypeParam(-0.148987f), TypeParam(-0.345107f), TypeParam(0.753440f), TypeParam(-0.055873f), TypeParam(0.674388f), TypeParam(0.063018f), TypeParam(-0.054480f), TypeParam(-0.034452f), TypeParam(0.780917f), TypeParam(0.193151f), TypeParam(-0.140647f), TypeParam(-0.047364f), TypeParam(-0.095816f), TypeParam(-0.046983f), TypeParam(0.254384f), TypeParam(-0.123703f), TypeParam(0.191358f), TypeParam(0.674903f), TypeParam(-0.311971f), TypeParam(1.032054f), TypeParam(0.672506f), TypeParam(0.009147f), TypeParam(0.281933f), TypeParam(0.135835f), TypeParam(-0.145082f), TypeParam(-0.392560f), TypeParam(-0.229593f), TypeParam(-0.632284f), TypeParam(-0.936929f), TypeParam(-0.916689f), TypeParam(-0.502247f), TypeParam(-0.108609f), TypeParam(-0.645451f), TypeParam(0.242939f), TypeParam(-0.165902f), TypeParam(-1.220095f), TypeParam(-0.015084f), TypeParam(-0.300940f), TypeParam(-0.352557f), TypeParam(-0.886474f), TypeParam(0.109150f), TypeParam(0.398365f), TypeParam(0.235757f), TypeParam(0.358618f), TypeParam(0.082189f), TypeParam(0.268617f), TypeParam(0.077955f), TypeParam(-0.157573f), TypeParam(0.023048f), TypeParam(-0.346908f), TypeParam(0.360128f), TypeParam(0.389098f), TypeParam(0.122882f), TypeParam(0.675956f), TypeParam(0.735857f), TypeParam(0.354858f), TypeParam(0.244544f), TypeParam(0.631102f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.916003f, 0.150784f, -0.179898f, 0.402727f, -0.549764f, 1.772484f, 1.014343f, 0.502823f, 0.976771f, -0.071957f, 0.519875f, 0.408665f, 1.435640f, -0.807775f, -0.181661f, -0.574026f, -0.335351f, -0.155602f, 0.348749f, 1.055618f, 0.737784f, -0.394725f, 0.597608f, 0.006105f}; + std::initializer_list X_data{TypeParam(-1.916003f), TypeParam(0.150784f), TypeParam(-0.179898f), TypeParam(0.402727f), TypeParam(-0.549764f), TypeParam(1.772484f), TypeParam(1.014343f), TypeParam(0.502823f), TypeParam(0.976771f), TypeParam(-0.071957f), TypeParam(0.519875f), TypeParam(0.408665f), TypeParam(1.435640f), TypeParam(-0.807775f), TypeParam(-0.181661f), TypeParam(-0.574026f), TypeParam(-0.335351f), TypeParam(-0.155602f), TypeParam(0.348749f), TypeParam(1.055618f), TypeParam(0.737784f), TypeParam(-0.394725f), TypeParam(0.597608f), TypeParam(0.006105f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.189838f, -1.050410f, -1.072351f, -0.930754f, 
-0.502573f, 0.186642f, -0.564332f, -0.042774f, -0.143740f, 1.097448f, -0.547044f, 1.127440f, -0.921224f, -1.001202f, 0.390232f, -0.698394f, 0.615509f, -0.663897f, 0.944958f, 1.161950f, 0.076823f, 0.256464f, 1.118784f, 0.711380f}; + std::initializer_list Grid_data{TypeParam(-0.189838f), TypeParam(-1.050410f), TypeParam(-1.072351f), TypeParam(-0.930754f), TypeParam(-0.502573f), TypeParam(0.186642f), TypeParam(-0.564332f), TypeParam(-0.042774f), TypeParam(-0.143740f), TypeParam(1.097448f), TypeParam(-0.547044f), TypeParam(1.127440f), TypeParam(-0.921224f), TypeParam(-1.001202f), TypeParam(0.390232f), TypeParam(-0.698394f), TypeParam(0.615509f), TypeParam(-0.663897f), TypeParam(0.944958f), TypeParam(1.161950f), TypeParam(0.076823f), TypeParam(0.256464f), TypeParam(1.118784f), TypeParam(0.711380f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.078787f, -1.795786f, -0.023270f, -0.113413f, 0.444460f, -0.023826f, 0.807136f, 1.011742f, 0.674182f, 0.754935f, 0.472262f, 0.494688f, 1.347277f, -0.223507f, -0.417529f, -0.160549f, -0.353331f, -0.276367f, 0.376591f, 0.571813f, 0.551111f, 0.022384f, 0.166782f, -0.109583f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.078787f), TypeParam(-1.795786f), TypeParam(-0.023270f), TypeParam(-0.113413f), TypeParam(0.444460f), TypeParam(-0.023826f), TypeParam(0.807136f), TypeParam(1.011742f), TypeParam(0.674182f), TypeParam(0.754935f), TypeParam(0.472262f), TypeParam(0.494688f), TypeParam(1.347277f), TypeParam(-0.223507f), TypeParam(-0.417529f), TypeParam(-0.160549f), TypeParam(-0.353331f), TypeParam(-0.276367f), TypeParam(0.376591f), TypeParam(0.571813f), TypeParam(0.551111f), TypeParam(0.022384f), TypeParam(0.166782f), TypeParam(-0.109583f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.332555f, 0.980958f, 0.002632f, -1.976749f, 0.979548f, 1.109773f, -0.534887f, 0.705692f, -0.143637f, -0.600830f, 0.315853f, -0.604687f, -0.300652f, -0.375240f, 0.377196f, -0.140920f, 1.159946f, 2.364598f, 0.320719f, 0.397938f, -0.680097f, -1.201632f, 0.270077f, -0.036712f, -0.972864f, 0.792393f, -1.159168f, -0.016679f, -0.665027f, 0.809646f, -1.684452f, 0.049476f, 0.065748f, 0.279619f, -1.079668f, 0.301309f, 1.010100f, -0.119015f, -0.104838f, 0.916627f, -0.522838f, 0.485269f, -1.221088f, 2.044754f, -0.669823f, 0.128370f, 0.080480f, 0.372679f, -0.046427f, -0.732652f, -0.395790f, 0.012594f, -0.170518f, -0.706783f, -0.862588f, -1.177275f, -1.165262f, 0.914826f, -0.661128f, -0.386656f, -0.599246f, 0.544643f, 0.930679f, -1.146137f, 0.212913f, -0.022433f, 1.692830f, 0.187511f, -0.631569f, -0.311540f, -0.885167f, -0.429959f}; + std::initializer_list X_data{TypeParam(-0.332555f), TypeParam(0.980958f), TypeParam(0.002632f), TypeParam(-1.976749f), TypeParam(0.979548f), TypeParam(1.109773f), TypeParam(-0.534887f), 
TypeParam(0.705692f), TypeParam(-0.143637f), TypeParam(-0.600830f), TypeParam(0.315853f), TypeParam(-0.604687f), TypeParam(-0.300652f), TypeParam(-0.375240f), TypeParam(0.377196f), TypeParam(-0.140920f), TypeParam(1.159946f), TypeParam(2.364598f), TypeParam(0.320719f), TypeParam(0.397938f), TypeParam(-0.680097f), TypeParam(-1.201632f), TypeParam(0.270077f), TypeParam(-0.036712f), TypeParam(-0.972864f), TypeParam(0.792393f), TypeParam(-1.159168f), TypeParam(-0.016679f), TypeParam(-0.665027f), TypeParam(0.809646f), TypeParam(-1.684452f), TypeParam(0.049476f), TypeParam(0.065748f), TypeParam(0.279619f), TypeParam(-1.079668f), TypeParam(0.301309f), TypeParam(1.010100f), TypeParam(-0.119015f), TypeParam(-0.104838f), TypeParam(0.916627f), TypeParam(-0.522838f), TypeParam(0.485269f), TypeParam(-1.221088f), TypeParam(2.044754f), TypeParam(-0.669823f), TypeParam(0.128370f), TypeParam(0.080480f), TypeParam(0.372679f), TypeParam(-0.046427f), TypeParam(-0.732652f), TypeParam(-0.395790f), TypeParam(0.012594f), TypeParam(-0.170518f), TypeParam(-0.706783f), TypeParam(-0.862588f), TypeParam(-1.177275f), TypeParam(-1.165262f), TypeParam(0.914826f), TypeParam(-0.661128f), TypeParam(-0.386656f), TypeParam(-0.599246f), TypeParam(0.544643f), TypeParam(0.930679f), TypeParam(-1.146137f), TypeParam(0.212913f), TypeParam(-0.022433f), TypeParam(1.692830f), TypeParam(0.187511f), TypeParam(-0.631569f), TypeParam(-0.311540f), TypeParam(-0.885167f), TypeParam(-0.429959f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.453992f, 0.394222f, 0.755023f, -0.025610f, 0.658840f, 0.982105f, -0.642922f, -0.265292f, -1.080379f, 0.275464f, 0.855228f, -0.233029f, 0.191483f, 0.383441f, -0.025595f, 0.932929f, 0.174866f, -1.179535f, -0.990943f, -1.188918f, 0.049460f, 0.648682f, -0.158317f, 1.078936f, -0.215883f, 0.245340f, 1.082089f, 0.607310f, -0.038283f, 1.155868f, -0.716957f, 0.446971f, 0.757844f, -0.743030f, -1.127212f, 0.383835f, -0.455267f, -0.605570f, 0.238686f, -0.870514f, 1.079285f, -0.107719f, -0.384303f, 1.003178f, 0.334130f, 0.228627f, -0.573757f, 1.143690f, -0.365482f, 0.998076f, -0.088210f, 0.601965f, 0.843747f, -0.893403f, -0.799804f, -1.186625f, 0.865515f, 1.031983f, -0.438564f, -0.587735f, 0.200868f, 0.646055f, 0.296203f, -0.250092f, -0.763290f, 1.026321f, -0.777136f, -1.159559f, -0.479127f, 0.239290f, 0.446029f, 0.464001f, -0.695158f, -0.460548f, -0.533616f, -0.581111f, -1.010728f, 0.245640f, -0.348981f, -1.155007f, -0.700701f, -0.720655f, -0.517635f, -0.741485f, -0.208103f, 0.430035f, -0.971177f, -0.102798f, -0.345348f, -0.613510f, -0.266458f, -0.508597f, 0.038577f, -0.866220f, 0.227567f, 1.101759f, 0.994334f, -0.538031f, 0.369874f, -1.134245f, 1.010332f, -1.195878f, -1.072351f, -1.077155f, -1.114385f, 0.162516f, -0.317319f, 0.287217f}; + std::initializer_list Grid_data{TypeParam(-0.453992f), TypeParam(0.394222f), TypeParam(0.755023f), TypeParam(-0.025610f), TypeParam(0.658840f), TypeParam(0.982105f), TypeParam(-0.642922f), TypeParam(-0.265292f), TypeParam(-1.080379f), TypeParam(0.275464f), TypeParam(0.855228f), TypeParam(-0.233029f), TypeParam(0.191483f), TypeParam(0.383441f), TypeParam(-0.025595f), TypeParam(0.932929f), TypeParam(0.174866f), TypeParam(-1.179535f), TypeParam(-0.990943f), TypeParam(-1.188918f), TypeParam(0.049460f), TypeParam(0.648682f), TypeParam(-0.158317f), TypeParam(1.078936f), TypeParam(-0.215883f), TypeParam(0.245340f), TypeParam(1.082089f), TypeParam(0.607310f), TypeParam(-0.038283f), TypeParam(1.155868f), TypeParam(-0.716957f), TypeParam(0.446971f), 
TypeParam(0.757844f), TypeParam(-0.743030f), TypeParam(-1.127212f), TypeParam(0.383835f), TypeParam(-0.455267f), TypeParam(-0.605570f), TypeParam(0.238686f), TypeParam(-0.870514f), TypeParam(1.079285f), TypeParam(-0.107719f), TypeParam(-0.384303f), TypeParam(1.003178f), TypeParam(0.334130f), TypeParam(0.228627f), TypeParam(-0.573757f), TypeParam(1.143690f), TypeParam(-0.365482f), TypeParam(0.998076f), TypeParam(-0.088210f), TypeParam(0.601965f), TypeParam(0.843747f), TypeParam(-0.893403f), TypeParam(-0.799804f), TypeParam(-1.186625f), TypeParam(0.865515f), TypeParam(1.031983f), TypeParam(-0.438564f), TypeParam(-0.587735f), TypeParam(0.200868f), TypeParam(0.646055f), TypeParam(0.296203f), TypeParam(-0.250092f), TypeParam(-0.763290f), TypeParam(1.026321f), TypeParam(-0.777136f), TypeParam(-1.159559f), TypeParam(-0.479127f), TypeParam(0.239290f), TypeParam(0.446029f), TypeParam(0.464001f), TypeParam(-0.695158f), TypeParam(-0.460548f), TypeParam(-0.533616f), TypeParam(-0.581111f), TypeParam(-1.010728f), TypeParam(0.245640f), TypeParam(-0.348981f), TypeParam(-1.155007f), TypeParam(-0.700701f), TypeParam(-0.720655f), TypeParam(-0.517635f), TypeParam(-0.741485f), TypeParam(-0.208103f), TypeParam(0.430035f), TypeParam(-0.971177f), TypeParam(-0.102798f), TypeParam(-0.345348f), TypeParam(-0.613510f), TypeParam(-0.266458f), TypeParam(-0.508597f), TypeParam(0.038577f), TypeParam(-0.866220f), TypeParam(0.227567f), TypeParam(1.101759f), TypeParam(0.994334f), TypeParam(-0.538031f), TypeParam(0.369874f), TypeParam(-1.134245f), TypeParam(1.010332f), TypeParam(-1.195878f), TypeParam(-1.072351f), TypeParam(-1.077155f), TypeParam(-1.114385f), TypeParam(0.162516f), TypeParam(-0.317319f), TypeParam(0.287217f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.517362f, 1.168304f, -0.283719f, -0.056944f, -0.345007f, -1.383013f, -0.517978f, -0.099340f, 0.531814f, -0.051495f, 0.570203f, -0.350444f, -0.195512f, 0.335075f, 0.533103f, -0.173681f, 0.110927f, 0.549661f, -0.303447f, -0.209369f, -0.479343f, 0.113517f, -0.222508f, -0.981697f, -1.000072f, 0.163343f, -0.019158f, 0.217390f, -0.442252f, -1.020732f, -0.645033f, -0.481248f, -0.359233f, -0.271288f, -0.165768f, -0.092544f, -0.219889f, 0.671201f, -0.041137f, -0.289275f, -0.022793f, -0.130253f, -0.072692f, -0.451858f, 0.402947f, 0.168711f, 0.110811f, 0.202315f, -0.200036f, -0.331588f, 0.583341f, -0.522838f, 1.010100f, -0.018650f, 1.269564f, -0.168394f, -0.209390f, 0.740205f, -0.675828f, -0.325915f, -0.404694f, 0.067064f, -0.744102f, -0.639736f, -0.416580f, -0.317643f, 0.004590f, -0.665815f, -0.163600f, -0.661128f, -0.862588f, -0.132515f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.517362f), TypeParam(1.168304f), TypeParam(-0.283719f), TypeParam(-0.056944f), TypeParam(-0.345007f), TypeParam(-1.383013f), TypeParam(-0.517978f), TypeParam(-0.099340f), TypeParam(0.531814f), TypeParam(-0.051495f), TypeParam(0.570203f), TypeParam(-0.350444f), TypeParam(-0.195512f), TypeParam(0.335075f), TypeParam(0.533103f), TypeParam(-0.173681f), TypeParam(0.110927f), TypeParam(0.549661f), TypeParam(-0.303447f), TypeParam(-0.209369f), TypeParam(-0.479343f), TypeParam(0.113517f), TypeParam(-0.222508f), TypeParam(-0.981697f), TypeParam(-1.000072f), TypeParam(0.163343f), TypeParam(-0.019158f), TypeParam(0.217390f), TypeParam(-0.442252f), TypeParam(-1.020732f), TypeParam(-0.645033f), TypeParam(-0.481248f), TypeParam(-0.359233f), TypeParam(-0.271288f), TypeParam(-0.165768f), 
TypeParam(-0.092544f), TypeParam(-0.219889f), TypeParam(0.671201f), TypeParam(-0.041137f), TypeParam(-0.289275f), TypeParam(-0.022793f), TypeParam(-0.130253f), TypeParam(-0.072692f), TypeParam(-0.451858f), TypeParam(0.402947f), TypeParam(0.168711f), TypeParam(0.110811f), TypeParam(0.202315f), TypeParam(-0.200036f), TypeParam(-0.331588f), TypeParam(0.583341f), TypeParam(-0.522838f), TypeParam(1.010100f), TypeParam(-0.018650f), TypeParam(1.269564f), TypeParam(-0.168394f), TypeParam(-0.209390f), TypeParam(0.740205f), TypeParam(-0.675828f), TypeParam(-0.325915f), TypeParam(-0.404694f), TypeParam(0.067064f), TypeParam(-0.744102f), TypeParam(-0.639736f), TypeParam(-0.416580f), TypeParam(-0.317643f), TypeParam(0.004590f), TypeParam(-0.665815f), TypeParam(-0.163600f), TypeParam(-0.661128f), TypeParam(-0.862588f), TypeParam(-0.132515f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.050553f, -0.825690f, -0.616085f, 0.337113f, 0.370334f, -0.105073f, -0.565382f, 0.396842f, -0.373193f, -0.780451f, -1.932970f, 1.104960f, -2.569945f, 0.661190f, -0.192302f, 0.734279f, 0.351872f, -1.068136f, 0.173665f, -0.778153f, -0.981877f, 1.485344f, 0.431733f, 0.428167f}; + std::initializer_list X_data{TypeParam(-0.050553f), TypeParam(-0.825690f), TypeParam(-0.616085f), TypeParam(0.337113f), TypeParam(0.370334f), TypeParam(-0.105073f), TypeParam(-0.565382f), TypeParam(0.396842f), TypeParam(-0.373193f), TypeParam(-0.780451f), TypeParam(-1.932970f), TypeParam(1.104960f), TypeParam(-2.569945f), TypeParam(0.661190f), TypeParam(-0.192302f), TypeParam(0.734279f), TypeParam(0.351872f), TypeParam(-1.068136f), TypeParam(0.173665f), TypeParam(-0.778153f), TypeParam(-0.981877f), TypeParam(1.485344f), TypeParam(0.431733f), TypeParam(0.428167f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.330875f, 0.589988f, 0.011588f, -1.144325f, -1.038357f, 0.435055f, -1.053243f, -0.957144f, -0.715458f, 1.143742f, -0.341215f, -0.494762f, -0.810255f, 0.767649f, -0.193763f, 0.231402f, 0.286668f, 0.338432f, 0.768106f, 0.062272f, 0.124125f, -0.077928f, -0.932481f, -0.274618f}; + std::initializer_list Grid_data{TypeParam(-0.330875f), TypeParam(0.589988f), TypeParam(0.011588f), TypeParam(-1.144325f), TypeParam(-1.038357f), TypeParam(0.435055f), TypeParam(-1.053243f), TypeParam(-0.957144f), TypeParam(-0.715458f), TypeParam(1.143742f), TypeParam(-0.341215f), TypeParam(-0.494762f), TypeParam(-0.810255f), TypeParam(0.767649f), TypeParam(-0.193763f), TypeParam(0.231402f), TypeParam(0.286668f), TypeParam(0.338432f), TypeParam(0.768106f), TypeParam(0.062272f), TypeParam(0.124125f), TypeParam(-0.077928f), TypeParam(-0.932481f), TypeParam(-0.274618f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.204265f, -0.447104f, 0.027635f, -0.050553f, 0.370334f, -0.248695f, -1.306797f, -0.073120f, -1.391077f, -0.565382f, -1.932970f, -0.419110f, 
0.351872f, 0.030903f, -0.124253f, 0.565919f, 0.276202f, -1.171718f, 0.431733f, 0.001712f, 0.689913f, 1.386595f, 0.443614f, -0.505878f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.204265f), TypeParam(-0.447104f), TypeParam(0.027635f), TypeParam(-0.050553f), TypeParam(0.370334f), TypeParam(-0.248695f), TypeParam(-1.306797f), TypeParam(-0.073120f), TypeParam(-1.391077f), TypeParam(-0.565382f), TypeParam(-1.932970f), TypeParam(-0.419110f), TypeParam(0.351872f), TypeParam(0.030903f), TypeParam(-0.124253f), TypeParam(0.565919f), TypeParam(0.276202f), TypeParam(-1.171718f), TypeParam(0.431733f), TypeParam(0.001712f), TypeParam(0.689913f), TypeParam(1.386595f), TypeParam(0.443614f), TypeParam(-0.505878f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.727099f, 0.057663f, -0.548384f, 0.078163f, -0.133679f, 0.211872f, 0.271687f, -1.221973f, -2.630687f, -0.558102f, -0.327183f, 0.039894f, 1.222102f, 0.144418f, 0.696676f, -2.231791f, 0.910544f, 2.749837f, -0.354036f, -0.106102f, 2.453576f, 0.332319f, -1.743712f, 1.416859f, 0.260041f, -1.179930f, 0.407328f, 0.375476f, 2.028488f, 0.174825f, -1.467126f, 0.079045f, 0.870076f, -0.895165f, 0.631429f, 0.358222f, 1.484120f, -0.622331f, 0.727481f, 0.644213f, 1.299103f, -0.378573f, 1.360908f, 0.905514f, 0.180065f, 0.972162f, 1.246238f, -0.537204f, -1.241497f, -0.772822f, -0.149044f, -1.642060f, 0.120091f, 0.937023f, 0.422106f, 0.652040f, 0.045585f, -1.089530f, 0.356099f, 0.536075f, -1.840257f, -1.035736f, 0.348653f, 0.187942f, 0.150011f, 0.521798f, 1.271739f, 0.977495f, 0.811927f, 0.641729f, 0.964401f, -0.693074f}; + std::initializer_list X_data{TypeParam(-0.727099f), TypeParam(0.057663f), TypeParam(-0.548384f), TypeParam(0.078163f), TypeParam(-0.133679f), TypeParam(0.211872f), TypeParam(0.271687f), TypeParam(-1.221973f), TypeParam(-2.630687f), TypeParam(-0.558102f), TypeParam(-0.327183f), TypeParam(0.039894f), TypeParam(1.222102f), TypeParam(0.144418f), TypeParam(0.696676f), TypeParam(-2.231791f), TypeParam(0.910544f), TypeParam(2.749837f), TypeParam(-0.354036f), TypeParam(-0.106102f), TypeParam(2.453576f), TypeParam(0.332319f), TypeParam(-1.743712f), TypeParam(1.416859f), TypeParam(0.260041f), TypeParam(-1.179930f), TypeParam(0.407328f), TypeParam(0.375476f), TypeParam(2.028488f), TypeParam(0.174825f), TypeParam(-1.467126f), TypeParam(0.079045f), TypeParam(0.870076f), TypeParam(-0.895165f), TypeParam(0.631429f), TypeParam(0.358222f), TypeParam(1.484120f), TypeParam(-0.622331f), TypeParam(0.727481f), TypeParam(0.644213f), TypeParam(1.299103f), TypeParam(-0.378573f), TypeParam(1.360908f), TypeParam(0.905514f), TypeParam(0.180065f), TypeParam(0.972162f), TypeParam(1.246238f), TypeParam(-0.537204f), TypeParam(-1.241497f), TypeParam(-0.772822f), TypeParam(-0.149044f), TypeParam(-1.642060f), TypeParam(0.120091f), 
TypeParam(0.937023f), TypeParam(0.422106f), TypeParam(0.652040f), TypeParam(0.045585f), TypeParam(-1.089530f), TypeParam(0.356099f), TypeParam(0.536075f), TypeParam(-1.840257f), TypeParam(-1.035736f), TypeParam(0.348653f), TypeParam(0.187942f), TypeParam(0.150011f), TypeParam(0.521798f), TypeParam(1.271739f), TypeParam(0.977495f), TypeParam(0.811927f), TypeParam(0.641729f), TypeParam(0.964401f), TypeParam(-0.693074f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{1.017692f, -0.818194f, 0.525611f, -0.556812f, -0.124601f, 1.120205f, 0.153552f, -1.144168f, 1.103147f, -0.050771f, -0.600881f, -0.633732f, 1.029039f, 0.020253f, 0.662802f, 0.788674f, -0.465758f, 0.101853f, -0.776226f, 1.002064f, -0.634553f, 0.797064f, 0.304043f, 0.740241f, -0.845484f, -0.037319f, 0.621792f, -0.047898f, -0.017218f, 0.584766f, -0.896882f, -0.240587f, 0.546590f, 0.588539f, 1.114539f, -0.237379f, 0.284327f, -0.590432f, -0.201402f, -0.602420f, 0.889284f, 0.007310f, 0.488176f, 0.660055f, 0.223618f, 0.127703f, -0.087830f, -1.016490f, 0.193341f, -0.265853f, -1.008634f, 1.118021f, -0.127930f, -0.598904f, -1.168221f, -1.105256f, 0.456964f, -0.547805f, -0.518368f, -0.694346f, 0.968648f, -0.288466f, 0.777819f, 0.952657f, -0.930362f, 0.895254f, -0.229149f, 1.149323f, 0.612939f, -1.162419f, 0.222934f, 0.421831f, -0.435327f, 0.909973f, -0.993750f, -0.380767f, 1.143396f, 1.171977f, 0.599451f, -0.716336f, -1.032482f, -0.975683f, -0.299985f, 0.679795f, 0.379920f, -0.145729f, 1.079221f, 0.942322f, -0.560859f, -0.519668f, -0.014079f, 0.249021f, -0.008590f, 0.463277f, 0.827937f, -0.216375f, 0.589310f, 0.163207f, 0.460623f, 0.494016f, -0.320739f, -0.535032f, 0.512922f, -0.768302f, 0.630003f, -0.769945f, 0.823242f, 0.481487f}; + std::initializer_list Grid_data{TypeParam(1.017692f), TypeParam(-0.818194f), TypeParam(0.525611f), TypeParam(-0.556812f), TypeParam(-0.124601f), TypeParam(1.120205f), TypeParam(0.153552f), TypeParam(-1.144168f), TypeParam(1.103147f), TypeParam(-0.050771f), TypeParam(-0.600881f), TypeParam(-0.633732f), TypeParam(1.029039f), TypeParam(0.020253f), TypeParam(0.662802f), TypeParam(0.788674f), TypeParam(-0.465758f), TypeParam(0.101853f), TypeParam(-0.776226f), TypeParam(1.002064f), TypeParam(-0.634553f), TypeParam(0.797064f), TypeParam(0.304043f), TypeParam(0.740241f), TypeParam(-0.845484f), TypeParam(-0.037319f), TypeParam(0.621792f), TypeParam(-0.047898f), TypeParam(-0.017218f), TypeParam(0.584766f), TypeParam(-0.896882f), TypeParam(-0.240587f), TypeParam(0.546590f), TypeParam(0.588539f), TypeParam(1.114539f), TypeParam(-0.237379f), TypeParam(0.284327f), TypeParam(-0.590432f), TypeParam(-0.201402f), TypeParam(-0.602420f), TypeParam(0.889284f), TypeParam(0.007310f), TypeParam(0.488176f), TypeParam(0.660055f), TypeParam(0.223618f), TypeParam(0.127703f), TypeParam(-0.087830f), TypeParam(-1.016490f), TypeParam(0.193341f), TypeParam(-0.265853f), TypeParam(-1.008634f), TypeParam(1.118021f), TypeParam(-0.127930f), TypeParam(-0.598904f), TypeParam(-1.168221f), TypeParam(-1.105256f), TypeParam(0.456964f), TypeParam(-0.547805f), TypeParam(-0.518368f), TypeParam(-0.694346f), TypeParam(0.968648f), TypeParam(-0.288466f), TypeParam(0.777819f), TypeParam(0.952657f), TypeParam(-0.930362f), TypeParam(0.895254f), TypeParam(-0.229149f), TypeParam(1.149323f), TypeParam(0.612939f), TypeParam(-1.162419f), TypeParam(0.222934f), TypeParam(0.421831f), TypeParam(-0.435327f), TypeParam(0.909973f), TypeParam(-0.993750f), TypeParam(-0.380767f), TypeParam(1.143396f), TypeParam(1.171977f), 
TypeParam(0.599451f), TypeParam(-0.716336f), TypeParam(-1.032482f), TypeParam(-0.975683f), TypeParam(-0.299985f), TypeParam(0.679795f), TypeParam(0.379920f), TypeParam(-0.145729f), TypeParam(1.079221f), TypeParam(0.942322f), TypeParam(-0.560859f), TypeParam(-0.519668f), TypeParam(-0.014079f), TypeParam(0.249021f), TypeParam(-0.008590f), TypeParam(0.463277f), TypeParam(0.827937f), TypeParam(-0.216375f), TypeParam(0.589310f), TypeParam(0.163207f), TypeParam(0.460623f), TypeParam(0.494016f), TypeParam(-0.320739f), TypeParam(-0.535032f), TypeParam(0.512922f), TypeParam(-0.768302f), TypeParam(0.630003f), TypeParam(-0.769945f), TypeParam(0.823242f), TypeParam(0.481487f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.144687f, 0.794879f, 0.517780f, -0.372025f, -2.071523f, -0.953122f, -0.143000f, 0.040151f, 0.511071f, -0.723342f, 0.441486f, 0.101130f, -0.668215f, -0.313612f, 0.918245f, -0.165560f, -0.141496f, -0.002992f, -0.187333f, 0.433250f, -0.456623f, -0.082449f, -0.849978f, -0.635311f, -1.562003f, -0.323540f, 0.716348f, 0.089914f, 0.085623f, 0.617075f, -0.522245f, 2.013170f, 0.249061f, 0.948093f, 0.518262f, 0.230788f, -0.422900f, 1.315807f, -1.265941f, -0.772822f, 0.375354f, 0.159706f, 1.190603f, 0.217497f, -0.622331f, -0.640623f, -1.324261f, -0.126419f, 0.497220f, -0.421485f, -0.512049f, 0.218454f, -0.680520f, 0.432900f, 0.292848f, 0.338349f, 0.787015f, 0.977495f, 0.494135f, 0.649655f, 0.367739f, 0.766775f, 0.652040f, 1.018832f, 0.738819f, 0.107251f, 0.287288f, 0.515065f, 0.300961f, -0.279154f, 0.866776f, 0.738188f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.144687f), TypeParam(0.794879f), TypeParam(0.517780f), TypeParam(-0.372025f), TypeParam(-2.071523f), TypeParam(-0.953122f), TypeParam(-0.143000f), TypeParam(0.040151f), TypeParam(0.511071f), TypeParam(-0.723342f), TypeParam(0.441486f), TypeParam(0.101130f), TypeParam(-0.668215f), TypeParam(-0.313612f), TypeParam(0.918245f), TypeParam(-0.165560f), TypeParam(-0.141496f), TypeParam(-0.002992f), TypeParam(-0.187333f), TypeParam(0.433250f), TypeParam(-0.456623f), TypeParam(-0.082449f), TypeParam(-0.849978f), TypeParam(-0.635311f), TypeParam(-1.562003f), TypeParam(-0.323540f), TypeParam(0.716348f), TypeParam(0.089914f), TypeParam(0.085623f), TypeParam(0.617075f), TypeParam(-0.522245f), TypeParam(2.013170f), TypeParam(0.249061f), TypeParam(0.948093f), TypeParam(0.518262f), TypeParam(0.230788f), TypeParam(-0.422900f), TypeParam(1.315807f), TypeParam(-1.265941f), TypeParam(-0.772822f), TypeParam(0.375354f), TypeParam(0.159706f), TypeParam(1.190603f), TypeParam(0.217497f), TypeParam(-0.622331f), TypeParam(-0.640623f), TypeParam(-1.324261f), TypeParam(-0.126419f), TypeParam(0.497220f), TypeParam(-0.421485f), TypeParam(-0.512049f), TypeParam(0.218454f), TypeParam(-0.680520f), TypeParam(0.432900f), TypeParam(0.292848f), TypeParam(0.338349f), TypeParam(0.787015f), TypeParam(0.977495f), TypeParam(0.494135f), TypeParam(0.649655f), TypeParam(0.367739f), TypeParam(0.766775f), TypeParam(0.652040f), TypeParam(1.018832f), TypeParam(0.738819f), TypeParam(0.107251f), TypeParam(0.287288f), TypeParam(0.515065f), TypeParam(0.300961f), TypeParam(-0.279154f), TypeParam(0.866776f), TypeParam(0.738188f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", 
Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "reflection";
   int64_t align_corners = 1;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-0.599439f, 0.317612f, -0.294302f, -0.530613f, 0.754687f, 0.092241f, -1.009405f, -1.155944f, 0.336327f, 0.159353f, -1.134330f, 0.510271f, 0.271972f, 1.301884f, 1.027400f, 1.193876f, 0.304363f, 1.027256f, 0.186801f, 0.719412f, -0.310900f, -1.123812f, -0.312771f, 2.729156f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.599439f), TypeParam(0.317612f), TypeParam(-0.294302f), TypeParam(-0.530613f), TypeParam(0.754687f), TypeParam(0.092241f), TypeParam(-1.009405f), TypeParam(-1.155944f), TypeParam(0.336327f), TypeParam(0.159353f), TypeParam(-1.134330f), TypeParam(0.510271f), TypeParam(0.271972f), TypeParam(1.301884f), TypeParam(1.027400f), TypeParam(1.193876f), TypeParam(0.304363f), TypeParam(1.027256f), TypeParam(0.186801f), TypeParam(0.719412f), TypeParam(-0.310900f), TypeParam(-1.123812f), TypeParam(-0.312771f), TypeParam(2.729156f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.853801f, 0.833200f, -0.477474f, 0.131677f, 0.571825f, 0.858708f, -1.120796f, 1.194690f, -0.301706f, 0.488934f, -0.745307f, -0.923452f, -0.812682f, 0.707226f, -0.591920f, 0.697573f, 0.362777f, 0.477332f, -0.266909f, -0.379588f, -0.561456f, -0.670762f, 1.106438f, -0.065215f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.853801f), TypeParam(0.833200f), TypeParam(-0.477474f), TypeParam(0.131677f), TypeParam(0.571825f), TypeParam(0.858708f), TypeParam(-1.120796f), TypeParam(1.194690f), TypeParam(-0.301706f), TypeParam(0.488934f), TypeParam(-0.745307f), TypeParam(-0.923452f), TypeParam(-0.812682f), TypeParam(0.707226f), TypeParam(-0.591920f), TypeParam(0.697573f), TypeParam(0.362777f), TypeParam(0.477332f), TypeParam(-0.266909f), TypeParam(-0.379588f), TypeParam(-0.561456f), TypeParam(-0.670762f), TypeParam(1.106438f), TypeParam(-0.065215f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{0.031577f, -0.232574f, 0.133168f, 0.515460f, 0.063332f, -0.470541f, 0.353729f, 0.159106f, 0.163701f, -0.770097f, -0.133556f, -0.925350f, 0.568498f, 0.636194f, 0.976680f, 0.921805f, 0.684184f, 1.189063f, -0.133022f, 0.070598f, 0.388079f, -0.232737f, 0.042589f, -0.965013f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(0.031577f), TypeParam(-0.232574f), TypeParam(0.133168f), TypeParam(0.515460f), TypeParam(0.063332f), TypeParam(-0.470541f), TypeParam(0.353729f), TypeParam(0.159106f), TypeParam(0.163701f), TypeParam(-0.770097f), TypeParam(-0.133556f), TypeParam(-0.925350f), TypeParam(0.568498f), TypeParam(0.636194f), TypeParam(0.976680f), TypeParam(0.921805f), TypeParam(0.684184f), TypeParam(1.189063f), TypeParam(-0.133022f), TypeParam(0.070598f), TypeParam(0.388079f), TypeParam(-0.232737f), TypeParam(0.042589f), TypeParam(-0.965013f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.441629f, 0.199148f, 1.214051f, -0.000869f, 0.863692f, -0.067719f, -0.621662f, 0.235179f, 0.691041f, 0.176564f, 0.036477f, -0.085879f, 0.785440f, -1.837889f, -0.300151f, -1.710413f, 0.484432f, 2.160478f, -0.049246f, 0.372475f, -1.060470f, -1.000841f, -0.473439f, 0.963055f, 0.174518f, 0.932434f, 0.039338f, -0.343549f, -1.446623f, -0.673622f, 0.520395f, -0.279228f, -0.367065f, -0.871085f, 0.649273f, -0.835047f, 1.063542f, -1.829784f, 1.476173f, -1.048210f, -1.127299f, 1.204756f, -0.998390f, -1.014054f, -1.032717f, 0.977184f, 0.959897f, -0.749289f, 0.784492f, 1.343993f, 1.291144f, 0.099496f, 2.086763f, 0.529948f, -2.296640f, 0.570701f, 0.491216f, -0.003836f, -0.591929f, -0.076994f, 1.239698f, -0.888840f, 0.623497f, 0.769879f, 2.240972f, -2.081689f, 0.798466f, 1.207944f, -0.486804f, -0.488222f, -0.746382f, -0.220282f}; + std::initializer_list X_data{TypeParam(-0.441629f), TypeParam(0.199148f), TypeParam(1.214051f), TypeParam(-0.000869f), TypeParam(0.863692f), TypeParam(-0.067719f), TypeParam(-0.621662f), TypeParam(0.235179f), TypeParam(0.691041f), TypeParam(0.176564f), TypeParam(0.036477f), TypeParam(-0.085879f), TypeParam(0.785440f), TypeParam(-1.837889f), TypeParam(-0.300151f), TypeParam(-1.710413f), TypeParam(0.484432f), TypeParam(2.160478f), TypeParam(-0.049246f), TypeParam(0.372475f), TypeParam(-1.060470f), TypeParam(-1.000841f), TypeParam(-0.473439f), TypeParam(0.963055f), TypeParam(0.174518f), TypeParam(0.932434f), TypeParam(0.039338f), TypeParam(-0.343549f), TypeParam(-1.446623f), TypeParam(-0.673622f), TypeParam(0.520395f), TypeParam(-0.279228f), TypeParam(-0.367065f), TypeParam(-0.871085f), TypeParam(0.649273f), TypeParam(-0.835047f), TypeParam(1.063542f), TypeParam(-1.829784f), TypeParam(1.476173f), TypeParam(-1.048210f), TypeParam(-1.127299f), TypeParam(1.204756f), TypeParam(-0.998390f), TypeParam(-1.014054f), TypeParam(-1.032717f), TypeParam(0.977184f), TypeParam(0.959897f), TypeParam(-0.749289f), TypeParam(0.784492f), TypeParam(1.343993f), TypeParam(1.291144f), TypeParam(0.099496f), TypeParam(2.086763f), TypeParam(0.529948f), TypeParam(-2.296640f), TypeParam(0.570701f), TypeParam(0.491216f), TypeParam(-0.003836f), TypeParam(-0.591929f), TypeParam(-0.076994f), TypeParam(1.239698f), TypeParam(-0.888840f), TypeParam(0.623497f), TypeParam(0.769879f), TypeParam(2.240972f), TypeParam(-2.081689f), TypeParam(0.798466f), TypeParam(1.207944f), TypeParam(-0.486804f), TypeParam(-0.488222f), TypeParam(-0.746382f), TypeParam(-0.220282f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.169044f, 0.178997f, 1.112567f, -0.825642f, -0.359793f, 0.170758f, -0.081412f, 0.319486f, 0.630993f, -0.493702f, 0.093438f, 1.085657f, -0.679024f, -0.813753f, -0.920282f, 0.717311f, -1.100678f, -0.583561f, 0.810473f, -0.719377f, 0.975857f, -0.560957f, 0.189840f, 0.157082f, -0.029434f, 0.747413f, 1.019186f, -0.749235f, 0.673000f, 0.320624f, -0.022362f, -0.839050f, 0.355966f, 0.871005f, -1.030007f, -1.108265f, -1.179701f, 0.277273f, -0.344802f, -0.372753f, 1.117390f, -0.306079f, -0.762057f, 0.107942f, -0.658634f, -0.351593f, 0.633875f, 0.276953f, -0.823465f, 
1.142446f, 0.811875f, -0.818022f, 0.522699f, 0.493103f, -0.861061f, -0.843352f, -0.993629f, 0.534540f, 0.209070f, 0.507143f, -0.527071f, 0.902309f, 0.153227f, -0.957513f, -0.302041f, 0.612404f, 0.263859f, -0.183579f, -0.838388f, -0.746482f, 1.035039f, -0.687403f, 0.850371f, -0.401659f, 0.011995f, -1.168548f, -0.390077f, 1.011575f, -1.077360f, 0.603794f, -1.009901f, 0.175023f, -1.087964f, -0.949961f, -0.968757f, -0.416100f, 0.163389f, -0.879807f, 0.304124f, 0.722748f, 0.978239f, 1.062535f, 0.790067f, -0.353356f, -0.110591f, 1.061730f, 0.596951f, -0.318231f, 0.905999f, -1.048710f, 1.027042f, 0.671407f, -0.880154f, -0.978736f, 0.938431f, 1.183815f, 0.104716f, -0.468883f}; + std::initializer_list Grid_data{TypeParam(-0.169044f), TypeParam(0.178997f), TypeParam(1.112567f), TypeParam(-0.825642f), TypeParam(-0.359793f), TypeParam(0.170758f), TypeParam(-0.081412f), TypeParam(0.319486f), TypeParam(0.630993f), TypeParam(-0.493702f), TypeParam(0.093438f), TypeParam(1.085657f), TypeParam(-0.679024f), TypeParam(-0.813753f), TypeParam(-0.920282f), TypeParam(0.717311f), TypeParam(-1.100678f), TypeParam(-0.583561f), TypeParam(0.810473f), TypeParam(-0.719377f), TypeParam(0.975857f), TypeParam(-0.560957f), TypeParam(0.189840f), TypeParam(0.157082f), TypeParam(-0.029434f), TypeParam(0.747413f), TypeParam(1.019186f), TypeParam(-0.749235f), TypeParam(0.673000f), TypeParam(0.320624f), TypeParam(-0.022362f), TypeParam(-0.839050f), TypeParam(0.355966f), TypeParam(0.871005f), TypeParam(-1.030007f), TypeParam(-1.108265f), TypeParam(-1.179701f), TypeParam(0.277273f), TypeParam(-0.344802f), TypeParam(-0.372753f), TypeParam(1.117390f), TypeParam(-0.306079f), TypeParam(-0.762057f), TypeParam(0.107942f), TypeParam(-0.658634f), TypeParam(-0.351593f), TypeParam(0.633875f), TypeParam(0.276953f), TypeParam(-0.823465f), TypeParam(1.142446f), TypeParam(0.811875f), TypeParam(-0.818022f), TypeParam(0.522699f), TypeParam(0.493103f), TypeParam(-0.861061f), TypeParam(-0.843352f), TypeParam(-0.993629f), TypeParam(0.534540f), TypeParam(0.209070f), TypeParam(0.507143f), TypeParam(-0.527071f), TypeParam(0.902309f), TypeParam(0.153227f), TypeParam(-0.957513f), TypeParam(-0.302041f), TypeParam(0.612404f), TypeParam(0.263859f), TypeParam(-0.183579f), TypeParam(-0.838388f), TypeParam(-0.746482f), TypeParam(1.035039f), TypeParam(-0.687403f), TypeParam(0.850371f), TypeParam(-0.401659f), TypeParam(0.011995f), TypeParam(-1.168548f), TypeParam(-0.390077f), TypeParam(1.011575f), TypeParam(-1.077360f), TypeParam(0.603794f), TypeParam(-1.009901f), TypeParam(0.175023f), TypeParam(-1.087964f), TypeParam(-0.949961f), TypeParam(-0.968757f), TypeParam(-0.416100f), TypeParam(0.163389f), TypeParam(-0.879807f), TypeParam(0.304124f), TypeParam(0.722748f), TypeParam(0.978239f), TypeParam(1.062535f), TypeParam(0.790067f), TypeParam(-0.353356f), TypeParam(-0.110591f), TypeParam(1.061730f), TypeParam(0.596951f), TypeParam(-0.318231f), TypeParam(0.905999f), TypeParam(-1.048710f), TypeParam(1.027042f), TypeParam(0.671407f), TypeParam(-0.880154f), TypeParam(-0.978736f), TypeParam(0.938431f), TypeParam(1.183815f), TypeParam(0.104716f), TypeParam(-0.468883f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.414201f, 0.167816f, -0.042305f, -0.423495f, -0.101419f, 0.120192f, -1.543294f, 0.344146f, 0.709278f, 0.248721f, -0.269138f, 0.158159f, 0.659876f, 0.226329f, 0.874509f, 0.240959f, 0.412611f, 0.225904f, -0.448580f, 0.057703f, -0.426538f, -0.401142f, -0.147435f, 0.401852f, -0.355426f, -0.286018f, -0.219687f, -0.564205f, 0.282723f, 
0.363522f, -0.543706f, -0.787722f, -0.692217f, -0.594894f, 0.091005f, -0.328214f, 0.919003f, 0.408116f, 0.631220f, 0.303619f, -0.197801f, -0.308153f, 0.094457f, 1.027881f, -0.077622f, -0.597219f, -0.661449f, 0.947805f, 0.279352f, 0.828246f, 0.571205f, 1.646163f, 0.714257f, 0.049881f, -1.680014f, -0.056047f, 0.892393f, 0.250564f, 0.138843f, 0.178706f, 0.161286f, 0.036891f, -0.141908f, -0.510903f, 0.733949f, -0.112944f, -0.581858f, -0.269439f, 0.056781f, 0.200325f, 0.814038f, 0.277386f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-0.414201f), TypeParam(0.167816f), TypeParam(-0.042305f), TypeParam(-0.423495f), TypeParam(-0.101419f), TypeParam(0.120192f), TypeParam(-1.543294f), TypeParam(0.344146f), TypeParam(0.709278f), TypeParam(0.248721f), TypeParam(-0.269138f), TypeParam(0.158159f), TypeParam(0.659876f), TypeParam(0.226329f), TypeParam(0.874509f), TypeParam(0.240959f), TypeParam(0.412611f), TypeParam(0.225904f), TypeParam(-0.448580f), TypeParam(0.057703f), TypeParam(-0.426538f), TypeParam(-0.401142f), TypeParam(-0.147435f), TypeParam(0.401852f), TypeParam(-0.355426f), TypeParam(-0.286018f), TypeParam(-0.219687f), TypeParam(-0.564205f), TypeParam(0.282723f), TypeParam(0.363522f), TypeParam(-0.543706f), TypeParam(-0.787722f), TypeParam(-0.692217f), TypeParam(-0.594894f), TypeParam(0.091005f), TypeParam(-0.328214f), TypeParam(0.919003f), TypeParam(0.408116f), TypeParam(0.631220f), TypeParam(0.303619f), TypeParam(-0.197801f), TypeParam(-0.308153f), TypeParam(0.094457f), TypeParam(1.027881f), TypeParam(-0.077622f), TypeParam(-0.597219f), TypeParam(-0.661449f), TypeParam(0.947805f), TypeParam(0.279352f), TypeParam(0.828246f), TypeParam(0.571205f), TypeParam(1.646163f), TypeParam(0.714257f), TypeParam(0.049881f), TypeParam(-1.680014f), TypeParam(-0.056047f), TypeParam(0.892393f), TypeParam(0.250564f), TypeParam(0.138843f), TypeParam(0.178706f), TypeParam(0.161286f), TypeParam(0.036891f), TypeParam(-0.141908f), TypeParam(-0.510903f), TypeParam(0.733949f), TypeParam(-0.112944f), TypeParam(-0.581858f), TypeParam(-0.269439f), TypeParam(0.056781f), TypeParam(0.200325f), TypeParam(0.814038f), TypeParam(0.277386f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "reflection";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
-  std::initializer_list<float> X_data{-0.173652f, -1.513725f, -0.704586f, -1.952375f, -0.699404f, -0.806298f, 1.640852f, -0.138969f, -0.695411f, -1.352111f, 0.568797f, -0.564294f, -0.056468f, 0.641604f, -0.438370f, 0.450167f, -1.091401f, 1.669729f, -0.908544f, 0.244467f, 0.172109f, 1.156741f, -0.617128f, 1.155460f};
+  std::initializer_list<TypeParam> X_data{TypeParam(-0.173652f), TypeParam(-1.513725f), TypeParam(-0.704586f), TypeParam(-1.952375f), TypeParam(-0.699404f), TypeParam(-0.806298f), TypeParam(1.640852f), TypeParam(-0.138969f), TypeParam(-0.695411f), TypeParam(-1.352111f), TypeParam(0.568797f), TypeParam(-0.564294f), TypeParam(-0.056468f), TypeParam(0.641604f), TypeParam(-0.438370f), TypeParam(0.450167f), TypeParam(-1.091401f), TypeParam(1.669729f), TypeParam(-0.908544f), TypeParam(0.244467f), TypeParam(0.172109f), TypeParam(1.156741f), TypeParam(-0.617128f), TypeParam(1.155460f)};
   std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
-  std::initializer_list<float> Grid_data{0.252250f, -0.151452f, 0.824706f, -0.588292f, -0.591147f, -0.155082f, -0.732938f, 0.457493f, -0.439559f, 0.492330f, 0.696447f, 0.700722f, -0.220298f, 0.654884f, -0.635434f, -1.195619f, -0.114204f, -0.870080f, -0.929674f, 0.305035f, 1.025429f, -0.472240f, -0.067881f, -0.869393f};
+  std::initializer_list<TypeParam> Grid_data{TypeParam(0.252250f), TypeParam(-0.151452f), TypeParam(0.824706f), TypeParam(-0.588292f), TypeParam(-0.591147f), TypeParam(-0.155082f), TypeParam(-0.732938f), TypeParam(0.457493f), TypeParam(-0.439559f), TypeParam(0.492330f), TypeParam(0.696447f), TypeParam(0.700722f), TypeParam(-0.220298f), TypeParam(0.654884f), TypeParam(-0.635434f), TypeParam(-1.195619f), TypeParam(-0.114204f), TypeParam(-0.870080f), TypeParam(-0.929674f), TypeParam(0.305035f), TypeParam(1.025429f), TypeParam(-0.472240f), TypeParam(-0.067881f), TypeParam(-0.869393f)};
   std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};
-  std::initializer_list<float> Y_data{-1.538390f, -1.565293f, -0.581079f, -0.701030f, -0.725252f, -0.806298f, -0.850602f, -0.281588f, -0.151944f, 0.172138f, 0.177246f, -0.564294f, -0.316822f, -0.056468f, 0.212846f, -0.737167f, 0.585773f, 0.245182f, -0.111277f, -0.908544f, -0.463717f, -0.189009f, 0.510522f, -0.410307f};
-  test.AddInput<float>("X", X_shape, X_data);
-  test.AddInput<float>("Grid", Grid_shape, Grid_data);
+  std::initializer_list<TypeParam> Y_data{TypeParam(-1.538390f), TypeParam(-1.565293f), TypeParam(-0.581079f), TypeParam(-0.701030f), TypeParam(-0.725252f), TypeParam(-0.806298f), TypeParam(-0.850602f), TypeParam(-0.281588f), TypeParam(-0.151944f), TypeParam(0.172138f), TypeParam(0.177246f), TypeParam(-0.564294f), TypeParam(-0.316822f), TypeParam(-0.056468f), TypeParam(0.212846f), TypeParam(-0.737167f), TypeParam(0.585773f), TypeParam(0.245182f), TypeParam(-0.111277f), TypeParam(-0.908544f), TypeParam(-0.463717f), TypeParam(-0.189009f), TypeParam(0.510522f), TypeParam(-0.410307f)};
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
   test.AddAttribute("mode", mode);
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
-  test.AddOutput<float>("Y", Y_shape, Y_data);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
   RunTests(test, GetExecutionProviders(20));
 }

-TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";
   std::string padding_mode = "reflection";
   int64_t align_corners = 0;
   std::initializer_list<int64_t> X_shape{2, 2, 3, 3, 2};
-  std::initializer_list<float> X_data{1.179856f, 1.432512f, 1.016210f, -0.661096f, 0.335863f, 0.565957f, -0.517555f, 2.232456f, -0.615173f, -0.073628f, -0.260768f, -1.952025f, 0.304237f, 0.902323f, -0.485170f, 0.781595f, -1.777093f, -0.274107f, -1.030698f, 0.181435f, 1.947646f, 1.007702f, -0.100718f, 0.154090f, -0.483193f, 1.565921f, -0.932274f, 0.313820f, -0.439116f, -0.411861f, -0.821795f, -1.685022f, -0.013518f, 0.519914f, -0.175407f, -0.507962f, 0.050913f, 0.981904f, 1.087165f, 1.758657f, 0.075954f, -0.481552f, 0.085590f, 0.537831f, -0.419622f, -1.756791f, 1.324879f, -0.267061f, -0.683518f, 0.605393f, 0.041004f,
-0.756742f, 0.744950f, -0.508619f, -0.594679f, -1.165646f, -0.699604f, -0.271502f, 0.437731f, -2.206233f, 1.088781f, -0.629873f, -0.904741f, -1.233533f, 2.466710f, -0.117309f, -0.684130f, 0.598811f, 0.288846f, -1.195569f, 0.935300f, 0.962852f}; + std::initializer_list X_data{TypeParam(1.179856f), TypeParam(1.432512f), TypeParam(1.016210f), TypeParam(-0.661096f), TypeParam(0.335863f), TypeParam(0.565957f), TypeParam(-0.517555f), TypeParam(2.232456f), TypeParam(-0.615173f), TypeParam(-0.073628f), TypeParam(-0.260768f), TypeParam(-1.952025f), TypeParam(0.304237f), TypeParam(0.902323f), TypeParam(-0.485170f), TypeParam(0.781595f), TypeParam(-1.777093f), TypeParam(-0.274107f), TypeParam(-1.030698f), TypeParam(0.181435f), TypeParam(1.947646f), TypeParam(1.007702f), TypeParam(-0.100718f), TypeParam(0.154090f), TypeParam(-0.483193f), TypeParam(1.565921f), TypeParam(-0.932274f), TypeParam(0.313820f), TypeParam(-0.439116f), TypeParam(-0.411861f), TypeParam(-0.821795f), TypeParam(-1.685022f), TypeParam(-0.013518f), TypeParam(0.519914f), TypeParam(-0.175407f), TypeParam(-0.507962f), TypeParam(0.050913f), TypeParam(0.981904f), TypeParam(1.087165f), TypeParam(1.758657f), TypeParam(0.075954f), TypeParam(-0.481552f), TypeParam(0.085590f), TypeParam(0.537831f), TypeParam(-0.419622f), TypeParam(-1.756791f), TypeParam(1.324879f), TypeParam(-0.267061f), TypeParam(-0.683518f), TypeParam(0.605393f), TypeParam(0.041004f), TypeParam(-0.756742f), TypeParam(0.744950f), TypeParam(-0.508619f), TypeParam(-0.594679f), TypeParam(-1.165646f), TypeParam(-0.699604f), TypeParam(-0.271502f), TypeParam(0.437731f), TypeParam(-2.206233f), TypeParam(1.088781f), TypeParam(-0.629873f), TypeParam(-0.904741f), TypeParam(-1.233533f), TypeParam(2.466710f), TypeParam(-0.117309f), TypeParam(-0.684130f), TypeParam(0.598811f), TypeParam(0.288846f), TypeParam(-1.195569f), TypeParam(0.935300f), TypeParam(0.962852f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.625842f, 0.210304f, -0.725943f, -0.553764f, -0.182412f, -0.296478f, -0.254040f, -0.820211f, 0.869312f, 0.622346f, 0.236815f, 0.271706f, 0.140482f, 0.897281f, 0.271537f, 0.182799f, -0.659653f, 0.400310f, -1.122656f, 0.378466f, -1.040147f, -0.496646f, 0.633526f, -0.714734f, 0.955528f, -0.663024f, 1.136629f, 0.369854f, -0.520025f, 0.731855f, -1.062711f, -0.760189f, -0.751812f, 0.157968f, 0.117892f, -1.032129f, 1.157953f, -0.001147f, -0.640796f, 0.028663f, -0.515104f, 0.331070f, 0.434411f, -0.340393f, 0.069958f, 0.714010f, -0.780518f, -0.267586f, -0.177029f, -0.793935f, 0.097737f, 0.044103f, -0.969274f, 0.246164f, 1.145360f, 0.638273f, -0.650926f, 1.098440f, -0.824873f, -0.610135f, 0.529312f, 0.954650f, 1.145143f, 1.033109f, -0.660775f, 0.274592f, -0.753497f, 0.026500f, 0.994206f, 0.590870f, -1.108049f, -0.516447f, -1.012489f, 0.565286f, -0.152334f, -0.877228f, -0.383453f, 0.393797f, 0.111096f, 1.125969f, -0.015932f, 0.377468f, -0.363512f, 0.143194f, 0.042988f, 1.030777f, 0.502813f, -0.683870f, -1.066269f, -1.141727f, -0.435790f, 0.155118f, 1.128919f, -0.117905f, 0.469189f, 0.609870f, -0.919201f, -0.992659f, 0.454699f, 0.559331f, -0.558762f, 0.188050f, -1.174933f, 0.015126f, 0.294147f, 0.011359f, -0.190476f, 0.499476f}; + std::initializer_list Grid_data{TypeParam(0.625842f), TypeParam(0.210304f), TypeParam(-0.725943f), TypeParam(-0.553764f), TypeParam(-0.182412f), TypeParam(-0.296478f), TypeParam(-0.254040f), TypeParam(-0.820211f), TypeParam(0.869312f), TypeParam(0.622346f), TypeParam(0.236815f), TypeParam(0.271706f), TypeParam(0.140482f), 
TypeParam(0.897281f), TypeParam(0.271537f), TypeParam(0.182799f), TypeParam(-0.659653f), TypeParam(0.400310f), TypeParam(-1.122656f), TypeParam(0.378466f), TypeParam(-1.040147f), TypeParam(-0.496646f), TypeParam(0.633526f), TypeParam(-0.714734f), TypeParam(0.955528f), TypeParam(-0.663024f), TypeParam(1.136629f), TypeParam(0.369854f), TypeParam(-0.520025f), TypeParam(0.731855f), TypeParam(-1.062711f), TypeParam(-0.760189f), TypeParam(-0.751812f), TypeParam(0.157968f), TypeParam(0.117892f), TypeParam(-1.032129f), TypeParam(1.157953f), TypeParam(-0.001147f), TypeParam(-0.640796f), TypeParam(0.028663f), TypeParam(-0.515104f), TypeParam(0.331070f), TypeParam(0.434411f), TypeParam(-0.340393f), TypeParam(0.069958f), TypeParam(0.714010f), TypeParam(-0.780518f), TypeParam(-0.267586f), TypeParam(-0.177029f), TypeParam(-0.793935f), TypeParam(0.097737f), TypeParam(0.044103f), TypeParam(-0.969274f), TypeParam(0.246164f), TypeParam(1.145360f), TypeParam(0.638273f), TypeParam(-0.650926f), TypeParam(1.098440f), TypeParam(-0.824873f), TypeParam(-0.610135f), TypeParam(0.529312f), TypeParam(0.954650f), TypeParam(1.145143f), TypeParam(1.033109f), TypeParam(-0.660775f), TypeParam(0.274592f), TypeParam(-0.753497f), TypeParam(0.026500f), TypeParam(0.994206f), TypeParam(0.590870f), TypeParam(-1.108049f), TypeParam(-0.516447f), TypeParam(-1.012489f), TypeParam(0.565286f), TypeParam(-0.152334f), TypeParam(-0.877228f), TypeParam(-0.383453f), TypeParam(0.393797f), TypeParam(0.111096f), TypeParam(1.125969f), TypeParam(-0.015932f), TypeParam(0.377468f), TypeParam(-0.363512f), TypeParam(0.143194f), TypeParam(0.042988f), TypeParam(1.030777f), TypeParam(0.502813f), TypeParam(-0.683870f), TypeParam(-1.066269f), TypeParam(-1.141727f), TypeParam(-0.435790f), TypeParam(0.155118f), TypeParam(1.128919f), TypeParam(-0.117905f), TypeParam(0.469189f), TypeParam(0.609870f), TypeParam(-0.919201f), TypeParam(-0.992659f), TypeParam(0.454699f), TypeParam(0.559331f), TypeParam(-0.558762f), TypeParam(0.188050f), TypeParam(-1.174933f), TypeParam(0.015126f), TypeParam(0.294147f), TypeParam(0.011359f), TypeParam(-0.190476f), TypeParam(0.499476f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.274014f, 0.145076f, 0.451342f, -0.273219f, -1.128307f, 0.962473f, 0.629978f, 0.370138f, 0.901663f, 0.778787f, 1.179856f, 0.014218f, -0.634683f, 0.585419f, 0.972130f, 1.911376f, 0.389205f, 0.849839f, 0.738424f, 0.054296f, -1.034114f, 0.096287f, -0.408114f, -0.474491f, 0.784791f, 0.001762f, -1.672976f, -1.127656f, -1.030698f, 1.105979f, 0.979492f, -0.258014f, 0.693543f, 1.010218f, -0.008927f, -0.078404f, -0.384825f, 0.944247f, -0.508619f, 0.548774f, 0.068986f, 0.881841f, 0.869967f, -0.274754f, 0.337312f, -0.374188f, 0.161655f, 0.050913f, 0.146763f, 0.119233f, -0.438980f, 0.228062f, -0.187221f, -0.376543f, -2.077576f, -1.120214f, 0.962852f, -0.133462f, 0.314542f, -1.044921f, 1.568017f, -0.060947f, 0.838264f, -0.652863f, 0.978122f, -0.594679f, 0.366536f, 0.596221f, -0.120431f, -0.435362f, -0.328892f, -0.434798f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.274014f), TypeParam(0.145076f), TypeParam(0.451342f), TypeParam(-0.273219f), TypeParam(-1.128307f), TypeParam(0.962473f), TypeParam(0.629978f), TypeParam(0.370138f), TypeParam(0.901663f), TypeParam(0.778787f), TypeParam(1.179856f), TypeParam(0.014218f), TypeParam(-0.634683f), TypeParam(0.585419f), TypeParam(0.972130f), TypeParam(1.911376f), TypeParam(0.389205f), TypeParam(0.849839f), 
TypeParam(0.738424f), TypeParam(0.054296f), TypeParam(-1.034114f), TypeParam(0.096287f), TypeParam(-0.408114f), TypeParam(-0.474491f), TypeParam(0.784791f), TypeParam(0.001762f), TypeParam(-1.672976f), TypeParam(-1.127656f), TypeParam(-1.030698f), TypeParam(1.105979f), TypeParam(0.979492f), TypeParam(-0.258014f), TypeParam(0.693543f), TypeParam(1.010218f), TypeParam(-0.008927f), TypeParam(-0.078404f), TypeParam(-0.384825f), TypeParam(0.944247f), TypeParam(-0.508619f), TypeParam(0.548774f), TypeParam(0.068986f), TypeParam(0.881841f), TypeParam(0.869967f), TypeParam(-0.274754f), TypeParam(0.337312f), TypeParam(-0.374188f), TypeParam(0.161655f), TypeParam(0.050913f), TypeParam(0.146763f), TypeParam(0.119233f), TypeParam(-0.438980f), TypeParam(0.228062f), TypeParam(-0.187221f), TypeParam(-0.376543f), TypeParam(-2.077576f), TypeParam(-1.120214f), TypeParam(0.962852f), TypeParam(-0.133462f), TypeParam(0.314542f), TypeParam(-1.044921f), TypeParam(1.568017f), TypeParam(-0.060947f), TypeParam(0.838264f), TypeParam(-0.652863f), TypeParam(0.978122f), TypeParam(-0.594679f), TypeParam(0.366536f), TypeParam(0.596221f), TypeParam(-0.120431f), TypeParam(-0.435362f), TypeParam(-0.328892f), TypeParam(-0.434798f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.741614f, -1.612838f, 0.274100f, -0.685296f, -0.032079f, -0.246424f, 0.089412f, -0.776545f, -0.152179f, 0.312533f, -1.503701f, -0.720829f, 0.877575f, 0.407229f, -0.889951f, 0.603605f, -0.140859f, 2.032775f, -0.520668f, 1.063163f, -1.008883f, 0.194195f, -0.303240f, -0.967884f}; + std::initializer_list X_data{TypeParam(0.741614f), TypeParam(-1.612838f), TypeParam(0.274100f), TypeParam(-0.685296f), TypeParam(-0.032079f), TypeParam(-0.246424f), TypeParam(0.089412f), TypeParam(-0.776545f), TypeParam(-0.152179f), TypeParam(0.312533f), TypeParam(-1.503701f), TypeParam(-0.720829f), TypeParam(0.877575f), TypeParam(0.407229f), TypeParam(-0.889951f), TypeParam(0.603605f), TypeParam(-0.140859f), TypeParam(2.032775f), TypeParam(-0.520668f), TypeParam(1.063163f), TypeParam(-1.008883f), TypeParam(0.194195f), TypeParam(-0.303240f), TypeParam(-0.967884f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.932019f, -0.034394f, 0.554511f, 0.484230f, 0.141120f, 0.485083f, -0.836516f, 0.999462f, 0.026764f, 0.775689f, 0.265464f, -0.133497f, 0.514005f, 1.139161f, 1.183700f, -1.010095f, 0.072779f, -0.862052f, 0.699178f, 0.861473f, -0.842637f, -0.069355f, 0.830374f, 0.793568f}; + std::initializer_list Grid_data{TypeParam(-0.932019f), TypeParam(-0.034394f), TypeParam(0.554511f), TypeParam(0.484230f), TypeParam(0.141120f), TypeParam(0.485083f), TypeParam(-0.836516f), TypeParam(0.999462f), TypeParam(0.026764f), TypeParam(0.775689f), TypeParam(0.265464f), TypeParam(-0.133497f), TypeParam(0.514005f), TypeParam(1.139161f), TypeParam(1.183700f), TypeParam(-1.010095f), TypeParam(0.072779f), 
TypeParam(-0.862052f), TypeParam(0.699178f), TypeParam(0.861473f), TypeParam(-0.842637f), TypeParam(-0.069355f), TypeParam(0.830374f), TypeParam(0.793568f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.274192f, -0.348792f, -0.238780f, -0.048938f, -0.195915f, -0.488976f, -0.104505f, -0.351103f, -0.583059f, -1.533095f, -1.141282f, 0.187052f, 1.668728f, 0.345182f, 0.682750f, 1.893112f, -0.775917f, 1.920082f, -0.889375f, 1.071508f, 0.336517f, -0.933740f, -0.981629f, -0.893789f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.274192f), TypeParam(-0.348792f), TypeParam(-0.238780f), TypeParam(-0.048938f), TypeParam(-0.195915f), TypeParam(-0.488976f), TypeParam(-0.104505f), TypeParam(-0.351103f), TypeParam(-0.583059f), TypeParam(-1.533095f), TypeParam(-1.141282f), TypeParam(0.187052f), TypeParam(1.668728f), TypeParam(0.345182f), TypeParam(0.682750f), TypeParam(1.893112f), TypeParam(-0.775917f), TypeParam(1.920082f), TypeParam(-0.889375f), TypeParam(1.071508f), TypeParam(0.336517f), TypeParam(-0.933740f), TypeParam(-0.981629f), TypeParam(-0.893789f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.333395f, 0.977190f, 0.214232f, 0.363731f, -1.352515f, -0.980304f, -0.354887f, -0.481711f, -0.607915f, -0.309748f, 2.262781f, 0.963363f, 1.997079f, 0.987449f, -0.537662f, 1.011585f, 0.822184f, 0.567108f, 0.135401f, -0.943315f, -0.614181f, 0.030652f, 0.914757f, 0.971777f}; + std::initializer_list X_data{TypeParam(0.333395f), TypeParam(0.977190f), TypeParam(0.214232f), TypeParam(0.363731f), TypeParam(-1.352515f), TypeParam(-0.980304f), TypeParam(-0.354887f), TypeParam(-0.481711f), TypeParam(-0.607915f), TypeParam(-0.309748f), TypeParam(2.262781f), TypeParam(0.963363f), TypeParam(1.997079f), TypeParam(0.987449f), TypeParam(-0.537662f), TypeParam(1.011585f), TypeParam(0.822184f), TypeParam(0.567108f), TypeParam(0.135401f), TypeParam(-0.943315f), TypeParam(-0.614181f), TypeParam(0.030652f), TypeParam(0.914757f), TypeParam(0.971777f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.487111f, 0.913573f, 0.641905f, -0.093110f, 0.512522f, 0.358369f, 0.655341f, -0.964320f, 0.370929f, -1.136512f, -0.789199f, -0.447185f, -0.116915f, -1.132446f, 0.029865f, 0.191588f, -0.476239f, 0.389224f, 1.048588f, -0.204978f, -0.639094f, -1.062994f, -0.876243f, -0.663705f}; + std::initializer_list Grid_data{TypeParam(-0.487111f), TypeParam(0.913573f), TypeParam(0.641905f), TypeParam(-0.093110f), TypeParam(0.512522f), TypeParam(0.358369f), TypeParam(0.655341f), TypeParam(-0.964320f), TypeParam(0.370929f), TypeParam(-1.136512f), TypeParam(-0.789199f), TypeParam(-0.447185f), TypeParam(-0.116915f), TypeParam(-1.132446f), TypeParam(0.029865f), TypeParam(0.191588f), TypeParam(-0.476239f), TypeParam(0.389224f), TypeParam(1.048588f), 
TypeParam(-0.204978f), TypeParam(-0.639094f), TypeParam(-1.062994f), TypeParam(-0.876243f), TypeParam(-0.663705f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.051920f, 0.501832f, -0.508839f, 0.563480f, 0.297178f, 0.246571f, 1.781955f, -0.353574f, 0.481200f, -0.258839f, -0.145200f, -0.469558f, 0.624262f, 0.351267f, 0.180256f, 0.571859f, 0.903895f, 1.383745f, -0.081406f, 0.133665f, 0.348401f, -0.164219f, 0.138237f, 0.203282f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.051920f), TypeParam(0.501832f), TypeParam(-0.508839f), TypeParam(0.563480f), TypeParam(0.297178f), TypeParam(0.246571f), TypeParam(1.781955f), TypeParam(-0.353574f), TypeParam(0.481200f), TypeParam(-0.258839f), TypeParam(-0.145200f), TypeParam(-0.469558f), TypeParam(0.624262f), TypeParam(0.351267f), TypeParam(0.180256f), TypeParam(0.571859f), TypeParam(0.903895f), TypeParam(1.383745f), TypeParam(-0.081406f), TypeParam(0.133665f), TypeParam(0.348401f), TypeParam(-0.164219f), TypeParam(0.138237f), TypeParam(0.203282f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.480448f, 0.682093f, 0.237716f, -1.234307f, 2.139750f, 2.410321f, 0.491472f, -0.553422f, 0.032129f, -0.162503f, 0.144036f, -1.889875f, -0.293944f, -1.390146f, -1.552136f, 1.604720f, -1.707202f, 0.182427f, -0.631000f, 0.196649f, 0.427711f, -0.014224f, -1.319834f, -2.703346f}; + std::initializer_list X_data{TypeParam(-0.480448f), TypeParam(0.682093f), TypeParam(0.237716f), TypeParam(-1.234307f), TypeParam(2.139750f), TypeParam(2.410321f), TypeParam(0.491472f), TypeParam(-0.553422f), TypeParam(0.032129f), TypeParam(-0.162503f), TypeParam(0.144036f), TypeParam(-1.889875f), TypeParam(-0.293944f), TypeParam(-1.390146f), TypeParam(-1.552136f), TypeParam(1.604720f), TypeParam(-1.707202f), TypeParam(0.182427f), TypeParam(-0.631000f), TypeParam(0.196649f), TypeParam(0.427711f), TypeParam(-0.014224f), TypeParam(-1.319834f), TypeParam(-2.703346f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.503717f, 0.572989f, 0.179517f, -0.060398f, 0.503876f, 0.288627f, -1.148268f, 0.194010f, -0.532910f, -0.636357f, 0.464076f, 0.245386f, 0.203212f, -0.569260f, 0.554489f, 1.126118f, 0.146805f, 0.493232f, -1.052794f, 0.713394f, 0.416866f, 0.540634f, 0.500415f, -0.315629f}; + std::initializer_list Grid_data{TypeParam(0.503717f), TypeParam(0.572989f), TypeParam(0.179517f), TypeParam(-0.060398f), TypeParam(0.503876f), TypeParam(0.288627f), TypeParam(-1.148268f), TypeParam(0.194010f), TypeParam(-0.532910f), TypeParam(-0.636357f), TypeParam(0.464076f), TypeParam(0.245386f), TypeParam(0.203212f), TypeParam(-0.569260f), TypeParam(0.554489f), TypeParam(1.126118f), TypeParam(0.146805f), TypeParam(0.493232f), TypeParam(-1.052794f), TypeParam(0.713394f), TypeParam(0.416866f), TypeParam(0.540634f), 
TypeParam(0.500415f), TypeParam(-0.315629f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.885659f, -0.722912f, -0.180469f, 0.697015f, -0.322127f, -0.292851f, -0.867861f, -0.047527f, -0.447720f, 0.028100f, 0.191874f, -0.378776f, -0.321888f, -0.277691f, -0.037604f, -1.766707f, 0.320836f, 0.415106f, 0.179209f, -2.609096f, -0.929794f, -0.788240f, -1.212243f, 0.337704f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.885659f), TypeParam(-0.722912f), TypeParam(-0.180469f), TypeParam(0.697015f), TypeParam(-0.322127f), TypeParam(-0.292851f), TypeParam(-0.867861f), TypeParam(-0.047527f), TypeParam(-0.447720f), TypeParam(0.028100f), TypeParam(0.191874f), TypeParam(-0.378776f), TypeParam(-0.321888f), TypeParam(-0.277691f), TypeParam(-0.037604f), TypeParam(-1.766707f), TypeParam(0.320836f), TypeParam(0.415106f), TypeParam(0.179209f), TypeParam(-2.609096f), TypeParam(-0.929794f), TypeParam(-0.788240f), TypeParam(-1.212243f), TypeParam(0.337704f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.924256f, -2.309784f, 1.272769f, 0.548427f, -1.478527f, -3.472946f, -1.252325f, 0.268589f, 0.326270f, 0.105016f, 0.515184f, -0.951158f, -0.658693f, -2.018776f, 0.981625f, -0.401504f, 1.560519f, -0.129836f, -1.876357f, 0.511516f, -1.825582f, 0.358958f, -0.805392f, -1.409127f}; + std::initializer_list X_data{TypeParam(-0.924256f), TypeParam(-2.309784f), TypeParam(1.272769f), TypeParam(0.548427f), TypeParam(-1.478527f), TypeParam(-3.472946f), TypeParam(-1.252325f), TypeParam(0.268589f), TypeParam(0.326270f), TypeParam(0.105016f), TypeParam(0.515184f), TypeParam(-0.951158f), TypeParam(-0.658693f), TypeParam(-2.018776f), TypeParam(0.981625f), TypeParam(-0.401504f), TypeParam(1.560519f), TypeParam(-0.129836f), TypeParam(-1.876357f), TypeParam(0.511516f), TypeParam(-1.825582f), TypeParam(0.358958f), TypeParam(-0.805392f), TypeParam(-1.409127f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.874856f, -1.090775f, 1.169192f, 0.447098f, 0.583418f, 0.267395f, 0.788144f, 1.129706f, -0.102229f, -0.984624f, 1.101916f, -0.253070f, -0.578731f, 0.738703f, 0.669694f, 0.160659f, -0.075327f, -0.229561f, 1.100291f, 0.731142f, 0.714643f, 0.765214f, -0.628031f, 0.250554f}; + std::initializer_list Grid_data{TypeParam(0.874856f), TypeParam(-1.090775f), TypeParam(1.169192f), TypeParam(0.447098f), TypeParam(0.583418f), TypeParam(0.267395f), TypeParam(0.788144f), TypeParam(1.129706f), TypeParam(-0.102229f), TypeParam(-0.984624f), TypeParam(1.101916f), TypeParam(-0.253070f), TypeParam(-0.578731f), TypeParam(0.738703f), TypeParam(0.669694f), TypeParam(0.160659f), TypeParam(-0.075327f), TypeParam(-0.229561f), TypeParam(1.100291f), TypeParam(0.731142f), TypeParam(0.714643f), TypeParam(0.765214f), TypeParam(-0.628031f), TypeParam(0.250554f)}; 
std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-2.647128f, -2.154235f, -0.768645f, -3.893546f, -1.698376f, -0.114530f, 0.458115f, -0.696657f, -0.370692f, -1.169692f, -0.754730f, 0.320002f, 1.683550f, -0.301499f, -0.176003f, -0.236653f, -0.278257f, 1.480160f, -0.700350f, 0.095525f, -0.891605f, -1.569065f, -1.633715f, -1.535763f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-2.647128f), TypeParam(-2.154235f), TypeParam(-0.768645f), TypeParam(-3.893546f), TypeParam(-1.698376f), TypeParam(-0.114530f), TypeParam(0.458115f), TypeParam(-0.696657f), TypeParam(-0.370692f), TypeParam(-1.169692f), TypeParam(-0.754730f), TypeParam(0.320002f), TypeParam(1.683550f), TypeParam(-0.301499f), TypeParam(-0.176003f), TypeParam(-0.236653f), TypeParam(-0.278257f), TypeParam(1.480160f), TypeParam(-0.700350f), TypeParam(0.095525f), TypeParam(-0.891605f), TypeParam(-1.569065f), TypeParam(-1.633715f), TypeParam(-1.535763f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.328038f, -0.658850f, -0.054298f, 0.012663f, -0.077366f, 0.644305f, -1.262985f, 0.922028f, 0.189962f, 0.518836f, 1.168413f, -0.286220f, 0.431207f, -0.295352f, -0.357675f, -0.311715f, 0.839514f, -0.651820f, -0.283934f, 0.430508f, 0.206334f, 0.765966f, -1.144732f, -0.507045f}; + std::initializer_list X_data{TypeParam(-0.328038f), TypeParam(-0.658850f), TypeParam(-0.054298f), TypeParam(0.012663f), TypeParam(-0.077366f), TypeParam(0.644305f), TypeParam(-1.262985f), TypeParam(0.922028f), TypeParam(0.189962f), TypeParam(0.518836f), TypeParam(1.168413f), TypeParam(-0.286220f), TypeParam(0.431207f), TypeParam(-0.295352f), TypeParam(-0.357675f), TypeParam(-0.311715f), TypeParam(0.839514f), TypeParam(-0.651820f), TypeParam(-0.283934f), TypeParam(0.430508f), TypeParam(0.206334f), TypeParam(0.765966f), TypeParam(-1.144732f), TypeParam(-0.507045f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.372000f, -1.056863f, -0.360826f, -0.268314f, 0.691035f, -0.595044f, 0.720198f, 0.166462f, -0.201118f, -1.069416f, 1.184721f, -0.213980f, 0.755038f, -0.620722f, -1.168597f, -0.956522f, -0.614982f, -0.382162f, -0.169456f, 1.000817f, -1.106710f, 0.598940f, 1.009714f, 0.007723f}; + std::initializer_list Grid_data{TypeParam(-0.372000f), TypeParam(-1.056863f), TypeParam(-0.360826f), TypeParam(-0.268314f), TypeParam(0.691035f), TypeParam(-0.595044f), TypeParam(0.720198f), TypeParam(0.166462f), TypeParam(-0.201118f), TypeParam(-1.069416f), TypeParam(1.184721f), TypeParam(-0.213980f), TypeParam(0.755038f), TypeParam(-0.620722f), TypeParam(-1.168597f), TypeParam(-0.956522f), TypeParam(-0.614982f), TypeParam(-0.382162f), TypeParam(-0.169456f), TypeParam(1.000817f), TypeParam(-1.106710f), TypeParam(0.598940f), TypeParam(1.009714f), TypeParam(0.007723f)}; std::initializer_list 
Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.403118f, -0.158055f, -0.496030f, 0.161379f, -0.440603f, -0.193607f, -0.746082f, -0.076433f, 0.751030f, 0.360851f, -0.488453f, 0.664305f, -0.259139f, 0.411796f, -0.156648f, 0.281569f, 0.437515f, -0.313812f, 0.573781f, -0.265706f, 0.200380f, -0.906155f, -0.724311f, 0.760352f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.403118f), TypeParam(-0.158055f), TypeParam(-0.496030f), TypeParam(0.161379f), TypeParam(-0.440603f), TypeParam(-0.193607f), TypeParam(-0.746082f), TypeParam(-0.076433f), TypeParam(0.751030f), TypeParam(0.360851f), TypeParam(-0.488453f), TypeParam(0.664305f), TypeParam(-0.259139f), TypeParam(0.411796f), TypeParam(-0.156648f), TypeParam(0.281569f), TypeParam(0.437515f), TypeParam(-0.313812f), TypeParam(0.573781f), TypeParam(-0.265706f), TypeParam(0.200380f), TypeParam(-0.906155f), TypeParam(-0.724311f), TypeParam(0.760352f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.290962f, 0.867797f, -0.085436f, -1.597520f, 0.695524f, 0.838739f, 0.513032f, 0.166242f, -0.546135f, -0.780313f, -0.512993f, -0.449479f, 1.594718f, 0.953375f, 0.692587f, -0.798364f, -0.128799f, -0.456210f, 2.098909f, -1.561220f, 1.713821f, -0.701970f, -0.287280f, -1.708048f}; + std::initializer_list X_data{TypeParam(-0.290962f), TypeParam(0.867797f), TypeParam(-0.085436f), TypeParam(-1.597520f), TypeParam(0.695524f), TypeParam(0.838739f), TypeParam(0.513032f), TypeParam(0.166242f), TypeParam(-0.546135f), TypeParam(-0.780313f), TypeParam(-0.512993f), TypeParam(-0.449479f), TypeParam(1.594718f), TypeParam(0.953375f), TypeParam(0.692587f), TypeParam(-0.798364f), TypeParam(-0.128799f), TypeParam(-0.456210f), TypeParam(2.098909f), TypeParam(-1.561220f), TypeParam(1.713821f), TypeParam(-0.701970f), TypeParam(-0.287280f), TypeParam(-1.708048f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.934471f, 0.728362f, -0.458301f, -1.040800f, 0.157908f, 0.753451f, -0.122762f, 0.100970f, 0.889432f, 0.495471f, 0.897108f, 0.176205f, 0.134514f, -0.287037f, -0.202498f, -0.637759f, 0.802292f, 1.094459f, 0.445338f, 0.034096f, -0.396126f, -1.184798f, -0.222199f, -0.851887f}; + std::initializer_list Grid_data{TypeParam(0.934471f), TypeParam(0.728362f), TypeParam(-0.458301f), TypeParam(-1.040800f), TypeParam(0.157908f), TypeParam(0.753451f), TypeParam(-0.122762f), TypeParam(0.100970f), TypeParam(0.889432f), TypeParam(0.495471f), TypeParam(0.897108f), TypeParam(0.176205f), TypeParam(0.134514f), TypeParam(-0.287037f), TypeParam(-0.202498f), TypeParam(-0.637759f), TypeParam(0.802292f), TypeParam(1.094459f), TypeParam(0.445338f), TypeParam(0.034096f), TypeParam(-0.396126f), TypeParam(-1.184798f), TypeParam(-0.222199f), TypeParam(-0.851887f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - 
std::initializer_list Y_data{1.037788f, -0.275160f, 0.953595f, -0.518196f, 0.118127f, -1.525148f, -0.413483f, 0.696689f, -0.450182f, -0.696169f, -0.561886f, -0.828986f, 0.343953f, 1.379632f, -0.417260f, -0.781500f, 1.666511f, 1.599268f, 0.106200f, 1.088396f, -2.079140f, -0.612122f, 1.822402f, 1.173807f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.037788f), TypeParam(-0.275160f), TypeParam(0.953595f), TypeParam(-0.518196f), TypeParam(0.118127f), TypeParam(-1.525148f), TypeParam(-0.413483f), TypeParam(0.696689f), TypeParam(-0.450182f), TypeParam(-0.696169f), TypeParam(-0.561886f), TypeParam(-0.828986f), TypeParam(0.343953f), TypeParam(1.379632f), TypeParam(-0.417260f), TypeParam(-0.781500f), TypeParam(1.666511f), TypeParam(1.599268f), TypeParam(0.106200f), TypeParam(1.088396f), TypeParam(-2.079140f), TypeParam(-0.612122f), TypeParam(1.822402f), TypeParam(1.173807f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index c7e263ca3f654..bf58a5d3fc1d5 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -14,6 +14,17 @@ padding_modes = ["zeros", "border", "reflection"] align_corners_options = [True, False] +print( + """ +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +""" +) # Loop over the combinations of parameters torch.manual_seed(0) for opset_version in [16, 20]: @@ -42,11 +53,15 @@ input_tensor, grid_tensor, mode=mode, padding_mode=padding_mode, align_corners=align_corners ) - X_data_str = "{" + ", ".join([f"{x:.6f}f" for x in input_tensor.numpy().flatten()]) + "}" - Grid_data_str = "{" + ", ".join([f"{x:.6f}f" for x in grid_tensor.numpy().flatten()]) + "}" + X_data_str = "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in input_tensor.numpy().flatten()]) + "}" + Grid_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in grid_tensor.numpy().flatten()]) + "}" + ) Y_shape = output_tensor.shape - Y_data_str = "{" + ", ".join([f"{x:.6f}f" for x in output_tensor.numpy().flatten()]) + "}" + Y_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in output_tensor.numpy().flatten()]) + "}" + ) onnx_mode = mode if opset_version >= 20: @@ -58,24 +73,25 @@ onnx_align_corners = 1 if align_corners else 0 test_name = f"test_grid_sample_{opset_version}_{ndim}D_{mode}_{padding_mode}_{'align_corners' if align_corners else 'no_align_corners'}" - print(f"TEST(GridSampleTest, {test_name}) {{") - print(f'OpTester test("GridSample", {opset_version});') - print(f'std::string mode = "{onnx_mode}";') - print(f'std::string padding_mode = "{padding_mode}";') - print(f"int64_t align_corners = {onnx_align_corners};") - print(f"std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") - print(f"std::initializer_list X_data { X_data_str };") - print(f"std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") - print(f"std::initializer_list Grid_data 
{ Grid_data_str };") - print(f"std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") - print(f"std::initializer_list Y_data { Y_data_str };") - - print('test.AddInput("X", X_shape, X_data);') - print('test.AddInput("Grid", Grid_shape, Grid_data);') - print('test.AddAttribute("mode", mode);') - print('test.AddAttribute("padding_mode", padding_mode);') - print('test.AddAttribute("align_corners", align_corners);') - print('test.AddOutput("Y", Y_shape, Y_data);') - print(f"RunTests(test, GetExecutionProviders({opset_version}));") + spaces = " " + print(f"TYPED_TEST(GridSampleTest, {test_name}) {{") + print(f'{spaces}OpTester test("GridSample", {opset_version});') + print(f'{spaces}std::string mode = "{onnx_mode}";') + print(f'{spaces}std::string padding_mode = "{padding_mode}";') + print(f"{spaces}int64_t align_corners = {onnx_align_corners};") + print(f"{spaces}std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") + print(f"{spaces}std::initializer_list X_data { X_data_str };") + print(f"{spaces}std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") + print(f"{spaces}std::initializer_list Grid_data { Grid_data_str };") + print(f"{spaces}std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") + print(f"{spaces}std::initializer_list Y_data { Y_data_str };") + + print(f'{spaces}test.AddInput("X", X_shape, X_data);') + print(f'{spaces}test.AddInput("Grid", Grid_shape, Grid_data);') + print(f'{spaces}test.AddAttribute("mode", mode);') + print(f'{spaces}test.AddAttribute("padding_mode", padding_mode);') + print(f'{spaces}test.AddAttribute("align_corners", align_corners);') + print(f'{spaces}test.AddOutput("Y", Y_shape, Y_data);') + print(f"{spaces}RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 111520ef03e26..84fb6157b8884 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -10,6 +10,13 @@ namespace onnxruntime { namespace test { +template +class ResizeOpTest : public ::testing::Test { +}; + +using ResizeOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ResizeOpTest, ResizeOpTestTypes); + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { @@ -226,26 +233,26 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); } -TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { +TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); - std::vector roi{}; + std::vector roi{}; std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; test.AddAttribute("mode", "linear"); constexpr int64_t N = 1, C = 1, H = 2, W = 4; - std::vector X = { - 1.0f, 2.0f, 3.0f, 4.0f, - 5.0f, 6.0f, 7.0f, 8.0f}; + std::vector X = { + TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f), + TypeParam(5.0f), TypeParam(6.0f), TypeParam(7.0f), TypeParam(8.0f)}; - test.AddInput("X", {N, C, H, W}, X); - test.AddInput("roi", {0}, roi); + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); test.AddInput("scales", {4}, scales, scales_in_initializer); - std::vector Y = {2.66666651f, 4.3333331f}; + 
std::vector Y = {TypeParam(2.66666651f), TypeParam(4.3333331f)}; - test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); // QNN: result diff // TRT: Segmentation fault in A100 std::unordered_set excluded_providers({kQnnExecutionProvider}); diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 83b308b57f26b..a32d43f296250 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -263,17 +263,33 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -template +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + std::vector inputs_T(inputs.size()); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v) { + for (size_t i = 0; i < inputs.size(); i++) { + inputs_T[i] = static_cast(inputs[i]); + } + return inputs_T; + } else { + ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size()); + return inputs_T; + } +} + +template static void TestSlice1DIntData() { - static_assert(std::is_integral_v); - RunSliceTest({6}, - {0, 1, 2, 3, 4, 5}, - {2}, - {4}, - {0}, - {}, - {2}, - {2, 3}); + // static_assert(std::is_integral_v); + RunSliceTest({6}, + GetTypedArray({0.f, 1.f, 2.f, 3.f, 4.f, 5.f}), + {2}, + {4}, + {0}, + {}, + {2}, + GetTypedArray({2.f, 3.f})); } TEST(SliceTest, Slice1D_Int32) { @@ -284,6 +300,21 @@ TEST(SliceTest, Slice1D_Int64) { TestSlice1DIntData(); } +TEST(SliceTest, Slice1D_Float) { + TestSlice1DIntData(); +} + +TEST(SliceTest, Slice1D_Float16) { + TestSlice1DIntData(); +} + +template +class SliceTest : public ::testing::Test { +}; + +using SliceTestTypes = ::testing::Types; +TYPED_TEST_SUITE(SliceTest, SliceTestTypes); + TEST(SliceTest, Slice1D_String) { RunSliceTest({6}, {"0", "1", "2", "3", "4", "5"}, @@ -296,16 +327,16 @@ TEST(SliceTest, Slice1D_String) { } // Only Slice V10 can run the following tests -TEST(SliceTest, Slice1D_WithPositiveSteps) { - RunSliceTest({6}, - {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, - {0}, - {6}, - {0}, - {2}, - {3}, - {0.0f, 2.0f, 4.0f}, - true); +TYPED_TEST(SliceTest, Slice1D_WithPositiveSteps) { + RunSliceTest({6}, + GetTypedArray({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}), + {0}, + {6}, + {0}, + {2}, + {3}, + GetTypedArray({0.0f, 2.0f, 4.0f}), + true); } // In numpy: diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index a0c1d675f506f..4954b82690e0f 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -4,10 +4,18 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "core/providers/cpu/tensor/space_depth_ops.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { namespace test { +template +class TensorOpTest : public ::testing::Test { +}; + +using TensorOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(TensorOpTest, TensorOpTestTypes); + TEST(TensorOpTest, SpaceToDepthTest_1) { OpTester test("SpaceToDepth"); constexpr int64_t blocksize = 2; @@ -36,8 +44,8 @@ TEST(TensorOpTest, SpaceToDepthTest_1) { 3.1f, 3.3f}; test.AddOutput("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); - // TODO: Test is flaky on QNN EP (CPU backend). 
Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -103,8 +111,8 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { 88., 103., 106., 68., 71., 86., 89., 104., 107.}; test.AddOutput("output", {2, 27, 1, 2}, result); - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 - // test is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -259,7 +267,7 @@ TEST(TensorOpTest, DepthToSpaceTest_2) { test.Run(); } -TEST(TensorOpTest, DepthToSpaceTest_3) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_3) { OpTester test("DepthToSpace", 11); // create an opset 11 model with missing default attribute constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -281,8 +289,6 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -298,11 +304,24 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); - test.Run(); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + test.AddInput("input", {N, C, H, W}, X_fp16); + } + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } -TEST(TensorOpTest, DepthToSpaceTest_4) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "DCR" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -325,8 +344,6 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -342,11 +359,25 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); - test.Run(); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + } + + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } -TEST(TensorOpTest, DepthToSpaceTest_5) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "CRD" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -362,15 +393,25 @@ TEST(TensorOpTest, DepthToSpaceTest_5) { 27., 28., 29., 30., 31., 32.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = {0., 9., 1., 10., 2., 11., 18., 27., 19., 28., 20., 29., 3., 12., 4., 13., 5., 14., 21., 30., 22., 31., 23., 32.}; - test.AddOutput("output", {1, 1, 4, 6}, result); - test.Run(); + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {1, 1, 4, 6}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {1, 1, 4, 6}, result_fp16); + } + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. 
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});
 }

 TEST(TensorOpTest, DepthToSpaceTest_CRD_Batched) {
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index 066302a4a37d1..48872404f08bd 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -178,6 +178,17 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {
   RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
 }

+template <typename T>
+std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
+  if constexpr (std::is_same<T, float>::value) {
+    return inputs;
+  } else {
+    std::vector<MLFloat16> inputs_fp16(inputs.size());
+    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
+    return inputs_fp16;
+  }
+}
+
 TEST(SplitOperatorTest, Axis0UnequalSplitString) {
   constexpr int64_t axis = 0;
   std::vector<ShapeAndStringData> outputs;
@@ -222,6 +233,26 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) {
   RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider});
 }

+TEST(SplitOperatorTest, Axis1EqualSplitFloat16) {
+  constexpr int64_t axis = 1;
+  std::vector<ShapeAndData<MLFloat16>> outputs;
+
+  // input shape and data
+  ShapeAndData<MLFloat16> input = {{2, 4},
+                                   GetTypedArray<MLFloat16>({1.f, 2.f, 3.f, 4.f,
+                                                             5.f, 6.f, 7.f, 8.f})};
+
+  outputs.push_back({{2, 2},
+                     GetTypedArray<MLFloat16>({1.f, 2.f,
+                                               5.f, 6.f})});
+
+  outputs.push_back({{2, 2},
+                     GetTypedArray<MLFloat16>({3.f, 4.f,
+                                               7.f, 8.f})});
+  RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
+  RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider});
+}
+
 TEST(SplitOperatorTest, Axis1EqualSplitString) {
   constexpr int64_t axis = 1;
   std::vector<ShapeAndStringData> outputs;
diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
index 01dba55ceb8ed..3b46dc3f5d6a2 100644
--- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
@@ -12,6 +12,13 @@ namespace onnxruntime {
 namespace test {

+template <typename T>
+class TransposeOpTest : public ::testing::Test {
+};
+
+using TransposeOpTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(TransposeOpTest, TransposeOpTestTypes);
+
 TEST(TransposeOpTest, IsTransposeReshapeTest) {
   std::vector<int64_t> input_dims{1, 2, 3, 4, 1};
   std::vector<size_t> perm{0, 1, 2, 3, 4};
@@ -62,18 +69,27 @@ void TransposeTest(const std::vector<int64_t>& input_shape,
   }
 }

+template <typename T>
+std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
+  if constexpr (std::is_same<T, float>::value) {
+    return inputs;
+  } else {
+    std::vector<MLFloat16> inputs_fp16(inputs.size());
+    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
+    return inputs_fp16;
+  }
+}
+
 // Test 2 dimensional transpose, with no permutation attribute specified
-TEST(TransposeOpTest, TwoDimNoAttr) {
+TYPED_TEST(TransposeOpTest, TwoDimNoAttr) {
   std::vector<int64_t> input_shape({2, 3});
-  std::vector<float> input_vals = {
-      1.0f, 2.0f, 3.0f,
-      4.0f, 5.0f, 6.0f};
+  std::vector<TypeParam> input_vals = GetTypedArray<TypeParam>({1.0f, 2.0f, 3.0f,
+                                                                4.0f, 5.0f, 6.0f});

   std::vector<int64_t> expected_shape({3, 2});
-  std::vector<float> expected_vals = {
-      1.0f, 4.0f,
-      2.0f, 5.0f,
-      3.0f, 6.0f};
+  std::vector<TypeParam> expected_vals = GetTypedArray<TypeParam>({1.0f, 4.0f,
+                                                                   2.0f, 5.0f,
+                                                                   3.0f, 6.0f});

   TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals,
                 {kTensorrtExecutionProvider}, {7, 21});  // TensorRT: SegFault error
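Note: the GetTypedArray helper above is now copy-pasted into slice_op.test.cc, split_op_test.cc, and transpose_test.cc. If the duplication ever becomes a maintenance burden, the three copies could be folded into one shared utility. A minimal sketch follows; the idea of a shared header and the exact signature are assumptions on my part, not part of this diff, and only float and MLFloat16 instantiations are supported (matching the tests):

// Hypothetical shared test utility; mirrors the per-file helpers in this diff.
// Converts float literals to the element type under test (float or MLFloat16).
template <typename T>
std::vector<T> GetTypedArray(std::vector<float> inputs) {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, MLFloat16>,
                "only float and MLFloat16 are supported");
  if constexpr (std::is_same_v<T, float>) {
    return inputs;
  } else {
    // Bulk-convert with the same ConvertFloatToMLFloat16 utility the tests already use.
    std::vector<MLFloat16> inputs_fp16(inputs.size());
    ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
    return inputs_fp16;
  }
}

Call sites would then read GetTypedArray<TypeParam>({...}) exactly as in the typed tests above.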
diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc
index f03782c33c30a..9b83dd281a56d 100644
--- a/onnxruntime/test/providers/qnn/cast_test.cc
+++ b/onnxruntime/test/providers/qnn/cast_test.cc
@@ -49,7 +49,8 @@ static GetTestModelFn BuildCastTestCase(const std::vector<int64_t>& shape,
 template <typename InputType>
 static void RunCastOpTest(const std::vector<int64_t>& shape, ONNX_NAMESPACE::TensorProto_DataType dst_type,
                           ExpectedEPNodeAssignment expected_ep_assignment,
-                          bool use_htp) {
+                          bool use_htp,
+                          bool enable_fp16_precision = true) {
   ProviderOptions provider_options;
 #if defined(_WIN32)
   provider_options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll";
@@ -57,6 +58,12 @@ static void RunCastOpTest(const std::vector<int64_t>& shape, ONNX_NAMESPACE::Ten
   provider_options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so";
 #endif

+  if (use_htp && enable_fp16_precision) {
+    provider_options["enable_htp_fp16_precision"] = "1";
+  } else {
+    provider_options["enable_htp_fp16_precision"] = "0";
+  }
+
   RunQnnModelTest(BuildCastTestCase<InputType>(shape, dst_type),
                   provider_options,
                   13,  // opset
@@ -93,19 +100,19 @@ TEST_F(QnnCPUBackendTests, TestCastFloatToInt32) {
 // Cast int32_t to float on HTP
 TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) {
   RunCastOpTest<int32_t>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All,
-                         true);
+                         true, false);
 }

 // Cast uint8_t to float on HTP
 TEST_F(QnnHTPBackendTests, TestCastUInt8ToFloatHTP) {
   RunCastOpTest<uint8_t>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All,
-                         true);
+                         true, false);
 }

 // Cast float to int32_t on HTP
 TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) {
   RunCastOpTest<float>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All,
-                       true);
+                       true, false);
 }

 // Cast int64_t to int32_t on HTP
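The same enable_htp_fp16_precision plumbing recurs in the clip, matmul, and transpose test helpers in this diff. A small factored-out helper would capture the pattern; this is a hypothetical sketch (the name and placement are assumptions), though the option keys come straight from the changes above:

// Hypothetical helper: QNN HTP provider options with the fp16-precision toggle
// these test helpers keep repeating inline.
static ProviderOptions MakeHtpProviderOptions(bool enable_fp16_precision) {
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  // "1" enables the HTP fp16 fast path; "0" keeps fp32 math for accuracy-sensitive tests.
  provider_options["enable_htp_fp16_precision"] = enable_fp16_precision ? "1" : "0";
  return provider_options;
}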
"libQnnCpu.so" : "libQnnHtp.so"; #endif + if (!on_cpu_backend && enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), provider_options, opset, @@ -80,7 +87,9 @@ TEST_F(QnnHTPBackendTests, Clip_f32) { {TestInputDef({}, true, {-5.0f}), TestInputDef({}, true, {5.0f})}, ExpectedEPNodeAssignment::All, - on_cpu_backend); + on_cpu_backend, + 13, + false); } // Test Clip with int32 on HTP diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index d8c34d6a6c6ed..708aac03ceb2e 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -117,7 +117,8 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, ExpectedEPNodeAssignment expected_ep_assignment, int opset = 21, bool use_contrib_qdq = false, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -125,6 +126,12 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), BuildQDQPerChannelMatMulTestCase(input_def, weights_def, @@ -275,7 +282,8 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { ExpectedEPNodeAssignment::All, 21, false, - QDQTolerance(0.007f)); + QDQTolerance(0.007f), + false); } // Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index c4367aeb52edc..236b66a2d8a78 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -912,10 +912,28 @@ static GetTestModelFn BuildCastAddTestCase() { }; } -// A repro of QC case 06838696, accuracy issue for Cast + Op (quantized) -// the value pair(1, 0.00392156886) at index #1 don't match, -// which is -0.996078 from 1 -TEST_F(QnnHTPBackendTests, DISABLED_CastAddHTPAccuracyTest) { +TEST_F(QnnHTPBackendTests, ProfilingTest) { + onnxruntime::ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_fp16_precision"] = "1"; + provider_options["profiling_level"] = "detailed"; + provider_options["profiling_file_path"] = "detailed_profile.csv"; + + auto input_defs = {TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f), + TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f)}; + RunQnnModelTest(BuildOpTestCase("Add", input_defs, {}, {}, kOnnxDomain), + provider_options, + 13, + ExpectedEPNodeAssignment::All, + 0.008f); +} + +TEST_F(QnnHTPBackendTests, CastAddHTPAccuracyTest) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index bb77c92668853..7f55a44c748b6 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ 
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index bb77c92668853..7f55a44c748b6 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -33,16 +33,37 @@ struct QuantParams {
   float scale;
   QType zero_point;

+  inline std::pair<float, float> CalcRminRmax() const {
+    constexpr float qmin = static_cast<float>(std::numeric_limits<QType>::min());
+    constexpr float qmax = static_cast<float>(std::numeric_limits<QType>::max());
+    const float qrange = (qmax - qmin);
+    const float rrange = this->scale * qrange;
+    const float rmin = -(static_cast<float>(this->zero_point) - qmin) * this->scale;
+    const float rmax = rrange + rmin;
+
+    return {rmin, rmax};
+  }
+
+  inline bool IsSymmetric() const {
+    constexpr float qmin = static_cast<float>(std::numeric_limits<QType>::min());
+    constexpr float qmax = static_cast<float>(std::numeric_limits<QType>::max());
+    float init_zero_point = (qmin + qmax) / 2.0;
+    const QType symm_zero_point = static_cast<QType>(RoundHalfToEven(
+        std::max(qmin, std::min(qmax, init_zero_point))));
+
+    return this->zero_point == symm_zero_point;
+  }
+
   static QuantParams Compute(float rmin, float rmax, bool symmetric = false) {
     return Compute(
         rmin,
         rmax,
-        static_cast<float>(std::numeric_limits<QType>::min()),
-        static_cast<float>(std::numeric_limits<QType>::max()),
+        std::numeric_limits<QType>::min(),
+        std::numeric_limits<QType>::max(),
         symmetric);
   }

-  static QuantParams Compute(float rmin, float rmax, float qmin, float qmax, bool symmetric = false) {
+  static QuantParams Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) {
     // Ensure a minimum range of 0.0001 (required by QNN)
     rmax = std::max(rmax, rmin + 0.0001f);
@@ -56,8 +77,8 @@ struct QuantParams {
       rmin = -abs_max;
     }

-    float qmin_flt = qmin;
-    float qmax_flt = qmax;
+    const float qmin_flt = qmin;
+    const float qmax_flt = qmax;
     const float scale = (rmax - rmin) / (qmax_flt - qmin_flt);
     float initial_zero_point = 0.0f;
@@ -76,6 +97,13 @@ struct QuantParams {
   }
 };

+// Utility that converts quantization parameters from one type to another (e.g., uint8 to uint16).
+template <typename DstQType, typename SrcQType>
+inline QuantParams<DstQType> ConvertQuantParams(QuantParams<SrcQType> src_qparams) {
+  std::pair<float, float> src_rmin_rmax = src_qparams.CalcRminRmax();
+  return QuantParams<DstQType>::Compute(src_rmin_rmax.first, src_rmin_rmax.second, src_qparams.IsSymmetric());
+}
+
 // Signature for function that builds a QDQ model.
 // The parameter `output_qparams` contains quantization parameters that *can* be used for the QDQ model output.
 // These output quantization parameters are computed by first running the float32 model and determining the
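To make the conversion concrete, here is a worked pass through ConvertQuantParams for the uint8 parameters used by the Convert-fusion test below. This is an illustrative sketch, not diff content: it assumes the symmetric branch of Compute picks the midpoint zero point, and the explicit template argument follows the <DstQType, SrcQType> ordering reconstructed above.

// Converting QuantParams<uint8_t>{scale = 1.0f, zero_point = 128} to uint16:
//   CalcRminRmax(): qrange = 255 - 0 = 255, rrange = 1.0f * 255 = 255,
//                   rmin = -(128 - 0) * 1.0f = -128, rmax = 255 + (-128) = 127.
//   IsSymmetric():  RoundHalfToEven((0 + 255) / 2.0) = 128 == zero_point, so true.
//   Compute(-128, 127, symmetric = true): the range widens to [-128, 128] via abs_max,
//                   scale = (128 - (-128)) / 65535 = 256 / 65535 ~= 0.0039063f,
//                   zero_point = RoundHalfToEven((0 + 65535) / 2.0f) = 32768.
QuantParams<uint8_t> u8_qparams = {1.0f, 128};
QuantParams<uint16_t> u16_qparams = ConvertQuantParams<uint16_t>(u8_qparams);
// u16_qparams.scale ~= 0.0039063f, u16_qparams.zero_point == 32768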
diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
index 2ebc2c6251b44..018720fd8b71f 100644
--- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
+++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -1206,6 +1206,80 @@ TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) {
                   ExpectedEPNodeAssignment::All);
 }

+// Builds a graph where a (DQ -> Q) sequence at the graph's output is fused into a QNN Convert operator.
+// ONNX Graph: DQ -> Add -> Q -> DQ -> Q -> graph_output
+// QNN Graph:  DQ -> Add -> Q -> Convert -> graph_output
+template <typename InQuantType, typename OutQuantType>
+static GetTestModelFn BuildDQQConvertAtOutputTestCase(const TestInputDef<float>& input0_def,
+                                                      const TestInputDef<float>& input1_def,
+                                                      const QuantParams<OutQuantType>& output_qparams) {
+  return [input0_def, input1_def, output_qparams](ModelTestBuilder& builder) {
+    // Input0 -> Quantize(InQuantType) -> Dequantize(InQuantType to float) -> input0_after_qdq
+    NodeArg* input0 = MakeTestInput(builder, input0_def);
+    QuantParams<InQuantType> input0_qparams = GetTestInputQuantParams<InQuantType>(input0_def);
+    NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale,
+                                               input0_qparams.zero_point);
+
+    // Input1 -> Quantize(InQuantType) -> Dequantize(InQuantType to float) -> input1_after_qdq
+    NodeArg* input1 = MakeTestInput(builder, input1_def);
+    QuantParams<InQuantType> input1_qparams = GetTestInputQuantParams<InQuantType>(input1_def);
+    NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale,
+                                               input1_qparams.zero_point);
+
+    // Add op -> op_output
+    auto* op_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {input0_after_qdq, input1_after_qdq}, {op_output});
+
+    // op_output -> Quantize(InQuantType) -> add_out_q
+    QuantParams<InQuantType> add_out_qparams = ConvertQuantParams<InQuantType>(output_qparams);
+    add_out_qparams.scale *= 1.01f;  // Make qparams slightly different so DQ->Q are not optimized out.
+    NodeArg* add_out_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(op_output, add_out_qparams.scale,
+                                  add_out_qparams.zero_point, add_out_q);
+
+    // Add DQ
+    NodeArg* add_out_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(add_out_q, add_out_qparams.scale,
+                                    add_out_qparams.zero_point, add_out_dq);
+
+    // Add a Q to quantize to OutQuantType
+    // The previous DQ and this Q will be fused into a QNN Convert.
+    NodeArg* q_conv_out = builder.MakeOutput();
+    builder.AddQuantizeLinearNode(add_out_dq, output_qparams.scale, output_qparams.zero_point,
+                                  q_conv_out);
+  };
+}
+
+// Test fusion of (DQ -> Q) into QNN's Convert op using the same quant type.
+TEST_F(QnnHTPBackendTests, DQ_Q_ConvertFusion_SameType) {
+  std::vector<float> input0_data = {-8.0f, -6.0f, -2.0f, 0.0f, 2.0f, 4.0f, 6.0f, 8.0f};
+  std::vector<float> input1_data = {-8.0f, -6.0f, -2.0f, 0.0f, 2.0f, 4.0f, 6.0f, 8.0f};
+  TestInputDef<float> input0_def({1, 2, 2, 2}, false, input0_data);
+  TestInputDef<float> input1_def({1, 2, 2, 2}, false, input1_data);
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  QuantParams<uint8_t> out_qparams_u8 = {1.0f, 128};
+  QuantParams<uint16_t> out_qparams_u16 = {1.0f, 32768};
+
+  // QNN Convert op converts uint8 to uint8 at the graph output. Slightly different scale values.
+  RunQnnModelTest(BuildDQQConvertAtOutputTestCase<uint8_t, uint8_t>(input0_def, input1_def, out_qparams_u8),
+                  provider_options,
+                  21,
+                  ExpectedEPNodeAssignment::All);
+
+  // QNN Convert op converts uint16 to uint16 at the graph output. Slightly different scale values.
TEST_F(QnnHTPBackendTests, UnaryOp_HardSigmoid_QU8) { RunQDQOpTest<uint8_t>("HardSigmoid", {TestInputDef<float>({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, @@ -1259,14 +1333,6 @@ TEST_F(QnnHTPBackendTests, UnaryOp_HardSigmoid_F32_as_FP16) { } // Check that QNN EP can support float16 HardSigmoid on HTP -// It is using decompose way for FP16 since ElementWiseNeuron failed to finalize the graph with the error below: -// \HTP\src\hexagon\prepare\tcm_migration.cc:1829:ERROR:no properties registered for q::QNN_HardSigmoid -// \HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::QNN_HardSigmoid -// \HTP\HTP\src\hexagon\prepare\graph_prepare.cc:1238:ERROR:Op 0x101000000010 preparation failed with err:-1 -// Completed stage: Graph Transformations and Optimizations (16361 us) -// QnnDsp "node" generated: could not create op -// QnnDsp RouterWindows graph prepare failed 12 -// QnnDsp Failed to finalize graph (id: 1) with err 1002 TEST_F(QnnHTPBackendTests, UnaryOp_HardSigmoid_FP16) { std::vector<float> input_data = GetFloatDataInRange(-5.0f, 5.0f, 16); diff --git a/onnxruntime/test/providers/qnn/transpose_htp_test.cc b/onnxruntime/test/providers/qnn/transpose_htp_test.cc index 119b8301f36ed..63746e22d214d 100644 --- a/onnxruntime/test/providers/qnn/transpose_htp_test.cc +++ b/onnxruntime/test/providers/qnn/transpose_htp_test.cc @@ -90,7 +90,8 @@ static void RunTransposeQDQTest(const TestInputDef<float>& input_def, template <typename DataType> static void RunTransposeNonQDQOnHTP(const TestInputDef<DataType>& input_def, const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -98,6 +99,12 @@ static void RunTransposeNonQDQOnHTP(const TestInputDef<DataType>& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildTransposeTestCase<DataType>(input_def, attrs), provider_options, 13, @@ -123,7 +130,7 @@ TEST_F(QnnHTPBackendTests, TransposeInt32OnHTP) { TEST_F(QnnHTPBackendTests, TransposeFloatOnHTP) { RunTransposeNonQDQOnHTP<float>(TestInputDef<float>({1, 3, 224, 128}, false, 0, 10.0f), {utils::MakeAttribute("perm", std::vector<int64_t>{0, 2, 3, 1})}, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, false); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index dc21d4e4a5890..08ec5de328b9d 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -29,6 +29,12 @@ GREEN = "\033[32m" RESET = "\033[0m" +ORT_TYPE = TensorProto.FLOAT +TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 +NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 +RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 +ATOL = RTOL + class Formats: BSNH = 0 @@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size,
config.q_sequence_length, @@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -264,7 +270,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -378,7 +384,7 @@ def create_group_query_attention_graph_past( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -391,7 +397,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -401,7 +407,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -424,7 +430,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ 
-433,7 +439,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -445,7 +451,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -453,7 +459,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -464,12 +470,12 @@ def create_group_query_attention_graph_past( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -479,7 +485,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True): config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) key_padding_mask = generate_random_padding_mask( @@ -722,13 +728,13 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -835,13 +841,13 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -1017,9 +1023,11 @@ def attention_ref( attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1076,7 +1084,7 @@ def 
parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt( kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...").to(dtype=TORCH_TYPE) k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded @@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "No buff", @@ -1458,8 +1466,8 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ 
-1467,7 +1475,7 @@ def parity_check_gqa_past( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1476,7 +1484,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1485,7 +1493,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1494,7 +1502,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1503,7 +1511,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1534,8 +1542,8 @@ def parity_check_gqa_past( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1624,11 +1632,11 @@ def parity_check_gqa_past( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): torch.manual_seed(69) q = torch.randn( @@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = 
torch.randn( @@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff( angle = ( torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi ) - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff( out = out.detach().cpu().numpy() # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "NO buff", @@ -1983,8 +1991,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -1996,8 +2004,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 509f56664e572..102846e08ac5f 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -46,7 +46,7 @@ namespace qnnctxgen { "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Only applies to float32 models on the HTP backend. Defaults to '1' (FP16 precision). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n"
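The same flag travels through the provider-options map in the tests above; from Python it is set when creating a session. A hedged sketch of that usage (the model path and option values are placeholders, not part of the patch):

import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",  # placeholder float32 model
    providers=[(
        "QNNExecutionProvider",
        {
            "backend_path": "QnnHtp.dll",      # libQnnHtp.so on Linux
            "enable_htp_fp16_precision": "1",  # '1' is now the default
        },
    )],
)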
All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def make_model(model_path: str): + """ + Creates a QDQ model with a (DQ -> Transpose -> Q -> GRAPH OUTPUT) sequence. The Transpose is optimized out + and the TransposeOptimizer should also remove the empty (DQ -> Q) sequence. + """ + input0_shape = (1, 3, 4, 4) + + inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)] + outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.UINT8, None)] + + mul_weight_scale_data = np.array(1.0, dtype=np.float32) + mul_weight_zp_data = np.array(0, dtype=np.int8) + + initializers = [ + onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"), + onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"), + onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"), + onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"), + onnx.numpy_helper.from_array(mul_weight_scale_data, "mul_weight_scale"), + onnx.numpy_helper.from_array(mul_weight_zp_data, "mul_weight_zp"), + ] + nodes = [] + + # Transpose to channel-last + tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1)) + nodes.append(tp0_node) + + # Q_0 + q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node") + nodes.append(q0_node) + + # DQ_0 + dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node") + nodes.append(dq0_node) + + # Sigmoid + sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node") + nodes.append(sigmoid_node) + + # Q_1 + q1_node = onnx.helper.make_node( + "QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node" + ) + nodes.append(q1_node) + + # DQ_1 + dq1_node = onnx.helper.make_node( + "DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node" + ) + nodes.append(dq1_node) + + # DQ for mul input[1] + mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8) + mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight") + initializers.append(mul_weight) + + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["mul_weight", "mul_weight_scale", "mul_weight_zp"], + ["mul_input_1"], + name="dq_mul_input_1", + ) + ) + + # Mul + mul_node = onnx.helper.make_node("Mul", ["dq1_out", "mul_input_1"], ["mul_out"], name="mul_node") + nodes.append(mul_node) + + # Q_2 + q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node") + nodes.append(q2_node) + + # DQ_2 + dq2_node = onnx.helper.make_node( + "DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node" + ) + nodes.append(dq2_node) + + # Transpose to channel-first + tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["tp1_out"], name="tp1_node", perm=(0, 3, 1, 2)) + nodes.append(tp1_node) + + # Q_3 to graph output + nodes.append( + onnx.helper.make_node("QuantizeLinear", ["tp1_out", "scale_inv_255", "zp_0"], ["output0"], name="q3_node") + ) + + graph = onnx.helper.make_graph( + nodes, + "transpose_opt_empty_dqq_graph_output", + inputs, + outputs, + initializer=initializers, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 19), + ] + qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + print("[INFO]: Running onnx.checker on qdq model") 
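The (DQ -> Q) pair left at the graph output uses identical scale/zero-point values ("scale_inv_255", "zp_0" on q2/dq2 and q3), which is why the optimizer can drop it: quantize(dequantize(q)) is the identity when the parameters match (up to clamping). A quick Python check, independent of the patch:

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    return max(qmin, min(qmax, round(x / scale) + zero_point))

scale, zero_point = 1.0 / 255.0, 0  # the shared qparams from the model above
assert all(quantize(dequantize(q, scale, zero_point), scale, zero_point) == q
           for q in range(256))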
diff --git a/onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx b/onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx new file mode 100644 index 0000000000000..e5a0675c61cae Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx differ diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index 6bc0f8d105495..f6d5d133262c4 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -38,6 +38,10 @@ void VerifyOutput(const std::string& output_name, EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan<int32_t>(), tensor.DataAsSpan<int32_t>())) << " mismatch for " << output_name; break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan<uint16_t>(), tensor.DataAsSpan<uint16_t>())) + << " mismatch for " << output_name; + break; case ONNX_NAMESPACE::TensorProto_DataType_UINT8: EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan<uint8_t>(), tensor.DataAsSpan<uint8_t>())) << " mismatch for " << output_name; @@ -55,6 +59,11 @@ void VerifyOutput(const std::string& output_name, ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan<float>())); break; } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + EXPECT_THAT(expected_tensor.DataAsSpan<MLFloat16>(), + ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan<MLFloat16>())); + break; + } default: ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type); } diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..5173125cb8634 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -23,7 +23,8 @@ enum DataLocation { DATA_LOCATION_CPU = 1, DATA_LOCATION_CPU_PINNED = 2, DATA_LOCATION_TEXTURE = 3, - DATA_LOCATION_GPU_BUFFER = 4 + DATA_LOCATION_GPU_BUFFER = 4, + DATA_LOCATION_ML_TENSOR = 5 }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); @@ -235,7 +236,8 @@ void OrtFree(void* ptr) { OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && - data_location != DATA_LOCATION_GPU_BUFFER) { + data_location != DATA_LOCATION_GPU_BUFFER && + data_location != DATA_LOCATION_ML_TENSOR) { std::ostringstream ostr; ostr << "Invalid data location: " << data_location; CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); @@ -264,10 +266,15 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { OrtMemoryInfo* memory_info = nullptr; - if (data_location != DATA_LOCATION_GPU_BUFFER) { - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); - } else { - RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + switch (data_location) { + case DATA_LOCATION_GPU_BUFFER: + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + break; + case DATA_LOCATION_ML_TENSOR:
RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + break; + default: + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); } REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); @@ -418,15 +425,18 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, if (output_location != DATA_LOCATION_NONE && output_location != DATA_LOCATION_CPU && output_location != DATA_LOCATION_CPU_PINNED && - output_location != DATA_LOCATION_GPU_BUFFER) { + output_location != DATA_LOCATION_GPU_BUFFER && + output_location != DATA_LOCATION_ML_TENSOR) { std::ostringstream ostr; ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); } OrtMemoryInfo* memory_info = nullptr; - if (output_location != DATA_LOCATION_GPU_BUFFER) { + if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_TENSOR) { RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else if (output_location == DATA_LOCATION_ML_TENSOR) { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } else { RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 70ed295887994..68332d07a9782 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -202,5 +202,38 @@ Module['jsepInit'] = (name, params) => { Module.jsepUploadExternalBuffer = (dataId, buffer) => { backend['upload'](dataId, buffer); }; + } else if (name === 'webnn') { + // Functions called from EM_ASM need to be assigned in a way that can be minified. + // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't + // change the name. + + [Module.jsepBackend, + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module['jsepEnsureTensor'], + Module.jsepUploadTensor, + Module['jsepDownloadTensor'], + ] = params; + + // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. + Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; + + // Functions called from JS also need to have explicit names. 
+ const backend = Module.jsepBackend; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); + }; + Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { + backend['registerMLContext'](sessionId, mlContext); + }; + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { + return backend['createMLTensorDownloader'](tensorId, type); + }; + Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { + return backend['registerMLTensor'](tensor, dataType, shape); + }; } }; diff --git a/packages.config b/packages.config index 24289f36689a7..597ca77a321c5 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 7d384f7b1df67..72d9ce72ea7cb 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,5 +1,7 @@ # This file is auto updated by dependabot -lintrunner-adapters>=0.12.4 +# When any package below is changed, you should run "lintrunner init" again. +lintrunner==0.12.5 +lintrunner-adapters==0.12.4 # RUFF ruff==0.5.4 # BLACK-ISORT diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8c2451778420c..0806b56a95c9d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2097,10 +2097,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if not args.disable_ml_ops and not args.use_tensorrt: run_subprocess([sys.executable, "onnxruntime_test_python_mlops.py"], cwd=cwd, dll_path=dll_path) - if args.use_tensorrt: - run_subprocess( - [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path - ) + # if args.use_tensorrt: + # run_subprocess( + # [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path + # ) try: import onnx # noqa: F401 diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index bb4cfb2e09dcc..ae0769e7fb93c 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -20,6 +20,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Mul|| |ai.onnx:Pow|Only supports cases when both inputs are fp32.| +|ai.onnx:Reciprocal|CoreML requires an `epsilon` parameter (default 1e-4) that ONNX does not provide.| |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation.
There are too many permutations to describe the valid combinations.| @@ -27,5 +28,6 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Split|If provided, `splits` must be constant.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| +|ai.onnx:Sqrt|| |ai.onnx:Tanh|| |ai.onnx:Transpose|| diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index bcce208aea2c1..71b14e676f8b1 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -283,7 +283,7 @@ stages: - stage: Llama2_7B_ONNX dependsOn: - Build_Onnxruntime_Cuda - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) + condition: or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'), eq(variables['UseA100'], '1')) jobs: - job: Llama2_7B_ONNX timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 4bcbc12574b4d..e2d977bd60986 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -83,7 +83,7 @@ variables: value: 11.8 - name: win_trt_home - value: $(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 - name: win_cuda_home value: $(Agent.TempDirectory)\v11.8 diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 785dc901d6e43..7118e85e9ea4b 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -63,9 +63,9 @@ variables: value: '' - name: win_trt_home ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5 + value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 - name: win_cuda_home ${{ if eq(parameters.CudaVersion, '11.8') }}: value: $(Agent.TempDirectory)\v11.8 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 008292d855fc0..6b6630b4bb0f2 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -44,9 +44,9 @@ variables: value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240719.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.3.0.26-1.cuda11.8 + value: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.3.0.26-1.cuda12.5 + value: 10.4.0.26-1.cuda12.6 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index e172611d898bf..fb2c86dbf68e3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,12 +8,12 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.3.cuda_12_5_cudnn_9 + default: 10.4.cuda_12_5_cudnn_9 values: - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 - - 10.3.cuda_11_8_cudnn_8 - - 10.3.cuda_12_5_cudnn_9 + - 10.4.cuda_11_8_cudnn_8 + - 10.4.cuda_12_5_cudnn_9 - BIN - name: UseTensorrtOssParser diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index 4276e6cfba38a..b1e5816fb748e 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -61,7 +61,7 @@ stages: ${{ if eq(parameters.CudaVersion, '12.2') }}: DockerBuildArgs: " --build-arg BASEIMAGE=nvidia/cuda:12.2.2-devel-ubuntu20.04 - --build-arg TRT_VERSION=10.3.0.26-1+cuda12.5 + --build-arg TRT_VERSION=10.4.0.26-1+cuda12.6 --build-arg BUILD_UID=$( id -u ) " ${{ else }}: diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 3853bdbd1eb88..79f0732b245e2 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -226,7 +226,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_trt.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index de2677ebc6594..5ba1e78cbbf0c 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -55,7 +55,7 @@ stages: python_wheel_suffix: '_gpu' timeout: 480 docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 - trt_version: '10.3.0.26-1.cuda11.8' + trt_version: '10.4.0.26-1.cuda11.8' cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 430dc89b5b097..61e181a6004e9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -58,6 +58,10 @@ stages: showWarnings: true workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + - template: ../templates/jar-maven-signing-win.yml + parameters: + 
JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + - task: CopyFiles@2 displayName: 'Copy Java Files to Artifact Staging Directory' inputs: diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 846fae29e45ab..805094864956d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -49,9 +49,9 @@ jobs: value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240719.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.3.0.26-1.cuda11.8 + value: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.3.0.26-1.cuda12.5 + value: 10.4.0.26-1.cuda12.6 pool: ${{ parameters.machine_pool }} steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index dcde93e261c0d..034f5221aba49 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -78,9 +78,9 @@ stages: - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.3.0.26-1.cuda11.8 + value: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.3.0.26-1.cuda12.5 + value: 10.4.0.26-1.cuda12.5 steps: - checkout: self clean: true @@ -147,9 +147,9 @@ stages: value: '12' - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.3.0.26-1.cuda11.8 + value: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.3.0.26-1.cuda12.5 + value: 10.4.0.26-1.cuda12.6 steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime submodules: false diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml index ed09b490c3f4d..119024f8bd3e2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -65,9 +65,9 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} ${{ if eq(parameters.cuda_version, '11.8') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{ if eq(parameters.cuda_version, '12.2') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 
--cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ${{ if eq(parameters.enable_linux_gpu, true) }}: - template: ../templates/py-linux-gpu.yml @@ -79,7 +79,7 @@ stages: cuda_version: ${{ parameters.cuda_version }} ${{ if eq(parameters.cuda_version, '11.8') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 - trt_version: 10.3.0.26-1.cuda11.8 + trt_version: 10.4.0.26-1.cuda11.8 ${{ if eq(parameters.cuda_version, '12.2') }}: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240719.1 - trt_version: 10.3.0.26-1.cuda12.5 + trt_version: 10.4.0.26-1.cuda12.6 diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 8ce0e09dce605..ecc0a53f028a4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -102,6 +102,10 @@ jobs: /bin/bash /onnxruntime_src/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh workingDirectory: $(Build.SourcesDirectory) + - template: jar-maven-signing-linux.yml + parameters: + JarFileDirectory: '$(artifacts_directory)' + - task: PublishBuildArtifacts@1 inputs: pathtoPublish: '$(artifacts_directory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 3e90a401d4deb..a483db2f9688e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -236,6 +236,10 @@ stages: showWarnings: true workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + - template: jar-maven-signing-win.yml + parameters: + JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + - task: CopyFiles@2 displayName: 'Copy Java Files to Artifact Staging Directory' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index cbba1cb8ba8bd..39479e1b8d208 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.184 + version: 1.0.188 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.184 + version: 1.0.188 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. 
diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml new file mode 100644 index 0000000000000..96be3b7b0746e --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -0,0 +1,55 @@ +parameters: + - name: JarFileDirectory + type: string + +steps: + - task: AzureKeyVault@2 + displayName: 'Get GnuPG signing keys' + inputs: + azureSubscription: 'OnnxrunTimeCodeSign_20240611' + KeyVaultName: 'ort-release' + SecretsFilter: 'java-pgp-pwd,java-pgp-key' + RunAsPreJob: false + + - task: CmdLine@2 + displayName: 'Sign jar files: GnuPG and sha256' + inputs: + workingDirectory: '$(Build.SourcesDirectory)' + script: | + #!/bin/bash + set -ex + + jar_file_directory='${{ parameters.JarFileDirectory }}' + working_directory='$(Build.SourcesDirectory)' + original_private_key='$(java-pgp-key)' + original_passphrase='$(java-pgp-pwd)' + + private_key_file=$working_directory/private_key.txt + passphrase_file=$working_directory/passphrase.txt + + echo "Generating GnuPG key files." + printf "%s" "$original_private_key" >$private_key_file + printf "%s" "$original_passphrase" >$passphrase_file + echo "Generated GnuPG key files." + + echo "Importing GnuPG private key file." + gpg --batch --import $private_key_file + echo "Imported GnuPG private key file." + + for file in $(find $jar_file_directory -type f); do + echo "GnuPG signing to file: $file" + gpg --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file + echo "GnuPG signed to file: $file" + done + + for file in $(find $jar_file_directory -type f); do + echo "Adding checksum of sha256 to file: $file" + sha256sum $file | awk '{print $1}' >$file.sha256 + echo "Added checksum of sha256 to file: $file" + done + + echo "GnuPG and sha256 signing to files completed." + echo "Deleting GnuPG key files." + rm -f $private_key_file + rm -f $passphrase_file + echo "Deleted GnuPG key files." diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml new file mode 100644 index 0000000000000..182a2ebe3b4c9 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml @@ -0,0 +1,70 @@ +parameters: + - name: JarFileDirectory + type: string + +steps: + - task: AzureKeyVault@2 + displayName: 'Get GnuPG signing keys' + inputs: + azureSubscription: 'OnnxrunTimeCodeSign_20240611' + KeyVaultName: 'ort-release' + SecretsFilter: 'java-pgp-pwd,java-pgp-key' + RunAsPreJob: false + + - task: PowerShell@2 + displayName: 'Sign jar files: GnuPG and sha256' + inputs: + targetType: 'inline' + workingDirectory: '$(Build.SourcesDirectory)' + script: | + $jar_file_directory = '${{ parameters.JarFileDirectory }}' + $working_directory = '$(Build.SourcesDirectory)' + + $original_passphrase='$(java-pgp-pwd)' + $original_private_key='$(java-pgp-key)' + + $gpg_exe_path = "C:\Program Files (x86)\gnupg\bin\gpg.exe" + + $passphrase_file = Join-Path -Path $working_directory -ChildPath "passphrase.txt" + $private_key_file = Join-Path -Path $working_directory -ChildPath "private_key.txt" + + Write-Host "Generating GnuPG key files." + Out-File -FilePath $passphrase_file -InputObject $original_passphrase -NoNewline -Encoding ascii + Out-File -FilePath $private_key_file -InputObject $original_private_key -NoNewline -Encoding ascii + Write-Host "Generated GnuPG key files." 
+ + Write-Host "Importing GnuPG private key file." + & $gpg_exe_path --batch --import $private_key_file + if ($lastExitCode -ne 0) { + Write-Host -Object "GnuPG importing private key command failed. Exitcode: $lastExitCode" + exit $lastExitCode + } + Write-Host "Imported GnuPG private key file." + + $targeting_original_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name + foreach ($file in $targeting_original_files) { + $file_path = Join-Path $jar_file_directory -ChildPath $file + Write-Host "GnuPG signing to file: "$file_path + & $gpg_exe_path --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file_path + if ($lastExitCode -ne 0) { + Write-Host -Object "GnuPG signing file command failed. Exitcode: $lastExitCode" + exit $lastExitCode + } + Write-Host "GnuPG signed to file: "$file_path + } + + $targeting_asc_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name + foreach ($file in $targeting_asc_files) { + $file_path = Join-Path $jar_file_directory -ChildPath $file + Write-Host "Adding checksum of sha256 to file: "$file_path + $file_path_sha256 = $file_path + ".sha256" + CertUtil -hashfile $file_path SHA256 + CertUtil -hashfile $file_path SHA256 | find /v `"hash`" | Out-File -FilePath $file_path_sha256 + Write-Host "Added checksum of sha256 to file: "$file_path + } + + Write-Host "GnuPG and sha256 signing to files completed." + Write-Host "Deleting GnuPG key files." + Remove-Item -Path $passphrase_file + Remove-Item -Path $private_key_file + Write-Host "Deleted GnuPG key files." diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 9339eb3f4b9ad..6ce4ad78e6f9e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -13,10 +13,10 @@ parameters: - 12.2 - name: TrtVersion type: string - default: '10.3.0.26' + default: '10.4.0.26' values: - 8.6.1.6 - - 10.3.0.26 + - 10.4.0.26 steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: @@ -42,7 +42,7 @@ steps: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.0" displayName: Set trtCudaVersion - - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.3.0.26')) }}: + - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.4.0.26')) }}: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.5" displayName: Set trtCudaVersion diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index ac6bf48c2ab68..6a2b7f4566b61 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -24,11 +24,11 @@ steps: displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' - ${{ if eq(parameters.DownloadTRT, 'true') }}: - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) + displayName: 'Download
TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8' - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6" $(Agent.TempDirectory) + displayName: 'Download TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 3edae95243943..d19472bcbab5a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -22,10 +22,10 @@ parameters: - name: trt_version type: string - default: '10.3.0.26-1.cuda11.8' + default: '10.4.0.26-1.cuda11.8' values: - - 10.3.0.26-1.cuda11.8 - - 10.3.0.26-1.cuda12.5 + - 10.4.0.26-1.cuda11.8 + - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 35a81c754b38a..0c3cd60a712fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -18,10 +18,10 @@ parameters: - name: trt_version type: string - default: '10.3.0.26-1.cuda11.8' + default: '10.4.0.26-1.cuda11.8' values: - - 10.3.0.26-1.cuda11.8 - - 10.3.0.26-1.cuda12.5 + - 10.4.0.26-1.cuda11.8 + - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index e95de10de2709..8a6434e757a3c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -381,7 +381,7 @@ stages: variables: CUDA_VERSION: '11.8' buildArch: x64 - EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" + EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" EnvSetupScript: setup_env_gpu.bat EP_NAME: gpu VSGenerator: 'Visual Studio 17 2022' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 5c5ccdef980fe..5c78a5dbac6ee 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -298,7 +298,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.8' - EP_BUILD_FLAGS: --use_tensorrt 
--tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -308,7 +308,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.9' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -318,7 +318,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.10' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -328,7 +328,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.11' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -338,7 +338,7 @@ stages: parameters: MACHINE_POOL: 'onnxruntime-Win2022-GPU-A10' PYTHON_VERSION: '3.12' - EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ENV_SETUP_SCRIPT: setup_env_gpu.bat EP_NAME: gpu publish_symbols: ${{ parameters.publish_symbols }} @@ -506,7 +506,7 @@ stages: docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - trt_version: '10.3.0.26-1.cuda11.8' + trt_version: '10.4.0.26-1.cuda11.8' cuda_version: '11.8' - 
${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index ef120be5d0391..7c04d6aa2e739 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -39,9 +39,9 @@ parameters: variables: - name: win_trt_folder ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8 + value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5 + value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index 76a9d5f0b09b6..c1a445e29fc89 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 -ARG TRT_VERSION=10.3.0.26-1.cuda12.4 +ARG TRT_VERSION=10.4.0.26-1.cuda12.6 FROM $BASEIMAGE AS base ARG TRT_VERSION ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch index d6c89703db2e4..a228ebed165eb 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.3.0.26-1.cuda11.8 +ARG TRT_VERSION=10.4.0.26-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index d9875a81d2226..6a4244b7aad0d 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.3.0.26-1+cuda11.8 +ARG TRT_VERSION=10.4.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg index c2d65b813310d..418c551ab38b4 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg @@ -6,7 +6,7 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.3.0.26-1+cuda11.8 +ARG TRT_VERSION=10.4.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 
 FROM $BASEIMAGE AS base
 ARG TRT_VERSION
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10
index 7c99d933d72ec..a7d8f220ea9b3 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10
@@ -31,7 +31,7 @@ RUN pip install --upgrade pip
 RUN pip install psutil setuptools>=68.2.2
 
 # Install TensorRT
-RUN version="10.3.0.26-1+cuda11.8" &&\
+RUN version="10.4.0.26-1+cuda11.8" &&\
     apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\
     apt-get update &&\
     apt-get install -y \
@@ -61,7 +61,7 @@ RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ];
 RUN apt-get install -y valgrind
 
 # Build final image from base. Builds ORT.
-FROM base as final
+FROM base AS final
 ARG BUILD_USER=onnxruntimedev
 ARG BUILD_UID=1000
 RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10
index 449d73066481b..523318f09aba6 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10
@@ -31,7 +31,7 @@ RUN pip install --upgrade pip
 RUN pip install setuptools>=68.2.2 psutil
 
 # Install TensorRT
-RUN version="10.3.0.26-1+cuda12.5" &&\
+RUN version="10.4.0.26-1+cuda12.6" &&\
     apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\
     apt-get update &&\
     apt-get install -y \
@@ -61,7 +61,7 @@ RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ];
 RUN apt-get install -y valgrind
 
 # Build final image from base. Builds ORT.
-FROM base as final
+FROM base AS final
 ARG BUILD_USER=onnxruntimedev
 ARG BUILD_UID=1000
 RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
index 710c73ccdaf98..85b1469a038fd 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
@@ -5,7 +5,7 @@
 ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8
 FROM $BASEIMAGE
-ARG TRT_VERSION=10.3.0.26-1.cuda11.8
+ARG TRT_VERSION=10.4.0.26-1.cuda11.8
 
 #Install TensorRT only if TRT_VERSION is not empty
 RUN if [ -n "${TRT_VERSION}" ]; then \
diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat
index 87affc1348edf..6a660ecaa40d2 100644
--- a/tools/ci_build/github/windows/setup_env_gpu.bat
+++ b/tools/ci_build/github/windows/setup_env_gpu.bat
@@ -6,10 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ (
 ) else (
   set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64;%PATH%
 )
-set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5\lib;%PATH%
+set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH%
 
 @REM The default version is still cuda v12.2, because set cuda v11.8 after it
-set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.3.0.26.Windows10.x86_64.cuda-11.8\lib
+set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8\lib
 if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ (
   set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64
 ) else (
diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat
index 9bd26cc0dc824..4f2272e306570 100644
--- a/tools/ci_build/github/windows/setup_env_trt.bat
+++ b/tools/ci_build/github/windows/setup_env_trt.bat
@@ -6,6 +6,6 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ (
 ) else (
   set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64
 )
-set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.3.0.26.Windows10.x86_64.cuda-12.5\lib;%PATH%
+set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH%
 set GRADLE_OPTS=-Dorg.gradle.daemon=false
 set CUDA_MODULE_LOADING=LAZY
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index 56e739f5ff3b5..683d7b6be2aa8 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -225,7 +225,7 @@ def add_common_dependencies(xml_text, package_name, version):
 
 
 def generate_dependencies(xml_text, package_name, version):
-    dml_dependency = ''
+    dml_dependency = ''
     if package_name == "Microsoft.AI.MachineLearning":
         xml_text.append("")
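
A version bump like this touches the same TensorRT string in many unrelated files, so a missed file silently pins the old release. Below is a minimal sketch of a consistency check one might run after applying the patch; it is not part of this change or of the onnxruntime repository, and the EXPECTED_TRT constant, scanned root, and regex are assumptions for illustration only.

#!/usr/bin/env python3
# Hypothetical helper: flag TensorRT version strings under a CI tree that do
# not match the expected release (e.g. stale 10.3.0.26 after a 10.4.0.26 bump).
import pathlib
import re
import sys

EXPECTED_TRT = "10.4.0.26"  # assumed target version of this patch
# Matches "TensorRT-<a.b.c.d>" folder names and "<a.b.c.d>-1.cudaXX" /
# "<a.b.c.d>-1+cudaXX" package versions as they appear in these CI files.
TRT_PATTERN = re.compile(r"TensorRT-(\d+\.\d+\.\d+\.\d+)|(\d+\.\d+\.\d+\.\d+)-1[.+]cuda")

def find_stale(root: str) -> list[str]:
    stale = []
    for path in pathlib.Path(root).rglob("*"):
        if not path.is_file():
            continue
        text = path.read_text(encoding="utf-8", errors="ignore")
        for m in TRT_PATTERN.finditer(text):
            version = m.group(1) or m.group(2)
            if version != EXPECTED_TRT:
                stale.append(f"{path}: {version}")
    return stale

if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "tools/ci_build"
    problems = find_stale(root)
    for line in problems:
        print(line)
    sys.exit(1 if problems else 0)

Note that intentionally retained legacy versions (such as the 8.6.1.6 entry kept in the TrtVersion parameter values) would also be reported, so the output is a review aid rather than a hard gate.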