Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul #54303

Merged
merged 15 commits into from
May 7, 2024

Conversation

jishnub
Copy link
Contributor

@jishnub jishnub commented Apr 29, 2024

Matrix multiplication for wrapper types such as Hermitian currently uses the unwrapping mechanism that assigns a character based on the type of the wrapper. However, this isn't always unique, as for Hermitian/Symmetric types, this also looks as the uplo field, which isn't usually known at compile time.

wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h'
wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's'
wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's'

An example of a badly inferred function call because of this:

julia> @descend_code_warntype (A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(Symmetric(rand(2,2)))
(::var"#1#2")(A) @ Main REPL[2]:1
┌ Warning: couldn't retrieve source of (::var"#1#2")(A) @ Main REPL[2]:1
└ @ TypedSyntax ~/.julia/packages/TypedSyntax/cH1Nu/src/node.jl:36
Variables
  #self#::Core.Const(var"#1#2"())
  A::Symmetric{Float64, Matrix{Float64}}

Body::Union{Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Symmetric{Float64, Matrix{Float64}}, Transpose{Float64, Matrix{Float64}}, Matrix{Float64}}
    @ REPL[2]:1 within `#1`
1%1 = LinearAlgebra.wrap::Core.Const(LinearAlgebra.wrap)
│   %2 = Main.parent::Core.Const(parent)
│   %3 = (%2)(A)::Matrix{Float64}%4 = LinearAlgebra.wrapper_char::Core.Const(LinearAlgebra.wrapper_char)
│   %5 = (%4)(A)::Char%6 = (%1)(%3, %5)::Union{Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Symmetric{Float64, Matrix{Float64}}, Transpose{Float64, Matrix{Float64}}, Matrix{Float64}}
└──      return %6
Select a call to descend into or  to ascend. [q]uit. [b]ookmark.
Toggles: [w]arn, [h]ide type-stable statements, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native, [j]ump to source always.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
 • %3 = parent(::Symmetric{Float64, Matrix{Float64}})::Matrix{Float64}
   %5 = wrapper_char(::Symmetric{Float64, Matrix{Float64}})::Char
   %6 = wrap(::Matrix{Float64},::Char)::

The output type is inferred as a large union, which complicates further type-inference downstream. Often, the impact of the runtime dispatch is minimal due to function barriers. However, we may avoid the runtime dispatch altogether.

This PR separates the uplo character from that for the type, storing them both in a newly defined struct. Using this approach, the type information may be constant-propagated even if the uplo isn't, and the return type may be concretely inferred. After this,

julia> @descend_code_warntype (A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(Symmetric(rand(2,2)))
(::var"#3#4")(A) @ Main REPL[3]:1
┌ Warning: couldn't retrieve source of (::var"#3#4")(A) @ Main REPL[3]:1
└ @ TypedSyntax ~/.julia/packages/TypedSyntax/cH1Nu/src/node.jl:36
Variables
  #self#::Core.Const(var"#3#4"())
  A::Symmetric{Float64, Matrix{Float64}}

Body::Symmetric{Float64, Matrix{Float64}}
    @ REPL[3]:1 within `unknown scope`
1%1 = LinearAlgebra.wrap::Core.Const(LinearAlgebra.wrap)
│   %2 = Main.parent::Core.Const(parent)
│   %3 = (%2)(A)::Matrix{Float64}%4 = LinearAlgebra.wrapper_char::Core.Const(LinearAlgebra.wrapper_char)
│   %5 = (%4)(A)::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool])
│   %6 = (%1)(%3, %5)::Symmetric{Float64, Matrix{Float64}}
└──      return %6
Select a call to descend into or  to ascend. [q]uit. [b]ookmark.
Toggles: [w]arn, [h]ide type-stable statements, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native, [j]ump to source always.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
 • %3 = parent(::Symmetric{Float64, Matrix{Float64}})::Matrix{Float64}
   %5 = wrapper_char(::Symmetric{Float64, Matrix{Float64}})::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool])
   %6 = < constprop > wrap(::Matrix{Float64},::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool]))::
   

