Skip to content

Commit

Permalink
Merge pull request #38 from gridap/adding_snes_support
Browse files Browse the repository at this point in the history
Adding snes support
  • Loading branch information
fverdugo authored Oct 29, 2021
2 parents 9f41fdb + dbcca1b commit 2620049
Show file tree
Hide file tree
Showing 27 changed files with 1,245 additions and 88 deletions.
522 changes: 522 additions & 0 deletions Manifest.toml

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.3.0"

[deps]
Gridap = "56d4f2e9-7ea1-5844-9cf6-b9c51ca7ce8e"
GridapDistributed = "f9701e48-63b3-45aa-9a63-9bc6c271f355"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Expand All @@ -15,7 +16,8 @@ SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"

[compat]
Gridap = "0.17"
MPI = "0.14, 0.15, 0.16"
GridapDistributed = "0.2.0"
MPI = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
PETSc_jll = "3.13"
PartitionedArrays = "0.2.4"
SparseMatricesCSR = "0.6.1"
Expand Down
7 changes: 6 additions & 1 deletion src/GridapPETSc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,16 @@ end
include("PETSC.jl")

using GridapPETSc.PETSC: @check_error_code
using GridapPETSc.PETSC: PetscBool, PetscInt, PetscScalar, Vec, Mat, KSP, PC
using GridapPETSc.PETSC: PetscBool, PetscInt, PetscScalar, Vec, Mat, KSP, PC, SNES
#export PETSC
export @check_error_code
export PetscBool, PetscInt, PetscScalar, Vec, Mat, KSP, PC

include("Environment.jl")

export PETScVector
export get_local_oh_vector
export get_local_vector, restore_local_vector!
export PETScMatrix
export petsc_sparse
include("PETScArrays.jl")
Expand All @@ -82,6 +84,9 @@ include("PartitionedArrays.jl")
export PETScLinearSolver
include("PETScLinearSolvers.jl")

export PETScNonlinearSolver
include("PETScNonlinearSolvers.jl")

include("PETScAssembly.jl")

end # module
70 changes: 63 additions & 7 deletions src/PETSC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ let types_jl = joinpath(@__DIR__,"..","deps","PetscDataTypes.jl")
if !isfile(types_jl)
msg = """
GridapPETSc needs to be configured before use. Type
pkg> build
and try again.
"""
error(msg)
end

include(types_jl)
end

Expand Down Expand Up @@ -98,7 +98,7 @@ macro PETSC_VIEWER_STDOUT_SELF()
quote
PETSC_VIEWER_STDOUT_(MPI.COMM_SELF)
end
end
end

"""
@PETSC_VIEWER_STDOUT_WORLD
Expand All @@ -109,7 +109,7 @@ macro PETSC_VIEWER_STDOUT_WORLD()
quote
PETSC_VIEWER_STDOUT_(MPI.COMM_WORLD)
end
end
end

"""
@PETSC_VIEWER_DRAW_SELF
Expand All @@ -120,7 +120,7 @@ macro PETSC_VIEWER_DRAW_SELF()
quote
PETSC_VIEWER_DRAW_(MPI.COMM_SELF)
end
end
end

