Fix gradient clipping (microsoft#5150)
The gradient clipping API does not apply the clipping coefficient correctly: the coefficient is floored at 1.0 with torch.max instead of capped at 1.0 with torch.min, so gradients are never scaled down. This PR fixes the reduction and adds a test case that compares against torch.nn.utils.clip_grad_norm_.
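For context (illustration only, not part of the commit), a minimal sketch of the intended clipping behavior; the helper name clip_by_global_norm and the eps value are assumptions for this example:

import torch

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    # Compute the global L2 norm over all gradients.
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    # Coefficient that would bring the norm exactly to max_norm.
    clip_coef = max_norm / (total_norm + eps)
    # Cap the coefficient at 1.0 so gradients are only ever scaled down;
    # flooring it at 1.0 (the old torch.max behavior) is the bug being fixed.
    clip_coef = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm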

Co-authored-by: Logan Adams <[email protected]>
2 people authored and SNahir committed Mar 11, 2024
1 parent 28c4ec8 commit 60f0fdc
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/utils.py
@@ -407,7 +407,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     max_norm = torch.tensor([float(max_norm)], device=parameters[0].device)
     clip_coef = max_norm / (total_norm + 1e-6)
     tmp_tensor = torch.tensor([1.0], device=parameters[0].device)
-    clip_coef = torch.max(tmp_tensor, clip_coef)
+    clip_coef = torch.min(tmp_tensor, clip_coef)
     for p in parameters:
         p.grad.data.mul_(clip_coef)
     return total_norm
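A small worked example (illustrative, not from the repository) of why the coefficient must be reduced with torch.min, using max_norm=0.1 and two unit gradients as in the new test:

import torch

total_norm = torch.tensor([2.0]).sqrt()      # grads [1.0, 1.0] -> global norm = sqrt(2) ~ 1.414
max_norm = torch.tensor([0.1])
one = torch.tensor([1.0])

clip_coef = max_norm / (total_norm + 1e-6)   # ~ 0.0707: gradients must shrink
print(torch.min(one, clip_coef))             # ~ 0.0707 -> fixed code clips correctly
print(torch.max(one, clip_coef))             # 1.0      -> old code leaves gradients unclipped

# If total_norm were already below max_norm (say 0.05), clip_coef ~ 2.0:
# torch.min caps it at 1.0 (gradients untouched), while torch.max would
# return 2.0 and wrongly scale the gradients up.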
25 changes: 23 additions & 2 deletions tests/unit/runtime/test_runtime_utils.py
@@ -26,10 +26,10 @@ def test_call_to_str():
     assert c2s('hello', 1138, val=3) == 'hello(1138, val=3)'


-class TestClibGradNorm(DistributedTest):
+class TestClipGradNorm(DistributedTest):
     world_size = 2

-    def test(self):
+    def test_gather(self):
         param1 = torch.nn.Parameter(torch.Tensor([0]))
         param1.grad = torch.Tensor([1])
         param2 = torch.nn.Parameter(torch.Tensor([0]))
@@ -50,6 +50,27 @@ def test(self):

         assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1"

+    def test_clipped_val(self):
+        max_norm = 0.1
+
+        def test_params():
+            param1 = torch.nn.Parameter(torch.Tensor([0]))
+            param1.grad = torch.Tensor([1])
+            param2 = torch.nn.Parameter(torch.Tensor([0]))
+            param2.grad = torch.Tensor([1])
+            return [param1, param2]
+
+        # This assumes gradients are same on all the ranks and doesn't consider multiple ranks
+        params_expected = test_params()
+        torch.nn.utils.clip_grad_norm_(params_expected, max_norm)
+
+        params_actual = test_params()
+        ds_utils.clip_grad_norm_(params_actual, max_norm=max_norm)
+
+        # This can be allclose
+        assert torch.equal(params_expected[0].grad, params_actual[0].grad)
+        assert torch.equal(params_expected[1].grad, params_actual[1].grad)
+

 @pytest.mark.parametrize("check_using_norm", [(False), (True)])
 class TestCheckOverflow(DistributedTest):
