Skip to content

Commit

Permalink
Enable 3.13 CI (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 authored Oct 11, 2024
1 parent 0219671 commit 706cef7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12', '3.13']
os: [ubuntu-latest]
include:
- python-version: '3.9'
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]

[project.urls]
Expand Down Expand Up @@ -74,7 +75,7 @@ test = [
"networkx>=2.5",
"scikit-learn>=1.0",
"tqdm",
"tslearn>=0.5",
"tslearn>=0.5; python_version < '3.13'",
"matplotlib",
]
docs = [
Expand Down Expand Up @@ -184,14 +185,14 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"]
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.9,3.10,3.11,3.12},py3.9-jax-default
env_list = lint-code,py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default
skip_missing_interpreters = true
[testenv]
extras =
test
# https://github.com/google/flax/issues/3329
py{3.9,3.10,3.11,3.12},py3.9-jax-default: neural
py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default: neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
3 changes: 2 additions & 1 deletion tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import jaxopt
import numpy as np
import scipy as sp
from tslearn import metrics as ts_metrics

from ott.geometry import costs, pointcloud, regularizers
from ott.math import utils as mu
Expand Down Expand Up @@ -372,6 +371,7 @@ class TestSoftDTW:
@pytest.mark.parametrize("m", [9, 10])
@pytest.mark.parametrize("gamma", [1e-3, 5])
def test_soft_dtw(self, rng: jax.Array, n: int, m: int, gamma: float):
ts_metrics = pytest.importorskip("tslearn.metrics")
rng1, rng2 = jax.random.split(rng, 2)
t1 = jax.random.normal(rng1, (n,))
t2 = jax.random.normal(rng2, (m,))
Expand All @@ -388,6 +388,7 @@ def test_soft_dtw_debiased(
debiased: bool,
jit: bool,
):
ts_metrics = pytest.importorskip("tslearn.metrics")
gamma = 1e-1
rng1, rng2 = jax.random.split(rng, 2)
t1 = jax.random.normal(rng1, (16,))
Expand Down

0 comments on commit 706cef7

Please sign in to comment.