[Deepspeed] Allow HF optimizer and scheduler to be passed to deepspeed #10464

Merged: 32 commits merged into huggingface:master on Mar 16, 2021

Conversation

@cli99 (Contributor) commented Mar 1, 2021:

Use HF optimizer and/or scheduler unless specified in deepspeed config

If HF is already creating an optimizer and LR scheduler, we should not try to match that config/implementation in ds_config.json; instead we pass them to deepspeed.initialize(..., lr_scheduler=hf_lr_scheduler).

  • This PR checks whether ds_config defines an optimizer or a scheduler; if it does not, it calls create_optimizer or create_scheduler (split out of create_optimizer_and_scheduler) to create them, and the resulting HF optimizer and scheduler are then passed to deepspeed.initialize() (see the sketch below).
    DeepSpeed can handle any optimizer and scheduler as long as they are passed directly to deepspeed.initialize() as objects.
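For illustration, a minimal sketch of that flow, assuming a trainer object that exposes create_optimizer()/create_scheduler() and a DS config dict; this is not the actual integrations.py code, and depending on the DeepSpeed version the dict is passed via config= or the older config_params=:

```python
# Minimal sketch of the idea, not the actual integrations.py implementation.
# Assumptions: `trainer` exposes create_optimizer()/create_scheduler() and the
# DeepSpeed config is available as the dict `ds_config`.
import deepspeed


def init_deepspeed_sketch(trainer, model, ds_config, num_training_steps):
    optimizer, lr_scheduler = None, None

    if "optimizer" not in ds_config:
        # No optimizer section in the DS config: let HF build the optimizer.
        trainer.create_optimizer()
        optimizer = trainer.optimizer

    if "scheduler" not in ds_config:
        # No scheduler section in the DS config: let HF build the LR scheduler
        # (this needs an optimizer object, hence the chicken-and-egg issue below).
        trainer.create_scheduler(num_training_steps=num_training_steps)
        lr_scheduler = trainer.lr_scheduler

    # Objects passed explicitly are used as-is; whatever is still None here is
    # created by DeepSpeed from ds_config instead.
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        config=ds_config,  # older DeepSpeed releases used config_params= instead
    )
    return model_engine, optimizer, lr_scheduler
```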

Due to the chicken-n-egg init problem, the valid combinations are:

| Combos       | HF Scheduler | DS Scheduler |
|--------------|--------------|--------------|
| HF Optimizer | Yes          | Yes          |
| DS Optimizer | No           | Yes          |

but if cpu_offload is used, all bets are off - we can only use the DS optimizer/scheduler.


added by @stas00 below:

Added:

  • make init_deepspeed support config dict, besides the config file - this makes the testing much easier
  • add tests for this PR using this new feature of passing the dict (a rough illustration follows this list)
  • various small clean ups
  • update the docs
  • check for cpu_offload - add test
  • recode the config overrides to have one true source of values
  • tweak one non-working test
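A rough illustration of what the dict-based config enables in tests; the real tests live in examples/tests/deepspeed/test_deepspeed.py and are structured differently, and the config path and TrainingArguments entry point here are assumptions:

```python
# Rough illustration only; the actual tests differ. Passing the DS config as a
# dict makes it trivial to drop the "optimizer"/"scheduler" sections so that
# the HF-side objects are exercised instead.
import json


def ds_config_without_optimizer_and_scheduler(base_config_path="ds_config.json"):
    with open(base_config_path) as f:
        ds_config = json.load(f)

    # These are the "del lines" referred to later in the discussion: removing
    # the sections tells the integration to fall back to the HF optimizer and
    # scheduler.
    del ds_config["optimizer"]
    del ds_config["scheduler"]
    return ds_config


# The dict could then be fed to the trainer, e.g. via
# TrainingArguments(..., deepspeed=ds_config_without_optimizer_and_scheduler()),
# assuming init_deepspeed accepts a dict as well as a file path (which this PR adds).
```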

Blocking event: waiting for the new 0.3.13 release from DeepSpeed.

@sgugger

@stas00 stas00 self-assigned this Mar 1, 2021

@stas00 (Contributor) left a comment:

That's a very good proposal, @cli99.

So are you saying DeepSpeed will handle any optimizer and scheduler if these are passed directly as an object?

One more thing we need is a test; it would probably be easier for you if I added it, since I need to rework a few things to support this test. But you will need to switch this to a normal PR from Draft for me to be able to push into it. We usually don't use Draft; instead we edit the PR title to start with [WIP] (work in progress).

Review threads on src/transformers/integrations.py and src/transformers/trainer.py (resolved). Diff excerpt from the review thread:

    ZERO_DP_2 = "zero2"
    ZERO_DP_3 = "zero3"
    OFFLOAD = "offload"
    ZERO_DP_2 = "zero_dp_2"

@stas00 (Contributor) commented Mar 1, 2021:

Thank you for catching those discrepancies.

Ideally let's try not to cross unrelated areas in the same PR, but I suppose it's OK for now.

@cli99 cli99 marked this pull request as ready for review March 2, 2021 01:47
@cli99 cli99 changed the title Pass HF optimizer and scheduler to deepspeed if not specified in deepspeed config [WIP] Pass HF optimizer and scheduler to deepspeed if not specified in deepspeed config Mar 2, 2021

@stas00 (Contributor) commented Mar 2, 2021:

OK, 2 tests added, and no, this doesn't work when either the default optimizer or the default scheduler is removed from the DS config - e.g. if you comment out the del lines in the tests, we are back to using the DS optim/sched and things are back to normal.

I didn't have time to investigate as it's late, so just sharing the outputs at the moment - will look closer tomorrow. I think both are issues on the DeepSpeed side, but I could be wrong.

Also note that the normal CI doesn't run these tests, so green doesn't say anything about those.

pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_hf_native_scheduler

examples/tests/deepspeed/test_deepspeed.py:103:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/trainer.py:917: in train
    model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
src/transformers/integrations.py:351: in init_deepspeed
    trainer.create_scheduler(num_training_steps=num_training_steps)
src/transformers/trainer.py:685: in create_scheduler
    self.lr_scheduler = get_scheduler(
src/transformers/optimization.py:266: in get_scheduler
    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
src/transformers/optimization.py:98: in get_linear_schedule_with_warmup
    return LambdaLR(optimizer, lr_lambda, last_epoch)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torch.optim.lr_scheduler.LambdaLR object at 0x7fd86fb0a160>, optimizer = None
lr_lambda = <function get_linear_schedule_with_warmup.<locals>.lr_lambda at 0x7fd86fafc160>, last_epoch = -1, verbose = False

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
>           self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
E           AttributeError: 'NoneType' object has no attribute 'param_groups'

/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:197: AttributeError
pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_hf_native_optimizer

examples/tests/deepspeed/test_deepspeed.py:91: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/trainer.py:917: in train
    model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
src/transformers/integrations.py:384: in init_deepspeed
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
../../github/00optimize/DeepSpeed/deepspeed/__init__.py:110: in initialize
    engine = DeepSpeedEngine(args=args,
../../github/00optimize/DeepSpeed/deepspeed/runtime/engine.py:174: in __init__
    self._configure_optimizer(optimizer, model_parameters)
../../github/00optimize/DeepSpeed/deepspeed/runtime/engine.py:570: in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
../../github/00optimize/DeepSpeed/deepspeed/runtime/engine.py:691: in _configure_zero_optimizer
    optimizer = FP16_DeepSpeedZeroOptimizer(
../../github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage2.py:239: in __init__
    flatten_dense_tensors_aligned(
../../github/00optimize/DeepSpeed/deepspeed/runtime/zero/stage2.py:74: in flatten_dense_tensors_aligned
    return _flatten_dense_tensors(padded_tensor_list)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensors = []

    def _flatten_dense_tensors(tensors):
        """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
        same dense type.
    
        Since inputs are dense, the resulting tensor will be a concatenated 1D
        buffer. Element-wise operation on this buffer will be equivalent to
        operating individually.
    
        Args:
            tensors (Iterable[Tensor]): dense tensors to flatten.

        Returns:
            A contiguous 1D buffer containing input tensors.
        """
        if len(tensors) == 1:
            return tensors[0].contiguous().view(-1)
>       flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
E       RuntimeError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat.  This usually means that this function requires a non-empty list of Tensors.  Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
E
E       CPU: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:5925 [kernel]
E       CUDA: registered at /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7100 [kernel]
E       QuantizedCPU: registered at /pytorch/build/aten/src/ATen/RegisterQuantizedCPU.cpp:641 [kernel]
E       BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E       Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
E       AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradNestedTensor: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       UNKNOWN_TENSOR_TYPE_ID: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:9161 [autograd kernel]
E       Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_2.cpp:10551 [kernel]
E       Autocast: registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [kernel]
E       Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
E       VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/torch/_utils.py:259: RuntimeError

@stas00 stas00 self-requested a review March 2, 2021 06:42

@stas00 (Contributor) commented Mar 2, 2021:

OK, so I had a look at the first failing test.

These two can't be separated the way it was done, since the optimizer is needed to init the scheduler, but we don't have it yet if it's DeepSpeed that creates the optimizer. So we have a chicken-and-egg problem here, unless DeepSpeed provides a new API to handle that.

So probably at the moment we can only support combos 1, 2 and 3, but not 4:

  1. DS scheduler + DS optimizer
  2. HF scheduler + HF optimizer
  3. DS scheduler + HF optimizer
  4. HF scheduler + DS optimizer

Note that I added a new test for combo 2 and renamed all the tests to match, so now we have:

pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_hf_scheduler_hf_optimizer
pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_ds_scheduler_hf_optimizer
pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_hf_scheduler_ds_optimizer
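
To make the chicken-and-egg problem concrete, here is roughly what combo 4 runs into; this mirrors the LambdaLR traceback above. get_scheduler is the real transformers helper, while the surrounding values are illustrative:

```python
# Combo 4 (HF scheduler + DS optimizer): the HF scheduler factory needs a live
# optimizer object, but DeepSpeed only creates its optimizer inside
# deepspeed.initialize(), which has not run yet at scheduler-creation time.
from transformers.optimization import get_scheduler

optimizer = None  # would only be created later, by DeepSpeed

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=1000,
)
# -> AttributeError: 'NoneType' object has no attribute 'param_groups'
#    (the same failure as in the scheduler traceback earlier)
```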

@cli99 (Contributor, Author) commented Mar 6, 2021:

This DeepSpeed PR deepspeedai/DeepSpeed#827 fixes the issues. The following combinations would then pass:

  • DS scheduler + DS optimizer
  • HF scheduler + HF optimizer
  • DS scheduler + HF optimizer

Shall we put a check in HF to disallow the case HF scheduler + DS optimizer?
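
For reference, such a check could look roughly as follows; this is a sketch only, not the code that ended up in the PR, and the exact condition, message, and placement in the integration code differ:

```python
# Sketch of the requested guard, not the final implementation: reject the
# unsupported "HF scheduler + DS optimizer" combination up front, before
# deepspeed.initialize() is called.
def check_scheduler_optimizer_combo(ds_config):
    ds_has_optimizer = "optimizer" in ds_config
    ds_has_scheduler = "scheduler" in ds_config

    if ds_has_optimizer and not ds_has_scheduler:
        # Combo 4: HF would have to build the scheduler from an optimizer
        # object that DeepSpeed only creates later inside deepspeed.initialize().
        raise ValueError(
            "Using a DeepSpeed-configured optimizer with a HF scheduler is not "
            "supported: either add a scheduler section to the DeepSpeed config "
            "or remove the optimizer section so HF creates both."
        )
```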

@stas00 (Contributor) commented Mar 6, 2021:

I tested with your deepspeedai/DeepSpeed#827 PR tree and indeed

pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_hf_scheduler_hf_optimizer
pytest -sv examples/tests/deepspeed/test_deepspeed.py -k test_ds_scheduler_hf_optimizer

now pass. Awesome!

> Shall we put a check in HF to disallow the case HF scheduler + DS optimizer?

Correct! Please let me know whether you will be taking care of it or whether you'd rather I finish this up. Either way works.

We will also need to update the docs to reflect this new, more flexible reality. I can take care of that.

We will need to wait for a new release on your side before committing this PR and adding a requirement for that new version. I already have this check set up in another PR waiting for this new release: #9624

There is another PR by @jeffra today that also needs a new release first before we can update our tree.

@cli99 (Contributor, Author) commented Mar 7, 2021:

Great. Can you please add the check for the case HF scheduler + DS optimizer? Since you are updating the docs, I think it makes more sense for you to do it. I will work with @jeffra to push the deepspeed PRs into the new release. Thanks.

@stas00 stas00 requested review from sgugger and removed request for stas00 March 8, 2021 22:42

@sgugger (Collaborator) left a comment:

Thanks for your PR!

It makes sense to split the optimizer and scheduler creation as you did, and since it's done in a backward-compatible way, I have no worries. Would you mind adding those two new methods to the documentation of the Trainer (at the top of the file main_classes/trainer.rst)?

Review threads on src/transformers/training_args.py and src/transformers/trainer.py (resolved).
@stas00 stas00 changed the title [WIP] Pass HF optimizer and scheduler to deepspeed if not specified in deepspeed config [Deepspeed] Allow HF optimizer and scheduler to be passed to deepspeed Mar 9, 2021

@stas00 (Contributor) commented Mar 13, 2021:

@cli99, I made further changes to your original code:

  1. As @jeffra suggested, we can't use the HF optimizer with offload enabled, so I added code to defend against that.
  2. I realized my original design was flawed in that the user could end up with a mismatch between the command-line args and the DS config, so I recoded the optimizer/scheduler config sections to override the DS config with the command-line args where needed.

Please let me know if I broke anything in your original plan. I have also updated the docs extensively. They look a bit scary at the moment and will need a rework down the road.

My main goal here is to prevent the user from getting subtle errors, hence having the command-line arguments override the DS config (a rough sketch of both changes follows below). Hope it makes sense.
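
A rough sketch of both changes, assuming the standard DeepSpeed config layout (optimizer/scheduler sections with a params dict, cpu_offload under zero_optimization); the real code in src/transformers/integrations.py is more thorough:

```python
# Illustrative sketch, not the actual integrations.py code. `args` is a
# TrainingArguments-like object; `ds_config` is the DeepSpeed config dict.
def reconcile_ds_config(args, ds_config):
    cpu_offload = ds_config.get("zero_optimization", {}).get("cpu_offload", False)

    if "optimizer" not in ds_config:
        if cpu_offload:
            # Guard: with cpu_offload DeepSpeed must own the optimizer, so a
            # HF-created optimizer is rejected.
            raise ValueError(
                "ZeRO cpu_offload requires the optimizer to be configured in the DeepSpeed config"
            )
    else:
        # One true source of values: the command-line args override the DS
        # config so the user cannot end up with silently mismatched settings.
        ds_config["optimizer"].setdefault("params", {})
        ds_config["optimizer"]["params"]["lr"] = args.learning_rate
        ds_config["optimizer"]["params"]["weight_decay"] = args.weight_decay

    if "scheduler" in ds_config:
        ds_config["scheduler"].setdefault("params", {})
        ds_config["scheduler"]["params"]["warmup_num_steps"] = args.warmup_steps

    return ds_config
```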

@stas00 (Contributor) commented Mar 13, 2021:

@sgugger, I made more doc updates - if you get a chance, please kindly skim over them. Thank you!

I think we will merge this on Monday when deepspeed==0.3.13 is planned to be released.

@sgugger (Collaborator) left a comment:

Just one nit in the documentation. Thanks for all the work!

Review thread on docs/source/main_classes/trainer.rst (resolved).

@sgugger (Collaborator) commented Mar 16, 2021:

@stas00 I'll let you merge when you are ready (since you followed this more closely than me). It looks good to merge to me :-)
Thanks for your contribution @cli99!

@stas00 (Contributor) commented Mar 16, 2021:

I'm on top of this - we are waiting for a new DeepSpeed release required by this PR. Thank you, @sgugger

@stas00 stas00 merged commit c83fbc5 into huggingface:master Mar 16, 2021
Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request on Jul 15, 2021: [Deepspeed] Allow HF optimizer and scheduler to be passed to deepspeed (huggingface#10464)

* pass hf optimizer and scheduler to deepspeed if not specified in ds config

* pass hf optimizer and scheduler to deepspeed if not specified in ds config

* update

* make init_deepspeed support config dict

* fix docstring formatting

* clean up trainer's comments

* add new tests

* fix type

* composit argparse doesn't work

* style

* add a new test, rename others

* document new functionality

* complete tests, add docs

* style

* correct level

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* add new methods to the doc

* must tell DS we are using a non-native optimizer

* add protection against cpu_offload + HF optimizer combo

* fix the cli overrides

* sync docs + tests

* restore AdamW

* better docs

* need new version

* no longer needed

* remove outdate information

* refactor duplicated code

Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>