Skip to content

Commit

Permalink
allow NNlib v0.9 (#38)
Browse files Browse the repository at this point in the history
* allow NNlib v0.9

* fix test
  • Loading branch information
CarloLucibello authored Jun 16, 2023
1 parent 469b192 commit a531d55
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[compat]
Adapt = "3.0"
CUDA = "3.8"
CUDA = "4"
ChainRulesCore = "1.13"
Compat = "4.2"
GPUArraysCore = "0.1.0"
NNlib = "0.8"
NNlib = "0.8, 0.9"
Zygote = "0.6.35"
julia = "1.6"

Expand Down
6 changes: 5 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ end
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
if VERSION >= v"1.9" && CUDA.functional()
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
else
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
end
end

@testset "onehotbatch(::CuArray, ::UnitRange)" begin
Expand Down

0 comments on commit a531d55

Please sign in to comment.