diff --git a/clash-protocols/src/Data/Constraint/Nat/Extra.hs b/clash-protocols/src/Data/Constraint/Nat/Extra.hs index 9af80935..2344569e 100644 --- a/clash-protocols/src/Data/Constraint/Nat/Extra.hs +++ b/clash-protocols/src/Data/Constraint/Nat/Extra.hs @@ -23,13 +23,17 @@ leModulusDivisor :: forall a b. (1 <= b) => Dict (Mod a b + 1 <= b) leModulusDivisor = unsafeCoerce (Dict :: Dict (0 <= 0)) -- | if (a <= 0) then (a ~ 0) -leZeroIsZero :: forall a. (a <= 0) => Dict (a ~ 0) +leZeroIsZero :: forall (a :: Nat). (a <= 0) => Dict (a ~ 0) leZeroIsZero = unsafeCoerce (Dict :: Dict (0 ~ 0)) -- | if (1 <= a) and (1 <= b) then (1 <= DivRU a b) strictlyPositiveDivRu :: forall a b. (1 <= a, 1 <= b) => Dict (1 <= DivRU a b) strictlyPositiveDivRu = unsafeCoerce (Dict :: Dict (0 <= 0)) --- | if (1 <= a) then (b <= ceiling(b/a) * a) -timesDivRu :: forall a b. (1 <= a) => Dict (b <= Div (b + (a - 1)) a * a) -timesDivRu = unsafeCoerce (Dict :: Dict (0 <= 0)) +-- | if (1 <= a) then (b <= ceil(b/a) * a) +leTimesDivRu :: forall a b. (1 <= a) => Dict (b <= a * DivRU b a) +leTimesDivRu = unsafeCoerce (Dict :: Dict (0 <= 0)) + +-- | if (1 <= a) then (a * ceil(b/a) ~ b + Mod (a - Mod b a) a) +eqTimesDivRu :: forall a b. (1 <= a) => Dict (a * DivRU b a ~ b + Mod (a - Mod b a) a) +eqTimesDivRu = unsafeCoerce (Dict :: Dict (0 ~ 0)) diff --git a/clash-protocols/src/Protocols/PacketStream/Depacketizers.hs b/clash-protocols/src/Protocols/PacketStream/Depacketizers.hs index b45530e2..5953ea4a 100644 --- a/clash-protocols/src/Protocols/PacketStream/Depacketizers.hs +++ b/clash-protocols/src/Protocols/PacketStream/Depacketizers.hs @@ -20,51 +20,59 @@ import qualified Protocols.Df as Df import Protocols.PacketStream.Base import Data.Constraint (Dict (Dict)) -import Data.Constraint.Nat.Extra (leModulusDivisor, timesDivRu) +import Data.Constraint.Nat.Extra import Data.Data ((:~:) (Refl)) import Data.Maybe defaultByte :: BitVector 8 defaultByte = 0x00 --- Since the header might be unaligned compared to the datawidth --- we need to store a partial fragment when forwarding. --- The number of bytes we need to store depends on our "unalignedness". --- --- Ex. We parse a header of 17 bytes and our @dataWidth@ is 4 bytes. --- That means at the end of the header we can have upto 3 bytes left --- in the fragment which we may need to forward. -type ForwardBufSize (headerBytes :: Nat) (dataWidth :: Nat) = +{- | Vectors of this size are able to hold @headerBytes `DivRU` dataWidth@ + transfers of size @dataWidth@, which is bigger than or equal to @headerBytes@. +-} +type BufSize (headerBytes :: Nat) (dataWidth :: Nat) = + dataWidth * headerBytes `DivRU` dataWidth + +{- | Since the header might be unaligned compared to the datawidth + we need to store a partial fragment when forwarding. + The number of bytes we need to store depends on our "unalignedness". + + Ex. We parse a header of 17 bytes and our @dataWidth@ is 4 bytes. + That means at the end of the header we can have upto 3 bytes left + in the fragment which we may need to forward. +-} +type ForwardBytes (headerBytes :: Nat) (dataWidth :: Nat) = (dataWidth - (headerBytes `Mod` dataWidth)) `Mod` dataWidth +-- | Depacketizer constraints. type DepacketizerCt (headerBytes :: Nat) (dataWidth :: Nat) = ( KnownNat headerBytes , KnownNat dataWidth , 1 <= headerBytes , 1 <= dataWidth - , headerBytes <= headerBytes `DivRU` dataWidth * dataWidth + , BufSize headerBytes dataWidth ~ headerBytes + ForwardBytes headerBytes dataWidth , headerBytes `Mod` dataWidth <= dataWidth + , ForwardBytes headerBytes dataWidth <= dataWidth ) --- TODO remove _fwdBuf and just use the last ForwardBufSize bytes of _parseBuf instead --- See https://github.com/clash-lang/clash-protocols/issues/105 +{- | Depacketizer state. Either we are parsing a header, or we are forwarding + the rest of the packet along with the parsed header in its metadata. +-} data DepacketizerState (headerBytes :: Nat) (dataWidth :: Nat) = Parse { _aborted :: Bool -- ^ Whether the packet is aborted. We need this, because _abort might -- have been set in the bytes to be parsed. - , _parseBuf :: Vec (dataWidth * headerBytes `DivRU` dataWidth) (BitVector 8) - -- ^ Parse buffer. - , _fwdBuf :: Vec (ForwardBufSize headerBytes dataWidth) (BitVector 8) - -- ^ Buffer containing data bytes that could not be sent immediately + , _buf :: Vec (BufSize headerBytes dataWidth) (BitVector 8) + -- ^ The first @headerBytes@ of this buffer are for the parsed header. + -- The bytes after that are data bytes that could not be sent immediately -- due to misalignment of @dataWidth@ and @headerBytes@. , _counter :: Index (headerBytes `DivRU` dataWidth) -- ^ @maxBound + 1@ is the number of fragments we need to parse. } | Forward { _aborted :: Bool - , _parseBuf :: Vec (dataWidth * headerBytes `DivRU` dataWidth) (BitVector 8) - , _fwdBuf :: Vec (ForwardBufSize headerBytes dataWidth) (BitVector 8) + , _buf :: Vec (BufSize headerBytes dataWidth) (BitVector 8) , _counter :: Index (headerBytes `DivRU` dataWidth) , _lastFwd :: Bool -- ^ True iff we have seen @_last@ set but the number of data bytes was too @@ -76,26 +84,26 @@ deriving instance (DepacketizerCt headerBytes dataWidth) => NFDataX (DepacketizerState headerBytes dataWidth) --- | Initial state of @depacketizerT@ +-- | Initial state of @depacketizerT@. instance (DepacketizerCt headerBytes dataWidth) => Default (DepacketizerState headerBytes dataWidth) where def :: DepacketizerState headerBytes dataWidth - def = Parse False (repeat undefined) (repeat undefined) maxBound + def = Parse False (repeat undefined) maxBound +-- | Depacketizer state transition function. depacketizerT :: forall (headerBytes :: Nat) - (dataWidth :: Nat) (header :: Type) (metaIn :: Type) - (metaOut :: Type). - (BitSize header ~ headerBytes * 8) => + (metaOut :: Type) + (dataWidth :: Nat). (BitPack header) => - (DepacketizerCt headerBytes dataWidth) => + (BitSize header ~ headerBytes * 8) => (NFDataX metaIn) => - (ForwardBufSize headerBytes dataWidth <= dataWidth) => + (DepacketizerCt headerBytes dataWidth) => (header -> metaIn -> metaOut) -> DepacketizerState headerBytes dataWidth -> (Maybe (PacketStreamM2S dataWidth metaIn), PacketStreamS2M) -> @@ -106,8 +114,7 @@ depacketizerT _ Parse{..} (Just PacketStreamM2S{..}, _) = (nextStOut, (PacketStr where nextAborted = _aborted || _abort nextCounter = pred _counter - nextParseBuf = fst (shiftInAtN _parseBuf _data) - nextFwdBuf = dropLe (SNat @(dataWidth - ForwardBufSize headerBytes dataWidth)) _data + nextParseBuf = fst (shiftInAtN _buf _data) prematureEnd idx = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of Just Refl -> True @@ -115,19 +122,17 @@ depacketizerT _ Parse{..} (Just PacketStreamM2S{..}, _) = (nextStOut, (PacketStr -- Upon seeing _last being set, move back to the initial state if the -- right amount of bytes were not parsed yet, or if they were, but there - -- were no data bytes after that. In any case, we pass the same buffers - -- for efficiency reasons, because their initial contents are undefined - -- anyway. + -- were no data bytes after that. nextStOut = case (_counter == 0, _last) of (False, Nothing) -> - Parse nextAborted nextParseBuf nextFwdBuf nextCounter + Parse nextAborted nextParseBuf nextCounter (False, Just _) -> def (True, Just idx) | prematureEnd idx -> def (True, _) -> - Forward nextAborted nextParseBuf nextFwdBuf nextCounter (isJust _last) + Forward nextAborted nextParseBuf nextCounter (isJust _last) outReady | Forward{_lastFwd = True} <- nextStOut = False @@ -135,37 +140,39 @@ depacketizerT _ Parse{..} (Just PacketStreamM2S{..}, _) = (nextStOut, (PacketStr depacketizerT toMetaOut st@Forward{..} (Just pkt@PacketStreamM2S{..}, bwdIn) = (nextStOut, (PacketStreamS2M outReady, Just outPkt)) where nextAborted = _aborted || _abort + nextBuf = header ++ nextFwdBytes newLast = adjustLast <$> _last - (dataOut, nextFwdBuf) = splitAt (SNat @dataWidth) (_fwdBuf ++ _data) + (header, fwdBytes) = splitAt (SNat @headerBytes) _buf + (dataOut, nextFwdBytes) = splitAt (SNat @dataWidth) (fwdBytes ++ _data) -- Only use if headerBytes `Mod` dataWidth > 0. adjustLast :: Index dataWidth -> Either (Index dataWidth) (Index dataWidth) adjustLast idx = if idx < x then Left (idx + y) else Right (idx - x) where x = natToNum @(headerBytes `Mod` dataWidth) - y = natToNum @(ForwardBufSize headerBytes dataWidth) + y = natToNum @(ForwardBytes headerBytes dataWidth) outPkt = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of Just Refl -> pkt - { _meta = toMetaOut (bitCoerce $ takeLe (SNat @headerBytes) _parseBuf) _meta + { _meta = toMetaOut (bitCoerce header) _meta , _abort = nextAborted } Nothing -> pkt { _data = if _lastFwd - then _fwdBuf ++ repeat @(dataWidth - ForwardBufSize headerBytes dataWidth) defaultByte + then fwdBytes ++ repeat @(dataWidth - ForwardBytes headerBytes dataWidth) defaultByte else dataOut , _last = if _lastFwd then either Just Just =<< newLast else either Just (const Nothing) =<< newLast - , _meta = toMetaOut (bitCoerce $ takeLe (SNat @headerBytes) _parseBuf) _meta + , _meta = toMetaOut (bitCoerce header) _meta , _abort = nextAborted } - nextForwardSt = Forward nextAborted _parseBuf nextFwdBuf maxBound + nextForwardSt = Forward nextAborted nextBuf maxBound nextSt = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of Just Refl | isJust _last -> def @@ -183,38 +190,40 @@ depacketizerT toMetaOut st@Forward{..} (Just pkt@PacketStreamM2S{..}, bwdIn) = ( | otherwise = _ready bwdIn depacketizerT _ st (Nothing, bwdIn) = (st, (bwdIn, Nothing)) --- | Reads bytes at the start of each packet into metadata. +{- | +Reads bytes at the start of each packet into @_meta@. If a packet contains +less valid bytes than @headerBytes + 1@, it does not send out anything. + +If @dataWidth@ divides @headerBytes@, this component runs at full throughput. +Otherwise, it gives backpressure for one clock cycle per packet larger than +@headerBytes + 1@ valid bytes. +-} depacketizerC :: forall - (dom :: Domain) - (dataWidth :: Nat) + (headerBytes :: Nat) + (header :: Type) (metaIn :: Type) (metaOut :: Type) - (header :: Type) - (headerBytes :: Nat). - ( HiddenClockResetEnable dom - , NFDataX metaOut - , NFDataX metaIn - , BitPack header - , BitSize header ~ headerBytes * 8 - , KnownNat headerBytes - , 1 <= dataWidth - , 1 <= headerBytes - , KnownNat dataWidth - ) => + (dataWidth :: Nat) + (dom :: Domain). + (HiddenClockResetEnable dom) => + (BitPack header) => + (BitSize header ~ headerBytes * 8) => + (NFDataX metaIn) => + (KnownNat headerBytes) => + (KnownNat dataWidth) => + (1 <= headerBytes) => + (1 <= dataWidth) => -- | Used to compute final metadata of outgoing packets from header and incoming metadata (header -> metaIn -> metaOut) -> Circuit (PacketStream dom dataWidth metaIn) (PacketStream dom dataWidth metaOut) depacketizerC toMetaOut = forceResetSanity |> fromSignals ckt where - ckt = case ( timesDivRu @dataWidth @headerBytes + ckt = case ( eqTimesDivRu @dataWidth @headerBytes , leModulusDivisor @headerBytes @dataWidth + , leModulusDivisor @(dataWidth - (headerBytes `Mod` dataWidth)) @dataWidth ) of - (Dict, Dict) -> case compareSNat (SNat @(ForwardBufSize headerBytes dataWidth)) (SNat @dataWidth) of - SNatLE -> mealyB (depacketizerT toMetaOut) def - _ -> - clashCompileError - "depacketizer1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues" + (Dict, Dict, Dict) -> mealyB (depacketizerT toMetaOut) def type DepacketizeToDfCt (headerBytes :: Nat) (dataWidth :: Nat) = ( KnownNat headerBytes @@ -340,5 +349,5 @@ depacketizeToDfC :: Circuit (PacketStream dom dataWidth meta) (Df dom a) depacketizeToDfC toOut = forceResetSanity |> fromSignals ckt where - ckt = case timesDivRu @dataWidth @headerBytes of + ckt = case leTimesDivRu @dataWidth @headerBytes of Dict -> mealyB (depacketizeToDfT toOut) def diff --git a/clash-protocols/tests/Tests/Haxioms.hs b/clash-protocols/tests/Tests/Haxioms.hs index 68188c61..93a18a18 100644 --- a/clash-protocols/tests/Tests/Haxioms.hs +++ b/clash-protocols/tests/Tests/Haxioms.hs @@ -78,19 +78,35 @@ prop_strictlyPositiveDivRu = property $ do {- | Test whether the following equation holds: - b <= Div (b + (a - 1)) a * a + b <= a * DivRU b a Given: 1 <= a -Tests: 'Data.Constraint.Nat.Extra.timesDivRU'. +Tests: 'Data.Constraint.Nat.Extra.leTimesDivRu'. -} -prop_timesDivRU :: Property -prop_timesDivRU = property $ do +prop_leTimesDivRu :: Property +prop_leTimesDivRu = property $ do a <- forAll (genNatural 1) b <- forAll (genNatural 0) - assert (b <= (b + (a - 1) `div` a) * a) + assert (b <= a * divRU b a) + +{- | Test whether the following equation holds: + + a * DivRU b a ~ b + Mod (a - Mod b a) a + +Given: + + 1 <= a + +Tests: 'Data.Constraint.Nat.Extra.eqTimesDivRu'. +-} +prop_eqTimesDivRu :: Property +prop_eqTimesDivRu = property $ do + a <- forAll (genNatural 1) + b <- forAll (genNatural 0) + a * (b `divRU` a) === b + (a - b `mod` a) `mod` a tests :: TestTree tests =