Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a quasi quote for rewrite rules #133

Merged
merged 12 commits into from
Aug 4, 2021
14 changes: 10 additions & 4 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ library
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.QuasiRules
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Jikka.Core.Language.TypeCheck
Jikka.Core.Language.Util
Jikka.Core.Language.Value
Jikka.Core.Parse
Jikka.Core.Parse.Alex
Jikka.Core.Parse.Happy
Jikka.Core.Parse.Token
Jikka.CPlusPlus.Convert
Jikka.CPlusPlus.Convert.AddMain
Jikka.CPlusPlus.Convert.BundleRuntime
Expand Down Expand Up @@ -162,7 +167,7 @@ library
, deepseq >=1.4.4 && <1.5
, directory >=1.3.3 && <1.4
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand All @@ -186,7 +191,7 @@ executable jikka
, deepseq >=1.4.4 && <1.5
, directory >=1.3.3 && <1.4
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand All @@ -210,7 +215,7 @@ test-suite jikka-doctest
, directory >=1.3.3 && <1.4
, doctest
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand Down Expand Up @@ -244,6 +249,7 @@ test-suite jikka-test
Jikka.Core.FormatSpec
Jikka.Core.Language.ArithmeticalExprSpec
Jikka.Core.Language.BetaSpec
Jikka.Core.ParseSpec
Jikka.CPlusPlus.Convert.FromCoreSpec
Jikka.CPlusPlus.FormatSpec
Jikka.Python.Convert.ToRestrictedPythonSpec
Expand Down Expand Up @@ -282,7 +288,7 @@ test-suite jikka-test
, hspec
, mtl >=2.2.2 && <2.3
, ormolu
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand Down
4 changes: 1 addition & 3 deletions doctests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import Test.DocTest
main :: IO ()
main = doctest
[ "src/Jikka/Common/"
, "src/Jikka/Core/"
, "src/Jikka/CPlusPlus/"
, "src/Jikka/Python/Convert/"
, "src/Jikka/Python/Language/"
, "src/Jikka/RestrictedPython/"
, "src/Jikka/RestrictedPython/Language/"
]
2 changes: 1 addition & 1 deletion package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies:
- deepseq >= 1.4.4 && < 1.5
- directory >= 1.3.3 && < 1.4
- mtl >= 2.2.2 && < 2.3
- template-haskell >= 2.14.0 && < 2.17
- template-haskell >= 2.16.0 && < 2.17
- text >= 1.2.3 && < 1.3
- transformers >= 0.5.6 && < 0.6
- vector >= 0.12.3 && < 0.13
Expand Down
336 changes: 179 additions & 157 deletions src/Jikka/CPlusPlus/Convert/FromCore.hs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/Jikka/Common/Alpha.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import Control.Monad.Reader
import Control.Monad.Signatures
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Data.Unique
import Language.Haskell.TH (Q)

class Monad m => MonadAlpha m where
nextCounter :: m Int
Expand Down Expand Up @@ -84,3 +86,9 @@ evalAlpha f i = runIdentity (evalAlphaT f i)

resetAlphaT :: Monad m => Int -> AlphaT m ()
resetAlphaT i = AlphaT $ \_ -> return ((), i)

instance MonadAlpha IO where
nextCounter = hashUnique <$> newUnique

