Skip to content

Commit

Permalink
Merge pull request #210 from kmyk/sum-sum-abs
Browse files Browse the repository at this point in the history
feat(core): Reduce `\sum \sum |a_i - a_j|`
  • Loading branch information
kmyk authored Sep 3, 2021
2 parents 366d746 + fc51b8e commit da7eb6e
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 23 deletions.
2 changes: 2 additions & 0 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ library
Jikka.Core.Convert.RemoveUnusedVars
Jikka.Core.Convert.SegmentTree
Jikka.Core.Convert.ShortCutFusion
Jikka.Core.Convert.SortAbs
Jikka.Core.Convert.SpecializeFoldl
Jikka.Core.Convert.TrivialLetElimination
Jikka.Core.Convert.TypeInfer
Expand Down Expand Up @@ -250,6 +251,7 @@ test-suite jikka-test
Jikka.Core.Convert.RemoveUnusedVarsSpec
Jikka.Core.Convert.SegmentTreeSpec
Jikka.Core.Convert.ShortCutFusionSpec
Jikka.Core.Convert.SortAbsSpec
Jikka.Core.Convert.SpecializeFoldlSpec
Jikka.Core.Convert.TrivialLetEliminationSpec
Jikka.Core.Convert.TypeInferSpec
Expand Down
3 changes: 2 additions & 1 deletion examples/data/METADATA.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"sum_sum_plus_one": "jikka-judge://sum-sum-plus-one",
"sum_sum_plus_one_lt": "jikka-judge://sum-sum-plus-one-lt",
"sum_sum_plus_two": "jikka-judge://sum-sum-plus-two",
"sum_sum_square": "jikka-judge://sum-sum-square"
"sum_sum_square": "jikka-judge://sum-sum-square",
"sum_sum_abs_one": "jikka-judge://sum-sum-abs-one"
}
19 changes: 19 additions & 0 deletions examples/sum_sum_abs_one.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# https://judge.kimiyuki.net/problem/sum-sum-abs-one
from typing import *

def solve(a: List[int]) -> int:
ans = 0
for a_i in a:
for a_j in a:
ans += abs(a_i - a_j)
return ans

def main() -> None:
n = int(input())
a = list(map(int, input().split()))
assert len(a) == n
ans = solve(a)
print(ans)

