Skip to content

Commit

Permalink
add selection of a & fix srt8
Browse files Browse the repository at this point in the history
  • Loading branch information
wissygh committed Jun 14, 2022
1 parent ffdb455 commit 2220115
Show file tree
Hide file tree
Showing 13 changed files with 483 additions and 157 deletions.
36 changes: 35 additions & 1 deletion arithmetic/src/division/srt/SRT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,35 @@ class SRT(
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {
// val x = (radixLog2, a, dTruncateWidth)
// val tips = x match {
// case (2,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (2,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (2,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,4,6) => require(rTruncateWidth >= 7, "rTruncateWidth need >= 7")
// case (3,4,7) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6")
//
// case (3,5,5) => require(rTruncateWidth >= 5, "rTruncateWidth need >= 5")
// case (3,5,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,6,4) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6")
// case (3,6,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case (3,7,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (3,7,5) => require(rTruncateWidth >= 3, "rTruncateWidth need >= 3")
//
// case (4,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (4,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
// case (4,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4")
//
// case _ => println("this srt is not supported")
// }

val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))
// select radix

// select radix
if (radixLog2 == 2) { // SRT4
val srt = Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
srt.input <> input
Expand All @@ -31,4 +57,12 @@ class SRT(
srt.input <> input
output <> srt.output
}

// val srt = radixLog2 match {
// case 2 => Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
// case 3 => Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth))
// case 4 => Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth))
// }
// srt.input <> input
// output <> srt.output
}
2 changes: 1 addition & 1 deletion arithmetic/src/division/srt/SRTTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ case class SRTTable(
}

// TODO: select a Constant from each m, then offer the table to QDS.
// select rule: symmetry and draw a line parallel to the X-axis, how define the rule
// todo: ? select rule: symmetry and draw a line parallel to the X-axis, how define the rule
lazy val tablesToQDS: Seq[Seq[Int]] = {
(aMin.toInt to aMax.toInt).drop(1).map { k =>
k -> dSet.dropRight(1).map { d =>
Expand Down
4 changes: 2 additions & 2 deletions arithmetic/src/division/srt/srt16/OTF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ object OTF {
)(quotient: UInt,
quotientMinusOne: UInt,
selectedQuotientOH: UInt
): Seq[UInt] = {
): Vec[UInt] = {
val m = Module(new OTF(radixLog2, qWidth, ohWidth))
m.input.quotient := quotient
m.input.quotientMinusOne := quotientMinusOne
m.input.selectedQuotientOH := selectedQuotientOH
Seq(m.output.quotient, m.output.quotientMinusOne)
VecInit(m.output.quotient, m.output.quotientMinusOne)
}
}
1 change: 0 additions & 1 deletion arithmetic/src/division/srt/srt16/SRT16.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class SRT16(
val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(4).head(csa0InWidth))) // 2

// qds

val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS
val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0)
val qdsOH0: UInt =
Expand Down
62 changes: 43 additions & 19 deletions arithmetic/src/division/srt/srt4/OTF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,52 @@ import division.srt._
import chisel3._
import chisel3.util.Mux1H

class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module {
class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int, a: Int) extends Module {
val input = IO(Input(new OTFInput(qWidth, ohWidth)))
val output = IO(Output(new OTFOutput(qWidth)))

val radix: Int = 1 << radixLog2
// datapath
// q_j+1 in this circle, only for srt4
val qNext: UInt = Mux1H(
Seq(
input.selectedQuotientOH(0) -> "b110".U,
input.selectedQuotientOH(1) -> "b111".U,
input.selectedQuotientOH(2) -> "b000".U,
input.selectedQuotientOH(3) -> "b001".U,
input.selectedQuotientOH(4) -> "b010".U
)
)

// val cShiftQ: Bool = qNext >= 0.U
// val cShiftQM: Bool = qNext <= 0.U
val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR
val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR
val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0)
val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0)
val qNext: UInt = Wire(UInt(3.W))
val cShiftQ, cShiftQM = Wire(Bool())

