Replace approximate formula with exact one for throughput (#251)
deepakn94 authored Feb 22, 2022
1 parent 541b967 · commit ae6277f
Showing 1 changed file with 8 additions and 8 deletions.
megatron/training.py: 16 changes (8 additions & 8 deletions)
@@ -668,23 +668,23 @@ def add_to_logging(name):
  elapsed_time_per_iteration = elapsed_time / total_iterations

  seq_len = args.curriculum_seqlen if args.curriculum_learning else args.seq_length
+ hidden_size = args.hidden_size
+ num_layers = args.num_layers
+ vocab_size = args.padded_vocab_size

- # throughput
+ # Compute throughput.
  samples_per_sec = batch_size / elapsed_time_per_iteration
  samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
  tokens_per_sec = samples_per_sec * seq_len
  tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size

- # general TFLOPs formula
- # model_size_in_B * 4 * 2 * seqlen * global_batch_size / (time_in_sec_per_iteration * total_gpus * 1e3)
- #
+ # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
+ # https://arxiv.org/pdf/2104.04473.pdf).
  # The factor of 4 is when used with activation check-pointing,
  # otherwise it will be 3, but for 200B model, activation check-pointing will always be on.
- #
- # here:
- # model_size_in_B * 4 * 2 * seqlen * batch_size / (time_in_msec_per_iteration * total_gpus * 1e3)
  checkpoint_activations_factor = 4 if args.checkpoint_activations else 3
- tflops = args.parameters_in_billions_no_embedding * checkpoint_activations_factor * 2 * seq_len * batch_size / (elapsed_time_per_iteration * args.world_size * 1e3)
+ flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))
+ tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12))

  # only the last rank process has a non-None _GLOBAL_TENSORBOARD_WRITER
  if writer and is_last_rank():
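For context, the old code estimated FLOPs from the parameter count alone, while the new code counts them exactly per Equation 3 of Narayanan et al. (https://arxiv.org/pdf/2104.04473.pdf): FLOPs per iteration = 24 · C · B · s · l · h^2 · (1 + s/(6h) + V/(16·l·h)), where C is the checkpoint-activations factor (4 with recomputation, 3 without), B the global batch size, s the sequence length, l the number of layers, h the hidden size, and V the padded vocabulary size. The sketch below is not part of the commit; all config values are hypothetical, chosen to resemble a GPT-3-scale run, and it compares the two estimates:

# Standalone sketch, not part of the commit: compares the old approximate
# per-GPU TFLOPs estimate with the exact one introduced here. All numbers
# below are hypothetical (roughly a GPT-3-scale model on 384 GPUs).

def tflops_exact(batch_size, seq_len, num_layers, hidden_size, vocab_size,
                 iter_time_s, num_gpus, checkpoint_activations=True):
    # Equation 3, Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf
    factor = 4 if checkpoint_activations else 3  # 3 = fwd + bwd; +1 for recompute
    flops = (24 * factor * batch_size * seq_len * num_layers * hidden_size**2) * (
        1. + seq_len / (6. * hidden_size)
        + vocab_size / (16. * num_layers * hidden_size))
    return flops / (iter_time_s * num_gpus * 1e12)

def tflops_approx(params_in_billions, batch_size, seq_len,
                  iter_time_s, num_gpus, checkpoint_activations=True):
    # Old formula: model_size_in_B * factor * 2 * seqlen * batch_size
    #              / (time_in_sec_per_iteration * total_gpus * 1e3)
    factor = 4 if checkpoint_activations else 3
    return (params_in_billions * factor * 2 * seq_len * batch_size
            / (iter_time_s * num_gpus * 1e3))

cfg = dict(batch_size=1536, seq_len=2048, iter_time_s=60.0, num_gpus=384)
print(tflops_exact(num_layers=96, hidden_size=12288, vocab_size=51200, **cfg))
# -> ~196 TFLOPs/GPU
print(tflops_approx(params_in_billions=174.6, **cfg))  # -> ~191 TFLOPs/GPU

With these assumed numbers the two estimates land within a few percent of each other; the exact formula matters more when seq_len is large relative to hidden_size or when the vocabulary term is significant.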