Skip to content

Commit

Permalink
Hparams: Generate metric values for data provider-based session group…
Browse files Browse the repository at this point in the history
…s. (#6543)

Generate metric values for hparams plugin `/session_groups` requests
when the session groups are generated from
DataProvider.read_hyperparameters().

We need to reuse the logic introduced in
#6539 to generate
metric_infos for each session group and also query for scalar values. We
reuse the existing logic to join the two collections of data into
metric_values for the `/session_groups` request.

We also continue the work begun in
#6541 to improve how we
generate sessions - in this case also handling cases where experiment_id
is not specified for the session. This became urgently necessary to
address in order to get new tests in list_session_groups_test.py to work
with existing test data.
  • Loading branch information
bmd3k authored Aug 14, 2023
1 parent d4df603 commit 11e188a
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 18 deletions.
18 changes: 17 additions & 1 deletion tensorboard/plugins/hparams/backend_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def compute_metric_infos_from_data_provider_session_groups(
self, ctx, experiment_id, session_groups
):
session_runs = set(
f"{s.experiment_id}/{s.run}" if s.run else s.experiment_id
generate_data_provider_session_name(experiment_id, s)
for sg in session_groups
for s in sg.sessions
)
Expand Down Expand Up @@ -460,6 +460,22 @@ def _compute_metric_names(self, ctx, experiment_id, session_runs):
return metric_names_list


def generate_data_provider_session_name(experiment_id, session):
"""Generates a name from a HyperparameterSesssionRun.
If the HyperparameterSessionRun contains no experiment or run information
then the name is set to the original experiment_id.
"""
if not session.experiment_id and not session.run:
return experiment_id
elif not session.experiment_id:
return session.run
elif not session.run:
return session.experiment_id
else:
return f"{session.experiment_id}/{session.run}"


def _find_longest_parent_path(path_set, path):
"""Finds the longest "parent-path" of 'path' in 'path_set'.
Expand Down
37 changes: 37 additions & 0 deletions tensorboard/plugins/hparams/backend_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,43 @@ def test_experiment_from_data_provider_session_group_without_run_name(self):
"""
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_from_data_provider_session_group_without_experiment_name(
self,
):
self._mock_tb_context.data_provider.list_tensors.side_effect = None
self._hyperparameters = provider.ListHyperparametersResult(
hyperparameters=[],
session_groups=[
provider.HyperparameterSessionGroup(
root=provider.HyperparameterSessionRun(
experiment_id="", run="exp/session_1"
),
sessions=[
provider.HyperparameterSessionRun(
experiment_id="", run="exp/session_1"
),
],
hyperparameter_values=[],
),
],
)
actual_exp = self._experiment_from_metadata()
expected_exp = """
metric_infos: {
name: {group: '', tag: 'accuracy'}
}
metric_infos: {
name: {group: '', tag: 'loss'}
}
metric_infos: {
name: {group: 'eval', tag: 'loss'}
}
metric_infos: {
name: {group: 'train', tag: 'loss'}
}
"""
self.assertProtoEquals(expected_exp, actual_exp)

def test_experiment_from_data_provider_old_response_type(self):
self._hyperparameters = [
provider.Hyperparameter(
Expand Down
70 changes: 53 additions & 17 deletions tensorboard/plugins/hparams/list_session_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@

from tensorboard.data import provider
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context as backend_context_lib
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import json_format_compat
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import metrics
from tensorboard.plugins.hparams import plugin_data_pb2


class Handler:
Expand Down Expand Up @@ -93,13 +95,15 @@ def _session_groups_from_tags(self):
hparams_run_to_tag_to_content,
# Don't pass any information from the DataProvider since we are only
# examining session groups based on tag metadata
[],
provider.ListHyperparametersResult(
hyperparameters=[], session_groups=[]
),
)
extractors = _create_extractors(self._request.col_params)
filters = _create_filters(self._request.col_params, extractors)

session_groups = self._build_session_groups(
hparams_run_to_tag_to_content, experiment
hparams_run_to_tag_to_content, experiment.metric_infos
)
session_groups = self._filter(session_groups, filters)
self._sort(session_groups, extractors)
Expand All @@ -116,16 +120,37 @@ def _session_groups_from_data_provider(self):
sort,
)

metric_infos = self._backend_context.compute_metric_infos_from_data_provider_session_groups(
self._request_context, self._experiment_id, response
)

all_metric_evals = self._backend_context.read_last_scalars(
self._request_context,
self._experiment_id,
run_tag_filter=None,
)

session_groups = []
for provider_group in response:
sessions = [
api_pb2.Session(name=f"{s.experiment_id}/{s.run}")
for s in provider_group.sessions
]
name = (
f"{provider_group.root.experiment_id}/{provider_group.root.run}"
if provider_group.root.run
else provider_group.root.experiment_id
sessions = []
for session in provider_group.sessions:
session_name = (
backend_context_lib.generate_data_provider_session_name(
self._experiment_id, session
)
)
sessions.append(
self._build_session(
metric_infos,
session_name,
plugin_data_pb2.SessionStartInfo(),
plugin_data_pb2.SessionEndInfo(),
all_metric_evals,
)
)

name = backend_context_lib.generate_data_provider_session_name(
self._experiment_id, provider_group.root
)
session_group = api_pb2.SessionGroup(
name=name,
Expand Down Expand Up @@ -154,9 +179,16 @@ def _session_groups_from_data_provider(self):

session_groups.append(session_group)

# Compute the session group's aggregated metrics for each group.
for group in session_groups:
if group.sessions:
self._aggregate_metrics(group)

return session_groups

def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
def _build_session_groups(
self, hparams_run_to_tag_to_content, metric_infos
):
"""Returns a list of SessionGroups protobuffers from the summary
data."""

Expand All @@ -178,7 +210,7 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
metric_runs = set()
metric_tags = set()
for session_name in session_names:
for metric in experiment.metric_infos:
for metric in metric_infos:
metric_name = metric.name
(run, tag) = metrics.run_tag_from_session_and_metric(
session_name, metric_name
Expand Down Expand Up @@ -207,7 +239,11 @@ def _build_session_groups(self, hparams_run_to_tag_to_content, experiment):
tag_to_content[metadata.SESSION_END_INFO_TAG]
)
session = self._build_session(
experiment, session_name, start_info, end_info, all_metric_evals
metric_infos,
session_name,
start_info,
end_info,
all_metric_evals,
)
if session.status in self._request.allowed_statuses:
self._add_session(session, start_info, groups_by_name)
Expand Down Expand Up @@ -263,7 +299,7 @@ def _add_session(self, session, start_info, groups_by_name):
groups_by_name[group_name] = group

def _build_session(
self, experiment, name, start_info, end_info, all_metric_evals
self, metric_infos, name, start_info, end_info, all_metric_evals
):
"""Builds a session object."""

Expand All @@ -273,7 +309,7 @@ def _build_session(
start_time_secs=start_info.start_time_secs,
model_uri=start_info.model_uri,
metric_values=self._build_session_metric_values(
experiment, name, all_metric_evals
metric_infos, name, all_metric_evals
),
monitor_url=start_info.monitor_url,
)
Expand All @@ -283,13 +319,13 @@ def _build_session(
return result

def _build_session_metric_values(
self, experiment, session_name, all_metric_evals
self, metric_infos, session_name, all_metric_evals
):
"""Builds the session metric values."""

# result is a list of api_pb2.MetricValue instances.
result = []
for metric_info in experiment.metric_infos:
for metric_info in metric_infos:
metric_name = metric_info.name
(run, tag) = metrics.run_tag_from_session_and_metric(
session_name, metric_name
Expand Down
Loading

0 comments on commit 11e188a

Please sign in to comment.