Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gomezzz committed Nov 25, 2024
1 parent e77a536 commit ad99710
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions torchquad/tests/integrator_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ def fn_const(x):
# Determine expected dtype
if backend == "tensorflow":
import tensorflow as tf
expected_dtype_name = (
dtype_arg if dtype_arg else tf.keras.backend.floatx()
)

expected_dtype_name = dtype_arg if dtype_arg else tf.keras.backend.floatx()
else:
expected_dtype_name = dtype_arg if dtype_arg else dtype_global

Expand Down Expand Up @@ -109,16 +108,15 @@ def fn_const(x):
)

assert infer_backend(result) == backend
assert get_dtype_name(result) == expected_dtype_name, (
f"Expected dtype {expected_dtype_name}, got {get_dtype_name(result)}"
)
assert (
get_dtype_name(result) == expected_dtype_name
), f"Expected dtype {expected_dtype_name}, got {get_dtype_name(result)}"

# VEGAS seems to be bad at integrating constant functions currently
max_error = 0.03 if integrator_name == "VEGAS" else 1e-5
assert anp.abs(result - (-4.0)) < max_error



test_integrate_numpy = setup_test_for_backend(_run_simple_integrations, "numpy", None)
test_integrate_torch = setup_test_for_backend(_run_simple_integrations, "torch", None)
test_integrate_jax = setup_test_for_backend(_run_simple_integrations, "jax", None)
Expand Down

0 comments on commit ad99710

Please sign in to comment.