diff --git a/.github/workflows/aarch64-linux-gnu-shared.yaml b/.github/workflows/aarch64-linux-gnu-shared.yaml index dbba7c132..1f548e237 100644 --- a/.github/workflows/aarch64-linux-gnu-shared.yaml +++ b/.github/workflows/aarch64-linux-gnu-shared.yaml @@ -34,11 +34,12 @@ concurrency: jobs: aarch64_linux_gnu_shared: runs-on: ${{ matrix.os }} - name: aarch64 shared lib test + name: aarch64 shared GPU ${{ matrix.gpu }} strategy: fail-fast: false matrix: os: [ubuntu-latest] + gpu: [ON, OFF] steps: - uses: actions/checkout@v4 @@ -79,15 +80,24 @@ jobs: make -j2 make install - - name: cache-toolchain - id: cache-toolchain + - name: cache-toolchain (CPU) + if: matrix.gpu == 'OFF' + id: cache-toolchain-cpu uses: actions/cache@v4 with: path: toolchain key: gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz - - name: Download toolchain - if: steps.cache-toolchain.outputs.cache-hit != 'true' + - name: cache-toolchain (GPU) + if: matrix.gpu == 'ON' + id: cache-toolchain-gpu + uses: actions/cache@v4 + with: + path: toolchain + key: gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz + + - name: Download toolchain (CPU, gcc 7.5) + if: steps.cache-toolchain-cpu.outputs.cache-hit != 'true' && matrix.gpu == 'OFF' shell: bash run: | wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz @@ -95,6 +105,15 @@ jobs: mkdir $GITHUB_WORKSPACE/toolchain tar xf ./gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz --strip-components 1 -C $GITHUB_WORKSPACE/toolchain + - name: Download toolchain (GPU, gcc 10.3) + if: steps.cache-toolchain-gpu.outputs.cache-hit != 'true' && matrix.gpu == 'ON' + shell: bash + run: | + wget -qq https://huggingface.co/csukuangfj/sherpa-ncnn-toolchains/resolve/main/gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz + + mkdir $GITHUB_WORKSPACE/toolchain + tar xf ./gcc-arm-10.3-2021.07-x86_64-aarch64-none-linux-gnu.tar.xz --strip-components 1 -C 
$GITHUB_WORKSPACE/toolchain + - name: Set environment variable if: steps.cache-build-result.outputs.cache-hit != 'true' shell: bash @@ -103,19 +122,31 @@ jobs: echo "$GITHUB_WORKSPACE/bin" >> "$GITHUB_PATH" ls -lh "$GITHUB_WORKSPACE/toolchain/bin" - echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV" - echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV" + if [[ ${{ matrix.gpu }} == OFF ]]; then + echo "CC=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV" + echo "CXX=aarch64-linux-gnu-g++" >> "$GITHUB_ENV" + else + echo "CC=aarch64-none-linux-gnu-gcc" >> "$GITHUB_ENV" + echo "CXX=aarch64-none-linux-gnu-g++" >> "$GITHUB_ENV" + fi - name: Display toolchain info shell: bash run: | - aarch64-linux-gnu-gcc --version + if [[ ${{ matrix.gpu }} == OFF ]]; then + which aarch64-linux-gnu-gcc + aarch64-linux-gnu-gcc --version + else + which aarch64-none-linux-gnu-gcc + aarch64-none-linux-gnu-gcc --version + fi - name: Display qemu-aarch64 -h shell: bash run: | export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc qemu-aarch64 -h - name: build aarch64-linux-gnu @@ -127,6 +158,7 @@ jobs: cmake --version export BUILD_SHARED_LIBS=ON + export SHERPA_ONNX_ENABLE_GPU=${{ matrix.gpu }} ./build-aarch64-linux-gnu.sh @@ -140,7 +172,11 @@ jobs: run: | export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH - export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc + if [[ ${{ matrix.gpu }} == OFF ]]; then + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-linux-gnu/libc + else + export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/aarch64-none-linux-gnu/libc + fi ls -lh ./build-aarch64-linux-gnu/bin @@ -151,11 +187,20 @@ jobs: - name: Copy files shell: bash run: | - aarch64-linux-gnu-strip --version + if [[ ${{ matrix.gpu }} == OFF ]]; then + aarch64-linux-gnu-strip --version + else + 
aarch64-none-linux-gnu-strip --version + fi SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-linux-aarch64-shared + if [[ ${{ matrix.gpu }} == OFF ]]; then + dst=${dst}-cpu + else + dst=${dst}-gpu + fi mkdir $dst cp -a build-aarch64-linux-gnu/install/bin $dst/ @@ -166,7 +211,11 @@ jobs: ls -lh $dst/bin/ echo "strip" - aarch64-linux-gnu-strip $dst/bin/* + if [[ ${{ matrix.gpu }} == OFF ]]; then + aarch64-linux-gnu-strip $dst/bin/* + else + aarch64-none-linux-gnu-strip $dst/bin/* + fi tree $dst @@ -174,8 +223,8 @@ jobs: - uses: actions/upload-artifact@v4 with: - name: sherpa-onnx-linux-aarch64-shared - path: sherpa-onnx-*linux-aarch64-shared.tar.bz2 + name: sherpa-onnx-linux-aarch64-shared-gpu-${{ matrix.gpu }} + path: sherpa-onnx-*linux-aarch64-shared*.tar.bz2 # https://huggingface.co/docs/hub/spaces-github-actions - name: Publish to huggingface @@ -198,7 +247,7 @@ jobs: cd huggingface mkdir -p aarch64 - cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 + cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64 git status git lfs track "*.bz2" diff --git a/build-aarch64-linux-gnu.sh b/build-aarch64-linux-gnu.sh index d9851fbe1..62b359c17 100755 --- a/build-aarch64-linux-gnu.sh +++ b/build-aarch64-linux-gnu.sh @@ -44,6 +44,21 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then BUILD_SHARED_LIBS=OFF fi +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then + # By default, use CPU + SHERPA_ONNX_ENABLE_GPU=OFF + + # If you use GPU, then please make sure you have NVIDIA GPUs on your board. + # It uses onnxruntime 1.11.0. + # + # Tested on Jetson Nano B01 +fi + +if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"ON" ]]; then + # Build shared libs if building GPU is enabled. 
+ BUILD_SHARED_LIBS=ON +fi + cmake \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ @@ -51,6 +66,7 @@ cmake \ -DBUILD_ESPEAK_NG_TESTS=OFF \ -DCMAKE_INSTALL_PREFIX=./install \ -DCMAKE_BUILD_TYPE=Release \ + -DSHERPA_ONNX_ENABLE_GPU=$SHERPA_ONNX_ENABLE_GPU \ -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ -DSHERPA_ONNX_ENABLE_TESTS=OFF \ -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ diff --git a/cmake/onnxruntime-linux-aarch64-gpu.cmake b/cmake/onnxruntime-linux-aarch64-gpu.cmake new file mode 100644 index 000000000..64db9c22b --- /dev/null +++ b/cmake/onnxruntime-linux-aarch64-gpu.cmake @@ -0,0 +1,101 @@ +# Copyright (c) 2022-2024 Xiaomi Corporation +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + +if(NOT CMAKE_SYSTEM_NAME STREQUAL Linux) + message(FATAL_ERROR "This file is for Linux only. Given: ${CMAKE_SYSTEM_NAME}") +endif() + +if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) + message(FATAL_ERROR "This file is for aarch64 only. Given: ${CMAKE_SYSTEM_PROCESSOR}") +endif() + +if(NOT BUILD_SHARED_LIBS) + message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}") +endif() + +if(NOT SHERPA_ONNX_ENABLE_GPU) + message(FATAL_ERROR "This file is for NVIDIA GPU only. Given SHERPA_ONNX_ENABLE_GPU: ${SHERPA_ONNX_ENABLE_GPU}") +endif() + +set(onnxruntime_URL "https://github.com/csukuangfj/onnxruntime-libs/releases/download/v1.11.0/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2") +set(onnxruntime_URL2 "https://hf-mirror.com/csukuangfj/onnxruntime-libs/resolve/main/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2") +set(onnxruntime_HASH "SHA256=36eded935551e23aead09d4173bdf0bd1e7b01fdec15d77f97d6e34029aa60d7") + +# If you don't have access to the Internet, +# please download onnxruntime to one of the following locations. +# You can add more if you want. 
+set(possible_file_locations + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2 + ${CMAKE_SOURCE_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2 + ${CMAKE_BINARY_DIR}/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2 + /tmp/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2 + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-gpu-1.11.0.tar.bz2 +) + +foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(onnxruntime_URL "${f}") + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) + message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}") + set(onnxruntime_URL2) + break() + endif() +endforeach() + +FetchContent_Declare(onnxruntime + URL + ${onnxruntime_URL} + ${onnxruntime_URL2} + URL_HASH ${onnxruntime_HASH} +) + +FetchContent_GetProperties(onnxruntime) +if(NOT onnxruntime_POPULATED) + message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}") + FetchContent_Populate(onnxruntime) +endif() +message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") + +find_library(location_onnxruntime onnxruntime + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + NO_CMAKE_SYSTEM_PATH +) + +message(STATUS "location_onnxruntime: ${location_onnxruntime}") + +add_library(onnxruntime SHARED IMPORTED) + +set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime} + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" +) + +find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + NO_CMAKE_SYSTEM_PATH +) + +add_library(onnxruntime_providers_cuda SHARED IMPORTED) +set_target_properties(onnxruntime_providers_cuda PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib} +) +message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}") + +# for libonnxruntime_providers_shared.so +find_library(location_onnxruntime_providers_shared_lib onnxruntime_providers_shared + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + 
NO_CMAKE_SYSTEM_PATH +) +add_library(onnxruntime_providers_shared SHARED IMPORTED) +set_target_properties(onnxruntime_providers_shared PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime_providers_shared_lib} +) +message(STATUS "location_onnxruntime_providers_shared_lib: ${location_onnxruntime_providers_shared_lib}") + +file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime*") +message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}") +install(FILES ${onnxruntime_lib_files} DESTINATION lib) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 6655b45cd..8453b96bd 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -13,7 +13,9 @@ function(download_onnxruntime) include(onnxruntime-linux-riscv64-static) endif() elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) - if(BUILD_SHARED_LIBS) + if(SHERPA_ONNX_ENABLE_GPU) + include(onnxruntime-linux-aarch64-gpu) + elseif(BUILD_SHARED_LIBS) include(onnxruntime-linux-aarch64) else() include(onnxruntime-linux-aarch64-static) diff --git a/cmake/piper-phonemize.cmake b/cmake/piper-phonemize.cmake index bcea4e8ac..9c9c71f5a 100644 --- a/cmake/piper-phonemize.cmake +++ b/cmake/piper-phonemize.cmake @@ -1,18 +1,18 @@ function(download_piper_phonemize) include(FetchContent) - set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/dc6b5f4441bffe521047086930b0fc12686acd56.zip") - set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip") - set(piper_phonemize_HASH "SHA256=b9faa04204b1756fa455a962abb1f037041c040133d55be58d11f11ab9b3ce14") + set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip") + set(piper_phonemize_URL2 
"https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip") + set(piper_phonemize_HASH "SHA256=ab4d06ca76047e1585c63c482f39ffead5315785345055360703cc9382c5e74b") # If you don't have access to the Internet, # please pre-download kaldi-decoder set(possible_file_locations - $ENV{HOME}/Downloads/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip - ${CMAKE_SOURCE_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip - ${CMAKE_BINARY_DIR}/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip - /tmp/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip - /star-fj/fangjun/download/github/piper-phonemize-dc6b5f4441bffe521047086930b0fc12686acd56.zip + $ENV{HOME}/Downloads/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip + ${CMAKE_SOURCE_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip + ${CMAKE_BINARY_DIR}/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip + /tmp/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip + /star-fj/fangjun/download/github/piper-phonemize-38ee199dcc49c7b6de89f7ebfb32ed682763fa1b.zip ) foreach(f IN LISTS possible_file_locations) diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index 521739a89..506e63e13 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -7,6 +7,8 @@ #include #include +#include + #if __ANDROID_API__ >= 8 #include "android/log.h" #define SHERPA_ONNX_LOGE(...) 
\ @@ -36,30 +38,28 @@ #endif // Read an integer -#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - dst = atoi(value.get()); \ - if (dst < 0) { \ - SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + dst = atoi(value.c_str()); \ + if (dst < 0) { \ + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ + exit(-1); \ + } \ } while (0) #define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \ do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ dst = default_value; \ } else { \ - dst = atoi(value.get()); \ + dst = atoi(value.c_str()); \ if (dst < 0) { \ SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ exit(-1); \ @@ -68,118 +68,111 @@ } while (0) // read a vector of integers -#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - bool ret = SplitStringToIntegers(value.get(), ",", true, &dst); \ - if (!ret) { \ - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + 
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of floats -#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ - if (!ret) { \ - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of strings -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - SplitStringToVector(value.get(), ",", false, &dst); \ - \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. 
Empty vector!", \ - value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + SplitStringToVector(value.c_str(), ",", false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ + value.c_str(), src_key); \ + exit(-1); \ + } \ } while (0) // read a vector of strings separated by sep -#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - SplitStringToVector(value.get(), sep, false, &dst); \ - \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ - value.get(), src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + SplitStringToVector(value.c_str(), sep, false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. 
Empty vector!", \ + value.c_str(), src_key); \ + exit(-1); \ + } \ } while (0) // Read a string -#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - dst = value.get(); \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ - exit(-1); \ - } \ +#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + dst = std::move(value); \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ + exit(-1); \ + } \ } while (0) -#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ - exit(-1); \ - } \ - \ - dst = value.get(); \ +#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + dst = std::move(value); \ } while (0) -#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ - default_value) \ - do { \ - auto value = \ - meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ - if (!value) { \ - dst = default_value; \ - } else { \ - dst = value.get(); \ - if (dst.empty()) { \ - SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ - exit(-1); \ - } \ - } \ +#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ + default_value) \ + do { \ + auto value = 
LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + dst = default_value; \ + } else { \ + dst = std::move(value); \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ + exit(-1); \ + } \ + } \ } while (0) #define SHERPA_ONNX_EXIT(code) exit(code) diff --git a/sherpa-onnx/csrc/offline-ced-model.cc b/sherpa-onnx/csrc/offline-ced-model.cc index 538fe5bdb..d6dd35290 100644 --- a/sherpa-onnx/csrc/offline-ced-model.cc +++ b/sherpa-onnx/csrc/offline-ced-model.cc @@ -46,7 +46,7 @@ class OfflineCEDModel::Impl { int32_t NumEventClasses() const { return num_event_classes_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-ct-transformer-model.cc b/sherpa-onnx/csrc/offline-ct-transformer-model.cc index 2ce593b3e..d616484b4 100644 --- a/sherpa-onnx/csrc/offline-ct-transformer-model.cc +++ b/sherpa-onnx/csrc/offline-ct-transformer-model.cc @@ -44,7 +44,7 @@ class OfflineCtTransformerModel::Impl { return std::move(ans[0]); } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } const OfflineCtTransformerModelMetaData &GetModelMetadata() const { return meta_data_; diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index 9d1e05d9b..2cbd936ea 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -53,8 +53,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, Ort::AllocatorWithDefaultOptions allocator; auto model_type = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type) { + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n" "If you are using models from NeMo, please 
refer to\n" @@ -74,22 +74,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kUnknown; } - if (model_type.get() == std::string("EncDecCTCModelBPE")) { + if (model_type == "EncDecCTCModelBPE") { return ModelType::kEncDecCTCModelBPE; - } else if (model_type.get() == std::string("EncDecCTCModel")) { + } else if (model_type == "EncDecCTCModel") { return ModelType::kEncDecCTCModel; - } else if (model_type.get() == std::string("EncDecHybridRNNTCTCBPEModel")) { + } else if (model_type == "EncDecHybridRNNTCTCBPEModel") { return ModelType::kEncDecHybridRNNTCTCBPEModel; - } else if (model_type.get() == std::string("tdnn")) { + } else if (model_type == "tdnn") { return ModelType::kTdnn; - } else if (model_type.get() == std::string("zipformer2_ctc")) { + } else if (model_type == "zipformer2_ctc") { return ModelType::kZipformerCtc; - } else if (model_type.get() == std::string("wenet_ctc")) { + } else if (model_type == "wenet_ctc") { return ModelType::kWenetCtc; - } else if (model_type.get() == std::string("telespeech_ctc")) { + } else if (model_type == "telespeech_ctc") { return ModelType::kTeleSpeechCtc; } else { - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); return ModelType::kUnknown; } } diff --git a/sherpa-onnx/csrc/offline-moonshine-model.cc b/sherpa-onnx/csrc/offline-moonshine-model.cc index ab71d000f..bf9624d4d 100644 --- a/sherpa-onnx/csrc/offline-moonshine-model.cc +++ b/sherpa-onnx/csrc/offline-moonshine-model.cc @@ -155,7 +155,7 @@ class OfflineMoonshineModel::Impl { return {std::move(cached_decoder_out[0]), std::move(next_states)}; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void InitPreprocessor(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc 
index 708cb4b4f..14dc7dbe4 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -68,7 +68,7 @@ class OfflineNemoEncDecCtcModel::Impl { int32_t SubsamplingFactor() const { return subsampling_factor_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } std::string FeatureNormalizationMethod() const { return normalize_type_; } diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc index ce1851062..9c61cb350 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -56,7 +56,7 @@ class OfflineParaformerModel::Impl { const std::vector &InverseStdDev() const { return inv_stddev_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 07887df60..f89bb9bec 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -121,9 +121,9 @@ std::unique_ptr OfflineRecognizerImpl::Create( Ort::AllocatorWithDefaultOptions allocator; // used in the macro below - auto model_type_ptr = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type_ptr) { + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n\n" "Please refer to the following URLs to add metadata" @@ -164,7 +164,6 @@ std::unique_ptr OfflineRecognizerImpl::Create( "\n"); exit(-1); } - std::string model_type(model_type_ptr.get()); if (model_type == "conformer" || model_type == "zipformer" || model_type == "zipformer2") { @@ -301,9 +300,9 @@
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below - auto model_type_ptr = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type_ptr) { + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n\n" "Please refer to the following URLs to add metadata" @@ -344,7 +343,6 @@ std::unique_ptr OfflineRecognizerImpl::Create( "\n"); exit(-1); } - std::string model_type(model_type_ptr.get()); if (model_type == "conformer" || model_type == "zipformer" || model_type == "zipformer2") { diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc index 24903a41a..a914ccf4a 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model.cc +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -56,7 +56,7 @@ class OfflineSenseVoiceModel::Impl { return meta_data_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc index ea91d1c55..d7db0040c 100644 --- a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc @@ -63,7 +63,7 @@ class OfflineTdnnCtcModel::Impl { int32_t VocabSize() const { return vocab_size_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc b/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc index 68c0afbe8..aeb918cd3 100644 --- a/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc @@ -69,7 +69,7 @@ class OfflineTeleSpeechCtcModel::Impl { int32_t SubsamplingFactor() const { return 
subsampling_factor_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index 6a297347d..910ae3475 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -95,11 +95,11 @@ class OfflineTransducerModel::Impl { int32_t VocabSize() const { return vocab_size_; } int32_t ContextSize() const { return context_size_; } int32_t SubsamplingFactor() const { return 4; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } Ort::Value BuildDecoderInput( const std::vector &results, - int32_t end_index) const { + int32_t end_index) { assert(end_index <= results.size()); int32_t batch_size = end_index; @@ -122,7 +122,7 @@ class OfflineTransducerModel::Impl { } Ort::Value BuildDecoderInput(const std::vector &results, - int32_t end_index) const { + int32_t end_index) { assert(end_index <= results.size()); int32_t batch_size = end_index; diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc index 5332a835e..7dd5d31b8 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -123,7 +123,7 @@ class OfflineTransducerNeMoModel::Impl { return std::move(logit[0]); } - std::vector GetDecoderInitStates(int32_t batch_size) const { + std::vector GetDecoderInitStates(int32_t batch_size) { std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; Ort::Value s0 = Ort::Value::CreateTensor(allocator_, s0_shape.data(), s0_shape.size()); @@ -149,7 +149,7 @@ class OfflineTransducerNeMoModel::Impl { int32_t SubsamplingFactor() const { return subsampling_factor_; } int32_t VocabSize() const { return vocab_size_; } - OrtAllocator *Allocator() 
const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } std::string FeatureNormalizationMethod() const { return normalize_type_; } diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc index 93fdffab8..d696aa1c7 100644 --- a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc @@ -47,7 +47,7 @@ class OfflineWenetCtcModel::Impl { int32_t SubsamplingFactor() const { return subsampling_factor_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 485eaf93c..0747a329b 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -188,7 +188,7 @@ class OfflineWhisperModel::Impl { return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } const std::vector &GetInitialTokens() const { return sot_sequence_; } diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc index 8a2e80dc2..7ddf6d9b3 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -47,7 +47,7 @@ class OfflineZipformerAudioTaggingModel::Impl { int32_t NumEventClasses() const { return num_event_classes_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc index 8db9439e4..a783ce506 100644 --- 
a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -48,7 +48,7 @@ class OfflineZipformerCtcModel::Impl { int32_t VocabSize() const { return vocab_size_; } int32_t SubsamplingFactor() const { return 4; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void Init(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc index ce8da377e..f4fb3c8f9 100644 --- a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc @@ -47,7 +47,7 @@ class OnlineCNNBiLSTMModel::Impl { return {std::move(ans[0]), std::move(ans[1])}; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const { return meta_data_; diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 7c252f5a4..2bceffc7d 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -163,8 +163,11 @@ std::vector OnlineConformerTransducerModel::StackStates( conv_vec[i] = &states[i][1]; } - Ort::Value attn = Cat(allocator_, attn_vec, 2); - Ort::Value conv = Cat(allocator_, conv_vec, 2); + auto allocator = + const_cast(this)->allocator_; + + Ort::Value attn = Cat(allocator, attn_vec, 2); + Ort::Value conv = Cat(allocator, conv_vec, 2); std::vector ans; ans.reserve(2); @@ -183,8 +186,11 @@ OnlineConformerTransducerModel::UnStackStates( std::vector> ans(batch_size); - std::vector attn_vec = Unbind(allocator_, &states[0], 2); - std::vector conv_vec = Unbind(allocator_, &states[1], 2); + auto allocator = + const_cast(this)->allocator_; + + std::vector attn_vec = Unbind(allocator, &states[0], 2); + std::vector conv_vec = 
Unbind(allocator, &states[1], 2); assert(attn_vec.size() == batch_size); assert(conv_vec.size() == batch_size); diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 094cc933c..b9ef9ca5b 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -158,9 +158,10 @@ std::vector OnlineLstmTransducerModel::StackStates( h_buf[i] = &states[i][0]; c_buf[i] = &states[i][1]; } + auto allocator = const_cast(this)->allocator_; - Ort::Value h = Cat(allocator_, h_buf, 1); - Ort::Value c = Cat(allocator_, c_buf, 1); + Ort::Value h = Cat(allocator, h_buf, 1); + Ort::Value c = Cat(allocator, c_buf, 1); std::vector ans; ans.reserve(2); @@ -177,8 +178,10 @@ std::vector> OnlineLstmTransducerModel::UnStackStates( std::vector> ans(batch_size); - std::vector h_vec = Unbind(allocator_, &states[0], 1); - std::vector c_vec = Unbind(allocator_, &states[1], 1); + auto allocator = const_cast(this)->allocator_; + + std::vector h_vec = Unbind(allocator, &states[0], 1); + std::vector c_vec = Unbind(allocator, &states[1], 1); assert(h_vec.size() == batch_size); assert(c_vec.size() == batch_size); diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.cc b/sherpa-onnx/csrc/online-nemo-ctc-model.cc index d93ff73b1..172ee69f4 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model.cc +++ b/sherpa-onnx/csrc/online-nemo-ctc-model.cc @@ -102,7 +102,7 @@ class OnlineNeMoCtcModel::Impl { int32_t ChunkShift() const { return chunk_shift_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } // Return a vector containing 3 tensors // - cache_last_channel @@ -119,7 +119,7 @@ class OnlineNeMoCtcModel::Impl { } std::vector StackStates( - std::vector> states) const { + std::vector> states) { int32_t batch_size = static_cast(states.size()); if (batch_size == 1) { return std::move(states[0]); @@ -157,6 +157,8 @@ class 
OnlineNeMoCtcModel::Impl { std::vector states) const { assert(states.size() == 3); + auto allocator = const_cast(this)->allocator_; + std::vector> ans; auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); @@ -171,9 +173,9 @@ class OnlineNeMoCtcModel::Impl { for (int32_t i = 0; i != 3; ++i) { std::vector v; if (i == 2) { - v = Unbind(allocator_, &states[i], 0); + v = Unbind(allocator, &states[i], 0); } else { - v = Unbind(allocator_, &states[i], 0); + v = Unbind(allocator, &states[i], 0); } assert(v.size() == batch_size); diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc index 9397ff75b..d7d2e436d 100644 --- a/sherpa-onnx/csrc/online-paraformer-model.cc +++ b/sherpa-onnx/csrc/online-paraformer-model.cc @@ -105,7 +105,7 @@ class OnlineParaformerModel::Impl { const std::vector &InverseStdDev() const { return inv_stddev_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } private: void InitEncoder(void *model_data, size_t model_data_length) { diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 1b13d3a2d..2a44ddbe0 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -5,10 +5,10 @@ #include "sherpa-onnx/csrc/online-rnn-lm.h" +#include #include #include #include -#include #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" @@ -53,49 +53,49 @@ class OnlineRnnLM::Impl { // classic rescore function void ComputeLMScore(float scale, int32_t context_size, - std::vector *hyps) { - Ort::AllocatorWithDefaultOptions allocator; - - for (auto &hyp : *hyps) { - for (auto &h_m : hyp) { - auto &h = h_m.second; - auto &ys = h.ys; - const int32_t token_num_in_chunk = - ys.size() - context_size - h.cur_scored_pos - 1; - - if (token_num_in_chunk < 1) { - continue; - } - - if (h.nn_lm_states.empty()) { - h.nn_lm_states = Convert(GetInitStates()); - } - - if 
(token_num_in_chunk >= h.lm_rescore_min_chunk) { - std::array x_shape{1, token_num_in_chunk}; - - Ort::Value x = Ort::Value::CreateTensor( - allocator, x_shape.data(), x_shape.size()); - int64_t *p_x = x.GetTensorMutableData(); - std::copy(ys.begin() + context_size + h.cur_scored_pos, - ys.end() - 1, p_x); - - // streaming forward by NN LM - auto out = ScoreToken(std::move(x), - Convert(std::move(h.nn_lm_states))); - - // update NN LM score in hyp - const float *p_nll = out.first.GetTensorData(); - h.lm_log_prob = -scale * (*p_nll); - - // update NN LM states in hyp - h.nn_lm_states = Convert(std::move(out.second)); - - h.cur_scored_pos += token_num_in_chunk; - } + std::vector *hyps) { + Ort::AllocatorWithDefaultOptions allocator; + + for (auto &hyp : *hyps) { + for (auto &h_m : hyp) { + auto &h = h_m.second; + auto &ys = h.ys; + const int32_t token_num_in_chunk = + ys.size() - context_size - h.cur_scored_pos - 1; + + if (token_num_in_chunk < 1) { + continue; + } + + if (h.nn_lm_states.empty()) { + h.nn_lm_states = Convert(GetInitStates()); + } + + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { + std::array x_shape{1, token_num_in_chunk}; + + Ort::Value x = Ort::Value::CreateTensor( + allocator, x_shape.data(), x_shape.size()); + int64_t *p_x = x.GetTensorMutableData(); + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, + p_x); + + // streaming forward by NN LM + auto out = + ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states))); + + // update NN LM score in hyp + const float *p_nll = out.first.GetTensorData(); + h.lm_log_prob = -scale * (*p_nll); + + // update NN LM states in hyp + h.nn_lm_states = Convert(std::move(out.second)); + + h.cur_scored_pos += token_num_in_chunk; } } } + } std::pair> ScoreToken( Ort::Value x, std::vector states) { @@ -125,7 +125,7 @@ class OnlineRnnLM::Impl { } // get init states for classic rescore - std::vector GetInitStates() const { + std::vector GetInitStates() { std::vector ans; 
ans.reserve(init_states_.size()); @@ -226,7 +226,7 @@ std::pair> OnlineRnnLM::ScoreToken( // classic rescore scores void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, - std::vector *hyps) { + std::vector *hyps) { return impl_->ComputeLMScore(scale, context_size, hyps); } @@ -235,5 +235,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { return impl_->ComputeLMScoreSF(scale, hyp); } - } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 16577dd49..51a9aef3c 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -54,8 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, Ort::AllocatorWithDefaultOptions allocator; auto model_type = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type) { + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n" "Please make sure you are using the latest export-onnx.py from icefall " @@ -63,16 +63,16 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kUnknown; } - if (model_type.get() == std::string("conformer")) { + if (model_type == "conformer") { return ModelType::kConformer; - } else if (model_type.get() == std::string("lstm")) { + } else if (model_type == "lstm") { return ModelType::kLstm; - } else if (model_type.get() == std::string("zipformer")) { + } else if (model_type == "zipformer") { return ModelType::kZipformer; - } else if (model_type.get() == std::string("zipformer2")) { + } else if (model_type == "zipformer2") { return ModelType::kZipformer2; } else { - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); return ModelType::kUnknown; } } diff --git 
a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 4e12da44c..264593a1c 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -197,7 +197,7 @@ class OnlineTransducerNeMoModel::Impl { int32_t VocabSize() const { return vocab_size_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } std::string FeatureNormalizationMethod() const { return normalize_type_; } @@ -224,6 +224,8 @@ class OnlineTransducerNeMoModel::Impl { std::vector ans; + auto allocator = const_cast(this)->allocator_; + // stack cache_last_channel std::vector buf(batch_size); @@ -239,9 +241,9 @@ class OnlineTransducerNeMoModel::Impl { Ort::Value c{nullptr}; if (i == 2) { - c = Cat(allocator_, buf, 0); + c = Cat(allocator, buf, 0); } else { - c = Cat(allocator_, buf, 0); + c = Cat(allocator, buf, 0); } ans.push_back(std::move(c)); @@ -251,7 +253,7 @@ class OnlineTransducerNeMoModel::Impl { } std::vector> UnStackStates( - std::vector states) const { + std::vector states) { assert(states.size() == 3); std::vector> ans; diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index 1b1605183..cce322aa4 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -101,7 +101,7 @@ class OnlineWenetCtcModel::Impl { return config_.wenet_ctc.chunk_size * subsampling_factor_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } // Return a vector containing 3 tensors // - attn_cache diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 324b2b088..36e2d9dbd 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -179,12 +179,15 @@ std::vector 
OnlineZipformerTransducerModel::StackStates( std::vector ans; ans.reserve(states[0].size()); + auto allocator = + const_cast(this)->allocator_; + // cached_len for (int32_t i = 0; i != num_encoders; ++i) { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][i]; } - auto v = Cat(allocator_, buf, 1); // (num_layers, 1) + auto v = Cat(allocator, buf, 1); // (num_layers, 1) ans.push_back(std::move(v)); } @@ -193,7 +196,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][num_encoders + i]; } - auto v = Cat(allocator_, buf, 1); // (num_layers, 1, encoder_dims) + auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims) ans.push_back(std::move(v)); } @@ -203,7 +206,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( buf[n] = &states[n][num_encoders * 2 + i]; } // (num_layers, left_context_len, 1, attention_dims) - auto v = Cat(allocator_, buf, 2); + auto v = Cat(allocator, buf, 2); ans.push_back(std::move(v)); } @@ -213,7 +216,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( buf[n] = &states[n][num_encoders * 3 + i]; } // (num_layers, left_context_len, 1, attention_dims/2) - auto v = Cat(allocator_, buf, 2); + auto v = Cat(allocator, buf, 2); ans.push_back(std::move(v)); } @@ -223,7 +226,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( buf[n] = &states[n][num_encoders * 4 + i]; } // (num_layers, left_context_len, 1, attention_dims/2) - auto v = Cat(allocator_, buf, 2); + auto v = Cat(allocator, buf, 2); ans.push_back(std::move(v)); } @@ -233,7 +236,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( buf[n] = &states[n][num_encoders * 5 + i]; } // (num_layers, 1, encoder_dims, cnn_module_kernels-1) - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } @@ -243,7 +246,7 @@ std::vector OnlineZipformerTransducerModel::StackStates( buf[n] = &states[n][num_encoders * 6 + i]; } // (num_layers, 
1, encoder_dims, cnn_module_kernels-1) - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } @@ -258,12 +261,15 @@ OnlineZipformerTransducerModel::UnStackStates( int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; int32_t num_encoders = num_encoder_layers_.size(); + auto allocator = + const_cast(this)->allocator_; + std::vector> ans; ans.resize(batch_size); // cached_len for (int32_t i = 0; i != num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 1); + auto v = Unbind(allocator, &states[i], 1); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -273,7 +279,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_avg for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 1); + auto v = Unbind(allocator, &states[i], 1); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -283,7 +289,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_key for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 2); + auto v = Unbind(allocator, &states[i], 2); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -293,7 +299,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_val for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 2); + auto v = Unbind(allocator, &states[i], 2); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -303,7 +309,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_val2 for (int32_t i = 4 * num_encoders; i != 5 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 2); + auto v = Unbind(allocator, &states[i], 2); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -313,7 +319,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_conv1 for 
(int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 1); + auto v = Unbind(allocator, &states[i], 1); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -323,7 +329,7 @@ OnlineZipformerTransducerModel::UnStackStates( // cached_conv2 for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { - auto v = Unbind(allocator_, &states[i], 1); + auto v = Unbind(allocator, &states[i], 1); assert(v.size() == batch_size); for (int32_t n = 0; n != batch_size; ++n) { diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc index 04699a56b..8f0708ad1 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -70,7 +70,7 @@ class OnlineZipformer2CtcModel::Impl { int32_t ChunkShift() const { return decode_chunk_len_; } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } // Return a vector containing 3 tensors // - attn_cache @@ -86,7 +86,7 @@ class OnlineZipformer2CtcModel::Impl { } std::vector StackStates( - std::vector> states) const { + std::vector> states) { int32_t batch_size = static_cast(states.size()); std::vector buf(batch_size); @@ -159,7 +159,7 @@ class OnlineZipformer2CtcModel::Impl { } std::vector> UnStackStates( - std::vector states) const { + std::vector states) { int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0); assert(states.size() == m * 6 + 2); diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index 0782f06fc..03c68474c 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -185,6 +185,9 @@ std::vector OnlineZipformer2TransducerModel::StackStates( std::vector buf(batch_size); + auto allocator = + const_cast(this)->allocator_; + 
std::vector ans; int32_t num_states = static_cast(states[0].size()); ans.reserve(num_states); @@ -194,42 +197,42 @@ std::vector OnlineZipformer2TransducerModel::StackStates( for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i]; } - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i + 1]; } - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i + 2]; } - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i + 3]; } - auto v = Cat(allocator_, buf, 1); + auto v = Cat(allocator, buf, 1); ans.push_back(std::move(v)); } { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i + 4]; } - auto v = Cat(allocator_, buf, 0); + auto v = Cat(allocator, buf, 0); ans.push_back(std::move(v)); } { for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][6 * i + 5]; } - auto v = Cat(allocator_, buf, 0); + auto v = Cat(allocator, buf, 0); ans.push_back(std::move(v)); } } @@ -238,7 +241,7 @@ std::vector OnlineZipformer2TransducerModel::StackStates( for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][num_states - 2]; } - auto v = Cat(allocator_, buf, 0); + auto v = Cat(allocator, buf, 0); ans.push_back(std::move(v)); } @@ -246,7 +249,7 @@ std::vector OnlineZipformer2TransducerModel::StackStates( for (int32_t n = 0; n != batch_size; ++n) { buf[n] = &states[n][num_states - 1]; } - auto v = Cat(allocator_, buf, 0); + auto v = Cat(allocator, buf, 0); ans.push_back(std::move(v)); } return ans; @@ -261,12 +264,15 @@ OnlineZipformer2TransducerModel::UnStackStates( int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + auto allocator = + const_cast(this)->allocator_; + 
std::vector> ans; ans.resize(batch_size); for (int32_t i = 0; i != m; ++i) { { - auto v = Unbind(allocator_, &states[i * 6], 1); + auto v = Unbind(allocator, &states[i * 6], 1); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -274,7 +280,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[i * 6 + 1], 1); + auto v = Unbind(allocator, &states[i * 6 + 1], 1); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -282,7 +288,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[i * 6 + 2], 1); + auto v = Unbind(allocator, &states[i * 6 + 2], 1); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -290,7 +296,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[i * 6 + 3], 1); + auto v = Unbind(allocator, &states[i * 6 + 3], 1); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -298,7 +304,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[i * 6 + 4], 0); + auto v = Unbind(allocator, &states[i * 6 + 4], 0); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -306,7 +312,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[i * 6 + 5], 0); + auto v = Unbind(allocator, &states[i * 6 + 5], 0); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -316,7 +322,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } { - auto v = Unbind(allocator_, &states[m * 6], 0); + auto v = Unbind(allocator, &states[m * 6], 0); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { @@ -324,7 +330,7 @@ OnlineZipformer2TransducerModel::UnStackStates( } } { - auto v = Unbind(allocator_, &states[m * 6 + 
1], 0); + auto v = Unbind(allocator, &states[m * 6 + 1], 0); assert(static_cast(v.size()) == batch_size); for (int32_t n = 0; n != batch_size; ++n) { diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 0f637020a..0dc69fc8f 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -21,6 +21,36 @@ namespace sherpa_onnx { +static std::string GetInputName(Ort::Session *sess, size_t index, + OrtAllocator *allocator) { +// Note(fangjun): We only tested 1.17.1 and 1.11.0 +// For other versions, we may need to change it +#if ORT_API_VERSION >= 17 + auto v = sess->GetInputNameAllocated(index, allocator); + return v.get(); +#else + auto v = sess->GetInputName(index, allocator); + std::string ans = v; + allocator->Free(allocator, v); + return ans; +#endif +} + +static std::string GetOutputName(Ort::Session *sess, size_t index, + OrtAllocator *allocator) { +// Note(fangjun): We only tested 1.17.1 and 1.11.0 +// For other versions, we may need to change it +#if ORT_API_VERSION >= 17 + auto v = sess->GetOutputNameAllocated(index, allocator); + return v.get(); +#else + auto v = sess->GetOutputName(index, allocator); + std::string ans = v; + allocator->Free(allocator, v); + return ans; +#endif +} + void GetInputNames(Ort::Session *sess, std::vector *input_names, std::vector *input_names_ptr) { Ort::AllocatorWithDefaultOptions allocator; @@ -28,8 +58,7 @@ void GetInputNames(Ort::Session *sess, std::vector *input_names, input_names->resize(node_count); input_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { - auto tmp = sess->GetInputNameAllocated(i, allocator); - (*input_names)[i] = tmp.get(); + (*input_names)[i] = GetInputName(sess, i, allocator); (*input_names_ptr)[i] = (*input_names)[i].c_str(); } } @@ -41,8 +70,7 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names, output_names->resize(node_count); output_names_ptr->resize(node_count); for (size_t i = 0; i != node_count; ++i) { - 
auto tmp = sess->GetOutputNameAllocated(i, allocator); - (*output_names)[i] = tmp.get(); + (*output_names)[i] = GetOutputName(sess, i, allocator); (*output_names_ptr)[i] = (*output_names)[i].c_str(); } } @@ -78,12 +106,24 @@ Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out, void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { Ort::AllocatorWithDefaultOptions allocator; +#if ORT_API_VERSION >= 17 std::vector v = meta_data.GetCustomMetadataMapKeysAllocated(allocator); for (const auto &key : v) { auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); os << key.get() << "=" << p.get() << "\n"; } +#else + int64_t num_keys = 0; + char **keys = meta_data.GetCustomMetadataMapKeys(allocator, num_keys); + for (int32_t i = 0; i < num_keys; ++i) { + auto v = LookupCustomModelMetaData(meta_data, keys[i], allocator); + os << keys[i] << "=" << v << "\n"; + allocator.Free(keys[i]); + } + + allocator.Free(keys); +#endif } Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { @@ -361,4 +401,20 @@ std::vector Convert(std::vector values) { return ans; } +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data, + const char *key, + OrtAllocator *allocator) { +// Note(fangjun): We only tested 1.17.1 and 1.11.0 +// For other versions, we may need to change it +#if ORT_API_VERSION >= 17 + auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator); + return v.get(); +#else + auto v = meta_data.LookupCustomMetadataMap(key, allocator); + std::string ans = v; + allocator->Free(allocator, v); + return ans; +#endif +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 98eb25137..8a19a2baf 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -59,6 +59,9 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names, Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value 
*encoder_out, int32_t t); +std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data, + const char *key, OrtAllocator *allocator); + void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 9c5eb2b1a..160797c6e 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -60,6 +60,7 @@ Ort::SessionOptions GetSessionOptionsImpl( case Provider::kCPU: break; // nothing to do for the CPU provider case Provider::kXnnpack: { +#if ORT_API_VERSION >= 17 if (std::find(available_providers.begin(), available_providers.end(), "XnnpackExecutionProvider") != available_providers.end()) { sess_opts.AppendExecutionProvider("XNNPACK"); @@ -67,6 +68,11 @@ Ort::SessionOptions GetSessionOptionsImpl( SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", os.str().c_str()); } +#else + SHERPA_ONNX_LOGE( + "Does not support xnnpack for onnxruntime: %d. Fallback to cpu!", + static_cast(ORT_API_VERSION)); +#endif break; } case Provider::kTRT: { diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc index 4dafce91d..b9591d624 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -40,8 +40,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, Ort::AllocatorWithDefaultOptions allocator; auto model_type = - meta_data.LookupCustomMetadataMapAllocated("framework", allocator); - if (!model_type) { + LookupCustomModelMetaData(meta_data, "framework", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n" "Please make sure you have added metadata to the model.\n\n" @@ -52,14 +52,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kUnknown; } - if (model_type.get() == std::string("wespeaker")) { + if 
(model_type == "wespeaker") { return ModelType::kWeSpeaker; - } else if (model_type.get() == std::string("3d-speaker")) { + } else if (model_type == "3d-speaker") { return ModelType::k3dSpeaker; - } else if (model_type.get() == std::string("nemo")) { + } else if (model_type == "nemo") { return ModelType::kNeMo; } else { - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); return ModelType::kUnknown; } } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc index 2e481b20f..1b60e2469 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc @@ -53,7 +53,7 @@ class SpeakerEmbeddingExtractorNeMoModel::Impl { return std::move(outputs[0]); } - OrtAllocator *Allocator() const { return allocator_; } + OrtAllocator *Allocator() { return allocator_; } const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { return meta_data_; diff --git a/sherpa-onnx/csrc/spoken-language-identification-impl.cc b/sherpa-onnx/csrc/spoken-language-identification-impl.cc index 109b48789..5b29df484 100644 --- a/sherpa-onnx/csrc/spoken-language-identification-impl.cc +++ b/sherpa-onnx/csrc/spoken-language-identification-impl.cc @@ -42,8 +42,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, Ort::AllocatorWithDefaultOptions allocator; auto model_type = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type) { + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { SHERPA_ONNX_LOGE( "No model_type in the metadata!\n" "Please make sure you have added metadata to the model.\n\n" @@ -54,11 +54,10 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kUnknown; } - auto model_type_str = 
std::string(model_type.get()); - if (model_type_str.find("whisper") == 0) { + if (model_type.find("whisper") == 0) { return ModelType::kWhisper; } else { - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); return ModelType::kUnknown; } } diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 173b060b4..77b976431 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -29,20 +29,19 @@ namespace { const char *ws = " \t\n\r\f\v"; // trim from end of string (right) -inline std::string &TrimRight(std::string &s, const char *t = ws) { - s.erase(s.find_last_not_of(t) + 1); - return s; +inline void TrimRight(std::string *s, const char *t = ws) { + s->erase(s->find_last_not_of(t) + 1); } // trim from beginning of string (left) -inline std::string &TrimLeft(std::string &s, const char *t = ws) { - s.erase(0, s.find_first_not_of(t)); - return s; +inline void TrimLeft(std::string *s, const char *t = ws) { + s->erase(0, s->find_first_not_of(t)); } // trim from both ends of string (right then left) -inline std::string &Trim(std::string &s, const char *t = ws) { - return TrimLeft(TrimRight(s, t), t); +inline void Trim(std::string *s, const char *t = ws) { + TrimRight(s, t); + TrimLeft(s, t); } } // namespace @@ -56,7 +55,7 @@ std::unordered_map ReadTokens( std::string sym; int32_t id = -1; while (std::getline(is, line)) { - Trim(line); + Trim(&line); std::istringstream iss(line); iss >> sym; if (iss.eof()) {