Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Add Bundle Literals. #929

Open
wants to merge 18 commits into
base: master-deprecated
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/antlr4/FIRRTL.g4
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ type
| type '[' intLit ']' // Vector
;

expField
: fieldId ':' exp
;

field
: 'flip'? fieldId ':' type
;
Expand Down Expand Up @@ -168,6 +172,8 @@ exp
| 'mux(' exp exp exp ')'
| 'validif(' exp exp ')'
| primop exp* intLit* ')'
| '{' expField* '}' // Bundle Expression
| '[' exp* ']' // Vector Expression
;

id
Expand Down
16 changes: 16 additions & 0 deletions src/main/proto/firrtl.proto
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,20 @@ message Firrtl {
Width point = 3;
}

message BundleExpression {
message Field {
// Required
string name = 1;
// Required
Expression value = 2;
}
repeated Field field = 1;
}

message VectorExpression {
repeated Expression exp = 1;
}

message ValidIf {
// Required.
Expression condition = 1;
Expand Down Expand Up @@ -475,6 +489,8 @@ message Firrtl {
UIntLiteral uint_literal = 2;
SIntLiteral sint_literal = 3;
FixedLiteral fixed_literal = 11;
BundleExpression bundle_expression = 12;
VectorExpression vector_expression = 13;
ValidIf valid_if = 4;
//ExtractBits extract_bits = 5;
Mux mux = 6;
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/firrtl/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ object Utils extends LazyLogging {
case ex: DoPrim => MALE
case ex: UIntLiteral => MALE
case ex: SIntLiteral => MALE
case ex: BundleExpression => MALE
case ex: VectorExpression => MALE
case ex: Mux => MALE
case ex: ValidIf => MALE
case WInvalid => MALE
Expand Down Expand Up @@ -623,6 +625,8 @@ object Utils extends LazyLogging {
case ex: DoPrim => SourceFlow
case ex: UIntLiteral => SourceFlow
case ex: SIntLiteral => SourceFlow
case ex: BundleExpression => SourceFlow
case ex: VectorExpression => SourceFlow
case ex: Mux => SourceFlow
case ex: ValidIf => SourceFlow
case WInvalid => SourceFlow
Expand Down
8 changes: 8 additions & 0 deletions src/main/scala/firrtl/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
}

private def visitExpField[FirrtlNode](ctx: ExpFieldContext): (String, Expression) = {
(ctx.fieldId.getText, visitExp(ctx.exp))
}

// Special case "type" of CHIRRTL mems because their size can be BigInt
private def visitCMemType(ctx: TypeContext): (Type, BigInt) = {
def loc: String = s"${ctx.getStart.getLine}:${ctx.getStart.getCharPositionInLine}"
Expand Down Expand Up @@ -387,6 +391,10 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType)
case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType)
case "{" =>
BundleExpression(ctx.expField.asScala.map(visitExpField))
case "[" =>
VectorExpression(ctx_exp.map(visitExp), UnknownType)
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions src/main/scala/firrtl/WIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ object WSubField {
def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow)
def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow)
}
object WSubLiteral {
def unapply(w: Expression): Option[Expression] = w match {
case WSubField(BundleExpression(lits), name, _, _) =>
lits.collectFirst({ case (n, value) if n == name => value })
case WSubField(WSubLiteral(BundleExpression(lits)), name, _, _) =>
lits.collectFirst({ case (n, value) if n == name => value })
case WSubIndex(VectorExpression(exprs, _), index, _, _) =>
exprs.lift(index)
case WSubIndex(WSubLiteral(VectorExpression(exprs, _)), index, _, _) =>
exprs.lift(index)
case _ => None
}
}
case class WSubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow) extends Expression with GenderFromFlow {
def serialize: String = s"${expr.serialize}[$value]"
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/firrtl/annotations/Target.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ object Target {
case d: DoPrim => m.ref("@" + d.serialize)
case d: Mux => m.ref("@" + d.serialize)
case d: ValidIf => m.ref("@" + d.serialize)
case b: BundleExpression => m.ref("@" + b.serialize)
case v: VectorExpression => m.ref("@" + v.serialize)
case d: Literal => m.ref("@" + d.serialize)
case other => sys.error(s"Unsupported: $other")
}
Expand Down
34 changes: 34 additions & 0 deletions src/main/scala/firrtl/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,40 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter
def foreachType(f: Type => Unit): Unit = Unit
def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
case class BundleExpression(lits: Seq[(String, Expression)]) extends Expression {
def tpe = BundleType(lits.map { case (name, lit) =>
Field(name = name, flip = Default, tpe = lit.tpe)
})
def serialize =
"{ " + (lits.map({ case (n, v) =>
s"$n : ${v.serialize}"
}) mkString ", ") + " }"
def foreachExpr(f: Expression => Unit): Unit = lits.foreach { case (_, l) => l foreachExpr f }
def foreachType(f: Type => Unit): Unit = lits.foreach { case (_, l) => l foreachType f }
def foreachWidth(f: Width => Unit): Unit = lits.foreach { case (_, l) => l foreachWidth f }
def mapExpr(f: Expression => Expression): Expression = BundleExpression(lits.map { case (s, l) =>
(s, l mapExpr f)
})
def mapType(f: Type => Type): Expression = BundleExpression(lits.map { case (s, l) =>
(s, l mapType f)
})
def mapWidth(f: Width => Width): Expression = BundleExpression(lits.map { case (s, l) =>
(s, l mapWidth f)
})
}
case class VectorExpression(exprs: Seq[Expression], tpe: Type) extends Expression {
def serialize =
"[" + exprs.map(_.serialize).mkString(", ") + "]" // TODO type annotation
def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(_ foreachExpr f)
def foreachType(f: Type => Unit): Unit = exprs.foreach(_ foreachType f)
def foreachWidth(f: Width => Unit): Unit = exprs.foreach(_ foreachWidth f)
def mapExpr(f: Expression => Expression): Expression =
VectorExpression(exprs.map(_ mapExpr f), tpe)
def mapType(f: Type => Type): Expression =
VectorExpression(exprs.map(_ mapType f), tpe)
def mapWidth(f: Width => Width): Expression =
VectorExpression(exprs.map(_ mapWidth f), tpe)
}
case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression {
def serialize: String = op.serialize + "(" +
(args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")"
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/firrtl/passes/Checks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ trait CheckHighFormLike { this: Pass =>
def validSubexp(info: Info, mname: String)(e: Expression): Unit = {
e match {
case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf |
_: BundleExpression | _: VectorExpression => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
}
Expand Down Expand Up @@ -470,6 +471,8 @@ object CheckTypes extends Pass with PreservesAll[Transform] {
case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
case _: BundleType |
_: VectorType => (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
case UnknownType =>
errors.append(new IllegalUnknownType(info, mname, e.serialize))
(isUInt, isSInt, isClock, isFix, isAsync, isInterval)
Expand Down
5 changes: 5 additions & 0 deletions src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ object ConvertFixedToSInt extends Pass with PreservesAll[Transform] {
newExp map updateExpType
case e: UIntLiteral => e
case e: SIntLiteral => e
case e: BundleExpression => e
case e: VectorExpression =>
val point = calcPoint(e.exprs)
val newExprs = e.exprs.map(alignArg(_, point))
VectorExpression(newExprs, UnknownType)
case _ => e map updateExpType match {
case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
Expand Down
12 changes: 11 additions & 1 deletion src/main/scala/firrtl/passes/InferTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ object InferTypes extends Pass with PreservesAll[Transform] {
case e: DoPrim => PrimOps.set_primop_type(e)
case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval))
case e: ValidIf => e copy (tpe = e.value.tpe)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
case e: VectorExpression => e copy (tpe = VectorType(
e.exprs.map(_.tpe).reduce[Type](mux_type_and_widths),
e.exprs.length
))
case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2))))
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}

def infer_types_s(types: TypeMap)(s: Statement): Statement = s match {
Expand Down Expand Up @@ -106,6 +111,11 @@ object CInferTypes extends Pass with PreservesAll[Transform] {
case (e: DoPrim) => PrimOps.set_primop_type(e)
case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval))
case (e: ValidIf) => e copy (tpe = e.value.tpe)
case e: VectorExpression => e copy (tpe = VectorType(
e.exprs.map(_.tpe).reduce[Type](mux_type),
e.exprs.length
))
case e: BundleExpression => e.copy(lits = e.lits.map(x => (x._1, infer_types_e(types)(x._2))))
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}

Expand Down
13 changes: 7 additions & 6 deletions src/main/scala/firrtl/passes/LowerTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ object LowerTypes extends Transform {

def lowerTypesExp(memDataTypeMap: MemDataTypeMap,
info: Info, mname: String)(e: Expression): Expression = e match {
case e: WRef => e
case (_: WSubField | _: WSubIndex) => kind(e) match {
case e @ (_: WRef | _: UIntLiteral | _: SIntLiteral | _: BundleExpression | _: VectorExpression) => e
case WSubLiteral(exp) =>
lowerTypesExp(memDataTypeMap, info, mname)(exp)
case WSubIndex(ve: VectorExpression, index, _, _) => ve.exprs(index)
case e @ (_: WSubField | _: WSubIndex) => kind(e) match {
case InstanceKind =>
val (root, tail) = splitRef(e)
val name = loweredName(tail)
Expand All @@ -162,10 +165,8 @@ object LowerTypes extends Transform {
}
case _ => WRef(loweredName(e), e.tpe, kind(e), flow(e))
}
case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname)
case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname)
case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
case e @ (_: Mux | _: ValidIf | _: DoPrim | _: VectorExpression) =>
e map lowerTypesExp(memDataTypeMap, info, mname)
}
def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = {
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/firrtl/passes/Uniquify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ object Uniquify extends Transform {
val (subExp, subMap) = rec(e.expr, m)
val index = uniquifyNamesExp(e.index, map)
(WSubAccess(subExp, index, e.tpe, e.flow), subMap)
case (_: UIntLiteral | _: SIntLiteral) => (exp, m)
case _: VectorExpression => (exp, m)
case (_: UIntLiteral | _: SIntLiteral | _: BundleExpression) => (exp, m)
case (_: Mux | _: ValidIf | _: DoPrim) =>
(exp map ((e: Expression) => uniquifyNamesExp(e, map)), m)
}
Expand Down Expand Up @@ -267,6 +268,8 @@ object Uniquify extends Transform {
uniquifyNamesExp(e, nameMap.toMap)
case e: Mux => e map uniquifyExp
case e: ValidIf => e map uniquifyExp
case e: BundleExpression => e map uniquifyExp
case e: VectorExpression => e map uniquifyExp
case (_: UIntLiteral | _: SIntLiteral) => e
case e: DoPrim => e map uniquifyExp
}
Expand Down
17 changes: 17 additions & 0 deletions src/main/scala/firrtl/proto/FromProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ object FromProto {
ir.FixedLiteral(convert(fixed.getValue), width, point)
}

def convert(fixed: Firrtl.Expression.BundleExpression.Field): (String, ir.Expression) = {
val value = convert(fixed.getValue)
value match {
case l: ir.Expression => (fixed.getName, l)
}
}

def convert(fixed: Firrtl.Expression.BundleExpression): ir.BundleExpression = {
ir.BundleExpression(fixed.getFieldList.asScala.map(convert(_)))
}

def convert(fixed: Firrtl.Expression.VectorExpression): ir.VectorExpression = {
ir.VectorExpression(fixed.getExpList.asScala.map(convert(_)), ir.UnknownType)
}

def convert(subfield: Firrtl.Expression.SubField): ir.SubField =
ir.SubField(convert(subfield.getExpression), subfield.getField, ir.UnknownType)

Expand Down Expand Up @@ -104,6 +119,8 @@ object FromProto {
case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral)
case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral)
case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral)
case BUNDLE_EXPRESSION_FIELD_NUMBER => convert(expr.getBundleExpression)
case VECTOR_EXPRESSION_FIELD_NUMBER => convert(expr.getVectorExpression)
case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp)
case MUX_FIELD_NUMBER => convert(expr.getMux)
}
Expand Down
16 changes: 16 additions & 0 deletions src/main/scala/firrtl/proto/ToProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,22 @@ object ToProto {
convert(width).foreach(fb.setWidth)
convert(point).foreach(fb.setPoint)
eb.setFixedLiteral(fb)
case ir.BundleExpression(fields) =>
val bb = Firrtl.Expression.BundleExpression.newBuilder()
fields.foreach({ case (n, v) =>
val fb = Firrtl.Expression.BundleExpression.Field.newBuilder()
fb.setName(n)
fb.setValue(convert(v))
bb.addField(fb)
})
eb.setBundleExpression(bb)
case ir.VectorExpression(exps, _) =>
val bb = Firrtl.Expression.VectorExpression.newBuilder()
exps.foreach({ case e =>
bb.addExp(convert(e))
})

eb.setVectorExpression(bb)
case ir.DoPrim(op, args, consts, _) =>
val db = Firrtl.Expression.PrimOp.newBuilder()
.setOp(convert(op))
Expand Down
19 changes: 19 additions & 0 deletions src/test/scala/firrtlTests/LowerTypesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,25 @@ class LowerTypesSpec extends FirrtlFlatSpec {
executeTest(input, expected)
}

it should "lower registers with aggregate expression initialization" in {
val input =
"""circuit Test :
| module Test :
| input clock : Clock
| input reset : UInt<1>
| reg x : { a : UInt<1>, b : UInt<1>}[2], clock with :
| reset => (reset, [{ a : UInt<1>("h0"), b : UInt<1>("h0") }, { a : UInt<1>("h1"), b : UInt<1>("h1") }])
""".stripMargin
val expected = Seq(
"reg x_0_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))",
"reg x_0_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h0\"))",
"reg x_1_a : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))",
"reg x_1_b : UInt<1>, clock with :", "reset => (reset, UInt<1>(\"h1\"))"
) map normalized

executeTest(input, expected)
}

it should "lower DefRegister expressions: clock, reset, and init" in {
val input =
"""circuit Test :
Expand Down
Loading