diff --git a/docs/changelog.md b/docs/changelog.md index 3f2c1d8e..a168ea10 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,9 @@ ## Unreleased +- Add a mechanism to allow overriding the global variables in an + analyzed module. Use this mechanism to set the type of + `qcore.testing.Anything` to `Any`. (#786) - Rename the `is_compatible` and `get_compatibility_error` functions to `is_assignable` and `get_assignability_error` to align with the terminology in the typing spec (#785) diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 855b14ae..0a9e02d2 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -54,6 +54,7 @@ import asynq import qcore import typeshed_client +from qcore.testing import Anything from typing_extensions import Annotated, Protocol, get_args, get_origin from . import attributes, format_strings, importer, node_visitor, type_evaluation @@ -627,6 +628,19 @@ def should_check_for_duplicate_values(cls: object, options: Options) -> bool: return True +def _anything_to_any(obj: object) -> Optional[Value]: + if obj is Anything: + return AnyValue(AnySource.explicit) + return None + + +class TransformGlobals(PyObjectSequenceOption[Callable[[object], Optional[Value]]]): + """Transform global variables.""" + + name = "transform_globals" + default_value = [_anything_to_any] + + class IgnoredTypesForAttributeChecking(PyObjectSequenceOption[type]): """Used in the check for object attributes that are accessed but not set. In general, the check will only alert about attributes that don't exist when it has visited all the base classes of @@ -1181,6 +1195,7 @@ def __init__( self.scopes = build_stacked_scopes( self.module, simplification_limit=self.options.get_value_for(UnionSimplificationLimit), + options=self.options, ) self.node_context = StackedContexts() self.asynq_checker = AsynqChecker( @@ -5916,7 +5931,10 @@ def visit_expression(self, node: ast.AST) -> Value: def build_stacked_scopes( - module: Optional[types.ModuleType], simplification_limit: Optional[int] = None + module: Optional[types.ModuleType], + simplification_limit: Optional[int] = None, + *, + options: Options, ) -> StackedScopes: # Build a StackedScopes object. # Not part of stacked_scopes.py to avoid a circular dependency. @@ -5928,7 +5946,12 @@ def build_stacked_scopes( for key, value in module.__dict__.items(): val = type_from_annotations(annotations, key, globals=module.__dict__) if val is None: - val = KnownValue(value) + for transformer in options.get_value_for(TransformGlobals): + maybe_val = transformer(value) + if maybe_val is not None: + val = maybe_val + if val is None: + val = KnownValue(value) module_vars[key] = val return StackedScopes(module_vars, module, simplification_limit=simplification_limit) diff --git a/pyanalyze/test_name_check_visitor.py b/pyanalyze/test_name_check_visitor.py index 60618e65..c74ed2b7 100644 --- a/pyanalyze/test_name_check_visitor.py +++ b/pyanalyze/test_name_check_visitor.py @@ -921,6 +921,13 @@ def capybara(foo): assert_is_value(assert_eq, KnownValue(_assert_eq)) + @assert_passes() + def test_transform_globals(self): + from qcore.testing import Anything + + def f(): + assert_is_value(Anything, AnyValue(AnySource.explicit)) + class TestComprehensions(TestNameCheckVisitorBase): @assert_passes() diff --git a/pyanalyze/test_stacked_scopes.py b/pyanalyze/test_stacked_scopes.py index 57da5099..f69cf1b4 100644 --- a/pyanalyze/test_stacked_scopes.py +++ b/pyanalyze/test_stacked_scopes.py @@ -1,6 +1,7 @@ # static analysis: ignore from .error_code import ErrorCode from .name_check_visitor import build_stacked_scopes +from .options import Options from .stacked_scopes import ScopeType, uniq_chain from .test_name_check_visitor import TestNameCheckVisitorBase from .test_node_visitor import assert_passes, skip_before @@ -29,7 +30,7 @@ class Module: class TestStackedScopes: def setup_method(self): - self.scopes = build_stacked_scopes(Module) + self.scopes = build_stacked_scopes(Module, options=Options({})) def test_scope_type(self): assert ScopeType.module_scope == self.scopes.scope_type()