diff --git a/.github/scripts/test-spoken-language-identification.sh b/.github/scripts/test-spoken-language-identification.sh new file mode 100755 index 000000000..028e5c23a --- /dev/null +++ b/.github/scripts/test-spoken-language-identification.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +names=( +tiny +base +small +medium +) + +# all_language_codes=bo,ml,tt,fa,sl,bg,sn,sr,tl,km,ln,mr,hr,eu,ro,ba,bs,pl,as,nn,sk,ko,oc,ar,uz,pa,tg,mk,kk,hi,ha,uk,is,de,el,ja,yo,be,so,tk,id,sa,ru,yi,en,am,cs,ne,la,sv,su,pt,mi,ca,sd,hy,haw,fi,et,kn,da,lt,it,nl,he,mg,ur,tr,af,br,bn,ta,no,my,si,mt,th,gl,sw,mn,jw,ms,ps,fo,ka,hu,zh,ht,az,fr,lo,sq,gu,cy,lv,es,lb,te,vi + +log "Download test waves" +waves=( +ar-arabic.wav +bg-bulgarian.wav +cs-czech.wav +da-danish.wav +de-german.wav +el-greek.wav +en-english.wav +es-spanish.wav +fa-persian.wav +fi-finnish.wav +fr-french.wav +hi-hindi.wav +hr-croatian.wav +id-indonesian.wav +it-italian.wav +ja-japanese.wav +ko-korean.wav +nl-dutch.wav +no-norwegian.wav +po-polish.wav +pt-portuguese.wav +ro-romanian.wav +ru-russian.wav +sk-slovak.wav +sv-swedish.wav +ta-tamil.wav +tl-tagalog.wav +tr-turkish.wav +uk-ukrainian.wav +zh-chinese.wav +) + +for wav in ${waves[@]}; do + echo "Downloading $wav" + curl -SL -O https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/$wav + ls -lh *.wav +done + +for name in ${names[@]}; do + log "------------------------------------------------------------" + log "Run $name" + log "------------------------------------------------------------" + + repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-whisper-$name + log "Start testing ${repo_url}" + repo=$(basename $repo_url) + log "Download pretrained model and test-data from $repo_url" + + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + pushd $repo + git lfs pull --include "*.onnx" + # git lfs pull --include "*.ort" + ls -lh *.onnx + popd + + for wav in ${waves[@]}; do + log "test fp32 onnx" + + time $EXE \ + --whisper-encoder=$repo/${name}-encoder.onnx \ + --whisper-decoder=$repo/${name}-decoder.onnx \ + $wav + + log "test int8 onnx" + + time $EXE \ + --whisper-encoder=$repo/${name}-encoder.int8.onnx \ + --whisper-decoder=$repo/${name}-decoder.int8.onnx \ + $wav + done + rm -rf $repo +done diff --git a/.github/workflows/build-wheels-linux.yaml b/.github/workflows/build-wheels-linux.yaml index 48f1a7767..0443074c2 100644 --- a/.github/workflows/build-wheels-linux.yaml +++ b/.github/workflows/build-wheels-linux.yaml @@ -82,7 +82,6 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 - shell: bash with: max_attempts: 20 timeout_seconds: 200 diff --git a/.github/workflows/build-wheels-macos-arm64.yaml b/.github/workflows/build-wheels-macos-arm64.yaml index 1b90ab4d4..b41f5eba7 100644 --- a/.github/workflows/build-wheels-macos-arm64.yaml +++ b/.github/workflows/build-wheels-macos-arm64.yaml @@ -21,27 +21,12 @@ jobs: fail-fast: false matrix: os: [macos-latest] - python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] + python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"] steps: - uses: actions/checkout@v4 - # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ - # for a list of versions - name: Build wheels - if: matrix.python-version == 'cp37' - uses: pypa/cibuildwheel@v2.11.4 - env: - CIBW_BUILD: "${{ matrix.python-version}}-* " - CIBW_ENVIRONMENT: SHERPA_ONNX_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64'" - CIBW_ARCHS: "arm64" - CIBW_BUILD_VERBOSITY: 3 - - # Don't repair macOS wheels - CIBW_REPAIR_WHEEL_COMMAND_MACOS: "" - - - name: Build wheels - if: matrix.python-version != 'cp37' uses: pypa/cibuildwheel@v2.15.0 env: CIBW_BUILD: "${{ matrix.python-version}}-* " diff --git a/.github/workflows/linux-gpu.yaml b/.github/workflows/linux-gpu.yaml index d3bfd118b..bccde7390 100644 --- a/.github/workflows/linux-gpu.yaml +++ b/.github/workflows/linux-gpu.yaml @@ -92,6 +92,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test spoken language identification + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-language-identification + + .github/scripts/test-spoken-language-identification.sh + - name: Test online CTC shell: bash run: | @@ -116,6 +124,7 @@ jobs: .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper shell: bash run: | diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index a60b3e430..754daa312 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -123,6 +123,15 @@ jobs: name: release-${{ matrix.build_type }}-${{ matrix.shared_lib }} path: build/bin/* + - name: Test spoken language identification + if: matrix.build_type != 'Debug' + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-language-identification + + .github/scripts/test-spoken-language-identification.sh + - name: Test transducer kws shell: bash run: | @@ -140,6 +149,7 @@ jobs: .github/scripts/test-online-ctc.sh - name: Test offline Whisper + if: matrix.build_type != 'Debug' shell: bash run: | export PATH=$PWD/build/bin:$PATH diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index b48a4a000..04abcd31d 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -102,6 +102,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test spoken language identification + if: matrix.build_type != 'Debug' + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-language-identification + + .github/scripts/test-spoken-language-identification.sh + - name: Test transducer kws shell: bash run: | @@ -135,6 +144,7 @@ jobs: .github/scripts/test-online-paraformer.sh - name: Test offline Whisper + if: matrix.build_type != 'Debug' shell: bash run: | export PATH=$PWD/build/bin:$PATH diff --git a/.github/workflows/windows-x64-cuda.yaml b/.github/workflows/windows-x64-cuda.yaml index d4ca33a79..0672065c2 100644 --- a/.github/workflows/windows-x64-cuda.yaml +++ b/.github/workflows/windows-x64-cuda.yaml @@ -68,6 +68,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test spoken language identification + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-language-identification.exe + + .github/scripts/test-spoken-language-identification.sh + - name: Test online CTC shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 46daea36b..cf982f6fb 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -68,6 +68,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test spoken language identification + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-language-identification.exe + + .github/scripts/test-spoken-language-identification.sh + - name: Test online CTC shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 1ed8ea0a0..b701b8c0c 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -69,6 +69,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + # - name: Test spoken language identification + # shell: bash + # run: | + # export PATH=$PWD/build/bin/Release:$PATH + # export EXE=sherpa-onnx-offline-language-identification.exe + # + # .github/scripts/test-spoken-language-identification.sh + - name: Test online CTC shell: bash run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 45e359e4c..495cab28d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.9.13") +set(SHERPA_ONNX_VERSION "1.9.14") # Disable warning about # diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 45b534d05..7fb0fc0f2 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -43,6 +43,50 @@ def enable_alsa(): return build_alsa and is_linux() and (is_arm64() or is_x86()) +def get_binaries(): + binaries = [ + "sherpa-onnx", + "sherpa-onnx-keyword-spotter", + "sherpa-onnx-microphone", + "sherpa-onnx-microphone-offline", + "sherpa-onnx-microphone-offline-speaker-identification", + "sherpa-onnx-offline", + "sherpa-onnx-offline-language-identification", + "sherpa-onnx-offline-tts", + "sherpa-onnx-offline-tts-play", + "sherpa-onnx-offline-websocket-server", + "sherpa-onnx-online-websocket-client", + "sherpa-onnx-online-websocket-server", + "sherpa-onnx-vad-microphone", + "sherpa-onnx-vad-microphone-offline-asr", + ] + + if enable_alsa(): + binaries += [ + "sherpa-onnx-alsa", + "sherpa-onnx-alsa-offline", + "sherpa-onnx-alsa-offline-speaker-identification", + "sherpa-onnx-offline-tts-play-alsa", + ] + + if is_windows(): + binaries += [ + "espeak-ng.dll", + "kaldi-decoder-core.dll", + "kaldi-native-fbank-core.dll", + "onnxruntime.dll", + "piper_phonemize.dll", + "sherpa-onnx-c-api.dll", + "sherpa-onnx-core.dll", + "sherpa-onnx-fst.lib", + "sherpa-onnx-kaldifst-core.lib", + "sherpa-onnx-portaudio.dll", + "ucd.dll", + ] + + return binaries + + try: from wheel.bdist_wheel import bdist_wheel as _bdist_wheel @@ -150,38 +194,7 @@ def build_extension(self, ext: setuptools.extension.Extension): suffix = ".exe" if is_windows() else "" # Remember to also change setup.py - binaries = ["sherpa-onnx"] - binaries += ["sherpa-onnx-keyword-spotter"] - binaries += ["sherpa-onnx-offline"] - binaries += ["sherpa-onnx-microphone"] - binaries += ["sherpa-onnx-microphone-offline"] - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"] - binaries += ["sherpa-onnx-online-websocket-server"] - binaries += ["sherpa-onnx-offline-websocket-server"] - binaries += ["sherpa-onnx-online-websocket-client"] - binaries += ["sherpa-onnx-vad-microphone"] - binaries += ["sherpa-onnx-vad-microphone-offline-asr"] - binaries += ["sherpa-onnx-offline-tts"] - binaries += ["sherpa-onnx-offline-tts-play"] - - if enable_alsa(): - binaries += ["sherpa-onnx-alsa"] - binaries += ["sherpa-onnx-alsa-offline"] - binaries += ["sherpa-onnx-offline-tts-play-alsa"] - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"] - - if is_windows(): - binaries += ["kaldi-native-fbank-core.dll"] - binaries += ["sherpa-onnx-c-api.dll"] - binaries += ["sherpa-onnx-core.dll"] - binaries += ["sherpa-onnx-portaudio.dll"] - binaries += ["onnxruntime.dll"] - binaries += ["piper_phonemize.dll"] - binaries += ["espeak-ng.dll"] - binaries += ["ucd.dll"] - binaries += ["kaldi-decoder-core.dll"] - binaries += ["sherpa-onnx-fst.lib"] - binaries += ["sherpa-onnx-kaldifst-core.lib"] + binaries = get_binaries() for f in binaries: suffix = "" if (".dll" in f or ".lib" in f) else suffix diff --git a/python-api-examples/spoken-language-identification.py b/python-api-examples/spoken-language-identification.py new file mode 100755 index 000000000..8e5ad8b61 --- /dev/null +++ b/python-api-examples/spoken-language-identification.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for spoken languge identification. +It detects the language spoken in the given wave file. + +Usage: + +1. Download a whisper multilingual model. We use a tiny model below. +Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +to download more models. + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.tar.bz2 +rm sherpa-onnx-whisper-tiny.tar.bz2 + +We only use the int8.onnx models below. + +2. Download a test wave. + +You can find many wave files for different languages at +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs + +wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav + +python3 ./python-api-examples/spoken-language-identification.py + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \ + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \ + --num-threads=1 \ + ./de-german.wav +""" + +import argparse +import logging +import time +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_onnx + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--whisper-encoder", + required=True, + type=str, + help="Path to a multilingual whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + required=True, + type=str, + help="Path to a multilingual whisper decoder model", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to identify. It must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + assert args.num_threads > 0, args.num_threads + config = sherpa_onnx.SpokenLanguageIdentificationConfig( + whisper=sherpa_onnx.SpokenLanguageIdentificationWhisperConfig( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + ), + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + slid = sherpa_onnx.SpokenLanguageIdentification(config) + + samples, sample_rate = read_wave(args.sound_file) + + start_time = time.time() + stream = slid.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + lang = slid.compute(stream) + end_time = time.time() + + elapsed_seconds = end_time - start_time + audio_duration = len(samples) / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + logging.info(f"File: {args.sound_file}") + logging.info(f"Detected language: {lang}") + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/setup.py b/setup.py index 6d5ccf73b..150444a86 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 -import os import re -import sys from pathlib import Path import setuptools @@ -11,7 +9,7 @@ BuildExtension, bdist_wheel, cmake_extension, - enable_alsa, + get_binaries, is_windows, ) @@ -42,39 +40,7 @@ def get_binaries_to_install(): bin_dir.mkdir(parents=True, exist_ok=True) suffix = ".exe" if is_windows() else "" - # Remember to also change cmake/cmake_extension.py - binaries = ["sherpa-onnx"] - binaries += ["sherpa-onnx-keyword-spotter"] - binaries += ["sherpa-onnx-offline"] - binaries += ["sherpa-onnx-microphone"] - binaries += ["sherpa-onnx-microphone-offline"] - binaries += ["sherpa-onnx-microphone-offline-speaker-identification"] - binaries += ["sherpa-onnx-online-websocket-server"] - binaries += ["sherpa-onnx-offline-websocket-server"] - binaries += ["sherpa-onnx-online-websocket-client"] - binaries += ["sherpa-onnx-vad-microphone"] - binaries += ["sherpa-onnx-vad-microphone-offline-asr"] - binaries += ["sherpa-onnx-offline-tts"] - binaries += ["sherpa-onnx-offline-tts-play"] - - if enable_alsa(): - binaries += ["sherpa-onnx-alsa"] - binaries += ["sherpa-onnx-alsa-offline"] - binaries += ["sherpa-onnx-offline-tts-play-alsa"] - binaries += ["sherpa-onnx-alsa-offline-speaker-identification"] - - if is_windows(): - binaries += ["kaldi-native-fbank-core.dll"] - binaries += ["sherpa-onnx-c-api.dll"] - binaries += ["sherpa-onnx-core.dll"] - binaries += ["sherpa-onnx-portaudio.dll"] - binaries += ["onnxruntime.dll"] - binaries += ["piper_phonemize.dll"] - binaries += ["espeak-ng.dll"] - binaries += ["ucd.dll"] - binaries += ["kaldi-decoder-core.dll"] - binaries += ["sherpa-onnx-fst.lib"] - binaries += ["sherpa-onnx-kaldifst-core.lib"] + binaries = get_binaries() exe = [] for f in binaries: diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 198e8fa7c..6a14aa780 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -86,6 +86,8 @@ set(sources silero-vad-model-config.cc silero-vad-model.cc slice.cc + spoken-language-identification-impl.cc + spoken-language-identification.cc stack.cc symbol-table.cc text-utils.cc @@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) + add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) set(main_exes sherpa-onnx @@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline sherpa-onnx-offline-parallel sherpa-onnx-offline-tts + sherpa-onnx-offline-language-identification ) foreach(exe IN LISTS main_exes) diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index cf0649acb..5a6fb1858 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -23,7 +23,7 @@ enum class ModelType { kTdnn, kZipformerCtc, kWenetCtc, - kUnkown, + kUnknown, }; } // namespace @@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, "run.sh\n" "\n" "for how to add metadta to model.onnx\n"); - return ModelType::kUnkown; + return ModelType::kUnknown; } if (model_type.get() == std::string("EncDecCTCModelBPE")) { @@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kWenetCtc; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); - return ModelType::kUnkown; + return ModelType::kUnknown; } } std::unique_ptr OfflineCtcModel::Create( const OfflineModelConfig &config) { - ModelType model_type = ModelType::kUnkown; + ModelType model_type = ModelType::kUnknown; std::string filename; if (!config.nemo_ctc.model.empty()) { @@ -113,7 +113,7 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kWenetCtc: return std::make_unique(config); break; - case ModelType::kUnkown: + case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; } @@ -125,7 +125,7 @@ std::unique_ptr OfflineCtcModel::Create( std::unique_ptr OfflineCtcModel::Create( AAssetManager *mgr, const OfflineModelConfig &config) { - ModelType model_type = ModelType::kUnkown; + ModelType model_type = ModelType::kUnknown; std::string filename; if (!config.nemo_ctc.model.empty()) { @@ -160,7 +160,7 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kWenetCtc: return std::make_unique(mgr, config); break; - case ModelType::kUnkown: + case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; } diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index 4f5be4ca1..ea62925b0 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { num_frames = max_num_frames - 50; } - NormalizeFeatures(f.data(), num_frames, feat_dim); + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); // note that 1000 is an experience-value. // You can replace 1000 by other values, say, 100. @@ -162,38 +162,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { } } - private: - static void NormalizeFeatures(float *features, int32_t num_frames, - int32_t feat_dim) { - // log_spec = torch.clamp(features, min=1e-10).log10() - // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - // mel = (log_spec + 4.0) / 4.0 - - int32_t n = num_frames * feat_dim; - float max_v = -1e20; - for (int32_t i = 0; i != n; ++i) { - float f = features[i]; - - f = std::max(f, 1e-10); - f = std::log10(f); - - max_v = std::max(f, max_v); - - features[i] = f; - } - - max_v -= 8; - - for (int32_t i = 0; i != n; ++i) { - float f = features[i]; - f = std::max(f, max_v); - - f = (f + 4) / 4; - - features[i] = f; - } - } - private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index 396e76ec6..15eacb62b 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -12,56 +12,6 @@ namespace sherpa_onnx { -int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage( - Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT - int64_t token_val = model_->SOT(); - std::array token_shape{1, 1}; - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - Ort::Value tokens = Ort::Value::CreateTensor( - memory_info, &token_val, 1, token_shape.data(), token_shape.size()); - - auto self_kv_cache = model_->GetInitialSelfKVCache(); - - std::array offset_shape{1}; - Ort::Value offset = Ort::Value::CreateTensor( - model_->Allocator(), offset_shape.data(), offset_shape.size()); - *(offset.GetTensorMutableData()) = 0; - - auto decoder_out = model_->ForwardDecoder( - std::move(tokens), std::move(self_kv_cache.first), - std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), - std::move(offset)); - - cross_k = std::move(std::get<3>(decoder_out)); - cross_v = std::move(std::get<4>(decoder_out)); - - const float *p_logits = std::get<0>(decoder_out).GetTensorData(); - int32_t vocab_size = model_->VocabSize(); - const auto &all_language_ids = model_->GetAllLanguageIDs(); - - int32_t lang_id = all_language_ids[0]; - float this_logit = p_logits[lang_id]; - - for (int32_t i = 1; i != all_language_ids.size(); ++i) { - int32_t id = all_language_ids[i]; - float p = p_logits[id]; - - if (p > this_logit) { - this_logit = p; - lang_id = id; - } - } -#if 1 - SHERPA_ONNX_LOGE("Detected language: %s", - model_->GetID2Lang().at(lang_id).c_str()); -#endif - - return lang_id; -} - std::vector OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, Ort::Value cross_v) { @@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, // 0: sot, 1: lang_id, 2: task, 3: no_timestamps initial_tokens[1] = lang_id; } else { - int32_t lang_id = DetectLanguage(cross_k, cross_v); + int32_t lang_id = model_->DetectLanguage(cross_k, cross_v); // 0: sot, 1: lang_id, 2: task, 3: no_timestamps initial_tokens[1] = lang_id; diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h index b74bd94a4..5f2b41680 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h @@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { std::vector Decode(Ort::Value cross_k, Ort::Value cross_v) override; - int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT - Ort::Value &cross_v) const; // NOLINT - private: OfflineWhisperModelConfig config_; OfflineWhisperModel *model_; // not owned diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.cc b/sherpa-onnx/csrc/offline-whisper-model-config.cc index 946437bc2..821ccfbbe 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.cc +++ b/sherpa-onnx/csrc/offline-whisper-model-config.cc @@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { po->Register( "whisper-tail-paddings", &tail_paddings, - "Suggest value: 50 for English models. 300 for multilingual models. " + "Suggested value: 50 for English models. 300 for multilingual models. " "Since we have removed the 30-second constraint, we need to add some " "tail padding frames " - "so that whisper can detect the eot token. Leave it to -1 to use 50 for " - "English models and 300 for multilingual models."); + "so that whisper can detect the eot token. Leave it to -1 to use 1000."); } bool OfflineWhisperModelConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); + return false; + } + if (!FileExists(encoder)) { SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); return false; } + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); + return false; + } + if (!FileExists(decoder)) { SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); return false; diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 8774233a7..2dcfac907 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { + debug_ = config_.debug; + { + auto buf = ReadFile(config.whisper.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.whisper.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + explicit Impl(const SpokenLanguageIdentificationConfig &config) + : lid_config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + debug_ = config_.debug; { auto buf = ReadFile(config.whisper.encoder); InitEncoder(buf.data(), buf.size()); @@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl { env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), allocator_{} { + debug_ = config_.debug; { auto buf = ReadFile(mgr, config.whisper.encoder); InitEncoder(buf.data(), buf.size()); @@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl { std::move(decoder_input[4]), std::move(decoder_input[5])}; } + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT + Ort::Value &cross_v) { // NOLINT + int64_t token_val = SOT(); + std::array token_shape{1, 1}; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + Ort::Value tokens = Ort::Value::CreateTensor( + memory_info, &token_val, 1, token_shape.data(), token_shape.size()); + + auto self_kv_cache = GetInitialSelfKVCache(); + + std::array offset_shape{1}; + Ort::Value offset = Ort::Value::CreateTensor( + Allocator(), offset_shape.data(), offset_shape.size()); + *(offset.GetTensorMutableData()) = 0; + + auto decoder_out = + ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first), + std::move(self_kv_cache.second), std::move(cross_k), + std::move(cross_v), std::move(offset)); + + cross_k = std::move(std::get<3>(decoder_out)); + cross_v = std::move(std::get<4>(decoder_out)); + + const float *p_logits = std::get<0>(decoder_out).GetTensorData(); + int32_t vocab_size = VocabSize(); + const auto &all_language_ids = GetAllLanguageIDs(); + + int32_t lang_id = all_language_ids[0]; + float this_logit = p_logits[lang_id]; + + for (int32_t i = 1; i != all_language_ids.size(); ++i) { + int32_t id = all_language_ids[i]; + float p = p_logits[id]; + + if (p > this_logit) { + this_logit = p; + lang_id = id; + } + } + + if (debug_) { + SHERPA_ONNX_LOGE("Detected language: %s", + GetID2Lang().at(lang_id).c_str()); + } + + return lang_id; + } + std::pair GetInitialSelfKVCache() { std::array shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; @@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl { // get meta data Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); - if (config_.debug) { + if (debug_) { std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); @@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl { private: OfflineModelConfig config_; + SpokenLanguageIdentificationConfig lid_config_; + bool debug_ = false; Ort::Env env_; Ort::SessionOptions sess_opts_; Ort::AllocatorWithDefaultOptions allocator_; @@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl { OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) : impl_(std::make_unique(config)) {} +OfflineWhisperModel::OfflineWhisperModel( + const SpokenLanguageIdentificationConfig &config) + : impl_(std::make_unique(config)) {} + #if __ANDROID_API__ >= 9 OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config) @@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, std::move(n_layer_cross_v), std::move(offset)); } +int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT + Ort::Value &cross_v) { // NOLINT + return impl_->DetectLanguage(cross_k, cross_v); +} + std::pair OfflineWhisperModel::GetInitialSelfKVCache() const { return impl_->GetInitialSelfKVCache(); @@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const { return impl_->IsMultiLingual(); } +void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames, + int32_t feat_dim) { + // log_spec = torch.clamp(features, min=1e-10).log10() + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + // mel = (log_spec + 4.0) / 4.0 + + int32_t n = num_frames * feat_dim; + float max_v = -1e20; + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + + f = std::max(f, 1e-10); + f = std::log10(f); + + max_v = std::max(f, max_v); + + features[i] = f; + } + + max_v -= 8; + + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + f = std::max(f, max_v); + + f = (f + 4) / 4; + + features[i] = f; + } +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-whisper-model.h b/sherpa-onnx/csrc/offline-whisper-model.h index 3d0674099..386a4d87d 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.h +++ b/sherpa-onnx/csrc/offline-whisper-model.h @@ -18,6 +18,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/spoken-language-identification.h" namespace sherpa_onnx { @@ -25,6 +26,9 @@ class OfflineWhisperModel { public: explicit OfflineWhisperModel(const OfflineModelConfig &config); + explicit OfflineWhisperModel( + const SpokenLanguageIdentificationConfig &config); + #if __ANDROID_API__ >= 9 OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config); #endif @@ -72,7 +76,8 @@ class OfflineWhisperModel { Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v, Ort::Value offset) const; - int32_t DetectLanguage() const; + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT + Ort::Value &cross_v); // NOLINT /** Return the initial self kv cache in a pair * - n_layer_self_k_cache A 4-D tensor of shape @@ -98,6 +103,9 @@ class OfflineWhisperModel { int32_t Translate() const; bool IsMultiLingual() const; + static void NormalizeFeatures(float *features, int32_t num_frames, + int32_t feat_dim); + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 83bdc3906..2bc9acc53 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -28,7 +28,7 @@ enum class ModelType { kLstm, kZipformer, kZipformer2, - kUnkown, + kUnknown, }; } // namespace @@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, "No model_type in the metadata!\n" "Please make sure you are using the latest export-onnx.py from icefall " "to export your transducer models"); - return ModelType::kUnkown; + return ModelType::kUnknown; } if (model_type.get() == std::string("conformer")) { @@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kZipformer2; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); - return ModelType::kUnkown; + return ModelType::kUnknown; } } @@ -93,7 +93,7 @@ std::unique_ptr OnlineTransducerModel::Create( model_type.c_str()); } } - ModelType model_type = ModelType::kUnkown; + ModelType model_type = ModelType::kUnknown; { auto buffer = ReadFile(config.transducer.encoder); @@ -110,7 +110,7 @@ std::unique_ptr OnlineTransducerModel::Create( return std::make_unique(config); case ModelType::kZipformer2: return std::make_unique(config); - case ModelType::kUnkown: + case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); return nullptr; } @@ -185,7 +185,7 @@ std::unique_ptr OnlineTransducerModel::Create( return std::make_unique(mgr, config); case ModelType::kZipformer2: return std::make_unique(mgr, config); - case ModelType::kUnkown: + case ModelType::kUnknown: SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); return nullptr; } diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 759c66dd4..aacd1e158 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions( return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions( + const SpokenLanguageIdentificationConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 53cc22b76..9bb3e4371 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -12,6 +12,7 @@ #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" +#include "sherpa-onnx/csrc/spoken-language-identification.h" #include "sherpa-onnx/csrc/vad-model-config.h" namespace sherpa_onnx { @@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); Ort::SessionOptions GetSessionOptions( const SpeakerEmbeddingExtractorConfig &config); + +Ort::SessionOptions GetSessionOptions( + const SpokenLanguageIdentificationConfig &config); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc new file mode 100644 index 000000000..83756621d --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc @@ -0,0 +1,107 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include + +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/spoken-language-identification.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Spoken language identification with sherpa-onnx. + +Usage: + +(1) Use a whisper multilingual model + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.tar.bz2 +rm sherpa-onnx-whisper-tiny.tar.bz2 + +We only use the int8.onnx models below. + +./bin/sherpa-onnx-offline-spoken-language-identification \ + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \ + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \ + --num-threads=1 \ + /path/to/foo.wav + +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. +You can find test waves for different languages at +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html +Note that only whisper multilingual models are supported. For instance, +"tiny" is supported but "tiny.en" is not. +for a list of pre-trained models to download. +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::SpokenLanguageIdentificationConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Error: Please provide 1 wave file.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating spoken language identifier ...\n"); + sherpa_onnx::SpokenLanguageIdentification slid(config); + + fprintf(stderr, "Started\n"); + const std::string wav_filename = po.GetArg(1); + + int32_t sampling_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + return -1; + } + float duration = samples.size() / static_cast(sampling_rate); + + const auto begin = std::chrono::steady_clock::now(); + + auto s = slid.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + auto language = slid.Compute(s.get()); + + const auto end = std::chrono::steady_clock::now(); + + fprintf(stderr, "Done!\n\n"); + fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(), + language.c_str()); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "num threads: %d\n", config.num_threads); + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc index 1d2798f6c..52466edc5 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -16,7 +16,7 @@ enum class ModelType { kWeSpeaker, k3dSpeaker, kNeMo, - kUnkown, + kUnknown, }; } // namespace @@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" "add_meta_data.py" "to add metadata to models from WeSpeaker\n"); - return ModelType::kUnkown; + return ModelType::kUnknown; } if (model_type.get() == std::string("wespeaker")) { @@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kNeMo; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); - return ModelType::kUnkown; + return ModelType::kUnknown; } } std::unique_ptr SpeakerEmbeddingExtractorImpl::Create( const SpeakerEmbeddingExtractorConfig &config) { - ModelType model_type = ModelType::kUnkown; + ModelType model_type = ModelType::kUnknown; { auto buffer = ReadFile(config.model); @@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create( return std::make_unique(config); case ModelType::kNeMo: return std::make_unique(config); - case ModelType::kUnkown: - SHERPA_ONNX_LOGE( - "Unknown model type in for speaker embedding extractor!"); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!"); return nullptr; } @@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create( std::unique_ptr SpeakerEmbeddingExtractorImpl::Create( AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) { - ModelType model_type = ModelType::kUnkown; + ModelType model_type = ModelType::kUnknown; { auto buffer = ReadFile(mgr, config.model); @@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create( config); case ModelType::kNeMo: return std::make_unique(mgr, config); - case ModelType::kUnkown: + case ModelType::kUnknown: SHERPA_ONNX_LOGE( "Unknown model type in for speaker embedding extractor!"); return nullptr; diff --git a/sherpa-onnx/csrc/spoken-language-identification-impl.cc b/sherpa-onnx/csrc/spoken-language-identification-impl.cc new file mode 100644 index 000000000..599a72a7c --- /dev/null +++ b/sherpa-onnx/csrc/spoken-language-identification-impl.cc @@ -0,0 +1,88 @@ +// sherpa-onnx/csrc/spoken-language-identification-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" + +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h" + +namespace sherpa_onnx { + +namespace { + +enum class ModelType { + kWhisper, + kUnknown, +}; + +} + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions sess_opts; + + auto sess = std::make_unique(env, model_data, model_data_length, + sess_opts); + + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; + auto model_type = + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); + if (!model_type) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "Please make sure you have added metadata to the model.\n\n" + "For instance, you can use\n" + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/" + "export-onnx.py " + "to add metadata to models from whisper\n"); + return ModelType::kUnknown; + } + + auto model_type_str = std::string(model_type.get()); + if (model_type_str.find("whisper") == 0) { + return ModelType::kWhisper; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + return ModelType::kUnknown; + } +} + +std::unique_ptr +SpokenLanguageIdentificationImpl::Create( + const SpokenLanguageIdentificationConfig &config) { + ModelType model_type = ModelType::kUnknown; + { + if (config.whisper.encoder.empty()) { + SHERPA_ONNX_LOGE("Only whisper models are supported at present"); + exit(-1); + } + auto buffer = ReadFile(config.whisper.encoder); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWhisper: + return std::make_unique(config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE( + "Unknown model type for spoken language identification!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/spoken-language-identification-impl.h b/sherpa-onnx/csrc/spoken-language-identification-impl.h new file mode 100644 index 000000000..b9112fa4f --- /dev/null +++ b/sherpa-onnx/csrc/spoken-language-identification-impl.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/spoken-language-identification-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ + +#include +#include + +#include "sherpa-onnx/csrc/spoken-language-identification.h" + +namespace sherpa_onnx { + +class SpokenLanguageIdentificationImpl { + public: + virtual ~SpokenLanguageIdentificationImpl() = default; + + static std::unique_ptr Create( + const SpokenLanguageIdentificationConfig &config); + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::string Compute(OfflineStream *s) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ diff --git a/sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h b/sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h new file mode 100644 index 000000000..a44001d58 --- /dev/null +++ b/sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h @@ -0,0 +1,119 @@ +// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-whisper-model.h" +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class SpokenLanguageIdentificationWhisperImpl + : public SpokenLanguageIdentificationImpl { + public: + explicit SpokenLanguageIdentificationWhisperImpl( + const SpokenLanguageIdentificationConfig &config) + : config_(config), model_(std::make_unique(config)) { + Check(); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(WhisperTag{}); + } + + std::string Compute(OfflineStream *s) const override { + int32_t max_num_frames = 3000; + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + // we use 50 here so that there will be some zero tail paddings + if (num_frames >= max_num_frames - 50) { + SHERPA_ONNX_LOGE( + "Only waves less than 30 seconds are supported. We process only the " + "first 30 seconds and discard the remaining data"); + num_frames = max_num_frames - 50; + } + + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); + + // note that 1000 is an experience-value. + // You can replace 1000 by other values, say, 100. + // + // Since we have removed the 30 seconds constraint, we need + // tail_padding_frames so that whisper is able to detect the eot token. + int32_t tail_padding_frames = 1000; + + if (config_.whisper.tail_paddings > 0) { + tail_padding_frames = config_.whisper.tail_paddings; + } + + int32_t actual_frames = + std::min(num_frames + tail_padding_frames, max_num_frames); + + std::array shape{1, actual_frames, feat_dim}; + + Ort::Value mel = Ort::Value::CreateTensor( + model_->Allocator(), shape.data(), shape.size()); + + float *p_mel = mel.GetTensorMutableData(); + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); + + std::fill_n(p_mel + num_frames * feat_dim, + (actual_frames - num_frames) * feat_dim, 0); + + mel = Transpose12(model_->Allocator(), &mel); + + try { + auto cross_kv = model_->ForwardEncoder(std::move(mel)); + int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second); + const auto &id2lang = model_->GetID2Lang(); + if (id2lang.count(lang_id)) { + return id2lang.at(lang_id); + } else { + SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.", + lang_id); + return ""; + } + } catch (const Ort::Exception &ex) { + SHERPA_ONNX_LOGE( + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " + "input frames: %d, Current tail " + "paddings: %d. If you see a lot of such exceptions, please consider " + "using a larger --whisper-tail-paddings", + ex.what(), num_frames, tail_padding_frames); + return ""; + } + } + + private: + void Check() const { + if (!model_->IsMultiLingual()) { + SHERPA_ONNX_LOGE( + "Only whisper multilingual models can be used for spoken language " + "identification. Given: %s,%s", + config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str()); + exit(-1); + } + } + + private: + SpokenLanguageIdentificationConfig config_; + std::unique_ptr model_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ diff --git a/sherpa-onnx/csrc/spoken-language-identification.cc b/sherpa-onnx/csrc/spoken-language-identification.cc new file mode 100644 index 000000000..868382835 --- /dev/null +++ b/sherpa-onnx/csrc/spoken-language-identification.cc @@ -0,0 +1,117 @@ +// sherpa-onnx/csrc/spoken-language-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/spoken-language-identification.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/spoken-language-identification-impl.h" + +namespace sherpa_onnx { + +void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) { + po->Register( + "whisper-encoder", &encoder, + "Path to then encoder of a whisper multilingual model. Support only " + "tiny, base, small, medium, large."); + + po->Register( + "whisper-decoder", &decoder, + "Path to the decoder of a whisper multilingual model. Support only " + "tiny, base, small, medium, large."); + + po->Register( + "whisper-tail-paddings", &tail_paddings, + "Suggested value: 300 for multilingual models. " + "Since we have removed the 30-second constraint, we need to add some " + "tail padding frames " + "so that whisper can detect the eot token. Leave it to -1 to use 1000"); +} + +bool SpokenLanguageIdentificationWhisperConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); + return false; + } + + return true; +} + +std::string SpokenLanguageIdentificationWhisperConfig::ToString() const { + std::ostringstream os; + + os << "SpokenLanguageIdentificationWhisperConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "tail_paddings=" << tail_paddings << ")"; + + return os.str(); +} + +void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) { + whisper.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool SpokenLanguageIdentificationConfig::Validate() const { + if (!whisper.Validate()) { + return false; + } + + return true; +} + +std::string SpokenLanguageIdentificationConfig::ToString() const { + std::ostringstream os; + + os << "SpokenLanguageIdentificationConfig("; + os << "whisper=\"" << whisper.ToString() << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +SpokenLanguageIdentification::SpokenLanguageIdentification( + const SpokenLanguageIdentificationConfig &config) + : impl_(SpokenLanguageIdentificationImpl::Create(config)) {} + +SpokenLanguageIdentification::~SpokenLanguageIdentification() = default; + +std::unique_ptr SpokenLanguageIdentification::CreateStream() + const { + return impl_->CreateStream(); +} + +std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const { + return impl_->Compute(s); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/spoken-language-identification.h b/sherpa-onnx/csrc/spoken-language-identification.h new file mode 100644 index 000000000..83e60da59 --- /dev/null +++ b/sherpa-onnx/csrc/spoken-language-identification.h @@ -0,0 +1,89 @@ +// sherpa-onnx/csrc/spoken-language-identification.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ + +#include +#include + +#include "sherpa-onnx/csrc/offline-stream.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct SpokenLanguageIdentificationWhisperConfig { + // Requires a multi-lingual whisper model. + // That is, it supports only tiny, base, small, medium, large. + // Note: It does NOT support tiny.en, base.en, small.en, medium.en + std::string encoder; + std::string decoder; + + // Number of tail padding frames. + // + // Since we remove the 30-second constraint, we need to add some paddings + // at the end. + // + // Recommended values: + // - 50 for English models + // - 300 for multilingual models + int32_t tail_paddings = -1; + + SpokenLanguageIdentificationWhisperConfig() = default; + + SpokenLanguageIdentificationWhisperConfig(const std::string &encoder, + const std::string &decoder, + int32_t tail_paddings) + : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +struct SpokenLanguageIdentificationConfig { + SpokenLanguageIdentificationWhisperConfig whisper; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + SpokenLanguageIdentificationConfig() = default; + + SpokenLanguageIdentificationConfig( + const SpokenLanguageIdentificationWhisperConfig &whisper, + int32_t num_threads, bool debug, const std::string &provider) + : whisper(whisper), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class SpokenLanguageIdentificationImpl; + +class SpokenLanguageIdentification { + public: + explicit SpokenLanguageIdentification( + const SpokenLanguageIdentificationConfig &config); + + ~SpokenLanguageIdentification(); + + // Create a stream to accept audio samples and compute features + std::unique_ptr CreateStream() const; + + // Return a string containing the language, e.g., en, zh, de, + // etc. + // Note: en is for English, zh is for Chinese, de is for German, etc. + std::string Compute(OfflineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index bba7903a0..ff81d5e4e 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -33,6 +33,7 @@ set(srcs silero-vad-model-config.cc speaker-embedding-extractor.cc speaker-embedding-manager.cc + spoken-language-identification.cc vad-model-config.cc vad-model.cc voice-activity-detector.cc diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 7b0d7c0a0..b30ed16da 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -22,6 +22,7 @@ #include "sherpa-onnx/python/csrc/online-stream.h" #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" +#include "sherpa-onnx/python/csrc/spoken-language-identification.h" #include "sherpa-onnx/python/csrc/vad-model-config.h" #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" @@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflineTts(&m); PybindSpeakerEmbeddingExtractor(&m); PybindSpeakerEmbeddingManager(&m); + PybindSpokenLanguageIdentification(&m); PybindAlsa(&m); } diff --git a/sherpa-onnx/python/csrc/spoken-language-identification.cc b/sherpa-onnx/python/csrc/spoken-language-identification.cc new file mode 100644 index 000000000..f528e5561 --- /dev/null +++ b/sherpa-onnx/python/csrc/spoken-language-identification.cc @@ -0,0 +1,60 @@ +// sherpa-onnx/python/csrc/spoken-language-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/spoken-language-identification.h" + +#include + +#include "sherpa-onnx/csrc/spoken-language-identification.h" + +namespace sherpa_onnx { + +static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) { + using PyClass = SpokenLanguageIdentificationWhisperConfig; + + py::class_(*m, "SpokenLanguageIdentificationWhisperConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("encoder"), py::arg("decoder"), + py::arg("tail_paddings") = -1) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("tail_paddings", &PyClass::tail_paddings) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindSpokenLanguageIdentificationConfig(py::module *m) { + PybindSpokenLanguageIdentificationWhisperConfig(m); + + using PyClass = SpokenLanguageIdentificationConfig; + + py::class_(*m, "SpokenLanguageIdentificationConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("whisper"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("whisper", &PyClass::whisper) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindSpokenLanguageIdentification(py::module *m) { + PybindSpokenLanguageIdentificationConfig(m); + + using PyClass = SpokenLanguageIdentification; + py::class_(*m, "SpokenLanguageIdentification") + .def(py::init(), + py::arg("config"), py::call_guard()) + .def("create_stream", &PyClass::CreateStream, + py::call_guard()) + .def("compute", &PyClass::Compute, + py::call_guard()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/spoken-language-identification.h b/sherpa-onnx/python/csrc/spoken-language-identification.h new file mode 100644 index 000000000..52b3d9bc2 --- /dev/null +++ b/sherpa-onnx/python/csrc/spoken-language-identification.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/spoken-language-identification.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindSpokenLanguageIdentification(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index ee22bd432..1f98bef69 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -13,6 +13,9 @@ SpeakerEmbeddingExtractorConfig, SpeakerEmbeddingManager, SpeechSegment, + SpokenLanguageIdentification, + SpokenLanguageIdentificationConfig, + SpokenLanguageIdentificationWhisperConfig, VadModel, VadModelConfig, VoiceActivityDetector,