Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure equal constructed and parsed keyword objects #94

Merged
merged 8 commits into from
May 31, 2023
19 changes: 19 additions & 0 deletions src/daidepp/keywords/base_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class MTO(_DAIDEObject):
unit: Unit
location: Location

def __post_init__(self):
if isinstance(self.location, str):
object.__setattr__(self, "location", Location(province=self.location))
super().__post_init__()

mjspeck marked this conversation as resolved.
Show resolved Hide resolved
def __str__(self):
return f"( {self.unit} ) MTO {self.location}"

Expand All @@ -82,6 +87,13 @@ class SUP(_DAIDEObject):
supported_unit: Unit
province_no_coast: Optional[ProvinceNoCoast] = None

def __post_init__(self):
if isinstance(self.province_no_coast, Location):
object.__setattr__(
self, "province_no_coast", self.province_no_coast.province
)
super().__post_init__()

@property
def unit(self) -> Unit:
"""Unit attribute to keep API consistent
Expand Down Expand Up @@ -150,6 +162,8 @@ def __init__(self, unit, province, *province_seas):
self.__post_init__()

def __post_init__(self):
if isinstance(self.province, str):
object.__setattr__(self, "province", Location(province=self.province))
super().__post_init__()
if not self.province_seas:
raise ValueError(
Expand All @@ -173,6 +187,11 @@ class RTO(_DAIDEObject):
unit: Unit
location: Location

def __post_init__(self):
if isinstance(self.location, str):
object.__setattr__(self, "location", Location(province=self.location))
super().__post_init__()

def __str__(self):
return f"( {self.unit} ) RTO {self.location}"

Expand Down
14 changes: 7 additions & 7 deletions src/daidepp/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def visit_wht(self, node, visited_children) -> WHT:
return WHT(unit)

def visit_how(self, node, visited_children) -> HOW:
_, _, province_power, _ = visited_children
_, _, province_power, _ = visited_children[0]
return HOW(province_power)

def visit_exp(self, node, visited_children) -> EXP:
Expand Down Expand Up @@ -139,7 +139,7 @@ def why_param(self, node, visited_children) -> Union[FCT, THK, PRP, INS]:

def visit_why(self, node, visited_children) -> WHY:
_, _, why_param, _ = visited_children
return WHY(why_param)
return WHY(why_param[0])

def visit_pob(self, node, visited_children) -> POB:
_, _, why, _ = visited_children
Expand Down Expand Up @@ -287,7 +287,7 @@ def visit_for(self, node, visited_children) -> FOR:
_, start_turn, _, _, end_turn, _ = turn
return FOR(start_turn, end_turn, arrangement)
else:
return FOR(start_turn, None, arrangement)
return FOR(turn, None, arrangement)

def visit_xoy(self, node, visited_children) -> XOY:
_, _, power_x, _, _, power_y, _ = visited_children
Expand Down Expand Up @@ -317,11 +317,11 @@ def visit_snd(self, node, visited_children) -> SND:
_,
) = visited_children

recv_power = [recv_power]
recv_powers = [recv_power]
for ws_recv_power in ws_recv_powers:
_, recv_power = ws_recv_power
recv_power.append(recv_power)
return SND(power, recv_power, message)
recv_powers.append(recv_power)
return SND(power, recv_powers, message)

def visit_fwd(self, node, visited_children) -> FWD:
_, _, power, ws_powers, _, _, power_1, _, _, power_2, _ = visited_children
Expand Down Expand Up @@ -368,7 +368,7 @@ def visit_sup(self, node, visited_children) -> SUP:
return SUP(supporting_unit, supported_unit)
else:
_, _, province_no_coast = ws_mto_prov = ws_province_no_coast[0]
return SUP(supporting_unit, supported_unit, province_no_coast)
return SUP(supporting_unit, supported_unit, province_no_coast.province)

def visit_cvy(self, node, visited_children) -> CVY:
_, convoying_unit, _, _, _, convoyed_unit, _, _, _, province = visited_children
Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@

from daidepp.grammar import create_daide_grammar
from daidepp.grammar.grammar import DAIDELevel
from daidepp.keywords.press_keywords import AnyDAIDEToken
from daidepp.visitor import daide_visitor

# Declared outside of fixture for performance
MAX_DAIDE_LEVEL = get_args(DAIDELevel)[-1]
ALL_GRAMMAR = create_daide_grammar(level=MAX_DAIDE_LEVEL, string_type="all")


@pytest.fixture
def daide_parser():
"""Helper function to ensure grammar and visitor construct correct DAIDE objects"""

def parse_daide(string: str) -> AnyDAIDEToken:
"""Parses a DAIDE string into `daidepp` objects.
:param string: String to parse into DAIDE.
:return: Parsed DAIDE object.
:raises ValueError: If string is invalid DAIDE.
"""
try:
parse_tree = ALL_GRAMMAR.parse(string)
return daide_visitor.visit(parse_tree)
except Exception as ex:
raise ValueError(f"Failed to parse DAIDE string: {string!r}") from ex

return parse_daide


@pytest.fixture(scope="session")
Expand Down
70 changes: 56 additions & 14 deletions tests/keywords/test_base_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
(("TUR", "FLT", "BRE"), "TUR FLT BRE"),
],
)
def test_Unit(input, expected_output):
print(input)
def test_Unit(input, expected_output, daide_parser):
unit = Unit(*input)
assert str(unit) == expected_output
assert unit == daide_parser(expected_output)
hash(unit)


@pytest.mark.parametrize(
Expand All @@ -28,21 +29,27 @@ def test_Unit(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) HLD"),
],
)
def test_HLD(input, expected_output):
def test_HLD(input, expected_output, daide_parser):
hld = HLD(*input)
assert str(hld) == expected_output
assert hld == daide_parser(expected_output)
hash(hld)


@pytest.mark.parametrize(
["input", "expected_output"],
[
((Unit("AUS", "FLT", "ALB"), "BUL"), "( AUS FLT ALB ) MTO BUL"),
((Unit("AUS", "FLT", "ALB"), Location("BUL")), "( AUS FLT ALB ) MTO BUL"),
((Unit("ENG", "AMY", "ANK"), "CLY"), "( ENG AMY ANK ) MTO CLY"),
((Unit("ENG", "AMY", "ANK"), Location("CLY")), "( ENG AMY ANK ) MTO CLY"),
],
)
def test_MTO(input, expected_output):
def test_MTO(input, expected_output, daide_parser):
mto = MTO(*input)
assert str(mto) == expected_output
assert mto == daide_parser(expected_output)
hash(mto)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -74,9 +81,11 @@ def test_MTO(input, expected_output):
),
],
)
def test_SUP(input, expected_output):
def test_SUP(input, expected_output, daide_parser):
sup = SUP(*input)
assert str(sup) == expected_output
assert sup == daide_parser(expected_output)
hash(sup)


@pytest.mark.parametrize(
Expand All @@ -94,13 +103,16 @@ def test_SUP(input, expected_output):
),
],
)
def test_SUP_location(supporting_unit, supported_unit, province_no_coast):
def test_SUP_location(supporting_unit, supported_unit, province_no_coast, daide_parser):
sup = SUP(
supported_unit=supported_unit,
supporting_unit=supporting_unit,
province_no_coast=province_no_coast,
)
assert isinstance(sup.province_no_coast, str)
assert isinstance(sup.province_no_coast_location, Location)
assert sup == daide_parser(str(sup))
hash(sup)


@pytest.mark.parametrize(
Expand All @@ -116,9 +128,11 @@ def test_SUP_location(supporting_unit, supported_unit, province_no_coast):
),
],
)
def test_CVY(input, expected_output):
def test_CVY(input, expected_output, daide_parser):
cvy = CVY(*input)
assert str(cvy) == expected_output
assert cvy == daide_parser(expected_output)
hash(cvy)


@pytest.mark.parametrize(
Expand All @@ -136,23 +150,41 @@ def test_CVY(input, expected_output):
(Unit("FRA", "FLT", "APU"), "CON", "ADR", "AEG", "BAL"),
"( FRA FLT APU ) CTO CON VIA ( ADR AEG BAL )",
),
(
(Unit("AUS", "FLT", "ALB"), Location("BUL"), "ADR"),
"( AUS FLT ALB ) CTO BUL VIA ( ADR )",
),
(
(Unit("ENG", "AMY", "ANK"), Location("CLY"), "ADR", "AEG"),
"( ENG AMY ANK ) CTO CLY VIA ( ADR AEG )",
),
(
(Unit("FRA", "FLT", "APU"), Location("CON"), "ADR", "AEG", "BAL"),
"( FRA FLT APU ) CTO CON VIA ( ADR AEG BAL )",
),
],
)
def test_MoveByCVY(input, expected_output):
def test_MoveByCVY(input, expected_output, daide_parser):
mvc = MoveByCVY(*input)
assert str(mvc) == expected_output
assert mvc == daide_parser(expected_output)
hash(mvc)


@pytest.mark.parametrize(
["input", "expected_output"],
[
((Unit("AUS", "FLT", "ALB"), "BUL"), "( AUS FLT ALB ) RTO BUL"),
((Unit("AUS", "FLT", "ALB"), Location("BUL")), "( AUS FLT ALB ) RTO BUL"),
((Unit("ENG", "AMY", "ANK"), "CLY"), "( ENG AMY ANK ) RTO CLY"),
((Unit("ENG", "AMY", "ANK"), Location("CLY")), "( ENG AMY ANK ) RTO CLY"),
],
)
def test_RTO(input, expected_output):
def test_RTO(input, expected_output, daide_parser):
rto = RTO(*input)
assert str(rto) == expected_output
assert rto == daide_parser(expected_output)
hash(rto)


@pytest.mark.parametrize(
Expand All @@ -162,9 +194,11 @@ def test_RTO(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) DSB"),
],
)
def test_DSB(input, expected_output):
def test_DSB(input, expected_output, daide_parser):
dsb = DSB(*input)
assert str(dsb) == expected_output
assert dsb == daide_parser(expected_output)
hash(dsb)


@pytest.mark.parametrize(
Expand All @@ -174,9 +208,11 @@ def test_DSB(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) BLD"),
],
)
def test_BLD(input, expected_output):
def test_BLD(input, expected_output, daide_parser):
bld = BLD(*input)
assert str(bld) == expected_output
assert bld == daide_parser(expected_output)
hash(bld)


@pytest.mark.parametrize(
Expand All @@ -186,9 +222,11 @@ def test_BLD(input, expected_output):
((Unit("ENG", "AMY", "ANK"),), "( ENG AMY ANK ) REM"),
],
)
def test_REM(input, expected_output):
def test_REM(input, expected_output, daide_parser):
rem = REM(*input)
assert str(rem) == expected_output
assert rem == daide_parser(expected_output)
hash(rem)


@pytest.mark.parametrize(
Expand All @@ -198,9 +236,11 @@ def test_REM(input, expected_output):
(("ENG",), "ENG WVE"),
],
)
def test_WVE(input, expected_output):
def test_WVE(input, expected_output, daide_parser):
wve = WVE(*input)
assert str(wve) == expected_output
assert wve == daide_parser(expected_output)
hash(wve)


@pytest.mark.parametrize(
Expand All @@ -209,6 +249,8 @@ def test_WVE(input, expected_output):
(("SPR", 1901), "SPR 1901"),
],
)
def test_turn(input, expected_output):
def test_turn(input, expected_output, daide_parser):
turn_1 = Turn(*input)
assert str(turn_1) == expected_output
assert turn_1 == daide_parser(expected_output)
hash(turn_1)
Loading