Skip to content

Commit

Permalink
Add haxiom module
Browse files Browse the repository at this point in the history
  • Loading branch information
t-wallet committed Aug 20, 2024
1 parent d1d4cb4 commit c91991d
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 55 deletions.
3 changes: 3 additions & 0 deletions clash-protocols/clash-protocols.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ library
, clash-protocols-base
, circuit-notation
, clash-prelude-hedgehog
, constraints
, data-default
, deepseq
, extra
Expand Down Expand Up @@ -170,6 +171,7 @@ library
autogen-modules: Paths_clash_protocols

other-modules:
Data.Constraint.Nat.Extra
Data.Maybe.Extra
Clash.Sized.Vector.Extra
Paths_clash_protocols
Expand All @@ -183,6 +185,7 @@ test-suite unittests
ghc-options: -threaded -with-rtsopts=-N
main-is: unittests.hs
other-modules:
Tests.Haxioms
Tests.Protocols
Tests.Protocols.Df
Tests.Protocols.DfConv
Expand Down
35 changes: 35 additions & 0 deletions clash-protocols/src/Data/Constraint/Nat/Extra.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{-# LANGUAGE AllowAmbiguousTypes #-}

{-
NOTE [constraint solver addition]
The functions in this module enable us introduce trivial constraints that are not
solved by the constraint solver.
-}
module Data.Constraint.Nat.Extra where

import Clash.Prelude
import Data.Constraint
import Unsafe.Coerce (unsafeCoerce)

{- | Postulates that multiplying some number /a/ by some constant /b/, and
subsequently dividing that result by /b/ equals /a/.
-}
cancelMulDiv :: forall a b. (1 <= b) => Dict (DivRU (a * b) b ~ a)
cancelMulDiv = unsafeCoerce (Dict :: Dict (0 ~ 0))

-- | if (1 <= b) then (Mod a b + 1 <= b)
leModulusDivisor :: forall a b. 1 <= b => Dict (Mod a b + 1 <= b)
leModulusDivisor = unsafeCoerce (Dict :: Dict (0 <= 0))

-- | if (a <= 0) then (a ~ 0)
leZeroIsZero :: forall a. (a <= 0) => Dict (a ~ 0)
leZeroIsZero = unsafeCoerce (Dict :: Dict (0 ~ 0))

-- | if (1 <= a) and (1 <= b) then (1 <= DivRU a b)
strictlyPositiveDivRu :: forall a b. (1 <= a, 1 <= b) => Dict (1 <= DivRU a b)
strictlyPositiveDivRu = unsafeCoerce (Dict :: Dict (0 <= 0))

-- | if (1 <= a) then (b <= ceiling(b/a) * a)
timesDivRu :: forall a b. (1 <= a) => Dict (b <= Div (b + (a - 1)) a * a)
timesDivRu = unsafeCoerce (Dict :: Dict (0 <= 0))
56 changes: 22 additions & 34 deletions clash-protocols/src/Protocols/PacketStream/Depacketizers.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=10 #-}
{-# OPTIONS_HADDOCK hide #-}

{- |
Expand All @@ -21,6 +21,8 @@ import Protocols.PacketStream.Base

import Data.Data ((:~:) (Refl))
import Data.Maybe
import Data.Constraint (Dict(Dict))
import Data.Constraint.Nat.Extra (timesDivRu, leModulusDivisor)

defaultByte :: BitVector 8
defaultByte = 0x00
Expand All @@ -36,10 +38,12 @@ type ForwardBufSize (headerBytes :: Nat) (dataWidth :: Nat) =
(dataWidth - (headerBytes `Mod` dataWidth)) `Mod` dataWidth

type DepacketizerCt (headerBytes :: Nat) (dataWidth :: Nat) =
( headerBytes `Mod` dataWidth <= dataWidth
( KnownNat headerBytes
, KnownNat dataWidth
, 1 <= headerBytes
, 1 <= dataWidth
, KnownNat headerBytes
, headerBytes <= headerBytes `DivRU` dataWidth * dataWidth
, headerBytes `Mod` dataWidth <= dataWidth
)

-- TODO remove _fwdBuf and just use the last ForwardBufSize bytes of _parseBuf instead
Expand Down Expand Up @@ -92,7 +96,6 @@ depacketizerT ::
(DepacketizerCt headerBytes dataWidth) =>
(NFDataX metaIn) =>
(ForwardBufSize headerBytes dataWidth <= dataWidth) =>
(headerBytes <= dataWidth * headerBytes `DivRU` dataWidth) =>
(header -> metaIn -> metaOut) ->
DepacketizerState headerBytes dataWidth ->
(Maybe (PacketStreamM2S dataWidth metaIn), PacketStreamS2M) ->
Expand Down Expand Up @@ -196,34 +199,29 @@ depacketizerC ::
, BitSize header ~ headerBytes * 8
, KnownNat headerBytes
, 1 <= dataWidth
, 1 <= headerBytes
, KnownNat dataWidth
) =>
-- | Used to compute final metadata of outgoing packets from header and incoming metadata
(header -> metaIn -> metaOut) ->
Circuit (PacketStream dom dataWidth metaIn) (PacketStream dom dataWidth metaOut)
depacketizerC toMetaOut = forceResetSanity |> fromSignals outCircuit
depacketizerC toMetaOut = forceResetSanity |> fromSignals ckt
where
modProof = compareSNat (SNat @(headerBytes `Mod` dataWidth)) (SNat @dataWidth)
divProof = compareSNat (SNat @headerBytes) (SNat @(dataWidth * headerBytes `DivRU` dataWidth))

outCircuit =
case (modProof, divProof) of
(SNatLE, SNatLE) -> case compareSNat (SNat @(ForwardBufSize headerBytes dataWidth)) (SNat @dataWidth) of
SNatLE -> mealyB (depacketizerT @headerBytes toMetaOut) def
_ ->
clashCompileError
"depacketizer1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
ckt = case ( timesDivRu @dataWidth @headerBytes
, leModulusDivisor @headerBytes @dataWidth
) of
(Dict, Dict) -> case compareSNat (SNat @(ForwardBufSize headerBytes dataWidth)) (SNat @dataWidth) of
SNatLE -> mealyB (depacketizerT toMetaOut) def
_ ->
clashCompileError
"depacketizer0: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
"depacketizer1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"

type DepacketizeToDfCt (headerBytes :: Nat) (dataWidth :: Nat) =
( 1 <= headerBytes `DivRU` dataWidth
, headerBytes `Mod` dataWidth <= dataWidth
, headerBytes <= dataWidth * headerBytes `DivRU` dataWidth
( KnownNat headerBytes
, KnownNat dataWidth
, 1 <= headerBytes
, 1 <= dataWidth
, KnownNat headerBytes
, headerBytes <= headerBytes `DivRU` dataWidth * dataWidth
)

data DfDepacketizerState (headerBytes :: Nat) (dataWidth :: Nat)
Expand Down Expand Up @@ -334,23 +332,13 @@ depacketizeToDfC ::
(BitPack header) =>
(KnownNat headerBytes) =>
(KnownNat dataWidth) =>
(1 <= headerBytes) =>
(1 <= dataWidth) =>
(BitSize header ~ headerBytes * 8) =>
-- | function that transforms the given meta + parsed header to the output Df
(header -> meta -> a) ->
Circuit (PacketStream dom dataWidth meta) (Df dom a)
depacketizeToDfC toOut = forceResetSanity |> fromSignals outCircuit
depacketizeToDfC toOut = forceResetSanity |> fromSignals ckt
where
divProof = compareSNat (SNat @headerBytes) (SNat @(dataWidth * headerBytes `DivRU` dataWidth))
modProof = compareSNat (SNat @(headerBytes `Mod` dataWidth)) (SNat @dataWidth)

outCircuit =
case (divProof, modProof) of
(SNatLE, SNatLE) -> case compareSNat d1 (SNat @(headerBytes `DivRU` dataWidth)) of
SNatLE -> mealyB (depacketizeToDfT toOut) def
_ ->
clashCompileError
"depacketizeToDfC0: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
_ ->
clashCompileError
"depacketizeToDfC1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
ckt = case timesDivRu @dataWidth @headerBytes of
Dict -> mealyB (depacketizeToDfT toOut) def
30 changes: 12 additions & 18 deletions clash-protocols/src/Protocols/PacketStream/Packetizers.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=10 #-}
{-# OPTIONS_HADDOCK hide #-}

{- |
Expand All @@ -18,9 +19,10 @@ import qualified Protocols.Df as Df
import Protocols.PacketStream.Base

import Clash.Sized.Vector.Extra (takeLe)
import Data.Data ((:~:) (Refl))
import Data.Maybe
import Data.Maybe.Extra
import Data.Constraint.Nat.Extra (leModulusDivisor, strictlyPositiveDivRu, leZeroIsZero)
import Data.Constraint (Dict(Dict))

defaultByte :: BitVector 8
defaultByte = 0x00
Expand Down Expand Up @@ -313,19 +315,13 @@ packetizerC ::
Circuit (PacketStream dom dataWidth metaIn) (PacketStream dom dataWidth metaOut)
packetizerC toMetaOut toHeader = fromSignals outCircuit
where
outCircuit = case compareSNat (SNat @(headerBytes `Mod` dataWidth)) (SNat @dataWidth) of
SNatLE -> case compareSNat (SNat @(headerBytes + 1)) (SNat @dataWidth) of
outCircuit = case leModulusDivisor @headerBytes @dataWidth of
Dict -> case compareSNat (SNat @(headerBytes + 1)) (SNat @dataWidth) of
SNatLE -> mealyB (packetizerT1 @headerBytes toMetaOut toHeader) (Insert1 False)
SNatGT -> case sameNat (SNat @(headerBytes `Mod` dataWidth)) d0 of
Just Refl -> mealyB (packetizerT2 @headerBytes toMetaOut toHeader) LoadHeader2
_ -> case compareSNat d1 (SNat @(headerBytes `Mod` dataWidth)) of
SNatLE -> mealyB (packetizerT3 @headerBytes toMetaOut toHeader) LoadHeader3
SNatGT ->
clashCompileError
"packetizerC0: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
_ ->
clashCompileError
"packetizerC1: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
SNatGT -> case compareSNat (SNat @(headerBytes `Mod` dataWidth)) d0 of
SNatLE -> case leZeroIsZero @(headerBytes `Mod` dataWidth) of
Dict -> mealyB (packetizerT2 @headerBytes toMetaOut toHeader) LoadHeader2
SNatGT -> mealyB (packetizerT3 @headerBytes toMetaOut toHeader) LoadHeader3

data DfPacketizerState (metaOut :: Type) (headerBytes :: Nat) (dataWidth :: Nat)
= DfIdle
Expand Down Expand Up @@ -404,14 +400,15 @@ packetizeFromDfC ::
(BitSize header ~ headerBytes * 8) =>
(KnownNat headerBytes) =>
(KnownNat dataWidth) =>
(1 <= headerBytes) =>
(1 <= dataWidth) =>
-- | Function that transforms the Df input to the output metadata.
(a -> metaOut) ->
-- | Function that transforms the Df input to the header that will be packetized.
(a -> header) ->
Circuit (Df dom a) (PacketStream dom dataWidth metaOut)
packetizeFromDfC toMetaOut toHeader = case compareSNat d1 (SNat @(headerBytes `DivRU` dataWidth)) of
SNatLE -> case compareSNat (SNat @headerBytes) (SNat @dataWidth) of
packetizeFromDfC toMetaOut toHeader = case strictlyPositiveDivRu @headerBytes @dataWidth of
Dict -> case compareSNat (SNat @headerBytes) (SNat @dataWidth) of
-- We don't need a state machine in this case, as we are able to packetize
-- the entire payload in one clock cycle.
SNatLE -> Circuit (unbundle . fmap go . bundle)
Expand All @@ -425,6 +422,3 @@ packetizeFromDfC toMetaOut toHeader = case compareSNat d1 (SNat @(headerBytes `D
SNatGT -> natToNum @(headerBytes `Mod` dataWidth - 1)
_ -> natToNum @(dataWidth - 1)
SNatGT -> fromSignals (mealyB (packetizeFromDfT toMetaOut toHeader) DfIdle)
SNatGT ->
clashCompileError
"packetizeFromDfC: Absurd, Report this to the Clash compiler team: https://github.com/clash-lang/clash-compiler/issues"
99 changes: 99 additions & 0 deletions clash-protocols/tests/Tests/Haxioms.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{-# LANGUAGE NumericUnderscores #-}

module Tests.Haxioms where

import Prelude
import Numeric.Natural

import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range

import Test.Tasty
import Test.Tasty.Hedgehog (HedgehogTestLimit (HedgehogTestLimit))
import Test.Tasty.Hedgehog.Extra (testProperty)
import Test.Tasty.TH (testGroupGenerator)

-- | Generate a 'Natural' greater than or equal to /n/. Can generate 'Natural's
-- up to /n+1000/. This should be enough, given that naturals in this module are
-- used in proofs.
genNatural :: Natural -> Gen Natural
genNatural min_ = Gen.integral (Range.linear min_ (1000 + min_))

-- | Like 'DivRU', but at term-level.
divRU :: Natural -> Natural -> Natural
divRU dividend divider =
case dividend `divMod` divider of
(n, 0) -> n
(n, _) -> n + 1

-- | Test whether the following equation holds:
--
-- DivRU (a * b) b ~ a
--
-- Given:
--
-- 1 <= b
--
-- Tests: 'Data.Constraint.Nat.Extra.cancelMulDiv'.
--
prop_cancelMulDiv :: Property
prop_cancelMulDiv = property $ do
a <- forAll (genNatural 0)
b <- forAll (genNatural 1)
divRU (a * b) b === a

-- | Test whether the following equation holds:
--
-- Mod a b + 1 <= b
--
-- Given:
--
-- 1 <= b
--
-- Tests: 'Data.Constraint.Nat.Extra.leModulusDivisor'.
--
prop_leModulusDivisor :: Property
prop_leModulusDivisor = property $ do
a <- forAll (genNatural 0)
b <- forAll (genNatural 1)
assert (a `mod` b + 1 <= b)

-- | Test whether the following equation holds:
--
-- 1 <= DivRU a b
--
-- Given:
--
-- 1 <= a, 1 <= b
--
-- Tests: 'Data.Constraint.Nat.Extra.strictlyPositiveDivRu'.
--
prop_strictlyPositiveDivRu :: Property
prop_strictlyPositiveDivRu = property $ do
a <- forAll (genNatural 1)
b <- forAll (genNatural 1)
assert (1 <= divRU a b)

-- | Test whether the following equation holds:
--
-- b <= Div (b + (a - 1)) a * a
--
-- Given:
--
-- 1 <= a
--
-- Tests: 'Data.Constraint.Nat.Extra.timesDivRU'.
--
prop_timesDivRU :: Property
prop_timesDivRU = property $ do
a <- forAll (genNatural 1)
b <- forAll (genNatural 0)
assert (b <= (b + (a - 1) `div` a) * a)

tests :: TestTree
tests =
localOption (mkTimeout 10_000_000 {- 10 seconds -}) $
localOption
(HedgehogTestLimit (Just 100_000))
$(testGroupGenerator)
7 changes: 4 additions & 3 deletions clash-protocols/tests/unittests.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module Main where

import Control.Concurrent (setNumCapabilities)
import Control.Monad (join)
import System.Environment (lookupEnv, setEnv)
import Test.Tasty
import Text.Read (readMaybe)
import Prelude

import qualified Tests.Haxioms
import qualified Tests.Protocols

main :: IO ()
Expand All @@ -15,7 +15,7 @@ main = do
setEnv "TASTY_NUM_THREADS" "2"

-- Detect "THREADS" environment variable on CI
nThreads <- join . fmap readMaybe <$> lookupEnv "THREADS"
nThreads <- (readMaybe =<<) <$> lookupEnv "THREADS"
case nThreads of
Nothing -> pure ()
Just n -> do
Expand All @@ -27,5 +27,6 @@ tests :: TestTree
tests =
testGroup
"Tests"
[ Tests.Protocols.tests
[ Tests.Haxioms.tests
, Tests.Protocols.tests
]

0 comments on commit c91991d

Please sign in to comment.