refactor: cleanup some old pre-1.0 hacks
avik-pal committed Nov 23, 2024
1 parent d755929 commit 5e3dca2
Showing 3 changed files with 12 additions and 13 deletions.
15 changes: 7 additions & 8 deletions src/extended_ops.jl
@@ -224,19 +224,18 @@ CRC.@non_differentiable istraining(::Any)
 
 end
 
-using .LuxOps: LuxOps, multigate
+using .LuxOps: LuxOps, multigate, xlogx, xlogy, foldl_init
 
 const safe_getproperty = LuxOps.getproperty
 const safe_eachslice = LuxOps.eachslice
 
-# TODO: directly import them from LuxOps from 1.0
-const private_xlogx = LuxOps.xlogx
-const private_xlogy = LuxOps.xlogy
-const private_foldl_init = LuxOps.foldl_init
-
 # These are defined here to avoid a circular dependency among modules
-for (op, field) in (:bias => :use_bias, :affine => :affine,
-        :track_stats => :track_stats, :train_state => :train_state)
+for (op, field) in (
+    :bias => :use_bias,
+    :affine => :affine,
+    :track_stats => :track_stats,
+    :train_state => :train_state
+)
     @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer)
         res = known(safe_getproperty(l, Val($(Meta.quot(field)))))
         return ifelse(res === nothing, false, res)
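Note: the losses and the recurrence below rely on `xlogx`/`xlogy` following the usual convention of returning zero when the first argument is zero, so exact-zero labels do not produce NaN. A minimal, hypothetical sketch of that assumed behaviour (not the LuxOps definitions):

# Hypothetical stand-ins for illustration only; the real functions live in LuxOps.
_demo_xlogx(x) = iszero(x) ? zero(x) : x * log(x)      # x * log(x), with 0 at x == 0
_demo_xlogy(x, y) = iszero(x) ? zero(x) : x * log(y)   # x * log(y), with 0 at x == 0

_demo_xlogx(0.0)        # 0.0 rather than NaN
_demo_xlogy(0.0, 0.0)   # 0.0 rather than NaN
_demo_xlogy(0.5, 2.0)   # ≈ 0.3466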
8 changes: 4 additions & 4 deletions src/helpers/losses.jl
@@ -261,8 +261,8 @@ end
 
 for logits in (true, false)
     return_expr = logits ? :(return loss.agg((1 .- ỹ) .* ŷ .- logsigmoid.(ŷ))) :
-                  :(return loss.agg(-private_xlogy.(ỹ, ŷ .+ ϵ) .-
-                                    private_xlogy.(1 .- ỹ, 1 .- ŷ .+ ϵ)))
+                  :(return loss.agg(-xlogy.(ỹ, ŷ .+ ϵ) .-
+                                    xlogy.(1 .- ỹ, 1 .- ŷ .+ ϵ)))
 
     @eval function unsafe_apply_loss(loss::BinaryCrossEntropyLoss{$(logits)}, ŷ, y)
         T = promote_type(eltype(ŷ), eltype(y))
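For orientation, a rough sketch of what the non-logits branch above computes, with a hypothetical `_xlogy` helper, mean aggregation, and no label smoothing (the real code uses `loss.agg` and the smoothed `ỹ`):

_xlogy(x, y) = iszero(x) ? zero(x) : x * log(y)

# Binary cross-entropy: mean of -[y*log(ŷ + ϵ) + (1 - y)*log(1 - ŷ + ϵ)]
function demo_bce(ŷ, y; ϵ = eps(eltype(ŷ)))
    return sum(-_xlogy.(y, ŷ .+ ϵ) .- _xlogy.(1 .- y, 1 .- ŷ .+ ϵ)) / length(y)
end

demo_bce([0.9, 0.1, 0.8], [1.0, 0.0, 1.0])  # ≈ 0.14, small loss for good predictions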
@@ -387,7 +387,7 @@ for logits in (true, false)
                   :(return LossFunctionImpl.fused_agg(
                       loss.agg, -, sum(ỹ .* logsoftmax(ŷ; loss.dims); loss.dims))) :
                   :(return LossFunctionImpl.fused_agg(
-                      loss.agg, -, sum(private_xlogy.(ỹ, ŷ .+ ϵ); loss.dims)))
+                      loss.agg, -, sum(xlogy.(ỹ, ŷ .+ ϵ); loss.dims)))
 
     @eval function unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y)
         T = promote_type(eltype(ŷ), eltype(y))
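Likewise, a rough sketch of the non-logits cross-entropy branch, again with hypothetical helpers, mean aggregation, and one-hot targets:

_xlogy(x, y) = iszero(x) ? zero(x) : x * log(y)

# Cross-entropy: -Σ y*log(ŷ + ϵ) along the class dimension, then averaged over samples.
function demo_crossentropy(ŷ, y; dims = 1, ϵ = eps(eltype(ŷ)))
    return sum(-sum(_xlogy.(y, ŷ .+ ϵ); dims)) / size(ŷ, 2)
end

ŷ = [0.7 0.2; 0.2 0.5; 0.1 0.3]   # 3 classes × 2 samples (columns sum to 1)
y = [1.0 0.0; 0.0 1.0; 0.0 0.0]   # one-hot targets
demo_crossentropy(ŷ, y)           # ≈ (-log(0.7) - log(0.5)) / 2 ≈ 0.52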
@@ -603,7 +603,7 @@ end
 function unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y)
     cross_entropy = unsafe_apply_loss(loss.celoss, ŷ, y)
     # Intentional broadcasting for Zygote type stability
-    entropy = loss.agg(sum(private_xlogx.(y); loss.dims))
+    entropy = loss.agg(sum(xlogx.(y); loss.dims))
     return entropy + cross_entropy
 end
 
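The KL divergence above is assembled as an entropy term (Σ y*log(y), via `xlogx`) plus the cross-entropy term; a hypothetical worked sketch of that decomposition:

_xlogx(x) = iszero(x) ? zero(x) : x * log(x)
_xlogy(x, y) = iszero(x) ? zero(x) : x * log(y)

# Σ y*log(y) - Σ y*log(ŷ) = Σ y*log(y / ŷ); zero entries of y contribute nothing.
demo_kl(ŷ, y) = sum(_xlogx.(y)) + sum(-_xlogy.(y, ŷ))

demo_kl([0.25, 0.25, 0.5], [0.25, 0.25, 0.5])  # 0.0 for identical distributions
demo_kl([0.1, 0.1, 0.8], [0.25, 0.25, 0.5])    # ≈ 0.22 otherwise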
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
@@ -131,7 +131,7 @@ function (r::Recurrence{True})(x::Union{AbstractVector, NTuple}, ps, st::NamedTu
         (out, carry), state = apply(r.cell, (input, carry), ps, state)
         return vcat(outputs, [out]), carry, state
     end
-    results = private_foldl_init(recur_op, x)
+    results = foldl_init(recur_op, x)
     return first(results), last(results)
 end
 
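`foldl_init` is assumed here to act as a left fold whose first step establishes the `(outputs, carry, state)` accumulator; a hypothetical toy version of that pattern (not the LuxOps definition):

# Hypothetical helper: foldl with `nothing` as the initial accumulator,
# letting the op special-case the first element.
_demo_foldl_init(op, xs; init = nothing) = foldl(op, xs; init)

# Toy usage mirroring the recurrence pattern: thread a (history, carry) pair.
step(::Nothing, x) = ([x], x)
step((hist, carry), x) = (vcat(hist, [x + carry]), x + carry)

_demo_foldl_init(step, [1, 2, 3])  # ([1, 3, 6], 6)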
