Remove explicit optimization from math.logsumexp #4747

Closed
ricardoV94 opened this issue Jun 6, 2021 · 2 comments · Fixed by #4860

@ricardoV94
Member

ricardoV94 commented Jun 6, 2021

Aesara already performs one pass of this optimization (which will be exactly equivalent to the pymc3 one after pymc-devs/pytensor#465). As things stand, the final graph contains two passes of the same optimization, because the output of the explicit pymc3 optimization still has the form that Aesara rewrites: log(sum(exp(x))) -> x_max + << log(sum(exp(x - x_max))) >>, where the marked inner term is again a log(sum(exp(...))) and is rewritten a second time.

https://github.com/pymc-devs/pymc3/blob/c62100c99df9a74194db89863776f155a95c076a/pymc3/math.py#L189-L194
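For reference, the stabilization rests on the standard log-sum-exp identity, written here with m as the maximum of x (notation ours, not taken from either codebase):

$$
\log \sum_i e^{x_i} = m + \log \sum_i e^{x_i - m}, \qquad m = \max_i x_i
$$

Applying the identity again to the inner term, with y_i = x_i - m, gives m' = max_i y_i = 0 whenever m is finite, so the second pass reduces to

$$
\log \sum_i e^{y_i} = 0 + \log \sum_i e^{y_i - 0}
$$

and only adds a redundant max reduction, subtraction and addition to the graph.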

import aesara
import aesara.tensor as at
from pymc3.math import logsumexp

x = at.vector('x')
y = logsumexp(x)
f = aesara.function([x], y)
aesara.dprint(f)

Elemwise{Composite{(i0 + log(i1) + i2)}}[(0, 0)] [id A] ''   9
 |InplaceDimShuffle{x} [id B] ''   5
 | |Reduce{maximum}{0} [id C] 'max'   4
 |   |Elemwise{sub,no_inplace} [id D] ''   3
 |     |x [id E]
 |     |Elemwise{Composite{Switch(IsInf(i0), i1, i0)}}[(0, 0)] [id F] ''   2
 |       |InplaceDimShuffle{x} [id G] ''   1
 |       | |Reduce{maximum}{0} [id H] 'max'   0
 |       |   |x [id E]
 |       |TensorConstant{(1,) of 0} [id I]
 |InplaceDimShuffle{x} [id J] ''   8
 | |Sum{acc_dtype=float64} [id K] ''   7
 |   |Elemwise{Composite{exp((i0 - i1))}}[(0, 0)] [id L] ''   6
 |     |Elemwise{sub,no_inplace} [id D] ''   3
 |     |InplaceDimShuffle{x} [id B] ''   5
 |Elemwise{Composite{Switch(IsInf(i0), i1, i0)}}[(0, 0)] [id F] ''   2

y = at.log(at.sum(at.exp(x), axis=None, keepdims=True))
f = aesara.function([x], y)
aesara.dprint(f)

Elemwise{Composite{(i0 + log(i1))}}[(0, 0)] [id A] ''   5
 |InplaceDimShuffle{x} [id B] ''   1
 | |Reduce{maximum}{0} [id C] 'max'   0
 |   |x [id D]
 |InplaceDimShuffle{x} [id E] ''   4
   |Sum{acc_dtype=float64} [id F] ''   3
     |Elemwise{Composite{exp((i0 - i1))}} [id G] ''   2
       |x [id D]
       |InplaceDimShuffle{x} [id B] ''   1
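A NumPy sketch of why the double pass is redundant. It is our own reconstruction from the two dprint outputs above, not the pymc3 or Aesara source, and the function names `aesara_style_lse` and `pymc3_style_lse` are hypothetical. Both functions return the same value, and the second shift is by exactly 0.

```python
import numpy as np


def aesara_style_lse(x):
    """One pass of the max-shift rewrite, mirroring the second graph:
    max(x) + log(sum(exp(x - max(x))))."""
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


def pymc3_style_lse(x):
    """Explicit shift as in the first graph, after which the inner
    log(sum(exp(x - m))) is rewritten a second time."""
    m = np.max(x)
    # Mirrors the Switch(IsInf(max), 0, max) node in the first graph.
    if np.isinf(m):
        m = 0.0
    # Second pass: the inner term gets the same max-shift again, even though
    # max(x - m) is already 0 for finite m.
    return m + aesara_style_lse(x - m)


x = np.array([1000.0, 1000.0])

# The naive form overflows: exp(1000) is inf in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))

print(naive, aesara_style_lse(x), pymc3_style_lse(x))
# The extra pass contributes nothing numerically: the inner max is exactly 0.
print(np.max(x - np.max(x)))
```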
ricardoV94 self-assigned this Jun 6, 2021
@brandonwillard
Contributor

Looks like the second graph was truncated.

@ricardoV94
Member Author

> Looks like the second graph was truncated.

Thanks, fixed it.

ricardoV94 added a commit to ricardoV94/pymc that referenced this issue Jul 13, 2021
…deprecate `log1mexp` in favor of Aesara implementations.

Closes pymc-devs#4747
ricardoV94 added a commit that referenced this issue Jul 13, 2021
…deprecate `log1mexp` in favor of Aesara implementations.

Closes #4747
ricardoV94 added a commit to ricardoV94/pymc that referenced this issue Jul 15, 2021
…deprecate `log1mexp` in favor of Aesara implementations.

Closes pymc-devs#4747