Skip to content

Commit

Permalink
propagate mp config (#7637)
Browse files Browse the repository at this point in the history
Signed-off-by: eharper <[email protected]>
  • Loading branch information
ericharper authored Oct 4, 2023
1 parent e0ed81d commit de61527
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention
gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
Expand Down

0 comments on commit de61527

Please sign in to comment.