
Nx.LinAlg.qr does not support vectorization #1556

Closed · jyc opened this issue Nov 5, 2024 · 6 comments
@jyc commented Nov 5, 2024

Hello! Thanks again for making Nx. Sorry I haven't been able to try out your CPU-based implementation of LU yet from #1388 (comment)!

I noticed today that QR does not seem to work with vectorized inputs in EXLA:

input =
  Nx.tensor([
    [
      [1, 0],
      [0, 1]
    ],
    [
      [0, 1],
      [1, 0]
    ]
  ])
  |> Nx.vectorize(:foo)

Nx.LinAlg.qr(input)

The result is incorrect; every vectorized slice after the first comes back as all zeros.

{#Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.3767501109.3804364811.204054>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, 0.0],
       [0.0, 0.0]
     ]
   ]
 >,
 #Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.3767501109.3804364811.204055>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, 0.0],
       [0.0, 0.0]
     ]
   ]
 >}

Also, the process my Livebook is attached to exits with "Abort trap: 6" shortly afterwards! This happens whether I execute Nx.LinAlg.qr in a Livebook session or through IEx.

Is there anything I can do to help debug this? Thanks again.

@polvalente (Contributor) commented

Which version of Nx are you using? Please try the main branch: while implementing LU, I ran into a bug that was also present in QR and could account for this.

I'm getting different results on main than on 0.9.1, and main reconstructs the input correctly :)
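
As a quick check (a minimal sketch, assuming the vectorized input tensor from the snippet above), multiplying Q by R should reproduce the input slice by slice:

{q, r} = Nx.LinAlg.qr(input)

# Q and R keep the vectorized[foo: 2] axis, so Nx.dot/2 multiplies the
# 2x2 factors slice by slice; Nx.all_close/2 then compares each product
# against the corresponding input slice (1 means they match).
Nx.all_close(Nx.dot(q, r), input)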

@polvalente (Contributor) commented

The bug was due to "overstriding" in the pointer arithmetic. In C++, pointer arithmetic on a typed pointer already scales the right operand by the size of the datatype (for a float *p, p + i advances i * sizeof(float) bytes), and I was multiplying by the element size explicitly on top of that, so accesses past the first slice landed in the wrong place. This might be the cause of the exit you're seeing too.

@jyc (Author) commented Nov 5, 2024

Interesting, thanks! I was on 0.8 and just verified that it doesn't work on 0.9.1 either, like you said. I was going to try main, but using:

      {:nx,
       git: "https://github.com/elixir-nx/nx.git", tag: "9e2cd048de610151b85a27a183035bc0873fa77f"},
      {:exla,
       git: "https://github.com/elixir-nx/nx.git", tag: "9e2cd048de610151b85a27a183035bc0873fa77f"},

in mix.exs gave me warnings; I got:

    warning: redefining module NxRoot (current version defined in memory)
    │
  2 │ defmodule NxRoot do
    │ ~~~~~~~~~~~~~~~~~~~
    │
    └─ /Users/jyc/projects/foo/server/deps/exla/mix.exs:2: NxRoot (module)

I imagine this is because Nx and EXLA are both defined in this repository. What would be the correct way to set Mix deps?

EDIT: The full error is:

Erlang/OTP 27 [erts-15.0] [source] [64-bit] [smp:10:10] [ds:10:10:10] [async-threads:1] [jit]

    warning: redefining module NxRoot (current version defined in memory)
    │
  2 │ defmodule NxRoot do
    │ ~~~~~~~~~~~~~~~~~~~
    │
    └─ /Users/jyc/projects/foo/server/deps/exla/mix.exs:2: NxRoot (module)

** (Mix) App nx lists itself as a dependency

@jyc (Author) commented Nov 5, 2024

I think you're right that this works on main!

I messed around with options and this in mix.exs seems to work, although I have no idea whether it's correct:

      {:nx, git: "https://github.com/elixir-nx/nx.git", sparse: "nx", ref: "main", override: true},
      {:exla, git: "https://github.com/elixir-nx/nx.git", sparse: "exla", ref: "main"},

The output is now:

{#Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.1453107890.856555539.251335>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, -1.0],
       [-1.0, 0.0]
     ]
   ]
 >,
 #Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.1453107890.856555539.251336>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [-1.0, 0.0],
       [0.0, -1.0]
     ]
   ]
 >}

@jyc (Author) commented Nov 5, 2024

Thanks so much for the help! For my own edification, is 7af065e the commit that contains the fix you mentioned?

@polvalente (Contributor) commented

Glad it worked! Yes, that's the commit with the fix :)

Also, for the future, you can use a shorter notation for github deps:

{:nx, github: "elixir-nx/nx", branch: "main", sparse: "nx"}
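
The matching EXLA entry would presumably take the same shape, mirroring the sparse checkout used above:

{:exla, github: "elixir-nx/nx", branch: "main", sparse: "exla"}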
