Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement arithmetic inference for binary expressions with float and int instances #13590

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 242 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,104 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might futz with the organization of these and add some comments?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @carljm regarding the handling of IntLiteral without it being a special case

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we probably want a general-purpose function like this for binary operations between instance types:

fn perform_binary_operation<'db>(
    db: &'db dyn Db,
    left: ClassType<'db>,
    right: ClassType<'db>,
    dunder_name: &str,
) -> Option<Type<'db>> {
    // TODO the reflected dunder actually has priority if the r.h.s. is a strict subclass of the l.h.s.
    // TODO Some other complications too!

    let dunder = left.class_member(db, dunder_name);
    if !dunder.is_unbound() {
        return dunder
            .call(db, &[Type::Instance(left), Type::Instance(right)])
            .return_ty(db);
    }

    let reflected_dunder = right.class_member(db, &format!("r{dunder_name}"));
    if !reflected_dunder.is_unbound() {
        return dunder
            .call(db, &[Type::Instance(right), Type::Instance(left)])
            .return_ty(db);
    }

    None
}

And then for e.g. inferring x * y, where x is an instance type and y is an IntLiteral type we'd just do something like

let Type::Instance(x_class) = x;
let int_class = builtins_symbol_ty(&db, "int");
perform_binary_operation(db, x_class, int_class, "__mul__")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What perform_binary_operation when passed "__mul__" does is:

  • Lookup the __mul__ function on the type of x
  • If it exists, call __mul__(x, y)
  • If type(x).__mul__ did not exist, lookup __rmul__ on the type of y
  • If type(y).__rmul__ exists, call __rmul__(y, x)
  • Else, return None

This is a generalised routine that will work for inferring binary operations between any two instance types. And unless we have a literal on both sides of the binary operation, we may as well treat an IntLiteral variant as "just an instance of int", and fallback to the generalised routine

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The algorithm for inferring binary operations is unfortunately really really complicated in its totality, though. See https://snarky.ca/unravelling-binary-arithmetic-operations-in-python/ for all the gory details!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting! Thanks for the context.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, @AlexWaygood perfectly summarized what I was talking about in standup, and with lots more useful details, too!!

Copy link
Contributor

@carljm carljm Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, so either way we are going to have large match statements, but once we are in the realm of "things typeshed can tell us" (which is all of the cases added in this PR), we should avoid adding new special case match arms and just have a generic version that looks up and "calls" the appropriate dunder methods like Alex suggests. Pretty much we should only have special-case implementations for literal types if there is a possibility we can infer a literal type out of the operation (something typeshed clearly can't do), otherwise we should be falling back to the general case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I was thinking something was wrong here but wasn't aware typeshed does all of this.

Copy link
Member

@AlexWaygood AlexWaygood Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, we go to great lengths to accurately reflect all the dunders in typeshed: https://github.com/python/typeshed/blob/44aa63330b03bdacab731af2333ff9bf70855de3/stdlib/builtins.pyi#L227-L330

And we have quite extensive testing to check that none are omitted from the stub or have inconsistent signatures with what actually exists at runtime.

Type::Instance(cls),
Type::IntLiteral(_),
ast::Operator::Mult
| ast::Operator::Div
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::Mod,
) if cls.is_stdlib_symbol(self.db, "builtins", "float") => {
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(
Type::Instance(cls),
Type::IntLiteral(_),
ast::Operator::Mult
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::FloorDiv
| ast::Operator::Mod,
) if cls.is_stdlib_symbol(self.db, "builtins", "int") => {
builtins_symbol_ty(self.db, "int").to_instance(self.db)
}

(Type::Instance(cls), Type::IntLiteral(_), ast::Operator::Div)
if cls.is_stdlib_symbol(self.db, "builtins", "int") =>
{
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(Type::Instance(left_cls), Type::Instance(right_cls), ast::Operator::Div)
if left_cls.is_stdlib_symbol(self.db, "builtins", "int")
&& (right_cls.is_stdlib_symbol(self.db, "builtins", "int")
|| right_cls.is_stdlib_symbol(self.db, "builtins", "float")) =>
{
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(Type::IntLiteral(_), Type::Instance(cls), ast::Operator::Div)
if cls.is_stdlib_symbol(self.db, "builtins", "int") =>
{
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(
Type::Instance(left_cls),
Type::Instance(right_cls),
ast::Operator::Mult
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::FloorDiv
| ast::Operator::Mod,
) if left_cls.is_stdlib_symbol(self.db, "builtins", "int")
&& right_cls.is_stdlib_symbol(self.db, "builtins", "int") =>
{
builtins_symbol_ty(self.db, "int").to_instance(self.db)
}

(
Type::Instance(left_cls),
Type::Instance(right_cls),
ast::Operator::Mult
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::Div
| ast::Operator::Mod
| ast::Operator::FloorDiv,
) if left_cls.is_stdlib_symbol(self.db, "builtins", "float")
&& right_cls.is_stdlib_symbol(self.db, "builtins", "float") =>
{
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(
Type::IntLiteral(_),
Type::Instance(cls),
ast::Operator::Mult
| ast::Operator::Div
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::Mod
| ast::Operator::FloorDiv,
) if cls.is_stdlib_symbol(self.db, "builtins", "float") => {
builtins_symbol_ty(self.db, "float").to_instance(self.db)
}

(
Type::IntLiteral(_),
Type::Instance(cls),
ast::Operator::Mult
| ast::Operator::Add
| ast::Operator::Sub
| ast::Operator::Mod
| ast::Operator::FloorDiv,
) if cls.is_stdlib_symbol(self.db, "builtins", "int") => {
builtins_symbol_ty(self.db, "int").to_instance(self.db)
}

_ => Type::Todo, // TODO
}
}
Expand Down Expand Up @@ -4196,6 +4294,148 @@ mod tests {
Ok(())
}

#[test]
fn int_arithmetic() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
x = int()

a = x + 1
b = x - 4
c = x * 5
d = x // 3
e = x / 3
f = x % 3
",
)?;

assert_public_ty(&db, "src/a.py", "a", "int");
assert_public_ty(&db, "src/a.py", "b", "int");
assert_public_ty(&db, "src/a.py", "c", "int");
assert_public_ty(&db, "src/a.py", "d", "int");
assert_public_ty(&db, "src/a.py", "e", "float");
assert_public_ty(&db, "src/a.py", "f", "int");

db.write_dedented(
"src/b.py",
"
x = int()

a = 1 + x
b = 4 - x
c = 5 * x
d = 3 // x
e = 3 / x
f = 3 % x
",
)?;

assert_public_ty(&db, "src/b.py", "a", "int");
assert_public_ty(&db, "src/b.py", "b", "int");
assert_public_ty(&db, "src/b.py", "c", "int");
assert_public_ty(&db, "src/b.py", "d", "int");
assert_public_ty(&db, "src/b.py", "e", "float");
assert_public_ty(&db, "src/b.py", "f", "int");

db.write_dedented(
"src/c.py",
"
x = int()
y = int()

a = x + y
b = x - y
c = x * y
d = x // y
e = x / y
f = x % y
",
)?;

assert_public_ty(&db, "src/c.py", "a", "int");
assert_public_ty(&db, "src/c.py", "b", "int");
assert_public_ty(&db, "src/c.py", "c", "int");
assert_public_ty(&db, "src/c.py", "d", "int");
assert_public_ty(&db, "src/c.py", "e", "float");
assert_public_ty(&db, "src/c.py", "f", "int");

Ok(())
}

#[test]
fn float_arithmetic() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
x = float()

a = x + 1
b = x - 4
c = x * 5
d = x // 3
e = x / 3
f = x % 3
",
)?;

