From 78b7af089e79c0bc6d1ccdb6f5be560018ab95d5 Mon Sep 17 00:00:00 2001 From: hotman78 Date: Fri, 6 Aug 2021 09:04:49 +0000 Subject: [PATCH 1/4] feat(core): add gcd/lcm Semigroup to segment tree --- runtime/include/jikka/segment_tree.hpp | 6 ++++++ src/Jikka/CPlusPlus/Convert/FromCore.hs | 16 ++++++++++++++++ src/Jikka/CPlusPlus/Format.hs | 2 ++ src/Jikka/CPlusPlus/Language/Expr.hs | 2 ++ src/Jikka/Core/Convert/CumulativeSum.hs | 8 ++++++++ src/Jikka/Core/Convert/SegmentTree.hs | 4 ++++ src/Jikka/Core/Convert/SpecializeFoldl.hs | 7 +++++++ src/Jikka/Core/Format.hs | 4 ++++ src/Jikka/Core/Language/BuiltinPatterns.hs | 4 ++++ src/Jikka/Core/Language/Expr.hs | 6 ++++++ src/Jikka/Core/Language/TypeCheck.hs | 4 ++++ 11 files changed, 63 insertions(+) diff --git a/runtime/include/jikka/segment_tree.hpp b/runtime/include/jikka/segment_tree.hpp index f9fc6c38..430215dc 100644 --- a/runtime/include/jikka/segment_tree.hpp +++ b/runtime/include/jikka/segment_tree.hpp @@ -17,8 +17,14 @@ inline int64_t min_int64_t(int64_t a, int64_t b) { return std::min(a, b); } inline int64_t max_int64_t(int64_t a, int64_t b) { return std::max(a, b); } +inline int64_t gcd_int64_t(int64_t a, int64_t b) { return std::gcd(a, b); } + +inline int64_t lcm_int64_t(int64_t a, int64_t b) { return std::lcm(a, b); } + inline int64_t const_zero() { return 0; } +inline int64_t const_one() { return 1; } + inline int64_t const_int64_min() { return INT64_MIN; } inline int64_t const_int64_max() { return INT64_MAX; } diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index ff92c19d..e37a0f2d 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -68,6 +68,8 @@ runSemigroup = \case X.SemigroupIntPlus -> Y.MonoidIntPlus X.SemigroupIntMin -> Y.MonoidIntMin X.SemigroupIntMax -> Y.MonoidIntMax + X.SemigroupIntGcd -> Y.MonoidIntGcd + X.SemigroupIntLcm -> Y.MonoidIntLcm runLiteral :: (MonadAlpha m, MonadError Error m) => Env -> X.Literal -> m Y.Expr runLiteral env = \case @@ -478,6 +480,20 @@ runAppBuiltin env f ts args = wrapError' ("converting builtin " ++ X.formatBuilt ], Y.Var y ) + X.Gcd1 -> go11' $ \t xs -> do + y <- Y.newFreshName Y.LocalNameKind + return + ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 0, Y.Lam [(Y.TyAuto,Y.VarName "a" ),(Y.TyAuto,Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::gcd" [] [Y.Var $ Y.VarName "a",Y.Var $ Y.VarName "b"]]]))) + ], + Y.Var y + ) + X.Lcm1 -> go11' $ \t xs -> do + y <- Y.newFreshName Y.LocalNameKind + return + ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 1, Y.Lam [(Y.TyAuto,Y.VarName "a"),(Y.TyAuto,Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::lcm" [] [Y.Var $ Y.VarName "a",Y.Var $ Y.VarName "b"]]]))) + ], + Y.Var y + ) X.All -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return diff --git a/src/Jikka/CPlusPlus/Format.hs b/src/Jikka/CPlusPlus/Format.hs index 6ff4485b..6f417135 100644 --- a/src/Jikka/CPlusPlus/Format.hs +++ b/src/Jikka/CPlusPlus/Format.hs @@ -189,6 +189,8 @@ formatType = \case MonoidIntPlus -> "atcoder::segtree" MonoidIntMin -> "atcoder::segtree" MonoidIntMax -> "atcoder::segtree" + MonoidIntGcd -> "atcoder::segtree" + MonoidIntLcm -> "atcoder::segtree" TyIntValue n -> show n formatLiteral :: Literal -> Code diff --git a/src/Jikka/CPlusPlus/Language/Expr.hs b/src/Jikka/CPlusPlus/Language/Expr.hs index 41d225d3..56233466 100644 --- a/src/Jikka/CPlusPlus/Language/Expr.hs +++ b/src/Jikka/CPlusPlus/Language/Expr.hs @@ -41,6 +41,8 @@ data Monoid' = MonoidIntPlus | MonoidIntMin | MonoidIntMax + | MonoidIntGcd + | MonoidIntLcm deriving (Eq, Ord, Show, Read) data Literal diff --git a/src/Jikka/Core/Convert/CumulativeSum.hs b/src/Jikka/Core/Convert/CumulativeSum.hs index acc6381e..afd5af07 100644 --- a/src/Jikka/Core/Convert/CumulativeSum.hs +++ b/src/Jikka/Core/Convert/CumulativeSum.hs @@ -59,6 +59,14 @@ rule = RewriteRule $ \_ -> \case Just <$> cumulativeMax (Min2' t) t (Just a0) a n Min1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do Just <$> cumulativeMax (Min2' t) t Nothing a n + Lcm1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do + Just <$> cumulativeMax Lcm' t (Just a0) a n + Lcm1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do + Just <$> cumulativeMax Lcm' t Nothing a n + Gcd1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do + Just <$> cumulativeMax Gcd' t (Just a0) a n + Gcd1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do + Just <$> cumulativeMax Gcd' t Nothing a n _ -> return Nothing runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program diff --git a/src/Jikka/Core/Convert/SegmentTree.hs b/src/Jikka/Core/Convert/SegmentTree.hs index 37d2bdc0..0257b0c1 100644 --- a/src/Jikka/Core/Convert/SegmentTree.hs +++ b/src/Jikka/Core/Convert/SegmentTree.hs @@ -74,6 +74,8 @@ builtinToSemigroup builtin ts = case (builtin, ts) of (Plus, []) -> Just SemigroupIntPlus (Min2, [IntTy]) -> Just SemigroupIntMin (Max2, [IntTy]) -> Just SemigroupIntMax + (Gcd, []) -> Just SemigroupIntGcd + (Lcm, []) -> Just SemigroupIntLcm _ -> Nothing semigroupToBuiltin :: Semigroup' -> (Builtin, [Type]) @@ -81,6 +83,8 @@ semigroupToBuiltin = \case SemigroupIntPlus -> (Plus, []) SemigroupIntMin -> (Min2, [IntTy]) SemigroupIntMax -> (Max2, [IntTy]) + SemigroupIntGcd -> (Gcd,[]) + SemigroupIntLcm -> (Lcm,[]) unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr) unCumulativeSum a = \case diff --git a/src/Jikka/Core/Convert/SpecializeFoldl.hs b/src/Jikka/Core/Convert/SpecializeFoldl.hs index ce6649d3..57a1d593 100644 --- a/src/Jikka/Core/Convert/SpecializeFoldl.hs +++ b/src/Jikka/Core/Convert/SpecializeFoldl.hs @@ -50,6 +50,13 @@ rule = simpleRewriteRule $ \case -- Max1 Min2' _ (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) Min2' _ e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) + + -- Lcm1 + Lcm' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) + Lcm' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) + -- Gcd1 + Gcd' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Gcd1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) + Gcd' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Gcd1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) -- others _ -> Nothing -- The outer floor-mod is required because foldl for empty lists returns values without modulo. diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index fb18c7b5..a6c50597 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -133,6 +133,8 @@ formatSemigroup = \case SemigroupIntPlus -> "int.plus" SemigroupIntMin -> "int.min" SemigroupIntMax -> "int.max" + SemigroupIntGcd -> "int.gcd" + SemigroupIntLcm -> "int.lcm" data Builtin' = Fun String @@ -217,6 +219,8 @@ analyzeBuiltin = \case Max1 -> Fun "max" ArgMin -> Fun "argmin" ArgMax -> Fun "argmax" + Gcd1 -> Fun "gcd" + Lcm1 -> Fun "lcm" All -> Fun "all" Any -> Fun "any" Sorted -> Fun "sort" diff --git a/src/Jikka/Core/Language/BuiltinPatterns.hs b/src/Jikka/Core/Language/BuiltinPatterns.hs index 4ab931dc..dbb3b8b6 100644 --- a/src/Jikka/Core/Language/BuiltinPatterns.hs +++ b/src/Jikka/Core/Language/BuiltinPatterns.hs @@ -147,6 +147,10 @@ pattern ArgMin' t e = AppBuiltin11 ArgMin t e pattern ArgMax' t e = AppBuiltin11 ArgMax t e +pattern Gcd1' t e = AppBuiltin11 Gcd1 t e + +pattern Lcm1' t e = AppBuiltin11 Lcm1 t e + pattern All' e = AppBuiltin1 All e pattern Any' e = AppBuiltin1 Any e diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index d5cc44f1..dec75a16 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -66,6 +66,8 @@ data Semigroup' = SemigroupIntPlus | SemigroupIntMin | SemigroupIntMax + | SemigroupIntGcd + | SemigroupIntLcm deriving (Eq, Ord, Show, Read, Data, Typeable) -- | TODO: What is the difference between `Literal` and `Builtin`? @@ -210,6 +212,10 @@ data Builtin ArgMin | -- | \(: \forall \alpha. \list(\alpha) \to \int\) ArgMax + | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) + Gcd1 + | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) + Lcm1 | -- | \(: \list(\bool) \to \bool\) All | -- | \(: \list(\bool) \to \bool\) diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index 50ef36af..c21301bd 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -96,6 +96,8 @@ builtinToType builtin ts = Max1 -> go1 $ \t -> FunLTy t ArgMin -> go1 $ \t -> FunTy (ListTy t) IntTy ArgMax -> go1 $ \t -> FunTy (ListTy t) IntTy + Gcd1 -> go1 $ \t -> FunLTy t + Lcm1 -> go1 $ \t -> FunLTy t All -> go0 $ FunLTy BoolTy Any -> go0 $ FunLTy BoolTy Sorted -> go1 $ \t -> Fun1STy (ListTy t) @@ -131,6 +133,8 @@ semigroupToType = \case SemigroupIntPlus -> IntTy SemigroupIntMin -> IntTy SemigroupIntMax -> IntTy + SemigroupIntGcd -> IntTy + SemigroupIntLcm -> IntTy literalToType :: MonadError Error m => Literal -> m Type literalToType = \case From 403574a0891e1f0267322c8f034dc5e75c62d634 Mon Sep 17 00:00:00 2001 From: hotman78 Date: Fri, 6 Aug 2021 09:53:54 +0000 Subject: [PATCH 2/4] fix: format --- src/Jikka/CPlusPlus/Convert/FromCore.hs | 4 ++-- src/Jikka/Core/Convert/SegmentTree.hs | 4 ++-- src/Jikka/Core/Convert/SpecializeFoldl.hs | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index e37a0f2d..892efe45 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -483,14 +483,14 @@ runAppBuiltin env f ts args = wrapError' ("converting builtin " ++ X.formatBuilt X.Gcd1 -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return - ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 0, Y.Lam [(Y.TyAuto,Y.VarName "a" ),(Y.TyAuto,Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::gcd" [] [Y.Var $ Y.VarName "a",Y.Var $ Y.VarName "b"]]]))) + ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 0, Y.Lam [(Y.TyAuto, Y.VarName "a"), (Y.TyAuto, Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::gcd" [] [Y.Var $ Y.VarName "a", Y.Var $ Y.VarName "b"]]]))) ], Y.Var y ) X.Lcm1 -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return - ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 1, Y.Lam [(Y.TyAuto,Y.VarName "a"),(Y.TyAuto,Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::lcm" [] [Y.Var $ Y.VarName "a",Y.Var $ Y.VarName "b"]]]))) + ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 1, Y.Lam [(Y.TyAuto, Y.VarName "a"), (Y.TyAuto, Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::lcm" [] [Y.Var $ Y.VarName "a", Y.Var $ Y.VarName "b"]]]))) ], Y.Var y ) diff --git a/src/Jikka/Core/Convert/SegmentTree.hs b/src/Jikka/Core/Convert/SegmentTree.hs index 0257b0c1..df4dd89c 100644 --- a/src/Jikka/Core/Convert/SegmentTree.hs +++ b/src/Jikka/Core/Convert/SegmentTree.hs @@ -83,8 +83,8 @@ semigroupToBuiltin = \case SemigroupIntPlus -> (Plus, []) SemigroupIntMin -> (Min2, [IntTy]) SemigroupIntMax -> (Max2, [IntTy]) - SemigroupIntGcd -> (Gcd,[]) - SemigroupIntLcm -> (Lcm,[]) + SemigroupIntGcd -> (Gcd, []) + SemigroupIntLcm -> (Lcm, []) unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr) unCumulativeSum a = \case diff --git a/src/Jikka/Core/Convert/SpecializeFoldl.hs b/src/Jikka/Core/Convert/SpecializeFoldl.hs index 57a1d593..9585222b 100644 --- a/src/Jikka/Core/Convert/SpecializeFoldl.hs +++ b/src/Jikka/Core/Convert/SpecializeFoldl.hs @@ -50,7 +50,6 @@ rule = simpleRewriteRule $ \case -- Max1 Min2' _ (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) Min2' _ e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) - -- Lcm1 Lcm' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) Lcm' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) From 695460780770e57ab57467604b4eaab21841f92c Mon Sep 17 00:00:00 2001 From: hotman Date: Sat, 7 Aug 2021 12:40:32 +0900 Subject: [PATCH 3/4] fix: explanation of Gcd1/Lcm1 Co-authored-by: Kimiyuki Onaka --- src/Jikka/Core/Language/Expr.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index dec75a16..b6bab6ea 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -212,9 +212,9 @@ data Builtin ArgMin | -- | \(: \forall \alpha. \list(\alpha) \to \int\) ArgMax - | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) + | -- | \(: \list(\int) \to \int\) Gcd1 - | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) + | -- | \(: \list(\int) \to \int\) Lcm1 | -- | \(: \list(\bool) \to \bool\) All From baa573291cfc47e719609640405ff541a867374c Mon Sep 17 00:00:00 2001 From: hotman Date: Sat, 7 Aug 2021 12:41:32 +0900 Subject: [PATCH 4/4] fix: type of Gcd1/Lcm1 Co-authored-by: Kimiyuki Onaka --- src/Jikka/Core/Language/TypeCheck.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index c21301bd..61c6512b 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -96,8 +96,8 @@ builtinToType builtin ts = Max1 -> go1 $ \t -> FunLTy t ArgMin -> go1 $ \t -> FunTy (ListTy t) IntTy ArgMax -> go1 $ \t -> FunTy (ListTy t) IntTy - Gcd1 -> go1 $ \t -> FunLTy t - Lcm1 -> go1 $ \t -> FunLTy t + Gcd1 -> go0 $ FunLTy IntTy + Lcm1 -> go0 $ FunLTy IntTy All -> go0 $ FunLTy BoolTy Any -> go0 $ FunLTy BoolTy Sorted -> go1 $ \t -> Fun1STy (ListTy t)