This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

[WIP] GMRES Implementation #8

Open · wants to merge 1 commit into base: main
Project.toml (1 addition, 0 deletions)
@@ -14,6 +14,7 @@
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

 [weakdeps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl (2 additions, 0 deletions)
@@ -5,13 +5,15 @@ using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines,
 using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray,
     DenseCuMatrix
 using ConcreteStructs: @concrete
+using FastClosures: @closure
 using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!

 const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}

 const CuUniformBlockDiagonalMatrix{T} = UniformBlockDiagonalMatrix{T, <:CuArray{T, 3}}

 include("batched_mul.jl")
+include("kernels.jl")
 include("factorization.jl")

 end
ext/BatchedRoutinesCUDAExt/kernels.jl (new file)
+# Some of these are based off
+# https://github.com/JaneliaSciComp/BatchedBLAS.jl/blob/master/src/BatchedBLAS.jl
+## https://github.com/JuliaLang/julia/issues/40469
+@inline _maybe_cast(::Type, x) = x
+@inline function _maybe_cast(::Type{T}, x::AbstractFloat) where {T <: Integer}
+    return round(T, clamp(x, typemin(T), typemax(T)))
+end
+
+# One thread per batch: o[k] accumulates dot(x[:, k], y[:, k]) in precision T.
+@inline function batched_dot_kernel!(::Type{T}, o, x, y) where {T}
+    k = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
+    @inbounds if k ≤ size(x, 2)
+        tmp = T(0)
+        for i in 1:size(x, 1)
+            tmp += x[i, k] * T(y[i, k])
+        end
+        o[k] = _maybe_cast(eltype(o), tmp)
+    end
+    return nothing
+end
+
+function BatchedRoutines.batched_dot!(A::CuVector, B::CuMatrix, C::CuMatrix)
+    T = promote_type(eltype(A), eltype(B), eltype(C))
+    # Launch enough threads to cover every batch; the default single-thread launch
+    # would only compute the first dot product.
+    nbatch = size(B, 2)
+    CUDA.@cuda name="batched_dot!" threads=min(nbatch, 256) blocks=cld(nbatch, 256) batched_dot_kernel!(T, A, B, C)
+    return A
+end
+
+# TODO: batched_axpy!
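For orientation, here is a minimal usage sketch of the new `batched_dot!` overload. The sizes, element type, and values below are illustrative and not part of the PR:

```julia
using CUDA, BatchedRoutines

B = CUDA.rand(Float32, 16, 8)   # 8 batches of length-16 vectors (one per column)
C = CUDA.rand(Float32, 16, 8)
A = CUDA.zeros(Float32, 8)      # receives one dot product per batch

BatchedRoutines.batched_dot!(A, B, C)   # A[k] ≈ dot(B[:, k], C[:, k])
```

The `_maybe_cast` helper additionally allows the output vector to have an integer element type, in which case the accumulated value is rounded and clamped to that type's range.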
src/BatchedRoutines.jl (8 additions, 5 deletions)
@@ -11,9 +11,10 @@ import PrecompileTools: @recompile_invalidations
     using ConcreteStructs: @concrete
     using FastClosures: @closure
     using FillArrays: Fill, OneElement
-    using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
-        mul!, pinv
+    using LinearAlgebra: BLAS, ColumnNorm, I, LinearAlgebra, NoPivot, RowMaximum,
+        RowNonZero, axpby!, axpy!, mul!, norm, pinv
     using LuxDeviceUtils: LuxDeviceUtils, get_device
+    using Printf: @printf
 end

 const CRC = ChainRulesCore

@@ -30,18 +31,20 @@ include("api.jl")
 include("helpers.jl")
 include("matrix.jl")

+include("internal.jl")

 include("impl/batched_mul.jl")
+include("impl/batched_gmres.jl")

 include("chainrules.jl")

 export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote
 export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize,
-    batched_mul, batched_pinv, batched_transpose
+    batched_mul, batched_norm, batched_pinv, batched_transpose
 export batchview, nbatches
 export UniformBlockDiagonalMatrix

-# TODO: Ship a custom GMRES routine & if needed some of the other complex nonlinear solve
-# routines
+# Special Solvers
+export BatchedGmresSolver, batched_gmres, batched_gmres!

 end
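The solver itself lives in `src/impl/batched_gmres.jl`, which is not shown in this diff, so the exact call signature is unconfirmed. As a rough orientation only, the exported names suggest a Krylov-style interface along these lines; everything below, including the constructor arguments, is an assumption:

```julia
using BatchedRoutines

# Hypothetical problem: 4 independent 8×8 linear systems, one per batch.
A = UniformBlockDiagonalMatrix(randn(8, 8, 4))  # assumed constructor from a 3D array
b = randn(8, 4)                                 # one right-hand side per column/batch

# Assumed out-of-place call: solve A * x ≈ b batch-wise with GMRES.
x = batched_gmres(A, b)

# Assumed in-place variant reusing a preallocated workspace.
solver = BatchedGmresSolver(A, b)
batched_gmres!(solver, A, b)
```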
src/api.jl (5 additions, 2 deletions)
@@ -55,8 +55,6 @@ TODO: Needs Documentation (take from NNlib.jl)
 """
 batched_mul(A, B) = _batched_mul(A, B)

-batched_mul!(C, A, B) = _batched_mul!(C, A, B)
-
 """
     batched_transpose(X::AbstractArray{T, 3}) where {T}

@@ -128,3 +126,8 @@ Reshape `x` into a matrix with the batch dimension as the last dimension.
 Reshape `x` into an array with the batch dimension as the last dimension.
 """
 @inline batched_reshape(x::AbstractArray, dims...) = reshape(x, dims..., nbatches(x))
+
+@inline function batched_norm(x::AbstractMatrix; drop::Val{D}=Val(true)) where {D}
+    y = sqrt.(sum(abs2, x; dims=1))
+    return D ? dropdims(y; dims=1) : y
+end
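A quick illustration of the new `batched_norm`: it takes the Euclidean norm of each column (batch), and `drop=Val(false)` keeps the singleton row dimension. The input values are made up and the results computed by hand:

```julia
using BatchedRoutines

x = [3.0 0.0;
     4.0 5.0]                        # 2 batches (columns)

batched_norm(x)                      # [5.0, 5.0] — a length-2 vector
batched_norm(x; drop=Val(false))     # [5.0 5.0]  — a 1×2 matrix
```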
src/chainrules.jl (4 additions, 9 deletions)
@@ -50,8 +50,7 @@ function CRC.rrule(::typeof(batched_gradient), ad, f::F, x) where {F}
     dx = batched_gradient(ad, f, x)
     ∇batched_gradient = @closure Δ -> begin
         ∂x = _jacobian_vector_product(
-            AutoForwardDiff(), @closure(x->batched_gradient(ad, f, x)),
-            x, reshape(Δ, size(x)))
+            AutoForwardDiff(), @closure(x->batched_gradient(ad, f, x)), x, Δ)
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
     return dx, ∇batched_gradient

@@ -70,14 +69,10 @@ function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
     dx = batched_gradient(ad, f, x, p)
     ∇batched_gradient = @closure Δ -> begin
         ∂x = _jacobian_vector_product(
-            AutoForwardDiff(), @closure(x->batched_gradient(ad, Base.Fix2(f, p), x)),
-            x, reshape(Δ, size(x)))
+            AutoForwardDiff(), @closure(x->batched_gradient(ad, Base.Fix2(f, p), x)), x, Δ)
+        ad_ = _maybe_remove_chunksize(ad, p)
         ∂p = _jacobian_vector_product(AutoForwardDiff(),
-            @closure((x, p)->batched_gradient(
-                _maybe_remove_chunksize(ad, p), Base.Fix1(f, x), p)),
-            x,
-            reshape(Δ, size(x)),
-            p)
+            @closure((x, p)->batched_gradient(ad_, Base.Fix1(f, x), p)), x, Δ, p)
        return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
     return dx, ∇batched_gradient
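To see where these adjusted pullbacks come into play, here is a hedged end-to-end sketch of differentiating through `batched_gradient`. It assumes Zygote is loaded so the rrules above are picked up; `f`, `x`, and `p` are illustrative and not from the PR:

```julia
using BatchedRoutines, Zygote

f = (x, p) -> sum(abs2, x .* p)   # scalar objective over a batch of columns
x = randn(4, 3)                   # 3 batches of length-4 vectors
p = randn(4, 3)

# Forward: per-batch gradients of f with respect to x.
g = batched_gradient(AutoForwardDiff(), f, x, p)

# Reverse: differentiating the call above routes through the rrule in this diff,
# which propagates Δ forward-over-reverse via _jacobian_vector_product.
∂x, ∂p = Zygote.gradient((x, p) -> sum(batched_gradient(AutoForwardDiff(), f, x, p)), x, p)
```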