diff --git a/examples/dp_z-kubaru.py b/examples/dp_z-kubaru.py index 6ca2256e..257dd901 100644 --- a/examples/dp_z-kubaru.py +++ b/examples/dp_z-kubaru.py @@ -4,7 +4,7 @@ INF = 10 ** 18 def solve(n: int, c: int, h: List[int]) -> int: - assert 2 <= n <= 10 ** 5 + assert 2 <= n <= 2 * 10 ** 5 assert 1 <= c <= 10 ** 12 assert len(h) == n assert all(1 <= h_i <= 10 ** 6 for h_i in h) diff --git a/examples/dp_z-morau.py b/examples/dp_z-morau.py index aeb14533..6bf1d5f3 100644 --- a/examples/dp_z-morau.py +++ b/examples/dp_z-morau.py @@ -4,7 +4,7 @@ INF = 10 ** 18 def solve(n: int, c: int, h: List[int]) -> int: - assert 2 <= n <= 10 ** 5 + assert 2 <= n <= 2 * 10 ** 5 assert 1 <= c <= 10 ** 12 assert len(h) == n assert all(1 <= h_i <= 10 ** 6 for h_i in h) diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index ff92c19d..d2eea3c1 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -481,14 +481,14 @@ runAppBuiltin env f ts args = wrapError' ("converting builtin " ++ X.formatBuilt X.All -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return - ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.Equal (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool True)]) (Y.end xs))) + ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.Equal (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool False)]) (Y.end xs))) ], Y.Var y ) X.Any -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return - ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.NotEqual (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool False)]) (Y.end xs))) + ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.NotEqual (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool True)]) (Y.end xs))) ], Y.Var y ) @@ -637,6 +637,10 @@ runExpr env = \case (stmts1, e1) <- runExpr env e1 (stmts2, e2) <- runExpr ((x, t, y) : env) e2 return (stmts1 ++ Y.Declare t' y (Y.DeclareCopy e1) : stmts2, e2) + X.Assert e1 e2 -> do + (stmts1, e1) <- runExpr env e1 + (stmts2, e2) <- runExpr env e2 + return (stmts1 ++ Y.Assert e1 : stmts2, e2) runToplevelFunDef :: (MonadAlpha m, MonadError Error m) => Env -> Y.VarName -> [(X.VarName, X.Type)] -> X.Type -> X.Expr -> m [Y.ToplevelStatement] runToplevelFunDef env f args ret body = do @@ -713,6 +717,11 @@ runToplevelExpr env = \case stmt <- runToplevelFunDef ((f, t, g) : env) g args ret body cont <- runToplevelExpr ((f, t, g) : env) cont return $ stmt ++ cont + X.ToplevelAssert e cont -> do + (stmts, e) <- runExpr env e + let stmt = Y.StaticAssert (Y.CallExpr (Y.Lam [] Y.TyBool (stmts ++ [Y.Return e])) []) "" + cont <- runToplevelExpr env cont + return $ stmt : cont runProgram :: (MonadAlpha m, MonadError Error m) => X.Program -> m Y.Program runProgram prog = Y.Program <$> runToplevelExpr [] prog diff --git a/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs b/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs index 8aa2e44b..a788f7ec 100644 --- a/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs +++ b/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs @@ -152,6 +152,7 @@ runToplevelStatement :: MonadState (M.Map VarName VarName) m => ToplevelStatemen runToplevelStatement = \case VarDef t x e -> VarDef t x <$> runExpr e FunDef ret f args body -> FunDef ret f args <$> runStatements body [] + StaticAssert e msg -> StaticAssert <$> runExpr e <*> pure msg runProgram :: Monad m => Program -> m Program runProgram (Program decls) = (`evalStateT` M.empty) $ do diff --git a/src/Jikka/CPlusPlus/Convert/UnpackTuples.hs b/src/Jikka/CPlusPlus/Convert/UnpackTuples.hs index 0ff91689..1346a50e 100644 --- a/src/Jikka/CPlusPlus/Convert/UnpackTuples.hs +++ b/src/Jikka/CPlusPlus/Convert/UnpackTuples.hs @@ -203,6 +203,7 @@ runToplevelStatement :: (MonadAlpha m, MonadError Error m, MonadState (M.Map Var runToplevelStatement = \case VarDef t x e -> VarDef t x <$> runExpr e FunDef ret f args body -> FunDef ret f args <$> runStatements body [] + StaticAssert e msg -> StaticAssert <$> runExpr e <*> pure msg runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program runProgram (Program decls) = (`evalStateT` M.empty) $ do diff --git a/src/Jikka/CPlusPlus/Format.hs b/src/Jikka/CPlusPlus/Format.hs index 753398fe..00cb8fe3 100644 --- a/src/Jikka/CPlusPlus/Format.hs +++ b/src/Jikka/CPlusPlus/Format.hs @@ -320,6 +320,8 @@ formatToplevelStatement = \case args' = intercalate ", " $ map (\(t, x) -> formatType t ++ " " ++ unVarName x) args body' = concatMap formatStatement body in [ret' ++ " " ++ unVarName f ++ "(" ++ args' ++ ") {"] ++ body' ++ ["}"] + StaticAssert e msg -> + ["static_assert (" ++ resolvePrec CommaPrec (formatExpr e) ++ ", " ++ formatLiteral (LitString msg) ++ ");"] formatProgram :: Program -> [Code] formatProgram prog = @@ -327,6 +329,7 @@ formatProgram prog = standardHeaders = [ "#include ", "#include ", + "#include ", "#include ", "#include ", "#include ", diff --git a/src/Jikka/CPlusPlus/Language/Expr.hs b/src/Jikka/CPlusPlus/Language/Expr.hs index 242e239b..d2bd6ea7 100644 --- a/src/Jikka/CPlusPlus/Language/Expr.hs +++ b/src/Jikka/CPlusPlus/Language/Expr.hs @@ -249,6 +249,8 @@ data ToplevelStatement VarDef Type VarName Expr | -- | @T f(T1 x1, T2 x2, ...) { stmt1; stmt2; ... }@ FunDef Type VarName [(Type, VarName)] [Statement] + | -- | @static_assert(e, msg);@ + StaticAssert Expr String deriving (Eq, Ord, Show, Read) newtype Program = Program diff --git a/src/Jikka/CPlusPlus/Language/Util.hs b/src/Jikka/CPlusPlus/Language/Util.hs index 1c53f692..9e4dace3 100644 --- a/src/Jikka/CPlusPlus/Language/Util.hs +++ b/src/Jikka/CPlusPlus/Language/Util.hs @@ -198,6 +198,7 @@ mapExprStatementToplevelStatementM :: Monad m => (Expr -> m Expr) -> (Statement mapExprStatementToplevelStatementM f g = \case VarDef t x e -> VarDef t x <$> mapExprStatementExprM f g e FunDef ret h args body -> FunDef ret h args <$> mapM (mapExprStatementStatementM f g) body + StaticAssert e msg -> StaticAssert <$> mapExprStatementExprM f g e <*> pure msg mapExprStatementProgramM :: Monad m => (Expr -> m Expr) -> (Statement -> m Statement) -> Program -> m Program mapExprStatementProgramM f g (Program decls) = Program <$> mapM (mapExprStatementToplevelStatementM f g) decls diff --git a/src/Jikka/Core/Convert/ANormal.hs b/src/Jikka/Core/Convert/ANormal.hs index 84d6a269..db9a5e28 100644 --- a/src/Jikka/Core/Convert/ANormal.hs +++ b/src/Jikka/Core/Convert/ANormal.hs @@ -22,7 +22,7 @@ import Jikka.Core.Language.Lint import Jikka.Core.Language.TypeCheck import Jikka.Core.Language.Util -destruct :: (MonadAlpha m, MonadError Error m) => TypeEnv -> Expr -> m (TypeEnv, Expr -> Expr, Expr) +destruct :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m ([(VarName, Type)], Expr -> Expr, Expr) destruct env = \case e@Var {} -> return (env, id, e) e@Lit {} -> return (env, id, e) @@ -38,8 +38,12 @@ destruct env = \case (env, ctx, e1) <- destruct env e1 (env, ctx', e2) <- destruct ((x, t) : env) e2 return (env, ctx . Let x t e1 . ctx', e2) + Assert e1 e2 -> do + (env, ctx, e1) <- destruct env e1 + (env, ctx', e2) <- destruct env e2 + return (env, ctx . Assert e1 . ctx', e2) -runApp :: (MonadAlpha m, MonadError Error m) => TypeEnv -> Expr -> [Expr] -> m Expr +runApp :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> [Expr] -> m Expr runApp env f args = go env id args where go :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> ([Expr] -> [Expr]) -> [Expr] -> m Expr @@ -51,7 +55,7 @@ runApp env f args = go env id args e <- go env (acc . (arg :)) args return $ ctx e -runExpr :: (MonadAlpha m, MonadError Error m) => TypeEnv -> Expr -> m Expr +runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr runExpr env = \case Var x -> return $ Var x Lit lit -> return $ Lit lit @@ -70,19 +74,24 @@ runExpr env = \case (env, ctx, e1) <- destruct env e1 e2 <- runExpr ((x, t) : env) e2 return $ ctx (Let x t e1 e2) + Assert e1 e2 -> do + e1 <- runExpr env e1 + (env, ctx, e1) <- destruct env e1 + e2 <- runExpr env e2 + return $ ctx (Assert e1 e2) -runToplevelExpr :: (MonadAlpha m, MonadError Error m) => TypeEnv -> ToplevelExpr -> m ToplevelExpr -runToplevelExpr env = \case - ResultExpr e -> ResultExpr <$> runExpr env e - ToplevelLet x t e cont -> do - e <- runExpr env e - cont <- runToplevelExpr ((x, t) : env) cont - return $ ToplevelLet x t e cont - ToplevelLetRec f args ret body cont -> do - let t = curryFunTy (map snd args) ret - body <- runExpr (reverse args ++ (f, t) : env) body - cont <- runToplevelExpr ((f, t) : env) cont - return $ ToplevelLetRec f args ret body cont +-- | TODO: convert `ToplevelExpr` too +runProgram :: (MonadAlpha m, MonadError Error m) => ToplevelExpr -> m ToplevelExpr +runProgram = mapToplevelExprProgramM go + where + go env = \case + ResultExpr e -> ResultExpr <$> runExpr env e + ToplevelLet x t e cont -> ToplevelLet x t <$> runExpr env e <*> pure cont + ToplevelLetRec f args ret body cont -> do + let t = curryFunTy (map snd args) ret + let env' = reverse args ++ (f, t) : env + ToplevelLetRec f args ret <$> runExpr env' body <*> pure cont + ToplevelAssert e cont -> ToplevelAssert <$> runExpr env e <*> pure cont -- | `run` makes a given program A-normal form. -- A program is an A-normal form iff assigned exprs of all let-statements are values or function applications. @@ -99,6 +108,6 @@ runToplevelExpr env = \case run :: (MonadAlpha m, MonadError Error m) => Program -> m Program run prog = wrapError' "Jikka.Core.Convert.ANormal" $ do prog <- Alpha.runProgram prog - prog <- runToplevelExpr [] prog + prog <- runProgram prog ensureWellTyped prog return prog diff --git a/src/Jikka/Core/Convert/Alpha.hs b/src/Jikka/Core/Convert/Alpha.hs index cf4fbdbf..9424c22a 100644 --- a/src/Jikka/Core/Convert/Alpha.hs +++ b/src/Jikka/Core/Convert/Alpha.hs @@ -38,6 +38,7 @@ runExpr env = \case y <- rename x e2 <- runExpr ((x, y) : env) e2 return $ Let y t e1 e2 + Assert e1 e2 -> Assert <$> runExpr env e1 <*> runExpr env e2 runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr runToplevelExpr env = \case @@ -57,6 +58,7 @@ runToplevelExpr env = \case body <- runExpr (args1 ++ (f, g) : env) body cont <- runToplevelExpr ((f, g) : env) cont return $ ToplevelLetRec g args2 ret body cont + ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr env e1 <*> runToplevelExpr env e2 runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program runProgram = runToplevelExpr [] diff --git a/src/Jikka/Core/Convert/ConstantPropagation.hs b/src/Jikka/Core/Convert/ConstantPropagation.hs index 4eccc240..25be5f87 100644 --- a/src/Jikka/Core/Convert/ConstantPropagation.hs +++ b/src/Jikka/Core/Convert/ConstantPropagation.hs @@ -35,6 +35,7 @@ runExpr env = \case in if isConstantTimeExpr e1' then runExpr (M.insert x e1' env) e2 else Let x t e1' (runExpr env e2) + Assert e1 e2 -> Assert (runExpr env e1) (runExpr env e2) runToplevelExpr :: Env -> ToplevelExpr -> ToplevelExpr runToplevelExpr env = \case @@ -46,6 +47,7 @@ runToplevelExpr env = \case else ToplevelLet x t e' (runToplevelExpr env cont) ToplevelLetRec f args ret body cont -> ToplevelLetRec f args ret (runExpr env body) (runToplevelExpr env cont) + ToplevelAssert e1 e2 -> ToplevelAssert (runExpr env e1) (runToplevelExpr env e2) run' :: Program -> Program run' = runToplevelExpr M.empty diff --git a/src/Jikka/Core/Convert/KubaruToMorau.hs b/src/Jikka/Core/Convert/KubaruToMorau.hs index 90d4fa03..fab47aec 100644 --- a/src/Jikka/Core/Convert/KubaruToMorau.hs +++ b/src/Jikka/Core/Convert/KubaruToMorau.hs @@ -50,6 +50,7 @@ runFunctionBody c i j step y x k = do Lam x t e | x == c || x == i || x == j -> throwRuntimeError "name confliction found" | otherwise -> Lam x t <$> go e + Assert e1 e2 -> Assert <$> go e1 <*> go e2 go step -- | TODO: remove the assumption that the length of @a@ is equals to @n@ diff --git a/src/Jikka/Core/Convert/MakeScanl.hs b/src/Jikka/Core/Convert/MakeScanl.hs index 84b49f60..5cf2f544 100644 --- a/src/Jikka/Core/Convert/MakeScanl.hs +++ b/src/Jikka/Core/Convert/MakeScanl.hs @@ -74,6 +74,7 @@ getRecurrenceFormulaStep1 shift t a i body = do App f e -> App <$> go f <*> go e Lam x t e -> Lam x t <$> if x == a then Just e else go e Let x t e1 e2 -> Let x t <$> go e1 <*> if x == a then Just e2 else go e2 + Assert f e -> Assert <$> go f <*> go e return $ case go body of Just body -> Just $ Lam2 x t i IntTy body Nothing -> Nothing @@ -97,6 +98,7 @@ getRecurrenceFormulaStep shift size t a i body = do App f e -> App <$> go f <*> go e Lam x t e -> Lam x t <$> if x == a then Just e else go e Let x t e1 e2 -> Let x t <$> go e1 <*> if x == a then Just e2 else go e2 + Assert f e -> Assert <$> go f <*> go e return $ case go body of Just body -> Just $ Lam2 x (TupleTy ts) i IntTy (uncurryApp (Tuple' ts) (map (\i -> Proj' ts i (Var x)) [1 .. size - 1] ++ [body])) Nothing -> Nothing @@ -148,6 +150,7 @@ checkAccumulationFormulaStep a i = go App f e -> go f && go e Lam x _ e -> x == a || go e Let x _ e1 e2 -> go e1 && (x == a || go e2) + Assert e1 e2 -> go e1 && go e2 -- | -- * This assumes that `Range2` and `Range3` are already converted to `Range1` (`Jikka.Core.Convert.ShortCutFusion`). @@ -196,6 +199,7 @@ checkGenericRecurrenceFormulaStep a = \i k -> go (M.fromList [(i, k - 1)]) App f e -> go env f && go env e Lam x _ e -> x == a || go env e Let x _ e1 e2 -> go env e1 && (x == a || go env e2) + Assert e1 e2 -> go env e1 && go env e2 reduceFoldlSetAtGeneric :: MonadAlpha m => RewriteRule m reduceFoldlSetAtGeneric = RewriteRule $ \_ -> \case diff --git a/src/Jikka/Core/Convert/RemoveUnusedVars.hs b/src/Jikka/Core/Convert/RemoveUnusedVars.hs index 0b371494..d7c125ef 100644 --- a/src/Jikka/Core/Convert/RemoveUnusedVars.hs +++ b/src/Jikka/Core/Convert/RemoveUnusedVars.hs @@ -23,32 +23,24 @@ import Jikka.Core.Language.FreeVars (isUnusedVar) import Jikka.Core.Language.Lint import Jikka.Core.Language.Util -runLet :: VarName -> Type -> Expr -> Expr -> Expr -runLet x t e1 e2 - | isUnusedVar x e2 = e2 - | otherwise = Let x t e1 e2 +runExpr :: [(VarName, Type)] -> Expr -> Expr +runExpr _ = mapExpr go [] + where + go _ = \case + Let x _ _ e2 | x `isUnusedVar` e2 -> e2 + e -> e -runExpr :: Expr -> Expr -runExpr = \case - Var x -> Var x - Lit lit -> Lit lit - App f e -> App (runExpr f) (runExpr e) - Lam x t e -> Lam x t (runExpr e) - Let x t e1 e2 -> runLet x t (runExpr e1) (runExpr e2) - -runToplevelExpr :: ToplevelExpr -> ToplevelExpr -runToplevelExpr = \case - ResultExpr e -> ResultExpr $ runExpr e - ToplevelLet x t e cont -> ToplevelLet x t (runExpr e) (runToplevelExpr cont) +-- | TODO: Remove `ToplevelLet` if its variable is not used. +runToplevelExpr :: [(VarName, Type)] -> ToplevelExpr -> ToplevelExpr +runToplevelExpr _ = \case ToplevelLetRec f args ret body cont -> - let body' = runExpr body - cont' = runToplevelExpr cont - in if isUnusedVar f body' - then ToplevelLet f (curryFunTy (map snd args) ret) (curryLam args body') cont' - else ToplevelLetRec f args ret body' cont' + if isUnusedVar f body + then ToplevelLet f (curryFunTy (map snd args) ret) (curryLam args body) cont + else ToplevelLetRec f args ret body cont + e -> e run' :: Program -> Program -run' = runToplevelExpr +run' = mapToplevelExprProgram runToplevelExpr . mapExprProgram runExpr -- | `run` removes unused variables in given programs. -- diff --git a/src/Jikka/Core/Convert/SegmentTree.hs b/src/Jikka/Core/Convert/SegmentTree.hs index 37d2bdc0..1859ead4 100644 --- a/src/Jikka/Core/Convert/SegmentTree.hs +++ b/src/Jikka/Core/Convert/SegmentTree.hs @@ -113,6 +113,7 @@ replaceWithSegtrees a segtrees = go M.empty in case check env e1' of Just (e1', b, semigrp) -> go (M.insert x (e1', b, semigrp) env) e2 Nothing -> Let x t (go env e1) (go env e2) + Assert e1 e2 -> Assert (go env e1) (go env e2) check :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Maybe (Expr, Expr, Semigroup') check env = \case Var x -> M.lookup x env diff --git a/src/Jikka/Core/Convert/TrivialLetElimination.hs b/src/Jikka/Core/Convert/TrivialLetElimination.hs index 74933ab0..41848a33 100644 --- a/src/Jikka/Core/Convert/TrivialLetElimination.hs +++ b/src/Jikka/Core/Convert/TrivialLetElimination.hs @@ -35,12 +35,14 @@ isEliminatable x = \case App f e -> isEliminatable x f `plus` isEliminatable x e Lam y _ e -> if x == y then Nothing else isEliminatable x e $> False -- moving an expr into a lambda may increase the time complexity Let y _ e1 e2 -> isEliminatable x e1 `plus` (if x == y then Nothing else isEliminatable x e2) + Assert e1 e2 -> isEliminatable x e1 `plus` isEliminatable x e2 isEliminatableToplevelExpr :: VarName -> ToplevelExpr -> Maybe Bool isEliminatableToplevelExpr x = \case ResultExpr e -> isEliminatable x e ToplevelLet y _ e cont -> isEliminatable x e `plus` (if x == y then Nothing else isEliminatableToplevelExpr x cont) ToplevelLetRec f args _ body cont -> if x == f then Nothing else isEliminatableToplevelExpr x cont `plus` (if x `elem` map fst args then Nothing else isEliminatable x body) + ToplevelAssert e cont -> isEliminatable x e `plus` isEliminatableToplevelExpr x cont runExpr :: M.Map VarName Expr -> Expr -> Expr runExpr env = \case @@ -53,6 +55,7 @@ runExpr env = \case in if isEliminatable x e2 /= Just False then runExpr (M.insert x e1' env) e2 else Let x t e1' (runExpr env e2) + Assert e1 e2 -> Assert (runExpr env e1) (runExpr env e2) runToplevelExpr :: M.Map VarName Expr -> ToplevelExpr -> ToplevelExpr runToplevelExpr env = \case @@ -64,6 +67,8 @@ runToplevelExpr env = \case else ToplevelLet x t e' (runToplevelExpr env cont) ToplevelLetRec f args ret body cont -> ToplevelLetRec f args ret (runExpr env body) (runToplevelExpr env cont) + ToplevelAssert e cont -> + ToplevelAssert (runExpr env e) (runToplevelExpr env cont) run' :: Program -> Program run' = runToplevelExpr M.empty diff --git a/src/Jikka/Core/Convert/TypeInfer.hs b/src/Jikka/Core/Convert/TypeInfer.hs index e268f550..db290a50 100644 --- a/src/Jikka/Core/Convert/TypeInfer.hs +++ b/src/Jikka/Core/Convert/TypeInfer.hs @@ -73,6 +73,9 @@ formularizeExpr = \case formularizeVarName x t formularizeExpr' e1 t formularizeExpr e2 + Assert e1 e2 -> do + formularizeExpr' e1 BoolTy + formularizeExpr e2 formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> Type -> m () formularizeExpr' e t = do @@ -91,6 +94,9 @@ formularizeToplevelExpr = \case mapM_ (uncurry formularizeVarName) args formularizeExpr' body ret formularizeToplevelExpr cont + ToplevelAssert e cont -> do + formularizeExpr' e BoolTy + formularizeToplevelExpr cont formularizeProgram :: (MonadAlpha m, MonadError Error m) => Program -> m [Equation] formularizeProgram prog = getDual <$> execWriterT (formularizeToplevelExpr prog) diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index 2b080ffc..8037a18c 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -29,7 +29,7 @@ import qualified Data.Vector as V import Jikka.Common.Alpha import Jikka.Common.Error import Jikka.Common.Matrix -import Jikka.Core.Format (formatBuiltinIsolated) +import Jikka.Core.Format (formatBuiltinIsolated, formatExpr) import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Jikka.Core.Language.Lint @@ -292,6 +292,11 @@ evaluateExpr env = \case Let x _ e1 e2 -> do v1 <- evaluateExpr env e1 evaluateExpr ((x, v1) : env) e2 + Assert e1 e2 -> do + p <- valueToBool =<< evaluateExpr env e1 + if p + then evaluateExpr env e2 + else throwRuntimeError $ "assertion failed: " ++ formatExpr e1 callToplevelExpr :: (MonadFix m, MonadError Error m) => Env -> ToplevelExpr -> [Value] -> m Value callToplevelExpr env e args = case e of @@ -301,6 +306,11 @@ callToplevelExpr env e args = case e of ToplevelLetRec f args' _ body cont -> do val <- mfix $ \val -> evaluateExpr ((f, val) : env) (curryLam args' body) callToplevelExpr ((f, val) : env) cont args + ToplevelAssert e cont -> do + p <- valueToBool =<< evaluateExpr env e + if p + then callToplevelExpr env cont args + else throwRuntimeError $ "toplevel assertion failed: " ++ formatExpr e ResultExpr e -> do val <- evaluateExpr env e callValue val args diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index fb18c7b5..f5c8d5c0 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -315,6 +315,7 @@ formatExpr' = \case let (args, body) = uncurryLam e in ("fun " ++ formatFormalArgs args ++ " ->\n" ++ indent ++ "\n" ++ resolvePrec parenPrec (formatExpr' body) ++ "\n" ++ dedent ++ "\n", lambdaPrec) Let x t e1 e2 -> ("let " ++ unVarName x ++ ": " ++ formatType t ++ " =\n" ++ indent ++ "\n" ++ resolvePrec parenPrec (formatExpr' e1) ++ "\n" ++ dedent ++ "\nin " ++ resolvePrec lambdaPrec (formatExpr' e2), lambdaPrec) + Assert e1 e2 -> ("assert " ++ resolvePrec parenPrec (formatExpr' e1) ++ " in " ++ resolvePrec lambdaPrec (formatExpr' e2), lambdaPrec) formatExpr :: Expr -> String formatExpr = unlines . makeIndentFromMarkers 4 . lines . resolvePrec parenPrec . formatExpr' @@ -324,6 +325,7 @@ formatToplevelExpr = \case ResultExpr e -> lines (resolvePrec lambdaPrec (formatExpr' e)) ToplevelLet x t e cont -> let' (unVarName x) t e cont ToplevelLetRec f args ret e cont -> let' ("rec " ++ unVarName f ++ " " ++ formatFormalArgs args) ret e cont + ToplevelAssert e cont -> ["assert " ++ resolvePrec parenPrec (formatExpr' e), "in"] ++ formatToplevelExpr cont where let' s t e cont = ["let " ++ s ++ ": " ++ formatType t ++ " =", indent] diff --git a/src/Jikka/Core/Language/Beta.hs b/src/Jikka/Core/Language/Beta.hs index fadd0fb9..a331acbf 100644 --- a/src/Jikka/Core/Language/Beta.hs +++ b/src/Jikka/Core/Language/Beta.hs @@ -46,6 +46,7 @@ substitute x e = \case else do (y, e2) <- resolveConflict e (y, e2) Let y t e1 <$> substitute x e e2 + Assert e1 e2 -> Assert <$> substitute x e e1 <*> substitute x e e2 substituteToplevelExpr :: (MonadAlpha m, MonadError Error m) => VarName -> Expr -> ToplevelExpr -> m ToplevelExpr substituteToplevelExpr x e = \case @@ -73,6 +74,7 @@ substituteToplevelExpr x e = \case return (args ++ [(y, t)], body) foldM go ([], body) args ToplevelLetRec f args ret body <$> substituteToplevelExpr x e cont + ToplevelAssert e1 e2 -> ToplevelAssert <$> substitute x e e1 <*> substituteToplevelExpr x e e2 resolveConflict :: MonadAlpha m => Expr -> (VarName, Expr) -> m (VarName, Expr) resolveConflict e (x, e') = diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index 81834c47..1f301455 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -291,7 +291,8 @@ data Literal -- \vert & \mathrm{literal}\ldots \\ -- \vert & e_0(e_1, e_2, \dots, e_n) \\ -- \vert & \lambda ~ x_0\colon \tau_0, x_1\colon \tau_1, \dots, x_{n-1}\colon \tau_{n-1}. ~ e \\ --- \vert & \mathbf{let} ~ x\colon \tau = e_1 ~ \mathbf{in} ~ e_2 +-- \vert & \mathbf{let} ~ x\colon \tau = e_1 ~ \mathbf{in} ~ e_2 \\ +-- \vert & \mathbf{assert} ~ e_1 ~ \mathbf{in} ~ e_2 -- \end{array} -- \] data Expr @@ -301,6 +302,7 @@ data Expr | Lam VarName Type Expr | -- | This "let" is not recursive. Let VarName Type Expr Expr + | Assert Expr Expr deriving (Eq, Ord, Show, Read, Data, Typeable) pattern Fun2Ty t1 t2 ret = FunTy t1 (FunTy t2 ret) @@ -399,13 +401,15 @@ pattern Lam3 x1 t1 x2 t2 x3 t3 e = Lam x1 t1 (Lam x2 t2 (Lam x3 t3 e)) -- \begin{array}{rl} -- \mathrm{tle} ::= & e \\ -- \vert & \mathbf{let}~ x: \tau = e ~\mathbf{in}~ \mathrm{tle} \\ --- \vert & \mathbf{let~rec}~ x(x: \tau, x: \tau, \dots, x: \tau): \tau = e ~\mathbf{in}~ \mathrm{tle} +-- \vert & \mathbf{let~rec}~ x(x: \tau, x: \tau, \dots, x: \tau): \tau = e ~\mathbf{in}~ \mathrm{tle} \\ +-- \vert & \mathbf{assert}~ e ~\mathbf{in}~ \mathrm{tle} -- \end{array} -- \] data ToplevelExpr = ResultExpr Expr | ToplevelLet VarName Type Expr ToplevelExpr | ToplevelLetRec VarName [(VarName, Type)] Type Expr ToplevelExpr + | ToplevelAssert Expr ToplevelExpr deriving (Eq, Ord, Show, Read, Data, Typeable) type Program = ToplevelExpr diff --git a/src/Jikka/Core/Language/FreeVars.hs b/src/Jikka/Core/Language/FreeVars.hs index 321f7e75..1adde133 100644 --- a/src/Jikka/Core/Language/FreeVars.hs +++ b/src/Jikka/Core/Language/FreeVars.hs @@ -26,6 +26,7 @@ isFreeVar x = \case App f e -> isFreeVar x f || isFreeVar x e Lam y _ e -> x /= y && isFreeVar x e Let y _ e1 e2 -> (y /= x && isFreeVar x e1) || isFreeVar x e2 + Assert e1 e2 -> isFreeVar x e1 || isFreeVar x e2 -- | `isUnusedVar` is the negation of `isFreeVar`. -- @@ -44,6 +45,7 @@ isFreeVarOrScopedVar x = \case App f e -> isFreeVarOrScopedVar x f || isFreeVarOrScopedVar x e Lam y _ e -> x == y || isFreeVarOrScopedVar x e Let y _ e1 e2 -> y == x || isFreeVarOrScopedVar x e1 || isFreeVarOrScopedVar x e2 + Assert e1 e2 -> isFreeVarOrScopedVar x e1 || isFreeVarOrScopedVar x e2 freeTyVars :: Type -> [TypeName] freeTyVars = \case diff --git a/src/Jikka/Core/Language/QuasiRules.hs b/src/Jikka/Core/Language/QuasiRules.hs index ec0c0bb8..852a633a 100644 --- a/src/Jikka/Core/Language/QuasiRules.hs +++ b/src/Jikka/Core/Language/QuasiRules.hs @@ -128,6 +128,10 @@ toPatE = \case modify' (\env -> env {vars = (x, Just (VarE y)) : vars env}) e2 <- toPatE e2 lift [p|Let $(pure (VarP y)) $(pure t) $(pure e1) $(pure e2)|] + Assert e1 e2 -> do + e1 <- toPatE e1 + e2 <- toPatE e2 + lift [p|Assert $(pure e1) $(pure e2)|] toExpT :: Type -> StateT Env Q Exp toExpT = \case @@ -201,6 +205,11 @@ toExpE e = do (stmts', e2) <- toExpE e2 e <- lift [e|Let $(pure (VarE y)) $(pure t) $(pure e1) $(pure e2)|] return (stmts ++ BindS (VarP y) (VarE genVarName) : stmts', e) + Assert e1 e2 -> do + (stmts1, e1) <- toExpE e1 + (stmts2, e2) <- toExpE e2 + e <- lift [e|Assert $(pure e1) $(pure e2)|] + return (stmts1 ++ stmts2, e) ruleExp :: String -> Q Exp ruleExp s = do diff --git a/src/Jikka/Core/Language/RewriteRules.hs b/src/Jikka/Core/Language/RewriteRules.hs index cd7c1af8..b9db1362 100644 --- a/src/Jikka/Core/Language/RewriteRules.hs +++ b/src/Jikka/Core/Language/RewriteRules.hs @@ -85,6 +85,10 @@ applyRewriteRuleToImmediateSubExprs f env = \case e1' <- lift $ unRewriteRule f env e1 e2' <- lift $ unRewriteRule f ((x, t) : env) e2 return $ fmap (uncurry (Let x t)) (coalesceMaybes e1 e1' e2 e2') + Assert e1 e2 -> do + e1' <- lift $ unRewriteRule f env e1 + e2' <- lift $ unRewriteRule f env e2 + return $ fmap (uncurry Assert) (coalesceMaybes e1 e1' e2 e2') joinStateT :: Monad m => StateT s (StateT s m) a -> StateT s m a joinStateT f = do @@ -130,6 +134,10 @@ applyRewriteRuleToplevelExpr f env = \case body' <- applyRewriteRule f (reverse args ++ env') body cont' <- applyRewriteRuleToplevelExpr f env' cont return $ fmap (uncurry (ToplevelLetRec g args ret)) (coalesceMaybes body body' cont cont') + ToplevelAssert e1 e2 -> do + e1' <- applyRewriteRule f env e1 + e2' <- applyRewriteRuleToplevelExpr f env e2 + return $ fmap (uncurry ToplevelAssert) (coalesceMaybes e1 e1' e2 e2') applyRewriteRuleProgram :: MonadError Error m => RewriteRule m -> Program -> m (Maybe Program) applyRewriteRuleProgram f prog = evalStateT (applyRewriteRuleToplevelExpr f [] prog) 0 diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index 50ef36af..c38a3e11 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -176,6 +176,11 @@ typecheckExpr env = \case throwInternalError $ "wrong type binding: " ++ formatExpr (Let x t e1 e2) let env' = if x == VarName "_" then env else (x, t) : env typecheckExpr env' e2 + Assert e1 e2 -> do + t <- typecheckExpr env e1 + when (t /= BoolTy) $ do + throwInternalError $ "wrong type assertion: expr = " ++ formatExpr e1 ++ " has type = " ++ formatType t + typecheckExpr env e2 typecheckToplevelExpr :: MonadError Error m => TypeEnv -> ToplevelExpr -> m Type typecheckToplevelExpr env = \case @@ -193,6 +198,11 @@ typecheckToplevelExpr env = \case when (ret' /= ret) $ do throwInternalError $ "returned type is not correct: context = (let rec " ++ unVarName f ++ " " ++ unwords (map (\(x, t) -> unVarName x ++ ": " ++ formatType t) args) ++ ": " ++ formatType ret ++ " = " ++ formatExpr body ++ " in ...), expected type = " ++ formatType ret ++ ", actual type = " ++ formatType ret' typecheckToplevelExpr ((f, t) : env) cont + ToplevelAssert e1 e2 -> do + t <- typecheckExpr env e1 + when (t /= BoolTy) $ do + throwInternalError $ "wrong type toplevel assertion: expr = " ++ formatExpr e1 ++ " has type = " ++ formatType t + typecheckToplevelExpr env e2 typecheckProgram :: MonadError Error m => Program -> m Type typecheckProgram prog = wrapError' "Jikka.Core.Language.TypeCheck.typecheckProgram" $ do diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index 6f0382ba..e2c7ee05 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -59,6 +59,7 @@ mapTypeExprM f = go App f e -> App <$> go f <*> go e Lam x t body -> Lam x <$> f t <*> go body Let x t e1 e2 -> Let x <$> f t <*> go e1 <*> go e2 + Assert e1 e2 -> Assert <$> go e1 <*> go e2 mapTypeExpr :: (Type -> Type) -> Expr -> Expr mapTypeExpr f e = runIdentity (mapTypeExprM (return . f) e) @@ -68,6 +69,7 @@ mapTypeToplevelExprM f = \case ResultExpr e -> ResultExpr <$> mapTypeExprM f e ToplevelLet x t e cont -> ToplevelLet x <$> f t <*> mapTypeExprM f e <*> mapTypeToplevelExprM f cont ToplevelLetRec g args ret body cont -> ToplevelLetRec g <$> mapM (\(x, t) -> (x,) <$> f t) args <*> f ret <*> mapTypeExprM f body <*> mapTypeToplevelExprM f cont + ToplevelAssert e cont -> ToplevelAssert <$> mapTypeExprM f e <*> mapTypeToplevelExprM f cont mapTypeProgramM :: Monad m => (Type -> m Type) -> Program -> m Program mapTypeProgramM = mapTypeToplevelExprM @@ -86,16 +88,34 @@ mapExprM' pre post env e = do App g e -> App <$> go env g <*> go env e Lam x t body -> Lam x t <$> go ((x, t) : env) body Let y t e1 e2 -> Let y t <$> go env e1 <*> go ((y, t) : env) e2 + Assert e1 e2 -> Assert <$> go env e1 <*> go env e2 + post env e + +mapToplevelExprM' :: Monad m => ([(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr) -> ([(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr +mapToplevelExprM' pre post env e = do + e <- pre env e + e <- case e of + ResultExpr e -> return $ ResultExpr e + ToplevelLet y t e cont -> + ToplevelLet y t e <$> mapToplevelExprM' pre post ((y, t) : env) cont + ToplevelLetRec g args ret body cont -> + let env' = (g, foldr (FunTy . snd) ret args) : env + in ToplevelLetRec g args ret body <$> mapToplevelExprM' pre post env' cont + ToplevelAssert e cont -> + ToplevelAssert e <$> mapToplevelExprM' pre post env cont post env e mapExprToplevelExprM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr -mapExprToplevelExprM' pre post env = \case - ResultExpr e -> ResultExpr <$> mapExprM' pre post env e - ToplevelLet y t e cont -> - ToplevelLet y t <$> mapExprM' pre post env e <*> mapExprToplevelExprM' pre post ((y, t) : env) cont - ToplevelLetRec g args ret body cont -> - let env' = (g, foldr (FunTy . snd) ret args) : env - in ToplevelLetRec g args ret <$> mapExprM' pre post (reverse args ++ env') body <*> mapExprToplevelExprM' pre post env' cont +mapExprToplevelExprM' pre post env = mapToplevelExprM' pre' (\_ e -> return e) env + where + go = mapExprM' pre post + pre' env = \case + ResultExpr e -> ResultExpr <$> go env e + ToplevelLet y t e cont -> ToplevelLet y t <$> go env e <*> pure cont + ToplevelLetRec g args ret body cont -> + let env' = (g, foldr (FunTy . snd) ret args) : env + in ToplevelLetRec g args ret <$> go (reverse args ++ env') body <*> pure cont + ToplevelAssert e cont -> ToplevelAssert <$> go env e <*> pure cont mapExprProgramM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> Program -> m Program mapExprProgramM' pre post = mapExprToplevelExprM' pre post [] @@ -119,6 +139,16 @@ mapExprToplevelExpr f env e = runIdentity $ mapExprToplevelExprM (\env e -> retu mapExprProgram :: ([(VarName, Type)] -> Expr -> Expr) -> Program -> Program mapExprProgram f prog = runIdentity $ mapExprProgramM (\env e -> return $ f env e) prog +-- | `mapToplevelExprM` is a wrapper of `mapToplevelExprM'`. This function works in post-order. +mapToplevelExprM :: Monad m => ([(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr +mapToplevelExprM f env e = mapToplevelExprM' (\_ e -> return e) f env e + +mapToplevelExprProgramM :: Monad m => ([(VarName, Type)] -> Program -> m Program) -> Program -> m Program +mapToplevelExprProgramM f prog = mapToplevelExprM f [] prog + +mapToplevelExprProgram :: ([(VarName, Type)] -> Program -> Program) -> Program -> Program +mapToplevelExprProgram f prog = runIdentity $ mapToplevelExprProgramM (\env e -> return $ f env e) prog + listSubExprs :: Expr -> [Expr] listSubExprs e = getDual . execWriter $ mapExprM go [] e where @@ -284,6 +314,7 @@ isConstantTimeExpr = \case _ -> False Lam _ _ _ -> True Let _ _ e1 e2 -> isConstantTimeExpr e1 && isConstantTimeExpr e2 + Assert e1 e2 -> isConstantTimeExpr e1 && isConstantTimeExpr e2 -- | `replaceLenF` replaces @len(f)@ in an expr with @i + k@. -- * This assumes that there are no name conflicts. @@ -299,6 +330,7 @@ replaceLenF f i k = go Lam x t body -> Lam x t <$> (if x == f then return body else go body) Let y _ _ _ | y == i -> throwInternalError "Jikka.Core.Language.Util.replaceLenF: name conflict" Let y t e1 e2 -> Let y t <$> go e1 <*> (if y == f then return e2 else go e2) + Assert e1 e2 -> Assert <$> go e1 <*> go e2 -- | `getRecurrenceFormulaBase` makes a pair @((a_0, ..., a_{k - 1}), a)@ from @setat (... (setat a 0 a_0) ...) (k - 1) a_{k - 1})@. getRecurrenceFormulaBase :: Expr -> ([Expr], Expr) diff --git a/src/Jikka/Core/Parse/Alex.x b/src/Jikka/Core/Parse/Alex.x index ae807188..d594f038 100644 --- a/src/Jikka/Core/Parse/Alex.x +++ b/src/Jikka/Core/Parse/Alex.x @@ -63,6 +63,7 @@ tokens :- "if" { tok If } "then" { tok Then } "else" { tok Else } + "assert" { tok Assert } "forall" { tok Forall } -- punctuations diff --git a/src/Jikka/Core/Parse/Happy.y b/src/Jikka/Core/Parse/Happy.y index f2f7149d..4aa79736 100644 --- a/src/Jikka/Core/Parse/Happy.y +++ b/src/Jikka/Core/Parse/Happy.y @@ -55,6 +55,7 @@ import qualified Jikka.Core.Parse.Token as L "if" { WithLoc _ L.If } "then" { WithLoc _ L.Then } "else" { WithLoc _ L.Else } + "assert" { WithLoc _ L.Assert } "forall" { WithLoc _ L.Forall } -- punctuations @@ -471,10 +472,15 @@ lambda_expr :: { Expr } let_expr :: { Expr } : "let" identifier ":" type "=" expression "in" expression { Let $2 $4 $6 $8 } +-- Assertion +assert_expr :: { Expr } + : "assert" expression "->" expression { Assert $2 $4 } + expression_nolet :: { Expr } : implies_test { $1 } | conditional_expression { $1 } | lambda_expr { $1 } + | assert_expr { $1 } expression :: { Expr } : expression_nolet { $1 } | let_expr { $1 } diff --git a/src/Jikka/Core/Parse/Token.hs b/src/Jikka/Core/Parse/Token.hs index c061f015..9fe8387e 100644 --- a/src/Jikka/Core/Parse/Token.hs +++ b/src/Jikka/Core/Parse/Token.hs @@ -61,6 +61,7 @@ data Token | Else | Fun | Dot + | Assert | Forall | -- punctuations Arrow diff --git a/src/Jikka/RestrictedPython/Convert/ToCore.hs b/src/Jikka/RestrictedPython/Convert/ToCore.hs index 8cd305c4..237e25fd 100644 --- a/src/Jikka/RestrictedPython/Convert/ToCore.hs +++ b/src/Jikka/RestrictedPython/Convert/ToCore.hs @@ -387,7 +387,10 @@ runStatements (stmt : stmts) cont = case stmt of runStatements stmts cont X.For x iter body -> runForStatement x iter body stmts cont X.If e body1 body2 -> runIfStatement e body1 body2 stmts cont - X.Assert _ -> runStatements stmts cont + X.Assert e -> do + e <- runExpr e + cont <- runStatements stmts cont + return $ Y.Assert e cont X.Append loc t x e -> do case X.exprToTarget x of Nothing -> throwSemanticErrorAt' loc "invalid `append` method"