assert_public_ty(&db, "src/a.py", "a", "float");
assert_public_ty(&db, "src/a.py", "b", "float");
assert_public_ty(&db, "src/a.py", "c", "float");
assert_public_ty(&db, "src/a.py", "d", "float");
assert_public_ty(&db, "src/a.py", "e", "float");
assert_public_ty(&db, "src/a.py", "f", "float");

db.write_dedented(
"src/b.py",
"
x = float()

a = 1 + x
b = 4 - x
c = 5 * x
d = 3 // x
e = 3 / x
f = 3 % x
",
)?;

assert_public_ty(&db, "src/b.py", "a", "float");
assert_public_ty(&db, "src/b.py", "b", "float");
assert_public_ty(&db, "src/b.py", "c", "float");
assert_public_ty(&db, "src/b.py", "d", "float");
assert_public_ty(&db, "src/b.py", "e", "float");
assert_public_ty(&db, "src/b.py", "f", "float");

db.write_dedented(
"src/c.py",
"
x = float()
y = float()

a = x + y
b = x - y
c = x * y
d = x // y
e = x / y
f = x % y
",
)?;

assert_public_ty(&db, "src/c.py", "a", "float");
assert_public_ty(&db, "src/c.py", "b", "float");
assert_public_ty(&db, "src/c.py", "c", "float");
assert_public_ty(&db, "src/c.py", "d", "float");
assert_public_ty(&db, "src/c.py", "e", "float");
assert_public_ty(&db, "src/c.py", "f", "float");

Ok(())
}

#[test]
fn division_by_zero() -> anyhow::Result<()> {
let mut db = setup_db();
Expand All @@ -4214,9 +4454,8 @@ mod tests {
assert_public_ty(&db, "/src/a.py", "a", "float");
assert_public_ty(&db, "/src/a.py", "b", "int");
assert_public_ty(&db, "/src/a.py", "c", "int");
// TODO: These should be `int` and `float` respectively once we support inference
assert_public_ty(&db, "/src/a.py", "d", "@Todo");
assert_public_ty(&db, "/src/a.py", "e", "@Todo");
assert_public_ty(&db, "/src/a.py", "d", "int");
assert_public_ty(&db, "/src/a.py", "e", "float");

assert_file_diagnostics(
&db,
Expand Down
Loading