diff --git a/Jikka.cabal b/Jikka.cabal index f4d13798..ee1068b8 100644 --- a/Jikka.cabal +++ b/Jikka.cabal @@ -84,6 +84,7 @@ library Jikka.Core.Convert.RemoveUnusedVars Jikka.Core.Convert.SegmentTree Jikka.Core.Convert.ShortCutFusion + Jikka.Core.Convert.SortAbs Jikka.Core.Convert.SpecializeFoldl Jikka.Core.Convert.TrivialLetElimination Jikka.Core.Convert.TypeInfer @@ -250,6 +251,7 @@ test-suite jikka-test Jikka.Core.Convert.RemoveUnusedVarsSpec Jikka.Core.Convert.SegmentTreeSpec Jikka.Core.Convert.ShortCutFusionSpec + Jikka.Core.Convert.SortAbsSpec Jikka.Core.Convert.SpecializeFoldlSpec Jikka.Core.Convert.TrivialLetEliminationSpec Jikka.Core.Convert.TypeInferSpec diff --git a/examples/data/METADATA.json b/examples/data/METADATA.json index 75ff297a..c0d1f9c4 100644 --- a/examples/data/METADATA.json +++ b/examples/data/METADATA.json @@ -11,5 +11,6 @@ "sum_sum_plus_one": "jikka-judge://sum-sum-plus-one", "sum_sum_plus_one_lt": "jikka-judge://sum-sum-plus-one-lt", "sum_sum_plus_two": "jikka-judge://sum-sum-plus-two", - "sum_sum_square": "jikka-judge://sum-sum-square" + "sum_sum_square": "jikka-judge://sum-sum-square", + "sum_sum_abs_one": "jikka-judge://sum-sum-abs-one" } diff --git a/examples/sum_sum_abs_one.py b/examples/sum_sum_abs_one.py new file mode 100644 index 00000000..bba1c900 --- /dev/null +++ b/examples/sum_sum_abs_one.py @@ -0,0 +1,19 @@ +# https://judge.kimiyuki.net/problem/sum-sum-abs-one +from typing import * + +def solve(a: List[int]) -> int: + ans = 0 + for a_i in a: + for a_j in a: + ans += abs(a_i - a_j) + return ans + +def main() -> None: + n = int(input()) + a = list(map(int, input().split())) + assert len(a) == n + ans = solve(a) + print(ans) + +if __name__ == "__main__": + main() diff --git a/scripts/integration_tests.py b/scripts/integration_tests.py index bad67b4b..d6c1c777 100644 --- a/scripts/integration_tests.py +++ b/scripts/integration_tests.py @@ -58,7 +58,7 @@ def generate_test_cases_of_library_checker(*, problem_id: str, is_jikka_judge: b # collect testcases testcases: List[Tuple[pathlib.Path, pathlib.Path]] = [] - for inputcase in (info_toml.parent / 'in').iterdir(): + for inputcase in sorted((info_toml.parent / 'in').iterdir()): # sorting makes examples first outputcase = info_toml.parent / 'out' / (inputcase.stem + '.out') testcases.append((inputcase, outputcase)) return testcases @@ -69,7 +69,7 @@ def collect_test_cases(script: pathlib.Path, *, tempdir: pathlib.Path, library_c testcases: List[Tuple[pathlib.Path, pathlib.Path]] = [] # text files - for path in pathlib.Path('examples', 'data').iterdir(): + for path in sorted(pathlib.Path('examples', 'data').iterdir()): if path.name[:-len(''.join(path.suffixes))] != script.stem: continue if path.suffix != '.in': @@ -77,7 +77,7 @@ def collect_test_cases(script: pathlib.Path, *, tempdir: pathlib.Path, library_c testcases.append((path, path.with_suffix('.out'))) # using generators - for generator_path in pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py'): + for generator_path in sorted(pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py')): _, testset_name, _, _ = generator_path.name.split('.') for solver_ext in ('.py', '.cpp'): diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs index f576e335..f50c31b9 100644 --- a/src/Jikka/Core/Convert.hs +++ b/src/Jikka/Core/Convert.hs @@ -37,6 +37,7 @@ import qualified Jikka.Core.Convert.PropagateMod as PropagateMod import qualified Jikka.Core.Convert.RemoveUnusedVars as RemoveUnusedVars import qualified Jikka.Core.Convert.SegmentTree as SegmentTree import qualified Jikka.Core.Convert.ShortCutFusion as ShortCutFusion +import qualified Jikka.Core.Convert.SortAbs as SortAbs import qualified Jikka.Core.Convert.SpecializeFoldl as SpecializeFoldl import qualified Jikka.Core.Convert.TrivialLetElimination as TrivialLetElimination import qualified Jikka.Core.Convert.TypeInfer as TypeInfer @@ -49,6 +50,7 @@ run'' prog = do prog <- UnpackTuple.run prog prog <- MatrixExponentiation.run prog prog <- SpecializeFoldl.run prog + prog <- SortAbs.run prog prog <- MakeScanl.run prog prog <- PropagateMod.run prog prog <- ConstantPropagation.run prog diff --git a/src/Jikka/Core/Convert/Alpha.hs b/src/Jikka/Core/Convert/Alpha.hs index 9424c22a..9db7457d 100644 --- a/src/Jikka/Core/Convert/Alpha.hs +++ b/src/Jikka/Core/Convert/Alpha.hs @@ -10,7 +10,13 @@ -- Maintainer : kimiyuki95@gmail.com -- Stability : experimental -- Portability : portable -module Jikka.Core.Convert.Alpha where +module Jikka.Core.Convert.Alpha + ( run, + runProgram, + runToplevelExpr, + runExpr, + ) +where import Jikka.Common.Alpha import Jikka.Common.Error @@ -22,31 +28,35 @@ rename x = do i <- nextCounter return $ VarName (base ++ "$" ++ show i) -runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr -runExpr env = \case +runExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr +runExpr' env = \case Var x -> case lookup x env of Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just y -> return $ Var y Lit lit -> return $ Lit lit - App f e -> App <$> runExpr env f <*> runExpr env e + App f e -> App <$> runExpr' env f <*> runExpr' env e Lam x t body -> do y <- rename x - body <- runExpr ((x, y) : env) body + body <- runExpr' ((x, y) : env) body return $ Lam y t body Let x t e1 e2 -> do - e1 <- runExpr env e1 + e1 <- runExpr' env e1 y <- rename x - e2 <- runExpr ((x, y) : env) e2 + e2 <- runExpr' ((x, y) : env) e2 return $ Let y t e1 e2 - Assert e1 e2 -> Assert <$> runExpr env e1 <*> runExpr env e2 + Assert e1 e2 -> Assert <$> runExpr' env e1 <*> runExpr' env e2 -runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr -runToplevelExpr env = \case - ResultExpr e -> ResultExpr <$> runExpr env e +runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr +runExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do + runExpr' (map (\(x, _) -> (x, x)) env) e + +runToplevelExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr +runToplevelExpr' env = \case + ResultExpr e -> ResultExpr <$> runExpr' env e ToplevelLet x t e cont -> do y <- rename x - e <- runExpr env e - cont <- runToplevelExpr ((x, y) : env) cont + e <- runExpr' env e + cont <- runToplevelExpr' ((x, y) : env) cont return $ ToplevelLet y t e cont ToplevelLetRec f args ret body cont -> do g <- rename f @@ -55,13 +65,18 @@ runToplevelExpr env = \case return (x, y, t) let args1 = map (\(x, y, _) -> (x, y)) args let args2 = map (\(_, y, t) -> (y, t)) args - body <- runExpr (args1 ++ (f, g) : env) body - cont <- runToplevelExpr ((f, g) : env) cont + body <- runExpr' (args1 ++ (f, g) : env) body + cont <- runToplevelExpr' ((f, g) : env) cont return $ ToplevelLetRec g args2 ret body cont - ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr env e1 <*> runToplevelExpr env e2 + ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr' env e1 <*> runToplevelExpr' env e2 + +runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr +runToplevelExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do + runToplevelExpr' (map (\(x, _) -> (x, x)) env) e runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program -runProgram = runToplevelExpr [] +runProgram prog = wrapError' "Jikka.Core.Convert.Alpha" $ do + runToplevelExpr' [] prog -- | `run` renames variables in exprs to avoid name conflictions, even if the scopes of two variables are distinct. -- @@ -81,5 +96,4 @@ runProgram = runToplevelExpr [] -- > in x2 = x0 + y1 -- > x2 + y1 run :: (MonadAlpha m, MonadError Error m) => Program -> m Program -run prog = wrapError' "Jikka.Core.Convert.Alpha" $ do - runToplevelExpr [] prog +run = runProgram diff --git a/src/Jikka/Core/Convert/SortAbs.hs b/src/Jikka/Core/Convert/SortAbs.hs new file mode 100644 index 00000000..f8d7251a --- /dev/null +++ b/src/Jikka/Core/Convert/SortAbs.hs @@ -0,0 +1,135 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module : Jikka.Core.Convert.SortAbs +-- Description : remove abs with sorting. / sort によって abs を除去します。 +-- Copyright : (c) Kimiyuki Onaka, 2021 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +-- +-- \[ +-- \newcommand\int{\mathbf{int}} +-- \newcommand\bool{\mathbf{bool}} +-- \newcommand\list{\mathbf{list}} +-- \] +module Jikka.Core.Convert.SortAbs + ( run, + + -- * internal rules + rule, + ) +where + +import Control.Monad.Trans.Maybe +import Jikka.Common.Alpha +import Jikka.Common.Error +import qualified Jikka.Core.Convert.Alpha as Alpha +import Jikka.Core.Language.ArithmeticExpr +import Jikka.Core.Language.Beta +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.QuasiRules +import Jikka.Core.Language.RewriteRules +import Jikka.Core.Language.Util + +-- | @replaceAbsDelta x y z e@ replaces \(\levert x - y \rvert\) in \(e\) with \(z\). +replaceAbsDelta :: VarName -> VarName -> VarName -> Expr -> Expr +replaceAbsDelta x y z e = mapSubExpr go [] e + where + go _ = \case + Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var x) (Var y)))) -> Var z + Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var y) (Var x)))) -> Var z + e -> e + +swapTwoVars :: MonadAlpha m => VarName -> VarName -> Expr -> m Expr +swapTwoVars x y e = do + x' <- genVarName x + y' <- genVarName y + e <- substitute x (Var x') e + e <- substitute y (Var y') e + e <- substitute x' (Var y) e + substitute y' (Var x) e + +-- | TODO: accept more functions +isSymmetric :: MonadAlpha m => VarName -> VarName -> Expr -> m Bool +isSymmetric x y f = do + g <- swapTwoVars x y f + return $ parseArithmeticExpr g == parseArithmeticExpr f + +rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m +rule = makeRewriteRule "sum/sum/abs/symmetric" $ \env -> \case + Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ f) xs'))) xs) | xs' == xs -> runMaybeT $ do + delta <- lift genVarName' + let f' = replaceAbsDelta x y delta f + guard $ f' /= f -- f has |x - y| + guard =<< lift (isSymmetric x y f') -- symmetric + ys <- lift $ genVarName'' xs + i <- lift genVarName' + j <- lift genVarName' + lt <- lift $ substitute delta (Minus' (Var x) (Var y)) f' + eq <- lift $ substitute delta (LitInt' 0) f' + gt <- lift $ substitute delta (Minus' (Var y) (Var x)) f' + let ctx = Let y IntTy (At' IntTy (Var ys) (Var j)) + let lt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx lt)) (Range1' (Var i))) + let eq' = Let j IntTy (Var i) (ctx eq) + let gt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx gt)) (Range2' (Plus' (Var i) (LitInt' 1)) (Len' IntTy (Var ys)))) + let e = + Let ys (ListTy IntTy) (Sorted' IntTy xs) $ + Sum' + ( Map' + IntTy + IntTy + ( Lam + i + IntTy + ( Let + x + IntTy + (At' IntTy (Var ys) (Var i)) + (Plus' (Plus' lt' eq') gt') + ) + ) + (Range1' (Len' IntTy (Var ys))) + ) + lift $ Alpha.runExpr (typeEnv env) e + _ -> return Nothing + +runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program +runProgram = applyRewriteRuleProgram' rule + +-- | `run` reduces \(\lvert \sum _ {a_i \in a} \sum _ {a_j \in a} f(a, a_i, a_j) \rvert\) to \(\mathbf{let}~ b = \mathrm{sort}(a) ~\mathbf{in}~ \sum \sum f'(a, a_i, a_j)\) when \(f\) contains \(\lvert a_i - a_j \rvert\) and \(f(a, a_i, a_j) = f(a, a_j, a_i)\) holds. +-- +-- == Example +-- +-- Before: +-- +-- > sum (map (fun (a_i: int) -> +-- > sum (map (fun (a_j: int) -> +-- > abs (a_i - a_j) +-- > ) a) +-- > ) a) +-- +-- After: +-- +-- > let b = sort a +-- > in sum (map (fun (i: int) -> +-- > (sum (map (fun (b_j: int) -> +-- > b_i - b_j +-- > ) b[:i]) +-- > + 0 +-- > + sum (map (fun (b_j: int) -> +-- > b_j - b_i +-- > ) b[i + 1:])) +-- > ) (range (length b))) +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.SortAbs" $ do + precondition $ do + ensureWellTyped prog + prog <- runProgram prog + postcondition $ do + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index 3b5f700e..d343b503 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -30,6 +30,11 @@ genVarName x = do genVarName' :: MonadAlpha m => m VarName genVarName' = genVarName (VarName "_") +genVarName'' :: MonadAlpha m => Expr -> m VarName +genVarName'' = \case + Var x -> genVarName x + _ -> genVarName' + mapSubTypesM :: Monad m => (Type -> m Type) -> Type -> m Type mapSubTypesM f = go where diff --git a/test/Jikka/Core/Convert/SortAbsSpec.hs b/test/Jikka/Core/Convert/SortAbsSpec.hs new file mode 100644 index 00000000..ccfb73c7 --- /dev/null +++ b/test/Jikka/Core/Convert/SortAbsSpec.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.SortAbsSpec (spec) where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.SortAbs (run) +import qualified Jikka.Core.Convert.TypeInfer as TypeInfer +import Jikka.Core.Format (formatProgram) +import Jikka.Core.Language.Expr +import Jikka.Core.Parse (parseProgram) +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +parseProgram' :: [String] -> Program +parseProgram' = fromSuccess . flip evalAlphaT 100 . (TypeInfer.run <=< parseProgram . unlines) + +spec :: Spec +spec = describe "run" $ do + it "works about sum" $ do + let prog = + parseProgram' + [ "fun (a: int list) ->", + " sum (map (fun (a_i: int) ->", + " sum (map (fun (a_j: int) ->", + " abs (a_i - a_j)", + " ) a)", + " ) a)" + ] + let expected = + parseProgram' + [ "fun (a: int list) ->", + " let a$6 = sorted a", + " in sum (map (fun ($7: int) ->", + " let a_i$8 = a$6[$7] in", + " sum (map (fun ($9: int) ->", + " let a_j$10 = a$6[$9]", + " in a_i$8 - a_j$10", + " ) (range $7))", + " + (let $11 = $7", + " in let a_j$12 = a$6[$11]", + " in 0)", + " + sum (map (fun ($13: int) ->", + " let a_j$14 = a$6[$13]", + " in a_j$14 - a_i$8", + " ) (range2 ($7 + 1) (len a$6)))", + " ) (range (len a$6)))" + ] + (formatProgram <$> run' prog) `shouldBe` Right (formatProgram expected)