Skip to content

Commit

Permalink
Merge pull request #150 from PumasAI/separatepullbackarg
Browse files Browse the repository at this point in the history
`alloc_return` needs 5 args
  • Loading branch information
chriselrod authored Aug 29, 2023
2 parents f028d69 + c95e8bc commit 87a5b40
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.4.5"
version = "0.4.6"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
3 changes: 3 additions & 0 deletions src/SimpleChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ using LoopVectorization: matmul_params, @turbo
# macro turbo(ex0, ex1)
# esc(ex1)
# end
# macro turbo(ex0, ex1, ex2)
# esc(ex2)
# end

export SimpleChain,
TurboDense,
Expand Down
65 changes: 60 additions & 5 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,26 @@ end
C[m, n] = Cmn
end
end
# Gathered matrix product written into `_C`:
#
#     C[m, n] = Σₖ A[m, k] * B[inds[k], n]
#
# `inds` selects rows of `_B`; the reduction index `k` runs over the shared
# range of `A`'s second axis and `inds`' first axis. Every entry of `_C`
# covered by the loop is overwritten (not accumulated into).
#
# The `identity`, `nothing` and `False` arguments are unnamed and take part
# only in dispatch for this method.
# NOTE(review): the arrays are passed through `zero_offsets` but `inds` is
# not — confirm the index convention expected of `inds` against the caller.
@inline function denserev!(
  ::typeof(identity),
  ::Nothing,
  _C::AbstractArray{T1,N},
  _A::AbstractMatrix,
  _B::AbstractArray{T2,N},
  inds::AbstractVector{<:Integer},
  ::False
) where {T1<:Base.HWReal,T2<:Base.HWReal,N}
  C = zero_offsets(_C)
  A = zero_offsets(_A)
  B = zero_offsets(_B)
  @turbo for n ∈ indices((B, C), 2), m ∈ indices((A, C), 1)
    Cmn = zero(eltype(C))
    for k ∈ indices((A, inds), (2, 1))
      Cmn += A[m, k] * B[inds[k], n]
    end
    C[m, n] = Cmn
  end
end

