Skip to content

Commit

Permalink
Eliminated special case for dot product in multiply kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
brainandforce committed May 16, 2024
1 parent 054d34e commit 7cb5747
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/src/api/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ CliffordNumbers.mul
CliffordNumbers.GradeFilter
CliffordNumbers.nondegenerate_mask
CliffordNumbers.mul_mask
CliffordNumbers.mul_signs
CliffordNumbers.bitindex_shuffle
CliffordNumbers.widen_grade_for_mul
```
Expand Down
35 changes: 27 additions & 8 deletions src/multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ const ContractionGradeFilters = Union{GradeFilter{:⨼},GradeFilter{:⨽},GradeF
Generates a `NTuple{L,Bool}` which is `true` whenever the multiplication of the blade indexed by `a`
and blades indexed by `B` is nonzero. `false` is returned if the grades multiply to zero due to the
squaring of a degenerate component, or if they are filtered by `F`.
In the special case of dot products (`F = CliffordNumbers.GradeFilter{:dot}()`), the return type is
not an` NTuple{L,Bool}`, but `NTuple{L,Int8}`, as some multiplications must flip sign
"""
function mul_mask(F::GradeFilter, a::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q}
return map(b -> F(a,b) & nondegenerate_mult(a,b), B)
Expand All @@ -107,6 +104,32 @@ end
mul_mask(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q}) where Q = mul_mask(F, a, Tuple(B))
mul_mask(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q}) where Q = mul_mask(F, Tuple(B), a)

"""
CliffordNumbers.mul_signs(F::GradeFilter, a::BitIndex{Q}, B::NTuple{L,BitIndices{Q}})
CliffordNumbers.mul_signs(F::GradeFilter, B::NTuple{L,BitIndices{Q}}, a::BitIndex{Q})
CliffordNumbers.mul_signs(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q})
CliffordNumbers.mul_signs(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q})
Generates an `NTuple{L,Int8}` which represents the sign associated with the multiplication needed to
calculate components of a multiplication result.
This is equivalent to `sign.(B)` unless `F === CliffordNumbers.GradeFilter{:dot}()`.
"""
mul_signs(::GradeFilter, ::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q} = sign.(B)
mul_signs(::GradeFilter, B::NTuple{L,BitIndex{Q}}, ::BitIndex{Q}) where {L,Q} = sign.(B)

function mul_signs(::GradeFilter{:dot}, a::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q}
return sign.(B) .* Int8(-1).^(grade.(B) .* (grade(a) .- grade.(B)))
end

function mul_signs(::GradeFilter{:dot}, B::NTuple{L,BitIndex{Q}}, a::BitIndex{Q}) where {L,Q}
return sign.(B) .* Int8(-1).^(grade(a) .* (grade.(B) .- grade(a)))
end

mul_signs(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q}) where Q = mul_signs(F, a, Tuple(B))
mul_signs(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q}) where Q = mul_signs(F, Tuple(B), a)

#---Product return types---------------------------------------------------------------------------#
"""
CliffordNumbers.product_return_type(::Type{X}, ::Type{Y}, [::GradeFilter{S}])
Expand Down Expand Up @@ -204,11 +227,7 @@ kernel just returns the geometric product.
# But all values are known at compile time, so interpolate them into expressions
ia = to_index(x, a)
tuple_inds = to_index.(y, inds)
# Special case for dot products
signs = sign.(inds)
if F <: GradeFilter{:dot}
signs = signs .* Int8(-1).^(grade.(inds) .* (grade(a) .- grade.(inds)))
end
signs = mul_signs(F(), a, inds)
# Construct the tuples that contribute to the product
x_tuple_ex = :(Tuple(x)[$ia] .* $x_mask)
y_tuple_ex = :(getindex.(tuple(Tuple(y)), $tuple_inds) .* $signs .* $y_mask)
Expand Down

0 comments on commit 7cb5747

Please sign in to comment.