-
-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move
ctc_loss
from Flux to NNlib (#426)
* move ctc loss from Flux * fixup
- Loading branch information
Showing 5 changed files with 187 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# CTC loss moved from Flux.jl to NNlib + NNlibCUDA | ||
|
||
## CPU implementation | ||
|
||
"""
    logaddexp(a, b)

Return `log(exp(a) + exp(b))`, computed without leaving log space so the
result stays numerically stable for very negative inputs.
"""
function logaddexp(a, b)
    # -Inf is the log-space zero: adding it leaves the other operand unchanged.
    isinf(a) && return b
    isinf(b) && return a

    # Keep the larger value outside the exponential so `exp` sees a
    # non-positive argument and cannot overflow to Inf. E.g. with
    # a = -900, b = -800, exp(-800 - -900) = exp(100) overflows Float32.
    lo, hi = a < b ? (a, b) : (b, a)
    return hi + log(1 + exp(lo - hi))
end
|
||
"""
    add_blanks(z, blank)

Return a new label sequence with `blank` inserted before, after, and
between every element of `z`; the result has length `2length(z) + 1`.
"""
function add_blanks(z, blank)
    augmented = fill(blank, 2length(z) + 1)
    # Labels land on the even positions; blanks remain on the odd ones.
    for k in eachindex(z)
        augmented[2k] = z[k]
    end
    return augmented
end
|
||
"""
    ctc_alpha(ŷ, y)

Forward (alpha) pass of the CTC loss. `ŷ` is a classes-by-time matrix of
raw activations (`logsoftmax` is applied here), `y` the label sequence.
The blank label is taken to be the last class, `size(ŷ, 1)`.

Returns a named tuple `(loss, alpha, zprime, logsoftyhat)`: the negative
log-likelihood, the full alpha matrix, the blank-augmented label sequence,
and the log-softmaxed activations — the latter three are reused by
`∇ctc_loss`.
"""
function ctc_alpha(ŷ::AbstractArray, y)
    typed_zero = zero(ŷ[1])
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)          # blank label index: last class by convention
    z′ = add_blanks(y, blank)   # blank-augmented label sequence
    T = size(ŷ, 2)              # number of time steps
    U′ = length(z′)             # augmented sequence length, 2|y| + 1

    # α[u, t] accumulates log-probabilities; log(0) == -Inf is log-space zero.
    α = fill(log(typed_zero), U′, T)
    # At t = 1 only the leading blank or the first real label can start a path.
    α[1,1] = ŷ[blank, 1]
    α[2,1] = ŷ[z′[2], 1]
    for t=2:T
        # Positions below `bound` cannot reach the end of z′ in the
        # remaining time steps, so they are skipped (stay log-zero).
        bound = max(1, U′ - 2(T - t) - 1)
        for u=bound:U′
            if u == 1
                # First position: can only be reached by staying put.
                α[u,t] = α[u, t-1]
            else
                # Stay at u, or advance from u-1.
                α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

                # array bounds check and f(u) function from Eq. 7.9
                if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
                    α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
                end
            end
            # Multiply (add in log space) by the emission probability at t.
            α[u,t] += ŷ[z′[u], t]
        end
    end
    # Valid paths end on the final blank or the final real label.
    return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end
