Skip to content

Commit

Permalink
fix: use generic broadcasting for complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 25, 2024
1 parent 6f9f8d6 commit f412731
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.9"
version = "1.3.10"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
5 changes: 4 additions & 1 deletion lib/LuxLib/src/impl/bias_activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,17 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation),
return y, ∇bias_activation_rrule
end

y, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ +, x, reshape_bias(x, bias))
y, ∇broadcast = CRC.rrule_via_ad(
cfg, broadcast_bias_activation_generic, σ, x, reshape_bias(x, bias))
∇bias_activation_rrule = @closure Δ -> begin
_, _, ∂x, ∂bias = ∇broadcast(Δ)
return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(vec(∂bias))
end
return y, ∇bias_activation_rrule
end

@inline broadcast_bias_activation_generic::F, x, b) where {F} = σ.(x .+ b)

bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x
for bType in (Nothing, AbstractVector)
@eval function bias_activation!!::F, x::AbstractVector, bias::$(bType)) where {F}
Expand Down
7 changes: 6 additions & 1 deletion lib/LuxLib/src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ is_mutable_array(::Nothing) = True()

ChainRulesCore.@non_differentiable is_mutable_array(::Any...)

for op in (:has_dual, :has_float16, :is_tracked)
for op in (:has_dual, :has_float16, :is_tracked, :has_complex)
@eval $op(::Nothing) = False()
@eval $op(x::Numeric) = $op(eltype(x))
end
Expand All @@ -38,6 +38,9 @@ has_dual(::Type{<:ForwardDiff.Dual}) = True()
has_float16(_) = False()
has_float16(::Type{<:Float16}) = True()

has_complex(_) = False()
has_complex(::Type{<:Complex}) = True()

is_tracked(_) = False()

has_autodiff_value(x) = is_tracked(x) | has_dual(x)
Expand All @@ -51,6 +54,7 @@ function use_generic_broadcasting(xs::Tuple)
xs_unwrapped = unrolled_map(unwrap_array, xs)
return unrolled_any(has_autodiff_value, xs_unwrapped) |
unrolled_any(has_float16, xs_unwrapped) |
unrolled_any(has_complex, xs_unwrapped) |
unrolled_any(static_isa(StaticArray), xs_unwrapped)
end

Expand Down Expand Up @@ -198,6 +202,7 @@ Currently supported modes are:
+ ReverseDiff Arrays
+ Tracker Arrays
+ ForwardDiff.Dual Arrays
+ Complex Arrays
- `GPUBroadcastOp{dev}`: GPU Arrays where `dev` is obtained from `get_device_type(xs)`.
This option dispatches should preferably use `KernelAbstractions` or specialized vendor
Expand Down
27 changes: 27 additions & 0 deletions test/issue_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testitem "complex differentiation: issue #977" tags=[:misc] begin
using Lux, Zygote, Random

rng = Random.default_rng()
Random.seed!(rng, 666)

rbf(x) = exp.(-(x .^ 2))

U = Lux.Chain(
Lux.Dense(1, 10, rbf),
Lux.Dense(10, 3, rbf)
)

θ, st = Lux.setup(rng, U)

function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64)
return imag(f(x + ϵ * im)) / ϵ
end

loss(t) = sum(complex_step_differentiation-> U([τ], θ, st)[begin], t, 1e-5))

if pkgversion(LuxLib) v"1.3.10"
@test only(Zygote.gradient(loss, 1.0)) isa Float64
else
@test_broken only(Zygote.gradient(loss, 1.0)) isa Float64
end
end

0 comments on commit f412731

Please sign in to comment.