Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array overloads #131

Merged
merged 34 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
91bacda
Add array overloads
adrhill Jun 19, 2024
a6354e4
Formatting
adrhill Jun 19, 2024
7b35af5
Fix test
adrhill Jun 19, 2024
7f54674
Merge branch 'main' into ah/array-overloads-2
adrhill Jun 20, 2024
695d142
Add `logabsdet`
adrhill Jun 20, 2024
bd3e4b6
Add `norm` and `opnorm`
adrhill Jun 20, 2024
f5c8e15
Add `eigen`, `eigmax`, `eigmin`
adrhill Jun 20, 2024
d94f67b
Add `exp`, `^`
adrhill Jun 20, 2024
3955e06
Remove overloads on `float`
adrhill Jun 20, 2024
d1d7b75
Introduce first- and second-order OR
adrhill Jun 20, 2024
7455bdb
Fix `opnorm`
adrhill Jun 20, 2024
9985a4c
Support casting matrix of tracers to sparse matrix
adrhill Jun 20, 2024
6aadcf5
Add comment linking to SparseArrays source
adrhill Jun 20, 2024
ce45598
More tests
adrhill Jun 20, 2024
69e6636
Keep `float` around in this PR
adrhill Jun 20, 2024
006c045
Fix `logabsdet`
adrhill Jun 20, 2024
a5a5126
Add determinant overloads for Duals
adrhill Jun 20, 2024
0cb5812
Minor cleanup
adrhill Jun 20, 2024
8b7b883
Update `logabsdet` tests
adrhill Jun 21, 2024
3738631
Fix compat entries
adrhill Jun 24, 2024
8590ad9
Less conservative `logabsdet`
adrhill Jun 24, 2024
89b617e
Rename according to review
adrhill Jun 24, 2024
e722a15
Shuffle things around
adrhill Jun 24, 2024
7b2df1d
Refactor tests
adrhill Jun 24, 2024
87180b9
Check zeros and ones in `^`
adrhill Jun 24, 2024
88ce464
More context on Dual overloads
adrhill Jun 24, 2024
dfbffb2
Refactor tests
adrhill Jun 24, 2024
75404c1
Shorter kwarg names
adrhill Jun 24, 2024
4a85f83
Try to get better error on 1.6
adrhill Jun 24, 2024
b0734d7
Exclude test on Julia 1.6
adrhill Jun 24, 2024
62c5f27
Fix `opnorm`
adrhill Jun 24, 2024
415fb9f
Remove `^1` case for type stability
adrhill Jun 25, 2024
8721ec9
Add suggestions from code review
adrhill Jun 25, 2024
981907c
Add tests for new cases
adrhill Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ version = "0.6.0-DEV"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -25,7 +27,10 @@ SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"
ADTypes = "1"
Compat = "3,4"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
NNlib = "0.8, 0.9"
Random = "<0.0.1, 1"
Requires = "1.3"
SparseArrays = "<0.0.1, 1"
SpecialFunctions = "2.4"
Expand Down
5 changes: 5 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ module SparseConnectivityTracer

using ADTypes: ADTypes
using Compat: Returns
using SparseArrays: SparseArrays
using SparseArrays: sparse
using Random: AbstractRNG, SamplerType

using LinearAlgebra: LinearAlgebra
using FillArrays: Fill

using DocStringExtensions

if !isdefined(Base, :get_extension)
Expand All @@ -26,6 +30,7 @@ include("overloads/hessian_tracer.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/overload_all.jl")
include("overloads/arrays.jl")

