From 377696329d9fbb0ac797ecf7bc7c7d4e79a42783 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 26 Oct 2023 13:57:06 +0800 Subject: [PATCH] Support vits models from piper --- sherpa-onnx/csrc/lexicon.cc | 17 ++- sherpa-onnx/csrc/lexicon.h | 7 +- sherpa-onnx/csrc/offline-tts-vits-impl.h | 15 ++- sherpa-onnx/csrc/offline-tts-vits-model.cc | 139 ++++++++++++++++----- sherpa-onnx/csrc/offline-tts-vits-model.h | 1 + 5 files changed, 130 insertions(+), 49 deletions(-) diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 717d426d9..c036a7a36 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -83,8 +83,8 @@ static std::vector ConvertTokensToIds( Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations, const std::string &language, - bool debug /*= false*/) - : debug_(debug) { + bool debug /*= false*/, bool is_piper /*= false*/) + : debug_(debug), is_piper_(is_piper) { InitLanguage(language); { @@ -103,8 +103,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, #if __ANDROID_API__ >= 9 Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, const std::string &punctuations, - const std::string &language, bool debug /*= false*/) - : debug_(debug) { + const std::string &language, bool debug /*= false*/, + bool is_piper /*= false*/) + : debug_(debug), is_piper_(is_piper) { InitLanguage(language); { @@ -206,6 +207,10 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( int32_t blank = token2id_.at(" "); std::vector ans; + if (is_piper_) { + ans.push_back(token2id_.at("^")); // sos + } + for (const auto &w : words) { if (punctuations_.count(w)) { ans.push_back(token2id_.at(w)); @@ -227,6 +232,10 @@ std::vector Lexicon::ConvertTextToTokenIdsEnglish( ans.resize(ans.size() - 1); } + if (is_piper_) { + ans.push_back(token2id_.at("$")); // eos + } + return ans; } diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index a01004f24..dfc6cf9fd 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -24,12 +24,13 @@ class Lexicon { public: Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations, const std::string &language, - bool debug = false); + bool debug = false, bool is_piper = false); #if __ANDROID_API__ >= 9 Lexicon(AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, const std::string &punctuations, - const std::string &language, bool debug = false); + const std::string &language, bool debug = false, + bool is_piper = false); #endif std::vector ConvertTextToTokenIds(const std::string &text) const; @@ -59,7 +60,7 @@ class Lexicon { std::unordered_map token2id_; Language language_; bool debug_; - // + bool is_piper_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 847fb305b..1845cf2a7 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -26,15 +26,15 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) : model_(std::make_unique(config.model)), lexicon_(config.model.vits.lexicon, config.model.vits.tokens, - model_->Punctuations(), model_->Language(), - config.model.debug) {} + model_->Punctuations(), model_->Language(), config.model.debug, + model_->IsPiper()) {} #if __ANDROID_API__ >= 9 OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) : model_(std::make_unique(mgr, config.model)), lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens, - model_->Punctuations(), model_->Language(), - config.model.debug) {} + model_->Punctuations(), model_->Language(), config.model.debug, + model_->IsPiper()) {} #endif GeneratedAudio Generate(const std::string &text, int64_t sid = 0, @@ -43,17 +43,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { if (num_speakers == 0 && sid != 0) { SHERPA_ONNX_LOGE( "This is a single-speaker model and supports only sid 0. Given sid: " - "%d", + "%d. sid is ignored", static_cast(sid)); - return {}; } if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { SHERPA_ONNX_LOGE( "This model contains only %d speakers. sid should be in the range " - "[%d, %d]. Given: %d", + "[%d, %d]. Given: %d. Use sid=0", num_speakers, 0, num_speakers - 1, static_cast(sid)); - return {}; + sid = 0; } std::vector x = lexicon_.ConvertTextToTokenIds(text); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 53d8449e0..ab14b55de 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -38,6 +38,107 @@ class OfflineTtsVitsModel::Impl { #endif Ort::Value Run(Ort::Value x, int64_t sid, float speed) { + if (is_piper_) { + return RunVitsPiper(std::move(x), sid, speed); + } + + return RunVits(std::move(x), sid, speed); + } + + int32_t SampleRate() const { return sample_rate_; } + + bool AddBlank() const { return add_blank_; } + + std::string Punctuations() const { return punctuations_; } + std::string Language() const { return language_; } + bool IsPiper() const { return is_piper_; } + int32_t NumSpeakers() const { return num_speakers_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---vits model---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); + SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); + SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); + SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); + + std::string comment; + SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); + if (comment.find("piper") != std::string::npos) { + is_piper_ = true; + } + } + + Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int64_t len = x_shape[1]; + int64_t len_shape = 1; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); + + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; + + if (speed != 1 && speed > 0) { + length_scale = 1. / speed; + } + std::array scales = {noise_scale, length_scale, noise_scale_w}; + + int64_t scale_shape = 3; + + Ort::Value scales_tensor = Ort::Value::CreateTensor( + memory_info, scales.data(), scales.size(), &scale_shape, 1); + + int64_t sid_shape = 1; + Ort::Value sid_tensor = + Ort::Value::CreateTensor(memory_info, &sid, 1, &sid_shape, 1); + + std::vector inputs; + inputs.reserve(4); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(scales_tensor)); + + if (input_names_.size() == 4 && input_names_.back() == "sid") { + inputs.push_back(std::move(sid_tensor)); + } + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + Ort::Value RunVits(Ort::Value x, int64_t sid, float speed) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -94,40 +195,6 @@ class OfflineTtsVitsModel::Impl { return std::move(out[0]); } - int32_t SampleRate() const { return sample_rate_; } - - bool AddBlank() const { return add_blank_; } - - std::string Punctuations() const { return punctuations_; } - std::string Language() const { return language_; } - int32_t NumSpeakers() const { return num_speakers_; } - - private: - void Init(void *model_data, size_t model_data_length) { - sess_ = std::make_unique(env_, model_data, model_data_length, - sess_opts_); - - GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); - - GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); - - // get meta data - Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); - if (config_.debug) { - std::ostringstream os; - os << "---vits model---\n"; - PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); - } - - Ort::AllocatorWithDefaultOptions allocator; // used in the macro below - SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); - SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); - SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); - SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); - SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); - } - private: OfflineTtsModelConfig config_; Ort::Env env_; @@ -147,6 +214,8 @@ class OfflineTtsVitsModel::Impl { int32_t num_speakers_; std::string punctuations_; std::string language_; + + bool is_piper_ = false; }; OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) @@ -175,6 +244,8 @@ std::string OfflineTtsVitsModel::Punctuations() const { std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } +bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); } + int32_t OfflineTtsVitsModel::NumSpeakers() const { return impl_->NumSpeakers(); } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index dfe743cab..1cf9ad2ea 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -47,6 +47,7 @@ class OfflineTtsVitsModel { std::string Punctuations() const; std::string Language() const; + bool IsPiper() const; int32_t NumSpeakers() const; private: