From 7017473d187993a3dec85d05aaee1ea03ab4a003 Mon Sep 17 00:00:00 2001 From: Ian Lavery Date: Tue, 21 Nov 2023 16:41:05 -0800 Subject: [PATCH] v2.0 rust (#298) --- .github/workflows/rust-codestyle.yml | 12 ++ .github/workflows/rust-demos.yml | 8 + binding/rust/Cargo.toml | 11 +- binding/rust/src/leopard.rs | 196 ++++++++++++++--- binding/rust/tests/leopard_tests.rs | 307 ++++++++++++++++----------- demo/rust/filedemo/Cargo.lock | 78 ++++++- demo/rust/filedemo/Cargo.toml | 4 +- demo/rust/filedemo/src/main.rs | 22 +- demo/rust/micdemo/Cargo.lock | 20 +- demo/rust/micdemo/Cargo.toml | 4 +- demo/rust/micdemo/src/main.rs | 22 +- resources/spell-check/dict.txt | 1 + 12 files changed, 483 insertions(+), 202 deletions(-) diff --git a/.github/workflows/rust-codestyle.yml b/.github/workflows/rust-codestyle.yml index bfda2850..3ef059c8 100644 --- a/.github/workflows/rust-codestyle.yml +++ b/.github/workflows/rust-codestyle.yml @@ -44,6 +44,10 @@ jobs: toolchain: stable override: true + - name: Rust build binding + run: bash copy.sh && cargo build --verbose + working-directory: binding/rust + - name: Run clippy run: cargo clippy -- -D warnings working-directory: binding/rust @@ -65,6 +69,10 @@ jobs: toolchain: stable override: true + - name: Rust build binding + run: bash copy.sh && cargo build --verbose + working-directory: binding/rust + - name: Run clippy run: cargo clippy -- -D warnings working-directory: demo/rust/filedemo @@ -86,6 +94,10 @@ jobs: toolchain: stable override: true + - name: Rust build binding + run: bash copy.sh && cargo build --verbose + working-directory: binding/rust + - name: Run clippy run: cargo clippy -- -D warnings working-directory: demo/rust/micdemo diff --git a/.github/workflows/rust-demos.yml b/.github/workflows/rust-demos.yml index 0c8f8c3e..0f17b470 100644 --- a/.github/workflows/rust-demos.yml +++ b/.github/workflows/rust-demos.yml @@ -41,6 +41,10 @@ jobs: toolchain: stable override: true + - name: Rust build binding + run: bash copy.sh && cargo build --verbose + working-directory: binding/rust + - name: Rust build micdemo run: cargo build --verbose working-directory: demo/rust/micdemo @@ -73,6 +77,10 @@ jobs: toolchain: nightly override: true + - name: Rust build binding + run: bash copy.sh && cargo build --verbose + working-directory: binding/rust + - name: Rust build filedemo run: cargo build --verbose working-directory: demo/rust/filedemo diff --git a/binding/rust/Cargo.toml b/binding/rust/Cargo.toml index 4fde2d34..1e805b77 100644 --- a/binding/rust/Cargo.toml +++ b/binding/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pv_leopard" -version = "1.2.0" +version = "2.0.0" edition = "2018" description = "The Rust bindings for Picovoice's Leopard library" license = "Apache-2.0" @@ -27,10 +27,11 @@ crate_type = ["lib"] [dependencies] libc = "0.2" -libloading = "0.7" +libloading = "0.8" [dev-dependencies] distance = "0.4.0" -itertools = "0.10" -rodio = "0.15" -serde_json = "1.0.91" +itertools = "0.11" +rodio = "0.17" +serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } \ No newline at end of file diff --git a/binding/rust/src/leopard.rs b/binding/rust/src/leopard.rs index 71067166..600716e8 100644 --- a/binding/rust/src/leopard.rs +++ b/binding/rust/src/leopard.rs @@ -26,7 +26,11 @@ use libloading::{Library, Symbol}; use crate::util::{pathbuf_to_cstring, pv_library_path, pv_model_path}; #[repr(C)] -struct CLeopard {} +struct CLeopard { + // Fields suggested by the Rustonomicon: https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs + _data: [u8; 0], + _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} #[repr(C)] struct CLeopardWord { @@ -34,6 +38,7 @@ struct CLeopardWord { start_sec: f32, end_sec: f32, confidence: f32, + speaker_tag: i32, } #[repr(C)] @@ -58,6 +63,7 @@ type PvLeopardInitFn = unsafe extern "C" fn( access_key: *const c_char, model_path: *const c_char, enable_automatic_punctuation: bool, + enable_diarization: bool, object: *mut *mut CLeopard, ) -> PvStatus; type PvSampleRateFn = unsafe extern "C" fn() -> i32; @@ -80,6 +86,12 @@ type PvLeopardProcessFileFn = unsafe extern "C" fn( type PvLeopardDeleteFn = unsafe extern "C" fn(object: *mut CLeopard); type PvLeopardTranscriptDeleteFn = unsafe extern "C" fn(transcript: *mut c_char); type PvLeopardWordsDeleteFn = unsafe extern "C" fn(words: *mut CLeopardWord); +type PvGetErrorStackFn = unsafe extern "C" fn( + message_stack: *mut *mut *mut c_char, + message_stack_depth: *mut i32 +) -> PvStatus; +type PvFreeErrorStackFn = unsafe extern "C" fn(message_stack: *mut *mut c_char); +type PvSetSdkFn = unsafe extern "C" fn(sdk: *const c_char); #[derive(Clone, Debug)] pub enum LeopardErrorStatus { @@ -91,8 +103,9 @@ pub enum LeopardErrorStatus { #[derive(Clone, Debug)] pub struct LeopardError { - status: LeopardErrorStatus, - message: String, + pub status: LeopardErrorStatus, + pub message: String, + pub message_stack: Vec, } impl LeopardError { @@ -100,13 +113,35 @@ impl LeopardError { Self { status, message: message.into(), + message_stack: Vec::new() + } + } + + pub fn new_with_stack( + status: LeopardErrorStatus, + message: impl Into, + message_stack: impl Into> + ) -> Self { + Self { + status, + message: message.into(), + message_stack: message_stack.into(), } } } impl std::fmt::Display for LeopardError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {:?}", self.message, self.status) + let mut message_string = String::new(); + message_string.push_str(&format!("{} with status '{:?}'", self.message, self.status)); + + if !self.message_stack.is_empty() { + message_string.push(':'); + for x in 0..self.message_stack.len() { + message_string.push_str(&format!(" [{}] {}\n", x, self.message_stack[x])) + }; + } + write!(f, "{}", message_string) } } @@ -117,6 +152,7 @@ pub struct LeopardBuilder { model_path: PathBuf, library_path: PathBuf, enable_automatic_punctuation: bool, + enable_diarization: bool, } impl Default for LeopardBuilder { @@ -127,6 +163,7 @@ impl Default for LeopardBuilder { impl LeopardBuilder { const DEFAULT_ENABLE_AUTOMATIC_PUNCTUATION: bool = false; + const DEFAULT_ENABLE_DIARIZATION: bool = false; pub fn new() -> Self { Self { @@ -134,6 +171,7 @@ impl LeopardBuilder { model_path: pv_model_path(), library_path: pv_library_path(), enable_automatic_punctuation: Self::DEFAULT_ENABLE_AUTOMATIC_PUNCTUATION, + enable_diarization: Self::DEFAULT_ENABLE_DIARIZATION, } } @@ -160,12 +198,22 @@ impl LeopardBuilder { self } + pub fn enable_diarization( + &mut self, + enable_diarization: bool, + ) -> &mut Self { + self.enable_diarization = enable_diarization; + self + } + + pub fn init(&self) -> Result { let inner = LeopardInner::init( &self.access_key, &self.model_path, &self.library_path, self.enable_automatic_punctuation, + self.enable_diarization, ); match inner { Ok(inner) => Ok(Leopard { @@ -182,6 +230,7 @@ pub struct LeopardWord { pub start_sec: f32, pub end_sec: f32, pub confidence: f32, + pub speaker_tag: i32, } impl From<&CLeopardWord> for Result { @@ -200,6 +249,7 @@ impl From<&CLeopardWord> for Result { start_sec: c_leopard_word.start_sec, end_sec: c_leopard_word.end_sec, confidence: c_leopard_word.confidence, + speaker_tag: c_leopard_word.speaker_tag, }) } } @@ -254,23 +304,60 @@ unsafe fn load_library_fn( }) } -fn check_fn_call_status(status: PvStatus, function_name: &str) -> Result<(), LeopardError> { +fn check_fn_call_status( + vtable: &LeopardInnerVTable, + status: PvStatus, + function_name: &str +) -> Result<(), LeopardError> { match status { PvStatus::SUCCESS => Ok(()), - _ => Err(LeopardError::new( - LeopardErrorStatus::LibraryError(status), - format!("Function '{}' in the leopard library failed", function_name), - )), + _ => unsafe { + let mut message_stack_ptr: *mut c_char = std::ptr::null_mut(); + let mut message_stack_ptr_ptr = addr_of_mut!(message_stack_ptr); + + let mut message_stack_depth: i32 = 0; + let err_status = (vtable.pv_get_error_stack)( + addr_of_mut!(message_stack_ptr_ptr), + addr_of_mut!(message_stack_depth), + ); + + if err_status != PvStatus::SUCCESS { + return Err(LeopardError::new( + LeopardErrorStatus::LibraryError(err_status), + "Unable to get Leopard error state.", + )) + } + + let mut message_stack = Vec::new(); + for i in 0..message_stack_depth as usize { + let message = CStr::from_ptr(*message_stack_ptr_ptr.add(i)); + let message = message.to_string_lossy().into_owned(); + message_stack.push(message); + } + + (vtable.pv_free_error_stack)(message_stack_ptr_ptr); + + Err(LeopardError::new_with_stack( + LeopardErrorStatus::LibraryError(status), + format!("'{function_name}' failed"), + message_stack, + )) + }, } } struct LeopardInnerVTable { + pv_leopard_init: RawSymbol, pv_leopard_process: RawSymbol, pv_leopard_process_file: RawSymbol, pv_leopard_delete: RawSymbol, pv_leopard_transcript_delete: RawSymbol, pv_leopard_words_delete: RawSymbol, - + pv_leopard_version: RawSymbol, + pv_sample_rate: RawSymbol, + pv_get_error_stack: RawSymbol, + pv_free_error_stack: RawSymbol, + pv_set_sdk: RawSymbol, _lib_guard: Library, } @@ -279,6 +366,7 @@ impl LeopardInnerVTable { // SAFETY: the library will be hold by this struct and therefore the symbols can't outlive the library unsafe { Ok(Self { + pv_leopard_init: load_library_fn(&lib, b"pv_leopard_init")?, pv_leopard_process: load_library_fn(&lib, b"pv_leopard_process")?, pv_leopard_process_file: load_library_fn(&lib, b"pv_leopard_process_file")?, pv_leopard_delete: load_library_fn(&lib, b"pv_leopard_delete")?, @@ -287,6 +375,12 @@ impl LeopardInnerVTable { b"pv_leopard_transcript_delete", )?, pv_leopard_words_delete: load_library_fn(&lib, b"pv_leopard_words_delete")?, + pv_leopard_version: load_library_fn(&lib, b"pv_leopard_version")?, + pv_sample_rate: load_library_fn(&lib, b"pv_sample_rate")?, + + pv_get_error_stack: load_library_fn(&lib, b"pv_get_error_stack")?, + pv_free_error_stack: load_library_fn(&lib, b"pv_free_error_stack")?, + pv_set_sdk: load_library_fn(&lib, b"pv_set_sdk")?, _lib_guard: lib, }) @@ -311,6 +405,7 @@ impl LeopardInner { model_path: P, library_path: P, enable_automatic_punctuation: bool, + enable_diarization: bool, ) -> Result { if access_key.is_empty() { return Err(LeopardError::new( @@ -346,6 +441,18 @@ impl LeopardInner { ) })?; + let vtable = LeopardInnerVTable::new(lib)?; + + let sdk_string = match CString::new("rust") { + Ok(sdk_string) => sdk_string, + Err(err) => { + return Err(LeopardError::new( + LeopardErrorStatus::ArgumentError, + format!("sdk_string is not a valid C string {err}"), + )) + } + }; + let access_key = match CString::new(access_key) { Ok(access_key) => access_key, Err(err) => { @@ -362,37 +469,32 @@ impl LeopardInner { // safe, because we don't use the raw symbols after this function // anymore. let (sample_rate, version) = unsafe { - let pv_leopard_init = load_library_fn::(&lib, b"pv_leopard_init")?; - let pv_sample_rate = load_library_fn::(&lib, b"pv_sample_rate")?; - let pv_leopard_version = - load_library_fn::(&lib, b"pv_leopard_version")?; + (vtable.pv_set_sdk)(sdk_string.as_ptr()); - let status = pv_leopard_init( + let status = (vtable.pv_leopard_init)( access_key.as_ptr(), pv_model_path.as_ptr(), enable_automatic_punctuation, + enable_diarization, addr_of_mut!(cleopard), ); - check_fn_call_status(status, "pv_leopard_init")?; + check_fn_call_status(&vtable, status, "pv_leopard_init")?; - let version = match CStr::from_ptr(pv_leopard_version()).to_str() { - Ok(string) => string.to_string(), - Err(err) => { - return Err(LeopardError::new( - LeopardErrorStatus::LibraryLoadError, - format!("Failed to get version info from Leopard Library: {}", err), - )) - } - }; + let version = CStr::from_ptr((vtable.pv_leopard_version)()) + .to_string_lossy() + .into_owned(); - (pv_sample_rate(), version) + ( + (vtable.pv_sample_rate)(), + version + ) }; Ok(Self { cleopard, sample_rate, version, - vtable: LeopardInnerVTable::new(lib)?, + vtable, }) } @@ -418,7 +520,7 @@ impl LeopardInner { addr_of_mut!(words_ptr), ); - check_fn_call_status(status, "pv_leopard_process")?; + check_fn_call_status(&self.vtable, status, "pv_leopard_process")?; let transcript = String::from(CStr::from_ptr(transcript_ptr).to_str().map_err(|_| { @@ -486,7 +588,7 @@ impl LeopardInner { )); } - check_fn_call_status(status, "pv_leopard_process_file")?; + check_fn_call_status(&self.vtable, status, "pv_leopard_process_file")?; } let transcript = @@ -521,3 +623,39 @@ impl Drop for LeopardInner { } } } + +#[cfg(test)] +mod tests { + use std::env; + + use crate::util::{pv_library_path, pv_model_path}; + use crate::leopard::{LeopardInner}; + + #[test] + fn test_process_error_stack() { + let access_key = env::var("PV_ACCESS_KEY") + .expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable"); + + let mut inner = LeopardInner::init( + &access_key.as_str(), + pv_model_path(), + pv_library_path(), + false, + false + ).expect("Unable to create Leopard"); + + let test_pcm = vec![0; 1024]; + let address = inner.cleopard; + inner.cleopard = std::ptr::null_mut(); + + let res = inner.process(&test_pcm); + + inner.cleopard = address; + if let Err(err) = res { + assert!(err.message_stack.len() > 0); + assert!(err.message_stack.len() < 8); + } else { + assert_eq!(res.unwrap().transcript, ""); + } + } +} \ No newline at end of file diff --git a/binding/rust/tests/leopard_tests.rs b/binding/rust/tests/leopard_tests.rs index d946ec3d..015543b9 100644 --- a/binding/rust/tests/leopard_tests.rs +++ b/binding/rust/tests/leopard_tests.rs @@ -14,12 +14,49 @@ mod tests { use distance::*; use itertools::Itertools; use rodio::{source::Source, Decoder}; - use serde_json::Value; + use serde::Deserialize; use std::env; use std::fs::{read_to_string, File}; use std::io::BufReader; - use leopard::{LeopardBuilder, LeopardTranscript}; + use leopard::{LeopardBuilder, LeopardWord}; + + #[derive(Debug, Deserialize)] + struct WordJson { + word: String, + start_sec: Option, + end_sec: Option, + confidence: Option, + speaker_tag: i32, + } + + #[derive(Debug, Deserialize)] + struct LanguageTestJson { + language: String, + audio_file: String, + transcript: String, + transcript_with_punctuation: String, + error_rate: f32, + words: Vec, + } + + #[derive(Debug, Deserialize)] + struct DiarizationTestJson { + language: String, + audio_file: String, + words: Vec, + } + + #[derive(Debug, Deserialize)] + struct TestsJson { + language_tests: Vec, + diarization_tests: Vec, + } + + #[derive(Debug, Deserialize)] + struct RootJson { + tests: TestsJson, + } fn append_lang(path: &str, language: &str) -> String { if language == "en" { @@ -29,7 +66,7 @@ mod tests { } } - fn load_test_data() -> Value { + fn load_test_data() -> TestsJson { let test_json_path = format!( "{}{}", env!("CARGO_MANIFEST_DIR"), @@ -37,9 +74,8 @@ mod tests { ); let contents: String = read_to_string(test_json_path).expect("Unable to read test_data.json"); - let test_json: Value = - serde_json::from_str(&contents).expect("Unable to parse test_data.json"); - test_json + let root: RootJson = serde_json::from_str(&contents).expect("Failed to parse JSON"); + root.tests } fn model_path_by_language(language: &str) -> String { @@ -56,30 +92,30 @@ mod tests { return distance as f32 / expected_transcript.len() as f32; } - fn validate_metadata(leopard_transcript: LeopardTranscript, audio_length: f32) { - let norm_transcript = leopard_transcript.transcript.to_uppercase(); - for i in 0..leopard_transcript.words.len() { - let leopard_word = leopard_transcript.words.get(i).unwrap().clone(); - - assert!(norm_transcript.contains(&leopard_word.word.to_uppercase())); - assert!(leopard_word.start_sec > 0.0); - assert!(leopard_word.start_sec <= leopard_word.end_sec); - if i < (leopard_transcript.words.len() - 1) { - let next_leopard_word = leopard_transcript.words.get(i + 1).unwrap().clone(); - assert!(leopard_word.end_sec <= next_leopard_word.start_sec); + fn validate_metadata(words: Vec, reference_words: Vec, enable_diarization: bool) { + for i in 0..words.len() { + let leopard_word = words.get(i).unwrap().clone(); + let reference_word = reference_words.get(i).unwrap().clone(); + assert!(&leopard_word.word.to_uppercase() == &reference_word.word.to_uppercase()); + assert!((leopard_word.start_sec-reference_word.start_sec.unwrap()).abs() <= 0.1); + assert!((leopard_word.end_sec-reference_word.end_sec.unwrap()).abs() <= 0.1); + assert!((leopard_word.confidence-reference_word.confidence.unwrap()).abs() <= 0.1); + if enable_diarization { + assert!(leopard_word.speaker_tag == reference_word.speaker_tag); + } else { + assert!(leopard_word.speaker_tag == -1); } - assert!(leopard_word.end_sec <= audio_length); - assert!(leopard_word.confidence >= 0.0 && leopard_word.confidence <= 1.0); } } fn run_test_process( language: &str, transcript: &str, - punctuations: Vec<&str>, - test_punctuation: bool, + enable_automatic_punctuation: bool, + enable_diarization: bool, error_rate: f32, test_audio: &str, + words: Vec ) { let access_key = env::var("PV_ACCESS_KEY") .expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable"); @@ -93,38 +129,32 @@ mod tests { test_audio ); - let mut norm_transcript = transcript.to_string(); - if !test_punctuation { - punctuations.iter().for_each(|p| { - norm_transcript = norm_transcript.replace(p, ""); - }); - } - let audio_file = BufReader::new(File::open(&audio_path).expect(&audio_path)); let source = Decoder::new(audio_file).unwrap(); let leopard = LeopardBuilder::new() .access_key(access_key) .model_path(model_path) - .enable_automatic_punctuation(test_punctuation) + .enable_automatic_punctuation(enable_automatic_punctuation) + .enable_diarization(enable_diarization) .init() .expect("Unable to create Leopard"); assert_eq!(leopard.sample_rate(), source.sample_rate()); - let audio_file_duration = source.total_duration().unwrap().as_secs_f32(); let result = leopard.process(&source.collect_vec()).unwrap(); - assert!(character_error_rate(&result.transcript, &norm_transcript) < error_rate); - validate_metadata(result, audio_file_duration); + assert!(character_error_rate(&result.transcript, &transcript) < error_rate); + validate_metadata(result.words, words, enable_diarization); } fn run_test_process_file( language: &str, transcript: &str, - punctuations: Vec<&str>, - test_punctuation: bool, + enable_automatic_punctuation: bool, + enable_diarization: bool, error_rate: f32, test_audio: &str, + words: Vec ) { let access_key = env::var("PV_ACCESS_KEY") .expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable"); @@ -138,147 +168,166 @@ mod tests { test_audio ); - let mut norm_transcript = transcript.to_string(); - if !test_punctuation { - punctuations.iter().for_each(|p| { - norm_transcript = norm_transcript.replace(p, ""); - }); - } + let leopard = LeopardBuilder::new() + .access_key(access_key) + .model_path(model_path) + .enable_automatic_punctuation(enable_automatic_punctuation) + .enable_diarization(enable_diarization) + .init() + .expect("Unable to create Leopard"); - let audio_file = BufReader::new(File::open(&audio_path).expect(&audio_path)); - let source = Decoder::new(audio_file).unwrap(); + let result = leopard.process_file(audio_path).unwrap(); + + assert!(character_error_rate(&result.transcript, &transcript) < error_rate); + validate_metadata(result.words, words, enable_diarization); + } + + fn run_test_diarization( + language: &str, + test_audio: &str, + reference_words: Vec + ) { + let access_key = env::var("PV_ACCESS_KEY") + .expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable"); + + let model_path = model_path_by_language(language); + + let audio_path = format!( + "{}{}{}", + env!("CARGO_MANIFEST_DIR"), + "/../../resources/audio_samples/", + test_audio + ); let leopard = LeopardBuilder::new() .access_key(access_key) .model_path(model_path) - .enable_automatic_punctuation(test_punctuation) + .enable_diarization(true) .init() .expect("Unable to create Leopard"); - assert_eq!(leopard.sample_rate(), source.sample_rate()); - let audio_file_duration = source.total_duration().unwrap().as_secs_f32(); let result = leopard.process_file(audio_path).unwrap(); - assert!(character_error_rate(&result.transcript, &norm_transcript) < error_rate); - validate_metadata(result, audio_file_duration); + for i in 0..result.words.len() { + let leopard_word = result.words.get(i).unwrap().clone(); + let reference_word = reference_words.get(i).unwrap().clone(); + assert!(&leopard_word.word.to_uppercase() == &reference_word.word.to_uppercase()); + assert!(leopard_word.speaker_tag == reference_word.speaker_tag); + } } #[test] fn test_process() -> Result<(), String> { - let test_json: Value = load_test_data(); - - for t in test_json["tests"]["parameters"].as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript = t["transcript"].as_str().unwrap(); - let punctuations = t["punctuations"] - .as_array() - .unwrap() - .iter() - .map(|v| v.as_str().unwrap()) - .collect_vec(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; - - let test_audio = t["audio_file"].as_str().unwrap(); + let test_json: TestsJson = load_test_data(); + for t in test_json.language_tests { run_test_process( - language, - transcript, - punctuations, + &t.language, + &t.transcript, false, - error_rate, - &test_audio, + false, + t.error_rate, + &t.audio_file, + t.words ); } Ok(()) } #[test] - fn test_process_punctuation() -> Result<(), String> { - let test_json: Value = load_test_data(); - - for t in test_json["tests"]["parameters"].as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript = t["transcript"].as_str().unwrap(); - let punctuations = t["punctuations"] - .as_array() - .unwrap() - .iter() - .map(|v| v.as_str().unwrap()) - .collect_vec(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; - - let test_audio = t["audio_file"].as_str().unwrap(); + fn test_process_file() -> Result<(), String> { + let test_json: TestsJson = load_test_data(); - run_test_process( - language, - transcript, - punctuations, - true, - error_rate, - &test_audio, + for t in test_json.language_tests { + run_test_process_file( + &t.language, + &t.transcript, + false, + false, + t.error_rate, + &t.audio_file, + t.words ); } Ok(()) } - #[test] - fn test_process_file() -> Result<(), String> { - let test_json: Value = load_test_data(); - - for t in test_json["tests"]["parameters"].as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript = t["transcript"].as_str().unwrap(); - let punctuations = t["punctuations"] - .as_array() - .unwrap() - .iter() - .map(|v| v.as_str().unwrap()) - .collect_vec(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; - let test_audio = t["audio_file"].as_str().unwrap(); + #[test] + fn test_process_file_punctuation() -> Result<(), String> { + let test_json: TestsJson = load_test_data(); + for t in test_json.language_tests { run_test_process_file( - language, - transcript, - punctuations, + &t.language, + &t.transcript_with_punctuation, + true, false, - error_rate, - &test_audio, + t.error_rate, + &t.audio_file, + t.words ); } Ok(()) } #[test] - fn test_process_file_punctuation() -> Result<(), String> { - let test_json: Value = load_test_data(); - - for t in test_json["tests"]["parameters"].as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript = t["transcript"].as_str().unwrap(); - let punctuations = t["punctuations"] - .as_array() - .unwrap() - .iter() - .map(|v| v.as_str().unwrap()) - .collect_vec(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; - - let test_audio = t["audio_file"].as_str().unwrap(); + fn test_process_file_diarization() -> Result<(), String> { + let test_json: TestsJson = load_test_data(); + for t in test_json.language_tests { run_test_process_file( - language, - transcript, - punctuations, + &t.language, + &t.transcript, + false, true, - error_rate, - &test_audio, + t.error_rate, + &t.audio_file, + t.words + ); + } + Ok(()) + } + + #[test] + fn test_diarization() -> Result<(), String> { + let test_json: TestsJson = load_test_data(); + + for t in test_json.diarization_tests { + run_test_diarization( + &t.language, + &t.audio_file, + t.words ); } Ok(()) } + #[test] + fn test_error_stack() { + let mut error_stack = Vec::new(); + + let res = LeopardBuilder::new() + .access_key("invalid") + .init(); + + if let Err(err) = res { + error_stack = err.message_stack + } + + assert!(0 < error_stack.len() && error_stack.len() <= 8); + + let res = LeopardBuilder::new() + .access_key("invalid") + .init(); + if let Err(err) = res { + assert_eq!(error_stack.len(), err.message_stack.len()); + for i in 0..error_stack.len() { + assert_eq!(error_stack[i], err.message_stack[i]) + } + } + } + #[test] fn test_version() { let access_key = env::var("PV_ACCESS_KEY") @@ -291,4 +340,4 @@ mod tests { assert_ne!(leopard.version(), "") } -} +} \ No newline at end of file diff --git a/demo/rust/filedemo/Cargo.lock b/demo/rust/filedemo/Cargo.lock index a2ca5a5f..c33422f6 100644 --- a/demo/rust/filedemo/Cargo.lock +++ b/demo/rust/filedemo/Cargo.lock @@ -138,12 +138,12 @@ checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" [[package]] name = "libloading" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f84d96438c15fcd6c3f244c8fce01d1e2b9c6b5623e9c711dc9286d8fc92d6a" +checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" dependencies = [ "cfg-if", - "winapi", + "windows-sys", ] [[package]] @@ -197,9 +197,7 @@ dependencies = [ [[package]] name = "pv_leopard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d920fcd9bda7cf20f70083cbb85000321a78ef9638ec8418fc6f0a1f0d30b3fc" +version = "2.0.0" dependencies = [ "libc", "libloading", @@ -207,7 +205,7 @@ dependencies = [ [[package]] name = "pv_leopard_filedemo" -version = "1.1.1" +version = "2.0.0" dependencies = [ "chrono", "clap", @@ -380,3 +378,69 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/demo/rust/filedemo/Cargo.toml b/demo/rust/filedemo/Cargo.toml index a9e6cc0a..ae6c6894 100644 --- a/demo/rust/filedemo/Cargo.toml +++ b/demo/rust/filedemo/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pv_leopard_filedemo" -version = "1.1.1" +version = "2.0.0" edition = "2018" [dependencies] @@ -8,5 +8,5 @@ chrono = "0.4.20" clap = "3.2.16" hound = "3.4.0" itertools = "0.10.3" -pv_leopard = "=1.2.0" +pv_leopard = { path = "../../../binding/rust" } tabwriter = "1.2.1" diff --git a/demo/rust/filedemo/src/main.rs b/demo/rust/filedemo/src/main.rs index ff8bf42d..005da8d2 100644 --- a/demo/rust/filedemo/src/main.rs +++ b/demo/rust/filedemo/src/main.rs @@ -1,5 +1,5 @@ /* - Copyright 2022 Picovoice Inc. + Copyright 2022-2023 Picovoice Inc. You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE" file accompanying this source. @@ -21,6 +21,7 @@ fn leopard_demo( access_key: &str, model_path: Option<&str>, enable_automatic_punctuation: bool, + enable_diarization: bool, verbose: bool, ) { let mut leopard_builder = LeopardBuilder::new(); @@ -31,6 +32,7 @@ fn leopard_demo( let leopard = leopard_builder .enable_automatic_punctuation(enable_automatic_punctuation) + .enable_diarization(enable_diarization) .access_key(access_key) .init() .expect("Failed to create Leopard"); @@ -40,13 +42,13 @@ fn leopard_demo( if verbose { println!(); let mut tw = TabWriter::new(vec![]); - writeln!(&mut tw, "Word\tStart Sec\tEnd Sec\tConfidence").unwrap(); - writeln!(&mut tw, "----\t---------\t-------\t----------").unwrap(); + writeln!(&mut tw, "Word\tStart Sec\tEnd Sec\tConfidence\tSpeaker Tag").unwrap(); + writeln!(&mut tw, "----\t---------\t-------\t----------\t-----------").unwrap(); leopard_transcript.words.iter().for_each(|word| { writeln!( &mut tw, - "{}\t{:.2}\t{:.2}\t{:.2}", - word.word, word.start_sec, word.end_sec, word.confidence + "{}\t{:.2}\t{:.2}\t{:.2}\t{}", + word.word, word.start_sec, word.end_sec, word.confidence, word.speaker_tag ) .unwrap(); }); @@ -86,9 +88,15 @@ fn main() { .arg( Arg::with_name("disable_automatic_punctuation") .long("disable_automatic_punctuation") - .short('d') + .short('p') .help("Set to disable automatic punctuation insertion."), ) + .arg( + Arg::with_name("disable_speaker_diarization") + .long("disable_speaker_diarization") + .short('d') + .help("Set to disable speaker diarization."), + ) .arg( Arg::with_name("verbose") .long("verbose") @@ -106,6 +114,7 @@ fn main() { let model_path = matches.value_of("model_path"); let enable_automatic_punctuation = !matches.contains_id("disable_automatic_punctuation"); + let enable_diarization = !matches.contains_id("disable_speaker_diarization"); let verbose = matches.contains_id("verbose"); @@ -114,6 +123,7 @@ fn main() { access_key, model_path, enable_automatic_punctuation, + enable_diarization, verbose, ); } diff --git a/demo/rust/micdemo/Cargo.lock b/demo/rust/micdemo/Cargo.lock index 9670b109..9d84f66e 100644 --- a/demo/rust/micdemo/Cargo.lock +++ b/demo/rust/micdemo/Cargo.lock @@ -197,16 +197,6 @@ version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" -[[package]] -name = "libloading" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" -dependencies = [ - "cfg-if", - "winapi", -] - [[package]] name = "libloading" version = "0.8.0" @@ -267,17 +257,15 @@ dependencies = [ [[package]] name = "pv_leopard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d920fcd9bda7cf20f70083cbb85000321a78ef9638ec8418fc6f0a1f0d30b3fc" +version = "2.0.0" dependencies = [ "libc", - "libloading 0.7.4", + "libloading", ] [[package]] name = "pv_leopard_micdemo" -version = "1.1.1" +version = "2.0.0" dependencies = [ "chrono", "clap", @@ -296,7 +284,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40be1d15bef993d0b923720589479d4d2e93e9fb6286328e2551f0fdbf45de31" dependencies = [ "libc", - "libloading 0.8.0", + "libloading", ] [[package]] diff --git a/demo/rust/micdemo/Cargo.toml b/demo/rust/micdemo/Cargo.toml index d6e02264..841b9a5b 100644 --- a/demo/rust/micdemo/Cargo.toml +++ b/demo/rust/micdemo/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pv_leopard_micdemo" -version = "1.1.1" +version = "2.0.0" edition = "2018" [dependencies] @@ -9,6 +9,6 @@ clap = "3.2.16" ctrlc = "3.2.2" hound = "3.4.0" itertools = "0.10.3" -pv_leopard = "=1.2.0" +pv_leopard = { path = "../../../binding/rust" } pv_recorder = "=1.2.1" tabwriter = "1.2.1" diff --git a/demo/rust/micdemo/src/main.rs b/demo/rust/micdemo/src/main.rs index 4bdc2963..0beb0e3f 100644 --- a/demo/rust/micdemo/src/main.rs +++ b/demo/rust/micdemo/src/main.rs @@ -1,5 +1,5 @@ /* - Copyright 2022 Picovoice Inc. + Copyright 2022-2023 Picovoice Inc. You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE" file accompanying this source. @@ -26,6 +26,7 @@ fn leopard_demo( access_key: &str, model_path: Option<&str>, enable_automatic_punctuation: bool, + enable_diarization: bool, verbose: bool, ) { let mut leopard_builder = LeopardBuilder::new(); @@ -36,6 +37,7 @@ fn leopard_demo( let leopard = leopard_builder .enable_automatic_punctuation(enable_automatic_punctuation) + .enable_diarization(enable_diarization) .access_key(access_key) .init() .expect("Failed to create Leopard"); @@ -89,13 +91,13 @@ fn leopard_demo( if verbose { println!(); let mut tw = TabWriter::new(vec![]); - writeln!(&mut tw, "Word\tStart Sec\tEnd Sec\tConfidence").unwrap(); - writeln!(&mut tw, "----\t---------\t-------\t----------").unwrap(); + writeln!(&mut tw, "Word\tStart Sec\tEnd Sec\tConfidence\tSpeaker Tag").unwrap(); + writeln!(&mut tw, "----\t---------\t-------\t----------\t-----------").unwrap(); leopard_transcript.words.iter().for_each(|word| { writeln!( &mut tw, - "{}\t{:.2}\t{:.2}\t{:.2}", - word.word, word.start_sec, word.end_sec, word.confidence + "{}\t{:.2}\t{:.2}\t{:.2}\t{}", + word.word, word.start_sec, word.end_sec, word.confidence, word.speaker_tag ) .unwrap(); }); @@ -145,9 +147,15 @@ fn main() { .arg( Arg::with_name("disable_automatic_punctuation") .long("disable_automatic_punctuation") - .short('d') + .short('p') .help("Set to disable automatic punctuation insertion."), ) + .arg( + Arg::with_name("disable_speaker_diarization") + .long("disable_speaker_diarization") + .short('d') + .help("Set to disable speaker diarization."), + ) .arg( Arg::with_name("verbose") .long("verbose") @@ -187,6 +195,7 @@ fn main() { let model_path = matches.value_of("model_path"); let enable_automatic_punctuation = !matches.contains_id("disable_automatic_punctuation"); + let enable_diarization = !matches.contains_id("disable_speaker_diarization"); let verbose = matches.contains_id("verbose"); @@ -195,6 +204,7 @@ fn main() { access_key, model_path, enable_automatic_punctuation, + enable_diarization, verbose, ); } diff --git a/resources/spell-check/dict.txt b/resources/spell-check/dict.txt index 0d29eca2..f5bc2b33 100644 --- a/resources/spell-check/dict.txt +++ b/resources/spell-check/dict.txt @@ -127,6 +127,7 @@ ritmi RNFS rodio RTLD +Rustonomicon Sameline Signup sizecache