Skip to content

Commit

Permalink
Depacketizer: remove redundant buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
t-wallet committed Aug 20, 2024
1 parent d75310a commit 30d979c
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 69 deletions.
12 changes: 8 additions & 4 deletions clash-protocols/src/Data/Constraint/Nat/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ leModulusDivisor :: forall a b. (1 <= b) => Dict (Mod a b + 1 <= b)
leModulusDivisor = unsafeCoerce (Dict :: Dict (0 <= 0))

-- | if (a <= 0) then (a ~ 0)
leZeroIsZero :: forall a. (a <= 0) => Dict (a ~ 0)
leZeroIsZero :: forall (a :: Nat). (a <= 0) => Dict (a ~ 0)
leZeroIsZero = unsafeCoerce (Dict :: Dict (0 ~ 0))

-- | if (1 <= a) and (1 <= b) then (1 <= DivRU a b)
strictlyPositiveDivRu :: forall a b. (1 <= a, 1 <= b) => Dict (1 <= DivRU a b)
strictlyPositiveDivRu = unsafeCoerce (Dict :: Dict (0 <= 0))

-- | if (1 <= a) then (b <= ceiling(b/a) * a)
timesDivRu :: forall a b. (1 <= a) => Dict (b <= Div (b + (a - 1)) a * a)
timesDivRu = unsafeCoerce (Dict :: Dict (0 <= 0))
-- | if (1 <= a) then (b <= ceil(b/a) * a)
leTimesDivRu :: forall a b. (1 <= a) => Dict (b <= a * DivRU b a)
leTimesDivRu = unsafeCoerce (Dict :: Dict (0 <= 0))

-- | if (1 <= a) then (a * ceil(b/a) ~ b + Mod (a - Mod b a) a)
eqTimesDivRu :: forall a b. (1 <= a) => Dict (a * DivRU b a ~ b + Mod (a - Mod b a) a)
eqTimesDivRu = unsafeCoerce (Dict :: Dict (0 ~ 0))
129 changes: 69 additions & 60 deletions clash-protocols/src/Protocols/PacketStream/Depacketizers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,51 +20,59 @@ import qualified Protocols.Df as Df
import Protocols.PacketStream.Base

import Data.Constraint (Dict (Dict))
import Data.Constraint.Nat.Extra (leModulusDivisor, timesDivRu)
import Data.Constraint.Nat.Extra
import Data.Data ((:~:) (Refl))
import Data.Maybe

defaultByte :: BitVector 8
defaultByte = 0x00

-- Since the header might be unaligned compared to the datawidth
-- we need to store a partial fragment when forwarding.
-- The number of bytes we need to store depends on our "unalignedness".
--
-- Ex. We parse a header of 17 bytes and our @dataWidth@ is 4 bytes.
-- That means at the end of the header we can have upto 3 bytes left
-- in the fragment which we may need to forward.
type ForwardBufSize (headerBytes :: Nat) (dataWidth :: Nat) =
{- | Vectors of this size are able to hold @headerBytes `DivRU` dataWidth@
transfers of size @dataWidth@, which is bigger than or equal to @headerBytes@.
-}
type BufSize (headerBytes :: Nat) (dataWidth :: Nat) =
dataWidth * headerBytes `DivRU` dataWidth

{- | Since the header might be unaligned compared to the datawidth
we need to store a partial fragment when forwarding.
The number of bytes we need to store depends on our "unalignedness".
Ex. We parse a header of 17 bytes and our @dataWidth@ is 4 bytes.
That means at the end of the header we can have upto 3 bytes left
in the fragment which we may need to forward.
-}
type ForwardBytes (headerBytes :: Nat) (dataWidth :: Nat) =
(dataWidth - (headerBytes `Mod` dataWidth)) `Mod` dataWidth

