diff --git a/src/Flux.jl b/src/Flux.jl
index a423311204..119e40a9ac 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
        RNN, LSTM, GRU, GRUv3,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
-       Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       Dropout, AlphaDropout,
+       LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       MultiHeadAttention,
        Upsample, PixelShuffle,
        fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
        testmode!, trainmode!
@@ -59,6 +61,7 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
+include("layers/attention.jl")
 include("layers/show.jl")
 
 include("loading.jl")
diff --git a/src/layers/attention.jl b/src/layers/attention.jl
new file mode 100644
index 0000000000..5b4dacfaf1
--- /dev/null
+++ b/src/layers/attention.jl
@@ -0,0 +1,133 @@
+
+const A3{T} = AbstractArray{T, 3}
+const IntOrDims{N} = Union{Int, Dims{N}}
+
+"""
+    MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])
+
+The multi-head dot-product attention layer used in Transformer architectures [1].
+
+Returns the transformed input sequence and the attention scores.
+
+[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.
+
+# Arguments
+
+- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
+  In the most general case, it is given as
+  a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
+  It can also take the simpler forms
+  b) `dims::Int`;
+  c) `in_dim::Int => (qk_dim, v_dim) => out_dim`;
+  d) `in_dim::Int => qkv_dim => out_dim`.
+- `nheads`: number of heads. Default `8`.
+- `bias`: whether the pointwise query, key, value, and output Dense layers use a bias. Default `false`.
+- `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
+- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.
+
+# Forward
+
+    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])
+
+The arguments of the forward pass are:
+
+- `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
+- `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
+- `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
+- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  It will be added to the attention scores before the softmax.
+  Default `nothing`.
+- `mask`: Input array broadcastable to size
+  `(kv_len, q_len, nheads, batch_size)`.
+  The mask is applied to the attention scores just before the softmax.
+  See [`NNlib.make_causal_mask`](@ref) for creating causal masks.
+  Default `nothing`.
+
+Alternative calling signatures are `mha(q_in)`, equivalent to `mha(q_in, q_in, q_in)` (self-attention),
+and `mha(q_in, k_in)`, equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
+
+See also [`NNlib.dot_product_attention`](@ref).
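+
+For instance, causal self-attention can be obtained by passing a mask built with
+[`NNlib.make_causal_mask`](@ref) (an illustrative sketch; the array sizes are arbitrary):
+
+```julia
+mha = MultiHeadAttention(64, nheads = 8)
+x = rand(Float32, 64, 10, 32)       # (dim, len, batch_size)
+mask = NNlib.make_causal_mask(x)    # (10, 10) boolean mask
+y, α = mha(x; mask)
+# α[kv, q, head, batch] is zero whenever kv > q
+```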
+
+# Examples
+
+```julia
+mha = MultiHeadAttention(64, nheads = 8)
+q = rand(Float32, (64, 10, 32))
+k = rand(Float32, (64, 20, 32))
+v = rand(Float32, (64, 20, 32))
+y, α = mha(q, k, v)
+# [y] = [64, 10, 32]
+# [α] = [20, 10, 8, 32]
+
+mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
+y, α = mha(q) # self-attention
+# [y] = [1024, 10, 32]
+# [α] = [10, 10, 8, 32]
+```
+"""
+struct MultiHeadAttention{P1, D, P2}
+  nheads::Int
+  q_proj::P1
+  k_proj::P1
+  v_proj::P1
+  attn_drop::D
+  out_proj::P2
+end
+
+@functor MultiHeadAttention
+
+function MultiHeadAttention(dims;
+                            nheads::Int = 8,
+                            bias::Bool = false,
+                            init = glorot_uniform,
+                            dropout_prob = 0.0)
+
+  dims = normalize_mha_dims(dims)
+  @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
+  @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
+  q_proj = Dense(dims.q_in => dims.qk; bias, init)
+  k_proj = Dense(dims.k_in => dims.qk; bias, init)
+  v_proj = Dense(dims.v_in => dims.v; bias, init)
+  attn_drop = Dropout(dropout_prob)
+  out_proj = Dense(dims.v => dims.out; bias, init)
+  return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
+end
+
+# turns the dims argument into a named tuple
+normalize_mha_dims(dims::Int) =
+  (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)
+
+function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}})
+  if in isa Int
+    q_in = k_in = v_in = in
+  else
+    q_in, k_in, v_in = in
+  end
+  if qkv isa Int
+    qk = v = qkv
+  else
+    qk, v = qkv
+  end
+  return (; q_in, k_in, v_in, qk, v, out)
+end
+
+# self-attention
+(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)
+
+# key and value are the same
+(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)
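+
+# For illustration only (a sketch of how `normalize_mha_dims` resolves the
+# `dims` argument; the sizes below are arbitrary):
+#
+#   MultiHeadAttention(64)                # all dims equal to 64
+#   MultiHeadAttention(64 => 32 => 8)     # qk_dim = v_dim = 32, out_dim = 8
+#   MultiHeadAttention((3, 4, 4) => (8, 8) => 5, nheads = 2)
+#     # q_proj: 3 => 8, k_proj: 4 => 8, v_proj: 4 => 8, out_proj: 8 => 5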
+
+function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
+                                   bias=nothing; mask=nothing)
+  ## [q_in] = [q_in_dim, q_len, batch_size]
+  ## [k_in] = [k_in_dim, kv_len, batch_size]
+  ## [v_in] = [v_in_dim, kv_len, batch_size]
+  q = mha.q_proj(q_in)  # [q] = [qk_dim, q_len, batch_size]
+  k = mha.k_proj(k_in)  # [k] = [qk_dim, kv_len, batch_size]
+  v = mha.v_proj(v_in)  # [v] = [v_dim, kv_len, batch_size]
+  x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
+  x = mha.out_proj(x)
+  # [x] = [out_dim, q_len, batch_size]
+  # [α] = [kv_len, q_len, nheads, batch_size]
+  return x, α
+end
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index a406c4129e..90c7ab0b40 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -338,3 +338,29 @@ end
     @test eltype(pool(reshape(gx,3,4,1))) == Float16
   end
 end
+
+@testset "MultiHeadAttention" begin
+  dim = 4; nheads = 2; len = 3; batch_size = 5
+  mha_cpu = MultiHeadAttention(dim; nheads)
+  x_cpu = rand(Float32, (dim, len, batch_size))
+  y_cpu, α_cpu = mha_cpu(x_cpu)
+
+  mha_gpu = mha_cpu |> gpu
+  x_gpu = x_cpu |> gpu
+  y_gpu, α_gpu = mha_gpu(x_gpu)
+  @test y_gpu isa CuArray{Float32}
+  @test α_gpu isa CuArray{Float32}
+  @test Array(y_gpu) ≈ y_cpu atol=1e-4
+  @test Array(α_gpu) ≈ α_cpu atol=1e-4
+
+  gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
+    y, α = mha(x)
+    return sum(y.^2) + sum(α.^2)
+  end
+  gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
+    y, α = mha(x)
+    return sum(y.^2) + sum(α.^2)
+  end
+  check_grad(gm_gpu, gm_cpu)
+  check_grad(gx_gpu, gx_cpu)
+end
diff --git a/test/layers/attention.jl b/test/layers/attention.jl
new file mode 100644
index 0000000000..a4c90b36ed
--- /dev/null
+++ b/test/layers/attention.jl
@@ -0,0 +1,65 @@
+
+
+@testset "attention" begin
+  dim = 4; nheads = 2; len = 3; batch_size = 5
+  mha = MultiHeadAttention(dim; nheads)
+  q = rand(Float32, (dim, len, batch_size))
+  k = rand(Float32, (dim, len, batch_size))
+  v = rand(Float32, (dim, len, batch_size))
+
+  y, α = mha(q, k, v)
+  @test y isa Array{Float32, 3}
+  @test size(y) == (dim, len, batch_size)
+  @test α isa Array{Float32, 4}
+  @test size(α) == (len, len, nheads, batch_size)
+
+  @testset "self-attention" begin
+    y1, α1 = mha(q)
+    y2, α2 = mha(q, q, q)
+    @test y1 ≈ y2
+    @test α1 ≈ α2
+  end
+
+  @testset "key and value are the same" begin
+    y1, α1 = mha(q, k)
+    y2, α2 = mha(q, k, k)
+    @test y1 ≈ y2
+    @test α1 ≈ α2
+  end
+
+  @testset "change dims" begin
+    dims = 4 => 10 => 5
+    nheads = 5
+    mha2 = MultiHeadAttention(dims; nheads)
+    y2, _ = mha2(q, k, v)
+    @test size(y2) == (dims.second.second, len, batch_size)
+  end
+
+  @testset "mask" begin
+    mask = NNlib.make_causal_mask(q)
+    y, α = mha(q; mask)
+    @test all(α[2, 1, :, :] .== 0)
+    @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
+  end
+
+  @testset "bias" begin
+    # use bias to produce a causal mask
+    b = zeros(Float32, (len, len))
+    for i in 1:len, j in i:len
+      b[i, j] = typemax(Float32)
+    end
+    y, α = mha(q, k, v, b)
+    @test all(α[2, 1, :, :] .== 0)
+    @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
+  end
+
+  @testset "gradient" begin
+    gm, gq = gradient(mha, q) do mha, q
+      y, α = mha(q)
+      return sum(y.^2) + sum(α.^2)
+    end
+    check_grad_type(gm, mha)
+    check_grad_type(gq, q)
+  end
+end
+
diff --git a/test/runtests.jl b/test/runtests.jl
index a2a8f66323..a14372317c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -33,6 +33,7 @@ Random.seed!(0)
   end
 
   @testset "Layers" begin
+    include("layers/attention.jl")
     include("layers/basic.jl")
     include("layers/normalisation.jl")
     include("layers/stateless.jl")
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 2b07e59d08..f07fb1c721 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,27 +1,33 @@
-function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu, g_cpu;
+                    rtol=1e-4, atol=1e-4,
+                    allow_nothing::Bool=false)
   allow_nothing && return
   @show g_gpu g_cpu
   @test false
 end
-check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol; allow_nothing::Bool) =
-  check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing)
-check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
+  check_grad(g_gpu[], g_cpu[]; rtol, atol, allow_nothing)
+
+check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test true
-check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test g_cpu ≈ g_gpu rtol=rtol atol=atol
-check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
   @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol
 
-function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
   for (v1, v2) in zip(g_gpu, g_cpu)
-    check_grad(v1, v2, atol, rtol; allow_nothing)
+    check_grad(v1, v2; rtol, atol, allow_nothing)
   end
 end
 
-function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
   for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu))
     @test k1 == k2
-    check_grad(v1, v2, atol, rtol; allow_nothing)
+    check_grad(v1, v2; rtol, atol, allow_nothing)
   end
 end
@@ -31,10 +37,14 @@ check_type(x::CuArray{Float32}) = true
 check_type(x::Array{Float32}) = true
 
 function gpu_autodiff_test(
-  f_cpu, xs_cpu::Array{Float32}...;
-  test_equal=true, rtol=1e-4, atol=1e-4,
-  checkgrad::Bool = true, allow_nothing::Bool = false,
-)
+    f_cpu,
+    xs_cpu::Array{Float32}...;
+    test_equal=true,
+    rtol=1e-4, atol=1e-4,
+    checkgrad::Bool = true,
+    allow_nothing::Bool = false,
+  )
+
+  # Compare CPU & GPU function outputs.
   f_gpu = f_cpu |> gpu
   xs_gpu = gpu.(xs_cpu)
@@ -60,7 +70,7 @@
   if test_equal
     @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
     for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu)
-      check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing)
+      check_grad(g_gpu, g_cpu; atol, rtol, allow_nothing)
     end
   end
 
@@ -78,7 +88,22 @@
     @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
     @assert length(ps_gpu) == length(ps_cpu)
     for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu)
-      check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing)
+      check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu]; atol, rtol, allow_nothing)
     end
   end
 end
+
+# check_grad_type checks that the gradient type matches the primal type.
+
+check_grad_type(g::Nothing, x) = nothing
+
+function check_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2}
+  @test T1 == T2
+  @test size(g) == size(x)
+end
+
+function check_grad_type(g::NamedTuple, x::T) where T
+  for f in fieldnames(T)
+    check_grad_type(g[f], getfield(x, f))
+  end
+end
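+
+# A minimal usage sketch (illustrative only, not executed by the test suite):
+#
+#   x = rand(Float32, 3, 4)
+#   g, = gradient(x -> sum(abs2, x), x)
+#   check_grad_type(g, x)   # passes: same eltype (Float32) and same size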