instance MonadAlpha Q where
nextCounter = liftIO nextCounter
2 changes: 1 addition & 1 deletion src/Jikka/Core/Convert/ANormal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ runExpr env = \case
f <- runExpr env f
args <- mapM (runExpr env) args
case (f, args) of
(Lit (LitBuiltin (If _)), [e1, e2, e3]) -> do
(Lit (LitBuiltin If _), [e1, e2, e3]) -> do
(_, ctx, e1) <- destruct env e1
return $ ctx (App3 f e1 e2 e3)
_ -> runApp env f args
Expand Down
2 changes: 1 addition & 1 deletion src/Jikka/Core/Convert/CumulativeSum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ rule = RewriteRule $ \_ -> \case
then At' IntTy (Var b) n
else Minus' (At' IntTy (Var b) (Plus' n (formatArithmeticalExpr shift))) (At' IntTy (Var b) (formatArithmeticalExpr shift))
return . Just $
Let b (ListTy IntTy) (Scanl' IntTy IntTy (Lit (LitBuiltin Plus)) Lit0 a) e
Let b (ListTy IntTy) (Scanl' IntTy IntTy (Builtin Plus) Lit0 a) e
_ -> return Nothing
Max1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax (Max2' t) t (Just a0) a n
Expand Down
13 changes: 7 additions & 6 deletions src/Jikka/Core/Convert/MakeScanl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ module Jikka.Core.Convert.MakeScanl
where

import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Map as M
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand All @@ -56,7 +57,7 @@ reduceScanlBuild = simpleRewriteRule $ \case
_ -> Nothing

-- | `getRecurrenceFormulaStep1` removes `At` in @body@.
getRecurrenceFormulaStep1 :: MonadAlpha m => Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep1 :: MonadAlpha m => Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep1 shift t a i body = do
x <- genVarName a
let proj k =
Expand All @@ -78,13 +79,13 @@ getRecurrenceFormulaStep1 shift t a i body = do
Nothing -> Nothing

-- | `getRecurrenceFormulaStep` replaces `At` in @body@ with `Proj`.
getRecurrenceFormulaStep :: MonadAlpha m => Int -> Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep :: MonadAlpha m => Integer -> Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep shift size t a i body = do
x <- genVarName a
let ts = replicate size t
let ts = replicate (fromInteger size) t
let proj k =
if 0 <= toInteger shift + k && toInteger shift + k < toInteger size
then Just $ Proj' ts (shift + fromInteger k) (Var x)
then Just $ Proj' ts (shift + k) (Var x)
else Nothing
let go :: Expr -> Maybe Expr
go = \case
Expand Down Expand Up @@ -129,9 +130,9 @@ reduceFoldlSetAtRecurrence = RewriteRule $ \_ -> \case
_ -> do
let ts = replicate (length base) t2
let base' = uncurryApp (Tuple' ts) base
step <- MaybeT $ getRecurrenceFormulaStep (- length base + fromInteger k) (length base) t2 a i step
step <- MaybeT $ getRecurrenceFormulaStep (- genericLength base + k) (genericLength base) t2 a i step
x <- lift (genVarName a)
return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (length base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base)
return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (genericLength base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base)
_ -> return Nothing

-- | `checkAccumulationFormulaStep` checks that all `At` in @body@ about @a@ are @At a i@.
Expand Down
17 changes: 9 additions & 8 deletions src/Jikka/Core/Convert/MatrixExponentiation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ where

import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Vector as V
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand Down Expand Up @@ -52,13 +53,13 @@ fromAffineMatrix a b =
bottom = uncurryApp (Tuple' (replicate (w + 1) IntTy)) (replicate w (LitInt' 0) ++ [LitInt' 1])
in uncurryApp (Tuple' (replicate (h + 1) (TupleTy (replicate (w + 1) IntTy)))) (V.toList (V.zipWith go (unMatrix a) b) ++ [bottom])

toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Int -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr)))
toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Integer -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr)))
toMatrix env x n step =
case curryApp step of
(Tuple' _, es) -> runMaybeT $ do
xs <- V.fromList <$> replicateM n (lift (genVarName x))
xs <- V.fromList <$> replicateM (fromInteger n) (lift (genVarName x))
let unpackTuple _ e = case e of
Proj' _ i (Var x') | x' == x -> Var (xs V.! i)
Proj' _ i (Var x') | x' == x -> Var (xs V.! fromInteger i)
_ -> e
rows <- MaybeT . return . forM es $ \e -> do
let e' = mapExpr unpackTuple env e
Expand All @@ -69,14 +70,14 @@ toMatrix env x n step =
return (a, b)
_ -> return Nothing

addOneToVector :: Int -> VarName -> Expr
addOneToVector :: Integer -> VarName -> Expr
addOneToVector n x =
let ts = replicate n IntTy
let ts = replicate (fromInteger n) IntTy
in uncurryApp (Tuple' (IntTy : ts)) (map (\i -> Proj' ts i (Var x)) [0 .. n - 1] ++ [LitInt' 1])

removeOneFromVector :: Int -> VarName -> Expr
removeOneFromVector :: Integer -> VarName -> Expr
removeOneFromVector n x =
let ts = replicate n IntTy
let ts = replicate (fromInteger n) IntTy
in uncurryApp (Tuple' ts) (map (\i -> Proj' (IntTy : ts) i (Var x)) [0 .. n - 1])

rule :: MonadAlpha m => RewriteRule m
Expand All @@ -93,7 +94,7 @@ rule = RewriteRule $ \env -> \case
b' = Mult' (FloorDiv' (Minus' (Pow' a k) (LitInt' 1)) (Minus' a (LitInt' 1))) b -- This division has no remainder.
in Just $ Plus' (Mult' a' base) b'
Iterate' (TupleTy ts) k (Lam x _ step) base | isVectorTy' ts -> do
let n = length ts
let n = genericLength ts
let go n step base = MatAp' n n (MatPow' n step k) base
step <- toMatrix env x n step
case step of
Expand Down
9 changes: 5 additions & 4 deletions src/Jikka/Core/Convert/PropagateMod.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module Jikka.Core.Convert.PropagateMod
)
where

