Prints benchmark results in a neat table and attempts to run every benchmark #1464

Merged
merged 7 commits on Mar 15, 2024
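
Based on the Display implementation added for BenchmarkCollection in backend-comparison/src/persistence/base.rs, the summary table printed at the end of a run should look roughly like this (column widths come from the format string; the row uses the sample record from the new unit test, not a real measurement):

| Benchmark      | Backend                            | Median         |
|----------------|------------------------------------|----------------|
| unary          | candle-Cuda(0)                     | 8.592ms        |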
90 changes: 65 additions & 25 deletions backend-comparison/src/burnbenchapp/base.rs
@@ -1,20 +1,22 @@
use super::{
auth::{save_token, CLIENT_ID},
App,
};
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
use arboard::Clipboard;
use clap::{Parser, Subcommand, ValueEnum};
use github_device_flow::{self, DeviceFlow};
use serde_json;
use std::fs;
use std::io::{BufRead, BufReader, Result as ioResult};
use std::{
process::{Command, Stdio},
process::{Command, ExitStatus, Stdio},
thread, time,
};

use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter};

use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};

use super::{
auth::{save_token, CLIENT_ID},
App,
};

const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
@@ -184,17 +186,12 @@ fn command_run(run_args: RunArgs) {
}
let total_combinations = run_args.backends.len() * run_args.benches.len();
println!(
"Executing the following benchmark and backend combinations (Total: {}):",
"Executing benchmark and backend combinations in total: {}",
total_combinations
);
for backend in &run_args.backends {
for bench in &run_args.benches {
println!("- Benchmark: {}, Backend: {}", bench, backend);
}
}
let mut app = App::new();
app.init();
println!("Running benchmarks...");
println!("Running benchmarks...\n");
app.run(
&run_args.benches,
&run_args.backends,
@@ -204,7 +201,7 @@ fn command_run(run_args: RunArgs) {
}

#[allow(unused)] // for tui as this is WIP
pub(crate) fn run_cargo(command: &str, params: &[&str]) {
pub(crate) fn run_cargo(command: &str, params: &[&str]) -> ioResult<ExitStatus> {
let mut cargo = Command::new("cargo")
.arg(command)
.arg("--color=always")
@@ -213,22 +210,36 @@ pub(crate) fn run_cargo(command: &str, params: &[&str]) {
.stderr(Stdio::inherit())
.spawn()
.expect("cargo process should run");
let status = cargo.wait().expect("");
if !status.success() {
std::process::exit(status.code().unwrap_or(1));
}
cargo.wait()
}

pub(crate) fn run_backend_comparison_benchmarks(
benches: &[BenchmarkValues],
backends: &[BackendValues],
token: Option<&str>,
) {
// Iterate over each combination of backend and bench
for backend in backends.iter() {
for bench in benches.iter() {
// Prefix and postfix for titles
let filler = ["="; 10].join("");

// Delete the file that lists the benchmark result file paths, if it exists
let benchmark_results_file = dirs::home_dir()
.expect("Home directory should exist")
.join(".cache")
.join("burn")
.join("backend-comparison")
.join("benchmark_results.txt");

fs::remove_file(benchmark_results_file.clone()).ok();

// Iterate through every combination of benchmark and backend
for bench in benches.iter() {
for backend in backends.iter() {
let bench_str = bench.to_string();
let backend_str = backend.to_string();
println!(
"{}Benchmarking {} on {}{}",
filler, bench_str, backend_str, filler
);
let mut args = vec![
"-p",
"backend-comparison",
@@ -246,7 +257,36 @@ pub(crate) fn run_backend_comparison_benchmarks(
args.push("--sharing-token");
args.push(t);
}
run_cargo("bench", &args);
let status = run_cargo("bench", &args).unwrap();
if !status.success() {
println!(
"Benchmark {} didn't ran successfully on the backend {}",
bench_str, backend_str
);
continue;
}
}
}

// Iterate through each benchmark result file listed in backend-comparison/benchmark_results.txt
// and print them in a single table.
let mut benchmark_results = BenchmarkCollection::default();
if let Ok(file) = fs::File::open(benchmark_results_file.clone()) {
let file_reader = BufReader::new(file);
for file in file_reader.lines() {
let file_path = file.unwrap();
if let Ok(br_file) = fs::File::open(file_path.clone()) {
let benchmarkrecord =
serde_json::from_reader::<_, BenchmarkRecord>(br_file).unwrap();
benchmark_results.records.push(benchmarkrecord)
} else {
println!("Cannot find the benchmark-record file: {}", file_path);
};
}
println!(
"{}Benchmark Results{}\n\n{}",
filler, filler, benchmark_results
);
fs::remove_file(benchmark_results_file).ok();
}
}
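
Since run_cargo now returns io::Result<ExitStatus> instead of exiting the process on failure, the caller decides what a failed benchmark means for the rest of the run. Below is a minimal standalone sketch of that pattern; the benchmark names and cargo arguments are placeholders, not the exact flags built by run_backend_comparison_benchmarks:

use std::io::Result as IoResult;
use std::process::{Command, ExitStatus, Stdio};

// Spawn `cargo <command> <params...>` and report its exit status instead of
// terminating the current process, mirroring the new run_cargo signature.
fn run_cargo(command: &str, params: &[&str]) -> IoResult<ExitStatus> {
    Command::new("cargo")
        .arg(command)
        .args(params)
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .spawn()?
        .wait()
}

fn main() -> IoResult<()> {
    // Placeholder benchmark list; the real runner builds its combinations from CLI arguments.
    for bench in ["unary", "binary"] {
        let status = run_cargo("bench", &["--bench", bench])?;
        if !status.success() {
            // Log and keep going so the remaining benchmarks still run.
            eprintln!("benchmark {bench} failed with status {:?}", status.code());
        }
    }
    Ok(())
}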
207 changes: 201 additions & 6 deletions backend-comparison/src/persistence/base.rs
@@ -1,19 +1,19 @@
use std::fs;

use burn::{
serde::{ser::SerializeStruct, Serialize, Serializer},
serde::{de::Visitor, ser::SerializeStruct, Deserialize, Serialize, Serializer},
tensor::backend::Backend,
};
use burn_common::benchmark::BenchmarkResult;
use dirs;
use reqwest::header::{HeaderMap, ACCEPT, AUTHORIZATION, USER_AGENT};
use serde_json;

use std::fmt::Display;
use std::time::Duration;
use std::{fs, io::Write};
#[derive(Default, Clone)]
pub struct BenchmarkRecord {
backend: String,
device: String,
results: BenchmarkResult,
pub results: BenchmarkResult,
}

/// Save the benchmarks results on disk.
@@ -77,10 +77,22 @@ pub fn save<B: Backend>(
record.results.name, record.results.timestamp
);
let file_path = cache_dir.join(file_name);
let file = fs::File::create(file_path).expect("Benchmark file should exist or be created");
let file =
fs::File::create(file_path.clone()).expect("Benchmark file should exist or be created");
serde_json::to_writer_pretty(file, &record)
.expect("Benchmark file should be updated with benchmark results");

// Append the benchmark result file path to the benchmark_results.txt file in the cache folder so the benchmark runner can pick it up later
let benchmark_results_path = cache_dir.join("benchmark_results.txt");
let mut benchmark_results_file = fs::OpenOptions::new()
.append(true)
.create(true)
.open(benchmark_results_path)
.unwrap();
benchmark_results_file
.write_all(format!("{}\n", file_path.to_string_lossy()).as_bytes())
.unwrap();

if url.is_some() {
println!("Sharing results...");
let client = reqwest::blocking::Client::new();
@@ -154,3 +166,186 @@ impl Serialize for BenchmarkRecord {
)
}
}

struct BenchmarkRecordVisitor;

impl<'de> Visitor<'de> for BenchmarkRecordVisitor {
type Value = BenchmarkRecord;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "Serialized Json object of BenchmarkRecord")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: burn::serde::de::MapAccess<'de>,
{
let mut br = BenchmarkRecord::default();
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"backend" => br.backend = map.next_value::<String>()?,
"device" => br.device = map.next_value::<String>()?,
"gitHash" => br.results.git_hash = map.next_value::<String>()?,
"name" => br.results.name = map.next_value::<String>()?,
"max" => {
let value = map.next_value::<u64>()?;
br.results.computed.max = Duration::from_micros(value);
}
"mean" => {
let value = map.next_value::<u64>()?;
br.results.computed.mean = Duration::from_micros(value);
}
"median" => {
let value = map.next_value::<u64>()?;
br.results.computed.median = Duration::from_micros(value);
}
"min" => {
let value = map.next_value::<u64>()?;
br.results.computed.min = Duration::from_micros(value);
}
"options" => br.results.options = map.next_value::<Option<String>>()?,
"rawDurations" => br.results.raw.durations = map.next_value::<Vec<Duration>>()?,
"shapes" => br.results.shapes = map.next_value::<Vec<Vec<usize>>>()?,
"timestamp" => br.results.timestamp = map.next_value::<u128>()?,
"variance" => {
let value = map.next_value::<u64>()?;
br.results.computed.variance = Duration::from_micros(value)
}

"numSamples" => _ = map.next_value::<usize>()?,
_ => panic!("Unexpected Key: {}", key),
}
}

Ok(br)
}
}

impl<'de> Deserialize<'de> for BenchmarkRecord {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: burn::serde::Deserializer<'de>,
{
deserializer.deserialize_map(BenchmarkRecordVisitor)
}
}

#[derive(Default)]
pub(crate) struct BenchmarkCollection {
pub records: Vec<BenchmarkRecord>,
}

impl Display for BenchmarkCollection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(
f,
"| {0:<15}| {1:<35}| {2:<15}|\n|{3:-<16}|{4:-<36}|{5:-<16}|",
"Benchmark", "Backend", "Median", "", "", ""
)?;
for record in self.records.iter() {
let backend = [record.backend.clone(), record.device.clone()].join("-");
writeln!(
f,
"| {0:<15}| {1:<35}| {2:<15.3?}|",
record.results.name, backend, record.results.computed.median
)?;
}

Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn get_benchmark_result() {
let sample_result = r#"{
"backend": "candle",
"device": "Cuda(0)",
"gitHash": "02d37011ab4dc773286e5983c09cde61f95ba4b5",
"name": "unary",
"max": 8858,
"mean": 8629,
"median": 8592,
"min": 8506,
"numSamples": 10,
"options": null,
"rawDurations": [
{
"secs": 0,
"nanos": 8858583
},
{
"secs": 0,
"nanos": 8719822
},
{
"secs": 0,
"nanos": 8705335
},
{
"secs": 0,
"nanos": 8835636
},
{
"secs": 0,
"nanos": 8592507
},
{
"secs": 0,
"nanos": 8506423
},
{
"secs": 0,
"nanos": 8534337
},
{
"secs": 0,
"nanos": 8506627
},
{
"secs": 0,
"nanos": 8521615
},
{
"secs": 0,
"nanos": 8511474
}
],
"shapes": [
[
32,
512,
1024
]
],
"timestamp": 1710208069697,
"variance": 0
}"#;
let record = serde_json::from_str::<BenchmarkRecord>(sample_result).unwrap();
assert!(record.backend == "candle");
assert!(record.device == "Cuda(0)");
assert!(record.results.git_hash == "02d37011ab4dc773286e5983c09cde61f95ba4b5");
assert!(record.results.name == "unary");
assert!(record.results.computed.max.as_micros() == 8858);
assert!(record.results.computed.mean.as_micros() == 8629);
assert!(record.results.computed.median.as_micros() == 8592);
assert!(record.results.computed.min.as_micros() == 8506);
assert!(record.results.options.is_none());
assert!(record.results.shapes == vec![vec![32, 512, 1024]]);
assert!(record.results.timestamp == 1710208069697);
assert!(record.results.computed.variance.as_micros() == 0);

// Check raw durations
assert!(record.results.raw.durations[0] == Duration::from_nanos(8858583));
assert!(record.results.raw.durations[1] == Duration::from_nanos(8719822));
assert!(record.results.raw.durations[2] == Duration::from_nanos(8705335));
assert!(record.results.raw.durations[3] == Duration::from_nanos(8835636));
assert!(record.results.raw.durations[4] == Duration::from_nanos(8592507));
assert!(record.results.raw.durations[5] == Duration::from_nanos(8506423));
assert!(record.results.raw.durations[6] == Duration::from_nanos(8534337));
assert!(record.results.raw.durations[7] == Duration::from_nanos(8506627));
assert!(record.results.raw.durations[8] == Duration::from_nanos(8521615));
assert!(record.results.raw.durations[9] == Duration::from_nanos(8511474));
}
}
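
A hedged, crate-internal sketch of how the new pieces compose: deserialize one record written by save() and print it through the BenchmarkCollection table formatter. It assumes the code lives inside backend-comparison (BenchmarkCollection is pub(crate)), and the file name is a placeholder — real record files are named from the benchmark name and timestamp:

use std::fs;

use crate::persistence::{BenchmarkCollection, BenchmarkRecord};

fn print_one_record() {
    // Placeholder file name; `save()` derives real names from benchmark name and timestamp.
    let record_path = dirs::home_dir()
        .expect("Home directory should exist")
        .join(".cache")
        .join("burn")
        .join("backend-comparison")
        .join("some_record.json");
    let file = fs::File::open(record_path).expect("record file should exist");
    let record = serde_json::from_reader::<_, BenchmarkRecord>(file)
        .expect("record should deserialize");

    // Collect one (or many) records and print them as a single table.
    let mut collection = BenchmarkCollection::default();
    collection.records.push(record);
    println!("{}", collection);
}

The design relies on each `cargo bench` invocation appending its result file path to benchmark_results.txt, which lets the runner aggregate records produced by separate benchmark processes into one table at the end.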