ZeRO 0, 1, 2, 3 produce different results #966

Closed
szhengac opened this issue Apr 16, 2021 · 30 comments · Fixed by #1021

@szhengac
Contributor

szhengac commented Apr 16, 2021

Previously I observed that ZeRO 1 produces significantly worse performance than ZeRO 2 during finetuning. A similar observation is mentioned in #757. I created a simple test to see how the loss changes with different ZeRO stages. The test script (test_zero.py) is pasted below and is run with deepspeed test_zero.py --zero ${ZERO}:

import os
import json
import argparse
import torch
import deepspeed
from torch import nn
from torch.utils.data.distributed import DistributedSampler


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, zero=0):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        mlp = [self.linear]
        mlp.append(torch.nn.Linear(hidden_dim, hidden_dim//2))
        for _ in range(6):
            l = torch.nn.Linear(hidden_dim//2, hidden_dim//2)
            mlp.append(l)
        mlp.append(torch.nn.Linear(hidden_dim//2, hidden_dim))
        # The final layer ties its weight and bias to self.linear.
        l = torch.nn.Linear(hidden_dim, hidden_dim)
        l.weight = self.linear.weight
        l.bias = self.linear.bias
        mlp.append(l)
        if zero == 3:
            # ZeRO-3 partitions parameters, so parameters accessed outside
            # their owning module must be registered as external parameters.
            deepspeed.zero.register_external_parameter(self, self.linear.weight)
            deepspeed.zero.register_external_parameter(self, self.linear.bias)
        self.mlp = nn.Sequential(*mlp)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = x
        hidden_dim = self.mlp(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def get_data_loader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=sampler)
    return train_loader


def get_args(tmpdir, config_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--zero', type=int, default=0)
    args = parser.parse_args()  #args=''

    config_dict["zero_optimization"]["stage"] = args.zero
    print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
    config_path = create_config_from_dict(tmpdir, config_dict)

    args.deepspeed_config = config_path
    return args


def print0(msg):
    if torch.distributed.get_rank() == 0:
        print(msg, flush=True)


rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)

config_dict = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 4,
    "steps_per_print": 1,
    "zero_allow_untested_optimizer": True,
    "optimizer": {
        "type": "LAMB",
        "params": {
            "lr": 0.02,
            "weight_decay": 0.01,
            "bias_correction": True,
            "eps": 1e-6
        }
    },
    "gradient_clipping": 1.0,
    "fp16": {
        "enabled": True,
        "initial_scale_power": 10
    },
    "zero_optimization": {
        "stage": 1,
        "overlap_comm": True,
        "contiguous_gradients": False,
        "reduce_bucket_size": 20
    }
}
#        "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
hidden_dim = 4

model = SimpleModel(hidden_dim, empty_grad=False, zero=args.zero)

model, _, _,_ = deepspeed.initialize(args=args,
                                     model=model,
                                     model_parameters=model.parameters(),
                                     dist_init_required=True)


def print_params(tag, model):
    if torch.distributed.get_rank() == 0:
        for n, p in model.named_parameters():
            print0("{} {}:{}".format(tag, n, p))


data_loader = get_data_loader(model=model,
                              total_samples=1000,
                              hidden_dim=hidden_dim,
                              device=model.device)
#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    #if torch.distributed.get_rank() == 0 and model.is_gradient_accumulation_boundary():
    model.backward(loss)
    model.step()
    if torch.distributed.get_rank() == 0 and model.is_gradient_accumulation_boundary():
        print("{}, LOSS: {}".format(n, loss.item()))
    #print_params('step={}'.format(n), model)
    if n == 4: break
                                               

The following shows the results:

ZeRO 0:

0, LOSS: 1.6396484375
1, LOSS: 1.4296875
2, LOSS: 1.4267578125
3, LOSS: 1.529296875
4, LOSS: 1.623046875

ZeRO 1:

0, LOSS: 1.6396484375
1, LOSS: 1.4296875
2, LOSS: 1.427734375
3, LOSS: 1.5322265625
4, LOSS: 1.626953125

ZeRO 2:

0, LOSS: 1.6396484375
1, LOSS: 1.4306640625
2, LOSS: 1.427734375
3, LOSS: 1.529296875
4, LOSS: 1.623046875

ZeRO 3 gives me an error:

    accumulate=True if self.micro_step_id > 0 else False)
  File "/usr/local/lib64/python3.7/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 453, in partition_gradients
    accumulate=accumulate)
  File "/usr/local/lib64/python3.7/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 787, in _partition_gradients
    accumulate=accumulate)
  File "/usr/local/lib64/python3.7/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 855, in _partition_gradient
    param.grad.data = dest_tensor_full_buffer.data
UnboundLocalError: local variable 'dest_tensor_full_buffer' referenced before assignment

As can be seen, ZeRO 0 and ZeRO 2 produce exactly the same results, while ZeRO 1 gives different losses. In addition, the ZeRO 3 test fails.

@tjruwase FYI

@aced125

aced125 commented Apr 16, 2021

@szhengac Set hidden_dim > 35 and ZeRO stage 3 should work on 8 GPUs.

See also #940

@tjruwase
Contributor

@szhengac, thanks for this excellent repro recipe. I am able to repro on my side.

@aced125, thanks for sharing a workaround and for your hard work on #940. However, with hidden_dim=36, I am seeing the same UnboundLocalError with zero-3. On the other hand, with hidden_dim=36, the zero-1 regression seems to disappear.

You have both provided very helpful clues for this investigation.

@tjruwase
Contributor

@szhengac, so I have fixed zero-3. Below are the full results with hidden_dim=4. I want to clarify that zero-0 and zero-2 don't match exactly; they differ on steps 1 and 2. I think my observation matches your results, right?

step  zero-0 zero-1 zero-2 zero-3
0 1.639648438 1.639648438 1.639648438 1.639648438
1 1.4296875 1.4296875 1.430664063 1.430664063
2 1.426757813 1.427734375 1.427734375 1.427734375
3 1.529296875 1.532226563 1.529296875 1.529296875
4 1.623046875 1.626953125 1.623046875 1.623046875

@szhengac, I will continue investigating the zero-1 regression, but please do share if you have any other evidence.

@szhengac
Contributor Author

@tjruwase Thanks for the quick fix for ZeRO 3. I didn't notice that steps 1 and 2 have different losses for ZeRO 0 and 2. I think this is because LAMB behaves differently when optimizer state partitioning is used. After changing the optimizer to Adam, I obtain the following:

step zero-0 zero-1 zero-2
0 1.23046875 1.23046875 1.23046875
1 1.533203125 1.533203125 1.533203125
2 1.3779296875 1.3779296875 1.3779296875
3 1.3134765625 1.3134765625 1.3134765625
4 1.220703125 1.2216796875 1.220703125

You can see that now ZeRO 0 and 2 have exactly the same numbers, while the loss at step 4 is different for ZeRO 1.
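
For reference, the only change for this comparison is the optimizer block of the config. A minimal sketch of the swap (the Adam hyperparameters other than the learning rate are illustrative assumptions, not necessarily the exact values used):

# Sketch of the optimizer swap for the Adam comparison above; only the change
# from "type": "LAMB" to "type": "Adam" matters here. The params below are
# illustrative assumptions.
config_dict["optimizer"] = {
    "type": "Adam",
    "params": {
        "lr": 0.02,
        "weight_decay": 0.01,
        "eps": 1e-6
    }
}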

@tjruwase
Contributor

@szhengac, I get different observations for Adam on 8 GPUs; see below. Of course, some of the differences in our observations could be due to environment, i.e., initialization and padding.

step zero-0 zero-1 zero-2 zero-3
0 1.639648438 1.639648438 1.639648438 1.639648438
1 1.399414063 1.399414063 1.399414063 1.399414063
2 1.442382813 1.44140625 1.442382813 1.438476563
3 1.525390625 1.524414063 1.525390625 1.514648438
4 1.537109375 1.537109375 1.537109375 1.52734375

@tjruwase
Contributor

@szhengac, the fix for the UnboundLocalError has been merged; perhaps you can close this issue once you have verified it. Thanks.

@szhengac
Contributor Author

szhengac commented Apr 16, 2021

So in your testing, ZeRO 1 and 3 have different losses from ZeRO 0 and 2. Do you have any idea why? Since ZeRO 1 produces very bad performance in my finetuning task, it would be good to figure this out and get consistent results.

@tjruwase
Contributor

@szhengac, I agree that ZeRO 1 requires further investigation, which I plan to do. However, it seems this particular issue was opened for a ZeRO 3 error. Please feel free to open a separate issue to track the ZeRO 1 regression, where we can continue the investigation. Does that make sense?

@tjruwase
Contributor

@szhengac, on second thought, let's just use this issue to at least understand the source of the different results.

@tjruwase
Contributor

@aced125, I saw your question about the fix. Let me know if you want more details after reviewing the PR.

@aced125

aced125 commented Apr 17, 2021

@tjruwase here is the bug with, for example, nn.Conv1d.

Please replace the SimpleModel class with:

class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, zero=0):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        mlp = [self.linear]
        mlp.append(torch.nn.Linear(hidden_dim, hidden_dim//2))
        for _ in range(6):
            l = torch.nn.Linear(hidden_dim//2, hidden_dim//2)
            mlp.append(l)
        mlp.append(torch.nn.Linear(hidden_dim//2, hidden_dim))
        l = torch.nn.Linear(hidden_dim, hidden_dim)
        l.weight = self.linear.weight
        l.bias = self.linear.bias
        mlp.append(l)
        if zero == 3:
            deepspeed.zero.register_external_parameter(self, self.linear.weight)
            deepspeed.zero.register_external_parameter(self, self.linear.bias)
        self.mlp = nn.Sequential(*mlp)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.conv1d = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, y):
        hidden_dim = self.conv1d(x.unsqueeze(1)).squeeze(1)
        hidden_dim = self.mlp(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)

and run

deepspeed test_zero.py --zero 3

Also, add CPU offloading.
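
For concreteness, a sketch of that change against the config_dict in the original script (the exact key is version-dependent; "cpu_offload" is assumed here, while newer DeepSpeed releases use an "offload_optimizer" section instead):

# Sketch only: enable CPU offloading of optimizer state in the repro config.
# The "cpu_offload" flag is an assumption based on ZeRO configs of this era;
# newer versions expose "offload_optimizer": {"device": "cpu"} instead.
config_dict["zero_optimization"]["cpu_offload"] = True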

@aced125

aced125 commented Apr 17, 2021

Actually, I seem to be getting a different error (on an A100) when running the above:

RuntimeError: p.type().is_cuda() INTERNAL ASSERT FAILED at "/home/ubuntu/anaconda3/envs/torch/lib/python3.7/site-packages/deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp":49, please report a bug to PyTorch. p must be a CUDA tensor

@tjruwase
Contributor

@aced125, this is expected: CPU offloading means the optimizer executes on the CPU, so you need a CPU implementation of the optimizer. However, deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp is the CUDA implementation of LAMB, and we currently don't have a CPU implementation of LAMB.
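
To illustrate (a sketch, assuming the usual op module paths), DeepSpeed ships a CPU Adam optimizer but only a CUDA LAMB, which is why LAMB cannot be combined with CPU offloading:

# Sketch: DeepSpeed provides a CPU implementation of Adam, but LAMB is
# CUDA-only, which is why LAMB fails once the optimizer runs on the CPU.
from deepspeed.ops.adam import DeepSpeedCPUAdam  # CPU implementation exists
from deepspeed.ops.lamb import FusedLamb         # CUDA-only implementation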

@tjruwase
Contributor

@szhengac, I have a fix for the zero-1 issue in this branch; can you give it a try?

@szhengac
Contributor Author

@tjruwase So the current workaround is to disable reduce-scatter?

@tjruwase
Contributor

@szhengac, yes. Reduce-scatter appears to be broken in zero-1, and it might take a while to fix. Does it work for you now?
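
For anyone following along, the workaround amounts to running ZeRO stage 1 without reduce-scatter, falling back to all-reduce for gradient averaging. Roughly, in terms of the config_dict from the original script (a sketch, assuming the standard "reduce_scatter" flag is honored by stage 1 in this version):

# Sketch: disable reduce-scatter so gradients are averaged with all-reduce
# instead. Assumes the standard "reduce_scatter" ZeRO flag applies to stage 1
# in this DeepSpeed version.
config_dict["zero_optimization"]["reduce_scatter"] = False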

@szhengac
Contributor Author

@tjruwase Yes. This workaround works for the test cases.

@tjruwase
Contributor

Thanks for the confirmation. We are trying to fix reduce-scatter ASAP, bandwidth permitting, but we hope this workaround can unblock you until then.

@tjruwase
Contributor

@szhengac, did you also check whether zero-3 matches the others?

@szhengac
Contributor Author

ZeRO 3 does not match if I increase the number of steps to 8.
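
Concretely, this just means extending the early exit in the training loop of test_zero.py above:

for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    model.backward(loss)
    model.step()
    if torch.distributed.get_rank() == 0 and model.is_gradient_accumulation_boundary():
        print("{}, LOSS: {}".format(n, loss.item()))
    if n == 8: break  # was: if n == 4: break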

@tjruwase
Contributor

@szhengac, thanks for sharing that update. I will look further. Can you also share your results?

@szhengac
Contributor Author

step zero-0,1,2 zero-3
0 1.23046875 1.23046875
1 1.533203125 1.533203125
2 1.3779296875 1.3779296875
3 1.3134765625 1.3134765625
4 1.220703125 1.220703125
5 1.462890625 1.4619140625
6 1.263671875 1.263671875
7 1.234375 1.23828125
8 1.392578125 1.3955078125

@tjruwase
Contributor

@szhengac, I don't see any difference between zero 3 and the others with the Adam optimizer. Are your results from LAMB or Adam?

@szhengac
Contributor Author

@tjruwase I use Adam.

@szhengac
Contributor Author

Did you increase the number of iterations to 8?

@szhengac
Contributor Author

Also, my deepspeed version is '0.3.15+e414435'.

@tjruwase
Contributor

@szhengac, sorry, this was my error: I was running on 2 GPUs. With 8 GPUs, I do see differences between zero 3 and the others.

@SantoshGuptaML

It looks like they merged this fix (#968) not too long ago and incorporated it into the official release 9 days ago: https://github.com/microsoft/DeepSpeed/releases

I ran pip install deepspeed -U and it started working.

@tjruwase
Contributor

@szhengac, I think the ZeRO3 regression is now fixed; the PR should be merged soon. Thanks for helping to find this corner case. I extended your test case to 17 iterations.

step zero-0 zero-1 zero-2 zero-3
1 1.683594 1.683594 1.683594 1.683594
2 1.853516 1.853516 1.853516 1.853516
3 1.154297 1.154297 1.154297 1.154297
4 1.125977 1.125977 1.125977 1.125977
5 1.227539 1.227539 1.227539 1.227539
6 1.537109 1.537109 1.537109 1.537109
7 1.385742 1.385742 1.385742 1.385742
8 1.322266 1.322266 1.322266 1.322266
9 1.381836 1.381836 1.381836 1.381836
10 1.415039 1.415039 1.415039 1.415039
11 1.389648 1.389648 1.389648 1.389648
12 1.552734 1.552734 1.552734 1.552734
13 1.501953 1.501953 1.501953 1.501953
14 1.351563 1.351563 1.351563 1.351563
15 1.404297 1.404297 1.404297 1.404297
16 1.589844 1.589844 1.589844 1.589844
17 1.422852 1.422852 1.422852 1.422852

@szhengac
Contributor Author

szhengac commented Apr 29, 2021

@tjruwase Thanks for the quick fix. I just took a look at your PR. So the squared gradient norms were reduced twice?
