From 5e3dca2c1a87747595f52927767d826adc609082 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 23 Nov 2024 18:35:21 -0500 Subject: [PATCH] refactor: cleanup some old pre-1.0 hacks --- src/extended_ops.jl | 15 +++++++-------- src/helpers/losses.jl | 8 ++++---- src/layers/recurrent.jl | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/extended_ops.jl b/src/extended_ops.jl index fafe000f2..0223d775c 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -224,19 +224,18 @@ CRC.@non_differentiable istraining(::Any) end -using .LuxOps: LuxOps, multigate +using .LuxOps: LuxOps, multigate, xlogx, xlogy, foldl_init const safe_getproperty = LuxOps.getproperty const safe_eachslice = LuxOps.eachslice -# TODO: directly import them from LuxOps from 1.0 -const private_xlogx = LuxOps.xlogx -const private_xlogy = LuxOps.xlogy -const private_foldl_init = LuxOps.foldl_init - # These are defined here to avoid a circular dependency among modules -for (op, field) in (:bias => :use_bias, :affine => :affine, - :track_stats => :track_stats, :train_state => :train_state) +for (op, field) in ( + :bias => :use_bias, + :affine => :affine, + :track_stats => :track_stats, + :train_state => :train_state +) @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer) res = known(safe_getproperty(l, Val($(Meta.quot(field))))) return ifelse(res === nothing, false, res) diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 493e1b357..0974fbc9a 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -261,8 +261,8 @@ end for logits in (true, false) return_expr = logits ? :(return loss.agg((1 .- ỹ) .* ŷ .- logsigmoid.(ŷ))) : - :(return loss.agg(-private_xlogy.(ỹ, ŷ .+ ϵ) .- - private_xlogy.(1 .- ỹ, 1 .- ŷ .+ ϵ))) + :(return loss.agg(-xlogy.(ỹ, ŷ .+ ϵ) .- + xlogy.(1 .- ỹ, 1 .- ŷ .+ ϵ))) @eval function unsafe_apply_loss(loss::BinaryCrossEntropyLoss{$(logits)}, ŷ, y) T = promote_type(eltype(ŷ), eltype(y)) @@ -387,7 +387,7 @@ for logits in (true, false) :(return LossFunctionImpl.fused_agg( loss.agg, -, sum(ỹ .* logsoftmax(ŷ; loss.dims); loss.dims))) : :(return LossFunctionImpl.fused_agg( - loss.agg, -, sum(private_xlogy.(ỹ, ŷ .+ ϵ); loss.dims))) + loss.agg, -, sum(xlogy.(ỹ, ŷ .+ ϵ); loss.dims))) @eval function unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y) T = promote_type(eltype(ŷ), eltype(y)) @@ -603,7 +603,7 @@ end function unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y) cross_entropy = unsafe_apply_loss(loss.celoss, ŷ, y) # Intentional broadcasting for Zygote type stability - entropy = loss.agg(sum(private_xlogx.(y); loss.dims)) + entropy = loss.agg(sum(xlogx.(y); loss.dims)) return entropy + cross_entropy end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index de95b5b59..57e571650 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -131,7 +131,7 @@ function (r::Recurrence{True})(x::Union{AbstractVector, NTuple}, ps, st::NamedTu (out, carry), state = apply(r.cell, (input, carry), ps, state) return vcat(outputs, [out]), carry, state end - results = private_foldl_init(recur_op, x) + results = foldl_init(recur_op, x) return first(results), last(results) end