Skip to content

Commit

Permalink
Specify type of MPO and pass OpCacheVec as strings
Browse files Browse the repository at this point in the history
  • Loading branch information
corbett5 committed Feb 26, 2024
1 parent e0cb5bf commit 5133de1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 21 deletions.
50 changes: 36 additions & 14 deletions src/MPOConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ function for_non_zeros_batch(f::Function, A::SparseMatrixCSC, cols::Vector{Int})
end

@timeit function at_site!(
ValType::Type{<:Number},
graphs::Dict{QN,MPOGraph{C}},
n::Int,
sites::Vector{<:Index},
tol::Real,
opCacheVec::OpCacheVec,
)::Tuple{Dict{QN,MPOGraph{C}},BlockSparseMatrix{C},Index} where {C}
)::Tuple{Dict{QN,MPOGraph{C}},BlockSparseMatrix{ValType},Index} where {C}
hasQNs = hasqns(sites)
nextGraphs = Dict{QN,MPOGraph{C}}()

matrix = BlockSparseMatrix{C}()
matrix = BlockSparseMatrix{ValType}()

qi = Vector{Pair{QN,Int}}()
outgoingLinkOffset = 0
Expand Down Expand Up @@ -194,7 +195,9 @@ end

llinks = Vector{Index}(undef, N + 1)
for n in 0:N
graphs, symbolicMatrix, llinks[n + 1] = at_site!(graphs, n, sites, tol, opCacheVec)
graphs, symbolicMatrix, llinks[n + 1] = at_site!(
ValType, graphs, n, sites, tol, opCacheVec
)

# For the 0th iteration we only care about constructing the graphs for the next site.
n == 0 && continue
Expand Down Expand Up @@ -239,27 +242,46 @@ end
return H
end

function svdMPO_new(os::OpIDSum{C}, opCacheVec::OpCacheVec, sites; kwargs...)::MPO where {C}
# Function barrier to improve type stability
ValType = determine_val_type(os)
return svdMPO_new(ValType, os, opCacheVec, sites; kwargs...)
end

function MPO_new(
ValType::Type{<:Number},
os::OpIDSum,
sites::Vector{<:Index},
opCacheVec::OpCacheVec;
basisOpCacheVec::Union{Nothing,OpCacheVec}=nothing,
kwargs...,
basisOpCacheVec=nothing,
tol::Real=-1,
)::MPO
opCacheVec = to_OpCacheVec(sites, opCacheVec)
basisOpCacheVec = to_OpCacheVec(sites, basisOpCacheVec)
os, opCacheVec = prepare_opID_sum!(os, sites, opCacheVec, basisOpCacheVec)
return svdMPO_new(ValType, os, opCacheVec, sites; tol=tol)
end

function MPO_new(
os::OpIDSum, sites::Vector{<:Index}, opCacheVec; basisOpCacheVec=nothing, tol::Real=-1
)::MPO
opCacheVec = to_OpCacheVec(sites, opCacheVec)
ValType = determine_val_type(os, opCacheVec)
return MPO_new(ValType, os, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol)
end

return svdMPO_new(os, opCacheVec, sites; kwargs...)
function MPO_new(
ValType::Type{<:Number},
os::OpSum,
sites::Vector{<:Index};
basisOpCacheVec=nothing,
tol::Real=-1,
)::MPO
opIDSum, opCacheVec = op_sum_to_opID_sum(os, sites)
return MPO_new(
ValType, opIDSum, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol
)
end

function MPO_new(os::OpSum, sites::Vector{<:Index}; kwargs...)::MPO
function MPO_new(
os::OpSum, sites::Vector{<:Index}; basisOpCacheVec=nothing, tol::Real=-1
)::MPO
opIDSum, opCacheVec = op_sum_to_opID_sum(os, sites)
return MPO_new(opIDSum, sites, opCacheVec; kwargs...)
return MPO_new(opIDSum, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol)
end

function sparsity(mpo::MPO)::Float64
Expand Down
31 changes: 24 additions & 7 deletions src/OpIDSum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ end

OpCacheVec = Vector{Vector{OpInfo}}

function to_OpCacheVec(sites::Vector{<:Index}, ops::OpCacheVec)::OpCacheVec
length(sites) != length(ops) &&
error("Mismatch in the number of sites in `sites` and `ops`.")
any(ops[i][1].matrix != I for i in 1:length(sites)) &&
error("The first operator on each site must be the identity.")
return ops
end

function to_OpCacheVec(sites::Vector{<:Index}, ops::Vector{Vector{String}})::OpCacheVec
length(sites) != length(ops) &&
error("Mismatch in the number of sites in `sites` and `ops`.")
any(ops[i][1] != "I" for i in 1:length(sites)) &&
error("The first operator on each site must be the identity.")
return [[OpInfo(ITensors.Op(op, n), sites[n]) for op in ops[n]] for n in 1:length(sites)]
end

function to_OpCacheVec(sites::Vector{<:Index}, ::Nothing)::Nothing
return nothing
end

struct OpID
id::Int16
n::Int16
Expand Down Expand Up @@ -71,13 +91,10 @@ function add_to_scalar!(os::OpIDSum{C}, i::Integer, scalar::C)::Nothing where {C
return nothing
end

# TODO: Define as `C`. Rename `coefficient_type`.
function determine_val_type(os::OpIDSum{C}) where {C}
for i in eachindex(os)
scalar, ops = os[i]
(!isreal(scalar)) && return ComplexF64
end

function determine_val_type(os::OpIDSum{C}, opCacheVec::OpCacheVec) where {C}
!all(isreal(scalar) for scalar in os.scalars) && return ComplexF64
!all(isreal(op.matrix) for opsOfSite in opCacheVec for op in opsOfSite) &&
return ComplexF64
return Float64
end

Expand Down

0 comments on commit 5133de1

Please sign in to comment.