WIP 1.0 support
closes #353
MikeInnes committed Aug 20, 2018
1 parent 0ef6456 commit 5a023a9
Showing 9 changed files with 46 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/Flux.jl
@@ -10,7 +10,6 @@ export Chain, Dense, RNN, LSTM, GRU, Conv,
params, mapleaves, cpu, gpu

@reexport using NNlib
-using NNlib: @fix

include("tracker/Tracker.jl")
using .Tracker
2 changes: 1 addition & 1 deletion src/cuda/cuda.jl
@@ -1,6 +1,6 @@
module CUDA

-using CuArrays
+using ..CuArrays

CuArrays.cudnn_available() && include("cudnn.jl")

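A note on the relative import: on Julia 1.0 a bare "using CuArrays" inside the CUDA submodule would be resolved as a package dependency from the load path, but Flux only loads CuArrays conditionally (via Requires), so the submodule has to reach the CuArrays binding already present in its parent module. Leading dots walk the module tree instead of the package path: ..CuArrays means "CuArrays as bound in the parent module", and the single-dot forms in cudnn.jl below refer to that same binding from inside the CUDA module. A minimal sketch of the relative-path rule, using hypothetical modules (Parent, Dep, Child are not Flux code):

    module Parent

    module Dep
    export greet
    greet() = "hello from Dep"
    end

    module Child
    # ..Dep resolves Dep in the enclosing module (Parent) rather than
    # looking for a package called Dep in the load path.
    using ..Dep
    greet_twice() = (greet(), greet())
    end

    end

    Parent.Child.greet_twice()   # ("hello from Dep", "hello from Dep")
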
12 changes: 6 additions & 6 deletions src/cuda/cudnn.jl
@@ -1,7 +1,7 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc, FilterDesc

-using LinearAlgebra
+using LinearAlgebra

mutable struct DropoutDesc
ptr::Ptr{Nothing}
@@ -243,8 +243,8 @@ end

import ..Flux: Flux, relu
import ..Tracker: TrackedArray
-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims

function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
@@ -326,7 +326,7 @@ end
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
end
end

@@ -341,6 +341,6 @@ end
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-dWi.', dWh.', db))
+transpose(dWi), transpose(dWh), db))
end
end
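
The dWi.' and dWh.' changes are a pure syntax migration: the postfix .' (transpose) operator was removed in Julia 1.0, so the transpose has to be spelled as a call. A small illustration on plain arrays (the matrix W here is made up, nothing Flux-specific):

    using LinearAlgebra

    W = [1 2 3; 4 5 6]

    # The Julia 0.6 spelling was W.'; on 1.0 that is a parse error.
    Wt = transpose(W)       # lazy Transpose wrapper sharing W's memory
    Wc = permutedims(W)     # eager copy, a plain 3x2 Matrix

    size(Wt) == (3, 2)      # true
    Wt == Wc                # true
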
2 changes: 1 addition & 1 deletion src/layers/basic.jl
@@ -77,7 +77,7 @@ end

function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
-@fix σ.(W*x .+ b)
+σ.(W*x .+ b)
end

function Base.show(io::IO, l::Dense)
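
The @fix macro (whose import is dropped from src/Flux.jl above) was a pre-1.0 workaround for broadcasting over Flux's tracked arrays; with the Julia 1.0 broadcast machinery, plus the TrackedStyle handling extended in src/tracker/array.jl below, the Dense forward pass can be an ordinary fused broadcast. A standalone sketch of that call, with untracked arrays and a locally defined sigmoid standing in for Flux's σ:

    # Hypothetical standalone version of the Dense forward pass.
    σ(x) = 1 / (1 + exp(-x))

    W = randn(5, 10)
    b = randn(5)
    x = randn(10)

    y = σ.(W*x .+ b)    # one fused broadcast over the mat-vec result
    size(y) == (5,)     # true
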
2 changes: 1 addition & 1 deletion src/layers/stateless.jl
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+-sum(y .* log.(ŷ) .* weight) / size(y, 2)
end

@deprecate logloss(x, y) crossentropy(x, y)
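
Likewise, crossentropy is now just the fused expression -sum(y .* log.(ŷ) .* weight) / size(y, 2). A worked usage sketch with made-up numbers (one-hot targets, a batch of two columns, default weight of 1):

    ŷ = [0.7 0.2;
         0.2 0.5;
         0.1 0.3]
    y = [1.0 0.0;
         0.0 1.0;
         0.0 0.0]

    loss = -sum(y .* log.(ŷ)) / size(y, 2)
    # = -(log(0.7) + log(0.5)) / 2 ≈ 0.5249
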
5 changes: 3 additions & 2 deletions src/onehot.jl
@@ -33,8 +33,9 @@ import Adapt.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
-import CuArrays: CuArray, cudaconvert
-Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
+import .CuArrays: CuArray, cudaconvert
+import Base.Broadcast: BroadcastStyle, ArrayStyle
+BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end

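Julia 1.0 replaced the Broadcast._containertype hook with the BroadcastStyle trait, so the GPU one-hot path now declares ArrayStyle{CuArray} for OneHotMatrix{<:CuArray}; a broadcast involving such a matrix is then planned and materialised by CuArrays' broadcast implementation, which is what the cx .+ 1 test added below checks. A CPU-only sketch of the same trait pattern, with a hypothetical Wrapper type standing in for OneHotMatrix and plain Matrix standing in for CuArray:

    import Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted

    struct Wrapper{A<:AbstractMatrix{Float64}} <: AbstractMatrix{Float64}
      data::A
    end
    Base.size(w::Wrapper) = size(w.data)
    Base.getindex(w::Wrapper, i::Int, j::Int) = w.data[i, j]

    # Broadcasts touching a Wrapper are handled as plain Matrix broadcasts,
    # mirroring the ArrayStyle{CuArray} declaration in the diff above.
    BroadcastStyle(::Type{<:Wrapper}) = ArrayStyle{Matrix}()
    Base.similar(bc::Broadcasted{ArrayStyle{Matrix}}, ::Type{T}) where T =
      similar(Array{T}, axes(bc))

    w = Wrapper(ones(2, 2))
    w .+ 1    # 2x2 Matrix{Float64} of 2.0, dispatched via ArrayStyle{Matrix}
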
29 changes: 29 additions & 0 deletions src/tracker/array.jl
@@ -381,3 +381,32 @@ function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
bc = Broadcast.flatten(bc)
∇broadcast(bc.f, bc.args...)
end

+using Requires
+
+# https://github.com/FluxML/Flux.jl/issues/353
+@init @eval Base.Broadcast begin
+  function flatten(bc::Broadcasted{Style}) where {Style}
+    isflat(bc) && return bc
+    args = cat_nested(bc)
+    let makeargs = make_makeargs(bc), f = bc.f
+      newf = @inline function(args::Vararg{Any,N}) where N
+        f(makeargs(args...)...)
+      end
+      return Broadcasted{Style}(newf, args, bc.axes)
+    end
+  end
+  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
+    bc = t[1]
+    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
+      let makeargs = make_makeargs(makeargs, bc.args)
+        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
+        return @inline function(args::Vararg{Any,N}) where N
+          args1 = makeargs(args...)
+          a, b = headargs(args1...), tailargs(args1...)
+          (f(a...), b...)
+        end
+      end
+    end
+  end
+end
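
The @eval Base.Broadcast block above is a load-time patch (deferred via Requires' @init) of flatten and make_makeargs; it appears to track the stock Julia 1.0 definitions, adding @inline annotations to the generated closures so that broadcasts over TrackedArrays stay type-stable (issue #353). For context, flatten is what the Base.copy(bc::Broadcasted{TrackedStyle}) method at the top of this hunk relies on: it collapses a nested broadcast tree into a single call whose function and leaf arguments ∇broadcast can differentiate. A small illustration of flatten on plain arrays (x, y and the sin/cos/+ tree are made up for the example):

    using Base.Broadcast: broadcasted, flatten, materialize

    x, y = [0.1, 0.2], [0.3, 0.4]

    # Nested tree for sin.(cos.(x) .+ y), built without fusing.
    bc = broadcasted(sin, broadcasted(+, broadcasted(cos, x), y))

    flat = flatten(bc)
    flat.args == (x, y)                        # true: leaf args gathered in order
    materialize(flat) == sin.(cos.(x) .+ y)    # true: same fused result
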
1 change: 1 addition & 0 deletions test/cuda/cuda.jl
@@ -14,6 +14,7 @@ cx = gpu(x)
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
+@test (cx .+ 1) isa CuArray

m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
cm = gpu(m)
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -3,6 +3,9 @@ using Random

Random.seed!(0)

+# So we can use the system CuArrays
+insert!(LOAD_PATH, 2, "@v#.#")

@testset "Flux" begin

include("utils.jl")
@@ -12,7 +15,7 @@ include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")

-if Base.find_package("CuArrays") ≠ nothing
+if Base.find_package("CuArrays") != nothing
include("cuda/cuda.jl")
end

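The LOAD_PATH tweak makes the per-version user environment visible to the test process, so a CuArrays installed there can be found even though it is not a test dependency; "@v#.#" is Julia's name for that environment (e.g. ~/.julia/environments/v1.0), and Base.find_package then gates the CUDA tests. A compact, hedged sketch of the same gating pattern (the @info messages are illustrative only):

    # Make the per-version user environment visible to this process.
    insert!(LOAD_PATH, 2, "@v#.#")

    if Base.find_package("CuArrays") != nothing
      @info "CuArrays found; the GPU test file would be included here"
    else
      @info "CuArrays not found; skipping GPU tests"
    end
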
