From cb88a8b189ecc62ef55188a84283bd384c7b2803 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:38:33 +0200 Subject: [PATCH] Enable 3.13 CI --- .github/workflows/tests.yml | 2 +- pyproject.toml | 7 ++++--- tests/geometry/costs_test.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7a91ec027..1c6d25843 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -75,7 +75,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12', '3.13'] os: [ubuntu-latest] include: - python-version: '3.9' diff --git a/pyproject.toml b/pyproject.toml index 6c15004e8..1ddcf2e26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] [project.urls] @@ -74,7 +75,7 @@ test = [ "networkx>=2.5", "scikit-learn>=1.0", "tqdm", - "tslearn>=0.5", + "tslearn>=0.5; python_version < '3.13'", "matplotlib", ] docs = [ @@ -184,14 +185,14 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"] legacy_tox_ini = """ [tox] min_version = 4.0 -env_list = lint-code,py{3.9,3.10,3.11,3.12},py3.9-jax-default +env_list = lint-code,py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default skip_missing_interpreters = true [testenv] extras = test # https://github.com/google/flax/issues/3329 - py{3.9,3.10,3.11,3.12},py3.9-jax-default: neural + py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default: neural pass_env = CUDA_*,PYTEST_*,CI commands_pre = gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 7a5f4eba3..0f675458b 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -21,7 +21,6 @@ import jaxopt import numpy as np import scipy as sp -from tslearn import metrics as ts_metrics from ott.geometry import costs, pointcloud, regularizers from ott.math import utils as mu @@ -372,6 +371,7 @@ class TestSoftDTW: @pytest.mark.parametrize("m", [9, 10]) @pytest.mark.parametrize("gamma", [1e-3, 5]) def test_soft_dtw(self, rng: jax.Array, n: int, m: int, gamma: float): + ts_metrics = pytest.importorskip("tslearn.metrics") rng1, rng2 = jax.random.split(rng, 2) t1 = jax.random.normal(rng1, (n,)) t2 = jax.random.normal(rng2, (m,)) @@ -388,6 +388,7 @@ def test_soft_dtw_debiased( debiased: bool, jit: bool, ): + ts_metrics = pytest.importorskip("tslearn.metrics") gamma = 1e-1 rng1, rng2 = jax.random.split(rng, 2) t1 = jax.random.normal(rng1, (16,))