if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions scripts/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_test_cases_of_library_checker(*, problem_id: str, is_jikka_judge: b

# collect testcases
testcases: List[Tuple[pathlib.Path, pathlib.Path]] = []
for inputcase in (info_toml.parent / 'in').iterdir():
for inputcase in sorted((info_toml.parent / 'in').iterdir()): # sorting makes examples first
outputcase = info_toml.parent / 'out' / (inputcase.stem + '.out')
testcases.append((inputcase, outputcase))
return testcases
Expand All @@ -69,15 +69,15 @@ def collect_test_cases(script: pathlib.Path, *, tempdir: pathlib.Path, library_c
testcases: List[Tuple[pathlib.Path, pathlib.Path]] = []

# text files
for path in pathlib.Path('examples', 'data').iterdir():
for path in sorted(pathlib.Path('examples', 'data').iterdir()):
if path.name[:-len(''.join(path.suffixes))] != script.stem:
continue
if path.suffix != '.in':
continue
testcases.append((path, path.with_suffix('.out')))

# using generators
for generator_path in pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py'):
for generator_path in sorted(pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py')):
_, testset_name, _, _ = generator_path.name.split('.')

for solver_ext in ('.py', '.cpp'):
Expand Down
2 changes: 2 additions & 0 deletions src/Jikka/Core/Convert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import qualified Jikka.Core.Convert.PropagateMod as PropagateMod
import qualified Jikka.Core.Convert.RemoveUnusedVars as RemoveUnusedVars
import qualified Jikka.Core.Convert.SegmentTree as SegmentTree
import qualified Jikka.Core.Convert.ShortCutFusion as ShortCutFusion
import qualified Jikka.Core.Convert.SortAbs as SortAbs
import qualified Jikka.Core.Convert.SpecializeFoldl as SpecializeFoldl
import qualified Jikka.Core.Convert.TrivialLetElimination as TrivialLetElimination
import qualified Jikka.Core.Convert.TypeInfer as TypeInfer
Expand All @@ -49,6 +50,7 @@ run'' prog = do
prog <- UnpackTuple.run prog
prog <- MatrixExponentiation.run prog
prog <- SpecializeFoldl.run prog
prog <- SortAbs.run prog
prog <- MakeScanl.run prog
prog <- PropagateMod.run prog
prog <- ConstantPropagation.run prog
Expand Down
52 changes: 33 additions & 19 deletions src/Jikka/Core/Convert/Alpha.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
-- Maintainer : [email protected]
-- Stability : experimental
-- Portability : portable
module Jikka.Core.Convert.Alpha where
module Jikka.Core.Convert.Alpha
( run,
runProgram,
runToplevelExpr,
runExpr,
)
where

import Jikka.Common.Alpha
import Jikka.Common.Error
Expand All @@ -22,31 +28,35 @@ rename x = do
i <- nextCounter
return $ VarName (base ++ "$" ++ show i)

runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr
runExpr env = \case
runExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr
runExpr' env = \case
Var x -> case lookup x env of
Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x
Just y -> return $ Var y
Lit lit -> return $ Lit lit
App f e -> App <$> runExpr env f <*> runExpr env e
App f e -> App <$> runExpr' env f <*> runExpr' env e
Lam x t body -> do
y <- rename x
body <- runExpr ((x, y) : env) body
body <- runExpr' ((x, y) : env) body
return $ Lam y t body
Let x t e1 e2 -> do
e1 <- runExpr env e1
e1 <- runExpr' env e1
y <- rename x
e2 <- runExpr ((x, y) : env) e2
e2 <- runExpr' ((x, y) : env) e2
return $ Let y t e1 e2
Assert e1 e2 -> Assert <$> runExpr env e1 <*> runExpr env e2
Assert e1 e2 -> Assert <$> runExpr' env e1 <*> runExpr' env e2

runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr
runToplevelExpr env = \case
ResultExpr e -> ResultExpr <$> runExpr env e
runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr
runExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do
runExpr' (map (\(x, _) -> (x, x)) env) e

runToplevelExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr
runToplevelExpr' env = \case
ResultExpr e -> ResultExpr <$> runExpr' env e
ToplevelLet x t e cont -> do
y <- rename x
e <- runExpr env e
cont <- runToplevelExpr ((x, y) : env) cont
e <- runExpr' env e
cont <- runToplevelExpr' ((x, y) : env) cont
return $ ToplevelLet y t e cont
ToplevelLetRec f args ret body cont -> do
g <- rename f
Expand All @@ -55,13 +65,18 @@ runToplevelExpr env = \case
return (x, y, t)
let args1 = map (\(x, y, _) -> (x, y)) args
let args2 = map (\(_, y, t) -> (y, t)) args
body <- runExpr (args1 ++ (f, g) : env) body
cont <- runToplevelExpr ((f, g) : env) cont
body <- runExpr' (args1 ++ (f, g) : env) body
cont <- runToplevelExpr' ((f, g) : env) cont
return $ ToplevelLetRec g args2 ret body cont
ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr env e1 <*> runToplevelExpr env e2
ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr' env e1 <*> runToplevelExpr' env e2

runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr
runToplevelExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do
runToplevelExpr' (map (\(x, _) -> (x, x)) env) e

runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = runToplevelExpr []
runProgram prog = wrapError' "Jikka.Core.Convert.Alpha" $ do
runToplevelExpr' [] prog

-- | `run` renames variables in exprs to avoid name conflictions, even if the scopes of two variables are distinct.
--
Expand All @@ -81,5 +96,4 @@ runProgram = runToplevelExpr []
-- > in x2 = x0 + y1
-- > x2 + y1
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.Alpha" $ do
runToplevelExpr [] prog
run = runProgram
135 changes: 135 additions & 0 deletions src/Jikka/Core/Convert/SortAbs.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module : Jikka.Core.Convert.SortAbs
-- Description : remove abs with sorting. / sort によって abs を除去します。
-- Copyright : (c) Kimiyuki Onaka, 2021
-- License : Apache License 2.0
-- Maintainer : [email protected]
-- Stability : experimental
-- Portability : portable
--
-- \[
-- \newcommand\int{\mathbf{int}}
-- \newcommand\bool{\mathbf{bool}}
-- \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.SortAbs
( run,

-- * internal rules
rule,
)
where

import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | @replaceAbsDelta x y z e@ replaces \(\levert x - y \rvert\) in \(e\) with \(z\).
replaceAbsDelta :: VarName -> VarName -> VarName -> Expr -> Expr
replaceAbsDelta x y z e = mapSubExpr go [] e
where
go _ = \case
Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var x) (Var y)))) -> Var z
Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var y) (Var x)))) -> Var z
e -> e

