diff --git a/t1/src/mask/MaskCompress.scala b/t1/src/mask/MaskCompress.scala index 4f7f7dd56..cb381958f 100644 --- a/t1/src/mask/MaskCompress.scala +++ b/t1/src/mask/MaskCompress.scala @@ -184,6 +184,12 @@ class MaskCompress(val parameter: CompressParam) changeUIntSize(compressInitPipe, maxCountWidth) } + val tailCountForMask: UInt = { + val minElementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 + val maxCountWidth = log2Ceil(minElementSizePerSet) + changeUIntSize(compressInit, maxCountWidth) + } + val compressDataReg = RegInit(0.U((parameter.laneNumber * parameter.datapathWidth).W)) val compressTailValid: Bool = RegInit(false.B) val compressWriteGroupCount: UInt = RegInit(0.U(parameter.groupNumberBits.W)) @@ -232,7 +238,7 @@ class MaskCompress(val parameter: CompressParam) val dataByte = 1 << sewInt val elementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 / dataByte VecInit(Seq.tabulate(elementSizePerSet) { elementIndex => - val elementValid = elementIndex.U < tailCount + val elementValid = elementIndex.U < tailCountForMask val elementMask = Fill(dataByte, elementValid) elementMask }).asUInt @@ -310,5 +316,5 @@ class MaskCompress(val parameter: CompressParam) val ffoOutPipe: UInt = initRegEnable(completedLeftOr | Fill(parameter.laneNumber, ffoValid), in.fire) outWire.ffoOutput := ffoOutPipe out := RegNext(outWire, 0.U.asTypeOf(outWire)) - io.stageValid := stage2Valid || in.valid + io.stageValid := stage2Valid || in.valid || compressTailValid } diff --git a/t1/src/mask/MaskReduce.scala b/t1/src/mask/MaskReduce.scala index d0abe2598..ff48579d0 100644 --- a/t1/src/mask/MaskReduce.scala +++ b/t1/src/mask/MaskReduce.scala @@ -194,7 +194,7 @@ class MaskReduce(val parameter: MaskReduceParameter) // count update // todo: stateCross <=> stateOrder ?? - when((stateCross && !floatType) || waiteDeq || in.fire) { + when((stateCross && !floatAdd) || waiteDeq || in.fire) { crossFoldCount := Mux(in.fire, 0.U, crossFoldCount + 1.U) }