
Implement cuDNN softmax and log-softmax, add tests. #34

Merged
2 commits merged into feiwang3311:master from cudnn-softmax on Oct 22, 2018

Conversation

dan-zheng (Collaborator)

  • Implement cuDNN softmax and log-softmax.
    • The functions are currently limited to 2-D tensors.
  • Add tests verifying both forward and backward results.
    • Add CPU softmax and log-softmax tests to ensure parity.

Low-priority todos:

  • Support softmax grad on CPU.
  • Implement softmax for arbitrary-rank tensors. The API should look like
    `softmax(axis: Int)`, matching other reduction ops. (A rough NumPy sketch
    of both is included below.)
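
For reference, a minimal NumPy sketch of what an axis-general softmax/log-softmax and their CPU gradients could look like (illustrative only; the function names and signatures here are not the API added in this PR):

import numpy as np

def softmax(x, axis=-1):
    # Shift by the max along `axis` for numerical stability.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def softmax_grad(dy, y, axis=-1):
    # Backward of y = softmax(x): dx = y * (dy - sum(dy * y)).
    return y * (dy - (dy * y).sum(axis=axis, keepdims=True))

def log_softmax_grad(dy, y, axis=-1):
    # Backward of y = log_softmax(x): dx = dy - softmax(x) * sum(dy),
    # where softmax(x) = exp(y).
    return dy - np.exp(y) * dy.sum(axis=axis, keepdims=True)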

@dan-zheng requested a review from feiwang3311 on October 22, 2018 02:48
@feiwang3311 merged commit 3626b99 into feiwang3311:master on Oct 22, 2018
@dan-zheng deleted the cudnn-softmax branch on October 22, 2018 15:22
@dan-zheng (Collaborator, Author)

Btw, softmax calculations verified via PyTorch:

import torch
import torch.nn.functional as F

a = torch.Tensor(range(6)).reshape(2, 3)
a.requires_grad = True
b = F.log_softmax(a, dim=1)
b.backward(torch.ones(2, 3))

print(b)
# tensor([[-2.4076, -1.4076, -0.4076],
#         [-2.4076, -1.4076, -0.4076]], grad_fn=<LogSoftmaxBackward>)

print(a.grad)
# tensor([[ 0.7299,  0.2658, -0.9957],
#         [ 0.7299,  0.2658, -0.9957]])
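
These gradient values match the closed-form log-softmax backward, dx = dy - softmax(a) * sum(dy) taken along dim=1; with dy = ones this reduces to 1 - 3 * softmax(a). A quick check (continuing the snippet above):

expected = torch.ones(2, 3) - 3 * F.softmax(a, dim=1).detach()
print(torch.allclose(a.grad, expected))
# True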