-- | Depacketizer constraints.
type DepacketizerCt (headerBytes :: Nat) (dataWidth :: Nat) =
( KnownNat headerBytes
, KnownNat dataWidth
, 1 <= headerBytes
, 1 <= dataWidth
, headerBytes <= headerBytes `DivRU` dataWidth * dataWidth
, BufSize headerBytes dataWidth ~ headerBytes + ForwardBytes headerBytes dataWidth
, headerBytes `Mod` dataWidth <= dataWidth
, ForwardBytes headerBytes dataWidth <= dataWidth
)

-- TODO remove _fwdBuf and just use the last ForwardBufSize bytes of _parseBuf instead
-- See https://github.com/clash-lang/clash-protocols/issues/105
{- | Depacketizer state. Either we are parsing a header, or we are forwarding
the rest of the packet along with the parsed header in its metadata.
-}
data DepacketizerState (headerBytes :: Nat) (dataWidth :: Nat)
= Parse
{ _aborted :: Bool
-- ^ Whether the packet is aborted. We need this, because _abort might
-- have been set in the bytes to be parsed.
, _parseBuf :: Vec (dataWidth * headerBytes `DivRU` dataWidth) (BitVector 8)
-- ^ Parse buffer.
, _fwdBuf :: Vec (ForwardBufSize headerBytes dataWidth) (BitVector 8)
-- ^ Buffer containing data bytes that could not be sent immediately
, _buf :: Vec (BufSize headerBytes dataWidth) (BitVector 8)
-- ^ The first @headerBytes@ of this buffer are for the parsed header.
-- The bytes after that are data bytes that could not be sent immediately
-- due to misalignment of @dataWidth@ and @headerBytes@.
, _counter :: Index (headerBytes `DivRU` dataWidth)
-- ^ @maxBound + 1@ is the number of fragments we need to parse.
}
| Forward
{ _aborted :: Bool
, _parseBuf :: Vec (dataWidth * headerBytes `DivRU` dataWidth) (BitVector 8)
, _fwdBuf :: Vec (ForwardBufSize headerBytes dataWidth) (BitVector 8)
, _buf :: Vec (BufSize headerBytes dataWidth) (BitVector 8)
, _counter :: Index (headerBytes `DivRU` dataWidth)
, _lastFwd :: Bool
-- ^ True iff we have seen @_last@ set but the number of data bytes was too
Expand All @@ -76,26 +84,26 @@ deriving instance
(DepacketizerCt headerBytes dataWidth) =>
NFDataX (DepacketizerState headerBytes dataWidth)

-- | Initial state of @depacketizerT@
-- | Initial state of @depacketizerT@.
instance
(DepacketizerCt headerBytes dataWidth) =>
Default (DepacketizerState headerBytes dataWidth)
where
def :: DepacketizerState headerBytes dataWidth
def = Parse False (repeat undefined) (repeat undefined) maxBound
def = Parse False (repeat undefined) maxBound

-- | Depacketizer state transition function.
depacketizerT ::
forall
(headerBytes :: Nat)
(dataWidth :: Nat)
(header :: Type)
(metaIn :: Type)
(metaOut :: Type).
(BitSize header ~ headerBytes * 8) =>
(metaOut :: Type)
(dataWidth :: Nat).
(BitPack header) =>
(DepacketizerCt headerBytes dataWidth) =>
(BitSize header ~ headerBytes * 8) =>
(NFDataX metaIn) =>
(ForwardBufSize headerBytes dataWidth <= dataWidth) =>
(DepacketizerCt headerBytes dataWidth) =>
(header -> metaIn -> metaOut) ->
DepacketizerState headerBytes dataWidth ->
(Maybe (PacketStreamM2S dataWidth metaIn), PacketStreamS2M) ->
Expand All @@ -106,66 +114,65 @@ depacketizerT _ Parse{..} (Just PacketStreamM2S{..}, _) = (nextStOut, (PacketStr
where
nextAborted = _aborted || _abort
nextCounter = pred _counter
nextParseBuf = fst (shiftInAtN _parseBuf _data)
nextFwdBuf = dropLe (SNat @(dataWidth - ForwardBufSize headerBytes dataWidth)) _data
nextParseBuf = fst (shiftInAtN _buf _data)

prematureEnd idx = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of
Just Refl -> True
_ -> idx < natToNum @(headerBytes `Mod` dataWidth)

-- Upon seeing _last being set, move back to the initial state if the
-- right amount of bytes were not parsed yet, or if they were, but there
-- were no data bytes after that. In any case, we pass the same buffers
-- for efficiency reasons, because their initial contents are undefined
-- anyway.
-- were no data bytes after that.
nextStOut = case (_counter == 0, _last) of
(False, Nothing) ->
Parse nextAborted nextParseBuf nextFwdBuf nextCounter
Parse nextAborted nextParseBuf nextCounter
(False, Just _) ->
def
(True, Just idx)
| prematureEnd idx ->
def
(True, _) ->
Forward nextAborted nextParseBuf nextFwdBuf nextCounter (isJust _last)
Forward nextAborted nextParseBuf nextCounter (isJust _last)

outReady
| Forward{_lastFwd = True} <- nextStOut = False
| otherwise = True
depacketizerT toMetaOut st@Forward{..} (Just pkt@PacketStreamM2S{..}, bwdIn) = (nextStOut, (PacketStreamS2M outReady, Just outPkt))
where
nextAborted = _aborted || _abort
nextBuf = header ++ nextFwdBytes
newLast = adjustLast <$> _last
(dataOut, nextFwdBuf) = splitAt (SNat @dataWidth) (_fwdBuf ++ _data)
(header, fwdBytes) = splitAt (SNat @headerBytes) _buf
(dataOut, nextFwdBytes) = splitAt (SNat @dataWidth) (fwdBytes ++ _data)

-- Only use if headerBytes `Mod` dataWidth > 0.
adjustLast :: Index dataWidth -> Either (Index dataWidth) (Index dataWidth)
adjustLast idx = if idx < x then Left (idx + y) else Right (idx - x)
where
x = natToNum @(headerBytes `Mod` dataWidth)
y = natToNum @(ForwardBufSize headerBytes dataWidth)
y = natToNum @(ForwardBytes headerBytes dataWidth)

outPkt = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of
Just Refl ->
pkt
{ _meta = toMetaOut (bitCoerce $ takeLe (SNat @headerBytes) _parseBuf) _meta
{ _meta = toMetaOut (bitCoerce header) _meta
, _abort = nextAborted
}
Nothing ->
pkt
{ _data =
if _lastFwd
then _fwdBuf ++ repeat @(dataWidth - ForwardBufSize headerBytes dataWidth) defaultByte
then fwdBytes ++ repeat @(dataWidth - ForwardBytes headerBytes dataWidth) defaultByte
else dataOut
, _last =
if _lastFwd
then either Just Just =<< newLast
else either Just (const Nothing) =<< newLast
, _meta = toMetaOut (bitCoerce $ takeLe (SNat @headerBytes) _parseBuf) _meta
, _meta = toMetaOut (bitCoerce header) _meta
, _abort = nextAborted
}

nextForwardSt = Forward nextAborted _parseBuf nextFwdBuf maxBound
nextForwardSt = Forward nextAborted nextBuf maxBound
nextSt = case sameNat d0 (SNat @(headerBytes `Mod` dataWidth)) of
Just Refl
| isJust _last -> def
Expand All @@ -183,38 +190,40 @@ depacketizerT toMetaOut st@Forward{..} (Just pkt@PacketStreamM2S{..}, bwdIn) = (
| otherwise = _ready bwdIn
depacketizerT _ st (Nothing, bwdIn) = (st, (bwdIn, Nothing))

-- | Reads bytes at the start of each packet into metadata.
{- |
Reads bytes at the start of each packet into @_meta@. If a packet contains
less valid bytes than @headerBytes + 1@, it does not send out anything.
If @dataWidth@ divides @headerBytes@, this component runs at full throughput.
Otherwise, it gives backpressure for one clock cycle per packet larger than
@headerBytes + 1@ valid bytes.
-}
depacketizerC ::
forall
(dom :: Domain)
(dataWidth :: Nat)
(headerBytes :: Nat)
(header :: Type)
(metaIn :: Type)
(metaOut :: Type)
(header :: Type)
(headerBytes :: Nat).
( HiddenClockResetEnable dom
, NFDataX metaOut
, NFDataX metaIn
, BitPack header
, BitSize header ~ headerBytes * 8
, KnownNat headerBytes
, 1 <= dataWidth
, 1 <= headerBytes
, KnownNat dataWidth
) =>
(dataWidth :: Nat)
(dom :: Domain).
(HiddenClockResetEnable dom) =>
(BitPack header) =>
(BitSize header ~ headerBytes * 8) =>
(NFDataX metaIn) =>
(KnownNat headerBytes) =>
(KnownNat dataWidth) =>
(1 <= headerBytes) =>
(1 <= dataWidth) =>
-- | Used to compute final metadata of outgoing packets from header and incoming metadata
(header -> metaIn -> metaOut) ->
Circuit (PacketStream dom dataWidth metaIn) (PacketStream dom dataWidth metaOut)
depacketizerC toMetaOut = forceResetSanity |> fromSignals ckt
where
ckt = case ( timesDivRu @dataWidth @headerBytes
ckt = case ( eqTimesDivRu @dataWidth @headerBytes
, leModulusDivisor @headerBytes @dataWidth
, leModulusDivisor @(dataWidth - (headerBytes `Mod` dataWidth)) @dataWidth
) of
(Dict, Dict) -> case compareSNat (SNat @(ForwardBufSize headerBytes dataWidth)) (SNat @dataWidth) of
SNatLE -> mealyB (depacketizerT toMetaOut) def
_ ->
clashCompileError
"depacketizer1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
(Dict, Dict, Dict) -> mealyB (depacketizerT toMetaOut) def

type DepacketizeToDfCt (headerBytes :: Nat) (dataWidth :: Nat) =
( KnownNat headerBytes
Expand Down Expand Up @@ -340,5 +349,5 @@ depacketizeToDfC ::
Circuit (PacketStream dom dataWidth meta) (Df dom a)
depacketizeToDfC toOut = forceResetSanity |> fromSignals ckt
where
ckt = case timesDivRu @dataWidth @headerBytes of
ckt = case leTimesDivRu @dataWidth @headerBytes of
Dict -> mealyB (depacketizeToDfT toOut) def
26 changes: 21 additions & 5 deletions clash-protocols/tests/Tests/Haxioms.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,35 @@ prop_strictlyPositiveDivRu = property $ do

{- | Test whether the following equation holds:
b <= Div (b + (a - 1)) a * a
b <= a * DivRU b a
Given:
1 <= a
Tests: 'Data.Constraint.Nat.Extra.timesDivRU'.
Tests: 'Data.Constraint.Nat.Extra.leTimesDivRu'.
-}
prop_timesDivRU :: Property
prop_timesDivRU = property $ do
prop_leTimesDivRu :: Property
prop_leTimesDivRu = property $ do
a <- forAll (genNatural 1)
b <- forAll (genNatural 0)
assert (b <= (b + (a - 1) `div` a) * a)
assert (b <= a * divRU b a)

{- | Test whether the following equation holds:
a * DivRU b a ~ b + Mod (a - Mod b a) a
Given:
1 <= a
Tests: 'Data.Constraint.Nat.Extra.eqTimesDivRu'.
-}
prop_eqTimesDivRu :: Property
prop_eqTimesDivRu = property $ do
a <- forAll (genNatural 1)
b <- forAll (genNatural 0)
a * (b `divRU` a) === b + (a - b `mod` a) `mod` a

tests :: TestTree
tests =
Expand Down

0 comments on commit 30d979c

Please sign in to comment.