Add tests for complex valued training
staticfloat committed Nov 29, 2021
1 parent 9326702 commit 8c3d852
Showing 2 changed files with 40 additions and 0 deletions.
3 changes: 3 additions & 0 deletions test/losses.jl
@@ -39,6 +39,9 @@ y = [1, 1, 0, 0]

@testset "mse" begin
@test mse(ŷ, y) (.1^2 + .9^2)/2

# Test that mse() loss works on complex values:
@test mse(0 + 0im, 1 + 1im) == 2
end

@testset "mae" begin
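As a quick check of the new assertion, a minimal sketch of the arithmetic, assuming `mse(ŷ, y)` reduces to `mean(abs2.(ŷ .- y))` (the `abs2` is what makes it complex-safe):

using Statistics: mean

ŷ, y = 0 + 0im, 1 + 1im
err  = ŷ .- y              # -1 - 1im
mean(abs2.(err))           # (-1)^2 + (-1)^2 = 2, matching the test above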
37 changes: 37 additions & 0 deletions test/optimise.jl
@@ -190,3 +190,40 @@ end
  Flux.update!(opt, θ, gs)
  @test w ≈ wold .- 0.1
end

# Flux PR #1776
# We need to test that optimisers like ADAM that maintain an internal momentum
# estimate properly calculate the second-order statistics on the gradients as
# they flow backward through the model. Previously, we would calculate second-
# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
# wreaks all sorts of havoc on our training loops. This test ensures that
# a simple optimization is monotonically decreasing (up to learning-step effects).
@testset "Momentum Optimisers and complex values" begin
# Test every optimizer that has momentum internally
for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
# Our "model" is just a complex number
w = zeros(ComplexF32, 1)

# Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
function loss()
# Deterministic training data is the best training data
x = ones(1, 1) + 1im*ones(1, 1)

# Manually implement `mse()` to allow demonstration of brokenness
# on older Flux builds that don't have a fixed `mse()`
return sum(abs2.(w * x .- conj(x)))
end

params = Flux.Params([w])
opt = opt_ctor(1e-2)

# Train for 10 iterations, enforcing that loss is monotonically decreasing
last_loss = Inf
for idx in 1:10
grads = Flux.gradient(loss, params)
@test loss() < last_loss
last_loss = loss()
Flux.update!(opt, params, grads)
end
end
end
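To make the comment above concrete, here is a hedged sketch of the difference between the two second-moment formulas; the variables `β2`, `ϵ`, `η`, `v` and the single update step are illustrative assumptions, not Flux's actual optimiser internals:

Δ = 1.0f0 + 1.0f0im          # a complex gradient

Δ^2                          # 0 + 2im: complex, unusable as a variance estimate
Δ * conj(Δ)                  # 2 + 0im: real and non-negative, numerically equal to abs2(Δ)

# One ADAM-like step using the complex-aware second moment (illustrative only):
β2, ϵ, η = 0.999, 1e-8, 1e-2
v    = (1 - β2) * abs2(Δ)    # real-valued second-moment accumulator
step = η * Δ / (sqrt(v) + ϵ) # denominator stays real and positive

Accumulating the real `abs2(Δ)` keeps `sqrt(v)` well defined, which is what the monotonically decreasing loss in the loop above checks indirectly.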
