RFC: Add norm(A, p; dims) #43459

Open
wants to merge 7 commits into
base: master

@mcabbott
Contributor

mcabbott commented Dec 18, 2021

Closes JuliaLang/LinearAlgebra.jl#697. RFC, I guess?

One motivation is that this can be faster than making slices:

julia> @btime map(norm, eachrow($(rand(100, 100))));
  12.625 μs (1 allocation: 896 bytes)

julia> @btime norm($(rand(100, 100)), 2, dims=2);  # with PR
  1.771 μs (5 allocations: 976 bytes)
  1.650 μs (1 allocation: 896 bytes)

Here, "5 allocations" is sometimes 1, as you'd expect. I don't know why; it seems to depend on load order? I think there was an issue about this, which I can't find again.

The 0,1,Inf norm implementations are trivial.
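
For arrays of plain numbers, these reductions might look roughly like the sketch below. The `*_sketch` names are hypothetical and the functions are non-mutating, so this is an illustration rather than the PR's code. Note that `norm(A, 0)` itself returns a floating-point count, which this sketch does not reproduce.

```julia
# Hypothetical helpers for arrays of plain numbers (names are not from the PR).
norm0_sketch(A; dims)   = count(!iszero, A; dims=dims)   # number of nonzero entries per slice
norm1_sketch(A; dims)   = sum(abs, A; dims=dims)         # sum of magnitudes per slice
normInf_sketch(A; dims) = maximum(abs, A; dims=dims)     # largest magnitude per slice

A = [1 -2 0; 0 3 -4]
norm0_sketch(A; dims=2)     # 2×1 matrix, one count per row
norm1_sketch(A; dims=2)     # 2×1 matrix, one sum per row
normInf_sketch(A; dims=2)   # 2×1 matrix, one maximum per row
```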

For the 2-norm, I initially made it check whether values in A are in the goldilocks zone, as norm(A) does. But this check alone takes longer than the happy path of sum(abs2, A; dims). Instead, the present PR does that and then checks the answer, and re-does any slices which are dangerous. I hope this is correct. I presume the typical case would not have many such zero/Inf answers, in which case this will be fast. It's a bit awkward that eachslice doesn't do what is needed here, so instead I taped something together, trying to make the common cases fast. It is a bit ugly but is there a better way?
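
A minimal sketch of the check-afterwards idea, assuming a real-valued array. The name `norm2_dims_sketch` and the way it rebuilds each slice are illustrative assumptions, not the PR's implementation: it takes the fast path first, then recomputes with `norm` only those slices whose result came out as 0 or Inf.

```julia
using LinearAlgebra

# Hypothetical helper, not the PR's code.
function norm2_dims_sketch(A::AbstractArray{<:Real}; dims)
    # Fast path: squaring can overflow to Inf or underflow to 0 for extreme values.
    B = sqrt.(sum(abs2, A; dims=dims))
    for I in CartesianIndices(B)
        if iszero(B[I]) || isinf(B[I])
            # Rebuild the slice of A that was reduced into B[I]:
            # a Colon for every reduced dimension, the fixed index otherwise.
            idx = ntuple(d -> size(B, d) == 1 ? Colon() : I[d], ndims(A))
            # norm rescales internally, so it is safe for this slice.
            B[I] = norm(view(A, idx...))
        end
    end
    return B
end

A = [1e200 1e200; 3.0 4.0; 0.0 0.0]
B = norm2_dims_sketch(A; dims=2)   # row 1 overflows on the fast path and is redone
```

Only the rows that produced 0 or Inf pay for the slower per-slice call, which matches the expectation above that such rows are rare in typical input.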

The p-norm is much slower, and for now it just slices. It could probably be done the same way, though. [Edit: now works like 2-norm]

@ararslan ararslan added the linear algebra label Dec 18, 2021
@oscardssmith
Member

I don't like this (because I don't like dims arguments in general), but this seems reasonable to have.

@mcabbott
Contributor Author

Do you think that checking for 0 & Inf afterwards is sufficient to catch all the floating point problems that the existing implementation catches?

The existing one checks beforehand, but doing that on the N^2 matrix seems to take longer than the entire operation. Whereas checking the N results afterwards is usually quick. At least when not too many are 0/Inf.

@oscardssmith
Member

That is sufficient (and a much better idea).

@oscardssmith oscardssmith added the linalg triage, triage, and needs tests labels Dec 28, 2021
@oscardssmith
Member

oscardssmith commented Dec 28, 2021

Can you separate out

the present PR does that and then checks the answer, and re-does any slices which are dangerous.

into its own PR? I think that is an easy win, while the new method probably needs triage. Alternatively, would you mind if I rewrote #43256 to use this?

@mcabbott
Contributor Author

I wondered about using the same check-afterwards idea for the complete norm. I can have a go but if you beat me to it that's even better. (Did not see #43256.)

But I don't think such a change need alter this PR. They would share the idea but not share code for it, I think. Unless you are proposing that norm(::Matrix) should apply this idea chunk-wise?

norm0_dims!(B, A) = count!(!iszero, B, A)
norm1_dims!(B, A) = Base.mapreducedim!(norm, +, B, A)
normInf_dims!(B, A) = Base.mapreducedim!(norm, max, B, A)
normMinusInf_dims!(B, A) = Base.mapreducedim!(norm, min, B, A)
Contributor Author

BTW the reason these are all mutating B is to work around #43461. With things like mapreduce(norm, max, A; dims) instead, some were not type-stable. There's a PR to fix that, though.

@oscardssmith
Member

Triage would rather not add new dims arguments and instead use #32310 (or similar) to make eachslice faster.

@oscardssmith oscardssmith removed the triage and linalg triage labels Jan 6, 2022
@mcabbott
Contributor Author

mcabbott commented Jan 6, 2022

What's slow isn't eachslice, though. It's that performing the calculation slice-by-slice is slow, especially when the direction is cache-unfriendly. (The eachrow example above does not have eachslice's type-stability issues.)

@oscardssmith
Member

With the SlicedArray type, couldn't this PR be implemented for norm(::Slices, p), though?

@mcabbott
Contributor Author

mcabbott commented Jan 6, 2022

No, norm regards nested arrays as a bag of numbers, norm([[[1,2],3],4]) ≈ norm([1,2,3,4]), so I don't think the meaning of norm(::Slices) could differ from norm(collect(::Slices)).
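
A small illustration of that "bag of numbers" behaviour, using the default 2-norm and a simpler nested array than the one quoted. The variable names are only for this example.

```julia
using LinearAlgebra

nested = [[1.0, 2.0], [3.0, 4.0]]   # a vector of vectors
flat   = [1.0, 2.0, 3.0, 4.0]       # the same numbers, flattened

same_as_flat = norm(nested) ≈ norm(flat)   # both equal sqrt(30)

# The same holds for a collection of row views: the result is one number
# for the whole matrix, not one number per row.
A = [1.0 2.0; 3.0 4.0]
rows_as_bag = norm(collect(eachrow(A))) ≈ norm(A)
```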

It would be possible to overload norm.(::Slices). The two tricky things there are that sum.(eachslice(rand(2,3); dims=1)) uses the opposite convention to sum for what dims means, and that you have to decide whether this should allocate the result immediately or fuse with further operations. That seems like a bigger design decision.
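
The difference in what dims means can be seen on a small integer matrix. In this sketch (variable names are illustrative), `eachslice` is told which dimension to keep, while `sum` is told which dimension to reduce away.

```julia
A = reshape(1:6, 2, 3)   # 2×3 matrix with rows [1, 3, 5] and [2, 4, 6]

# eachslice(A; dims=1) iterates over dimension 1, so each slice is a row.
per_row_slices = sum.(eachslice(A; dims=1))   # a length-2 vector

# sum(A; dims=2) reduces over dimension 2, leaving one value per row.
per_row_reduce = sum(A; dims=2)               # a 2×1 matrix

same_values = per_row_slices == vec(per_row_reduce)
```

So the row sums come from `dims=1` in one spelling and `dims=2` in the other, and the shapes of the results differ as well.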

EDIT:

map(::typeof(norm), ::Slices) avoids the fusion question. But A ./ map(Fix2(norm, 1), eachslice(A; dims=2, drop=false)) is quite a mouthful.
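
As a sketch of what that expression does today, assuming Julia 1.9 or later (for the `drop=false` keyword of `eachslice`) and importing the unexported `Base.Fix2`, the following normalises each column by its 1-norm and compares against a plain reduction:

```julia
using LinearAlgebra
using Base: Fix2

A = [3.0 1.0; -4.0 1.0]

# With drop=false the slices keep a singleton first dimension, so the
# mapped result is a 1×2 matrix that broadcasts against A column-wise.
col_norms = map(Fix2(norm, 1), eachslice(A; dims=2, drop=false))
normalised = A ./ col_norms

# The same scaling written as an ordinary reduction over dimension 1.
matches_reduction = normalised ≈ A ./ sum(abs, A; dims=1)
```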

All of these have a bit of the problem that reduce(vcat, xs) has -- by magically upgrading a function which already works (but slowly) we are left guessing as to whether a given use is actually going to hit the magic fast path, or not. Whereas right now, the existence of a dims method is evidence of the existence of a special path.

There's also a problem of return types. map(norm, eachcol(A::CuArray)) isa Vector right now. If a magic fast path existed, it would want to make a CuArray, and this is what you want for uses like A ./ norm(A; dims=2). There are very few uses where you could equally accept an Array or (as a magic optimisation) a CuArray. That's an issue for all proposals like sum.(eachcol(A)) too.

@mcabbott
Contributor Author

Today I discovered that 1.9 has sortperm(randn(3,5), dims=1), and it was useful.

Rebasing this PR & timing it on 1.10-, which includes the new eachslice of #32310, the benefit is still pretty similar:

julia> @btime mapslices(norm, $(rand(100, 100)), dims=2);  # re-written for 1.9, #40996
  30.375 μs (15 allocations: 2.12 KiB)

julia> @btime map(norm, eachrow($(rand(100, 100))));  # with JuliaLang/julia#32310
  22.000 μs (1 allocation: 896 bytes)

julia> @btime norm($(rand(100, 100)), 2, dims=2);  # with PR
  2.949 μs (5 allocations: 976 bytes)

Labels
linear algebra, needs tests
Development

Successfully merging this pull request may close these issues.

Norm along side axis
3 participants