From fa6194c0a9c5907e2c8fe046009d3aea2b9d53a4 Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Mon, 18 May 2020 01:41:07 +0200 Subject: [PATCH] Refactor test_definition_start_end_position to use parametrize --- test/test_api/test_classes.py | 59 ++++++++++++++++------------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/test/test_api/test_classes.py b/test/test_api/test_classes.py index 24cf7f36f..c59750323 100644 --- a/test/test_api/test_classes.py +++ b/test/test_api/test_classes.py @@ -605,37 +605,32 @@ def test_pseudotreenameclass_type(Script): assert Script('from typing import Any\n').get_names()[0].type == 'class' -def test_definition_start_end_position(Script): - '''Tests for definition_start_position and definition_end_position''' - code = '\n'.join([ - 'def a_func():', - ' return "bar"', - '', - 'var1 = 12', - '', - 'class AClass:', - ' """my class"""', - ' @staticmethod', - ' def hello():', - ' func_var = 1', - ' return func_var', - ]) - script = Script(code=code) - names = script.get_names(all_scopes=True) - assert len(names) == 5 - a_func, var1, AClass, hello, func_var = names - - assert a_func.get_definition_start_position() == (1, 0) - assert a_func.get_definition_end_position() == (2, 16) - - assert var1.get_definition_start_position() == (4, 0) - assert var1.get_definition_end_position() == (4, 9) - - assert AClass.get_definition_start_position() == (6, 0) - assert AClass.get_definition_end_position() == (11, 23) +cls_code = '''\ +class AClass: + """my class""" + @staticmethod + def hello(): + func_var = 1 + return func_var +''' - assert hello.get_definition_start_position() == (9, 4) - assert hello.get_definition_end_position() == (11, 23) - assert func_var.get_definition_start_position() == (10, 8) - assert func_var.get_definition_end_position() == (10, 20) +@pytest.mark.parametrize( + 'code, pos, start, end', [ + ('def a_func():\n return "bar"\n', (1, 4), (1, 0), (2, 16)), + ('var1 = 12', (1, 0), (1, 0), (1, 9)), + ('class AClass: pass', (1, 6), (1, 0), (1, 18)), + ('class AClass: pass\n', (1, 6), (1, 0), (1, 18)), + (cls_code, (1, 6), (1, 0), (6, 23)), + (cls_code, (4, 8), (4, 4), (6, 23)), + (cls_code, (5, 8), (5, 8), (5, 20)), + ] +) +def test_definition_start_end_position(Script, code, pos, start, end): + '''Tests for definition_start_position and definition_end_position''' + name = next( + n for n in Script(code=code).get_names(all_scopes=True, references=True) + if n._name.tree_name.start_pos <= pos <= n._name.tree_name.end_pos + ) + assert name.get_definition_start_position() == start + assert name.get_definition_end_position() == end