Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise NotImplementedError on missing backend functionality #2055

Merged
merged 3 commits into from
Oct 5, 2019

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Oct 4, 2019

Addresses #2053

Note I've been having mkl errors when using .summary(). It would be nice to find a non-mkl workaround if not torch.backends.mkl.is_available().

Tested

  • added a unit test

@neerajprad
Copy link
Member

if not torch.backends.mkl.is_available()

I think the diagnostics should just have nan if mkl isn't available (that's probably required by the fft function), and IIRC we made a change to do that sometime back. This looks like a regression to me. cc. @fehiepsi

tests/test_generic.py Outdated Show resolved Hide resolved
@neerajprad
Copy link
Member

LGTM! I'll fix the mkl issue in another PR.

@neerajprad
Copy link
Member

@fritzo - could you paste the error trace? I think the only diagnostic function that is likely to throw this is effective_sample_size but that is wrapped in a _safe block. I think the issue is that perhaps pytest is treating that warning as an error? If that's the case, we could just add @pytest.mark.filterwarnings("ignore:..") to the test.

def _safe(fn):
    """
    Safe version of utilities in the :mod:`pyro.ops.stats` module. Wrapped
    functions return `NaN` tensors instead of throwing exceptions.

    :param fn: stats function from :mod:`pyro.ops.stats` module.
    """
    @functools.wraps(fn)
    def wrapped(sample, *args, **kwargs):
        try:
            val = fn(sample, *args, **kwargs)
        except Exception:
            warnings.warn(tb.format_exc())
            val = torch.full(sample.shape[2:], float("nan"),
                             dtype=sample.dtype, device=sample.device)
        return val

    return wrapped

@fritzo
Copy link
Member Author

fritzo commented Oct 4, 2019

Thanks, @neerajprad your suggestion of @pytest.mark.filterwarnings fixed my issue.

neerajprad
neerajprad previously approved these changes Oct 4, 2019
@neerajprad neerajprad merged commit 69abc10 into dev Oct 5, 2019
@fritzo fritzo deleted the pyro-generic-not-implemented branch October 12, 2019 01:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants