diff --git a/pyt/cfg/stmt_visitor.py b/pyt/cfg/stmt_visitor.py index de8e396e..00a0fb2f 100644 --- a/pyt/cfg/stmt_visitor.py +++ b/pyt/cfg/stmt_visitor.py @@ -11,6 +11,7 @@ ) from ..core.ast_helper import ( generate_ast, + get_call_names, get_call_names_as_string ) from ..core.module_definitions import ( @@ -472,14 +473,6 @@ def assignment_call_node(self, left_hand_label, ast_node): call = self.visit(ast_node.value) call_label = call.left_hand_side - if isinstance(call, BBorBInode): - # Necessary to know e.g. - # `image_name = image_name.replace('..', '')` - # is a reassignment. - vars_visitor = VarsVisitor() - vars_visitor.visit(ast_node.value) - call.right_hand_side_variables.extend(vars_visitor.result) - call_assignment = AssignmentCallNode( left_hand_label + ' = ' + call_label, left_hand_label, @@ -572,7 +565,7 @@ def visit_While(self, node): return self.loop_node_skeleton(test, node) - def add_blackbox_or_builtin_call(self, node, blackbox): + def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901 """Processes a blackbox or builtin function when it is called. Nothing gets assigned to ret_func_foo in the builtin/blackbox case. @@ -597,14 +590,14 @@ def add_blackbox_or_builtin_call(self, node, blackbox): saved_function_call_index = self.function_call_index self.undecided = False - call_label = LabelVisitor() - call_label.visit(node) + call_label_visitor = LabelVisitor() + call_label_visitor.visit(node) - index = call_label.result.find('(') + call_function_label = call_label_visitor.result[:call_label_visitor.result.find('(')] # Create e.g. ~call_1 = ret_func_foo LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index) - RHS = 'ret_' + call_label.result[:index] + '(' + RHS = 'ret_' + call_function_label + '(' call_node = BBorBInode( label='', @@ -613,7 +606,7 @@ def add_blackbox_or_builtin_call(self, node, blackbox): right_hand_side_variables=[], line_number=node.lineno, path=self.filenames[-1], - func_name=call_label.result[:index] + func_name=call_function_label ) visual_args = list() rhs_vars = list() @@ -657,6 +650,11 @@ def add_blackbox_or_builtin_call(self, node, blackbox): # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))` last_return_value_of_nested_call.connect(call_node) + call_names = list(get_call_names(node.func)) + if len(call_names) > 1: + # taint is a RHS variable (self) of taint.lower() + rhs_vars.append(call_names[0]) + if len(visual_args) > 0: for arg in visual_args: RHS = RHS + arg + ", " @@ -667,7 +665,6 @@ def add_blackbox_or_builtin_call(self, node, blackbox): call_node.label = LHS + " = " + RHS call_node.right_hand_side_variables = rhs_vars - # Used in get_sink_args, not using right_hand_side_variables because it is extended in assignment_call_node rhs_visitor = RHSVisitor() rhs_visitor.visit(node) call_node.args = rhs_visitor.result diff --git a/tests/vulnerabilities/vulnerabilities_test.py b/tests/vulnerabilities/vulnerabilities_test.py index ff48c38c..599ea4aa 100644 --- a/tests/vulnerabilities/vulnerabilities_test.py +++ b/tests/vulnerabilities/vulnerabilities_test.py @@ -111,8 +111,9 @@ def test_build_sanitiser_node_dict(self): self.assertEqual(sanitiser_dict['escape'][0], cfg.nodes[3]) - def run_analysis(self, path): - self.cfg_create_from_file(path) + def run_analysis(self, path=None): + if path: + self.cfg_create_from_file(path) cfg_list = [self.cfg] FrameworkAdaptor(cfg_list, [], [], is_flask_route_function) @@ -468,6 +469,20 @@ def test_yield(self): self.assertAlphaEqual(str(vuln), EXPECTED_VULNERABILITY_DESCRIPTION) + def test_method_of_taint(self): + def assert_vulnerable(fixture): + tree = ast.parse('TAINT = request.args.get("")\n' + fixture + '\nexecute(result)') + self.cfg_create_from_ast(tree) + vulnerabilities = self.run_analysis() + self.assert_length(vulnerabilities, expected_length=1, msg=fixture) + + assert_vulnerable('result = TAINT') + assert_vulnerable('result = TAINT.lower()') + assert_vulnerable('result = str(TAINT)') + assert_vulnerable('result = str(TAINT.lower())') + assert_vulnerable('result = repr(str("%s" % TAINT.lower().upper()))') + assert_vulnerable('result = repr(str("{}".format(TAINT.lower())))') + class EngineDjangoTest(VulnerabilitiesBaseTestCase): def run_analysis(self, path):