diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc
index 4a11cae29..099401d2a 100644
--- a/sherpa-onnx/c-api/c-api.cc
+++ b/sherpa-onnx/c-api/c-api.cc
@@ -24,6 +24,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-punctuation.h"
 #include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/online-punctuation.h"
 #include "sherpa-onnx/csrc/online-recognizer.h"
 #include "sherpa-onnx/csrc/resample.h"
 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -1717,6 +1718,72 @@ const char *SherpaOfflinePunctuationAddPunct(
 
 void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }
 
+// Opaque handle exposed through the C API; owns the C++ implementation.
+struct SherpaOnnxOnlinePunctuation {
+  std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
+};
+
+// Creates an online punctuation processor from a C config.
+// Returns nullptr (after logging) if config is null or construction throws.
+// NULL string fields are skipped so the C++ config keeps its own defaults.
+const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
+    const SherpaOnnxOnlinePunctuationConfig *config) {
+  if (!config) {
+    SHERPA_ONNX_LOGE("Failed to create online punctuation: config is null");
+    return nullptr;
+  }
+
+  try {
+    sherpa_onnx::OnlinePunctuationConfig punctuation_config;
+    auto &model = punctuation_config.model;
+
+    // A zero-initialized C struct leaves these pointers NULL, and copying
+    // from a NULL C string is not safe, so only copy the ones that are set.
+    if (config->model.cnn_bilstm) model.cnn_bilstm = config->model.cnn_bilstm;
+    if (config->model.bpe_vocab) model.bpe_vocab = config->model.bpe_vocab;
+    if (config->model.provider) model.provider = config->model.provider;
+    model.num_threads = config->model.num_threads;
+    model.debug = config->model.debug;
+
+    // unique_ptr releases the handle if constructing impl throws.
+    auto p = std::make_unique<SherpaOnnxOnlinePunctuation>();
+    p->impl =
+        std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
+    return p.release();
+  } catch (const std::exception &e) {
+    SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
+    return nullptr;
+  }
+}
+
+// Destroys a handle returned by SherpaOnnxCreateOnlinePunctuation().
+// Passing nullptr is a no-op.
+void SherpaOnnxDestroyOnlinePunctuation(const SherpaOnnxOnlinePunctuation *p) {
+  delete p;
+}
+
+// Punctuates `text` and returns a newly allocated, NUL-terminated copy that
+// must be released with SherpaOnnxOnlinePunctuationFreeText().
+// Returns nullptr if either argument is null or the model throws.
+const char *SherpaOnnxOnlinePunctuationAddPunct(
+    const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
+  if (!punctuation || !text) return nullptr;
+
+  try {
+    std::string s = punctuation->impl->AddPunctuationWithCase(text);
+    char *p = new char[s.size() + 1];
+    std::copy(s.begin(), s.end(), p);
+    p[s.size()] = '\0';
+    return p;
+  } catch (const std::exception &e) {
+    SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
+    return nullptr;
+  }
+}
+
+// Frees a string returned by SherpaOnnxOnlinePunctuationAddPunct().
+void SherpaOnnxOnlinePunctuationFreeText(const char *text) { delete[] text; }
+
 struct SherpaOnnxLinearResampler {
   std::unique_ptr<sherpa_onnx::LinearResample> impl;
 };
diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h
index 4d4a2c4fc..990c94cb3 100644
--- a/sherpa-onnx/c-api/c-api.h
+++ b/sherpa-onnx/c-api/c-api.h
@@ -1369,6 +1369,40 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(
 
 SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);
 
+SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationModelConfig {
+  const char *cnn_bilstm;  // path to the model file, e.g. model.onnx
+  const char *bpe_vocab;   // path to the BPE vocabulary, e.g. bpe.vocab
+  int32_t num_threads;
+  int32_t debug;
+  const char *provider;  // e.g. "cpu"
+} SherpaOnnxOnlinePunctuationModelConfig;
+
+SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
+  SherpaOnnxOnlinePunctuationModelConfig model;
+} SherpaOnnxOnlinePunctuationConfig;
+
+SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
+
+// Create an online punctuation processor. Returns NULL on failure.
+// The user has to invoke SherpaOnnxDestroyOnlinePunctuation() to free the
+// returned pointer to avoid a memory leak.
+SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
+    const SherpaOnnxOnlinePunctuationConfig *config);
+
+// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
+SHERPA_ONNX_API void SherpaOnnxDestroyOnlinePunctuation(
+    const SherpaOnnxOnlinePunctuation *punctuation);
+
+// Add punctuation to the input text. Returns NULL if either argument is NULL
+// or if processing fails. The user has to invoke
+// SherpaOnnxOnlinePunctuationFreeText() to free the returned pointer
+// to avoid a memory leak.
+SHERPA_ONNX_API const char *SherpaOnnxOnlinePunctuationAddPunct(
+    const SherpaOnnxOnlinePunctuation *punctuation, const char *text);
+
+// Free a pointer returned by SherpaOnnxOnlinePunctuationAddPunct()
+SHERPA_ONNX_API void SherpaOnnxOnlinePunctuationFreeText(const char *text);
+
 // for resampling
 SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler
     SherpaOnnxLinearResampler;
diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift
index 661ebba28..b100ef408 100644
--- a/swift-api-examples/SherpaOnnx.swift
+++ b/swift-api-examples/SherpaOnnx.swift
@@ -1095,6 +1095,60 @@ class SherpaOnnxOfflinePunctuationWrapper {
   }
 }
 
+/// Builds the C model config for online punctuation from Swift values.
+func sherpaOnnxOnlinePunctuationModelConfig(
+  cnnBiLstm: String,
+  bpeVocab: String,
+  numThreads: Int = 1,
+  debug: Int = 0,
+  provider: String = "cpu"
+) -> SherpaOnnxOnlinePunctuationModelConfig {
+  return SherpaOnnxOnlinePunctuationModelConfig(
+    cnn_bilstm: toCPointer(cnnBiLstm),
+    bpe_vocab: toCPointer(bpeVocab),
+    num_threads: Int32(numThreads),
+    debug: Int32(debug),
+    provider: toCPointer(provider))
+}
+
+/// Wraps a model config in the top-level online punctuation config.
+func sherpaOnnxOnlinePunctuationConfig(
+  model: SherpaOnnxOnlinePunctuationModelConfig
+) -> SherpaOnnxOnlinePunctuationConfig {
+  return SherpaOnnxOnlinePunctuationConfig(model: model)
+}
+
+/// Swift owner of a C online punctuation handle; the handle is destroyed in `deinit`.
+class SherpaOnnxOnlinePunctuationWrapper {
+  /// A pointer to the underlying counterpart in C
+  let ptr: OpaquePointer!
+
+  /// Constructor taking a model config
+  init(
+    config: UnsafePointer<SherpaOnnxOnlinePunctuationConfig>!
+  ) {
+    ptr = SherpaOnnxCreateOnlinePunctuation(config)
+  }
+
+  deinit {
+    if let ptr {
+      SherpaOnnxDestroyOnlinePunctuation(ptr)
+    }
+  }
+
+  /// Returns `text` with punctuation added.
+  ///
+  /// The C call returns NULL on failure; the input is then returned unchanged
+  /// rather than crashing on a forced unwrap.
+  func addPunct(text: String) -> String {
+    guard let cText = SherpaOnnxOnlinePunctuationAddPunct(ptr, toCPointer(text)) else {
+      return text
+    }
+    defer { SherpaOnnxOnlinePunctuationFreeText(cText) }
+    return String(cString: cText)
+  }
+}
+
 func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String)
   -> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
 {
diff --git a/swift-api-examples/add-punctuation-online.swift b/swift-api-examples/add-punctuation-online.swift
new file mode 100644
index 000000000..79af921eb
--- /dev/null
+++ b/swift-api-examples/add-punctuation-online.swift
@@ -0,0 +1,35 @@
+func run() {
+  let model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
+  let bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
+
+  // Create model config
+  let modelConfig = sherpaOnnxOnlinePunctuationModelConfig(
+    cnnBiLstm: model,
+    bpeVocab: bpe
+  )
+
+  // Create punctuation config
+  var config = sherpaOnnxOnlinePunctuationConfig(model: modelConfig)
+
+  // Create punctuation instance
+  let punct = SherpaOnnxOnlinePunctuationWrapper(config: &config)
+
+  // Test texts
+  let textList = [
+    "how are you i am fine thank you",
+    "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
+  ]
+
+  // Process each text
+  for i in 0..