From 2598d56e5db1d9296054945e48fb754e7bbf4646 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 11 Jun 2024 12:53:21 +0200 Subject: [PATCH] feat: Add progress bar on terminal --- Cargo.lock | 138 +++++++++++++++++++++++----------- Cargo.toml | 5 +- python/nutpie/compile_pymc.py | 7 +- python/nutpie/compile_stan.py | 7 +- python/nutpie/sample.py | 28 ++++--- src/progress.rs | 63 +++++++++++++++- src/wrapper.rs | 91 ++++++++++++++++------ tests/test_pymc.py | 21 ++++++ 8 files changed, 272 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff12f73..8aa21f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,9 +60,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "arrow" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" +checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" dependencies = [ "arrow-arith", "arrow-array", @@ -78,9 +78,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" +checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" dependencies = [ "arrow-array", "arrow-buffer", @@ -93,9 +93,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" +checksum = "d33238427c60271710695f17742f45b1a5dc5bcfc5c15331c25ddfe7abf70d97" dependencies = [ "ahash", "arrow-buffer", @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" +checksum = "fe9b95e825ae838efaf77e366c00d3fc8cca78134c9db497d6bda425f2e7b7c1" dependencies = [ "bytes", "half", @@ -120,9 +120,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" +checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab" dependencies = [ "arrow-array", "arrow-buffer", @@ -140,9 +140,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" +checksum = "cb29be98f987bcf217b070512bb7afba2f65180858bca462edf4a39d84a23e10" dependencies = [ "arrow-buffer", "arrow-schema", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" +checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" +checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" dependencies = [ "ahash", "arrow-array", @@ -182,18 +182,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" +checksum = "32aae6a60458a2389c0da89c9de0b7932427776127da1a738e2efc21d32f3393" dependencies = [ "bitflags 2.5.0", ] [[package]] name = "arrow-select" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" +checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102" dependencies = [ "ahash", "arrow-array", @@ -205,9 +205,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" +checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" dependencies = [ "arrow-array", "arrow-buffer", @@ -342,9 +342,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" [[package]] name = "cexpr" @@ -402,9 +402,9 @@ dependencies = [ [[package]] name = "clang-sys" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a483f3cbf7cec2e153d424d0e92329d816becc6421389bd494375c6065921b9b" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", @@ -413,18 +413,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" dependencies = [ "anstyle", "clap_lex", @@ -432,9 +432,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" [[package]] name = "coe-rs" @@ -442,6 +442,19 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys", +] + [[package]] name = "const-random" version = "0.1.18" @@ -586,6 +599,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "enum-as-inner" version = "0.6.0" @@ -878,12 +897,34 @@ dependencies = [ "cc", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "indoc" version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + [[package]] name = "is-terminal" version = "0.4.12" @@ -1306,6 +1347,12 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "numpy" version = "0.21.0" @@ -1329,6 +1376,7 @@ dependencies = [ "arrow", "bridgestan", "criterion", + "indicatif", "itertools 0.13.0", "ndarray", "numpy", @@ -1346,9 +1394,9 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f206f155b5652ce881cff9d8cbf683f3f73c9e476458cbc2d8a160b4aa443a9" +checksum = "93c95d63c0d52a79a61ff7343323fc73f7d41415cd3fc4f0439ea4efb9183dfb" dependencies = [ "anyhow", "arrow", @@ -1501,9 +1549,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.84" +version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" dependencies = [ "unicode-ident", ] @@ -1698,9 +1746,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -1710,9 +1758,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", @@ -1721,9 +1769,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "rustc-hash" @@ -1942,6 +1990,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + [[package]] name = "unindent" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index 0e3682c..a2c187c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.10.0" +nuts-rs = "0.11.0" numpy = "0.21.0" ndarray = "0.15.6" rand = "0.8.5" @@ -29,7 +29,7 @@ thiserror = "1.0.44" rand_chacha = "0.3.1" rayon = "1.9.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "51.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.13.0" bridgestan = "2.4.1" @@ -37,6 +37,7 @@ rand_distr = "0.4.3" smallvec = "1.11.0" upon = { version = "0.8.1", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } +indicatif = "0.17.8" [dependencies.pyo3] version = "0.21.0" diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index b9c28b2..4df2994 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -87,10 +87,13 @@ def with_data(self, **updates): user_data=user_data, ) - def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): + def _make_sampler(self, settings, init_mean, cores, progress_type): model = self._make_model(init_mean) return _lib.PySampler.from_pymc( - settings, cores, model, template, rate, callback + settings, + cores, + model, + progress_type, ) def _make_model(self, init_mean): diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 03f5444..7a28052 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -80,10 +80,13 @@ def _make_model(self, init_mean): return self.with_data().model return self.model - def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): + def _make_sampler(self, settings, init_mean, cores, progress_type): model = self._make_model(init_mean) return _lib.PySampler.from_stan( - settings, cores, model, template, rate, callback + settings, + cores, + model, + progress_type, ) @property diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 9c818c8..20beecf 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -254,7 +254,7 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): value="{{ chain.finished_draws }}"> - {{ chain.total_draws }} + {{ chain.finished_draws }} {{ chain.divergences }} {{ chain.step_size }} {{ chain.latest_num_steps }} @@ -335,16 +335,16 @@ def __init__( self._html = None - if progress_template is None: - progress_template = _progress_template + if not progress_bar: + progress_type = _lib.ProgressType.none() - if progress_style is None: - progress_style = _progress_style + elif in_notebook(): + if progress_template is None: + progress_template = _progress_template + + if progress_style is None: + progress_style = _progress_style - if not progress_bar or not in_notebook(): - progress_template = "" - callback = None - else: import IPython self._html = "" @@ -358,13 +358,17 @@ def callback(formatted): self._html = formatted self.display_id.update(self) + progress_type = _lib.ProgressType.template_callback( + progress_rate, progress_template, cores, callback + ) + else: + progress_type = _lib.ProgressType.indicatif(progress_rate) + self._sampler = compiled_model._make_sampler( settings, init_mean, cores, - progress_template, - progress_rate, - callback=callback, + progress_type, ) def wait(self, *, timeout=None): diff --git a/src/progress.rs b/src/progress.rs index 8c130e5..906c50d 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,6 +1,7 @@ use std::{collections::BTreeMap, time::Duration}; use anyhow::{Context, Result}; +use indicatif::ProgressBar; use nuts_rs::{ChainProgress, ProgressCallback}; use pyo3::{Py, PyAny, Python}; use time_humanize::{Accuracy, Tense}; @@ -193,14 +194,16 @@ fn estimate_remaining_time( time_sampling: Duration, progress: &[ChainProgress], ) -> Option { - let finished_draws: f64 = progress + let finished_draws: u64 = progress .iter() - .map(|chain| chain.finished_draws as f64) + .map(|chain| chain.finished_draws as u64) .sum(); - if !(finished_draws > 0.) { + if finished_draws == 0 { return None; } + let finished_draws = finished_draws as f64; + // TODO this assumes that so far all cores were used all the time let time_per_draw = time_sampling.mul_f64((n_cores as f64) / finished_draws); @@ -221,3 +224,57 @@ fn estimate_remaining_time( Some(core_times.into_iter().max().unwrap_or(Duration::ZERO)) } + +pub struct IndicatifHandler { + rate: Duration, +} + +impl IndicatifHandler { + pub fn new(rate: Duration) -> Self { + Self { rate } + } + + pub fn into_callback(self) -> Result { + let mut finished = false; + let mut last_draws = 0; + let mut bar = None; + + let callback = move |_time_sampling, progress: Box<[ChainProgress]>| { + let total: u64 = progress.iter().map(|chain| chain.total_draws as u64).sum(); + + if bar.is_none() { + bar = Some(ProgressBar::new(total)); + } + + let Some(ref bar) = bar else { unreachable!() }; + + if finished { + return; + } + if progress + .iter() + .all(|chain| chain.finished_draws == chain.total_draws) + { + finished = true; + bar.set_position(total); + bar.finish(); + } + + let finished_draws: u64 = progress + .iter() + .map(|chain| chain.finished_draws as u64) + .sum(); + + let delta = finished_draws.saturating_sub(last_draws); + if delta > 0 { + bar.set_position(finished_draws); + last_draws = finished_draws; + } + }; + + Ok(ProgressCallback { + callback: Box::new(callback), + rate: self.rate, + }) + } +} diff --git a/src/wrapper.rs b/src/wrapper.rs index f6d6faa..ee1c908 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - progress::ProgressHandler, + progress::{IndicatifHandler, ProgressHandler}, pymc::{ExpandFunc, LogpFunc, PyMcModel}, stan::{StanLibrary, StanModel}, }; @@ -240,26 +240,72 @@ pub(crate) enum SamplerState { Empty, } +#[derive(Clone)] #[pyclass] -struct PySampler(SamplerState); +pub enum ProgressType { + Callback { + rate: Duration, + n_cores: usize, + template: String, + callback: Py, + }, + Indicatif { + rate: Duration, + }, + None {}, +} + +impl ProgressType { + fn into_callback(self) -> Result> { + match self { + ProgressType::Callback { + callback, + rate, + n_cores, + template, + } => { + let handler = ProgressHandler::new(callback, rate, template, n_cores); + let callback = handler.into_callback()?; + + Ok(Some(callback)) + } + ProgressType::Indicatif { rate } => { + let handler = IndicatifHandler::new(rate); + Ok(Some(handler.into_callback()?)) + } + ProgressType::None {} => Ok(None), + } + } +} + +#[pymethods] +impl ProgressType { + #[staticmethod] + fn indicatif(rate: u64) -> Self { + let rate = Duration::from_millis(rate); + ProgressType::Indicatif { rate } + } + + #[staticmethod] + fn none() -> Self { + ProgressType::None {} + } -fn make_callback( - template: String, - n_cores: usize, - rate: Duration, - callback: Option>, -) -> Result> { - match callback { - Some(callback) => { - let handler = ProgressHandler::new(callback, rate, template, n_cores); - let callback = handler.into_callback()?; - - Ok(Some(callback)) + #[staticmethod] + fn template_callback(rate: u64, template: String, n_cores: usize, callback: Py) -> Self { + let rate = Duration::from_millis(rate); + ProgressType::Callback { + callback, + template, + n_cores, + rate, } - None => Ok(None), } } +#[pyclass] +struct PySampler(SamplerState); + #[pymethods] impl PySampler { #[staticmethod] @@ -267,12 +313,9 @@ impl PySampler { settings: PyDiagGradNutsSettings, cores: usize, model: PyMcModel, - template: String, - rate: u64, - callback: Option>, + progress_type: ProgressType, ) -> PyResult { - let rate = Duration::from_millis(rate); - let callback = make_callback(template, cores, rate, callback)?; + let callback = progress_type.into_callback()?; let sampler = Sampler::new(model, settings.0, cores, callback)?; Ok(PySampler(SamplerState::Running(sampler))) } @@ -282,12 +325,9 @@ impl PySampler { settings: PyDiagGradNutsSettings, cores: usize, model: StanModel, - template: String, - rate: u64, - callback: Option>, + progress_type: ProgressType, ) -> PyResult { - let rate = Duration::from_millis(rate); - let callback = make_callback(template, cores, rate, callback)?; + let callback = progress_type.into_callback()?; let sampler = Sampler::new(model, settings.0, cores, callback)?; Ok(PySampler(SamplerState::Running(sampler))) } @@ -514,6 +554,7 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) } diff --git a/tests/test_pymc.py b/tests/test_pymc.py index d39b06e..ef249cf 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -57,6 +57,27 @@ def test_pymc_model_with_coordinate(): trace.posterior.a # noqa: B018 +def test_pymc_model_store_extra(): + with pm.Model() as model: + model.add_coord("foo", length=5) + pm.Normal("a", dims="foo") + + compiled = nutpie.compile_pymc_model(model) + trace = nutpie.sample( + compiled, + chains=1, + store_mass_matrix=True, + store_divergences=True, + store_unconstrained=True, + store_gradient=True, + ) + trace.posterior.a # noqa: B018 + _ = trace.sample_stats.unconstrained_draw + _ = trace.sample_stats.gradient + _ = trace.sample_stats.divergence_start + _ = trace.sample_stats.mass_matrix_inv + + def test_trafo(): with pm.Model() as model: pm.Uniform("a")