Implement logprob derivation for some forms of AdvancedIndexing mixtures #6398

Open
ricardoV94 opened this issue Dec 14, 2022 · 1 comment
@ricardoV94

    > Are we guaranteed to always be able to later rewrite this safely into something that is a proper logprob?

There are two types of logp that are relevant here: 1) an RV that is broadcast directly, and 2) an RV on which we perform advanced indexing (i.e., a mixture).

For the former, we would need to symbolically unbroadcast the value variable and then evaluate the logp of the base RV at that value:

def logp_broadcast_to(value, base_rv, broadcast_shape):
  return logp(base_rv, unbroadcast(value, broadcast_shape))

I am not sure what the unbroadcast function would look like, but when the unbroadcasting is valid it should satisfy the following (in pseudo-code), and it should raise otherwise:

np.testing.assert_array_equal(unbroadcast(broadcast_to(value, shape), shape).eval(), value)
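
To make the intended behaviour concrete, here is a minimal NumPy sketch of such an `unbroadcast`. It is only an illustration, not a symbolic PyTensor implementation. It also assumes the helper receives the original (pre-broadcast) shape, named `original_shape` here, which is a hypothetical signature and differs from the pseudo-code above.

```python
import numpy as np


def unbroadcast(value, original_shape):
    """Undo a broadcast of `value` back to `original_shape`, or raise.

    Hypothetical helper: the name and the signature are assumptions made
    for illustration only.
    """
    value = np.asarray(value)
    original_shape = tuple(original_shape)

    n_new_dims = value.ndim - len(original_shape)
    if n_new_dims < 0:
        raise ValueError("value has fewer dimensions than original_shape")

    # Left-pad the original shape with 1s, following NumPy broadcasting rules.
    padded_shape = (1,) * n_new_dims + original_shape
    for value_dim, orig_dim in zip(value.shape, padded_shape):
        if orig_dim != value_dim and orig_dim != 1:
            raise ValueError("shapes are not broadcast-compatible")

    # Keep only the first slice along every dimension that was broadcast.
    index = tuple(
        slice(0, 1) if orig_dim == 1 else slice(None)
        for orig_dim in padded_shape
    )
    candidate = value[index]

    # The value is a valid broadcast only if re-broadcasting reproduces it.
    if not np.array_equal(np.broadcast_to(candidate, value.shape), value):
        raise ValueError("value is not a broadcast of an array of original_shape")

    return candidate.reshape(original_shape)
```

The final check is what makes the function raise for values that could not have come from a broadcast, e.g. `[[1, 2], [3, 4]]` with `original_shape=(2,)`.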

For the latter, we would need to somehow evaluate the logp only at the unique indexes. Something like:

def logp_mixture_advanced_integer_indexing(value, base_rvs, indexes):
  unique_indexes = [pt.unique(index) for index in indexes]
  # And either ignore or assert the value should be the same for repeated indexes
  return logp_basic_mixture_rv(value[unique_indexes], base_rvs, unique_indexes)

This requires some thought to implement, which is why this PR does not attempt to do so.
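
As a rough illustration of the idea (not of the PyTensor implementation), here is a NumPy sketch for a simplified case: a vector of independent normal variables `x`, indexed as `x[index]` with a 1-D integer `index`. The function names and the choice of a normal base RV are assumptions. Note that the sketch selects values by the position of the first occurrence of each unique index, and raises when repeated indexes carry different values.

```python
import numpy as np


def normal_logpdf(x, mu, sigma):
    # Elementwise log-density of Normal(mu, sigma).
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def logp_unique_integer_indexing(value, mu, sigma, index):
    """Logp of `x[index] == value`, one term per unique index.

    Hypothetical helper for a 1-D integer `index` into independent
    Normal(mu[k], sigma[k]) variables.
    """
    value = np.asarray(value, dtype=float)
    index = np.asarray(index)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    if index.ndim != 1 or value.shape != index.shape:
        raise ValueError("this sketch only handles 1-D index and value")

    # first_pos[j] is the position in `index` where unique_idx[j] first occurs.
    unique_idx, first_pos, inverse = np.unique(
        index, return_index=True, return_inverse=True
    )
    unique_values = value[first_pos]

    # Repeated indexes refer to the same variable, so their values must agree.
    if not np.array_equal(unique_values[inverse], value):
        raise ValueError("repeated indexes must have identical values")

    return normal_logpdf(unique_values, mu[unique_idx], sigma[unique_idx])
```

For example, with `index = [2, 0, 2]` the result has two terms, one for element 0 and one for element 2, and the value at positions 0 and 2 must be the same.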

We could at least implement the logp for the cases where the indexes are unique (either because the user promised us so, or because we added an assert):

def logp_mixture_advanced_integer_indexing(value, base_rvs, indexes):
  indexes_with_assert = [assert_op(index, pt.unique(index).size == index.size) for index in indexes]
  return logp_basic_mixture_rv(value, base_rvs, indexes_with_assert)
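
The same restricted case can be sketched eagerly in NumPy. This is only meant to show the shape of the check: `assert_unique` and `logp_unique_indexing` are hypothetical names, and a unit-variance normal stands in for the base RV.

```python
import numpy as np


def assert_unique(index):
    # Return `index` unchanged, raising if any entry is repeated.
    index = np.asarray(index)
    if np.unique(index).size != index.size:
        raise ValueError("indexes must be unique")
    return index


def logp_unique_indexing(value, mu, index):
    """Elementwise logp of `x[index] == value` for x[k] ~ Normal(mu[k], 1).

    With unique indexes every entry of `value` maps to a distinct variable,
    so no deduplication is needed.
    """
    index = assert_unique(index)
    value = np.asarray(value, dtype=float)
    mu = np.asarray(mu, dtype=float)
    return -0.5 * np.log(2 * np.pi) - 0.5 * (value - mu[index]) ** 2
```

Putting the uniqueness check in front keeps the logp itself identical to the basic indexing case, which is the appeal of handling this case first.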

Originally posted by @ricardoV94 in #6369 (comment)

@ricardoV94 changed the title from "Implement logprob derivation for some forms of AdvancedIndexing" to "Implement logprob derivation for some forms of AdvancedIndexing mixtures" on Dec 14, 2022
@Om-Doiphode

Om-Doiphode commented Feb 1, 2023

@ricardoV94 I would like to work on this issue. Can you tell me how to get started?
