Skip to content

Commit

Permalink
CairoMakie default lib
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzoFioroni committed Nov 10, 2024
1 parent 62937fd commit 2cc344b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 39 deletions.
61 changes: 25 additions & 36 deletions ext/QuantumToolboxCairoMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,31 @@ using CairoMakie

function QuantumToolbox.plot_wigner(
library::Val{:CairoMakie},
state::QuantumObject{<:AbstractArray{T},OpType},
state::QuantumObject{<:AbstractArray{T},OpType};
xvec::Union{Nothing,AbstractVector} = nothing,
yvec::Union{Nothing,AbstractVector} = nothing;
yvec::Union{Nothing,AbstractVector} = nothing,
g::Real = 2,
method::WignerSolver = WignerClenshaw(),
projection::String = "2d",
fig::Union{Figure,Nothing} = nothing,
ax::Union{Axis,Nothing} = nothing,
projection::Union{Val,Symbol} = Val(:two_dim),
location::Union{GridPosition,Nothing} = nothing,
colorbar::Bool = false,
kwargs...,
) where {T,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject}}
projection == "2d" || projection == "3d" || throw(ArgumentError("Unsupported projection: $projection"))
QuantumToolbox.getVal(projection) == :two_dim ||
QuantumToolbox.getVal(projection) == :three_dim ||
throw(ArgumentError("Unsupported projection: $projection"))

return _plot_wigner(
library,
state,
xvec,
yvec,
Val(Symbol(projection)),
QuantumToolbox.makeVal(projection),
g,
method,
fig,
ax,
location,
colorbar;
kwargs...
kwargs...,
)
end

Expand All @@ -38,20 +38,17 @@ function _plot_wigner(
state::QuantumObject{<:AbstractArray{T},OpType},
xvec::AbstractVector,
yvec::AbstractVector,
projection::Val{Symbol("2d")},
projection::Val{:two_dim},
g::Real,
method::WignerSolver,
fig::Union{Figure,Nothing},
ax::Union{Axis,Nothing},
location::Union{GridPosition,Nothing},
colorbar::Bool;
kwargs...,
) where {T,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject}}
fig, ax = _getFigAx(fig, ax)
fig, location = _getFigAndLocation(location)

gridPos = _gridPosFromAx(ax)
CairoMakie.delete!(ax)
lyt = GridLayout(location)

lyt = GridLayout(gridPos)
ax = Axis(lyt[1, 1])

wig = wigner(state, xvec, yvec; g = g, method = method)
Expand All @@ -74,20 +71,17 @@ function _plot_wigner(
state::QuantumObject{<:AbstractArray{T},OpType},
xvec::AbstractVector,
yvec::AbstractVector,
projection::Val{Symbol("3d")},
projection::Val{:three_dim},
g::Real,
method::WignerSolver,
fig::Union{Figure,Nothing},
ax::Union{Axis,Nothing},
location::Union{GridPosition,Nothing},
colorbar::Bool;
kwargs...,
) where {T,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject}}
fig, ax = _getFigAx(fig, ax)
fig, location = _getFigAndLocation(location)

gridPos = _gridPosFromAx(ax)
CairoMakie.delete!(ax)
lyt = GridLayout(location)

lyt = GridLayout(gridPos)
ax = Axis3(lyt[1, 1], azimuth = 1.775pi, elevation = pi / 16, protrusions = (30, 90, 30, 30), viewmode = :stretch)

wig = wigner(state, xvec, yvec; g = g, method = method)
Expand All @@ -106,22 +100,17 @@ function _plot_wigner(
return fig, ax, surf
end

_getFigAx(fig::Figure, ax::Axis) = fig, ax
_getFigAx(fig::Figure, ::Nothing) = fig, Axis(fig[1, 1])
_getFigAx(::Nothing, ax::Axis) = _figFromChildren(ax), ax
function _getFigAx(::Nothing, ::Nothing)
function _getFigAndLocation(location::Nothing)
fig = Figure()
ax = Axis(fig[1, 1])
return fig, ax
return fig, fig[1, 1]
end
function _getFigAndLocation(location::GridPosition)
fig = _figFromChildren(location.layout)
return fig, location
end

_figFromChildren(children) = _figFromChildren(children.parent)
_figFromChildren(fig::Figure) = fig

function _gridPosFromAx(ax::Axis)
content = CairoMakie.Makie.GridLayoutBase.gridcontent(ax)
gl, sp, si = content.parent, content.span, content.side
return GridPosition(gl, sp, si)
end
_figFromChildren(::Nothing) = throw(ArgumentError("No Figure has been found at the top of the layout hierarchy."))

end
9 changes: 6 additions & 3 deletions src/visualization.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
export plot_wigner

plot_wigner(library::Val{T}, args...; kwargs...) where {T} =
throw(ArgumentError("Unsupported visualization library: $(getVal(library))"))
plot_wigner(library::Symbol, args...; kwargs...) = plot_wigner(Val(library), args...; kwargs...)
plot_wigner(
state::QuantumObject{<:AbstractArray{T},OpType};
library::Union{Val,Symbol} = Val(:CairoMakie),
kwargs...,
) where {T,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject}} =
plot_wigner(makeVal(library), state; kwargs...)

0 comments on commit 2cc344b

Please sign in to comment.