diff --git a/tests/test_h3.py b/tests/test_h3.py index f9f8dcd4..8ac307d2 100644 --- a/tests/test_h3.py +++ b/tests/test_h3.py @@ -496,14 +496,19 @@ def cells_at_res(res): for parent in cells: yield from h3.cell_to_children(parent, res) - def roundtrip(children, res_parent): - for c in children: - parent = h3.cell_to_parent(c, res_parent) - pos = h3.cell_to_child_pos(res_parent, c) - yield h3.child_pos_to_cell(parent, res_child, pos) - - for res_parent in [0, 1]: - for res_child in [0, 1, 2, 3]: - if res_parent <= res_child: - children = set(cells_at_res(res_child)) - assert set(roundtrip(children, res_parent)) == children + def roundtrip(child, res_parent): + res_child = h3.get_resolution(child) + parent = h3.cell_to_parent(child, res_parent) + pos = h3.cell_to_child_pos(res_parent, child) + return h3.child_pos_to_cell(parent, res_child, pos) + + res_pairs = [ + (res_parent, res_child) + for res_parent in [0, 1] + for res_child in [0, 1, 2, 3] + if res_parent <= res_child + ] + + for res_parent, res_child in res_pairs: + for child in cells_at_res(res_child): + assert child == roundtrip(child, res_parent)