Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Migrate to an operator based implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 29, 2024
1 parent d0ce078 commit 7dd7160
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 486 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down Expand Up @@ -53,6 +54,7 @@ PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLOperators = "0.3.8"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
Expand Down
5 changes: 3 additions & 2 deletions ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
module BatchedRoutinesCUDAExt

using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines,
UniformBlockDiagonalMatrix, batchview, nbatches
UniformBlockDiagonalOperator, batchview, nbatches
using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray,
DenseCuMatrix
using ConcreteStructs: @concrete
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!

const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}

const CuUniformBlockDiagonalMatrix{T} = UniformBlockDiagonalMatrix{T, <:CuArray{T, 3}}
const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

include("batched_mul.jl")
include("factorization.jl")
Expand Down
12 changes: 8 additions & 4 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ end

for pT in (:RowMaximum, :RowNonZero, :NoPivot)
@eval begin
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalMatrix, pivot::$pT; kwargs...)
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalOperator, pivot::$pT; kwargs...)
return LinearAlgebra.lu!(A, !(pivot isa NoPivot); kwargs...)
end
end
end

function LinearAlgebra.lu!(
A::CuUniformBlockDiagonalMatrix, pivot::Bool=true; check::Bool=true, kwargs...)
A::CuUniformBlockDiagonalOperator, pivot::Bool=true; check::Bool=true, kwargs...)
pivot_array, info_, factors = CUBLAS.getrf_strided_batched!(A.data, pivot)
info = Array(info_)
check && LinearAlgebra.checknonsingular.(info)
Expand Down Expand Up @@ -82,11 +82,15 @@ function Base.show(io::IO, QR::CuBatchedQR)
return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))")
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalMatrix, ::ColumnNorm; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator; kwargs...)
return LinearAlgebra.qr!(A, NoPivot(); kwargs...)
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalOperator, ::ColumnNorm; kwargs...)
throw(ArgumentError("ColumnNorm is not supported for batched CUDA QR factorization!"))
end

function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator, ::NoPivot; kwargs...)
τ, factors = CUBLAS.geqrf_batched!(collect(batchview(A)))
return CuBatchedQR{eltype(A)}(factors, τ, size(A))
end
Expand Down
8 changes: 4 additions & 4 deletions ext/BatchedRoutinesFiniteDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BatchedRoutinesFiniteDiffExt

using ADTypes: AutoFiniteDiff
using ArrayInterface: matrix_colors, parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, _assert_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, _assert_type
using FastClosures: @closure
using FiniteDiff: FiniteDiff

Expand All @@ -14,15 +14,15 @@ using FiniteDiff: FiniteDiff
ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T}
J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype)
(_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F}
f! = @closure (y, x_) -> copyto!(y, f(x_))
fx = f(x)
J = UniformBlockDiagonalMatrix(similar(
J = UniformBlockDiagonalOperator(similar(
x, promote_type(eltype(fx), eltype(x)), size(fx, 1), size(x, 1), size(x, 2)))
sparsecache = FiniteDiff.JacobianCache(
x, fx, ad.fdjtype; colorvec=matrix_colors(J), sparsity=J)
Expand Down
8 changes: 4 additions & 4 deletions ext/BatchedRoutinesForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, batched_jacobian,
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian,
batched_mul, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
Expand Down Expand Up @@ -117,7 +117,7 @@ end
else
jac_call = :((y, J) = __batched_value_and_jacobian(ad, f, u, $(Val(CK))))
end
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J))))
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalOperator(J))))
end

