Move ctc_loss from Flux to NNlib (#426)
* move ctc loss from Flux

* fixup
mcabbott authored Jul 22, 2022
1 parent c9faa64 commit 023cd3d
Showing 5 changed files with 187 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.8"
version = "0.8.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -61,6 +61,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("ctc.jl")
export ctc_loss

include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!
132 changes: 132 additions & 0 deletions src/ctc.jl
@@ -0,0 +1,132 @@
# CTC loss moved from Flux.jl to NNlib + NNlibCUDA

## CPU implementation

"""
logaddexp(a, b)
Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`
"""
function logaddexp(a, b)
isinf(a) && return b
isinf(b) && return a

# always want the greater number on the left in the exponentiation;
# the magnitude difference may end up making the number very positive
# which will cause exp() to return Inf
# E.g., a = -900, b = -800, will give exp(-800 - -900), which will be
# Inf for Float32 values
if a < b
a, b = b, a
end
return a + log(1+exp(b-a))
end
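# For example (illustrative values, not part of the committed file): with Float32
# inputs the naive formula underflows to -Inf while logaddexp stays finite:
#   logaddexp(-900.0f0, -800.0f0)        # ≈ -800.0f0
#   log(exp(-900.0f0) + exp(-800.0f0))   # -Inf32, since both exp() calls underflow to 0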

"""
add_blanks(z)
Adds blanks to the start and end of `z`, and between items in `z`
"""
function add_blanks(z, blank)
z′ = fill(blank, 2*length(z) + 1)
z′[2 .* eachindex(z)] = z
return z′
end
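# For example, with labels drawn from 1:2 and blank index 3:
#   add_blanks([1, 2], 3) == [3, 1, 3, 2, 3]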

function ctc_alpha(ŷ::AbstractArray, y)
    typed_zero = zero(ŷ[1])
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    z′ = add_blanks(y, blank)
    T = size(ŷ, 2)
    U′ = length(z′)

    α = fill(log(typed_zero), U′, T)
    α[1,1] = ŷ[blank, 1]
    α[2,1] = ŷ[z′[2], 1]
    for t=2:T
        bound = max(1, U′ - 2(T - t) - 1)
        for u=bound:U′
            if u == 1
                α[u,t] = α[u, t-1]
            else
                α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

                # array bounds check and f(u) function from Eq. 7.9
                if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
                    α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
                end
            end
            α[u,t] += ŷ[z′[u], t]
        end
    end
    return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end
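# Here α[u, t] is the log-probability of all partial alignments of the first `t`
# frames that end at position `u` of the blank-augmented sequence z′; the loss is
# -log of the mass in the two valid terminal states (last label or trailing blank),
# i.e. -logaddexp(α[end, T], α[end-1, T]).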

function ∇ctc_loss(ŷ::AbstractArray, y, out)
    loss, α, z′, ŷ = out
    U′, T = size(α)
    blank = size(ŷ, 1)
    typed_zero = zero(first(α))

    # Calculate beta coefficients, from the bottom-right, to the upper-left
    β = fill(log(typed_zero), U′, T)

    # Fill bottom-right corner so bounding errors can be avoided
    # by starting `u` at `U′-1`
    β[U′, T] = typed_zero
    β[U′-1, T] = typed_zero

    # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
    for t=(T-1):-1:1
        bound = min(U′, 2t)
        for u=bound:-1:1
            if u == U′
                β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
            else
                β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

                # array bounds check and g(u) function from Eq. 7.16
                if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
                    β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
                end
            end
        end
    end

    # Accumulate alpha-beta products for each category,
    # then calculate gradients
    accum = fill(log(typed_zero), size(ŷ))
    for t=1:T
        for u=1:U′
            accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
        end
    end
    grads = exp.(ŷ) .- exp.(accum .+ loss)
    return grads
end
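# Since ŷ holds log-softmax values here, exp.(ŷ) is the softmax output and
# exp.(accum .+ loss) is the posterior probability of emitting each class at each
# time step (accum log-sums α + β per class, and loss = -log p(y|x)). The result is
# the derivative of the loss with respect to the raw activations, which is what
# test/ctc.jl checks numerically below.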

"""
ctc_loss(ŷ, y)
Computes the connectionist temporal classification loss between `ŷ`
and `y`.
`ŷ` must be a classes-by-time matrices, i.e., each row
represents a class and each column represents a time step.
Additionally, the `logsoftmax` function will be applied to `ŷ`, so
`ŷ` must be the raw activation values from the neural network and
not, for example, the activations after being passed through a
`softmax` activation function. `y` must be a 1D array of the labels
associated with `ŷ`. The blank label is assumed to be the last label
category in `ŷ`, so it is equivalent to `size(ŷ, 1)`.
Used for sequence-to-sequence classification problems such as
speech recognition and handwriting recognition where the exact
time-alignment of the output (e.g., letters) is not needed to
solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
    tmp = ctc_alpha(ŷ, y)
    ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
    return tmp.loss, ctc_loss_pullback
end
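A minimal usage sketch of the new export, following the docstring above; the sizes and labels are arbitrary examples, and Zygote is assumed for the gradient call (as in the tests below):

using NNlib, Zygote

ŷ = randn(Float32, 5, 20)   # 5 classes (class 5 acts as the blank) by 20 time steps, raw activations
y = [1, 3, 2, 4]            # label sequence drawn from the non-blank classes 1:4

loss = ctc_loss(ŷ, y)                # scalar loss
grad = gradient(ctc_loss, ŷ, y)[1]   # same size as ŷ, computed via the rrule above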
47 changes: 47 additions & 0 deletions test/ctc.jl
@@ -0,0 +1,47 @@
using Test
using NNlib: ctc_loss
using Zygote: gradient
using LinearAlgebra

# Custom function to check numerical gradient of ctc loss,
# based on `ngradient` in `Tracker.jl`
function ctc_ngradient(x, y)
    f = ctc_loss
    grads = zero(x)
    for i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2
        y1 = f(x, y)
        x[i] = tmp + δ/2
        y2 = f(x, y)
        x[i] = tmp
        grads[i] = (y2-y1)/δ
    end
    return grads
end
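# ctc_ngradient perturbs one entry of x at a time by ±δ/2, with δ = sqrt(eps()) ≈ 1.5e-8
# for Float64, and takes the central difference (y2 - y1)/δ as the reference gradient.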

@testset "ctc_loss" begin
    x = rand(10, 50)
    y = rand(1:9, 30)
    g1 = gradient(ctc_loss, x, y)[1]
    g2 = ctc_ngradient(x, y)
    @test g1 ≈ g2 rtol=1e-5 atol=1e-5

    # tests using hand-calculated values
    x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 3.6990738275138035

    g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5

    x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 8.02519869363453

    g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -48,6 +48,10 @@ include("test_utils.jl")
include("conv_bias_act.jl")
end

@testset "CTC Loss" begin
include("ctc.jl")
end

@testset "Inference" begin
include("inference.jl")
end

2 comments on commit 023cd3d

@mcabbott (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/64797

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.8.9 -m "<description of version>" 023cd3da63892b754b4197fe7a848093128f2bf9
git push origin v0.8.9
