Skip to content

Commit

Permalink
Move to no parametric type on abstract quanutum objects
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Dec 9, 2024
1 parent 12765d5 commit c86c1e6
Show file tree
Hide file tree
Showing 20 changed files with 71 additions and 100 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "QuantumOpticsBase"
uuid = "4f57444f-1401-5e15-980d-4471b28d5678"
version = "0.5.4"
version = "0.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -25,7 +25,7 @@ FastGaussQuadrature = "0.5, 1"
FillArrays = "0.13, 1"
LRUCache = "1"
LinearAlgebra = "1"
QuantumInterface = "0.3.3"
QuantumInterface = "0.4.0"
Random = "1"
RecursiveArrayTools = "3"
SparseArrays = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/QuantumOpticsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import LinearAlgebra: mul!, rmul!
import RecursiveArrayTools

import QuantumInterface: dagger, directsum, , dm, embed, nsubsystems, expect, identityoperator, identitysuperoperator,
permutesystems, projector, ptrace, reduced, tensor, , variance, apply!, basis, AbstractSuperOperator
permutesystems, projector, ptrace, reduced, tensor, , variance, apply!, basis, basis_l, basis_r

# index helpers
import QuantumInterface: complement, remove, shiftremove, reducedindices!, check_indices, check_sortedindices, check_embed_indices
Expand Down
2 changes: 1 addition & 1 deletion src/bases.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import QuantumInterface: Basis, basis, GenericBasis, CompositeBasis,
equal_shape, equal_bases, IncompatibleBases, @samebases, samebases, check_samebases,
equal_shape, IncompatibleBases, @samebases, samebases, check_samebases,
multiplicable, check_multiplicable, reduced, ptrace, permutesystems
15 changes: 8 additions & 7 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Abstract type for operators with a data field.
This is an abstract type for operators that have a direct matrix representation
stored in their `.data` field.
"""
abstract type DataOperator{BL,BR} <: AbstractOperator{BL,BR} end
abstract type BLROperator{BL,BR} <: AbstractOperator end
abstract type DataOperator{BL,BR} <: BLROperator{BL,BR} end


# Common error messages
Expand Down Expand Up @@ -109,18 +110,18 @@ Expectation value of the given operator `op` for the specified `state`.
`state` can either be a (density) operator or a ket.
"""
expect(op::AbstractOperator{B,B}, state::Ket{B}) where B = dot(state.data, (op * state).data)
expect(op::BLROperator{B,B}, state::Ket{B}) where B = dot(state.data, (op * state).data)

# TODO upstream this one
# expect(op::AbstractOperator{B,B}, state::AbstractKet{B}) where B = norm(op * state) ^ 2

function expect(indices, op::AbstractOperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis}
function expect(indices, op::BLROperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis}
N = length(state.basis.shape)
indices_ = complement(N, indices)
expect(op, ptrace(state, indices_))
end

expect(index::Integer, op::AbstractOperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis} = expect([index], op, state)
expect(index::Integer, op::BLROperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis} = expect([index], op, state)

"""
variance(op, state)
Expand All @@ -129,18 +130,18 @@ Variance of the given operator `op` for the specified `state`.
`state` can either be a (density) operator or a ket.
"""
function variance(op::AbstractOperator{B,B}, state::Ket{B}) where B
function variance(op::BLROperator{B,B}, state::Ket{B}) where B
x = op*state
state.data'*(op*x).data - (state.data'*x.data)^2
end

function variance(indices, op::AbstractOperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis}
function variance(indices, op::BLROperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis}
N = length(state.basis.shape)
indices_ = complement(N, indices)
variance(op, ptrace(state, indices_))
end

