From 09e6e65547e2c5d434bf9d17408faaf04c58c5e5 Mon Sep 17 00:00:00 2001
From: "Felix L." <50841330+Felix-El@users.noreply.github.com>
Date: Thu, 15 Aug 2024 20:50:09 +0200
Subject: [PATCH] Avoid memory leaks

This commit replaces the `threadpool` crate with a handcrafted solution
based on scoped threads. This leaves `valgrind` much happier than before.
We also lose some dependency baggage.
---
 CHANGELOG.md |  4 ++++
 Cargo.toml   |  3 +--
 src/lib.rs   | 52 ++++++++++++++++++++++++++++++++--------------------
 src/pool.rs  | 27 +++++++++++++++++++++++++++
 4 files changed, 64 insertions(+), 22 deletions(-)
 create mode 100644 src/pool.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4631ed..05a9f21 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
 and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+- Replace the dependency on the `threadpool` crate with a custom solution
+  built only on the standard library, using scoped threads.
+  This fixes memory leaks observed when running under `valgrind`.
+- Bump MSRV to 1.63, as required for scoped threads.
 
 ## [0.7.3] - 2024-05-10
 - Default to single-threaded tests for WebAssembly (thanks @alexcrichton) in [#41](https://github.com/LukasKalbertodt/libtest-mimic/pull/41)
diff --git a/Cargo.toml b/Cargo.toml
index 26d4bf1..6bae64b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,7 +3,7 @@ name = "libtest-mimic"
 version = "0.7.3"
 authors = ["Lukas Kalbertodt <lukas.kalbertodt@gmail.com>"]
 edition = "2021"
-rust-version = "1.60"
+rust-version = "1.63"
 
 description = """
 Write your own test harness that looks and behaves like the built-in test \
@@ -20,7 +20,6 @@ exclude = [".github"]
 
 [dependencies]
 clap = { version = "4.0.8", features = ["derive"] }
-threadpool = "1.8.1"
 termcolor = "1.0.5"
 escape8259 = "0.5.2"
 
diff --git a/src/lib.rs b/src/lib.rs
index be1e65c..e4c2537 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -71,18 +71,22 @@
 
 #![forbid(unsafe_code)]
 
-use std::{borrow::Cow, fmt, process::{self, ExitCode}, sync::mpsc, time::Instant};
+use std::{
+    borrow::Cow,
+    fmt,
+    process::{self, ExitCode},
+    sync::mpsc,
+    time::Instant,
+};
 
 mod args;
+mod pool;
 mod printer;
 
 use printer::Printer;
-use threadpool::ThreadPool;
 
 pub use crate::args::{Arguments, ColorSetting, FormatSetting};
 
-
-
 /// A single test or benchmark.
 ///
 /// The original `libtest` often calls benchmarks "tests", which is a bit
@@ -143,8 +147,9 @@ impl Trial {
                 Err(failed) => Outcome::Failed(failed),
                 Ok(_) if test_mode => Outcome::Passed,
                 Ok(Some(measurement)) => Outcome::Measured(measurement),
-                Ok(None)
-                    => Outcome::Failed("bench runner returned `Ok(None)` in bench mode".into()),
+                Ok(None) => {
+                    Outcome::Failed("bench runner returned `Ok(None)` in bench mode".into())
+                }
             }),
             info: TestInfo {
                 name: name.into(),
@@ -284,13 +289,11 @@ impl Failed {
 impl<M: std::fmt::Display> From<M> for Failed {
     fn from(msg: M) -> Self {
         Self {
-            msg: Some(msg.to_string())
+            msg: Some(msg.to_string()),
         }
     }
 }
 
-
-
 /// The outcome of performing a test/benchmark.
 #[derive(Debug, Clone)]
 enum Outcome {
@@ -473,7 +476,7 @@ pub fn run(args: &Arguments, mut tests: Vec<Trial>) -> Conclusion {
             Outcome::Failed(failed) => {
                 failed_tests.push((test, failed.msg));
                 conclusion.num_failed += 1;
-            },
+            }
             Outcome::Ignored => conclusion.num_ignored += 1,
             Outcome::Measured(_) => conclusion.num_measured += 1,
         }
@@ -481,7 +484,14 @@ pub fn run(args: &Arguments, mut tests: Vec<Trial>) -> Conclusion {
 
     // Execute all tests.
     let test_mode = !args.bench;
-    if platform_defaults_to_one_thread() || args.test_threads == Some(1) {
+
+    let num_threads = platform_defaults_to_one_thread()
+        .then_some(1)
+        .or(args.test_threads)
+        .or_else(|| std::thread::available_parallelism().ok().map(Into::into))
+        .unwrap_or(1);
+
+    if num_threads == 1 {
         // Run test sequentially in main thread
         for test in tests {
             // Print `test foo ...`, run the test, then print the outcome in
@@ -496,28 +506,29 @@ pub fn run(args: &Arguments, mut tests: Vec<Trial>) -> Conclusion {
         }
     } else {
         // Run test in thread pool.
-        let pool = match args.test_threads {
-            Some(num_threads) => ThreadPool::new(num_threads),
-            None => ThreadPool::default()
-        };
+        let num_tests = tests.len();
         let (sender, receiver) = mpsc::channel();
 
-        let num_tests = tests.len();
-        for test in tests {
+        let mut tasks: Vec<pool::BoxedTask> = Default::default();
+
+        for test in tests.into_iter() {
             if args.is_ignored(&test) {
                 sender.send((Outcome::Ignored, test.info)).unwrap();
             } else {
                 let sender = sender.clone();
-                pool.execute(move || {
+
+                tasks.push(Box::new(move || {
                     // It's fine to ignore the result of sending. If the
                     // receiver has hung up, everything will wind down soon
                     // anyway.
                     let outcome = run_single(test.runner, test_mode);
                     let _ = sender.send((outcome, test.info));
-                });
+                }));
             }
         }
 
+        pool::scoped_run_tasks(tasks, num_threads);
+
         for (outcome, test_info) in receiver.iter().take(num_tests) {
             // In multithreaded mode, we do only print the start of the line
             // after the test ran, as otherwise it would lead to terribly
@@ -552,7 +563,8 @@ fn run_single(runner: Box<dyn FnOnce() -> Outcome + Send>, test_mode: bool)
         // The `panic` information is just an `Any` object representing the
         // value the panic was invoked with. For most panics (which use
         // `panic!` like `println!`), this is either `&str` or `String`.
-        let payload = e.downcast_ref::<String>()
+        let payload = e
+            .downcast_ref::<String>()
             .map(|s| s.as_str())
             .or(e.downcast_ref::<&str>().map(|s| *s));
 
diff --git a/src/pool.rs b/src/pool.rs
new file mode 100644
index 0000000..409ec9a
--- /dev/null
+++ b/src/pool.rs
@@ -0,0 +1,27 @@
+use std::{sync, thread};
+
+pub(crate) type Task = dyn FnOnce() + Send;
+pub(crate) type BoxedTask = Box<Task>;
+
+pub(crate) fn scoped_run_tasks(
+    tasks: Vec<BoxedTask>,
+    num_threads: usize,
+) {
+    if num_threads < 2 {
+        // There is another code path for num_threads == 1 running entirely in the main thread.
+        panic!("`scoped_run_tasks` may not be called with `num_threads` less than 2");
+    }
+
+    let sync_iter = sync::Mutex::new(tasks.into_iter());
+    let next_task = || sync_iter.lock().unwrap().next();
+
+    thread::scope(|scope| {
+        for _ in 0..num_threads {
+            scope.spawn(|| {
+                while let Some(task) = next_task() {
+                    task();
+                }
+            });
+        }
+    });
+}
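Editor's note (not part of the patch): below is a minimal, standalone sketch of the scoped-thread pattern that the new src/pool.rs introduces, for readers who want to try it in isolation. The names `run_tasks` and the `main` demo are illustrative only; the patch's actual entry point is `pool::scoped_run_tasks` shown above. Workers pull boxed tasks from a mutex-guarded iterator, and `std::thread::scope` joins every worker before returning, which is presumably why no thread state outlives the call and valgrind output gets cleaner.

use std::{sync::Mutex, thread};

// Illustrative stand-in for the pattern in src/pool.rs: worker threads
// repeatedly take the next task from a shared, mutex-guarded iterator.
fn run_tasks(tasks: Vec<Box<dyn FnOnce() + Send>>, num_threads: usize) {
    // The iterator itself is the work queue; the Mutex serializes access.
    let queue = Mutex::new(tasks.into_iter());
    // The guard is dropped before the task runs, so the lock is only held
    // while fetching the next task, not while executing it.
    let next = || queue.lock().unwrap().next();

    // `thread::scope` joins all spawned workers before it returns.
    thread::scope(|scope| {
        for _ in 0..num_threads.max(1) {
            scope.spawn(|| {
                while let Some(task) = next() {
                    task();
                }
            });
        }
    });
}

fn main() {
    // Build a few boxed tasks and run them on four scoped workers.
    let tasks: Vec<Box<dyn FnOnce() + Send>> = (0..8)
        .map(|i| Box::new(move || println!("task {i}")) as Box<dyn FnOnce() + Send>)
        .collect();
    run_tasks(tasks, 4);
}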