
[deepspeed checkpointing] AttributeError: 'NoneType' object has no attribute 'numel' #598

Open
g-karthik opened this issue Dec 11, 2020 · 23 comments · Fixed by #660

@g-karthik

So I took a public GPT-2 class implementation (not Megatron-LM) and added DeepSpeed activation checkpointing to it for all 48 layers.
In my training script for this class, I added the following line:

deepspeed.checkpointing.configure(mpu_=None, deepspeed_config=args.deepspeed_config)

My deepspeed config JSON is as follows:

{
  "train_batch_size": 128,
  "gradient_accumulation_steps": 8,
  "gradient_clipping": 1.0,
  "optimizer": {
    "type": "adam",
    "params": {
      "lr": 6.25e-5
    }
  },
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "cpu_offload": true,
    "contiguous_gradients": true,
    "overlap_comm": false,
    "allgather_bucket_size": 500000000
  },

  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "number_checkpoints": 48,
    "cpu_checkpointing": true
  }

}

When I try running my script, I get the following error:

  File "/path/to/my/modeling_gpt2.py", line 221, in forward
    encoder_attention_mask)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 582, in checkpoint
    return CheckpointFunction.apply(function, *args)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 376, in forward
    partition_size = get_partition_size(item)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 275, in get_partition_size
    size = item.numel()
AttributeError: 'NoneType' object has no attribute 'numel'

Any ideas what's going on?

@tjruwase @ShadenSmith

@ShadenSmith
Contributor

Hi @g-karthik, that is an unfriendly way of reporting that the partition_activations, contiguous_memory_optimization, and cpu_checkpointing features require model parallelism. The partition_activations feature assumes that activations are replicated among a group of model-parallel workers (as is the case in Megatron) and then partitions them to save additional memory during activation checkpointing. The contiguous_memory_optimization feature saves additional memory by allocating all activations up front, avoiding fragmentation. The cpu_checkpointing feature then offloads the saved activations to CPU memory.

I will file an issue to catch and report the configuration error when model parallelism is not enabled.
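(Roughly, the guard I have in mind would be something like the sketch below; the function name and error message are just illustrative, not DeepSpeed's actual code.)

def validate_act_ckpt_config(act_ckpt_config, mpu=None):
    # Hypothetical guard: these features assume a model-parallel group whose
    # ranks hold replicated activations, so they need an mpu object.
    mp_only = ("partition_activations",
               "contiguous_memory_optimization",
               "cpu_checkpointing")
    enabled = [f for f in mp_only if act_ckpt_config.get(f, False)]
    if mpu is None and enabled:
        raise ValueError("These activation_checkpointing options require "
                         "model parallelism (pass mpu_): " + ", ".join(enabled))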

@g-karthik
Author

Ah I see, thanks @ShadenSmith! I set those to false and now get a different NoneType error :(

  File "/path/to/my/modeling_gpt2.py", line 221, in forward
    encoder_attention_mask)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 582, in checkpoint
    return CheckpointFunction.apply(function, *args)
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 416, in forward
    inputs_cuda = [item.to(cuda_device) for item in args]
  File "/usr/local/lib/python3.6/dist-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 416, in <listcomp>
    inputs_cuda = [item.to(cuda_device) for item in args]
AttributeError: 'NoneType' object has no attribute 'to'

@g-karthik
Author

g-karthik commented Dec 11, 2020

Btw, as a more foundational question: why couldn't we have upfront contiguous memory allocation for activations in the no-model-parallelism case? Same question for offloading activations to CPU with cpu_checkpointing=True.

I understand the behavior of partition_activations is strongly tied to model parallelism, but the others shouldn't necessarily need it, right?

@ShadenSmith
Contributor

Hm, can you share a bit about the model code that uses checkpointing? My first response still applies, but now I'm wondering if there's a bigger issue that we're not catching. This looks like checkpoint() is not getting the args, which should be the input to the checkpointed method.

@g-karthik
Author

g-karthik commented Dec 11, 2020

@ShadenSmith yep, this is kinda what it looks like inside the model's forward():

            checkpoint = deepspeed.checkpointing.checkpoint

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # checkpointing only works with tuple returns, not with lists
                    return tuple(output for output in module(*inputs, use_cache, output_attentions))

                return custom_forward

            # Note: several of these arguments (e.g. layer_past, encoder_hidden_states,
            # encoder_attention_mask) can legitimately be None for a plain GPT-2 forward pass.
            outputs = checkpoint(create_custom_forward(block),
                                 hidden_states,
                                 layer_past,
                                 attention_mask,
                                 head_mask[i],
                                 encoder_hidden_states,
                                 encoder_attention_mask)


I assumed deepspeed.checkpointing.checkpoint is like a drop-in replacement for torch.utils.checkpoint.checkpoint since the signatures are the same. Is that a fair assumption to make?

@g-karthik
Author

g-karthik commented Dec 12, 2020

Ah @ShadenSmith I think I know what the issue is, but correct me if I'm wrong. The default torch.utils.checkpoint.checkpoint implementation does not assume anything about the availability of all *args, i.e., its CheckpointFunction.forward() implementation allows some of the arguments to be None.

However, the corresponding DeepSpeed CheckpointFunction.forward() implementation seems to assume in multiple places that all passed args are not None.

Perhaps these enumerations need to be made more generic to account for cases where a checkpointed layer has some args as None, depending on the consumer's choice? Something like the sketch below.
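A minimal sketch of the kind of None-tolerant handling I mean (hypothetical helper; not DeepSpeed's actual code):

import torch

def move_args_to_device(args, device):
    # Tolerate None (and other non-tensor) arguments the way
    # torch.utils.checkpoint does: only move tensors, pass the rest through.
    return [a.to(device) if torch.is_tensor(a) else a for a in args]

# e.g. inside CheckpointFunction.forward():
#   inputs_cuda = move_args_to_device(args, torch.cuda.current_device())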

EDIT: Yep, I ran a test where I simply replaced deepspeed.checkpointing.checkpoint with torch.utils.checkpoint.checkpoint (all else equal), and I can confirm the job runs successfully with the replacement. So there definitely need to be some updates to the DeepSpeed checkpointing implementation.
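(The swap for that test was roughly just:)

import torch.utils.checkpoint

# Temporary workaround: fall back to the stock PyTorch implementation,
# which tolerates None arguments.
# checkpoint = deepspeed.checkpointing.checkpoint
checkpoint = torch.utils.checkpoint.checkpoint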

Also, do you happen to have performed any side-by-side benchmarking of torch.utils.checkpoint.checkpoint and deepspeed.checkpointing.checkpoint? I'd love to know more about it if y'all have done it! :)

@g-karthik
Author

@ShadenSmith happy new year! Just bumping this back up your notifications!

@ShadenSmith
Contributor

Hey @g-karthik , happy new year and thanks for the ping! I hope you had a nice holiday. I spent some time away to focus on python scalability :-).

Thanks for the deep dive on debugging this issue. I think you're right. DeepSpeed's activation checkpointing should work as a drop-in replacement in the simple case without partitioning, etc.

I don't know if we've done any benchmarking beyond large-scale runs with model parallelism, where it's critical for huge activations. Maybe @samyam knows more.

@g-karthik
Author

g-karthik commented Jan 6, 2021

@ShadenSmith The args fix should be straightforward, unless I'm missing something? i.e., just need to ignore those args that are None?

I've been using torch.utils.checkpoint.checkpoint as a workaround so far, and I'm unable to fit the GPT-2 10B configuration (with 1024 sequence length) on my Tesla V100 GPU (32 GB memory). The best I've fit with this workaround is the GPT-2 8B configuration (with 512 sequence length). I'm not quite sure yet whether using deepspeed.checkpointing.checkpoint will help me fit 10B with 1024 sequence length, but I'll at least be able to try after the above args fix.

@ShadenSmith
Contributor

Hi @g-karthik I submitted a PR with a fix and corresponding unit test. Are you able to give that a try? It would be great to ensure that my new unit test covers your case.

@ShadenSmith
Contributor

I went ahead and merged the PR; please re-open if the issue is still present for you!

@g-karthik
Author

@ShadenSmith thanks for the fix! Will test this out later this week and get back to you!

@g-karthik
Author

g-karthik commented Feb 26, 2021

@ShadenSmith sorry for the delay, I did test this fix and it works, thanks a lot! Note that this is with Hugging Face's transformer models without usage of mpu_.

Keeping all else equal, I do not see any meaningful difference between DeepSpeed's activation checkpointing and torch.utils.checkpoint - GPU memory consumption is about the same, and the wall-clock breakdowns for forward/backward are also about the same. Since the best I could do with the latter was 8B (with 512 sequence length), the best I can do with the former will also be 8B. What else would I need to do to fit a 10B GPT-2 on a 32 GB V100, as is claimed for ZeRO-Offload?

I also have a couple other questions:

  • What are some ways by which activation memory can be reduced even further in the mpu_=None setting? As mentioned earlier in this thread, I wouldn't be able to fit a 10B sized GPT-2 on a V100 otherwise.
  • What is the value of using synchronize_checkpoint_boundary? As in, in what scenarios would one want to perform a torch.cuda.synchronize() for each checkpointed layer and what would the benefit of doing that be?
  • Why can't activations be offloaded to CPU in the mpu_=None case? We're already offloading optimizer states to CPU via cpu_offload.
  • Why can't all activations be allocated up front in the mpu_=None case?

I'd asked these last two questions earlier on in the above thread but I didn't quite understand why they're tied to the partition_activations feature, which assumes activation replication across model-parallel workers.

@g-karthik
Author

@ShadenSmith also take a look at this DS config JSON from the DeepSpeedExamples repo for 10B (with model-parallel degree MP=1 as seen here) in the context of the last two questions above.

Clearly, even when model-parallel degree = 1, the features "partition_activations", "cpu_checkpointing" and "contiguous_memory_optimization" are being enabled and used.

So shouldn't these be supported when mpu_=None? Because that's equivalent to a model-parallel degree of 1.

@samyam
Contributor

samyam commented Apr 12, 2021

Hi @g-karthik, these features are supported if mpu=None. You should be able to turn on all of the features you referred to above without providing any mpu object, and it will default to model parallelism = 1.

You can offload activation checkpoints to CPU with mpu=None. To do this you need to set both partition_activations and cpu_checkpointing to True, which can be done even when mpu is None. This is a superficial restriction in the code that we need to fix.

To enable cpu_checkpointing you first need to enable partition_activations.
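For example, with mpu=None an activation_checkpointing section like your original one should work (sketch):

"activation_checkpointing": {
  "partition_activations": true,
  "contiguous_memory_optimization": true,
  "cpu_checkpointing": true,
  "number_checkpoints": 48
}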

@ShadenSmith sorry for the delay, I did test this fix and it works, thanks a lot! Note that this is with Hugging Face's transformer models without usage of mpu_.

Keeping all else equal, I do not see any meaningful difference between DeepSpeed's activation checkpointing and torch.utils.checkpoint - GPU memory consumption is about the same, and the wall-clock breakdowns for forward/backward are also about the same. Since the best I could do with the latter was 8B (with 512 sequence length), the best I can do with the former will also be 8B. What else would I need to do to fit a 10B GPT-2 on a 32 GB V100, as is claimed for ZeRO-Offload?

I also have a couple other questions:

  • What are some ways by which activation memory can be reduced even further in the mpu_=None setting? As mentioned earlier in this thread, I wouldn't be able to fit a 10B sized GPT-2 on a V100 otherwise.

Offload to CPU

  • What is the value of using synchronize_checkpoint_boundary? As in, in what scenarios would one want to perform a torch.cuda.synchronize() for each checkpointed layer and what would the benefit of doing that be?

Sometimes it helps the memory allocator clean up memory and avoid OOM, or speeds up memory allocations.

  • Why can't activations be offloaded to CPU in the mpu_=None case? We're already offloading optimizer states to CPU via cpu_offload.

It can be. You just need to set partition_activations to true as well. This is a superficial restriction, as you have noticed, and we need to fix it.

  • Why can't all activations be allocated up front in the mpu_=None case?

Can you provide some more context on this question?

I'd asked these last two questions earlier on in the above thread but I didn't quite understand why they're tied to the partition_activations feature, which assumes activation replication across model-parallel workers.

This is a superficial restriction. You can get around it by just setting partition_activations to True even when running without an mpu. It will assume MP=1, do the right thing, and allow you to offload your activations.

@samyam
Contributor

samyam commented Apr 12, 2021

@ShadenSmith also take a look at this DS config JSON from the DeepSpeedExamples repo for 10B (with model-parallel degree MP=1 as seen here) in the context of the last two questions above.

Clearly, even when model-parallel degree = 1, the features "partition_activations", "cpu_checkpointing" and "contiguous_memory_optimization" are being enabled and used.

So shouldn't these be supported when mpu_=None? Because that's equivalent to a model-parallel degree of 1.

It should be supported with mpu=None. Are you running into errors?

@g-karthik
Author

@samyam thanks for your response! Yes, I am indeed running into errors when setting those flags to true; see #284 (comment). This error is with the latest DeepSpeed. And as a reminder, I use Hugging Face's GPT-2.

@tjruwase tjruwase reopened this Apr 12, 2021
@tjruwase
Contributor

@g-karthik, can you please share repro steps with HF GPT-2? I am trying to do that now.

@g-karthik
Author

@tjruwase I replied here #284 (comment), we can continue this discussion on this thread since this issue has to do with deepspeed's activation checkpointing.

The other thread is closely intertwined with this one though, since reproducing those claims requires deepspeed's activation checkpointing to work fine.

@tjruwase
Contributor

tjruwase commented Apr 13, 2021

@g-karthik, thanks for reminding/showing me how to enable deepspeed.checkpointing in HF GPT2. However, after doing that I am still unable to repro the reported error. I am seeing some issues that suggest that you and I are running different code stacks. Below are my observations:

  1. deepspeed.checkpointing.configure() should fail with a deepspeed_config that works for HF. This is because HF forbids batch_size and gas (gradient accumulation steps) params in the json, while deepspeed.checkpointing.configure() requires them.

  2. gradient_checkpointing is disabled in the GPT2 config, so I hacked this line to checkpoint regardless of the config setting. However, this triggered this warning because of the use_cache setting.

So how are you avoiding these two issues? Can you share more details of your repro setup? Below is my environment:

DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.7.1+cu110
torch cuda version ............... 11.0
nvcc version ..................... 11.0
deepspeed install path ........... ['/opt/conda/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.3.14+112ebff, 112ebff, master
deepspeed wheel compiled w. ...... torch 1.7, cuda 11.0
Name: transformers
Version: 4.6.0.dev0

@g-karthik
Author

g-karthik commented Apr 13, 2021

@tjruwase Indeed, my stack is different, as are my dependencies - I'm using torch 1.4, CUDA 10.1, transformers 3.1. I don't think it matters for this issue though.

About 1 - I have a separate deepspeed_config.json file which has the train_batch_size aptly configured for my cluster. So my train script's DataLoader would be configured using the micro-batch size, and the DeepSpeedEngine would be configured using the deepspeed config. And that DataLoader would be used to serve examples to the DeepSpeed model-wrapper model_engine.

As for 2 - I created a local copy of the HF transformers modeling_gpt2.py and edited it to ensure that deepspeed's checkpointing is always used. And my training script uses this local copy.

I don't currently use the HF Trainer class, but I would imagine you should be able to repro this with that class as well, as long as that class pulls your deepspeed-edited version of the GPT-2 model class and you have configured checkpointing with your deepspeed_config.json correctly. Maybe you can temporarily hack in a second deepspeed_config.json into the HF Trainer for the purpose of configuring deepspeed.checkpointing? It's a one-time op, after all.
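Roughly, the one-time configuration I mean looks like this (the config file name is just a placeholder for your second JSON):

import deepspeed

# One-time setup: point deepspeed.checkpointing at a separate config file that
# carries the activation_checkpointing (and batch size) settings, independent
# of whatever config the HF Trainer itself consumes.
deepspeed.checkpointing.configure(mpu_=None,
                                  deepspeed_config="ds_checkpointing_config.json")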

@g-karthik
Author

@tjruwase did you get a chance to work through this?

@tjruwase
Contributor

@g-karthik, sorry, due to our latest release I did not have the bandwidth to explore the proposed workarounds. I will resume this work this week. Stay tuned.
