Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformGlobals config option #786

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Add a mechanism to allow overriding the global variables in an
analyzed module. Use this mechanism to set the type of
`qcore.testing.Anything` to `Any`. (#786)
- Rename the `is_compatible` and `get_compatibility_error` functions
to `is_assignable` and `get_assignability_error` to align with the
terminology in the typing spec (#785)
Expand Down
27 changes: 25 additions & 2 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import asynq
import qcore
import typeshed_client
from qcore.testing import Anything
from typing_extensions import Annotated, Protocol, get_args, get_origin

from . import attributes, format_strings, importer, node_visitor, type_evaluation
Expand Down Expand Up @@ -627,6 +628,19 @@ def should_check_for_duplicate_values(cls: object, options: Options) -> bool:
return True


def _anything_to_any(obj: object) -> Optional[Value]:
if obj is Anything:
return AnyValue(AnySource.explicit)
return None


class TransformGlobals(PyObjectSequenceOption[Callable[[object], Optional[Value]]]):
"""Transform global variables."""

name = "transform_globals"
default_value = [_anything_to_any]


class IgnoredTypesForAttributeChecking(PyObjectSequenceOption[type]):
"""Used in the check for object attributes that are accessed but not set. In general, the check
will only alert about attributes that don't exist when it has visited all the base classes of
Expand Down Expand Up @@ -1181,6 +1195,7 @@ def __init__(
self.scopes = build_stacked_scopes(
self.module,
simplification_limit=self.options.get_value_for(UnionSimplificationLimit),
options=self.options,
)
self.node_context = StackedContexts()
self.asynq_checker = AsynqChecker(
Expand Down Expand Up @@ -5916,7 +5931,10 @@ def visit_expression(self, node: ast.AST) -> Value:


def build_stacked_scopes(
module: Optional[types.ModuleType], simplification_limit: Optional[int] = None
module: Optional[types.ModuleType],
simplification_limit: Optional[int] = None,
*,
options: Options,
) -> StackedScopes:
# Build a StackedScopes object.
# Not part of stacked_scopes.py to avoid a circular dependency.
Expand All @@ -5928,7 +5946,12 @@ def build_stacked_scopes(
for key, value in module.__dict__.items():
val = type_from_annotations(annotations, key, globals=module.__dict__)
if val is None:
val = KnownValue(value)
for transformer in options.get_value_for(TransformGlobals):
maybe_val = transformer(value)
if maybe_val is not None:
val = maybe_val
if val is None:
val = KnownValue(value)
module_vars[key] = val
return StackedScopes(module_vars, module, simplification_limit=simplification_limit)

Expand Down
7 changes: 7 additions & 0 deletions pyanalyze/test_name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,13 @@ def capybara(foo):

assert_is_value(assert_eq, KnownValue(_assert_eq))

@assert_passes()
def test_transform_globals(self):
from qcore.testing import Anything

def f():
assert_is_value(Anything, AnyValue(AnySource.explicit))


class TestComprehensions(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
3 changes: 2 additions & 1 deletion pyanalyze/test_stacked_scopes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# static analysis: ignore
from .error_code import ErrorCode
from .name_check_visitor import build_stacked_scopes
from .options import Options
from .stacked_scopes import ScopeType, uniq_chain
from .test_name_check_visitor import TestNameCheckVisitorBase
from .test_node_visitor import assert_passes, skip_before
Expand Down Expand Up @@ -29,7 +30,7 @@ class Module:

class TestStackedScopes:
def setup_method(self):
self.scopes = build_stacked_scopes(Module)
self.scopes = build_stacked_scopes(Module, options=Options({}))

def test_scope_type(self):
assert ScopeType.module_scope == self.scopes.scope_type()
Expand Down
Loading