Skip to content

Commit

Permalink
define dual_type and nondual_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Dec 1, 2024
1 parent e19b521 commit a270ec7
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ dual(x) = x
nondual(r::AbstractUnitRange) = r
isdual(::AbstractUnitRange) = false

dual_type(x) = dual_type(typeof(x))
dual_type(T::Type) = T
nondual_type(x) = nondual_type(typeof(x))
nondual_type(T::Type) = T

using LabelledNumbers: LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
Expand Down
17 changes: 17 additions & 0 deletions src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@ nondual(a::GradedUnitRangeDual) = a.nondual_unitrange
dual(a::GradedUnitRangeDual) = nondual(a)
flip(a::GradedUnitRangeDual) = dual(flip(nondual(a)))
isdual(::GradedUnitRangeDual) = true

function nondual_type(
::Type{<:GradedUnitRangeDual{<:Any,<:Any,NondualUnitRange}}
) where {NondualUnitRange}
return NondualUnitRange
end
dual_type(T::Type{<:GradedUnitRangeDual}) = nondual_type(T)
function dual_type(type::Type{<:AbstractGradedUnitRange{T,BlockLasts}}) where {T,BlockLasts}
return GradedUnitRangeDual{T,BlockLasts,type}
end
function LabelledNumbers.label_type(
type::Type{<:GradedUnitRangeDual}
)
# `dual_type` right now doesn't do anything but anticipates defining `SectorDual`.
return dual_type(label_type(nondual_type(type)))
end

## TODO: Define this to instantiate a dual unit range.
## materialize_dual(a::GradedUnitRangeDual) = materialize_dual(nondual(a))

Expand Down
15 changes: 13 additions & 2 deletions src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@ label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a))
isdual(::LabelledUnitRangeDual) = true
blocklabels(la::LabelledUnitRangeDual) = [label(la)]

function nondual_type(
::Type{<:LabelledUnitRangeDual{<:Any,NondualUnitRange}}
) where {NondualUnitRange}
return NondualUnitRange
end
dual_type(T::Type{<:LabelledUnitRangeDual}) = nondual_type(T)
function dual_type(T::Type{<:LabelledUnitRange})
return LabelledUnitRangeDual{eltype(T),T}
end

LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))
LabelledNumbers.LabelledStyle(::LabelledUnitRangeDual) = IsLabelled()
function LabelledNumbers.label_type(
::Type{<:LabelledUnitRangeDual{<:Any,NondualUnitRange}}
type::Type{<:LabelledUnitRangeDual{<:Any,NondualUnitRange}}
) where {NondualUnitRange}
return label_type(NondualUnitRange)
# `dual_type` right now doesn't do anything but anticipates defining `SectorDual`.
return dual_type(label_type(nondual_type(type)))
end

for f in [:first, :getindex, :last, :length, :step]
Expand Down
12 changes: 12 additions & 0 deletions test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ using GradedUnitRanges:
blockmergesortperm,
blocksortperm,
dual,
dual_type,
flip,
gradedrange,
isdual,
nondual,
nondual_type,
space_isequal,
sector_type
using LabelledNumbers:
Expand Down Expand Up @@ -95,6 +97,11 @@ end
@test label_type(lad) === U1
@test sector_type(lad) === U1

@test dual_type(la) === typeof(lad)
@test dual_type(lad) === typeof(la)
@test nondual_type(lad) === typeof(la)
@test nondual_type(la) === typeof(la)

@test iterate(lad) == (1, 1)
@test iterate(lad) == (1, 1)
@test iterate(lad, 1) == (2, 2)
Expand Down Expand Up @@ -167,6 +174,11 @@ end
@test !space_isequal(a, ad)
@test !space_isequal(ad, a)

@test dual_type(a) === typeof(ad)
@test dual_type(ad) === typeof(a)
@test nondual_type(ad) === typeof(a)
@test nondual_type(a) === typeof(a)

@test isdual(ad)
@test !isdual(a)
@test axes(Base.Slice(a)) isa Tuple{typeof(a)}
Expand Down

0 comments on commit a270ec7

Please sign in to comment.