From 1bae4085caa61012af123554858f14f8649091ea Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Dec 2024 16:03:03 +0800 Subject: [PATCH] Add speaker diarization API for HarmonyOS. (#1609) --- .../SherpaOnnxHar/sherpa_onnx/Index.ets | 29 ++++---- .../cpp/non-streaming-speaker-diarization.cc | 30 ++++++++ .../main/cpp/types/libsherpa_onnx/Index.d.ts | 5 ++ .../sherpa_onnx/src/main/cpp/wave-writer.cc | 7 +- .../NonStreamingSpeakerDiarization.ets | 73 +++++++++++++++++++ .../ets/components/SpeakerIdentification.ets | 10 +-- .../src/main/ets/components/StreamingAsr.ets | 3 +- .../src/main/ets/components/Vad.ets | 6 +- .../lib/non-streaming-speaker-diarization.js | 2 +- sherpa-onnx/c-api/c-api.cc | 51 +++++++++++-- sherpa-onnx/c-api/c-api.h | 5 ++ .../csrc/offline-speaker-diarization-impl.cc | 24 +++++- .../csrc/offline-speaker-diarization-impl.h | 10 +-- ...ffline-speaker-diarization-pyannote-impl.h | 22 ++++-- .../csrc/offline-speaker-diarization.cc | 24 +++++- .../csrc/offline-speaker-diarization.h | 10 +-- ...ine-speaker-segmentation-pyannote-model.cc | 37 ++++++++-- ...line-speaker-segmentation-pyannote-model.h | 10 +-- 18 files changed, 279 insertions(+), 79 deletions(-) create mode 100644 harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/NonStreamingSpeakerDiarization.ets diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets b/harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets index 5132df5f1..16c6279e1 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets @@ -1,11 +1,6 @@ -export { - listRawfileDir, - readWave, - readWaveFromBinary, -} from "libsherpa_onnx.so"; +export { listRawfileDir, readWave, readWaveFromBinary, } from "libsherpa_onnx.so"; -export { - CircularBuffer, +export { CircularBuffer, SileroVadConfig, SpeechSegment, Vad, @@ -13,8 +8,7 @@ export { } from './src/main/ets/components/Vad'; -export { - Samples, +export { Samples, OfflineStream, FeatureConfig, OfflineTransducerModelConfig, @@ -31,8 +25,7 @@ export { OfflineRecognizer, } from './src/main/ets/components/NonStreamingAsr'; -export { - OnlineStream, +export { OnlineStream, OnlineTransducerModelConfig, OnlineParaformerModelConfig, OnlineZipformer2CtcModelConfig, @@ -43,8 +36,7 @@ export { OnlineRecognizer, } from './src/main/ets/components/StreamingAsr'; -export { - OfflineTtsVitsModelConfig, +export { OfflineTtsVitsModelConfig, OfflineTtsModelConfig, OfflineTtsConfig, OfflineTts, @@ -52,8 +44,15 @@ export { TtsInput, } from './src/main/ets/components/NonStreamingTts'; -export { - SpeakerEmbeddingExtractorConfig, +export { SpeakerEmbeddingExtractorConfig, SpeakerEmbeddingExtractor, SpeakerEmbeddingManager, } from './src/main/ets/components/SpeakerIdentification'; + +export { OfflineSpeakerSegmentationPyannoteModelConfig, + OfflineSpeakerSegmentationModelConfig, + OfflineSpeakerDiarizationConfig, + OfflineSpeakerDiarizationSegment, + OfflineSpeakerDiarization, + FastClusteringConfig, +} from './src/main/ets/components/NonStreamingSpeakerDiarization'; diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/non-streaming-speaker-diarization.cc b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/non-streaming-speaker-diarization.cc index a35f7924a..2cda40e76 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/non-streaming-speaker-diarization.cc +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/non-streaming-speaker-diarization.cc @@ -101,6 +101,17 @@ static SherpaOnnxFastClusteringConfig GetFastClusteringConfig( static Napi::External CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); + +#if __OHOS__ + if (info.Length() != 2) { + std::ostringstream os; + os << "Expect only 2 arguments. Given: " << info.Length(); + + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); + + return {}; + } +#else if (info.Length() != 1) { std::ostringstream os; os << "Expect only 1 argument. Given: " << info.Length(); @@ -109,6 +120,7 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) { return {}; } +#endif if (!info[0].IsObject()) { Napi::TypeError::New(env, "Expect an object as the argument") @@ -129,8 +141,18 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) { SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_on, minDurationOn); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_off, minDurationOff); +#if __OHOS__ + std::unique_ptr + mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]), + &OH_ResourceManager_ReleaseNativeResourceManager); + + const SherpaOnnxOfflineSpeakerDiarization *sd = + SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(&c, mgr.get()); +#else const SherpaOnnxOfflineSpeakerDiarization *sd = SherpaOnnxCreateOfflineSpeakerDiarization(&c); +#endif if (c.segmentation.pyannote.model) { delete[] c.segmentation.pyannote.model; @@ -224,9 +246,17 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper( Napi::Float32Array samples = info[1].As(); +#if __OHOS__ + // Note(fangjun): For unknown reasons on HarmonyOS, we need to divide it by + // sizeof(float) here + const SherpaOnnxOfflineSpeakerDiarizationResult *r = + SherpaOnnxOfflineSpeakerDiarizationProcess( + sd, samples.Data(), samples.ElementLength() / sizeof(float)); +#else const SherpaOnnxOfflineSpeakerDiarizationResult *r = SherpaOnnxOfflineSpeakerDiarizationProcess(sd, samples.Data(), samples.ElementLength()); +#endif int32_t num_segments = SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r); diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/types/libsherpa_onnx/Index.d.ts b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/types/libsherpa_onnx/Index.d.ts index d2b6d6ea4..f71e2f6ee 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/types/libsherpa_onnx/Index.d.ts +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/types/libsherpa_onnx/Index.d.ts @@ -62,3 +62,8 @@ export const speakerEmbeddingManagerVerify: (handle: object, obj: {name: string, export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean; export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number; export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array; + +export const createOfflineSpeakerDiarization: (config: object, mgr?: object) => object; +export const getOfflineSpeakerDiarizationSampleRate: (handle: object) => number; +export const offlineSpeakerDiarizationProcess: (handle: object, samples: Float32Array) => object; +export const offlineSpeakerDiarizationSetConfig: (handle: object, config: object) => void; diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/wave-writer.cc b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/wave-writer.cc index 3ade695a0..8f6d7bcab 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/wave-writer.cc +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/wave-writer.cc @@ -67,10 +67,15 @@ static Napi::Boolean WriteWaveWrapper(const Napi::CallbackInfo &info) { Napi::Float32Array samples = obj.Get("samples").As(); int32_t sample_rate = obj.Get("sampleRate").As().Int32Value(); - +#if __OHOS__ + int32_t ok = SherpaOnnxWriteWave( + samples.Data(), samples.ElementLength() / sizeof(float), sample_rate, + info[0].As().Utf8Value().c_str()); +#else int32_t ok = SherpaOnnxWriteWave(samples.Data(), samples.ElementLength(), sample_rate, info[0].As().Utf8Value().c_str()); +#endif return Napi::Boolean::New(env, ok); } diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/NonStreamingSpeakerDiarization.ets b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/NonStreamingSpeakerDiarization.ets new file mode 100644 index 000000000..c0dcd3af1 --- /dev/null +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/NonStreamingSpeakerDiarization.ets @@ -0,0 +1,73 @@ +import { + createOfflineSpeakerDiarization, + getOfflineSpeakerDiarizationSampleRate, + offlineSpeakerDiarizationProcess, + offlineSpeakerDiarizationSetConfig, +} from 'libsherpa_onnx.so'; + +import { SpeakerEmbeddingExtractorConfig } from './SpeakerIdentification'; + +export class OfflineSpeakerSegmentationPyannoteModelConfig { + public model: string = ''; +} + +export class OfflineSpeakerSegmentationModelConfig { + public pyannote: OfflineSpeakerSegmentationPyannoteModelConfig = new OfflineSpeakerSegmentationPyannoteModelConfig(); + public numThreads: number = 1; + public debug: boolean = false; + public provider: string = 'cpu'; +} + +export class FastClusteringConfig { + public numClusters: number = -1; + public threshold: number = 0.5; +} + +export class OfflineSpeakerDiarizationConfig { + public segmentation: OfflineSpeakerSegmentationModelConfig = new OfflineSpeakerSegmentationModelConfig(); + public embedding: SpeakerEmbeddingExtractorConfig = new SpeakerEmbeddingExtractorConfig(); + public clustering: FastClusteringConfig = new FastClusteringConfig(); + public minDurationOn: number = 0.2; + public minDurationOff: number = 0.5; +} + +export class OfflineSpeakerDiarizationSegment { + public start: number = 0; // in secondspublic end: number = 0; // in secondspublic speaker: number = + 0; // ID of the speaker; count from 0 +} + +export class OfflineSpeakerDiarization { + public config: OfflineSpeakerDiarizationConfig; + public sampleRate: number; + private handle: object; + + constructor(config: OfflineSpeakerDiarizationConfig, mgr?: object) { + this.handle = createOfflineSpeakerDiarization(config, mgr); + this.config = config; + + this.sampleRate = getOfflineSpeakerDiarizationSampleRate(this.handle); + } + + /** + * samples is a 1-d float32 array. Each element of the array should be + * in the range [-1, 1]. + * + * We assume its sample rate equals to this.sampleRate. + * + * Returns an array of object, where an object is + * + * { + * "start": start_time_in_seconds, + * "end": end_time_in_seconds, + * "speaker": an_integer, + * } + */ + process(samples: Float32Array): OfflineSpeakerDiarizationSegment { + return offlineSpeakerDiarizationProcess(this.handle, samples) as OfflineSpeakerDiarizationSegment; + } + + setConfig(config: OfflineSpeakerDiarizationConfig) { + offlineSpeakerDiarizationSetConfig(this.handle, config); + this.config.clustering = config.clustering; + } +} diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/SpeakerIdentification.ets b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/SpeakerIdentification.ets index e490ab15f..50868dfb0 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/SpeakerIdentification.ets +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/SpeakerIdentification.ets @@ -35,8 +35,7 @@ export class SpeakerEmbeddingExtractor { } createStream(): OnlineStream { - return new OnlineStream( - speakerEmbeddingExtractorCreateStream(this.handle)); + return new OnlineStream(speakerEmbeddingExtractorCreateStream(this.handle)); } isReady(stream: OnlineStream): boolean { @@ -44,8 +43,7 @@ export class SpeakerEmbeddingExtractor { } compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array { - return speakerEmbeddingExtractorComputeEmbedding( - this.handle, stream.handle, enableExternalBuffer); + return speakerEmbeddingExtractorComputeEmbedding(this.handle, stream.handle, enableExternalBuffer); } } @@ -106,9 +104,7 @@ export class SpeakerEmbeddingManager { addMulti(speaker: SpeakerNameWithEmbeddingList): boolean { const c: SpeakerNameWithEmbeddingN = { - name: speaker.name, - vv: flatten(speaker.v), - n: speaker.v.length, + name: speaker.name, vv: flatten(speaker.v), n: speaker.v.length, }; return speakerEmbeddingManagerAddListFlattened(this.handle, c); } diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/StreamingAsr.ets b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/StreamingAsr.ets index 3b2985771..f8b3c61e4 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/StreamingAsr.ets +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/StreamingAsr.ets @@ -125,8 +125,7 @@ export class OnlineRecognizer { } getResult(stream: OnlineStream): OnlineRecognizerResult { - const jsonStr: string = - getOnlineStreamResultAsJson(this.handle, stream.handle); + const jsonStr: string = getOnlineStreamResultAsJson(this.handle, stream.handle); let o = JSON.parse(jsonStr) as OnlineRecognizerResultJson; diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/Vad.ets b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/Vad.ets index 8f1bf18d6..cae2cbf13 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/Vad.ets +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/ets/components/Vad.ets @@ -62,8 +62,7 @@ export class CircularBuffer { // return a float32 array get(startIndex: number, n: number, enableExternalBuffer: boolean = true): Float32Array { - return circularBufferGet( - this.handle, startIndex, n, enableExternalBuffer); + return circularBufferGet(this.handle, startIndex, n, enableExternalBuffer); } pop(n: number) { @@ -93,8 +92,7 @@ export class Vad { private handle: object; constructor(config: VadConfig, bufferSizeInSeconds?: number, mgr?: object) { - this.handle = - createVoiceActivityDetector(config, bufferSizeInSeconds, mgr); + this.handle = createVoiceActivityDetector(config, bufferSizeInSeconds, mgr); this.config = config; } diff --git a/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js b/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js index 8ec31ee10..37c4a7493 100644 --- a/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js +++ b/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js @@ -27,7 +27,7 @@ class OfflineSpeakerDiarization { } setConfig(config) { - addon.offlineSpeakerDiarizationSetConfig(config); + addon.offlineSpeakerDiarizationSetConfig(this.handle, config); this.config.clustering = config.clustering; } } diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 7748b9fee..84b33eb0d 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1784,8 +1784,8 @@ struct SherpaOnnxOfflineSpeakerDiarizationResult { sherpa_onnx::OfflineSpeakerDiarizationResult impl; }; -const SherpaOnnxOfflineSpeakerDiarization * -SherpaOnnxCreateOfflineSpeakerDiarization( +static sherpa_onnx::OfflineSpeakerDiarizationConfig +GetOfflineSpeakerDiarizationConfig( const SherpaOnnxOfflineSpeakerDiarizationConfig *config) { sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config; @@ -1820,6 +1820,22 @@ SherpaOnnxCreateOfflineSpeakerDiarization( sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5); + if (sd_config.segmentation.debug || sd_config.embedding.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", sd_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str()); +#endif + } + + return sd_config; +} + +const SherpaOnnxOfflineSpeakerDiarization * +SherpaOnnxCreateOfflineSpeakerDiarization( + const SherpaOnnxOfflineSpeakerDiarizationConfig *config) { + auto sd_config = GetOfflineSpeakerDiarizationConfig(config); + if (!sd_config.Validate()) { SHERPA_ONNX_LOGE("Errors in config"); return nullptr; @@ -1831,10 +1847,6 @@ SherpaOnnxCreateOfflineSpeakerDiarization( sd->impl = std::make_unique(sd_config); - if (sd_config.segmentation.debug || sd_config.embedding.debug) { - SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str()); - } - return sd; } @@ -2029,5 +2041,32 @@ SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS( } #endif // #if SHERPA_ONNX_ENABLE_TTS == 1 + // +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 +const SherpaOnnxOfflineSpeakerDiarization * +SherpaOnnxCreateOfflineSpeakerDiarizationOHOS( + const SherpaOnnxOfflineSpeakerDiarizationConfig *config, + NativeResourceManager *mgr) { + if (!mgr) { + return SherpaOnnxCreateOfflineSpeakerDiarization(config); + } + + auto sd_config = GetOfflineSpeakerDiarizationConfig(config); + + if (!sd_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaOnnxOfflineSpeakerDiarization *sd = + new SherpaOnnxOfflineSpeakerDiarization; + + sd->impl = + std::make_unique(mgr, sd_config); + + return sd; +} + +#endif // #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 #endif // #ifdef __OHOS__ diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 111aae779..a781520ff 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1577,6 +1577,11 @@ SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor * SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS( const SherpaOnnxSpeakerEmbeddingExtractorConfig *config, NativeResourceManager *mgr); + +SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarization * +SherpaOnnxCreateOfflineSpeakerDiarizationOHOS( + const SherpaOnnxOfflineSpeakerDiarizationConfig *config, + NativeResourceManager *mgr); #endif #if defined(__GNUC__) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc index 15c3a2eb4..c0af5c7c4 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc @@ -6,6 +6,15 @@ #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h" @@ -23,10 +32,10 @@ OfflineSpeakerDiarizationImpl::Create( return nullptr; } -#if __ANDROID_API__ >= 9 +template std::unique_ptr OfflineSpeakerDiarizationImpl::Create( - AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) { + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) { if (!config.segmentation.pyannote.model.empty()) { return std::make_unique(mgr, config); } @@ -35,6 +44,17 @@ OfflineSpeakerDiarizationImpl::Create( return nullptr; } + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config); #endif } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h index 41f0e1e2f..d2cbdebd2 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -8,11 +8,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/offline-speaker-diarization.h" namespace sherpa_onnx { @@ -21,10 +16,9 @@ class OfflineSpeakerDiarizationImpl { static std::unique_ptr Create( const OfflineSpeakerDiarizationConfig &config); -#if __ANDROID_API__ >= 9 + template static std::unique_ptr Create( - AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); -#endif + Manager *mgr, const OfflineSpeakerDiarizationConfig &config); virtual ~OfflineSpeakerDiarizationImpl() = default; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 51d712eb8..e8228d47b 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -11,11 +11,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "Eigen/Dense" #include "sherpa-onnx/csrc/fast-clustering.h" #include "sherpa-onnx/csrc/math.h" @@ -71,16 +66,15 @@ class OfflineSpeakerDiarizationPyannoteImpl Init(); } -#if __ANDROID_API__ >= 9 + template OfflineSpeakerDiarizationPyannoteImpl( - AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) : config_(config), segmentation_model_(mgr, config_.segmentation), embedding_extractor_(mgr, config_.embedding), clustering_(std::make_unique(config_.clustering)) { Init(); } -#endif int32_t SampleRate() const override { const auto &meta_data = segmentation_model_.GetModelMetaData(); @@ -213,8 +207,13 @@ class OfflineSpeakerDiarizationPyannoteImpl } } } else { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "powerset_max_classes = %{public}d is currently not supported!", i); +#else SHERPA_ONNX_LOGE( "powerset_max_classes = %d is currently not supported!", i); +#endif SHERPA_ONNX_EXIT(-1); } } @@ -229,10 +228,17 @@ class OfflineSpeakerDiarizationPyannoteImpl int32_t window_shift = meta_data.window_shift; if (n <= 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "number of audio samples is %{public}d (<= 0). Please provide a " + "positive number", + n); +#else SHERPA_ONNX_LOGE( "number of audio samples is %d (<= 0). Please provide a positive " "number", n); +#endif return {}; } diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index a4b021b73..1e861ab88 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -7,6 +7,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" namespace sherpa_onnx { @@ -74,11 +83,10 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization( const OfflineSpeakerDiarizationConfig &config) : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} -#if __ANDROID_API__ >= 9 +template OfflineSpeakerDiarization::OfflineSpeakerDiarization( - AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) : impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {} -#endif OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; @@ -98,4 +106,14 @@ OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( return impl_->Process(audio, n, std::move(callback), callback_arg); } +#if __ANDROID_API__ >= 9 +template OfflineSpeakerDiarization::OfflineSpeakerDiarization( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeakerDiarization::OfflineSpeakerDiarization( + NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 4a517fbb2..acbb6f524 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/fast-clustering-config.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" @@ -62,10 +57,9 @@ class OfflineSpeakerDiarization { explicit OfflineSpeakerDiarization( const OfflineSpeakerDiarizationConfig &config); -#if __ANDROID_API__ >= 9 - OfflineSpeakerDiarization(AAssetManager *mgr, + template + OfflineSpeakerDiarization(Manager *mgr, const OfflineSpeakerDiarizationConfig &config); -#endif ~OfflineSpeakerDiarization(); diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc index e3768dcf4..093e871b4 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -8,6 +8,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" @@ -24,8 +33,8 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { Init(buf.data(), buf.size()); } -#if __ANDROID_API__ >= 9 - Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + template + Impl(Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_ERROR), sess_opts_(GetSessionOptions(config)), @@ -33,7 +42,6 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { auto buf = ReadFile(mgr, config_.pyannote.model); Init(buf.data(), buf.size()); } -#endif const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() const { @@ -61,7 +69,11 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { if (config_.debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -103,12 +115,11 @@ OfflineSpeakerSegmentationPyannoteModel:: const OfflineSpeakerSegmentationModelConfig &config) : impl_(std::make_unique(config)) {} -#if __ANDROID_API__ >= 9 +template OfflineSpeakerSegmentationPyannoteModel:: OfflineSpeakerSegmentationPyannoteModel( - AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config) : impl_(std::make_unique(mgr, config)) {} -#endif OfflineSpeakerSegmentationPyannoteModel:: ~OfflineSpeakerSegmentationPyannoteModel() = default; @@ -123,4 +134,18 @@ Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward( return impl_->Forward(std::move(x)); } +#if __ANDROID_API__ >= 9 +template OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + AAssetManager *mgr, + const OfflineSpeakerSegmentationModelConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + NativeResourceManager *mgr, + const OfflineSpeakerSegmentationModelConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h index 6b835763b..a3cc7ed3f 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h @@ -6,11 +6,6 @@ #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" @@ -22,10 +17,9 @@ class OfflineSpeakerSegmentationPyannoteModel { explicit OfflineSpeakerSegmentationPyannoteModel( const OfflineSpeakerSegmentationModelConfig &config); -#if __ANDROID_API__ >= 9 + template OfflineSpeakerSegmentationPyannoteModel( - AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config); -#endif + Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config); ~OfflineSpeakerSegmentationPyannoteModel();