Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split up gemm_wrapper and stabilize MulAddMul strategically #47206

Closed
wants to merge 13 commits into from
24 changes: 12 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,18 +402,18 @@ end

const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractVecOrMat, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractVecOrMat, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractVector, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = throw(MethodError(mul!, (C, A, B)), MulAddMul(alpha, beta))
@inline mul!(C::AbstractVector, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = throw(MethodError(mul!, (C, A, B)), MulAddMul(alpha, beta))

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,14 @@ for Tri in (:UpperTriangular, :LowerTriangular)
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
$Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
@eval @inline mul!(C::$Tri, A::$Tri, D::Diagonal, α::Number, β::Number) = $Tri(mul!(C.data, A.data, D, α, β))
@eval @inline function mul!(C::$Tri, A::$UTri, D::Diagonal, α::Number, β::Number)
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = iszero(β) ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′))
$Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′))
end
end

Expand Down
70 changes: 70 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,76 @@ end
end
end

"""
@stable_muladdmul

Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an
argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)`
and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch.

For example, 'f(x, y, MulAddMul(alpha, beta))` is transformed into
```
if isone(alpha)
if iszero(beta)
f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
else
if iszero(beta)
f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
end
```

This avoids the type instability of the `MulAddMul(alpha, beta)` constructor,
which causes runtime dispatch in case alpha and zero are not constants.
"""
macro stable_muladdmul(expr)
expr.head == :call || throw(ArgumentError("Can only handle function calls."))
for (i, e) in enumerate(expr.args)
e isa Expr || continue
if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3
e.args[2] isa Symbol || continue
e.args[3] isa Symbol || continue
local asym = e.args[2]
local bsym = e.args[3]

local e_sub11 = copy(expr)
e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub10 = copy(expr)
e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub01 = copy(expr)
e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub00 = copy(expr)
e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_out = quote
if isone($asym)
if iszero($bsym)
$e_sub11
else
$e_sub10
end
else
if iszero($bsym)
$e_sub01
else
$e_sub00
end
end
end
return esc(e_out)
end
end
throw(ArgumentError("No valid MulAddMul expression found."))
end

MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)

@inline (::MulAddMul{true})(x) = x
Expand Down
Loading