## Exposed API
Expand All @@ -132,8 +132,8 @@ end
end
J = ForwardDiff.jacobian(f, u, cfg)
(_assert_type(f) && _assert_type(u) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
Expand Down
25 changes: 20 additions & 5 deletions src/BatchedRoutines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ import PrecompileTools: @recompile_invalidations
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
mul!, pinv
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLOperators: AbstractSciMLOperator
end

function __init__()
@static if isdefined(Base.Experimental, :register_error_hint)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
if any(Base.Fix2(isa, UniformBlockDiagonalOperator), exc.args)
print(io, "\nHINT: ")
printstyled(
io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \
operations. If you want this supported open an issue at \
https://github.com/LuxDL/BatchedRoutines.jl to discuss it.";
color=:cyan)
end
end
end
end

const CRC = ChainRulesCore
Expand All @@ -28,7 +44,9 @@ const BatchedMatrix{T} = AbstractArray{T, 3}

include("api.jl")
include("helpers.jl")
include("matrix.jl")

include("operator.jl")
include("factorization.jl")

include("impl/batched_mul.jl")
include("impl/batched_gmres.jl")
Expand All @@ -39,9 +57,6 @@ export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote
export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize,
batched_mul, batched_pinv, batched_transpose
export batchview, nbatches
export UniformBlockDiagonalMatrix

# TODO: Ship a custom GMRES routine & if needed some of the other complex nonlinear solve
# routines
export UniformBlockDiagonalOperator

end
7 changes: 4 additions & 3 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
batched_jacobian(ad, f::F, x, p) where {F}
Use the backend `ad` to compute the Jacobian of `f` at `x` in batched mode. Returns a
[`UniformBlockDiagonalMatrix`](@ref) as the Jacobian.
[`UniformBlockDiagonalOperator`](@ref) as the Jacobian.
!!! warning
Expand Down Expand Up @@ -63,6 +63,7 @@ batched_mul!(C, A, B, α=true, β=false) = _batched_mul!(C, A, B, α, β)
Transpose the first two dimensions of `X`.
"""
batched_transpose(X::BatchedMatrix) = PermutedDimsArray(X, (2, 1, 3))
batched_transpose(X::AbstractMatrix) = reshape(X, 1, size(X, 1), size(X, 2))

"""
batched_adjoint(X::AbstractArray{T, 3}) where {T}
Expand Down Expand Up @@ -101,15 +102,15 @@ batchview(A::AbstractVector{T}) where {T} = isbitstype(T) ? (A,) : A

"""
batched_pinv(A::AbstractArray{T, 3}) where {T}
batched_pinv(A::UniformBlockDiagonalMatrix)
batched_pinv(A::UniformBlockDiagonalOperator)
Compute the pseudo-inverse of `A` in batched mode.
"""
@inline batched_pinv(x::AbstractArray{T, 3}) where {T} = _batched_map(pinv, x)

"""
batched_inv(A::AbstractArray{T, 3}) where {T}
batched_inv(A::UniformBlockDiagonalMatrix)
batched_inv(A::UniformBlockDiagonalOperator)
Compute the inverse of `A` in batched mode.
"""
Expand Down
36 changes: 7 additions & 29 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,40 +101,18 @@ function CRC.rrule(::typeof(_batched_mul), A::AbstractArray{T1, 3},
return batched_mul(A, B), ∇batched_mul
end

function CRC.rrule(::typeof(*), X::UniformBlockDiagonalMatrix{<:Union{Real, Complex}},
Y::AbstractMatrix{<:Union{Real, Complex}})
function ∇times(_Δ)
Δ = CRC.unthunk(_Δ)
∂X = CRC.@thunk*batched_adjoint(batched_reshape(Y, :, 1)))
∂Y = CRC.@thunk begin
res = (X' * Δ)
Y isa UniformBlockDiagonalMatrix ? res : dropdims(res.data; dims=2)
end
return (NoTangent(), ∂X, ∂Y)
end
return X * Y, ∇times
end

function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(*),
X::AbstractMatrix{<:Union{Real, Complex}},
Y::UniformBlockDiagonalMatrix{<:Union{Real, Complex}})
_f = @closure (x, y) -> dropdims(
batched_mul(reshape(x, :, 1, nbatches(x)), y.data); dims=1)
return CRC.rrule_via_ad(cfg, _f, X, Y)
end

# constructor
function CRC.rrule(::Type{<:UniformBlockDiagonalMatrix}, data)
function UniformBlockDiagonalMatrix(Δ)
∂data = Δ isa UniformBlockDiagonalMatrix ? Δ.data :
function CRC.rrule(::Type{<:UniformBlockDiagonalOperator}, data)
function UniformBlockDiagonalOperator(Δ)
∂data = Δ isa UniformBlockDiagonalOperator ? getdata(Δ) :
isa NoTangent ? NoTangent() : Δ)
return (NoTangent(), ∂data)
end
return UniformBlockDiagonalMatrix(data), ∇UniformBlockDiagonalMatrix
return UniformBlockDiagonalOperator(data), ∇UniformBlockDiagonalOperator
end

function CRC.rrule(::typeof(getproperty), A::UniformBlockDiagonalMatrix, x::Symbol)
function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::Symbol)
@assert x === :data
∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalMatrix(Δ))
return A.data, ∇getproperty
∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalOperator(Δ))
return op.data, ∇getproperty
end
115 changes: 115 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end

