Skip to content

Commit

Permalink
First step on decode API
Browse files Browse the repository at this point in the history
1. add QMC decoder from rocket-chip.
2. add decoder API.
  • Loading branch information
sequencer committed Jan 21, 2021
1 parent 616256c commit b76dff8
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package chisel3.util.experimental.decoder

import chisel3._
import chisel3.util.BitPat

abstract class Decoder {
def decode(addr: UInt, default: BitPat, mapping: Iterable[(BitPat, BitPat)]): UInt
}
196 changes: 196 additions & 0 deletions src/main/scala/chisel3/util/experimental/decoder/QMCDecoder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package chisel3.util.experimental.decoder

import chisel3._
import chisel3.util.{BitPat, Cat}

import scala.annotation.tailrec
import scala.collection.mutable

object QMCDecoder {
def apply(): QMCDecoder = new QMCDecoder()

/** decoder cache during a chisel elaboration. */
private val caches: mutable.Map[UInt, mutable.Map[Term, Bool]] = mutable.Map[UInt, mutable.Map[Term, Bool]]()
}

class QMCDecoder extends Decoder {
def decode(addr: UInt, default: BitPat, mapping: Iterable[(BitPat, BitPat)]): UInt = {
def logic(addr: UInt, addrWidth: Int, cache: mutable.Map[Term, Bool], terms: Seq[Term]): Bool = {
terms.map { t =>
cache
.getOrElseUpdate(
t, (
if (t.mask == 0)
addr
else
addr & (BigInt(2).pow(addrWidth) - (t.mask + 1)).U(addrWidth.W)
) === t.value.U(addrWidth.W)
)
}.foldLeft(false.B)(_ || _)
}

def term(lit: BitPat) = new Term(lit.value, BigInt(2).pow(lit.getWidth) - (lit.mask + 1))

def getEssentialPrimeImplicants(prime: Seq[Term], minterms: Seq[Term]): (Seq[Term], Seq[Term], Seq[Term]) = {
val primeCovers = prime.map(p => minterms.filter(p.covers))
for (((icover, pi), i) <- (primeCovers zip prime).zipWithIndex) {
for (((jcover, pj), _) <- (primeCovers zip prime).zipWithIndex.drop(i + 1)) {
if (icover.size > jcover.size && jcover.forall(pi.covers))
return getEssentialPrimeImplicants(prime.filter(_ != pj), minterms)
}
}
val essentiallyCovered = minterms.filter(t => prime.count(_.covers(t)) == 1)
val essential = prime.filter(p => essentiallyCovered.exists(p.covers))
val nonessential = prime.filterNot(essential contains _)
val uncovered = minterms.filterNot(t => essential.exists(_.covers(t)))
if (essential.isEmpty || uncovered.isEmpty)
(essential, nonessential, uncovered)
else {
val (a, b, c) = getEssentialPrimeImplicants(nonessential, uncovered)
(essential ++ a, b, c)
}
}

def getCover(implicants: Seq[Term], minterms: Seq[Term], bits: Int): Seq[Term] = {
def getCost(cover: Seq[Term], bits: Int) = cover.map(bits - _.mask.bitCount).sum

def cheaper(a: List[Term], b: List[Term], bits: Int) = {
val ca = getCost(a, bits)
val cb = getCost(b, bits)

@tailrec
def listLess(a: List[Term], b: List[Term]): Boolean = b.nonEmpty && (a.isEmpty || a.head < b.head || a.head == b.head && listLess(a.tail, b.tail))

ca < cb || ca == cb && listLess(a.sortWith(_ < _), b.sortWith(_ < _))
}

if (minterms.nonEmpty) {
val cover = minterms.map(m => implicants.filter(_.covers(m)))
val all = cover.tail.foldLeft(cover.head.map(Set(_)))((c0, c1) => c0.flatMap(a => c1.map(a + _)))
all.map(_.toList).reduceLeft((a, b) => if (cheaper(a, b, bits)) a else b)
} else
Seq[Term]()
}

def verify(cover: Seq[Term], minterms: Seq[Term], maxterms: Seq[Term]): Unit = {
assert(minterms.forall(t => cover.exists(_.covers(t))))
assert(maxterms.forall(t => !cover.exists(_ intersects t)))
}

def simplifyDC(minterms: Seq[Term], maxterms: Seq[Term], bits: Int): Seq[Term] = {
def getPrimeImplicants(minterms: Seq[Term], maxterms: Seq[Term], bits: Int): Seq[Term] = {
def getImplicitDC(maxterms: Seq[Term], term: Term, bits: Int, above: Boolean): Term = {
for (i <- 0 until bits) {
var t: Term = null
if (above && ((term.value | term.mask) & (BigInt(1) << i)) == 0)
t = new Term(term.value | (BigInt(1) << i), term.mask)
else if (!above && (term.value & (BigInt(1) << i)) != 0)
t = new Term(term.value & ~(BigInt(1) << i), term.mask)
if (t != null && !maxterms.exists(_.intersects(t)))
return t
}
null
}

var prime = List[Term]()
minterms.foreach(_.prime = true)
val mint = minterms.map(t => new Term(t.value, t.mask))
val cols = (0 to bits).map(b => mint.filter(b == _.mask.bitCount))
val table = cols.map(c => (0 to bits).map(b => mutable.Set(c.filter(b == _.value.bitCount): _*)))

for (i <- 0 to bits) {
for (j <- 0 until bits - i) {
table(i)(j).foreach(a => table(i + 1)(j) ++= table(i)(j + 1).filter(_ similar a).map(_ merge a))
}
for (j <- 0 until bits - i) {
for (a <- table(i)(j).filter(_.prime)) {
val dc = getImplicitDC(maxterms, a, bits, above = true)
if (dc != null)
table(i + 1)(j) += dc merge a
}
for (a <- table(i)(j + 1).filter(_.prime)) {
val dc = getImplicitDC(maxterms, a, bits, above = false)
if (dc != null)
table(i + 1)(j) += a merge dc
}
}
for (r <- table(i))
for (p <- r; if p.prime)
prime = p :: prime
}
prime.sortWith(_ < _)
}

val prime = getPrimeImplicants(minterms, maxterms, bits)
val (eprime, prime2, uncovered) = getEssentialPrimeImplicants(prime, minterms)
val cover = eprime ++ getCover(prime2, uncovered, bits)
verify(cover, minterms, maxterms)
cover
}

def simplify(minterms: Seq[Term], dontcares: Seq[Term], bits: Int): Seq[Term] = {
def getPrimeImplicants(implicants: Seq[Term], bits: Int): Seq[Term] = {
var prime = List[Term]()
implicants.foreach(_.prime = true)
val cols = (0 to bits).map(b => implicants.filter(b == _.mask.bitCount))
val table = cols.map(c => (0 to bits).map(b => mutable.Set(c.filter(b == _.value.bitCount): _*)))
for (i <- 0 to bits) {
for (j <- 0 until bits - i)
table(i)(j).foreach(a => table(i + 1)(j) ++= table(i)(j + 1).filter(_.similar(a)).map(_.merge(a)))
for (r <- table(i))
for (p <- r; if p.prime)
prime = p :: prime
}
prime.sortWith(_ < _)
}


if (dontcares.isEmpty) {
// As an elaboration performance optimization, don't be too clever if
// there are no don't-cares; synthesis can figure it out.
minterms
} else {
val prime = getPrimeImplicants(minterms ++ dontcares, bits)
minterms.foreach(t => assert(prime.exists(_.covers(t))))
val (eprime, prime2, uncovered) = getEssentialPrimeImplicants(prime, minterms)
val cover = eprime ++ getCover(prime2, uncovered, bits)
minterms.foreach(t => assert(cover.exists(_.covers(t)))) // sanity check
cover
}
}

val cache = QMCDecoder.caches.getOrElseUpdate(addr, mutable.Map[Term, Bool]())
val defaultTerm = term(default)
val (keys, values) = mapping.unzip
val addrWidth = keys.map(_.getWidth).max
val terms = keys.toList.map(k => term(k))
val termvalues = terms.zip(values.toList.map(term))

for (t <- keys.zip(terms).tails; if t.nonEmpty)
for (u <- t.tail)
assert(
!t.head._2.intersects(u._2),
"DecodeLogic: keys " + t.head + " and " + u + " overlap"
)

Cat(
(0 until default.getWidth.max(values.map(_.getWidth).max))
.map({ i: Int =>
val mint: Seq[Term] =
termvalues.filter { case (_, t) => ((t.mask >> i) & 1) == 0 && ((t.value >> i) & 1) == 1 }.map(_._1)
val maxt: Seq[Term] =
termvalues.filter { case (_, t) => ((t.mask >> i) & 1) == 0 && ((t.value >> i) & 1) == 0 }.map(_._1)
val dc: Seq[Term] = termvalues.filter { case (_, t) => ((t.mask >> i) & 1) == 1 }.map(_._1)
if (((defaultTerm.mask >> i) & 1) != 0) {
logic(addr, addrWidth, cache, simplifyDC(mint, maxt, addrWidth))
} else {
val defbit = (defaultTerm.value.toInt >> i) & 1
val t = if (defbit == 0) mint else maxt
val bit = logic(addr, addrWidth, cache, simplify(t, dc, addrWidth))
if (defbit == 0) bit else ~bit
}
})
.reverse
)
}
}
32 changes: 32 additions & 0 deletions src/main/scala/chisel3/util/experimental/decoder/Term.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package chisel3.util.experimental.decoder

