Skip to content

Commit

Permalink
Add TH tuple instances for Simulate
Browse files Browse the repository at this point in the history
  • Loading branch information
lmbollen committed Nov 19, 2024
1 parent 4b3fdec commit 28173f2
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
5 changes: 5 additions & 0 deletions clash-protocols/src/Protocols/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=20 #-}
#if !MIN_VERSION_clash_prelude(1, 8, 2)
{-# OPTIONS_GHC -fno-warn-orphans #-}
#endif
Expand Down Expand Up @@ -33,8 +34,10 @@ import qualified Clash.Explicit.Prelude as CE
import Clash.Prelude (type (*), type (+))
import qualified Clash.Prelude as C

import Protocols.Internal.TH (simulateTupleInstances)
import Protocols.Internal.Types
import Protocols.Plugin
import Protocols.Plugin.Cpp (maxTupleSize)
import Protocols.Plugin.TaggedBundle
import Protocols.Plugin.Units

Expand Down Expand Up @@ -234,6 +237,8 @@ instance (Simulate a, Simulate b) => Simulate (a, b) where
in
((fwdL1, fwdR1), (bwdL1, bwdR1))

simulateTupleInstances 3 maxTupleSize

instance (Drivable a, Drivable b) => Drivable (a, b) where
type ExpectType (a, b) = (ExpectType a, ExpectType b)

Expand Down
77 changes: 77 additions & 0 deletions clash-protocols/src/Protocols/Internal/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

module Protocols.Internal.TH where

import qualified Clash.Prelude as C
import Control.Monad.Extra (concatMapM)
import GHC.TypeNats
import Language.Haskell.TH
import Protocols.Internal.Types
import Protocols.Plugin

{- | Template haskell function to generate IdleCircuit instances for the tuples
n through m inclusive. To see a 2-tuple version of the pattern we generate,
Expand All @@ -31,3 +34,77 @@ idleCircuitTupleInstance n =
mkFwdExpr ty = [e|idleFwd $ Proxy @($ty)|]
bwdExpr = tupE $ map mkBwdExpr circTys
mkBwdExpr ty = [e|idleBwd $ Proxy @($ty)|]

simulateTupleInstances :: Int -> Int -> DecsQ
simulateTupleInstances n m = concatMapM simulateTupleInstance [n .. m]

simulateTupleInstance :: Int -> DecsQ
simulateTupleInstance n =
[d|
instance ($instCtx) => Simulate $instTy where
type SimulateFwdType $instTy = $fwdType
type SimulateBwdType $instTy = $bwdType
type SimulateChannels $instTy = $channelSum

simToSigFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigFwd (Proxy @($ty)) $expr|]) circTys fwdExpr)
simToSigBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigBwd (Proxy @($ty)) $expr|]) circTys bwdExpr)
sigToSimFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimFwd (Proxy @($ty)) $expr|]) circTys fwdExpr)
sigToSimBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimBwd (Proxy @($ty)) $expr|]) circTys bwdExpr)

stallC $(varP $ mkName "conf") $(varP $ mkName "rem0") = $(letE (stallVecs ++ stallCircuits) stallCExpr)
|]
where
-- Generate the types for the instance
circTys = map (\i -> varT $ mkName $ "c" <> show i) [1 .. n]
instTy = foldl appT (tupleT n) circTys
instCtx = foldl appT (tupleT n) $ map (\ty -> [t|Simulate $ty|]) circTys
fwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateFwdType $ty|]) circTys
bwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateBwdType $ty|]) circTys
channelSum = foldl1 (\a b -> [t|$a + $b|]) $ map (\ty -> [t|SimulateChannels $ty|]) circTys

-- Relevant expressions and patterns
fwdPat0 = tupP $ map (\i -> varP $ mkName $ "fwd" <> show i) [1 .. n]
bwdPat0 = tupP $ map (\i -> varP $ mkName $ "bwd" <> show i) [1 .. n]
fwdExpr = map (\i -> varE $ mkName $ "fwd" <> show i) [1 .. n]
bwdExpr = map (\i -> varE $ mkName $ "bwd" <> show i) [1 .. n]
fwdExpr1 = map (\i -> varE $ mkName $ "fwdStalled" <> show i) [1 .. n]
bwdExpr1 = map (\i -> varE $ mkName $ "bwdStalled" <> show i) [1 .. n]

-- stallC Declaration: Split off the stall vectors from the large input vector
stallVecs = zipWith mkStallVec [1 .. n] circTys
mkStallVec i ty =
valD
mkStallPat
( normalB [e|(C.splitAtI @(SimulateChannels $ty) $(varE (mkName $ "rem" <> show (i - 1))))|]
)
[]
where
mkStallPat =
tupP
[ varP (mkName $ "stalls" <> show i)
, varP (mkName $ if i == n then "_" else "rem" <> show i)
]

-- stallC Declaration: Generate stalling circuits
stallCircuits = zipWith mkStallCircuit [1 .. n] circTys
mkStallCircuit i ty =
valD
[p|Circuit $(varP $ mkName $ "stalled" <> show i)|]
(normalB [e|stallC @($ty) conf $(varE $ mkName $ "stalls" <> show i)|])
[]

-- Generate the stallC expression
stallCExpr =
[e|
Circuit $ \($fwdPat0, $bwdPat0) -> $(letE stallCResultDecs [e|($(tupE fwdExpr1), $(tupE bwdExpr1))|])
|]

stallCResultDecs = map mkStallCResultDec [1 .. n]
mkStallCResultDec i =
valD
(tupP [varP $ mkName $ "fwdStalled" <> show i, varP $ mkName $ "bwdStalled" <> show i])
( normalB $
appE (varE $ mkName $ "stalled" <> show i) $
tupE [varE $ mkName $ "fwd" <> show i, varE $ mkName $ "bwd" <> show i]
)
[]

0 comments on commit 28173f2

Please sign in to comment.