Skip to content

Commit

Permalink
[tune] Fix Analysis.dataframe() documentation and enable passing of…
Browse files Browse the repository at this point in the history
… `mode=None` (#18850)
  • Loading branch information
krfricke authored Sep 23, 2021
1 parent cc84f18 commit 2d46e0e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
25 changes: 19 additions & 6 deletions python/ray/tune/analysis/experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,31 @@ def dataframe(self,
mode: Optional[str] = None) -> DataFrame:
"""Returns a pandas.DataFrame object constructed from the trials.
This function will look through all observed results of each trial
and return the one corresponding to the passed ``metric`` and
``mode``: If ``mode=min``, it returns the result with the lowest
*ever* observed ``metric`` for this trial (this is not necessarily
the last)! For ``mode=max``, it's the highest, respectively. If
``metric=None`` or ``mode=None``, the last result will be returned.
Args:
metric (str): Key for trial info to order on.
If None, uses last result.
mode (str): One of [min, max].
mode (None|str): One of [None, "min", "max"].
Returns:
pd.DataFrame: Constructed from a result dict of each trial.
"""
# Allow None values here.
if metric or self.default_metric:
metric = self._validate_metric(metric)
if mode or self.default_mode:
mode = self._validate_mode(mode)
# Do not validate metric/mode here or set from default metric/mode!
# Otherwise we will get confusing results as the lowest ever observed
# result may not be the last result.
if mode and mode not in ["min", "max"]:
raise ValueError("If set, `mode` has to be one of [min, max]")

if mode and not metric:
raise ValueError(
"If a `mode` is passed to `Analysis.dataframe(), you'll "
"also have to pass a `metric`!")

rows = self._retrieve_rows(metric=metric, mode=mode)
all_configs = self.get_all_configs(prefix=True)
Expand Down Expand Up @@ -343,6 +355,7 @@ def _retrieve_rows(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Dict[str, Any]:
assert mode is None or mode in ["max", "min"]
assert not mode or metric
rows = {}
for path, df in self.trial_dataframes.items():
if mode == "max":
Expand Down
37 changes: 37 additions & 0 deletions python/ray/tune/tests/test_experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,43 @@ def train(config):
self.assertEquals(ea.best_result_df.loc[trials[2].trial_id, "res"],
309)

def testDataframeBestResult(self):
def train(config):
if config["var"] == 1:
tune.report(loss=9)
tune.report(loss=7)
tune.report(loss=5)
else:
tune.report(loss=10)
tune.report(loss=4)
tune.report(loss=10)

analysis = tune.run(
train,
config={"var": tune.grid_search([1, 2])},
metric="loss",
mode="min")

self.assertEqual(analysis.best_config["var"], 1)

with self.assertRaises(ValueError):
# Should raise because we didn't pass a metric
df = analysis.dataframe(mode="max")

# If we specify `min`, we expect the lowest ever observed result
df = analysis.dataframe(metric="loss", mode="min")
var = df[df.loss == df.loss.min()]["config/var"].values[0]
self.assertEqual(var, 2)

# If we don't pass a mode, we just fetch the last result
df = analysis.dataframe(metric="loss")
var = df[df.loss == df.loss.min()]["config/var"].values[0]
self.assertEqual(var, 1)

df = analysis.dataframe()
var = df[df.loss == df.loss.min()]["config/var"].values[0]
self.assertEqual(var, 1)


if __name__ == "__main__":
import pytest
Expand Down

0 comments on commit 2d46e0e

Please sign in to comment.