Skip to content

Commit

Permalink
add support for printing input/output to NNs (#46)
Browse files Browse the repository at this point in the history
Found this sitting on my local disk; guess it was useful at some point.
  • Loading branch information
tgolsson authored Aug 31, 2023
1 parent 38e8fda commit d8e0f2f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
7 changes: 1 addition & 6 deletions crates/cervo-core/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,7 @@ impl Batcher {
}
}

Ok(self
.scratch
.ids
.drain(..)
.zip(outputs.into_iter())
.collect::<_>())
Ok(self.scratch.ids.drain(..).zip(outputs).collect::<_>())
}

/// Check if there is any data to run on here.
Expand Down
4 changes: 3 additions & 1 deletion crates/cervo-runtime/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ impl Runtime {
let mut unselected_jobs = Vec::new();

while let Some(ticket) = self.queue.pop() {
let Some(model) = self.models.iter().find(|m| m.id == ticket.1) else {continue};
let Some(model) = self.models.iter().find(|m| m.id == ticket.1) else {
continue;
};

if model.needs_to_execute()
&& (selected_jobs.is_empty() || model.can_run_in_time(available_cpu_time))
Expand Down
55 changes: 50 additions & 5 deletions crates/cervo/src/commands/run.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{bail, Result};
use cervo_asset::AssetData;
use cervo_core::prelude::{Inferer, InfererExt, State};
use cervo_core::prelude::{Inferer, InfererExt, Response, State};
use clap::Parser;

use std::{collections::HashMap, fs::File, path::PathBuf, time::Instant};
Expand All @@ -19,6 +19,12 @@ pub(crate) struct Args {
/// An epsilon key to randomize noise.
#[clap(short, long)]
with_epsilon: Option<String>,

#[clap(long)]
print_output: bool,

#[clap(long)]
print_input: bool,
}

fn build_inputs_from_desc(count: u64, inputs: &[(String, Vec<usize>)]) -> HashMap<u64, State<'_>> {
Expand All @@ -39,6 +45,28 @@ fn build_inputs_from_desc(count: u64, inputs: &[(String, Vec<usize>)]) -> HashMa
.collect()
}

fn indent_by(target: String, prefix_len: usize) -> String {
let prefix = " ".repeat(prefix_len);

target
.lines()
.map(|line| format!("{}{}", prefix, line))
.collect::<Vec<_>>()
.join("\n")
}

fn print_input(obs: &HashMap<u64, State<'_>>) {
let formatted = format!("{:#?}", obs);
let indented = indent_by(formatted, 4);
println!("Inputs:\n{}", indented);
}

fn print_output(obs: &HashMap<u64, Response<'_>>) {
let formatted = format!("{:#?}", obs);
let indented = indent_by(formatted, 4);
println!("Outputs:\n{}", indented);
}

pub(super) fn run(config: Args) -> Result<()> {
let mut reader = File::open(&config.file)?;
let inferer = if cervo_nnef::is_nnef_tar(&config.file) {
Expand All @@ -65,19 +93,36 @@ pub(super) fn run(config: Args) -> Result<()> {
.collect::<Vec<_>>();

let observations = build_inputs_from_desc(config.batch_size as u64, &shapes);

if config.print_input {
print_input(&observations);
}

inferer.infer_batch(observations.clone())?;

let start = Instant::now();
inferer.infer_batch(observations)?;
start.elapsed()
let res = inferer.infer_batch(observations)?;

let dur = start.elapsed();
if config.print_output {
print_output(&res);
}

dur
} else {
let shapes = inferer.input_shapes().to_vec();
let observations = build_inputs_from_desc(config.batch_size as u64, &shapes);
inferer.infer_batch(observations.clone())?;

let start = Instant::now();
inferer.infer_batch(observations)?;
start.elapsed()
let res = inferer.infer_batch(observations)?;

let dur = start.elapsed();
if config.print_output {
print_output(&res);
}

dur
};

println!(
Expand Down

0 comments on commit d8e0f2f

Please sign in to comment.