Skip to content

Commit

Permalink
Merge pull request #94 from aphedges/fix-parse-to-object
Browse files Browse the repository at this point in the history
Ensure equal constructed and parsed keyword objects
  • Loading branch information
mjspeck authored May 31, 2023
2 parents 483cc14 + 3d53dd8 commit 1cfc75a
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 71 deletions.
19 changes: 19 additions & 0 deletions src/daidepp/keywords/base_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class MTO(_DAIDEObject):
unit: Unit
location: Location

def __post_init__(self):
if isinstance(self.location, str):
object.__setattr__(self, "location", Location(province=self.location))
super().__post_init__()

def __str__(self):
return f"( {self.unit} ) MTO {self.location}"

Expand All @@ -82,6 +87,13 @@ class SUP(_DAIDEObject):
supported_unit: Unit
province_no_coast: Optional[ProvinceNoCoast] = None

def __post_init__(self):
if isinstance(self.province_no_coast, Location):
object.__setattr__(
self, "province_no_coast", self.province_no_coast.province
)
super().__post_init__()

@property
def unit(self) -> Unit:
"""Unit attribute to keep API consistent
Expand Down Expand Up @@ -150,6 +162,8 @@ def __init__(self, unit, province, *province_seas):
self.__post_init__()

def __post_init__(self):
if isinstance(self.province, str):
object.__setattr__(self, "province", Location(province=self.province))
super().__post_init__()
if not self.province_seas:
raise ValueError(
Expand All @@ -173,6 +187,11 @@ class RTO(_DAIDEObject):
unit: Unit
location: Location

def __post_init__(self):
if isinstance(self.location, str):
object.__setattr__(self, "location", Location(province=self.location))
super().__post_init__()

def __str__(self):
return f"( {self.unit} ) RTO {self.location}"

Expand Down
14 changes: 7 additions & 7 deletions src/daidepp/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def visit_wht(self, node, visited_children) -> WHT:
return WHT(unit)

def visit_how(self, node, visited_children) -> HOW:
_, _, province_power, _ = visited_children
_, _, province_power, _ = visited_children[0]
return HOW(province_power)

def visit_exp(self, node, visited_children) -> EXP:
Expand Down Expand Up @@ -139,7 +139,7 @@ def why_param(self, node, visited_children) -> Union[FCT, THK, PRP, INS]:

def visit_why(self, node, visited_children) -> WHY:
_, _, why_param, _ = visited_children
return WHY(why_param)
return WHY(why_param[0])

def visit_pob(self, node, visited_children) -> POB:
_, _, why, _ = visited_children
Expand Down Expand Up @@ -287,7 +287,7 @@ def visit_for(self, node, visited_children) -> FOR:
_, start_turn, _, _, end_turn, _ = turn
return FOR(start_turn, end_turn, arrangement)
else:
return FOR(start_turn, None, arrangement)
return FOR(turn, None, arrangement)

def visit_xoy(self, node, visited_children) -> XOY:
_, _, power_x, _, _, power_y, _ = visited_children
Expand Down Expand Up @@ -317,11 +317,11 @@ def visit_snd(self, node, visited_children) -> SND:
_,
) = visited_children

recv_power = [recv_power]
recv_powers = [recv_power]
for ws_recv_power in ws_recv_powers:
_, recv_power = ws_recv_power
recv_power.append(recv_power)
return SND(power, recv_power, message)
recv_powers.append(recv_power)
return SND(power, recv_powers, message)

def visit_fwd(self, node, visited_children) -> FWD:
_, _, power, ws_powers, _, _, power_1, _, _, power_2, _ = visited_children
Expand Down Expand Up @@ -368,7 +368,7 @@ def visit_sup(self, node, visited_children) -> SUP:
return SUP(supporting_unit, supported_unit)
else:
_, _, province_no_coast = ws_mto_prov = ws_province_no_coast[0]
return SUP(supporting_unit, supported_unit, province_no_coast)
return SUP(supporting_unit, supported_unit, province_no_coast.province)

def visit_cvy(self, node, visited_children) -> CVY:
_, convoying_unit, _, _, _, convoyed_unit, _, _, _, province = visited_children
Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@

from daidepp.grammar import create_daide_grammar
from daidepp.grammar.grammar import DAIDELevel
from daidepp.keywords.press_keywords import AnyDAIDEToken
from daidepp.visitor import daide_visitor

# Declared outside of fixture for performance
MAX_DAIDE_LEVEL = get_args(DAIDELevel)[-1]
ALL_GRAMMAR = create_daide_grammar(level=MAX_DAIDE_LEVEL, string_type="all")


@pytest.fixture
def daide_parser():
"""Helper function to ensure grammar and visitor construct correct DAIDE objects"""

def parse_daide(string: str) -> AnyDAIDEToken:
"""Parses a DAIDE string into `daidepp` objects.
:param string: String to parse into DAIDE.
:return: Parsed DAIDE object.
:raises ValueError: If string is invalid DAIDE.
"""
try:
parse_tree = ALL_GRAMMAR.parse(string)
return daide_visitor.visit(parse_tree)
except Exception as ex:
raise ValueError(f"Failed to parse DAIDE string: {string!r}") from ex

return parse_daide


@pytest.fixture(scope="session")
Expand Down
70 changes: 56 additions & 14 deletions tests/keywords/test_base_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
(("TUR", "FLT", "BRE"), "TUR FLT BRE"),
],
)
def test_Unit(input, expected_output):
print(input)
def test_Unit(input, expected_output, daide_parser):
unit = Unit(*input)
assert str(unit) == expected_output
assert unit == daide_parser(expected_output)
hash(unit)


@pytest.mark.parametrize(
Expand All @@ -28,21 +29,27 @@ def test_Unit(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) HLD"),
],
)
def test_HLD(input, expected_output):
def test_HLD(input, expected_output, daide_parser):
hld = HLD(*input)
assert str(hld) == expected_output
assert hld == daide_parser(expected_output)
hash(hld)


@pytest.mark.parametrize(
["input", "expected_output"],
[
((Unit("AUS", "FLT", "ALB"), "BUL"), "( AUS FLT ALB ) MTO BUL"),
((Unit("AUS", "FLT", "ALB"), Location("BUL")), "( AUS FLT ALB ) MTO BUL"),
((Unit("ENG", "AMY", "ANK"), "CLY"), "( ENG AMY ANK ) MTO CLY"),
((Unit("ENG", "AMY", "ANK"), Location("CLY")), "( ENG AMY ANK ) MTO CLY"),
],
)
def test_MTO(input, expected_output):
def test_MTO(input, expected_output, daide_parser):
mto = MTO(*input)
assert str(mto) == expected_output
assert mto == daide_parser(expected_output)
hash(mto)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -74,9 +81,11 @@ def test_MTO(input, expected_output):
),
],
)
def test_SUP(input, expected_output):
def test_SUP(input, expected_output, daide_parser):
sup = SUP(*input)
assert str(sup) == expected_output
assert sup == daide_parser(expected_output)
hash(sup)


@pytest.mark.parametrize(
Expand All @@ -94,13 +103,16 @@ def test_SUP(input, expected_output):
),
],
)
def test_SUP_location(supporting_unit, supported_unit, province_no_coast):
def test_SUP_location(supporting_unit, supported_unit, province_no_coast, daide_parser):
sup = SUP(
supported_unit=supported_unit,
supporting_unit=supporting_unit,
province_no_coast=province_no_coast,
)
assert isinstance(sup.province_no_coast, str)
assert isinstance(sup.province_no_coast_location, Location)
assert sup == daide_parser(str(sup))
hash(sup)


@pytest.mark.parametrize(
Expand All @@ -116,9 +128,11 @@ def test_SUP_location(supporting_unit, supported_unit, province_no_coast):
),
],
)
def test_CVY(input, expected_output):
def test_CVY(input, expected_output, daide_parser):
cvy = CVY(*input)
assert str(cvy) == expected_output
assert cvy == daide_parser(expected_output)
hash(cvy)


@pytest.mark.parametrize(
Expand All @@ -136,23 +150,41 @@ def test_CVY(input, expected_output):
(Unit("FRA", "FLT", "APU"), "CON", "ADR", "AEG", "BAL"),
"( FRA FLT APU ) CTO CON VIA ( ADR AEG BAL )",
),
(
(Unit("AUS", "FLT", "ALB"), Location("BUL"), "ADR"),
"( AUS FLT ALB ) CTO BUL VIA ( ADR )",
),
(
(Unit("ENG", "AMY", "ANK"), Location("CLY"), "ADR", "AEG"),
"( ENG AMY ANK ) CTO CLY VIA ( ADR AEG )",
),
(
(Unit("FRA", "FLT", "APU"), Location("CON"), "ADR", "AEG", "BAL"),
"( FRA FLT APU ) CTO CON VIA ( ADR AEG BAL )",
),
],
)
def test_MoveByCVY(input, expected_output):
def test_MoveByCVY(input, expected_output, daide_parser):
mvc = MoveByCVY(*input)
assert str(mvc) == expected_output
assert mvc == daide_parser(expected_output)
hash(mvc)


@pytest.mark.parametrize(
["input", "expected_output"],
[
((Unit("AUS", "FLT", "ALB"), "BUL"), "( AUS FLT ALB ) RTO BUL"),
((Unit("AUS", "FLT", "ALB"), Location("BUL")), "( AUS FLT ALB ) RTO BUL"),
((Unit("ENG", "AMY", "ANK"), "CLY"), "( ENG AMY ANK ) RTO CLY"),
((Unit("ENG", "AMY", "ANK"), Location("CLY")), "( ENG AMY ANK ) RTO CLY"),
],
)
def test_RTO(input, expected_output):
def test_RTO(input, expected_output, daide_parser):
rto = RTO(*input)
assert str(rto) == expected_output
assert rto == daide_parser(expected_output)
hash(rto)


@pytest.mark.parametrize(
Expand All @@ -162,9 +194,11 @@ def test_RTO(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) DSB"),
],
)
def test_DSB(input, expected_output):
def test_DSB(input, expected_output, daide_parser):
dsb = DSB(*input)
assert str(dsb) == expected_output
assert dsb == daide_parser(expected_output)
hash(dsb)


@pytest.mark.parametrize(
Expand All @@ -174,9 +208,11 @@ def test_DSB(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) BLD"),
],
)
def test_BLD(input, expected_output):
def test_BLD(input, expected_output, daide_parser):
bld = BLD(*input)
assert str(bld) == expected_output
assert bld == daide_parser(expected_output)
hash(bld)


@pytest.mark.parametrize(
Expand All @@ -186,9 +222,11 @@ def test_BLD(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) REM"),
],
)
def test_REM(input, expected_output):
def test_REM(input, expected_output, daide_parser):
rem = REM(*input)
assert str(rem) == expected_output
assert rem == daide_parser(expected_output)
hash(rem)


@pytest.mark.parametrize(
Expand All @@ -198,9 +236,11 @@ def test_REM(input, expected_output):
(("ENG",), "ENG WVE"),
],
)
def test_WVE(input, expected_output):
def test_WVE(input, expected_output, daide_parser):
wve = WVE(*input)
assert str(wve) == expected_output
assert wve == daide_parser(expected_output)
hash(wve)


@pytest.mark.parametrize(
Expand All @@ -209,6 +249,8 @@ def test_WVE(input, expected_output):
(("SPR", 1901), "SPR 1901"),
],
)
def test_turn(input, expected_output):
def test_turn(input, expected_output, daide_parser):
turn_1 = Turn(*input)
assert str(turn_1) == expected_output
assert turn_1 == daide_parser(expected_output)
hash(turn_1)
Loading

0 comments on commit 1cfc75a

Please sign in to comment.