[codegen] softmax nans #17670

Closed
dan-garvey opened this issue Jun 13, 2024 · 5 comments
Labels: bug 🐞 (Something isn't working), codegen/llvm (LLVM code generation compiler backend)

Comments

@dan-garvey
Contributor

dan-garvey commented Jun 13, 2024

What happened?

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<2x24x1178x1178xf32>
  %1 = linalg.softmax dimension(3) ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%0 : tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32>
  return %1 : tensor<2x24x1178x1178xf32>
}

Compile command:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-target-cpu-features=host 42.mlir -o 42.vmfb

input npy
https://sharkblobs.blob.core.windows.net/dan/42_inputs.npy

output npy (for comparison)
https://sharkblobs.blob.core.windows.net/dan/42_out.npy

iree-run-module --module=42.vmfb --function=softmax --input=@42_inputs.npy --output=@42_out_repro.npy

@dan-garvey dan-garvey added the bug 🐞 Something isn't working label Jun 13, 2024
@dan-garvey
Contributor Author

dan-garvey commented Jun 13, 2024

@hanhanW identified that the rows [0, 18, 63, *] and [0, 18, 1389, *] of the output are NaNs.

@hanhanW hanhanW added the codegen/llvm LLVM code generation compiler backend label Jun 13, 2024
@hanhanW
Contributor

hanhanW commented Jun 13, 2024

@pashu123 please help with further triage. We dumped the inputs and outputs and verified that the output contains NaNs.
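For anyone repeating that check, here is a minimal sketch of how rows containing NaNs could be located in a dumped output. It is not the script used in this thread. The helper name `nan_rows` and the tiny stand-in data are hypothetical, and it works on a flat row-major list so it needs only the standard library. With the real `.npy` files you would load the array first and flatten it.

```python
import math
from itertools import product

def nan_rows(flat, shape):
    """Return the index prefixes (all dims but the last) of rows containing NaN.

    `flat` is a row-major flat sequence of floats; `shape` is its logical shape.
    """
    *outer, inner = shape
    rows = []
    for row_id, prefix in enumerate(product(*(range(d) for d in outer))):
        row = flat[row_id * inner:(row_id + 1) * inner]
        if any(math.isnan(v) for v in row):
            rows.append(prefix)
    return rows

# Tiny stand-in for the 2x24x1178x1178 output: shape (2, 2, 3), one bad row.
nan = float("nan")
data = [0.2, 0.3, 0.5,  nan, nan, nan,
        0.1, 0.1, 0.8,  0.3, 0.3, 0.4]
print(nan_rows(data, (2, 2, 3)))  # [(0, 1)]
```

Each reported prefix corresponds to one softmax row, written above in the `[i, j, k, *]` form.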

@pashu123
Contributor

pashu123 commented Jun 14, 2024

After further debugging, the problem turns out to be in the max calculation. Smaller repro:

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178xf32> {
  %4 = tensor.empty() : tensor<2x24x1178xf32>
  %cst = arith.constant -3.40282347E+38 : f32
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x24x1178xf32>) -> tensor<2x24x1178xf32>

  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%5 : tensor<2x24x1178xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.maximumf %in, %out : f32
    linalg.yield %10 : f32
  } -> tensor<2x24x1178xf32>

  return %6 : tensor<2x24x1178xf32>
}

https://gist.github.com/pashu123/83ca1f519aa39f1ce7a035122bbb7e54 (Compile and run commands are same as above)
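To show why the max reduction matters, here is a rough sketch of the usual numerically stable softmax: take the row max, subtract it, exponentiate, then normalize. This is plain Python, not the IREE lowering, and `row_max` and `softmax_row` are hypothetical names. The thread does not say how a wrong max becomes a NaN. One plausible path, assumed here, is that a max that is too large makes every `exp` underflow to zero, so the normalization becomes 0/0. The example uses values near -800 because Python floats are doubles; in f32 the underflow would start at roughly -104.

```python
import math

NEG_FLT_MAX = -3.40282347e38  # same constant the repro uses to seed the max

def row_max(row):
    """Max reduction seeded with the most negative finite f32 value."""
    acc = NEG_FLT_MAX
    for v in row:
        acc = max(v, acc)
    return acc

def softmax_row(row, m):
    """Softmax of `row` using `m` as the (supposed) row max."""
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    # Python raises on 0.0 / 0.0; emulate the IEEE result (NaN) instead.
    return [e / total if total != 0.0 else float("nan") for e in exps]

row = [-800.0, -801.0, -802.0]
good = softmax_row(row, row_max(row))  # correct max: well-defined probabilities
bad = softmax_row(row, -0.0)           # max "stuck" at -0.0: every exp underflows
print(good)
print(bad)
```

With the correct max the largest exponent is exactly 0, so at least one term equals 1 and the sum cannot be zero. That guarantee is lost when the max is wrong.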

I have created a Python script to debug: https://gist.github.com/pashu123/898636a138e41e1db2443acd1248d6d4

Output of the Python script:

Mismatch at index (np.int64(0), np.int64(2), np.int64(1)): golden=-1.6139899492263794, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(9)): golden=-1.1718499660491943, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(10)): golden=-1.594499945640564, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(11)): golden=-1.9860199689865112, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(18)): golden=-1.1132500171661377, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(19)): golden=-2.1459200382232666, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(20)): golden=-1.3908900022506714, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(21)): golden=-1.2039200067520142, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(23)): golden=-3.720489978790283, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(24)): golden=-3.0760700702667236, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(25)): golden=-3.9601500034332275, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(26)): golden=-2.8110198974609375, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(27)): golden=-1.5647300481796265, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(29)): golden=-1.171970009803772, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(30)): golden=-2.9511098861694336, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(31)): golden=-1.1302900314331055, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(32)): golden=-3.8724400997161865, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(33)): golden=-1.8330700397491455, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(34)): golden=-1.1605299711227417, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(35)): golden=-5.191100120544434, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(36)): golden=-3.998159885406494, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(48)): golden=-2.524359941482544, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(54)): golden=-1.4726300239562988, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(55)): golden=-6.302299976348877, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(56)): golden=-1.0678900480270386, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(62)): golden=-3.644969940185547, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(63)): golden=-4.302030086517334, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(65)): golden=-1.2450499534606934, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(66)): golden=-2.546420097351074, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(70)): golden=-1.760390043258667, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(80)): golden=-1.1018799543380737, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(84)): golden=-2.6196000576019287, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(85)): golden=-1.4363700151443481, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(92)): golden=-1.8270699977874756, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(93)): golden=-5.119679927825928, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(94)): golden=-3.4443399906158447, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(95)): golden=-1.8535699844360352, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(98)): golden=-1.696810007095337, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(99)): golden=-2.281130075454712, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(100)): golden=-2.694159984588623, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(101)): golden=-3.200939893722534, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(102)): golden=-4.250319957733154, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(103)): golden=-2.6362600326538086, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(108)): golden=-1.3708399534225464, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(115)): golden=-1.9866199493408203, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(118)): golden=-2.3564600944519043, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(122)): golden=-4.689330101013184, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(123)): golden=-3.47625994682312, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(124)): golden=-2.152790069580078, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(125)): golden=-1.2989599704742432, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(127)): golden=-5.363550186157227, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(128)): golden=-4.256410121917725, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(130)): golden=-2.7768800258636475, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(134)): golden=-1.7649099826812744, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(135)): golden=-3.982069969177246, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(136)): golden=-6.1743998527526855, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(137)): golden=-6.286499977111816, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(138)): golden=-2.8284900188446045, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(139)): golden=-5.993460178375244, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(140)): golden=-3.436150074005127, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(144)): golden=-2.254849910736084, iree=-0.0
Mismatch at index (np.int64(0), np.int64(4), np.int64(58)): golden=-1.4180999994277954, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(1)): golden=-19.961200714111328, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(2)): golden=-25.51609992980957, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(3)): golden=-6.272930145263672, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(4)): golden=-9.712470054626465, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(6)): golden=-10.295499801635742, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(7)): golden=-16.54210090637207, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(8)): golden=-39.7671012878418, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(9)): golden=-22.47920036315918, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(10)): golden=-23.77739906311035, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(11)): golden=-40.10390090942383, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(12)): golden=-10.307700157165527, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(13)): golden=-8.724579811096191, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(14)): golden=-3.1235899925231934, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(15)): golden=-15.26159954071045, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(16)): golden=-8.746410369873047, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(17)): golden=-9.033740043640137, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(18)): golden=-36.70589828491211, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(19)): golden=-41.16350173950195, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(20)): golden=-38.764198303222656, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(21)): golden=-20.7450008392334, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(22)): golden=-14.468999862670898, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(23)): golden=-19.56329917907715, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(24)): golden=-17.083499908447266, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(25)): golden=-20.79840087890625, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(26)): golden=-11.901700019836426, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(27)): golden=-21.383699417114258, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(28)): golden=-17.52400016784668, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(29)): golden=-16.292200088500977, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(30)): golden=-15.337599754333496, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(31)): golden=-14.481499671936035, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(32)): golden=-21.077600479125977, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(33)): golden=-26.247299194335938, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(34)): golden=-31.76959991455078, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(35)): golden=-27.92840003967285, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(36)): golden=-10.960000038146973, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(37)): golden=-10.226499557495117, iree=-0.0

It looks like IREE's output gets stuck at -0.0 for these entries.
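A comparison of this kind can be sketched as follows. This is not the content of the linked gist. The function name `mismatches`, the tolerances, and the four-element sample data are assumptions for illustration, and the arrays are passed as flat row-major lists so that only the standard library is needed.

```python
import math
from itertools import product

def mismatches(golden, actual, shape, rtol=1e-5, atol=1e-6):
    """Yield-style list of (index, golden, actual) for elements that differ.

    Both inputs are flat row-major sequences with logical shape `shape`.
    NaN never compares close, so NaN entries are reported as mismatches too.
    """
    out = []
    for flat_idx, idx in enumerate(product(*(range(d) for d in shape))):
        g, a = golden[flat_idx], actual[flat_idx]
        if not math.isclose(g, a, rel_tol=rtol, abs_tol=atol):
            out.append((idx, g, a))
    return out

# Hypothetical sample: two of four reduced values come back as -0.0.
golden = [-1.61399, -0.5, -1.17185, -2.0]
iree = [-0.0, -0.5, -0.0, -2.0]
for idx, g, a in mismatches(golden, iree, (2, 2)):
    print(f"Mismatch at index {idx}: golden={g}, iree={a}")
```

Every mismatch in the log above has a negative golden value and an IREE value of -0.0, which is what points at the max reduction rather than at random corruption.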

Using arith.maxnumf (https://mlir.llvm.org/docs/Dialects/ArithOps/#arithmaxnumf-arithmaxnumfop) instead of arith.maximumf, i.e.,

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178xf32> {
  %4 = tensor.empty() : tensor<2x24x1178xf32>
  %cst = arith.constant -3.40282347E+38 : f32
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x24x1178xf32>) -> tensor<2x24x1178xf32>

  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%5 : tensor<2x24x1178xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.maxnumf %in, %out : f32
    linalg.yield %10 : f32
  } -> tensor<2x24x1178xf32>

  return %6 : tensor<2x24x1178xf32>
}

This solves the problem. I am still reading the documentation; it is not yet clear to me why the arith.maximumf version fails 😆.

Cherry-pick: https://github.com/pashu123/llvm-project/tree/fyi_soft (verified)

@hanhanW
Contributor

hanhanW commented Jun 14, 2024

Some related read: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

I recall when we split min (max) into minimum/minnum (maximum/maxnum). We may have missed softmax at that time because it was not on my radar.
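For reference, the difference between the two ops after that split can be sketched in plain Python. This follows the semantics as I understand them from the arith dialect documentation: `arith.maximumf` propagates NaN and treats -0.0 as less than +0.0, while `arith.maxnumf` returns the non-NaN operand when exactly one operand is NaN. The Python functions are hypothetical emulations, not MLIR code, and they do not by themselves explain the -0.0 results above, whose root cause this thread leaves open.

```python
import math

def maximumf(a, b):
    """Emulates arith.maximumf: NaN if either operand is NaN; -0.0 < +0.0."""
    if math.isnan(a) or math.isnan(b):
        return float("nan")
    if a == b:  # also covers the signed-zero case, where -0.0 == +0.0
        return a if math.copysign(1.0, a) > 0 else b
    return a if a > b else b

def maxnumf(a, b):
    """Emulates arith.maxnumf: if exactly one operand is NaN, return the other."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a > b else b  # signed-zero ordering is not modeled here

def reduce_max(op, row, init=-3.40282347e38):
    """Row reduction in the same operand order as the repro: op(in, out)."""
    acc = init
    for v in row:
        acc = op(v, acc)
    return acc

row = [-1.5, float("nan"), -0.25]
print(reduce_max(maximumf, row))  # nan
print(reduce_max(maxnumf, row))   # -0.25
```

So with `maximumf` a single NaN poisons the whole reduction, while `maxnumf` skips it.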

pashu123 added a commit to pashu123/llvm-project that referenced this issue Jun 19, 2024
pashu123 added a commit to llvm/llvm-project that referenced this issue Jun 20, 2024
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this issue Jul 9, 2024
@hanhanW
Contributor

hanhanW commented Jul 31, 2024

The upstream change has landed in IREE: #18033

Closing the issue.

@hanhanW hanhanW closed this as completed Jul 31, 2024