Skip to content

Commit

Permalink
Merge pull request #175 from kmyk/pull-163
Browse files Browse the repository at this point in the history
chore: Merge #163
  • Loading branch information
kmyk authored Aug 7, 2021
2 parents 45c04f3 + 12e974e commit 62907b8
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 0 deletions.
6 changes: 6 additions & 0 deletions runtime/include/jikka/segment_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ inline int64_t min_int64_t(int64_t a, int64_t b) { return std::min(a, b); }

inline int64_t max_int64_t(int64_t a, int64_t b) { return std::max(a, b); }

inline int64_t gcd_int64_t(int64_t a, int64_t b) { return std::gcd(a, b); }

inline int64_t lcm_int64_t(int64_t a, int64_t b) { return std::lcm(a, b); }

inline int64_t const_zero() { return 0; }

inline int64_t const_one() { return 1; }

inline int64_t const_int64_min() { return INT64_MIN; }

inline int64_t const_int64_max() { return INT64_MAX; }
Expand Down
16 changes: 16 additions & 0 deletions src/Jikka/CPlusPlus/Convert/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ runSemigroup = \case
X.SemigroupIntPlus -> Y.MonoidIntPlus
X.SemigroupIntMin -> Y.MonoidIntMin
X.SemigroupIntMax -> Y.MonoidIntMax
X.SemigroupIntGcd -> Y.MonoidIntGcd
X.SemigroupIntLcm -> Y.MonoidIntLcm

runLiteral :: (MonadAlpha m, MonadError Error m) => Env -> X.Literal -> m Y.Expr
runLiteral env = \case
Expand Down Expand Up @@ -478,6 +480,20 @@ runAppBuiltin env f ts args = wrapError' ("converting builtin " ++ X.formatBuilt
],
Y.Var y
)
X.Gcd1 -> go11' $ \t xs -> do
y <- Y.newFreshName Y.LocalNameKind
return
( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 0, Y.Lam [(Y.TyAuto, Y.VarName "a"), (Y.TyAuto, Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::gcd" [] [Y.Var $ Y.VarName "a", Y.Var $ Y.VarName "b"]]])))
],
Y.Var y
)
X.Lcm1 -> go11' $ \t xs -> do
y <- Y.newFreshName Y.LocalNameKind
return
( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 1, Y.Lam [(Y.TyAuto, Y.VarName "a"), (Y.TyAuto, Y.VarName "b")] Y.TyAuto [Y.Return $ Y.callFunction "std::lcm" [] [Y.Var $ Y.VarName "a", Y.Var $ Y.VarName "b"]]])))
],
Y.Var y
)
X.All -> go01' $ \xs -> do
y <- Y.newFreshName Y.LocalNameKind
return
Expand Down
2 changes: 2 additions & 0 deletions src/Jikka/CPlusPlus/Format.hs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ formatType = \case
MonoidIntPlus -> "atcoder::segtree<int64_t, jikka::plus_int64_t, jikka::const_zero>"
MonoidIntMin -> "atcoder::segtree<int64_t, jikka::min_int64_t, jikka::const_int64_max>"
MonoidIntMax -> "atcoder::segtree<int64_t, jikka::max_int64_t, jikka::const_int64_min>"
MonoidIntGcd -> "atcoder::segtree<int64_t, jikka::gcd_int64_t, jikka::const_zero>"
MonoidIntLcm -> "atcoder::segtree<int64_t, jikka::lcm_int64_t, jikka::const_one>"
TyIntValue n -> show n

formatLiteral :: Literal -> Code
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/CPlusPlus/Language/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ data Monoid'
MonoidIntMin
| -- | \((\mathrm{int64\_t}, \max, \mathrm{INT64\_MIN})\)
MonoidIntMax
| -- | \((\mathbb{Z}, \gcd, 0)\)
MonoidIntGcd
| -- | \((\mathbb{Z}, \mathrm{lcm}, 1)\)
MonoidIntLcm
deriving (Eq, Ord, Show, Read)

