From 68b8834de78c38fb65178f38f3b4c582cc5b0d52 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Thu, 1 Nov 2018 08:03:16 +0000 Subject: [PATCH 01/13] Add Bundle Literals. --- src/main/antlr4/FIRRTL.g4 | 5 +++ src/main/proto/firrtl.proto | 11 ++++++ src/main/scala/firrtl/Utils.scala | 1 + src/main/scala/firrtl/Visitor.scala | 12 +++++++ src/main/scala/firrtl/WIR.scala | 9 +++++ src/main/scala/firrtl/ir/IR.scala | 16 +++++++++ .../firrtl/passes/ConvertFixedToSInt.scala | 1 + src/main/scala/firrtl/passes/InferTypes.scala | 4 +-- src/main/scala/firrtl/passes/LowerTypes.scala | 4 ++- src/main/scala/firrtl/passes/Uniquify.scala | 4 +-- src/main/scala/firrtl/proto/FromProto.scala | 12 +++++++ src/main/scala/firrtl/proto/ToProto.scala | 9 +++++ src/test/scala/firrtlTests/ParserSpec.scala | 14 ++++++++ src/test/scala/firrtlTests/ProtoBufSpec.scala | 16 +++++++++ src/test/scala/firrtlTests/UnitTests.scala | 35 +++++++++++++++++++ 15 files changed, 148 insertions(+), 5 deletions(-) diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index cc5d0a1614..1dc5195fb8 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -61,6 +61,10 @@ type | type '[' intLit ']' // Vector ; +litField + : fieldId ':' exp + ; + field : 'flip'? fieldId ':' type ; @@ -170,6 +174,7 @@ exp | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' + | '{' litField* '}' ; id diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 7be042ab79..3c38340abc 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -323,6 +323,16 @@ message Firrtl { string id = 1; } + message BundleLiteral { + message LiteralField { + // Required + string name = 1; + // Required + Expression value = 2; + } + repeated LiteralField field = 1; + } + message IntegerLiteral { // Base 10 value. May begin with a sign (+|-). Only zero can begin with a // '0'. @@ -441,6 +451,7 @@ message Firrtl { UIntLiteral uint_literal = 2; SIntLiteral sint_literal = 3; FixedLiteral fixed_literal = 11; + BundleLiteral bundle_literal = 12; ValidIf valid_if = 4; //ExtractBits extract_bits = 5; Mux mux = 6; diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 2d1b0b7422..89b8d63133 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -562,6 +562,7 @@ object Utils extends LazyLogging { case ex: DoPrim => MALE case ex: UIntLiteral => MALE case ex: SIntLiteral => MALE + case ex: BundleLiteral => MALE case ex: Mux => MALE case ex: ValidIf => MALE case WInvalid => MALE diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 7b965bfff0..a6e8927fb0 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -138,6 +138,16 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { } } + private def visitLitField[FirrtlNode](ctx: FIRRTLParser.LitFieldContext): (String, Literal) = { + val expr = visitExp(ctx.exp) match { + case u: UIntLiteral => u + case s: SIntLiteral => s + case b: BundleLiteral => b + case _ => throw new ParserException(s"Illegal expression in bundle literal at ${ctx.exp}") + } + (ctx.fieldId.getText, expr) + } + private def visitField[FirrtlNode](ctx: FIRRTLParser.FieldContext): Field = { val flip = if (ctx.getChild(0).getText == "flip") Flip else Default Field(ctx.fieldId.getText, flip, visitType(ctx.`type`)) @@ -311,6 +321,8 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { } case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) + case "{" => + BundleLiteral(ctx.litField.asScala.map(visitLitField)) case _ => ctx.getChild(1).getText match { case "." => diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index f61fa41e4d..13654cf060 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -49,6 +49,15 @@ object WSubField { def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER) def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UNKNOWNGENDER) } +object WSubFieldLiteral { + def unapply(ws: WSubField): Option[Literal] = ws match { + case WSubField(BundleLiteral(lits), name, _, _) => + lits.collectFirst({ case (n, value) if n == name => value }) + case WSubField(WSubFieldLiteral(BundleLiteral(lits)), name, _, _) => + lits.collectFirst({ case (n, value) if n == name => value }) + case _ => None + } +} case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${expr.serialize}[$value]" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index faebc7b891..d0ebdf0d35 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -154,6 +154,22 @@ abstract class Literal extends Expression { val value: BigInt val width: Width } +case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { + val value = lits.map({ case (_, lit) => (lit.value, lit.width) }).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => + (prev << w.toInt) + v + } + val width = lits.map(_._2.width).reduce(_ + _) + def tpe = BundleType(lits.map { case (name, lit) => + Field(name = name, flip = Default, tpe = lit.tpe) + }) + def serialize = + "{ " + (lits.map({ case (n, v) => + s"$n : ${v.serialize}" + }) mkString ", ") + " }" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this +} case class UIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = UIntType(width) def serialize = s"""UInt${width.serialize}("h""" + value.toString(16)+ """")""" diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 4004b8d640..49151fce1f 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -55,6 +55,7 @@ object ConvertFixedToSInt extends Pass { newExp map updateExpType case e: UIntLiteral => e case e: SIntLiteral => e + case e: BundleLiteral => e case _ => e map updateExpType match { case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 24482076ad..79921b013c 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -31,7 +31,7 @@ object InferTypes extends Pass { case e: DoPrim => PrimOps.set_primop_type(e) case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) case e: ValidIf => e copy (tpe = e.value.tpe) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e } def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { @@ -90,7 +90,7 @@ object CInferTypes extends Pass { case (e: DoPrim) => PrimOps.set_primop_type(e) case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval)) case (e: ValidIf) => e copy (tpe = e.value.tpe) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e } def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 663241d30e..772308522a 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -139,6 +139,8 @@ object LowerTypes extends Transform { def lowerTypesExp(memDataTypeMap: MemDataTypeMap, info: Info, mname: String)(e: Expression): Expression = e match { case e: WRef => e + case WSubFieldLiteral(lit) => + lowerTypesExp(memDataTypeMap, info, mname)(lit) case (_: WSubField | _: WSubIndex) => kind(e) match { case InstanceKind => val (root, tail) = splitRef(e) @@ -156,7 +158,7 @@ object LowerTypes extends Transform { case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname) case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname) case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e } def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 73f967f499..10e9ed27e0 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -168,7 +168,7 @@ object Uniquify extends Transform { val (subExp, subMap) = rec(e.expr, m) val index = uniquifyNamesExp(e.index, map) (WSubAccess(subExp, index, e.tpe, e.gender), subMap) - case (_: UIntLiteral | _: SIntLiteral) => (exp, m) + case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) } @@ -247,7 +247,7 @@ object Uniquify extends Transform { uniquifyNamesExp(e, nameMap.toMap) case e: Mux => e map uniquifyExp case e: ValidIf => e map uniquifyExp - case (_: UIntLiteral | _: SIntLiteral) => e + case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e case e: DoPrim => e map uniquifyExp } diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index dda2099c7d..3f8a3daa15 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -75,6 +75,17 @@ object FromProto { ir.FixedLiteral(convert(fixed.getValue), width, point) } + def convert(fixed: Firrtl.Expression.BundleLiteral.LiteralField): (String, ir.Literal) = { + val value = convert(fixed.getValue) + value match { + case l: ir.Literal => (fixed.getName, l) + } + } + + def convert(fixed: Firrtl.Expression.BundleLiteral): ir.BundleLiteral = { + ir.BundleLiteral(fixed.getFieldList.asScala.map(convert(_))) + } + def convert(subfield: Firrtl.Expression.SubField): ir.SubField = ir.SubField(convert(subfield.getExpression), subfield.getField, ir.UnknownType) @@ -103,6 +114,7 @@ object FromProto { case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral) case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral) + case BUNDLE_LITERAL_FIELD_NUMBER => convert(expr.getBundleLiteral) case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) case MUX_FIELD_NUMBER => convert(expr.getMux) } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index b3fb9a0c83..79b161b895 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -164,6 +164,15 @@ object ToProto { convert(width).foreach(fb.setWidth) convert(point).foreach(fb.setPoint) eb.setFixedLiteral(fb) + case ir.BundleLiteral(fields) => + val bb = Firrtl.Expression.BundleLiteral.newBuilder() + fields.foreach({ case (n, v) => + val fb = Firrtl.Expression.BundleLiteral.LiteralField.newBuilder() + fb.setName(n) + fb.setValue(convert(v)) + bb.addField(fb) + }) + eb.setBundleLiteral(bb) case ir.DoPrim(op, args, consts, _) => val db = Firrtl.Expression.PrimOp.newBuilder() .setOp(convert(op)) diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 384b75d22c..ad75a8f1df 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -180,6 +180,7 @@ class ParserPropSpec extends FirrtlPropSpec { def legalStartChar = Gen.frequency((1, '_'), (20, Gen.alphaChar)) def legalChar = Gen.frequency((1, Gen.numChar), (1, '$'), (10, legalStartChar)) + def uintValues = Gen.choose(0, 1000000) def identifier = for { x <- legalStartChar @@ -215,4 +216,17 @@ class ParserPropSpec extends FirrtlPropSpec { } } } + property("Bundle literals should be OK") { + forAll (identifier, bundleField, uintValues) { case (id, field, uval) => + whenever(id.nonEmpty && field.nonEmpty) { + val input = s""" + |circuit Test : + | module Test : + | output $id : { $field : UInt<32> } + | $id <= { $field : UInt<32>("h${uval.toHexString}") } + |""".stripMargin + firrtl.Parser.parse(input split "\n") + } + } + } } diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 2a60eab536..811cf39096 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -117,6 +117,22 @@ class ProtoBufSpec extends FirrtlFlatSpec { FromProto.convert(ToProto.convert(flit).build) should equal (flit) } + it should "support Bundle Literals" in { + val ulit = ir.UIntLiteral(123, ir.IntWidth(32)) + FromProto.convert(ToProto.convert(ulit).build) should equal (ulit) + + val slit = ir.SIntLiteral(-123, ir.IntWidth(32)) + FromProto.convert(ToProto.convert(slit).build) should equal (slit) + + val flit = ir.FixedLiteral(-123, ir.IntWidth(32), ir.IntWidth(30)) + FromProto.convert(ToProto.convert(flit).build) should equal (flit) + val blit = ir.BundleLiteral(Seq( + ("a", ulit), + ("b", ir.BundleLiteral(Seq( ("c", slit), ("d", flit)))), + )) + FromProto.convert(ToProto.convert(blit).build) should equal (blit) + } + it should "support Analog and Attach" in { val analog = ir.AnalogType(IntWidth(8)) FromProto.convert(ToProto.convert(analog).build) should equal (analog) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 62ed561e1a..a308668103 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -88,6 +88,41 @@ class UnitTests extends FirrtlFlatSpec { } } + "Connecting bundle literals" should "work" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes, + ) + val input = + """circuit Unit : + | module Unit : + | output x : { a: { b: UInt<32> } } + | x <= { a: { b: UInt<32>("h5") } } + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | output x : { a: { b: UInt<32> } } + | x.a.b <= UInt<32>("h5") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val iWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(cResult, cWriter) + (parse(iWriter.toString())) should be (parse(cWriter.toString())) + } + "Partial connection two bundle types whose relative flips don't match but leaf node directions do" should "connect correctly" in { val passes = Seq( ToWorkingIR, From 4abd9a4ea5b0ecc9d0c365d77ee8568707b4b0d1 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 5 Nov 2018 03:43:46 +0000 Subject: [PATCH 02/13] Initial work on a vector expression --- src/main/antlr4/FIRRTL.g4 | 3 +- src/main/scala/firrtl/Utils.scala | 1 + src/main/scala/firrtl/Visitor.scala | 2 + src/main/scala/firrtl/ir/IR.scala | 46 ++++++++++++------- src/main/scala/firrtl/passes/Checks.scala | 2 +- .../firrtl/passes/ConvertFixedToSInt.scala | 4 ++ src/main/scala/firrtl/passes/InferTypes.scala | 8 ++++ src/main/scala/firrtl/passes/LowerTypes.scala | 11 ++--- src/main/scala/firrtl/passes/Uniquify.scala | 2 + 9 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index 1dc5195fb8..7a793e75b4 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -174,7 +174,8 @@ exp | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' - | '{' litField* '}' + | '{' litField* '}' // Bundle Literal + | '[' exp* ']' // Vector Literal ; id diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 89b8d63133..bbf6e99d8d 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -563,6 +563,7 @@ object Utils extends LazyLogging { case ex: UIntLiteral => MALE case ex: SIntLiteral => MALE case ex: BundleLiteral => MALE + case ex: VectorExpression => MALE case ex: Mux => MALE case ex: ValidIf => MALE case WInvalid => MALE diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index a6e8927fb0..6883285077 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -323,6 +323,8 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) case "{" => BundleLiteral(ctx.litField.asScala.map(visitLitField)) + case "[" => + VectorExpression(ctx.exp.asScala.map(visitExp), UnknownType) case _ => ctx.getChild(1).getText match { case "." => diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index d0ebdf0d35..e38e1e106b 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -154,22 +154,6 @@ abstract class Literal extends Expression { val value: BigInt val width: Width } -case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { - val value = lits.map({ case (_, lit) => (lit.value, lit.width) }).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => - (prev << w.toInt) + v - } - val width = lits.map(_._2.width).reduce(_ + _) - def tpe = BundleType(lits.map { case (name, lit) => - Field(name = name, flip = Default, tpe = lit.tpe) - }) - def serialize = - "{ " + (lits.map({ case (n, v) => - s"$n : ${v.serialize}" - }) mkString ", ") + " }" - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this -} case class UIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = UIntType(width) def serialize = s"""UInt${width.serialize}("h""" + value.toString(16)+ """")""" @@ -202,6 +186,36 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) } +case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { + val value = lits.map({ case (_, lit) => (lit.value, lit.width) }).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => + (prev << w.toInt) + v + } + val width = lits.map(_._2.width).reduce(_ + _) + def tpe = BundleType(lits.map { case (name, lit) => + Field(name = name, flip = Default, tpe = lit.tpe) + }) + def serialize = + "{ " + (lits.map({ case (n, v) => + s"$n : ${v.serialize}" + }) mkString ", ") + " }" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this +} +case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expression { + // val value = lits.map(x => (x.value, x.width)).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => + // (prev << w.toInt) + v + // } + // val width = lits.map(_.width).reduce(_ + _) + def serialize = + "[" + exprs.map(_.serialize).mkString(", ") + "]" // TODO type annotation + def mapExpr(f: Expression => Expression): Expression = // this + VectorExpression(exprs.map(_ mapExpr f), tpe) + def mapType(f: Type => Type): Expression = // this + VectorExpression(exprs.map(_ mapType f), tpe) + def mapWidth(f: Width => Width): Expression = // this + VectorExpression(exprs.map(_ mapWidth f), tpe) +} case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + (args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")" diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4c7458bf13..3c26307530 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -408,7 +408,7 @@ object CheckTypes extends Pass { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) => - errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) + errors.append(new InvalidConnect(info, mname, s"${sx.loc.serialize} (${sx.loc.tpe.serialize})", s"${sx.expr.serialize} (${sx.expr.tpe.serialize})")) case sx: PartialConnect if !bulk_equals(sx.loc.tpe, sx.expr.tpe, Default, Default) => errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) case sx: DefRegister => diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 49151fce1f..f3d1fe8341 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -56,6 +56,10 @@ object ConvertFixedToSInt extends Pass { case e: UIntLiteral => e case e: SIntLiteral => e case e: BundleLiteral => e + case e: VectorExpression => + val point = calcPoint(e.exprs) + val newExprs = e.exprs.map(alignArg(_, point)) + VectorExpression(newExprs, UnknownType) case _ => e map updateExpType match { case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 79921b013c..c7d6765ab8 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -31,6 +31,10 @@ object InferTypes extends Pass { case e: DoPrim => PrimOps.set_primop_type(e) case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) case e: ValidIf => e copy (tpe = e.value.tpe) + case e: VectorExpression => e copy (tpe = VectorType( + e.exprs.map(_.tpe).reduce[Type](mux_type_and_widths), + e.exprs.length + )) case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e } @@ -90,6 +94,10 @@ object CInferTypes extends Pass { case (e: DoPrim) => PrimOps.set_primop_type(e) case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval)) case (e: ValidIf) => e copy (tpe = e.value.tpe) + case e: VectorExpression => e copy (tpe = VectorType( + e.exprs.map(_.tpe).reduce[Type](mux_type), + e.exprs.length + )) case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 772308522a..be315bf6b1 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -138,10 +138,11 @@ object LowerTypes extends Transform { def lowerTypesExp(memDataTypeMap: MemDataTypeMap, info: Info, mname: String)(e: Expression): Expression = e match { - case e: WRef => e + case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e case WSubFieldLiteral(lit) => lowerTypesExp(memDataTypeMap, info, mname)(lit) - case (_: WSubField | _: WSubIndex) => kind(e) match { + case WSubIndex(ve: VectorExpression, index, _, _) => ve.exprs(index) + case e @ (_: WSubField | _: WSubIndex) => kind(e) match { case InstanceKind => val (root, tail) = splitRef(e) val name = loweredName(tail) @@ -155,10 +156,8 @@ object LowerTypes extends Transform { } case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) } - case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname) - case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname) - case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) - case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e + case e @ (_: Mux | _: ValidIf | _: DoPrim | _: VectorExpression) => + e map lowerTypesExp(memDataTypeMap, info, mname) } def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 10e9ed27e0..f20ebba4e0 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -168,6 +168,7 @@ object Uniquify extends Transform { val (subExp, subMap) = rec(e.expr, m) val index = uniquifyNamesExp(e.index, map) (WSubAccess(subExp, index, e.tpe, e.gender), subMap) + case _: VectorExpression => (exp, m) case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) @@ -247,6 +248,7 @@ object Uniquify extends Transform { uniquifyNamesExp(e, nameMap.toMap) case e: Mux => e map uniquifyExp case e: ValidIf => e map uniquifyExp + case e: VectorExpression => e map uniquifyExp case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e case e: DoPrim => e map uniquifyExp } From 58210136169f15a5296ca19690c959d33876f036 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 5 Nov 2018 04:50:07 +0000 Subject: [PATCH 03/13] Add proto and some tests --- src/main/proto/firrtl.proto | 25 +++++++++++-------- src/main/scala/firrtl/proto/FromProto.scala | 5 ++++ src/main/scala/firrtl/proto/ToProto.scala | 7 ++++++ src/test/scala/firrtlTests/ParserSpec.scala | 14 +++++++++++ src/test/scala/firrtlTests/ProtoBufSpec.scala | 6 +++++ 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 3c38340abc..077ad9368e 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -323,16 +323,6 @@ message Firrtl { string id = 1; } - message BundleLiteral { - message LiteralField { - // Required - string name = 1; - // Required - Expression value = 2; - } - repeated LiteralField field = 1; - } - message IntegerLiteral { // Base 10 value. May begin with a sign (+|-). Only zero can begin with a // '0'. @@ -358,6 +348,20 @@ message Firrtl { Width point = 3; } + message BundleLiteral { + message LiteralField { + // Required + string name = 1; + // Required + Expression value = 2; + } + repeated LiteralField field = 1; + } + + message VectorExpression { + repeated Expression exp = 1; + } + message ValidIf { // Required. Expression condition = 1; @@ -452,6 +456,7 @@ message Firrtl { SIntLiteral sint_literal = 3; FixedLiteral fixed_literal = 11; BundleLiteral bundle_literal = 12; + VectorExpression vector_expression = 13; ValidIf valid_if = 4; //ExtractBits extract_bits = 5; Mux mux = 6; diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 3f8a3daa15..615f01f937 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -86,6 +86,10 @@ object FromProto { ir.BundleLiteral(fixed.getFieldList.asScala.map(convert(_))) } + def convert(fixed: Firrtl.Expression.VectorExpression): ir.VectorExpression = { + ir.VectorExpression(fixed.getExpList.asScala.map(convert(_)), ir.UnknownType) + } + def convert(subfield: Firrtl.Expression.SubField): ir.SubField = ir.SubField(convert(subfield.getExpression), subfield.getField, ir.UnknownType) @@ -115,6 +119,7 @@ object FromProto { case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral) case BUNDLE_LITERAL_FIELD_NUMBER => convert(expr.getBundleLiteral) + case VECTOR_EXPRESSION_FIELD_NUMBER => convert(expr.getVectorExpression) case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) case MUX_FIELD_NUMBER => convert(expr.getMux) } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 79b161b895..bad878bd67 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -173,6 +173,13 @@ object ToProto { bb.addField(fb) }) eb.setBundleLiteral(bb) + case ir.VectorExpression(exps, _) => + val bb = Firrtl.Expression.VectorExpression.newBuilder() + exps.foreach({ case e => + bb.addExp(convert(e)) + }) + + eb.setVectorExpression(bb) case ir.DoPrim(op, args, consts, _) => val db = Firrtl.Expression.PrimOp.newBuilder() .setOp(convert(op)) diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index ad75a8f1df..682da0805f 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -229,4 +229,18 @@ class ParserPropSpec extends FirrtlPropSpec { } } } + property("Vector expressions should be OK") { + forAll (identifier, uintValues) { case (id, uval) => + whenever(id.nonEmpty) { + val entries = (0 until 6).map(x => "UInt(\"h" + (x + uval).toHexString + "\")") + val input = s""" + |circuit Test : + | module Test : + | output $id : UInt<32>[6] + | $id <= [${entries.mkString(", ")}] + |""".stripMargin + firrtl.Parser.parse(input split "\n") + } + } + } } diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 811cf39096..90549e2946 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -133,6 +133,12 @@ class ProtoBufSpec extends FirrtlFlatSpec { FromProto.convert(ToProto.convert(blit).build) should equal (blit) } + it should "support Vector Expressions" in { + val ulits = for (i <- 0 until 10) yield ir.UIntLiteral(i + 3, ir.UnknownWidth) + val ve = VectorExpression(ulits, ir.UnknownType) + FromProto.convert(ToProto.convert(ve).build) should equal (ve) + } + it should "support Analog and Attach" in { val analog = ir.AnalogType(IntWidth(8)) FromProto.convert(ToProto.convert(analog).build) should equal (analog) From 947b0131cf303d01a912aba198f5ba47682e0f87 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 5 Nov 2018 20:03:55 +0000 Subject: [PATCH 04/13] Add vector expression unit test --- src/test/scala/firrtlTests/UnitTests.scala | 52 ++++++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index a308668103..b11b37d714 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -101,14 +101,60 @@ class UnitTests extends FirrtlFlatSpec { val input = """circuit Unit : | module Unit : - | output x : { a: { b: UInt<32> } } - | x <= { a: { b: UInt<32>("h5") } } + | output x : { a: { b: UInt<32> }, c: UInt<32> } + | x <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") } | """.stripMargin val check = """circuit Unit : | module Unit : - | output x : { a: { b: UInt<32> } } + | output x : { a: { b: UInt<32> }, c: UInt<32> } | x.a.b <= UInt<32>("h5") + | x.c <= UInt<32>("h6") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val iWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(cResult, cWriter) + (parse(iWriter.toString())) should be (parse(cWriter.toString())) + } + + "Connecting vector expressions" should "work" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes, + ) + val input = + """circuit Unit : + | module Unit : + | output x : UInt<32>[6] + | x <= [ UInt("h1"), UInt("h2"), UInt("h4"), UInt("h8"), UInt("h10"), UInt("h20") ] + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | output x_0 : UInt<32> + | output x_1 : UInt<32> + | output x_2 : UInt<32> + | output x_3 : UInt<32> + | output x_4 : UInt<32> + | output x_5 : UInt<32> + | x_0 <= UInt("h1") + | x_1 <= UInt("h2") + | x_2 <= UInt("h4") + | x_3 <= UInt("h8") + | x_4 <= UInt("h10") + | x_5 <= UInt("h20") | """.stripMargin val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => p.runTransform(c) From 080a3331aeae8e8fc66b5cf0b8adeb1c56d4b3ea Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Thu, 6 Dec 2018 20:42:59 +0000 Subject: [PATCH 05/13] More tests --- src/main/scala/firrtl/passes/CheckChirrtl.scala | 2 +- src/main/scala/firrtl/passes/Checks.scala | 6 ++++-- src/test/scala/firrtlTests/UnitTests.scala | 4 ++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 8f37c8bf1f..90369b3670 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -59,7 +59,7 @@ object CheckChirrtl extends Pass { def validSubexp(info: Info, mname: String)(e: Expression): Expression = { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess | - _: Mux | _: ValidIf => // No error + _: Mux | _: ValidIf | _: BundleLiteral | _: VectorExpression => // No error case _ => errors append new InvalidAccessException(info, mname) } e diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 3c26307530..449b52d7d7 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -135,7 +135,8 @@ object CheckHighForm extends Pass { def validSubexp(info: Info, mname: String)(e: Expression): Expression = { e match { - case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf | + _: BundleLiteral | _: VectorExpression => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } e @@ -302,10 +303,11 @@ object CheckTypes extends Pass { case s: SIntType => (isUInt, true, isClock, isFix) case ClockType => (isUInt, isSInt, true, isFix) case f: FixedType => (isUInt, isSInt, isClock, true) + case _: BundleType | _: VectorType => (isUInt, isSInt, isClock, isFix) // ignore until lowered case UnknownType => errors.append(new IllegalUnknownType(info, mname, e.serialize)) (isUInt, isSInt, isClock, isFix) - case other => throwInternalError(s"Illegal Type: ${other.serialize}") + case other => throwInternalError(s"Illegal Type: ${other.serialize} (for DoPrim ${e.serialize})") } } match { // (UInt, SInt, Clock, Fixed) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index b11b37d714..8441bb3a27 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -102,14 +102,18 @@ class UnitTests extends FirrtlFlatSpec { """circuit Unit : | module Unit : | output x : { a: { b: UInt<32> }, c: UInt<32> } + | output y : UInt<32> | x <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") } + | y <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") }.a.b | """.stripMargin val check = """circuit Unit : | module Unit : | output x : { a: { b: UInt<32> }, c: UInt<32> } + | output y : UInt<32> | x.a.b <= UInt<32>("h5") | x.c <= UInt<32>("h6") + | y <= { a: { b: UInt<32>("h5") }, c: UInt<32>("h6") }.a.b | """.stripMargin val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => p.runTransform(c) From 643507a90f116c0ef82d2e7f3b1366c7fd88bdbd Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Thu, 20 Dec 2018 01:24:58 +0000 Subject: [PATCH 06/13] Fix unit tests --- src/main/scala/firrtl/ir/IR.scala | 33 ++++++++++++++++++---- src/test/scala/firrtlTests/UnitTests.scala | 2 ++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index d424df7600..4ead263f30 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -228,9 +228,27 @@ case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { "{ " + (lits.map({ case (n, v) => s"$n : ${v.serialize}" }) mkString ", ") + " }" - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = lits.foreach { case (_, l) => l foreachExpr f } + def foreachType(f: Type => Unit): Unit = lits.foreach { case (_, l) => l foreachType f } + def foreachWidth(f: Width => Unit): Unit = lits.foreach { case (_, l) => l foreachWidth f } + def mapExpr(f: Expression => Expression): Expression = BundleLiteral(lits.map { case (s, l) => + l mapExpr f match { + case lit: Literal => (s, lit) + case _ => throw new Exception("Oh no!") + } + }) + def mapType(f: Type => Type): Expression = BundleLiteral(lits.map { case (s, l) => + l mapType f match { + case lit: Literal => (s, lit) + case _ => throw new Exception("Oh no!") + } + }) + def mapWidth(f: Width => Width): Expression = BundleLiteral(lits.map { case (s, l) => + l mapWidth f match { + case lit: Literal => (s, lit) + case _ => throw new Exception("Oh no!") + } + }) } case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expression { // val value = lits.map(x => (x.value, x.width)).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => @@ -239,11 +257,14 @@ case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expressio // val width = lits.map(_.width).reduce(_ + _) def serialize = "[" + exprs.map(_.serialize).mkString(", ") + "]" // TODO type annotation - def mapExpr(f: Expression => Expression): Expression = // this + def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(_ foreachExpr f) + def foreachType(f: Type => Unit): Unit = exprs.foreach(_ foreachType f) + def foreachWidth(f: Width => Unit): Unit = exprs.foreach(_ foreachWidth f) + def mapExpr(f: Expression => Expression): Expression = VectorExpression(exprs.map(_ mapExpr f), tpe) - def mapType(f: Type => Type): Expression = // this + def mapType(f: Type => Type): Expression = VectorExpression(exprs.map(_ mapType f), tpe) - def mapWidth(f: Width => Width): Expression = // this + def mapWidth(f: Width => Width): Expression = VectorExpression(exprs.map(_ mapWidth f), tpe) } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index f5a9833f4d..830dfa888e 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -93,6 +93,7 @@ class UnitTests extends FirrtlFlatSpec { ToWorkingIR, CheckHighForm, ResolveKinds, + ResolveGenders, InferTypes, CheckTypes, ExpandConnects, @@ -133,6 +134,7 @@ class UnitTests extends FirrtlFlatSpec { ToWorkingIR, CheckHighForm, ResolveKinds, + ResolveGenders, InferTypes, CheckTypes, ExpandConnects, From eee8d9256dcbc1e30ad57cd6a622b2614eb3f95f Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Sun, 23 Dec 2018 19:55:03 -0500 Subject: [PATCH 07/13] Add more tests. --- src/main/scala/firrtl/WIR.scala | 10 +- src/main/scala/firrtl/ir/IR.scala | 1 + src/main/scala/firrtl/passes/LowerTypes.scala | 6 +- src/test/scala/firrtlTests/MemSpec.scala | 252 ++++++++++++++++++ 4 files changed, 263 insertions(+), 6 deletions(-) diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index bcdf8d2d01..ed284d960a 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -55,12 +55,16 @@ object WSubField { def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER) def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UNKNOWNGENDER) } -object WSubFieldLiteral { - def unapply(ws: WSubField): Option[Literal] = ws match { +object WSubLiteral { + def unapply(w: Expression): Option[Expression] = w match { case WSubField(BundleLiteral(lits), name, _, _) => lits.collectFirst({ case (n, value) if n == name => value }) - case WSubField(WSubFieldLiteral(BundleLiteral(lits)), name, _, _) => + case WSubField(WSubLiteral(BundleLiteral(lits)), name, _, _) => lits.collectFirst({ case (n, value) if n == name => value }) + case WSubIndex(VectorExpression(exprs, _), index, _, _) => + exprs.lift(index) + case WSubIndex(WSubLiteral(VectorExpression(exprs, _)), index, _, _) => + exprs.lift(index) case _ => None } } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 4ead263f30..918739be13 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -8,6 +8,7 @@ import Utils.indent /** Intermediate Representation */ abstract class FirrtlNode { def serialize: String + override def toString = serialize } abstract class Info extends FirrtlNode { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 61afa0c62c..bdd466eb62 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -138,9 +138,9 @@ object LowerTypes extends Transform { def lowerTypesExp(memDataTypeMap: MemDataTypeMap, info: Info, mname: String)(e: Expression): Expression = e match { - case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e - case WSubFieldLiteral(lit) => - lowerTypesExp(memDataTypeMap, info, mname)(lit) + case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleLiteral | _: VectorExpression) => e + case WSubLiteral(exp) => + lowerTypesExp(memDataTypeMap, info, mname)(exp) case WSubIndex(ve: VectorExpression, index, _, _) => ve.exprs(index) case e @ (_: WSubField | _: WSubIndex) => kind(e) match { case InstanceKind => diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index 67b7e74df1..f709eb48cd 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -2,6 +2,15 @@ package firrtlTests +import java.io._ +import org.scalatest._ +import org.scalatest.prop._ +import firrtl._ +import firrtl.ir._ +import firrtl.passes._ +import firrtl.transforms._ +import FirrtlCheckers._ + class MemSpec extends FirrtlPropSpec { property("Zero-ported mems should be supported!") { @@ -11,5 +20,248 @@ class MemSpec extends FirrtlPropSpec { property("Mems with zero-width elements should be supported!") { runFirrtlTest("ZeroWidthMem", "/features") } + + property("Writing to mems with bundle literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ExpandConnects, + LowerTypes, + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | output i : { a: { b: UInt<32> }, c: UInt<32> } + | output o : { a: { b: UInt<32> }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= { a: { b: UInt<32>("h0") }, c: UInt<32>("h0") } + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | output i : { a: { b: UInt<32> }, c: UInt<32> } + | output o : { a: { b: UInt<32> }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w.a.b <= UInt<32>("h0") + | w.c <= UInt<32>("h0") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + def removeMask(s: String) = + s.split("\n") filter { l => !(l contains "mask") } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeMask(iWriter.toString())) + val cCircuit = parse(removeMask(cWriter.toString())) + iCircuit should be (cCircuit) + } + + property("Writing to mems with a vector of bundle literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ReplaceAccesses, + ExpandConnects, + LowerTypes, + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> }[3] + | output o : { a: { b: UInt<32> }, c: UInt<32> }[3] + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[3][16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= [ { a: { b: UInt<32>("h0") }, c: UInt<32>("h0") }, { a: { b: UInt<32>("h1") }, c: UInt<32>("h1") }, { a: { b: UInt<32>("h4") }, c: UInt<32>("h4") }] + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32> }, c: UInt<32> }[3] + | output o : { a: { b: UInt<32> }, c: UInt<32> }[3] + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32> }, c: UInt<32> }[3][16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w[0].a.b <= UInt<32>("h0") + | w[0].c <= UInt<32>("h0") + | w[1].a.b <= UInt<32>("h1") + | w[1].c <= UInt<32>("h1") + | w[2].a.b <= UInt<32>("h4") + | w[2].c <= UInt<32>("h4") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Input done") + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Output done") + // All the assignments to addr, clk, mask, en, etc. will be in different orders + // Get rid of them, leaving just data + def removeJunk(s: String) = + s.split("\n") filter { l => !( + (l contains "mask") || + (l contains "invalid") || + (l contains ".addr") || + (l contains ".clk") || + (l contains ".en") || + (l contains ".mask") + ) } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeJunk(iWriter.toString())) + val cCircuit = parse(removeJunk(cWriter.toString())) + iCircuit should be (cCircuit) + } + + ignore("Writing to mems with a bundle of vector literals should work") { + val passes = Seq( + CheckChirrtl, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, + ToWorkingIR, + CheckHighForm, + ResolveKinds, + ResolveGenders, + InferTypes, + CheckTypes, + ReplaceAccesses, + ExpandConnects, + LowerTypes, + ) + val input = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32>[3] }, c: UInt<32> } + | output o : { a: { b: UInt<32>[3] }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32>[3] }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w <= { a: { b: [ UInt<32>("h0"), UInt<32>("h3"), UInt<32>("h6")] }, c: UInt<32>("h1") } + | """.stripMargin + val check = + """circuit Unit : + | module Unit : + | input clock : Clock + | input i : { a: { b: UInt<32>[3] }, c: UInt<32> } + | output o : { a: { b: UInt<32>[3] }, c: UInt<32> } + | input wZero: UInt<1> + | input waddr: UInt<4> + | input raddr: UInt<4> + | + | cmem ram : { a: { b: UInt<32>[3] }, c: UInt<32> }[16] + | + | infer mport r = ram[raddr], clock + | o <= r + | infer mport w = ram[waddr], clock + | w <= i + | when wZero : + | w.a.b[0] <= UInt<32>("h0") + | w.a.b[1] <= UInt<32>("h3") + | w.a.b[2] <= UInt<32>("h6") + | w.c <= UInt<32>("h1") + | """.stripMargin + val iResult = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Input done") + val cResult = passes.foldLeft(CircuitState(parse(check), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + println("Output done") + // All the assignments to addr, clk, mask, en, etc. will be in different orders + // Get rid of them, leaving just data + def removeJunk(s: String) = + s.split("\n") filter { l => !( + (l contains "mask") || + (l contains "invalid") || + (l contains ".addr") || + (l contains ".clk") || + (l contains ".en") || + (l contains ".mask") + ) } mkString "\n" + val iWriter = new StringWriter() + val cWriter = new StringWriter() + (new HighFirrtlEmitter).emit(iResult, iWriter) + (new HighFirrtlEmitter).emit(cResult, cWriter) + val iCircuit = parse(removeJunk(iWriter.toString())) + val cCircuit = parse(removeJunk(cWriter.toString())) + iCircuit should be (cCircuit) + } } From e63b628a4ed790b30e59e1d1d238fc6f3e018ee8 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Sun, 23 Dec 2018 22:27:10 -0500 Subject: [PATCH 08/13] I think Scala 2.11 doesn't like commas --- src/test/scala/firrtlTests/MemSpec.scala | 6 +++--- src/test/scala/firrtlTests/ProtoBufSpec.scala | 2 +- src/test/scala/firrtlTests/UnitTests.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index f709eb48cd..b985a7561d 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -34,7 +34,7 @@ class MemSpec extends FirrtlPropSpec { InferTypes, CheckTypes, ExpandConnects, - LowerTypes, + LowerTypes ) val input = """circuit Unit : @@ -106,7 +106,7 @@ class MemSpec extends FirrtlPropSpec { CheckTypes, ReplaceAccesses, ExpandConnects, - LowerTypes, + LowerTypes ) val input = """circuit Unit : @@ -193,7 +193,7 @@ class MemSpec extends FirrtlPropSpec { CheckTypes, ReplaceAccesses, ExpandConnects, - LowerTypes, + LowerTypes ) val input = """circuit Unit : diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 0cd74f42c9..4a55dab9ca 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -127,7 +127,7 @@ class ProtoBufSpec extends FirrtlFlatSpec { FromProto.convert(ToProto.convert(flit).build) should equal (flit) val blit = ir.BundleLiteral(Seq( ("a", ulit), - ("b", ir.BundleLiteral(Seq( ("c", slit), ("d", flit)))), + ("b", ir.BundleLiteral(Seq( ("c", slit), ("d", flit)))) )) FromProto.convert(ToProto.convert(blit).build) should equal (blit) } diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 830dfa888e..6b3974a627 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -97,7 +97,7 @@ class UnitTests extends FirrtlFlatSpec { InferTypes, CheckTypes, ExpandConnects, - LowerTypes, + LowerTypes ) val input = """circuit Unit : @@ -138,7 +138,7 @@ class UnitTests extends FirrtlFlatSpec { InferTypes, CheckTypes, ExpandConnects, - LowerTypes, + LowerTypes ) val input = """circuit Unit : From 8a0bd90faa32598361621ec4f25f680e601e20cb Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 2 Mar 2020 18:05:44 -0800 Subject: [PATCH 09/13] Update flow --- src/main/scala/firrtl/Utils.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 8301bcde92..ad75ea73d3 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -625,6 +625,8 @@ object Utils extends LazyLogging { case ex: DoPrim => SourceFlow case ex: UIntLiteral => SourceFlow case ex: SIntLiteral => SourceFlow + case ex: BundleLiteral => SourceFlow + case ex: VectorExpression => SourceFlow case ex: Mux => SourceFlow case ex: ValidIf => SourceFlow case WInvalid => SourceFlow From 2dbca05301e60d9403a218c7b3e2af4849280440 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 2 Mar 2020 19:14:26 -0800 Subject: [PATCH 10/13] Make tests pass --- src/main/antlr4/FIRRTL.g4 | 4 +-- src/main/proto/firrtl.proto | 4 +-- src/main/scala/firrtl/Visitor.scala | 18 +++++-------- src/main/scala/firrtl/ir/IR.scala | 26 +++---------------- src/main/scala/firrtl/passes/InferTypes.scala | 6 +++-- src/main/scala/firrtl/proto/FromProto.scala | 4 +-- src/main/scala/firrtl/proto/ToProto.scala | 2 +- src/test/scala/firrtlTests/MemSpec.scala | 6 ++--- 8 files changed, 24 insertions(+), 46 deletions(-) diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index bbd91acfa6..307e9915cc 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -59,7 +59,7 @@ type | type '[' intLit ']' // Vector ; -litField +expField : fieldId ':' exp ; @@ -172,7 +172,7 @@ exp | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' - | '{' litField* '}' // Bundle Literal + | '{' expField* '}' // Bundle Literal | '[' exp* ']' // Vector Literal ; diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 1fa76c4eae..f1476eabd1 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -378,13 +378,13 @@ message Firrtl { } message BundleLiteral { - message LiteralField { + message Field { // Required string name = 1; // Required Expression value = 2; } - repeated LiteralField field = 1; + repeated Field field = 1; } message VectorExpression { diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index a1f2aab8cf..0ae688e6e2 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -175,14 +175,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } } - private def visitLitField[FirrtlNode](ctx: LitFieldContext): (String, Literal) = { - val expr = visitExp(ctx.exp) match { - case u: UIntLiteral => u - case s: SIntLiteral => s - case b: BundleLiteral => b - case _ => throw new ParserException(s"Illegal expression in bundle literal at ${ctx.exp}") - } - (ctx.fieldId.getText, expr) + private def visitExpField[FirrtlNode](ctx: ExpFieldContext): (String, Expression) = { + (ctx.fieldId.getText, visitExp(ctx.exp)) } // Special case "type" of CHIRRTL mems because their size can be BigInt @@ -397,10 +391,10 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) - case "{" => - BundleLiteral(ctx.litField.asScala.map(visitLitField)) - case "[" => - VectorExpression(ctx.exp.asScala.map(visitExp), UnknownType) + case "{" => + BundleLiteral(ctx.expField.asScala.map(visitExpField)) + case "[" => + VectorExpression(ctx_exp.map(visitExp), UnknownType) } } } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index c4a54d49c4..888e5d3137 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -11,7 +11,6 @@ import scala.math.BigDecimal.RoundingMode._ /** Intermediate Representation */ abstract class FirrtlNode { def serialize: String - override def toString = serialize } abstract class Info extends FirrtlNode { @@ -241,11 +240,7 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def foreachType(f: Type => Unit): Unit = Unit def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } -case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { - val value = lits.map({ case (_, lit) => (lit.value, lit.width) }).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => - (prev << w.toInt) + v - } - val width = lits.map(_._2.width).reduce(_ + _) +case class BundleLiteral(lits: Seq[(String, Expression)]) extends Expression { def tpe = BundleType(lits.map { case (name, lit) => Field(name = name, flip = Default, tpe = lit.tpe) }) @@ -257,29 +252,16 @@ case class BundleLiteral(lits: Seq[(String, Literal)]) extends Literal { def foreachType(f: Type => Unit): Unit = lits.foreach { case (_, l) => l foreachType f } def foreachWidth(f: Width => Unit): Unit = lits.foreach { case (_, l) => l foreachWidth f } def mapExpr(f: Expression => Expression): Expression = BundleLiteral(lits.map { case (s, l) => - l mapExpr f match { - case lit: Literal => (s, lit) - case _ => throw new Exception("Oh no!") - } + (s, l mapExpr f) }) def mapType(f: Type => Type): Expression = BundleLiteral(lits.map { case (s, l) => - l mapType f match { - case lit: Literal => (s, lit) - case _ => throw new Exception("Oh no!") - } + (s, l mapType f) }) def mapWidth(f: Width => Width): Expression = BundleLiteral(lits.map { case (s, l) => - l mapWidth f match { - case lit: Literal => (s, lit) - case _ => throw new Exception("Oh no!") - } + (s, l mapWidth f) }) } case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expression { - // val value = lits.map(x => (x.value, x.width)).foldLeft(BigInt(0)) { case (prev, (v, IntWidth(w))) => - // (prev << w.toInt) + v - // } - // val width = lits.map(_.width).reduce(_ + _) def serialize = "[" + exprs.map(_.serialize).mkString(", ") + "]" // TODO type annotation def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(_ foreachExpr f) diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 99f7259089..fe32772852 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -45,7 +45,8 @@ object InferTypes extends Pass { e.exprs.map(_.tpe).reduce[Type](mux_type_and_widths), e.exprs.length )) - case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e + case e: BundleLiteral => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { @@ -107,7 +108,8 @@ object CInferTypes extends Pass { e.exprs.map(_.tpe).reduce[Type](mux_type), e.exprs.length )) - case e @ (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e + case e: BundleLiteral => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 225b2d539a..cac6fa8c02 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -76,10 +76,10 @@ object FromProto { ir.FixedLiteral(convert(fixed.getValue), width, point) } - def convert(fixed: Firrtl.Expression.BundleLiteral.LiteralField): (String, ir.Literal) = { + def convert(fixed: Firrtl.Expression.BundleLiteral.Field): (String, ir.Expression) = { val value = convert(fixed.getValue) value match { - case l: ir.Literal => (fixed.getName, l) + case l: ir.Expression => (fixed.getName, l) } } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 585adf035b..e479be85cc 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -179,7 +179,7 @@ object ToProto { case ir.BundleLiteral(fields) => val bb = Firrtl.Expression.BundleLiteral.newBuilder() fields.foreach({ case (n, v) => - val fb = Firrtl.Expression.BundleLiteral.LiteralField.newBuilder() + val fb = Firrtl.Expression.BundleLiteral.Field.newBuilder() fb.setName(n) fb.setValue(convert(v)) bb.addField(fb) diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index b23f07466b..61726a6f75 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -37,7 +37,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """circuit Unit : | module Unit : | input clock : Clock - | output i : { a: { b: UInt<32> }, c: UInt<32> } + | input i : { a: { b: UInt<32> }, c: UInt<32> } | output o : { a: { b: UInt<32> }, c: UInt<32> } | input wZero: UInt<1> | input waddr: UInt<4> @@ -56,7 +56,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """circuit Unit : | module Unit : | input clock : Clock - | output i : { a: { b: UInt<32> }, c: UInt<32> } + | input i : { a: { b: UInt<32> }, c: UInt<32> } | output o : { a: { b: UInt<32> }, c: UInt<32> } | input wZero: UInt<1> | input waddr: UInt<4> @@ -176,7 +176,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { iCircuit should be (cCircuit) } - ignore("Writing to mems with a bundle of vector literals should work") { + property("Writing to mems with a bundle of vector literals should work") { val passes = Seq( CheckChirrtl, CInferTypes, From 1156eb3359f60b40b5c444922f6ebaaf0e0f8805 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 2 Mar 2020 19:27:04 -0800 Subject: [PATCH 11/13] Rename BundleLiteral to BundleExpression. This reflects the fact that it is no longer a literal. --- src/main/antlr4/FIRRTL.g4 | 4 ++-- src/main/proto/firrtl.proto | 4 ++-- src/main/scala/firrtl/Utils.scala | 4 ++-- src/main/scala/firrtl/Visitor.scala | 2 +- src/main/scala/firrtl/WIR.scala | 4 ++-- src/main/scala/firrtl/ir/IR.scala | 8 ++++---- src/main/scala/firrtl/passes/Checks.scala | 2 +- src/main/scala/firrtl/passes/ConvertFixedToSInt.scala | 2 +- src/main/scala/firrtl/passes/InferTypes.scala | 4 ++-- src/main/scala/firrtl/passes/LowerTypes.scala | 2 +- src/main/scala/firrtl/passes/Uniquify.scala | 5 +++-- src/main/scala/firrtl/proto/FromProto.scala | 8 ++++---- src/main/scala/firrtl/proto/ToProto.scala | 8 ++++---- src/test/scala/firrtlTests/ProtoBufSpec.scala | 4 ++-- 14 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index 307e9915cc..f64e8c0323 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -172,8 +172,8 @@ exp | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' - | '{' expField* '}' // Bundle Literal - | '[' exp* ']' // Vector Literal + | '{' expField* '}' // Bundle Expression + | '[' exp* ']' // Vector Expression ; id diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index f1476eabd1..c4b15fed1b 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -377,7 +377,7 @@ message Firrtl { Width point = 3; } - message BundleLiteral { + message BundleExpression { message Field { // Required string name = 1; @@ -489,7 +489,7 @@ message Firrtl { UIntLiteral uint_literal = 2; SIntLiteral sint_literal = 3; FixedLiteral fixed_literal = 11; - BundleLiteral bundle_literal = 12; + BundleExpression bundle_expression = 12; VectorExpression vector_expression = 13; ValidIf valid_if = 4; //ExtractBits extract_bits = 5; diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index ad75ea73d3..292388c965 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -592,7 +592,7 @@ object Utils extends LazyLogging { case ex: DoPrim => MALE case ex: UIntLiteral => MALE case ex: SIntLiteral => MALE - case ex: BundleLiteral => MALE + case ex: BundleExpression => MALE case ex: VectorExpression => MALE case ex: Mux => MALE case ex: ValidIf => MALE @@ -625,7 +625,7 @@ object Utils extends LazyLogging { case ex: DoPrim => SourceFlow case ex: UIntLiteral => SourceFlow case ex: SIntLiteral => SourceFlow - case ex: BundleLiteral => SourceFlow + case ex: BundleExpression => SourceFlow case ex: VectorExpression => SourceFlow case ex: Mux => SourceFlow case ex: ValidIf => SourceFlow diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 0ae688e6e2..b9f4da2938 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -392,7 +392,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) case "{" => - BundleLiteral(ctx.expField.asScala.map(visitExpField)) + BundleExpression(ctx.expField.asScala.map(visitExpField)) case "[" => VectorExpression(ctx_exp.map(visitExp), UnknownType) } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 3e6db23913..2019ba51b4 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -77,9 +77,9 @@ object WSubField { } object WSubLiteral { def unapply(w: Expression): Option[Expression] = w match { - case WSubField(BundleLiteral(lits), name, _, _) => + case WSubField(BundleExpression(lits), name, _, _) => lits.collectFirst({ case (n, value) if n == name => value }) - case WSubField(WSubLiteral(BundleLiteral(lits)), name, _, _) => + case WSubField(WSubLiteral(BundleExpression(lits)), name, _, _) => lits.collectFirst({ case (n, value) if n == name => value }) case WSubIndex(VectorExpression(exprs, _), index, _, _) => exprs.lift(index) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 888e5d3137..ead0b57171 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -240,7 +240,7 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def foreachType(f: Type => Unit): Unit = Unit def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } -case class BundleLiteral(lits: Seq[(String, Expression)]) extends Expression { +case class BundleExpression(lits: Seq[(String, Expression)]) extends Expression { def tpe = BundleType(lits.map { case (name, lit) => Field(name = name, flip = Default, tpe = lit.tpe) }) @@ -251,13 +251,13 @@ case class BundleLiteral(lits: Seq[(String, Expression)]) extends Expression { def foreachExpr(f: Expression => Unit): Unit = lits.foreach { case (_, l) => l foreachExpr f } def foreachType(f: Type => Unit): Unit = lits.foreach { case (_, l) => l foreachType f } def foreachWidth(f: Width => Unit): Unit = lits.foreach { case (_, l) => l foreachWidth f } - def mapExpr(f: Expression => Expression): Expression = BundleLiteral(lits.map { case (s, l) => + def mapExpr(f: Expression => Expression): Expression = BundleExpression(lits.map { case (s, l) => (s, l mapExpr f) }) - def mapType(f: Type => Type): Expression = BundleLiteral(lits.map { case (s, l) => + def mapType(f: Type => Type): Expression = BundleExpression(lits.map { case (s, l) => (s, l mapType f) }) - def mapWidth(f: Width => Width): Expression = BundleLiteral(lits.map { case (s, l) => + def mapWidth(f: Width => Width): Expression = BundleExpression(lits.map { case (s, l) => (s, l mapWidth f) }) } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index d11b055fa7..cb2e5cd076 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -151,7 +151,7 @@ trait CheckHighFormLike { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf | - _: BundleLiteral | _: VectorExpression => // No error + _: BundleExpression | _: VectorExpression => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } } diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 991b34d088..faefa1ea7d 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -55,7 +55,7 @@ object ConvertFixedToSInt extends Pass { newExp map updateExpType case e: UIntLiteral => e case e: SIntLiteral => e - case e: BundleLiteral => e + case e: BundleExpression => e case e: VectorExpression => val point = calcPoint(e.exprs) val newExprs = e.exprs.map(alignArg(_, point)) diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index fe32772852..2bb32e4a2f 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -45,7 +45,7 @@ object InferTypes extends Pass { e.exprs.map(_.tpe).reduce[Type](mux_type_and_widths), e.exprs.length )) - case e: BundleLiteral => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) + case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) case e @ (_: UIntLiteral | _: SIntLiteral) => e } @@ -108,7 +108,7 @@ object CInferTypes extends Pass { e.exprs.map(_.tpe).reduce[Type](mux_type), e.exprs.length )) - case e: BundleLiteral => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) + case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2)))) case e @ (_: UIntLiteral | _: SIntLiteral) => e } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 925b3bb35a..3559ad4ad9 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -138,7 +138,7 @@ object LowerTypes extends Transform { def lowerTypesExp(memDataTypeMap: MemDataTypeMap, info: Info, mname: String)(e: Expression): Expression = e match { - case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleLiteral | _: VectorExpression) => e + case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleExpression | _: VectorExpression) => e case WSubLiteral(exp) => lowerTypesExp(memDataTypeMap, info, mname)(exp) case WSubIndex(ve: VectorExpression, index, _, _) => ve.exprs(index) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 705dbaaa01..f5acafba32 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -176,7 +176,7 @@ object Uniquify extends Transform { val index = uniquifyNamesExp(e.index, map) (WSubAccess(subExp, index, e.tpe, e.flow), subMap) case _: VectorExpression => (exp, m) - case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => (exp, m) + case (_: UIntLiteral | _: SIntLiteral | _: BundleExpression) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) } @@ -255,8 +255,9 @@ object Uniquify extends Transform { uniquifyNamesExp(e, nameMap.toMap) case e: Mux => e map uniquifyExp case e: ValidIf => e map uniquifyExp + case e: BundleExpression => e map uniquifyExp case e: VectorExpression => e map uniquifyExp - case (_: UIntLiteral | _: SIntLiteral | _: BundleLiteral) => e + case (_: UIntLiteral | _: SIntLiteral) => e case e: DoPrim => e map uniquifyExp } diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index cac6fa8c02..6c4ae3cd71 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -76,15 +76,15 @@ object FromProto { ir.FixedLiteral(convert(fixed.getValue), width, point) } - def convert(fixed: Firrtl.Expression.BundleLiteral.Field): (String, ir.Expression) = { + def convert(fixed: Firrtl.Expression.BundleExpression.Field): (String, ir.Expression) = { val value = convert(fixed.getValue) value match { case l: ir.Expression => (fixed.getName, l) } } - def convert(fixed: Firrtl.Expression.BundleLiteral): ir.BundleLiteral = { - ir.BundleLiteral(fixed.getFieldList.asScala.map(convert(_))) + def convert(fixed: Firrtl.Expression.BundleExpression): ir.BundleExpression = { + ir.BundleExpression(fixed.getFieldList.asScala.map(convert(_))) } def convert(fixed: Firrtl.Expression.VectorExpression): ir.VectorExpression = { @@ -119,7 +119,7 @@ object FromProto { case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral) case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral) - case BUNDLE_LITERAL_FIELD_NUMBER => convert(expr.getBundleLiteral) + case BUNDLE_EXPRESSION_FIELD_NUMBER => convert(expr.getBundleExpression) case VECTOR_EXPRESSION_FIELD_NUMBER => convert(expr.getVectorExpression) case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) case MUX_FIELD_NUMBER => convert(expr.getMux) diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index e479be85cc..e32dbdaf1c 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -176,15 +176,15 @@ object ToProto { convert(width).foreach(fb.setWidth) convert(point).foreach(fb.setPoint) eb.setFixedLiteral(fb) - case ir.BundleLiteral(fields) => - val bb = Firrtl.Expression.BundleLiteral.newBuilder() + case ir.BundleExpression(fields) => + val bb = Firrtl.Expression.BundleExpression.newBuilder() fields.foreach({ case (n, v) => - val fb = Firrtl.Expression.BundleLiteral.Field.newBuilder() + val fb = Firrtl.Expression.BundleExpression.Field.newBuilder() fb.setName(n) fb.setValue(convert(v)) bb.addField(fb) }) - eb.setBundleLiteral(bb) + eb.setBundleExpression(bb) case ir.VectorExpression(exps, _) => val bb = Firrtl.Expression.VectorExpression.newBuilder() exps.foreach({ case e => diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index eb419849a1..27075da298 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -123,9 +123,9 @@ class ProtoBufSpec extends FirrtlFlatSpec { val flit = ir.FixedLiteral(-123, ir.IntWidth(32), ir.IntWidth(30)) FromProto.convert(ToProto.convert(flit).build) should equal (flit) - val blit = ir.BundleLiteral(Seq( + val blit = ir.BundleExpression(Seq( ("a", ulit), - ("b", ir.BundleLiteral(Seq( ("c", slit), ("d", flit)))) + ("b", ir.BundleExpression(Seq( ("c", slit), ("d", flit)))) )) FromProto.convert(ToProto.convert(blit).build) should equal (blit) } From ec34ce32d4fdd64b87203314d3484b26f35f7dab Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 2 Mar 2020 19:39:22 -0800 Subject: [PATCH 12/13] Small cleanup --- src/test/scala/firrtlTests/MemSpec.scala | 1 + src/test/scala/firrtlTests/ParserSpec.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index 61726a6f75..f2dce1cae4 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -260,6 +260,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { val cCircuit = parse(removeJunk(cWriter.toString())) iCircuit should be (cCircuit) } + property("Very large memories should be supported") { val addrWidth = 65 val memSize = BigInt(1) << addrWidth diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 1911a80d17..f2b32dd168 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -237,7 +237,7 @@ class ParserPropSpec extends FirrtlPropSpec { } } } - property("Bundle literals should be OK") { + property("Bundle expressions should be OK") { forAll (identifier, bundleField, uintValues) { case (id, field, uval) => whenever(id.nonEmpty && field.nonEmpty) { val input = s""" From c43243f6cfb47d04d85581384f5d621783d21229 Mon Sep 17 00:00:00 2001 From: Paul Rigge Date: Mon, 2 Mar 2020 19:48:27 -0800 Subject: [PATCH 13/13] Add test for bundle expression as register init --- .../scala/firrtl/annotations/Target.scala | 2 ++ .../scala/firrtlTests/LowerTypesSpec.scala | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index c7cb98d234..537aa45f21 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -115,6 +115,8 @@ object Target { case d: DoPrim => m.ref("@" + d.serialize) case d: Mux => m.ref("@" + d.serialize) case d: ValidIf => m.ref("@" + d.serialize) + case b: BundleExpression => m.ref("@" + b.serialize) + case v: VectorExpression => m.ref("@" + v.serialize) case d: Literal => m.ref("@" + d.serialize) case other => sys.error(s"Unsupported: $other") } diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index b0e5727b1d..ad5d87a85d 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -115,6 +115,25 @@ class LowerTypesSpec extends FirrtlFlatSpec { executeTest(input, expected) } + it should "lower registers with aggregate expression initialization" in { + val input = + """circuit Test : + | module Test : + | input clock : Clock + | input reset : UInt<1> + | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : + | reset => (reset, [{ a : UInt<1>("h0"), b : UInt<1>("h0") }, { a : UInt<1>("h1"), b : UInt<1>("h1") }]) + """.stripMargin + val expected = Seq( + "reg x_0_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))", + "reg x_0_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))", + "reg x_1_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))", + "reg x_1_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))" + ) map normalized + + executeTest(input, expected) + } + it should "lower DefRegister expressions: clock, reset, and init" in { val input = """circuit Test :