Skip to content

Commit

Permalink
Adjust tolerance in a test
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Feb 7, 2023
1 parent 1236aa6 commit c60cf50
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Sinkhorn."""
from typing import Optional

import pytest

Expand Down Expand Up @@ -471,7 +472,7 @@ def test_sinkhorn_online_memory(self, batch_size: int):
@pytest.mark.fast.with_args(
cost_fn=[None, costs.SqPNorm(1.6)],
)
def test_primal_cost_grid(self, cost_fn):
def test_primal_cost_grid(self, cost_fn: Optional[costs.CostFn]):
"""Test computation of primal / costs for Grids."""
ns = [6, 7, 11]
xs = [
Expand All @@ -493,7 +494,7 @@ def test_primal_cost_grid(self, cost_fn):
cost = jnp.sum(transport_matrix * cost_matrix)
assert cost > 0.0
assert out.primal_cost > 0.0
np.testing.assert_allclose(cost, out.primal_cost, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(cost, out.primal_cost, rtol=1e-5, atol=1e-5)
assert jnp.isfinite(out.dual_cost)
assert out.primal_cost - out.dual_cost > 0.0

Expand Down

0 comments on commit c60cf50

Please sign in to comment.