Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Emit a diagnostic if the value of a starred expression or a yield from expression is not iterable #13240

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 89 additions & 20 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,27 +401,35 @@ impl<'db> Type<'db> {
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db);

let dunder_iter_method = type_of_class.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
}
/// Emits a diagnostic and returns `Unknown` if `self` represents a type that is not iterable.
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
fn iterate(&self, db: &'db dyn Db, mut diagnostic_fn: impl FnMut()) -> Type<'db> {
MichaReiser marked this conversation as resolved.
Show resolved Hide resolved
fn loop_var_from_iterable<'db>(
db: &'db dyn Db,
iterable_ty: &Type<'db>,
) -> Option<Type<'db>> {
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let iterable_meta_type = iterable_ty.to_meta_type(db);

let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
}

// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
dunder_get_item_method.call(db)
// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
dunder_get_item_method.call(db)
}
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
loop_var_from_iterable(db, self).unwrap_or_else(|| {
diagnostic_fn();
Type::Unknown
})
}

#[must_use]
Expand Down Expand Up @@ -789,4 +797,65 @@ mod tests {
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn starred_expressions_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass

class Iterator:
def __next__(self) -> int:
return 42

class Iterable:
def __iter__(self) -> Iterator:

x = [*NotIterable()]
y = [*Iterable()]
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn yield_from_expression_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass

class Iterator:
def __next__(self) -> int:
return 42

class Iterable:
def __iter__(self) -> Iterator:

def generator_function():
yield from Iterable()
yield from NotIterable()
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}
36 changes: 23 additions & 13 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse);
}

/// Emit a diagnostic declaring that the object represented by `node` is not iterable
fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, iterable_ty: Type<'db>) {
self.add_diagnostic(
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand All @@ -1042,16 +1054,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
self.add_diagnostic(
iterable.into(),
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
let loop_var_value_ty = iterable_ty.iterate(self.db, || {
self.not_iterable_diagnostic(iterable.into(), iterable_ty);
});

self.types
Expand Down Expand Up @@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = starred;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty.iterate(self.db, || {
self.not_iterable_diagnostic(value.as_ref().into(), iterable_ty);
});

// TODO
Type::Unknown
Expand All @@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
let ast::ExprYieldFrom { range: _, value } = yield_from;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty.iterate(self.db, || {
self.not_iterable_diagnostic(value.as_ref().into(), iterable_ty);
});

// TODO get type from awaitable
// TODO get type from `SendType` of generator/awaitable
Type::Unknown
}

Expand Down
Loading