|
||
"""
    ∇ctc_loss(ŷ, y, out)

Gradient of the CTC loss with respect to the raw activations `ŷ`.
`out` is the named tuple returned by `ctc_alpha(ŷ, y)`, whose alpha
matrix, augmented labels, and log-softmaxed activations are reused here
to run the backward (beta) pass and combine both into the gradient.
"""
function ∇ctc_loss(ŷ::AbstractArray, y, out)
    # Unpack the forward-pass results; ŷ is rebound to logsoftmax(ŷ).
    loss, α, z′, ŷ = out
    U′, T = size(α)
    blank = size(ŷ, 1)   # blank label index: last class
    typed_zero = zero(first(α))

    # Calculate beta coefficients, from the bottom-right, to the upper-left
    β = fill(log(typed_zero), U′, T)   # log(0) == -Inf is log-space zero

    # Fill bottom-right corner so bounding errors can be avoided
    # by starting `u` at `U′-1`
    β[U′, T] = typed_zero
    β[U′-1, T] = typed_zero

    # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
    for t=(T-1):-1:1
        # Positions above `bound` cannot have been reached from the start
        # of z′ by time t, so they are skipped (stay log-zero).
        bound = min(U′, 2t)
        for u=bound:-1:1
            if u == U′
                # Last position: only reachable by staying put at u.
                β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
            else
                # Stay at u, or advance to u+1; each weighted by the
                # emission probability at t+1 (added in log space).
                β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

                # array bounds check and g(u) function from Eq. 7.16
                if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
                    β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
                end
            end
        end
    end

    # Accumulate alpha-beta products for each category,
    # then calculate gradients
    accum = fill(log(typed_zero), size(ŷ))
    for t=1:T
        for u=1:U′
            # Several u positions may share the same class z′[u];
            # their contributions are summed in log space.
            accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
        end
    end
    # Gradient w.r.t. raw activations: softmax(ŷ) - posterior (Eq. 7.3x);
    # `+ loss` divides by the total path probability in log space.
    grads = exp.(ŷ) .- exp.(accum .+ loss)
    return grads
end
|
||
"""
    ctc_loss(ŷ, y)

Compute the connectionist temporal classification loss between `ŷ`
and `y`.

`ŷ` must be a classes-by-time matrix: each row represents a class and
each column represents a time step. Additionally, the `logsoftmax`
function will be applied to `ŷ`, so `ŷ` must be the raw activation
values from the neural network and not, for example, the activations
after being passed through a `softmax` activation function. `y` must
be a 1D array of the labels associated with `ŷ`. The blank label is
assumed to be the last label category in `ŷ`, so it is equivalent to
`size(ŷ, 1)`.

Used for sequence-to-sequence classification problems such as
speech recognition and handwriting recognition where the exact
time-alignment of the output (e.g., letters) is not needed to
solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
function ctc_loss(ŷ::AbstractArray, y)
    # The forward pass computes everything; only the scalar loss is exposed.
    return ctc_alpha(ŷ, y).loss
end
|
||
# Custom reverse rule: the forward pass already produces everything the
# gradient needs, so it is cached and reused in the pullback.
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
    forward = ctc_alpha(ŷ, y)
    function ctc_loss_pullback(Δ)
        ȳ = Δ .* ∇ctc_loss(ŷ, y, forward)
        # No tangent for the function itself or the integer labels.
        return (NoTangent(), ȳ, NoTangent())
    end
    return forward.loss, ctc_loss_pullback
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
using Test | ||
using NNlib: ctc_loss | ||
using Zygote: gradient | ||
using LinearAlgebra | ||
|
||
# Custom function to check numerical gradient of ctc loss,
# based on `ngradient` in `Tracker.jl`
"""
    ctc_ngradient(x, y, f = ctc_loss)

Estimate the gradient of `f(x, y)` with respect to `x` by central finite
differences, perturbing one element of `x` at a time. `x` is restored to
its original values before returning. Defaults to differentiating
`ctc_loss`; pass another `f(x, y)` to check a different scalar function.
"""
function ctc_ngradient(x, y, f = ctc_loss)
    grads = zero(x)
    δ = sqrt(eps())   # step size is loop-invariant; computed once
    for i in eachindex(x)
        tmp = x[i]
        x[i] = tmp - δ / 2
        y1 = f(x, y)
        x[i] = tmp + δ / 2
        y2 = f(x, y)
        x[i] = tmp    # restore the original entry
        grads[i] = (y2 - y1) / δ
    end
    return grads
end
|
||
@testset "ctc_loss" begin
    # The analytic (Zygote) gradient should agree with a finite-difference
    # estimate on a random classes-by-time input.
    x = rand(10, 50)
    y = rand(1:9, 30)
    analytic = gradient(ctc_loss, x, y)[1]
    numeric = ctc_ngradient(x, y)
    @test analytic ≈ numeric rtol=1e-5 atol=1e-5

    # Hand-calculated reference values for a small 3-class problem.
    x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 3.6990738275138035

    g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5

    # A second hand-calculated case with larger-magnitude activations,
    # exercising the numerically-stable log-space arithmetic.
    x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 8.02519869363453

    g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
023cd3d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
023cd3d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/64797
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: