From 632f787b7241249d3cea9f8ea9f522f878866e86 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 28 Jun 2019 15:54:26 -0500 Subject: [PATCH] Allow keywords in etuples Closes #36. --- symbolic_pymc/etuple.py | 61 +++++++++++++++++++++++++++++----- tests/tensorflow/test_unify.py | 2 +- tests/test_etuple.py | 45 ++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 10 deletions(-) diff --git a/symbolic_pymc/etuple.py b/symbolic_pymc/etuple.py index 5e8efb8..847fdad 100644 --- a/symbolic_pymc/etuple.py +++ b/symbolic_pymc/etuple.py @@ -1,5 +1,8 @@ +import inspect import reprlib +import toolz + from multipledispatch import dispatch from kanren.term import operator, arguments @@ -10,6 +13,26 @@ etuple_repr.maxother = 100 +class KwdPair(tuple): + """A class used to indicate a keyword + value mapping. + + TODO: Could subclass `ast.keyword`. + + """ + + def __new__(cls, arg, value): + assert isinstance(arg, str) + obj = super().__new__(cls, (arg, value)) + return obj + + @property + def eval_obj(self): + return KwdPair(self[0], getattr(self[1], "eval_obj", self[1])) + + def __repr__(self): + return f"{str(self[0])}={repr(self[1])}" + + class ExpressionTuple(tuple): """A tuple object that represents an expression. @@ -18,6 +41,14 @@ class ExpressionTuple(tuple): """ + null = object() + + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + # TODO: Consider making this a weakref. + obj._eval_obj = cls.null + return obj + @property def eval_obj(self): """Return the evaluation of this expression tuple. @@ -25,13 +56,26 @@ def eval_obj(self): XXX: If the object isn't cached, it will be evaluated recursively. """ - if hasattr(self, "_eval_obj"): + if self._eval_obj is not ExpressionTuple.null: return self._eval_obj else: evaled_args = [getattr(i, "eval_obj", i) for i in self[1:]] - _eval_obj = self[0](*evaled_args) + arg_grps = toolz.groupby(lambda x: isinstance(x, KwdPair), evaled_args) + evaled_args = arg_grps.get(False, []) + evaled_kwargs = arg_grps.get(True, []) + + op = self[0] + try: + op_sig = inspect.signature(op) + except ValueError: + _eval_obj = op(*(evaled_args + [kw[1] for kw in evaled_kwargs])) + else: + op_args = op_sig.bind(*evaled_args, **dict(evaled_kwargs)) + op_args.apply_defaults() + + _eval_obj = op(*op_args.args, **op_args.kwargs) - assert not isinstance(_eval_obj, ExpressionTuple) + # assert not isinstance(_eval_obj, ExpressionTuple) self._eval_obj = _eval_obj return self._eval_obj @@ -86,16 +130,17 @@ def etuple(*args, **kwargs): If the keyword 'eval_obj' is given, the `ExpressionTuple`'s evaluated object is set to the corresponding value. + XXX: There is no verification/check that the arguments evaluate to the + user-specified 'eval_obj', so be careful. """ - res = ExpressionTuple(args) + _eval_obj = kwargs.pop("eval_obj", ExpressionTuple.null) - if "eval_obj" in kwargs: - _eval_obj = kwargs.pop("eval_obj") + etuple_kwargs = tuple(KwdPair(k, v) for k, v in kwargs.items()) - assert not isinstance(_eval_obj, ExpressionTuple) + res = ExpressionTuple(args + etuple_kwargs) - res._eval_obj = _eval_obj + res._eval_obj = _eval_obj return res diff --git a/tests/tensorflow/test_unify.py b/tests/tensorflow/test_unify.py index 76d4f5c..d0dfdbf 100644 --- a/tests/tensorflow/test_unify.py +++ b/tests/tensorflow/test_unify.py @@ -34,7 +34,7 @@ def test_etuple_term(): assert isinstance(test_e[2], TFlowMetaTensorShape) assert test_e[2] is a_mt.op.node_def['shape'] - del test_e._eval_obj + test_e._eval_obj = ExpressionTuple.null a_evaled = test_e.eval_obj assert all([a == b for a, b in zip(a_evaled.rands(), a_mt.rands())]) diff --git a/tests/test_etuple.py b/tests/test_etuple.py index 00a572a..2051c00 100644 --- a/tests/test_etuple.py +++ b/tests/test_etuple.py @@ -14,7 +14,7 @@ def test_op(*args): e1 = etuple(test_op, 1, 2) - assert not hasattr(e1, '_eval_obj') + assert e1._eval_obj is ExpressionTuple.null with pytest.raises(ValueError): e1.eval_obj = 1 @@ -74,3 +74,46 @@ def test_etuple_term(): e1_dup_2 = term(operator(e1), arguments(e1)) assert e1_dup_2 == e1_obj + + +def test_etuple_kwargs(): + """Test keyword arguments and default argument values.""" + def test_func(a, b, c=None, d='d-arg', **kwargs): + assert isinstance(c, (type(None), int)) + return [a, b, c, d] + + e1 = etuple(test_func, 1, 2) + assert e1.eval_obj == [1, 2, None, 'd-arg'] + + # Make sure we handle variadic args properly + def test_func2(*args, c=None, d='d-arg', **kwargs): + assert isinstance(c, (type(None), int)) + return list(args) + [c, d] + + e11 = etuple(test_func2, 1, 2) + assert e11.eval_obj == [1, 2, None, 'd-arg'] + + e2 = etuple(test_func, 1, 2, 3) + assert e2.eval_obj == [1, 2, 3, 'd-arg'] + + e3 = etuple(test_func, 1, 2, 3, 4) + assert e3.eval_obj == [1, 2, 3, 4] + + e4 = etuple(test_func, 1, 2, c=3) + assert e4.eval_obj == [1, 2, 3, 'd-arg'] + + e5 = etuple(test_func, 1, 2, d=3) + assert e5.eval_obj == [1, 2, None, 3] + + e6 = etuple(test_func, 1, 2, 3, d=4) + assert e6.eval_obj == [1, 2, 3, 4] + + # Try evaluating nested etuples + e7 = etuple(test_func, etuple(add, 1, 0), 2, + c=etuple(add, 1, etuple(add, 1, 1))) + assert e7.eval_obj == [1, 2, 3, 'd-arg'] + + # Try a function without an obtainable signature object + e8 = etuple(enumerate, etuple(list, ['a', 'b', 'c', 'd']), + start=etuple(add, 1, etuple(add, 1, 1))) + assert list(e8.eval_obj) == [(3, 'a'), (4, 'b'), (5, 'c'), (6, 'd')]