Skip to content

Commit

Permalink
Remove Val in typeinfo of WeightNorm (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Aug 26, 2022
1 parent 4e394f7 commit 4936d51
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Random.seed!(rng, 0)

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
Chain(Dense(256, 1, tanh),Dense(1, 10)))
Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> gpu
Expand Down
8 changes: 4 additions & 4 deletions src/layers/normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ end

function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol},
dims::Union{Tuple, Nothing}=nothing) where {N}
return WeightNorm{Val{which_params}, typeof(layer), typeof(dims)}(layer, dims)
return WeightNorm{which_params, typeof(layer), typeof(dims)}(layer, dims)
end

@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims))
Expand All @@ -384,7 +384,7 @@ end
@inline _get_norm_except_dims(N, dims::Tuple) = filter(i -> !(i in dims), 1:N)

function initialparameters(rng::AbstractRNG,
wn::WeightNorm{Val{which_params}}) where {which_params}
wn::WeightNorm{which_params}) where {which_params}
ps_layer = initialparameters(rng, wn.layer)
ps_normalized = []
ps_unnormalized = []
Expand Down Expand Up @@ -419,7 +419,7 @@ function (wn::WeightNorm)(x, ps, s::NamedTuple)
return wn.layer(x, merge(_ps, ps.unnormalized), s)
end

@inbounds @generated function _get_normalized_parameters(::WeightNorm{Val{which_params}},
@inbounds @generated function _get_normalized_parameters(::WeightNorm{which_params},
dims::T,
ps) where {T, which_params}
parameter_names = string.(which_params)
Expand Down Expand Up @@ -449,7 +449,7 @@ end
return Expr(:block, calls...)
end

function Base.show(io::IO, w::WeightNorm{Val{which_params}}) where {which_params}
function Base.show(io::IO, w::WeightNorm{which_params}) where {which_params}
return print(io, "WeightNorm{", which_params, "}(", w.layer, ")")
end

Expand Down

0 comments on commit 4936d51

Please sign in to comment.