From b43ef933b763b012575e20b24ecc58ba20aa3a6b Mon Sep 17 00:00:00 2001 From: Yorick Peterse Date: Fri, 13 Dec 2024 03:46:22 +0100 Subject: [PATCH] WIP: try fixing default methods --- compiler/src/llvm/passes.rs | 6 +- compiler/src/mir/specialize.rs | 186 +++++++----- compiler/src/symbol_names.rs | 7 + compiler/src/type_check/methods.rs | 15 +- types/src/lib.rs | 56 +++- types/src/specialize.rs | 457 ++++++++++++++++++++++------- 6 files changed, 530 insertions(+), 197 deletions(-) diff --git a/compiler/src/llvm/passes.rs b/compiler/src/llvm/passes.rs index 92f48928..716e7c94 100644 --- a/compiler/src/llvm/passes.rs +++ b/compiler/src/llvm/passes.rs @@ -46,7 +46,8 @@ use std::thread::scope; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use types::module_name::ModuleName; use types::{ - ClassId, Database, Intrinsic, Shape, TypeRef, BYTE_ARRAY_ID, STRING_ID, + ClassId, Database, Intrinsic, Shape, TypeRef, TypeSpecializationKey, + BYTE_ARRAY_ID, STRING_ID, }; const NIL_VALUE: bool = false; @@ -842,8 +843,9 @@ impl<'shared, 'module, 'ctx> LowerModule<'shared, 'module, 'ctx> { ), }; + let key = TypeSpecializationKey::new(vec![shape]); let class_id = ClassId::array() - .specializations(&self.shared.state.db)[&vec![shape]]; + .specializations(&self.shared.state.db)[&key]; let layout = self.layouts.instances[class_id.0 as usize]; let array = builder.allocate_instance( self.module, diff --git a/compiler/src/mir/specialize.rs b/compiler/src/mir/specialize.rs index 93f3aa8d..266b82f3 100644 --- a/compiler/src/mir/specialize.rs +++ b/compiler/src/mir/specialize.rs @@ -1,7 +1,7 @@ use crate::mir::{ Block, BlockId, Borrow, CallDynamic, CallInstance, CastType, Class as MirClass, Drop, Instruction, InstructionLocation, Method, Mir, - RegisterId, SELF_ID, + RegisterId, }; use crate::state::State; use indexmap::{IndexMap, IndexSet}; @@ -33,10 +33,15 @@ fn specialize_constants( let mut classes = Vec::new(); let shapes = HashMap::new(); + // Constants never need access to the self type, so we just use a dummy + // value here. + let stype = ClassInstance::new(ClassId::nil()); + for &id in mir.constants.keys() { let old_typ = id.value_type(db); - let new_typ = TypeSpecializer::new(db, interned, &shapes, &mut classes) - .specialize(old_typ); + let new_typ = + TypeSpecializer::new(db, interned, &shapes, &mut classes, stype) + .specialize(old_typ); id.set_value_type(db, new_typ); } @@ -92,6 +97,9 @@ fn shapes_compatible_with_bounds( } struct Job { + /// The type of `self` within the method. + self_type: ClassInstance, + /// The ID of the method that's being specialized. method: MethodId, @@ -117,11 +125,12 @@ impl Work { fn push( &mut self, + self_type: ClassInstance, method: MethodId, shapes: HashMap, ) -> bool { if self.done.insert(method) { - self.jobs.push_back(Job { method, shapes }); + self.jobs.push_back(Job { self_type, method, shapes }); true } else { false @@ -220,6 +229,7 @@ impl DynamicCalls { /// A compiler pass that specializes generic types. pub(crate) struct Specialize<'a, 'b> { + self_type: ClassInstance, method: MethodId, state: &'a mut State, work: &'b mut Work, @@ -267,7 +277,7 @@ impl<'a, 'b> Specialize<'a, 'b> { let main_method = state.db.main_method().unwrap(); let main_mod = main_class.module(&state.db); - work.push(main_method, HashMap::new()); + work.push(ClassInstance::new(main_class), main_method, HashMap::new()); // The main() method isn't called explicitly, so we have to manually // record it in the main class. @@ -278,6 +288,7 @@ impl<'a, 'b> Specialize<'a, 'b> { Specialize { state, interned: &mut intern, + self_type: job.self_type, method: job.method, shapes: job.shapes, work: &mut work, @@ -328,54 +339,23 @@ impl<'a, 'b> Specialize<'a, 'b> { } fn run(&mut self, mir: &mut Mir, dynamic_calls: &mut DynamicCalls) { - self.update_self_type(mir); + // TODO: remove + //let cls = self.self_type.instance_of().name(&self.state.db); + let rec = self.method.receiver(&self.state.db); + + println!( + "{:<20} {:<30} ID: {}", + types::format::format_type(&self.state.db, rec), + self.method.name(&self.state.db), + self.method.0, + ); + self.process_instructions(mir, dynamic_calls); self.process_specialized_types(mir, dynamic_calls); self.expand_instructions(mir); self.add_methods(mir); } - fn update_self_type(&mut self, mir: &mut Mir) { - let method = mir.methods.get_mut(&self.method).unwrap(); - - if method.id.is_static(&self.state.db) - || method.id.is_extern(&self.state.db) - { - return; - } - - let self_type = method.id.receiver(&self.state.db); - let mut self_regs = vec![false; method.registers.len()]; - - // This ensures that if `self` is assigned to other registers, we also - // update those registers' types. - for block in &method.body.blocks { - for instruction in &block.instructions { - match instruction { - Instruction::Borrow(ins) if ins.value.0 == SELF_ID => { - self_regs[ins.register.0] = true; - } - Instruction::MoveRegister(ins) - if ins.source.0 == SELF_ID - || self_regs[ins.source.0] => - { - self_regs[ins.target.0] = true; - } - _ => {} - } - } - } - - method.registers.get_mut(RegisterId(SELF_ID)).value_type = self_type; - - for (idx, val) in self_regs.into_iter().enumerate() { - if val { - method.registers.get_mut(RegisterId(idx)).value_type = - self_type; - } - } - } - fn process_instructions( &mut self, mir: &mut Mir, @@ -394,6 +374,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, &self.shapes, &mut self.classes, + self.self_type, ) .specialize(reg.value_type); } @@ -486,15 +467,16 @@ impl<'a, 'b> Specialize<'a, 'b> { } }, Instruction::Allocate(ins) => { - let cls = method + let old = ins.class; + let new = method .registers .value_type(ins.register) .class_id(&self.state.db) .unwrap(); - ins.class = cls; - self.schedule_regular_dropper(cls); - self.schedule_regular_inline_type_methods(cls); + ins.class = new; + self.schedule_regular_dropper(old, new); + self.schedule_regular_inline_type_methods(new); } Instruction::Free(ins) => { let cls = method @@ -506,34 +488,56 @@ impl<'a, 'b> Specialize<'a, 'b> { ins.class = cls; } Instruction::Spawn(ins) => { - let cls = method + let old = ins.class; + let new = method .registers .value_type(ins.register) .class_id(&self.state.db) .unwrap(); - ins.class = cls; - self.schedule_regular_dropper(cls); + ins.class = new; + self.schedule_regular_dropper(old, new); } Instruction::SetField(ins) => { + let db = &mut self.state.db; + ins.class = method .registers .value_type(ins.receiver) - .class_id(&self.state.db) + .class_id(db) + .unwrap(); + + ins.field = ins + .class + .field_by_index(db, ins.field.index(db)) .unwrap(); } Instruction::GetField(ins) => { + let db = &mut self.state.db; + ins.class = method .registers .value_type(ins.receiver) - .class_id(&self.state.db) + .class_id(db) + .unwrap(); + + ins.field = ins + .class + .field_by_index(db, ins.field.index(db)) .unwrap(); } Instruction::FieldPointer(ins) => { + let db = &mut self.state.db; + ins.class = method .registers .value_type(ins.receiver) - .class_id(&self.state.db) + .class_id(db) + .unwrap(); + + ins.field = ins + .class + .field_by_index(db, ins.field.index(db)) .unwrap(); } Instruction::MethodPointer(ins) => { @@ -558,6 +562,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, &self.shapes, &mut self.classes, + self.self_type, ) .specialize(ins.argument); } @@ -660,7 +665,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.add_implementation_shapes(call.method, &mut shapes); self.add_method_bound_shapes(call.method, &mut shapes); - self.specialize_method(class, call.method, &shapes); + self.specialize_method(class, call.method, &shapes, None); } } } @@ -723,7 +728,7 @@ impl<'a, 'b> Specialize<'a, 'b> { } } - self.specialize_method(class, method, &shapes) + self.specialize_method(class, method, &shapes, None) } fn call_dynamic( @@ -748,6 +753,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, &self.shapes, &mut self.classes, + self.self_type, shape, ); } @@ -811,7 +817,7 @@ impl<'a, 'b> Specialize<'a, 'b> { // base class, so we don't accidentally end up using the // wrong shape on a future iteration of the surrounding // loop. - for (¶m, shape) in params.iter().zip(key) { + for (¶m, shape) in params.iter().zip(key.shapes) { shapes.insert(param, shape); } @@ -828,12 +834,12 @@ impl<'a, 'b> Specialize<'a, 'b> { // type arguments may differ per specialization. self.add_implementation_shapes(method_impl, &mut shapes); self.add_method_bound_shapes(method_impl, &mut shapes); - self.specialize_method(class, method_impl, &shapes); + self.specialize_method(class, method_impl, &shapes, None); } } else { self.add_implementation_shapes(method_impl, &mut shapes); self.add_method_bound_shapes(method_impl, &mut shapes); - self.specialize_method(class, method_impl, &shapes); + self.specialize_method(class, method_impl, &shapes, None); } } @@ -899,7 +905,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.add_implementation_shapes(method_impl, &mut shapes); self.add_method_bound_shapes(method_impl, &mut shapes); - let new = self.specialize_method(class, method_impl, &shapes); + let new = self.specialize_method(class, method_impl, &shapes, None); Instruction::CallInstance(Box::new(CallInstance { register: call.register, @@ -916,13 +922,17 @@ impl<'a, 'b> Specialize<'a, 'b> { class: ClassId, method: MethodId, shapes: &HashMap, + custom_self_type: Option, ) -> MethodId { + let ins = ClassInstance::new(class); + let stype = custom_self_type.unwrap_or(ins); + // Regular methods on regular types don't need to be specialized. if !class.is_generic(&self.state.db) && !class.is_closure(&self.state.db) && !method.is_generic(&self.state.db) { - if self.work.push(method, shapes.clone()) { + if self.work.push(stype, method, shapes.clone()) { self.update_method_type(method, shapes); self.regular_methods.push(method); } @@ -943,25 +953,21 @@ impl<'a, 'b> Specialize<'a, 'b> { return new; } - let ins = ClassInstance::new(class); let new_rec = method.receiver_for_class_instance(&self.state.db, ins); let new = self.specialize_method_type(new_rec, method, key, shapes); - self.work.push(new, shapes.clone()); + self.work.push(stype, new, shapes.clone()); self.specialized_methods.push((method, new)); new } - fn schedule_regular_dropper(&mut self, class: ClassId) { - if class.is_generic(&self.state.db) { + fn schedule_regular_dropper(&mut self, original: ClassId, class: ClassId) { + if class.is_generic(&self.state.db) || class.is_closure(&self.state.db) + { return; } - if let Some(dropper) = class.method(&self.state.db, DROPPER_METHOD) { - if self.work.push(dropper, HashMap::new()) { - self.regular_methods.push(dropper); - } - } + self.generate_dropper(original, class); } fn schedule_regular_inline_type_methods(&mut self, class: ClassId) { @@ -975,8 +981,9 @@ impl<'a, 'b> Specialize<'a, 'b> { for name in methods { let method = class.method(&self.state.db, name).unwrap(); + let stype = ClassInstance::new(class); - if self.work.push(method, HashMap::new()) { + if self.work.push(stype, method, HashMap::new()) { self.regular_methods.push(method); } } @@ -990,8 +997,16 @@ impl<'a, 'b> Specialize<'a, 'b> { return; }; + // References to `self` in closures should point to the type of the + // scope the closure is defined in, not the closure itself. + let stype = if class.is_closure(&self.state.db) { + self.self_type + } else { + ClassInstance::new(class) + }; + if original == class { - if self.work.push(method, HashMap::new()) { + if self.work.push(stype, method, HashMap::new()) { self.regular_methods.push(method); } @@ -1008,7 +1023,7 @@ impl<'a, 'b> Specialize<'a, 'b> { .collect() }; - let new = self.specialize_method(class, method, &shapes); + let new = self.specialize_method(class, method, &shapes, Some(stype)); class.add_method(&mut self.state.db, name.to_string(), new); } @@ -1028,7 +1043,9 @@ impl<'a, 'b> Specialize<'a, 'b> { let method = original.method(&self.state.db, name).unwrap(); if original == class { - if self.work.push(method, HashMap::new()) { + let stype = ClassInstance::new(class); + + if self.work.push(stype, method, HashMap::new()) { self.regular_methods.push(method); } @@ -1041,7 +1058,7 @@ impl<'a, 'b> Specialize<'a, 'b> { .zip(class.shapes(&self.state.db).clone()) .collect(); - let new = self.specialize_method(class, method, &shapes); + let new = self.specialize_method(class, method, &shapes, None); let name = method.name(&self.state.db).clone(); class.add_method(&mut self.state.db, name, new); @@ -1054,7 +1071,11 @@ impl<'a, 'b> Specialize<'a, 'b> { let shapes = self.shapes.clone(); let method = original.method(&self.state.db, CALL_METHOD).unwrap(); - self.specialize_method(class, method, &shapes); + // Within a closure's `call` method, explicit references to or captures + // of `self` should refer to the type of `self` as used by the method in + // which the closure is defined, instead of pointing to the closure's + // type. + self.specialize_method(class, method, &shapes, Some(self.self_type)); } /// Creates a new specialized method, using an existing method as its @@ -1097,6 +1118,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(arg.value_type); @@ -1107,6 +1129,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(raw_var_type); @@ -1124,6 +1147,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(old_ret); @@ -1159,6 +1183,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(arg.value_type); @@ -1168,6 +1193,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(raw_var_type); @@ -1185,6 +1211,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, shapes, &mut self.classes, + self.self_type, ) .specialize(old_ret); @@ -1321,6 +1348,7 @@ impl<'a, 'b> Specialize<'a, 'b> { self.interned, &self.shapes, &mut self.classes, + self.self_type, key, ); } diff --git a/compiler/src/symbol_names.rs b/compiler/src/symbol_names.rs index f5b3e1c0..3389fb9b 100644 --- a/compiler/src/symbol_names.rs +++ b/compiler/src/symbol_names.rs @@ -147,6 +147,8 @@ impl SymbolNames { let mut setup_classes = HashMap::new(); let mut setup_constants = HashMap::new(); + let mut names = std::collections::HashSet::new(); + for module in mir.modules.values() { for &class in &module.classes { let class_name = format!( @@ -155,6 +157,11 @@ impl SymbolNames { qualified_class_name(db, module.id, class) ); + // TODO: remove + if !names.insert(class_name.clone()) { + panic!("duplicate symbol name: {}", class_name); + } + classes.insert(class, class_name); } } diff --git a/compiler/src/type_check/methods.rs b/compiler/src/type_check/methods.rs index f7d0f9dc..efb7bfbc 100644 --- a/compiler/src/type_check/methods.rs +++ b/compiler/src/type_check/methods.rs @@ -991,11 +991,16 @@ impl<'a> DefineMethods<'a> { ..Default::default() }; let bounds = TypeBounds::new(); - let self_type = TypeId::TraitInstance(TraitInstance::rigid( - self.db_mut(), - trait_id, - &bounds, - )); + let mut ins = TraitInstance::rigid(self.db_mut(), trait_id, &bounds); + + // We set this flag so that when we specialize the method, we know that + // references to this type should be replaced with the type of the + // implementing type. Without this flag, given some value typed as trait + // `Foo`, we have no way of knowing if that `Foo` is the type of `self` + // or an unrelated value that happens to have the same type. + ins.self_type = true; + + let self_type = TypeId::TraitInstance(ins); let receiver = receiver_type(self.db(), self_type, node.kind); method.set_receiver(self.db_mut(), receiver); diff --git a/types/src/lib.rs b/types/src/lib.rs index d0cb879f..de378bbf 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -902,11 +902,18 @@ pub struct TraitInstance { /// After type specialization takes place, this value shouldn't be used any /// more as specialized types won't have their type arguments set. type_arguments: u32, + + /// A boolean indicating if this type is the type of `self` inside a method. + /// + /// This field is used during type specialization such that we know if some + /// trait `Foo` is the type of `self` or some unrelated type that happens to + /// be typed as `Foo` as well. + pub self_type: bool, } impl TraitInstance { pub fn new(instance_of: TraitId) -> Self { - Self { instance_of, type_arguments: 0 } + Self { instance_of, type_arguments: 0, self_type: false } } pub fn rigid( @@ -940,7 +947,11 @@ impl TraitInstance { let type_args_id = db.type_arguments.len() as u32; db.type_arguments.push(arguments); - TraitInstance { instance_of, type_arguments: type_args_id } + TraitInstance { + instance_of, + type_arguments: type_args_id, + self_type: false, + } } pub fn instance_of(self) -> TraitId { @@ -1306,6 +1317,29 @@ impl ClassKind { } } +/// A type used as the key for a type specialization lookup. +#[derive(Eq, PartialEq, Hash, Clone, Debug)] +pub struct TypeSpecializationKey { + pub shapes: Vec, + + /// Closures may be defined in a default method, in which case we should + /// specialize them for every type that implements the corresponding trait. + self_type: Option, +} + +impl TypeSpecializationKey { + pub fn new(shapes: Vec) -> TypeSpecializationKey { + TypeSpecializationKey { self_type: None, shapes } + } + + pub fn for_closure( + self_type: ClassInstance, + shapes: Vec, + ) -> TypeSpecializationKey { + TypeSpecializationKey { self_type: Some(self_type), shapes } + } +} + /// An Inko class as declared using the `class` keyword. pub struct Class { kind: ClassKind, @@ -1329,7 +1363,7 @@ pub struct Class { methods: HashMap, implemented_traits: HashMap, constructors: IndexMap, - specializations: HashMap, ClassId>, + specializations: HashMap, /// The ID of the class this class is a specialization of. specialization_source: Option, @@ -1532,6 +1566,11 @@ impl ClassId { &self.get(db).name } + // TODO: necessary? + pub fn set_name(self, db: &mut Database, name: String) { + self.get_mut(db).name = name; + } + pub fn kind(self, db: &Database) -> ClassKind { self.get(db).kind } @@ -1750,10 +1789,19 @@ impl ClassId { self.get_mut(db).specialization_source = Some(class); } + pub fn add_specialization( + self, + db: &mut Database, + key: TypeSpecializationKey, + class: ClassId, + ) { + self.get_mut(db).specializations.insert(key, class); + } + pub fn specializations( self, db: &Database, - ) -> &HashMap, ClassId> { + ) -> &HashMap { &self.get(db).specializations } diff --git a/types/src/specialize.rs b/types/src/specialize.rs index 37ce532c..71351c54 100644 --- a/types/src/specialize.rs +++ b/types/src/specialize.rs @@ -1,6 +1,6 @@ use crate::{ ClassId, ClassInstance, Database, InternedTypeArguments, Shape, TypeId, - TypeParameterId, TypeRef, + TypeParameterId, TypeRef, TypeSpecializationKey, }; use std::collections::HashMap; @@ -38,6 +38,9 @@ pub struct TypeSpecializer<'a, 'b, 'c> { /// parameter as it was determined when creating the newly specialized /// class. shapes: &'b HashMap, + + /// The type `self` is an instance of. + self_type: ClassInstance, } impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { @@ -46,11 +49,12 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { interned: &'b mut InternedTypeArguments, shapes: &'b HashMap, classes: &'c mut Vec, + self_type: ClassInstance, key: &mut Vec, ) { for shape in key { TypeSpecializer::specialize_shape( - db, interned, shapes, classes, shape, + db, interned, shapes, classes, self_type, shape, ); } } @@ -60,6 +64,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { interned: &'b mut InternedTypeArguments, shapes: &'b HashMap, classes: &'c mut Vec, + self_type: ClassInstance, shape: &mut Shape, ) { match shape { @@ -67,8 +72,10 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { | Shape::Inline(i) | Shape::InlineRef(i) | Shape::InlineMut(i) => { - *i = TypeSpecializer::new(db, interned, shapes, classes) - .specialize_class_instance(*i); + *i = TypeSpecializer::new( + db, interned, shapes, classes, self_type, + ) + .specialize_class_instance(*i); } _ => {} } @@ -79,12 +86,55 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { interned: &'b mut InternedTypeArguments, shapes: &'b HashMap, classes: &'c mut Vec, + self_type: ClassInstance, ) -> TypeSpecializer<'a, 'b, 'c> { - TypeSpecializer { db, interned, shapes, classes } + TypeSpecializer { db, interned, shapes, classes, self_type } } pub fn specialize(&mut self, value: TypeRef) -> TypeRef { match value { + // When specializing default methods inherited from traits, we need + // to replace the trait types used for `self` with the type of + // whatever implements the trait. This is needed such that if e.g. a + // closure captures `self` and `self` is a stack allocated type, the + // closure is specialized correctly. + TypeRef::Owned(TypeId::TraitInstance(i)) if i.self_type => { + // TODO: remove + println!( + " swapping '{}' with '{}'", + crate::format::format_type(self.db, value), + crate::format::format_type(self.db, self.self_type) + ); + TypeRef::Owned(TypeId::ClassInstance(self.self_type)) + } + TypeRef::Uni(TypeId::TraitInstance(i)) if i.self_type => { + // TODO: remove + println!( + " swapping '{}' with '{}'", + crate::format::format_type(self.db, value), + crate::format::format_type(self.db, self.self_type) + ); + TypeRef::Uni(TypeId::ClassInstance(self.self_type)) + } + TypeRef::Ref(TypeId::TraitInstance(i)) if i.self_type => { + // TODO: remove + println!( + " swapping '{}' with '{}'", + crate::format::format_type(self.db, value), + crate::format::format_type(self.db, self.self_type) + ); + TypeRef::Ref(TypeId::ClassInstance(self.self_type)) + } + TypeRef::Mut(TypeId::TraitInstance(i)) if i.self_type => { + // TODO: remove + println!( + " swapping '{}' with '{}'", + crate::format::format_type(self.db, value), + crate::format::format_type(self.db, self.self_type) + ); + + TypeRef::Mut(TypeId::ClassInstance(self.self_type)) + } // When specializing type parameters, we have to reuse existing // shapes if there are any. This leads to a bit of duplication, but // there's not really a way around that without making things more @@ -244,12 +294,11 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { ) -> ClassInstance { let class = ins.instance_of(); + // For regular instances we only need to specialize the first reference. if class.specialization_source(self.db).is_some() { return ins; } - // Rather than introducing another flag to skip processing already - // processed types, we just reuse the specialization source. class.set_specialization_source(self.db, class); self.classes.push(class); @@ -265,6 +314,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { self.interned, self.shapes, self.classes, + self.self_type, ) .specialize(v) }) @@ -281,6 +331,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { self.interned, self.shapes, self.classes, + self.self_type, ) .specialize(old); @@ -301,7 +352,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { } let mut args = ins.type_arguments(self.db).unwrap().clone(); - let mut key: Vec = class + let mut shapes: Vec = class .type_parameters(self.db) .into_iter() .map(|p| { @@ -318,12 +369,13 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { self.interned, self.shapes, self.classes, - &mut key, + self.self_type, + &mut shapes, ); + let key = TypeSpecializationKey::new(shapes); let new = class - .get(self.db) - .specializations + .specializations(self.db) .get(&key) .cloned() .unwrap_or_else(|| self.specialize_class(class, key)); @@ -338,39 +390,53 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { &mut self, ins: ClassInstance, ) -> ClassInstance { - let class = ins.instance_of; - - if class.specialization_source(self.db).is_some() { - return ins; - } - + // We don't check the specialization source for closures, as each + // closure _always_ needs to be specialized, as its behaviour/layout may + // change based on how the surrounding method is specialized. + // // Closures may capture types that contain generic type parameters. If // the shapes of those parameters changes, we must specialize the // closure accordingly. For this reason, the specialization key is all // the shapes the closure can possibly access, rather than this being // limited to the types captured. - let mut key = ordered_shapes_from_map(self.shapes); + let mut shapes = ordered_shapes_from_map(self.shapes); TypeSpecializer::specialize_shapes( self.db, self.interned, self.shapes, self.classes, - &mut key, + self.self_type, + &mut shapes, ); + let key = TypeSpecializationKey::for_closure(self.self_type, shapes); + let class = ins.instance_of; let new = class - .get(self.db) - .specializations + .specializations(self.db) .get(&key) .cloned() .unwrap_or_else(|| self.specialize_class(class, key)); + // TODO: something better to prevent name conflicts + new.set_name( + self.db, + format!( + "{} in {}", + class.name(self.db), + self.self_type.instance_of().name(self.db) + ), + ); + ClassInstance::new(new) } #[allow(clippy::unnecessary_to_owned)] - fn specialize_class(&mut self, class: ClassId, key: Vec) -> ClassId { + fn specialize_class( + &mut self, + class: ClassId, + key: TypeSpecializationKey, + ) -> ClassId { let new = class.clone_for_specialization(self.db); self.classes.push(new); @@ -385,26 +451,30 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { new.get_mut(self.db).type_parameters.insert(name, param); } - new.set_shapes(self.db, key.clone()); - class.get_mut(self.db).specializations.insert(key.clone(), new); + let shapes = key.shapes.clone(); + + new.set_shapes(self.db, shapes.clone()); + class.add_specialization(self.db, key, new); // When specializing fields and constructors, we want them to reuse the // shapes we just created. - let class_mapping = class - .type_parameters(self.db) - .into_iter() - .zip(key) - .fold(HashMap::new(), |mut map, (param, shape)| { - map.insert(param, shape); - map - }); + let mut class_mapping = HashMap::new(); // Closures may capture generic parameters from the outside, and the // classes themselves aren't generic, so we reuse the outer shapes // instead. let kind = class.kind(self.db); - let mapping = - if kind.is_closure() { self.shapes } else { &class_mapping }; + let mapping = if kind.is_closure() { + self.shapes + } else { + for (param, shape) in + class.type_parameters(self.db).into_iter().zip(shapes) + { + class_mapping.insert(param, shape); + } + + &class_mapping + }; if kind.is_enum() { for old_cons in class.constructors(self.db) { @@ -420,6 +490,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { self.interned, mapping, self.classes, + self.self_type, ) .specialize(v) }) @@ -447,6 +518,7 @@ impl<'a, 'b, 'c> TypeSpecializer<'a, 'b, 'c> { self.interned, mapping, self.classes, + self.self_type, ) .specialize(orig_typ); @@ -463,9 +535,9 @@ mod tests { use crate::format::format_type; use crate::test::{ any, generic_instance_id, immutable, instance, mutable, new_class, - new_enum_class, new_parameter, owned, parameter, rigid, uni, + new_enum_class, new_parameter, new_trait, owned, parameter, rigid, uni, }; - use crate::{ClassId, Location, ModuleId, Visibility}; + use crate::{ClassId, Location, ModuleId, TraitInstance, Visibility}; #[test] fn test_specialize_type() { @@ -480,19 +552,30 @@ mod tests { let raw1 = owned(generic_instance_id(&mut db, class, vec![int])); let raw2 = owned(generic_instance_id(&mut db, class, vec![int])); let mut classes = Vec::new(); - let spec1 = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw1); - let spec2 = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw2); + let stype = ClassInstance::new(ClassId::int()); + let spec1 = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw1); + let spec2 = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw2); assert_eq!(format_type(&db, spec1), "Array[Int]"); assert_eq!(format_type(&db, spec2), "Array[Int]"); - assert_eq!(class.get(&db).specializations.len(), 1); + assert_eq!(class.specializations(&db).len(), 1); - let new_class = - *class.get(&db).specializations.get(&vec![Shape::int()]).unwrap(); + let key = TypeSpecializationKey::new(vec![Shape::int()]); + let new_class = *class.specializations(&db).get(&key).unwrap(); assert_eq!(classes, &[ClassId::int(), new_class]); assert_eq!(new_class.specialization_source(&db), Some(class)); @@ -525,9 +608,15 @@ mod tests { let raw = TypeRef::Pointer(generic_instance_id(&mut db, class, vec![int])); let mut classes = Vec::new(); - let spec = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw); + let stype = ClassInstance::new(ClassId::int()); + let spec = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw); assert_eq!(format_type(&db, spec), "Pointer[Array[Int]]"); } @@ -548,16 +637,22 @@ mod tests { vec![immutable(instance(foo))], )); let mut classes = Vec::new(); - let spec = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw); + let stype = ClassInstance::new(ClassId::int()); + let spec = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw); assert_eq!(format_type(&db, spec), "Array[ref Foo]"); assert_eq!( spec, TypeRef::Owned(TypeId::ClassInstance(ClassInstance { instance_of: ClassId(db.number_of_classes() as u32 - 1), - type_arguments: 1 + type_arguments: 1, })) ); assert_eq!(classes.len(), 2); @@ -625,9 +720,15 @@ mod tests { let mut interned = InternedTypeArguments::new(); let mut classes = Vec::new(); - let spec = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw); + let stype = ClassInstance::new(ClassId::int()); + let spec = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw); assert_eq!(format_type(&db, spec), "(Int, ref X, mut Y: mut)"); assert_eq!(classes.len(), 2); @@ -682,9 +783,15 @@ mod tests { let raw = owned(generic_instance_id(&mut db, opt, vec![TypeRef::int()])); - let res = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw); + let stype = ClassInstance::new(ClassId::int()); + let res = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw); assert_eq!(classes.len(), 2); assert!(classes[1].kind(&db).is_enum()); @@ -715,13 +822,24 @@ mod tests { let raw = owned(generic_instance_id(&mut db, class, vec![int])); let mut classes = Vec::new(); let mut interned = InternedTypeArguments::new(); - let res1 = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(raw); - - let res2 = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(res1); + let stype = ClassInstance::new(ClassId::int()); + let res1 = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(raw); + + let res2 = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(res1); assert_eq!(res1, res2); assert_eq!(classes, &[ClassId::int(), res1.class_id(&db).unwrap()]); @@ -737,17 +855,33 @@ mod tests { shapes.insert(param, Shape::Atomic); - let owned = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(owned(parameter(param))); - - let immutable = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(immutable(parameter(param))); - - let mutable = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(mutable(parameter(param))); + let stype = ClassInstance::new(ClassId::int()); + let owned = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(owned(parameter(param))); + + let immutable = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(immutable(parameter(param))); + + let mutable = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(mutable(parameter(param))); assert_eq!(owned, TypeRef::Owned(TypeId::AtomicTypeParameter(param))); assert_eq!(immutable, TypeRef::Ref(TypeId::AtomicTypeParameter(param))); @@ -764,21 +898,42 @@ mod tests { shapes.insert(param, Shape::Mut); - let owned = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(owned(parameter(param))); - - let uni = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(uni(parameter(param))); - - let immutable = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(immutable(parameter(param))); - - let mutable = - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(mutable(parameter(param))); + let stype = ClassInstance::new(ClassId::int()); + let owned = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(owned(parameter(param))); + + let uni = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(uni(parameter(param))); + + let immutable = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(immutable(parameter(param))); + + let mutable = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype, + ) + .specialize(mutable(parameter(param))); assert_eq!(owned, TypeRef::Mut(TypeId::TypeParameter(param))); assert_eq!(uni, TypeRef::UniMut(TypeId::TypeParameter(param))); @@ -796,50 +951,138 @@ mod tests { let ins = ClassInstance::new(cls); let p1 = new_parameter(&mut db, "X"); let p2 = new_parameter(&mut db, "Y"); + let stype = ClassInstance::new(ClassId::int()); shapes.insert(p1, Shape::Inline(ins)); shapes.insert(p2, Shape::Copy(ins)); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(owned(parameter(p1))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(owned(parameter(p1))), owned(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(uni(parameter(p1))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(uni(parameter(p1))), owned(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(mutable(parameter(p1))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(mutable(parameter(p1))), mutable(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(immutable(parameter(p1))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(immutable(parameter(p1))), immutable(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(owned(parameter(p2))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(owned(parameter(p2))), owned(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(owned(parameter(p2))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(owned(parameter(p2))), owned(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(mutable(parameter(p2))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(mutable(parameter(p2))), owned(instance(cls)) ); assert_eq!( - TypeSpecializer::new(&mut db, &mut interned, &shapes, &mut classes) - .specialize(immutable(parameter(p2))), + TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + stype + ) + .specialize(immutable(parameter(p2))), owned(instance(cls)) ); } + + #[test] + fn test_specialize_trait_self_type() { + let mut db = Database::new(); + let shapes = HashMap::new(); + let mut classes = Vec::new(); + let mut interned = InternedTypeArguments::new(); + let trt = new_trait(&mut db, "ToThing"); + let cls = new_class(&mut db, "Thing"); + let mut old_self = TraitInstance::new(trt); + + old_self.self_type = true; + + let new_self = ClassInstance::new(cls); + let mut spec = TypeSpecializer::new( + &mut db, + &mut interned, + &shapes, + &mut classes, + new_self, + ); + + assert_eq!( + spec.specialize(owned(TypeId::TraitInstance(old_self))), + owned(TypeId::ClassInstance(new_self)) + ); + assert_eq!( + spec.specialize(immutable(TypeId::TraitInstance(old_self))), + immutable(TypeId::ClassInstance(new_self)) + ); + assert_eq!( + spec.specialize(mutable(TypeId::TraitInstance(old_self))), + mutable(TypeId::ClassInstance(new_self)) + ); + assert_eq!( + spec.specialize(uni(TypeId::TraitInstance(old_self))), + uni(TypeId::ClassInstance(new_self)) + ); + } }