From 844f293b57f81a2dcd13d55c96ae927aabecc987 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Tue, 5 Jul 2022 15:57:58 -0500
Subject: [PATCH 1/2] Update docstrings and type hints for Op's gradient methods

---
 aesara/graph/op.py | 60 ++++++++++++++++++++++++++++++----------------
 1 file changed, 39 insertions(+), 21 deletions(-)

diff --git a/aesara/graph/op.py b/aesara/graph/op.py
index ffe748166b..631b6bc339 100644
--- a/aesara/graph/op.py
+++ b/aesara/graph/op.py
@@ -320,53 +320,73 @@ def __ne__(self, other: Any) -> bool:
 
     add_tag_trace = staticmethod(add_tag_trace)
 
     def grad(
-        self, inputs: List[Variable], output_grads: List[Variable]
+        self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
     ) -> List[Variable]:
-        """Construct a graph for the gradient with respect to each input variable.
+        r"""Construct a graph for the gradient with respect to each input variable.
 
         Each returned `Variable` represents the gradient with respect to that
         input computed based on the symbolic gradients with respect to each
         output.  If the output is not differentiable with respect to an input,
-        then this method should return an instance of type ``NullType`` for that
+        then this method should return an instance of type `NullType` for that
         input.
 
+        Using the reverse-mode AD characterization given in [1]_, for a
+        :math:`C = f(A, B)` representing the function implemented by the `Op`
+        and its two arguments :math:`A` and :math:`B`, given by the
+        `Variable`\s in `inputs`, the values returned by `Op.grad` represent
+        the quantities :math:`\bar{A} \equiv \frac{\partial S_O}{\partial A}` and
+        :math:`\bar{B}`, for some scalar output term :math:`S_O` of :math:`C`
+        in
+
+        .. math::
+
+            \operatorname{Tr}\left(\bar{C}^\top dC\right) =
+                \operatorname{Tr}\left(\bar{A}^\top dA\right) +
+                \operatorname{Tr}\left(\bar{B}^\top dB\right)
+
+
         Parameters
         ----------
-        inputs : list of Variable
+        inputs
             The input variables.
-        output_grads : list of Variable
+        output_grads
             The gradients of the output variables.
 
         Returns
         -------
-        grads : list of Variable
+        grads
             The gradients with respect to each `Variable` in `inputs`.
 
+        .. [1] Giles, Mike. 2008. “An Extended Collection of Matrix Derivative Results for Forward and Reverse Mode Automatic Differentiation.”
+
         """
         raise NotImplementedError()
 
     def L_op(
         self,
-        inputs: List[Variable],
-        outputs: List[Variable],
-        output_grads: List[Variable],
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
     ) -> List[Variable]:
         r"""Construct a graph for the L-operator.
 
-        This method is primarily used by `Lop` and dispatches to
-        :meth:`Op.grad` by default.
+        The L-operator computes a row vector times the Jacobian.
+
+        This method dispatches to :meth:`Op.grad` by default. In one sense,
+        this method is given the original outputs, in case they are needed to
+        compute the return value, whereas `Op.grad` is not.
 
-        The L-operator computes a *row* vector times the Jacobian. The
-        mathematical relationship is
-        :math:`v \frac{\partial f(x)}{\partial x}`.
-        The L-operator is also supported for generic tensors (not only for
-        vectors).
+        See `Op.grad` for a mathematical explanation of the inputs and outputs
+        of this method.
 
         Parameters
         ----------
-        inputs : list of Variable
-        outputs : list of Variable
-        output_grads : list of Variable
+        inputs
+            The inputs of the `Apply` node using this `Op`.
+        outputs
+            The outputs of the `Apply` node using this `Op`.
+        output_grads
+            The gradients with respect to each `Variable` in `outputs`.
""" return self.grad(inputs, output_grads) @@ -378,8 +398,6 @@ def R_op( This method is primarily used by `Rop`. - Suppose the `Op` outputs ``[ f_1(inputs), ..., f_n(inputs) ]``. - Parameters ---------- inputs From 62f50bbdbbcd3e7604ffc0137ed280b3e16ec3d6 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 5 Jul 2022 18:27:39 -0500 Subject: [PATCH 2/2] Refactor aesara.gradient and add type hints --- aesara/gradient.py | 429 ++++++++++++++++++++--------------------- tests/test_gradient.py | 2 - 2 files changed, 211 insertions(+), 220 deletions(-) diff --git a/aesara/gradient.py b/aesara/gradient.py index 080ae4a012..a57d061400 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -1,19 +1,30 @@ """Driver for gradient calculations.""" -import logging import time import warnings -from collections import OrderedDict from functools import partial, reduce -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Mapping, + MutableSequence, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np +from typing_extensions import Literal import aesara from aesara.compile.ops import ViewOp from aesara.configdefaults import config from aesara.graph import utils -from aesara.graph.basic import NominalVariable, Variable +from aesara.graph.basic import Apply, NominalVariable, Variable from aesara.graph.null_type import NullType, null_type from aesara.graph.op import get_test_values from aesara.graph.type import Type @@ -23,26 +34,18 @@ from aesara.compile.mode import Mode -__docformat__ = "restructuredtext en" -_logger = logging.getLogger("aesara.gradient") +V = TypeVar("V", bound=Optional[Variable]) -# we can't do "import aesara.tensor" -# tensor depends on aesara.compile -# aesara.compile depends on aesara.gradient (this file) -# the reason aesara.compile depends on aesara.gradient -# is that aesara.compile.builders contains the op from graph -# functionality and it uses aesara.gradient to implement -# the new op's grad method -tensor = None -_msg_retType = "op.grad(...) returned a non-list" +# TODO: Refactor this so that it's not a global variable +grad_time: float = 0.0 -grad_time = 0 - -def format_as(use_list, use_tuple, outputs): - """ - Formats the outputs according to the flags `use_list` and `use_tuple`. +# TODO: Add `overload` variants +def as_list_or_tuple( + use_list: bool, use_tuple: bool, outputs: Union[V, Sequence[V]] +) -> Union[V, List[V], Tuple[V, ...]]: + """Return either a single object or a list/tuple of objects. If `use_list` is True, `outputs` is returned as a list (if `outputs` is not a list or a tuple then it is converted in a one element list). 
@@ -52,22 +55,25 @@ def format_as(use_list, use_tuple, outputs):
     """
     if use_list and use_tuple:
         raise ValueError("Both flags cannot be simultaneously True")
-    if (use_list or use_tuple) and not isinstance(outputs, (list, tuple)):
-        if use_list:
-            return [outputs]
-        else:
-            return (outputs,)
-    elif not (use_list or use_tuple) and isinstance(outputs, (list, tuple)):
-        if len(outputs) != 1:
-            raise ValueError("Wrong arguments; expected a one element list")
-        return outputs[0]
-    elif use_list or use_tuple:
-        if use_list:
-            return list(outputs)
+
+    if use_list or use_tuple:
+        if isinstance(outputs, Sequence):
+            if use_list:
+                return list(outputs)
+            else:
+                return tuple(outputs)
         else:
-            return tuple(outputs)
+            if use_list:
+                return [outputs]
+            else:
+                return (outputs,)
     else:
-        return outputs
+        if isinstance(outputs, Sequence):
+            if len(outputs) != 1:
+                raise ValueError("Wrong arguments; expected a one element list")
+            return outputs[0]
+        else:
+            return outputs
 
 
 def grad_not_implemented(op, x_pos, x, comment=""):
@@ -155,97 +161,87 @@ def __str__(self):
 
 disconnected_type = DisconnectedType()
 
 
-########################
-# R Operator
-########################
-
+def Rop(
+    f: Union[Variable, Sequence[Variable]],
+    wrt: Union[Variable, Sequence[Variable]],
+    eval_points: Union[Variable, Sequence[Variable]],
+    disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise",
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
+) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
+    """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`.
 
-def Rop(f, wrt, eval_points, disconnected_outputs="raise", return_disconnected="zero"):
-    """
-    Computes the R operation on `f` wrt to `wrt` at `eval_points`.
-
-    Mathematically this stands for the jacobian of `f` wrt
-    to `wrt` right muliplied by the eval points.
+    Mathematically this stands for the Jacobian of `f` right multiplied by the
+    `eval_points`.
 
     Parameters
     ----------
-    f : :class:`~aesara.graph.basic.Variable` or list of Variables
-        `f` stands for the output of the computational graph to which you
-        want to apply the R operator
-    wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
-        variables for which you compute the R operator of the expression
-        described by `f`
-    eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
-        evaluation points for each of the variables in `wrt`
-    disconnected_outputs : str
+    f
+        The outputs of the computational graph to which the R-operator is
+        applied.
+    wrt
+        Variables for which the R-operator of `f` is computed.
+    eval_points
+        Points at which to evaluate each of the variables in `wrt`.
+    disconnected_outputs
        Defines the behaviour if some of the variables in `f`
        have no dependency on any of the variable in `wrt` (or if
        all links are non-differentiable). The possible values are:
-        - 'ignore': considers that the gradient on these parameters is zero.
-        - 'warn': consider the gradient zero, and print a warning.
-        - 'raise': raise DisconnectedInputError.
+        - ``'ignore'``: considers that the gradient on these parameters is zero.
+        - ``'warn'``: consider the gradient zero, and print a warning.
+        - ``'raise'``: raise `DisconnectedInputError`.
-    return_disconnected : {'zero', 'None', 'Disconnected'}
-        - 'zero' : If wrt[i] is disconnected, return value i will be
-                   wrt[i].zeros_like()
-        - 'None' : If wrt[i] is disconnected, return value i will be
-                   None
-        - 'Disconnected' : returns variables of type DisconnectedType
+    return_disconnected
+        - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``wrt[i].zeros_like()``.
+        - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``None``.
+        - ``'disconnected'`` : returns variables of type `DisconnectedType`
 
     Returns
    -------
-    :class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of f
-        Symbolic expression such that
-        R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]
+    A symbolic expression obeying
+        ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
    where the indices in that expression are magic multidimensional
    indices that specify both the position within a list and all
-    coordinates of the tensor element in the last.
+    coordinates of the tensor elements.
 
    If `wrt` is a list/tuple, then return a list/tuple with the results.
 
    """
-    using_list = isinstance(f, list)
-    using_tuple = isinstance(f, tuple)
 
    if not isinstance(wrt, (list, tuple)):
-        wrt = [wrt]
+        _wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
+    else:
+        _wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]
 
    if not isinstance(eval_points, (list, tuple)):
-        eval_points = [eval_points]
+        _eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
+    else:
+        _eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]
 
    if not isinstance(f, (list, tuple)):
-        f = [f]
+        _f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
+    else:
+        _f = [aesara.tensor.as_tensor_variable(x) for x in f]
 
-    if len(wrt) != len(eval_points):
+    if len(_wrt) != len(_eval_points):
        raise ValueError("`wrt` must be the same length as `eval_points`.")
 
    # Check that each element of wrt corresponds to an element
    # of eval_points with the same dimensionality.
-    for pack in enumerate(zip(wrt, eval_points)):
-        i = pack[0]
-        wrt_elem, eval_point = pack[1]
-        if not isinstance(wrt_elem, Variable):
-            wrt_elem = aesara.tensor.as_tensor_variable(wrt_elem)
-        if not isinstance(eval_point, Variable):
-            eval_point = aesara.tensor.as_tensor_variable(eval_point)
+    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
         try:
-
             if wrt_elem.type.ndim != eval_point.type.ndim:
                 raise ValueError(
-                    "Element "
-                    + str(i)
-                    + " of wrt/eval_point have mismatched "
-                    + "dimensionality: "
-                    + str(wrt_elem.type.ndim)
-                    + " versus "
-                    + str(eval_point.type.ndim)
+                    f"Element {i} of `wrt` and `eval_point` have mismatched dimensionalities: "
+                    f"{wrt_elem.type.ndim} and {eval_point.type.ndim}"
                 )
         except AttributeError:
             # wrt_elem and eval_point don't always have ndim like random type
             # Tensor, Sparse have the ndim attribute
             pass
 
-    seen_nodes = OrderedDict()
+    seen_nodes: Dict[Apply, Sequence[Variable]] = {}
 
     def _traverse(node):
         """TODO: writeme"""
@@ -260,8 +256,8 @@ def _traverse(node):
             # inputs of the node
             local_eval_points = []
             for inp in inputs:
-                if inp in wrt:
-                    local_eval_points.append(eval_points[wrt.index(inp)])
+                if inp in _wrt:
+                    local_eval_points.append(_eval_points[_wrt.index(inp)])
                 elif inp.owner is None:
                     try:
                         local_eval_points.append(inp.zeros_like())
@@ -316,13 +312,13 @@ def _traverse(node):
     # end _traverse
 
     # Populate the dictionary
-    for out in f:
+    for out in _f:
         _traverse(out.owner)
 
-    rval = []
-    for out in f:
-        if out in wrt:
-            rval.append(eval_points[wrt.index(out)])
+    rval: List[Optional[Variable]] = []
+    for out in _f:
+        if out in _wrt:
+            rval.append(_eval_points[_wrt.index(out)])
         elif (
             seen_nodes.get(out.owner, None) is None
             or seen_nodes[out.owner][out.owner.outputs.index(out)] is None
@@ -361,81 +357,89 @@ def _traverse(node):
         else:
             rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
 
-    return format_as(using_list, using_tuple, rval)
+    using_list = isinstance(f, list)
+    using_tuple = isinstance(f, tuple)
+    return as_list_or_tuple(using_list, using_tuple, rval)
 
 
-def Lop(f, wrt, eval_points, consider_constant=None, disconnected_inputs="raise"):
-    """Computes the L operation on `f` with respect to `wrt` at `eval_points`.
+def Lop(
+    f: Union[Variable, Sequence[Variable]],
+    wrt: Union[Variable, Sequence[Variable]],
+    eval_points: Union[Variable, Sequence[Variable]],
+    consider_constant: Optional[Sequence[Variable]] = None,
+    disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
+) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
+    """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`.
 
     Mathematically this stands for the Jacobian of `f` with respect to `wrt`
     left muliplied by the `eval_points`.
 
     Parameters
     ----------
-    f : :class:`~aesara.graph.basic.Variable` or list of Variables
-        `f` stands for the output of the computational graph to which you
-        want to apply the L operator
-    wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
-        variables for which you compute the L operator of the expression
-        described by `f`
-    eval_points : :class:`~aesara.graph.basic.Variable` or list of Variables
-        evaluation points for each of the variables in `f`
+    f
+        The outputs of the computational graph to which the L-operator is
+        applied.
+    wrt
+        Variables for which the L-operator of `f` is computed.
+    eval_points
+        Points at which to evaluate each of the variables in `f`.
+    consider_constant
+        See `grad`.
+    disconnected_inputs
+        See `grad`.
 
    Returns
    -------
-    :class:`~aesara.graph.basic.Variable` or list/tuple of Variables depending on type of `f`
-        Symbolic expression such that
+    A symbolic expression satisfying
        ``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
    where the indices in that expression are magic multidimensional
    indices that specify both the position within a list and all
-    coordinates of the tensor element in the last
+    coordinates of the tensor elements.
 
    If `f` is a list/tuple, then return a list/tuple with the results.
 
    """
    if not isinstance(eval_points, (list, tuple)):
-        eval_points = [eval_points]
-
-    using_list = isinstance(wrt, list)
-    using_tuple = isinstance(wrt, tuple)
+        _eval_points: List[Variable] = [aesara.tensor.as_tensor_variable(eval_points)]
+    else:
+        _eval_points = [aesara.tensor.as_tensor_variable(x) for x in eval_points]
 
    if not isinstance(f, (list, tuple)):
-        f = [f]
+        _f: List[Variable] = [aesara.tensor.as_tensor_variable(f)]
+    else:
+        _f = [aesara.tensor.as_tensor_variable(x) for x in f]
 
-    # make copies of f and grads so we don't modify the client's copy
-    f = list(f)
-    grads = list(eval_points)
+    grads = list(_eval_points)
 
    if not isinstance(wrt, (list, tuple)):
-        wrt = [wrt]
+        _wrt: List[Variable] = [aesara.tensor.as_tensor_variable(wrt)]
+    else:
+        _wrt = [aesara.tensor.as_tensor_variable(x) for x in wrt]
 
-    assert len(f) == len(grads)
-    known = OrderedDict(zip(f, grads))
+    assert len(_f) == len(grads)
+    known = dict(zip(_f, grads))
 
    ret = grad(
        cost=None,
        known_grads=known,
        consider_constant=consider_constant,
-        wrt=wrt,
+        wrt=_wrt,
        disconnected_inputs=disconnected_inputs,
    )
 
-    return format_as(using_list, using_tuple, ret)
-
-
-#########################
-# Gradient
-#########################
+    using_list = isinstance(wrt, list)
+    using_tuple = isinstance(wrt, tuple)
+    return as_list_or_tuple(using_list, using_tuple, ret)
 
 
 def grad(
-    cost,
-    wrt,
-    consider_constant=None,
-    disconnected_inputs="raise",
-    add_names=True,
-    known_grads=None,
-    return_disconnected="zero",
-    null_gradients="raise",
-):
+    cost: Optional[Variable],
+    wrt: Union[Variable, Sequence[Variable]],
+    consider_constant: Optional[Sequence[Variable]] = None,
+    disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise",
+    add_names: bool = True,
+    known_grads: Optional[Mapping[Variable, Variable]] = None,
+    return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
+    null_gradients: Literal["raise", "return"] = "raise",
+) -> Union[Optional[Variable], Sequence[Optional[Variable]]]:
    """
    Return symbolic gradients of one cost with respect to one or more variables.
 
@@ -445,49 +449,47 @@ def grad(
 
    Parameters
    ----------
-    cost : :class:`~aesara.graph.basic.Variable` scalar (0-dimensional) tensor variable or ``None``
-        Value that we are differentiating (that we want the gradient of).
-        May be `None` if `known_grads` is provided.
-    wrt : :class:`~aesara.graph.basic.Variable` or list of Variables
-        Term[s] with respect to which we want gradients
-    consider_constant : list of variables
-        Expressions not to backpropagate through
+    cost
+        Value that we are differentiating (i.e. for which we want the
+        gradient). May be `None` if `known_grads` is provided.
+    wrt
+        The term(s) with respect to which we want gradients.
+    consider_constant
+        Expressions not to backpropagate through.
    disconnected_inputs : {'ignore', 'warn', 'raise'}
        Defines the behaviour if some of the variables in `wrt` are
        not part of the computational graph computing `cost` (or if
        all links are non-differentiable).
        The possible values are:
-        - 'ignore': considers that the gradient on these parameters is zero.
-        - 'warn': consider the gradient zero, and print a warning.
-        - 'raise': raise DisconnectedInputError.
-    add_names : bool
-        If True, variables generated by grad will be named
-        (d<cost.name>/d<wrt.name>) provided that both cost and wrt
-        have names
-    known_grads : OrderedDict, optional
-        A ordered dictionary mapping variables to their gradients. This is
-        useful in the case where you know the gradient on some
+        - ``'ignore'``: considers that the gradient on these parameters is zero
+        - ``'warn'``: consider the gradient zero, and print a warning
+        - ``'raise'``: raise `DisconnectedInputError`
+    add_names
+        If ``True``, variables generated by `grad` will be named
+        ``(d<cost.name>/d<wrt.name>)`` provided that both `cost` and `wrt`
+        have names.
+    known_grads
+        An ordered dictionary mapping variables to their gradients. This is
+        useful in the case where you know the gradients of some
        variables but do not know the original cost.
-    return_disconnected : {'zero', 'None', 'Disconnected'}
-        - 'zero' : If wrt[i] is disconnected, return value i will be
-                   wrt[i].zeros_like()
-        - 'None' : If wrt[i] is disconnected, return value i will be
-                   None
-        - 'Disconnected' : returns variables of type DisconnectedType
-    null_gradients : {'raise', 'return'}
-        Defines the behaviour if some of the variables in `wrt` have a
+    return_disconnected
+        - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``wrt[i].zeros_like()``
+        - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be
+          ``None``
+        - ``'disconnected'`` : returns variables of type `DisconnectedType`
+    null_gradients
+        Defines the behaviour when some of the variables in `wrt` have a
        null gradient. The possibles values are:
-        - 'raise' : raise a NullTypeGradError exception
-        - 'return' : return the null gradients
+        - ``'raise'`` : raise a `NullTypeGradError` exception
+        - ``'return'`` : return the null gradients
 
    Returns
    -------
-    variable or list/tuple of variables (matches `wrt`)
-        Symbolic expression of gradient of `cost` with respect to each
-        of the `wrt` terms. If an element of `wrt` is not
-        differentiable with respect to the output, then a zero
-        variable is returned.
+    A symbolic expression for the gradient of `cost` with respect to each
+    of the `wrt` terms. If an element of `wrt` is not differentiable with
+    respect to the output, then a zero variable is returned.
 
    """
    t0 = time.time()
@@ -498,30 +500,17 @@ def grad(
 
    if cost is not None and isinstance(cost.type, NullType):
        raise ValueError(
-            "Can't differentiate a NaN cost."
-            "cost is NaN because " + cost.type.why_null
-        )
-
-    if cost is not None and cost.ndim != 0:
-        raise TypeError("cost must be a scalar.")
-
-    if isinstance(wrt, set):
-        raise TypeError(
-            "wrt must not be a set. sets have no defined "
-            "iteration order, so we can't return gradients in a"
-            " matching order."
+            "Can't differentiate a NaN cost. "
" + f"Cost is NaN because {cost.type.why_null}" ) - using_list = isinstance(wrt, list) - using_tuple = isinstance(wrt, tuple) - if not using_list and not using_tuple: - wrt = [wrt] + if cost is not None and cost.type.ndim != 0: + raise TypeError("Cost must be a scalar.") - for elem in wrt: - if not isinstance(elem, Variable): - raise TypeError( - "Expected Variable, got " + str(elem) + " of type " + str(type(elem)) - ) + if not isinstance(wrt, Sequence): + _wrt: List[Variable] = [wrt] + else: + _wrt = [x for x in wrt] outputs = [] if cost is not None: @@ -529,16 +518,15 @@ def grad( if known_grads is not None: outputs.extend(list(known_grads.keys())) - var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, wrt, consider_constant) + var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant) # build a dict mapping var to the gradient of cost with respect to var - grad_dict = OrderedDict() + grad_dict = {} if known_grads is None: - known_grads = OrderedDict() - else: - m = "known_grads must be an OrderedDict. " - assert isinstance(known_grads, OrderedDict) or len(known_grads) <= 1, m + known_grads = {} + + assert isinstance(known_grads, dict) # The gradient of the cost is 1 unless specified otherwise by known_grads. if cost is not None: @@ -615,7 +603,7 @@ def handle_disconnected(var): # if wrt is such a variable, populate the grad_dict with this info # so that wrt not being in var_to_app_to_idx won't cause an error below # according to the flag, possibly raise an error if wrt is disconnected - for elem in wrt: + for elem in _wrt: if elem not in var_to_app_to_idx and elem is not cost and elem not in grad_dict: handle_disconnected(elem) grad_dict[elem] = disconnected_type() @@ -632,32 +620,38 @@ def handle_disconnected(var): if hasattr(g.type, "dtype"): assert g.type.dtype in aesara.tensor.type.float_dtypes - rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name) + _rval: Sequence[Variable] = _populate_grad_dict( + var_to_app_to_idx, grad_dict, _wrt, cost_name + ) + + rval: MutableSequence[Optional[Variable]] = list(_rval) - for i in range(len(rval)): - if isinstance(rval[i].type, NullType): + for i in range(len(_rval)): + if isinstance(_rval[i].type, NullType): if null_gradients == "raise": raise NullTypeGradError( - f"grad encountered a NaN. {rval[i].type.why_null}" + f"`grad` encountered a NaN. 
                )
            else:
                assert null_gradients == "return"
 
-        if isinstance(rval[i].type, DisconnectedType):
-            handle_disconnected(rval[i])
+        if isinstance(_rval[i].type, DisconnectedType):
+            handle_disconnected(_rval[i])
            if return_disconnected == "zero":
-                rval[i] = _float_zeros_like(wrt[i])
-            elif return_disconnected == "None":
+                rval[i] = _float_zeros_like(_wrt[i])
+            elif return_disconnected.lower() == "none":
                rval[i] = None
            else:
-                assert return_disconnected == "Disconnected"
+                assert return_disconnected.lower() == "disconnected"
 
-    if using_tuple:
-        rval = tuple(rval)
-    elif not using_list:
-        (rval,) = rval
 
    t1 = time.time()
    global grad_time
    grad_time += t1 - t0
+
+    if isinstance(wrt, tuple):
+        return tuple(rval)
+    elif not isinstance(wrt, list):
+        return rval[0]
+
    return rval
 
 
@@ -801,7 +795,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
    for i in range(len(grads)):
        grads[i] += cost_grads[i]
 
-    pgrads = OrderedDict(zip(params, grads))
+    pgrads = dict(zip(params, grads))
    # separate wrt from end grads:
    wrt_grads = list(pgrads[k] for k in wrt)
    end_grads = list(pgrads[k] for k in end)
@@ -916,7 +910,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
    # var_to_app_to_idx[var][node] = [i,j] means node has
    # var as input at positions i and j
 
-    var_to_app_to_idx = OrderedDict()
+    var_to_app_to_idx = dict()
 
    # Set of variables that have been added to their true parents
    # ('true' here means that the elements of the variable are a function
@@ -954,13 +948,13 @@ def account_for(var):
                continue
 
            if ipt not in var_to_app_to_idx:
-                # This object here *must* be an OrderedDict, because
+                # This object here *must* be ordered, because
                # we iterate over its keys when adding up the terms of the
                # gradient on ipt. If it is a regular dict, the grad method
                # will return something that is analytically correct, but
                # whose order of doing additions depends on the memory
                # location of the apply nodes.
-                var_to_app_to_idx[ipt] = OrderedDict()
+                var_to_app_to_idx[ipt] = {}
                app_to_idx = var_to_app_to_idx[ipt]
                if app not in app_to_idx:
                    app_to_idx[app] = []
@@ -1052,7 +1046,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
    """
    # build a dict mapping node to the terms node contributes to each of
    # its inputs' gradients
-    term_dict = OrderedDict()
+    term_dict = {}
 
    def access_term_cache(node):
        """Populates term_dict[node] and returns it"""
@@ -1978,7 +1972,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
 
    if expression.ndim == 0:
        # expression is just a scalar, use grad
-        return format_as(
+        return as_list_or_tuple(
            using_list,
            using_tuple,
            grad(
@@ -2013,7 +2007,7 @@ def inner_function(*args):
        non_sequences=[expression] + wrt,
    )
    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return format_as(using_list, using_tuple, jacobs)
+    return as_list_or_tuple(using_list, using_tuple, jacobs)
 
 
 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
@@ -2093,7 +2087,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
            not updates
        ), "Scan has returned a list of updates; this should not happen."
        hessians.append(hess)
-    return format_as(using_list, using_tuple, hessians)
+    return as_list_or_tuple(using_list, using_tuple, hessians)
 
 
 def _is_zero(x):
@@ -2134,7 +2128,6 @@ def grad(self, args, g_outs):
 
 consider_constant_ = ConsiderConstant()
 
-# I create a function only to have the doc show well.
 def consider_constant(x):
     """
     DEPRECATED: use zero_grad() or disconnected_grad() instead.
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index b968aef585..7d7c92a391 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -278,8 +278,6 @@ def test_1None_rval(self):
         g = grad(a1.outputs[0], a1.outputs[1], disconnected_inputs="ignore")
         assert g.owner.op == at.fill
         assert g.owner.inputs[1].data == 0
-        with pytest.raises(TypeError):
-            grad(a1.outputs[0], "wtf")
 
     def test_NNone_rval(self):
         # grad: Test returning some zero value from grad
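For reference, a minimal usage sketch of the operators whose docstrings are updated above. It assumes a working Aesara installation; the graph and the variable names (W, x, v, u) are illustrative and not taken from the patch.

    import aesara
    import aesara.tensor as at
    from aesara.gradient import Lop, Rop

    x = at.vector("x")
    W = at.matrix("W")
    f = at.tanh(at.dot(W, x))  # a vector-valued function of x

    v = at.vector("v")  # tangent vector, same shape as x
    u = at.vector("u")  # cotangent vector, same shape as f

    # R-operator: Jacobian of f (w.r.t. x) right-multiplied by v, i.e. J @ v.
    jvp = Rop(f, x, v)

    # L-operator: Jacobian of f (w.r.t. x) left-multiplied by u, i.e. u @ J.
    vjp = Lop(f, x, u)

    # For a scalar cost, the ordinary gradient is the special case u = 1.
    g = aesara.grad(f.sum(), x)

    fn = aesara.function([W, x, v, u], [jvp, vjp, g])

Under this reading, `Rop` corresponds to a Jacobian-vector product (forward mode) and `Lop` to a vector-Jacobian product (reverse mode), which is the relationship the updated `Op.grad`/`Op.L_op` docstrings describe.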