Skip to content

Commit

Permalink
[red-knot] infer_symbol_public_type infers union of all definitions (#…
Browse files Browse the repository at this point in the history
…11669)

## Summary

Rename `infer_symbol_type` to `infer_symbol_public_type`, and allow it
to work on symbols with more than one definition. For now, use the most
cautious/sound inference, which is the union of all definitions. We can
prune this union more in future by eliminating definitions if we can
show that they can't be visible (this requires both that the symbol is
definitely later reassigned, and that there is no intervening
call/import that might be able to see the overwritten definition).

## Test Plan

Added a test showing inference of union from multiple definitions.
  • Loading branch information
carljm authored Jun 3, 2024
1 parent 2b28889 commit b02d3f3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 30 deletions.
10 changes: 5 additions & 5 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::source::{source_text, Source};
use crate::symbols::{
resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable,
};
use crate::types::{infer_definition_type, infer_symbol_type, Type};
use crate::types::{infer_definition_type, infer_symbol_public_type, Type};

#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult<Diagnostics> {
Expand Down Expand Up @@ -104,14 +104,14 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
for (symbol, definition) in context.symbols().all_definitions() {
match definition {
Definition::Import(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
context.push_diagnostic(format!("Unresolved module {}", import.module));
}
}
Definition::ImportFrom(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
let module_name = import.module().map(Deref::deref).unwrap_or_default();
Expand Down Expand Up @@ -217,8 +217,8 @@ impl<'a> SemanticLintContext<'a> {
&self.symbols
}

pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(
pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_public_type(
self.db.upcast(),
GlobalSymbolId {
file_id: self.file_id,
Expand Down
20 changes: 10 additions & 10 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;

pub(crate) mod infer;

pub(crate) use infer::{infer_definition_type, infer_symbol_type};
pub(crate) use infer::{infer_definition_type, infer_symbol_public_type};

/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -119,7 +119,7 @@ impl TypeStore {
self.modules.remove(&file_id);
}

pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) {
pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) {
self.add_or_get_module(symbol.file_id)
.symbol_types
.insert(symbol.symbol_id, ty);
Expand All @@ -131,7 +131,7 @@ impl TypeStore {
.insert(node_key, ty);
}

pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
self.try_get_module(symbol.file_id)?
.symbol_types
.get(&symbol.symbol_id)
Expand Down Expand Up @@ -182,12 +182,12 @@ impl TypeStore {
.add_class(name, scope_id, bases)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
}

fn add_intersection(
&mut self,
&self,
file_id: FileId,
positive: &[Type],
negative: &[Type],
Expand Down Expand Up @@ -393,7 +393,7 @@ impl ModuleTypeId {

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
Ok(Some(infer_symbol_type(db, symbol_id)?))
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
} else {
Ok(None)
}
Expand Down Expand Up @@ -441,7 +441,7 @@ impl ClassTypeId {
let ClassType { scope_id, .. } = *self.class(db)?;
let table = symbol_table(db, self.file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(infer_symbol_type(
Ok(Some(infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: self.file_id,
Expand Down Expand Up @@ -497,7 +497,7 @@ struct ModuleTypeStore {
unions: IndexVec<ModuleUnionTypeId, UnionType>,
/// arena of all intersection types created in this module
intersections: IndexVec<ModuleIntersectionTypeId, IntersectionType>,
/// cached types of symbols in this module
/// cached public types of symbols in this module
symbol_types: FxHashMap<SymbolId, Type>,
/// cached types of AST nodes in this module
node_types: FxHashMap<NodeKey, Type>,
Expand Down Expand Up @@ -777,7 +777,7 @@ mod tests {

#[test]
fn add_union() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand All @@ -794,7 +794,7 @@ mod tests {

#[test]
fn add_intersection() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand Down
83 changes: 68 additions & 15 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,41 @@ use crate::types::{ModuleTypeId, Type};
use crate::{FileId, Name};

// FIXME: Figure out proper deadlock-free synchronisation now that this takes `&db` instead of `&mut db`.
/// Resolve the public-facing type for a symbol (the type seen by other scopes: other modules, or
/// nested functions). Because calls to nested functions and imports can occur anywhere in control
/// flow, this type must be conservative and consider all definitions of the symbol that could
/// possibly be seen by another scope. Currently we take the most conservative approach, which is
/// the union of all definitions. We may be able to narrow this in future to eliminate definitions
/// which can't possibly (or are at least unlikely to) be seen by any other scope, so that e.g. we
/// could infer `Literal["1"]` instead of `Literal[1] | Literal["1"]` for `x` in
/// `x = 1; x = str(x);`.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
let symbols = symbol_table(db, symbol.file_id)?;
let defs = symbols.definitions(symbol.symbol_id);
let jar: &SemanticJar = db.jar()?;

if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) {
if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) {
return Ok(ty);
}

// TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1);

let ty = infer_definition_type(db, symbol, defs[0].clone())?;
let mut tys = defs
.iter()
.map(|def| infer_definition_type(db, symbol, def.clone()))
.peekable();
let ty = if let Some(first) = tys.next() {
if tys.peek().is_some() {
Type::Union(jar.type_store.add_union(
symbol.file_id,
&Iterator::chain([first].into_iter(), tys).collect::<QueryResult<Vec<_>>>()?,
))
} else {
first?
}
} else {
Type::Unknown
};

jar.type_store.cache_symbol_type(symbol, ty);
jar.type_store.cache_symbol_public_type(symbol, ty);

// TODO record dependencies
Ok(ty)
Expand Down Expand Up @@ -65,7 +84,7 @@ pub fn infer_definition_type(
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
infer_symbol_type(db, remote_symbol)
infer_symbol_public_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
}
Expand Down Expand Up @@ -158,7 +177,8 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu
ast::Expr::Name(name) => {
// TODO look up in the correct scope, don't assume global
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id })
// TODO should use only reachable definitions, not public type
infer_symbol_public_type(db, GlobalSymbolId { file_id, symbol_id })
} else {
Ok(Type::Unknown)
}
Expand All @@ -182,7 +202,7 @@ mod tests {
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::{symbol_table, GlobalSymbolId};
use crate::types::{infer_symbol_type, Type};
use crate::types::{infer_symbol_public_type, Type};
use crate::Name;

// TODO with virtual filesystem we shouldn't have to write files to disk for these
Expand Down Expand Up @@ -228,7 +248,7 @@ mod tests {
.root_symbol_id_by_name("E")
.expect("E symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -259,7 +279,7 @@ mod tests {
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -300,7 +320,7 @@ mod tests {
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -345,7 +365,7 @@ mod tests {
.root_symbol_id_by_name("D")
.expect("D symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -375,7 +395,7 @@ mod tests {
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand All @@ -388,4 +408,37 @@ mod tests {
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]");
Ok(())
}

/// Checks that a symbol with more than one definition gets a union as its public type.
///
/// The module `a` assigns `x` in both branches of an `if`/`else`, and the inferred public type
/// of `x` is expected to be a `Type::Union` that displays as `(Literal[1] | Literal[2])`.
#[test]
fn resolve_union() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

// Write the module source to disk under `case.src` so that it can be resolved as module `a`.
let path = case.src.path().join("a.py");
std::fs::write(path, "if flag:\n x = 1\nelse:\n x = 2")?;
let file = resolve_module(db, ModuleName::new("a"))?
.expect("module should be found")
.path(db)?
.file();
// Look up `x` by name among the root-scope symbols of the module.
let syms = symbol_table(db, file)?;
let x_sym = syms
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: x_sym,
},
)?;

// The type store from the jar is needed to render the type for the display assertion.
let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::Union(_)));
assert_eq!(
format!("{}", ty.display(&jar.type_store)),
"(Literal[1] | Literal[2])"
);
Ok(())
}
}

0 comments on commit b02d3f3

Please sign in to comment.