From a5c5d8c456d7315d842ae1da774dae69f5f39607 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 22 Feb 2021 12:55:58 +0100
Subject: [PATCH 1/6] add embedding layer

Co-authored-by: Kyle Daruwalla
---
 docs/src/models/layers.md |  1 +
 src/Flux.jl               |  2 +-
 src/layers/basic.jl       | 53 +++++++++++++++++++++++++++++++++++++++
 test/cuda/layers.jl       | 15 +++++++++++
 test/layers/basic.jl      | 25 ++++++++++++++++++
 5 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index cee27941de..10333f9e9d 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -58,6 +58,7 @@ SkipConnection
 Parallel
 Flux.Bilinear
 Flux.Diagonal
+Flux.Embedding
 ```
 
 ## Normalisation & Regularisation
diff --git a/src/Flux.jl b/src/Flux.jl
index f4e808bd05..400a4be7e1 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-  RNN, LSTM, GRU,
+  RNN, LSTM, GRU, Embedding,
   SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
   AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
   Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 997f72bcde..b8defdb0b8 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -8,6 +8,7 @@ on a given input.
 `m[1:3](x)` will calculate the output of the first three layers.
 
 # Examples
+
 ```jldoctest
 julia> m = Chain(x -> x^2, x -> x+1);
 
@@ -428,3 +429,55 @@ function Base.show(io::IO, m::Parallel)
   join(io, m.layers, ", ")
   print(io, ")")
 end
+
+"""
+    Embedding(in, out; init=randn)
+
+A lookup table that stores embeddings of dimension `out`
+for a vocabulary of size `in`.
+
+This layer is often used to store word embeddings and retrieve them using indices.
+The input to the layer can be either a vector of indices
+or the corresponding [onehot encoding](@ref Flux.OneHotArray).
+
+# Examples
+
+```julia-repl
+julia> vocab_size, embed_size = 1000, 4;
+
+julia> model = Embedding(vocab_size, embed_size)
+Embedding(1000, 4)
+
+julia> vocab_idxs = [1, 722, 53, 220, 3];
+
+julia> x = OneHotMatrix(vocab_idxs, vocab_size);
+
+julia> model(x)
+4×5 Matrix{Float32}:
+  0.91139    0.670462   0.463217   0.670462    0.110932
+  0.247225  -0.0823874  0.698694  -0.0823874   0.945958
+ -0.393626  -0.590136  -0.545422  -0.590136    0.77743
+ -0.497621   0.87595   -0.870251   0.87595    -0.772696
+
+julia> model(vocab_idxs) == model(x)
+true
+```
+"""
+struct Embedding{W}
+  weight::W
+end
+
+@functor Embedding
+
+function Embedding(in::Integer, out::Integer;
+                   init = (i...) -> randn(Float32, i...))
+  return Embedding(init(out, in))
+end
+
+(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
+(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+
+function Base.show(io::IO, m::Embedding)
+  print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
+end
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index 0a4034303d..8d7ba90e80 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -259,3 +259,18 @@ end
     end
   end
 end
+
+@testset "Embedding" begin
+  vocab_size, embed_size = 10, 4
+  m = Embedding(vocab_size, embed_size)
+  x = rand(1:vocab_size, 3)
+  y = m(x)
+  m_g = m |> gpu
+  x_g = x |> gpu
+  y_g = m_g(x_g)
+  @test collect(y_g) == y
+  gs = gradient(() -> sum(tanh.(m(x))), params(m))
+  gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
+  @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+end
+
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 4bf4d430f2..5a22838bd1 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -191,4 +191,29 @@ import Flux: activations
       @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     end
   end
+
+  @testset "Embedding" begin
+    vocab_size, embed_size = 10, 4
+    m = Embedding(vocab_size, embed_size)
+    @test size(m.weight) == (embed_size, vocab_size)
+
+    x = rand(1:vocab_size, 3)
+    y = m(x)
+    @test y isa Matrix{Float32}
+    @test y ≈ m.weight[:,x]
+    x2 = OneHotMatrix(x, vocab_size)
+    y2 = m(x2)
+    @test y2 isa Matrix{Float32}
+    @test y2 ≈ y
+    @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
+
+    x = rand(1:vocab_size, 3, 4)
+    y = m(x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (embed_size, 3, 4)
+
+    @test m(2) ≈ m.weight[:,2]
+    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
+    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
+  end
 end

From c4f2351faf01c0e6aae3d6b45d3fd0deaa9ee812 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 10 Jul 2021 17:41:58 +0200
Subject: [PATCH 2/6] use gather; fix outdated docs

Co-authored-by: Manikya
---
 docs/src/gpu.md                   |  2 +-
 docs/src/models/advanced.md       |  2 +-
 docs/src/models/nnlib.md          |  7 ++++++
 docs/src/models/overview.md       | 38 +++++++++++++++----------------
 docs/src/models/regularisation.md |  4 ++--
 src/layers/basic.jl               |  3 ++-
 src/utils.jl                      |  2 +-
 test/cuda/layers.jl               |  2 +-
 test/utils.jl                     | 21 +++++++++--------
 9 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/docs/src/gpu.md b/docs/src/gpu.md
index ceee92d6c3..fc97eed377 100644
--- a/docs/src/gpu.md
+++ b/docs/src/gpu.md
@@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
 ```julia
 d = Dense(10, 5, σ)
 d = fmap(cu, d)
-d.W # CuArray
+d.weight # CuArray
 d(cu(rand(10))) # CuArray output
 
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md
index 98659a7e06..77d8940a89 100644
--- a/docs/src/models/advanced.md
+++ b/docs/src/models/advanced.md
@@ -68,7 +68,7 @@ by simply deleting it from `ps`:
 
 ```julia
 ps = params(m)
-delete!(ps, m[2].b)
+delete!(ps, m[2].bias)
 ```
 
 ## Custom multiple input or output layer
diff --git a/docs/src/models/nnlib.md b/docs/src/models/nnlib.md
index d60cc9ea52..a60585a4df 100644
--- a/docs/src/models/nnlib.md
+++ b/docs/src/models/nnlib.md
@@ -67,3 +67,10 @@ NNlib.batched_mul
 NNlib.batched_mul!
 NNlib.batched_adjoint
 NNlib.batched_transpose
 ```
+
+## Gather and Scatter
+
+```@docs
+NNlib.gather
+NNlib.scatter
+```
diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md
index 697c3c65e3..ccc62953e9 100644
--- a/docs/src/models/overview.md
+++ b/docs/src/models/overview.md
@@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s
 
 This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:
 
-```
+```julia
 julia> using Flux
 
 julia> actual(x) = 4x + 2
@@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.
 
 Use the `actual` function to build sets of data for training and verification:
 
-```
+```julia
 julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
 ([0 1 … 4 5], [6 7 … 9 10])
 
@@ -42,22 +42,22 @@ Normally, your training and test data come from real world observations, but thi
 Now, build a model to make predictions with `1` input and `1` output:
 
-```
+```julia
 julia> model = Dense(1, 1)
 Dense(1, 1)
 
-julia> model.W
-1-element Array{Float64,1}:
- -0.99009055
+julia> model.weight
+1×1 Matrix{Float32}:
+ -1.4925033
 
-julia> model.b
-1-element Array{Float64,1}:
+julia> model.bias
+1-element Vector{Float32}:
 0.0
 ```
 
-Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
+Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weight matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
 
-```
+```julia
 julia> predict = Dense(1, 1)
 ```
 
 This model will already make predictions, though not accurate ones yet:
 
-```
+```julia
 julia> predict(x_train)
-1×6 Array{Float32,2}:
- -1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
+1×6 Matrix{Float32}:
+ 0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
 ```
 
 In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.
 
-```
+```julia
 julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
 loss (generic function with 1 method)
 
@@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f
 Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):
 
-```
+```julia
 julia> using Flux: train!
 
 julia> opt = Descent()
 Descent(0.1)
 
 julia> data = [(x_train, y_train)]
@@ -100,12 +100,12 @@
 Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:
 
-```
-julia> predict.W
+```julia
+julia> predict.weight
 1-element Array{Float64,1}:
 -0.99009055
 
-julia> predict.b
+julia> predict.bias
 1-element Array{Float64,1}:
 0.0
 ```
 
@@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
 These are the parameters Flux will change, one step at a time, to improve predictions.
 Each of the parameters comes from the `predict` model:
 ```
-julia> predict.W in parameters, predict.b in parameters
+julia> predict.weight in parameters, predict.bias in parameters
 (true, true)
 ```
 
diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md
index 2a533b4182..b1089d5cef 100644
--- a/docs/src/models/regularisation.md
+++ b/docs/src/models/regularisation.md
@@ -13,10 +13,10 @@ m = Dense(10, 5)
 loss(x, y) = logitcrossentropy(m(x), y)
 ```
 
-We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
+We can apply L2 regularisation by taking the squared norm of the parameters, `m.weight` and `m.bias`.
 
 ```julia
-penalty() = sum(abs2, m.W) + sum(abs2, m.b)
+penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
 loss(x, y) = logitcrossentropy(m(x), y) + penalty()
 ```
 
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index b8defdb0b8..754f7fb2f5 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -475,7 +475,8 @@ end
 
 (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
-(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::Integer) = m([x])
+(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
 
 function Base.show(io::IO, m::Embedding)
diff --git a/src/utils.jl b/src/utils.jl
index 7fbe2f2ec6..39f6079cde 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -15,7 +15,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
 ```jldoctest
 julia> layer = Dense(10, 20);
 
-julia> Flux.nfan(size(layer.W))
+julia> Flux.nfan(size(layer.weight))
 (10, 20)
 
 julia> layer = Conv((3, 3), 2=>10);
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index 8d7ba90e80..a8bfdaa0b1 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -140,7 +140,7 @@
 
     @test sum(l(ip)) ≈ 0.f0
     gs = gradient(() -> sum(l(ip)), Flux.params(l))
-    @test l.b ∉ gs.params
+    @test l.bias ∉ gs.params
   end
 
   @testset "Extended BatchNorm" begin
diff --git a/test/utils.jl b/test/utils.jl
index 8a3ad8741d..fa05b343e6 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -226,19 +226,19 @@ end
   m = Chain(Dense(10, 5, relu), Dense(5, 2))
   x64 = rand(Float64, 10)
   x32 = rand(Float32, 10)
-  @test eltype(m[1].W) == Float32
+  @test eltype(m[1].weight) == Float32
   @test eltype(m(x32)) == Float32
   @test eltype(m(x64)) == Float64
   @test eltype(f64(m)(x32)) == Float64
   @test eltype(f64(m)(x64)) == Float64
-  @test eltype(f64(m)[1].W) == Float64
-  @test eltype(f32(f64(m))[1].W) == Float32
+  @test eltype(f64(m)[1].weight) == Float64
+  @test eltype(f32(f64(m))[1].weight) == Float32
 end
 
 @testset "Zeros" begin
   m = Dense(3,2; bias=false)
-  @test f64(m).b === m.b === Zeros()
-  @test f32(m).b === m.b === Zeros()
+  @test f64(m).bias === m.bias === Zeros()
+  @test f32(m).bias === m.bias === Zeros()
 
   @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
     o = ones(s)
@@ -340,19 +340,20 @@ end
 nobias(n) = Zeros()
 testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
-  @test l1.W == l2.W
-  @test l1.b == l2.b
-  @test_skip typeof(l1.b) === typeof(l2.b)
+  @test l1.weight == l2.weight
+  @test l1.bias == l2.bias
+  @test_skip typeof(l1.bias) === typeof(l2.bias)
 end
 
 @testset "loadparams!" begin
-  import Flux: loadparams!
   pars(w, b) = [w, b]
+  import Flux: loadparams!, Zeros
+  pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))]
-  pars(l) = pars(l.W, l.b)
+  pars(l) = pars(l.weight, l.bias)
   pararray(m) = mapreduce(pars, vcat, m)
-  weights(m) = mapreduce(l -> [l.W], vcat, m)
+  weights(m) = mapreduce(l -> [l.weight], vcat, m)
   @testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
     m = dm(bt)
     loadparams!(m, params(m))

From 64e0f365b9b43833ab5c914cc0ed18c9020195d9 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 11 Jul 2021 13:03:03 +0200
Subject: [PATCH 3/6] more embedding tests; keep Embedding unexported

---
 Project.toml        |  6 ++++--
 src/Flux.jl         |  2 +-
 test/cuda/layers.jl | 25 +++++++++++++++++++++++--
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index d4bcbca67a..7760a0c9d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,12 +1,14 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.12.4"
+version = "0.12.5"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -37,7 +39,7 @@ Colors = "0.12"
 Functors = "0.2.1"
 Juno = "0.8"
 MacroTools = "0.5"
-NNlib = "0.7.14"
+NNlib = "0.7.24"
 NNlibCUDA = "0.1"
 Reexport = "0.2, 1.0"
 StatsBase = "0.33"
diff --git a/src/Flux.jl b/src/Flux.jl
index 400a4be7e1..f4e808bd05 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
 export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-  RNN, LSTM, GRU, Embedding,
+  RNN, LSTM, GRU,
   SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
   AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
   Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index a8bfdaa0b1..e384039799 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -261,16 +261,37 @@ end
 end
 
 @testset "Embedding" begin
-  vocab_size, embed_size = 10, 4
+  vocab_size, embed_size = 5, 2
   m = Embedding(vocab_size, embed_size)
-  x = rand(1:vocab_size, 3)
+
+  x = [1, 3, 5]
   y = m(x)
   m_g = m |> gpu
   x_g = x |> gpu
   y_g = m_g(x_g)
   @test collect(y_g) == y
+
+  gs = gradient(() -> sum(m(x)), params(m))
+  gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
+  @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+
   gs = gradient(() -> sum(tanh.(m(x))), params(m))
   gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
   @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+
+  @testset "repeated indexes" begin
+    vocab_size, embed_size = 5, 2
+    m = Embedding(vocab_size, embed_size)
+
+    x = [1, 3, 5, 3] # repeated indexes
+    y = m(x)
+    m_g = m |> gpu
+    x_g = x |> gpu
+    y_g = m_g(x_g)
+    @test collect(y_g) == y
+    gs = gradient(() -> sum(m(x)), params(m))
+    gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
+    @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
+  end
 end

From 82c3f293df916a62673d874be189e546bf81eb11 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 11 Jul 2021 13:14:36 +0200
Subject: [PATCH 4/6] update news

---
 NEWS.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/NEWS.md b/NEWS.md
index c61fa02923..64ad240178 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,14 @@
 # Flux Release Notes
 
+## v0.12.4
+* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
+  based on recently added `NNlib.gather` and `NNlib.scatter`.
+
+## v0.12.1 - v0.12.3
+
+* CUDA.jl 3.0 support
+* Bug fixes and optimizations.
+
 ## v0.12.0
 
 * Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).

From 058a4a07962d5d82caa7f5a3f24f8afe7faec58c Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 11 Jul 2021 14:11:17 +0200
Subject: [PATCH 5/6] fix import

---
 src/layers/basic.jl  | 2 ++
 test/cuda/layers.jl  | 4 ++--
 test/layers/basic.jl | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 754f7fb2f5..f2b145f62b 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -443,6 +443,8 @@ or the corresponding [onehot encoding](@ref Flux.OneHotArray).
 # Examples
 
 ```julia-repl
+julia> using Flux: Embedding
+
 julia> vocab_size, embed_size = 1000, 4;
 
 julia> model = Embedding(vocab_size, embed_size)
 Embedding(1000, 4)
diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl
index e384039799..afa9c3dacf 100644
--- a/test/cuda/layers.jl
+++ b/test/cuda/layers.jl
@@ -262,7 +262,7 @@ end
 
 @testset "Embedding" begin
   vocab_size, embed_size = 5, 2
-  m = Embedding(vocab_size, embed_size)
+  m = Flux.Embedding(vocab_size, embed_size)
 
   x = [1, 3, 5]
   y = m(x)
@@ -281,7 +281,7 @@ end
 
   @testset "repeated indexes" begin
     vocab_size, embed_size = 5, 2
-    m = Embedding(vocab_size, embed_size)
+    m = Flux.Embedding(vocab_size, embed_size)
 
     x = [1, 3, 5, 3] # repeated indexes
     y = m(x)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 5a22838bd1..f75af7c3cb 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -194,7 +194,7 @@ import Flux: activations
 
   @testset "Embedding" begin
     vocab_size, embed_size = 10, 4
-    m = Embedding(vocab_size, embed_size)
+    m = Flux.Embedding(vocab_size, embed_size)
     @test size(m.weight) == (embed_size, vocab_size)
 
     x = rand(1:vocab_size, 3)

From dfb390da30b74af8f2219ad6457c1626ade8621e Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 12 Jul 2021 19:54:14 +0200
Subject: [PATCH 6/6] use randn32

---
 src/layers/basic.jl | 7 ++-----
 src/utils.jl        | 2 ++
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index f2b145f62b..ba3d0e23b2 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -471,7 +471,8 @@ end
 
 @functor Embedding
 
-function Embedding(in::Integer, out::Integer;
-                   init = (i...) -> randn(Float32, i...))
-  return Embedding(init(out, in))
-end
-
+Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
+
 (m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
 (m::Embedding)(x::Integer) = m([x])
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
diff --git a/src/utils.jl b/src/utils.jl
index 39f6079cde..a7c90807cf 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit
 
 ones32(dims...) = Base.ones(Float32, dims...)
 zeros32(dims...) = Base.zeros(Float32, dims...)
+rand32(dims...) = Base.rand(Float32, dims...)
+randn32(dims...) = Base.randn(Float32, dims...)
 
 """
     create_bias(weights, bias, length)
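
For anyone trying the series out locally, here is a minimal usage sketch (not part of the patches themselves). It assumes a Flux 0.12.5 checkout with patches 1-6 applied; the model shape, sizes, and variable names below are illustrative only, and `Embedding` is deliberately left unexported (PATCH 3/6), so it is accessed as `Flux.Embedding`.

```julia
# Minimal sketch of using the Embedding layer introduced by this series
# (assumes Flux 0.12.5 with patches 1-6 applied; all sizes are made up).
using Flux

vocab_size, embed_size, nclasses = 10_000, 64, 5

model = Chain(
  Flux.Embedding(vocab_size, embed_size),  # token indices -> embed_size × seq_len matrix
  x -> sum(x, dims = 2) / size(x, 2),      # average the token embeddings
  Flux.flatten,
  Dense(embed_size, nclasses),
)

tokens = rand(1:vocab_size, 30)            # one "sentence" of 30 token indices
y = model(tokens)                          # nclasses × 1 output for this sentence

# Onehot input goes through m.weight * x and matches integer indexing:
emb = Flux.Embedding(vocab_size, embed_size)
emb(Flux.onehotbatch(tokens, 1:vocab_size)) == emb(tokens)  # true
```

The integer/vector path is what PATCH 2/6 routes through `NNlib.gather`, while onehot input reuses the matrix multiply, which is why the two calls in the last line agree.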