Skip to content

Commit

Permalink
Merge pull request #210 from esa/CI_fixes
Browse files Browse the repository at this point in the history
CI fixes
  • Loading branch information
gomezzz authored Nov 25, 2024
2 parents 364e735 + 50ca860 commit 37fa291
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Note also that installing PyTorch with *pip* may **not** set it up with CUDA sup
Here are installation instructions for other numerical backends:
```sh
conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
conda install "numpy>=1.19.5" -c conda-forge
```

Expand Down
2 changes: 1 addition & 1 deletion docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Here are installation instructions for other numerical backends:
.. code-block:: bash
conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
conda install "numpy>=1.19.5" -c conda-forge
More installation instructions for numerical backends can be found in
Expand Down
6 changes: 3 additions & 3 deletions environment_all_backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- loguru>=0.5.3
- matplotlib>=3.3.3
- pytest>=6.2.1
- python>=3.8
- python==3.12
- scipy>=1.6.0
- sphinx>=3.4.3
- sphinx_rtd_theme>=0.5.1
Expand All @@ -16,9 +16,9 @@ dependencies:
- numpy>=1.19.5
- cudatoolkit>=11.1
- pytorch>=1.9 # CPU version
- tensorflow>=2.10.0 # CPU version
# jaxlib with CUDA support is not available for conda
- pip:
- --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
- jax[cpu]>=0.2.22 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
- tensorflow>=2.18.0 # CPU version
- jax[cpu]>=0.4.17 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
# CPU version
7 changes: 6 additions & 1 deletion torchquad/integration/base_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ def evaluate_integrand(fn, points, weights=None, args=None):
len(result.shape) > 1
): # if the the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
integrand_shape = anp.array(
result.shape[1:], like=infer_backend(points)
[
dim if isinstance(dim, int) else dim.as_list()
for dim in result.shape[1:]
],
like=infer_backend(points),
)

weights = anp.repeat(
anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
).reshape((weights.shape[0], *(integrand_shape)))
Expand Down
17 changes: 9 additions & 8 deletions torchquad/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,21 @@ def _setup_integration_domain(dim, integration_domain, backend):
# Get a globally default backend
backend = _get_default_backend()
dtype_arg = _get_precision(backend)
if dtype_arg is not None:
# For NumPy and Tensorflow there is no global dtype, so set the
# configured default dtype here
integration_domain = anp.array(
integration_domain, like=backend, dtype=dtype_arg
)
else:
integration_domain = anp.array(integration_domain, like=backend)
if backend == "tensorflow":
import tensorflow as tf

dtype_arg = dtype_arg or tf.keras.backend.floatx()

integration_domain = anp.array(
integration_domain, like=backend, dtype=dtype_arg
)

if integration_domain.shape != (dim, 2):
raise ValueError(
"The integration domain has an unexpected shape. "
f"Expected {(dim, 2)}, got {integration_domain.shape}"
)

return integration_domain


Expand Down
2 changes: 1 addition & 1 deletion torchquad/tests/integration_test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _poly(self, x):
# Tensorflow does not automatically cast float32 to complex128,
# so we do it here explicitly.
assert self.is_complex
exponentials = anp.cast(exponentials, self.coeffs.dtype)
exponentials = exponentials.astype(self.coeffs.dtype)

# multiply by coefficients
exponentials = anp.multiply(exponentials, self.coeffs)
Expand Down
27 changes: 21 additions & 6 deletions torchquad/tests/integrator_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,25 @@ def fn_const(x):
("jax", "float64", "float32"),
]:
continue

integrator_name = type(integrator).__name__

# VEGAS supports only numpy and torch
if integrator_name == "VEGAS" and backend in ["jax", "tensorflow"]:
continue

# Set the global precision
set_precision(dtype_global, backend=backend)

# Determine expected dtype
if backend == "tensorflow":
import tensorflow as tf

expected_dtype_name = dtype_arg if dtype_arg else tf.keras.backend.floatx()
else:
expected_dtype_name = dtype_arg if dtype_arg else dtype_global

