Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of np.linalg.lstsq() via SVD #2744

Merged
merged 7 commits into from
May 11, 2020

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Apr 16, 2020

This is an initial implementation of np.linalg.lstsq based on the SVD. A full solution would involve adding wrappers for *gelsd to lax_linalg.py & cusolver.

I estimate this is about 2x slower than the full solution, based on performance of the relevant lapack code paths in numpy.

Addresses part of #1999

jax/numpy/linalg.py Outdated Show resolved Hide resolved
jax/numpy/linalg.py Outdated Show resolved Hide resolved
jax/numpy/linalg.py Show resolved Hide resolved
jax/numpy/linalg.py Outdated Show resolved Hide resolved
jax/numpy/linalg.py Show resolved Hide resolved
tests/linalg_test.py Show resolved Hide resolved
tests/linalg_test.py Show resolved Hide resolved
@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 17, 2020

CI failure looks like a conda failure. Can someone restart it?

@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 17, 2020

nvm, once I logged-in to Travis I was able to restart.

@joaogui1
Copy link
Contributor

joaogui1 commented Apr 18, 2020

Hey, what about #2200 ?@shoyer

@mattjj
Copy link
Collaborator

mattjj commented Apr 18, 2020

Uh oh, I forgot about #2200 :/ This is completely my bad.

I'm really sorry for the repeated work, and for accidentally ignoring the work you put in @joaogui1 . As JAX activity has picked up (especially inside Alphabet) we've gotten a lot worse at following up on PRs from amazing OSS contributors. (OSS contributions are especially amazing because contributors don't have the benefit of our extremely-active internal chat channels.)

We're trying to address the general problem. As of this week we're experimenting with a GitHub rotation. It's tough, though, because the JAX team is pretty small. I'm optimistic that we'll get better over time.

As for this specific case: @jakevdp @joaogui1 is there a way to combine efforts here and maybe draw on both PRs? Or is there too much redundancy and we need to chalk this up to a mistake to learn from?

@joaogui1 let me know if there is some course of action I can take to make this more right. (Also, maybe this is a good chance to highlight any other PRs of yours that we've let languish...)

@joaogui1
Copy link
Contributor

So, after reading his code I believe @jakevdp implementation is better than mine, so I will close my PR.
I have two other open PRs:

  • [Awaiting reviews] Initial implementation of Depthwise Conv2D #1496 , frankly there are implementations of it in both haiku and flax, so I think I will also close this one. If I may say I think there should be a section in Contributing.md explaining that stax is not a priority for PRs (if it's accepting them at all) as it is more a proof of concept. I mean there aresome dead PRs related to it, and Jake Bradbury said:

... stax.py is meant more as an example (forking encouraged!) than as a growing or comprehensive library ...

  • The other one is [Draft] Performance Tests #1862 which I still think is important, but I would need more input from the dev team, so I understand if it will take a while longer for someone to pop up there and help me with deciding what tests are needed

Also, if you guys want some help I can reopen #1874 and search for solved/obsolete issues and PRs so someone can close them and help organize things a little more

@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 18, 2020

I'm not able to get check_grads to pass consistently. It seems to be flaking on complex and/or low-rank inputs – 7 failures with --num_generated_cases=25. The gradient computation and its test is still a bit of a black box to me, and I haven't made much progress in debugging the issue.

@hawkinsp
Copy link
Collaborator

Is it worth merging this even without gradients?

(It seems that the problem is most likely not with the PR itself, given it doesn't implement any new gradients.)

@jakevdp
Copy link
Collaborator Author

jakevdp commented Apr 28, 2020

Is it worth merging this even without gradients?

Perhaps – I'm planning on digging-in to implement gradients via a custom jvp, but I won't have time to look closely at that until next week.

@jakevdp
Copy link
Collaborator Author

jakevdp commented May 4, 2020

In offline discussion with @mattjj, we decided it would be worth submitting this even without robust gradient support. I removed the gradient test for now, and marked it TODO.

@shoyer
Copy link
Collaborator

shoyer commented May 4, 2020 via email

@jakevdp
Copy link
Collaborator Author

jakevdp commented May 6, 2020

I just rebased on master to pick up changes there.

PTAL - I think this is ready for a final review.

@jakevdp jakevdp merged commit db71f3c into jax-ml:master May 11, 2020
@jakevdp jakevdp deleted the lstsq branch May 11, 2020 22:00
j-towns pushed a commit to j-towns/jax that referenced this pull request May 14, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants