From d6527f2899ee2c4cba7eb0481911f2bc6eca78e3 Mon Sep 17 00:00:00 2001 From: Lennart Van Hirtum Date: Sun, 2 Jun 2024 16:54:33 +0200 Subject: [PATCH] Have explicit nodes in Flattening for Func Calls This is for two reasons - Func calls have become relevant for generative code execution. It was really difficult to do with the current Write-based system - Also future idea for "action" functions, that enable an implicit port when activating. --- src/arena_alloc.rs | 46 +++++++++--- src/dev_aid/lsp/tree_walk.rs | 5 ++ src/flattening/mod.rs | 35 ++++++++- src/flattening/parse.rs | 115 +++++++++++++++-------------- src/flattening/typechecking.rs | 54 ++++++++++---- src/instantiation/execute.rs | 74 +++++++++++++------ src/instantiation/latency_count.rs | 6 +- src/instantiation/mod.rs | 2 +- test.sus | 9 +++ 9 files changed, 232 insertions(+), 114 deletions(-) diff --git a/src/arena_alloc.rs b/src/arena_alloc.rs index dd14b16..36cea15 100644 --- a/src/arena_alloc.rs +++ b/src/arena_alloc.rs @@ -53,12 +53,6 @@ impl UUID { pub struct UUIDRange(pub UUID, pub UUID); -impl UUIDRange { - pub fn contains(&self, id : UUID) -> bool { - self.0.0 >= id.0 && self.1.0 < id.0 - } -} - impl Debug for UUIDRange { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_str(IndexMarker::DISPLAY_NAME)?; @@ -87,16 +81,38 @@ impl Hash for UUIDRange { } } +impl IntoIterator for UUIDRange { + type Item = UUID; + + type IntoIter = UUIDRangeIter; + + fn into_iter(self) -> Self::IntoIter { + UUIDRangeIter(self.0, self.1) + } +} + #[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub struct UUIDRangeIter(UUID, UUID); +pub struct UUIDRangeIter(UUID, UUID); -impl UUIDRange { +impl UUIDRange { pub fn empty() -> Self { UUIDRange(UUID(0, PhantomData), UUID(0, PhantomData)) } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn contains(&self, id : UUID) -> bool { + self.0.0 >= id.0 && self.1.0 < id.0 + } + pub fn iter(&self) -> UUIDRangeIter { + self.into_iter() + } + pub fn len(&self) -> usize { + self.1.0 - self.0.0 + } } -impl Iterator for UUIDRange { +impl Iterator for UUIDRangeIter { type Item = UUID; fn next(&mut self) -> Option { @@ -108,9 +124,19 @@ impl Iterator for UUIDRange { Some(result) } } + fn size_hint(&self) -> (usize, Option) { + let sz = self.len(); + (sz, Some(sz)) + } +} + +impl ExactSizeIterator for UUIDRangeIter { + fn len(&self) -> usize { + self.1.0 - self.0.0 + } } -impl UUIDRange { +impl UUIDRangeIter { pub fn skip_to(&mut self, to : UUID) { assert!(to.0 >= self.0.0); assert!(to.0 <= self.1.0); diff --git a/src/dev_aid/lsp/tree_walk.rs b/src/dev_aid/lsp/tree_walk.rs index 9168b16..7e455ca 100644 --- a/src/dev_aid/lsp/tree_walk.rs +++ b/src/dev_aid/lsp/tree_walk.rs @@ -225,6 +225,11 @@ impl<'linker, Visitor : FnMut(Span, LocationInfo<'linker>), Pruner : Fn(Span) -> Instruction::Write(write) => { self.walk_wire_ref(md_id, md, &write.to); } + Instruction::FuncCall(fc) => { + if let Some(submod_name_span) = fc.name_span { + self.visit(submod_name_span, LocationInfo::InModule(md_id, md, fc.submodule_instruction, InModule::NamedSubmodule(md.instructions[fc.submodule_instruction].unwrap_submodule()))); + } + } Instruction::IfStatement(_) | Instruction::ForStatement(_) => {} }; } diff --git a/src/flattening/mod.rs b/src/flattening/mod.rs index 557eca3..dc1bf11 100644 --- a/src/flattening/mod.rs +++ b/src/flattening/mod.rs @@ -136,6 +136,7 @@ impl Module { pub fn get_instruction_span(&self, instr_id : FlatID) -> Span { match &self.instructions[instr_id] { Instruction::SubModule(sm) => sm.module_name_span, + Instruction::FuncCall(fc) => fc.whole_func_span, Instruction::Declaration(decl) => decl.get_span(), Instruction::Wire(w) => w.span, Instruction::Write(conn) => conn.to_span, @@ -393,6 +394,9 @@ impl WireSource { WireSource::Constant(_) => {} } } + pub const fn new_error() -> WireSource { + WireSource::Constant(Value::Error) + } } #[derive(Debug, Clone)] @@ -469,9 +473,9 @@ impl Declaration { #[derive(Debug)] pub struct SubModuleInstance { pub module_uuid : ModuleUUID, + pub module_name_span : Span, /// Name is not always present in source code. Such as in inline function call syntax: my_mod(a, b, c) pub name : Option<(String, Span)>, - pub module_name_span : Span, pub local_interface_domains : FlatAlloc, pub documentation : Documentation } @@ -486,6 +490,28 @@ impl SubModuleInstance { } } +#[derive(Debug)] +pub struct FuncCallInstruction { + pub submodule_instruction : FlatID, + pub module_uuid : ModuleUUID, + /// arguments.len() == func_call_inputs.len() ALWAYS + pub arguments : Vec, + /// arguments.len() == func_call_inputs.len() ALWAYS + pub func_call_inputs : PortIDRange, + pub func_call_outputs : PortIDRange, + /// If this is None, that means the submodule was declared implicitly. Hence it could also be used at compiletime + pub name_span : Option, + pub arguments_span : BracketSpan, + pub whole_func_span : Span, +} + +impl FuncCallInstruction { + pub fn could_be_at_compile_time(&self) -> bool { + todo!("self.name_span.is_none() but also other requirements, like if the module is a function") + } +} + + #[derive(Debug)] pub struct IfStatement { pub condition : FlatID, @@ -506,6 +532,7 @@ pub struct ForStatement { #[derive(Debug)] pub enum Instruction { SubModule(SubModuleInstance), + FuncCall(FuncCallInstruction), Declaration(Declaration), Wire(WireInstance), Write(Write), @@ -530,9 +557,9 @@ impl Instruction { sm } #[track_caller] - pub fn unwrap_write(&self) -> &Write { - let Self::Write(sm) = self else {panic!("unwrap_write on not a Write! Found {self:?}")}; - sm + pub fn unwrap_func_call(&self) -> &FuncCallInstruction { + let Self::FuncCall(fc) = self else {panic!("unwrap_func_call on not a FuncCallInstruction! Found {self:?}")}; + fc } } diff --git a/src/flattening/parse.rs b/src/flattening/parse.rs index 307c610..ab45435 100644 --- a/src/flattening/parse.rs +++ b/src/flattening/parse.rs @@ -1,11 +1,11 @@ -use std::{iter::zip, ops::{Deref, DerefMut}, str::FromStr}; +use std::{ops::{Deref, DerefMut}, str::FromStr}; use num::BigInt; use sus_proc_macro::{field, kind, kw}; use crate::{ - arena_alloc::{UUIDRange, UUID}, debug::SpanDebugger, errors::ErrorCollector, file_position::{BracketSpan, Span}, linker::{with_module_editing_context, ConstantUUIDMarker, Linker, ModuleUUID, ModuleUUIDMarker, NameElem, NameResolver, NamedConstant, NamedType, ResolvedName, Resolver, TypeUUIDMarker, WorkingOnResolver}, parser::Cursor, value::Value + arena_alloc::{UUIDRange, UUIDRangeIter, UUID}, debug::SpanDebugger, errors::ErrorCollector, file_position::{BracketSpan, Span}, linker::{with_module_editing_context, ConstantUUIDMarker, Linker, ModuleUUID, ModuleUUIDMarker, NameElem, NameResolver, NamedConstant, NamedType, ResolvedName, Resolver, TypeUUIDMarker, WorkingOnResolver}, parser::Cursor, value::Value }; use super::name_context::LocalVariableContext; @@ -133,7 +133,7 @@ struct FlatteningContext<'l, 'errs> { name_resolver : NameResolver<'l, 'errs>, errors : &'errs ErrorCollector<'l>, - ports_to_visit : UUIDRange, + ports_to_visit : UUIDRangeIter, local_variable_context : LocalVariableContext<'l, FlatID> } @@ -298,72 +298,67 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { ) } - fn desugar_func_call(&mut self, cursor : &mut Cursor) -> Option<(ModuleUUID, Option, FlatID, PortIDRange)> { + fn alloc_error(&mut self, span : Span) -> FlatID { + self.working_on.instructions.alloc(Instruction::Wire(WireInstance{typ : FullType::new_unset(), span, source : WireSource::new_error()})) + } + + fn flatten_func_call(&mut self, cursor : &mut Cursor) -> Option { + let whole_func_span = cursor.span(); cursor.go_down(kind!("func_call"), |cursor| { cursor.field(field!("name")); - let instantiation_flat_id = self.get_or_alloc_module_by_global_identifier(cursor); + let function_root = self.get_or_alloc_module_by_global_identifier(cursor); cursor.field(field!("arguments")); let arguments_span = BracketSpan::from_outer(cursor.span()); - let arguments = cursor.collect_list(kind!("parenthesis_expression_list"), |cursor| { + let mut arguments = cursor.collect_list(kind!("parenthesis_expression_list"), |cursor| { self.flatten_expr(cursor) }); - let (instantiation_flat_id, submodule_name_span) = instantiation_flat_id?; - let func_instantiation = self.working_on.instructions[instantiation_flat_id].unwrap_submodule(); + let (submodule_instruction, name_span) = function_root?; + let func_module = self.working_on.instructions[submodule_instruction].unwrap_submodule(); - - let module_uuid = func_instantiation.module_uuid; + let module_uuid = func_module.module_uuid; let md = &self.modules[module_uuid]; - let inputs = md.interfaces[Module::MAIN_INTERFACE_ID].func_call_inputs; - let outputs = md.interfaces[Module::MAIN_INTERFACE_ID].func_call_outputs; - + let interface = &md.interfaces[Module::MAIN_INTERFACE_ID]; + let func_call_inputs = interface.func_call_inputs; + let func_call_outputs = interface.func_call_outputs; + let arg_count = arguments.len(); - let expected_arg_count = inputs.len(); + let expected_arg_count = func_call_inputs.len(); - let mut args = arguments.as_slice(); - if arg_count != expected_arg_count { if arg_count > expected_arg_count { // Too many args, complain about excess args at the end - let excess_args_span = Span::new_overarching(self.working_on.instructions[args[expected_arg_count]].unwrap_wire().span, self.working_on.instructions[*args.last().unwrap()].unwrap_wire().span); + let excess_args_span = Span::new_overarching(self.working_on.instructions[arguments[expected_arg_count]].unwrap_wire().span, self.working_on.instructions[*arguments.last().unwrap()].unwrap_wire().span); self.errors .error(excess_args_span, format!("Excess argument. Function takes {expected_arg_count} args, but {arg_count} were passed.")) .info_obj(&md.link_info); // Shorten args to still get proper type checking for smaller arg array - args = &args[..expected_arg_count]; + arguments.truncate(expected_arg_count); } else { // Too few args, mention missing argument names self.errors .error(arguments_span.close_bracket(), format!("Too few arguments. Function takes {expected_arg_count} args, but {arg_count} were passed.")) .info_obj(&md.link_info); - } - } - for (port_id, arg_read_side) in zip(inputs, args) { - let arg_wire = self.working_on.instructions[*arg_read_side].unwrap_wire(); - let arg_wire_span = arg_wire.span; - let root = WireReferenceRoot::SubModulePort(PortInfo{ - submodule_name_span, - submodule_flat : instantiation_flat_id, - port : port_id, - port_name_span : None, // Not present in function call notation - port_identifier_typ : IdentifierType::Input - }); - self.working_on.instructions.alloc(Instruction::Write(Write{ - from: *arg_read_side, - to: WireReference{ - root, - path : Vec::new(), - }, - to_span : arg_wire_span, - write_modifiers : WriteModifiers::Connection{num_regs : 0, regs_span : arg_wire_span.empty_span_at_front()} - })); + while arguments.len() < expected_arg_count { + arguments.push(self.alloc_error(arguments_span.close_bracket())); + } + } } - Some((module_uuid, submodule_name_span, instantiation_flat_id, outputs)) + Some(self.working_on.instructions.alloc(Instruction::FuncCall(FuncCallInstruction{ + submodule_instruction, + module_uuid, + arguments, + func_call_inputs, + func_call_outputs, + name_span, + arguments_span, + whole_func_span + }))) }) } @@ -445,29 +440,30 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { WireSource::BinaryOp{op, left, right} }) } else if kind == kind!("func_call") { - if let Some((md_id, submodule_name_span, submodule, outputs)) = self.desugar_func_call(cursor) { - if outputs.len() != 1 { - let md = &self.modules[md_id]; + if let Some(fc_id) = self.flatten_func_call(cursor) { + let fc = self.working_on.instructions[fc_id].unwrap_func_call(); + if fc.func_call_outputs.len() != 1 { + let md = &self.modules[fc.module_uuid]; self.errors .error(expr_span, "A function called in this context may only return one result. Split this function call into a separate line instead.") .info_obj(&md.link_info); } - if outputs.len() >= 1 { + if fc.func_call_outputs.len() >= 1 { WireSource::WireRef(WireReference::simple_port(PortInfo{ - submodule_name_span, - submodule_flat: submodule, - port: outputs.0, + submodule_name_span : fc.name_span, + submodule_flat: fc.submodule_instruction, + port: fc.func_call_outputs.0, port_name_span: None, port_identifier_typ: IdentifierType::Output, })) } else { // Function desugaring or using threw an error - WireSource::Constant(Value::Error) + WireSource::new_error() } } else { // Function desugaring or using threw an error - WireSource::Constant(Value::Error) + WireSource::new_error() } } else if kind == kind!("parenthesis_expression") { return cursor.go_down_content(kind!("parenthesis_expression"), |cursor| self.flatten_expr(cursor)); @@ -475,7 +471,7 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { if let Some(wr) = self.flatten_wire_reference(cursor).expect_ready(self) { WireSource::WireRef(wr) } else { - WireSource::Constant(Value::Error) + WireSource::new_error() } }; @@ -498,7 +494,7 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { let root = WireReferenceRoot::LocalDecl(decl_id, expr_span); PartialWireReference::Ready(WireReference{root, path : Vec::new()}) } - Instruction::Wire(_) | Instruction::Write(_) | Instruction::IfStatement(_) | Instruction::ForStatement(_) => unreachable!() + Instruction::Wire(_) | Instruction::Write(_) | Instruction::IfStatement(_) | Instruction::ForStatement(_) | Instruction::FuncCall(_) => unreachable!() } } LocalOrGlobal::Global(global) => { @@ -599,12 +595,17 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { fn flatten_assign_function_call(&mut self, to : Vec<(Option<(WireReference, WriteModifiers)>, Span)>, cursor : &mut Cursor) { let func_call_span = cursor.span(); - let to_iter = if let Some((md_id, submodule_name_span, submodule, outputs)) = self.desugar_func_call(cursor) { + let to_iter = if let Some(fc_id) = self.flatten_func_call(cursor) { + let fc = self.working_on.instructions[fc_id].unwrap_func_call(); + + let outputs = fc.func_call_outputs; + let submodule_name_span = fc.name_span; + let submodule_flat = fc.submodule_instruction; let num_func_outputs = outputs.len(); let num_targets = to.len(); if num_targets != num_func_outputs { - let md = &self.modules[md_id]; + let md = &self.modules[fc.module_uuid]; if num_targets > num_func_outputs { let excess_results_span = Span::new_overarching(to[num_func_outputs].1, to.last().unwrap().1); self.errors @@ -624,11 +625,11 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { typ: FullType::new_unset(), span: func_call_span, source: WireSource::WireRef(WireReference::simple_port(PortInfo{ - submodule_name_span, - submodule_flat: submodule, port, port_name_span: None, port_identifier_typ: IdentifierType::Output, + submodule_name_span, + submodule_flat, })) })); self.working_on.instructions.alloc(Instruction::Write(Write{from, to, to_span, write_modifiers})); @@ -640,7 +641,7 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> { }; for leftover_to in to_iter { if let (Some((to, write_modifiers)), to_span) = leftover_to { - let err_id = self.working_on.instructions.alloc(Instruction::Wire(WireInstance{typ : FullType::new_unset(), span : func_call_span, source : WireSource::Constant(Value::Error)})); + let err_id = self.working_on.instructions.alloc(Instruction::Wire(WireInstance{typ : FullType::new_unset(), span : func_call_span, source : WireSource::new_error()})); self.working_on.instructions.alloc(Instruction::Write(Write{from: err_id, to, to_span, write_modifiers})); } } @@ -869,7 +870,7 @@ pub fn flatten<'cursor_linker, 'errs>(linker : *mut Linker, module_uuid : Module println!("Flattening {}", modules.working_on.link_info.name); let mut context = FlatteningContext { - ports_to_visit : modules.working_on.ports.id_range(), + ports_to_visit : modules.working_on.ports.id_range().into_iter(), errors : name_resolver.errors, modules, types, diff --git a/src/flattening/typechecking.rs b/src/flattening/typechecking.rs index 188964b..6b3ed43 100644 --- a/src/flattening/typechecking.rs +++ b/src/flattening/typechecking.rs @@ -75,13 +75,25 @@ impl<'l, 'errs> DerefMut for TypeCheckingContext<'l, 'errs> { } impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { - fn get_decl_of_module_port<'s>(&'s self, port : PortInfo) -> (&'s Declaration, FileUUID) { - let submodule_id = self.working_on.instructions[port.submodule_flat].unwrap_submodule().module_uuid; + fn get_decl_of_module_port<'s>(&'s self, port : PortID, submodule_instr : FlatID) -> (&'s Declaration, FileUUID) { + let submodule_id = self.working_on.instructions[submodule_instr].unwrap_submodule().module_uuid; let module = &self.modules[submodule_id]; - let decl = module.get_port_decl(port.port); + let decl = module.get_port_decl(port); (decl, module.link_info.file) } + fn get_type_of_port(&self, port : PortID, submodule_instr : FlatID) -> FullType { + let (decl, _file) = self.get_decl_of_module_port(port, submodule_instr); + let submodule_inst = self.working_on.instructions[submodule_instr].unwrap_submodule(); + let submodule_module = &self.modules[submodule_inst.module_uuid]; + let port_interface = submodule_module.ports[port].interface; + let port_local_domain = submodule_inst.local_interface_domains[port_interface]; + FullType { + typ : decl.typ_expr.to_type(), + domain : DomainType::Physical(port_local_domain) + } + } + fn get_wire_ref_declaration_point(&self, wire_ref_root : &WireReferenceRoot) -> Option { match wire_ref_root { WireReferenceRoot::LocalDecl(id, _) => { @@ -93,8 +105,8 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { linker_cst.get_span_file() } WireReferenceRoot::SubModulePort(port) => { - let (decl, file) = self.get_decl_of_module_port(*port); - Some((decl.typ_expr.get_span(), file)) + let (decl, file) = self.get_decl_of_module_port(port.port, port.submodule_flat); + Some((decl.get_span(), file)) } } } @@ -110,15 +122,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { linker_cst.get_full_type() } WireReferenceRoot::SubModulePort(port) => { - let (decl, _file) = self.get_decl_of_module_port(*port); - let submodule_inst = self.working_on.instructions[port.submodule_flat].unwrap_submodule(); - let submodule_module = &self.modules[submodule_inst.module_uuid]; - let port_interface = submodule_module.ports[port.port].interface; - let port_local_domain = submodule_inst.local_interface_domains[port_interface]; - FullType { - typ : decl.typ_expr.to_type(), - domain : DomainType::Physical(port_local_domain) - } + self.get_type_of_port(port.port, port.submodule_flat) } }; @@ -152,6 +156,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { } match &self.working_on.instructions[inst_id] { Instruction::SubModule(_) => {} + Instruction::FuncCall(_) => {} Instruction::Declaration(decl) => { if decl.identifier_type.is_generative() { assert!(matches!(self.declaration_depths[inst_id], ExtraInstructionData::Unset)); @@ -174,7 +179,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { return; } WireReferenceRoot::SubModulePort(port) => { - let r = self.get_decl_of_module_port(port); + let r = self.get_decl_of_module_port(port.port, port.submodule_flat); if !r.0.identifier_type.unwrap_is_input() { self.errors.error(conn.to_span, "Cannot assign to a submodule output port") @@ -281,6 +286,20 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { let Instruction::Wire(w) = &mut self.working_on.instructions[instr_id] else {unreachable!()}; w.typ = result_typ; } + Instruction::FuncCall(fc) => { + for (port, arg) in std::iter::zip(fc.func_call_inputs.into_iter(), &fc.arguments) { + let write_to_type = self.get_type_of_port(port, fc.submodule_instruction); + + let (decl, file) = self.get_decl_of_module_port(port, fc.submodule_instruction); + let declared_here = (decl.get_span(), file); + + // Typecheck the value with target type + let from_wire = self.working_on.instructions[*arg].unwrap_wire(); + + from_wire.span.debug(); + self.type_checker.typecheck_write_to(&from_wire.typ, from_wire.span, &write_to_type, "function argument", Some(declared_here)); + } + } Instruction::Write(conn) => { // Typecheck digging down into write side let mut write_to_type = self.get_type_of_wire_reference(&conn.to, conn.to_span); @@ -405,6 +424,11 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> { } } Instruction::SubModule(_) => {} // TODO Dependencies should be added here if for example generative templates get added + Instruction::FuncCall(fc) => { + for a in &fc.arguments { + instruction_fanins[fc.submodule_instruction].push(*a); + } + } Instruction::Declaration(decl) => { decl.typ_expr.for_each_generative_input(|id| instruction_fanins[inst_id].push(id)); } diff --git a/src/instantiation/execute.rs b/src/instantiation/execute.rs index 444fd2d..1002f87 100644 --- a/src/instantiation/execute.rs +++ b/src/instantiation/execute.rs @@ -102,11 +102,15 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { }) } - fn instantiate_wire_ref(&self, wire_ref : &WireReference) -> ExecutionResult { - // Later on, potentially allow module arrays - let mut path = Vec::new(); + fn instantiate_port_wire_ref_root(&self, port : PortID, submodule_instr : FlatID) -> InstantiatedWireRef { + let sm = &self.submodules[self.generation_state[submodule_instr].unwrap_submodule_instance()]; + let root = RealWireRefRoot::Wire(sm.port_map[port]); + + InstantiatedWireRef{root, path : Vec::new()} + } - let root = match &wire_ref.root { + fn realize_wire_ref_root(&self, wire_ref_root : &WireReferenceRoot) -> InstantiatedWireRef { + let root = match wire_ref_root { &WireReferenceRoot::LocalDecl(decl_id, _) => { match &self.generation_state[decl_id] { SubModuleOrWire::Wire(w) => RealWireRefRoot::Wire(*w), @@ -120,11 +124,17 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { RealWireRefRoot::Constant(val.clone()) } WireReferenceRoot::SubModulePort(port) => { - let sm = &self.submodules[self.generation_state[port.submodule_flat].unwrap_submodule_instance()]; - RealWireRefRoot::Wire(sm.port_map[port.port]) + return self.instantiate_port_wire_ref_root(port.port, port.submodule_flat); } }; + InstantiatedWireRef{root, path : Vec::new()} + } + + fn instantiate_wire_ref(&self, wire_ref : &WireReference) -> InstantiatedWireRef { + // Later on, potentially allow module arrays + let mut result = self.realize_wire_ref_root(&wire_ref.root); + for v in &wire_ref.path { match v { &WireReferencePathElement::ArrayIdx{idx, bracket_span} => { @@ -134,10 +144,10 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { &SubModuleOrWire::Wire(idx_wire) => { assert!(self.wires[idx_wire].typ == INT_CONCRETE_TYPE); - path.push(RealWirePathElem::MuxArrayWrite{ span:bracket_span, idx_wire}); + result.path.push(RealWirePathElem::MuxArrayWrite{ span:bracket_span, idx_wire}); } SubModuleOrWire::CompileTimeValue(cv) => { - path.push(RealWirePathElem::ConstArrayWrite{ + result.path.push(RealWirePathElem::ConstArrayWrite{ idx : cv.value.unwrap_integer().clone(), span : bracket_span }); @@ -147,24 +157,29 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } - Ok(InstantiatedWireRef{root, path}) + result + } + + fn instantiate_write_to_wire(&mut self, write_to_wire : WireID, to_path : Vec, from : WireID, num_regs : i64, original_instruction : FlatID, condition : Option) { + let from = ConnectFrom { + num_regs, + from, + condition, + original_connection : original_instruction + }; + + let RealWireDataSource::Multiplexer{is_state : _, sources} = &mut self.wires[write_to_wire].source else {caught_by_typecheck!("Should only be a writeable wire here")}; + + sources.push(MultiplexerSource{from, to_path}); } - fn process_connection(&mut self, wire_ref_inst : InstantiatedWireRef, write_modifiers : &WriteModifiers, conn_from : FlatID, original_connection : FlatID, condition : Option) -> ExecutionResult<()> { + fn instantiate_connection(&mut self, wire_ref_inst : InstantiatedWireRef, write_modifiers : &WriteModifiers, conn_from : FlatID, original_connection : FlatID, condition : Option) -> ExecutionResult<()> { match write_modifiers { WriteModifiers::Connection{num_regs, regs_span : _} => { match &wire_ref_inst.root { RealWireRefRoot::Wire(write_to_wire) => { - let from = ConnectFrom { - num_regs : *num_regs, - from : self.get_wire_or_constant_as_wire(conn_from), - condition, - original_connection - }; - - let RealWireDataSource::Multiplexer{is_state : _, sources} = &mut self.wires[*write_to_wire].source else {caught_by_typecheck!("Should only be a writeable wire here")}; - - sources.push(MultiplexerSource{from, to_path : wire_ref_inst.path}); + let from = self.get_wire_or_constant_as_wire(conn_from); + self.instantiate_write_to_wire(*write_to_wire, wire_ref_inst.path, from, *num_regs, original_connection, condition); } RealWireRefRoot::Generative(decl_id) => { let found_v = self.generation_state[conn_from].unwrap_generation_value().clone(); @@ -216,7 +231,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { fn compute_compile_time(&self, wire_inst : &WireInstance) -> ExecutionResult { Ok(match &wire_inst.source { WireSource::WireRef(wire_ref) => { - let wire_ref_instance = self.instantiate_wire_ref(wire_ref)?; + let wire_ref_instance = self.instantiate_wire_ref(wire_ref); self.compute_compile_time_wireref(wire_ref_instance)? } &WireSource::UnaryOp{op, right} => { @@ -263,7 +278,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { fn wire_to_real_wire(&mut self, w: &WireInstance, original_instruction : FlatID) -> ExecutionResult { let source = match &w.source { WireSource::WireRef(wire_ref) => { - let inst = self.instantiate_wire_ref(wire_ref)?; + let inst = self.instantiate_wire_ref(wire_ref); let root_wire = self.get_wire_ref_root_as_wire(inst.root, original_instruction); if inst.path.is_empty() { // Little optimization reduces instructions @@ -405,8 +420,19 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } Instruction::Write(conn) => { - let to_inst = self.instantiate_wire_ref(&conn.to)?; - self.process_connection(to_inst, &conn.write_modifiers, conn.from, original_instruction, condition)?; + let to_inst = self.instantiate_wire_ref(&conn.to); + self.instantiate_connection(to_inst, &conn.write_modifiers, conn.from, original_instruction, condition)?; + continue; + } + Instruction::FuncCall(fc) => { + let submod_id = self.generation_state[fc.submodule_instruction].unwrap_submodule_instance(); + for (port, arg) in std::iter::zip(fc.func_call_inputs.iter(), fc.arguments.iter()) { + let from = self.get_wire_or_constant_as_wire(*arg); + let submod = &self.submodules[submod_id]; + let port_wire = submod.port_map[port]; + self.instantiate_write_to_wire(port_wire, Vec::new(), from, 0, original_instruction, condition); + } + continue; } Instruction::IfStatement(stm) => { diff --git a/src/instantiation/latency_count.rs b/src/instantiation/latency_count.rs index c5c3425..d8ccc6e 100644 --- a/src/instantiation/latency_count.rs +++ b/src/instantiation/latency_count.rs @@ -77,9 +77,9 @@ fn make_path_info_string(writes : &[PathMuxSource<'_>], from_latency : i64, from fn filter_unique_write_flats<'w>(writes : &'w [PathMuxSource<'w>], instructions : &'w FlatAlloc) -> Vec<&'w crate::flattening::Write> { let mut result : Vec<&'w crate::flattening::Write> = Vec::new(); for w in writes { - let original_write = instructions[w.mux_input.from.original_connection].unwrap_write(); - - if !result.iter().any(|found_write| std::ptr::eq(*found_write, original_write)) {result.push(original_write)} + if let Instruction::Write(original_write) = &instructions[w.mux_input.from.original_connection] { + if !result.iter().any(|found_write| std::ptr::eq(*found_write, original_write)) {result.push(original_write)} + } } result } diff --git a/src/instantiation/mod.rs b/src/instantiation/mod.rs index 791f010..acfa1d3 100644 --- a/src/instantiation/mod.rs +++ b/src/instantiation/mod.rs @@ -10,7 +10,7 @@ use std::{cell::RefCell, ops::Deref, rc::Rc}; use num::BigInt; use crate::{ - arena_alloc::{FlatAlloc, UUIDMarker, UUID}, config, errors::{CompileError, ErrorCollector, ErrorStore}, file_position::BracketSpan, flattening::{BinaryOperator, FlatID, FlatIDMarker, Module, PortID, PortIDMarker, UnaryOperator}, linker::{Linker, ModuleUUID}, concrete_type::ConcreteType, value::{TypedValue, Value} + arena_alloc::{FlatAlloc, UUIDMarker, UUID}, concrete_type::ConcreteType, config, errors::{CompileError, ErrorCollector, ErrorStore}, file_position::BracketSpan, flattening::{BinaryOperator, FlatID, FlatIDMarker, Module, PortID, PortIDMarker, UnaryOperator}, linker::{Linker, ModuleUUID}, value::{TypedValue, Value} }; use self::latency_algorithm::SpecifiedLatency; diff --git a/test.sus b/test.sus index 0f07aee..ef75fee 100644 --- a/test.sus +++ b/test.sus @@ -808,3 +808,12 @@ module test_separated_domain : int main { } + +module no_port_module {} + +module use_no_input_module { + no_port_module() + + no_port_module no_port + no_port() +}