I'm attempting to implement a coupling layer (as in #416), now with custom layers. Again, they work fine (type stable) with `NamedTuple` parameters but not with `ComponentArray` parameters. This doesn't seem to be related to #416. It may be because I don't understand the role of `AbstractExplicitContainerLayer` and therefore am not using it.
The type instability only appears when the custom layer wraps a `Chain` of other Lux layers, not when it wraps a single layer.
Here's some code:
```julia
using Lux, ComponentArrays, Random
using InteractiveUtils  # provides @code_warntype outside the REPL

rng = Random.default_rng()

# Coupling-style layer: maps (x1, x2) to (sub_net(x1) + x2, x1)
struct LeapFrog{T} <: Lux.AbstractExplicitLayer
    sub_net::T
end

(frog::LeapFrog)(x, ps, st) = (frog.sub_net(x[1], ps, st)[1] + x[2], x[1]), st

# Parameters and states are simply those of the wrapped network
Lux.initialparameters(rng::AbstractRNG, frog::LeapFrog) = Lux.initialparameters(rng, frog.sub_net)
Lux.initialstates(rng::AbstractRNG, frog::LeapFrog) = Lux.initialstates(rng, frog.sub_net)

# LeapFrog containing a single Dense layer: TYPE STABLE
D = Dense(1 => 1)
F = LeapFrog(D)
C = Chain(F, F, F, F)
ps, st = Lux.setup(rng, C)
v = ([10.0], [-10.0])
C(v, ps, st)
@code_warntype C(v, ps, st)   # type stable
psc = ps |> ComponentArray
@code_warntype C(v, psc, st)  # type stable

# LeapFrog containing a Chain: NOT TYPE STABLE
D = Chain(Dense(1 => 10), Dense(10 => 1))
F = LeapFrog(D)
C = Chain(F, F, F, F)
ps, st = Lux.setup(rng, C)
v = ([10.0], [-10.0])
C(v, ps, st)
@code_warntype C(v, ps, st)   # type stable
psc = ps |> ComponentArray
@code_warntype C(v, psc, st)  # NOT type stable
```
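For reference, a minimal sketch of `LeapFrog` written as a container layer, which is what `AbstractExplicitContainerLayer` is for: its type parameter lists the fields that hold child layers, and Lux derives `initialparameters`/`initialstates` from them, so the two hand-written methods are no longer needed. It assumes the Lux API used in this report (later releases renamed these abstract types, e.g. `AbstractExplicitLayer` became `AbstractLuxLayer`), and `ContainerLeapFrog` is a hypothetical name. With a single listed field, this version of Lux passes the child's parameters through without nesting them under `:sub_net` (the same rule that keeps a `Chain`'s parameters at `layer_1`, `layer_2`, ...), so the forward pass looks like the original. This sketch is not claimed to fix the `ComponentArray` type instability.

```julia
using Lux, Random

# Hypothetical container-layer version of LeapFrog (name is illustrative).
# The type parameter (:sub_net,) tells Lux which field holds the child layer.
struct ContainerLeapFrog{T} <: Lux.AbstractExplicitContainerLayer{(:sub_net,)}
    sub_net::T
end

# No custom initialparameters/initialstates: Lux derives them from `sub_net`.
# With a single child field, ps and st are the child's own (not nested under :sub_net).
function (frog::ContainerLeapFrog)(x, ps, st)
    y, st_sub = frog.sub_net(x[1], ps, st)
    # Return the child's updated state rather than the incoming one
    return (y + x[2], x[1]), st_sub
end

rng = Random.default_rng()
F = ContainerLeapFrog(Chain(Dense(1 => 10), Dense(10 => 1)))
ps, st = Lux.setup(rng, F)
v = ([10.0f0], [-10.0f0])
out, st_new = F(v, ps, st)
```

Returning `st_sub` follows the usual Lux convention of propagating the child's updated state; for stateless layers such as `Dense` it is the same as returning the incoming `st`.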
Can you try changing `frog.sub_net(x[1],ps,st)[1]` to `first(frog.sub_net(x[1],ps,st))`? The former might cause type instability, since you are indexing into a heterogeneous container.
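A sketch of that suggestion applied to the report's layer: the forward pass takes the sub-network's output with `first` instead of `[1]`, and everything else is unchanged. It keeps the `AbstractExplicitLayer` API of the Lux version in the report; whether this change alone removes the instability with `ComponentArray` parameters is not established here.

```julia
using Lux, ComponentArrays, Random

struct LeapFrog{T} <: Lux.AbstractExplicitLayer
    sub_net::T
end

# Suggested change: `first(...)` instead of `[1]` to pull the output
# out of the (output, state) tuple returned by the sub-network
function (frog::LeapFrog)(x, ps, st)
    y = first(frog.sub_net(x[1], ps, st))
    return (y + x[2], x[1]), st
end

Lux.initialparameters(rng::AbstractRNG, frog::LeapFrog) = Lux.initialparameters(rng, frog.sub_net)
Lux.initialstates(rng::AbstractRNG, frog::LeapFrog) = Lux.initialstates(rng, frog.sub_net)

rng = Random.default_rng()
F = LeapFrog(Chain(Dense(1 => 10), Dense(10 => 1)))
C = Chain(F, F, F, F)
ps, st = Lux.setup(rng, C)
v = ([10.0f0], [-10.0f0])
out, _ = C(v, ps, st)

# Same call with flattened ComponentArray parameters
psc = ComponentArray(ps)
out_c, _ = C(v, psc, st)
```

To turn the `@code_warntype` inspection into a check that can run in a test suite, `Test.@inferred C(v, psc, st)` throws an error when the inferred return type does not match the actual one.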