include("interface.jl")
include("adtypes.jl")
Expand Down
231 changes: 231 additions & 0 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
second_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
including second-order interactions to update the `hessian` field of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(^, tracers)
```
"""
function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(second_order_or, ts; init=myempty(T))
end

function second_order_or(a::T, b::T) where {T<:ConnectivityTracer}
return connectivity_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
end

"""
first_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
excluding second-order interactions of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(+, tracers)
```
"""
function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(first_order_or, ts; init=myempty(T))
end
function first_order_or(a::T, b::T) where {T<:ConnectivityTracer}
return connectivity_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
end

#===========#
# Utilities #
#===========#

function split_dual_array(A::AbstractArray{D}) where {D<:Dual}
primals = getproperty.(A, :primal)
tracers = getproperty.(A, :tracer)
return primals, tracers
end
function split_dual_array(A::SparseArrays.SparseMatrixCSC{D}) where {D<:Dual}
A = Matrix(A)
primals = getproperty.(A, :primal)
tracers = getproperty.(A, :tracer)
return sparse(primals), sparse(tracers)
end

#==================#
# LinearAlgebra.jl #
#==================#

# TODO: replace `second_order_or` by less conservative sparsity patterns when possible

## Determinant
LinearAlgebra.det(A::AbstractMatrix{T}) where {T<:AbstractTracer} = second_order_or(A)
LinearAlgebra.logdet(A::AbstractMatrix{T}) where {T<:AbstractTracer} = second_order_or(A)
function LinearAlgebra.logabsdet(A::AbstractMatrix{T}) where {T<:AbstractTracer}
t1 = second_order_or(A)
t2 = sign(t1) # corresponds to sign of det(A): set first- and second-order derivatives to zero
return (t1, t2)
end

## Norm
function LinearAlgebra.norm(A::AbstractArray{T}, p::Real=2) where {T<:AbstractTracer}
return second_order_or(A)
end
function LinearAlgebra.opnorm(A::AbstractArray{T}, p::Real=2) where {T<:AbstractTracer}
if isone(p)
return first_order_or(A)
else
return second_order_or(A)
end
end
function LinearAlgebra.opnorm(A::AbstractMatrix{T}, p::Real=2) where {T<:AbstractTracer}
if isone(p)
return first_order_or(A)
else
return second_order_or(A)
end
end

## Eigenvalues

function LinearAlgebra.eigmax(
A::Union{T,AbstractMatrix{T}}; permute::Bool=true, scale::Bool=true
) where {T<:AbstractTracer}
return second_order_or(A)
end
function LinearAlgebra.eigmin(
A::Union{T,AbstractMatrix{T}}; permute::Bool=true, scale::Bool=true
) where {T<:AbstractTracer}
return second_order_or(A)
end
function LinearAlgebra.eigen(
A::AbstractMatrix{T};
permute::Bool=true,
scale::Bool=true,
sortby::Union{Function,Nothing}=nothing,
) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
n = size(A, 1)
t = second_order_or(A)
values = Fill(t, n)
vectors = Fill(t, n, n)
return LinearAlgebra.Eigen(values, vectors)
end

## Inverse
function LinearAlgebra.inv(A::StridedMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
t = second_order_or(A)
return Fill(t, size(A)...)
end
function LinearAlgebra.pinv(
A::AbstractMatrix{T}; atol::Real=0.0, rtol::Real=0.0
) where {T<:AbstractTracer}
n, m = size(A)
t = second_order_or(A)
return Fill(t, m, n)
end

## Division
function LinearAlgebra.:\(
A::AbstractMatrix{T}, B::AbstractVecOrMat
) where {T<:AbstractTracer}
Ainv = LinearAlgebra.pinv(A)
return Ainv * B
end

## Exponent
function LinearAlgebra.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
n = size(A, 1)
t = second_order_or(A)
return Fill(t, n, n)
end

## Matrix power
function LinearAlgebra.:^(A::AbstractMatrix{T}, p::Integer) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
n = size(A, 1)
if iszero(p)
return Fill(myempty(T), n, n)
elseif isone(p)
return A
else
t = second_order_or(A)
return Fill(t, n, n)
end
end

#==========================#
# LinearAlgebra.jl on Dual #
#==========================#

# `Duals` should use LinearAlgebra's generic fallback implementations
# to compute the "least conservative" sparsity patterns possible on a scalar level.

# The following three methods are a temporary fix for issue #108.
# TODO: instead overload `lu` on AbstractMatrix of Duals.
function LinearAlgebra.det(A::AbstractMatrix{D}) where {D<:Dual}
primals, tracers = split_dual_array(A)
p = LinearAlgebra.logdet(primals)
t = LinearAlgebra.logdet(tracers)
return D(p, t)
end
function LinearAlgebra.logdet(A::AbstractMatrix{D}) where {D<:Dual}
primals, tracers = split_dual_array(A)
p = LinearAlgebra.logdet(primals)
t = LinearAlgebra.logdet(tracers)
return D(p, t)
end
function LinearAlgebra.logabsdet(A::AbstractMatrix{D}) where {D<:Dual}
primals, tracers = split_dual_array(A)
p1, p2 = LinearAlgebra.logabsdet(primals)
t1, t2 = LinearAlgebra.logabsdet(tracers)
return (D(p1, t1), D(p2, t2))
end

#==============#
# SparseArrays #
#==============#

# Conversion of matrices of tracers to SparseMatrixCSC has to be rewritten
# due to use of `count(_isnotzero, M)` in SparseArrays.jl
#
# Code modified from MIT licensed SparseArrays.jl source:
# https://github.com/JuliaSparse/SparseArrays.jl/blob/45dfe459ede2fa1419e7068d4bda92d9d22bd44d/src/sparsematrix.jl#L901-L920
# Copyright (c) 2009-2024: Jeff Bezanson, Stefan Karpinski, Viral B. Shah, and other contributors: https://github.com/JuliaLang/julia/contributors
function SparseArrays.SparseMatrixCSC{Tv,Ti}(
M::StridedMatrix{Tv}
) where {Tv<:AbstractTracer,Ti}
nz = count(!isemptytracer, M)
colptr = zeros(Ti, size(M, 2) + 1)
nzval = Vector{Tv}(undef, nz)
rowval = Vector{Ti}(undef, nz)
colptr[1] = 1
cnt = 1
@inbounds for j in 1:size(M, 2)
for i in 1:size(M, 1)
v = M[i, j]
if !isemptytracer(v)
rowval[cnt] = i
nzval[cnt] = v
cnt += 1
end
end
colptr[j + 1] = cnt
end
return SparseArrays.SparseMatrixCSC(size(M, 1), size(M, 2), colptr, rowval, nzval)
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
Aqua.test_all(
SparseConnectivityTracer;
ambiguities=false,
deps_compat=(ignore=[:Random, :SparseArrays], check_extras=false),
deps_compat=(check_extras=false,),
stale_deps=(ignore=[:Requires],),
persistent_tasks=false,
)
Expand Down Expand Up @@ -82,6 +82,9 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
@testset "HessianTracer" begin
include("test_hessian.jl")
end
@testset "Array overloads" begin
include("test_arrays.jl")
end
end
end

Expand Down
Loading
Loading