Skip to content

Commit

Permalink
[red-knot] implement basic call expression inference (astral-sh#13164)
Browse files Browse the repository at this point in the history
## Summary

Adds basic support for inferring the type resulting from a call
expression. This only works for the *result* of call expressions; it
performs no inference on parameters. It also intentionally does nothing
with class instantiation, `__call__` implementors, or lambdas.

## Test Plan

Adds a test that it infers the right thing!

---------

Co-authored-by: Carl Meyer <[email protected]>
  • Loading branch information
chriskrycho and carljm authored Aug 30, 2024
1 parent a73bebc commit 28ab5f4
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 18 deletions.
48 changes: 48 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,33 @@ impl<'db> Type<'db> {
}
}

/// Return the type resulting from calling an object of this type.
///
/// Returns `None` if `self` is not a callable type.
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),

// TODO: handle class constructors
Type::Class(_class_ty) => Some(Type::Unknown),

// TODO: handle classes which implement the Callable protocol
Type::Instance(_instance_ty) => Some(Type::Unknown),

// `Any` is callable, and its return type is also `Any`.
Type::Any => Some(Type::Any),

Type::Unknown => Some(Type::Unknown),

// TODO: union and intersection types, if they reduce to `Callable`
Type::Union(_) => Some(Type::Unknown),
Type::Intersection(_) => Some(Type::Unknown),

_ => None,
}
}

#[must_use]
pub fn instance(&self) -> Type<'db> {
match self {
Expand Down Expand Up @@ -550,4 +577,25 @@ mod tests {
let b_file_diagnostics = super::check_types(&db, b_file);
assert_eq!(&*b_file_diagnostics, &[]);
}

#[test]
fn invalid_callable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
nonsense = 123
x = nonsense()
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'Literal[123]' is not callable"],
);
}
}
62 changes: 44 additions & 18 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,16 @@ impl<'db> TypeInferenceBuilder<'db> {
value,
} = assignment;

// TODO remove once we infer definitions in unpacking assignment, since that infers the RHS
// too, and uses the `infer_expression_types` query to do it
self.infer_expression(value);

for target in targets {
match target {
ast::Expr::Name(name) => {
self.infer_definition(name);
}
_ => {
// TODO infer definitions in unpacking assignment
self.infer_expression(target);
}
if let ast::Expr::Name(name) = target {
self.infer_definition(name);
} else {
// TODO infer definitions in unpacking assignment. When we do, this duplication of
// the "get `Expression`, call `infer_expression_types` on it, `self.extend`" dance
// will be removed; it'll all happen in `infer_assignment_definition` instead.
let expression = self.index.expression(value.as_ref());
self.extend(infer_expression_types(self.db, expression));
self.infer_expression(target);
}
}
}
Expand Down Expand Up @@ -1363,7 +1360,8 @@ impl<'db> TypeInferenceBuilder<'db> {
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty);
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());

ty
}
Expand Down Expand Up @@ -1746,10 +1744,18 @@ impl<'db> TypeInferenceBuilder<'db> {
} = call_expression;

self.infer_arguments(arguments);
self.infer_expression(func);

// TODO resolve to return type of `func`, if its a callable type
Type::Unknown
let function_type = self.infer_expression(func);
function_type.call(self.db).unwrap_or_else(|| {
self.add_diagnostic(
func.as_ref().into(),
"call-non-callable",
format_args!(
"Object of type '{}' is not callable",
function_type.display(self.db)
),
);
Type::Unknown
})
}

fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
Expand Down Expand Up @@ -2247,7 +2253,8 @@ impl<'db> TypeInferenceBuilder<'db> {
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
self.types.expressions.insert(expr_id, ty);
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());

ty
}
Expand Down Expand Up @@ -2808,6 +2815,25 @@ mod tests {
Ok(())
}

#[test]
fn basic_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
def get_int() -> int:
return 42
x = get_int()
",
)?;

assert_public_ty(&db, "src/a.py", "x", "int");

Ok(())
}

#[test]
fn resolve_union() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
1 change: 1 addition & 0 deletions crates/ruff_benchmark/benches/red_knot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/
// The "unresolved import" is because we don't understand `*` imports yet.
static EXPECTED_DIAGNOSTICS: &[&str] = &[
"/src/tomllib/_parser.py:7:29: Module 'collections.abc' has no member 'Iterable'",
"/src/tomllib/_parser.py:686:23: Object of type 'Unbound' is not callable",
"Line 69 is too long (89 characters)",
"Use double quotes for strings",
"Use double quotes for strings",
Expand Down

0 comments on commit 28ab5f4

Please sign in to comment.