[FEAT] Refactor logging 3/3 [v1] (#4552)
* wip

* wip check how many tests break

* wip

* resolve some bugs

* resolve more bugs

* resolve 2 bugs

* resolve

* temp fix

* update

* remove useless code

* remove result

* try to resolve bug

* update changelog

* formatting

* remove pl

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
3 people authored Nov 11, 2020
1 parent 514cb22 commit 3d202f9
Showing 9 changed files with 501 additions and 214 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -33,8 +33,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


- Added `manual_optimizer_step` which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
- Added logging using `self.log` in training and evaluation for most callbacks and model hooks (
[#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552),
[#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495),
[#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)
)

- Added `manual_optimizer_step` which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))

- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))

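For context on the entry above: `self.log` is the user-facing API that these PRs extend to callbacks and model hooks. A minimal sketch of its use inside a `LightningModule` (the model, dimensions, and metric name are illustrative, not part of this diff):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # each call is routed through the logger connector this PR refactors
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```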
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from collections import defaultdict, ChainMap
from enum import Enum
from typing import Union, Tuple, Any, Dict, Optional, List
@@ -419,13 +420,14 @@ def update_logger_connector(self, fx_name: str = None) -> None:
logger_connector = self.trainer.logger_connector

callback_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value

if not self._has_batch_loop_finished:
# get pbar
batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
logger_connector.add_progress_bar_metrics(batch_pbar_metrics)

if self._stage in LoggerStages.TRAIN.value:
if is_train:
# batch log metrics are tracked at each step only during training;
batch_log_metrics = self.get_latest_batch_log_metrics()
logger_connector.logged_metrics.update(batch_log_metrics)
@@ -443,6 +445,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
epoch_log_metrics = self.get_epoch_log_metrics()
logger_connector.logged_metrics.update(epoch_log_metrics)
logger_connector.logged_metrics.update(epoch_dict)
if not self.trainer.running_sanity_check and not is_train:
if len(epoch_log_metrics) > 0:
self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics))

# get forked_metrics
forked_metrics = self.get_forked_metrics()
@@ -451,6 +456,9 @@ def update_logger_connector(self, fx_name: str = None) -> None:
callback_metrics.update(epoch_log_metrics)
callback_metrics.update(forked_metrics)

if not is_train:
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
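A standalone sketch of the gating added above, with a plain list standing in for `trainer.dev_debugger` (the function name is illustrative):

```python
from copy import deepcopy


def maybe_track_epoch_metrics(epoch_log_metrics, running_sanity_check, is_train, history):
    # mirrors the added condition: record evaluation epoch metrics,
    # but never during the sanity check or the training stage
    if not running_sanity_check and not is_train and len(epoch_log_metrics) > 0:
        history.append(deepcopy(epoch_log_metrics))


history = []
maybe_track_epoch_metrics({"val_acc": 0.91}, running_sanity_check=False, is_train=False, history=history)
assert history == [{"val_acc": 0.91}]
```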
@@ -36,6 +36,7 @@ class LoggerConnector:
def __init__(self, trainer):
self.trainer = trainer
self.callback_metrics = {}
self.evaluation_callback_metrics = {}
self.logged_metrics = {}
self.progress_bar_metrics = {}
self.eval_loop_results = []
@@ -59,10 +60,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
on_epoch=on_epoch)

def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
# reset the result of the PL module
model = self.trainer.get_model()
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
self.cached_results._batch_size = Result.extract_batch_size(batch)
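The `_current_dataloader_idx` rule above, isolated as a hypothetical helper to make the single-dataloader special case explicit:

```python
def current_dataloader_idx(dataloader_idx, num_dataloaders):
    # metric keys stay unsuffixed in the common single-dataloader case
    return dataloader_idx if num_dataloaders > 1 else None


assert current_dataloader_idx(0, num_dataloaders=1) is None
assert current_dataloader_idx(2, num_dataloaders=3) == 2
```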

@@ -226,19 +226,41 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode):
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)

# TODO: deprecate parts of this for 1.0 (when removing results)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)

self._log_on_evaluation_epoch_end_metrics(epoch_logs)
def evaluation_epoch_end(self, testing):
# reset dataloader idx
model_ref = self.trainer.get_model()
model_ref._current_dataloader_idx = None

# setting `has_batch_loop_finished` to True
# will perform Results reduction across the entire epoch.
self.cached_results.has_batch_loop_finished = True

def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
callback_metrics = deepcopy(self.evaluation_callback_metrics)
for key in list(callback_metrics.keys()):
if "dataloader_idx" in key:
if f"dataloader_idx_{dl_idx}" not in key:
# remove keys tagged for a different dataloader from the callback metrics.
del callback_metrics[key]
if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)

# get the final loop results
eval_loop_results = self._get_evaluate_epoch_results(test_mode)
return eval_loop_results
def prepare_eval_loop_results(self):
num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):

self.prepare_eval_loop_results()

def _get_evaluate_epoch_results(self, test_mode):
# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
print('-' * 80)
@@ -253,106 +275,6 @@ def _get_evaluate_epoch_results(self, test_mode):
self.eval_loop_results = []
return results
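The new `prepare_eval_loop_results` path relies on per-dataloader key filtering; a self-contained sketch of that filtering (the helper name is illustrative):

```python
from copy import deepcopy


def filter_metrics_for_dataloader(callback_metrics, dl_idx):
    metrics = deepcopy(callback_metrics)
    for key in list(metrics):
        # drop keys tagged for a different dataloader; untagged keys are shared
        if "dataloader_idx" in key and f"dataloader_idx_{dl_idx}" not in key:
            del metrics[key]
    return metrics


metrics = {"val_loss/dataloader_idx_0": 0.3, "val_loss/dataloader_idx_1": 0.5, "epoch": 4}
assert filter_metrics_for_dataloader(metrics, 0) == {"val_loss/dataloader_idx_0": 0.3, "epoch": 4}
```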

def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
step_metrics = self.trainer.evaluation_loop.step_metrics

num_loaders = len(step_metrics)

# clear mem
self.trainer.evaluation_loop.step_metrics = []

if self.trainer.running_sanity_check:
return

# track all metrics we want to log
metrics_to_log = []

# ---------------------------
# UPDATE EPOCH LOGGED METRICS
# ---------------------------
# (ie: in methods at the val_epoch_end level)
# union the epoch logs with whatever was returned from loaders and reduced
epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()

self.logged_metrics.update(epoch_logger_metrics)
self.add_progress_bar_metrics(epoch_pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(epoch_logger_metrics)
self.callback_metrics.update(epoch_pbar_metrics)

if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)

# --------------------------------
# UPDATE METRICS PER DATALOADER
# --------------------------------
# each dataloader aggregated metrics
# now we log all of them
for dl_idx, dl_metrics in enumerate(step_metrics):
if len(dl_metrics) == 0:
# Ensure custom logged metrics are still included when they are absent from step metrics
if len(epoch_logger_metrics) > 0:
self.eval_loop_results.append(epoch_logger_metrics)
continue

reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
# track the metrics
logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
forked_metrics = reduced_epoch_metrics.get_forked_metrics()

# make the keys 'k/dl'
logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders)
pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders)
forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders)

self.logged_metrics.update(logger_metrics)
self.add_progress_bar_metrics(pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(logger_metrics)
self.callback_metrics.update(pbar_metrics)

# forked metrics were dropped, enable them for callbacks
self.callback_metrics.update(forked_metrics)

# track the final results for the dataloader
self.add_to_eval_loop_results(dl_idx, num_loaders)

# actually log
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
if len(metrics_to_log) > 0:
self.log_metrics(metrics_to_log, {})

def add_to_eval_loop_results(self, dl_idx, num_loaders):
callback_metrics = deepcopy(self.callback_metrics)
if num_loaders == 1:
if len(self.eval_loop_results) > 0:
self.eval_loop_results[0].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)
return

for key in list(callback_metrics.keys()):
if "dataloader_idx" in key:
if f"dataloader_idx_{dl_idx}" not in key:
# remove keys tagged for a different dataloader from the callback metrics.
del callback_metrics[key]
self.eval_loop_results.append(callback_metrics)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
return metrics

result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
return result
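For reference, the removed helper's renaming behavior, reproduced as a standalone sketch, shows where the `dataloader_idx_i` suffixes that the new filtering matches against come from:

```python
def rename_keys_by_dataloader_idx(metrics, dataloader_idx, num_loaders):
    # a single dataloader keeps its original keys
    if num_loaders == 1:
        return metrics
    return {f"{k}/dataloader_idx_{dataloader_idx}": v for k, v in metrics.items()}


assert rename_keys_by_dataloader_idx({"val_loss": 0.3}, 1, num_loaders=2) == {"val_loss/dataloader_idx_1": 0.3}
```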

def _track_callback_metrics(self, eval_results, using_eval_result):
if (
len(eval_results) > 0 and
@@ -364,8 +286,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if isinstance(eval_results, list):
for eval_result in eval_results:
self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
else:
self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
else:
flat = {}
if isinstance(eval_results, list):
@@ -381,6 +305,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
@@ -393,6 +318,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
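A reduced sketch of the dict/scalar flattening above (it assumes the legacy `checkpoint_on`/`early_stop_on` monitor-key convention; the function name is illustrative and list handling is omitted):

```python
import torch


def flatten_for_callbacks(eval_result):
    # a bare tensor return is auto-named val_loss for callbacks
    flat = {"val_loss": eval_result} if isinstance(eval_result, torch.Tensor) else dict(eval_result)
    if "val_loss" in flat:
        # legacy monitor keys consumed by checkpointing and early stopping
        flat["checkpoint_on"] = flat["val_loss"]
        flat["early_stop_on"] = flat["val_loss"]
    return flat


flat = flatten_for_callbacks(torch.tensor(0.25))
assert set(flat) == {"val_loss", "checkpoint_on", "early_stop_on"}
```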

def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
# eval loop returns all metrics
@@ -406,9 +332,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
self.trainer.logger_connector.log_metrics(log_metrics, {})

# track metrics for callbacks (all prog bar, logged and callback metrics)
callback_metrics.update(log_metrics)
callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
self.trainer.logger_connector.callback_metrics.update(log_metrics)
self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)
