Skip to content

Commit

Permalink
Do not combine condition domain with every wire reference, only writes
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Jul 5, 2024
1 parent bfdec36 commit 7073206
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 173 deletions.
48 changes: 31 additions & 17 deletions src/flattening/typechecking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
}

fn get_type_of_wire_reference(&self, wire_ref : &WireReference, wire_ref_span : Span) -> FullType {
fn get_type_of_wire_reference(&self, wire_ref : &WireReference) -> FullType {
let mut write_to_type = match &wire_ref.root {
WireReferenceRoot::LocalDecl(id, _) => {
let decl_root = self.working_on.instructions[*id].unwrap_wire_declaration();
Expand All @@ -122,14 +122,6 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
};

if let Some(condition_domain) = self.get_current_condition_domain() {
write_to_type.domain = self.type_checker.combine_domains::<false, _>(&write_to_type.domain, &DomainType::Physical(condition_domain.0), |wire_ref_domain_name, condition_domain_name| {
let wire_ref_domain_name = wire_ref_domain_name.unwrap();
self.errors.error(wire_ref_span, format!("Attempting to access a wire from domain '{wire_ref_domain_name}' within a condition in domain '{condition_domain_name}'"))
.info_same_file(condition_domain.1, format!("This condition has domain '{condition_domain_name}'"));
})
}

for p in &wire_ref.path {
match p {
&WireReferencePathElement::ArrayAccess{idx, bracket_span} => {
Expand Down Expand Up @@ -264,6 +256,25 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
}

/// TODO: writes to declarations that are in same scope need not be checked as such.
///
/// This allows to work with temporaries of a different domain within an if statement
///
/// Which could allow for a little more encapsulation in certain circumstances
///
/// Also, this meshes with the thing where we only add condition wires to writes that go
/// outside of a condition block
fn join_with_condition(&self, ref_domain : &DomainType, span : Span) {
if let Some(condition_domain) = self.get_current_condition_domain() {
// Just check that
self.type_checker.combine_domains::<false, _>(ref_domain, &DomainType::Physical(condition_domain.0), |wire_ref_domain_name, condition_domain_name| {
let wire_ref_domain_name = wire_ref_domain_name.unwrap();
self.errors.error(span, format!("Attempting to write to a wire from domain '{wire_ref_domain_name}' within a condition in domain '{condition_domain_name}'"))
.info_same_file(condition_domain.1, format!("This condition has domain '{condition_domain_name}'"));
});
}
}

fn typecheck_visit_instruction(&mut self, instr_id : FlatID) {
match &self.working_on.instructions[instr_id] {
Instruction::SubModule(sm) => {
Expand Down Expand Up @@ -298,14 +309,13 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
let start = &self.working_on.instructions[stm.start].unwrap_wire();
let end = &self.working_on.instructions[stm.end].unwrap_wire();

let declared_here = Some((loop_var.decl_span, self.errors.file));
self.type_checker.typecheck_and_generative::<true>(&start.typ, start.span, &loop_var.typ.typ, "for loop start", declared_here);
self.type_checker.typecheck_and_generative::<true>(&end.typ, end.span, &loop_var.typ.typ, "for loop end", declared_here);
self.type_checker.typecheck_and_generative::<true>(&start.typ, start.span, &loop_var.typ.typ, "for loop start", None);
self.type_checker.typecheck_and_generative::<true>(&end.typ, end.span, &loop_var.typ.typ, "for loop end", None);
}
Instruction::Wire(w) => {
let result_typ = match &w.source {
WireSource::WireRef(from_wire) => {
self.get_type_of_wire_reference(from_wire, w.span)
self.get_type_of_wire_reference(from_wire)
}
&WireSource::UnaryOp{op, right} => {
let right_wire = self.working_on.instructions[right].unwrap_wire();
Expand Down Expand Up @@ -333,17 +343,22 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
// Typecheck the value with target type
let from_wire = self.working_on.instructions[*arg].unwrap_wire();

from_wire.span.debug();
self.join_with_condition(&write_to_type.domain, from_wire.span.debug());
self.type_checker.typecheck_write_to(&from_wire.typ, from_wire.span, &write_to_type, "function argument", Some(declared_here));
}
}
Instruction::Write(conn) => {
// Typecheck digging down into write side
let mut write_to_type = self.get_type_of_wire_reference(&conn.to, conn.to_span);
let mut write_to_type = self.get_type_of_wire_reference(&conn.to);

self.join_with_condition(&write_to_type.domain, conn.to_span.debug());

let declared_here = self.get_wire_ref_declaration_point(&conn.to.root);

let write_context = match conn.write_modifiers {
WriteModifiers::Connection { num_regs:_, regs_span:_ } => "connection",
WriteModifiers::Connection { num_regs:_, regs_span:_ } => {
"connection"
},
WriteModifiers::Initial { initial_kw_span:_ } => {
write_to_type.domain = DomainType::Generative;
"initial value"
Expand Down Expand Up @@ -371,7 +386,6 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}

for elem_id in self.working_on.instructions.id_range() {
println!("{elem_id:?}: {:?}", &self.type_checker);
self.control_flow_visit_instruction(elem_id);
self.typecheck_visit_instruction(elem_id);
}
Expand Down
2 changes: 1 addition & 1 deletion src/instantiation/latency_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
match target {
LatencyMeaning::Wire(wire_id) => {
self.iter_sources_with_min_latency(&self.wires[*wire_id].source, |from, delta_latency| {
assert!(self.wires[from].domain == domain_id);
assert_eq!(self.wires[from].domain, domain_id);
fanins.push_to_last_group(FanInOut{other : latency_node_mapper.map_wire_to_latency_node[from], delta_latency});
});
}
Expand Down
155 changes: 155 additions & 0 deletions test.sus
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,158 @@ module templateModule<abc> {

}
*/

module test<T> {
interface test : ::int::<beep = 20 > 3; BEEP = int::<;>> ab
input gen int MY_INPUT

MY_INPUT = 3

input int beep

beep = 3

FIFO::<BITWIDTH = 4;> badoop
}

module use_test {
test::<3;> test_mod


}


module tinyTestMod {
input gen int beep

output int o = beep
}


module testTinyTestMod {
tinyTestMod::<beep = 3;> a
tinyTestMod::<beep = 4;> b
tinyTestMod::<beep = 3;> c
}



module tree_add {
input gen int WIDTH

input int[WIDTH] values
output int sum

if WIDTH == 1 {
sum = values[0]
} else {
gen int HALF_WIDTH = WIDTH / 2
tree_add::<HALF_WIDTH;> left
tree_add::<HALF_WIDTH;> right

for int i in 0..HALF_WIDTH {
left.values[i] = values[i]
right.values[i] = values[i+HALF_WIDTH]
}

if WIDTH % 2 == 0 {
reg sum = left.sum + right.sum
} else {
reg sum = left.sum + right.sum + values[WIDTH - 1]
}
}
}

module make_tree_add {
gen int SIZE = 255

int[SIZE] vs

for int i in 0..SIZE {
vs[i] = i
}

tree_add::<SIZE;> tr

tr.values = vs

output int beep = tr.sum
}


module replicate<T> {
input gen int NUM_REPLS

input T data

output T[NUM_REPLS] result

for int i in 0..NUM_REPLS {
result[i] = data
}
}

module use_replicate {
replicate::<NUM_REPLS = 50; T = bool> a
replicate::<NUM_REPLS = 20; T = int[30]> b
}

module permute_t<T> {
input gen int SIZE

input gen int[SIZE] SOURCES

interface permute : T[SIZE] d_in -> T[SIZE] d_out

for int i in 0..SIZE {
d_out[i] = d_in[SOURCES[i]]
}
}

module use_permute {
gen int[8] indices

indices[0] = 3
indices[1] = 2
indices[2] = 4
indices[3] = 5
indices[4] = 1
indices[5] = 2
indices[6] = 7
indices[7] = 6


int[2] inArr

inArr[0] = 2387
inArr[1] = 786823

permute_t::<SIZE = 8, SOURCES = indices; T = int> permut

int[8] beep = permut.permute(indices)
}

module instruction_decoder {
interface from : bool[32] instr
interface is_jump
interface is_load
interface is_arith

}

module run_instruction {
interface run_instruction : bool[32] instr
instruction_decoder decoder
decoder.from(instr)

if decoder.is_jump() : int target_addr {
// ...
}
if decoder.is_load() : int reg_to, int addr {
// ...
}
if decoder.is_arith() : int reg_a, int reg_b, Operator op {
// ...
}
}

Loading

0 comments on commit 7073206

Please sign in to comment.