-
-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement betainc and derivatives #464
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
We need to determine whether or not an Aesara implementation of _betainc_derivative
is a reasonable replacement for the Python-only implementation via BetaIncDd[a|b]
before merging.
If it's not possible to make an Aesara version that's comparable to the Python version in a reasonable amount of time/effort, then we can create a separate issue for that and merge this in the meantime.
aesara/scalar/math.py
Outdated
class BetaIncDdb(TernaryScalarOp): | ||
""" | ||
Gradient of the regularized incomplete beta function wrt to the second argument (b) | ||
""" | ||
|
||
def impl(self, a, b, x): | ||
return _betainc_derivative(a, b, x, wrtp=False) | ||
|
||
|
||
betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb") | ||
|
||
|
||
def _betainc_derivative(p, q, x, wrtp=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the/a Aesara implementation of _betainc_derivative
prohibitively slow compared to this Python implementation?
If not, we should definitely use the Aesara implementation. If it is, we need to figure out why and open an independent line of investigation for why that's the case, and fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually did not write it!
The BetaInc code which was just a copy of the c code behind the scipy.special.betainc was what was known to be prohibitively slow (to compile specially), as were the autodiff derivatives obtained from it. I have some crude benchmarks here: https://github.com/ricardoV94/derivatives_betainc/blob/master/comparison_aesara.ipynb
If we want a test case for exploring the slowness of scan that seems like a good start, as we have the scipy vs aesara with the exact same algorithm under the hood.
The derivatives are a complete different algorithm so they might be fine. Do you think it's worth trying to convert them to aesara code? I guess the concern here is that they break the auto-diff chain? Or is it an issue for the backends that would need custom dispatch?
e35a39a
to
91d36ac
Compare
Are we running any of the jobs with float32? I am puzzled as to why the custom tests fail here but pass locally and also pass on the PyMC3 PR. The original precision of 7 decimals should be fine on float64, whereas for float32 it should be 3. The current is 4 so that could explain it. Edit: I see now it was float32, it's specified during the create matrix id part of the job. Wonder if that could also be part of the test title. It got me by surprise |
Yes, it's also something we need to refactor entirely, because rerunning all the tests under a default of float32 is extremely time consuming and does not provide any additional coverage that couldn't be achieved more directly for a fraction of the time. |
9fa58fc
to
ffb575f
Compare
I am getting a ValueError in the test that expects a Details: ValueError: Scalar check failed (npy_float64)test_math.py::TestBetaIncGrad::test_stan_grad_combined FAILED [100%]
tests/scalar/test_math.py:54 (TestBetaIncGrad.test_stan_grad_combined)
self = <aesara.compile.function.types.Function object at 0x7f411e2dc400>
args = (1.0, 1.0, 1.0), kwargs = {}
restore_defaults = <function Function.__call__.<locals>.restore_defaults at 0x7f411e2aa820>
profile = None, t0 = 1623257249.2276874, output_subset = None, i = 3, arg = 1.0
s = <array(1.)>, c = <array(1.)>
def __call__(self, *args, **kwargs):
"""
Evaluates value of a function on given arguments.
Parameters
----------
args : list
List of inputs to the function. All inputs are required, even when
some of them are not necessary to calculate requested subset of
outputs.
kwargs : dict
The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key.
Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated. Regardless
of the presence of ``output_subset``, the updates are always calculated
and processed. To disable the updates, you should use the ``copy``
method with ``delete_updates=True``.
Returns
-------
list
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value
profile = self.profile
t0 = time.time()
output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
for arg in args:
s = self.input_storage[i]
s.storage[0] = arg
i += 1
else:
for c in self.input_storage:
c.provided = 0
if len(args) + len(kwargs) > len(self.input_storage):
raise TypeError("Too many parameter passed to aesara function")
# Set positional arguments
i = 0
for arg in args:
# TODO: provide a Param option for skipping the filter if we
# really want speed.
s = self.input_storage[i]
# see this emails for a discuation about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None:
s.storage[0] = arg
else:
try:
s.storage[0] = s.type.filter(
arg, strict=s.strict, allow_downcast=s.allow_downcast
)
except Exception as e:
function_name = "aesara function"
argument_name = "argument"
if self.name:
function_name += ' with name "' + self.name + '"'
if hasattr(arg, "name") and arg.name:
argument_name += ' with name "' + arg.name + '"'
where = get_variable_trace_string(self.maker.inputs[i].variable)
if len(e.args) == 1:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
+ e.args[0],
)
else:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
) + e.args
restore_defaults()
raise
s.provided += 1
i += 1
# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg
if (
not self.trust_input
and
# The getattr is only needed for old pickle
getattr(self, "_check_for_aliased_inputs", True)
):
# Collect aliased inputs among the storage space
args_share_memory = []
for i in range(len(self.input_storage)):
i_var = self.maker.inputs[i].variable
i_val = self.input_storage[i].storage[0]
if hasattr(i_var.type, "may_share_memory"):
is_aliased = False
for j in range(len(args_share_memory)):
group_j = zip(
[
self.maker.inputs[k].variable
for k in args_share_memory[j]
],
[
self.input_storage[k].storage[0]
for k in args_share_memory[j]
],
)
if any(
[
(
var.type is i_var.type
and var.type.may_share_memory(val, i_val)
)
for (var, val) in group_j
]
):
is_aliased = True
args_share_memory[j].append(i)
break
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for j in group[1:]:
self.input_storage[j].storage[0] = copy.copy(
self.input_storage[j].storage[0]
)
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
if not self.trust_input:
for c in self.input_storage:
if c.required and not c.provided:
restore_defaults()
raise TypeError(
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
if c.provided > 1:
restore_defaults()
raise TypeError(
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
if c.implicit and c.provided > 0:
restore_defaults()
raise TypeError(
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
# Do the actual work
t0_fn = time.time()
try:
outputs = (
> self.fn()
if output_subset is None
else self.fn(output_subset=output_subset)
)
E ValueError: Scalar check failed (npy_float64)
../../aesara/compile/function/types.py:976: ValueError
During handling of the above exception, another exception occurred:
self = <tests.scalar.test_math.TestBetaIncGrad object at 0x7f413a7f8d00>
def test_stan_grad_combined(self):
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b], null_gradients="return")
f_grad = function([a, b, z], betainc_grad)
for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.0, 1.0, 1.0, 0, np.nan),
(1.0, 1.0, 0.4, -0.36651629, 0.30649537),
):
assert_allclose(
> f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
)
test_math.py:66:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../aesara/compile/function/types.py:989: in __call__
raise_with_op(
../../aesara/link/utils.py:522: in raise_with_op
raise exc_value.with_traceback(exc_trace)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <aesara.compile.function.types.Function object at 0x7f411e2dc400>
args = (1.0, 1.0, 1.0), kwargs = {}
restore_defaults = <function Function.__call__.<locals>.restore_defaults at 0x7f411e2aa820>
profile = None, t0 = 1623257249.2276874, output_subset = None, i = 3, arg = 1.0
s = <array(1.)>, c = <array(1.)>
def __call__(self, *args, **kwargs):
"""
Evaluates value of a function on given arguments.
Parameters
----------
args : list
List of inputs to the function. All inputs are required, even when
some of them are not necessary to calculate requested subset of
outputs.
kwargs : dict
The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key.
Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated. Regardless
of the presence of ``output_subset``, the updates are always calculated
and processed. To disable the updates, you should use the ``copy``
method with ``delete_updates=True``.
Returns
-------
list
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
"""
def restore_defaults():
for i, (required, refeed, value) in enumerate(self.defaults):
if refeed:
if isinstance(value, Container):
value = value.storage[0]
self[i] = value
profile = self.profile
t0 = time.time()
output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
for arg in args:
s = self.input_storage[i]
s.storage[0] = arg
i += 1
else:
for c in self.input_storage:
c.provided = 0
if len(args) + len(kwargs) > len(self.input_storage):
raise TypeError("Too many parameter passed to aesara function")
# Set positional arguments
i = 0
for arg in args:
# TODO: provide a Param option for skipping the filter if we
# really want speed.
s = self.input_storage[i]
# see this emails for a discuation about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None:
s.storage[0] = arg
else:
try:
s.storage[0] = s.type.filter(
arg, strict=s.strict, allow_downcast=s.allow_downcast
)
except Exception as e:
function_name = "aesara function"
argument_name = "argument"
if self.name:
function_name += ' with name "' + self.name + '"'
if hasattr(arg, "name") and arg.name:
argument_name += ' with name "' + arg.name + '"'
where = get_variable_trace_string(self.maker.inputs[i].variable)
if len(e.args) == 1:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
+ e.args[0],
)
else:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
) + e.args
restore_defaults()
raise
s.provided += 1
i += 1
# Set keyword arguments
if kwargs: # for speed, skip the items for empty kwargs
for k, arg in kwargs.items():
self[k] = arg
if (
not self.trust_input
and
# The getattr is only needed for old pickle
getattr(self, "_check_for_aliased_inputs", True)
):
# Collect aliased inputs among the storage space
args_share_memory = []
for i in range(len(self.input_storage)):
i_var = self.maker.inputs[i].variable
i_val = self.input_storage[i].storage[0]
if hasattr(i_var.type, "may_share_memory"):
is_aliased = False
for j in range(len(args_share_memory)):
group_j = zip(
[
self.maker.inputs[k].variable
for k in args_share_memory[j]
],
[
self.input_storage[k].storage[0]
for k in args_share_memory[j]
],
)
if any(
[
(
var.type is i_var.type
and var.type.may_share_memory(val, i_val)
)
for (var, val) in group_j
]
):
is_aliased = True
args_share_memory[j].append(i)
break
if not is_aliased:
args_share_memory.append([i])
# Check for groups of more than one argument that share memory
for group in args_share_memory:
if len(group) > 1:
# copy all but the first
for j in group[1:]:
self.input_storage[j].storage[0] = copy.copy(
self.input_storage[j].storage[0]
)
# Check if inputs are missing, or if inputs were set more than once, or
# if we tried to provide inputs that are supposed to be implicit.
if not self.trust_input:
for c in self.input_storage:
if c.required and not c.provided:
restore_defaults()
raise TypeError(
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
if c.provided > 1:
restore_defaults()
raise TypeError(
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
if c.implicit and c.provided > 0:
restore_defaults()
raise TypeError(
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
)
# Do the actual work
t0_fn = time.time()
try:
outputs = (
> self.fn()
if output_subset is None
else self.fn(output_subset=output_subset)
)
E ValueError: Scalar check failed (npy_float64)
E Apply node that caused the error: mul(second.0, betainc_ddb.0)
E Toposort index: 9
E Inputs types: [Scalar(float64), Scalar(float64)]
E Inputs shapes: [(), 'No shapes']
E Inputs strides: [(), 'No strides']
E Inputs values: [1.0, nan]
E Outputs clients: [[TensorFromScalar(mul.0)]]
E
E Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1441, in <listcomp>
E rval = [access_grad_cache(elem) for elem in wrt]
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1394, in access_grad_cache
E term = access_term_cache(node)[idx]
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1059, in access_term_cache
E output_grads = [access_grad_cache(var) for var in node.outputs]
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1059, in <listcomp>
E output_grads = [access_grad_cache(var) for var in node.outputs]
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1394, in access_grad_cache
E term = access_term_cache(node)[idx]
E File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1221, in access_term_cache
E input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
E File "/home/ricardo/Documents/Projects/aesara/aesara/scalar/basic.py", line 1138, in L_op
E return self.grad(inputs, output_gradients)
E File "/home/ricardo/Documents/Projects/aesara/aesara/scalar/math.py", line 1094, in grad
E gz * betainc_ddb_scalar(a, b, x),
E
E HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
../../aesara/compile/function/types.py:976: ValueError |
That |
You are right. Is this something that should be worked around, and if so, how? |
532a8b9
to
cdb12ca
Compare
1ece906
to
41be515
Compare
We can perform the test with the |
Codecov Report
@@ Coverage Diff @@
## main #464 +/- ##
==========================================
+ Coverage 76.66% 76.71% +0.05%
==========================================
Files 148 148
Lines 46400 46510 +110
Branches 10202 10213 +11
==========================================
+ Hits 35573 35682 +109
Misses 8219 8219
- Partials 2608 2609 +1
|
All tests are passing now and coverage looks good. The Calling the derivative scalar ops directly is also fine but it seems more reasonable to test the derivatives via the I would open an issue to test an aesara pure implementation of the derivatives and merge this for the time being (if the code looks good) |
aesara/scalar/math.py
Outdated
class BetaIncDda(TernaryScalarOp): | ||
""" | ||
Gradient of the regularized incomplete beta function wrt to the first argument (a) | ||
""" | ||
|
||
def impl(self, a, b, x): | ||
return _betainc_derivative(a, b, x, wrtp=True) | ||
|
||
|
||
betainc_dda_scalar = BetaIncDda(upgrade_to_float_no_complex, name="betainc_dda") | ||
|
||
|
||
class BetaIncDdb(TernaryScalarOp): | ||
""" | ||
Gradient of the regularized incomplete beta function wrt to the second argument (b) | ||
""" | ||
|
||
def impl(self, a, b, x): | ||
return _betainc_derivative(a, b, x, wrtp=False) | ||
|
||
|
||
betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before we merge this, let's combine these into a single Op
with a boolean wrtp
attribute. The _betainc_derivative
function can then become the entire impl
method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, is my latest push what you had in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you refactor _betainc_derivative
to be BetaIncDdb.impl
, then yes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it
1e2d674
to
cca785f
Compare
This PR adds the equivalent Scipy
betainc
and Python-only Op's for the approximation of the derivatives wrt to the first two arguments. More context can be found in pymc-devs/pymc#4736One of the scalar gradient tests is failing locally because the expected
nan
return raises aValueError
in a test context, whereas it issues aRuntimeWarning
when running in the REPL. Other scalar gradient tests are failing in the CI due to numerical issues, but pass locally. Probably I need to specify in more detail the compilation mode (and if so, which one)?Here are a few important guidelines and requirements to check before your PR can be merged:
pre-commit
is installed and set up.