finish up
CarloLucibello committed Mar 5, 2023
1 parent 2b9b219 commit 5745555
Showing 8 changed files with 173 additions and 424 deletions.
5 changes: 0 additions & 5 deletions Project.toml
@@ -5,17 +5,13 @@ version = "0.13.14"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -26,7 +22,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
5 changes: 4 additions & 1 deletion src/Flux.jl
@@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
Dropout, AlphaDropout,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
MultiHeadAttention,
Upsample, PixelShuffle,
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
testmode!, trainmode!
@@ -59,6 +61,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/attention.jl")
include("layers/show.jl")

include("loading.jl")
281 changes: 55 additions & 226 deletions src/layers/attention.jl
@@ -1,36 +1,32 @@
using Flux, Functors, Test, LinearAlgebra, Random, Statistics
using CUDA
using NeuralAttentionlib
using NeuralAttentionlib: score_returning
using BenchmarkTools
using Flux: glorot_uniform
CUDA.allowscalar(false)

const A3{T} = AbstractArray{T, 3}
const A4{T} = AbstractArray{T, 4}
const TuplInt2 = Union{Int, Tuple{Int, Int}}
const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}

include("attention_nnlib.jl")
include("attention_tullio.jl")


"""
MultiHeadAttention(dims, nheads; [bias, init, dropout_prob])
MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])
The multi-head dot-product attention layer used in Transformer architectures [1].
Multi-head dot-product attention layer.
[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.
# Arguments
- `dims`: ...
- `nheads`: number of heads.
- `init`: weight initializer for the Dense layers.
- `bias`: whether pointwise QKVO dense transforms use bias.
- `dropout_prob`: dropout probability for the attention scores.
- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
In the most general case, it is given as
`(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
It can also take simpler forms such as
`dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`, or
`in_dim::Int => qkv_dim => out_dim`.
- `nheads`: number of heads. Default `8`.
- `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
- `bias`: whether pointwise QKVO dense transforms use bias. Default `false`.
- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.
# Forward
(::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
(mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
- `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
- `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
@@ -39,38 +35,58 @@ Multi-head dot-product attention layer.
`(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
- `withscores`: Whether to return the attention scores. Default `false`.
Alternatively, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
See also [`NNlib.dot_product_attention`](@ref).
# Examples
```julia
mha = MultiHeadAttention(64, 8)
mha = MultiHeadAttention(64, nheads = 8)
q = rand(Float32, (64, 10, 32))
k = rand(Float32, (64, 20, 32))
v = rand(Float32, (64, 20, 32))
y = mha(q, k, v) # [y] = [64, 10, 32]
mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
y = mha(q) # self-attention; [y] = [1024, 10, 32]
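# Illustrative addition (not from the original docstring): returning the
# attention scores and applying a causal mask. Assumes NNlib's
# `make_causal_mask` helper is available (Flux depends on NNlib).
mask = NNlib.make_causal_mask(q)          # (10, 10) Bool mask
y, α = mha(q; mask, withscores = true)    # [α] = [10, 10, 8, 32]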
```
"""
struct MultiHeadAttention{P1, D, P2}
nheads::Int
qkv_proj::P1
q_proj::P1
k_proj::P1
v_proj::P1
attn_drop::D
out_proj::P2
end

@functor MultiHeadAttention

function MultiHeadAttention(dims, nheads::Int;
function MultiHeadAttention(dims;
nheads::Int = 8,
bias::Bool = false,
init = glorot_uniform,
dropout_prob = 0.0)

dims = mha_process_dims(dims)
dims = normalize_mha_dims(dims)
@assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
qkv_proj = QKVProj(dims; bias, init)
@assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
q_proj = Dense(dims.q_in => dims.qk; bias, init)
k_proj = Dense(dims.k_in => dims.qk; bias, init)
v_proj = Dense(dims.v_in => dims.v; bias, init)
attn_drop = Dropout(dropout_prob)
out_proj = Dense(dims.v => dims.out; bias, init)
return MultiHeadAttention(nheads, qkv_proj, attn_drop, out_proj)
return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
end

mha_process_dims(dims::Int) =
# turns the dims argument into a named tuple
normalize_mha_dims(dims::Int) =
(; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
if in isa Int
q_in = k_in = v_in = in
else
@@ -85,209 +101,22 @@ function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2,
end
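
# Illustrative sketch (not part of the diff; assumes the folded part of the
# method above unpacks `qkv` the same way it unpacks `in`): the accepted
# `dims` forms, the named tuple they normalize to, and the Dense projections
# the constructor builds from it.
let
    # most general form: separate q/k/v input dims and separate qk/v inner dims
    @assert normalize_mha_dims((8, 12, 16) => (32, 64) => 96) ==
            (; q_in=8, k_in=12, v_in=16, qk=32, v=64, out=96)

    # a single Int makes every dimension equal
    @assert normalize_mha_dims(64) == (; q_in=64, k_in=64, v_in=64, qk=64, v=64, out=64)

    # Int => Int => Int sets qk_dim == v_dim == qkv_dim
    @assert normalize_mha_dims(64 => 128 => 96) ==
            (; q_in=64, k_in=64, v_in=64, qk=128, v=128, out=96)

    # the constructor turns these dims into four Dense projections
    mha = MultiHeadAttention((8, 12, 16) => (32, 64) => 96; nheads = 4)
    @assert size(mha.q_proj.weight)   == (32, 8)    # q_in => qk
    @assert size(mha.k_proj.weight)   == (32, 12)   # k_in => qk
    @assert size(mha.v_proj.weight)   == (64, 16)   # v_in => v
    @assert size(mha.out_proj.weight) == (96, 64)   # v => out
end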

# self-attention
(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...)
(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

# key and value are the same
(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)
(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
withscores=false, mask=nothing, impl=:nnlib)
function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
withscores=false, mask=nothing)
## [q_in] = [q_in_dim, q_len, batch_size]
## [k_in] = [k_in_dim, kv_len, batch_size]
## [v_in] = [v_in_dim, kv_len, batch_size]

q, k, v = m.qkv_proj(q_in, k_in, v_in)
# [q] = [qk_dim, q_len, batch_size]
# [k] = [qk_dim, kv_len, batch_size]
# [v] = [v_dim, kv_len, batch_size]

if impl == :tullio
x, α = dot_product_attention_tullio(m.nheads, q, k, v; mask, dropout=m.attn_drop)
elseif impl == :nalib
x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.nheads, q, k, v, mask)
elseif impl == :nnlib
x, α = dot_product_attention(q, k, v, bias; m.nheads, mask, fdrop=m.attn_drop)
else
error("Unknown attention implementation")
end

x = m.out_proj(x)

q = mha.q_proj(q_in) # [q] = [qk_dim, q_len, batch_size]
k = mha.k_proj(k_in) # [k] = [qk_dim, kv_len, batch_size]
v = mha.v_proj(v_in) # [v] = [v_dim, kv_len, batch_size]
x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
x = mha.out_proj(x)
# [x] = [out_dim, q_len, batch_size]
# [α] = [kv_len, q_len, nheads, batch_size]
return withscores ? (x, α) : x
end
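
# Illustrative usage sketch of the forward pass defined above (not part of the
# diff); the causal mask is assumed to come from NNlib's `make_causal_mask`.
let
    mha = MultiHeadAttention(64; nheads = 4)
    q  = rand(Float32, 64, 10, 32)   # (q_in_dim, q_len, batch)
    kv = rand(Float32, 64, 20, 32)   # (kv_in_dim, kv_len, batch)

    # cross-attention with key == value, scores not returned
    @assert size(mha(q, kv)) == (64, 10, 32)

    # causal self-attention, also returning the attention scores
    mask = NNlib.make_causal_mask(q)        # (10, 10) Bool mask
    y, α = mha(q; mask, withscores = true)
    @assert size(y) == (64, 10, 32)
    @assert size(α) == (10, 10, 4, 32)      # (kv_len, q_len, nheads, batch)
end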

struct QKVProj
q_proj::Dense
k_proj::Dense
v_proj::Dense
end

@functor QKVProj

function QKVProj(dims; bias = false, init=glorot_uniform)
return QKVProj(
Dense(dims.q_in => dims.qk; bias, init),
Dense(dims.k_in => dims.qk; bias, init),
Dense(dims.v_in => dims.v; bias, init))
end

function (proj::QKVProj)(q_in, k_in, v_in)
return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
end

function perf(dim, len, batch_size, nheads)
mha = MultiHeadAttention(dim, nheads)
x = rand(Float32, (dim, len, batch_size))

println("tullio")
@btime $mha($x, impl=:tullio);
@btime gradient(m -> sum(m($x, impl=:tullio)), $mha);

println("nalib")
@btime $mha($x, $x, $x, impl=:nalib);
@btime gradient(m -> sum(m($x, impl=:nalib)), $mha);

println("nnlib")
@btime $mha($x, $x, $x, impl=:nnlib);
@btime gradient(m -> sum(m($x, impl=:nnlib)), $mha);

if CUDA.functional()
mha_gpu = mha |> gpu
x_gpu = x |> gpu

println("tullio - gpu")
@btime $mha_gpu($x_gpu, impl=:tullio);
@btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);

println("nalib - gpu")
@btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
@btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);

println("nnlib - gpu")
@btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnlib);
@btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnlib)), $mha_gpu);
end
return nothing
end

function test(dim, nheads, len, batch_size)
mha = MultiHeadAttention(dim, nheads)
q = rand(Float32, (dim, len, batch_size))
k = rand(Float32, (dim, len, batch_size))
v = rand(Float32, (dim, len, batch_size))

y, α = mha(q, k, v, impl=:tullio, withscores=true)
@test y isa Array{Float32, 3}
@test size(y) == (dim, len, batch_size)
@test α isa Array{Float32, 4}
@test size(α) == (len, len, nheads, batch_size)

y2, α2 = mha(q, k, v, impl=:nalib, withscores=true)
@test size(y) == size(y2)
@test y2 ≈ y
@test size(α) == size(α2)
@test α2 ≈ α

y2b, α2b = mha(q, k, v, impl=:nnlib, withscores=true)
@test size(y) == size(y2b)
@test y2b ≈ y
@test size(α) == size(α2b)
@test α2b ≈ α

mask = make_causal_mask(q)
y3, α3 = mha(q, k, v; impl=:tullio, withscores=true, mask)
y4, α4 = mha(q, k, v, impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
@test y3 ≈ y4
@test α3 ≈ α4

if CUDA.functional()
mha_gpu = mha |> gpu
q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu

y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio)
y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib)
@test Array(y_gpu) ≈ Array(y_gpu2)
@test Array(y_gpu) ≈ y
end
return nothing
end

test(4, 2, 3, 1)

perf(128, 8, 128, 32)

## M1 Pro, NNlib v0.8.12
# tullio
# 2.948 ms (77 allocations: 7.25 MiB)
# 15.041 ms (1124 allocations: 16.71 MiB)
# nalib
# 3.503 ms (89 allocations: 7.75 MiB)
# 15.828 ms (604 allocations: 14.70 MiB)
# nnlib
# 3.611 ms (87 allocations: 9.25 MiB)
# 16.497 ms (1055 allocations: 20.71 MiB)

## M1 Pro, NNlib v0.8.13 (fast_maximum)
# tullio
# 2.427 ms (71 allocations: 7.13 MiB)
# 14.510 ms (1118 allocations: 16.59 MiB)
# nalib
# 3.052 ms (84 allocations: 7.63 MiB)
# 15.327 ms (599 allocations: 14.57 MiB)
# nnlib
# 3.166 ms (81 allocations: 9.13 MiB)
# 16.082 ms (1049 allocations: 20.58 MiB)

## Threadripper, NNlib v0.8.12
# tullio
# 5.658 ms (77 allocations: 7.25 MiB)
# 22.373 ms (1124 allocations: 16.71 MiB)
# nalib
# 6.187 ms (89 allocations: 7.75 MiB)
# 23.723 ms (604 allocations: 14.70 MiB)
# nnlib
# 6.473 ms (87 allocations: 9.25 MiB)
# 24.966 ms (1055 allocations: 20.71 MiB)
# tullio - gpu
# 145.332 μs (520 allocations: 24.52 KiB)
# 902.020 μs (2221 allocations: 117.19 KiB)
# nalib - gpu
# 162.354 μs (410 allocations: 18.03 KiB)
# 604.111 μs (1263 allocations: 71.78 KiB)
# nnlib - gpu
# 156.383 μs (440 allocations: 20.00 KiB)
# 835.374 μs (1969 allocations: 100.58 KiB)

## Threadripper, NNlib v0.8.13 (fast_maximum)
# tullio
# 4.599 ms (71 allocations: 7.13 MiB)
# 20.699 ms (1118 allocations: 16.59 MiB)
# nalib
# 5.049 ms (84 allocations: 7.63 MiB)
# 22.252 ms (599 allocations: 14.57 MiB)
# nnlib
# 5.378 ms (81 allocations: 9.13 MiB)
# 23.453 ms (1049 allocations: 20.58 MiB)
# tullio - gpu
# 145.824 μs (520 allocations: 24.52 KiB)
# 915.305 μs (2221 allocations: 117.19 KiB)
# nalib - gpu
# 164.789 μs (410 allocations: 18.03 KiB)
# 610.835 μs (1263 allocations: 71.78 KiB)
# nnlib - gpu
# 157.785 μs (440 allocations: 20.00 KiB)
# 852.087 μs (1969 allocations: 100.58 KiB)


# function prof()
# dim, len, batch_size, nheads = 128, 8, 128, 32;
# # dim = 384; len = 128; batch_size = 32; nheads = 12
# mha = MultiHeadAttention(dim, nheads)
# x = rand(Float32, (dim, len, batch_size))
# @btime mha(x, impl=:tullio);
# @btime mha(x, impl=:nnlib);
# @profview mha(x, impl=:tullio);
# @profview prof(mha, x);
# y, α = mha(x; impl=:nnlib, withscores=true, mask)
# y2, α2 = mha(x; impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask())
# end
