
WarmupDecayLR.params.total_num_steps - total or per gpu? #633

Closed
stas00 opened this issue Jan 4, 2021 · 10 comments · Fixed by #608 or #644

Comments

stas00 (Collaborator) commented Jan 4, 2021

We are having a bit of a hard time getting total_num_steps to pass to WarmupDecayLR at init time - it's a bit too early for this logic, as these values are configured once ddp/ds has started. We found a workaround, but it doesn't take into account the number of GPUs.

I wanted to check that WarmupDecayLR.params.total_num_steps expects the total for the whole world and not per gpu.

Currently the doc says:

class WarmupDecayLR(WarmupLR):
[...]
            total_num_steps (int): total number of training steps

so it'd be awesome to disambiguate the context.
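
For reference, here is the scheduler section of our config, shown as a Python dict purely for illustration (the values are placeholders):

    # scheduler section of the DeepSpeed config (illustrative values only)
    ds_config = {
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 3e-5,
                "warmup_num_steps": 500,
                "total_num_steps": 100000,  # <-- total for the whole world, or per gpu?
            },
        },
    }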

Thank you!

tjruwase (Contributor) commented Jan 5, 2021

I find it easier to think of a training step in terms of the effective batch size. So while this implies the world size, it also incorporates gradient accumulation steps. I guess another way to say it is that training steps == optimizer steps. Does this help? I will update the doc shortly.
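
To make that concrete, here is a rough sketch of the arithmetic I mean (the function and argument names are illustrative, not DeepSpeed API):

    import math

    def estimate_total_num_steps(num_samples, num_epochs, micro_batch_per_gpu,
                                 gradient_accumulation_steps, world_size):
        # one training step == one optimizer step == one effective batch
        effective_batch_size = (micro_batch_per_gpu
                                * gradient_accumulation_steps
                                * world_size)
        return math.ceil(num_samples * num_epochs / effective_batch_size)

    # e.g. 10,000 samples, 3 epochs, micro-batch 8, grad accum 4, 2 gpus
    # -> effective batch size 64 -> 469 optimizer steps in total
    print(estimate_total_num_steps(10_000, 3, 8, 4, 2))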

stas00 (Collaborator, Author) commented Jan 5, 2021

I'm not sure it helps. While it can be overridden by users with --max_steps, normally it's derived automatically at run time from the dataset length and gradient_accumulation_steps, to be exact:

            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            return math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)

So, the modified question with this context: do we need to divide this number by the number of GPUs or not?

My guess is that since you said "effective batch size", then yes, it should be divided by n_gpus, correct?

I don't think we have that n_gpus/world_size at ds init time, though, do we? It may change through args to the deepspeed launch script.

We are dealing with a chicken-and-egg problem here: we need to know the world size to get this right, but it's not set up until ds has been initialized.

tjruwase (Contributor) commented Jan 5, 2021

Besides the chicken-and-egg problem, I am a bit confused by that code snippet for computing training steps. It seems to be independent of the effective batch size, or am I missing something?

Also, is len(train_dataloader) per GPU or for the whole world?

stas00 (Collaborator, Author) commented Jan 5, 2021

You're correct - I hadn't supplied enough information for you to be able to help. My apologies, @tjruwase.

Here is the full code:

            train_dataloader = DataLoader(
                self.train_dataset,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                drop_last=self.args.dataloader_drop_last,
            )
            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            return math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)

and the batch size is --per_device_train_batch_size, so it is per GPU.

But this gives me the full dataset and not the exact slice derived from the distributed sampler, which I had to remove. The original had:

            train_dataloader = DataLoader(
                self.train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=self._get_train_sampler(),
                collate_fn=self.data_collator,
                drop_last=self.args.dataloader_drop_last,
            )

But I can't invoke the distributed sampler, since it's too early in the game; I get:

Traceback (most recent call last):
  File "./finetune_trainer.py", line 367, in <module>
    main()
  File "./finetune_trainer.py", line 282, in main
    trainer = Seq2SeqTrainer(
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/src/transformers/trainer.py", line 281, in __init__
    model, optimizer, lr_scheduler = self._init_deepspeed(model)
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/src/transformers/trainer.py", line 417, in _init_deepspeed
    print(f"STEPS {self.get_num_training_steps()}")
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/src/transformers/trainer.py", line 403, in get_num_training_steps
    sampler=self._get_train_sampler(),
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/src/transformers/trainer_seq2seq.py", line 46, in _get_train_sampler
    self.train_dataset.make_sortish_sampler(
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/examples/seq2seq/utils.py", line 170, in make_sortish_sampler
    return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-deepspeed/examples/seq2seq/utils.py", line 382, in __init__
    num_replicas = dist.get_world_size()
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 691, in get_world_size
    return _get_group_size(group)
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 243, in _get_group_size
    default_pg = _get_default_group()
  File "/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 327, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

So it seems to be a design issue in DS: it should have run torch.distributed.init_process_group() by now. For example, in the HF trainer we set it up as soon as the args have been parsed.

We actually don't want to do the calculation I pasted above at all, since it's too early and we would have to redo it once dist is set up, but we can't see another way.

Actually, this wouldn't be a problem at all if we could call deepspeed.initialize with some dummy number and then update it once we have correctly calculated total_num_steps during training. Is this possible?
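
Something along these lines is what I have in mind, illustrated on the scheduler class directly rather than through deepspeed.initialize (purely hypothetical: I'm assuming total_num_steps is stored as a plain attribute that the decay logic reads on every step, which I haven't verified):

    import torch
    from deepspeed.runtime.lr_schedules import WarmupDecayLR

    model = torch.nn.Linear(10, 10)                             # stand-in model
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

    # construct with a dummy total, then patch it once the real value is known
    scheduler = WarmupDecayLR(optimizer, total_num_steps=1000, warmup_num_steps=500)
    scheduler.total_num_steps = 100_000                         # the real value, computed later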

stas00 (Collaborator, Author) commented Jan 5, 2021

When I launch the program with:

deepspeed  ./finetune_trainer.py ...

It does know the world size early on, according to the logs:

[2021-01-05 13:30:20,634] [INFO] [launch.py:100:main] dist_world_size=2

but it won't call torch.distributed.init_process_group() until deepspeed.initialize runs, and that's too late for our needs.

I guess I could have hacked around this if I had the world size: simply measuring the length of the whole dataset and dividing it by the world size would give a roughly correct number. But again, I can't call dist.get_world_size() at that stage.

stas00 (Collaborator, Author) commented Jan 5, 2021

Perhaps the correct solution is that there should be 2 init functions:

  1. deepspeed.init_dist, which the user should call as soon as possible - the first thing in the script once the args have been parsed
  2. deepspeed.initialize - the rest of it.

All that deepspeed.init_dist needs to do is call torch.distributed.init_process_group().

To remain backward compatible, the newly proposed deepspeed.init_dist would be called internally by deepspeed.initialize if the former wasn't called explicitly. But users that need it early could now invoke it early. Once the whole dist world is correctly set up, we are much more flexible and can use the existing logic to integrate DS with ease.

stas00 (Collaborator, Author) commented Jan 7, 2021

Would it be possible to look at this issue, please? I can't complete the integration without this issue resolved. Thank you.

jeffra (Collaborator) commented Jan 7, 2021

If you need the world size before the deepspeed.initialize call, then yes, I think that's perfect motivation for calling deepspeed.init_distributed() sometime earlier. This will also help in other scenarios where torch.distributed needs to be used before deepspeed.initialize is invoked (e.g., model parallelism).

This call was introduced in #608 and is still pretty new. You are correct that if deepspeed.init_distributed is not called before deepspeed.initialize, then we will call init_distributed on the user's behalf to make sure the distributed backend is started properly.
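
In other words, the usage pattern would look roughly like this (a minimal sketch; assumes the script is launched via the deepspeed launcher so the usual distributed env vars are already set):

    import deepspeed
    import torch.distributed as dist

    # call as early as possible, right after argument parsing
    deepspeed.init_distributed()

    # the torch.distributed process group now exists, so the world size is
    # available for computing things like total_num_steps well before
    # deepspeed.initialize is called
    world_size = dist.get_world_size()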

stas00 (Collaborator, Author) commented Jan 7, 2021

Oh, you have recently added just that - fantastic! I will give it a try and report back. Thank you, @jeffra

stas00 (Collaborator, Author) commented Jan 7, 2021

It worked beautifully and greatly simplified our code - thank you, @jeffra

It'd be awesome to document it at https://www.deepspeed.ai/getting-started/#writing-deepspeed-models perhaps?

I've proposed a doc update in #644 to document this important function.

@stas00 stas00 closed this as completed Jan 7, 2021