ZeRO 0, 1, 2, 3 produce different results #966
@szhengac, thanks for this excellent repro recipe. I am able to repro on my side. @aced125, thanks for sharing a workaround and for your hard work on #940. These are very helpful clues that you have both provided to help with this investigation.
@szhengac, I have fixed zero-3. Below are the full results. I will continue investigating the zero-1 regression, but please do share if you have other evidence.
@tjruwase Thanks for the quick fix for ZeRO 3. I hadn't noticed that steps 1 and 2 have different losses for ZeRO 0 and 2. I think this is because LAMB behaves differently when optimizer state partitioning is used. After changing the optimizer to Adam, I obtain the following:
You can see that now ZeRO 0 and 2 have exactly the same numbers, while the loss at step 4 is different for ZeRO 1.
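For reference, the optimizer swap is just a config change; a minimal sketch, assuming a dict-style DeepSpeed config (the batch size and learning rate here are illustrative, not the values used in the test):

```python
# Sketch: use Adam instead of LAMB so the comparison is not affected by LAMB's
# layer-wise trust ratio, which is sensitive to how optimizer state is partitioned.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},  # previously {"type": "Lamb", ...}
    "zero_optimization": {"stage": 1},
}
```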
@szhengac, I get different observations for Adam on 8 GPUs; see below. Of course, some of the differences in our observations could be due to environment, e.g., initialization and padding.
@szhengac, the fix for the UnboundLocalError has been merged; perhaps you can close this issue once you verify. Thanks.
So in your testing, ZeRO 1 and 3 have different losses from ZeRO 0 and 2. Do you have any idea why? Since ZeRO 1 produces very bad performance in my finetuning task, it would be good to figure this out and get consistent results.
@szhengac, I agree that ZeRO 1 requires further investigation, which I plan to do. However, it seems this particular issue was opened for a ZeRO 3 error. Please feel free to open a separate issue to track the ZeRO 1 regression, where we can continue the investigation. Does that make sense?
@szhengac, on second thought, let's just use this issue to at least understand the source of the different results.
@aced125, I saw your question about the fix. Let me know if you want more details after reviewing the PR.
@tjruwase here is the bug. Please replace the `SimpleModel` class with the following and run:

```python
import torch
import torch.nn as nn
import deepspeed


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, zero=0):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        mlp = [self.linear]
        mlp.append(torch.nn.Linear(hidden_dim, hidden_dim // 2))
        for _ in range(6):
            l = torch.nn.Linear(hidden_dim // 2, hidden_dim // 2)
            mlp.append(l)
        mlp.append(torch.nn.Linear(hidden_dim // 2, hidden_dim))
        # Tie the last layer's weights/bias to the first linear layer.
        l = torch.nn.Linear(hidden_dim, hidden_dim)
        l.weight = self.linear.weight
        l.bias = self.linear.bias
        mlp.append(l)
        if zero == 3:
            # Under ZeRO-3, parameters used outside their owning module must be
            # registered as external parameters.
            deepspeed.zero.register_external_parameter(self, self.linear.weight)
            deepspeed.zero.register_external_parameter(self, self.linear.bias)
        self.mlp = nn.Sequential(*mlp)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.conv1d = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, y):
        hidden_dim = self.conv1d(x.unsqueeze(1)).squeeze(1)
        hidden_dim = self.mlp(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)
```

Also, add CPU offloading.
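A sketch of how CPU offloading might be enabled for this repro; the field names follow the newer DeepSpeed config schema (releases around 0.3.x exposed a `cpu_offload` boolean instead), so treat the exact keys as an assumption:

```python
# Sketch: ZeRO stage 3 with optimizer-state offload to the host.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},  # optimizer states and step live on CPU
    },
}
```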
Actually, I seem to be getting a different error (on A100) when running the above:
@aced125, this is expected: CPU offloading means the optimizer will execute on the CPU, so you need a CPU implementation of the optimizer.
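For illustration only (not from this thread): DeepSpeed ships a CPU Adam, `DeepSpeedCPUAdam`, which is the kind of CPU-capable optimizer an offloaded step needs. A minimal sketch, assuming the CPU Adam extension builds on your machine:

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Tiny CPU-resident module; with optimizer offload the update runs on the host,
# so the optimizer needs a CPU implementation such as DeepSpeedCPUAdam.
model = torch.nn.Linear(16, 16)
opt = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
opt.step()       # Adam update executed on the CPU
opt.zero_grad()
```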
@tjruwase So the current option is to disable reduce-scatter?
@szhengac, yes. Reduce-scatter appears to be broken in zero-1, and it might take a while to fix. Does it work for you now?
@tjruwase Yes. This workaround works for the test cases.
Thanks for the confirmation. We are trying to fix reduce-scatter ASAP, bandwidth permitting, but we hope this workaround can unblock you until then.
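The workaround above — disabling reduce-scatter for ZeRO stage 1 — maps to a config flag; a minimal sketch, assuming a dict-style config with other fields omitted:

```python
# Sketch: work around the broken zero-1 reduce-scatter path by falling back to
# allreduce-based gradient averaging.
ds_config = {
    "zero_optimization": {
        "stage": 1,
        "reduce_scatter": False,
    },
}
```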
@szhengac, did you also check whether zero-3 matches the others?
ZeRO 3 does not match if I increase the number of steps to 8.
@szhengac, thanks for sharing that update. I will look into it further. Can you also share your results?
@szhengac, I don't see any difference between zero 3 and the others with the Adam optimizer. Are your results from LAMB or Adam?
@tjruwase I use Adam.
Did you increase the number of iterations to 8?
And my DeepSpeed version is '0.3.15+e414435'.
@szhengac, sorry, this was my error. I was running on 2 GPUs. However, with 8 GPUs I do see differences between zero 3 and the others.
It looks like they just merged this fix not too long ago (#968) and incorporated it into the official release 9 days ago (https://github.com/microsoft/DeepSpeed/releases). I did `pip install deepspeed -U` and it started working.
@szhengac, I think the ZeRO 3 regression is now fixed; the PR should be merged soon. Thanks for helping to find this corner case. I extended your test case to 17 iterations.
@tjruwase Thanks for the quick fix. I just took a look at your PR. So the squared gradient norms were reduced twice?
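If the squared local gradient norms were indeed summed across ranks twice, the global gradient norm (and therefore gradient clipping) would be inflated by roughly the square root of the world size. A toy sketch of that effect, not DeepSpeed code (the per-rank values are made up):

```python
import math

# Hypothetical per-rank sums of squared gradients (world_size = 4).
local_sq_norms = [4.0, 9.0, 16.0, 25.0]

# Correct: sum the squared norms across ranks once, then take the square root.
correct_norm = math.sqrt(sum(local_sq_norms))                # ~7.35

# Buggy: the already-reduced sum gets summed across ranks a second time,
# multiplying it by world_size and inflating the norm by sqrt(world_size).
double_reduced = len(local_sq_norms) * sum(local_sq_norms)
buggy_norm = math.sqrt(double_reduced)                       # ~14.70

print(correct_norm, buggy_norm)
```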
Previously I observed that ZeRO 1 produces significantly worse performance than ZeRO 2 in finetuning. A similar observation is mentioned in #757. I created a simple test to see how the loss changes with different ZeRO stages. The test code snippet (`test_zero.py`) is pasted as follows (the code is run using `deepspeed test_zero.py --zero ${ZERO}`).
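For orientation, a minimal driver of roughly this shape (a sketch only — the model, dimensions, and config values are illustrative and this is not the exact `test_zero.py`) might look like:

```python
import argparse
import torch
import deepspeed

# Sketch of a test driver: same seed and data for every run, only --zero differs.
parser = argparse.ArgumentParser()
parser.add_argument("--zero", type=int, default=0)
parser.add_argument("--local_rank", type=int, default=-1)  # set by the deepspeed launcher
args = parser.parse_args()

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": args.zero},
}

torch.manual_seed(0)
hidden_dim = 4
model = torch.nn.Linear(hidden_dim, hidden_dim)  # stand-in for the issue's SimpleModel
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

for step in range(4):
    x = torch.randn(8, hidden_dim, device=engine.device)
    y = torch.randint(0, hidden_dim, (8,), device=engine.device)
    loss = torch.nn.functional.cross_entropy(engine(x), y)
    engine.backward(loss)
    engine.step()
    if engine.global_rank == 0:
        print(f"step {step}: loss {loss.item():.6f}")
```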
The following shows the results:
ZeRO 0:
ZeRO 1:
ZeRO 2:
ZeRO 3 gives me an error:
As can be seen, ZeRO 0 and 2 produce exactly the same result, while ZeRO 1 gives different losses, and the ZeRO 3 test fails.
@tjruwase FYI