Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch_size, rank_zero_only arguments for log_dict to match log #8628

Merged
merged 5 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

-

Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def log_dict(
sync_dist_op: Optional[Any] = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: Optional[bool] = None,
eladsegal marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Log a dictionary of values at once.
Expand All @@ -502,6 +504,10 @@ def log_dict(
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
"""
for k, v in dictionary.items():
self.log(
Expand All @@ -519,6 +525,8 @@ def log_dict(
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
)

@staticmethod
Expand Down