
Commit

Merge #1516
1516: add Embedding layer r=DhairyaLGandhi a=CarloLucibello

Basic implementation. 

It could perhaps be improved once FluxML/NNlib.jl#255 lands.
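
For reference, a minimal usage sketch of the layer this PR adds, following the docstring included below; the weights are initialised randomly, so concrete values are illustrative only.

```julia
using Flux

vocab_size, embed_size = 1000, 4
model = Flux.Embedding(vocab_size, embed_size)  # lookup table backed by a 4×1000 weight matrix

vocab_idxs = [1, 722, 53, 220, 3]               # integer token ids
y = model(vocab_idxs)                           # 4×5 Matrix{Float32}, one column per index

# the same lookup works from a one-hot encoding of the indices
x = Flux.OneHotMatrix(vocab_idxs, vocab_size)
model(x) == y                                   # true
```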

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Carlo Lucibello <[email protected]>
3 people authored Jul 13, 2021
2 parents 1a0b519 + 397cabd commit 9931730
Showing 13 changed files with 162 additions and 42 deletions.
9 changes: 9 additions & 0 deletions NEWS.md
@@ -1,5 +1,14 @@
# Flux Release Notes

## v0.12.4
* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
based on `NNlib.gather` and `NNlib.scatter`.

## v0.12.1 - v0.12.3

* CUDA.jl 3.0 support
* Bug fixes and optimizations.

## v0.12.0

* Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.12.4"
version = "0.12.5"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -37,7 +37,7 @@ Colors = "0.12"
Functors = "0.2.1"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.14"
NNlib = "0.7.24"
NNlibCUDA = "0.1"
Reexport = "0.2, 1.0"
StatsBase = "0.33"
2 changes: 1 addition & 1 deletion docs/src/gpu.md
@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # CuArray
d.weight # CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
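The same field access works when a layer is moved with Flux's higher-level `gpu` helper instead of `fmap(cu, ...)`; a small sketch, assuming a working CUDA setup (without one, `gpu` is a no-op and the arrays stay on the CPU):

```julia
using Flux

d = Dense(10, 5, σ)
d = d |> gpu                      # moves weight and bias to the GPU when CUDA is functional

d.weight                          # CuArray (or a plain Array if no GPU is available)
d(gpu(rand(Float32, 10)))         # CuArray output
```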
2 changes: 1 addition & 1 deletion docs/src/models/advanced.md
@@ -68,7 +68,7 @@ by simply deleting it from `ps`:

```julia
ps = params(m)
delete!(ps, m[2].b)
delete!(ps, m[2].bias)
```

## Custom multiple input or output layer
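To show the renamed field in context, a sketch of freezing one layer's bias by removing it from the parameter set before training; the model, data, and optimiser here are made up for illustration.

```julia
using Flux

m = Chain(Dense(3, 4, relu), Dense(4, 2))
ps = Flux.params(m)
delete!(ps, m[2].bias)                      # the second Dense layer's bias is no longer trainable

loss(x, y) = Flux.Losses.mse(m(x), y)
data = [(rand(Float32, 3, 8), rand(Float32, 2, 8))]
Flux.train!(loss, ps, data, Descent(0.1))   # m[2].bias stays at its initial zeros
```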
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -58,6 +58,7 @@ SkipConnection
Parallel
Flux.Bilinear
Flux.Diagonal
Flux.Embedding
```

## Normalisation & Regularisation
7 changes: 7 additions & 0 deletions docs/src/models/nnlib.md
@@ -67,3 +67,10 @@ NNlib.batched_mul!
NNlib.batched_adjoint
NNlib.batched_transpose
```

## Gather and Scatter

```@docs
NNlib.gather
NNlib.scatter
```
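Since `Embedding` is built on these functions, a brief sketch of what they do, assuming NNlib ≥ 0.7.24 as pinned in this PR: `gather` picks slices along the source's last dimension, and `scatter` accumulates slices back into the positions given by an index array.

```julia
using NNlib

W = Float32[1 2 3;
            4 5 6]                 # a 2×3 "embedding table"

NNlib.gather(W, [3, 1, 1])         # 2×3 matrix made of columns 3, 1 and 1 of W

src = Float32[10 20 30;
              40 50 60]
NNlib.scatter(+, src, [1, 2, 1])   # 2×2 matrix: columns 1 and 3 of src are summed into column 1
```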
38 changes: 19 additions & 19 deletions docs/src/models/overview.md
@@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s

This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:

```
```julia
julia> using Flux

julia> actual(x) = 4x + 2
@@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.

Use the `actual` function to build sets of data for training and verification:

```
```julia
julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 … 4 5], [6 7 … 9 10])

@@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi

Now, build a model to make predictions with `1` input and `1` output:

```
```julia
julia> model = Dense(1, 1)
Dense(1, 1)

julia> model.W
1-element Array{Float64,1}:
-0.99009055
julia> model.weight
1×1 Matrix{Float32}:
-1.4925033

julia> model.b
1-element Array{Float64,1}:
julia> model.bias
1-element Vector{Float32}:
0.0
```

Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weight matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:

```
```julia
julia> predict = Dense(1, 1)
```

`Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.

This model will already make predictions, though not accurate ones yet:

```
```julia
julia> predict(x_train)
1×6 Array{Float32,2}:
-1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
1×6 Matrix{Float32}:
0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
```

In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```
```julia
julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
loss (generic function with 1 method)

@@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f

Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):

```
```julia
julia> using Flux: train!

julia> opt = Descent()
@@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]

Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters, and the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:

```
julia> predict.W
```julia
julia> predict.weight
1-element Array{Float64,1}:
-0.99009055

julia> predict.b
julia> predict.bias
1-element Array{Float64,1}:
0.0
```
@@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:

```
julia> predict.W in parameters, predict.b in parameters
julia> predict.weight in parameters, predict.bias in parameters
(true, true)
```
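Putting the steps on this page together with the renamed fields, a compact end-to-end sketch; the exact numbers will differ from the transcripts above because `Dense(1, 1)` starts from a random weight.

```julia
using Flux
using Flux: train!

actual(x) = 4x + 2
x_train, x_test = hcat(0:5...), hcat(6:10...)
y_train, y_test = actual.(x_train), actual.(x_test)

predict = Dense(1, 1)
loss(x, y) = Flux.Losses.mse(predict(x), y)

opt = Descent()
data = [(x_train, y_train)]
parameters = Flux.params(predict)       # tracks predict.weight and predict.bias

for epoch in 1:200                      # repeated train! calls keep lowering the loss
    train!(loss, parameters, data, opt)
end

predict(x_test)                         # close to y_test once training has converged
```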
4 changes: 2 additions & 2 deletions docs/src/models/regularisation.md
@@ -13,10 +13,10 @@ m = Dense(10, 5)
loss(x, y) = logitcrossentropy(m(x), y)
```

We can apply L2 regularisation by taking the squared norm of the parameters, `m.W` and `m.b`.
We can apply L2 regularisation by taking the squared norm of the parameters, `m.weight` and `m.bias`.

```julia
penalty() = sum(abs2, m.W) + sum(abs2, m.b)
penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```

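A field-agnostic variant of the same penalty sums the squared norm over everything returned by `params(m)`, so no field names are needed at all; this is a sketch of one possible formulation rather than the only one.

```julia
using Flux
using Flux.Losses: logitcrossentropy

m = Dense(10, 5)

sqnorm(p) = sum(abs2, p)
penalty() = sum(sqnorm, Flux.params(m))       # covers m.weight and m.bias without naming them
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```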
58 changes: 58 additions & 0 deletions src/layers/basic.jl
@@ -8,6 +8,7 @@ on a given input.
`m[1:3](x)` will calculate the output of the first three layers.
# Examples
```jldoctest
julia> m = Chain(x -> x^2, x -> x+1);
@@ -428,3 +429,60 @@ function Base.show(io::IO, m::Parallel)
join(io, m.layers, ", ")
print(io, ")")
end

"""
Embedding(in, out; init=randn32)
A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
This layer is often used to store word embeddings and retrieve them using indices.
The input to the layer can be either a vector of indices
or the corresponding [one-hot encoding](@ref Flux.OneHotArray).
# Examples
```julia-repl
julia> using Flux: Embedding
julia> vocab_size, embed_size = 1000, 4;
julia> model = Embedding(vocab_size, embed_size)
Embedding(1000, 4)
julia> vocab_idxs = [1, 722, 53, 220, 3];
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
julia> model(x)
4×5 Matrix{Float32}:
0.91139 0.670462 0.463217 0.670462 0.110932
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
-0.497621 0.87595 -0.870251 0.87595 -0.772696
julia> model(vocab_idxs) == model(x)
true
```
"""
struct Embedding{W}
weight::W
end

@functor Embedding

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))


(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
end
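To make the dispatch above concrete, a small sketch (relying only on the definitions in this file) of how the three input forms relate: an integer selects one column, a vector of indices gathers several columns, and a higher-dimensional index array is flattened, looked up, and reshaped back.

```julia
using Flux

m = Flux.Embedding(10, 4)                 # weight is a 4×10 Matrix{Float32}

m(7) == m.weight[:, 7]                    # single index -> one embedding vector
m([1, 7, 7]) == m.weight[:, [1, 7, 7]]    # repeated indices are fine; the result is 4×3

x = [1 2; 3 4]                            # 2×2 array of indices
size(m(x)) == (4, 2, 2)                   # the embedding dimension is prepended to the index shape
```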
4 changes: 3 additions & 1 deletion src/utils.jl
@@ -15,7 +15,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
```jldoctest
julia> layer = Dense(10, 20);
julia> Flux.nfan(size(layer.W))
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10);
@@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit

ones32(dims...) = Base.ones(Float32, dims...)
zeros32(dims...) = Base.zeros(Float32, dims...)
rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, length)
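The two new helpers mirror `ones32`/`zeros32` and exist mainly so the `Embedding` constructor's default `init = randn32` yields `Float32` weights; a quick sketch of how they, or any function of the output dimensions, can be passed as an initialiser.

```julia
using Flux

Flux.randn32(4, 10) isa Matrix{Float32}          # true: standard-normal draws in Float32
Flux.rand32(4, 10)  isa Matrix{Float32}          # true: uniform [0, 1) draws in Float32

m = Flux.Embedding(10, 4; init = Flux.rand32)    # Embedding(in, out; init) calls init(out, in)
size(m.weight) == (4, 10)                        # true
```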
29 changes: 23 additions & 6 deletions test/cuda/layers.jl
@@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
# test
if test_cpu
@test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
@test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
if isnothing(xg_cpu)
@test isnothing(xg_gpu)
else
@test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
end
end
@test gs_gpu isa Flux.Zygote.Grads
for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
if isnothing(gs_cpu[p_cpu])
@test isnothing(gs_gpu[p_gpu])
else
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
end
end
end
end
@@ -114,6 +122,15 @@ pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

embedding = [Flux.Embedding]
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -135,12 +152,12 @@ end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
@test l.bias ∉ gs.params
end

@testset "Extended BatchNorm" begin
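The `isnothing` branches added to `gpu_gradtest` above are needed because an `Embedding`'s input is a set of integer indices, which carries no gradient; a rough CPU-side sketch of the behaviour the harness checks:

```julia
using Flux

m = Flux.Embedding(5, 2)
x = [1, 3, 5]

# differentiating with respect to the integer indices yields `nothing`...
gx = gradient(x -> sum(m(x)), x)[1]
gx === nothing                              # true

# ...while the lookup table itself still receives a gradient
gs = gradient(() -> sum(m(x)), Flux.params(m))
gs[m.weight] isa AbstractMatrix             # true
```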
25 changes: 25 additions & 0 deletions test/layers/basic.jl
@@ -191,4 +191,29 @@ import Flux: activations
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
end
end

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
y = m(x)
@test y isa Matrix{Float32}
@test y ≈ m.weight[:,x]
x2 = OneHotMatrix(x, vocab_size)
y2 = m(x2)
@test y2 isa Matrix{Float32}
@test y2 ≈ y
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))

x = rand(1:vocab_size, 3, 4)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end
end

3 comments on commit 9931730

@CarloLucibello
Member


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/41132

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.5 -m "<description of version>" 9931730129f514425d939683f72dd566a791472f
git push origin v0.12.5

@DhairyaLGandhi (Member) commented on 9931730 (Jul 19, 2021)


Flux is currently broken with the latest releases of our ecosystem, specifically Zygote. I am going to stop the registration and fix this before doing a release.

  Got exception outside of a @test
  MethodError: no method matching push!(::Zygote.IdSet{Union{}})
  Closest candidates are:
    push!(::Zygote.IdSet{T}, ::T) where T at /home/dhairyalgandhi/.julia/packages/Zygote/JBaiY/src/tools/idset.jl:10
    push!(::Any, ::Any, ::Any) at abstractarray.jl:2387
    push!(::Any, ::Any, ::Any, ::Any...) at abstractarray.jl:2388
    ...
  Stacktrace:
    [1] Zygote.IdSet{Union{}}(xs::Vector{Union{}})
      @ Zygote ~/.julia/packages/Zygote/JBaiY/src/tools/idset.jl:14
    [2] Zygote.IdSet(xs::Vector{Union{}})
      @ Zygote ~/.julia/packages/Zygote/JBaiY/src/tools/idset.jl:16
    [3] Params(xs::Vector{Union{}})