if (a == 2) {
qNext := Mux1H(
Seq(
input.selectedQuotientOH(0) -> "b110".U, //-2
input.selectedQuotientOH(1) -> "b111".U, //-1
input.selectedQuotientOH(2) -> "b000".U, // 0
input.selectedQuotientOH(3) -> "b001".U, // 1
input.selectedQuotientOH(4) -> "b010".U // 2
)
)
cShiftQ := input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR
cShiftQM := input.selectedQuotientOH(ohWidth / 2, 0).orR
} else if (a == 3) {
qNext := Mux1H(
Seq(
input.selectedQuotientOH(0) -> "b111".U, //-1
input.selectedQuotientOH(1) -> "b000".U, // 0
input.selectedQuotientOH(2) -> "b001".U // 1
)
) + Mux1H(
Seq(
input.selectedQuotientOH(3) -> "b110".U, // -2
input.selectedQuotientOH(4) -> "b000".U, // 0
input.selectedQuotientOH(5) -> "b010".U // 2
)
)
cShiftQ := input.selectedQuotientOH(5) ||
(input.selectedQuotientOH(4) && input.selectedQuotientOH(2, 1).orR)
cShiftQM := input.selectedQuotientOH(3) ||
(input.selectedQuotientOH(4) && input.selectedQuotientOH(1, 0).orR)
}

val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0)
val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0)

output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn
output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn
Expand All @@ -36,15 +59,16 @@ object OTF {
def apply(
radixLog2: Int,
qWidth: Int,
ohWidth: Int
ohWidth: Int,
a: Int
)(quotient: UInt,
quotientMinusOne: UInt,
selectedQuotientOH: UInt
): Seq[UInt] = {
val m = Module(new OTF(radixLog2, qWidth, ohWidth))
): Vec[UInt] = {
val m = Module(new OTF(radixLog2, qWidth, ohWidth, a))
m.input.quotient := quotient
m.input.quotientMinusOne := quotientMinusOne
m.input.selectedQuotientOH := selectedQuotientOH
Seq(m.output.quotient, m.output.quotientMinusOne)
VecInit(m.output.quotient, m.output.quotientMinusOne)
}
}
40 changes: 28 additions & 12 deletions arithmetic/src/division/srt/srt4/QDS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import chisel3.util.BitPat.bitPatToUInt
import chisel3.util.experimental.decode.TruthTable
import utils.{extend, sIntToBitPat}

class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module {
class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]], a: Int) extends Module {
// IO
val input = IO(Input(new QDSInput(rWidth, partialDividerWidth)))
val output = IO(Output(new QDSOutput(ohWidth)))
Expand Down Expand Up @@ -54,15 +54,30 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I
// decoder or findFirstOne here, prefer decoder, the decoder only for srt4
output.selectedQuotientOH := chisel3.util.experimental.decode.decoder(
selectPoints,
TruthTable(
Seq(
BitPat("b???0") -> BitPat("b10000"), //2
BitPat("b??01") -> BitPat("b01000"), //1
BitPat("b?011") -> BitPat("b00100"), //0
BitPat("b0111") -> BitPat("b00010") //-1
),
BitPat("b00001") //-2
)
a match {
case 2 =>
TruthTable(
Seq(
BitPat("b???0") -> BitPat("b10000"), //2
BitPat("b??01") -> BitPat("b01000"), //1
BitPat("b?011") -> BitPat("b00100"), //0
BitPat("b0111") -> BitPat("b00010") //-1
),
BitPat("b00001") //-2
)
case 3 =>
TruthTable(
Seq( // 2 0 -2 1 0 -1
BitPat("b??_???0") -> BitPat("b100_100"), //3 = 2 + 1
BitPat("b??_??01") -> BitPat("b100_010"), //2 = 2 + 0
BitPat("b??_?011") -> BitPat("b010_100"), //1 = 0 + 1
BitPat("b??_0111") -> BitPat("b010_010"), //0 = 0 + 0
BitPat("b?0_1111") -> BitPat("b010_001"), //-1 = 0 + -1
BitPat("b01_1111") -> BitPat("b001_010") //-2 = -2 + 0
),
BitPat("b001_001") //-3 = -2 + -1
)
}
)
}