This change should be compatible with existing codes, as the new struct subtypes an AbstractChar, and it may be converted and compared to a Char like before.

Fixes #53951.

@jishnub jishnub added linear algebra Linear algebra backport 1.11 Change should be backported to release-1.11 labels Apr 29, 2024
@dkarrasch
Copy link
Member

Should we even backport this to v1.10?

@dkarrasch
Copy link
Member

So, the main idea is not to confuse the compiler unnecessarily with uppercase/lowercase H or S when, for the result type, that distinction is irrelevant anyway, right? Because that distinction is just the value of a field, and not encoded into the type?

@jishnub
Copy link
Contributor Author

jishnub commented Apr 29, 2024

Yes, that's the idea.

Backporting to v1.10 might require some manual intervention, but should be a good idea.

@jishnub jishnub added the backport 1.10 Change should be backported to the 1.10 release label Apr 29, 2024
@jishnub
Copy link
Contributor Author

jishnub commented Apr 30, 2024

The last few commits improve type-stability and ensure constant propagation in various checks in the matmul functions. Introduces a new function _in that parallels in, but uses 2-value logic and is defined recursively. This allows checks like tA in ('T', 'N', 'C') to be evaluated at compile time, which should remove branches in the code. Ideally, in should already be doing this, but I don't know enough about the compile-time implications in the general case. For our specific case, this shouldn't matter much. Also, defines a function all_in that acts like all(in(..)), but ensures constant propagation by unrolling the loop over the arguments. This isn't used anymore, as all(map(in(..), ...)) achieves constant propagation without the need for special helper functions.

Fixes #53951 after the recent set of commits. After this,

julia> using LinearAlgebra

julia> using BenchmarkTools

julia> A = Hermitian([1.0 2.0; 2.0 3.0])
2×2 Hermitian{Float64, Matrix{Float64}}:
 1.0  2.0
 2.0  3.0

julia> B = [4.0 5.0; 6.0 7.0]
2×2 Matrix{Float64}:
 4.0  5.0
 6.0  7.0

julia> Y = similar(B)
2×2 Matrix{Float64}:
   6.0e-323  6.4e-323
 NaN         0.0

julia> @btime mul!($Y, $A, $B)
  127.843 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
 16.0  19.0
 26.0  31.0

julia> Badj = B'
2×2 adjoint(::Matrix{Float64}) with eltype Float64:
 4.0  6.0
 5.0  7.0

julia> @btime mul!($Y, $A, $Badj)
  44.311 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
 14.0  20.0
 23.0  33.0

@jishnub jishnub force-pushed the jishnub/linalgwrapperchar branch from ae7fc6c to 541a76d Compare May 2, 2024 19:13
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is true, this should be a giant contribution to the reduction of compile times, right? If we land in this branch, then we don't need to compile symm and hemm, or in the other case syrk/herk/gemm_wrapper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some compile-time improvement indeed, although it's not dramatic.
Each execution is in a separate session in the following:

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);

julia> @time mul!(C, A, B);
  0.847057 seconds (3.39 M allocations: 171.963 MiB, 24.97% gc time, 100.00% compilation time) # nightly
  0.757433 seconds (3.94 M allocations: 202.922 MiB, 4.65% gc time, 100.00% compilation time) # This PR

julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);

julia> @time mul!(C, A, B);
  1.098831 seconds (3.68 M allocations: 189.159 MiB, 24.52% gc time, 99.99% compilation time) # nightly
  0.687847 seconds (4.72 M allocations: 238.864 MiB, 7.04% gc time, 99.99% compilation time) # This PR

Descending into generic_matmatmul! using Cthulhu does seem to indicate that unused branches are eliminated, and e.g. in the first case, only gemm_wrapper! is being compiled, and in the second, only BLAS.symm! is compiled.

Copy link
Contributor Author

@jishnub jishnub May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code_typed for the first case (gemm) is identical between this PR and nightly:

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);

