JAX log accuracy and correctness #46
With https://github.com/pearu/stablehlo/tree/pearu/log, we'll have
The result of

```python
import jax.numpy as jnp
import functional_algorithms as fa

dtype = jnp.complex64
ctx = fa.Context(paths=[fa.algorithms])
# Trace the functional_algorithms implementation of complex log for complex64
graph = ctx.trace(getattr(fa.algorithms, "log"), dtype)
graph2 = graph.rewrite(fa.targets.numpy, fa.rewrite)
# numpy=jnp makes the generated NumPy-target function call jax.numpy instead
func = fa.targets.numpy.as_function(graph2, debug=2, numpy=jnp)
z = jnp.array(0.9970238-0.08273122j, dtype=jnp.complex64)
v = func(z)
print(f'{func.__name__}({str(z)}) -> {str(v)}')
```

is

that is, using jax functions directly produces results equivalent to those of the numpy functions.
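To check "equivalent result" claims numerically, one option is to measure the distance, in float32 units in the last place (ULPs), between a complex64 log result and a complex128 reference rounded back to complex64. The NumPy-only sketch below assumes this approach; `ulp_distance` and `log_ulp_errors` are hypothetical helpers (not part of functional_algorithms or JAX), and `log_fn` could be swapped for `func` or `jnp.log`.

```python
import numpy as np


def _ordered_int(v):
    # Reinterpret the float32 bit pattern as int32, then remap negative values
    # so the integers increase monotonically with the float value
    # (-0.0 and +0.0 both map to 0).
    i = np.asarray(v, dtype=np.float32).view(np.int32).astype(np.int64)
    return np.where(i < 0, -(2**31) - i, i)


def ulp_distance(a, b):
    """Number of float32 steps between a and b (0 means equal values)."""
    return int(np.abs(_ordered_int(a) - _ordered_int(b)))


def log_ulp_errors(z, log_fn=np.log):
    """ULP distances of log_fn(z) from a complex128 reference, per component."""
    z32 = np.complex64(z)
    got = log_fn(z32)
    # Reference: evaluate in double precision, then round to complex64
    ref = np.complex64(np.log(np.complex128(z32)))
    return ulp_distance(np.real(got), ref.real), ulp_distance(np.imag(got), ref.imag)


if __name__ == "__main__":
    re_err, im_err = log_ulp_errors(0.9970238-0.08273122j)
    print(f"np.log complex64: real {re_err} ULP, imag {im_err} ULP")
```

Rounding the reference to complex64 means a distance of 0 indicates the tested result matches the (nearly) correctly rounded value.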
With the following scripts:

```python
# run_pearu_log_branch_test.py in jax repo
import numpy as np
import jax.numpy as jnp
import functional_algorithms as fa


def main():
    dtype = jnp.complex64
    ctx = fa.Context(paths=[fa.algorithms])
    graph = ctx.trace(getattr(fa.algorithms, "log"), dtype)
    graph2 = graph.rewrite(fa.targets.numpy, fa.rewrite)
    func = fa.targets.numpy.as_function(graph2, debug=0)
    z = jnp.array(0.9970238-0.08273122j, dtype=dtype)
    # Dump the optimized HLO that XLA compiles for jnp.log
    print(jnp.log.lower(z).compile().as_text())
    print(f'EXPECTED (0.00090095366 or 0.0009009537): {str(func(z))}')
    print(f'JAX: {str(jnp.log(z))}')


if __name__ == "__main__":
    main()
```

```bash
#!/bin/bash
# update_pearu_log_branch.sh in stablehlo repo

# Regenerate the patterns using the local functional_algorithms checkout
PYTHONPATH=../functional_algorithms-4 python build_tools/math/generate_ChloDecompositionPatternsMath.py
git add stablehlo/transforms/StablehloComplexMathExpanderPatterns.td
git commit --amend --no-edit
git push -uf origin pearu/log
sleep 5

# Hash the branch archive and point XLA's stablehlo dependency at it
STABLEHLO_SHA256=`curl -sL https://github.com/pearu/stablehlo/archive/refs/heads/pearu/log.zip | sha256sum | cut -d' ' -f1`
cat > ../xla/third_party/stablehlo/workspace.bzl << EndOfMessage
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    tf_http_archive(
        name = "stablehlo",
        strip_prefix = "stablehlo-pearu-log",
        sha256 = "${STABLEHLO_SHA256}",
        urls = tf_mirror_urls("https://github.com/pearu/stablehlo/archive/refs/heads/pearu/log.zip"),
        patch_file = [],
    )
EndOfMessage
echo "Created ../xla/third_party/stablehlo/workspace.bzl:"
cat ../xla/third_party/stablehlo/workspace.bzl
```
it was established that in the stablehlo->xla->jax path, the …

A more stable algorithm is obtained by using …
Applying the above fix to https://github.com/pearu/stablehlo/tree/pearu/log, we'll have
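For illustration only, and not necessarily the fix referred to above: one standard way to reduce cancellation in the real part of log(z) when |z| is close to 1 is to form |z|^2 - 1 directly as (x - 1)(x + 1) + y*y and pass it to log1p, since log|z| = 0.5 * log1p(|z|^2 - 1). The float32 NumPy sketch below compares this with the naive log(hypot(x, y)); the function names are hypothetical. It only helps when |x| is near 1 and |y| is small, as at the example point; when |x| and |y| are comparable, the two terms themselves cancel and more care (for example, exact Dekker-style products) is needed.

```python
import numpy as np

ONE = np.float32(1)
HALF = np.float32(0.5)


def log_abs_naive(x, y):
    # log|z| via hypot: |z| is rounded to float32 before the log, so when
    # |z| is close to 1 that rounding error dominates the small result
    return np.log(np.hypot(x, y))


def log_abs_near_unit(x, y):
    # |z|^2 - 1 = (x - 1)(x + 1) + y*y. For x in [0.5, 2], x - 1 is exact
    # (Sterbenz lemma), so the remaining rounding errors are relative to the
    # small terms (x - 1)(x + 1) and y*y rather than to 1.
    # Then log|z| = 0.5 * log1p(|z|^2 - 1).
    return HALF * np.log1p((x - ONE) * (x + ONE) + y * y)


if __name__ == "__main__":
    x = np.float32(0.9970238)
    y = np.float32(-0.08273122)
    # float64 reference computed from the same float32 inputs
    ref = np.log(np.hypot(np.float64(x), np.float64(y)))
    for name, fn in [("naive", log_abs_naive), ("near_unit", log_abs_near_unit)]:
        val = fn(x, y)
        rel = abs(float(val) - ref) / abs(ref)
        print(f"{name}: {float(val):.8g}  rel. error {rel:.2e}")
```

The float32 constants `ONE` and `HALF` keep every intermediate in float32 under both legacy and NEP 50 NumPy promotion rules, so the comparison reflects single-precision behavior.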