Skip to content

Commit

Permalink
[red-knot] Add type inference for comprehension targets
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Sep 9, 2024
1 parent b04948f commit ef4caf1
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ where
iterable: &node.iter,
target: name_node,
first,
is_async: node.is_async,
},
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) first: bool,
pub(crate) is_async: bool,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -227,10 +228,12 @@ impl DefinitionNodeRef<'_> {
iterable,
target,
first,
is_async,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
is_async,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
Expand Down Expand Up @@ -337,6 +340,7 @@ pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
}

impl ComprehensionDefinitionKind {
Expand All @@ -351,6 +355,10 @@ impl ComprehensionDefinitionKind {
pub(crate) fn is_first(&self) -> bool {
self.first
}

pub(crate) fn is_async(&self) -> bool {
self.is_async
}
}

#[derive(Clone, Debug)]
Expand Down
269 changes: 255 additions & 14 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ impl<'db> TypeInferenceBuilder<'db> {
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
comprehension.is_async(),
definition,
);
}
Expand Down Expand Up @@ -1444,7 +1445,7 @@ impl<'db> TypeInferenceBuilder<'db> {

let expr_id = expression.scoped_ast_id(self.db, self.scope);
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
assert_eq!(previous, None);

ty
}
Expand Down Expand Up @@ -1747,22 +1748,40 @@ impl<'db> TypeInferenceBuilder<'db> {
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
is_async: bool,
definition: Definition<'db>,
) {
if !is_first {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);

// Two things are different if it's the first comprehension:
// (1) We must lookup the type of the expression in the outer scope
// (symbols defined in this scope are not available for definitions
// created by the first comprehension)
// (2) We must *not* call `self.extend()` on the result of the type inference,
// or we'll doubly infer the type of releveant symbols
// (any symbols referenced from the first comprehension will have already been
// inferred when analysing outer scopes)
let iterable_ty = if is_first {
let lookup_scope = self
.index
.parent_scope_id(self.scope.file_scope_id(self.db))
.expect("A comprehension should never be the top-level scope")
.to_scope_id(self.db, self.file);
result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope))
} else {
self.extend(result);
let _iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
}
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
// enclosing scope.
result.expression_ty(iterable.scoped_ast_id(self.db, self.scope))
};

// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
// similar to how it's done in `infer_for_statement_definition`.
let target_ty = Type::Unknown;
let target_ty = if is_async {
// TODO: async iterables/iterators! -- Alex
Type::Unknown
} else {
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self)
};

self.types
.expressions
Expand Down Expand Up @@ -4191,7 +4210,6 @@ mod tests {
",
)?;

// TODO(Alex) async iterables/iterators!
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");

Ok(())
Expand Down Expand Up @@ -4326,6 +4344,229 @@ mod tests {
Ok(())
}

#[test]
fn basic_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def foo():
[x for y in IterableOfIterables() for x in y]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
class IteratorOfIterables:
def __next__(self) -> IntIterable:
return IntIterable()
class IterableOfIterables:
def __iter__(self) -> IteratorOfIterables:
return IteratorOfIterables()
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "IntIterable");

Ok(())
}

#[test]
fn comprehension_inside_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def foo():
[[x for x in iter1] for y in iter2]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
iter1 = IntIterable()
iter2 = IntIterable()
",
)?;

assert_scope_ty(
&db,
"src/a.py",
&["foo", "<listcomp>", "<listcomp>"],
"x",
"int",
);
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "y", "int");

Ok(())
}

#[test]
fn comprehension_with_not_iterable_iter() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
[z for z in x]
",
)?;

assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "x", "Unbound");

// Iterating over an `Unbound` yields `Unknown`:
assert_scope_ty(&db, "src/a.py", &["<listcomp>"], "z", "Unknown");

// TODO: not the greatest error message in the world! --Alex
assert_file_diagnostics(
&db,
"src/a.py",
&["Object of type 'Unbound' is not iterable"],
);

Ok(())
}

#[test]
fn comprehension_with_not_iterable_iter_in_second_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def foo():
[z for x in IntIterable() for z in x]
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "int");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unknown");
assert_file_diagnostics(&db, "src/a.py", &["Object of type 'int' is not iterable"]);

Ok(())
}

#[test]
fn dict_comprehension_variable_key() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def foo():
{x: 0 for x in IntIterable()}
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
assert_file_diagnostics(&db, "src/a.py", &[]);

Ok(())
}

#[test]
fn dict_comprehension_variable_value() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def foo():
{0: x for x in IntIterable()}
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<dictcomp>"], "x", "int");
assert_file_diagnostics(&db, "src/a.py", &[]);

Ok(())
}

/// This tests that we understand that `async` comprehensions
/// do not work according to the synchronous iteration protocol
#[test]
fn invalid_async_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
async def foo():
[x async for x in Iterable()]
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");

Ok(())
}

#[test]
fn basic_async_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
async def foo():
[x async for x in AsyncIterable()]
class AsyncIterator:
async def __anext__(self) -> int:
return 42
class AsyncIterable:
def __aiter__(self) -> AsyncIterator:
return AsyncIterator()
",
)?;

// TODO async iterables/iterators! --Alex
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "x", "Unknown");

Ok(())
}

#[test]
fn invalid_iterable() {
let mut db = setup_db();
Expand Down

0 comments on commit ef4caf1

Please sign in to comment.