Skip to content

Commit

Permalink
Update actions' versions (#516)
Browse files Browse the repository at this point in the history
* Update actions' versions

* Don't use `jax.tree_map`
  • Loading branch information
michalk8 authored Apr 9, 2024
1 parent 4bba69f commit 276695e
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ jobs:
lint-kind: [code, docs]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.10'

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ jobs:
environment: publish-pypi

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.10'

Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ jobs:
jax-version: [jax-default, jax-latest]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -51,7 +51,7 @@ jobs:
image: docker://michalk8/cuda:12.2.2-cudnn8-devel-ubuntu22.04
options: --gpus="device=2"
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Install dependencies
run: |
Expand Down Expand Up @@ -84,9 +84,9 @@ jobs:
os: macos-14

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
3 changes: 2 additions & 1 deletion tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair
Expand Down Expand Up @@ -170,7 +171,7 @@ def test_pytree_mapping(self, epsilon, tau, lock_gmm1):
)
expected_gmm1_loc = 2.0 * self.gmm1.loc if not lock_gmm1 else self.gmm1.loc

pair_x_2 = jax.tree_map(lambda x: 2.0 * x, pair)
pair_x_2 = jtu.tree_map(lambda x: 2.0 * x, pair)
# gmm parameters should be doubled
np.testing.assert_allclose(2.0 * pair.gmm0.loc, pair_x_2.gmm0.loc)
np.testing.assert_allclose(expected_gmm1_loc, pair_x_2.gmm1.loc)
Expand Down
3 changes: 2 additions & 1 deletion tests/tools/gaussian_mixture/gaussian_mixture_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.tools.gaussian_mixture import gaussian_mixture, linalg
Expand Down Expand Up @@ -163,7 +164,7 @@ def test_pytree_mapping(self, rng: jax.Array):
gmm = gaussian_mixture.GaussianMixture.from_random(
rng=rng, n_components=3, n_dimensions=2
)
gmm_x_2 = jax.tree_map(lambda x: 2.0 * x, gmm)
gmm_x_2 = jtu.tree_map(lambda x: 2.0 * x, gmm)
np.testing.assert_allclose(2.0 * gmm.loc, gmm_x_2.loc, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(
2.0 * gmm.scale_params, gmm_x_2.scale_params, atol=1e-4, rtol=1e-4
Expand Down
7 changes: 4 additions & 3 deletions tests/tools/gaussian_mixture/gaussian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.tools.gaussian_mixture import gaussian, scale_tril
Expand Down Expand Up @@ -137,14 +138,14 @@ def test_transport(self, rng: jax.Array):

def test_flatten_unflatten(self, rng: jax.Array):
g = gaussian.Gaussian.from_random(rng, n_dimensions=3)
children, aux_data = jax.tree_util.tree_flatten(g)
g_new = jax.tree_util.tree_unflatten(aux_data, children)
children, aux_data = jtu.tree_flatten(g)
g_new = jtu.tree_unflatten(aux_data, children)

assert g == g_new

def test_pytree_mapping(self, rng: jax.Array):
g = gaussian.Gaussian.from_random(rng, n_dimensions=3)
g_x_2 = jax.tree_map(lambda x: 2 * x, g)
g_x_2 = jtu.tree_map(lambda x: 2 * x, g)

np.testing.assert_allclose(2.0 * g.loc, g_x_2.loc)
np.testing.assert_allclose(2.0 * g.scale.params, g_x_2.scale.params)
7 changes: 4 additions & 3 deletions tests/tools/gaussian_mixture/probabilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.tools.gaussian_mixture import probabilities
Expand Down Expand Up @@ -63,15 +64,15 @@ def test_sample(self, rng: jax.Array):
def test_flatten_unflatten(self):
probs = jnp.array([0.1, 0.2, 0.3, 0.4])
pp = probabilities.Probabilities.from_probs(probs)
children, aux_data = jax.tree_util.tree_flatten(pp)
pp_new = jax.tree_util.tree_unflatten(aux_data, children)
children, aux_data = jtu.tree_flatten(pp)
pp_new = jtu.tree_unflatten(aux_data, children)
np.testing.assert_array_equal(pp.params, pp_new.params)
assert pp == pp_new

def test_pytree_mapping(self):
probs = jnp.array([0.1, 0.2, 0.3, 0.4])
pp = probabilities.Probabilities.from_probs(probs)
pp_x_2 = jax.tree_map(lambda x: 2 * x, pp)
pp_x_2 = jtu.tree_map(lambda x: 2 * x, pp)
np.testing.assert_allclose(
2.0 * pp.params, pp_x_2.params, rtol=1e-6, atol=1e-6
)
7 changes: 4 additions & 3 deletions tests/tools/gaussian_mixture/scale_tril_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.math import matrix_square_root
Expand Down Expand Up @@ -102,12 +103,12 @@ def test_transport(self, rng: jax.Array):

def test_flatten_unflatten(self, rng: jax.Array):
scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3)
children, aux_data = jax.tree_util.tree_flatten(scale)
scale_new = jax.tree_util.tree_unflatten(aux_data, children)
children, aux_data = jtu.tree_flatten(scale)
scale_new = jtu.tree_unflatten(aux_data, children)
np.testing.assert_array_equal(scale.params, scale_new.params)
assert scale == scale_new

def test_pytree_mapping(self, rng: jax.Array):
scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3)
scale_x_2 = jax.tree_map(lambda x: 2 * x, scale)
scale_x_2 = jtu.tree_map(lambda x: 2 * x, scale)
np.testing.assert_allclose(2.0 * scale.params, scale_x_2.params)

0 comments on commit 276695e

Please sign in to comment.