Skip to content

Commit

Permalink
Allow keywords in etuples
Browse files Browse the repository at this point in the history
Closes #36.
  • Loading branch information
brandonwillard committed Jun 28, 2019
1 parent 06e80a9 commit 632f787
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 10 deletions.
61 changes: 53 additions & 8 deletions symbolic_pymc/etuple.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import inspect
import reprlib

import toolz

from multipledispatch import dispatch

from kanren.term import operator, arguments
Expand All @@ -10,6 +13,26 @@
etuple_repr.maxother = 100


class KwdPair(tuple):
"""A class used to indicate a keyword + value mapping.
TODO: Could subclass `ast.keyword`.
"""

def __new__(cls, arg, value):
assert isinstance(arg, str)
obj = super().__new__(cls, (arg, value))
return obj

@property
def eval_obj(self):
return KwdPair(self[0], getattr(self[1], "eval_obj", self[1]))

def __repr__(self):
return f"{str(self[0])}={repr(self[1])}"


class ExpressionTuple(tuple):
"""A tuple object that represents an expression.
Expand All @@ -18,20 +41,41 @@ class ExpressionTuple(tuple):
"""

null = object()

def __new__(cls, *args, **kwargs):
obj = super().__new__(cls, *args, **kwargs)
# TODO: Consider making this a weakref.
obj._eval_obj = cls.null
return obj

@property
def eval_obj(self):
"""Return the evaluation of this expression tuple.
XXX: If the object isn't cached, it will be evaluated recursively.
"""
if hasattr(self, "_eval_obj"):
if self._eval_obj is not ExpressionTuple.null:
return self._eval_obj
else:
evaled_args = [getattr(i, "eval_obj", i) for i in self[1:]]
_eval_obj = self[0](*evaled_args)
arg_grps = toolz.groupby(lambda x: isinstance(x, KwdPair), evaled_args)
evaled_args = arg_grps.get(False, [])
evaled_kwargs = arg_grps.get(True, [])

op = self[0]
try:
op_sig = inspect.signature(op)
except ValueError:
_eval_obj = op(*(evaled_args + [kw[1] for kw in evaled_kwargs]))
else:
op_args = op_sig.bind(*evaled_args, **dict(evaled_kwargs))
op_args.apply_defaults()

_eval_obj = op(*op_args.args, **op_args.kwargs)

assert not isinstance(_eval_obj, ExpressionTuple)
# assert not isinstance(_eval_obj, ExpressionTuple)

self._eval_obj = _eval_obj
return self._eval_obj
Expand Down Expand Up @@ -86,16 +130,17 @@ def etuple(*args, **kwargs):
If the keyword 'eval_obj' is given, the `ExpressionTuple`'s
evaluated object is set to the corresponding value.
XXX: There is no verification/check that the arguments evaluate to the
user-specified 'eval_obj', so be careful.
"""
res = ExpressionTuple(args)
_eval_obj = kwargs.pop("eval_obj", ExpressionTuple.null)

if "eval_obj" in kwargs:
_eval_obj = kwargs.pop("eval_obj")
etuple_kwargs = tuple(KwdPair(k, v) for k, v in kwargs.items())

assert not isinstance(_eval_obj, ExpressionTuple)
res = ExpressionTuple(args + etuple_kwargs)

res._eval_obj = _eval_obj
res._eval_obj = _eval_obj

return res

Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_etuple_term():
assert isinstance(test_e[2], TFlowMetaTensorShape)
assert test_e[2] is a_mt.op.node_def['shape']

del test_e._eval_obj
test_e._eval_obj = ExpressionTuple.null
a_evaled = test_e.eval_obj
assert all([a == b for a, b in zip(a_evaled.rands(), a_mt.rands())])

Expand Down
45 changes: 44 additions & 1 deletion tests/test_etuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_op(*args):

e1 = etuple(test_op, 1, 2)

assert not hasattr(e1, '_eval_obj')
assert e1._eval_obj is ExpressionTuple.null

with pytest.raises(ValueError):
e1.eval_obj = 1
Expand Down Expand Up @@ -74,3 +74,46 @@ def test_etuple_term():

e1_dup_2 = term(operator(e1), arguments(e1))
assert e1_dup_2 == e1_obj


def test_etuple_kwargs():
"""Test keyword arguments and default argument values."""
def test_func(a, b, c=None, d='d-arg', **kwargs):
assert isinstance(c, (type(None), int))
return [a, b, c, d]

e1 = etuple(test_func, 1, 2)
assert e1.eval_obj == [1, 2, None, 'd-arg']

# Make sure we handle variadic args properly
def test_func2(*args, c=None, d='d-arg', **kwargs):
assert isinstance(c, (type(None), int))
return list(args) + [c, d]

e11 = etuple(test_func2, 1, 2)
assert e11.eval_obj == [1, 2, None, 'd-arg']

e2 = etuple(test_func, 1, 2, 3)
assert e2.eval_obj == [1, 2, 3, 'd-arg']

e3 = etuple(test_func, 1, 2, 3, 4)
assert e3.eval_obj == [1, 2, 3, 4]

e4 = etuple(test_func, 1, 2, c=3)
assert e4.eval_obj == [1, 2, 3, 'd-arg']

e5 = etuple(test_func, 1, 2, d=3)
assert e5.eval_obj == [1, 2, None, 3]

e6 = etuple(test_func, 1, 2, 3, d=4)
assert e6.eval_obj == [1, 2, 3, 4]

# Try evaluating nested etuples
e7 = etuple(test_func, etuple(add, 1, 0), 2,
c=etuple(add, 1, etuple(add, 1, 1)))
assert e7.eval_obj == [1, 2, 3, 'd-arg']

# Try a function without an obtainable signature object
e8 = etuple(enumerate, etuple(list, ['a', 'b', 'c', 'd']),
start=etuple(add, 1, etuple(add, 1, 1)))
assert list(e8.eval_obj) == [(3, 'a'), (4, 'b'), (5, 'c'), (6, 'd')]

0 comments on commit 632f787

Please sign in to comment.