swapTwoVars :: MonadAlpha m => VarName -> VarName -> Expr -> m Expr
swapTwoVars x y e = do
x' <- genVarName x
y' <- genVarName y
e <- substitute x (Var x') e
e <- substitute y (Var y') e
e <- substitute x' (Var y) e
substitute y' (Var x) e

-- | TODO: accept more functions
isSymmetric :: MonadAlpha m => VarName -> VarName -> Expr -> m Bool
isSymmetric x y f = do
g <- swapTwoVars x y f
return $ parseArithmeticExpr g == parseArithmeticExpr f

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = makeRewriteRule "sum/sum/abs/symmetric" $ \env -> \case
Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ f) xs'))) xs) | xs' == xs -> runMaybeT $ do
delta <- lift genVarName'
let f' = replaceAbsDelta x y delta f
guard $ f' /= f -- f has |x - y|
guard =<< lift (isSymmetric x y f') -- symmetric
ys <- lift $ genVarName'' xs
i <- lift genVarName'
j <- lift genVarName'
lt <- lift $ substitute delta (Minus' (Var x) (Var y)) f'
eq <- lift $ substitute delta (LitInt' 0) f'
gt <- lift $ substitute delta (Minus' (Var y) (Var x)) f'
let ctx = Let y IntTy (At' IntTy (Var ys) (Var j))
let lt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx lt)) (Range1' (Var i)))
let eq' = Let j IntTy (Var i) (ctx eq)
let gt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx gt)) (Range2' (Plus' (Var i) (LitInt' 1)) (Len' IntTy (Var ys))))
let e =
Let ys (ListTy IntTy) (Sorted' IntTy xs) $
Sum'
( Map'
IntTy
IntTy
( Lam
i
IntTy
( Let
x
IntTy
(At' IntTy (Var ys) (Var i))
(Plus' (Plus' lt' eq') gt')
)
)
(Range1' (Len' IntTy (Var ys)))
)
lift $ Alpha.runExpr (typeEnv env) e
_ -> return Nothing

runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` reduces \(\lvert \sum _ {a_i \in a} \sum _ {a_j \in a} f(a, a_i, a_j) \rvert\) to \(\mathbf{let}~ b = \mathrm{sort}(a) ~\mathbf{in}~ \sum \sum f'(a, a_i, a_j)\) when \(f\) contains \(\lvert a_i - a_j \rvert\) and \(f(a, a_i, a_j) = f(a, a_j, a_i)\) holds.
--
-- == Example
--
-- Before:
--
-- > sum (map (fun (a_i: int) ->
-- > sum (map (fun (a_j: int) ->
-- > abs (a_i - a_j)
-- > ) a)
-- > ) a)
--
-- After:
--
-- > let b = sort a
-- > in sum (map (fun (i: int) ->
-- > (sum (map (fun (b_j: int) ->
-- > b_i - b_j
-- > ) b[:i])
-- > + 0
-- > + sum (map (fun (b_j: int) ->
-- > b_j - b_i
-- > ) b[i + 1:]))
-- > ) (range (length b)))
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.SortAbs" $ do
precondition $ do
ensureWellTyped prog
prog <- runProgram prog
postcondition $ do
ensureWellTyped prog
return prog
5 changes: 5 additions & 0 deletions src/Jikka/Core/Language/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ genVarName x = do
genVarName' :: MonadAlpha m => m VarName
genVarName' = genVarName (VarName "_")

genVarName'' :: MonadAlpha m => Expr -> m VarName
genVarName'' = \case
Var x -> genVarName x
_ -> genVarName'

mapSubTypesM :: Monad m => (Type -> m Type) -> Type -> m Type
mapSubTypesM f = go
where
Expand Down
51 changes: 51 additions & 0 deletions test/Jikka/Core/Convert/SortAbsSpec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{-# LANGUAGE OverloadedStrings #-}

module Jikka.Core.Convert.SortAbsSpec (spec) where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Convert.SortAbs (run)
import qualified Jikka.Core.Convert.TypeInfer as TypeInfer
import Jikka.Core.Format (formatProgram)
import Jikka.Core.Language.Expr
import Jikka.Core.Parse (parseProgram)
import Test.Hspec

run' :: Program -> Either Error Program
run' = flip evalAlphaT 0 . run

parseProgram' :: [String] -> Program
parseProgram' = fromSuccess . flip evalAlphaT 100 . (TypeInfer.run <=< parseProgram . unlines)

spec :: Spec
spec = describe "run" $ do
it "works about sum" $ do
let prog =
parseProgram'
[ "fun (a: int list) ->",
" sum (map (fun (a_i: int) ->",
" sum (map (fun (a_j: int) ->",
" abs (a_i - a_j)",
" ) a)",
" ) a)"
]
let expected =
parseProgram'
[ "fun (a: int list) ->",
" let a$6 = sorted a",
" in sum (map (fun ($7: int) ->",
" let a_i$8 = a$6[$7] in",
" sum (map (fun ($9: int) ->",
" let a_j$10 = a$6[$9]",
" in a_i$8 - a_j$10",
" ) (range $7))",
" + (let $11 = $7",
" in let a_j$12 = a$6[$11]",
" in 0)",
" + sum (map (fun ($13: int) ->",
" let a_j$14 = a$6[$13]",
" in a_j$14 - a_i$8",
" ) (range2 ($7 + 1) (len a$6)))",
" ) (range (len a$6)))"
]
(formatProgram <$> run' prog) `shouldBe` Right (formatProgram expected)

0 comments on commit da7eb6e

Please sign in to comment.