Fix blind spot in per-channelness checking logic of custom grad function
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored May 12, 2023
1 parent 912c8cd commit 57ed1b5
Showing 1 changed file with 4 additions and 4 deletions.
```diff
@@ -259,14 +259,14 @@ def asymmetric_gradients(tensor: torch.Tensor,
     grad_offset = grad_xq * (1 - mask_tensor)

     dim = list(range(len(tensor.shape)))
-    if len(delta.shape) > 1 and len(tensor.shape) > 1:
+    if delta.numel() > 1 and len(tensor.shape) > 1:
         dim.pop(channel_axis)

     num_steps = intermediate_result.num_steps
     encoding_min = intermediate_result.encoding_min
     encoding_max = intermediate_result.encoding_max

-    if len(delta.shape) > 1 and len(tensor.shape) == 1:
+    if delta.numel() > 1 and len(tensor.shape) == 1:
         # NOTE: Handle when applying per-channel quant to 1-D Tensor case such as bias tensor in Conv or beta/gamma in BatchNorm
         intermediate_term1 = grad_scale / num_steps
         intermediate_term2 = num_steps / (encoding_max - encoding_min) ** 2 * grad_offset
@@ -301,13 +301,13 @@ def symmetric_gradients(tensor: torch.Tensor,
     mask_tensor = Variable(mask_tensor.type_as(grad.data))

     dim = list(range(len(tensor.shape)))
-    if len(delta.shape) > 1 and len(tensor.shape) > 1:
+    if delta.numel() > 1 and len(tensor.shape) > 1:
         dim.pop(channel_axis)

     num_steps = intermediate_result.num_steps
     grad_tensor = mask_tensor * grad

-    if len(delta.shape) > 1 and len(tensor.shape) == 1:
+    if delta.numel() > 1 and len(tensor.shape) == 1:
         # NOTE: Handle when applying per-channel quant to 1-D Tensor case such as bias tensor in Conv or beta/gamma in BatchNorm
         grad_encoding_max = ((x_quant + offset) * grad) - (mask_tensor * (tensor / delta) * grad)
     else:
```
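For context, a minimal sketch of the blind spot this commit closes. The `is_per_channel` helper and the example tensors below are hypothetical illustrations, not part of the commit: they show why the old rank-based check misses per-channel quantization of a 1-D parameter, while checking the element count of the encoding does not.

```python
import torch

def is_per_channel(delta: torch.Tensor) -> bool:
    """Hypothetical helper mirroring the new check: a per-channel encoding
    carries one scale per channel, i.e. more than one element."""
    return delta.numel() > 1

# Per-tensor quantization of a 1-D bias uses a single scale.
per_tensor_delta = torch.tensor([0.02])
# Per-channel quantization of the same bias uses one scale per output channel,
# so delta is itself a 1-D tensor.
per_channel_delta = torch.full((16,), 0.02)

# Old check: len(delta.shape) > 1 is False for both tensors above, so
# per-channel quantization of a 1-D parameter (e.g. a Conv bias or
# BatchNorm beta/gamma) was misclassified as per-tensor.
print(len(per_channel_delta.shape) > 1)   # False  <- the blind spot
# New check: delta.numel() > 1 distinguishes the two cases correctly.
print(is_per_channel(per_channel_delta))  # True
print(is_per_channel(per_tensor_delta))   # False
```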
