Fix jax==0.4.4 jitting issues in CI #308

Closed
michalk8 opened this issue Feb 20, 2023 · 3 comments · Fixed by #310

Comments

@michalk8
Collaborator

No description provided.

@michalk8 michalk8 added the bug Something isn't working label Feb 20, 2023
@michalk8 michalk8 self-assigned this Feb 20, 2023
@michalk8
Collaborator Author

michalk8 commented Feb 20, 2023

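The issue seems to be related to caching. Each parametrization passes when run on its own, but selecting both in one session fails: `[False]` runs first and passes, then `[True]` fails.

pytest -k 'test_online_sinkhorn_jit[True]'  # works
pytest -k 'test_online_sinkhorn_jit[False]'  # works
pytest -k 'test_online_sinkhorn_jit'  # fails; traceback below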
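The same ordering can be probed outside pytest by calling the function un-jitted first and jitted second in one process, mirroring the `[False]` then `[True]` parametrization order. This is only a sketch of the procedure: `toy_callback` and `run_in_order` are hypothetical stand-ins (no ott code is involved), so this snippet is not expected to reproduce the failure by itself; swapping in the real `callback` from the test would be the actual check.

```python
import jax
import jax.numpy as jnp


def toy_callback(epsilon, batch_size):
    # Hypothetical stand-in for the test's `callback`: batch_size fixes an
    # array shape, so under jit it has to be a static (Python) value.
    x = jnp.arange(batch_size, dtype=jnp.float32)
    return jnp.sum(jnp.exp(-x / epsilon))


def run_in_order(fn, static_argnums=(1,)):
    # Mirror pytest's order in one process: the un-jitted ([False]) call
    # runs before the jitted ([True]) call, both with keyword arguments
    # and `static_argnums=(1,)`, as in the test.
    eager = fn(epsilon=1.0, batch_size=42)
    jitted = jax.jit(fn, static_argnums=static_argnums)(epsilon=1.0, batch_size=42)
    return eager, jitted


eager, jitted = run_in_order(toy_callback)
print(float(eager), float(jitted))
```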

Traceback:

============================= test session starts ==============================
platform darwin -- Python 3.10.6, pytest-7.2.0, pluggy-1.0.0
rootdir: /Users/michal/Projects/ott, configfile: pyproject.toml, testpaths: tests
plugins: memray-1.4.0, mock-3.10.0, xdist-3.0.2, cov-4.0.0, anyio-3.6.2
collected 786 items / 784 deselected / 2 selected

tests/solvers/linear/sinkhorn_misc_test.py .F                            [100%]

=================================== FAILURES ===================================
______________ TestSinkhornOnline.test_online_sinkhorn_jit[True] _______________

    #!/Users/michal/.mambaforge/envs/ott/bin/python3.10
    # -*- coding: utf-8 -*-
    import re
    import sys
    from pytest import console_main
    if __name__ == '__main__':
        sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
>       sys.exit(console_main())

../../.mambaforge/envs/ott/bin/pytest:8: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def console_main() -> int:
        """The CLI entry point of pytest.
    
        This function is not meant for programmable use; use `main()` instead.
        """
        # https://docs.python.org/3/library/signal.html#note-on-sigpipe
        try:
>           code = main()

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/config/__init__.py:190: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = None, plugins = None

    def main(
        args: Optional[Union[List[str], "os.PathLike[str]"]] = None,
        plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
    ) -> Union[int, ExitCode]:
        """Perform an in-process test run.
    
        :param args: List of command line arguments.
        :param plugins: List of plugin objects to be auto-registered during initialization.
    
        :returns: An exit code.
        """
        try:
            try:
                config = _prepareconfig(args, plugins)
            except ConftestImportFailure as e:
                exc_info = ExceptionInfo.from_exc_info(e.excinfo)
                tw = TerminalWriter(sys.stderr)
                tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
                exc_info.traceback = exc_info.traceback.filter(
                    filter_traceback_for_conftest_import_failure
                )
                exc_repr = (
                    exc_info.getrepr(style="short", chain=False)
                    if exc_info.traceback
                    else exc_info.exconly()
                )
                formatted_tb = str(exc_repr)
                for line in formatted_tb.splitlines():
                    tw.line(line.rstrip(), red=True)
                return ExitCode.USAGE_ERROR
            else:
                try:
>                   ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
                        config=config
                    )

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/config/__init__.py:167: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x10385b670>}
argname = 'config', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()
    
        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break
    
            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False
    
>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x102dc8a60>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/s... '_pytest.mark' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/mark/__init__.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x10385b670>}
firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/s... '_pytest.mark' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/mark/__init__.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x10385b670>}
firstresult = True

    def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
        """Execute a call into multiple python functions/methods and return the
        result(s).
    
        ``caller_kwargs`` comes from _HookCaller.__call__().
        """
        __tracebackhide__ = True
        results = []
        excinfo = None
        try:  # run impl and wrapper setup functions in a loop
            teardowns = []
            try:
                for hook_impl in reversed(hook_impls):
                    try:
                        args = [caller_kwargs[argname] for argname in hook_impl.argnames]
                    except KeyError:
                        for argname in hook_impl.argnames:
                            if argname not in caller_kwargs:
                                raise HookCallError(
                                    f"hook call must provide argument {argname!r}"
                                )
    
                    if hook_impl.hookwrapper:
                        try:
                            gen = hook_impl.function(*args)
                            next(gen)  # first yield
                            teardowns.append(gen)
                        except StopIteration:
                            _raise_wrapfail(gen, "did not yield")
                    else:
>                       res = hook_impl.function(*args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_callers.py:39: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

config = <_pytest.config.Config object at 0x10385b670>

    def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
>       return wrap_session(config, _main)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/main.py:317: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

config = <_pytest.config.Config object at 0x10385b670>
doit = <function _main at 0x1035fe440>

    def wrap_session(
        config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
    ) -> Union[int, ExitCode]:
        """Skeleton command line program."""
        session = Session.from_config(config)
        session.exitstatus = ExitCode.OK
        initstate = 0
        try:
            try:
                config._do_configure()
                initstate = 1
                config.hook.pytest_sessionstart(session=session)
                initstate = 2
>               session.exitstatus = doit(config, session) or 0

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/main.py:270: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

config = <_pytest.config.Config object at 0x10385b670>
session = <Session ott exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=2>

    def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
        """Default command line protocol for initialization, session,
        running tests and reporting."""
        config.hook.pytest_collection(session=session)
>       config.hook.pytest_runtestloop(session=session)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/main.py:324: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session ott exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=2>}
argname = 'session', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()
    
        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break
    
            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False
    
>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x102dc8a60>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/s...test/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x1268b5180>>]
kwargs = {'session': <Session ott exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=2>}
firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/s...test/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x1268b5180>>]
caller_kwargs = {'session': <Session ott exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=2>}
firstresult = True

    def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
        """Execute a call into multiple python functions/methods and return the
        result(s).
    
        ``caller_kwargs`` comes from _HookCaller.__call__().
        """
        __tracebackhide__ = True
        results = []
        excinfo = None
        try:  # run impl and wrapper setup functions in a loop
            teardowns = []
            try:
                for hook_impl in reversed(hook_impls):
                    try:
                        args = [caller_kwargs[argname] for argname in hook_impl.argnames]
                    except KeyError:
                        for argname in hook_impl.argnames:
                            if argname not in caller_kwargs:
                                raise HookCallError(
                                    f"hook call must provide argument {argname!r}"
                                )
    
                    if hook_impl.hookwrapper:
                        try:
                            gen = hook_impl.function(*args)
                            next(gen)  # first yield
                            teardowns.append(gen)
                        except StopIteration:
                            _raise_wrapfail(gen, "did not yield")
                    else:
>                       res = hook_impl.function(*args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_callers.py:39: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

session = <Session ott exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=2>

    def pytest_runtestloop(session: "Session") -> bool:
        if session.testsfailed and not session.config.option.continue_on_collection_errors:
            raise session.Interrupted(
                "%d error%s during collection"
                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
            )
    
        if session.config.option.collectonly:
            return True
    
        for i, item in enumerate(session.items):
            nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
>           item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/main.py:349: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_online_sinkhorn_jit[True]>, 'nextitem': None}
argname = 'nextitem', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()
    
        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break
    
            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False
    
>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x102dc8a60>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/Users/michal/.mambaforge/envs/ott/lib/python3....module '_pytest.warnings' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_online_sinkhorn_jit[True]>, 'nextitem': None}
firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/Users/michal/.mambaforge/envs/ott/lib/python3....module '_pytest.warnings' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_online_sinkhorn_jit[True]>, 'nextitem': None}
firstresult = True

    def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
        """Execute a call into multiple python functions/methods and return the
        result(s).
    
        ``caller_kwargs`` comes from _HookCaller.__call__().
        """
        __tracebackhide__ = True
        results = []
        excinfo = None
        try:  # run impl and wrapper setup functions in a loop
            teardowns = []
            try:
                for hook_impl in reversed(hook_impls):
                    try:
                        args = [caller_kwargs[argname] for argname in hook_impl.argnames]
                    except KeyError:
                        for argname in hook_impl.argnames:
                            if argname not in caller_kwargs:
                                raise HookCallError(
                                    f"hook call must provide argument {argname!r}"
                                )
    
                    if hook_impl.hookwrapper:
                        try:
                            gen = hook_impl.function(*args)
                            next(gen)  # first yield
                            teardowns.append(gen)
                        except StopIteration:
                            _raise_wrapfail(gen, "did not yield")
                    else:
>                       res = hook_impl.function(*args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_callers.py:39: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_online_sinkhorn_jit[True]>, nextitem = None

    def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
        ihook = item.ihook
        ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
>       runtestprotocol(item, nextitem=nextitem)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:112: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_online_sinkhorn_jit[True]>, log = True, nextitem = None

    def runtestprotocol(
        item: Item, log: bool = True, nextitem: Optional[Item] = None
    ) -> List[TestReport]:
        hasrequest = hasattr(item, "_request")
        if hasrequest and not item._request:  # type: ignore[attr-defined]
            # This only happens if the item is re-run, as is done by
            # pytest-rerunfailures.
            item._initrequest()  # type: ignore[attr-defined]
        rep = call_and_report(item, "setup", log)
        reports = [rep]
        if rep.passed:
            if item.config.getoption("setupshow", False):
                show_test_item(item)
            if not item.config.getoption("setuponly", False):
>               reports.append(call_and_report(item, "call", log))

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:131: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_online_sinkhorn_jit[True]>, when = 'call', log = True
kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo ConcretizationTypeError('Abstract tracer value encountered where concrete...ened position 0.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError') tblen=3>>
hook = <_pytest.config.compat.PathAwareHookProxy object at 0x10385b7f0>

    def call_and_report(
        item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
    ) -> TestReport:
>       call = call_runtest_hook(item, when, **kwds)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:220: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_online_sinkhorn_jit[True]>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)

    def call_runtest_hook(
        item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
    ) -> "CallInfo[None]":
        if when == "setup":
            ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
        elif when == "call":
            ihook = item.ihook.pytest_runtest_call
        elif when == "teardown":
            ihook = item.ihook.pytest_runtest_teardown
        else:
            assert False, f"Unhandled runtest hook case: {when}"
        reraise: Tuple[Type[BaseException], ...] = (Exit,)
        if not item.config.getoption("usepdb", False):
            reraise += (KeyboardInterrupt,)
>       return CallInfo.from_call(
            lambda: ihook(item=item, **kwds), when=when, reraise=reraise
        )

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:259: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x14fe576d0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)

    @classmethod
    def from_call(
        cls,
        func: "Callable[[], TResult]",
        when: "Literal['collect', 'setup', 'call', 'teardown']",
        reraise: Optional[
            Union[Type[BaseException], Tuple[Type[BaseException], ...]]
        ] = None,
    ) -> "CallInfo[TResult]":
        """Call func, wrapping the result in a CallInfo.
    
        :param func:
            The function to call. Called without arguments.
        :param when:
            The phase in which the function is called.
        :param reraise:
            Exception or exceptions that shall propagate if raised by the
            function, instead of being wrapped in the CallInfo.
        """
        excinfo = None
        start = timing.time()
        precise_start = timing.perf_counter()
        try:
>           result: Optional[TResult] = func()

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:339: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>       lambda: ihook(item=item, **kwds), when=when, reraise=reraise
    )

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:260: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_online_sinkhorn_jit[True]>}, argname = 'item'
firstresult = False

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()
    
        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break
    
            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False
    
>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x102dc8a60>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/Users/michal/.mambaforge/envs/ott/lib/python3....t.threadexception' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/threadexception.py'>>]
kwargs = {'item': <Function test_online_sinkhorn_jit[True]>}
firstresult = False

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/Users/michal/.mambaforge/envs/ott/lib/python3....t.threadexception' from '/Users/michal/.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/threadexception.py'>>]
caller_kwargs = {'item': <Function test_online_sinkhorn_jit[True]>}
firstresult = False

    def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
        """Execute a call into multiple python functions/methods and return the
        result(s).
    
        ``caller_kwargs`` comes from _HookCaller.__call__().
        """
        __tracebackhide__ = True
        results = []
        excinfo = None
        try:  # run impl and wrapper setup functions in a loop
            teardowns = []
            try:
                for hook_impl in reversed(hook_impls):
                    try:
                        args = [caller_kwargs[argname] for argname in hook_impl.argnames]
                    except KeyError:
                        for argname in hook_impl.argnames:
                            if argname not in caller_kwargs:
                                raise HookCallError(
                                    f"hook call must provide argument {argname!r}"
                                )
    
                    if hook_impl.hookwrapper:
                        try:
                            gen = hook_impl.function(*args)
                            next(gen)  # first yield
                            teardowns.append(gen)
                        except StopIteration:
                            _raise_wrapfail(gen, "did not yield")
                    else:
>                       res = hook_impl.function(*args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_callers.py:39: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_online_sinkhorn_jit[True]>

    def pytest_runtest_call(item: Item) -> None:
        _update_current_test_var(item, "call")
        try:
            del sys.last_type
            del sys.last_value
            del sys.last_traceback
        except AttributeError:
            pass
        try:
>           item.runtest()

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/runner.py:167: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <Function test_online_sinkhorn_jit[True]>

    def runtest(self) -> None:
        """Execute the underlying test function."""
>       self.ihook.pytest_pyfunc_call(pyfuncitem=self)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/python.py:1789: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_online_sinkhorn_jit[True]>}
argname = 'pyfuncitem', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()
    
        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break
    
            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False
    
>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x102dc8a60>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/Users/michal/.mambaforge/envs/ott/lib/python3....est_plugin.py'>>, <HookImpl plugin_name='memray_manager', plugin=<pytest_memray.plugin.Manager object at 0x1268b4be0>>]
kwargs = {'pyfuncitem': <Function test_online_sinkhorn_jit[True]>}
firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/Users/michal/.mambaforge/envs/ott/lib/python3....est_plugin.py'>>, <HookImpl plugin_name='memray_manager', plugin=<pytest_memray.plugin.Manager object at 0x1268b4be0>>]
caller_kwargs = {'pyfuncitem': <Function test_online_sinkhorn_jit[True]>}
firstresult = True

    def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
        """Execute a call into multiple python functions/methods and return the
        result(s).
    
        ``caller_kwargs`` comes from _HookCaller.__call__().
        """
        __tracebackhide__ = True
        results = []
        excinfo = None
        try:  # run impl and wrapper setup functions in a loop
            teardowns = []
            try:
                for hook_impl in reversed(hook_impls):
                    try:
                        args = [caller_kwargs[argname] for argname in hook_impl.argnames]
                    except KeyError:
                        for argname in hook_impl.argnames:
                            if argname not in caller_kwargs:
                                raise HookCallError(
                                    f"hook call must provide argument {argname!r}"
                                )
    
                    if hook_impl.hookwrapper:
                        try:
                            gen = hook_impl.function(*args)
                            next(gen)  # first yield
                            teardowns.append(gen)
                        except StopIteration:
                            _raise_wrapfail(gen, "did not yield")
                    else:
>                       res = hook_impl.function(*args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/pluggy/_callers.py:39: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

pyfuncitem = <Function test_online_sinkhorn_jit[True]>

    @hookimpl(trylast=True)
    def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
        testfunction = pyfuncitem.obj
        if is_async_function(testfunction):
            async_warn_and_skip(pyfuncitem.nodeid)
        funcargs = pyfuncitem.funcargs
        testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
>       result = testfunction(**testargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/_pytest/python.py:195: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <sinkhorn_misc_test.TestSinkhornOnline object at 0x14c6c3c10>, jit = True

    @pytest.mark.parametrize("jit", [False, True])
    def test_online_sinkhorn_jit(self, jit: bool):
    
      def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput:
        geom = pointcloud.PointCloud(
            self.x, self.y, epsilon=epsilon, batch_size=batch_size
        )
        prob = linear_problem.LinearProblem(geom, self.a, self.b)
        solver = sinkhorn.Sinkhorn(threshold=threshold)
        return solver(prob)
    
      threshold = 1e-1
      fun = jax.jit(callback, static_argnums=(1,)) if jit else callback
    
>     errors = fun(epsilon=1.0, batch_size=42).errors

tests/solvers/linear/sinkhorn_misc_test.py:232: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (), kwargs = {'batch_size': 42, 'epsilon': 1.0}, __tracebackhide__ = True
msg = 'jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<S...ludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------'

    @util.wraps(fun)
    def reraise_with_filtered_traceback(*args, **kwargs):
      __tracebackhide__ = True
      try:
>       return fun(*args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/traceback_util.py:163: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (), kwargs = {'batch_size': 42, 'epsilon': 1.0}

    @api_boundary
    def cache_miss(*args, **kwargs):
>     outs, out_flat, out_tree, args_flat = _python_pjit_helper(
          fun, infer_params_fn, *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:237: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = <function TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>
infer_params_fn = <function jit.<locals>.infer_params at 0x16a499090>, args = ()
kwargs = {'batch_size': 42, 'epsilon': 1.0}

    def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
>     args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
          *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:180: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (), kwargs = {'batch_size': 42, 'epsilon': 1.0}
pjit_info_args = PjitInfo(fun=<function TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>, in_shardings=<ja...ames=('batch_size',), donate_argnums=(), device=None, backend=None, keep_unused=False, inline=False, resource_env=None)

    def infer_params(*args, **kwargs):
      pjit_info_args = pjit.PjitInfo(
          fun=fun, in_shardings=in_shardings,
          out_shardings=out_shardings, static_argnums=static_argnums,
          static_argnames=static_argnames, donate_argnums=donate_argnums,
          device=device, backend=backend, keep_unused=keep_unused,
          inline=inline, resource_env=None)
>     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/api.py:443: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

pjit_info_args = PjitInfo(fun=<function TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>, in_shardings=<ja...ames=('batch_size',), donate_argnums=(), device=None, backend=None, keep_unused=False, inline=False, resource_env=None)
fun = <function TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>
static_argnums = (1,), static_argnames = ('batch_size',), donate_argnums = ()
device = None, backend = None, keep_unused = False, inline = False

    def common_infer_params(pjit_info_args, *args, **kwargs):
      (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
       donate_argnums, device, backend, keep_unused, inline,
       resource_env) = pjit_info_args
    
      if kwargs and not _is_unspecified(user_in_shardings):
        raise ValueError(
            "pjit does not support kwargs when in_shardings is specified.")
    
      if resource_env is not None:
        pjit_mesh = resource_env.physical_mesh
        if pjit_mesh.empty:
          if jax.config.jax_array:
            # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
            # if mesh is not empty then pjit will respect it.
            pass
          else:
            raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
                                "it's defined at the call site?")
      else:
        pjit_mesh = None
    
      if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
        raise ValueError(
            "Mesh context manager should not be used with jit when backend or "
            "device is also specified as an argument to jit.")
    
      f = lu.wrap_init(fun)
      f, dyn_args = argnums_partial_except(f, static_argnums, args,
                                            allow_invalid=True)
      del args
    
      # TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
      # flatten_axes which if kwargs are present in the treedef (even empty {}),
      # leads to wrong expansion.
      if kwargs:
        f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
        args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
        flat_fun, out_tree = flatten_fun(f, in_tree)
      else:
        args_flat, in_tree = tree_flatten(dyn_args)
        flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
        dyn_kwargs = ()
      del kwargs
    
      if donate_argnums and not jax.config.jax_debug_nans:
        donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
      else:
        donated_invars = (False,) * len(args_flat)
    
      if jax.config.jax_array:
        # If backend or device is set as an arg on jit, then resolve them to
        # in_shardings and out_shardings as if user passed in in_shardings
        # and out_shardings.
        if backend or device:
          in_shardings = out_shardings = _create_sharding_with_device_backend(
              device, backend)
        else:
          in_shardings = tree_map(
              lambda x: _create_sharding_for_array(pjit_mesh, x), user_in_shardings)
          out_shardings = tree_map(
              lambda x: _create_sharding_for_array(pjit_mesh, x), user_out_shardings)
      else:
        in_shardings = tree_map(
            lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
            user_in_shardings)
        out_shardings = tree_map(
            lambda x: x if _is_unspecified(x) else
            _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), user_out_shardings)
        # This check fails extremely rarely and has a huge cost in the dispatch
        # path. So hide it behind the jax_enable_checks flag.
        if jax.config.jax_enable_checks:
          _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
    
      del user_in_shardings, user_out_shardings
    
      local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
      # TODO(yashkatariya): This is a hack. This should go away when avals have
      # is_global attribute.
      if jax.config.jax_array:
        in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
      else:
        in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
      out_positional_semantics = (
          pxla._PositionalSemantics.GLOBAL
          if jax.config.jax_parallel_functions_output_gda or jax.config.jax_array else
          pxla.positional_semantics.val)
    
      global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
          hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
          tuple(isinstance(a, GDA) for a in args_flat), resource_env)
    
>     jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
          flat_fun, hashable_pytree(out_shardings), global_in_avals,
          HashableFunction(out_tree, closure=()),
          ('jit' if resource_env is None else 'pjit'))

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:520: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.WrapKwArgs object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback

args = (<hashable <lambda> with closure=(PyTreeDef(*), (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x125ff8d60>,))>, (ShapedArray(float32[], weak_type=True),), <hashable <lambda> with closure=()>, 'jit')
cache = {}
key = (((<function flatten_fun at 0x10d0d5e10>, (PyTreeDef(((), {'epsilon': *})),)), (<function _argnames_partial at 0x10d0d...5ff8d60>,))>, (ShapedArray(float32[], weak_type=True),), <hashable <lambda> with closure=()>, 'jit'), False, None, ...)
result = None

    def memoized_fun(fun: WrappedFun, *args):
      cache = fun_caches.setdefault(fun.f, {})
      if config.jax_check_tracer_leaks:
        key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
               config.x64_enabled, config.jax_default_device,
               config._trace_context())
      else:
        key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
               config.jax_default_device, config._trace_context())
      result = cache.get(key, None)
      if result is not None:
        ans, stores = result
        fun.populate_stores(stores)
      else:
>       ans = call(fun, *args)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/linear_util.py:301: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.WrapKwArgs object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback

out_shardings_thunk = <hashable <lambda> with closure=(PyTreeDef(*), (<jax._src.interpreters.pxla.UnspecifiedValue object at 0x125ff8d60>,))>
global_in_avals = (ShapedArray(float32[], weak_type=True),)
out_tree = <hashable <lambda> with closure=()>, api_name = 'jit'

    @lu.cache
    def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
      prev_positional_val = pxla.positional_semantics.val
      try:
        pxla.positional_semantics.val = pxla._PositionalSemantics.GLOBAL
        with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                                       "for pjit in {elapsed_time} sec",
                                        event=dispatch.JAXPR_TRACE_EVENT):
>         jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
              fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:932: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.Wr... object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback
, (ShapedArray(float32[], weak_type=True),))
kwargs = {'debug_info': DebugInfo(func_src_info='callback at /Users/michal/Projects/ott/tests/solvers/linear/sinkhorn_misc_test...TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>, PyTreeDef(((), {'epsilon': *})), True))}

    @wraps(func)
    def wrapper(*args, **kwargs):
      with TraceAnnotation(name, **decorator_kwargs):
>       return func(*args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/profiler.py:314: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.WrapKwArgs object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback

in_avals = (ShapedArray(float32[], weak_type=True),)
debug_info = DebugInfo(func_src_info='callback at /Users/michal/Projects/ott/tests/solvers/linear/sinkhorn_misc_test.py:221', trace... TestSinkhornOnline.test_online_sinkhorn_jit.<locals>.callback at 0x126b22290>, PyTreeDef(((), {'epsilon': *})), True))

    @profiler.annotate_function
    def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
                               in_avals: Sequence[AbstractValue],
                               debug_info: Optional[DebugInfo] = None,
                               *,
                               keep_inputs: Optional[List[bool]] = None):
      with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
        main.jaxpr_stack = ()  # type: ignore
>       jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
          fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/interpreters/partial_eval.py:1985: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.WrapKwArgs object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback

main = MainTrace(1,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[], weak_type=True),)

    def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                                  in_avals: Sequence[AbstractValue], *,
                                  keep_inputs: Optional[Sequence[bool]] = None,
                                  debug_info: Optional[DebugInfo] = None):
      keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
    
      frame = JaxprStackFrame()
      frame.debug_info = debug_info
      with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
        trace = DynamicJaxprTrace(main, core.cur_sublevel())
        in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
        in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
>       ans = fun.call_wrapped(*in_tracers_)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/interpreters/partial_eval.py:2002: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Wrapped function:
0   : flatten_fun   (PyTreeDef(((), {'epsilon': *})),)
1   : _argnames_partial   (<jax._src.util.WrapKwArgs object at 0x16a614850>,)
2   : _argnums_partial   ((), ())
Core: callback

args = []
kwargs = {'batch_size': 42, 'epsilon': Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>}
stack = [], gen = None, gen_static_args = None, out_store = None

    def call_wrapped(self, *args, **kwargs):
      """Calls the underlying function, applying the transforms.
    
      The positional `args` and keyword `kwargs` are passed to the first
      transformation generator.
      """
      stack = []
      for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
        gen = gen(*(gen_static_args + tuple(args)), **kwargs)
        args, kwargs = next(gen)
        stack.append((gen, out_store))
      gen = gen_static_args = out_store = None
    
      try:
>       ans = self.f(*args, **dict(self.params, **kwargs))

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/linear_util.py:165: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

epsilon = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
batch_size = 42

    def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput:
      geom = pointcloud.PointCloud(
          self.x, self.y, epsilon=epsilon, batch_size=batch_size
      )
      prob = linear_problem.LinearProblem(geom, self.a, self.b)
      solver = sinkhorn.Sinkhorn(threshold=threshold)
>     return solver(prob)

tests/solvers/linear/sinkhorn_misc_test.py:227: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <ott.solvers.linear.sinkhorn.Sinkhorn object at 0x16a615240>
ot_prob = <ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>
init = (None, None)

    def __call__(
        self,
        ot_prob: linear_problem.LinearProblem,
        init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None),
    ) -> SinkhornOutput:
      """Run Sinkhorn algorithm.
    
      Args:
        ot_prob: Linear OT problem.
        init: Initial dual potentials/scalings f_u and g_v, respectively.
          Any `None` values will be initialized using the initializer.
    
      Returns:
        The Sinkhorn output.
      """
      initializer = self.create_initializer()
      init_dual_a, init_dual_b = initializer(
          ot_prob, *init, lse_mode=self.lse_mode
      )
      run_fn = jax.jit(run) if self.jit else run
>     return run_fn(ot_prob, self, (init_dual_a, init_dual_b))

src/ott/solvers/linear/sinkhorn.py:774: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>, <ott.solvers.linear.sinkhorn.Sinkhorn objec...loat32[1000])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[402])>with<DynamicJaxprTrace(level=1/0)>))
kwargs = {}, __tracebackhide__ = True, mode = 'remove_frames'

    @util.wraps(fun)
    def reraise_with_filtered_traceback(*args, **kwargs):
      __tracebackhide__ = True
      try:
>       return fun(*args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/traceback_util.py:163: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>, <ott.solvers.linear.sinkhorn.Sinkhorn objec...loat32[1000])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[402])>with<DynamicJaxprTrace(level=1/0)>))
kwargs = {}

    @api_boundary
    def cache_miss(*args, **kwargs):
>     outs, out_flat, out_tree, args_flat = _python_pjit_helper(
          fun, infer_params_fn, *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:237: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fun = <function run at 0x13fe49120>
infer_params_fn = <function jit.<locals>.infer_params at 0x16a4995a0>
args = (<ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>, <ott.solvers.linear.sinkhorn.Sinkhorn objec...loat32[1000])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[402])>with<DynamicJaxprTrace(level=1/0)>))
kwargs = {}

    def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
>     args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
          *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:180: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>, <ott.solvers.linear.sinkhorn.Sinkhorn objec...loat32[1000])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[402])>with<DynamicJaxprTrace(level=1/0)>))
kwargs = {}
pjit_info_args = PjitInfo(fun=<function run at 0x13fe49120>, in_shardings=<jax._src.interpreters.pxla.UnspecifiedValue object at 0x125f..., static_argnames=(), donate_argnums=(), device=None, backend=None, keep_unused=False, inline=False, resource_env=None)

    def infer_params(*args, **kwargs):
      pjit_info_args = pjit.PjitInfo(
          fun=fun, in_shardings=in_shardings,
          out_shardings=out_shardings, static_argnums=static_argnums,
          static_argnames=static_argnames, donate_argnums=donate_argnums,
          device=device, backend=backend, keep_unused=keep_unused,
          inline=inline, resource_env=None)
>     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/api.py:443: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

pjit_info_args = PjitInfo(fun=<function run at 0x13fe49120>, in_shardings=<jax._src.interpreters.pxla.UnspecifiedValue object at 0x125f..., static_argnames=(), donate_argnums=(), device=None, backend=None, keep_unused=False, inline=False, resource_env=None)
fun = <function run at 0x13fe49120>, static_argnums = (), static_argnames = ()
donate_argnums = (), device = None, backend = None, keep_unused = False
inline = False

    def common_infer_params(pjit_info_args, *args, **kwargs):
      (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
       donate_argnums, device, backend, keep_unused, inline,
       resource_env) = pjit_info_args
    
      if kwargs and not _is_unspecified(user_in_shardings):
        raise ValueError(
            "pjit does not support kwargs when in_shardings is specified.")
    
      if resource_env is not None:
        pjit_mesh = resource_env.physical_mesh
        if pjit_mesh.empty:
          if jax.config.jax_array:
            # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
            # if mesh is not empty then pjit will respect it.
            pass
          else:
            raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
                                "it's defined at the call site?")
      else:
        pjit_mesh = None
    
      if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
        raise ValueError(
            "Mesh context manager should not be used with jit when backend or "
            "device is also specified as an argument to jit.")
    
      f = lu.wrap_init(fun)
      f, dyn_args = argnums_partial_except(f, static_argnums, args,
                                            allow_invalid=True)
      del args
    
      # TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
      # flatten_axes which if kwargs are present in the treedef (even empty {}),
      # leads to wrong expansion.
      if kwargs:
        f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
        args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
        flat_fun, out_tree = flatten_fun(f, in_tree)
      else:
        args_flat, in_tree = tree_flatten(dyn_args)
        flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
        dyn_kwargs = ()
      del kwargs
    
      if donate_argnums and not jax.config.jax_debug_nans:
        donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
      else:
        donated_invars = (False,) * len(args_flat)
    
      if jax.config.jax_array:
        # If backend or device is set as an arg on jit, then resolve them to
        # in_shardings and out_shardings as if user passed in in_shardings
        # and out_shardings.
        if backend or device:
          in_shardings = out_shardings = _create_sharding_with_device_backend(
              device, backend)
        else:
          in_shardings = tree_map(
              lambda x: _create_sharding_for_array(pjit_mesh, x), user_in_shardings)
          out_shardings = tree_map(
              lambda x: _create_sharding_for_array(pjit_mesh, x), user_out_shardings)
      else:
        in_shardings = tree_map(
            lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
            user_in_shardings)
        out_shardings = tree_map(
            lambda x: x if _is_unspecified(x) else
            _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), user_out_shardings)
        # This check fails extremely rarely and has a huge cost in the dispatch
        # path. So hide it behind the jax_enable_checks flag.
        if jax.config.jax_enable_checks:
          _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
    
      del user_in_shardings, user_out_shardings
    
      local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
      # TODO(yashkatariya): This is a hack. This should go away when avals have
      # is_global attribute.
      if jax.config.jax_array:
        in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
      else:
        in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
      out_positional_semantics = (
          pxla._PositionalSemantics.GLOBAL
          if jax.config.jax_parallel_functions_output_gda or jax.config.jax_array else
          pxla.positional_semantics.val)
    
>     global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
          hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
          tuple(isinstance(a, GDA) for a in args_flat), resource_env)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/pjit.py:516: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

>   def __bool__(self): return self.aval._bool(self)

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/core.py:650: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ShapedArray(bool[], weak_type=True)
arg = Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

    def error(self, arg):
>     raise ConcretizationTypeError(arg, fname_context)
E     jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
E     The problem arose with the `bool` function. 
E     The error occurred while tracing the function callback at /Users/michal/Projects/ott/tests/solvers/linear/sinkhorn_misc_test.py:221 for jit. This concrete value was not available in Python because it depends on the value of the argument passed at flattened position 0.
E     
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
E     
E     The stack trace below excludes JAX-internal frames.
E     The preceding is the original exception that occurred, unmodified.
E     
E     --------------------

../../.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/core.py:1341: UnfilteredStackTrace

The above exception was the direct cause of the following exception:

self = <sinkhorn_misc_test.TestSinkhornOnline object at 0x14c6c3c10>, jit = True

    @pytest.mark.parametrize("jit", [False, True])
    def test_online_sinkhorn_jit(self, jit: bool):
    
      def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput:
        geom = pointcloud.PointCloud(
            self.x, self.y, epsilon=epsilon, batch_size=batch_size
        )
        prob = linear_problem.LinearProblem(geom, self.a, self.b)
        solver = sinkhorn.Sinkhorn(threshold=threshold)
        return solver(prob)
    
      threshold = 1e-1
      fun = jax.jit(callback, static_argnums=(1,)) if jit else callback
    
>     errors = fun(epsilon=1.0, batch_size=42).errors

tests/solvers/linear/sinkhorn_misc_test.py:232: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/solvers/linear/sinkhorn_misc_test.py:227: in callback
    return solver(prob)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <ott.solvers.linear.sinkhorn.Sinkhorn object at 0x16a615240>
ot_prob = <ott.problems.linear.linear_problem.LinearProblem object at 0x16a6152d0>
init = (None, None)

    def __call__(
        self,
        ot_prob: linear_problem.LinearProblem,
        init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None),
    ) -> SinkhornOutput:
      """Run Sinkhorn algorithm.
    
      Args:
        ot_prob: Linear OT problem.
        init: Initial dual potentials/scalings f_u and g_v, respectively.
          Any `None` values will be initialized using the initializer.
    
      Returns:
        The Sinkhorn output.
      """
      initializer = self.create_initializer()
      init_dual_a, init_dual_b = initializer(
          ot_prob, *init, lse_mode=self.lse_mode
      )
      run_fn = jax.jit(run) if self.jit else run
>     return run_fn(ot_prob, self, (init_dual_a, init_dual_b))
E     jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
E     The problem arose with the `bool` function. 
E     The error occurred while tracing the function callback at /Users/michal/Projects/ott/tests/solvers/linear/sinkhorn_misc_test.py:221 for jit. This concrete value was not available in Python because it depends on the value of the argument passed at flattened position 0.
E     
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

src/ott/solvers/linear/sinkhorn.py:774: ConcretizationTypeError
=========================== short test summary info ============================
FAILED tests/solvers/linear/sinkhorn_misc_test.py::TestSinkhornOnline::test_online_sinkhorn_jit[True]
================= 1 failed, 1 passed, 784 deselected in 2.46s ==================
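For context, here is a minimal sketch of the error class reported above. It uses a hypothetical `relu_python_branch` function that is unrelated to OTT. Python control flow calls `bool()` on a traced value, and JAX cannot concretize that value under `jit`. Using `jnp.where` keeps the branch inside the traced computation instead. In the traceback above, by contrast, the `bool` conversion happens while JAX evaluates `_process_in_axis_resources` (the `pjit.py:516` frame) for the inner `jax.jit(run)`, not inside `callback`. That is consistent with the suspicion that the problem lies in JAX's caching rather than in the test itself.

```python
import jax
import jax.numpy as jnp


def relu_python_branch(x):
    # A Python `if` calls bool() on `x > 0`. Under jit, `x` is an abstract
    # tracer, so this raises ConcretizationTypeError (or a subclass of it).
    if x > 0:
        return x
    return jnp.zeros_like(x)


def relu_traceable(x):
    # jnp.where keeps the branch inside the traced computation.
    return jnp.where(x > 0, x, jnp.zeros_like(x))


try:
    jax.jit(relu_python_branch)(1.0)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(raised)                                # True
print(float(jax.jit(relu_traceable)(-2.0)))  # 0.0
```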

Edit 1: Setting JAX_JIT_PJIT_API_MERGE=0 seems to resolve this issue.
Edit 2: The issue seems to be related to caching.
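Here is a sketch of applying the Edit 1 workaround from Python, for example at the top of a `conftest.py`. The placement is an assumption. The variable has to be set before JAX is first imported, because JAX reads its `JAX_*` flags from the environment when it defines them. Only JAX releases that still have this flag (such as the 0.4.4 in the issue title) act on it. Other releases ignore the variable.

```python
# Hypothetical conftest.py snippet. Equivalent shell form:
#   JAX_JIT_PJIT_API_MERGE=0 pytest tests/solvers/linear/sinkhorn_misc_test.py
import os

# Must be set before the first `import jax` anywhere in the process.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

import jax  # noqa: E402  (deliberately imported after setting the flag)
```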

@marcocuturi
Contributor

marcocuturi commented Feb 20, 2023

Thanks Michal!

Just a heads-up for everyone: this is making essentially all tests fail, so once Michal has fixed it, we can look back at all PRs.

@JTT94
Collaborator

JTT94 commented Feb 20, 2023

Possibly related: I noticed that things fail when using jit inside a jit.

Consider a function consisting of potential initialisation, geometry, and Sinkhorn computations. The Sinkhorn solver has jit=True as its default argument.

  1. jit the outer function and pass jit=True to the Sinkhorn solver: seems to fail.
  2. jit the outer function and pass jit=False to the solver: seems to work.
  3. do not jit the outer function, but set jit=True in the solver: seems to work.
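The three configurations above can be sketched with a hypothetical `ToySolver`. It mirrors the `run_fn = jax.jit(run) if self.jit else run` line in `Sinkhorn.__call__` shown in the traceback. Nesting `jax.jit` is normally allowed in JAX, so this toy runs in all three cases. It only illustrates the control flow and the "disable the inner jit when the caller already jits" workaround (case 2). It does not reproduce the failure itself.

```python
import jax
import jax.numpy as jnp


def _run(x):
    # Hypothetical stand-in for the solver's inner loop.
    return jnp.sum(x ** 2)


class ToySolver:
    """Hypothetical solver with a `jit` flag, mirroring `Sinkhorn.__call__`."""

    def __init__(self, jit: bool = True):
        self.jit = jit

    def __call__(self, x):
        run_fn = jax.jit(_run) if self.jit else _run
        return run_fn(x)


def pipeline(x, inner_jit: bool):
    return ToySolver(jit=inner_jit)(x)


x = jnp.arange(3.0)
outer = jax.jit(pipeline, static_argnums=(1,))

r1 = outer(x, True)     # 1. outer jit + inner jit (nested jit)
r2 = outer(x, False)    # 2. outer jit, inner jit disabled (workaround)
r3 = pipeline(x, True)  # 3. no outer jit, inner jit enabled
```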

For example, this fails:

@pytest.mark.parametrize("jit", [False, True])
def test_online_sinkhorn_jit(self, jit: bool):

  def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput:
    geom = pointcloud.PointCloud(
        self.x, self.y, epsilon=epsilon, batch_size=batch_size
    )
    prob = linear_problem.LinearProblem(geom, self.a, self.b)
    solver = sinkhorn.Sinkhorn(threshold=threshold, jit=True)
    return solver(prob)

  threshold = 1e-1
  fun = jax.jit(callback, static_argnums=(1,)) if jit else callback

  errors = fun(epsilon=1.0, batch_size=42).errors
  err = errors[errors > -1][-1]
  assert threshold > err

and this is fine:

@pytest.mark.parametrize("jit", [False, True])
def test_online_sinkhorn_jit(self, jit: bool):

  def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput:
    geom = pointcloud.PointCloud(
        self.x, self.y, epsilon=epsilon, batch_size=batch_size
    )
    prob = linear_problem.LinearProblem(geom, self.a, self.b)
    solver = sinkhorn.Sinkhorn(threshold=threshold, jit=not jit)
    return solver(prob)

  threshold = 1e-1
  fun = jax.jit(callback, static_argnums=(1,)) if jit else callback

  errors = fun(epsilon=1.0, batch_size=42).errors
  err = errors[errors > -1][-1]
  assert threshold > err
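Both snippets call `jax.jit(callback, static_argnums=(1,))` using keyword arguments only. The `PjitInfo` frames in the traceback show `static_argnames = ('batch_size',)`, so JAX maps the positional index to the parameter name. As a result, `batch_size` stays a plain Python int during tracing. Only `epsilon` is traced, which matches "flattened position 0" in the error message. The sketch below uses a hypothetical `callback` and a `seen` dict to check this.

```python
import jax
import jax.numpy as jnp

seen = {}


def callback(epsilon, batch_size):
    # Record how each argument looks while jit traces the function.
    seen["epsilon_is_float"] = isinstance(epsilon, float)
    seen["batch_size"] = batch_size
    # jnp.ones needs a concrete shape, so this only works if batch_size is static.
    return epsilon * jnp.ones(batch_size)


fun = jax.jit(callback, static_argnums=(1,))
out = fun(epsilon=1.0, batch_size=3)
```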
