Skip to content

Commit

Permalink
Added test for functor Cholesky.
Browse files Browse the repository at this point in the history
  • Loading branch information
aterenin committed Apr 25, 2020
1 parent 621b4cb commit d5cb27e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet

functor(x) = (), _ -> x
Expand Down Expand Up @@ -115,3 +116,6 @@ paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)

f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)

# Functors for certain Julia data structures
@functor Cholesky
4 changes: 0 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra

"""
Chain(layers...)
Expand Down Expand Up @@ -260,5 +258,3 @@ end
function Base.show(io::IO, b::SkipConnection)
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end

@functor Cholesky
14 changes: 13 additions & 1 deletion test/cuda/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Flux, Test
using Flux.CuArrays
using Flux: gpu
using Flux: cpu, gpu
using LinearAlgebra: I, cholesky, Cholesky

@info "Testing GPU Support"

Expand Down Expand Up @@ -65,6 +66,17 @@ end
@test gradient(foo, cu(rand(1)))[1] isa CuArray
end

@testset "GPU functors" begin
@testset "Cholesky" begin
M = 2.0*I(10) |> collect
Q = cholesky(M)
Q_gpu = Q |> gpu
@test Q_gpu isa Cholesky{<:Any,<:CuArray}
Q_cpu = Q_gpu |> cpu
@test Q_cpu == cholesky(eltype(Q_gpu).(M))
end
end

if CuArrays.has_cudnn()
@info "Testing Flux/CUDNN"
include("cudnn.jl")
Expand Down

0 comments on commit d5cb27e

Please sign in to comment.