
makemore_part4_backprop dhpreact exact part is False. #45

Open
Stealeristaken opened this issue Mar 3, 2024 · 2 comments

Stealeristaken commented Mar 3, 2024

Hello Andrej. First of all, I would like to thank you for sharing such valuable videos with us for free.

While watching the 'makemore part 4' video, I was also applying it to a dataset I created myself. When I computed the chained derivative in the 'dhpreact' step, the exact check started failing, and since it is a chained operation, the failure propagated to all subsequent gradients as well. I share the code and the output below.
Please share any other solution if you have one: using different Torch versions and changing the dtype to 'double', as suggested in the comments, did not work for me.

# backprop through the cross-entropy loss, one step at a time
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
dcounts += torch.ones_like(counts) * dcounts_sum
dnorm_logits = counts * dcounts  # counts = norm_logits.exp()
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
# backprop through the second linear layer
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# backprop through tanh: h = tanh(hpreact), so d(tanh)/d(hpreact) = 1 - h**2
dhpreact = (1.0 - h**2) * dh
# backprop through the batchnorm gain/shift
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
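For context, the comparison helper from the notebook that produces the table below looks roughly like this (a sketch from memory, not a verbatim copy; the exact format string may differ):

```python
import torch

def cmp(s, dt, t):
    # exact: every element bitwise-identical to autograd's result;
    # approximate: equal within torch.allclose's default float tolerances
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
```

So `exact: False` with a `maxdiff` around 1e-9 means the manual gradient agrees with autograd to within float32 rounding, just not bit-for-bit.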


Output:

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnbias          | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
bnraw           | exact: False | approximate: True  | maxdiff: 6.984919309616089e-10
bnvar_inv       | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
bnvar           | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
@zzkzzkjsw commented

It seems to be a problem with different PyTorch versions. It works with torch==1.12.0, but not with the latest.

@Stealeristaken (Author) replied

> It seems to be a problem with different PyTorch versions. It works with torch==1.12.0, but not with the latest.

Yes, but honestly I would like to know what changed between versions that causes the error.
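A plausible explanation (my assumption, not confirmed in this thread): newer PyTorch releases may evaluate tanh's backward pass with a different operation order or fused kernel, so the hand-derived `(1.0 - h**2) * dh` can differ from autograd's result in the last float32 bit even though both are mathematically identical. A minimal sketch of the effect:

```python
import torch

# Two mathematically identical computations of d(tanh)/dx can disagree
# bitwise in float32, which flips `exact` to False while `approximate`
# (allclose) stays True -- matching the tiny maxdiff values in the output.
torch.manual_seed(0)
x = torch.randn(32, 64, requires_grad=True)
h = torch.tanh(x)
upstream = torch.randn_like(h)
h.backward(upstream)                             # autograd's tanh backward

dx_manual = (1.0 - h.detach() ** 2) * upstream   # hand-derived: 1 - tanh(x)^2

print('exact:      ', torch.equal(dx_manual, x.grad))    # may be False on some versions
print('approximate:', torch.allclose(dx_manual, x.grad))
print('maxdiff:    ', (dx_manual - x.grad).abs().max().item())
```

Whether `exact` comes out True or False depends on the PyTorch version and backend; the discrepancy, when present, is on the order of one float32 ulp, which is why casting to float64 is sometimes suggested as a workaround.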

conscell added a commit to conscell/nn-zero-to-hero that referenced this issue Dec 27, 2024:
Fixes karpathy#13 and karpathy#45, where `dhpreact` did not exactly match `hpreact.grad`.