diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..323237bab --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5e31ddcdb..dc73a1d7a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/Cancel.yml b/.github/workflows/Cancel.yml index 85b1ef3d2..652e014a9 100644 --- a/.github/workflows/Cancel.yml +++ b/.github/workflows/Cancel.yml @@ -13,7 +13,7 @@ jobs: cancel: runs-on: ubuntu-latest steps: - - uses: styfle/cancel-workflow-action@0.9.0 + - uses: styfle/cancel-workflow-action@0.12.0 with: all_but_latest: true workflow_id: ${{ github.event.workflow.id }} diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index f49c7cddf..c63b657c1 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -25,14 +25,14 @@ jobs: # package: {user: JuliaDiff, repo: Diffractor.jl} steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream - uses: actions/checkout@v4.0.0 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/.github/workflows/JuliaNightly.yml b/.github/workflows/JuliaNightly.yml index 0d3526c0b..a1281c5f6 100644 --- a/.github/workflows/JuliaNightly.yml +++ b/.github/workflows/JuliaNightly.yml @@ -23,7 +23,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/VersionVigilante_pull_request.yml b/.github/workflows/VersionVigilante_pull_request.yml index 76fffeac4..57ee668e3 100644 --- a/.github/workflows/VersionVigilante_pull_request.yml +++ b/.github/workflows/VersionVigilante_pull_request.yml @@ -6,7 +6,7 @@ jobs: VersionVigilante: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest - name: VersionVigilante.main id: versionvigilante_main diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 000000000..f80377a24 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,27 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + tool_name: JuliaFormatter + fail_on_error: true + filter_mode: added diff --git a/Project.toml b/Project.toml index 8dfefc100..77f514046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.54.0" +version = "1.58.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -23,15 +23,21 @@ Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" +Distributed = "1" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0" IrrationalConstants = "0.1.1, 0.2" JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" +LinearAlgebra = "1" +Random = "1" RealDot = "0.1" SparseInverseSubset = "0.1" +SparseArrays = "1" StaticArrays = "1.2" +Statistics = "1" StructArrays = "0.6.11" +SuiteSparse = "1" julia = "1.6" [extras] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 28e73c166..6d33a22e7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Base/broadcast.jl") +include("rulesets/Base/CoreLogging.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/CoreLogging.jl b/src/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..ae97f4e40 --- /dev/null +++ b/src/rulesets/Base/CoreLogging.jl @@ -0,0 +1,20 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) + +function rrule( + rc::RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.CoreLogging.with_logger), + f::Function, + logger::Base.CoreLogging.AbstractLogger, +) + y, f_pb = Base.CoreLogging.with_logger(logger) do + rrule_via_ad(rc, f) + end + with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent()) + return y, with_logger_pullback +end + +@non_differentiable Base.CoreLogging.current_logger(args...) +@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) +@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) +@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) +@non_differentiable Base.CoreLogging.handle_message(::Any...) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7fbf46062..078bb602a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - + Ȳf = Ȳ @static if VERSION >= v"1.9" # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358 @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R end end Yf = Y - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(Y, AbstractArray) Yf = [Y] @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R B̄ = A' \ Ȳf Ā = -B̄ * Y' t = (B - A * Y) * B̄' - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(t, AbstractArray) t = [t] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9576abd98..28cc11d19 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -94,6 +94,7 @@ end @scalar_rule fma(x, y, z) (y, x, true) @scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true) @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 1334cc925..830571ecd 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,6 +1,10 @@ # Int rather than Int64/Integer is intentional -function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) - return x.i, ẋ.i +function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}) + return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) +end + +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds) + return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) end "for a given tuple type, returns a Val{N} where N is the length of the tuple" @@ -140,7 +144,7 @@ end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) - view(dx, inds...) .+= Ref(dy) + @views dx[inds...] += dy return dx end function ∇getindex!(dx::AbstractArray, dy, inds...) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 58298f068..d35024163 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -483,10 +483,6 @@ end @non_differentiable Broadcast.result_style(::Any) @non_differentiable Broadcast.result_style(::Any, ::Any) -@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) -@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) -@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) -@non_differentiable Base.CoreLogging.handle_message(::Any...) @non_differentiable Libc.free(::Any) @non_differentiable Libc.getpid() diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 8ee8f8cd0..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC) end return Ω, det_pullback end + + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + + function spdiagm_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function spdiagm_pullback(ȳ) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function spdiagm_pullback(ȳ) + return (NoTangent(), diag(unthunk(ȳ))) + end + return spdiagm(v), spdiagm_pullback +end diff --git a/test/rulesets/Base/CoreLogging.jl b/test/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..28c0b74a8 --- /dev/null +++ b/test/rulesets/Base/CoreLogging.jl @@ -0,0 +1,11 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) +@testset "CoreLogging.jl" begin + @testset "with_logger" begin + test_rrule( + Base.CoreLogging.with_logger, + () -> 2.0 * 3.0, + Base.CoreLogging.NullLogger(); + check_inferred=false, + ) + end +end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..9a5278747 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -153,6 +153,16 @@ test_rrule(muladd, 10randn(), randn(), randn()) end + @testset "muladd ZeroTangent" begin + test_frule(muladd, 2.0, 3.0, ZeroTangent()) + test_frule(muladd, 2.0, ZeroTangent(), 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) + + test_rrule(muladd, 2.0, 3.0, ZeroTangent()) + test_rrule(muladd, 2.0, ZeroTangent(), 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + end + @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn()) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d3c7ecfb4..e878dd061 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,3 +1,19 @@ +struct FooTwoField + x::Float64 + y::Float64 +end + + +@testset "getfield" begin + test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false) + + test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) + test_frule(getfield, (; a=1.5, b=2.5), 2) + + test_frule(getfield, (1.5, 2.5), 2) + test_frule(getfield, (1.5, 2.5), 2, true) +end + @testset "getindex" begin @testset "getindex(::Tuple, ...)" begin x = (1.2, 3.4, 5.6) @@ -161,6 +177,14 @@ @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end + + @testset "getindex(::Array{<:AbstractGPUArray})" begin + x_gpu = jl(rand(1)) + y, back = rrule(getindex, [x_gpu], 1) + @test y === x_gpu + dxs_gpu = unthunk(back(jl([1.0]))[2]) + @test dxs_gpu == [jl([1.0])] + end end # first & tail handled by getfield rules diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 03f1052c2..283452a8a 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,6 +18,51 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +# copied over from test/rulesets/LinearAlgebra/structured +@testset "spdiagm" begin + @testset "without size" begin + M, N = 7, 9 + s = (8, 8) + a = randn(M) + b = randn(M) + c = randn(M - 1) + ȳ = randn(s) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, ps...) + @test y == spdiagm(ps...) + ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end + @testset "with size" begin + M, N = 7, 9 + a = randn(M) + b = randn(M) + c = randn(M - 1) + ȳ = randn(M, N) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, M, N, ps...) + @test y == spdiagm(M, N, ps...) + ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + @test ∂M === NoTangent() + @test ∂N === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end +end + @testset "findnz" begin A = sprand(5, 5, 0.5) dA = similar(A) @@ -42,4 +87,4 @@ end test_rrule(logabsdet, A) test_rrule(logdet, A) test_rrule(det, A) -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..768f7c208 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ end test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks + include_test("rulesets/Base/CoreLogging.jl") include_test("rulesets/Base/base.jl") include_test("rulesets/Base/fastmath_able.jl") include_test("rulesets/Base/evalpoly.jl") diff --git a/test/unzipped.jl b/test/unzipped.jl index 97aaa23f5..4215f3a6e 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -87,11 +87,14 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map # TODO invent some tests of this rrule's pullback function @test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6])) - @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6]) - @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray - @test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5]) - @test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray + + # depending on Julia/package versions, may get ReinterpretArray or JLArray + # Either is acceptable + @test isa( + unzip(jl([(missing, 2), (missing, 4), (missing, 6)]))[2], + Union{Base.ReinterpretArray,JLArray}, + ) end -end \ No newline at end of file +end