Skip to content

Commit

Permalink
Merge branch 'main' into compathelper/new_version/2023-10-29-00-58-21…
Browse files Browse the repository at this point in the history
…-718-03943160536
  • Loading branch information
avik-pal authored Oct 29, 2023
2 parents 30123c0 + af90779 commit b62d359
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 6 deletions.
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.8"
version = "0.5.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -26,6 +26,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -36,6 +37,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxChainRulesExt = "ChainRules"
LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxTransformExt = "Flux"
Expand All @@ -47,31 +49,37 @@ LuxZygoteExt = "Zygote"
[compat]
ADTypes = "0.1, 0.2"
Adapt = "3"
ChainRules = "1"
ChainRulesCore = "1"
ComponentArrays = "0.15.2"
ConcreteStructs = "0.2"
FillArrays = "0.13, 1"
Flux = "0.13, 0.14"
Functors = "0.2, 0.3, 0.4"
LinearAlgebra = "1.6"
LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
LuxCore = "0.1.6"
LuxDeviceUtils = "0.1"
LuxLib = "0.3"
MacroTools = "0.5"
Markdown = "1.6"
Optimisers = "0.2, 0.3"
PackageExtensionCompat = "1"
Random = "1.6"
Reexport = "1"
ReverseDiff = "1"
Setfield = "0.8, 1"
Statistics = "1"
SparseArrays = "1.6"
Statistics = "1.6"
Tracker = "0.2"
TruncatedStacktraces = "1.1"
WeightInitializers = "0.1"
Zygote = "0.6"
julia = "1.6"

[extras]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
1 change: 1 addition & 0 deletions examples/DDIM/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ Optimisers = "0.2, 0.3"
Plots = "1"
ProgressBars = "1"
Setfield = "1"
Statistics = "1"
Zygote = "0.6"
julia = "1.6"
1 change: 1 addition & 0 deletions examples/HyperNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ MLUtils = "0.2, 0.3, 0.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
Setfield = "0.8, 1"
Statistics = "1"
Zygote = "0.6"
1 change: 1 addition & 0 deletions examples/ImageNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
Setfield = "0.8.2, 1"
SimpleConfig = "0.1"
Statistics = "1"
Zygote = "0.6"
1 change: 1 addition & 0 deletions examples/NeuralODE/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
OrdinaryDiffEq = "6"
SciMLSensitivity = "7.45"
Statistics = "1"
Zygote = "0.6"
1 change: 1 addition & 0 deletions examples/PolynomialFitting/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
MakiePublication = "0.3"
Optimisers = "0.2, 0.3"
Statistics = "1"
Zygote = "0.6"
1 change: 1 addition & 0 deletions examples/SimpleRNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
MLUtils = "0.2, 0.3, 0.4"
Optimisers = "0.2, 0.3"
Statistics = "1"
Zygote = "0.6"
7 changes: 3 additions & 4 deletions examples/SimpleRNN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ end

# Now we need to define the behavior of the Classifier when it is invoked.

function (s::SpiralClassifier)(x::AbstractArray{T, 3},
ps::NamedTuple,
function (s::SpiralClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple,
st::NamedTuple) where {T}
## First we will have to run the sequence through the LSTM Cell
## The first call to LSTM Cell will create the initial hidden state
Expand Down Expand Up @@ -153,8 +152,8 @@ function main()
for (x, y) in train_loader
x = x |> dev
y = y |> dev
(loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
gs = back((one(loss), nothing, nothing))[1]
(loss, y_pred, st), back = pullback(compute_loss, x, y, model, ps, st)
gs = back((one(loss), nothing, nothing))[4]
opt_state, ps = Optimisers.update(opt_state, ps, gs)

println("Epoch [$epoch]: Loss $loss")
Expand Down
18 changes: 18 additions & 0 deletions ext/LuxChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module LuxChainRulesExt

using ChainRules, ChainRulesCore, Lux

# https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an
# emergency patch here
function ChainRules._setindex_zero(x::Vector{<:AbstractArray{T}}, dy,
inds::Integer...) where {T <: Number}
return [fill!(similar(xᵢ), 0) for xᵢ in x]
end

function ChainRules.∇getindex!(dx::Vector{<:AbstractArray{T}}, dy,
inds::Integer...) where {T <: Number}
dx[inds...] .+= dy
return dx
end

end

0 comments on commit b62d359

Please sign in to comment.