diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index 0035423b29..f64e8c0323 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -59,6 +59,10 @@ type | type '[' intLit ']' // Vector ; +expField + : fieldId ':' exp + ; + field : 'flip'? fieldId ':' type ; @@ -168,6 +172,8 @@ exp | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' + | '{' expField* '}' // Bundle Expression + | '[' exp* ']' // Vector Expression ; id diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 3d2c89f11d..c4b15fed1b 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -377,6 +377,20 @@ message Firrtl { Width point = 3; } + message BundleExpression { + message Field { + // Required + string name = 1; + // Required + Expression value = 2; + } + repeated Field field = 1; + } + + message VectorExpression { + repeated Expression exp = 1; + } + message ValidIf { // Required. Expression condition = 1; @@ -475,6 +489,8 @@ message Firrtl { UIntLiteral uint_literal = 2; SIntLiteral sint_literal = 3; FixedLiteral fixed_literal = 11; + BundleExpression bundle_expression = 12; + VectorExpression vector_expression = 13; ValidIf valid_if = 4; //ExtractBits extract_bits = 5; Mux mux = 6; diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index d75c504a1c..292388c965 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -592,6 +592,8 @@ object Utils extends LazyLogging { case ex: DoPrim => MALE case ex: UIntLiteral => MALE case ex: SIntLiteral => MALE + case ex: BundleExpression => MALE + case ex: VectorExpression => MALE case ex: Mux => MALE case ex: ValidIf => MALE case WInvalid => MALE @@ -623,6 +625,8 @@ object Utils extends LazyLogging { case ex: DoPrim => SourceFlow case ex: UIntLiteral => SourceFlow case ex: SIntLiteral => SourceFlow + case ex: BundleExpression => SourceFlow + case ex: VectorExpression => SourceFlow case ex: Mux => SourceFlow case ex: ValidIf => SourceFlow case WInvalid => SourceFlow diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 112343d1e3..b9f4da2938 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -175,6 +175,10 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } } + private def visitExpField[FirrtlNode](ctx: ExpFieldContext): (String, Expression) = { + (ctx.fieldId.getText, visitExp(ctx.exp)) + } + // Special case "type" of CHIRRTL mems because their size can be BigInt private def visitCMemType(ctx: TypeContext): (Type, BigInt) = { def loc: String = s"${ctx.getStart.getLine}:${ctx.getStart.getCharPositionInLine}" @@ -387,6 +391,10 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) + case "{" => + BundleExpression(ctx.expField.asScala.map(visitExpField)) + case "[" => + VectorExpression(ctx_exp.map(visitExp), UnknownType) } } } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index eb4a665f22..2019ba51b4 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -75,6 +75,19 @@ object WSubField { def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow) def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow) } +object WSubLiteral { + def unapply(w: Expression): Option[Expression] = w match { + case WSubField(BundleExpression(lits), name, _, _) => + lits.collectFirst({ case (n, value) if n == name => value }) + case WSubField(WSubLiteral(BundleExpression(lits)), name, _, _) => + lits.collectFirst({ case (n, value) if n == name => value }) + case WSubIndex(VectorExpression(exprs, _), index, _, _) => + exprs.lift(index) + case WSubIndex(WSubLiteral(VectorExpression(exprs, _)), index, _, _) => + exprs.lift(index) + case _ => None + } +} case class WSubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow) extends Expression with GenderFromFlow { def serialize: String = s"${expr.serialize}[$value]" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index c7cb98d234..537aa45f21 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -115,6 +115,8 @@ object Target { case d: DoPrim => m.ref("@" + d.serialize) case d: Mux => m.ref("@" + d.serialize) case d: ValidIf => m.ref("@" + d.serialize) + case b: BundleExpression => m.ref("@" + b.serialize) + case v: VectorExpression => m.ref("@" + v.serialize) case d: Literal => m.ref("@" + d.serialize) case other => sys.error(s"Unsupported: $other") } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 07cbb7e222..ead0b57171 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -240,6 +240,40 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def foreachType(f: Type => Unit): Unit = Unit def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } +case class BundleExpression(lits: Seq[(String, Expression)]) extends Expression { + def tpe = BundleType(lits.map { case (name, lit) => + Field(name = name, flip = Default, tpe = lit.tpe) + }) + def serialize = + "{ " + (lits.map({ case (n, v) => + s"$n : ${v.serialize}" + }) mkString ", ") + " }" + def foreachExpr(f: Expression => Unit): Unit = lits.foreach { case (_, l) => l foreachExpr f } + def foreachType(f: Type => Unit): Unit = lits.foreach { case (_, l) => l foreachType f } + def foreachWidth(f: Width => Unit): Unit = lits.foreach { case (_, l) => l foreachWidth f } + def mapExpr(f: Expression => Expression): Expression = BundleExpression(lits.map { case (s, l) => + (s, l mapExpr f) + }) + def mapType(f: Type => Type): Expression = BundleExpression(lits.map { case (s, l) => + (s, l mapType f) + }) + def mapWidth(f: Width => Width): Expression = BundleExpression(lits.map { case (s, l) => + (s, l mapWidth f) + }) +} +case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expression { + def serialize = + "[" + exprs.map(_.serialize).mkString(", ") + "]" // TODO type annotation + def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(_ foreachExpr f) + def foreachType(f: Type => Unit): Unit = exprs.foreach(_ foreachType f) + def foreachWidth(f: Width => Unit): Unit = exprs.foreach(_ foreachWidth f) + def mapExpr(f: Expression => Expression): Expression = + VectorExpression(exprs.map(_ mapExpr f), tpe) + def mapType(f: Type => Type): Expression = + VectorExpression(exprs.map(_ mapType f), tpe) + def mapWidth(f: Width => Width): Expression = + VectorExpression(exprs.map(_ mapWidth f), tpe) +} case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + (args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")" diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 3c9a7eda76..1e3f1d9f20 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -173,7 +173,8 @@ trait CheckHighFormLike { this: Pass => def validSubexp(info: Info, mname: String)(e: Expression): Unit = { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error - case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf | + _: BundleExpression | _: VectorExpression => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } } @@ -470,6 +471,8 @@ object CheckTypes extends Pass with PreservesAll[Transform] { case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) + case _: BundleType | + _: VectorType => (isUInt, isSInt, isClock, isFix, isAsync, isInterval) case UnknownType => errors.append(new IllegalUnknownType(info, mname, e.serialize)) (isUInt, isSInt, isClock, isFix, isAsync, isInterval) diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 7e65bdd12a..2cbcf3e882 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -65,6 +65,11 @@ object ConvertFixedToSInt extends Pass with PreservesAll[Transform] { newExp map updateExpType case e: UIntLiteral => e case e: SIntLiteral => e + case e: BundleExpression => e + case e: VectorExpression => + val point = calcPoint(e.exprs) + val newExprs = e.exprs.map(alignArg(_, point)) + VectorExpression(newExprs, UnknownType) case _ => e map updateExpType match { case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index d625b6269c..ee4e596661 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -45,7 +45,12 @@ object InferTypes extends Pass with PreservesAll[Transform] { case e: DoPrim => PrimOps.set_primop_type(e) case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) case e: ValidIf => e copy (tpe = e.value.tpe) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e: VectorExpression => e copy (tpe = VectorType( + e.exprs.map(_.tpe).reduce[Type](mux_type_and_widths), + e.exprs.length + )) + case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { @@ -106,6 +111,11 @@ object CInferTypes extends Pass with PreservesAll[Transform] { case (e: DoPrim) => PrimOps.set_primop_type(e) case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval)) case (e: ValidIf) => e copy (tpe = e.value.tpe) + case e: VectorExpression => e copy (tpe = VectorType( + e.exprs.map(_.tpe).reduce[Type](mux_type), + e.exprs.length + )) + case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) case e @ (_: UIntLiteral | _: SIntLiteral) => e } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 73ef8a22df..365096760c 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -147,8 +147,11 @@ object LowerTypes extends Transform { def lowerTypesExp(memDataTypeMap: MemDataTypeMap, info: Info, mname: String)(e: Expression): Expression = e match { - case e: WRef => e - case (_: WSubField | _: WSubIndex) => kind(e) match { + case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleExpression | _: VectorExpression) => e + case WSubLiteral(exp) => + lowerTypesExp(memDataTypeMap, info, mname)(exp) + case WSubIndex(ve: VectorExpression, index, _, _) => ve.exprs(index) + case e @ (_: WSubField | _: WSubIndex) => kind(e) match { case InstanceKind => val (root, tail) = splitRef(e) val name = loweredName(tail) @@ -162,10 +165,8 @@ object LowerTypes extends Transform { } case _ => WRef(loweredName(e), e.tpe, kind(e), flow(e)) } - case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname) - case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname) - case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e @ (_: Mux | _: ValidIf | _: DoPrim | _: VectorExpression) => + e map lowerTypesExp(memDataTypeMap, info, mname) } def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 1268cac278..9eeaa165f0 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -218,7 +218,8 @@ object Uniquify extends Transform { val (subExp, subMap) = rec(e.expr, m) val index = uniquifyNamesExp(e.index, map) (WSubAccess(subExp, index, e.tpe, e.flow), subMap) - case (_: UIntLiteral | _: SIntLiteral) => (exp, m) + case _: VectorExpression => (exp, m) + case (_: UIntLiteral | _: SIntLiteral | _: BundleExpression) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) } @@ -267,6 +268,8 @@ object Uniquify extends Transform { uniquifyNamesExp(e, nameMap.toMap) case e: Mux => e map uniquifyExp case e: ValidIf => e map uniquifyExp + case e: BundleExpression => e map uniquifyExp + case e: VectorExpression => e map uniquifyExp case (_: UIntLiteral | _: SIntLiteral) => e case e: DoPrim => e map uniquifyExp } diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index ef2ee5bdff..6c4ae3cd71 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -76,6 +76,21 @@ object FromProto { ir.FixedLiteral(convert(fixed.getValue), width, point) } + def convert(fixed: Firrtl.Expression.BundleExpression.Field): (String, ir.Expression) = { + val value = convert(fixed.getValue) + value match { + case l: ir.Expression => (fixed.getName, l) + } + } + + def convert(fixed: Firrtl.Expression.BundleExpression): ir.BundleExpression = { + ir.BundleExpression(fixed.getFieldList.asScala.map(convert(_))) + } + + def convert(fixed: Firrtl.Expression.VectorExpression): ir.VectorExpression = { + ir.VectorExpression(fixed.getExpList.asScala.map(convert(_)), ir.UnknownType) + } + def convert(subfield: Firrtl.Expression.SubField): ir.SubField = ir.SubField(convert(subfield.getExpression), subfield.getField, ir.UnknownType) @@ -104,6 +119,8 @@ object FromProto { case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral) case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral) + case BUNDLE_EXPRESSION_FIELD_NUMBER => convert(expr.getBundleExpression) + case VECTOR_EXPRESSION_FIELD_NUMBER => convert(expr.getVectorExpression) case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) case MUX_FIELD_NUMBER => convert(expr.getMux) } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 9fe01a0700..e32dbdaf1c 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -176,6 +176,22 @@ object ToProto { convert(width).foreach(fb.setWidth) convert(point).foreach(fb.setPoint) eb.setFixedLiteral(fb) + case ir.BundleExpression(fields) => + val bb = Firrtl.Expression.BundleExpression.newBuilder() + fields.foreach({ case (n, v) => + val fb = Firrtl.Expression.BundleExpression.Field.newBuilder() + fb.setName(n) + fb.setValue(convert(v)) + bb.addField(fb) + }) + eb.setBundleExpression(bb) + case ir.VectorExpression(exps, _) => + val bb = Firrtl.Expression.VectorExpression.newBuilder() + exps.foreach({ case e => + bb.addExp(convert(e)) + }) + + eb.setVectorExpression(bb) case ir.DoPrim(op, args, consts, _) => val db = Firrtl.Expression.PrimOp.newBuilder() .setOp(convert(op)) diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index b0e5727b1d..ad5d87a85d 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -115,6 +115,25 @@ class LowerTypesSpec extends FirrtlFlatSpec { executeTest(input, expected) } + it should "lower registers with aggregate expression initialization" in { + val input = + """circuit Test : + | module Test : + | input clock : Clock + | input reset : UInt<1> + | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : + | reset => (reset, [{ a : UInt<1>("h0"), b : UInt<1>("h0") }, { a : UInt<1>("h1"), b : UInt<1>("h1") }]) + """.stripMargin + val expected = Seq( + "reg x_0_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))", + "reg x_0_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))", + "reg x_1_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))", + "reg x_1_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))" + ) map normalized + + executeTest(input, expected) + } + it should "lower DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index 612a952d9c..f2dce1cae4 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -3,8 +3,11 @@ package firrtlTests import firrtl._ +import firrtl.passes._ import FirrtlCheckers._ +import java.io.StringWriter + class MemSpec extends FirrtlPropSpec with FirrtlMatchers { property("Zero-ported mems should be supported!") { @@ -15,6 +18,249 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { runFirrtlTest("ZeroWidthMem", "/features") } + property("Writing to mems with bundle literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> } + | output o : { a: { b: UInt<32> }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= { a: { b: UInt<32>("h0") }, c: UInt<32>("h0") } + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> } + | output o : { a: { b: UInt<32> }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w.a.b <= UInt<32>("h0") + | w.c <= UInt<32>("h0") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + def removeMask(s: String) = + s.split("\n") filter { l => !(l contains "mask") } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeMask(iWriter.toString())) + val cCircuit = parse(removeMask(cWriter.toString())) + iCircuit should be (cCircuit) + } + + property("Writing to mems with a vector of bundle literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ReplaceAccesses, + ExpandConnects, + LowerTypes + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> }[3] + | output o : { a: { b: UInt<32> }, c: UInt<32> }[3] + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[3][16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= [ { a: { b: UInt<32>("h0") }, c: UInt<32>("h0") }, { a: { b: UInt<32>("h1") }, c: UInt<32>("h1") }, { a: { b: UInt<32>("h4") }, c: UInt<32>("h4") }] + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> }[3] + | output o : { a: { b: UInt<32> }, c: UInt<32> }[3] + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[3][16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w[0].a.b <= UInt<32>("h0") + | w[0].c <= UInt<32>("h0") + | w[1].a.b <= UInt<32>("h1") + | w[1].c <= UInt<32>("h1") + | w[2].a.b <= UInt<32>("h4") + | w[2].c <= UInt<32>("h4") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Input done") + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Output done") + // All the assignments to addr, clk, mask, en, etc. will be in different orders + // Get rid of them, leaving just data + def removeJunk(s: String) = + s.split("\n") filter { l => !( + (l contains "mask") || + (l contains "invalid") || + (l contains ".addr") || + (l contains ".clk") || + (l contains ".en") || + (l contains ".mask") + ) } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeJunk(iWriter.toString())) + val cCircuit = parse(removeJunk(cWriter.toString())) + iCircuit should be (cCircuit) + } + + property("Writing to mems with a bundle of vector literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ReplaceAccesses, + ExpandConnects, + LowerTypes + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32>[3] }, c: UInt<32> } + | output o : { a: { b: UInt<32>[3] }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32>[3] }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= { a: { b: [ UInt<32>("h0"), UInt<32>("h3"), UInt<32>("h6")] }, c: UInt<32>("h1") } + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32>[3] }, c: UInt<32> } + | output o : { a: { b: UInt<32>[3] }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32>[3] }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w.a.b[0] <= UInt<32>("h0") + | w.a.b[1] <= UInt<32>("h3") + | w.a.b[2] <= UInt<32>("h6") + | w.c <= UInt<32>("h1") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Input done") + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Output done") + // All the assignments to addr, clk, mask, en, etc. will be in different orders + // Get rid of them, leaving just data + def removeJunk(s: String) = + s.split("\n") filter { l => !( + (l contains "mask") || + (l contains "invalid") || + (l contains ".addr") || + (l contains ".clk") || + (l contains ".en") || + (l contains ".mask") + ) } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeJunk(iWriter.toString())) + val cCircuit = parse(removeJunk(cWriter.toString())) + iCircuit should be (cCircuit) + } + property("Very large memories should be supported") { val addrWidth = 65 val memSize = BigInt(1) << addrWidth diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 4f28e10058..f2b32dd168 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -201,6 +201,7 @@ class ParserPropSpec extends FirrtlPropSpec { def legalStartChar = Gen.frequency((1, '_'), (20, Gen.alphaChar)) def legalChar = Gen.frequency((1, Gen.numChar), (1, '$'), (10, legalStartChar)) + def uintValues = Gen.choose(0, 1000000) def identifier = for { x <- legalStartChar @@ -236,4 +237,31 @@ class ParserPropSpec extends FirrtlPropSpec { } } } + property("Bundle expressions should be OK") { + forAll (identifier, bundleField, uintValues) { case (id, field, uval) => + whenever(id.nonEmpty && field.nonEmpty) { + val input = s""" + |circuit Test : + | module Test : + | output $id : { $field : UInt<32> } + | $id <= { $field : UInt<32>("h${uval.toHexString}") } + |""".stripMargin + firrtl.Parser.parse(input split "\n") + } + } + } + property("Vector expressions should be OK") { + forAll (identifier, uintValues) { case (id, uval) => + whenever(id.nonEmpty) { + val entries = (0 until 6).map(x => "UInt(\"h" + (x + uval).toHexString + "\")") + val input = s""" + |circuit Test : + | module Test : + | output $id : UInt<32>[6] + | $id <= [${entries.mkString(", ")}] + |""".stripMargin + firrtl.Parser.parse(input split "\n") + } + } + } } diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 7f41fb2610..27075da298 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -114,6 +114,28 @@ class ProtoBufSpec extends FirrtlFlatSpec { FromProto.convert(ToProto.convert(flit).build) should equal (flit) } + it should "support Bundle Literals" in { + val ulit = ir.UIntLiteral(123, ir.IntWidth(32)) + FromProto.convert(ToProto.convert(ulit).build) should equal (ulit) + + val slit = ir.SIntLiteral(-123, ir.IntWidth(32)) + FromProto.convert(ToProto.convert(slit).build) should equal (slit) + + val flit = ir.FixedLiteral(-123, ir.IntWidth(32), ir.IntWidth(30)) + FromProto.convert(ToProto.convert(flit).build) should equal (flit) + val blit = ir.BundleExpression(Seq( + ("a", ulit), + ("b", ir.BundleExpression(Seq( ("c", slit), ("d", flit)))) + )) + FromProto.convert(ToProto.convert(blit).build) should equal (blit) + } + + it should "support Vector Expressions" in { + val ulits = for (i <- 0 until 10) yield ir.UIntLiteral(i + 3, ir.UnknownWidth) + val ve = VectorExpression(ulits, ir.UnknownType) + FromProto.convert(ToProto.convert(ve).build) should equal (ve) + } + it should "support Analog and Attach" in { val analog = ir.AnalogType(IntWidth(8)) FromProto.convert(ToProto.convert(analog).build) should equal (analog) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 8788bac7a3..e8d514ec50 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -90,6 +90,93 @@ class UnitTests extends FirrtlFlatSpec { } } + "Connecting bundle literals" should "work" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes + ) + val input = + """circuit Unit : + | module Unit : + | output x : { a: { b: UInt<32> }, c: UInt<32> } + | output y : UInt<32> + | x <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") } + | y <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") }.a.b + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | output x : { a: { b: UInt<32> }, c: UInt<32> } + | output y : UInt<32> + | x.a.b <= UInt<32>("h5") + | x.c <= UInt<32>("h6") + | y <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") }.a.b + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val iWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(cResult, cWriter) + (parse(iWriter.toString())) should be (parse(cWriter.toString())) + } + + "Connecting vector expressions" should "work" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes + ) + val input = + """circuit Unit : + | module Unit : + | output x : UInt<32>[6] + | x <= [ UInt("h1"), UInt("h2"), UInt("h4"), UInt("h8"), UInt("h10"), UInt("h20") ] + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | output x_0 : UInt<32> + | output x_1 : UInt<32> + | output x_2 : UInt<32> + | output x_3 : UInt<32> + | output x_4 : UInt<32> + | output x_5 : UInt<32> + | x_0 <= UInt("h1") + | x_1 <= UInt("h2") + | x_2 <= UInt("h4") + | x_3 <= UInt("h8") + | x_4 <= UInt("h10") + | x_5 <= UInt("h20") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val iWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(cResult, cWriter) + (parse(iWriter.toString())) should be (parse(cWriter.toString())) + } + "Partial connection two bundle types whose relative flips don't match but leaf node directions do" should "connect correctly" in { val passes = Seq( ToWorkingIR,