diff --git a/src/MPOConstruction.jl b/src/MPOConstruction.jl index 7495b99..7954f24 100644 --- a/src/MPOConstruction.jl +++ b/src/MPOConstruction.jl @@ -53,16 +53,17 @@ function for_non_zeros_batch(f::Function, A::SparseMatrixCSC, cols::Vector{Int}) end @timeit function at_site!( + ValType::Type{<:Number}, graphs::Dict{QN,MPOGraph{C}}, n::Int, sites::Vector{<:Index}, tol::Real, opCacheVec::OpCacheVec, -)::Tuple{Dict{QN,MPOGraph{C}},BlockSparseMatrix{C},Index} where {C} +)::Tuple{Dict{QN,MPOGraph{C}},BlockSparseMatrix{ValType},Index} where {C} hasQNs = hasqns(sites) nextGraphs = Dict{QN,MPOGraph{C}}() - matrix = BlockSparseMatrix{C}() + matrix = BlockSparseMatrix{ValType}() qi = Vector{Pair{QN,Int}}() outgoingLinkOffset = 0 @@ -194,7 +195,9 @@ end llinks = Vector{Index}(undef, N + 1) for n in 0:N - graphs, symbolicMatrix, llinks[n + 1] = at_site!(graphs, n, sites, tol, opCacheVec) + graphs, symbolicMatrix, llinks[n + 1] = at_site!( + ValType, graphs, n, sites, tol, opCacheVec + ) # For the 0th iteration we only care about constructing the graphs for the next site. n == 0 && continue @@ -239,27 +242,46 @@ end return H end -function svdMPO_new(os::OpIDSum{C}, opCacheVec::OpCacheVec, sites; kwargs...)::MPO where {C} - # Function barrier to improve type stability - ValType = determine_val_type(os) - return svdMPO_new(ValType, os, opCacheVec, sites; kwargs...) -end - function MPO_new( + ValType::Type{<:Number}, os::OpIDSum, sites::Vector{<:Index}, opCacheVec::OpCacheVec; - basisOpCacheVec::Union{Nothing,OpCacheVec}=nothing, - kwargs..., + basisOpCacheVec=nothing, + tol::Real=-1, )::MPO + opCacheVec = to_OpCacheVec(sites, opCacheVec) + basisOpCacheVec = to_OpCacheVec(sites, basisOpCacheVec) os, opCacheVec = prepare_opID_sum!(os, sites, opCacheVec, basisOpCacheVec) + return svdMPO_new(ValType, os, opCacheVec, sites; tol=tol) +end + +function MPO_new( + os::OpIDSum, sites::Vector{<:Index}, opCacheVec; basisOpCacheVec=nothing, tol::Real=-1 +)::MPO + opCacheVec = to_OpCacheVec(sites, opCacheVec) + ValType = determine_val_type(os, opCacheVec) + return MPO_new(ValType, os, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol) +end - return svdMPO_new(os, opCacheVec, sites; kwargs...) +function MPO_new( + ValType::Type{<:Number}, + os::OpSum, + sites::Vector{<:Index}; + basisOpCacheVec=nothing, + tol::Real=-1, +)::MPO + opIDSum, opCacheVec = op_sum_to_opID_sum(os, sites) + return MPO_new( + ValType, opIDSum, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol + ) end -function MPO_new(os::OpSum, sites::Vector{<:Index}; kwargs...)::MPO +function MPO_new( + os::OpSum, sites::Vector{<:Index}; basisOpCacheVec=nothing, tol::Real=-1 +)::MPO opIDSum, opCacheVec = op_sum_to_opID_sum(os, sites) - return MPO_new(opIDSum, sites, opCacheVec; kwargs...) + return MPO_new(opIDSum, sites, opCacheVec; basisOpCacheVec=basisOpCacheVec, tol=tol) end function sparsity(mpo::MPO)::Float64 diff --git a/src/OpIDSum.jl b/src/OpIDSum.jl index 8a3202c..ba527dc 100644 --- a/src/OpIDSum.jl +++ b/src/OpIDSum.jl @@ -14,6 +14,26 @@ end OpCacheVec = Vector{Vector{OpInfo}} +function to_OpCacheVec(sites::Vector{<:Index}, ops::OpCacheVec)::OpCacheVec + length(sites) != length(ops) && + error("Mismatch in the number of sites in `sites` and `ops`.") + any(ops[i][1].matrix != I for i in 1:length(sites)) && + error("The first operator on each site must be the identity.") + return ops +end + +function to_OpCacheVec(sites::Vector{<:Index}, ops::Vector{Vector{String}})::OpCacheVec + length(sites) != length(ops) && + error("Mismatch in the number of sites in `sites` and `ops`.") + any(ops[i][1] != "I" for i in 1:length(sites)) && + error("The first operator on each site must be the identity.") + return [[OpInfo(ITensors.Op(op, n), sites[n]) for op in ops[n]] for n in 1:length(sites)] +end + +function to_OpCacheVec(sites::Vector{<:Index}, ::Nothing)::Nothing + return nothing +end + struct OpID id::Int16 n::Int16 @@ -71,13 +91,10 @@ function add_to_scalar!(os::OpIDSum{C}, i::Integer, scalar::C)::Nothing where {C return nothing end -# TODO: Define as `C`. Rename `coefficient_type`. -function determine_val_type(os::OpIDSum{C}) where {C} - for i in eachindex(os) - scalar, ops = os[i] - (!isreal(scalar)) && return ComplexF64 - end - +function determine_val_type(os::OpIDSum{C}, opCacheVec::OpCacheVec) where {C} + !all(isreal(scalar) for scalar in os.scalars) && return ComplexF64 + !all(isreal(op.matrix) for opsOfSite in opCacheVec for op in opsOfSite) && + return ComplexF64 return Float64 end