Skip to content

Commit

Permalink
Update lsq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hustzxd authored Nov 19, 2020
1 parent f2d8ce6 commit 86cf511
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion lsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def backward(ctx, grad_weight):
q_w = weight / alpha
indicate_small = (q_w < Qn).float()
indicate_big = (q_w > Qp).float()
indicate_middle = torch.ones(indicate_small.shape).to(indicate_small.device) - indicate_small - indicate_big
# indicate_middle = torch.ones(indicate_small.shape).to(indicate_small.device) - indicate_small - indicate_big
indicate_middle = 1.0 - indicate_small - indicate_big # Thanks to @haolibai
grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * (
-q_w + q_w.round())) * grad_weight * g).sum().unsqueeze(dim=0)
grad_weight = indicate_middle * grad_weight
Expand Down

0 comments on commit 86cf511

Please sign in to comment.