[profiling]add show_straggler argument to log_summary() #3579

Merged
merged 17 commits into from
Jun 30, 2023
6 changes: 4 additions & 2 deletions deepspeed/comm/comm.py
@@ -405,11 +405,13 @@ def monitored_barrier(group=None,
     return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)


-def log_summary():
+def log_summary(show_straggler=False):
     global cdb
     barrier(log_name='log_summary_barrier')
     if cdb.get_rank() == 0:
-        comms_logger.log_all()
+        comms_logger.log_all(print_log=True, show_straggler=show_straggler)
+    else:
+        comms_logger.log_all(print_log=False, show_straggler=show_straggler)
     barrier(log_name='log_summary_barrier')
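
The change above has every rank call `log_all`, with only rank 0 printing. A likely reason is that the straggler path issues a collective `all_reduce`, which every rank in the group has to enter. The following is a minimal sketch of that pattern using threads in place of ranks. `MinReducer`, `toy_log_all`, and `simulate_log_summary` are hypothetical names for illustration and are not part of DeepSpeed.

```python
import threading


class MinReducer:
    """Toy stand-in for a MIN all-reduce across `world_size` threads (assumption)."""

    def __init__(self, world_size):
        self._values = [None] * world_size
        self._barrier = threading.Barrier(world_size)

    def all_reduce_min(self, rank, value):
        self._values[rank] = value
        self._barrier.wait()  # blocks until every rank has contributed
        result = min(self._values)
        self._barrier.wait()  # keep values stable until every rank has read them
        return result


def toy_log_all(rank, latency, reducer, output, print_log):
    # Every rank takes part in the reduction; only some ranks print.
    fastest = reducer.all_reduce_min(rank, latency)
    if print_log:
        output.append(f"min latency: {fastest:.2f} ms")
    return latency - fastest


def simulate_log_summary(world_size, latencies):
    reducer = MinReducer(world_size)
    output = []
    stragglers = [None] * world_size

    def worker(rank):
        stragglers[rank] = toy_log_all(
            rank, latencies[rank], reducer, output, print_log=(rank == 0))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return output, stragglers
```

If only rank 0 entered `all_reduce_min` in this sketch, it would block at the first barrier, which is the hang the `else` branch avoids.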


45 changes: 39 additions & 6 deletions deepspeed/utils/comms_logging.py
@@ -122,13 +122,18 @@ def append(self, raw_name, record_name, latency, msg_size):
             log_dist(log_str, [0])

     # Print summary at end of iteration, epoch, or training
-    def log_all(self):
+    def log_all(self, print_log=True, show_straggler=False):
+        import torch
         from deepspeed.utils.timer import trim_mean
-        print(
-            f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
-        )
+        import deepspeed.comm as dist
+        from deepspeed.comm.reduce_op import ReduceOp
+        if print_log:
+            print(
+                f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
+            )
         for record_name in self.comms_dict.keys():
-            print(record_name)
+            if print_log:
+                print(record_name)
             for msg_size, vals in sorted(self.comms_dict[record_name].items()):
                 # vals[0] is the count for each msg size
                 count = vals[0]
@@ -139,6 +144,34 @@ def log_all(self):
                 avg_lat = trim_mean(vals[1], 0.1)
                 avg_algbw = trim_mean(vals[2], 0.1)
                 avg_busbw = trim_mean(vals[3], 0.1)
+                if print_log:
+                    print(
+                        f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}"
+                    )
+
+        if show_straggler:
+            if print_log:
+                print("_______________________________")
+                print("Breakdown with straggler effect")
+                print("-------------------------------")
                 print(
-                    f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}"
+                    f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}"
                 )
+            for record_name in self.comms_dict.keys():
+                if print_log:
+                    print(record_name)
+                for msg_size, vals in sorted(self.comms_dict[record_name].items()):
+                    # vals[0] is the count for each msg size
+                    count = vals[0]
+                    # vals[1] is a list of latency records for each msg size
+                    lats = torch.tensor(vals[1])
+                    min_lats = torch.tensor(vals[1])
+                    dist.all_reduce(min_lats, op=ReduceOp.MIN)
+                    total_lat = min_lats.sum().item()
+                    total_straggler = (lats - min_lats).sum().item()
+                    avg_lat = trim_mean(min_lats.tolist(), 0.1)
+                    avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1)
+                    if print_log:
+                        print(
+                            f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}"
+                        )
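
The straggler breakdown computed in this hunk can be illustrated without a distributed runtime. Below is a minimal sketch, assuming hypothetical per-rank latency lists in place of `vals[1]` and an element-wise minimum over ranks in place of `dist.all_reduce(min_lats, op=ReduceOp.MIN)`. The helper name `straggler_breakdown` is illustrative and not part of DeepSpeed, and the sketch reports plain sums only, leaving out the `trim_mean` averages.

```python
from typing import Dict, List, Tuple


def straggler_breakdown(per_rank_lats: List[List[float]]) -> Dict[int, Tuple[float, float]]:
    """Return {rank: (total_comm_latency_ms, total_straggler_ms)} for one message size.

    per_rank_lats[r][i] is the latency (ms) that rank r recorded for the i-th call.
    """
    # Stand-in for all_reduce(MIN): element-wise minimum across ranks.
    min_lats = [min(call) for call in zip(*per_rank_lats)]
    total_lat = sum(min_lats)

    result = {}
    for rank, lats in enumerate(per_rank_lats):
        # Time above the per-call minimum is attributed to waiting on other ranks.
        total_straggler = sum(lat - fastest for lat, fastest in zip(lats, min_lats))
        result[rank] = (total_lat, total_straggler)
    return result


# Two ranks, three calls each (made-up numbers for illustration).
breakdown = straggler_breakdown([[10.0, 12.0, 11.0], [14.0, 12.0, 15.0]])
```

Here the shared communication latency is the sum of per-call minima, while each rank keeps its own straggler total, which is why the values printed on rank 0 describe rank 0's waiting time.
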
13 changes: 12 additions & 1 deletion docs/_tutorials/comms-logging.md
@@ -13,7 +13,7 @@ In this tutorial, we introduce DeepSpeed communication logging and provide examp

NOTE: All logging communication calls are synchronized in order to provide accurate timing information. This may hamper performance if your model heavily uses asynchronous communication operations.

-Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` in the client code at the completion of training, an epoch, after N training iterations, etc.
+Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can be printed directly to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` or `deepspeed.comm.log_summary(show_straggler=True)` in the client code at the completion of training, an epoch, after N training iterations, etc.

## Usage

@@ -114,3 +114,14 @@ broadcast | [Caller Func: _broadcast_model]
reduce_scatter_tensor | [Caller Func: reduce_scatter_fn]
678.86 MB 80 1527.17 13.94 1211.75 1136.01
```

The straggler effect can be shown by passing the optional argument `show_straggler=True` to the `deepspeed.comm.log_summary()` call. The straggler effect is defined as the time a rank waits for the slowest rank to start communication. For each collective, `log_summary` takes the minimum collective time across all ranks and computes the straggler effect as follows:

```
straggler = sum(t_collectives - allreduce(t_collectives, MIN))
```

To print the straggler effect, use the following `log_summary` call in the example above:
```
dist.log_summary(show_straggler=True)
```