# Set integration domain
integration_domain = [[0.0, 1.0], [-2.0, 0.0]]
if dtype_arg is not None:
# Set the integration_domain dtype which should have higher priority
Expand All @@ -75,18 +86,18 @@ def fn_const(x):
)
assert infer_backend(integration_domain) == backend
assert get_dtype_name(integration_domain) == dtype_arg
expected_dtype_name = dtype_arg
else:
expected_dtype_name = dtype_global

print(
f"[2mTesting {integrator_name} with {backend}, argument dtype"
f" {dtype_arg}, global/default dtype {dtype_global}[m"
f"Testing {integrator_name} with {backend}, argument dtype"
f" {dtype_arg}, global/default dtype {dtype_global}"
)

# Integration
if integrator_name in ["MonteCarlo", "VEGAS"]:
extra_kwargs = {"seed": 0}
else:
extra_kwargs = {}

result = integrator.integrate(
fn=fn_const,
dim=2,
Expand All @@ -95,8 +106,12 @@ def fn_const(x):
backend=backend,
**extra_kwargs,
)

assert infer_backend(result) == backend
assert get_dtype_name(result) == expected_dtype_name
assert (
get_dtype_name(result) == expected_dtype_name
), f"Expected dtype {expected_dtype_name}, got {get_dtype_name(result)}"

# VEGAS seems to be bad at integrating constant functions currently
max_error = 0.03 if integrator_name == "VEGAS" else 1e-5
assert anp.abs(result - (-4.0)) < max_error
Expand Down
2 changes: 1 addition & 1 deletion torchquad/tests/monte_carlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _run_monte_carlo_tests(backend, _precision):
assert errors[4] < 32.0

for error in errors[6:10]:
assert error < 1e-2
assert error < 1.1e-2

for error in errors[10:]:
assert error < 28.03
Expand Down
30 changes: 18 additions & 12 deletions torchquad/utils/set_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ def _get_precision(backend):


def set_precision(data_type="float32", backend="torch"):
"""This function allows the user to set the default precision for floating point numbers for the given numerical backend.
"""Set the default precision for floating-point numbers for the given numerical backend.
Call before declaring your variables.
NumPy and Tensorflow don't have global dtypes:
NumPy and doesn't have global dtypes:
https://github.com/numpy/numpy/issues/6860
https://github.com/tensorflow/tensorflow/issues/26033
Therefore, torchquad sets the dtype argument for these two when initialising the integration domain.
Therefore, torchquad sets the dtype argument for these it when initialising the integration domain.
Args:
data_type (string, optional): Data type to use, either "float32" or "float64". Defaults to "float32".
backend (string, optional): Numerical backend for which the data type is changed. Defaults to "torch".
data_type (str, optional): Data type to use, either "float32" or "float64". Defaults to "float32".
backend (str, optional): Numerical backend for which the data type is changed. Defaults to "torch".
"""
# Backwards-compatibility: allow "float" and "double", optionally with
# upper-case letters
Expand Down Expand Up @@ -55,14 +56,19 @@ def set_precision(data_type="float32", backend="torch"):
)
torch.set_default_tensor_type(tensor_dtype)
elif backend == "jax":
from jax.config import config
from jax import config

config.update("jax_enable_x64", data_type == "float64")
logger.info(f"JAX data type set to {data_type}")
elif backend in ["numpy", "tensorflow"]:
os.environ[f"TORCHQUAD_DTYPE_{backend.upper()}"] = data_type
logger.info(
f"Default dtype config for backend {backend} set to {_get_precision(backend)}"
)
elif backend == "tensorflow":
import tensorflow as tf

# Set TensorFlow global precision
tf.keras.backend.set_floatx(data_type)
logger.info(f"TensorFlow default floatx set to {tf.keras.backend.floatx()}")
elif backend == "numpy":
# NumPy still lacks global dtype support
os.environ["TORCHQUAD_DTYPE_NUMPY"] = data_type
logger.info(f"NumPy default dtype set to {_get_precision('numpy')}")
else:
logger.error(f"Changing the data type is not supported for backend {backend}")

0 comments on commit 37fa291

Please sign in to comment.