xReLU
cossio committed Dec 13, 2021
1 parent 5476676 commit e774670
Showing 5 changed files with 103 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/RestrictedBoltzmannMachines.jl
@@ -18,6 +18,7 @@ module RestrictedBoltzmannMachines
include("layers/relu.jl")
include("layers/drelu.jl")
include("layers/prelu.jl")
include("layers/xrelu.jl")
include("layers/common.jl")

include("rbm.jl")
33 changes: 32 additions & 1 deletion src/layers/common.jl
@@ -10,7 +10,7 @@ function energy(layer::Union{Binary, Spin, Potts}, x::AbstractArray)
    return -reshape(x, length(layer), size(x)[end])' * vec(layer.θ)
end

-const _ThetaLayers = Union{Binary, Spin, Potts, Gaussian, ReLU, pReLU}
+const _ThetaLayers = Union{Binary, Spin, Potts, Gaussian, ReLU, pReLU, xReLU}
Base.ndims(layer::_ThetaLayers) = ndims(layer.θ)
Base.size(layer::_ThetaLayers) = size(layer.θ)
Base.size(layer::_ThetaLayers, d::Int) = size(layer.θ, d)
@@ -33,3 +33,34 @@ function dReLU(l::pReLU)
    θn = @. l.θ - l.Δ / (1 - l.η)
    return dReLU(θp, θn, γp, γn)
end
+
+function xReLU(l::dReLU)
+    #= The xReLU <-> dReLU conversion is bijective only if the γ's are positive =#
+    γp = abs.(l.γp)
+    γn = abs.(l.γn)
+    γ = @. 2γp * γn / (γp + γn)
+    ξ = @. (γn - γp) / (γp + γn - abs(γn - γp))
+    θ = @. (l.θp * γn + l.θn * γp) / (γp + γn)
+    Δ = @. γ * (l.θp - l.θn) / (γp + γn)
+    return xReLU(θ, Δ, γ, ξ)
+end
+
+function dReLU(l::xReLU)
+    ξp = @. (1 + abs(l.ξ)) / (1 + max( 2l.ξ, 0))
+    ξn = @. (1 + abs(l.ξ)) / (1 + max(-2l.ξ, 0))
+    γp = @. l.γ * ξp
+    γn = @. l.γ * ξn
+    θp = @. l.θ + l.Δ * ξp
+    θn = @. l.θ - l.Δ * ξn
+    return dReLU(θp, θn, γp, γn)
+end
+
+function xReLU(l::pReLU)
+    ξ = @. l.η / (1 - abs(l.η))
+    return xReLU(l.θ, l.Δ, l.γ, ξ)
+end
+
+function pReLU(l::xReLU)
+    η = @. l.ξ / (1 + abs(l.ξ))
+    return pReLU(l.θ, l.Δ, l.γ, η)
+end
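
Note: as the comment in xReLU(l::dReLU) says, the dReLU -> xReLU map drops the signs of the γ's, so a round trip recovers γp and γn only up to absolute value, while θp and θn come back exactly. A minimal standalone sketch of that check (hypothetical, not part of the commit; it assumes the conversions above and the xReLU struct from src/layers/xrelu.jl are in scope):

    using Test
    # Round-trip a random dReLU layer through xReLU. The γ's drawn here can be
    # negative (and are almost surely nonzero), so only their absolute values
    # survive the conversion.
    l  = dReLU(randn(5), randn(5), randn(5), randn(5))  # θp, θn, γp, γn
    l2 = dReLU(xReLU(l))
    @test l2.θp ≈ l.θp && l2.θn ≈ l.θn   # offsets are recovered exactly
    @test l2.γp ≈ abs.(l.γp)             # sign of γp is lost
    @test l2.γn ≈ abs.(l.γn)             # sign of γn is lost

The pReLU <-> xReLU pair, by contrast, is mutually inverse whenever |η| < 1 (equivalently, finite ξ), since η = ξ / (1 + |ξ|) inverts ξ = η / (1 - |η|).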
29 changes: 29 additions & 0 deletions src/layers/xrelu.jl
@@ -0,0 +1,29 @@
+struct xReLU{A<:AbstractArray}
+    θ::A
+    Δ::A
+    γ::A
+    ξ::A
+    function xReLU(θ::A, Δ::A, γ::A, ξ::A) where {A<:AbstractArray}
+        @assert size(θ) == size(γ) == size(Δ) == size(ξ)
+        return new{A}(θ, Δ, γ, ξ)
+    end
+end
+
+function xReLU(::Type{T}, n::Int...) where {T}
+    θ = zeros(T, n...)
+    Δ = zeros(T, n...)
+    γ = ones(T, n...)
+    ξ = zeros(T, n...)
+    return xReLU(θ, Δ, γ, ξ)
+end
+
+xReLU(n::Int...) = xReLU(Float64, n...)
+
+Flux.@functor xReLU
+
+energy(layer::xReLU, x::AbstractArray) = energy(dReLU(layer), x)
+cgf(layer::xReLU, inputs::AbstractArray, β::Real = 1) = cgf(dReLU(layer), inputs, β)
+
+function sample_from_inputs(layer::xReLU, inputs::AbstractArray, β::Real = 1)
+    return sample_from_inputs(dReLU(layer), inputs, β)
+end
57 changes: 39 additions & 18 deletions test/layers.jl
@@ -7,7 +7,8 @@ _layers = (
    RBMs.Gaussian,
    RBMs.ReLU,
    RBMs.dReLU,
-    RBMs.pReLU
+    RBMs.pReLU,
+    RBMs.xReLU
)

function random_layer(
@@ -22,8 +23,8 @@ function random_layer(
    return T(randn(N...), randn(N...))
end

-function random_layer(::Type{RBMs.dReLU}, N::Int...)
-    return RBMs.dReLU(randn(N...), randn(N...), randn(N...), randn(N...))
+function random_layer(::Type{T}, N::Int...) where {T <: Union{RBMs.dReLU, RBMs.xReLU}}
+    return T(randn(N...), randn(N...), randn(N...), randn(N...))
end

function random_layer(::Type{RBMs.pReLU}, N::Int...)
@@ -197,28 +198,48 @@ end
    @test RBMs.cgf(layer, inputs, β) ≈ vec(sum(Γ; dims=(1,2,3)))
end

@testset "pReLU / dReLU convert" begin
@testset "pReLU / xReLU / dReLU convert" begin
    N = (10, 7)
    B = 13
    x = randn(N..., B)

-    prelu = random_layer(RBMs.pReLU, N...)
-    drelu = RBMs.dReLU(prelu)
-    @test prelu.θ ≈ RBMs.pReLU(drelu).θ
-    @test prelu.Δ ≈ RBMs.pReLU(drelu).Δ
-    @test prelu.γ ≈ RBMs.pReLU(drelu).γ
-    @test prelu.η ≈ RBMs.pReLU(drelu).η
-    @test RBMs.energy(prelu, x) ≈ @inferred RBMs.energy(drelu, x)
-    @test RBMs.cgf(prelu, x) ≈ @inferred RBMs.cgf(drelu, x)
-
    drelu = random_layer(RBMs.dReLU, N...)
-    prelu = RBMs.pReLU(drelu)
-    @test drelu.θp ≈ RBMs.dReLU(prelu).θp
-    @test drelu.θn ≈ RBMs.dReLU(prelu).θn
+    prelu = @inferred RBMs.pReLU(drelu)
+    xrelu = @inferred RBMs.xReLU(drelu)
+    @test drelu.θp ≈ RBMs.dReLU(prelu).θp ≈ RBMs.dReLU(xrelu).θp
+    @test drelu.θn ≈ RBMs.dReLU(prelu).θn ≈ RBMs.dReLU(xrelu).θn
    @test drelu.γp ≈ RBMs.dReLU(prelu).γp
    @test drelu.γn ≈ RBMs.dReLU(prelu).γn
-    @test RBMs.energy(drelu, x) ≈ @inferred RBMs.energy(prelu, x)
-    @test RBMs.cgf(drelu, x) ≈ @inferred RBMs.cgf(prelu, x)
+    @test abs.(drelu.γp) ≈ RBMs.dReLU(xrelu).γp
+    @test abs.(drelu.γn) ≈ RBMs.dReLU(xrelu).γn
+    @test RBMs.energy(drelu, x) ≈ RBMs.energy(prelu, x) ≈ RBMs.energy(xrelu, x)
+    @test RBMs.cgf(drelu, x) ≈ RBMs.cgf(prelu, x) ≈ RBMs.cgf(xrelu, x)
+
+    prelu = random_layer(RBMs.pReLU, N...)
+    drelu = @inferred RBMs.dReLU(prelu)
+    xrelu = @inferred RBMs.xReLU(prelu)
+    @test prelu.θ ≈ RBMs.pReLU(drelu).θ ≈ RBMs.pReLU(xrelu).θ
+    @test prelu.Δ ≈ RBMs.pReLU(drelu).Δ ≈ RBMs.pReLU(xrelu).Δ
+    @test prelu.γ ≈ RBMs.pReLU(drelu).γ ≈ RBMs.pReLU(xrelu).γ
+    @test prelu.η ≈ RBMs.pReLU(drelu).η ≈ RBMs.pReLU(xrelu).η
+    @test RBMs.energy(drelu, x) ≈ RBMs.energy(prelu, x) ≈ RBMs.energy(xrelu, x)
+    @test RBMs.cgf(drelu, x) ≈ RBMs.cgf(prelu, x) ≈ RBMs.cgf(xrelu, x)
+
+    xrelu = random_layer(RBMs.xReLU, N...)
+    drelu = @inferred RBMs.dReLU(xrelu)
+    prelu = @inferred RBMs.pReLU(xrelu)
+    @test xrelu.θ ≈ RBMs.xReLU(drelu).θ ≈ RBMs.xReLU(prelu).θ
+    @test xrelu.Δ ≈ RBMs.xReLU(drelu).Δ ≈ RBMs.xReLU(prelu).Δ
+    @test xrelu.ξ ≈ RBMs.xReLU(drelu).ξ ≈ RBMs.xReLU(prelu).ξ
+    @test xrelu.γ ≈ RBMs.xReLU(prelu).γ
+    @test abs.(xrelu.γ) ≈ RBMs.xReLU(drelu).γ
+    @test RBMs.energy(drelu, x) ≈ RBMs.energy(prelu, x) ≈ RBMs.energy(xrelu, x)
+    @test RBMs.cgf(drelu, x) ≈ RBMs.cgf(prelu, x) ≈ RBMs.cgf(xrelu, x)
+
+    for layer in (drelu, prelu, xrelu)
+        @inferred RBMs.energy(layer, x)
+        @inferred RBMs.cgf(layer, x)
+    end
end

@testset "dReLU" begin
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -3,11 +3,11 @@ using SafeTestsets, Random, Test
@time @safetestset "util" begin include("util.jl") end
@time @safetestset "linalg" begin include("linalg.jl") end
@time @safetestset "onehot" begin include("onehot.jl") end
@time @safetestset "layers" begin include("layers.jl") end
@time @safetestset "rbm" begin include("rbm.jl") end
@time @safetestset "minibatches" begin include("minibatches.jl") end
@time @safetestset "zerosum" begin include("zerosum.jl") end
@time @safetestset "initialization" begin include("initialization.jl") end
@time @safetestset "layers" begin include("layers.jl") end
@time @safetestset "rbm" begin include("rbm.jl") end
@time @safetestset "pseudolikelihood" begin include("pseudolikelihood.jl") end
@time @safetestset "pgm" begin include("compare_to_pgm/pgm.jl") end
@time @safetestset "partition" begin include("partition.jl") end