private class Term(val value: BigInt, val mask: BigInt = 0) {
var prime = true

def covers(x: Term): Boolean = ((value ^ x.value) &~ mask | x.mask &~ mask).signum == 0

def intersects(x: Term): Boolean = ((value ^ x.value) &~ mask &~ x.mask).signum == 0

override def equals(that: Any): Boolean = that match {
case x: Term => x.value == value && x.mask == mask
case _ => false
}

override def hashCode: Int = value.toInt

def <(that: Term): Boolean = value < that.value || value == that.value && mask < that.mask

def similar(x: Term): Boolean = {
val diff = value - x.value
mask == x.mask && value > x.value && (diff & diff - 1) == 0
}

def merge(x: Term): Term = {
prime = false
x.prime = false
val bit = value - x.value
new Term(value &~ bit, mask | bit)
}

override def toString: String = value.toString(16) + "-" + mask.toString(16) + (if (prime) "p" else "")
}
36 changes: 36 additions & 0 deletions src/main/scala/chisel3/util/experimental/decoder/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package chisel3.util.experimental

import chisel3.{Bool, UInt}
import chisel3.util.BitPat

import scala.collection.mutable.ArrayBuffer

package object decoder {
def decode(addr: UInt, default: BitPat, mapping: Iterable[(BitPat, BitPat)], decoder: Decoder): UInt =
decoder.decode(addr, default, mapping)

def decode(addr: UInt, default: Seq[BitPat], mappingIn: Iterable[(BitPat, Seq[BitPat])], decoder: Decoder): Seq[UInt] = {
val mapping = ArrayBuffer.fill(default.size)(ArrayBuffer[(BitPat, BitPat)]())
for ((key, values) <- mappingIn)
for ((value, i) <- values.zipWithIndex)
mapping(i) += key -> value
for ((thisDefault, thisMapping) <- default.zip(mapping))
yield decode(addr, thisDefault, thisMapping, decoder)
}

def decode(addr: UInt, default: Seq[BitPat], mappingIn: List[(UInt, Seq[BitPat])], decoder: Decoder): Seq[UInt] =
decode(
addr,
default,
mappingIn.map(m => (BitPat(m._1), m._2)).asInstanceOf[Iterable[(BitPat, Seq[BitPat])]],
decoder
)

def decode(addr: UInt, trues: Iterable[UInt], falses: Iterable[UInt], decoder: Decoder): Bool =
decode(
addr,
BitPat.dontCare(1),
trues.map(BitPat(_) -> BitPat("b1")) ++ falses.map(BitPat(_) -> BitPat("b0")),
decoder
).asBool
}

0 comments on commit b76dff8

Please sign in to comment.