Skip to content

Commit

Permalink
Improve typevar solver for constrained TypeVars (#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Apr 27, 2022
1 parent 94ca131 commit 34f23c1
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Improve `TypeVar` solution heuristic for constrained
typevars with multiple solutions (#532)
- Fix resolution of stringified annotations in `__init__` methods (#530)
- Type check `yield`, `yield from`, and `return` nodes in generators (#529)
- Type check calls to comparison operators (#527)
Expand Down
5 changes: 3 additions & 2 deletions pyanalyze/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,10 @@ def capybara(s: str, b: bytes, sb: Union[str, bytes], unannotated):
assert_is_value(f(b"x"), TypedValue(bytes))
assert_is_value(f(s), TypedValue(str))
assert_is_value(f(b), TypedValue(bytes))
f(sb) # E: incompatible_argument
result = f(sb) # E: incompatible_argument
assert_is_value(result, AnyValue(AnySource.error))
f(3) # E: incompatible_argument
assert_is_value(f(unannotated), AnyValue(AnySource.inference))
assert_is_value(f(unannotated), AnyValue(AnySource.unannotated))

@assert_passes()
def test_constraint_in_typeshed(self):
Expand Down
13 changes: 11 additions & 2 deletions pyanalyze/test_typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ def capybara():
wrapped("x") # E: incompatible_argument


class TestOpen(TestNameCheckVisitorBase):
class TestIntegration(TestNameCheckVisitorBase):
@assert_passes()
def test_basic(self):
def test_open(self):
import io
from typing import Any, BinaryIO, IO

Expand All @@ -682,3 +682,12 @@ def capybara(buffering: int, mode: str):
assert_type(open("x", "rb", buffering=0), io.FileIO)
assert_type(open("x", "rb", buffering=buffering), BinaryIO)
assert_type(open("x", mode, buffering=buffering), IO[Any])

@assert_passes()
def test_itertools_count(self):
import itertools

def capybara():
assert_is_value(
itertools.count(1), GenericValue(itertools.count, [TypedValue(int)])
)
29 changes: 29 additions & 0 deletions pyanalyze/test_typevar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
AnnotatedValue,
AnySource,
AnyValue,
GenericValue,
KnownValue,
MultiValuedValue,
TypedValue,
Expand Down Expand Up @@ -260,6 +261,34 @@ def capybara():
m = min(E)
assert_is_value(m, TypedValue(E))

@assert_passes()
def test_constraints(self):
from typing import List, TypeVar

LT = TypeVar("LT", List[int], List[str])

def g(x: LT) -> LT:
return x

def pacarana() -> None:
assert_is_value(g([]), AnyValue(AnySource.inference))
assert_is_value(g([1]), GenericValue(list, [TypedValue(int)]))

@assert_passes()
def test_redundant_constraints(self):
from typing import TypeVar
from typing_extensions import SupportsIndex

T = TypeVar("T", int, float, SupportsIndex)

def f(x: T) -> T:
return x

def capybara(si: SupportsIndex):
assert_is_value(f(1), TypedValue(int))
assert_is_value(f(si), TypedValue(SupportsIndex))
assert_is_value(f(1.0), TypedValue(float))


class TestAnnotated(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
30 changes: 29 additions & 1 deletion pyanalyze/typevar.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,36 @@ def solve(
for option, can_assign in zip(options, can_assigns)
if not isinstance(can_assign, CanAssignError)
]
# If there's only one solution, pick it.
if len(available) == 1:
return available[0]
# TODO consider returning unite_values(*available) here instead.
# If we inferred Any, keep it; all the solutions will be valid, and
# picking one will lead to weird errors down the line.
if isinstance(solution, AnyValue):
return solution
available = remove_redundant_solutions(available, ctx)
if len(available) == 1:
return available[0]
# If there are still multiple options, we fall back to Any.
return AnyValue(AnySource.inference)
return solution


def remove_redundant_solutions(
solutions: Sequence[Value], ctx: CanAssignContext
) -> Sequence[Value]:
# This is going to be quadratic, so don't do it when there's too many
# opttions.
initial_count = len(solutions)
if initial_count > 10:
return solutions

temp_solutions = list(solutions)
for i in range(initial_count):
sol = temp_solutions[i]
for j, other in enumerate(temp_solutions):
if i == j or other is None:
continue
if sol.is_assignable(other, ctx) and not other.is_assignable(sol, ctx):
temp_solutions[i] = None
return [sol for sol in temp_solutions if sol is not None]

0 comments on commit 34f23c1

Please sign in to comment.