Skip to content

Commit

Permalink
More offline test improvements (#153)
Browse files Browse the repository at this point in the history
* Protect logging with a mutex

Main thread and worker thread output could get interleaved weirdly
without this

* Move segments.json saving to different thread

This was taking a considerable amount of time, especially for longer
input files, reducing overall utilization

* Check whether offline test can push more data before waiting

* Fix offline test with large files

In
```
circlebuf_push_back(
  &gf->input_buffers[c],
  audio[c].data() +
    frames_count * frame_size_bytes,
  frames_size_bytes);
```
`frames_count * frame_size_bytes` would overflow with `int` on
a 4 hour file; using `size_t` (on a 64 bit platform) fixes that
  • Loading branch information
palana authored Aug 14, 2024
1 parent 6cc88b1 commit bdab41c
Showing 1 changed file with 71 additions and 26 deletions.
97 changes: 71 additions & 26 deletions src/tests/localvocal-offline-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ void obs_log(int log_level, const char *format, ...)

auto diff = now - start;

static std::mutex log_mutex;
auto lock = std::lock_guard(log_mutex);
// print timestamp
printf("[%02d:%02d:%02d.%03d] [%02d:%02lld.%03lld] ", now_tm.tm_hour, now_tm.tm_min,
now_tm.tm_sec, (int)(epoch.count() % 1000),
Expand Down Expand Up @@ -194,6 +196,11 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
return gf;
}

std::mutex json_segments_input_mutex;
std::condition_variable json_segments_input_cv;
std::vector<nlohmann::json> json_segments_input;
bool json_segments_input_finished = false;

void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data,
size_t frames, int vad_state, const DetectionResultWithText &result)
{
Expand All @@ -214,33 +221,56 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
// obs_log(gf->log_level, "Saving %lu frames to %s", frames, filename.c_str());
// write_audio_wav_file(filename.c_str(), pcm32f_data, frames);

// append a row to the array in the segments.json file
std::string segments_filename = "segments.json";
nlohmann::json segments_json;

// Read existing segments from file
std::ifstream segments_file(segments_filename);
if (segments_file.is_open()) {
segments_file >> segments_json;
segments_file.close();
}

// Create a new segment object
nlohmann::json segment;
segment["start_time"] = result.start_timestamp_ms / 1000.0;
segment["end_time"] = result.end_timestamp_ms / 1000.0;
segment["segment_label"] = result.text;

// Add the new segment to the segments array
segments_json.push_back(segment);
{
auto lock = std::lock_guard(json_segments_input_mutex);

// Add the new segment to the segments array
json_segments_input.push_back(segment);
}
json_segments_input_cv.notify_one();
}

void json_segments_saver_thread_function()
{
std::string segments_filename = "segments.json";
nlohmann::json segments_json;

decltype(json_segments_input) json_segments_input_local;

for (;;) {
{
auto lock = std::unique_lock(json_segments_input_mutex);
while (json_segments_input.empty()) {
if (json_segments_input_finished)
return;
json_segments_input_cv.wait(lock, [&] {
return json_segments_input_finished ||
!json_segments_input.empty();
});
}

std::swap(json_segments_input, json_segments_input_local);
json_segments_input.clear();
}

for (auto &elem : json_segments_input_local) {
segments_json.push_back(std::move(elem));
}

// Write the updated segments back to the file
std::ofstream segments_file_out(segments_filename);
if (segments_file_out.is_open()) {
segments_file_out << std::setw(4) << segments_json << std::endl;
segments_file_out.close();
} else {
obs_log(gf->log_level, "Failed to open %s", segments_filename.c_str());
// Write the updated segments back to the file
std::ofstream segments_file_out(segments_filename);
if (segments_file_out.is_open()) {
segments_file_out << std::setw(4) << segments_json << std::endl;
segments_file_out.close();
} else {
obs_log(LOG_INFO, "Failed to open %s", segments_filename.c_str());
}
}
}

Expand Down Expand Up @@ -361,6 +391,7 @@ int wmain(int argc, wchar_t *argv[])

std::cout << "LocalVocal Offline Test" << std::endl;
transcription_filter_data *gf = nullptr;
std::optional<std::thread> audio_chunk_saver_thread;

std::vector<std::vector<uint8_t>> audio =
read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) {
Expand Down Expand Up @@ -419,6 +450,10 @@ int wmain(int argc, wchar_t *argv[])
return 1;
}

if (gf->enable_audio_chunks_callback) {
audio_chunk_saver_thread.emplace(json_segments_saver_thread_function);
}

// truncate the output file
obs_log(LOG_INFO, "Truncating output file");
std::ofstream output_file(gf->output_file_path, std::ios::trunc);
Expand All @@ -437,10 +472,10 @@ int wmain(int argc, wchar_t *argv[])

obs_log(LOG_INFO, "Sending samples to whisper buffer");
// 25 ms worth of frames
int frames = gf->sample_rate * window_size_in_ms.count() / 1000;
size_t frames = gf->sample_rate * window_size_in_ms.count() / 1000;
const int frame_size_bytes = sizeof(float);
int frames_size_bytes = frames * frame_size_bytes;
int frames_count = 0;
size_t frames_size_bytes = frames * frame_size_bytes;
size_t frames_count = 0;
int64_t start_time = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
Expand All @@ -464,12 +499,13 @@ int wmain(int argc, wchar_t *argv[])
if (false && now > max_wait)
break;

if (gf->input_buffers->size == 0)
break;

gf->input_cv->wait_for(
lock, std::chrono::milliseconds(10), [&] {
lock, std::chrono::milliseconds(1), [&] {
return gf->input_buffers->size == 0;
});
if (gf->input_buffers->size == 0)
break;
}
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
Expand Down Expand Up @@ -533,6 +569,15 @@ int wmain(int argc, wchar_t *argv[])
}
}

if (audio_chunk_saver_thread.has_value()) {
{
auto lock = std::lock_guard(json_segments_input_mutex);
json_segments_input_finished = true;
}
json_segments_input_cv.notify_one();
audio_chunk_saver_thread->join();
}

release_context(gf);

obs_log(LOG_INFO, "LocalVocal Offline Test Done");
Expand Down

0 comments on commit bdab41c

Please sign in to comment.