julia> @code_typed mul!(C, A, B)
CodeInfo(
1%1 = invoke LinearAlgebra.gemm_wrapper!(C::Matrix{Float64}, 'N'::Char, 'N'::Char, A::Matrix{Float64}, B::Matrix{Float64}, $(QuoteNode(LinearAlgebra.MulAddMul{true, true, Bool, Bool}(true, false)))::LinearAlgebra.MulAddMul{true, true, Bool, Bool})::Matrix{Float64}
└──      return %1
) => Matrix{Float64}

I'm not certain why there's a compile-time improvement here. (perhaps noise?) In this case, the all is already being folded (despite the loop over the characters). I suspect the loop is being unrolled entirely, as the characters are all Chars that are fully known at compile time.

The second case (symm) is where the major improvement comes in:

julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);

julia> @code_typed mul!(C, A, B)
CodeInfo(
1 ── %1  = Base.getfield(B, :uplo)::Char%2  = Base.bitcast(Base.UInt32, %1)::UInt32%3  = Base.bitcast(Base.UInt32, 'U')::UInt32%4  = (%2 === %3)::Bool%5  = Base.getfield(B, :data)::Matrix{Float64}
└───       goto #3 if not %4
2 ──       goto #4
3 ──       goto #4
4 ┄─ %9  = φ (#2 => 'S', #3 => 's')::Char%10 = Base.bitcast(Base.UInt32, %9)::UInt32%11 = Base.bitcast(Base.UInt32, 'S')::UInt32%12 = (%10 === %11)::Bool
└───       goto #5
5 ──       goto #7 if not %12
6 ──       goto #8
7 ──       nothing::Nothing
8 ┄─ %17 = φ (#6 => 'U', #7 => 'L')::Char%18 = invoke LinearAlgebra.BLAS.symm!('R'::Char, %17::Char, 1.0::Float64, %5::Matrix{Float64}, A::Matrix{Float64}, 0.0::Float64, C::Matrix{Float64})::Matrix{Float64}
└───       goto #9
9 ──       goto #10
10 ─       goto #11
11return %18
) => Matrix{Float64}

The BLAS.symm! branch that is being followed is "inlined" now. This is the case where the loop is not unrolled ordinarily, but using the all(map(..)) combination permits constant propagation.

@dkarrasch
Copy link
Member

I love it. I always dreamt of the day when that character stuff be inferred, or constant-propagated far enough. Is this ready to go now? I think we should first merge this, and then "stabilize MulAddMul strategically" PR, to give this one a chance for backport to v1.10, though I'm not sure if this is a bit too ambitious.

KristofferC pushed a commit that referenced this pull request May 7, 2024
With this, `isuppercase`/`islowercase` are evaluated at compile-time for
`Char` arguments:
```julia
julia> @code_typed (() -> isuppercase('A'))()
CodeInfo(
1 ─     return true
) => Bool

julia> @code_typed (() -> islowercase('A'))()
CodeInfo(
1 ─     return false
) => Bool
```
This would be useful in #54303,
where the case of the character indicates which triangular half of a
matrix is filled, and may be constant-propagated downstream.

---------

Co-authored-by: Shuhei Kadowaki <[email protected]>
@jishnub
Copy link
Contributor Author

jishnub commented May 7, 2024

Yes, this is ready from my side.

@dkarrasch dkarrasch merged commit c77671a into master May 7, 2024
7 checks passed
@dkarrasch dkarrasch deleted the jishnub/linalgwrapperchar branch May 7, 2024 16:19
@KristofferC KristofferC mentioned this pull request May 8, 2024
23 tasks
KristofferC pushed a commit that referenced this pull request May 20, 2024
KristofferC pushed a commit that referenced this pull request May 23, 2024
@KristofferC KristofferC removed the backport 1.11 Change should be backported to release-1.11 label May 28, 2024
KristofferC added a commit that referenced this pull request May 30, 2024
Backported PRs:
- [x] #54010 <!-- Overload `Base.literal_pow` for `AbstractQ` -->
- [x] #54143 <!-- Fix `make install` from tarballs -->
- [x] #54151 <!-- LinearAlgebra: Correct zero element in
`_generic_matvecmul!` for block adj/trans -->
- [x] #54233 <!-- set MAX_OS_WRITE on unix -->
- [x] #54251 <!-- fix typo in gc_mark_memory8 when chunking a large
array -->
- [x] #54363 <!-- typeintersect: fix another stack overflow caused by
circular constraints -->
- [x] #54497 <!-- Make TestLogger thread-safe (introduce a lock) -->
- [x] #53796 <!-- Add a missing doc -->
- [x] #54465 <!-- typeintersect: conservative typevar subtitution during
`finish_unionall` -->
- [x] #54514 <!-- typeintersect: followup cleanup for the nothrow path
of type instantiation -->

Need manual backport:
- [ ] #52505 <!-- fix alignment of emit_unbox_store copy -->
- [ ] #53373 <!-- fix sysimage-native-code=no option with pkgimages -->
- [ ] #53815 <!-- create phantom task for GC threads -->
- [ ] #53984 <!-- Profile: fix heap snapshot is valid char check -->
- [ ] #54276 <!-- Fix solve for complex `Hermitian` with non-vanishing
imaginary part on diagonal -->

Contains multiple commits, manual intervention needed:
- [ ] #52854 <!-- Change to streaming out the heap snapshot data -->
- [ ] #53218 <!-- Fix interpreter_exec.jl test -->
- [ ] #53833 <!-- Profile: make heap snapshots viewable in vscode viewer
-->
- [ ] #54303 <!-- LinearAlgebra: improve type-inference in
Symmetric/Hermitian matmul -->
- [ ] #52694 <!-- Reinstate similar for AbstractQ for backward
compatibility -->

Non-merged PRs with backport label:
- [ ] #54471 <!-- Actually setup jit targets when compiling
packageimages instead of targeting only one -->
- [ ] #53452 <!-- RFC: allow Tuple{Union{}}, returning Union{} -->
- [ ] #51479 <!-- prevent code loading from lookin in the versioned
environment when building Julia -->
@KristofferC KristofferC mentioned this pull request Jun 19, 2024
46 tasks
lazarusA pushed a commit to lazarusA/julia that referenced this pull request Jul 12, 2024
…ng#54346)

With this, `isuppercase`/`islowercase` are evaluated at compile-time for
`Char` arguments:
```julia
julia> @code_typed (() -> isuppercase('A'))()
CodeInfo(
1 ─     return true
) => Bool

julia> @code_typed (() -> islowercase('A'))()
CodeInfo(
1 ─     return false
) => Bool
```
This would be useful in JuliaLang#54303,
where the case of the character indicates which triangular half of a
matrix is filled, and may be constant-propagated downstream.

---------

Co-authored-by: Shuhei Kadowaki <[email protected]>
KristofferC added a commit that referenced this pull request Aug 13, 2024
Backported PRs:
- [x] #51351 <!-- Remove boxing in pinv -->
- [x] #52678 <!-- Profile: Improve module docstring -->
- [x] #54201 <!-- Fix generic triangular solves with empty matrices -->
- [x] #54605 <!-- Allow libquadmath to also fail as it is not available
on all systems -->
- [x] #54634 <!-- Fix trampoline assembly for build on clang 18 on apple
silicon -->
- [x] #54635 <!-- Aggressive constprop in trevc! to stabilize triangular
eigvec -->
- [x] #54645 <!-- ensure we set the right value to gc_first_tid -->
- [x] #54671 <!-- Add boundscheck in bindingkey_eq to avoid OOB access
due to data race -->
- [x] #54672 <!-- make: Fix `sed` command for LLVM libraries with no
symbol versioning -->
- [x] #54704 <!-- LazyString in reinterpretarray error messages -->
- [x] #54713 <!-- make: use `readelf` for LLVM symbol version detection
-->
- [x] #54781 <!-- [LinearAlgebra] Improve resilience to unknown
libblastrampoline flags -->
- [x] #54837 <!-- Do not add type tag size to the `alloc_typed` lowering
for GC allocations -->
- [x] #54815 <!-- add sticky task warning to `@task` and `schedule` -->
- [x] #55141 <!-- Update the aarch64 devdocs to reflect the current
state of its support -->
- [x] #55178 <!-- Compat for `Base.@nospecializeinfer` -->
- [x] #55197 <!-- compat notice for a[begin] indexing -->
- [x] #55209 <!-- correction to compat notice for a[begin] -->
- [x] #55203 <!-- document mutable struct const fields -->
- [x] #54769 <!-- add missing compat entry to edit -->
- [x] #54791 <!-- Bump libblastrampoline to v5.10.1 -->
- [x] #55070 <!-- LinearAlgebra: LazyString in error messages for
Diagonal/Bidiagonal -->
- [x] #54624 <!-- more precise aliasing checks for SubArray -->
- [x] #54690 <!-- Fix assertion/crash when optimizing function with dead
basic block -->
- [x] #55084 <!-- Use triple quotes in TOML.print when string contains
newline -->


Need manual backport:
- [ ] #52505 <!-- fix alignment of emit_unbox_store copy -->
- [ ] #53373 <!-- fix sysimage-native-code=no option with pkgimages -->
- [ ] #53984 <!-- Profile: fix heap snapshot is valid char check -->
- [ ] #54276 <!-- Fix solve for complex `Hermitian` with non-vanishing
imaginary part on diagonal -->
- [ ] #54669 <!-- Improve error message in inplace transpose -->
- [ ] #54871 <!-- Make warn missed transformations pass optional -->

Contains multiple commits, manual intervention needed:
- [ ] #52854 <!-- Change to streaming out the heap snapshot data -->
- [ ] #53218 <!-- Fix interpreter_exec.jl test -->
- [ ] #53833 <!-- Profile: make heap snapshots viewable in vscode viewer
-->
- [ ] #54303 <!-- LinearAlgebra: improve type-inference in
Symmetric/Hermitian matmul -->
- [ ] #52694 <!-- Reinstate similar for AbstractQ for backward
compatibility -->
- [ ] #54737 <!-- LazyString in interpolated error messages involving
types -->
- [ ] #54738 <!-- serialization: fix relocatability bug -->
- [ ] #55052 <!-- Fix `(l/r)mul!` with `Diagonal`/`Bidiagonal` -->

Non-merged PRs with backport label:
- [ ] #55220 <!-- `isfile_casesensitive` fixes on Windows -->
- [ ] #55169 <!-- `propertynames` for SVD respects private argument -->
- [ ] #55013 <!-- [docs] change docstring to match code -->
- [ ] #51479 <!-- prevent code loading from lookin in the versioned
environment when building Julia -->
- [ ] #50813 <!-- More doctests for Sockets and capitalization fix -->
- [ ] #50157 <!-- improve docs for `@inbounds` and
`Base.@propagate_inbounds` -->
- [ ] #41244 <!-- Fix shell `cd` error when working dir has been deleted
-->
@KristofferC KristofferC mentioned this pull request Sep 12, 2024
63 tasks
@KristofferC KristofferC mentioned this pull request Oct 29, 2024
47 tasks
@KristofferC
Copy link
Member

Could someone help with backporting this to 1.10? Just push a commit in that case.

@jishnub
Copy link
Contributor Author

jishnub commented Nov 4, 2024

Perhaps it'll be better to not backport this to v1.10. This changes the type of the character that is passed, and although it is internal, there are packages that call generic_matmatmul! directly (such as CUDA), and occasionally expect the characters to be Chars in the function signature. It seems unnecessary to change this on v1.10. This PR fixes some type-instabilities, which is more of an optimization than a bugfix.

@jishnub jishnub removed the backport 1.10 Change should be backported to the 1.10 release label Nov 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
linear algebra Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Type instability and allocations in mul! with Hermitians and Adjoint arguments
3 participants