Skip to content

Commit

Permalink
Improve matching on datas with defaults
Browse files Browse the repository at this point in the history
Resolves   #708.
  • Loading branch information
evhub committed Dec 26, 2022
1 parent e7a2316 commit a137c16
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 35 deletions.
14 changes: 8 additions & 6 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ base_pattern ::= (
- View Patterns (`(<expression>) -> <pattern>`): calls `<expression>` on the item being matched and matches the result to `<pattern>`. The match fails if a [`MatchError`](#matcherror) is raised. `<expression>` may be unparenthesized only when it is a single atom.
- Class and Data Type Matching:
- Classes or Data Types (`<name>(<args>)`): will match as a data type if given [a Coconut `data` type](#data) (or a tuple of Coconut data types) and a class otherwise.
- Data Types (`data <name>(<args>)`): will check that whatever is in that position is of data type `<name>` and will match the attributes to `<args>`. Includes support for positional arguments, named arguments, and starred arguments. Also supports strict attribute by prepending a dot to the attribute name that raises `AttributError` if the attribute is not present rather than failing the match (e.g. `data MyData(.my_attr=<some_pattern>)`).
- Data Types (`data <name>(<args>)`): will check that whatever is in that position is of data type `<name>` and will match the attributes to `<args>`. Includes support for positional arguments, named arguments, default arguments, and starred arguments. Also supports strict attributes by prepending a dot to the attribute name that raises `AttributError` if the attribute is not present rather than failing the match (e.g. `data MyData(.my_attr=<some_pattern>)`).
- Classes (`class <name>(<args>)`): does [PEP-634-style class matching](https://www.python.org/dev/peps/pep-0634/#class-patterns). Also supports strict attribute matching as above.
- Mapping Destructuring:
- Dicts (`{<key>: <value>, ...}`): will match any mapping (`collections.abc.Mapping`) with the given keys and values that match the value patterns. Keys must be constants or equality checks.
Expand Down Expand Up @@ -2769,11 +2769,7 @@ Coconut's `Expected` built-in is a Coconut [`data` type](#data) that represents

`Expected` is effectively equivalent to the following:
```coconut
data Expected[T](result: T?, error: BaseException?):
def __new__(cls, result: T?=None, error: BaseException?=None) -> Expected[T]:
if result is not None and error is not None:
raise TypeError("Expected cannot have both a result and an error")
return makedata(cls, result, error)
data Expected[T](result: T? = None, error: BaseException? = None):
def __bool__(self) -> bool:
return self.error is None
def __fmap__[U](self, func: T -> U) -> Expected[U]:
Expand Down Expand Up @@ -2807,6 +2803,12 @@ data Expected[T](result: T?, error: BaseException?):

`Expected` is primarily used as the return type for [`safe_call`](#safe_call). Generally, the best way to use `Expected` is with [`fmap`](#fmap), which will apply a function to the result if it exists, or otherwise retain the error. If you want to sequence multiple `Expected`-returning operations, `.and_then` should be used instead of `fmap`.

To match against an `Expected`, just:
```
Expected(res) = Expected("result")
Expected(error=err) = Expected(error=TypeError())
```

##### Example

**Coconut:**
Expand Down
7 changes: 6 additions & 1 deletion __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,14 @@ _coconut_tail_call = of = call


@_dataclass(frozen=True, slots=True)
class Expected(_t.Generic[_T], _t.Tuple):
class _BaseExpected(_t.Generic[_T], _t.Tuple):
result: _t.Optional[_T]
error: _t.Optional[BaseException]
class Expected(_BaseExpected[_T]):
__slots__ = ()
_coconut_is_data = True
__match_args__ = ("result", "error")
_coconut_data_defaults: _t.Mapping[int, None] = ...
@_t.overload
def __new__(
cls,
Expand Down
33 changes: 24 additions & 9 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
format_var,
none_coalesce_var,
is_data_var,
data_defaults_var,
funcwrapper,
non_syntactic_newline,
indchars,
Expand Down Expand Up @@ -157,6 +158,7 @@
split_leading_whitespace,
ordered_items,
tuple_str_of_str,
dict_to_str,
)
from coconut.compiler.header import (
minify_header,
Expand Down Expand Up @@ -2564,8 +2566,8 @@ def datadef_handle(self, loc, tokens):
base_args = [] # names of all the non-starred args
req_args = 0 # number of required arguments
starred_arg = None # starred arg if there is one else None
saw_defaults = False # whether there have been any default args so far
types = {} # arg position to typedef for arg
arg_defaults = {} # arg position to default for arg
for i, arg in enumerate(original_args):

star, default, typedef = False, None, None
Expand All @@ -2586,13 +2588,14 @@ def datadef_handle(self, loc, tokens):
if argname.startswith("_"):
raise CoconutDeferredSyntaxError("data fields cannot start with an underscore", loc)
if star:
internal_assert(default is None, "invalid default in starred data field", default)
if i != len(original_args) - 1:
raise CoconutDeferredSyntaxError("starred data field must come last", loc)
starred_arg = argname
else:
if default:
saw_defaults = True
elif saw_defaults:
if default is not None:
arg_defaults[i] = "__new__.__defaults__[{i}]".format(i=len(arg_defaults))
elif arg_defaults:
raise CoconutDeferredSyntaxError("data fields with defaults must come after data fields without", loc)
else:
req_args += 1
Expand Down Expand Up @@ -2668,7 +2671,7 @@ def {arg}(self):
arg=starred_arg,
kwd_only=("*, " if self.target.startswith("3") else ""),
)
elif saw_defaults:
elif arg_defaults:
extra_stmts += handle_indentation(
'''
def __new__(_coconut_cls, {all_args}):
Expand All @@ -2680,10 +2683,22 @@ def __new__(_coconut_cls, {all_args}):
base_args_tuple=tuple_str_of(base_args),
)

if arg_defaults:
extra_stmts += handle_indentation(
'''
{data_defaults_var} = {arg_defaults} {type_ignore}
''',
add_newline=True,
).format(
data_defaults_var=data_defaults_var,
arg_defaults=dict_to_str(arg_defaults),
type_ignore=self.type_ignore_comment(),
)

namedtuple_args = base_args + ([] if starred_arg is None else [starred_arg])
namedtuple_call = self.make_namedtuple_call(name, namedtuple_args, types)

return self.assemble_data(decorators, name, namedtuple_call, inherit, extra_stmts, stmts, namedtuple_args, paramdefs)
return self.assemble_data(decorators, name, namedtuple_call, inherit, extra_stmts, stmts, base_args, paramdefs)

def make_namedtuple_call(self, name, namedtuple_args, types=None):
"""Construct a namedtuple call."""
Expand Down Expand Up @@ -2727,8 +2742,9 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts,
# add universal statements
all_extra_stmts = handle_indentation(
"""
{is_data_var} = True
__slots__ = ()
{is_data_var} = True
__match_args__ = {match_args}
def __add__(self, other): return _coconut.NotImplemented
def __mul__(self, other): return _coconut.NotImplemented
def __rmul__(self, other): return _coconut.NotImplemented
Expand All @@ -2741,9 +2757,8 @@ def __hash__(self):
add_newline=True,
).format(
is_data_var=is_data_var,
match_args=tuple_str_of(match_args, add_quotes=True),
)
if self.target_info < (3, 10):
all_extra_stmts += "__match_args__ = " + tuple_str_of(match_args, add_quotes=True) + "\n"
all_extra_stmts += extra_stmts

# manage docstring
Expand Down
4 changes: 4 additions & 0 deletions coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
numpy_modules,
jax_numpy_modules,
self_match_types,
is_data_var,
data_defaults_var,
)
from coconut.util import (
univ_open,
Expand Down Expand Up @@ -209,6 +211,8 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):
empty_dict="{}",
lbrace="{",
rbrace="}",
is_data_var=is_data_var,
data_defaults_var=data_defaults_var,
target_startswith=target_startswith,
default_encoding=default_encoding,
hash_line=hash_prefix + use_hash + "\n" if use_hash is not None else "",
Expand Down
43 changes: 33 additions & 10 deletions coconut/compiler/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
function_match_error_var,
match_set_name_var,
is_data_var,
data_defaults_var,
default_matcher_style,
self_match_types,
)
Expand All @@ -47,6 +48,7 @@
handle_indentation,
add_int_and_strs,
ordered_items,
tuple_str_of,
)

# -----------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1039,15 +1041,8 @@ def match_data(self, tokens, item):

self.add_check("_coconut.isinstance(" + item + ", " + cls_name + ")")

if star_match is None:
self.add_check(
'_coconut.len({item}) == {total_len}'.format(
item=item,
total_len=len(pos_matches) + len(name_matches),
),
)
# avoid checking >= 0
elif len(pos_matches):
if len(pos_matches):
self.add_check(
"_coconut.len({item}) >= {min_len}".format(
item=item,
Expand All @@ -1063,6 +1058,34 @@ def match_data(self, tokens, item):
# handle keyword args
self.match_class_names(name_matches, item)

# handle data types with defaults for some arguments
if star_match is None:
# use a def so we can type ignore it
temp_var = self.get_temp_var()
self.add_def(
(
'{temp_var} ='
' _coconut.len({item}) <= _coconut.max({min_len}, _coconut.len({item}.__match_args__))'
' and _coconut.all('
'i in _coconut.getattr({item}, "{data_defaults_var}", {{}})'
' and {item}[i] == _coconut.getattr({item}, "{data_defaults_var}", {{}})[i]'
' for i in _coconut.range({min_len}, _coconut.len({item}.__match_args__))'
' if {item}.__match_args__[i] not in {name_matches}'
') if _coconut.hasattr({item}, "__match_args__")'
' else _coconut.len({item}) == {min_len}'
' {type_ignore}'
).format(
item=item,
temp_var=temp_var,
data_defaults_var=data_defaults_var,
min_len=len(pos_matches),
name_matches=tuple_str_of(name_matches, add_quotes=True),
type_ignore=self.comp.type_ignore_comment(),
),
)
with self.down_a_level():
self.add_check(temp_var)

def match_data_or_class(self, tokens, item):
"""Matches an ambiguous data or class match."""
cls_name, matches = tokens
Expand All @@ -1071,13 +1094,13 @@ def match_data_or_class(self, tokens, item):
self.add_def(
handle_indentation(
"""
{is_data_result_var} = _coconut.getattr({cls_name}, "{is_data_var}", False) or _coconut.isinstance({cls_name}, _coconut.tuple) and _coconut.all(_coconut.getattr(_coconut_x, "{is_data_var}", False) for _coconut_x in {cls_name}){type_comment}
{is_data_result_var} = _coconut.getattr({cls_name}, "{is_data_var}", False) or _coconut.isinstance({cls_name}, _coconut.tuple) and _coconut.all(_coconut.getattr(_coconut_x, "{is_data_var}", False) for _coconut_x in {cls_name}){type_ignore}
""",
).format(
is_data_result_var=is_data_result_var,
is_data_var=is_data_var,
cls_name=cls_name,
type_comment=self.comp.type_ignore_comment(),
type_ignore=self.comp.type_ignore_comment(),
),
)

Expand Down
11 changes: 4 additions & 7 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -1521,11 +1521,7 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
that may or may not be an error, similar to Haskell's Either.

Effectively equivalent to:
data Expected[T](result: T?, error: BaseException?):
def __new__(cls, result: T?=None, error: BaseException?=None) -> Expected[T]:
if result is not None and error is not None:
raise TypeError("Expected cannot have both a result and an error")
return makedata(cls, result, error)
data Expected[T](result: T? = None, error: BaseException? = None):
def __bool__(self) -> bool:
return self.error is None
def __fmap__[U](self, func: T -> U) -> Expected[U]:
Expand Down Expand Up @@ -1556,8 +1552,10 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
raise self.error
return self.result
'''
_coconut_is_data = True
__slots__ = ()
{is_data_var} = True
__match_args__ = ("result", "error")
{data_defaults_var} = {lbrace}0: None, 1: None{rbrace}
def __add__(self, other): return _coconut.NotImplemented
def __mul__(self, other): return _coconut.NotImplemented
def __rmul__(self, other): return _coconut.NotImplemented
Expand All @@ -1566,7 +1564,6 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){
return self.__class__ is other.__class__ and _coconut.tuple.__eq__(self, other)
def __hash__(self):
return _coconut.tuple.__hash__(self) ^ hash(self.__class__)
__match_args__ = ("result", "error")
def __new__(cls, result=_coconut_sentinel, error=None):
if result is not _coconut_sentinel and error is not None:
raise _coconut.TypeError("Expected cannot have both a result and an error")
Expand Down
8 changes: 8 additions & 0 deletions coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,14 @@ def tuple_str_of_str(argstr, add_parens=True):
return out


def dict_to_str(inputdict, quote_keys=False, quote_values=False):
"""Convert a dictionary of code snippets to a dict literal."""
return "{" + ", ".join(
(repr(key) if quote_keys else str(key)) + ": " + (repr(value) if quote_values else str(value))
for key, value in ordered_items(inputdict)
) + "}"


def split_comment(line, move_indents=False):
"""Split line into base and comment."""
if move_indents:
Expand Down
2 changes: 2 additions & 0 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def get_bool_env_var(env_var, default=False):
format_var = reserved_prefix + "_format"
is_data_var = reserved_prefix + "_is_data"
custom_op_var = reserved_prefix + "_op"
is_data_var = reserved_prefix + "_is_data"
data_defaults_var = reserved_prefix + "_data_defaults"

# prefer Matcher.get_temp_var to proliferating more vars here
match_to_args_var = reserved_prefix + "_match_args"
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 43
DEVELOP = 44
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
33 changes: 33 additions & 0 deletions coconut/tests/src/cocotest/agnostic/main.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,10 @@ def main_test() -> bool:
assert_raises(Expected(error=TypeError()).unwrap, TypeError)
assert_raises(Expected(error=KeyboardInterrupt()).unwrap, KeyboardInterrupt)
assert Expected(10).or_else(const <| Expected(20)) == Expected(10) == Expected(error=TypeError()).or_else(const <| Expected(10))
Expected(x) = Expected(10)
assert x == 10
Expected(error=err) = Expected(error=some_err)
assert err is some_err

recit = ([1,2,3] :: recit) |> map$(.+1)
assert tee(recit)
Expand Down Expand Up @@ -1410,6 +1414,35 @@ def main_test() -> bool:
hardref = map((.+1), [1,2,3])
assert weakref.ref(hardref)() |> list == [2, 3, 4]
assert parallel_map(ident, [MatchError]) |> list == [MatchError]
match data tuple(1, 2) in (1, 2, 3):
assert False
data TestDefaultMatching(x="x default", y="y default")
TestDefaultMatching(got_x) = TestDefaultMatching(1)
assert got_x == 1
TestDefaultMatching(y=got_y) = TestDefaultMatching(y=10)
assert got_y == 10
TestDefaultMatching() = TestDefaultMatching()
data HasStar(x, y, *zs)
HasStar(x, *ys) = HasStar(1, 2, 3, 4)
assert x == 1
assert ys == (2, 3, 4)
HasStar(x, y, z) = HasStar(1, 2, 3)
assert (x, y, z) == (1, 2, 3)
HasStar(5, y=10) = HasStar(5, 10)
HasStar(1, 2, 3, zs=(3,)) = HasStar(1, 2, 3)
HasStar(x=1, y=2) = HasStar(1, 2)
match HasStar(x) in HasStar(1, 2):
assert False
match HasStar(x, y) in HasStar(1, 2, 3):
assert False
data HasStarAndDef(x, y="y", *zs)
HasStarAndDef(1, "y") = HasStarAndDef(1)
HasStarAndDef(1) = HasStarAndDef(1)
HasStarAndDef(x=1) = HasStarAndDef(1)
HasStarAndDef(1, 2, 3, zs=(3,)) = HasStarAndDef(1, 2, 3)
HasStarAndDef(1, y=2) = HasStarAndDef(1, 2)
match HasStarAndDef(x, y) in HasStarAndDef(1, 2, 3):
assert False
return True

def test_asyncio() -> bool:
Expand Down
2 changes: 1 addition & 1 deletion coconut/tests/src/cocotest/agnostic/suite.coco
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def suite_test() -> bool:
else:
assert False
assert vector.__match_args__ == ("x", "y") == typed_vector.__match_args__ # type: ignore
assert Pred.__match_args__ == ("name", "args") == Pred_.__match_args__ # type: ignore
assert Pred.__match_args__ == ("name",) == Pred_.__match_args__ # type: ignore
m = Matchable(1, 2, 3)
class Matchable(newx, newy, newz) = m
assert (newx, newy, newz) == (1, 2, 3)
Expand Down

0 comments on commit a137c16

Please sign in to comment.