data Literal
Expand Down
8 changes: 8 additions & 0 deletions src/Jikka/Core/Convert/CumulativeSum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ rule = makeRewriteRule "Jikka.Core.Convert.CumulativeSum" $ \_ -> \case
Just <$> cumulativeMax (Min2' t) t (Just a0) a n
Min1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax (Min2' t) t Nothing a n
Lcm1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax Lcm' t (Just a0) a n
Lcm1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax Lcm' t Nothing a n
Gcd1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax Gcd' t (Just a0) a n
Gcd1' t (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax Gcd' t Nothing a n
_ -> return Nothing

runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/Core/Convert/SegmentTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ builtinToSemigroup builtin ts = case (builtin, ts) of
(Plus, []) -> Just SemigroupIntPlus
(Min2, [IntTy]) -> Just SemigroupIntMin
(Max2, [IntTy]) -> Just SemigroupIntMax
(Gcd, []) -> Just SemigroupIntGcd
(Lcm, []) -> Just SemigroupIntLcm
_ -> Nothing

semigroupToBuiltin :: Semigroup' -> (Builtin, [Type])
semigroupToBuiltin = \case
SemigroupIntPlus -> (Plus, [])
SemigroupIntMin -> (Min2, [IntTy])
SemigroupIntMax -> (Max2, [IntTy])
SemigroupIntGcd -> (Gcd, [])
SemigroupIntLcm -> (Lcm, [])

unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr)
unCumulativeSum a = \case
Expand Down
6 changes: 6 additions & 0 deletions src/Jikka/Core/Convert/SpecializeFoldl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
-- Max1
Min2' _ (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Min2' _ e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Min1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
-- Lcm1
Lcm' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Lcm' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Lcm1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
-- Gcd1
Gcd' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Gcd1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Gcd' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Gcd1' t2 (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
-- others
_ -> Nothing
-- The outer floor-mod is required because foldl for empty lists returns values without modulo.
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/Core/Evaluate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ segmentTreeGetRange semigrp segtree l r
SemigroupIntPlus -> sum slice
SemigroupIntMin -> minimum slice
SemigroupIntMax -> maximum slice
SemigroupIntGcd -> foldl gcd 0 slice
SemigroupIntLcm -> foldl lcm 1 slice

build :: MonadError Error m => (V.Vector Value -> m Value) -> V.Vector Value -> Integer -> m (V.Vector Value)
build _ _ n | n < 0 = throwRuntimeError $ "negative length: " ++ show n
Expand Down Expand Up @@ -223,6 +225,8 @@ callBuiltin builtin ts args = wrapError' ("while calling builtin " ++ formatBuil
Max1 -> go1 valueToList id (V.maximumBy compareValues')
ArgMin -> go1 valueToList ValInt $ \xs -> snd (minimumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..]))
ArgMax -> go1 valueToList ValInt $ \xs -> snd (maximumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..]))
Gcd1 -> go1 valueToIntList ValInt (foldl gcd 0)
Lcm1 -> go1 valueToIntList ValInt (foldl lcm 1)
All -> go1 valueToBoolList ValBool and
Any -> go1 valueToBoolList ValBool or
Sorted -> go1 valueToList ValList sortVector
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/Core/Format.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ formatSemigroup = \case
SemigroupIntPlus -> "int.plus"
SemigroupIntMin -> "int.min"
SemigroupIntMax -> "int.max"
SemigroupIntGcd -> "int.gcd"
SemigroupIntLcm -> "int.lcm"

data Builtin'
= Fun String
Expand Down Expand Up @@ -217,6 +219,8 @@ analyzeBuiltin = \case
Max1 -> Fun "max"
ArgMin -> Fun "argmin"
ArgMax -> Fun "argmax"
Gcd1 -> Fun "gcd"
Lcm1 -> Fun "lcm"
All -> Fun "all"
Any -> Fun "any"
Sorted -> Fun "sort"
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/Core/Language/BuiltinPatterns.hs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ pattern ArgMin' t e = AppBuiltin11 ArgMin t e

pattern ArgMax' t e = AppBuiltin11 ArgMax t e

pattern Gcd1' t e = AppBuiltin11 Gcd1 t e

pattern Lcm1' t e = AppBuiltin11 Lcm1 t e

pattern All' e = AppBuiltin1 All e

pattern Any' e = AppBuiltin1 Any e
Expand Down
6 changes: 6 additions & 0 deletions src/Jikka/Core/Language/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ data Semigroup'
= SemigroupIntPlus
| SemigroupIntMin
| SemigroupIntMax
| SemigroupIntGcd
| SemigroupIntLcm
deriving (Eq, Ord, Show, Read, Data, Typeable)

-- | TODO: What is the difference between `Literal` and `Builtin`?
Expand Down Expand Up @@ -210,6 +212,10 @@ data Builtin
ArgMin
| -- | \(: \forall \alpha. \list(\alpha) \to \int\)
ArgMax
| -- | \(: \list(\int) \to \int\)
Gcd1
| -- | \(: \list(\int) \to \int\)
Lcm1
| -- | \(: \list(\bool) \to \bool\)
All
| -- | \(: \list(\bool) \to \bool\)
Expand Down
4 changes: 4 additions & 0 deletions src/Jikka/Core/Language/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ builtinToType builtin ts =
Max1 -> go1 $ \t -> FunLTy t
ArgMin -> go1 $ \t -> FunTy (ListTy t) IntTy
ArgMax -> go1 $ \t -> FunTy (ListTy t) IntTy
Gcd1 -> go0 $ FunLTy IntTy
Lcm1 -> go0 $ FunLTy IntTy
All -> go0 $ FunLTy BoolTy
Any -> go0 $ FunLTy BoolTy
Sorted -> go1 $ \t -> Fun1STy (ListTy t)
Expand Down Expand Up @@ -131,6 +133,8 @@ semigroupToType = \case
SemigroupIntPlus -> IntTy
SemigroupIntMin -> IntTy
SemigroupIntMax -> IntTy
SemigroupIntGcd -> IntTy
SemigroupIntLcm -> IntTy

literalToType :: MonadError Error m => Literal -> m Type
literalToType = \case
Expand Down
2 changes: 2 additions & 0 deletions src/Jikka/Core/Language/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ isConstantTimeBuiltin = \case
Max1 -> False
ArgMin -> False
ArgMax -> False
Gcd1 -> False
Lcm1 -> False
All -> False
Any -> False
Sorted -> False
Expand Down

0 comments on commit 62907b8

Please sign in to comment.