import Data.List
import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand Down Expand Up @@ -81,11 +82,11 @@ putFloorMod (Mod m) =
LitInt' n -> case m of
LitInt' m -> return' $ LitInt' (n `mod` m)
_ -> return Nothing
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (length ts) e m)
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return' $ Proj' ts i (MatFloorMod' h w e m)
in return' $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- putFloorMod (Mod m) f
case f of
Expand Down Expand Up @@ -144,7 +145,7 @@ putVecFloorMod env = putFloorModGeneric fallback
fallback e (Mod m) = do
t <- typecheckExpr env e
case t of
TupleTy ts -> return $ VecFloorMod' (length ts) e m
TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m
_ -> throwInternalError $ "not a vector: " ++ formatType t

putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
Expand All @@ -153,7 +154,7 @@ putMatFloorMod env = putFloorModGeneric fallback
fallback e (Mod m) = do
t <- typecheckExpr env e
case t of
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (length ts) (length ts') e m
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (genericLength ts) (genericLength ts') e m
_ -> throwInternalError $ "not a matrix: " ++ formatType t

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
Expand Down
29 changes: 15 additions & 14 deletions src/Jikka/Core/Convert/SegmentTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,26 @@ pattern CumulativeSumFlip t f e es <-
x2 = findUnusedVarName (VarName "x") f
in Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x2)) (Var x1))) e es

builtinToSemigroup :: Builtin -> Maybe Semigroup'
builtinToSemigroup = \case
Plus -> Just SemigroupIntPlus
Min2 IntTy -> Just SemigroupIntMin
Max2 IntTy -> Just SemigroupIntMax
builtinToSemigroup :: Builtin -> [Type] -> Maybe Semigroup'
builtinToSemigroup builtin ts = case (builtin, ts) of
(Plus, []) -> Just SemigroupIntPlus
(Min2, [IntTy]) -> Just SemigroupIntMin
(Max2, [IntTy]) -> Just SemigroupIntMax
_ -> Nothing

semigroupToBuiltin :: Semigroup' -> Builtin
semigroupToBuiltin :: Semigroup' -> (Builtin, [Type])
semigroupToBuiltin = \case
SemigroupIntPlus -> Plus
SemigroupIntMin -> Min2 IntTy
SemigroupIntMax -> Max2 IntTy
SemigroupIntPlus -> (Plus, [])
SemigroupIntMin -> (Min2, [IntTy])
SemigroupIntMax -> (Max2, [IntTy])

unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr)
unCumulativeSum a = \case
CumulativeSum _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of
CumulativeSum _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of
Just semigrp -> Just (semigrp, b)
Nothing -> Nothing
-- Semigroups must be commutative to use CumulativeSumFlip.
CumulativeSumFlip _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of
CumulativeSumFlip _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of
Just semigrp -> Just (semigrp, b)
Nothing -> Nothing
_ -> Nothing
Expand All @@ -103,7 +103,7 @@ replaceWithSegtrees a segtrees = go M.empty
go env = \case
At' _ (check env -> Just (e, b, semigrp)) i ->
let e' = SegmentTreeGetRange' semigrp e (LitInt' 0) i
in AppBuiltin2 (semigroupToBuiltin semigrp) b e'
in App2 (Lit (uncurry LitBuiltin (semigroupToBuiltin semigrp))) b e'
Var x -> Var x
Lit lit -> Lit lit
App e1 e2 -> App (go env e1) (go env e2)
Expand All @@ -113,10 +113,11 @@ replaceWithSegtrees a segtrees = go M.empty
in case check env e1' of
Just (e1', b, semigrp) -> go (M.insert x (e1', b, semigrp) env) e2
Nothing -> Let x t (go env e1) (go env e2)
check :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Maybe (Expr, Expr, Semigroup')
check env = \case
Var x -> M.lookup x env
CumulativeSum _ (Lit (LitBuiltin op)) b (Var a') | a' == a -> case lookup op (map (first semigroupToBuiltin) segtrees) of
Just e -> Just (e, b, fromJust (builtinToSemigroup op))
CumulativeSum _ (Lit (LitBuiltin op ts)) b (Var a') | a' == a -> case lookup (op, ts) (map (first semigroupToBuiltin) segtrees) of
Just e -> Just (e, b, fromJust (builtinToSemigroup op ts))
Nothing -> Nothing
_ -> Nothing

Expand Down
Loading