Skip to content

Commit

Permalink
Red Knot - Infer the return value of bool() (astral-sh#13538)
Browse files Browse the repository at this point in the history
## Summary
Following astral-sh#13449, this PR adds custom handling for the bool constructor,
so when the input type has statically known truthiness value, it will be
used as the return value of the bool function.
For example, in the following snippet x will now be resolved to
`Literal[True]` instead of `bool`.
```python
x = bool(1)
```

## Test Plan
Some cargo tests were added.
  • Loading branch information
TomerBin authored Sep 27, 2024
1 parent 1639488 commit ec72e67
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 7 deletions.
15 changes: 14 additions & 1 deletion crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,20 @@ impl<'db> Type<'db> {
),

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => CallOutcome::callable(Type::Instance(class)),
Type::Class(class) => {
// If the class is the builtin-bool class (for example `bool(1)`), we try to return
// the specific truthiness value of the input arg, `Literal[True]` for the example above.
let is_bool = class.is_stdlib_symbol(db, "builtins", "bool");
CallOutcome::callable(if is_bool {
arg_types
.first()
.unwrap_or(&Type::Unknown)
.bool(db)
.into_type(db)
} else {
Type::Instance(class)
})
}

// TODO: handle classes which implement the `__call__` protocol
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown),
Expand Down
98 changes: 92 additions & 6 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2711,12 +2711,6 @@ mod tests {

use anyhow::Context;

use ruff_db::files::{system_path_to_file, File};
use ruff_db::parsed::parsed_module;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::name::Name;

use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
Expand All @@ -2728,6 +2722,11 @@ mod tests {
check_types, global_symbol_ty, infer_definition_types, symbol_ty, TypeCheckDiagnostics,
};
use crate::{HasTy, ProgramSettings, SemanticModel};
use ruff_db::files::{system_path_to_file, File};
use ruff_db::parsed::parsed_module;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::name::Name;

use super::TypeInferenceBuilder;

Expand Down Expand Up @@ -6483,4 +6482,91 @@ mod tests {
assert_public_ty(&db, "/src/a.py", "f", r#"Literal["x"]"#);
Ok(())
}

#[test]
fn bool_function_falsy_values() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
r#"
a = bool(0)
b = bool(())
c = bool(None)
d = bool("")
e = bool(False)
"#,
)?;
assert_public_ty(&db, "/src/a.py", "a", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "c", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "d", "Literal[False]");
assert_public_ty(&db, "/src/a.py", "e", "Literal[False]");
Ok(())
}

#[test]
fn builtin_bool_function_detected() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
redefined_builtin_bool = bool
def my_bool(x)-> bool: pass
",
)?;
db.write_dedented(
"/src/b.py",
"
from a import redefined_builtin_bool, my_bool
a = redefined_builtin_bool(0)
b = my_bool(0)
",
)?;
assert_public_ty(&db, "/src/b.py", "a", "Literal[False]");
assert_public_ty(&db, "/src/b.py", "b", "bool");
Ok(())
}

#[test]
fn bool_function_truthy_values() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
r#"
a = bool(1)
b = bool((0,))
c = bool("NON EMPTY")
d = bool(True)
def foo(): pass
e = bool(foo)
"#,
)?;

assert_public_ty(&db, "/src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "c", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "d", "Literal[True]");
assert_public_ty(&db, "/src/a.py", "e", "Literal[True]");
Ok(())
}

#[test]
fn bool_function_ambiguous_values() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
a = bool([])
b = bool({})
c = bool(set())
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "bool");
assert_public_ty(&db, "/src/a.py", "b", "bool");
assert_public_ty(&db, "/src/a.py", "c", "bool");
Ok(())
}
}

0 comments on commit ec72e67

Please sign in to comment.