Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow modify model config at decode time for ASR #1124

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream {
: impl(std::move(p)) {}
};

sherpa_onnx::OfflineRecognizerConfig convertConfig(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sherpa_onnx::OfflineRecognizerConfig convertConfig(
static sherpa_onnx::OfflineRecognizerConfig convertConfig(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const SherpaOnnxOfflineRecognizerConfig *config);
// Create an offline recognizer from a C API config.
//
// @param config Config for the recognizer. It is converted to the C++
//               config with convertConfig().
// @return A newly allocated recognizer, or nullptr if the converted config
//         fails validation.
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
    const SherpaOnnxOfflineRecognizerConfig *config) {
  sherpa_onnx::OfflineRecognizerConfig recognizer_config =
      convertConfig(config);

  if (!recognizer_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Build the implementation before allocating the wrapper so that the
  // wrapper cannot be leaked if constructing the recognizer throws.
  auto impl =
      std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);

  SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;
  recognizer->impl = std::move(impl);

  return recognizer;
}
sherpa_onnx::OfflineRecognizerConfig convertConfig(
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config;

recognizer_config.feat_config.sampling_rate =
Expand Down Expand Up @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str());
}

if (!recognizer_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
}

SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;

recognizer->impl =
std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);
return recognizer_config;
}

return recognizer;
void SetSherpaOnnxOfflineRecognizerConfig(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void SetSherpaOnnxOfflineRecognizerConfig(
void SherpaOnnxOfflineRecognizerSetConfig(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config){
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
convertConfig(config);
recognizer->impl->SetConfig(recognizer_config);
}

void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
Expand Down Expand Up @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
pText[text.size()] = 0;
r->text = pText;

// copy lang
const auto &lang = result.lang;
char *c_lang = new char[lang.size() + 1];
std::copy(lang.begin(), lang.end(), c_lang);
c_lang[lang.size()] = '\0';
r->lang = c_lang;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remember to use

delete[] r->lang;

to avoid a memory leak.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// copy json
std::string json = result.AsJsonString();
char *pJson = new char[json.size() + 1];
Expand Down
11 changes: 10 additions & 1 deletion sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream;
SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
const SherpaOnnxOfflineRecognizerConfig *config);

/// Update the configuration of an existing offline recognizer.
///
/// Which fields of the new config take effect depends on the underlying
/// recognizer implementation.
///
/// @param recognizer A pointer returned by CreateOfflineRecognizer()
/// @param config Config for the recognizer.
SHERPA_ONNX_API void SetSherpaOnnxOfflineRecognizerConfig(
const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config);

/// Free a pointer returned by CreateOfflineRecognizer()
///
/// @param p A pointer returned by CreateOfflineRecognizer()
Expand Down Expand Up @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;

// Pointer to continuous memory which holds timestamps
// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;
Expand Down Expand Up @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
* }
*/
const char *json;

// The recognized language; an empty string if the model does not report one.
const char *lang;

} SherpaOnnxOfflineRecognizerResult;

/// Get the result of the offline stream.
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the configuration stored in this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}


private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
return text;
}

// Default implementation: replace the stored configuration with `config`.
// The method is virtual, so a concrete recognizer may override it to honor
// only part of `config`.
//
// NOTE(review): the concrete recognizers return their own config_ from
// GetConfig(); confirm that this assignment updates the same member they
// read, otherwise the change is not visible to them.
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
config_ = config;
}

} // namespace sherpa_onnx
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class OfflineRecognizerImpl {

virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;

// Update the configuration after construction. The default implementation
// replaces the stored config; subclasses may override it to use only the
// fields they support.
virtual void SetConfig(const OfflineRecognizerConfig &config);

// Return the configuration held by the concrete recognizer.
virtual OfflineRecognizerConfig GetConfig() const = 0;

std::string ApplyInverseTextNormalization(std::string text) const;

private:
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the configuration stored in this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_->LfrWindowSize();
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the configuration stored in this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}


void InitHotwords() {
// each line in hotwords_file contains space-separated words

Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the configuration stored in this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
void PostInit() {
config_.feat_config.nemo_normalize_type =
Expand Down
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
}

r.text = text;
r.lang = src.lang;

return r;
}
Expand Down Expand Up @@ -98,8 +99,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}

// Update the whisper-specific settings at decode time.
//
// Only the whisper sub-config is taken from `config`; every other field of
// the stored configuration is left unchanged. The stored whisper config is
// handed to the decoder at the start of DecodeStream().
void SetConfig(const OfflineRecognizerConfig &config) override {
config_.model_config.whisper = config.model_config.whisper;
}

// Return a copy of the configuration stored in this recognizer, including
// any whisper settings changed through SetConfig().
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
void DecodeStream(OfflineStream *s) const {
decoder_->SetConfig(config_.model_config.whisper);

int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}

// Forward the new configuration to the underlying implementation, which
// decides which fields it actually applies.
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
impl_->SetConfig(config);
}

// Return the configuration reported by the underlying implementation.
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
return impl_->GetConfig();
}

} // namespace sherpa_onnx
9 changes: 9 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ class OfflineRecognizer {
*/
void DecodeStreams(OfflineStream **ss, int32_t n) const;

/** Update the recognizer configuration after construction, e.g., at decode
 * time.
 *
 * Onnxruntime Session objects are not affected by this method.
 * The exact behavior is defined by the specific recognizer impl.
 * For instance, the whisper recognizer only takes the whisper model config
 * (e.g., language and task) from `config` and ignores the remaining fields.
 *
 * @param config The new configuration.
 */
void SetConfig(const OfflineRecognizerConfig &config);

/** Return the configuration currently held by the underlying recognizer
 * impl.
 */
OfflineRecognizerConfig GetConfig() const;

private:
std::unique_ptr<OfflineRecognizerImpl> impl_;
};
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/offline-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ struct OfflineRecognitionResult {
// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;

/// timestamps.size() == tokens.size()
std::string lang;

/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-whisper-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_

#include <vector>
#include <string>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"

namespace sherpa_onnx {

/// Result produced by a whisper decoder for one utterance.
struct OfflineWhisperDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
/// Language of the utterance. The greedy search decoder looks it up in the
/// model's ID-to-language table and leaves it empty when there is no match.
std::string lang;
};

class OfflineWhisperDecoder {
Expand All @@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;

/** Update the whisper model config held by the decoder.
 *
 * @param config The new whisper model config.
 */
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;

};

} // namespace sherpa_onnx
Expand Down
13 changes: 12 additions & 1 deletion sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

namespace sherpa_onnx {

// Replace the whisper model config stored in this decoder; Decode() reads
// it (e.g., the language setting) on its next call.
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
config_ = config;
}

std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
Expand All @@ -24,7 +28,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
//
// For non-multilingual models, initial_tokens contains [sot]
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();

if (model_->IsMultiLingual()) {
if (!config_.language.empty()) {
const auto &lang2id = model_->GetLang2ID();
Expand Down Expand Up @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,

std::vector<OfflineWhisperDecoderResult> ans(1);

const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(initial_tokens[1])) {
ans[0].lang = id2lang.at(initial_tokens[1]);
} else {
ans[0].lang = "";
}

ans[0].tokens = std::move(predicted_tokens);

return ans;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"

namespace sherpa_onnx {
Expand All @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;

// Replace the whisper model config stored in this decoder.
void SetConfig(const OfflineWhisperModelConfig &config) override;

private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned
Expand Down