diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 6efa8a5592337..aecc05c91d736 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -1,192 +1,197 @@ -name: Mac_CI - -on: - push: - branches: - - main - - rel-* - pull_request: - branches: - - main - - rel-* - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - python_version: 3.11 - xcode_version: 15.2 - -jobs: - ARM64: - runs-on: macos-14 - - timeout-minutes: 60 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Verify ARM64 machine - shell: python - run: | - import platform - assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." - - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Build and test - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --update \ - --build --parallel \ - --test \ - --build_shared_lib \ - --build_objc \ - --use_coreml \ - --use_xnnpack \ - --use_binskim_compliant_compile_flags - - Vcpkg: - runs-on: macos-13 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: "Run vcpkg(x64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" - env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "x64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run compile_schema.py" - run: | - # Runner's host triplet should be x64-osx or arm64-osx - export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" - export PATH="$FLATC_DIR:$PATH" - flatc --version - python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" - - - name: "Detect protoc" - id: protoc-detect - run: | - export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" - export PATH="$PROTOC_DIR:$PATH" - protoc --version - echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" - - - name: "Run build.py(x64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/x64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch x86_64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - - name: "Run vcpkg(arm64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" 
- env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "arm64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run build.py(arm64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/arm64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch arm64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - Objective-C-StaticAnalysis: - runs-on: macos-14 - - timeout-minutes: 30 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Generate compile_commands.json and ONNX protobuf files - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --cmake_generator "Unix Makefiles" \ - --config Debug \ - --build_shared_lib \ - --use_coreml \ - --build_objc \ - --enable_training_apis \ - --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ - --use_binskim_compliant_compile_flags \ - --update \ - --build --parallel \ - --target onnx_proto - - - name: Analyze Objective-C/C++ source code - shell: bash - run: | - CLANG_TIDY_CHECKS="-*,clang-analyzer-*" - - "$(brew --prefix llvm@15)/bin/clang-tidy" \ - -p=./build/Debug \ - --checks="${CLANG_TIDY_CHECKS}" \ - --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ - --header-filter="objectivec/include|objectivec|onnxruntime/core" \ - ./objectivec/*.mm \ - ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ - ./onnxruntime/core/providers/coreml/model/*.mm +name: Mac_CI + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + python_version: 3.11 + +jobs: + ARM64: + runs-on: macos-14 + + env: + xcode_version: 16 + + timeout-minutes: 60 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Verify ARM64 machine + shell: python + run: | + import platform + assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." 
+ + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Build and test + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --update \ + --build --parallel \ + --test \ + --build_shared_lib \ + --build_objc \ + --use_coreml \ + --use_xnnpack \ + --use_binskim_compliant_compile_flags + + Vcpkg: + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: "Run vcpkg(x64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-osx or arm64-osx + export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" + export PATH="$FLATC_DIR:$PATH" + flatc --version + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" + + - name: "Detect protoc" + id: protoc-detect + run: | + export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" + export PATH="$PROTOC_DIR:$PATH" + protoc --version + echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" + + - name: "Run build.py(x64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/x64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch x86_64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + - name: "Run vcpkg(arm64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/arm64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch arm64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines 
"VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + Objective-C-StaticAnalysis: + runs-on: macos-14 + + env: + xcode_version: 15.2 + + timeout-minutes: 30 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Generate compile_commands.json and ONNX protobuf files + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --cmake_generator "Unix Makefiles" \ + --config Debug \ + --build_shared_lib \ + --use_coreml \ + --build_objc \ + --enable_training_apis \ + --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ + --use_binskim_compliant_compile_flags \ + --update \ + --build --parallel \ + --target onnx_proto + + - name: Analyze Objective-C/C++ source code + shell: bash + run: | + CLANG_TIDY_CHECKS="-*,clang-analyzer-*" + + "$(brew --prefix llvm@15)/bin/clang-tidy" \ + -p=./build/Debug \ + --checks="${CLANG_TIDY_CHECKS}" \ + --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ + --header-filter="objectivec/include|objectivec|onnxruntime/core" \ + ./objectivec/*.mm \ + ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ + ./onnxruntime/core/providers/coreml/model/*.mm diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 246675b72f4e6..7168a99fe1f93 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -642,10 +642,12 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) check_cxx_compiler_flag(-Wdeprecated-copy HAS_DEPRECATED_COPY) check_cxx_compiler_flag(-Wdeprecated-declarations HAS_DEPRECATED_DECLARATIONS) + check_cxx_compiler_flag(-Wdeprecated-this-capture HAS_DEPRECATED_THIS_CAPTURE) check_cxx_compiler_flag(-Wenum-constexpr-conversion HAS_ENUM_CONSTEXPR_CONVERSION) check_cxx_compiler_flag(-Wformat-truncation HAS_FORMAT_TRUNCATION) check_cxx_compiler_flag(-Wignored-attributes HAS_IGNORED_ATTRIBUTES) @@ -656,15 +658,15 @@ else() check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32) check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING) check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE) check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE) check_cxx_compiler_flag(-Wunused-but-set-parameter HAS_UNUSED_BUT_SET_PARAMETER) check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) - check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) + if(onnxruntime_ENABLE_TRAINING_APIS) - check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) if(HAS_DANGLING_REFERENCE) list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) endif() diff --git 
a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 43f18abbe9522..cb737ee53639f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -91,6 +91,7 @@ if (NOT WIN32) google_nsync URL ${DEP_URL_google_nsync} URL_HASH SHA1=${DEP_SHA1_google_nsync} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/nsync/nsync_1.26.0.patch FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync ) #nsync tests failed on Mac Build diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 7e992fb33077c..f2be742458313 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -352,12 +352,12 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # make both maccatalyst and other builds do the same thing. set(CUR_TARGET_CMAKE_SOURCE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${_LIB}.dir) add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ar -t $ | grep "\.o$" > ${_LIB}.object_file_list.txt + COMMAND /usr/bin/ar -t $ | grep "\.o$" > ${_LIB}.object_file_list.txt COMMAND ${CMAKE_COMMAND} -E env python3 ${CMAKE_CURRENT_SOURCE_DIR}/maccatalyst_prepare_objects_for_prelink.py ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR} ${CUR_STATIC_LIB_OBJ_DIR} ${CUR_STATIC_LIB_OBJ_DIR}/${_LIB}.object_file_list.txt WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) else() add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ar ARGS -x $ + COMMAND /usr/bin/ar ARGS -x $ WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) endif() endif() @@ -378,12 +378,12 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # do the pre-link with `ld -r` to create a single relocatable object with correct symbol visibility add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a + COMMAND /usr/bin/ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a WORKING_DIRECTORY ${STATIC_LIB_TEMP_DIR}) # create the static library add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o + COMMAND /usr/bin/libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o WORKING_DIRECTORY ${STATIC_LIB_DIR}) # Assemble the other pieces of the static framework diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index e3ea767401ddc..bbddefe531cb8 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -9,6 +9,7 @@ #cmakedefine HAS_CLASS_MEMACCESS #cmakedefine HAS_DEPRECATED_COPY #cmakedefine HAS_DEPRECATED_DECLARATIONS +#cmakedefine HAS_DEPRECATED_THIS_CAPTURE #cmakedefine HAS_FORMAT_TRUNCATION #cmakedefine HAS_IGNORED_ATTRIBUTES #cmakedefine HAS_MAYBE_UNINITIALIZED diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4b880c4437dfd..a4ba85e868896 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -893,8 +893,6 @@ if (MSVC) set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" APPEND PROPERTY COMPILE_OPTIONS "/bigobj") - set_property(SOURCE "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" - APPEND PROPERTY COMPILE_OPTIONS "/bigobj") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 73ece647d82c7..d63da63445fde 
100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -3,22 +3,22 @@ index c23746e7f..bc326c8b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,10 +23,10 @@ endif() - + set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) - + -find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) +find_package(Python3 COMPONENTS Interpreter REQUIRED) - + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") - + @@ -227,27 +227,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") - + -## OpenMP -if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - # workaround issue hipcc in rocm3.5 cannot find openmp @@ -53,11 +53,11 @@ index c23746e7f..bc326c8b5 100644 -else() - add_compile_definitions(__HIP_PLATFORM_HCC__=1) -endif() - + ## tidy include(EnableCompilerWarnings) @@ -541,11 +514,3 @@ rocm_install(FILES - + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") - @@ -88,7 +88,7 @@ index c0894f1d7..559481fee 100644 @@ -6,19 +6,7 @@ #include #include - + -// To be removed, which really does not tell the location of failed HIP functional call -inline void hip_check_error(hipError_t x) -{ @@ -121,9 +121,9 @@ index a164c3f94..293ead89a 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -11,6 +11,9 @@ - + namespace ck_tile { - + +template +constexpr bool always_false = false; + @@ -139,7 +139,7 @@ index a164c3f94..293ead89a 100644 } } }; - + + } // namespace ck_tile + @@ -150,7 +150,7 @@ index 3acdb4d87..cc26e184f 100644 @@ -8,20 +8,7 @@ #include #include - + -namespace ck_tile { -// To be removed, which really does not tell the location of failed HIP functional call -CK_TILE_HOST void hip_check_error(hipError_t x) @@ -198,3 +198,41 @@ index c035e7e56..8c5f36d2e 100644 set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(${INSTANCE_NAME}) set(result 0) +--- ./include/ck/utility/amd_buffer_addressing.hpp 2024-09-05 10:12:33.343091000 +0800 ++++ ./include/ck/utility/amd_buffer_addressing_new.hpp 2024-09-05 10:12:20.276686000 +0800 +@@ -991,7 +991,8 @@ + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), +- "s"(src_resource)); ++ "s"(src_resource) ++ : "memory"); + #else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = +--- ./include/ck_tile/core/arch/amd_buffer_addressing.hpp 2024-09-05 10:18:28.884031000 +0800 ++++ ./include/ck_tile/core/arch/amd_buffer_addressing_new.hpp 2024-09-05 10:17:29.434931000 +0800 +@@ -26,7 +26,12 @@ + CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) + { + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; +- return __builtin_bit_cast(int32x4_t, res); ++ int32x4_t r = __builtin_bit_cast(int32x4_t, res); ++ r.x = __builtin_amdgcn_readfirstlane(r.x); ++ r.y = __builtin_amdgcn_readfirstlane(r.y); ++ r.z = __builtin_amdgcn_readfirstlane(r.z); ++ r.w = __builtin_amdgcn_readfirstlane(r.w); ++ return r; + } + + // TODO: glc/slc/... 
+@@ -2016,7 +2021,8 @@ + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), +- "s"(src_resource)); ++ "s"(src_resource) ++ : "memory"); + #else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = diff --git a/cmake/patches/nsync/nsync_1.26.0.patch b/cmake/patches/nsync/nsync_1.26.0.patch new file mode 100644 index 0000000000000..78ef2b3cb20d4 --- /dev/null +++ b/cmake/patches/nsync/nsync_1.26.0.patch @@ -0,0 +1,14 @@ +diff --git a/public/nsync_atomic.h b/public/nsync_atomic.h +index aebe4f7..466a262 100644 +--- a/public/nsync_atomic.h ++++ b/public/nsync_atomic.h +@@ -45,7 +45,8 @@ NSYNC_CPP_END_ + NSYNC_CPP_START_ + typedef std::atomic nsync_atomic_uint32_; + NSYNC_CPP_END_ +-#define NSYNC_ATOMIC_UINT32_INIT_ ATOMIC_VAR_INIT (0) ++// Replace deprecated ATOMIC_VAR_INIT with std::atomic brace initialization ++#define NSYNC_ATOMIC_UINT32_INIT_ { 0 } + #define NSYNC_ATOMIC_UINT32_LOAD_(p) (std::atomic_load (p)) + #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (std::atomic_store ((p), (uint32_t) (v))) + #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) diff --git a/docs/Coding_Conventions_and_Standards.md b/docs/Coding_Conventions_and_Standards.md index e8e1e7dc9ccd8..f18f1036efee8 100644 --- a/docs/Coding_Conventions_and_Standards.md +++ b/docs/Coding_Conventions_and_Standards.md @@ -155,7 +155,7 @@ Using `Show Code Coverage Coloring` will allow you to visually inspect which lin This project uses [lintrunner](https://github.com/suo/lintrunner) for linting. It provides a consistent linting experience locally and in CI. You can install the dependencies and initialize with ```sh -pip install lintrunner lintrunner-adapters +pip install -r requirements-lintrunner.txt lintrunner init ``` diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 121240e6e18f9..734506681ab60 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -27,12 +27,12 @@ Do not modify directly.* |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[1, 10]|**T** = tensor(float), tensor(int32)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| |Asin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| @@ -482,7 +482,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -508,7 +508,7 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| diff --git a/include/onnxruntime/core/common/eigen_common_wrapper.h b/include/onnxruntime/core/common/eigen_common_wrapper.h index 57599e04037dc..19efa7bcff107 100644 --- a/include/onnxruntime/core/common/eigen_common_wrapper.h +++ b/include/onnxruntime/core/common/eigen_common_wrapper.h @@ -49,6 +49,12 @@ #pragma GCC diagnostic ignored "-Wshorten-64-to-32" #endif +// eigen-src/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h:215:9: +// error: implicit capture of 'this' with a capture default of '=' is deprecated [-Werror,-Wdeprecated-this-capture] +#ifdef HAS_DEPRECATED_THIS_CAPTURE +#pragma GCC diagnostic ignored "-Wdeprecated-this-capture" +#endif + #elif defined(_MSC_VER) // build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): // warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a4ec66761c4ba..3aa98bb020452 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3650,10 +3650,10 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Only used for float32 model. + "enable_htp_fp16_precision": Used for float32 model for HTP backend. Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": Default. With fp32 precision. - - "1": With fp16 precision. + - "0": With fp32 precision. + - "1": Default. With fp16 precision. "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - "0": Default. Disabled. - "1": Enabled. diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index e1ee2c14fd9d1..3f276a3670156 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -54,7 +54,7 @@ public class OnnxTensor extends OnnxTensorLike { * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is - * a copy of a user buffer.) + * a copy of a user buffer or array.) */ public boolean ownsBuffer() { return this.ownsBuffer; @@ -62,8 +62,8 @@ public boolean ownsBuffer() { /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not - * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by - * ORT) this method returns an empty {@link Optional}. + * backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty + * {@link Optional}. * *
<p>
Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be * used to repeatedly update a single tensor for multiple different inferences without allocating @@ -77,7 +77,116 @@ public boolean ownsBuffer() { * @return A reference to the buffer. */ public Optional getBufferRef() { - return Optional.ofNullable(buffer); + return Optional.ofNullable(duplicate(buffer)); + } + + /** + * Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent + * writes will modify the underlying memory in a racy way, don't do that. + * + *
<p>
Can be replaced to a call to buf.duplicate() in Java 9+. + * + * @param buf The buffer to duplicate. + * @return A copy of the buffer which refers to the same underlying memory, but has an independent + * position, limit and mark. + */ + private static Buffer duplicate(Buffer buf) { + if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder()); + } else if (buf instanceof ShortBuffer) { + return ((ShortBuffer) buf).duplicate(); + } else if (buf instanceof IntBuffer) { + return ((IntBuffer) buf).duplicate(); + } else if (buf instanceof LongBuffer) { + return ((LongBuffer) buf).duplicate(); + } else if (buf instanceof FloatBuffer) { + return ((FloatBuffer) buf).duplicate(); + } else if (buf instanceof DoubleBuffer) { + return ((DoubleBuffer) buf).duplicate(); + } else { + throw new IllegalStateException("Unknown buffer type " + buf.getClass()); + } + } + + /** + * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link + * ByteBuffer} then convert it to the right type. If it's not convertible it throws {@link + * IllegalStateException}. + * + *
<p>
Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve + * compatibility with existing {@link #getValue} calls. + * + * @param buf The buffer to convert. + * @return The buffer with the expected type. + */ + private Buffer castBuffer(Buffer buf) { + switch (info.type) { + case FLOAT: + if (buf instanceof FloatBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asFloatBuffer(); + } + break; + case DOUBLE: + if (buf instanceof DoubleBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asDoubleBuffer(); + } + break; + case BOOL: + case INT8: + case UINT8: + if (buf instanceof ByteBuffer) { + return buf; + } + break; + case BFLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer bf16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } + break; + case FLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer fp16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } + break; + case INT16: + if (buf instanceof ShortBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asShortBuffer(); + } + break; + case INT32: + if (buf instanceof IntBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asIntBuffer(); + } + break; + case INT64: + if (buf instanceof LongBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asLongBuffer(); + } + break; + } + throw new IllegalStateException( + "Invalid buffer type for cast operation, found " + + buf.getClass() + + " expected something convertible to " + + info.type); } @Override @@ -133,15 +242,26 @@ public Object getValue() throws OrtException { Object carrier = info.makeCarrier(); if (info.getNumElements() > 0) { // If the tensor has values copy them out - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); - } - if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { - // We read the strings out from native code in a flat array and then reshape - // to the desired output shape. - return OrtUtil.reshape((String[]) carrier, info.shape); - } else { - return carrier; + if (info.type == OnnxJavaType.STRING) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape if necessary. 
+ getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier); + if (info.shape.length != 1) { + carrier = OrtUtil.reshape((String[]) carrier, info.shape); + } + } else { + // Wrap ORT owned memory in buffer, otherwise use our reference + Buffer buf; + if (buffer == null) { + buf = castBuffer(getBuffer()); + } else { + buf = castBuffer(duplicate(buffer)); + } + // Copy out buffer into arrays + OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier); + } } + return carrier; } } @@ -175,8 +295,8 @@ public synchronized void close() { public ByteBuffer getByteBuffer() { checkClosed(); if (info.type != OnnxJavaType.STRING) { - ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); - ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); + ByteBuffer buffer = getBuffer(); + ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder()); output.put(buffer); output.rewind(); return output; @@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() { output.rewind(); return output; } else if (info.type == OnnxJavaType.FLOAT16) { - // if it's fp16 we need to copy it out by hand. + // if it's fp16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { - // if it's bf16 we need to copy it out by hand. + // if it's bf16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); @@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray(long apiHandle, long nativeHandle, Object carrier) + private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier) throws OrtException; private native void close(long apiHandle, long nativeHandle); @@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec info); } } else { + Buffer buf; if (info.shape.length == 0) { - data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data); - if (data == null) { + buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data); + if (buf == null) { throw new OrtException( "Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + data); } + } else { + buf = OrtUtil.convertArrayToBuffer(info, data); } return new OnnxTensor( - createTensor( - OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value), + createTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + buf, + 0, + info.type.size * info.numElements, + info.shape, + info.onnxType.value), allocator.handle, - info); + info, + buf, + true); } } else { throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator."); @@ -627,7 +758,26 @@ static OnnxTensor createTensor( */ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape) throws OrtException { - return createTensor(env, env.defaultAllocator, data, shape); + return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16); + } + + /** + * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. + * + *
<p>
If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime + * of the tensor. Uses the default allocator. + * + * @param env The current OrtEnvironment. + * @param data The tensor data. + * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. + * @return An OnnxTensor of the required shape. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + public static OnnxTensor createTensor( + OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { + return createTensor(env, env.defaultAllocator, data, shape, type); } /** @@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( - OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) + OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { if (!allocator.isClosed()) { - OnnxJavaType type = OnnxJavaType.INT16; - return createTensor(type, allocator, data, shape); + if ((type == OnnxJavaType.BFLOAT16) + || (type == OnnxJavaType.FLOAT16) + || (type == OnnxJavaType.INT16)) { + return createTensor(type, allocator, data, shape); + } else { + throw new IllegalArgumentException( + "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer."); + } } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -768,10 +926,6 @@ private static OnnxTensor createTensor( tuple.isCopy); } - private static native long createTensor( - long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType) - throws OrtException; - private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 4f3dee3c00b91..2f44236e4ef67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -26,10 +26,10 @@ public final class OrtUtil { private OrtUtil() {} /** - * Converts an long shape into a int shape. + * Converts a long shape into an int shape. * - *
<p>
Validates that the shape has more than 1 elements, less than 9 elements, each element is - * less than {@link Integer#MAX_VALUE} and that each entry is non-negative. + *
<p>
Validates that the shape has more than 1 element, less than 9 elements, each element is less + * than {@link Integer#MAX_VALUE} and that each entry is non-negative. * * @param shape The long shape. * @return The int shape. @@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) { } } + /** + * Stores a boxed primitive in a single element buffer of the unboxed type. + * + *
<p>
If it's not a boxed primitive then it returns null. + * + * @param javaType The type of the boxed primitive. + * @param data The boxed primitive. + * @return The primitive in a direct buffer. + */ + static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) { + switch (javaType) { + case FLOAT: + { + FloatBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + buf.put(0, (Float) data); + return buf; + } + case DOUBLE: + { + DoubleBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + buf.put(0, (Double) data); + return buf; + } + case BOOL: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0); + return buf; + } + case UINT8: + case INT8: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, (Byte) data); + return buf; + } + case FLOAT16: + case BFLOAT16: + case INT16: + { + ShortBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asShortBuffer(); + buf.put(0, (Short) data); + return buf; + } + case INT32: + { + IntBuffer buf = + ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer(); + buf.put(0, (Integer) data); + return buf; + } + case INT64: + { + LongBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + buf.put(0, (Long) data); + return buf; + } + case STRING: + case UNKNOWN: + default: + return null; + } + } + + /** + * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}. + * + *
<p>
Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or + * if the array is ragged. + * + * @param info The tensor info object containing the types and shape of the array. + * @param array The array object. + * @return A direct buffer containing all the elements. + */ + static Buffer convertArrayToBuffer(TensorInfo info, Object array) { + ByteBuffer byteBuffer = + ByteBuffer.allocateDirect((int) info.numElements * info.type.size) + .order(ByteOrder.nativeOrder()); + + Buffer buffer; + switch (info.type) { + case FLOAT: + buffer = byteBuffer.asFloatBuffer(); + break; + case DOUBLE: + buffer = byteBuffer.asDoubleBuffer(); + break; + case BOOL: + case INT8: + case UINT8: + // no-op, it's already a bytebuffer + buffer = byteBuffer; + break; + case BFLOAT16: + case FLOAT16: + case INT16: + buffer = byteBuffer.asShortBuffer(); + break; + case INT32: + buffer = byteBuffer.asIntBuffer(); + break; + case INT64: + buffer = byteBuffer.asLongBuffer(); + break; + case STRING: + case UNKNOWN: + default: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + + fillBufferFromArray(info, array, 0, buffer); + + if (buffer.remaining() != 0) { + throw new IllegalArgumentException( + "Failed to copy all elements into the buffer, expected to copy " + + info.numElements + + " into a buffer of capacity " + + buffer.capacity() + + " but had " + + buffer.remaining() + + " values left over."); + } + buffer.rewind(); + + return buffer; + } + + /** + * Fills the provided buffer with the values from the array, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param array The array object to read from. + * @param curDim The current dimension we're processing. + * @param buffer The buffer to write to. + */ + private static void fillBufferFromArray( + TensorInfo info, Object array, int curDim, Buffer buffer) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.put(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.put(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.put(bArr); + break; + case FLOAT16: + case BFLOAT16: + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.put(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.put(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.put(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + boolBuf.put(boolArr[i] ? 
(byte) 1 : (byte) 0); + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer); + } + } + } + } + + /** + * Fills the provided array with the values from the buffer, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param buffer The buffer to read from. + * @param curDim The current dimension we're processing. + * @param array The array object to write to. + */ + static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT16: + case BFLOAT16: + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.get(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.get(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.get(bArr); + break; + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.get(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.get(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.get(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false. + boolArr[i] = boolBuf.get() != 0; + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i)); + } + } + } + } + /** * Returns expected JDK map capacity for a given size, this factors in the default JDK load factor * diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 1c21387b50455..f3e9f21ef408d 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -323,6 +323,9 @@ public long getNumElements() { * all elements as that's the expected format of the native code. It can be reshaped to the * correct shape using {@link OrtUtil#reshape(String[],long[])}. * + *
<p>
For fp16 and bf16 tensors the output carrier type is float, and so this method produces + * multidimensional float arrays. + * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). @@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException { + Arrays.toString(shape)); } switch (type) { + case BFLOAT16: + case FLOAT16: case FLOAT: return OrtUtil.newFloatArray(shape); case DOUBLE: diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 7b26291581395..6a3c279073860 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque return sequenceInfo; } -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) { - int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray); - int64_t consumedSize = inputLength * onnxTypeSize(onnxType); - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray)inputArray; - (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray)inputArray; - (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray)inputArray; - (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray)inputArray; - (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported."); - return -1; - /* - float *floatArr = malloc(sizeof(float) * inputLength); - uint16_t *halfArr = (uint16_t *) outputTensor; - for (uint32_t i = 0; i < inputLength; i++) { - floatArr[i] = convertHalfToFloat(halfArr[i]); - } - jfloatArray typedArr = (jfloatArray) inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr); - free(floatArr); - return consumedSize; - */ - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray)inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray)inputArray; - (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor); - return consumedSize; - } - case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported."); - return -1; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray)inputArray; - (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - default: { - throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type."); - return -1; - } - } -} - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray inputObjArr = (jobjectArray)inputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i); - int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. - if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) { int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray); if (outputLength == 0) return 0; @@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT } } -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, - size_t dimensionsRemaining, jarray outputArray) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray outputObjArr = (jobjectArray)outputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i); - int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. 
- if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - jobject tempString = NULL; - // Get the buffer size needed - size_t totalStringLength = 0; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); - if (code != ORT_OK) { - return NULL; - } - - // Create the character and offset buffers, character is one larger to allow zero termination. - char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1)); - if (characterBuffer == NULL) { - throwOrtException(jniEnv, 1, "OOM error"); - } else { - size_t * offsets = malloc(sizeof(size_t)); - if (offsets != NULL) { - // Get a view on the String data - code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1)); - - if (code == ORT_OK) { - size_t curSize = (offsets[0]) + 1; - characterBuffer[curSize-1] = '\0'; - tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer); - } - - free((void*)characterBuffer); - free((void*)offsets); - } - } - - return tempString; -} - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { size_t bufferSize = 16; char * tempBuffer = malloc(bufferSize); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 023bc0c739583..7f41e06371f2a 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor); - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor); - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray); -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index b694f57357bb5..d757bd6281499 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -8,72 +8,6 @@ #include "OrtJniUtil.h" #include "ai_onnxruntime_OnnxTensor.h" -/* - * Class: ai_onnxruntime_OnnxTensor - * Method: createTensor - * Signature: (JJLjava/lang/Object;[JI)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, - jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
- const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Convert type to ONNX C enum - ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); - - // Extract the shape information - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); - - // Create the OrtValue - OrtValue* ortValue = NULL; - OrtErrorCode code = checkOrtStatus(jniEnv, api, - api->CreateTensorAsOrtValue( - allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue - ) - ); - (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - - int failed = 0; - if (code == ORT_OK) { - // Get a reference to the OrtValue's data - uint8_t* tensorData = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData)); - if (code == ORT_OK) { - // Check if we're copying a scalar or not - if (shapeLen == 0) { - // Scalars are passed in as a single element array - int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - // Extract the tensor shape information - JavaTensorTypeShape typeShape; - code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - - if (code == ORT_OK) { - // Copy the java array into the tensor - int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount, - typeShape.dimensions, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - failed = 1; - } - } - } else { - failed = 1; - } - } - - if (failed) { - api->ReleaseValue(ortValue); - ortValue = NULL; - } - - // Return the pointer to the OrtValue - return (jlong) ortValue; -} - /* * Class: ai_onnxruntime_OnnxTensor * Method: createTensorFromBuffer @@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer size_t sizeBytes = typeShape.elementCount * typeSize; uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr)); if (code == ORT_OK) { return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes); @@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool /* * Class: ai_onnxruntime_OnnxTensor - * Method: getArray - * Signature: (JJLjava/lang/Object;)V + * Method: getStringArray + * Signature: (JJ[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { - uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr)); - if (code == ORT_OK) { - copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount, - typeShape.dimensions, (jarray)carrier); - } + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here."); } } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 11141a3a65a3e..7cb6305923279 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException { container.put("wrong_name", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect name."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unknown input name")); + } finally { + OnnxValue.close(container.values()); } } } @@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException { container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect type."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected input data type")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongSizeInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + float[] wrongSizeData = Arrays.copyOf(inputData, 2 * 224 * 224); + Object tensor = OrtUtil.reshape(wrongSizeData, new long[] {1, 2, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Got invalid dimensions for input")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongRankInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + Object tensor = OrtUtil.reshape(inputData, new long[] {1, 1, 3, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Invalid rank for input")); + } finally { + OnnxValue.close(container.values()); } } } @@ 
-550,12 +595,12 @@ public void throwExtraInputs() throws OrtException { container.put("extra", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for too many inputs."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected number of inputs")); + } finally { + OnnxValue.close(container.values()); } } } @@ -565,12 +610,11 @@ public void testMultiThreads() throws OrtException, InterruptedException { int numThreads = 10; int loop = 10; SqueezeNetTuple tuple = openSessionSqueezeNet(); + Map container = new HashMap<>(); try (OrtSession session = tuple.session) { - float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - Map container = new HashMap<>(); long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; Object tensor = OrtUtil.reshape(inputData, inputShape); container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); @@ -592,8 +636,9 @@ public void testMultiThreads() throws OrtException, InterruptedException { } executor.shutdown(); executor.awaitTermination(1, TimeUnit.MINUTES); - OnnxValue.close(container.values()); assertTrue(executor.isTerminated()); + } finally { + OnnxValue.close(container.values()); } } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index ea210d96c1507..064f14f3b51ff 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -12,8 +12,11 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.IntBuffer; import java.nio.ShortBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.SplittableRandom; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -93,30 +96,108 @@ public void testScalarCreation() throws OrtException { } @Test - public void testBufferCreation() throws OrtException { + public void testArrayCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Test creating a value from an array - // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + // Test creating a value from a single dimensional array float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { - // array creation isn't backed by buffers - assertFalse(t.ownsBuffer()); - assertFalse(t.getBufferRef().isPresent()); - FloatBuffer buf = t.getFloatBuffer(); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); float[] output = new float[arrValues.length]; buf.get(output); Assertions.assertArrayEquals(arrValues, output); - // Can't modify the tensor through this buffer. + // Can modify the tensor through this buffer. 
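The assertions that follow exercise the new behaviour that a tensor created from a Java array is backed by a direct buffer, so writes made through getBufferRef() are visible when the tensor is read back. That mirrors how the native API behaves when a tensor is created over caller-owned memory; a minimal C++ sketch of the aliasing idea, assuming only the standard ORT C++ API (Ort::MemoryInfo, Ort::Value::CreateTensor) and illustrative names, not the Java binding's actual implementation:

  #include <onnxruntime_cxx_api.h>
  #include <array>
  #include <cstdint>

  // Sketch: a tensor created over caller-owned memory aliases that memory,
  // so a later write to the buffer is observable through the tensor.
  void bufferBackedTensorSketch() {
    Ort::Env env;
    std::array<float, 5> data{0, 1, 2, 3, 4};
    std::array<int64_t, 1> shape{5};
    Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value tensor = Ort::Value::CreateTensor<float>(
        mem_info, data.data(), data.size(), shape.data(), shape.size());
    data[0] = 25.0f;                                       // mutate the backing buffer
    float seen = tensor.GetTensorMutableData<float>()[0];  // reads 25.0f, not a stale copy
    (void)seen;
  }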
buf.put(0, 25); - Assertions.assertArrayEquals(arrValues, output); + Assertions.assertArrayEquals(new float[] {25, 1, 2, 3, 4}, (float[]) t.getValue()); } + // Test creating a value from a multidimensional float array + float[][][] arr3dValues = + new float[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, arr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + float[][][] output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + + // Can modify the tensor through the buffer. + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); + buf.put(0, 25); + buf.put(12, 32); + buf.put(13, 33); + buf.put(23, 35); + arr3dValues[0][0][0] = 25; + arr3dValues[2][0][0] = 32; + arr3dValues[2][0][1] = 33; + arr3dValues[3][1][2] = 35; + output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + } + + // Test creating a value from a multidimensional int array + int[][][] iArr3dValues = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, iArr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + int[][][] output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + + // Can modify the tensor through the buffer. + IntBuffer buf = (IntBuffer) t.getBufferRef().get(); + buf.put(0, 25); + iArr3dValues[0][0][0] = 25; + output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + } + + // Test creating a value from a ragged array throws + int[][][] ragged = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}}, + {{12, 13}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, ragged)) { + Assertions.fail("Can't create tensors from ragged arrays"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("ragged")); + } + + // Test creating a value from a non-array, non-primitive type throws. 
+ List list = new ArrayList<>(5); + list.add(5); + try (OnnxTensor t = OnnxTensor.createTensor(env, list)) { + Assertions.fail("Can't create tensors from lists"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("Cannot convert")); + } + } + + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Test creating a value from a non-direct byte buffer // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap - // direct byte buffers - // which can be directly passed to ORT + // direct byte buffers which can be directly passed to ORT + float[] arrValues = new float[] {0, 1, 2, 3, 4}; FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); nonDirectBuffer.put(arrValues); nonDirectBuffer.rewind(); @@ -335,10 +416,12 @@ public void testFp32ToFp16() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -347,6 +430,8 @@ public void testFp32ToFp16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToFp16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.fp16ToFloat(Fp16Conversions.floatToFp16(input[i][j])); } } floatBuf.rewind(); @@ -354,25 +439,31 @@ public void testFp32ToFp16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound fp16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } @@ -382,10 +473,12 @@ public void testFp32ToBf16() throws OrtException { String 
modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -394,6 +487,8 @@ public void testFp32ToBf16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToBf16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.bf16ToFloat(Fp16Conversions.floatToBf16(input[i][j])); } } floatBuf.rewind(); @@ -401,25 +496,31 @@ public void testFp32ToBf16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound bf16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } diff --git a/js/package-lock.json b/js/package-lock.json index d3684dfdf9117..58a13a9112116 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3282,12 +3282,12 @@ } }, "node_modules/micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -7059,12 +7059,12 @@ "dev": true }, "micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": 
"sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "requires": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" } }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 7b6140f3b1185..859bd850862aa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,18 +4,17 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { createTransposeProgramInfo } from './transpose'; import { createTensorShapeVariables, - fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, - UniformsArrayType, } from './common'; export interface InstanceNormAttributes { @@ -23,117 +22,7 @@ export interface InstanceNormAttributes { format: 'NHWC' | 'NCHW'; } -const createInstanceNormProgramInfo = ( - inputs: readonly TensorView[], - attributes: InstanceNormAttributes, -): ProgramInfo => { - const xShape = inputs[0].dims; - const outputShape = xShape; - const axis = 2; - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - const components = getMaxComponents(normSize); - const normPackedSize = normSize / components; - const inputShape = [xShape[0], xShape[1], normPackedSize]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: normSize }, - { type: DataType.uint32, data: normPackedSize }, - ]; - programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 
'f32' : `vec${components}`; - const workgroupSize = 64; - - const uniforms: UniformsArrayType = [ - { name: 'normSize', type: 'u32' }, - { name: 'normPackedSize', type: 'u32' }, - ]; - return ` - var meanShared : f32; - var squaredNormShared : f32; - var workgroupShared : array<${f32Type}, ${workgroupSize}>; - const workgroupSize = ${workgroupSize}u; - ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} - ${shaderHelper.mainStart(workgroupSize)} - let norm = global_idx / workgroupSize; - let batch = norm / uniforms.x_shape[1]; - let channel = norm % uniforms.x_shape[1]; - let localIndex = local_id.x; - - // initialize workgroup memory - var initial = ${f32Type}(0); - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); - } - workgroupShared[localIndex] = initial; - workgroupBarrier(); - - // Calculate the mean of current channel data. - for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) { - if (localIndex < currSize) { - workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize]; - } - workgroupBarrier(); - } - if (localIndex == 0) { - meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize); - } - workgroupBarrier(); - - // reinitialize workgroup memory. - initial = ${f32Type}(0); - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); - initial = initial + deviation * deviation; - } - workgroupShared[localIndex] = initial; - workgroupBarrier(); - - // Calculate the sum of square of deviation of current channel data. - for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) { - if (localIndex < currSize) { - workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize]; - } - workgroupBarrier(); - } - if (localIndex == 0) { - squaredNormShared = ${sumVector('workgroupShared[0]', components)}; - } - workgroupBarrier(); - - let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); - let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); - let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type - }(channelShift)); - ${output.set('batch', 'channel', 'h', 'value')}; - } - }`; - }; - return { - ...{ name: 'InstanceNormalization' }, - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${attributes.epsilon};${components}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], - dispatchGroup: { x: normCount }, - programUniforms, - }), - getShaderSource, - }; -}; - -const computeMean = ( +const computeChannelScaleShift = ( context: ComputeContext, input: TensorView, scale: TensorView, @@ -143,121 +32,140 @@ const computeMean = ( c: number, epsilon: number, ) => { - const components = getMaxComponents(c); - const WG = 64; - // we will store channel scale and channel shift in [2, components] matrix - // or in vec2 when components == 1 - const outputType = components === 1 ? 
'vec2f' : `mat2x${components}f`; - const sumCastType = components === 1 ? 'f32' : `vec${components}f`; - const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; - const unitsOfWork = (n * c) / components; - const wgSize = Math.ceil(h / WG); + const components = getMaxComponents(h); + const f32Type = components === 1 ? 'f32' : `vec${components}f`; + const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const unitsOfWork = n * c; - const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; - const meanProgramUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: wgSize }, - { type: DataType.uint32, data: h }, - { type: DataType.uint32, data: Math.floor(c / components) }, - { type: DataType.uint32, data: Math.floor((h * c) / components) }, - ]; + const inputShape = [n, c, h / components]; + const outputShape = [n, c, 2]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = []; + programUniforms.push(...createTensorShapeVariables(inputShape, outputShape)); - const getMeanShaderSource = (shaderHelper: ShaderHelper) => { - const inputHelper = inputVariable('input', input.dataType, input.dims, components); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', input.dataType, 3, components); + const s = inputVariable('scale', scale.dataType, scale.dims); + const b = inputVariable('bias', bias.dataType, bias.dims); + const output = outputVariable('output', DataType.float, 3, 2); + const variables = [x, s, b, output]; + const workgroupSize = 64; return ` - ${shaderHelper.declareVariables(inputHelper)} - @group(0) @binding(1) var output : array<${outputType}>; - struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; - @group(0) @binding(2) var uniforms: Uniforms; + var workgroup_shared : array<${wgType}, ${workgroupSize}>; + const workgroup_size = ${workgroupSize}u; + ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.mainStart(workgroupSize)} + let batch = workgroup_index / uniforms.x_shape[1]; + let channel = workgroup_index % uniforms.x_shape[1]; + let hight = uniforms.x_shape[2]; + // initialize workgroup memory + var sum = ${f32Type}(0); + var squared_sum = ${f32Type}(0); + for (var h = local_idx; h < hight; h += workgroup_size) { + let value = ${f32Type}(${x.get('batch', 'channel', 'h')}); + sum += value; + squared_sum += value * value; + } + workgroup_shared[local_idx] = ${wgType}(sum, squared_sum); + workgroupBarrier(); - ${shaderHelper.mainStart(WG)} - let currentImageNumber = global_idx / ${WG} / uniforms.C; - let currentChannelNumber = (global_idx / ${WG}) % uniforms.C; - let wgOffset = local_id.x * uniforms.wg_size; - if (wgOffset >= uniforms.H) { - return; + for (var currSize = workgroup_size >> 1; currSize > 0; currSize = currSize >> 1) { + if (local_idx < currSize) { + workgroup_shared[local_idx] = workgroup_shared[local_idx] + workgroup_shared[local_idx + currSize]; + } + workgroupBarrier(); } - let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H); + if (local_idx == 0) { + let sum_final = ${sumVector('workgroup_shared[0][0]', components)} / f32(hight * ${components}); + let squared_sum_final = ${sumVector('workgroup_shared[0][1]', components)} / f32(hight * ${components}); - let offset = currentImageNumber * uniforms.image_size + currentChannelNumber; - var sum = ${fillVector('f32', components)}; - var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 
wgOffset; i < wgMax; i++) { - let value = ${sumCastType}(input[offset + i * uniforms.C]); - sum += value; - squaredSum += value * value; + let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(${epsilon})); + let channel_scale = inv_std_dev * f32(scale[channel]); + let channel_shift = f32(bias[channel]) - sum_final * channel_scale; + output[workgroup_index] = vec2f(channel_scale, channel_shift); } - output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; }; - const meanValues = context.compute( + return context.compute( { - name: 'InstanceNormComputeMean', - shaderCache: { hint: `${components}`, inputDependencies: meanInputDependencies }, + name: 'InstanceNormComputeChannelScaleShift', + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, getRunData: () => ({ - outputs: [{ dims: [n, c, WG, 2], dataType: DataType.float }], - dispatchGroup: { x: (n * c) / components }, - programUniforms: meanProgramUniforms, + outputs: [{ dims: outputShape, dataType: DataType.float }], + dispatchGroup: { x: unitsOfWork }, + programUniforms, }), - getShaderSource: getMeanShaderSource, + getShaderSource, }, - { inputs: [input], outputs: [-1] }, + { inputs: [input, scale, bias], outputs: [-1] }, )[0]; +}; + +const createInstanceNormProgramInfo = ( + context: ComputeContext, + inputs: readonly TensorView[], + attributes: InstanceNormAttributes, +) => { + const xShape = inputs[0].dims; + const outputShape = xShape; + const axis = 2; + const N = xShape[0]; + const C = xShape[1]; + const H = ShapeUtil.sizeFromDimension(xShape, axis); + const components = getMaxComponents(H); + const outputSize = ShapeUtil.size(outputShape) / components; + // compute channel scale and channel shift. 
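computeChannelScaleShift, defined above, reduces each (batch, channel) slice to a sum and a sum of squares in a single pass, then folds the scale and bias inputs into one affine pair per channel. As a sketch of the algebra the shader implements, with H elements per channel, gamma/beta the scale/bias inputs, and epsilon the op attribute:

  \mu_c = \frac{1}{H}\sum_{h} x_{c,h}, \qquad m_c = \frac{1}{H}\sum_{h} x_{c,h}^{2}

  \mathrm{invStdDev}_c = \frac{1}{\sqrt{m_c - \mu_c^{2} + \epsilon}}, \qquad
  \mathrm{scale}_c = \mathrm{invStdDev}_c \cdot \gamma_c, \qquad
  \mathrm{shift}_c = \beta_c - \mu_c \cdot \mathrm{scale}_c

so the second program below only has to evaluate y = x * scale_c + shift_c per element.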
+ const channelScaleShift = computeChannelScaleShift( + context, + inputs[0], + inputs[1], + inputs[2], + N, + H, + C, + attributes.epsilon, + ); + + const inputShape = [N, C, H / components]; + const scaleShape = [N, C]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'none']; - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: unitsOfWork }, - { type: DataType.uint32, data: h }, - { type: DataType.uint32, data: Math.floor(c / components) }, - { type: DataType.uint32, data: Math.floor((WG * c) / components) }, - ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale_shift', DataType.float, scaleShape.length, 2); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, output]; return ` - @group(0) @binding(0) var input : array<${outputType}>; - @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; - @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; - @group(0) @binding(3) var output : array<${outputType}>; - struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32}; - @group(0) @binding(4) var uniforms: Uniforms; - + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')} - let currentImageNumber = global_idx / uniforms.C; - let currentChannelNumber = global_idx % uniforms.C; - - let offset = currentImageNumber * uniforms.image_size; - var sum = ${fillVector('f32', components)}; - var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) { - let value = input[offset + i + currentChannelNumber * ${WG}]; - sum += value[0]; - squaredSum += value[1]; - } - sum = sum / f32(uniforms.H); - squaredSum = squaredSum / f32(uniforms.H); - let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon})); - let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); - let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; - - output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let outputIndices = ${output.offsetToIndices('global_idx')}; + let batch = outputIndices[0]; + let channel = outputIndices[1]; + let scale_shift = ${scale.getByIndices('vec2(batch, channel)')}; + let value = ${x.getByOffset('global_idx')} * ${output.type.value}(scale_shift.x) + ${output.type.value}(scale_shift.y); + ${output.setByOffset('global_idx', 'value')}; }`; }; - return context.compute( + + context.compute( { - name: 'InstanceNormComputeChannelScaleShift', - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. 
- shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + name: 'InstanceNormalization', + shaderCache: { hint: `${components}`, inputDependencies }, getRunData: () => ({ - outputs: [{ dims: [n, c, 2], dataType: DataType.float }], - dispatchGroup: { x: Math.ceil(unitsOfWork / 64 /* workgroup size */) }, - programUniforms, + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputShape, scaleShape, inputShape), + ], }), getShaderSource, }, - { inputs: [meanValues, scale, bias], outputs: [-1] }, - )[0]; + { inputs: [inputs[0], channelScaleShift] }, + ); }; const createInstanceNormNHWCProgramInfo = ( @@ -277,30 +185,61 @@ const createInstanceNormNHWCProgramInfo = ( { type: DataType.uint32, data: Math.floor(C / components) }, ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - // first compute mean - const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + + // 1. transpose x from NHWC to NCHW + const transposedXPerm = [0, xShape.length - 1]; + for (let i = 0; i < xShape.length - 2; i++) { + transposedXPerm.push(i + 1); + } + const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0]; + // 2. compute channel scale and channel shift. + const channelScaleShift = computeChannelScaleShift( + context, + transposedX, + inputs[1], + inputs[2], + N, + H, + C, + attributes.epsilon, + ); const getShaderSource = (shaderHelper: ShaderHelper) => { const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; - + const scaleType = components === 1 ? 'vec2f' : `mat${components}x2f`; + const scaleData = (num: number) => { + const index = num === 0 ? 'x' : 'y'; + const f32Type = components === 1 ? 
'f32' : `vec${components}f`; + switch (components) { + case 1: + return `${dataType}(${f32Type}(scale.${index}))`; + case 2: + return `vec2<${dataType}>(${f32Type}(scale[0].${index}, scale[1].${index}))`; + case 4: + return `vec4<${dataType}>(${f32Type}(scale[0].${index}, scale[1].${index}, scale[2].${index}, scale[3].${index}))`; + default: + throw new Error(`Not supported compoents ${components}`); + } + }; const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; - @group(0) @binding(1) var scaleInput : array<${scaleType}>; + @group(0) @binding(1) var scale_input : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; struct Uniforms {H: u32, C : u32}; @group(0) @binding(3) var uniforms: Uniforms; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / (uniforms.C * uniforms.H); - let currentChannelNumber = global_idx % uniforms.C; + let current_image_number = global_idx / (uniforms.C * uniforms.H); + let current_channel_number = global_idx % uniforms.C; - let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber; - let scale = scaleInput[scaleOffset]; - output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); + let scale_offset = current_image_number * uniforms.C + current_channel_number; + let scale = scale_input[scale_offset]; + output[global_idx] = fma(input[global_idx], ${scaleData(0)}, ${scaleData(1)}); }`; }; context.compute( @@ -322,6 +261,6 @@ export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAt if (attributes.format === 'NHWC') { createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); } else { - context.compute(createInstanceNormProgramInfo(context.inputs, attributes)); + createInstanceNormProgramInfo(context, context.inputs, attributes); } }; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 9db48f74a94a4..6e723a76e8fd8 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -2390,12 +2390,12 @@ } }, "node_modules/micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -5514,12 +5514,12 @@ "dev": true }, "micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "requires": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" } }, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc index 7b84971585f9f..c8fe9c77d8ff8 100644 
--- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc @@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); }}; // For element-wise add // Allocate space for output of Q(BS, D) + bias(D) @@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); }}; // For element-wise add // Get Q's bias from combined bias @@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); +template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); + template Status MaybeTransposeToBNSH(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, @@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, OrtValue& out); +template Status MaybeTransposeToBNSH(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 570f4108c3f62..72adfa025da57 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -86,6 +86,11 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { std::atomic_bool failed{false}; int n = batch_size * sequence_length; + + // Put epsilon into local variable here to avoid the need to capture 'this' in the TryBatchParallelFor() lambda. + // Using the copy capture default (=) to implicitly capture 'this' is deprecated. 
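The two comment lines above refer to C++20 deprecating the implicit capture of `this` that a `[=]` capture-default performs when the lambda body calls a member function; copying the value into a local first means the lambda captures only that local. A minimal sketch of the pattern, with hypothetical names (Kernel, epsilon_) rather than the actual EmbedLayerNorm code:

  #include <cstddef>

  class Kernel {
   public:
    explicit Kernel(float eps) : epsilon_(eps) {}

    void Compute(const float* x, float* y, std::size_t n) const {
      // Copy the member once; "[=]" then captures the local by value and no longer
      // needs the (deprecated) implicit capture of 'this' that calling epsilon()
      // inside the lambda would require.
      const float epsilon_value = epsilon();
      auto body = [=](std::size_t i) { y[i] = x[i] + epsilon_value; };
      for (std::size_t i = 0; i < n; ++i) {
        body(i);
      }
    }

   private:
    float epsilon() const { return epsilon_; }
    float epsilon_;
  };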
+ const float epsilon_value = epsilon(); + concurrency::ThreadPool::TryBatchParallelFor( context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { int word_col_index = input_ids_data[index]; @@ -136,7 +141,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { y[i] = a; sum += a * a; } - T e = sqrt(sum / hidden_size + static_cast(epsilon())); + T e = sqrt(sum / hidden_size + static_cast(epsilon_value)); for (int i = 0; i < hidden_size; i++) { y[i] = y[i] / e * gamma_data[i] + beta_data[i]; } diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index bfec9aef56727..ccaeb6654e286 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -75,7 +75,7 @@ class GQAAttentionBase { int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]); // Compute the attention score. - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -87,16 +87,17 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp); + is_prompt, tp, allocator); return Status::OK(); } @@ -106,7 +107,7 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. 
Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor @@ -120,7 +121,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - ThreadPool* tp) const { // thread pool + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -131,7 +133,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_key, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -164,7 +168,7 @@ class GQAAttentionBase { const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; - T* output = attention_probs + output_offset; + float* output = attention_probs + output_offset; const T* k; if (packed_qkv) { @@ -190,12 +194,28 @@ class GQAAttentionBase { q = Q + q_input_chunk_length * i; } - math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, - static_cast(present_buffer_sequence_length), nullptr); + if constexpr (std::is_same::value) { + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, + output, static_cast(present_buffer_sequence_length), nullptr); + } else { + size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); + auto q_k_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator)); + + float* q_fp32 = static_cast(q_k_fp32); + MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length); + + float* k_fp32 = q_fp32 + head_size * sequence_length; + MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen); + + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32, + static_cast(head_size), k_fp32, static_cast(head_size), 0.0f /*bata*/, + output, static_cast(present_buffer_sequence_length), nullptr); + } // compute Softmax - T* output_softmax = output; + float* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { @@ -237,7 +257,7 @@ class GQAAttentionBase { template void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT + const float* attention_probs, // Attention probs with size BxNxSxT const T* V, // V value with size BxN_kvxSxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor const size_t batch_size, // batch size @@ -251,7 +271,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // 
whether it is prompt - ThreadPool* tp) const { + ThreadPool* tp, + AllocatorPtr allocator) const { const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -261,7 +282,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_value, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -285,6 +308,13 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; + size_t output_fp32_bytes = 0; + if constexpr (std::is_same::value) { + output_fp32_bytes = SafeInt(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float); + } + auto output_fp32 = allocator->Alloc(output_fp32_bytes); + BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator)); + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const size_t batch_index = i / num_heads_; @@ -305,15 +335,39 @@ class GQAAttentionBase { i / kv_num_heads_factor); } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ - attention_probs + attention_probs_offset, - static_cast(present_buffer_sequence_length), v, static_cast(head_size), - 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + if constexpr (std::is_same::value) { + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, + static_cast(head_size), 0.0f /*beta*/, output_current, + static_cast(hidden_size), nullptr); + } else { + size_t bytes = head_size * total_seqlen * sizeof(float); + auto v_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); + + float* v_fp32_ptr = static_cast(v_fp32); + MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen); + + float* output_fp32_current = static_cast(output_fp32) + + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v_fp32_ptr, + static_cast(head_size), 0.0f /*beta*/, output_fp32_current, + static_cast(hidden_size), nullptr); + } } }); + + if constexpr (std::is_same::value) { + MlasConvertFloatToHalfBuffer(static_cast(output_fp32), + output, + SafeInt(sequence_length) * batch_size * num_heads_ * head_size); + } } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 2a38e4a1ac636..a1ed35e54b008 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ 
b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -22,16 +22,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -ONNX_OPERATOR_TYPED_KERNEL_EX( - GroupQueryAttention, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - GroupQueryAttention); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + GroupQueryAttention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 6732f8b96cce2..cbfd2f0949363 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -13,16 +13,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -ONNX_OPERATOR_TYPED_KERNEL_EX( - RotaryEmbedding, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - RotaryEmbedding); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { @@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* sin_data = sin_cache + cache_offset; int cache_idx = 0; - T sign = 0; + bool sign = false; int j = 0; for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { cache_idx = (i / 2) % half_rotary_emb_dim; - sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); - j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + sign = i & 1; + j = sign ? i - 1 : i + 1; // i - sign } else { cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? 
static_cast(-1) : static_cast(1); + sign = (i >= half_rotary_emb_dim); j = (i + half_rotary_emb_dim) % rotary_emb_dim; } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); } for (int i = rotary_emb_dim; i < head_size; i++) { output_data[i] = input_data[i]; @@ -102,6 +114,10 @@ template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryPar const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output, bool interleaved); +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input, + const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache, + MLFloat16* output, bool interleaved); + template Status RotaryEmbedding::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index dcd1f5ec22b52..e75d485830ca5 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -288,8 +290,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7ed776f1358a5..7b1b136eb091e 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -227,10 +227,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSumSquare); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM); @@ -408,9 +414,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); @@ -636,9 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 13, int32_t, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); @@ -1443,16 +1457,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, ReduceSumSquare)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1725,12 +1751,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2065,11 +2099,19 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 91717486b77cb..a78ff69e5c894 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -757,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.min(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template min( + static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.max(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template max( + static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template min( + static_cast(per_iter_bh.ScalarInput1())); } else { - output_vec_map = input_0_vec_map.max(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template 
max( + static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(input_1_vec_map); + output_vec_map = input_0_vec_map.template min(input_1_vec_map); } else { - output_vec_map = input_0_vec_map.max(input_1_vec_map); + output_vec_map = input_0_vec_map.template max(input_1_vec_map); } }}; diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 5aac1d9387f57..24fbfbe8d525b 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -288,22 +288,36 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMin, 13); FastReduceKind operator|(FastReduceKind a, FastReduceKind b) { return static_cast(static_cast(a) | static_cast(b)); diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index db36754319309..55935a9eae86d 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -10,13 +10,10 @@ #include #include #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#if CUDA_VERSION >= 11000 -#include -#endif - namespace onnxruntime { namespace cuda { @@ 
-347,6 +344,21 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } +#define ISNAN_HALF(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~MLFloat16::kSignMask) \ + > MLFloat16::kPositiveInfinityBits + +#define ISNAN_BFLOAT16(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~BFloat16::kSignMask) \ + > BFloat16::kPositiveInfinityBits + +// CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2, +// so define our own equivalent constants to support older versions. +// Note that there is no consistent canonical NaN for FP16 and BF16; +// CUDA uses 0x7FFF for both, but ONNX Runtime uses 0x7E00 and 0x7FC1 +// for FP16 and BF16 respectively +// (see Float16Impl::kPositiveQNaNBits and BFloat16Impl::kPositiveQNaNBits). +#define NAN_HALF __ushort_as_half((unsigned short)0x7FFFU) +#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) + template <typename T> __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } @@ -360,6 +372,24 @@ __device__ __inline__ double _Min(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b ); } +template <> +__device__ __inline__ half _Min(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a < b ? a : b); +#else + return __hmin_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); +#else + return BFloat16(__hmin_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + template <typename T> __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } @@ -373,6 +403,29 @@ __device__ __inline__ double _Max(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b ); } +template <> +__device__ __inline__ half _Max(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a > b ? a : b); +#else + return __hmax_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); +#else + return BFloat16(__hmax_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + +#undef ISNAN_HALF +#undef ISNAN_BFLOAT16 +#undef NAN_HALF +#undef NAN_BFLOAT16 + template <typename T> __device__ __inline__ T _Abs(T a) { return a > (T)0 ?
a : -a; } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 9cd73edbff0e0..ac9098f907975 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -142,7 +142,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; - bool enable_HTP_FP16_precision_ = false; + bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 11073ab3584eb..a1f5eba9a24c8 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -731,10 +731,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR std::vector<int64_t> axes; size_t num_inputs = ctx->InputCount(); - if (num_inputs == 2) { + const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input<Tensor>(1) : nullptr; // optional input. may be nullptr. + if (axes_tensor != nullptr) { // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input<Tensor>(1); - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data<int64_t>(); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9e017df5baa3..83e7596d2f6b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2770,7 +2770,8 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options, if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync"); } - std::function<void()> run_fn = [=]() { + std::function<void()> run_fn = [run_options, feed_names, feeds, fetch_names, fetches, num_fetches, + callback, user_data, this]() { Status status = Status::OK(); ORT_TRY { if (run_options) { diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 924616f49ab25..e8c948ade1068 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -73,7 +73,7 @@ void usage() { "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'.
\n" diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 79704f2cc79e3..81c1a4ace1e33 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -24,8 +24,6 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -static const std::string MODEL_FOLDER = "testdata/transform/"; - TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index c1c48d4945a4d..6e811f4596eab 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -98,7 +98,7 @@ namespace perftest { "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb914646942fe..507ed8e91a728 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1787,54 +1787,90 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddOutput("min", {3, 3}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, - -1.0f, -1.0f, -2.0f, - 0.5f, 0.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { +void TestFloat16MinMax( + const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values) { + { std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); + if (nullptr != DefaultCpuExecutionProvider()) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (nullptr != DefaultCudaExecutionProvider()) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeMLFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeMLFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeMLFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, 
&execution_providers); } if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} -TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 4}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f, - -0.5f, 0.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 2.0f, 1.5f})); - test.AddOutput("min", {3, 4}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f, - -1.0f, -1.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 1.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeBFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeBFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeBFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } +TEST(MathOpTest, Min_13_Float16_MatrixVector) { + TestFloat16MinMax("Min", + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + {0.0f, 0.0f, 0.0f, + -1.0f, -1.0f, -2.0f, + 0.5f, 0.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_VectorMatrix) { + TestFloat16MinMax("Min", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 4}, + {1.0f, 1.0f, 1.0f, -1.0f, + -0.5f, 0.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 2.0f, 1.5f}, + {3, 4}, + {0.0f, 0.0f, 0.0f, -1.0f, + -1.0f, -1.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 1.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_Nan) { + TestFloat16MinMax("Min", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Min_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Min", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f}); +} + +TEST(MathOpTest, Min_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Min", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); +} TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -2185,54 +2221,57 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) { - OpTester test("Max", 12); - test.AddInput("data_0", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.0f, 0.5f, 0.75f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {4, 1}, - MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f})); - test.AddOutput("max", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 0.5f, 0.5f, 0.75f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - 
execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) { - OpTester test("Max", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddOutput("max", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(MathOpTest, Max_13_Float16_MatrixVector) { + TestFloat16MinMax("Max", + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.0f, 0.5f, 0.75f, + 0.5f, 0.0f, 2.0f}, + {4, 1}, {0.0f, -1.0f, 0.5f, 1.0f}, + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 0.5f, 0.5f, 0.75f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_VectorMatrix) { + TestFloat16MinMax("Max", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_Nan) { + TestFloat16MinMax("Max", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {0.5f, std::numeric_limits::quiet_NaN(), 1.0f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Max_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Max", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {0.25f, std::numeric_limits::quiet_NaN(), 1.0f}); +} + +TEST(MathOpTest, Max_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Max", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); } TEST(MathOpTest, Not) { diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0697187a777d6..0968bc32e0de4 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3246,6 +3246,26 @@ TEST(ReductionOpTest, ArgMax_do_not_keepdims_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: node1: at least 2 dimensions are required for input } +TEST(ReductionOpTest, ArgMax_int64) { + OpTester test("ArgMax", 13); + test.AddAttribute("axis", (int64_t)1); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {3, 1, 2}, + {1, 1, + 1, 1, + 1, 1}); + test.Run(); +} + 
TEST(ReductionOpTest, ArgMax_int32) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); @@ -3511,6 +3531,63 @@ TEST(ReductionOpTest, ArgMin_do_not_keepdims_2_select_last) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(ReductionOpTest, ArgMin_uint8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + +TEST(ReductionOpTest, ArgMin_int8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(ReductionOpTest, ArgMin_int64) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + TEST(ReductionOpTest, ArgMin_int32) { OpTester test("ArgMin"); test.AddAttribute("axis", (int64_t)0); diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index f03782c33c30a..9b83dd281a56d 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -49,7 +49,8 @@ static GetTestModelFn BuildCastTestCase(const std::vector& shape, template static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::TensorProto_DataType dst_type, ExpectedEPNodeAssignment expected_ep_assignment, - bool use_htp) { + bool use_htp, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; @@ -57,6 +58,12 @@ static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::Ten provider_options["backend_path"] = use_htp ? 
"libQnnHtp.so" : "libQnnCpu.so"; #endif + if (use_htp && enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildCastTestCase(shape, dst_type), provider_options, 13, // opset @@ -93,19 +100,19 @@ TEST_F(QnnCPUBackendTests, TestCastFloatToInt32) { // Cast int32_t to float on HTP TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast uint8_t to float on HTP TEST_F(QnnHTPBackendTests, TestCastUInt8ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast float to int32_t on HTP TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast int64_t to int32_t on HTP diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc index c3a75fd7446e2..cfa77a46210b3 100644 --- a/onnxruntime/test/providers/qnn/clip_op_test.cc +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -21,7 +21,8 @@ static void RunClipTest(const TestInputDef& input_def, const std::vector>& min_max_defs, ExpectedEPNodeAssignment expected_ep_assignment, bool on_cpu_backend = true, - int opset = 13) { + int opset = 13, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) @@ -30,6 +31,12 @@ static void RunClipTest(const TestInputDef& input_def, provider_options["backend_path"] = on_cpu_backend ? 
"libQnnCpu.so" : "libQnnHtp.so"; #endif + if (!on_cpu_backend && enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), provider_options, opset, @@ -80,7 +87,9 @@ TEST_F(QnnHTPBackendTests, Clip_f32) { {TestInputDef({}, true, {-5.0f}), TestInputDef({}, true, {5.0f})}, ExpectedEPNodeAssignment::All, - on_cpu_backend); + on_cpu_backend, + 13, + false); } // Test Clip with int32 on HTP diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index d8c34d6a6c6ed..708aac03ceb2e 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -117,7 +117,8 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, ExpectedEPNodeAssignment expected_ep_assignment, int opset = 21, bool use_contrib_qdq = false, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -125,6 +126,12 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), BuildQDQPerChannelMatMulTestCase(input_def, weights_def, @@ -275,7 +282,8 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { ExpectedEPNodeAssignment::All, 21, false, - QDQTolerance(0.007f)); + QDQTolerance(0.007f), + false); } // Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2ebc2c6251b44..83899ec6ef17b 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -157,6 +157,8 @@ static void RunOpTest(const std::string& op_type, if (enable_htp_fp16_precision) { provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; // enabled in QNN EP by default } // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. 
diff --git a/onnxruntime/test/providers/qnn/transpose_htp_test.cc b/onnxruntime/test/providers/qnn/transpose_htp_test.cc index 119b8301f36ed..63746e22d214d 100644 --- a/onnxruntime/test/providers/qnn/transpose_htp_test.cc +++ b/onnxruntime/test/providers/qnn/transpose_htp_test.cc @@ -90,7 +90,8 @@ static void RunTransposeQDQTest(const TestInputDef& input_def, template static void RunTransposeNonQDQOnHTP(const TestInputDef& input_def, const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -98,6 +99,12 @@ static void RunTransposeNonQDQOnHTP(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildTransposeTestCase(input_def, attrs), provider_options, 13, @@ -123,7 +130,7 @@ TEST_F(QnnHTPBackendTests, TransposeInt32OnHTP) { TEST_F(QnnHTPBackendTests, TransposeFloatOnHTP) { RunTransposeNonQDQOnHTP(TestInputDef({1, 3, 224, 128}, false, 0, 10.0f), {utils::MakeAttribute("perm", std::vector{0, 2, 3, 1})}, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, false); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index dc21d4e4a5890..08ec5de328b9d 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -29,6 +29,12 @@ GREEN = "\033[32m" RESET = "\033[0m" +ORT_TYPE = TensorProto.FLOAT +TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 +NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 +RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 +ATOL = RTOL + class Formats: BSNH = 0 @@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.q_sequence_length, @@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -264,7 +270,7 @@ def 
create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -378,7 +384,7 @@ def create_group_query_attention_graph_past( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -391,7 +397,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -401,7 +407,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -424,7 +430,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -433,7 +439,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -445,7 +451,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -453,7 +459,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -464,12 +470,12 @@ def create_group_query_attention_graph_past( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -479,7 +485,7 @@ def create_group_query_attention_graph_past( ), 
helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True): config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) key_padding_mask = generate_random_padding_mask( @@ -722,13 +728,13 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -835,13 +841,13 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -1017,9 +1023,11 @@ def attention_ref( attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1076,7 +1084,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + 
sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt( kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE) k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded @@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), 
rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "No buff", @@ -1458,8 +1466,8 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1467,7 +1475,7 @@ def parity_check_gqa_past( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1476,7 +1484,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1485,7 +1493,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1494,7 +1502,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1503,7 +1511,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1534,8 +1542,8 @@ def parity_check_gqa_past( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1624,11 +1632,11 @@ def parity_check_gqa_past( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): torch.manual_seed(69) q = torch.randn( @@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) 
k = torch.randn( @@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff( angle = ( torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi ) - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff( out = out.detach().cpu().numpy() # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "NO buff", @@ -1983,8 +1991,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -1996,8 +2004,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 509f56664e572..102846e08ac5f 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -46,7 +46,7 @@ namespace qnnctxgen { "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). 
\n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 7d384f7b1df67..72d9ce72ea7cb 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,5 +1,7 @@ # This file is auto updated by dependabot -lintrunner-adapters>=0.12.4 +# When any package below is changed, you shall run "lintrunner init" again. +lintrunner==0.12.5 +lintrunner-adapters==0.12.4 # RUFF ruff==0.5.4 # BLACK-ISORT diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 3e2ade7eacd25..94c2d35a563b6 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -196,42 +196,6 @@ stages: WITH_CACHE: false MachinePool: 'onnxruntime-Win-CPU-2022' -- stage: training_x64_debug - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'Debug' - buildArch: x64 - additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker - msbuildPlatform: x64 - isX86: false - job_name_suffix: training_x64_debug - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - isTraining: true - ORT_EP_NAME: CPU - GenerateDocumentation: false - WITH_CACHE: false - MachinePool: 'onnxruntime-Win2022-CPU-training-AMD' - -- stage: training_x64_release - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - buildArch: x64 - additionalBuildFlags: --enable_training --build_wheel - msbuildPlatform: x64 - isX86: false - job_name_suffix: training_x64_release - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - isTraining: true - ORT_EP_NAME: CPU - GenerateDocumentation: false - WITH_CACHE: false - MachinePool: 'onnxruntime-Win2022-CPU-training-AMD' - - stage: ort_training_apis_x64_release dependsOn: [] jobs: