Copied across all proc-blocks from hotg-ai/rune v0.8.0
The actual commit is 5baadf6af6479648258a871dec151e8643d82b71.
Michael-F-Bryan committed Oct 10, 2021 · 1 parent 8366e06 · commit 9cfa102
Showing 22 changed files with 1,741 additions and 0 deletions.
@@ -1 +1,11 @@
[workspace]
members = [
    "fft",
    "audio_float_conversion",
    "image-normalization",
    "label",
    "modulo",
    "most_confident_indices",
    "noise-filtering",
    "normalize",
]
@@ -0,0 +1,13 @@
[package]
name = "audio_float_conversion"
version = "0.8.0"
edition = "2018"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
hotg-rune-proc-blocks = "^0.8.0"

[dev-dependencies]
pretty_assertions = "0.7.2"
@@ -0,0 +1,91 @@
#![no_std]

extern crate alloc;

#[cfg(test)]
#[macro_use]
extern crate std;

use hotg_rune_proc_blocks::{ProcBlock, Transform, Tensor};

// TODO: Add Generics

#[derive(Debug, Clone, PartialEq, ProcBlock)]
#[transform(inputs = [i16; _], outputs = [f32; _])]
pub struct AudioFloatConversion {
    i16_max_as_float: f32,
}

const I16_MAX_AS_FLOAT: f32 = i16::MAX as f32;

impl AudioFloatConversion {
    pub const fn new() -> Self {
        AudioFloatConversion {
            i16_max_as_float: I16_MAX_AS_FLOAT,
        }
    }

    fn check_input_dimensions(&self, dimensions: &[usize]) {
        assert_eq!(
            dimensions.len(),
            1,
            "This proc block only supports 1D outputs (requested output: {:?})",
            dimensions
        );
    }
}

impl Default for AudioFloatConversion {
    fn default() -> Self { AudioFloatConversion::new() }
}

impl Transform<Tensor<i16>> for AudioFloatConversion {
    type Output = Tensor<f32>;

    fn transform(&mut self, input: Tensor<i16>) -> Self::Output {
        self.check_input_dimensions(input.dimensions());
        input.map(|_dims, &value| {
            (value as f32 / I16_MAX_AS_FLOAT).clamp(-1.0, 1.0)
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handle_empty() {
        let mut pb = AudioFloatConversion::new();
        let input = Tensor::new_vector(vec![0; 15]);

        let got = pb.transform(input);

        assert_eq!(got.dimensions(), &[15]);
    }

    #[test]
    fn does_it_match() {
        let max = i16::MAX;
        let min = i16::MIN;

        let mut pb = AudioFloatConversion::new();
        let input = Tensor::new_vector(vec![0, max / 2, min / 2]);

        let got = pb.transform(input);

        assert_eq!(got.elements()[0..3], [0.0, 0.49998474, -0.50001526]);
    }

    #[test]
    fn does_clutch_work() {
        let max = i16::MAX;
        let min = i16::MIN;

        let mut pb = AudioFloatConversion::new();
        let input = Tensor::new_vector(vec![max, min, min + 1]);

        let got = pb.transform(input);

        assert_eq!(got.elements()[0..3], [1.0, -1.0, -1.0]);
    }
}
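For reference, the conversion above divides each signed 16-bit sample by i16::MAX and clamps the result to [-1.0, 1.0]. A minimal caller sketch, not part of this diff and assuming the crate is imported as audio_float_conversion alongside hotg-rune-proc-blocks:

use hotg_rune_proc_blocks::{Tensor, Transform};
use audio_float_conversion::AudioFloatConversion;

fn main() {
    // A one-dimensional tensor of raw PCM samples.
    let samples = Tensor::new_vector(vec![0_i16, i16::MAX, i16::MIN]);

    let mut pb = AudioFloatConversion::new();

    // Each sample becomes sample / i16::MAX, clamped to [-1.0, 1.0],
    // so i16::MIN maps to exactly -1.0 rather than about -1.00003.
    let floats = pb.transform(samples);

    assert_eq!(floats.elements()[..3], [0.0, 1.0, -1.0]);
}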
@@ -0,0 +1,21 @@
[package]
name = "fft"
version = "0.8.0"
authors = ["The Rune Developers <[email protected]>"]
edition = "2018"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
hotg-rune-proc-blocks = "^0.8.0"
hound = "3.4"
libm = "0.2.1"
# See https://github.com/hotg-ai/rune/pull/107#issuecomment-825806000
mel = { git = "https://github.com/hotg-ai/mel", rev = "017694ee3143c11ea9b75ba6cd27fe7c8a69a867", default-features = false }
nalgebra = { version = "0.29", default-features = false, features = ["alloc"] }
normalize = { path = "../normalize", version = "^0.8.0" }
sonogram = { git = "https://github.com/hotg-ai/sonogram", rev = "009bc0cba44267d8a0807e43c9bb0712f0f334ea" }

[dev-dependencies]
pretty_assertions = "0.7.2"
@@ -0,0 +1,137 @@
#![no_std]

extern crate alloc;

#[cfg(test)]
#[macro_use]
extern crate std;
#[cfg(test)]
#[macro_use]
extern crate pretty_assertions;

/// A type alias for [`ShortTimeFourierTransform`] which uses the camel case
/// version of this crate.
pub type Fft = ShortTimeFourierTransform;

use alloc::{sync::Arc, vec::Vec};
use hotg_rune_proc_blocks::{ProcBlock, Transform, Tensor};
use sonogram::SpecOptionsBuilder;
use nalgebra::DMatrix;

#[derive(Debug, Clone, PartialEq, ProcBlock)]
pub struct ShortTimeFourierTransform {
    sample_rate: u32,
    bins: usize,
    window_overlap: f32,
}

const DEFAULT_SAMPLE_RATE: u32 = 16000;
const DEFAULT_BINS: usize = 480;
const DEFAULT_WINDOW_OVERLAP: f32 = 0.6666667;

impl ShortTimeFourierTransform {
    pub const fn new() -> Self {
        ShortTimeFourierTransform {
            sample_rate: DEFAULT_SAMPLE_RATE,
            bins: DEFAULT_BINS,
            window_overlap: DEFAULT_WINDOW_OVERLAP,
        }
    }

    fn transform_inner(&mut self, input: Vec<i16>) -> [u32; 1960] {
        // Build the spectrogram computation engine
        let mut spectrograph = SpecOptionsBuilder::new(49, 241)
            .set_window_fn(sonogram::hann_function)
            .load_data_from_memory(input, self.sample_rate as u32)
            .build();

        // Compute the spectrogram giving the number of bins in a window and the
        // overlap between neighbour windows.
        spectrograph.compute(self.bins, self.window_overlap);

        let spectrogram = spectrograph.create_in_memory(false);

        let filter_count: usize = 40;
        let power_spectrum_size = 241;
        let window_size = 480;
        let sample_rate_usize: usize = 16000;

        // build up the mel filter matrix
        let mut mel_filter_matrix =
            DMatrix::<f64>::zeros(filter_count, power_spectrum_size);
        for (row, col, coefficient) in mel::enumerate_mel_scaling_matrix(
            sample_rate_usize,
            window_size,
            power_spectrum_size,
            filter_count,
        ) {
            mel_filter_matrix[(row, col)] = coefficient;
        }

        let spectrogram = spectrogram.into_iter().map(f64::from);
        let power_spectrum_matrix_unflipped: DMatrix<f64> =
            DMatrix::from_iterator(49, power_spectrum_size, spectrogram);
        let power_spectrum_matrix_transposed =
            power_spectrum_matrix_unflipped.transpose();
        let mut power_spectrum_vec: Vec<_> =
            power_spectrum_matrix_transposed.row_iter().collect();
        power_spectrum_vec.reverse();
        let power_spectrum_matrix: DMatrix<f64> =
            DMatrix::from_rows(&power_spectrum_vec);
        let mel_spectrum_matrix = &mel_filter_matrix * &power_spectrum_matrix;
        let mel_spectrum_matrix = mel_spectrum_matrix.map(libm::sqrt);

        let min_value = mel_spectrum_matrix
            .data
            .as_vec()
            .iter()
            .fold(f64::INFINITY, |a, &b| a.min(b));
        let max_value = mel_spectrum_matrix
            .data
            .as_vec()
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

        let res: Vec<u32> = mel_spectrum_matrix
            .data
            .as_vec()
            .iter()
            .map(|freq| 65536.0 * (freq - min_value) / (max_value - min_value))
            .map(|freq| freq as u32)
            .collect();

        let mut out = [0; 1960];
        out.copy_from_slice(&res[..1960]);
        out
    }
}

impl Default for ShortTimeFourierTransform {
    fn default() -> Self { ShortTimeFourierTransform::new() }
}

impl Transform<Tensor<i16>> for ShortTimeFourierTransform {
    type Output = Tensor<u32>;

    fn transform(&mut self, input: Tensor<i16>) -> Self::Output {
        let input = input.elements().to_vec();
        let stft = self.transform_inner(input);
        Tensor::new_row_major(Arc::new(stft), alloc::vec![1, stft.len()])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let mut fft_pb = ShortTimeFourierTransform::new();
        fft_pb.set_sample_rate(16000);
        let input = Tensor::new_vector(vec![0; 16000]);

        let got = fft_pb.transform(input);

        assert_eq!(got.dimensions(), &[1, 1960]);
    }
}
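As the test above shows, this proc block turns a 1D tensor of 16 kHz i16 audio into a [1, 1960] tensor: a short-time Fourier transform is computed with sonogram, projected through a 40-band mel filter bank, and the magnitudes are rescaled into the range [0, 65536]. A minimal caller sketch, not part of this diff and assuming the crate is imported as fft; set_sample_rate is the setter used in the test above, presumably generated by the ProcBlock derive:

use hotg_rune_proc_blocks::{Tensor, Transform};
use fft::ShortTimeFourierTransform;

fn main() {
    // One second of silence at the default 16 kHz sample rate.
    let audio = Tensor::new_vector(vec![0_i16; 16000]);

    let mut stft = ShortTimeFourierTransform::new();
    // 16000 is also the default sample rate.
    stft.set_sample_rate(16000);

    // The output is a [1, 1960] tensor of u32 spectrogram magnitudes.
    let spectrogram = stft.transform(audio);
    assert_eq!(spectrogram.dimensions(), &[1, 1960]);
}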
@@ -0,0 +1,11 @@
[package]
name = "image-normalization"
version = "0.8.0"
edition = "2018"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
num-traits = { version = "0.2.14", default-features = false }
hotg-rune-proc-blocks = "^0.8.0"
@@ -0,0 +1,76 @@
#![no_std]

#[cfg(test)]
#[macro_use]
extern crate alloc;

use num_traits::{Bounded, ToPrimitive};
use hotg_rune_proc_blocks::{ProcBlock, Transform, Tensor};

/// A normalization routine which takes some tensor of integers and fits their
/// values to the range `[0, 1]` as `f32`'s.
#[derive(Debug, Default, Clone, PartialEq, ProcBlock)]
#[non_exhaustive]
#[transform(inputs = [u8; _], outputs = [f32; _])]
#[transform(inputs = [i8; _], outputs = [f32; _])]
#[transform(inputs = [u16; _], outputs = [f32; _])]
#[transform(inputs = [i16; _], outputs = [f32; _])]
#[transform(inputs = [u32; _], outputs = [f32; _])]
#[transform(inputs = [i32; _], outputs = [f32; _])]
pub struct ImageNormalization {}

impl ImageNormalization {
    fn check_input_dimensions(&self, dimensions: &[usize]) {
        match *dimensions {
            [_, _, _, 3] => {},
            [_, _, _, channels] => panic!(
                "The number of channels should be either 1 or 3, found {}",
                channels
            ),
            _ => panic!("The image normalization proc block only supports outputs of the form [frames, rows, columns, channels], found {:?}", dimensions),
        }
    }
}

impl<T> Transform<Tensor<T>> for ImageNormalization
where
    T: Bounded + ToPrimitive + Copy,
{
    type Output = Tensor<f32>;

    fn transform(&mut self, input: Tensor<T>) -> Self::Output {
        self.check_input_dimensions(input.dimensions());
        input.map(|_, &value| normalize(value).expect("Cast should never fail"))
    }
}

fn normalize<T>(value: T) -> Option<f32>
where
    T: Bounded + ToPrimitive,
{
    let min = T::min_value().to_f32()?;
    let max = T::max_value().to_f32()?;
    let value = value.to_f32()?;
    debug_assert!(min <= value && value <= max);

    Some((value - min) / (max - min))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalizing_with_default_distribution_is_noop() {
        let dims = vec![1, 1, 1, 3];
        let input: Tensor<u8> =
            Tensor::new_row_major(vec![0, 127, 255].into(), dims.clone());
        let mut norm = ImageNormalization::default();
        let should_be: Tensor<f32> =
            Tensor::new_row_major(vec![0.0, 127.0 / 255.0, 1.0].into(), dims);

        let got = norm.transform(input);

        assert_eq!(got, should_be);
    }
}
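The normalize helper above rescales each value linearly: for an integer type with bounds [min, max], the output is (value - min) / (max - min), so a u8 of 127 becomes 127 / 255 ≈ 0.498. A minimal caller sketch, not part of this diff and assuming the crate is imported as image_normalization:

use hotg_rune_proc_blocks::{Tensor, Transform};
use image_normalization::ImageNormalization;

fn main() {
    // A single 1x1 RGB frame; dimensions are [frames, rows, columns, channels].
    let image: Tensor<u8> =
        Tensor::new_row_major(vec![0, 127, 255].into(), vec![1, 1, 1, 3]);

    let mut norm = ImageNormalization::default();

    // Each u8 value v becomes (v - 0) / (255 - 0) as an f32 in [0, 1].
    let scaled = norm.transform(image);

    assert_eq!(scaled.elements()[..3], [0.0, 127.0 / 255.0, 1.0]);
}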
@@ -0,0 +1,10 @@
[package]
name = "label"
version = "0.8.0"
edition = "2018"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
hotg-rune-proc-blocks = "^0.8.0"