diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 72fe24064c25d..f78eec0937d3e 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2407,6 +2407,104 @@ impl<'db> TypeInferenceBuilder<'db> { } } + ( + Type::Instance(cls), + Type::IntLiteral(_), + ast::Operator::Mult + | ast::Operator::Div + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::Mod, + ) if cls.is_stdlib_symbol(self.db, "builtins", "float") => { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + ( + Type::Instance(cls), + Type::IntLiteral(_), + ast::Operator::Mult + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::FloorDiv + | ast::Operator::Mod, + ) if cls.is_stdlib_symbol(self.db, "builtins", "int") => { + builtins_symbol_ty(self.db, "int").to_instance(self.db) + } + + (Type::Instance(cls), Type::IntLiteral(_), ast::Operator::Div) + if cls.is_stdlib_symbol(self.db, "builtins", "int") => + { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + (Type::Instance(left_cls), Type::Instance(right_cls), ast::Operator::Div) + if left_cls.is_stdlib_symbol(self.db, "builtins", "int") + && (right_cls.is_stdlib_symbol(self.db, "builtins", "int") + || right_cls.is_stdlib_symbol(self.db, "builtins", "float")) => + { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + (Type::IntLiteral(_), Type::Instance(cls), ast::Operator::Div) + if cls.is_stdlib_symbol(self.db, "builtins", "int") => + { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + ( + Type::Instance(left_cls), + Type::Instance(right_cls), + ast::Operator::Mult + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::FloorDiv + | ast::Operator::Mod, + ) if left_cls.is_stdlib_symbol(self.db, "builtins", "int") + && right_cls.is_stdlib_symbol(self.db, "builtins", "int") => + { + builtins_symbol_ty(self.db, "int").to_instance(self.db) + } + + ( + Type::Instance(left_cls), + Type::Instance(right_cls), + ast::Operator::Mult + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::Div + | ast::Operator::Mod + | ast::Operator::FloorDiv, + ) if left_cls.is_stdlib_symbol(self.db, "builtins", "float") + && right_cls.is_stdlib_symbol(self.db, "builtins", "float") => + { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + ( + Type::IntLiteral(_), + Type::Instance(cls), + ast::Operator::Mult + | ast::Operator::Div + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::Mod + | ast::Operator::FloorDiv, + ) if cls.is_stdlib_symbol(self.db, "builtins", "float") => { + builtins_symbol_ty(self.db, "float").to_instance(self.db) + } + + ( + Type::IntLiteral(_), + Type::Instance(cls), + ast::Operator::Mult + | ast::Operator::Add + | ast::Operator::Sub + | ast::Operator::Mod + | ast::Operator::FloorDiv, + ) if cls.is_stdlib_symbol(self.db, "builtins", "int") => { + builtins_symbol_ty(self.db, "int").to_instance(self.db) + } + _ => Type::Todo, // TODO } } @@ -4196,6 +4294,148 @@ mod tests { Ok(()) } + #[test] + fn int_arithmetic() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + x = int() + + a = x + 1 + b = x - 4 + c = x * 5 + d = x // 3 + e = x / 3 + f = x % 3 + ", + )?; + + assert_public_ty(&db, "src/a.py", "a", "int"); + assert_public_ty(&db, "src/a.py", "b", "int"); + assert_public_ty(&db, "src/a.py", "c", "int"); + assert_public_ty(&db, "src/a.py", "d", "int"); + assert_public_ty(&db, "src/a.py", "e", "float"); + assert_public_ty(&db, "src/a.py", "f", "int"); + + db.write_dedented( + "src/b.py", + " + x = int() + + a = 1 + x + b = 4 - x + c = 5 * x + d = 3 // x + e = 3 / x + f = 3 % x + ", + )?; + + assert_public_ty(&db, "src/b.py", "a", "int"); + assert_public_ty(&db, "src/b.py", "b", "int"); + assert_public_ty(&db, "src/b.py", "c", "int"); + assert_public_ty(&db, "src/b.py", "d", "int"); + assert_public_ty(&db, "src/b.py", "e", "float"); + assert_public_ty(&db, "src/b.py", "f", "int"); + + db.write_dedented( + "src/c.py", + " + x = int() + y = int() + + a = x + y + b = x - y + c = x * y + d = x // y + e = x / y + f = x % y + ", + )?; + + assert_public_ty(&db, "src/c.py", "a", "int"); + assert_public_ty(&db, "src/c.py", "b", "int"); + assert_public_ty(&db, "src/c.py", "c", "int"); + assert_public_ty(&db, "src/c.py", "d", "int"); + assert_public_ty(&db, "src/c.py", "e", "float"); + assert_public_ty(&db, "src/c.py", "f", "int"); + + Ok(()) + } + + #[test] + fn float_arithmetic() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + x = float() + + a = x + 1 + b = x - 4 + c = x * 5 + d = x // 3 + e = x / 3 + f = x % 3 + ", + )?; + + assert_public_ty(&db, "src/a.py", "a", "float"); + assert_public_ty(&db, "src/a.py", "b", "float"); + assert_public_ty(&db, "src/a.py", "c", "float"); + assert_public_ty(&db, "src/a.py", "d", "float"); + assert_public_ty(&db, "src/a.py", "e", "float"); + assert_public_ty(&db, "src/a.py", "f", "float"); + + db.write_dedented( + "src/b.py", + " + x = float() + + a = 1 + x + b = 4 - x + c = 5 * x + d = 3 // x + e = 3 / x + f = 3 % x + ", + )?; + + assert_public_ty(&db, "src/b.py", "a", "float"); + assert_public_ty(&db, "src/b.py", "b", "float"); + assert_public_ty(&db, "src/b.py", "c", "float"); + assert_public_ty(&db, "src/b.py", "d", "float"); + assert_public_ty(&db, "src/b.py", "e", "float"); + assert_public_ty(&db, "src/b.py", "f", "float"); + + db.write_dedented( + "src/c.py", + " + x = float() + y = float() + + a = x + y + b = x - y + c = x * y + d = x // y + e = x / y + f = x % y + ", + )?; + + assert_public_ty(&db, "src/c.py", "a", "float"); + assert_public_ty(&db, "src/c.py", "b", "float"); + assert_public_ty(&db, "src/c.py", "c", "float"); + assert_public_ty(&db, "src/c.py", "d", "float"); + assert_public_ty(&db, "src/c.py", "e", "float"); + assert_public_ty(&db, "src/c.py", "f", "float"); + + Ok(()) + } + #[test] fn division_by_zero() -> anyhow::Result<()> { let mut db = setup_db(); @@ -4214,9 +4454,8 @@ mod tests { assert_public_ty(&db, "/src/a.py", "a", "float"); assert_public_ty(&db, "/src/a.py", "b", "int"); assert_public_ty(&db, "/src/a.py", "c", "int"); - // TODO: These should be `int` and `float` respectively once we support inference - assert_public_ty(&db, "/src/a.py", "d", "@Todo"); - assert_public_ty(&db, "/src/a.py", "e", "@Todo"); + assert_public_ty(&db, "/src/a.py", "d", "int"); + assert_public_ty(&db, "/src/a.py", "e", "float"); assert_file_diagnostics( &db,