variance(index::Integer, op::AbstractOperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(index::Integer, op::BLROperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)

# Helper functions to check validity of arguments
function check_ptrace_arguments(a::AbstractOperator, indices)
Expand Down
17 changes: 10 additions & 7 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Operator(qets::AbstractVector{<:Ket}) = Operator(first(qets).basis, GenericBasis
Operator(basis_r::Basis,qets::AbstractVector{<:Ket}) = Operator(first(qets).basis, basis_r, qets)
Operator(basis_l::BL,basis_r::BR,qets::AbstractVector{<:Ket}) where {BL,BR} = Operator{BL,BR}(basis_l, basis_r, reduce(hcat, getfield.(qets, :data)))

basis_l(op::Operator) = op.basis_l
basis_r(op::Operator) = op.basis_r

QuantumInterface.traceout!(s::QuantumOpticsBase.Operator, i) = QuantumInterface.ptrace(s,i)

Base.zero(op::Operator) = Operator(op.basis_l,op.basis_r,zero(op.data))
Expand Down Expand Up @@ -98,22 +101,22 @@ Base.isapprox(x::DataOperator, y::DataOperator; kwargs...) = false
*(a::Operator{B1, B2, T}, b::DataOperator{B2, B3}) where {B1, B2, B3, T} = error("no `*` method defined for DataOperator subtype $(typeof(b))") # defined to avoid method ambiguity
*(a::Operator, b::Number) = Operator(a.basis_l, a.basis_r, b*a.data)
*(a::Number, b::Operator) = Operator(b.basis_l, b.basis_r, a*b.data)
function *(op1::AbstractOperator{B1,B2}, op2::Operator{B2,B3,T}) where {B1,B2,B3,T}
function *(op1::BLROperator{B1,B2}, op2::Operator{B2,B3,T}) where {B1,B2,B3,T}
result = Operator{B1,B3}(op1.basis_l, op2.basis_r, similar(_parent(op2.data),promote_type(eltype(op1),eltype(op2)),length(op1.basis_l),length(op2.basis_r)))
mul!(result,op1,op2)
return result
end
function *(op1::Operator{B1,B2,T}, op2::AbstractOperator{B2,B3}) where {B1,B2,B3,T}
function *(op1::Operator{B1,B2,T}, op2::BLROperator{B2,B3}) where {B1,B2,B3,T}
result = Operator{B1,B3}(op1.basis_l, op2.basis_r, similar(_parent(op1.data),promote_type(eltype(op1),eltype(op2)),length(op1.basis_l),length(op2.basis_r)))
mul!(result,op1,op2)
return result
end
function *(op::AbstractOperator{BL,BR}, psi::Ket{BR,T}) where {BL,BR,T}
function *(op::BLROperator{BL,BR}, psi::Ket{BR,T}) where {BL,BR,T}
result = Ket{BL,T}(op.basis_l,similar(psi.data,length(op.basis_l)))
mul!(result,op,psi)
return result
end
function *(psi::Bra{BL,T}, op::AbstractOperator{BL,BR}) where {BL,BR,T}
function *(psi::Bra{BL,T}, op::BLROperator{BL,BR}) where {BL,BR,T}
result = Bra{BR,T}(op.basis_r, similar(psi.data,length(op.basis_r)))
mul!(result,psi,op)
return result
Expand Down Expand Up @@ -388,7 +391,7 @@ mul!(result::Bra{B2},a::Bra{B1},b::Operator{B1,B2},alpha,beta) where {B1,B2} = (
rmul!(op::Operator, x) = (rmul!(op.data, x); op)

# Multiplication for Operators in terms of their gemv! implementation
function mul!(result::Operator{B1,B3},M::AbstractOperator{B1,B2},b::Operator{B2,B3},alpha,beta) where {B1,B2,B3}
function mul!(result::Operator{B1,B3},M::BLROperator{B1,B2},b::Operator{B2,B3},alpha,beta) where {B1,B2,B3}
for i=1:size(b.data, 2)
bket = Ket(b.basis_l, b.data[:,i])
resultket = Ket(M.basis_l, result.data[:,i])
Expand All @@ -398,7 +401,7 @@ function mul!(result::Operator{B1,B3},M::AbstractOperator{B1,B2},b::Operator{B2,
return result
end

function mul!(result::Operator{B1,B3},b::Operator{B1,B2},M::AbstractOperator{B2,B3},alpha,beta) where {B1,B2,B3}
function mul!(result::Operator{B1,B3},b::Operator{B1,B2},M::BLROperator{B2,B3},alpha,beta) where {B1,B2,B3}
for i=1:size(b.data, 1)
bbra = Bra(b.basis_r, vec(b.data[i,:]))
resultbra = Bra(M.basis_r, vec(result.data[i,:]))
Expand Down Expand Up @@ -469,4 +472,4 @@ Base.similar(x::Operator, t) = typeof(x)(x.basis_l, x.basis_r, copy(x.data))
RecursiveArrayTools.recursivecopy!(dest::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copyto!(dest,src) # ODE in-place equations
RecursiveArrayTools.recursivecopy(x::Operator) = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Operator} = copy(x)
RecursiveArrayTools.recursivefill!(x::Operator, a) = fill!(x, a)
RecursiveArrayTools.recursivefill!(x::Operator, a) = fill!(x, a)
4 changes: 2 additions & 2 deletions src/operators_lazyproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ function LazyProduct(operators::T, factor::F=1) where {T,F}
LazyProduct{BL,BR,F,T,KTL,BTR}(operators, ket_l, bra_r, factor)
end



basis_l(op::LazyProduct) = op.basis_l
basis_r(op::LazyProduct) = op.basis_r

LazyProduct(operators::Vector{T}, factor=1) where T<:AbstractOperator = LazyProduct((operators...,), factor)
LazyProduct(operators::AbstractOperator...) = LazyProduct((operators...,))
Expand Down
5 changes: 4 additions & 1 deletion src/operators_lazysum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end
"""
Abstract class for all Lazy type operators ([`LazySum`](@ref), [`LazyProduct`](@ref), and [`LazyTensor`](@ref))
"""
abstract type LazyOperator{BL,BR} <: AbstractOperator{BL,BR} end
abstract type LazyOperator{BL,BR} <: BLROperator{BL,BR} end

"""
LazySum([Tf,] [factors,] operators)
Expand Down Expand Up @@ -49,6 +49,9 @@ mutable struct LazySum{BL,BR,F,T} <: LazyOperator{BL,BR}
end
end

basis_l(op::LazySum) = op.basis_l
basis_r(op::LazySum) = op.basis_r

LazySum(::Type{Tf}, basis_l::Basis, basis_r::Basis) where Tf = LazySum(basis_l,basis_r,Tf[],())
LazySum(basis_l::Basis, basis_r::Basis) = LazySum(ComplexF64, basis_l, basis_r)

Expand Down
7 changes: 5 additions & 2 deletions src/operators_lazytensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ LazyTensor(op::T, factor) where {T<:LazyTensor} = LazyTensor(op.basis_l, op.basi
LazyTensor(basis_l::CompositeBasis, basis_r::CompositeBasis, index::Integer, operator::T, factor=one(eltype(operator))) where T<:AbstractOperator = LazyTensor(basis_l, basis_r, [index], (operator,), factor)
LazyTensor(basis::Basis, index, operators, factor=_default_factor(operators)) = LazyTensor(basis, basis, index, operators, factor)

basis_l(op::LazyTensor) = op.basis_l
basis_r(op::LazyTensor) = op.basis_r

Base.copy(x::LazyTensor) = LazyTensor(x.basis_l, x.basis_r, copy(x.indices), Tuple(copy(op) for op in x.operators), x.factor)
function Base.eltype(x::LazyTensor)
F = eltype(x.factor)
Expand Down Expand Up @@ -112,14 +115,14 @@ function -(a::T1,b::T2) where {T1 <: single_dataoperator{B1,B2},T2 <: single_dat
LazySum(a) - LazySum(b)
end

function tensor(a::LazyTensor{B1,B2},b::AbstractOperator{B3,B4}) where {B1,B2,B3,B4}
function tensor(a::LazyTensor{B1,B2},b::BLROperator{B3,B4}) where {B1,B2,B3,B4}
if B3 <: CompositeBasis || B4 <: CompositeBasis
throw(ArgumentError("tensor(a::LazyTensor{B1,B2},b::AbstractOperator{B3,B4}) is not implemented for B3 or B4 being CompositeBasis unless b is identityoperator "))
else
a LazyTensor(b.basis_l,b.basis_r,[1],(b,),1)
end
end
function tensor(a::AbstractOperator{B1,B2},b::LazyTensor{B3,B4}) where {B1,B2,B3,B4}
function tensor(a::BLROperator{B1,B2},b::LazyTensor{B3,B4}) where {B1,B2,B3,B4}
if B1 <: CompositeBasis || B2 <: CompositeBasis
throw(ArgumentError("tensor(a::AbstractOperator{B1,B2},b::LazyTensor{B3,B4}) is not implemented for B1 or B2 being CompositeBasis unless b is identityoperator "))
else
Expand Down
5 changes: 4 additions & 1 deletion src/particle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ end
Abstract type for all implementations of FFT operators.
"""
abstract type FFTOperator{BL, BR, T} <: AbstractOperator{BL,BR} end
abstract type FFTOperator{BL, BR, T} <: BLROperator{BL,BR} end

Base.eltype(x::FFTOperator) = promote_type(eltype(x.mul_before), eltype(x.mul_after))

Expand Down Expand Up @@ -310,6 +310,9 @@ struct FFTOperators{BL,BR,T,P1,P2,P3,P4} <: FFTOperator{BL, BR, T}
end
end

basis_l(op::FFTOperators) = op.basis_l
basis_r(op::FFTOperators) = op.basis_r

"""
FFTKets
Expand Down
5 changes: 4 additions & 1 deletion src/spinors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,15 @@ end
Lazy implementation of `directsum`
"""
mutable struct LazyDirectSum{BL,BR,T} <: AbstractOperator{BL,BR}
mutable struct LazyDirectSum{BL,BR,T} <: BLROperator{BL,BR}
basis_l::BL
basis_r::BR
operators::T
end

basis_l(op::LazyDirectSum) = op.basis_l
basis_r(op::LazyDirectSum) = op.basis_r

# Methods
LazyDirectSum(op1::AbstractOperator, op2::AbstractOperator) = LazyDirectSum(directsum(op1.basis_l,op2.basis_l),directsum(op1.basis_r,op2.basis_r),(op1,op2))
LazyDirectSum(op1::LazyDirectSum, op2::AbstractOperator) = LazyDirectSum(directsum(op1.basis_l,op2.basis_l),directsum(op1.basis_r,op2.basis_r),(op1.operators...,op2))
Expand Down
4 changes: 2 additions & 2 deletions src/state_definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ end
Coherent thermal state ``D(α)exp(-H/T)/Tr[exp(-H/T)]D^†(α)``.
"""
function coherentthermalstate(::Type{C},basis::B,H::AbstractOperator{B,B},T,alpha) where {C,B<:FockBasis}
function coherentthermalstate(::Type{C},basis::B,H::BLROperator{B,B},T,alpha) where {C,B<:FockBasis}
D = displace(C,basis,alpha)
return D*thermalstate(H,T)*dagger(D)
end
coherentthermalstate(basis::B,H::AbstractOperator{B,B},T,alpha) where B<:FockBasis = coherentthermalstate(ComplexF64,basis,H,T,alpha)
coherentthermalstate(basis::B,H::BLROperator{B,B},T,alpha) where B<:FockBasis = coherentthermalstate(ComplexF64,basis,H,T,alpha)

"""
phase_average(rho)
Expand Down
9 changes: 6 additions & 3 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import QuantumInterface: StateVector, AbstractKet, AbstractBra
Bra state defined by coefficients in respect to the basis.
"""
mutable struct Bra{B,T} <: AbstractBra{B,T}
mutable struct Bra{B,T} <: AbstractBra
basis::B
data::T
function Bra{B,T}(b::B, data::T) where {B,T}
Expand All @@ -21,7 +21,7 @@ end
Ket state defined by coefficients in respect to the given basis.
"""
mutable struct Ket{B,T} <: AbstractKet{B,T}
mutable struct Ket{B,T} <: AbstractKet
basis::B
data::T
function Ket{B,T}(b::B, data::T) where {B,T}
Expand All @@ -30,6 +30,9 @@ mutable struct Ket{B,T} <: AbstractKet{B,T}
end
end

basis(x::Bra) = x.basis
basis(x::Ket) = x.basis

Base.zero(x::Bra) = Bra(x.basis, zero(x.data))
Base.zero(x::Ket) = Ket(x.basis, zero(x.data))
eltype(::Type{K}) where {K <: Ket{B,V}} where {B,V} = eltype(V)
Expand Down Expand Up @@ -256,4 +259,4 @@ RecursiveArrayTools.recursivecopy!(dest::Ket{B,A},src::Ket{B,A}) where {B,A} = c
RecursiveArrayTools.recursivecopy!(dest::Bra{B,A},src::Bra{B,A}) where {B,A} = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::T) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a)
RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a)
6 changes: 4 additions & 2 deletions src/states_lazyket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The subkets are stored in the `kets` field.
The main purpose of such a ket are simple computations for large product states, such as expectation values.
It's used to compute numeric initial states in QuantumCumulants.jl (see QuantumCumulants.initial_values).
"""
mutable struct LazyKet{B,T} <: AbstractKet{B,T}
mutable struct LazyKet{B,T} <: AbstractKet
basis::B
kets::T
function LazyKet(b::B, kets::T) where {B<:CompositeBasis,T<:Tuple}
Expand All @@ -23,6 +23,8 @@ function LazyKet(b::CompositeBasis, kets::Vector)
return LazyKet(b,Tuple(kets))
end

basis(ket::LazyKet) = b.basis

Base.eltype(ket::LazyKet) = Base.promote_type(eltype.(ket.kets)...)

Base.isequal(x::LazyKet, y::LazyKet) = isequal(x.basis, y.basis) && isequal(x.kets, y.kets)
Expand Down Expand Up @@ -145,4 +147,4 @@ function mul!(y::LazyKet{BL}, op::LazyTensor{BL, BR}, x::LazyKet{BR}, alpha, bet

rmul!(y.kets[1].data, op.factor * alpha)
return y
end
end
12 changes: 7 additions & 5 deletions src/superoperators.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import QuantumInterface: AbstractSuperOperator
import FastExpm: fastExpm

abstract type BLRSuperOperator{BL,BR} <: AbstractSuperOperator end

"""
SuperOperator <: AbstractSuperOperator
SuperOperator stored as representation, e.g. as a Matrix.
"""
mutable struct SuperOperator{B1,B2,T} <: AbstractSuperOperator{B1,B2}
mutable struct SuperOperator{B1,B2,T} <: BLRSuperOperator{B1,B2}
basis_l::B1
basis_r::B2
data::T
Expand Down Expand Up @@ -167,12 +169,12 @@ holds. `A` ond `B` can be dense or a sparse operators.
"""
sprepost(A::AbstractOperator, B::AbstractOperator) = SuperOperator((A.basis_l, B.basis_r), (A.basis_r, B.basis_l), kron(permutedims(B.data), A.data))

function _check_input(H::AbstractOperator{B1,B2}, J::Vector, Jdagger::Vector, rates) where {B1,B2}
function _check_input(H::BLROperator{B1,B2}, J::Vector, Jdagger::Vector, rates) where {B1,B2}
for j=J
@assert isa(j, AbstractOperator{B1,B2})
@assert isa(j, BLROperator{B1,B2})
end
for j=Jdagger
@assert isa(j, AbstractOperator{B1,B2})
@assert isa(j, BLROperator{B1,B2})
end
@assert length(J)==length(Jdagger)
if isa(rates, Matrix{<:Number})
Expand Down Expand Up @@ -323,7 +325,7 @@ end
Superoperator represented as a choi state.
"""
mutable struct ChoiState{B1,B2,T} <: AbstractSuperOperator{B1,B2}
mutable struct ChoiState{B1,B2,T} <: BLRSuperOperator{B1,B2}
basis_l::B1
basis_r::B2
data::T
Expand Down
8 changes: 4 additions & 4 deletions src/time_dependent_operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ and [`current_time`](@ref). A shorthand `op(t)`, equivalent to
A time-dependent operator is always concrete-valued according to the current
time of its internal clock.
"""
abstract type AbstractTimeDependentOperator{BL,BR} <: AbstractOperator{BL,BR} end
abstract type AbstractTimeDependentOperator{BL,BR} <: BLROperator{BL,BR} end

"""
current_time(op::AbstractOperator)
Expand Down Expand Up @@ -71,9 +71,9 @@ end

for func in (:expect, :variance)
@eval $func(op::AbstractTimeDependentOperator{B,B}, x::Ket{B}) where B = $func(static_operator(op), x)
@eval $func(op::AbstractTimeDependentOperator{B,B}, x::AbstractOperator{B,B}) where B = $func(static_operator(op), x)
@eval $func(index::Integer, op::AbstractTimeDependentOperator{B1,B2}, x::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = $func(index, static_operator(op), x)
@eval $func(indices, op::AbstractTimeDependentOperator{B1,B2}, x::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = $func(indices, static_operator(op), x)
@eval $func(op::AbstractTimeDependentOperator{B,B}, x::BLROperator{B,B}) where B = $func(static_operator(op), x)
@eval $func(index::Integer, op::AbstractTimeDependentOperator{B1,B2}, x::BLROperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = $func(index, static_operator(op), x)
@eval $func(indices, op::AbstractTimeDependentOperator{B1,B2}, x::BLROperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = $func(indices, static_operator(op), x)
end

# TODO: Consider using promotion to define arithmetic between operator types
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
names = [
"test_sortedindices.jl",

"test_bases.jl",
"test_states.jl",

"test_operators.jl",
Expand Down
Loading

0 comments on commit c86c1e6

Please sign in to comment.