Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement betainc and derivatives #464

Merged
merged 1 commit into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions aesara/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import os
import warnings

import numpy as np
import scipy.special
Expand All @@ -14,12 +15,15 @@
from aesara.gradient import grad_not_implemented
from aesara.scalar.basic import (
BinaryScalarOp,
ScalarOp,
UnaryScalarOp,
complex_types,
discrete_types,
exp,
float64,
float_types,
log,
log1p,
true_div,
upcast,
upgrade_to_float,
Expand Down Expand Up @@ -1044,3 +1048,221 @@ def c_code(self, node, name, inp, out, sub):


log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")


class BetaInc(ScalarOp):
"""
Regularized incomplete beta function
"""

nin = 3
nfunc_spec = ("scipy.special.betainc", 3, 1)

def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)

def grad(self, inp, grads):
a, b, x = inp
(gz,) = grads

return [
gz * betainc_der(a, b, x, True),
gz * betainc_der(a, b, x, False),
gz
* exp(
log1p(-x) * (b - 1)
+ log(x) * (a - 1)
- (gammaln(a) + gammaln(b) - gammaln(a + b))
),
]


betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")


class BetaIncDer(ScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first
argument (alpha) or the second argument (bbeta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.

Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""

nin = 4

def impl(self, p, q, x, wrtp):
def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""

if n == 1:
return p * f * (q - 1) / (q * (p + 1))

p2n = p + 2 * n
F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
F2 = (
(p + q + n - 2)
* (p + n - 1)
* (q - n)
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)

return F1 * F2

def _betainc_b_n(f, p, q, n):
"""
Offset (b_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
pf = p * f
p2n = p + 2 * n

N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
D1 = q * (p2n - 2) * p2n

return N1 / D1

def _betainc_da_n_dp(f, p, q, n):
"""
Derivative of a_n wrt p
"""

if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)

pp = p ** 2
ppp = pp * p
p2n = p + 2 * n

N1 = -(n - 1) * f ** 2 * pp * (q - n)
N2a = (-8 + 8 * p + 8 * q) * n ** 3
N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
N2e = (-29 + 19 * q) * p + 10 - 8 * q

D1 = q ** 2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2

return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2

def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))

p2n = p + 2 * n
F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)

return F1 / D1

def _betainc_db_n_dp(f, p, q, n):
"""
Derivative of b_n wrt p
"""
p2n = p + 2 * n
pp = p ** 2
q4 = 4 * q
p4 = 4 * p

F1 = (p * f / q) * (
(-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
)
D1 = (p2n - 2) ** 2 * p2n ** 2

return F1 / D1

def _betainc_db_n_dq(f, p, q, n):
"""
Derivative of b_n wrt to q
"""
p2n = p + 2 * n
return -(p ** 2 * f) / (q * (p2n - 2) * p2n)

# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan

if x > (p / (p + q)):
return -self.impl(q, p, 1 - x, not wrtp)

min_iters = 3
max_iters = 200
err_threshold = 1e-12

derivative_old = 0

Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0

f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x)
+ (q - 1) * np.log1p(-x)
- np.log(p)
- scipy.special.betaln(p, q)
)
if wrtp:
dK = (
np.log(x)
- 1 / p
+ scipy.special.digamma(p + q)
- scipy.special.digamma(p)
)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)

for n in range(1, max_iters + 1):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
da_n = _betainc_da_n_dp(f, p, q, n)
db_n = _betainc_db_n_dp(f, p, q, n)
else:
da_n = _betainc_da_n_dq(f, p, q, n)
db_n = _betainc_db_n_dq(f, p, q, n)

A = a_n_ * Am2 + b_n_ * Am1
B = a_n_ * Bm2 + b_n_ * Bm1
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1

Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB

if n < min_iters - 1:
continue

F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)

errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative

if d_errapx <= err_threshold:
break

if n >= max_iters:
warnings.warn(
f"_betainc_derivative did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan

return derivative


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
5 changes: 5 additions & 0 deletions aesara/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
Expand Down
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,11 @@ def log1mexp(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
Expand Down Expand Up @@ -2909,6 +2914,7 @@ def logsumexp(x, axis=None, keepdims=False):
"softplus",
"log1pexp",
"log1mexp",
"betainc",
"real",
"imag",
"angle",
Expand Down
23 changes: 22 additions & 1 deletion tests/scalar/test_math.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import scipy.special as sp

import aesara.tensor as aet
from aesara import function
from aesara.compile.mode import Mode
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau


def test_gammainc_nan():
Expand Down Expand Up @@ -44,3 +47,21 @@ def test_gammau_nan():
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))


def test_betainc():
a, b, x = aet.scalars("a", "b", "x")
res = betainc(a, b, x)
test_func = function([a, b, x], res, mode=Mode("py"))
assert np.isclose(test_func(15, 10, 0.7), sp.betainc(15, 10, 0.7))


def test_betainc_derivative_nan():
a, b, x = aet.scalars("a", "b", "x")
res = betainc_der(a, b, x, True)
test_func = function([a, b, x], res, mode=Mode("py"))
assert not np.isnan(test_func(1, 1, 1))
assert np.isnan(test_func(1, 1, -1))
assert np.isnan(test_func(1, 1, 2))
assert np.isnan(test_func(1, -1, 1))
assert np.isnan(test_func(1, 1, -1))
Loading