diff --git a/test/activation.jl b/test/activation.jl index d34e17c42..8a42f2532 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -100,7 +100,7 @@ end xs = Float32[1 2 3; 1000 2000 3000] @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.] - @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) + @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1] @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) # These values precalculated using PyTorch's nn.LogSoftmax