Skip to content

Commit

Permalink
Separate default_factory from default parameter in BaseField init
Browse files Browse the repository at this point in the history
Separating default and default_factory makes it consistent with parameters for field().
  • Loading branch information
daveraja committed May 8, 2024
1 parent ea7918f commit f498bd3
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 59 deletions.
132 changes: 88 additions & 44 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,21 +1352,31 @@ class DateField(StringField):
Args:
default: A default value (or function) to be used when instantiating a
``Predicate`` object. If a Python ``callable`` object is
specified (i.e., a function or functor) then it will be called (with no
arguments) when the predicate/complex-term object is instantiated.
default: A default value to be used when instantiating a ``Predicate`` object and no
value has been specified.
default_factory: A unary function (ie. a function with no arguments) for generating a
value when none has been specified.
index (bool): Determine if this field should be indexed by default in a
``FactBase```. Defaults to ``False``.
"""

def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None:
def __init__(
self,
*,
default: Any = MISSING,
default_factory: Callable[[], Any] = MISSING,
index: Any = MISSING,
) -> None:
self._index = index if index is not MISSING else False
if default is not MISSING and default_factory is not MISSING:
raise ValueError("can not specify both default and default_factory")

if default is MISSING:
self._default = (False, None)
self._default = MISSING
self._default_factory = MISSING
if default is MISSING and default_factory is MISSING:
return

cmplx = self.complex
Expand All @@ -1383,22 +1393,21 @@ def _process_cmplx_value(v):

_process_value = _process_basic_value if cmplx is None else _process_cmplx_value

# If the default is not a factory function than make sure the value can be converted to
# clingo without error.
if not callable(default):
# If we're using a default value then set the value (after some preprocessing).
if default is not MISSING:
try:
self._default = (True, _process_value(default))
self.pytocl(self._default[1])
self._default = _process_value(default)
self.pytocl(self._default)
except (TypeError, ValueError):
raise TypeError(
'Invalid default value "{}" for {}'.format(default, type(self).__name__)
)
else:

def _process_default():
return _process_value(default())
def _process_default_factory():
return _process_value(default_factory())

self._default = (True, _process_default)
self._default_factory = _process_default_factory

@staticmethod
@abc.abstractmethod
Expand All @@ -1420,13 +1429,18 @@ def complex(cls) -> Optional["Predicate"]:

@property
def has_default(self):
"""Returns whether a default value has been set"""
return self._default[0]
"""Returns whether there is a default value or default factory."""
return self._default is not MISSING or self._default_factory is not MISSING

@property
def has_default_factory(self):
"""Returns whether a default value has been set"""
return self._default[0] and callable(self._default[1])
"""Returns whether there is a default factory"""
return self._default_factory is not MISSING

@property
def default_factory(self):
"""Return the default factory function."""
return self._default_factory

@property
def default(self):
Expand All @@ -1439,11 +1453,11 @@ def default(self):
value and a ``None`` default value.
"""
if not self._default[0]:
return None
if callable(self._default[1]):
return self._default[1]()
return self._default[1]
if self._default is not MISSING:
return self._default
if self._default_factory is not MISSING:
return self._default_factory()
return None

@property
def index(self):
Expand Down Expand Up @@ -1511,19 +1525,23 @@ def field(
except (AttributeError, ValueError):
module = None
if default is not MISSING:
return _create_complex_term(basefield, default, module)
elif default_factory is not MISSING:
return _create_complex_term(basefield, default_factory, module)
return _create_complex_term(basefield, default=default, module=module)
if default_factory is not MISSING:
return _create_complex_term(basefield, default_factory=default_factory, module=module)
return basefield

if issubclass(basefield, BaseField):
if default is not MISSING:
return basefield(default)
elif default_factory is not MISSING:
return basefield(default_factory)
return basefield
return basefield(default=default)
if default_factory is not MISSING:
return basefield(default_factory=default_factory)
complex = basefield.complex
if complex is not None:
fields = (dn.defn for dn in complex.meta)
default, default_factory = _make_implicit_defaults_for_complex_field(fields)
return basefield(default=default, default_factory=default_factory)

raise TypeError(f"{basefield} can just be of Type '{BaseField}' or '{Sequence}'")
raise TypeError(f"{basefield} can only be of type '{BaseField}' or '{Sequence}'")


# ------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -2464,24 +2482,50 @@ def get_field_definition(defn: Any, module: str = "") -> BaseField:
return _create_complex_term(defn, module=module)


def _create_complex_term(defn: Any, default_value: Any = MISSING, module: str = "") -> BaseField:
def _make_implicit_defaults_for_complex_field(
fields: Tuple[BaseField],
) -> Tuple[Any, Union[Callable[[], Any], None]]:
"""Tries to create a default value or a default_factory.
Returns a pair (default, default_factory). If both values are MISSING then there is no way to
generate a default.
"""

def _default_factory():
return tuple([field.default for field in fields])

make_default_factory = False
for field in fields:
if not field.has_default and not field.has_default_factory:
return (MISSING, MISSING)
if field.has_default_factory:
make_default_factory = True
if make_default_factory:
return (MISSING, _default_factory)
return (_default_factory(), MISSING)


def _create_complex_term(
defn: Any,
*,
default: Any = MISSING,
default_factory: Callable[[], Any] = MISSING,
module: str = "",
) -> BaseField:
# NOTE: relies on a dict preserving insertion order - this is true from Python 3.7+. Python
# 3.7 is already end-of-life so there is no longer a reason to use OrderedDict.
proto = {f"arg{idx+1}": get_field_definition(dn) for idx, dn in enumerate(defn)}
class_name = (
f'ClormAnonTuple({",".join(f"{arg[0]}={repr(arg[1])}" for arg in proto.items())})'
)

if default_value is not MISSING:
default = default_value
set_default = True
else:
default = tuple([field.default for field in proto.values() if field.has_default])
set_default = len(default) == len(
proto
) # calling type modifies proto so compare it beforehand
if not set_default and default:
raise ValueError((f"Default {default_value} must have the same length as {defn}"))
# if no default or default_factory is specified then see if we try to construct one from
# the sub-fields.
if default is MISSING and default_factory is MISSING:
default, default_factory = _make_implicit_defaults_for_complex_field(
tuple(proto.values())
)
proto["Meta"] = type("Meta", (object,), {"is_tuple": True, "_anon": True})

# For pickling to work, the __module__ variable needs to be set to the frame
Expand All @@ -2505,7 +2549,7 @@ def _create_complex_term(defn: Any, default_value: Any = MISSING, module: str =
depth += 1
except ValueError:
pass
return ct.Field(default) if set_default else ct.Field()
return ct.Field(default=default, default_factory=default_factory)


# ------------------------------------------------------------------------------
Expand Down
65 changes: 50 additions & 15 deletions tests/test_orm_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class BadField(IntegerField, StringField):
def test_api_field_function(self):
with self.subTest("with single BaseField"):
f = field(IntegerField)
self.assertEqual(f, IntegerField)
self.assertTrue(isinstance(f, IntegerField))

f = field(IntegerField, default=4)
self.assertEqual(type(f), IntegerField)
Expand Down Expand Up @@ -469,8 +469,8 @@ def inc():
self.assertEqual(fld.default, 5)
self.assertTrue(fld.has_default)

fld = IntegerField(default=inc)
self.assertTrue(fld.has_default)
fld = IntegerField(default_factory=inc)
self.assertTrue(fld.has_default_factory)
self.assertEqual(fld.default, 1)
self.assertEqual(fld.default, 2)

Expand All @@ -480,11 +480,6 @@ def inc():
self.assertEqual(fld.default, 0)
self.assertTrue(fld.has_default)

# A default can also be specified as a position argument
fld = IntegerField(0)
self.assertEqual(fld.default, 0)
self.assertTrue(fld.has_default)

# --------------------------------------------------------------------------
# Test catching invalid instantiation of a field (such as giving a bad
# default values for a field).
Expand Down Expand Up @@ -525,10 +520,6 @@ def test_api_field_index(self):
self.assertTrue(fstr2.index)
self.assertTrue(fconst2.index)

# Specify with positional arguments
f = IntegerField(1, True)
self.assertTrue(f.index)

# --------------------------------------------------------------------------
# Test the SimpleField class that handles all primitive types
# (Integer, String, Constant).
Expand Down Expand Up @@ -1157,7 +1148,7 @@ class Q(Predicate):
def test_predicate_anonymous_field_with_default(self):
class P(Predicate):
first = IntegerField
tuple_ = (IntegerField(2), StringField("42"))
tuple_ = (IntegerField(default=2), StringField(default="42"))

p = P(first=15, tuple_=(1, "2"))
raw_p = Function("p", [Number(15), Function("", [Number(1), String("2")])])
Expand Down Expand Up @@ -2182,7 +2173,7 @@ class P(Predicate):
# --------------------------------------------------------------------------
# Test a simple predicate with a field that has a function default
# --------------------------------------------------------------------------
def test_predicate_with_function_default(self):
def test_predicate_with_default_factory(self):
val = 0

def inc():
Expand All @@ -2191,7 +2182,7 @@ def inc():
return val

class Fact(Predicate):
anum = IntegerField(default=inc)
anum = IntegerField(default_factory=inc)
astr = StringField()

func = Function("fact", [Number(1), String("test")])
Expand All @@ -2206,6 +2197,50 @@ class Fact(Predicate):
self.assertEqual(f1, f2)
self.assertEqual(f1.raw, func)

# ---------------------------------------------------------------------------------------
# Test a predicate a complex field that has an implicit default based on its subfields
# ---------------------------------------------------------------------------------------
def test_predicate_with_anon_tuple_field_with_implicit_default_factory(self):
val = 0

def inc():
nonlocal val
val += 1
return val

class Outer(Predicate):
x = (IntegerField(default_factory=inc), StringField(default="blah"))

x1 = Function("", [Number(1), String("blah")])
x2 = Function("", [Number(2), String("blah")])

self.assertEqual(Outer().x, x1)
self.assertEqual(Outer().x, x2)

# ---------------------------------------------------------------------------------------
# Test a predicate a complex field that has an implicit default based on its subfields
# ---------------------------------------------------------------------------------------
def test_predicate_with_complex_field_with_implicit_default_factory(self):
val = 0

def inc():
nonlocal val
val += 1
return val

class X(Predicate):
a: int = field(IntegerField, default_factory=inc)
b: str = field(StringField, default="blah")

class Outer(Predicate):
x = field(X.Field)

x1 = Function("x", [Number(1), String("blah")])
x2 = Function("x", [Number(2), String("blah")])

self.assertEqual(Outer().x, x1)
self.assertEqual(Outer().x, x2)

# --------------------------------------------------------------------------
# Test that we can initialise using positional arguments
# --------------------------------------------------------------------------
Expand Down

0 comments on commit f498bd3

Please sign in to comment.