doc entry for MultiHeadAttention #2218

Merged · 5 commits · Mar 23, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -8,3 +8,4 @@ deps
.vscode
Manifest.toml
LocalPreferences.toml
.DS_Store
8 changes: 5 additions & 3 deletions NEWS.md
@@ -1,14 +1,16 @@
# Flux Release Notes

## v0.13.15
* Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer.

## v0.13.14
* Fixed various deprecation warnings, from `Zygote.@nograd` and `Vararg`.
* Initial support for `AMDGPU` via extension mechanism.
* Add `gpu_backend` preference to select GPU backend using `LocalPreference.toml`.
* Add `Flux.gpu_backend!` method to switch between GPU backends.

## v0.13.13
* Added `f16` which changes precision to `Float16`, recursively.
* Initial support for AMDGPU via extension mechanism.
* Add `gpu_backend` preference to select GPU backend using `LocalPreference.toml`.
* Add `Flux.gpu_backend!` method to switch between GPU backends.
Comment on lines -9 to -11 (Member Author): these were misplaced

## v0.13.12
* CUDA.jl 4.0 compatibility.
11 changes: 10 additions & 1 deletion docs/src/models/layers.md
@@ -10,7 +10,7 @@ The `Dense` exemplifies several features:

* It takes an `init` keyword, which accepts a function acting like `rand`. That is, `init(2,3,4)` should create an array of this size. Flux has [many such functions](@ref man-init-funcs) built-in. All make a CPU array, moved later with [`gpu`](@ref Flux.gpu) if desired.

- * The bias vector is always intialised [`Flux.zeros32`](@ref). The keyword `bias=false` will turn this off, i.e. keeping the bias permanently zero.
+ * The bias vector is always initialised [`Flux.zeros32`](@ref). The keyword `bias=false` will turn this off, i.e. keeping the bias permanently zero.

* It is annotated with [`@functor`](@ref Functors.@functor), which means that [`params`](@ref Flux.params) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU.

@@ -54,6 +54,15 @@ SamePad
Flux.flatten
```

## MultiHeadAttention

The basic blocks needed to implement [Transformer](https://arxiv.org/abs/1706.03762) architectures. See also the functional counterparts
documented in NNlib's [Attention](@ref) section.

```@docs
MultiHeadAttention
```
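For orientation, a minimal usage sketch of the layer documented above; the array sizes are made up, and the `nheads` keyword and the `(output, attention_scores)` return follow the docstring added in this PR:

```julia
using Flux

# Illustrative sizes: embedding dimension 64, 8 heads, sequence length 10, batch of 32.
mha = MultiHeadAttention(64; nheads = 8)
q = rand(Float32, 64, 10, 32)   # (embed_dim, seq_len, batch)
y, α = mha(q)                   # self-attention: key and value default to the query
size(y)                         # (64, 10, 32)
size(α)                         # (10, 10, 8, 32), one attention map per head
```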

### Pooling

These layers are commonly used after a convolution layer, and reduce the size of its output. They have no trainable parameters.
24 changes: 18 additions & 6 deletions docs/src/models/nnlib.md
@@ -2,9 +2,20 @@

Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package. This includes activation functions, described on [their own page](@ref man-activations). Many of the functions on this page exist primarily as the internal implementation of Flux layers, but can also be used independently.


## Attention

Primitives for the [`MultiHeadAttention`](@ref) layer.

```@docs
NNlib.dot_product_attention
NNlib.dot_product_attention_scores
NNlib.make_causal_mask
```
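A rough sketch of how these primitives combine, using toy shapes and the default of a single head:

```julia
using NNlib

q = rand(Float32, 8, 6, 2)   # (feature_dim, seq_len, batch)
k = rand(Float32, 8, 6, 2)
v = rand(Float32, 8, 6, 2)

# Boolean mask so each position attends only to itself and earlier positions.
mask = NNlib.make_causal_mask(q)
y, α = NNlib.dot_product_attention(q, k, v; mask = mask)
size(y)   # (8, 6, 2)
size(α)   # (6, 6, 1, 2)
```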

## Softmax

- `Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.
+ `Flux`'s [`Flux.logitcrossentropy`](@ref) uses [`NNlib.logsoftmax`](@ref) internally.
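For reference, a quick illustration of that relationship, with arbitrary sizes:

```julia
using NNlib

x = randn(Float32, 5, 3)           # 5 classes, 3 samples
softmax(x; dims = 1)               # each column sums to 1
logsoftmax(x) ≈ log.(softmax(x))   # true, but computed in a numerically safer way
```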

```@docs
softmax
@@ -13,7 +24,8 @@ logsoftmax

## Pooling

- `Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, and `MeanPool` use `NNlib.PoolDims`, `NNlib.maxpool`, and `NNlib.meanpool` as their backend.
+ `Flux`'s [`AdaptiveMaxPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref), [`GlobalMeanPool`](@ref),
+ [`MaxPool`](@ref), and [`MeanPool`](@ref) use [`NNlib.PoolDims`](@ref), [`NNlib.maxpool`](@ref), and [`NNlib.meanpool`](@ref) as their backend.
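For example, the layer and the underlying kernel agree on a small input; the sizes below are chosen only for illustration:

```julia
using Flux, NNlib

x = rand(Float32, 8, 8, 3, 1)                    # a WHCN image batch
MaxPool((2, 2))(x) == NNlib.maxpool(x, (2, 2))   # the layer defers to the NNlib kernel
size(NNlib.maxpool(x, (2, 2)))                   # (4, 4, 3, 1)
```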

```@docs
PoolDims
@@ -32,7 +44,7 @@ pad_zeros

## Convolution

- `Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
+ `Flux`'s [`Conv`](@ref) and [`CrossCor`](@ref) layers use [`NNlib.DenseConvDims`](@ref) and [`NNlib.conv`](@ref) internally.
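A sketch of the low-level API those layers build on; the kernel and input sizes are arbitrary:

```julia
using NNlib

x = rand(Float32, 28, 28, 1, 16)           # WHCN input batch
w = rand(Float32, 3, 3, 1, 4)              # 3×3 kernel, 1 input channel, 4 output channels
cdims = DenseConvDims(x, w; padding = 1)   # geometry: shapes, stride, padding, dilation
y = conv(x, w, cdims)                      # size (28, 28, 4, 16) with this padding
```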

```@docs
conv
@@ -44,7 +56,7 @@ DenseConvDims

## Upsampling

- `Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend.
+ `Flux`'s [`Upsample`](@ref) layer uses [`NNlib.upsample_nearest`](@ref), [`NNlib.upsample_bilinear`](@ref), and [`NNlib.upsample_trilinear`](@ref) as its backend. Additionally, `Flux`'s [`PixelShuffle`](@ref) layer uses [`NNlib.pixel_shuffle`](@ref) as its backend.
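For instance, with an input size chosen only for illustration:

```julia
using NNlib

x = rand(Float32, 4, 4, 3, 1)         # small WHCN input
size(upsample_nearest(x, (2, 2)))     # (8, 8, 3, 1), each pixel repeated
size(upsample_bilinear(x, (2, 2)))    # (8, 8, 3, 1), interpolated instead
```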

```@docs
upsample_nearest
@@ -60,7 +72,7 @@ pixel_shuffle

## Batched Operations

- `Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally.
+ `Flux`'s [`Flux.Bilinear`](@ref) layer uses [`NNlib.batched_mul`](@ref) internally.
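A small sketch of what `batched_mul` computes:

```julia
using NNlib

A = rand(Float32, 2, 3, 10)            # a batch of ten 2×3 matrices
B = rand(Float32, 3, 4, 10)            # a matching batch of 3×4 matrices
C = batched_mul(A, B)                  # size (2, 4, 10): one matrix product per batch slice
C[:, :, 1] ≈ A[:, :, 1] * B[:, :, 1]   # true
```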

```@docs
batched_mul
@@ -72,7 +84,7 @@ batched_vec

## Gather and Scatter

- `Flux`'s `Embedding` layer uses `NNlib.gather` as its backend.
+ `Flux`'s [`Embedding`](@ref) layer uses [`NNlib.gather`](@ref) as its backend.
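For example, an embedding lookup is a gather of columns; the sizes here are illustrative:

```julia
using NNlib

emb = rand(Float32, 8, 100)   # an 8-dimensional embedding for a vocabulary of 100 tokens
idx = [5, 17, 17, 42]         # token indices to look up
NNlib.gather(emb, idx)        # size (8, 4): the columns of `emb` selected by `idx`
```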

```@docs
NNlib.gather
2 changes: 1 addition & 1 deletion docs/src/tutorials/2021-10-08-dcgan-mnist.md
@@ -101,7 +101,7 @@ We will be using the [relu](https://fluxml.ai/Flux.jl/stable/models/nnlib/#NNlib
We will also apply the weight initialization method mentioned in the original DCGAN paper.

```julia
- # Function for intializing the model weights with values
+ # Function for initializing the model weights with values
# sampled from a Gaussian distribution with μ=0 and σ=0.02
dcgan_init(shape...) = randn(Float32, shape) * 0.02f0
```
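For context, such an initializer is passed to a layer through its `init` keyword. The layer below is a made-up example, not part of the tutorial's model:

```julia
using Flux

# Hypothetical layer using the custom initializer defined above.
Conv((4, 4), 1 => 64; stride = 2, pad = 1, init = dcgan_init)
```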