
How to generate text with the Megatron-LM model trained with DeepSpeed #507

Closed

msmolyak opened this issue Nov 7, 2020 · 7 comments


msmolyak commented Nov 7, 2020

The DeepSpeed tutorial https://www.deepspeed.ai/tutorials/megatron/ does not offer any guidance on text generation after training the model. All the information in the tutorial deals with training, which seems to imply that the logic for generation does not change.

I tried running the Megatron-LM script for text generation (which works fine with a model trained using the Megatron-LM scripts), but it was unable to load the checkpoint generated by the DeepSpeed wrapper.

Are checkpoints generated by Megatron-LM and by the DeepSpeed wrapper of Megatron-LM binary-compatible? Do I need to wrap the text generation code with the deepspeed module the way it was done with the pre-training logic? Are there examples of text generation based on models trained with DeepSpeed?

Thank you,

Michael

msmolyak (Author) commented

I was able to modify the text generation utility generate_samples.py to wrap the Megatron-LM model in the DeepSpeed engine and use it to generate text from the checkpoint created by the DeepSpeed Megatron-LM wrapper.

I can create a pull request for the change, but the Megatron-LM code in this repository is over a year old, so I think fixing it would not serve much purpose. The preferred solution is for the Microsoft team to complete their effort of incorporating the more recent version of Megatron-LM into the DeepSpeed Examples repo. I left a note for the developers working on that branch apprising them of this bug.

Please note that another bug in DeepSpeed makes it hard to wrap a Megatron-LM model in the DeepSpeed engine for text generation. Text generation does not require an optimizer, and deepspeed.initialize() states that the optimizer is optional. In fact, passing None for the optimizer leads to an exception. Creating a "fake" optimizer just to make the DeepSpeed code work complicates the code and makes it hard to test on a machine with a small amount of GPU memory.
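
For illustration, the workaround amounts to something like the following sketch (the placeholder model is mine, not the actual Megatron-LM code; the zero-learning-rate SGD optimizer exists only to satisfy deepspeed.initialize() and is never stepped):

import torch
import deepspeed

# Illustrative placeholder standing in for the Megatron-LM GPT2Model.
model = torch.nn.Linear(8, 8)

# "Fake" optimizer: needed only because deepspeed.initialize() currently
# rejects optimizer=None; with lr=0.0 and no step() calls it has no effect.
dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

# Depending on the DeepSpeed version, the config is supplied via
# --deepspeed_config on args or the config keyword used here.
model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    optimizer=dummy_optimizer,
    config={"train_batch_size": 1},
)
model_engine.eval()  # generation only; the optimizer is never stepped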

My team fixed the optimizer issue and sent a pull request to the Microsoft team. Once it is merged, it will be much easier to use text generation with the DeepSpeed code.
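
Once that is merged, generation-only setup should reduce to roughly this sketch (model, args, and mpu as in the Megatron-LM setup code; optimizer=None reflects the expected post-fix behavior rather than the current API):

# Sketch of generation-only setup once optimizer=None is accepted.
model_engine, _, _, _ = deepspeed.initialize(
    model=model, optimizer=None, args=args, mpu=mpu)

# Skip optimizer state when loading a checkpoint purely for inference.
model_engine.load_checkpoint(args.load, load_optimizer_states=False)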

ShadenSmith (Contributor) commented

Hi @msmolyak, many many thanks to you and your team for the multiple contributions along this line. I am back after some sick leave and catching up.

Let me work on your submitted bugfix and ramp back up on the DSE side of things. We'd love to get the text generation PR from you once we're set up to properly track Megatron-LM's upstream branch.

hujian233 commented

Hi @msmolyak, I'd like to ask you a question. I modified generate_samples.py to use the setup_model_and_optimizer() function from pretrain_gpt2.py. It loads the model-parallel checkpoint successfully, but it can't generate any text: it gets stuck in logits = model(tokens, position_ids, attention_mask) without any error log. Do you know what might cause that?

def main():
    """Main training program."""

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = prepare_tokenizer(args)

    # Model, optimizer, and learning rate.
    if args.deepspeed:
        model = setup_model_and_optimizer(args)
    else:
        model = setup_model(args)

    # setting default batch size to 1
    args.batch_size = 1

    # generate samples
    generate_samples(model, tokenizer, args, torch.cuda.current_device())

msmolyak (Author) commented

Hi @hujian233,

Here are the notes I took when trying to make text generation work in the current version of DeepSpeedExamples (I did not commit the code anywhere; it was just a proof-of-concept effort):

https://docs.google.com/document/d/1My6UA-2n_MHMZO8w-xwnXoKBU1KWEr4B4_iYveFoXPs/edit?usp=sharing

Step 6, which deals with the changes to generate_samples.py, consists of three changes. The document contains the diff with my changes.

  1. Update the model using deepspeed.initialize()
  2. Obtain an optimizer, since the call above does not work without one (the optimizer is not used in text generation)
  3. Fix the logic for computing the vocabulary size

The second change was a hack to create an optimizer, which is totally superfluous. The preferred approach here would be to modify the DeepSpeed code to allow initializing a model without an optimizer. (My colleague submitted a pull request to that effect).
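
For reference, the vocabulary-size fix (change 3) mirrors the padding applied during pretraining, so that the vocabulary divides evenly across model-parallel ranks; a sketch, with a hypothetical helper name:

import mpu  # Megatron-LM's model-parallel utilities, as used elsewhere here

def padded_vocab_size(tokenizer_vocab_size, args):
    """Hypothetical helper: pad the vocabulary so it divides evenly across
    model-parallel ranks, matching the padding used during pretraining."""
    multiple = (args.make_vocab_size_divisible_by *
                mpu.get_model_parallel_world_size())
    after = tokenizer_vocab_size
    while after % multiple != 0:
        after += 1
    return after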

The modified generate_samples.py was able to generate text from the DeepSpeed-trained model. If this document does not offer any clues, let me know and I will try to run text generation with your code in my test environment.

hujian233 commented

Hi @msmolyak, thank you. I am now able to generate text normally.

hujian233 commented

Hi @msmolyak, have you solved the problem you raised? Can deepspeed.initialize() now be used without an optimizer?
This is my code:

def get_model(args):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    if args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    model = DDP(model)

    return model

def get_parameters(model):
    """Build the parameter groups for the optimizer."""

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (DDP, FP16_Module)):
        model = model.module
    param_groups = gpt2_get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False
    return param_groups

def setup_model_deepspeed(args):
    """Set up the model without an optimizer."""

    model = get_model(args)
    param_groups = get_parameters(model)
    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")

    model, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=param_groups,
        args=args,
        mpu=mpu,
        dist_init_required=False
    )

    if args.load is not None:
        _ = load_checkpoint(
            model, None, None, args)

    return model

It doesn't work well with the latest DeepSpeed version. Do you know what I'm missing? How should the DeepSpeed config file be set?
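
For reference, a minimal generation-time configuration might look like the sketch below (the values are illustrative, not settings from this thread); depending on the DeepSpeed version it is passed as a JSON file via --deepspeed_config or as a dict:

# Illustrative minimal DeepSpeed config for generation; values are assumptions.
ds_config = {
    "train_batch_size": 1,             # required by DeepSpeed even for inference
    "fp16": {"enabled": True},         # should match how the checkpoint was trained
    "zero_optimization": {"stage": 0}  # no optimizer sharding needed at inference
}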

loadams (Collaborator) commented Aug 21, 2023

Closing this issue as it appears to be stale. If you are hitting new issues/have more questions, please open a new issue with the latest DeepSpeed/Megatron-DeepSpeed repo and we would be happy to take a look. Thanks!

loadams closed this as completed Aug 21, 2023