Merge #1516
1516: add Embedding layer r=CarloLucibello a=CarloLucibello

Basic implementation. 

It could perhaps be improved once FluxML/NNlib.jl#255 lands.
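For reviewers, a minimal usage sketch of the layer added here (the weights are randomly initialised, so no output values are shown; the one-hot path uses `OneHotMatrix` as exported by Flux):

```julia
using Flux

emb = Flux.Embedding(1000, 4)      # vocabulary of 1000 tokens, 4-dimensional embeddings (4×1000 weight)

tokens = [1, 722, 53]              # integer token ids
emb(tokens)                        # 4×3 Matrix{Float32}, one column per token

# A one-hot encoded input takes the matrix-multiplication path and gives the same result.
emb(Flux.OneHotMatrix(tokens, 1000)) ≈ emb(tokens)   # true
```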

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Carlo Lucibello <[email protected]>
bors[bot] and CarloLucibello authored Jul 13, 2021
2 parents 1a14301 + dfb390d commit 13e607e
Showing 13 changed files with 173 additions and 37 deletions.
9 changes: 9 additions & 0 deletions NEWS.md
@@ -1,5 +1,14 @@
# Flux Release Notes

## v0.12.4
* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
based on recently added `NNlib.gather` and `NNlib.scatter`.

## v0.12.1 - v0.12.3

* CUDA.jl 3.0 support
* Bug fixes and optimizations.

## v0.12.0

* Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,12 +1,14 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.12.4"
version = "0.12.5"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -37,7 +39,7 @@ Colors = "0.12"
Functors = "0.2.1"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.14"
NNlib = "0.7.24"
NNlibCUDA = "0.1"
Reexport = "0.2, 1.0"
StatsBase = "0.33"
2 changes: 1 addition & 1 deletion docs/src/gpu.md
@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # CuArray
d.weight # CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
2 changes: 1 addition & 1 deletion docs/src/models/advanced.md
@@ -68,7 +68,7 @@ by simply deleting it from `ps`:

```julia
ps = params(m)
delete!(ps, m[2].b)
delete!(ps, m[2].bias)
```
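A brief sketch of the effect (the model, loss, and data below are placeholders for illustration, not taken from the documentation page): once deleted from `ps`, the bias receives no gradient and is left untouched by updates.

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = Flux.params(m)
delete!(ps, m[2].bias)                        # freeze the second layer's bias

x, y = rand(Float32, 10), rand(Float32, 2)
gs = gradient(() -> Flux.Losses.mse(m(x), y), ps)
m[2].bias ∈ gs.params                         # false: the frozen bias is skipped
Flux.Optimise.update!(Descent(0.1), ps, gs)   # updates only the parameters still in ps
```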

## Custom multiple input or output layer
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -58,6 +58,7 @@ SkipConnection
Parallel
Flux.Bilinear
Flux.Diagonal
Flux.Embedding
```

## Normalisation & Regularisation
7 changes: 7 additions & 0 deletions docs/src/models/nnlib.md
@@ -67,3 +67,10 @@ NNlib.batched_mul!
NNlib.batched_adjoint
NNlib.batched_transpose
```

## Gather and Scatter

```@docs
NNlib.gather
NNlib.scatter
```
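A hedged sketch of what these do in the matrix-plus-integer-index case that the new `Embedding` layer relies on (small literal values for illustration; see the docstrings above for the full semantics):

```julia
using NNlib

src = [10 20 30;
       40 50 60]                   # 2×3 source matrix

NNlib.gather(src, [3, 1])          # pick columns 3 and 1 -> [30 10; 60 40]

# scatter goes the other way: source columns that share a destination index are combined with the given operator.
NNlib.scatter(+, src, [1, 1, 2])   # -> [30 30; 90 60]
```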
38 changes: 19 additions & 19 deletions docs/src/models/overview.md
@@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s

This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:

```
```julia
julia> using Flux

julia> actual(x) = 4x + 2
@@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.

Use the `actual` function to build sets of data for training and verification:

```
```julia
julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 4 5], [6 7 9 10])

@@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi

Now, build a model to make predictions with `1` input and `1` output:

```
```julia
julia> model = Dense(1, 1)
Dense(1, 1)

julia> model.W
1-element Array{Float64,1}:
-0.99009055
julia> model.weight
1×1 Matrix{Float32}:
-1.4925033

julia> model.b
1-element Array{Float64,1}:
julia> model.bias
1-element Vector{Float32}:
0.0
```

Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weight matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:

```
```julia
julia> predict = Dense(1, 1)
```

`Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.
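That formula can be checked directly against the layer (a quick sketch; `σ` is the default `identity` here, and `≈` guards against floating-point differences):

```julia
julia> manual(x) = predict.weight * x .+ predict.bias;

julia> manual(x_train) ≈ predict(x_train)
true
```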

This model will already make predictions, though not accurate ones yet:

```
```julia
julia> predict(x_train)
1×6 Array{Float32,2}:
-1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
1×6 Matrix{Float32}:
0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
```

In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```
```julia
julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
loss (generic function with 1 method)

@@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f

Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):

```
```julia
julia> using Flux: train!

julia> opt = Descent()
@@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]

Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:

```
julia> predict.W
```julia
julia> predict.weight
1-element Array{Float64,1}:
-0.99009055

julia> predict.b
julia> predict.bias
1-element Array{Float64,1}:
0.0
```
@@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:

```
julia> predict.W in parameters, predict.b in parameters
julia> predict.weight in parameters, predict.bias in parameters
(true, true)
```
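To close the loop, a sketch of the actual training call using the pieces defined above (`loss`, `parameters`, `data`, `opt`); the loss values depend on the random initialisation, so none are shown:

```julia
julia> train!(loss, parameters, data, opt)

julia> loss(x_train, y_train)   # lower than before the call
```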
4 changes: 2 additions & 2 deletions docs/src/models/regularisation.md
@@ -13,10 +13,10 @@ m = Dense(10, 5)
loss(x, y) = logitcrossentropy(m(x), y)
```

We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`.

```julia
penalty() = sum(abs2, m.W) + sum(abs2, m.b)
penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```
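A possible generalisation (a sketch, not part of the page as quoted here): summing the squared norm over everything in `params(m)` regularises all parameters without naming each field.

```julia
sqnorm(x) = sum(abs2, x)
penalty() = sum(sqnorm, Flux.params(m))
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```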

53 changes: 53 additions & 0 deletions src/layers/basic.jl
@@ -8,6 +8,7 @@ on a given input.
`m[1:3](x)` will calculate the output of the first three layers.
# Examples
```jldoctest
julia> m = Chain(x -> x^2, x -> x+1);
@@ -428,3 +429,55 @@ function Base.show(io::IO, m::Parallel)
join(io, m.layers, ", ")
print(io, ")")
end

"""
    Embedding(in, out; init=randn32)
A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
This layer is often used to store word embeddings and retrieve them using indices.
The input to the layer can be either a vector of indices
or the corresponding [onehot encoding](@ref Flux.OneHotArray).
# Examples
```julia-repl
julia> using Flux: Embedding
julia> vocab_size, embed_size = 1000, 4;
julia> model = Embedding(vocab_size, embed_size)
Embedding(1000, 4)
julia> vocab_idxs = [1, 722, 53, 220, 3]
julia> x = OneHotMatrix(vocab_idxs, vocab_size);
julia> model(x)
4×5 Matrix{Float32}:
0.91139 0.670462 0.463217 0.670462 0.110932
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
-0.497621 0.87595 -0.870251 0.87595 -0.772696
julia> model(vocab_idxs) == model(x)
true
```
"""
struct Embedding{W}
weight::W
end

@functor Embedding

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
(m::Embedding)(x::Integer) = m([x])
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function Base.show(io::IO, m::Embedding)
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
end
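To make the dispatch above concrete, a small sketch with a hand-written 2×3 weight matrix (illustrative only, not part of the diff):

```julia
using Flux

emb = Flux.Embedding([1.0f0 2 3;
                      4     5 6])   # vocabulary of 3, embedding size 2

emb(2)                              # same as emb([2]): a 2×1 matrix holding column 2, i.e. [2, 5]
emb([3, 1])                         # NNlib.gather over columns -> [3.0 1.0; 6.0 4.0]
emb([1 2; 3 1])                     # a 2×2 index array -> 2×2×2 output via the reshape path
emb(Flux.OneHotVector(2, 3))        # weight * onehot -> the vector [2.0, 5.0]
```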
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
```jldoctest
julia> layer = Dense(10, 20);
julia> Flux.nfan(size(layer.W))
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10);
@@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit

ones32(dims...) = Base.ones(Float32, dims...)
zeros32(dims...) = Base.zeros(Float32, dims...)
rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, length)
38 changes: 37 additions & 1 deletion test/cuda/layers.jl
@@ -140,7 +140,7 @@ end

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∈ gs.params
@test l.bias ∈ gs.params
end

@testset "Extended BatchNorm" begin
@@ -259,3 +259,39 @@ end
end
end
end

@testset "Embedding" begin
vocab_size, embed_size = 5, 2
m = Flux.Embedding(vocab_size, embed_size)

x = [1, 3, 5]
y = m(x)
m_g = m |> gpu
x_g = x |> gpu
y_g = m_g(x_g)
@test collect(y_g) == y

gs = gradient(() -> sum(m(x)), params(m))
gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]

gs = gradient(() -> sum(tanh.(m(x))), params(m))
gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]

@testset "repeated indexes" begin
vocab_size, embed_size = 5, 2
m = Flux.Embedding(vocab_size, embed_size)

x = [1, 3, 5, 3] # repeated indexes
y = m(x)
m_g = m |> gpu
x_g = x |> gpu
y_g = m_g(x_g)
@test collect(y_g) == y
gs = gradient(() -> sum(m(x)), params(m))
gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
end
end

25 changes: 25 additions & 0 deletions test/layers/basic.jl
@@ -191,4 +191,29 @@ import Flux: activations
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
end
end

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
y = m(x)
@test y isa Matrix{Float32}
@test y ≈ m.weight[:,x]
x2 = OneHotMatrix(x, vocab_size)
y2 = m(x2)
@test y2 isa Matrix{Float32}
@test y2 ≈ y
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))

x = rand(1:vocab_size, 3, 4)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end
end