function valgrad_layer!(
pg::Ptr{T},
Expand All @@ -843,16 +863,22 @@ function valgrad_layer!(
pu::Ptr{UInt8}
) where {T,O}
input_dim = static_size(B, StaticInt(1))
batch_size = static_size(B, StaticInt(2))
batch_size = static_size(inds, StaticInt(1))
pu2 = Base.unsafe_convert(
Ptr{T},
pu + align(batch_size * td.outputdim * sizeof(T))
)
C, _pu3 =
alloc_return(td, batch_size, pu2, contiguous_axis(B), stride_rank(B))
A, p2 = getparams(td, p, input_dim)
C, _pu3 = alloc_return(
td,
batch_size,
pu2,
contiguous_axis(B),
stride_rank(A),
Val(ndims(B))
)
pu3 = Base.unsafe_convert(Ptr{UInt8}, _pu3)
∂C, _ = get∂C(td, C, pu)
A, p2 = getparams(td, p, input_dim)
f = td.f
dense!(f, ∂C, C, A, B, inds, static(O))
# doesn't need a pullback
Expand All @@ -871,7 +897,7 @@ function chain_valgrad_entry!(
pg2, larg, p2, pu2 = valgrad_layer!(pgp, l, arg, inds, p, pu)
val, grad, pu3 = chain_valgrad!(pg2, larg, Base.tail(layers), p2, pu2)
pu = pullback_common!(pgp, l, grad, arg, p, pu)::Ptr{UInt8}
pullback_param!(pgp, l, grad, arg, p, pu)
pullback_param!(pgp, l, grad, arg, p, pu, inds)
pga === nothing || pullback_arg!(pga, l, grad, arg, p, pu, pu3)
return val
end
Expand Down Expand Up @@ -961,6 +987,20 @@ function pullback_param!(
dense_param_update!(td, Ā, C̄, B)
return nothing
end
# Parameter pullback for a pass restricted to the samples listed in `inds`.
#
# Views the gradient buffer at `pg` as this layer's parameter array `Ā`
# (via `getparams`, sized from `B`'s first dimension) and fills it through
# `dense_param_update!`. The fifth positional argument (parameter pointer)
# and `pu` are not read by this method. Always returns `nothing`.
function pullback_param!(
  pg::Ptr{T},
  td::TurboDense{O},
  C̄,
  B,
  ::Ptr{T},
  pu::Ptr{UInt8},
  inds
) where {T,O}
  # Ā = C̄ * B'
  Ā = first(getparams(td, pg, static_size(B, StaticInt(1))))
  dense_param_update!(td, Ā, C̄, B, inds)
  return nothing
end
function dense_param_update!(::TurboDense{true}, Ā, C̄, B)
Kp1 = static_size(Ā, StaticInt(2))
K = Kp1 - StaticInt(1)
Expand All @@ -976,6 +1016,21 @@ end
# `TurboDense{false}` parameter gradient: a single `dense!` call with the
# adjoint of `B` as right operand, i.e. Ā = C̄ * B'.
function dense_param_update!(::TurboDense{false}, Ā, C̄, B)
  Bᵀ = B'
  return dense!(identity, nothing, Ā, C̄, Bᵀ, False())
end
# `TurboDense{true}` parameter gradient for a pass restricted to `inds`.
#
# With `Kp1` the number of columns of `Ā` and `K = Kp1 - 1`:
#   * columns `1:K` receive the gathered product computed by `denserev!`
#     from `C̄` and the rows of `B'` selected by `inds`;
#   * column `Kp1` receives, for each row `m`, the sum of `C̄[m, :]`.
function dense_param_update!(::TurboDense{true}, Ā, C̄, B, inds)
  Kp1 = static_size(Ā, StaticInt(2))
  K = Kp1 - StaticInt(1)
  denserev!(identity, nothing, view(Ā, :, static(1):K), C̄, B', inds, False())
  @turbo for m ∈ axes(Ā, 1)
    s = zero(eltype(Ā))
    for n ∈ axes(C̄, 2)
      s += C̄[m, n]
    end
    Ā[m, Kp1] = s
  end
end
# `TurboDense{false}` parameter gradient for a pass restricted to `inds`.
#
# `inds` has to select rows of `B'`, so this uses `denserev!` (which indexes
# its right operand as `B[inds[k], n]`), the same kernel the
# `TurboDense{true}` method uses for its leading columns. Here it fills all
# of `Ā`, since this variant has no extra trailing column.
function dense_param_update!(::TurboDense{false}, Ā, C̄, B, inds)
  denserev!(identity, nothing, Ā, C̄, B', inds, False())
end

@inline function dense!(
f,
Expand Down
20 changes: 20 additions & 0 deletions test/batch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using SimpleChains, Test
# Smoke test for `train_batched!`: build targets by evaluating the chain with
# one parameter vector, then train a second parameter vector on those targets
# and require the penalised loss to drop below a quarter of its initial value.
@testset "batch" begin
  inputs = randn(3, 100)
  net = SimpleChain(
    static(3),
    TurboDense{true}(tanh, 8),
    TurboDense{true}(identity, 4)
  )
  optimizer = SimpleChains.ADAM()
  regularizer = L2Penalty(0.01)

  # Targets come from the chain itself, evaluated at the first parameter set.
  target_params = SimpleChains.init_params(net)
  targets = Matrix(net(inputs, target_params))
  objective = regularizer(SimpleChains.add_loss(net, SquaredLoss(targets)))

  # Training starts from a second, separately initialised parameter set.
  params = SimpleChains.init_params(net)
  initial_loss = objective(inputs, params)
  params_copy = copy(params)
  @time SimpleChains.train_batched!(
    params_copy,
    params,
    objective,
    inputs,
    optimizer,
    10_000
  )
  @test objective(inputs, params) < 0.25 * initial_loss
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ InteractiveUtils.versioninfo(; verbose = true)
@testset "Glorot" begin
include("random.jl")
end
@testset "Batch" begin
include("batch.jl")
end
end
# TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped
# For now, there are the tests at the start.
Expand Down

4 comments on commit 87a5b40

@chriselrod
Copy link
Contributor Author

@chriselrod chriselrod commented on 87a5b40 Aug 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Action not recognized: registe

@chriselrod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/90434

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.6 -m "<description of version>" 87a5b400d798bd38dfde8ab93cce959a2b7d3ce3
git push origin v0.4.6

Please sign in to comment.