"""
@PETSC_VIEWER_DRAW_WORLD
Expand All @@ -131,7 +131,7 @@ macro PETSC_VIEWER_DRAW_WORLD()
quote
PETSC_VIEWER_DRAW_(MPI.COMM_WORLD)
end
end
end

# Vector related functions

Expand Down Expand Up @@ -197,7 +197,15 @@ Base.convert(::Type{Vec},p::Ptr{Cvoid}) = Vec(p)
@wrapper(:VecSetValues,PetscErrorCode,(Vec,PetscInt,Ptr{PetscInt},Ptr{PetscScalar},InsertMode),(x,ni,ix,y,iora),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecSetValues.html")
@wrapper(:VecGetValues,PetscErrorCode,(Vec,PetscInt,Ptr{PetscInt},Ptr{PetscScalar}),(x,ni,ix,y),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecGetValues.html")
@wrapper(:VecGetArray,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/release/docs/manualpages/Vec/VecGetArray.html")
@wrapper(:VecGetArrayRead,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/
release/docs/manualpages/Vec/VecGetArrayRead.html")
@wrapper(:VecGetArrayWrite,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/release/docs/manualpages/Vec/VecGetArrayWrite.html")
@wrapper(:VecRestoreArray,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/release/docs/manualpages/Vec/VecRestoreArray.html")
@wrapper(:VecRestoreArrayRead,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/
release/docs/manualpages/Vec/VecRestoreArrayRead.html")
@wrapper(:VecRestoreArrayWrite,PetscErrorCode,(Vec,Ptr{Ptr{PetscScalar}}),(x,a),"https://petsc.org/release/docs/manualpages/Vec/VecRestoreArrayWrite.html")
@wrapper(:VecGetSize,PetscErrorCode,(Vec,Ptr{PetscInt}),(vec,n),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecGetSize.html")
@wrapper(:VecGetLocalSize,PetscErrorCode,(Vec,Ptr{PetscInt}),(vec,n),"https://petsc.org/release/docs/manualpages/Vec/VecGetLocalSize.html")
@wrapper(:VecAssemblyBegin,PetscErrorCode,(Vec,),(vec,),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecAssemblyBegin.html")
@wrapper(:VecAssemblyEnd,PetscErrorCode,(Vec,),(vec,),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecAssemblyEnd.html")
@wrapper(:VecPlaceArray,PetscErrorCode,(Vec,Ptr{PetscScalar}),(vec,array),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecPlaceArray.html")
Expand All @@ -211,6 +219,9 @@ Base.convert(::Type{Vec},p::Ptr{Cvoid}) = Vec(p)
@wrapper(:VecAXPBY,PetscErrorCode,(Vec,PetscScalar,PetscScalar,Vec),(y,alpha,beta,x),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecAXPBY.html")
@wrapper(:VecSetOption,PetscErrorCode,(Vec,VecOption,PetscBool),(x,op,flg),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecSetOption.html")
@wrapper(:VecNorm,PetscErrorCode,(Vec,NormType,Ptr{PetscReal}),(x,typ,val),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/Vec/VecNorm.html")
@wrapper(:VecGhostGetLocalForm,PetscErrorCode,(Vec,Ptr{Vec}),(g,l),"https://petsc.org/release/docs/
manualpages/Vec/VecGhostGetLocalForm.html")
@wrapper(:VecGhostRestoreLocalForm,PetscErrorCode,(Vec,Ptr{Vec}),(g,l),"https://petsc.org/release/docs/manualpages/Vec/VecGhostRestoreLocalForm.html")

# Matrix related functions

Expand Down Expand Up @@ -607,4 +618,49 @@ Base.convert(::Type{PC},p::Ptr{Cvoid}) = PC(p)
@wrapper(:KSPGetPC,PetscErrorCode,(KSP,Ptr{PC}),(ksp,pc),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/KSP/KSPGetPC.html")
@wrapper(:PCSetType,PetscErrorCode,(PC,PCType),(pc,typ),"https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCSetType.html")


"""
Julia alias for the `SNES` C type.
See [PETSc manual](https://petsc.org/release/docs/manualpages/SNES/SNES.html).
"""
struct SNES
ptr::Ptr{Cvoid}
end
SNES() = SNES(Ptr{Cvoid}())
Base.convert(::Type{SNES},p::Ptr{Cvoid}) = SNES(p)

const SNESType = Cstring
const SNESNEWTONLS = "newtonls"
const SNESNEWTONTR = "newtontr"
const SNESPYTHON = "python"
const SNESNRICHARDSON = "nrichardson"
const SNESKSPONLY = "ksponly"
const SNESKSPTRANSPOSEONLY = "ksptransposeonly"
const SNESVINEWTONRSLS = "vinewtonrsls"
const SNESVINEWTONSSLS = "vinewtonssls"
const SNESNGMRES = "ngmres"
const SNESQN = "qn"
const SNESSHELL = "shell"
const SNESNGS = "ngs"
const SNESNCG = "ncg"
const SNESFAS = "fas"
const SNESMS = "ms"
const SNESNASM = "nasm"
const SNESANDERSON = "anderson"
const SNESASPIN = "aspin"
const SNESCOMPOSITE = "composite"
const SNESPATCH = "patch"


@wrapper(:SNESCreate,PetscErrorCode,(MPI.Comm,Ptr{SNES}),(comm,snes),"https://petsc.org/release/docs/manualpages/SNES/SNESCreate.html")
@wrapper(:SNESSetFunction,PetscErrorCode,(SNES,Vec,Ptr{Cvoid},Ptr{Cvoid}),(snes,vec,fptr,ctx),"https://petsc.org/release/docs/manualpages/SNES/SNESSetFunction.html")
@wrapper(:SNESSetJacobian,PetscErrorCode,(SNES,Mat,Mat,Ptr{Cvoid},Ptr{Cvoid}),(snes,A,P,jacptr,ctx),"https://petsc.org/release/docs/manualpages/SNES/SNESSetJacobian.html")
@wrapper(:SNESSolve,PetscErrorCode,(SNES,Vec,Vec),(snes,b,x),"https://petsc.org/release/docs/manualpages/SNES/SNESSolve.html")
@wrapper(:SNESDestroy,PetscErrorCode,(Ptr{SNES},),(snes,),"https://petsc.org/release/docs/manualpages/SNES/SNESDestroy.html")
@wrapper(:SNESSetFromOptions,PetscErrorCode,(SNES,),(snes,),"https://petsc.org/release/docs/manualpages/SNES/SNESSetFromOptions.html")
@wrapper(:SNESView,PetscErrorCode,(SNES,PetscViewer),(snes,viewer),"https://petsc.org/release/docs/manualpages/SNES/SNESView.html")
@wrapper(:SNESSetType,PetscErrorCode,(SNES,SNESType),(snes,type),"https://petsc.org/release/docs/
manualpages/SNES/SNESSetType.html")

end # module
121 changes: 120 additions & 1 deletion src/PETScArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,83 @@ function Base.convert(::Type{PETScVector},a::AbstractVector)
PETScVector(array)
end

function Base.copy!(a::AbstractVector,petsc_vec::Vec)
aux=PETScVector()
aux.vec[] = petsc_vec.ptr
Base.copy!(a,aux)
end

function Base.copy!(petsc_vec::Vec,a::AbstractVector)
aux=PETScVector()
aux.vec[] = petsc_vec.ptr
Base.copy!(aux,a)
end

function Base.copy!(vec::AbstractVector,petscvec::PETScVector)
lg=get_local_oh_vector(petscvec)
if isa(lg,PETScVector) # petscvec is a ghosted vector
lx=get_local_vector(lg)
@assert length(lx)==length(vec)
vec .= lx
restore_local_vector!(lx,lg)
GridapPETSc.Finalize(lg)
else # petscvec is NOT a ghosted vector
@assert length(lg)==length(vec)
vec .= lg
restore_local_vector!(lg,petscvec)
end
end

function Base.copy!(petscvec::PETScVector,vec::AbstractVector)
lg=get_local_oh_vector(petscvec)
if isa(lg,PETScVector) # petscvec is a ghosted vector
lx=get_local_vector(lg)
@assert length(lx)==length(vec)
lx .= vec
restore_local_vector!(lx,lg)
GridapPETSc.Finalize(lg)
else # petscvec is NOT a ghosted vector
@assert length(lg)==length(vec)
lg .= vec
restore_local_vector!(lg,petscvec)
end
end



function get_local_oh_vector(a::PETScVector)
v=PETScVector()
@check_error_code PETSC.VecGhostGetLocalForm(a.vec[],v.vec)
if v.vec[] != C_NULL # a is a ghosted vector
v.ownership=a
Init(v)
return v
else # a is NOT a ghosted vector
return get_local_vector(a)
end
end

function _local_size(a::PETScVector)
r_sz = Ref{PetscInt}()
@check_error_code PETSC.VecGetLocalSize(a.vec[], r_sz)
r_sz[]
end

# This function works with either ghosted or non-ghosted MPI vectors.
# In the case of a ghosted vector it solely returns the locally owned
# entries.
function get_local_vector(a::PETScVector)
r_pv = Ref{Ptr{PetscScalar}}()
@check_error_code PETSC.VecGetArray(a.vec[], r_pv)
v = unsafe_wrap(Array, r_pv[], _local_size(a); own = false)
return v
end

function restore_local_vector!(v::Array,a::PETScVector)
@check_error_code PETSC.VecRestoreArray(a.vec[], Ref(pointer(v)))
nothing
end

# Matrix

mutable struct PETScMatrix <: AbstractMatrix{PetscScalar}
Expand Down Expand Up @@ -194,6 +271,7 @@ function PETScMatrix(csr::SparseMatrixCSR{0,PetscScalar,PetscInt})
Init(A)
end


function Base.similar(::PETScMatrix,::Type{PetscScalar},ax::Tuple{Int,Int})
PETScMatrix(ax[1],ax[2])
end
Expand Down Expand Up @@ -225,6 +303,35 @@ function Base.copy(a::PETScMatrix)
Init(v)
end

function Base.copy!(petscmat::Mat,a::AbstractMatrix)
aux=PETScMatrix()
aux.mat[] = petscmat.ptr
Base.copy!(aux,a)
end


function Base.copy!(petscmat::PETScMatrix,mat::AbstractMatrix)
n = size(mat)[2]
cols = [PetscInt(j-1) for j=1:n]
row = Vector{PetscInt}(undef,1)
vals = Vector{eltype(mat)}(undef,n)
for i=1:size(mat)[1]
row[1]=PetscInt(i-1)
vals .= view(mat,i,:)
PETSC.MatSetValues(petscmat.mat[],
PetscInt(1),
row,
n,
cols,
vals,
PETSC.INSERT_VALUES)
end
@check_error_code PETSC.MatAssemblyBegin(petscmat.mat[], PETSC.MAT_FINAL_ASSEMBLY)
@check_error_code PETSC.MatAssemblyEnd(petscmat.mat[] , PETSC.MAT_FINAL_ASSEMBLY)
end



function Base.convert(::Type{PETScMatrix},a::PETScMatrix)
a
end
Expand All @@ -235,6 +342,19 @@ function Base.convert(::Type{PETScMatrix},a::AbstractSparseMatrix)
PETScMatrix(csr)
end

function Base.convert(::Type{PETScMatrix}, a::AbstractMatrix{PetscScalar})
m, n = size(a)
i = [PetscInt(n*(i-1)) for i=1:m+1]
j = [PetscInt(j-1) for i=1:m for j=1:n]
v = [ a[i,j] for i=1:m for j=1:n]
A = PETScMatrix()
A.ownership = a
@check_error_code PETSC.MatCreateSeqAIJWithArrays(MPI.COMM_SELF,m,n,i,j,v,A.mat)
@check_error_code PETSC.MatAssemblyBegin(A.mat[],PETSC.MAT_FINAL_ASSEMBLY)
@check_error_code PETSC.MatAssemblyEnd(A.mat[],PETSC.MAT_FINAL_ASSEMBLY)
Init(A)
end

function petsc_sparse(i,j,v,m,n)
csr = sparsecsr(Val(0),i,j,v,m,n)
convert(PETScMatrix,csr)
Expand Down Expand Up @@ -313,4 +433,3 @@ function LinearAlgebra.norm(a::PETScVector, p::Real=2)
@check_error_code PETSC.VecNorm(a.vec[],nt,val)
Float64(val[])
end

4 changes: 2 additions & 2 deletions src/PETScLinearSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ struct PETScLinearSolver{F} <: LinearSolver
comm::MPI.Comm
end

from_options(ksp) = @check_error_code PETSC.KSPSetFromOptions(ksp[])
ksp_from_options(ksp) = @check_error_code PETSC.KSPSetFromOptions(ksp[])

function PETScLinearSolver(comm::MPI.Comm)
PETScLinearSolver(from_options,comm)
PETScLinearSolver(ksp_from_options,comm)
end

PETScLinearSolver() = PETScLinearSolver(MPI.COMM_WORLD)
Expand Down
Loading

0 comments on commit 2620049

Please sign in to comment.