Skip to content

Commit

Permalink
FBL: update onnx model
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfgitpr committed Jun 21, 2024
1 parent e666aa5 commit e24a9d6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 52 deletions.
38 changes: 3 additions & 35 deletions src/apps/FoxBreatheLabeler/util/Fbl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ namespace FBL {

// 获取参数
m_audio_sample_rate = config["audio_sample_rate"].as<int>();
m_spec_win = config["spec_win"].as<int>();
m_hop_size = config["hop_size"].as<int>();

m_time_scale = 1.0 / (m_audio_sample_rate / m_hop_size);
m_time_scale = 1 / (static_cast<float>(m_audio_sample_rate) / static_cast<float>(m_hop_size));

if (!m_fblModel) {
qDebug() << "Cannot load ASR Model, there must be files model.onnx and vocab.txt";
Expand All @@ -51,36 +50,6 @@ namespace FBL {
}
}

// Function to unfold the signal into frames
std::vector<std::vector<float>> unfold_signal(const std::vector<float> &y, int frame_length, int hop_length) {
const int num_frames = (static_cast<int>(y.size()) - frame_length) / hop_length + 1;
std::vector<std::vector<float>> unfolded;
for (int i = 0; i < num_frames; ++i) {
const int start = i * hop_length;
const int end = start + frame_length;
unfolded.emplace_back(y.begin() + start, y.begin() + end);
}
return unfolded;
}

// Function to get music chunks
std::vector<std::vector<float>> get_music_chunk(const std::vector<float> &y, int frame_length = 2048,
int hop_length = 512, const std::string &pad_mode = "constant") {
const std::pair<int, int> padding = {(frame_length - hop_length) / 2, (frame_length - hop_length + 1) / 2};
const std::vector<float> y_padded = pad_signal(y, padding, pad_mode);
std::vector<std::vector<float>> y_f = unfold_signal(y_padded, frame_length, hop_length);

return y_f;
}

std::vector<float> sigmoid(const std::vector<float> &arr) {
std::vector<float> result;
result.reserve(arr.size());
for (const auto &item : arr)
result.push_back(1 / (1 + std::exp(-item)));
return result;
}

static std::vector<std::pair<float, float>> findSegmentsDynamic(const std::vector<float> &arr, double time_scale,
double threshold = 0.5, int max_gap = 5,
int ap_threshold = 10) {
Expand Down Expand Up @@ -140,9 +109,8 @@ namespace FBL {

std::string modelMsg;
std::vector<float> modelRes;
if (m_fblModel->forward(get_music_chunk(tmp, m_spec_win, m_hop_size), modelRes, modelMsg)) {
res = findSegmentsDynamic(sigmoid(modelRes), m_time_scale, ap_threshold,
static_cast<int>(ap_dur / m_time_scale));
if (m_fblModel->forward(std::vector<std::vector<float>>{tmp}, modelRes, modelMsg)) {
res = findSegmentsDynamic(modelRes, m_time_scale, ap_threshold, static_cast<int>(ap_dur / m_time_scale));
return true;
} else {
msg = QString::fromStdString(modelMsg);
Expand Down
8 changes: 4 additions & 4 deletions src/apps/FoxBreatheLabeler/util/Fbl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ namespace FBL {
explicit FBL(const QString &modelDir);
~FBL();

[[nodiscard]] bool recognize(const QString &filename, std::vector<std::pair<float, float>> &res, QString &msg, float ap_threshold = 0.4,
float ap_dur = 0.1) const;
[[nodiscard]] bool recognize(SF_VIO sf_vio, std::vector<std::pair<float, float>> &res, QString &msg, float ap_threshold = 0.4, float ap_dur = 0.08) const;
[[nodiscard]] bool recognize(const QString &filename, std::vector<std::pair<float, float>> &res, QString &msg,
float ap_threshold = 0.4, float ap_dur = 0.1) const;
[[nodiscard]] bool recognize(SF_VIO sf_vio, std::vector<std::pair<float, float>> &res, QString &msg,
float ap_threshold = 0.4, float ap_dur = 0.08) const;

private:
[[nodiscard]] SF_VIO resample(const QString &filename) const;

std::unique_ptr<FblModel> m_fblModel;

int m_audio_sample_rate;
int m_spec_win;
int m_hop_size;

float m_time_scale;
Expand Down
29 changes: 17 additions & 12 deletions src/apps/FoxBreatheLabeler/util/FblModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace FBL {
: m_env(Ort::Env(ORT_LOGGING_LEVEL_WARNING, "FblModel")), m_session_options(Ort::SessionOptions()),
m_session(nullptr) {

m_input_name = "mel";
m_input_name = "waveform";
m_output_name = "ap_probability";

#ifdef _WIN32
Expand All @@ -26,22 +26,27 @@ namespace FBL {

bool FblModel::forward(const std::vector<std::vector<float>> &input_data, std::vector<float> &result,
std::string &msg) const {
// 假设输入数据是二维的,形状为 (num_channels, height)
const size_t num_channels = input_data.size();
if (num_channels == 0) {
throw std::invalid_argument("Input data cannot be empty.");
const size_t batch_size = input_data.size();
if (batch_size == 0) {
throw std::invalid_argument("输入数据不能为空。");
}

const size_t height = input_data[0].size();

// 将输入数据展平成一维数组
std::vector<float> flattened_input;
// 确定输入数据中最大的长度
size_t max_height = 0;
for (const auto &channel_data : input_data) {
flattened_input.insert(flattened_input.end(), channel_data.begin(), channel_data.end());
max_height = std::max(max_height, channel_data.size());
}

// 创建一个用于存放扁平化后的输入数据的向量
std::vector<float> flattened_input(batch_size * max_height, 0.0f);

// 将输入数据扁平化并填充到flattened_input中
for (size_t i = 0; i < batch_size; ++i) {
std::copy(input_data[i].begin(), input_data[i].end(), flattened_input.begin() + i * max_height);
}

// 定义输入张量的形状,batch size 固定为 1
const std::array<int64_t, 3> input_shape_{1, static_cast<int64_t>(num_channels), static_cast<int64_t>(height)};
// 定义输入张量的形状
const std::array<int64_t, 2> input_shape_{static_cast<int64_t>(batch_size), static_cast<int64_t>(max_height)};

// 创建输入张量
const Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
Expand Down
2 changes: 1 addition & 1 deletion src/apps/FoxBreatheLabeler/util/FblThread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ namespace FBL {
}

if (i == 0) {
if (sp.start + sp_dur <= ap.first && ap.first < sp.end) {
if (cursor < ap.first && sp.start + sp_dur <= ap.first && ap.first < sp.end) {
out.append(Word{cursor, ap.first, "SP"});
cursor = ap.first;
}
Expand Down

0 comments on commit e24a9d6

Please sign in to comment.