const AdjointAbstractBatchedMatrixFactorization{T} = LinearAlgebra.AdjointFactorization{
T, <:AbstractBatchedMatrixFactorization{T}}
const TransposeAbstractBatchedMatrixFactorization{T} = LinearAlgebra.TransposeFactorization{
T, <:AbstractBatchedMatrixFactorization{T}}
const AdjointOrTransposeAbstractBatchedMatrixFactorization{T} = Union{
AdjointAbstractBatchedMatrixFactorization{T},
TransposeAbstractBatchedMatrixFactorization{T}}

const AllAbstractBatchedMatrixFactorization{T} = Union{
AbstractBatchedMatrixFactorization{T},
AdjointOrTransposeAbstractBatchedMatrixFactorization{T}}

nbatches(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = nbatches(parent(f))
batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = batchview(parent(f))
function batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization, idx::Int)
return batchview(parent(f), idx)
end

# First we take inputs and standardize them
function LinearAlgebra.ldiv!(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
LinearAlgebra.ldiv!(A, reshape(b, :, nbatches(A)))
return b
end

function LinearAlgebra.ldiv!(
X::AbstractVector, A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
LinearAlgebra.ldiv!(reshape(X, :, nbatches(A)), A, reshape(b, :, nbatches(A)))
return X
end

function Base.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector)
X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1))
LinearAlgebra.ldiv!(X, A, b)
return X
end

function Base.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractMatrix)
X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1))
LinearAlgebra.ldiv!(X, A, vec(b))
return reshape(X, :, nbatches(b))
end

# Now we implement the actual factorizations
## This just loops over the batches and calls the factorization on each, mostly used where
## we don't have native batched factorizations
struct GenericBatchedFactorization{T, A, F} <: AbstractBatchedMatrixFactorization{T}
alg::A
fact::Vector{F}

function GenericBatchedFactorization(alg, fact)
return GenericBatchedFactorization{eltype(first(fact))}(alg, fact)
end

function GenericBatchedFactorization{T}(alg::A, fact::Vector{F}) where {T, A, F}
return new{T, A, F}(alg, fact)
end
end

nbatches(F::GenericBatchedFactorization) = length(F.fact)
batchview(F::GenericBatchedFactorization) = F.fact
batchview(F::GenericBatchedFactorization, idx::Int) = F.fact[idx]
Base.size(F::GenericBatchedFactorization) = size(first(F.fact)) .* length(F.fact)
function Base.size(F::GenericBatchedFactorization, i::Integer)
return size(first(F.fact), i) * length(F.fact)
end

function LinearAlgebra.issuccess(fact::GenericBatchedFactorization)
return all(LinearAlgebra.issuccess, fact.fact)
end

function Base.adjoint(fact::GenericBatchedFactorization{T}) where {T}
return GenericBatchedFactorization{T}(fact.alg, adjoint.(fact.fact))
end

function Base.show(io::IO, mime::MIME"text/plain", F::GenericBatchedFactorization)
println(io, "GenericBatchedFactorization() with Batch Count: $(nbatches(F))")
Base.printstyled(io, "Factorization Function: "; color=:green)
show(io, mime, F.alg)
Base.printstyled(io, "\nPrototype Factorization: "; color=:green)
show(io, mime, first(F.fact))
end

for fact in (:qr, :lu, :cholesky)
fact! = Symbol(fact, :!)
@eval begin
function LinearAlgebra.$(fact)(op::UniformBlockDiagonalOperator, args...; kwargs...)
return LinearAlgebra.$(fact!)(copy(op), args...; kwargs...)
end

function LinearAlgebra.$(fact!)(
op::UniformBlockDiagonalOperator, args...; kwargs...)
fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op))
return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact)
end
end
end

function LinearAlgebra.ldiv!(A::GenericBatchedFactorization, b::AbstractMatrix)
@assert nbatches(A) == nbatches(b)
for i in 1:nbatches(A)
LinearAlgebra.ldiv!(batchview(A, i), batchview(b, i))
end
return b
end

function LinearAlgebra.ldiv!(
X::AbstractMatrix, A::GenericBatchedFactorization, b::AbstractMatrix)
@assert nbatches(A) == nbatches(b) == nbatches(X)
for i in 1:nbatches(A)
LinearAlgebra.ldiv!(batchview(X, i), batchview(A, i), batchview(b, i))
end
return X
end
3 changes: 0 additions & 3 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ end
return promote_type(T, eltype(f.x)), false
end

# eachrow override
@inline _eachrow(X) = eachrow(X)

# MLUtils.jl has too many unwanted dependencies
@inline fill_like(x::AbstractArray, v, ::Type{T}, dims...) where {T} = fill!(
similar(x, T, dims...), v)
Expand Down
Loading

0 comments on commit 7dd7160

Please sign in to comment.