diff --git a/applications/DeepSpeed-Chat/training/utils/perf.py b/applications/DeepSpeed-Chat/training/utils/perf.py
index 94b091cd0..df57a2046 100644
--- a/applications/DeepSpeed-Chat/training/utils/perf.py
+++ b/applications/DeepSpeed-Chat/training/utils/perf.py
@@ -17,6 +17,10 @@ def print_throughput(hf_model, args, e2e_time, rank=0):
         batch_size = args.per_device_train_batch_size
         samples_per_second = batch_size / e2e_time
         checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
+        if args.lora_dim > 0:
+            k = args.lora_dim * 2 / hidden_size
+            checkpoint_activations_factor -= (1 - k)
+
         hf_model._num_params = sum([
             p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
             for p in hf_model.parameters()
@@ -25,8 +29,7 @@ def print_throughput(hf_model, args, e2e_time, rank=0):
 
         # Megatron paper's formula to calculate training flops
         train_flops_per_iteration = calculate_flops(
-            checkpoint_activations_factor, batch_size, seq_length, num_layers,
-            hidden_size, vocab_size)
+            checkpoint_activations_factor, batch_size, seq_length, hf_config)
 
         train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
                                                     (10**12))
@@ -46,8 +49,15 @@ def print_throughput_step3(actor_model,
                            train_time,
                            rank=0):
     if rank <= 0:
-        hf_config = actor_model.config
-        num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
+        # Actor model passed here is a HF model.
+        actor_hf_config = actor_model.config
+        # Critic model passed here is a DeepSpeed Engine. The module inside is the Reward model (that wraps a HF model).
+        critic_hf_config = critic_model.module.config
+
+        actor_num_layers, actor_hidden_size, actor_vocab_size = get_hf_configs(
+            actor_hf_config)
+        critic_num_layers, critic_hidden_size, critic_vocab_size = get_hf_configs(
+            critic_hf_config)
 
         gpus_per_model = torch.distributed.get_world_size()
         seq_length = args.max_answer_seq_len + args.max_prompt_seq_len
@@ -57,10 +67,10 @@ def print_throughput_step3(actor_model,
         actor_checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
         critic_checkpoint_activations_factor = 4 if args.critic_gradient_checkpointing else 3
         if args.actor_lora_dim > 0:
-            k = args.actor_lora_dim * 2 / hidden_size
+            k = args.actor_lora_dim * 2 / actor_hidden_size
             actor_checkpoint_activations_factor -= (1 - k)
         if args.critic_lora_dim > 0:
-            k = args.critic_lora_dim * 2 / hidden_size
+            k = args.critic_lora_dim * 2 / critic_hidden_size
             critic_checkpoint_activations_factor -= (1 - k)
 
         actor_model._num_params = sum([
@@ -76,12 +86,13 @@ def print_throughput_step3(actor_model,
         critic_params_in_billions = critic_model._num_params / (1e9)
 
         # Megatron paper's formula to calculate training flops
+
         actor_train_flops_per_iteration = calculate_flops(
             actor_checkpoint_activations_factor, batch_size, seq_length,
-            num_layers, hidden_size, vocab_size)
+            actor_hf_config)
         critic_train_flops_per_iteration = calculate_flops(
             critic_checkpoint_activations_factor, batch_size, seq_length,
-            num_layers, hidden_size, vocab_size)
+            critic_hf_config)
 
         total_train_flops = actor_train_flops_per_iteration + critic_train_flops_per_iteration
         train_tflops = total_train_flops / (train_time * gpus_per_model *
@@ -91,24 +102,22 @@ def print_throughput_step3(actor_model,
 
         # Modified formula for calculating flops in the forward pass only
         gen_flops_per_iteration = (
-            24 * gen_bs * seq_length * num_layers *
-            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
-                                 (vocab_size /
-                                  (16.0 * num_layers * hidden_size)))
+            24 * gen_bs * seq_length * actor_num_layers *
+            (actor_hidden_size**2)) * (
+                1.0 + (seq_length / (6.0 * actor_hidden_size)) +
+                (actor_vocab_size /
+                 (16.0 * actor_num_layers * actor_hidden_size)))
 
         gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
                                                 (10**12))
 
-        if hf_config.torch_dtype == torch.float16:
+        if actor_hf_config.torch_dtype == torch.float16:
             num_bytes = 2
-        elif hf_config.torch_dtype == torch.float32:
+        elif actor_hf_config.torch_dtype == torch.float32:
             num_bytes = 4
         else:
             num_bytes = -1
 
-        print(
-            f"{num_bytes=}, {hf_config.torch_dtype=}, {actor_model._num_params=}"
-        )
 
         pertok_lat = gen_exp_time / args.max_answer_seq_len
         gen_bw = 1 / pertok_lat * actor_model._num_params * num_bytes / 1e9
@@ -134,7 +143,8 @@ def print_throughput_step3(actor_model,
 
 # Helper function to calculate FLOPs using the Megatron-LM paper's formula
 def calculate_flops(checkpoint_activations_factor, batch_size, seq_length,
-                    num_layers, hidden_size, vocab_size):
+                    hf_config):
+    num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
     flops_per_iteration = (24 * checkpoint_activations_factor * batch_size *
                            seq_length * num_layers * (hidden_size**2)) * (
                                1.0 + (seq_length / (6.0 * hidden_size)) +
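
For reference (not part of the patch), below is a minimal self-contained sketch of the Megatron-LM training-FLOPs formula that the refactored calculate_flops() applies per model config. The model dimensions and timings used here are hypothetical stand-ins for the values get_hf_configs() would read from a HF config and for the measured step time; they are not taken from the PR.

    # Sketch of the Megatron-LM FLOPs-per-iteration formula used in perf.py.
    def flops_per_iteration(checkpoint_activations_factor, batch_size,
                            seq_length, num_layers, hidden_size, vocab_size):
        return (24 * checkpoint_activations_factor * batch_size * seq_length *
                num_layers * (hidden_size**2)) * (
                    1.0 + (seq_length / (6.0 * hidden_size)) +
                    (vocab_size / (16.0 * num_layers * hidden_size)))

    # Hypothetical example: a 13B-scale config with gradient checkpointing on
    # (activation factor 4); e2e_time and gpus_per_model are made-up numbers.
    flops = flops_per_iteration(checkpoint_activations_factor=4,
                                batch_size=8,
                                seq_length=512,
                                num_layers=40,
                                hidden_size=5120,
                                vocab_size=32000)
    e2e_time, gpus_per_model = 10.0, 8  # seconds per step, world size
    train_tflops = flops / (e2e_time * gpus_per_model * (10**12))
    print(f"{train_tflops:.2f} TFLOPs per GPU")

This also illustrates the LoRA adjustment in the diff: when only low-rank adapters are recomputed, the activation-checkpointing factor is reduced from 4 toward 3 by (1 - 2 * lora_dim / hidden_size).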