Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StableHLO complex sqrt to stablehlo-complex-math-expander pass #2679

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pearu
Copy link
Contributor

@pearu pearu commented Dec 29, 2024

As in the title.

The existing implementation of sqrt on complex inputs uses polar form of complex sqrt which is inaccurate/incorrect on about 28 % of uniformly distributed samples over all complex plane. The JAX complex sqrt accuracy statistics is as follows:

test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 2792619172
ULP difference == 0: 297588
ULP difference == 1: 1106945
ULP difference == 2: 78921
ULP difference == 3: 5924
ULP difference == 4: 2938
ULP difference == 5: 2231
ULP difference == 6: 1855
ULP difference == 7: 1648
ULP difference == 8: 1449
ULP difference == 9: 1263
ULP difference == 10: 1089
ULP difference >= 11: 598949

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2760370645
ULP difference == 0: 699582
ULP difference == 1: 759750
ULP difference == 2: 58127
ULP difference == 3: 3237
ULP difference == 4: 1665
ULP difference == 5: 1185
ULP difference == 6: 1021
ULP difference == 7: 871
ULP difference == 8: 804
ULP difference == 9: 653
ULP difference == 10: 622
ULP difference >= 11: 573283

This PR provides an algorithm for complex sqrt that is accurate up to 3/6 ULP difference error on complex samples. The corresponding JAX complex sqrt accuracy statistics is as follows:

test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 5
ULP difference == 0: 1060571
ULP difference == 1: 1008268
ULP difference == 2: 31136
ULP difference == 3: 686
ULP difference == 4: 129
ULP difference == 5: 10

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2
ULP difference == 0: 1348868
ULP difference == 1: 751504
ULP difference == 2: 428

It is interesting to note that although the same algorithm is used for both CUDA and CPU platforms, then the expected maximal ULP difference is 4 (obtained from applying the algorithm to numpy arrays). Hence

  • the accuracy of complex sqrt on CPU is better than expected
  • the accuracy of complex sqrt on CUDA is worse than expected because CUDA sqrt produces slightly different results from std sqrt on float inputs.

@pearu
Copy link
Contributor Author

pearu commented Dec 30, 2024

@GleasonK , please review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant