From 7d9c775f211aaea074f6cb9155f6632313c90446 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Wed, 10 Jul 2024 10:04:25 +0300 Subject: [PATCH 1/5] feat: find best embedding matches --- sherpa-onnx/c-api/c-api.cc | 37 ++++++++++++++++++ sherpa-onnx/c-api/c-api.h | 33 ++++++++++++++++ sherpa-onnx/csrc/speaker-embedding-manager.cc | 38 +++++++++++++++++++ sherpa-onnx/csrc/speaker-embedding-manager.h | 24 ++++++++++++ 4 files changed, 132 insertions(+) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index e23305fb7..017f3d22b 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1256,6 +1256,43 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) { delete[] name; } +const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, + int32_t n) { + auto matches = p->impl->GetBestMatches(v, threshold, n); + + if (matches.empty()) { + return nullptr; + } + + auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult(); + result->count = matches.size(); + result->matches = + new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()]; + + for (int i = 0; i < matches.size(); ++i) { + result->matches[i].score = matches[i].score; + + char *name = new char[matches[i].name.size() + 1]; + std::copy(matches[i].name.begin(), matches[i].name.end(), name); + name[matches[i].name.size()] = '\0'; + + result->matches[i].name = name; + } + + return result; +} + +void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) { + for (int32_t i = 0; i < r->count; ++i) { + delete[] r->matches[i].name; + } + delete[] r->matches; + delete r; +}; + int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, const float *v, float threshold) { diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 2bfba98c7..8aa65e419 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch( SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( const char *name); +SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch { + float score; + char *name; +} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch; + +SHERPA_ONNX_API typedef struct + SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult { + SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches; + int32_t count; +} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult; + +// Get the best matching speakers whose embeddings match the given +// embedding. +// +// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance. +// @param v Pointer to an array containing the embedding vector. +// @param threshold Minimum similarity score required for a match (between 0 and +// 1). +// @param n Number of best matches to retrieve. +// @return Returns a pointer to +// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult +// containing the best matches found. Returns NULL if no matches are +// found. The caller is responsible for freeing the returned pointer +// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to +// avoid memory leaks. +SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, + int32_t n); + +SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r); + // Check whether the input embedding matches the embedding of the input // speaker. // diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index 6c90c1953..54fee2323 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -131,6 +131,39 @@ class SpeakerEmbeddingManager::Impl { return row2name_.at(max_index); } + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) { + std::vector matches; + + if (embedding_matrix_.rows() == 0) { + return matches; + } + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + Eigen::VectorXf scores = embedding_matrix_ * v; + + std::vector> score_indices; + for (int i = 0; i < scores.size(); ++i) { + if (scores[i] >= threshold) { + score_indices.emplace_back(scores[i], i); + } + } + + std::sort(score_indices.rbegin(), score_indices.rend(), + [](const auto &a, const auto &b) { return a.first < b.first; }); + + for (int i = 0; i < std::min(n, static_cast(score_indices.size())); + ++i) { + const auto &pair = score_indices[i]; + matches.push_back({row2name_.at(pair.second), pair.first}); + } + + return matches; + } + bool Verify(const std::string &name, const float *p, float threshold) { if (!name2row_.count(name)) { return false; @@ -219,6 +252,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p, return impl_->Search(p, threshold); } +std::vector SpeakerEmbeddingManager::GetBestMatches( + const float *p, float threshold, int32_t n) const { + return impl_->GetBestMatches(p, threshold, n); +} + bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, float threshold) const { return impl_->Verify(name, p, threshold); diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.h b/sherpa-onnx/csrc/speaker-embedding-manager.h index ae8728b13..9490765ca 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.h +++ b/sherpa-onnx/csrc/speaker-embedding-manager.h @@ -9,6 +9,11 @@ #include #include +struct SpeakerMatch { + const std::string name; + float score; +}; + namespace sherpa_onnx { class SpeakerEmbeddingManager { @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager { */ std::string Search(const float *p, float threshold) const; + /** + * It is for speaker identification. + * + * It computes the cosine similarity between a given embedding and all + * other embeddings and finds the embeddings that have the largest scores + * and the scores are above or equal to the threshold. Returns a vector of + * SpeakerMatch structures containing the speaker names and scores for the + * embeddings if found; otherwise, returns an empty vector. + * + * @param p A pointer to the input embedding. + * @param threshold A value between 0 and 1. + * @param n The number of top matches to return. + * @return A vector of SpeakerMatch structures. If matches are found, the + * vector contains the names and scores of the speakers. Otherwise, + * it returns an empty vector. + */ + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) const; + /* Check whether the input embedding matches the embedding of the input * speaker. * From 4c550730ebe68f111bf25c2225500bcf3c957190 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Wed, 10 Jul 2024 10:13:59 +0300 Subject: [PATCH 2/5] chore: add const to speaker match name Co-authored-by: Fangjun Kuang --- sherpa-onnx/c-api/c-api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 8aa65e419..fa8cd8587 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1111,7 +1111,7 @@ SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch { float score; - char *name; + const char *name; } SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch; SHERPA_ONNX_API typedef struct From 87573242a8231e9cacc2850bfd847c494650beb6 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Wed, 10 Jul 2024 19:41:00 +0300 Subject: [PATCH 3/5] Update sherpa-onnx/csrc/speaker-embedding-manager.cc Co-authored-by: Fangjun Kuang --- sherpa-onnx/csrc/speaker-embedding-manager.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index 54fee2323..701fa6e18 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -155,6 +155,7 @@ class SpeakerEmbeddingManager::Impl { std::sort(score_indices.rbegin(), score_indices.rend(), [](const auto &a, const auto &b) { return a.first < b.first; }); + matches.reserve(score_indices.size()); for (int i = 0; i < std::min(n, static_cast(score_indices.size())); ++i) { const auto &pair = score_indices[i]; From d1a50a19e81c10c9b6b55a032027317adeb85c78 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Wed, 10 Jul 2024 19:41:07 +0300 Subject: [PATCH 4/5] Update sherpa-onnx/c-api/c-api.h Co-authored-by: Fangjun Kuang --- sherpa-onnx/c-api/c-api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index fa8cd8587..4beba2a73 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1116,7 +1116,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch { SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult { - SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches; + const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches; int32_t count; } SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult; From 7d2341d1fe1ca45a2890a554fe2dc5360d6f15da Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:27:21 +0300 Subject: [PATCH 5/5] fix: return constant matches --- sherpa-onnx/c-api/c-api.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 017f3d22b..eb9ec8752 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1266,21 +1266,22 @@ SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( return nullptr; } - auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult(); - result->count = matches.size(); - result->matches = + auto resultMatches = new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()]; - for (int i = 0; i < matches.size(); ++i) { - result->matches[i].score = matches[i].score; + resultMatches[i].score = matches[i].score; char *name = new char[matches[i].name.size() + 1]; std::copy(matches[i].name.begin(), matches[i].name.end(), name); name[matches[i].name.size()] = '\0'; - result->matches[i].name = name; + resultMatches[i].name = name; } + auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult(); + result->count = matches.size(); + result->matches = resultMatches; + return result; }