Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: find best embedding matches #1102

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
delete[] name;
}

// Converts the C++ best-match list into a heap-allocated C result.
//
// Calls p->impl->GetBestMatches(v, threshold, n) and, when that list is not
// empty, copies every entry into a newly allocated array of
// SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch. Each name is duplicated into
// its own NUL-terminated char buffer, so the returned data does not alias the
// temporary C++ vector.
//
// @param p Manager to query; dereferenced without a null check.
// @param v Embedding to look up; forwarded unchanged to the implementation.
// @param threshold Minimum score; forwarded unchanged.
// @param n Maximum number of matches requested; forwarded unchanged.
// @return nullptr when there is no match. Otherwise a result that the caller
//         must release with SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches().
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
    const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
    int32_t n) {
  const auto matches = p->impl->GetBestMatches(v, threshold, n);

  if (matches.empty()) {
    return nullptr;
  }

  // The C struct stores the count as int32_t, so convert once, explicitly,
  // and use the same signed type for the loop index (no signed/unsigned mix).
  const int32_t count = static_cast<int32_t>(matches.size());

  auto *speakers = new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[count];
  for (int32_t i = 0; i != count; ++i) {
    const auto &src = matches[i].name;

    // +1 for the terminating NUL expected by C callers.
    char *name = new char[src.size() + 1];
    std::copy(src.begin(), src.end(), name);
    name[src.size()] = '\0';

    speakers[i].score = matches[i].score;
    speakers[i].name = name;
  }

  auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult();
  result->count = count;
  result->matches = speakers;

  return result;
}

// Releases a result created by
// SherpaOnnxSpeakerEmbeddingManagerGetBestMatches().
//
// Frees every per-match name buffer, then the matches array, then the result
// object itself.
//
// @param r Result to free. May be nullptr, which is what the getter returns
//          when no speaker matched; in that case this function does nothing.
void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
    const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) {
  if (r == nullptr) {
    return;
  }

  for (int32_t i = 0; i < r->count; ++i) {
    delete[] r->matches[i].name;
  }
  delete[] r->matches;
  delete r;
}

int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, float threshold) {
Expand Down
33 changes: 33 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
// Free a name string; the implementation releases it with delete[].
// Presumably intended for the pointer returned by
// SherpaOnnxSpeakerEmbeddingManagerSearch() -- confirm against that function.
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
const char *name);

// One entry of a best-matches result: a speaker name and its score.
SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch {
// Score of this speaker for the queried embedding.
float score;
// NUL-terminated speaker name. It is allocated by
// SherpaOnnxSpeakerEmbeddingManagerGetBestMatches() and released by
// SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(); do not free it directly.
const char *name;
} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch;

// Result returned by SherpaOnnxSpeakerEmbeddingManagerGetBestMatches().
SHERPA_ONNX_API typedef struct
SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult {
// Array holding `count` entries.
const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches;
// Number of entries in `matches`.
int32_t count;
} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult;

// Get the speakers whose stored embeddings best match the given embedding.
//
// At most n speakers are returned, ordered from the highest score to the
// lowest. Only speakers whose score is greater than or equal to the threshold
// are considered.
//
// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance.
// @param v Pointer to an array containing the embedding vector.
// @param threshold Minimum similarity score required for a match (between 0
//                  and 1).
// @param n Maximum number of best matches to retrieve.
// @return Returns a pointer to
//         SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult
//         containing the best matches found. Returns NULL if no matches are
//         found. The caller is responsible for freeing a returned result
//         using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to
//         avoid memory leaks.
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
int32_t n);

// Free a result returned by SherpaOnnxSpeakerEmbeddingManagerGetBestMatches().
//
// This releases every name string, the matches array, and the result itself,
// so none of them may be used afterwards.
//
// @param r A non-NULL pointer previously returned by
//          SherpaOnnxSpeakerEmbeddingManagerGetBestMatches().
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r);

// Check whether the input embedding matches the embedding of the input
// speaker.
//
Expand Down
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl {
return row2name_.at(max_index);
}

std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) {
std::vector<SpeakerMatch> matches;

if (embedding_matrix_.rows() == 0) {
return matches;
}

Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();

Eigen::VectorXf scores = embedding_matrix_ * v;

std::vector<std::pair<float, int>> score_indices;
for (int i = 0; i < scores.size(); ++i) {
if (scores[i] >= threshold) {
score_indices.emplace_back(scores[i], i);
}
}

std::sort(score_indices.rbegin(), score_indices.rend(),
[](const auto &a, const auto &b) { return a.first < b.first; });

matches.reserve(score_indices.size());
for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size()));
thewh1teagle marked this conversation as resolved.
Show resolved Hide resolved
++i) {
const auto &pair = score_indices[i];
matches.push_back({row2name_.at(pair.second), pair.first});
}

return matches;
}

bool Verify(const std::string &name, const float *p, float threshold) {
if (!name2row_.count(name)) {
return false;
Expand Down Expand Up @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p,
return impl_->Search(p, threshold);
}

// Public entry point for the best-matches lookup; all work is delegated to
// impl_, with the arguments passed through unchanged.
std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches(
    const float *p, float threshold, int32_t n) const {
  std::vector<SpeakerMatch> best = impl_->GetBestMatches(p, threshold, n);
  return best;
}

bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
float threshold) const {
return impl_->Verify(name, p, threshold);
Expand Down
24 changes: 24 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <string>
#include <vector>

// One speaker-identification hit: the speaker's name and its score.
//
// Members are intentionally non-const so the type stays assignable and
// movable; a const member would delete assignment and make containers of
// SpeakerMatch impossible to sort or erase from. Aggregate initialization
// in {name, score} order is unchanged.
//
// NOTE(review): this struct is declared in the global namespace, outside
// sherpa_onnx -- confirm whether that is intended.
struct SpeakerMatch {
  std::string name;  // Speaker name.
  float score;       // Score of this speaker for the queried embedding.
};

namespace sherpa_onnx {

class SpeakerEmbeddingManager {
Expand Down Expand Up @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager {
*/
std::string Search(const float *p, float threshold) const;

/**
* It is for speaker identification.
*
* It scores the given embedding against all stored embeddings (cosine
* similarity) and keeps those whose score is greater than or equal to the
* threshold. At most n of them are returned, ordered from the highest score
* to the lowest.
*
* @param p A pointer to the input embedding.
* @param threshold A value between 0 and 1.
* @param n The maximum number of top matches to return.
* @return A vector of SpeakerMatch structures holding the names and scores
*         of the matching speakers. If no speaker matches, it returns an
*         empty vector.
*/
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) const;

/* Check whether the input embedding matches the embedding of the input
* speaker.
*
Expand Down