
Commit

fix calcs.
awan-10 committed Aug 25, 2023
1 parent 8ed4067 commit 672a97a
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions applications/DeepSpeed-Chat/training/utils/perf.py
@@ -17,6 +17,10 @@ def print_throughput(hf_model, args, e2e_time, rank=0):
     batch_size = args.per_device_train_batch_size
     samples_per_second = batch_size / e2e_time
     checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
+    if args.lora_dim > 0:
+        k = args.lora_dim * 2 / hidden_size
+        checkpoint_activations_factor -= (1 - k)
+
     hf_model._num_params = sum([
         p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
         for p in hf_model.parameters()
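The added block treats training cost as a multiple of one forward pass, per the Megatron-LM cost model: 3x without activation checkpointing (one forward plus a backward at roughly twice the forward cost), 4x with it (the forward is recomputed during backward). When LoRA is enabled, backward work only flows through the low-rank adapters, so the commit discounts the multiplier by `(1 - k)`, where `k = 2 * lora_dim / hidden_size` approximates the adapter's cost relative to a dense `hidden_size x hidden_size` layer. A minimal sketch of the resulting factor; `effective_factor` is an illustrative name, not a function in perf.py:

```python
# Minimal sketch of the adjustment above; `effective_factor` is an
# illustrative name, not part of perf.py.
def effective_factor(gradient_checkpointing: bool, lora_dim: int,
                     hidden_size: int) -> float:
    # Megatron-LM cost model: 1x forward + 2x backward = 3x, plus one
    # recomputed forward (4x total) when activation checkpointing is on.
    factor = 4.0 if gradient_checkpointing else 3.0
    if lora_dim > 0:
        # Relative cost of a rank-r adapter vs. a dense h x h matmul.
        k = lora_dim * 2 / hidden_size
        factor -= (1 - k)
    return factor

# Example: lora_dim=128, hidden_size=4096 -> k = 0.0625, so the factor
# drops from 4.0 to 3.0625 (or from 3.0 to 2.0625).
print(effective_factor(True, 128, 4096))  # 3.0625
```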
@@ -25,8 +29,7 @@ def print_throughput(hf_model, args, e2e_time, rank=0):

     # Megatron paper's formula to calculate training flops
     train_flops_per_iteration = calculate_flops(
-        checkpoint_activations_factor, batch_size, seq_length, num_layers,
-        hidden_size, vocab_size)
+        checkpoint_activations_factor, batch_size, seq_length, hf_config)

     train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
                                                 (10**12))
@@ -46,8 +49,15 @@ def print_throughput_step3(actor_model,
                            train_time,
                            rank=0):
     if rank <= 0:
-        hf_config = actor_model.config
-        num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
+        # Actor model passed here is a HF model.
+        actor_hf_config = actor_model.config
+        # Critic model passed here is a DeepSpeed Engine. The module inside is the Reward model (that wraps a HF model).
+        critic_hf_config = critic_model.module.config
+
+        actor_num_layers, actor_hidden_size, actor_vocab_size = get_hf_configs(
+            actor_hf_config)
+        critic_num_layers, critic_hidden_size, critic_vocab_size = get_hf_configs(
+            critic_hf_config)

         gpus_per_model = torch.distributed.get_world_size()
         seq_length = args.max_answer_seq_len + args.max_prompt_seq_len
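Both reads go through `get_hf_configs`, which is defined elsewhere in DeepSpeed-Chat's utils rather than in this diff. A sketch of the contract it is assumed to satisfy; the fallback attribute names are assumptions, since HF architectures name these fields differently:

```python
# Assumed shape of the helper used above (the real one lives outside this
# diff); the getattr fallbacks are illustrative.
def get_hf_configs(hf_config):
    # e.g. GPT-2-style configs use n_layer / n_embd instead of
    # num_hidden_layers / hidden_size.
    num_layers = getattr(hf_config, "num_hidden_layers",
                         getattr(hf_config, "n_layer", None))
    hidden_size = getattr(hf_config, "hidden_size",
                          getattr(hf_config, "n_embd", None))
    vocab_size = getattr(hf_config, "vocab_size", None)
    assert all((num_layers, hidden_size, vocab_size)), \
        "could not read num_layers/hidden_size/vocab_size from hf_config"
    return num_layers, hidden_size, vocab_size
```

Reading the critic's own config matters because step 3 typically pairs a large actor with a much smaller reward/critic model; reusing the actor's `num_layers`/`hidden_size`/`vocab_size`, as the old code did, skewed every critic-side estimate below.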
@@ -57,10 +67,10 @@ def print_throughput_step3(actor_model,
         actor_checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
         critic_checkpoint_activations_factor = 4 if args.critic_gradient_checkpointing else 3
         if args.actor_lora_dim > 0:
-            k = args.actor_lora_dim * 2 / hidden_size
+            k = args.actor_lora_dim * 2 / actor_hidden_size
             actor_checkpoint_activations_factor -= (1 - k)
         if args.critic_lora_dim > 0:
-            k = args.critic_lora_dim * 2 / hidden_size
+            k = args.critic_lora_dim * 2 / critic_hidden_size
             critic_checkpoint_activations_factor -= (1 - k)

         actor_model._num_params = sum([
@@ -76,12 +86,13 @@ def print_throughput_step3(actor_model,
         critic_params_in_billions = critic_model._num_params / (1e9)

         # Megatron paper's formula to calculate training flops
+
         actor_train_flops_per_iteration = calculate_flops(
             actor_checkpoint_activations_factor, batch_size, seq_length,
-            num_layers, hidden_size, vocab_size)
+            actor_hf_config)
         critic_train_flops_per_iteration = calculate_flops(
             critic_checkpoint_activations_factor, batch_size, seq_length,
-            num_layers, hidden_size, vocab_size)
+            critic_hf_config)

         total_train_flops = actor_train_flops_per_iteration + critic_train_flops_per_iteration
         train_tflops = total_train_flops / (train_time * gpus_per_model *
@@ -91,24 +102,22 @@ def print_throughput_step3(actor_model,

         # Modified formula for calculating flops in the forward pass only
         gen_flops_per_iteration = (
-            24 * gen_bs * seq_length * num_layers *
-            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
-                                 (vocab_size /
-                                  (16.0 * num_layers * hidden_size)))
+            24 * gen_bs * seq_length * actor_num_layers *
+            (actor_hidden_size**2)) * (
+                1.0 + (seq_length / (6.0 * actor_hidden_size)) +
+                (actor_vocab_size /
+                 (16.0 * actor_num_layers * actor_hidden_size)))

         gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
                                                 (10**12))

-        if hf_config.torch_dtype == torch.float16:
+        if actor_hf_config.torch_dtype == torch.float16:
             num_bytes = 2
-        elif hf_config.torch_dtype == torch.float32:
+        elif actor_hf_config.torch_dtype == torch.float32:
             num_bytes = 4
         else:
             num_bytes = -1

-        print(
-            f"{num_bytes=}, {hf_config.torch_dtype=}, {actor_model._num_params=}"
-        )
         pertok_lat = gen_exp_time / args.max_answer_seq_len
         gen_bw = 1 / pertok_lat * actor_model._num_params * num_bytes / 1e9
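The forward-only estimate keeps the 24·B·s·L·h² term of the training formula but drops the backward/checkpointing multiplier, and after this fix it uses the actor's dimensions throughout, since only the actor generates. A worked sketch with assumed numbers (an OPT-1.3B-like actor: 24 layers, hidden 2048, vocab 50272):

```python
# Worked example of the generation-throughput math above; every numeric
# value here is assumed purely for illustration.
gen_bs, seq_length = 4, 512
actor_num_layers, actor_hidden_size, actor_vocab_size = 24, 2048, 50272

gen_flops_per_iteration = (
    24 * gen_bs * seq_length * actor_num_layers *
    (actor_hidden_size**2)) * (
        1.0 + (seq_length / (6.0 * actor_hidden_size)) +
        (actor_vocab_size /
         (16.0 * actor_num_layers * actor_hidden_size)))

gen_exp_time, gpus_per_model = 2.0, 8   # seconds per step, GPUs (assumed)
gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
                                        (10**12))
print(f"{gen_flops_per_iteration:.3e} FLOPs -> {gen_tflops:.2f} TFLOPs/GPU")
```

The `gen_bw` line is the same style of back-of-envelope estimate: generating one token reads every parameter once, so achieved bandwidth is roughly `num_params * num_bytes / per-token latency`, reported in GB/s via the `1e9` divisor.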

@@ -134,7 +143,8 @@ def print_throughput_step3(actor_model,

 # Helper function to calculate FLOPs using the Megatron-LM paper's formula
 def calculate_flops(checkpoint_activations_factor, batch_size, seq_length,
-                    num_layers, hidden_size, vocab_size):
+                    hf_config):
+    num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
     flops_per_iteration = (24 * checkpoint_activations_factor * batch_size *
                            seq_length * num_layers * (hidden_size**2)) * (
                                1.0 + (seq_length / (6.0 * hidden_size)) +
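The hunk above is cut off by the diff view before the vocabulary-projection term and the return statement. Reconstructed from the matching formula used for `gen_flops_per_iteration` earlier in this file (and runnable together with the `get_hf_configs` sketch above), the helper after this commit should read roughly:

```python
# Reconstruction of the full helper after this commit; the tail is inferred
# from the identical formula earlier in perf.py, not shown in this diff.
def calculate_flops(checkpoint_activations_factor, batch_size, seq_length,
                    hf_config):
    num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size *
                           seq_length * num_layers * (hidden_size**2)) * (
                               1.0 + (seq_length / (6.0 * hidden_size)) +
                               (vocab_size /
                                (16.0 * num_layers * hidden_size)))
    return flops_per_iteration
```

Moving the `get_hf_configs` call inside the helper is what lets the actor and critic call sites above pass their own `hf_config` instead of sharing one set of `num_layers`/`hidden_size`/`vocab_size`.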
