From c52d1e2baf386d5dc6566b820583a3d71e6827d6 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 19 May 2023 23:36:11 +0800 Subject: [PATCH 1/5] add show_straggler argument to log_summary() --- deepspeed/comm/comm.py | 6 ++++-- deepspeed/utils/comms_logging.py | 31 +++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 4a9b51b9e601..d13d2e664eee 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -405,11 +405,13 @@ def monitored_barrier(group=None, return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks) -def log_summary(): +def log_summary(show_straggler=False): global cdb barrier(log_name='log_summary_barrier') if cdb.get_rank() == 0: - comms_logger.log_all() + comms_logger.log_all(print_log=True, show_straggler=show_straggler) + else: + comms_logger.log_all(print_log=False, show_straggler=show_straggler) barrier(log_name='log_summary_barrier') diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py index 2400fa55b20e..7f3423b06363 100644 --- a/deepspeed/utils/comms_logging.py +++ b/deepspeed/utils/comms_logging.py @@ -122,23 +122,38 @@ def append(self, raw_name, record_name, latency, msg_size): log_dist(log_str, [0]) # Print summary at end of iteration, epoch, or training - def log_all(self): + def log_all(self, print_log=True, show_straggler=False): + import torch from deepspeed.utils.timer import trim_mean - print( - f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" - ) + import deepspeed.comm as dist + from deepspeed.comm.reduce_op import ReduceOp + if print_log: + msg = f"{'Comm. 
Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}" + if show_straggler: + msg += f"{'Total straggler(ms)': <20}" + msg += f"{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" + print(msg) for record_name in self.comms_dict.keys(): - print(record_name) + if print_log: + print(record_name) for msg_size, vals in sorted(self.comms_dict[record_name].items()): # vals[0] is the count for each msg size count = vals[0] # vals[1] is a list of latency records for each msg size total_lat = sum(vals[1]) + lats = torch.tensor(vals[1]) + min_lats = torch.tensor(vals[1]) + dist.all_reduce(min_lats, op=ReduceOp.MIN) + delta_lats = lats - min_lats + total_straggler = (lats - min_lats).sum().item() # vals[2] and vals[3] are the lists of algbw and busbw, respectively # Get rid of outliers when we print avg_lat = trim_mean(vals[1], 0.1) avg_algbw = trim_mean(vals[2], 0.1) avg_busbw = trim_mean(vals[3], 0.1) - print( - f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" - ) + if print_log: + msg = f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}" + if show_straggler: + msg += f"{total_straggler: <20.2f}" + msg += f"{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" + print(msg) From de368dbfe79fc83b396f9039cebfcac8ea751ab0 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Sat, 20 May 2023 16:50:03 +0800 Subject: [PATCH 2/5] Show straggler effect logging in seperate table --- deepspeed/utils/comms_logging.py | 40 ++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py index 7f3423b06363..8d9defebe0fe 100644 --- a/deepspeed/utils/comms_logging.py +++ b/deepspeed/utils/comms_logging.py @@ -128,11 +128,7 @@ def log_all(self, print_log=True, show_straggler=False): import deepspeed.comm as dist from deepspeed.comm.reduce_op import 
ReduceOp if print_log: - msg = f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}" - if show_straggler: - msg += f"{'Total straggler(ms)': <20}" - msg += f"{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" - print(msg) + print(f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}") for record_name in self.comms_dict.keys(): if print_log: print(record_name) @@ -141,19 +137,33 @@ def log_all(self, print_log=True, show_straggler=False): count = vals[0] # vals[1] is a list of latency records for each msg size total_lat = sum(vals[1]) - lats = torch.tensor(vals[1]) - min_lats = torch.tensor(vals[1]) - dist.all_reduce(min_lats, op=ReduceOp.MIN) - delta_lats = lats - min_lats - total_straggler = (lats - min_lats).sum().item() # vals[2] and vals[3] are the lists of algbw and busbw, respectively # Get rid of outliers when we print avg_lat = trim_mean(vals[1], 0.1) avg_algbw = trim_mean(vals[2], 0.1) avg_busbw = trim_mean(vals[3], 0.1) if print_log: - msg = f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}" - if show_straggler: - msg += f"{total_straggler: <20.2f}" - msg += f"{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" - print(msg) + print(f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}") + + if show_straggler: + if print_log: + print("_______________________________") + print("Breakdown with straggler effect") + print("-------------------------------") + print(f"{'Comm. 
Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}") + for record_name in self.comms_dict.keys(): + if print_log: + print(record_name) + for msg_size, vals in sorted(self.comms_dict[record_name].items()): + # vals[0] is the count for each msg size + count = vals[0] + # vals[1] is a list of latency records for each msg size + lats = torch.tensor(vals[1]) + min_lats = torch.tensor(vals[1]) + dist.all_reduce(min_lats, op=ReduceOp.MIN) + total_lat = min_lats.sum().item() + total_straggler = (lats - min_lats).sum().item() + avg_lat = trim_mean(min_lats.tolist(), 0.1) + avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1) + if print_log: + print(f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}") From 6884e33a91156c34a2bb3e0795efc3d7697b4c23 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Sat, 20 May 2023 17:01:33 +0800 Subject: [PATCH 3/5] fix formatting --- deepspeed/utils/comms_logging.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py index 8d9defebe0fe..8e6558cfb9dd 100644 --- a/deepspeed/utils/comms_logging.py +++ b/deepspeed/utils/comms_logging.py @@ -128,7 +128,9 @@ def log_all(self, print_log=True, show_straggler=False): import deepspeed.comm as dist from deepspeed.comm.reduce_op import ReduceOp if print_log: - print(f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}") + print( + f"{'Comm. 
Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" + ) for record_name in self.comms_dict.keys(): if print_log: print(record_name) @@ -143,14 +145,18 @@ def log_all(self, print_log=True, show_straggler=False): avg_algbw = trim_mean(vals[2], 0.1) avg_busbw = trim_mean(vals[3], 0.1) if print_log: - print(f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}") + print( + f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" + ) if show_straggler: if print_log: print("_______________________________") print("Breakdown with straggler effect") print("-------------------------------") - print(f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}") + print( + f"{'Comm. 
Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}" + ) for record_name in self.comms_dict.keys(): if print_log: print(record_name) @@ -166,4 +172,6 @@ def log_all(self, print_log=True, show_straggler=False): avg_lat = trim_mean(min_lats.tolist(), 0.1) avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1) if print_log: - print(f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}") + print( + f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}" + ) From 206d4550f433468d1daf6a5675de113ec8830cd1 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Sun, 21 May 2023 17:01:51 +0800 Subject: [PATCH 4/5] add docs for log_summary with straggler effect --- docs/_tutorials/comms-logging.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/_tutorials/comms-logging.md b/docs/_tutorials/comms-logging.md index b6a352b60f68..7313688a03e6 100644 --- a/docs/_tutorials/comms-logging.md +++ b/docs/_tutorials/comms-logging.md @@ -13,7 +13,7 @@ In this tutorial, we introduce DeepSpeed communication logging and provide examp NOTE: All logging communication calls are synchronized in order to provide accurate timing information. This may hamper performance if your model heavily uses asynchronous communication operations. -Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. 
Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` in the client code at the completion of training, an epoch, after N training iterations, etc. +Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` or `deepspeed.comm.log_summary(show_straggler=True)` in the client code at the completion of training, an epoch, after N training iterations, etc. ## Usage @@ -114,3 +114,14 @@ broadcast | [Caller Func: _broadcast_model] reduce_scatter_tensor | [Caller Func: reduce_scatter_fn] 678.86 MB 80 1527.17 13.94 1211.75 1136.01 ``` + +Straggler effect can be shown by supply optional argument `show_straggler=True` to `deepspeed.comm.log_summary()` call. Straggler effect is defined as the time a rank waits for the slowest rank to start communication.
For each collective, `log_summary` would get the minimum collective time among all ranks, compute straggler effect as follows: + +``` +straggler = sum(t_collectives - allreduce(t_collectives, MIN)) +``` + +Print straggler effect with the following `log_summary` call in the example above: +``` + dist.log_summary(show_straggler=True) +``` From b4c4472dffb54b54b9279e38185a7c840773faba Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 30 Jun 2023 09:48:22 +0800 Subject: [PATCH 5/5] fix typo --- docs/_tutorials/comms-logging.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_tutorials/comms-logging.md b/docs/_tutorials/comms-logging.md index 7313688a03e6..2719f08ad200 100644 --- a/docs/_tutorials/comms-logging.md +++ b/docs/_tutorials/comms-logging.md @@ -115,7 +115,7 @@ reduce_scatter_tensor | [Caller Func: reduce_scatter_fn] 678.86 MB 80 1527.17 13.94 1211.75 1136.01 ``` -Straggler effect can be shown by supply optional argument `show_straggler=True` to `deepspeed.comm.log_summary()` call. Straggler effect is defined as the time a rank waits for the slowest rank to start communication. For each collective, `log_summary` would get the minimum collective time among all ranks, compute straggler effect as follows: +Straggler effect can be shown by supplying the optional argument `show_straggler=True` to the `deepspeed.comm.log_summary()` call. Straggler effect is defined as the time a rank waits for the slowest rank to start communication. For each collective, `log_summary` gets the minimum collective time among all ranks and computes the straggler effect as follows: ``` straggler = sum(t_collectives - allreduce(t_collectives, MIN))