Expand All @@ -71,12 +86,13 @@ object QDS {
rWidth: Int,
ohWidth: Int,
partialDividerWidth: Int,
tables: Seq[Seq[Int]]
tables: Seq[Seq[Int]],
a: Int
)(partialReminderSum: UInt,
partialReminderCarry: UInt,
partialDivider: UInt
): UInt = {
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables))
val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables, a))
m.input.partialReminderSum := partialReminderSum
m.input.partialReminderCarry := partialReminderCarry
m.input.partialDivider := partialDivider
Expand Down
81 changes: 57 additions & 24 deletions arithmetic/src/division/srt/srt4/SRT4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import addition.csa.CarrySaveAdder
import addition.csa.common.CSACompressor3_2
import chisel3._
import chisel3.util._
import spire.math
import utils.leftShift

/** SRT4
Expand All @@ -26,11 +27,8 @@ class SRT4(
dTruncateWidth: Int = 4,
rTruncateWidth: Int = 4)
extends Module {

val xLen: Int = dividendWidth + radixLog2 + 1
val wLen: Int = xLen + radixLog2
val ohWidth: Int = 2 * a + 1

val xLen: Int = dividendWidth + radixLog2 + 1
val wLen: Int = xLen + radixLog2
// IO
val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))))
val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth)))
Expand All @@ -41,9 +39,8 @@ class SRT4(
val counterNext = Wire(UInt(log2Ceil(n).W))

// Control
// sign of select quotient, true -> negative, false -> positive
// sign of Cycle, true -> (counter === 0.U)
val qdsSign, isLastCycle, enable: Bool = Wire(Bool())
val isLastCycle, enable: Bool = Wire(Bool())

// State
// because we need a CSA to minimize the critical path
Expand All @@ -68,41 +65,77 @@ class SRT4(
output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2)
output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient)

// qds
val rWidth: Int = 1 + radixLog2 + rTruncateWidth
val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS
val ohWidth: Int = a match {
case 2 => 2 * a + 1
case 3 => 6
}
//qds
val selectedQuotientOH: UInt =
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)(
QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)(
leftShift(partialReminderSum, radixLog2).head(rWidth),
leftShift(partialReminderCarry, radixLog2).head(rWidth),
dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> ***
)
qdsSign := selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR
// On-The-Fly conversion
val otf = OTF(radixLog2, n, ohWidth, a)(quotient, quotientMinusOne, selectedQuotientOH)

// csa for SRT4 -> CSA32
val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen))
csa.in(0) := leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2)
csa.in(1) := leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign
csa.in(2) :=
Mux1H(
selectedQuotientOH,
//this is for SRT4, for SRT8 or SRT16, this should be changed
VecInit((-2 to 2).map {
val csa: Vec[UInt] =
if (a == 2) { // a == 2
//csa
val dividerMap = VecInit((-2 to 2).map {
case -2 => divider << 1
case -1 => divider
case 0 => 0.U
case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider
case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1)
})
)
val qdsSign = selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR
addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign,
Mux1H(selectedQuotientOH, dividerMap)
)
)
} else { // a==3
val qHigh = selectedQuotientOH(5, 3)
val qLow = selectedQuotientOH(2, 0)
val qds0Sign = qHigh.head(1)
val qds1Sign = qLow.head(1)

// On-The-Fly conversion
val otf = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH)
// csa
val dividerHMap = VecInit((-1 to 1).map {
case -1 => divider << 1 // -2
case 0 => 0.U // 0
case 1 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2
})
val dividerLMap = VecInit((-1 to 1).map {
case -1 => divider // -1
case 0 => 0.U // 0
case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1
})
val csa0 = addition.csa.c32(
VecInit(
leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2),
leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0Sign,
Mux1H(qHigh, dividerHMap)
)
)
addition.csa.c32(
VecInit(
csa0(1).head(wLen - radixLog2),
leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qds1Sign,
Mux1H(qLow, dividerLMap)
)
)
}

dividerNext := Mux(input.fire, input.bits.divider, divider)
counterNext := Mux(input.fire, input.bits.counter, counter - 1.U)
quotientNext := Mux(input.fire, 0.U, otf(0))
quotientMinusOneNext := Mux(input.fire, 0.U, otf(1))
partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa.out(1) << radixLog2)
partialReminderCarryNext := Mux(input.fire, 0.U, csa.out(0) << 1 + radixLog2)
partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa(1) << radixLog2)
partialReminderCarryNext := Mux(input.fire, 0.U, csa(0) << 1 + radixLog2)
}
Loading

0 comments on commit 2220115

Please sign in to comment.