Skip to content

Commit

Permalink
Base.Fix2 to resolve worldtime issue
Browse files Browse the repository at this point in the history
  • Loading branch information
terasakisatoshi committed Jun 4, 2022
1 parent e422978 commit 11edee6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ function create_model(hp::HyperParams)
k = 1 / (c * prod(kernel_size))
W = rand(rng, Uniform(-√k, k), kernel_size..., c, c_next)
b = rand(rng, Uniform(-√k, k), c_next)
push!(net, x -> mycircular(x, (kernel_size 2)...))
#push!(net, x -> mycircular(x, (kernel_size .÷ 2)...)) # <--- DO NOT USE. Use Base.Fix2 instead
# See https://github.com/JuliaIO/BSON.jl/issues/69
push!(net, Base.Fix2(mycircular, (kernel_size 2)))
if use_bn
push!(
net,
Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ function mycircular(Y::AbstractArray{<:Real,2 + 2}, d1=1, d2=1)
return cat(Z_top, Z_main, Z_bottom, dims=1)
end

function mycircular(Y::AbstractArray{<:Real,2 + 2}, ds::NTuple{2,Int})
mycircular(Y, ds[1], ds[2])
end

"""
Differentiable padarray for 3D
"""
Expand All @@ -67,6 +71,10 @@ function mycircular(Y::AbstractArray{<:Real,3 + 2}, d1=1, d2=1, d3=1)
return cat(Z_end, Z_, Z_begin, dims=3)
end

function mycircular(Y::AbstractArray{<:Real,3 + 2}, ds::NTuple{3,Int})
mycircular(Y, ds[1], ds[2], ds[3])
end

"""
Differentiable padarray for 4D
"""
Expand Down Expand Up @@ -96,3 +104,7 @@ function mycircular(Y::AbstractArray{<:Real,4 + 2}, d1=1, d2=1, d3=1, d4=1)
Z_end = Z_[:, :, :, end-(d4-1):end, :, :]
return cat(Z_end, Z_, Z_begin, dims=4)
end

function mycircular(Y::AbstractArray{<:Real,4 + 2}, ds::NTuple{4,Int})
mycircular(Y, ds[1], ds[2], ds[3], ds[4])
end
26 changes: 26 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,32 @@ end
@test tar ref
end

@testset "circular padding with Base.Fix2" begin
# used for 2D Lattice
d1 = 2
d2 = 3
d3 = 4
d4 = 5
x = rand(4, 4, 4, 4)
c = Chain(Base.Fix2(GomalizingFlow.mycircular, (d1, d2)))
tar = c(x)
ref = ImageFiltering.padarray(x, Pad(:circular, d1, d2, 0, 0)).parent
@test tar ref
# used for 3D Lattice
x = rand(4, 4, 4, 4, 4)
c = Chain(Base.Fix2(GomalizingFlow.mycircular, (d1, d2, d3)))
tar = c(x)
ref = ImageFiltering.padarray(x, Pad(:circular, d1, d2, d3, 0, 0)).parent
@test tar ref

# used for 4D Lattice
x = rand(7, 7, 7, 7, 7, 7)
c = Chain(Base.Fix2(GomalizingFlow.mycircular, (d1, d2, d3, d4)))
tar = c(x)
ref = ImageFiltering.padarray(x, Pad(:circular, d1, d2, d3, d4, 0, 0)).parent
@test tar ref
end

@testset "make_checker_mask" begin
@test GomalizingFlow.make_checker_mask((8, 8), 0) == [
0 1 0 1 0 1 0 1
Expand Down

0 comments on commit 11edee6

Please sign in to comment.