From 7cb574744f2f5d1c6997d9d73335f7fe63fe0007 Mon Sep 17 00:00:00 2001 From: Brandon Flores Date: Thu, 16 May 2024 16:30:17 -0500 Subject: [PATCH] Eliminated special case for dot product in multiply kernel --- docs/src/api/internal.md | 1 + src/multiply.jl | 35 +++++++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/docs/src/api/internal.md b/docs/src/api/internal.md index 212141a..23e91f4 100644 --- a/docs/src/api/internal.md +++ b/docs/src/api/internal.md @@ -35,6 +35,7 @@ CliffordNumbers.mul CliffordNumbers.GradeFilter CliffordNumbers.nondegenerate_mask CliffordNumbers.mul_mask +CliffordNumbers.mul_signs CliffordNumbers.bitindex_shuffle CliffordNumbers.widen_grade_for_mul ``` diff --git a/src/multiply.jl b/src/multiply.jl index e3ebcf1..3e298ec 100644 --- a/src/multiply.jl +++ b/src/multiply.jl @@ -92,9 +92,6 @@ const ContractionGradeFilters = Union{GradeFilter{:⨼},GradeFilter{:⨽},GradeF Generates a `NTuple{L,Bool}` which is `true` whenever the multiplication of the blade indexed by `a` and blades indexed by `B` is nonzero. `false` is returned if the grades multiply to zero due to the squaring of a degenerate component, or if they are filtered by `F`. - -In the special case of dot products (`F = CliffordNumbers.GradeFilter{:dot}()`), the return type is -not an` NTuple{L,Bool}`, but `NTuple{L,Int8}`, as some multiplications must flip sign """ function mul_mask(F::GradeFilter, a::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q} return map(b -> F(a,b) & nondegenerate_mult(a,b), B) @@ -107,6 +104,32 @@ end mul_mask(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q}) where Q = mul_mask(F, a, Tuple(B)) mul_mask(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q}) where Q = mul_mask(F, Tuple(B), a) +""" + CliffordNumbers.mul_signs(F::GradeFilter, a::BitIndex{Q}, B::NTuple{L,BitIndices{Q}}) + CliffordNumbers.mul_signs(F::GradeFilter, B::NTuple{L,BitIndices{Q}}, a::BitIndex{Q}) + + CliffordNumbers.mul_signs(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q}) + CliffordNumbers.mul_signs(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q}) + +Generates an `NTuple{L,Int8}` which represents the sign associated with the multiplication needed to +calculate components of a multiplication result. + +This is equivalent to `sign.(B)` unless `F === CliffordNumbers.GradeFilter{:dot}()`. +""" +mul_signs(::GradeFilter, ::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q} = sign.(B) +mul_signs(::GradeFilter, B::NTuple{L,BitIndex{Q}}, ::BitIndex{Q}) where {L,Q} = sign.(B) + +function mul_signs(::GradeFilter{:dot}, a::BitIndex{Q}, B::NTuple{L,BitIndex{Q}}) where {L,Q} + return sign.(B) .* Int8(-1).^(grade.(B) .* (grade(a) .- grade.(B))) +end + +function mul_signs(::GradeFilter{:dot}, B::NTuple{L,BitIndex{Q}}, a::BitIndex{Q}) where {L,Q} + return sign.(B) .* Int8(-1).^(grade(a) .* (grade.(B) .- grade(a))) +end + +mul_signs(F::GradeFilter, a::BitIndex{Q}, B::BitIndices{Q}) where Q = mul_signs(F, a, Tuple(B)) +mul_signs(F::GradeFilter, B::BitIndices{Q}, a::BitIndex{Q}) where Q = mul_signs(F, Tuple(B), a) + #---Product return types---------------------------------------------------------------------------# """ CliffordNumbers.product_return_type(::Type{X}, ::Type{Y}, [::GradeFilter{S}]) @@ -204,11 +227,7 @@ kernel just returns the geometric product. # But all values are known at compile time, so interpolate them into expressions ia = to_index(x, a) tuple_inds = to_index.(y, inds) - # Special case for dot products - signs = sign.(inds) - if F <: GradeFilter{:dot} - signs = signs .* Int8(-1).^(grade.(inds) .* (grade(a) .- grade.(inds))) - end + signs = mul_signs(F(), a, inds) # Construct the tuples that contribute to the product x_tuple_ex = :(Tuple(x)[$ia] .* $x_mask) y_tuple_ex = :(getindex.(tuple(Tuple(y)), $tuple_inds) .* $signs .* $y_mask)