Skip to content

Commit

Permalink
add a progress_bar for loading a BalancingLearner
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed May 27, 2019
1 parent d79512f commit 671d831
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
21 changes: 18 additions & 3 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def save(self, fname, compress=True):
for l in self.learners:
l.save(fname(l), compress=compress)

def load(self, fname, compress=True):
def load(self, fname, compress=True, with_progress_bar=False):
"""Load the data of the child learners from pickle files
in a directory.
Expand All @@ -389,16 +389,31 @@ def load(self, fname, compress=True):
compress : bool, default True
If the data is compressed when saved, one must load it
with compression too.
with_progress_bar : bool, default False
Display a progress bar using `tqdm`.
Example
-------
See the example in the `BalancingLearner.save` doc-string.
"""
def progress(seq):
if not with_progress_bar:
return seq
else:
from adaptive.notebook_integration import in_ipynb
desc = "Loading learners."
if in_ipynb():
from tqdm import tqdm_notebook
return tqdm_notebook(list(seq), desc=desc)
else:
from tqdm import tqdm
return tqdm(list(seq), desc=desc)

if isinstance(fname, Iterable):
for l, _fname in zip(self.learners, fname):
for l, _fname in progress(zip(self.learners, fname)):
l.load(_fname, compress=compress)
else:
for l in self.learners:
for l in progress(self.learners):
l.load(fname(l), compress=compress)

def _get_data(self):
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ dependencies:
- ipywidgets
- scikit-optimize
- plotly
- tqdm
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_version_and_cmdclass(package_name):
"bokeh",
"matplotlib",
"plotly",
"tqdm",
]
}

Expand Down

0 comments on commit 671d831

Please sign in to comment.