
Commit

autodiff: Make sure inputs to cross entropy loss are at least 2d for torch<=1.11.x

d6d95f3 ("Nback (PrincetonUniversity#2617)") changed the
format of cross entropy target that requires torch >=1.12
1.12.0+ includes input handling path to consider inputs without batch
dimension and can be used directly. [0,1,2]

Fixes: d6d95f3 ("Nback (PrincetonUniversity#2617)")
Closes: PrincetonUniversity#2665

[0] https://github.com/pytorch/pytorch/releases/tag/v1.12.0
[1] pytorch/pytorch#77653
[2] pytorch/pytorch@8881d7a
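
To illustrate the difference (a minimal sketch, not part of the commit; the tensor
shapes and values are made up): per the commit message, torch <= 1.11.x expects a
batch dimension on the inputs to nn.CrossEntropyLoss, so both tensors are promoted
to (1, C) with torch.atleast_2d, while torch >= 1.12 accepts the unbatched (C,)
tensors directly.

    import torch
    import torch.nn as nn

    logits = torch.randn(3)                  # unbatched logits, shape (C,)
    target = torch.tensor([0.2, 0.5, 0.3])   # class probabilities, shape (C,)

    loss_fn = nn.CrossEntropyLoss()

    # torch >= 1.12 accepts the unbatched tensors as-is (no-batch-dim support).
    # For torch <= 1.11.x, add an explicit batch dimension so both are (1, C).
    loss = loss_fn(torch.atleast_2d(logits), torch.atleast_2d(target))
    print(loss)   # scalar loss; with a batch of one, the value is unchanged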

Signed-off-by: Jan Vesely <[email protected]>
jvesely committed May 12, 2023
1 parent 32393ae commit 43bc49b
Showing 2 changed files with 6 additions and 8 deletions.
13 changes: 5 additions & 8 deletions psyneulink/library/compositions/autodiffcomposition.py
@@ -251,6 +251,7 @@
 import os
 import warnings
 import numpy as np
+from packaging import version
 from pathlib import Path, PosixPath

 try:
@@ -456,19 +457,15 @@ def _get_loss(self, loss_spec):
         elif loss_spec == Loss.SSE:
             return nn.MSELoss(reduction='sum')
         elif loss_spec == Loss.CROSS_ENTROPY:
+            if version.parse(torch.version.__version__) >= version.parse('1.12.0'):
+                return nn.CrossEntropyLoss()
+
             # Cross entropy loss is used for multiclass categorization and needs inputs in shape
             # ((# minibatch_size, C), targets) where C is a 1-d vector of probabilities for each potential category
             # and where target is a 1d vector of type long specifying the index to the target category. This
             # formatting is different from most other loss functions available to autodiff compositions,
             # and therefore requires a wrapper function to properly package inputs.
-            cross_entropy_loss = nn.CrossEntropyLoss()
-            return lambda x, y: cross_entropy_loss(
-                # x.unsqueeze(0),
-                x,
-                # y.type(torch.LongTensor)
-                # torch.argmax(y.type(torch.LongTensor))
-                y.type(x.type())
-            )
+            return lambda x, y: nn.CrossEntropyLoss()(torch.atleast_2d(x), torch.atleast_2d(y.type(x.type())))
         elif loss_spec == Loss.L1:
             return nn.L1Loss(reduction='sum')
         elif loss_spec == Loss.NLL:
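
Taken on its own, the new CROSS_ENTROPY branch amounts to the following standalone
sketch (illustrative only; make_cross_entropy_loss is a hypothetical helper, not a
PsyNeuLink API):

    import torch
    import torch.nn as nn
    from packaging import version

    def make_cross_entropy_loss():
        # torch >= 1.12 handles inputs without a batch dimension itself.
        if version.parse(torch.version.__version__) >= version.parse('1.12.0'):
            return nn.CrossEntropyLoss()
        # torch <= 1.11.x needs an explicit (1, C) batch dimension and matching dtypes.
        return lambda x, y: nn.CrossEntropyLoss()(torch.atleast_2d(x),
                                                  torch.atleast_2d(y.type(x.type())))

    loss_fn = make_cross_entropy_loss()
    logits = torch.randn(4)                      # unbatched logits, shape (C,)
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])   # probability target, shape (C,)
    print(loss_fn(logits, probs))                # scalar loss on both old and new torch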
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,6 +12,7 @@ modeci_mdf<0.5, >=0.3.4; (platform_machine == 'AMD64' or platform_machine == 'x8
 networkx<3.2
 numpy<1.22.5, >=1.19.0
 optuna<3.2.0
+packaging<24.0
 pandas<2.0.2
 pillow<9.6.0
 pint<0.22.0
