-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #210 from kmyk/sum-sum-abs
feat(core): Reduce `\sum \sum |a_i - a_j|`
- Loading branch information
Showing
9 changed files
with
252 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# https://judge.kimiyuki.net/problem/sum-sum-abs-one | ||
from typing import * | ||
|
||
def solve(a: List[int]) -> int: | ||
ans = 0 | ||
for a_i in a: | ||
for a_j in a: | ||
ans += abs(a_i - a_j) | ||
return ans | ||
|
||
def main() -> None: | ||
n = int(input()) | ||
a = list(map(int, input().split())) | ||
assert len(a) == n | ||
ans = solve(a) | ||
print(ans) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,13 @@ | |
-- Maintainer : [email protected] | ||
-- Stability : experimental | ||
-- Portability : portable | ||
module Jikka.Core.Convert.Alpha where | ||
module Jikka.Core.Convert.Alpha | ||
( run, | ||
runProgram, | ||
runToplevelExpr, | ||
runExpr, | ||
) | ||
where | ||
|
||
import Jikka.Common.Alpha | ||
import Jikka.Common.Error | ||
|
@@ -22,31 +28,35 @@ rename x = do | |
i <- nextCounter | ||
return $ VarName (base ++ "$" ++ show i) | ||
|
||
runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr | ||
runExpr env = \case | ||
runExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> Expr -> m Expr | ||
runExpr' env = \case | ||
Var x -> case lookup x env of | ||
Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x | ||
Just y -> return $ Var y | ||
Lit lit -> return $ Lit lit | ||
App f e -> App <$> runExpr env f <*> runExpr env e | ||
App f e -> App <$> runExpr' env f <*> runExpr' env e | ||
Lam x t body -> do | ||
y <- rename x | ||
body <- runExpr ((x, y) : env) body | ||
body <- runExpr' ((x, y) : env) body | ||
return $ Lam y t body | ||
Let x t e1 e2 -> do | ||
e1 <- runExpr env e1 | ||
e1 <- runExpr' env e1 | ||
y <- rename x | ||
e2 <- runExpr ((x, y) : env) e2 | ||
e2 <- runExpr' ((x, y) : env) e2 | ||
return $ Let y t e1 e2 | ||
Assert e1 e2 -> Assert <$> runExpr env e1 <*> runExpr env e2 | ||
Assert e1 e2 -> Assert <$> runExpr' env e1 <*> runExpr' env e2 | ||
|
||
runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr | ||
runToplevelExpr env = \case | ||
ResultExpr e -> ResultExpr <$> runExpr env e | ||
runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr | ||
runExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do | ||
runExpr' (map (\(x, _) -> (x, x)) env) e | ||
|
||
runToplevelExpr' :: (MonadAlpha m, MonadError Error m) => [(VarName, VarName)] -> ToplevelExpr -> m ToplevelExpr | ||
runToplevelExpr' env = \case | ||
ResultExpr e -> ResultExpr <$> runExpr' env e | ||
ToplevelLet x t e cont -> do | ||
y <- rename x | ||
e <- runExpr env e | ||
cont <- runToplevelExpr ((x, y) : env) cont | ||
e <- runExpr' env e | ||
cont <- runToplevelExpr' ((x, y) : env) cont | ||
return $ ToplevelLet y t e cont | ||
ToplevelLetRec f args ret body cont -> do | ||
g <- rename f | ||
|
@@ -55,13 +65,18 @@ runToplevelExpr env = \case | |
return (x, y, t) | ||
let args1 = map (\(x, y, _) -> (x, y)) args | ||
let args2 = map (\(_, y, t) -> (y, t)) args | ||
body <- runExpr (args1 ++ (f, g) : env) body | ||
cont <- runToplevelExpr ((f, g) : env) cont | ||
body <- runExpr' (args1 ++ (f, g) : env) body | ||
cont <- runToplevelExpr' ((f, g) : env) cont | ||
return $ ToplevelLetRec g args2 ret body cont | ||
ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr env e1 <*> runToplevelExpr env e2 | ||
ToplevelAssert e1 e2 -> ToplevelAssert <$> runExpr' env e1 <*> runToplevelExpr' env e2 | ||
|
||
runToplevelExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr | ||
runToplevelExpr env e = wrapError' "Jikka.Core.Convert.Alpha" $ do | ||
runToplevelExpr' (map (\(x, _) -> (x, x)) env) e | ||
|
||
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program | ||
runProgram = runToplevelExpr [] | ||
runProgram prog = wrapError' "Jikka.Core.Convert.Alpha" $ do | ||
runToplevelExpr' [] prog | ||
|
||
-- | `run` renames variables in exprs to avoid name conflictions, even if the scopes of two variables are distinct. | ||
-- | ||
|
@@ -81,5 +96,4 @@ runProgram = runToplevelExpr [] | |
-- > in x2 = x0 + y1 | ||
-- > x2 + y1 | ||
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program | ||
run prog = wrapError' "Jikka.Core.Convert.Alpha" $ do | ||
runToplevelExpr [] prog | ||
run = runProgram |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
{-# LANGUAGE FlexibleContexts #-} | ||
{-# LANGUAGE LambdaCase #-} | ||
|
||
-- | | ||
-- Module : Jikka.Core.Convert.SortAbs | ||
-- Description : remove abs with sorting. / sort によって abs を除去します。 | ||
-- Copyright : (c) Kimiyuki Onaka, 2021 | ||
-- License : Apache License 2.0 | ||
-- Maintainer : [email protected] | ||
-- Stability : experimental | ||
-- Portability : portable | ||
-- | ||
-- \[ | ||
-- \newcommand\int{\mathbf{int}} | ||
-- \newcommand\bool{\mathbf{bool}} | ||
-- \newcommand\list{\mathbf{list}} | ||
-- \] | ||
module Jikka.Core.Convert.SortAbs | ||
( run, | ||
|
||
-- * internal rules | ||
rule, | ||
) | ||
where | ||
|
||
import Control.Monad.Trans.Maybe | ||
import Jikka.Common.Alpha | ||
import Jikka.Common.Error | ||
import qualified Jikka.Core.Convert.Alpha as Alpha | ||
import Jikka.Core.Language.ArithmeticExpr | ||
import Jikka.Core.Language.Beta | ||
import Jikka.Core.Language.BuiltinPatterns | ||
import Jikka.Core.Language.Expr | ||
import Jikka.Core.Language.Lint | ||
import Jikka.Core.Language.QuasiRules | ||
import Jikka.Core.Language.RewriteRules | ||
import Jikka.Core.Language.Util | ||
|
||
-- | @replaceAbsDelta x y z e@ replaces \(\levert x - y \rvert\) in \(e\) with \(z\). | ||
replaceAbsDelta :: VarName -> VarName -> VarName -> Expr -> Expr | ||
replaceAbsDelta x y z e = mapSubExpr go [] e | ||
where | ||
go _ = \case | ||
Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var x) (Var y)))) -> Var z | ||
Abs' e | isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var y) (Var x)))) -> Var z | ||
e -> e | ||
|
||
swapTwoVars :: MonadAlpha m => VarName -> VarName -> Expr -> m Expr | ||
swapTwoVars x y e = do | ||
x' <- genVarName x | ||
y' <- genVarName y | ||
e <- substitute x (Var x') e | ||
e <- substitute y (Var y') e | ||
e <- substitute x' (Var y) e | ||
substitute y' (Var x) e | ||
|
||
-- | TODO: accept more functions | ||
isSymmetric :: MonadAlpha m => VarName -> VarName -> Expr -> m Bool | ||
isSymmetric x y f = do | ||
g <- swapTwoVars x y f | ||
return $ parseArithmeticExpr g == parseArithmeticExpr f | ||
|
||
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m | ||
rule = makeRewriteRule "sum/sum/abs/symmetric" $ \env -> \case | ||
Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ f) xs'))) xs) | xs' == xs -> runMaybeT $ do | ||
delta <- lift genVarName' | ||
let f' = replaceAbsDelta x y delta f | ||
guard $ f' /= f -- f has |x - y| | ||
guard =<< lift (isSymmetric x y f') -- symmetric | ||
ys <- lift $ genVarName'' xs | ||
i <- lift genVarName' | ||
j <- lift genVarName' | ||
lt <- lift $ substitute delta (Minus' (Var x) (Var y)) f' | ||
eq <- lift $ substitute delta (LitInt' 0) f' | ||
gt <- lift $ substitute delta (Minus' (Var y) (Var x)) f' | ||
let ctx = Let y IntTy (At' IntTy (Var ys) (Var j)) | ||
let lt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx lt)) (Range1' (Var i))) | ||
let eq' = Let j IntTy (Var i) (ctx eq) | ||
let gt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx gt)) (Range2' (Plus' (Var i) (LitInt' 1)) (Len' IntTy (Var ys)))) | ||
let e = | ||
Let ys (ListTy IntTy) (Sorted' IntTy xs) $ | ||
Sum' | ||
( Map' | ||
IntTy | ||
IntTy | ||
( Lam | ||
i | ||
IntTy | ||
( Let | ||
x | ||
IntTy | ||
(At' IntTy (Var ys) (Var i)) | ||
(Plus' (Plus' lt' eq') gt') | ||
) | ||
) | ||
(Range1' (Len' IntTy (Var ys))) | ||
) | ||
lift $ Alpha.runExpr (typeEnv env) e | ||
_ -> return Nothing | ||
|
||
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program | ||
runProgram = applyRewriteRuleProgram' rule | ||
|
||
-- | `run` reduces \(\lvert \sum _ {a_i \in a} \sum _ {a_j \in a} f(a, a_i, a_j) \rvert\) to \(\mathbf{let}~ b = \mathrm{sort}(a) ~\mathbf{in}~ \sum \sum f'(a, a_i, a_j)\) when \(f\) contains \(\lvert a_i - a_j \rvert\) and \(f(a, a_i, a_j) = f(a, a_j, a_i)\) holds. | ||
-- | ||
-- == Example | ||
-- | ||
-- Before: | ||
-- | ||
-- > sum (map (fun (a_i: int) -> | ||
-- > sum (map (fun (a_j: int) -> | ||
-- > abs (a_i - a_j) | ||
-- > ) a) | ||
-- > ) a) | ||
-- | ||
-- After: | ||
-- | ||
-- > let b = sort a | ||
-- > in sum (map (fun (i: int) -> | ||
-- > (sum (map (fun (b_j: int) -> | ||
-- > b_i - b_j | ||
-- > ) b[:i]) | ||
-- > + 0 | ||
-- > + sum (map (fun (b_j: int) -> | ||
-- > b_j - b_i | ||
-- > ) b[i + 1:])) | ||
-- > ) (range (length b))) | ||
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program | ||
run prog = wrapError' "Jikka.Core.Convert.SortAbs" $ do | ||
precondition $ do | ||
ensureWellTyped prog | ||
prog <- runProgram prog | ||
postcondition $ do | ||
ensureWellTyped prog | ||
return prog |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
{-# LANGUAGE OverloadedStrings #-} | ||
|
||
module Jikka.Core.Convert.SortAbsSpec (spec) where | ||
|
||
import Jikka.Common.Alpha | ||
import Jikka.Common.Error | ||
import Jikka.Core.Convert.SortAbs (run) | ||
import qualified Jikka.Core.Convert.TypeInfer as TypeInfer | ||
import Jikka.Core.Format (formatProgram) | ||
import Jikka.Core.Language.Expr | ||
import Jikka.Core.Parse (parseProgram) | ||
import Test.Hspec | ||
|
||
run' :: Program -> Either Error Program | ||
run' = flip evalAlphaT 0 . run | ||
|
||
parseProgram' :: [String] -> Program | ||
parseProgram' = fromSuccess . flip evalAlphaT 100 . (TypeInfer.run <=< parseProgram . unlines) | ||
|
||
spec :: Spec | ||
spec = describe "run" $ do | ||
it "works about sum" $ do | ||
let prog = | ||
parseProgram' | ||
[ "fun (a: int list) ->", | ||
" sum (map (fun (a_i: int) ->", | ||
" sum (map (fun (a_j: int) ->", | ||
" abs (a_i - a_j)", | ||
" ) a)", | ||
" ) a)" | ||
] | ||
let expected = | ||
parseProgram' | ||
[ "fun (a: int list) ->", | ||
" let a$6 = sorted a", | ||
" in sum (map (fun ($7: int) ->", | ||
" let a_i$8 = a$6[$7] in", | ||
" sum (map (fun ($9: int) ->", | ||
" let a_j$10 = a$6[$9]", | ||
" in a_i$8 - a_j$10", | ||
" ) (range $7))", | ||
" + (let $11 = $7", | ||
" in let a_j$12 = a$6[$11]", | ||
" in 0)", | ||
" + sum (map (fun ($13: int) ->", | ||
" let a_j$14 = a$6[$13]", | ||
" in a_j$14 - a_i$8", | ||
" ) (range2 ($7 + 1) (len a$6)))", | ||
" ) (range (len a$6)))" | ||
] | ||
(formatProgram <$> run' prog) `shouldBe` Right (formatProgram expected) |