diff --git a/.gitignore b/.gitignore index e0743e07f..48114405e 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,5 @@ sr-data vits-icefall-* sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 spoken-language-identification-test-wavs +my-release-key* +vits-zh-hf-fanchen-C diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 369aaa8c5..c342b9f61 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -158,6 +158,7 @@ class MainActivity : AppCompatActivity() { var ruleFars: String? var lexicon: String? var dataDir: String? + var dictDir: String? var assets: AssetManager? = application.assets // The purpose of such a design is to make the CI test easier @@ -169,6 +170,7 @@ class MainActivity : AppCompatActivity() { ruleFars = null lexicon = null dataDir = null + dictDir = null // Example 1: // modelDir = "vits-vctk" @@ -191,21 +193,36 @@ class MainActivity : AppCompatActivity() { // lexicon = "lexicon.txt" // Example 4: + // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/vits.html#csukuangfj-vits-zh-hf-fanchen-c-chinese-187-speakers + // modelDir = "vits-zh-hf-fanchen-C" + // modelName = "vits-zh-hf-fanchen-C.onnx" + // lexicon = "lexicon.txt" + // dictDir = "vits-zh-hf-fanchen-C/dict" + + // Example 5: // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-coqui-de-css10.tar.bz2 // modelDir = "vits-coqui-de-css10" // modelName = "model.onnx" - // lang = "deu" if (dataDir != null) { - val newDir = copyDataDir(modelDir) + val newDir = copyDataDir(modelDir!!) modelDir = newDir + "/" + modelDir dataDir = newDir + "/" + dataDir assets = null } + if (dictDir != null) { + val newDir = copyDataDir( modelDir!!) 
+ modelDir = newDir + "/" + modelDir + dictDir = modelDir + "/" + "dict" + ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" + assets = null + } + val config = getOfflineTtsConfig( modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "", dataDir = dataDir ?: "", + dictDir = dictDir ?: "", ruleFsts = ruleFsts ?: "", ruleFars = ruleFars ?: "", )!! diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt index 2514fcac5..e0f95166c 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt @@ -8,6 +8,7 @@ data class OfflineTtsVitsModelConfig( var lexicon: String = "", var tokens: String, var dataDir: String = "", + var dictDir: String = "", var noiseScale: Float = 0.667f, var noiseScaleW: Float = 0.8f, var lengthScale: Float = 1.0f, @@ -49,7 +50,7 @@ class OfflineTts( init { if (assetManager != null) { - ptr = new(assetManager, config) + ptr = newFromAsset(assetManager, config) } else { ptr = newFromFile(config) } @@ -87,7 +88,7 @@ class OfflineTts( fun allocate(assetManager: AssetManager? = null) { if (ptr == 0L) { if (assetManager != null) { - ptr = new(assetManager, config) + ptr = newFromAsset(assetManager, config) } else { ptr = newFromFile(config) } @@ -105,7 +106,7 @@ class OfflineTts( delete(ptr) } - private external fun new( + private external fun newFromAsset( assetManager: AssetManager, config: OfflineTtsConfig, ): Long @@ -152,6 +153,7 @@ fun getOfflineTtsConfig( modelName: String, lexicon: String, dataDir: String, + dictDir: String, ruleFsts: String, ruleFars: String ): OfflineTtsConfig? 
{ @@ -161,7 +163,8 @@ fun getOfflineTtsConfig( model = "$modelDir/$modelName", lexicon = "$modelDir/$lexicon", tokens = "$modelDir/tokens.txt", - dataDir = "$dataDir" + dataDir = dataDir, + dictDir = dictDir, ), numThreads = 2, debug = true, diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt index 5699ccf20..e02cc069c 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt @@ -42,6 +42,7 @@ object TtsEngine { private var ruleFars: String? = null private var lexicon: String? = null private var dataDir: String? = null + private var dictDir: String? = null private var assets: AssetManager? = null init { @@ -54,6 +55,7 @@ object TtsEngine { ruleFars = null lexicon = null dataDir = null + dictDir = null lang = null // Please enable one and only one of the examples below @@ -83,6 +85,14 @@ object TtsEngine { // lang = "zho" // Example 4: + // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/vits.html#csukuangfj-vits-zh-hf-fanchen-c-chinese-187-speakers + // modelDir = "vits-zh-hf-fanchen-C" + // modelName = "vits-zh-hf-fanchen-C.onnx" + // lexicon = "lexicon.txt" + // dictDir = "vits-zh-hf-fanchen-C/dict" + // lang = "zho" + + // Example 5: // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-coqui-de-css10.tar.bz2 // This model does not need lexicon or dataDir // modelDir = "vits-coqui-de-css10" @@ -108,9 +118,18 @@ object TtsEngine { assets = null } + if (dictDir != null) { + val newDir = copyDataDir(context, modelDir!!) 
+ modelDir = newDir + "/" + modelDir + dictDir = modelDir + "/" + "dict" + ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" + assets = null + } + val config = getOfflineTtsConfig( modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "", dataDir = dataDir ?: "", + dictDir = dictDir ?: "", ruleFsts = ruleFsts ?: "", ruleFars = ruleFars ?: "" )!! diff --git a/build-android-arm64-v8a.sh b/build-android-arm64-v8a.sh index f2726659d..181e70d4c 100755 --- a/build-android-arm64-v8a.sh +++ b/build-android-arm64-v8a.sh @@ -47,7 +47,7 @@ onnxruntime_version=1.17.1 if [ ! -f $onnxruntime_version/jni/arm64-v8a/libonnxruntime.so ]; then mkdir -p $onnxruntime_version pushd $onnxruntime_version - wget -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip + wget -c -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip unzip onnxruntime-android-${onnxruntime_version}.zip rm onnxruntime-android-${onnxruntime_version}.zip popd diff --git a/build-android-armv7-eabi.sh b/build-android-armv7-eabi.sh index 32572b5ec..5c8bcd132 100755 --- a/build-android-armv7-eabi.sh +++ b/build-android-armv7-eabi.sh @@ -48,7 +48,7 @@ onnxruntime_version=1.17.1 if [ ! 
-f $onnxruntime_version/jni/armeabi-v7a/libonnxruntime.so ]; then mkdir -p $onnxruntime_version pushd $onnxruntime_version - wget -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip + wget -c -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip unzip onnxruntime-android-${onnxruntime_version}.zip rm onnxruntime-android-${onnxruntime_version}.zip popd diff --git a/build-android-x86-64.sh b/build-android-x86-64.sh index 84071ced2..15241f050 100755 --- a/build-android-x86-64.sh +++ b/build-android-x86-64.sh @@ -48,7 +48,7 @@ onnxruntime_version=1.17.1 if [ ! -f $onnxruntime_version/jni/x86_64/libonnxruntime.so ]; then mkdir -p $onnxruntime_version pushd $onnxruntime_version - wget -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip + wget -c -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip unzip onnxruntime-android-${onnxruntime_version}.zip rm onnxruntime-android-${onnxruntime_version}.zip popd diff --git a/build-android-x86.sh b/build-android-x86.sh index 968e81239..c02d9fc5e 100755 --- a/build-android-x86.sh +++ b/build-android-x86.sh @@ -48,7 +48,7 @@ onnxruntime_version=1.17.1 if [ ! 
-f $onnxruntime_version/jni/x86/libonnxruntime.so ]; then mkdir -p $onnxruntime_version pushd $onnxruntime_version - wget -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip + wget -c -q https://github.com/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime-android-${onnxruntime_version}.zip unzip onnxruntime-android-${onnxruntime_version}.zip rm onnxruntime-android-${onnxruntime_version}.zip popd diff --git a/scripts/apk/build-apk-tts-engine.sh.in b/scripts/apk/build-apk-tts-engine.sh.in index 08d570384..80e34df3a 100644 --- a/scripts/apk/build-apk-tts-engine.sh.in +++ b/scripts/apk/build-apk-tts-engine.sh.in @@ -61,6 +61,11 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt sed -i.bak s%"ruleFsts = null"%"ruleFars = \"$rule_fars\""% ./TtsEngine.kt {% endif %} +{% if tts_model.dict_dir %} + dict_dir={{ tts_model.dict_dir }} + sed -i.bak s%"dictDir = null"%"dictDir = \"$dict_dir\""% ./TtsEngine.kt +{% endif %} + {% if tts_model.data_dir %} data_dir={{ tts_model.data_dir }} sed -i.bak s%"dataDir = null"%"dataDir = \"$data_dir\""% ./TtsEngine.kt diff --git a/scripts/apk/build-apk-tts.sh.in b/scripts/apk/build-apk-tts.sh.in index 5eb377e31..2caf3788a 100644 --- a/scripts/apk/build-apk-tts.sh.in +++ b/scripts/apk/build-apk-tts.sh.in @@ -59,6 +59,11 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt sed -i.bak s%"ruleFsts = null"%"ruleFars = \"$rule_fars\""% ./MainActivity.kt {% endif %} +{% if tts_model.dict_dir %} + dict_dir={{ tts_model.dict_dir }} + sed -i.bak s%"dictDir = null"%"dictDir = \"$dict_dir\""% ./MainActivity.kt +{% endif %} + {% if tts_model.data_dir %} data_dir={{ tts_model.data_dir }} sed -i.bak s%"dataDir = null"%"dataDir = \"$data_dir\""% ./MainActivity.kt diff --git a/scripts/apk/generate-tts-apk-script.py b/scripts/apk/generate-tts-apk-script.py index 1221c4d33..8c19a5e4e 
100755 --- a/scripts/apk/generate-tts-apk-script.py +++ b/scripts/apk/generate-tts-apk-script.py @@ -35,6 +35,7 @@ class TtsModel: rule_fsts: Optional[List[str]] = None rule_fars: Optional[List[str]] = None data_dir: Optional[str] = None + dict_dir: Optional[str] = None is_char: bool = False lang_iso_639_3: str = "" @@ -326,8 +327,14 @@ def get_vits_models() -> List[TtsModel]: rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] for m in chinese_models: s = [f"{m.model_dir}/{r}" for r in rule_fsts] + if "vits-zh-hf" in m.model_dir: + s = s[:-1] + m.dict_dir = m.model_dir + "/dict" + m.rule_fsts = ",".join(s) - m.rule_fars = f"{m.model_dir}/rule.far" + + if "vits-zh-hf" not in m.model_dir: + m.rule_fars = f"{m.model_dir}/rule.far" all_models = chinese_models + [ TtsModel( diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc index 5c5dcf2e1..ba68c50ee 100644 --- a/sherpa-onnx/csrc/audio-tagging-model-config.cc +++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc @@ -32,7 +32,7 @@ bool AudioTaggingModelConfig::Validate() const { } if (!ced.empty() && !FileExists(ced)) { - SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str()); + SHERPA_ONNX_LOGE("CED model file '%s' does not exist", ced.c_str()); return false; } diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index c6d1e24f2..966a19200 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -48,7 +48,7 @@ bool AudioTaggingConfig::Validate() const { } if (!FileExists(labels)) { - SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str()); + SHERPA_ONNX_LOGE("--labels '%s' does not exist", labels.c_str()); return false; } diff --git a/sherpa-onnx/csrc/file-utils.cc b/sherpa-onnx/csrc/file-utils.cc index f5cf48f97..8d87fd19f 100644 --- a/sherpa-onnx/csrc/file-utils.cc +++ b/sherpa-onnx/csrc/file-utils.cc @@ -7,7 +7,7 @@ #include #include -#include 
"sherpa-onnx/csrc/log.h" +#include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { @@ -17,7 +17,7 @@ bool FileExists(const std::string &filename) { void AssertFileExists(const std::string &filename) { if (!FileExists(filename)) { - SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!"; + SHERPA_ONNX_LOGE("filename '%s' does not exist", filename.c_str()); exit(-1); } } diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index c63fff1c0..2f99a34f1 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -146,6 +146,14 @@ class JiebaLexicon::Impl { if (token2id_.count(p.first) && !token2id_.count(p.second)) { token2id_[p.second] = token2id_[p.first]; } + + if (!token2id_.count(p.first) && token2id_.count(p.second)) { + token2id_[p.first] = token2id_[p.second]; + } + } + + if (!token2id_.count("、") && token2id_.count(",")) { + token2id_["、"] = token2id_[","]; } } diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index 274a7fddf..7e93d7a04 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -101,7 +101,8 @@ bool KeywordSpotterConfig::Validate() const { // Solution: take keyword_file variable is directly // parsed as a string of keywords if (!std::ifstream(keywords_file.c_str()).good()) { - SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str()); + SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.", + keywords_file.c_str()); return false; } #endif diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc index 481ecaef5..7b47b6b45 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc @@ -34,7 +34,7 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { bool OfflineCtcFstDecoderConfig::Validate() const { if (!graph.empty() && !FileExists(graph)) { - 
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); + SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str()); return false; } return true; diff --git a/sherpa-onnx/csrc/offline-lm-config.cc b/sherpa-onnx/csrc/offline-lm-config.cc index 078e56fab..791fa11af 100644 --- a/sherpa-onnx/csrc/offline-lm-config.cc +++ b/sherpa-onnx/csrc/offline-lm-config.cc @@ -22,7 +22,7 @@ void OfflineLMConfig::Register(ParseOptions *po) { bool OfflineLMConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc index 5589402ee..9ea26f814 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) { bool OfflineNemoEncDecCtcModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("NeMo model: '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/csrc/offline-paraformer-model-config.cc index 82886fe87..b95edfea6 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model-config.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model-config.cc @@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) { bool OfflineParaformerModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("Paraformer model '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc index 05fcc9092..72fcfefbf 100644 
--- a/sherpa-onnx/csrc/offline-transducer-model-config.cc +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -18,19 +18,19 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { bool OfflineTransducerModelConfig::Validate() const { if (!FileExists(encoder_filename)) { - SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", + SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist", encoder_filename.c_str()); return false; } if (!FileExists(decoder_filename)) { - SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", + SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist", decoder_filename.c_str()); return false; } if (!FileExists(joiner_filename)) { - SHERPA_ONNX_LOGE("transducer joiner: %s does not exist", + SHERPA_ONNX_LOGE("transducer joiner: '%s' does not exist", joiner_filename.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index e6195b4f9..d380ec18b 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -35,7 +35,7 @@ bool OfflineTtsVitsModelConfig::Validate() const { } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("--vits-model: '%s' does not exist", model.c_str()); return false; } @@ -45,31 +45,31 @@ bool OfflineTtsVitsModelConfig::Validate() const { } if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str()); + SHERPA_ONNX_LOGE("--vits-tokens: '%s' does not exist", tokens.c_str()); return false; } if (!data_dir.empty()) { if (!FileExists(data_dir + "/phontab")) { - SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test", + SHERPA_ONNX_LOGE("'%s/phontab' does not exist. Skipping test", data_dir.c_str()); return false; } if (!FileExists(data_dir + "/phonindex")) { - SHERPA_ONNX_LOGE("%s/phonindex does not exist. 
Skipping test", + SHERPA_ONNX_LOGE("'%s/phonindex' does not exist. Skipping test", data_dir.c_str()); return false; } if (!FileExists(data_dir + "/phondata")) { - SHERPA_ONNX_LOGE("%s/phondata does not exist. Skipping test", + SHERPA_ONNX_LOGE("'%s/phondata' does not exist. Skipping test", data_dir.c_str()); return false; } if (!FileExists(data_dir + "/intonations")) { - SHERPA_ONNX_LOGE("%s/intonations does not exist.", data_dir.c_str()); + SHERPA_ONNX_LOGE("'%s/intonations' does not exist.", data_dir.c_str()); return false; } } @@ -82,7 +82,8 @@ bool OfflineTtsVitsModelConfig::Validate() const { for (const auto &f : required_files) { if (!FileExists(dict_dir + "/" + f)) { - SHERPA_ONNX_LOGE("%s/%s does not exist.", data_dir.c_str(), f.c_str()); + SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(), + f.c_str()); return false; } } diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc index 34d4a39ca..4349f98e9 100644 --- a/sherpa-onnx/csrc/offline-tts.cc +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -42,7 +42,7 @@ bool OfflineTtsConfig::Validate() const { SplitStringToVector(rule_fsts, ",", false, &files); for (const auto &f : files) { if (!FileExists(f)) { - SHERPA_ONNX_LOGE("Rule fst %s does not exist. ", f.c_str()); + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); return false; } } @@ -53,7 +53,7 @@ bool OfflineTtsConfig::Validate() const { SplitStringToVector(rule_fars, ",", false, &files); for (const auto &f : files) { if (!FileExists(f)) { - SHERPA_ONNX_LOGE("Rule far %s does not exist. ", f.c_str()); + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. 
", f.c_str()); return false; } } diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc index f3543948e..2493971a6 100644 --- a/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc @@ -18,7 +18,7 @@ void OfflineWenetCtcModelConfig::Register(ParseOptions *po) { bool OfflineWenetCtcModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("WeNet model: '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.cc b/sherpa-onnx/csrc/offline-whisper-model-config.cc index 821ccfbbe..5d36b82c4 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.cc +++ b/sherpa-onnx/csrc/offline-whisper-model-config.cc @@ -48,7 +48,8 @@ bool OfflineWhisperModelConfig::Validate() const { } if (!FileExists(encoder)) { - SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); + SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist", + encoder.c_str()); return false; } @@ -58,7 +59,8 @@ bool OfflineWhisperModelConfig::Validate() const { } if (!FileExists(decoder)) { - SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); + SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist", + decoder.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc index 3034ff77f..633bfac86 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc @@ -21,7 +21,7 @@ bool OfflineZipformerAudioTaggingModelConfig::Validate() const { } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("--zipformer-model: '%s' does not exist", 
model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc index 1c661fcc1..e03e841fb 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc @@ -15,7 +15,7 @@ void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) { bool OfflineZipformerCtcModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist", + SHERPA_ONNX_LOGE("zipformer CTC model file '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc index 9eccebea7..8c4f3787f 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc @@ -31,7 +31,7 @@ void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) { bool OnlineCtcFstDecoderConfig::Validate() const { if (!graph.empty() && !FileExists(graph)) { - SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); + SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str()); return false; } return true; diff --git a/sherpa-onnx/csrc/online-lm-config.cc b/sherpa-onnx/csrc/online-lm-config.cc index af75d1667..42990f720 100644 --- a/sherpa-onnx/csrc/online-lm-config.cc +++ b/sherpa-onnx/csrc/online-lm-config.cc @@ -22,7 +22,7 @@ void OnlineLMConfig::Register(ParseOptions *po) { bool OnlineLMConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 16431c9b4..d2da161e9 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -45,7 +45,7 @@ bool OnlineModelConfig::Validate() const { } if 
(!FileExists(tokens)) { - SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); + SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.cc b/sherpa-onnx/csrc/online-paraformer-model-config.cc index a93fe2992..25a99262d 100644 --- a/sherpa-onnx/csrc/online-paraformer-model-config.cc +++ b/sherpa-onnx/csrc/online-paraformer-model-config.cc @@ -18,12 +18,12 @@ void OnlineParaformerModelConfig::Register(ParseOptions *po) { bool OnlineParaformerModelConfig::Validate() const { if (!FileExists(encoder)) { - SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str()); + SHERPA_ONNX_LOGE("Paraformer encoder '%s' does not exist", encoder.c_str()); return false; } if (!FileExists(decoder)) { - SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str()); + SHERPA_ONNX_LOGE("Paraformer decoder '%s' does not exist", decoder.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index f7015f98d..dd7572717 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -18,17 +18,19 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) { bool OnlineTransducerModelConfig::Validate() const { if (!FileExists(encoder)) { - SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", encoder.c_str()); + SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist", + encoder.c_str()); return false; } if (!FileExists(decoder)) { - SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", decoder.c_str()); + SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist", + decoder.c_str()); return false; } if (!FileExists(joiner)) { - SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner.c_str()); + SHERPA_ONNX_LOGE("joiner: '%s' does not exist", joiner.c_str()); return false; } diff --git 
a/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc b/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc index 6098be626..a47b3e162 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model-config.cc @@ -21,7 +21,7 @@ void OnlineWenetCtcModelConfig::Register(ParseOptions *po) { bool OnlineWenetCtcModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("WeNet CTC model %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("WeNet CTC model '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc index 836808d6f..ed9e7b8a9 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc @@ -22,7 +22,8 @@ bool OnlineZipformer2CtcModelConfig::Validate() const { } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("--zipformer2-ctc-model '%s' does not exist", + model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/silero-vad-model-config.cc b/sherpa-onnx/csrc/silero-vad-model-config.cc index 8419265fe..6589361ea 100644 --- a/sherpa-onnx/csrc/silero-vad-model-config.cc +++ b/sherpa-onnx/csrc/silero-vad-model-config.cc @@ -44,7 +44,8 @@ bool SileroVadModelConfig::Validate() const { } if (!FileExists(model)) { - SHERPA_ONNX_LOGE("Silero vad model file %s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("Silero vad model file '%s' does not exist", + model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index 35bfc297b..1c99de1a0 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -31,7 +31,7 @@ bool SpeakerEmbeddingExtractorConfig::Validate() const { } if (!FileExists(model)) { - 
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist", + SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/spoken-language-identification.cc b/sherpa-onnx/csrc/spoken-language-identification.cc index eff49662d..3797586a2 100644 --- a/sherpa-onnx/csrc/spoken-language-identification.cc +++ b/sherpa-onnx/csrc/spoken-language-identification.cc @@ -43,7 +43,8 @@ bool SpokenLanguageIdentificationWhisperConfig::Validate() const { } if (!FileExists(encoder)) { - SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str()); + SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist", + encoder.c_str()); return false; } @@ -53,7 +54,8 @@ bool SpokenLanguageIdentificationWhisperConfig::Validate() const { } if (!FileExists(decoder)) { - SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str()); + SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist", + decoder.c_str()); return false; } diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt index 6f14a35fa..bb08bbf35 100644 --- a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -9,11 +9,20 @@ if(NOT DEFINED ANDROID_ABI) include_directories($ENV{JAVA_HOME}/include/darwin) endif() -add_library(sherpa-onnx-jni +set(sources audio-tagging.cc jni.cc offline-stream.cc spoken-language-identification.cc ) + +if(SHERPA_ONNX_ENABLE_TTS) + list(APPEND sources + offline-tts.cc + ) +endif() + +add_library(sherpa-onnx-jni ${sources}) + target_link_libraries(sherpa-onnx-jni sherpa-onnx-core) install(TARGETS sherpa-onnx-jni DESTINATION lib) diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 6bb25a362..a8b2e4b6d 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -24,10 +24,6 @@ #include "sherpa-onnx/csrc/wave-writer.h" #include "sherpa-onnx/jni/common.h" -#if SHERPA_ONNX_ENABLE_TTS == 1 -#include "sherpa-onnx/csrc/offline-tts.h" -#endif - 
namespace sherpa_onnx { class SherpaOnnx { @@ -775,113 +771,6 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { return ans; } -#if SHERPA_ONNX_ENABLE_TTS == 1 -class SherpaOnnxOfflineTts { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnxOfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config) - : tts_(mgr, config) {} -#endif - explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config) - : tts_(config) {} - - GeneratedAudio Generate(const std::string &text, int64_t sid = 0, - float speed = 1.0, - std::function - callback = nullptr) const { - return tts_.Generate(text, sid, speed, callback); - } - - int32_t SampleRate() const { return tts_.SampleRate(); } - - int32_t NumSpeakers() const { return tts_.NumSpeakers(); } - - private: - OfflineTts tts_; -}; - -static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { - OfflineTtsConfig ans; - - jclass cls = env->GetObjectClass(config); - jfieldID fid; - - fid = env->GetFieldID(cls, "model", - "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); - jobject model = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(model); - - fid = env->GetFieldID(model_config_cls, "vits", - "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); - jobject vits = env->GetObjectField(model, fid); - jclass vits_cls = env->GetObjectClass(vits); - - fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;"); - jstring s = (jstring)env->GetObjectField(vits, fid); - const char *p = env->GetStringUTFChars(s, nullptr); - ans.model.vits.model = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(vits, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model.vits.lexicon = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(vits, fid); - p = env->GetStringUTFChars(s, nullptr); - 
ans.model.vits.tokens = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(vits, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model.vits.data_dir = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(vits_cls, "noiseScale", "F"); - ans.model.vits.noise_scale = env->GetFloatField(vits, fid); - - fid = env->GetFieldID(vits_cls, "noiseScaleW", "F"); - ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid); - - fid = env->GetFieldID(vits_cls, "lengthScale", "F"); - ans.model.vits.length_scale = env->GetFloatField(vits, fid); - - fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model.num_threads = env->GetIntField(model, fid); - - fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model.debug = env->GetBooleanField(model, fid); - - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model.provider = p; - env->ReleaseStringUTFChars(s, p); - - // for ruleFsts - fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.rule_fsts = p; - env->ReleaseStringUTFChars(s, p); - - // for ruleFars - fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.rule_fars = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "maxNumSentences", "I"); - ans.max_num_sentences = env->GetIntField(config, fid); - - return ans; -} -#endif - } // namespace sherpa_onnx SHERPA_ONNX_EXTERN_C @@ -1226,128 +1115,6 @@ jobject NewFloat(JNIEnv *env, float value) { return env->NewObject(cls, constructor, value); } -#if SHERPA_ONNX_ENABLE_TTS == 1 -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL 
Java_com_k2fsa_sherpa_onnx_OfflineTts_new( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - - auto tts = new sherpa_onnx::SherpaOnnxOfflineTts( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - - return (jlong)tts; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( - JNIEnv *env, jobject /*obj*/, jobject _config) { - auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - - if (!config.Validate()) { - SHERPA_ONNX_LOGE("Errors found in config!"); - } - - auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config); - - return (jlong)tts; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - delete reinterpret_cast(ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - return reinterpret_cast(ptr) - ->SampleRate(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - return reinterpret_cast(ptr) - ->NumSpeakers(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL -Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, - jlong ptr, jstring text, - jint sid, jfloat speed) { - const char *p_text = env->GetStringUTFChars(text, nullptr); - SHERPA_ONNX_LOGE("string is: %s", p_text); - - auto audio = - reinterpret_cast(ptr)->Generate( - p_text, sid, speed); - - jfloatArray samples_arr = 
env->NewFloatArray(audio.samples.size()); - env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), - audio.samples.data()); - - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( - 2, env->FindClass("java/lang/Object"), nullptr); - - env->SetObjectArrayElement(obj_arr, 0, samples_arr); - env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); - - env->ReleaseStringUTFChars(text, p_text); - - return obj_arr; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL -Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( - JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid, - jfloat speed, jobject callback) { - const char *p_text = env->GetStringUTFChars(text, nullptr); - SHERPA_ONNX_LOGE("string is: %s", p_text); - - std::function callback_wrapper = - [env, callback](const float *samples, int32_t n, float /*p*/) { - jclass cls = env->GetObjectClass(callback); - jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); - - jfloatArray samples_arr = env->NewFloatArray(n); - env->SetFloatArrayRegion(samples_arr, 0, n, samples); - env->CallVoidMethod(callback, mid, samples_arr); - }; - - auto audio = - reinterpret_cast(ptr)->Generate( - p_text, sid, speed, callback_wrapper); - - jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); - env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), - audio.samples.data()); - - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( - 2, env->FindClass("java/lang/Object"), nullptr); - - env->SetObjectArrayElement(obj_arr, 0, samples_arr); - env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); - - env->ReleaseStringUTFChars(text, p_text); - - return obj_arr; -} -#endif - SHERPA_ONNX_EXTERN_C JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc new file mode 
100644 index 000000000..d36c08420 --- /dev/null +++ b/sherpa-onnx/jni/offline-tts.cc @@ -0,0 +1,215 @@ +// sherpa-onnx/jni/offline-tts.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { + OfflineTtsConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID(cls, "model", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); + jobject model = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model); + + fid = env->GetFieldID(model_config_cls, "vits", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); + jobject vits = env->GetObjectField(model, fid); + jclass vits_cls = env->GetObjectClass(vits); + + fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(vits, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.data_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "dictDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + 
ans.model.vits.dict_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "noiseScale", "F"); + ans.model.vits.noise_scale = env->GetFloatField(vits, fid); + + fid = env->GetFieldID(vits_cls, "noiseScaleW", "F"); + ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid); + + fid = env->GetFieldID(vits_cls, "lengthScale", "F"); + ans.model.vits.length_scale = env->GetFloatField(vits, fid); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + // for ruleFsts + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fsts = p; + env->ReleaseStringUTFChars(s, p); + + // for ruleFars + fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fars = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxNumSentences", "I"); + ans.max_num_sentences = env->GetIntField(config, fid); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + } +#endif + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); +
auto tts = new sherpa_onnx::OfflineTts( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)tts; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + + auto tts = new sherpa_onnx::OfflineTts(config); + + return (jlong)tts; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + return reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->SampleRate(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + return reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->NumSpeakers(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, + jlong ptr, jstring text, + jint sid, jfloat speed) { + const char *p_text = env->GetStringUTFChars(text, nullptr); + SHERPA_ONNX_LOGE("string is: %s", p_text); + + auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( + p_text, sid, speed); + + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), + audio.samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); + + env->ReleaseStringUTFChars(text, p_text); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid, + jfloat speed, jobject callback) { + const char *p_text = env->GetStringUTFChars(text, nullptr); + SHERPA_ONNX_LOGE("string is: %s", p_text); + + std::function<void(const float *, int32_t, float)> callback_wrapper = + [env, callback](const float *samples, int32_t n, float /*progress*/) { + jclass cls = env->GetObjectClass(callback); + jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); + + jfloatArray samples_arr = env->NewFloatArray(n); + env->SetFloatArrayRegion(samples_arr, 0, n, samples); + env->CallVoidMethod(callback, mid, samples_arr); + }; + + auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate( + p_text, sid, speed, callback_wrapper); + + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), + audio.samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); + + env->ReleaseStringUTFChars(text, p_text); + + return obj_arr; +}