From 10dbf8f01bd48b46386ca92d1aa0e77daf2a237b Mon Sep 17 00:00:00 2001 From: ncaq Date: Sat, 24 Dec 2022 03:07:14 +0900 Subject: [PATCH] feat: instance `MonadLogic` for `WriterT` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I have implemented `MonadLogic` in `WriterT` as well. Until now, `WriterT` should not be used because it easily causes space leaks. However, [Control.Monad.Trans.Writer.CPS](https://www.stackage.org/haddock/nightly-2022-12-08/transformers-0.5.6.2/Control-Monad-Trans-Writer-CPS.html) no longer causes space leaks and can be used. Therefore, we would like to use `WriterT`, but it is inconvenient if it cannot be combined with `MonadLogic`. Therefore, I implemented it and added tests where `ReaderT` and `StateT` are tested. To use the CPS version of `WriterT`, I raised the mtl version requirement to [feat(Control.Monad.Writer.CPS): re export runWriterT by ncaq · Pull Request #136 · haskell/mtl](https://github.com/haskell/mtl/pull/136) has not yet been imported, so we have no choice but to rely directly on transformers. In the end, mtl depends on transformers, so we have determined that this is not a critical problem. We have taken care to make it easy to remove them when they are no longer needed by doing a limited `import`. --- Control/Monad/Logic/Class.hs | 14 +++++++ logict.cabal | 6 ++- test/Test.hs | 74 ++++++++++++++++++++++++++++++++++-- 3 files changed, 88 insertions(+), 6 deletions(-) diff --git a/Control/Monad/Logic/Class.hs b/Control/Monad/Logic/Class.hs index 061f1ff..f5d6278 100644 --- a/Control/Monad/Logic/Class.hs +++ b/Control/Monad/Logic/Class.hs @@ -33,6 +33,11 @@ import Control.Monad.Trans (MonadTrans(..)) import qualified Control.Monad.State.Lazy as LazyST import qualified Control.Monad.State.Strict as StrictST +#if MIN_VERSION_mtl(2,3,0) +import qualified Control.Monad.Writer.CPS as CpsW +import qualified Control.Monad.Trans.Writer.CPS as CpsW(writerT, runWriterT) +#endif + -- | A backtracking, logic programming monad. class (Monad m, Alternative m) => MonadLogic m where -- | Attempts to __split__ the computation, giving access to the first @@ -347,6 +352,15 @@ instance MonadLogic m => MonadLogic (ReaderT e m) where Nothing -> pure Nothing Just (a, m) -> pure (Just (a, lift m)) +#if MIN_VERSION_mtl(2,3,0) +instance (Monoid w, MonadLogic m, MonadPlus m) => MonadLogic (CpsW.WriterT w m) where + msplit wm = CpsW.writerT $ do + r <- msplit $ CpsW.runWriterT wm + case r of + Nothing -> pure (Nothing, mempty) + Just ((a, w), m) -> pure (Just (a, CpsW.writerT m), w) +#endif + -- | See note on splitting above. instance (MonadLogic m, MonadPlus m) => MonadLogic (StrictST.StateT s m) where msplit sm = StrictST.StateT $ \s -> diff --git a/logict.cabal b/logict.cabal index 47b75c8..6ef027a 100644 --- a/logict.cabal +++ b/logict.cabal @@ -40,11 +40,12 @@ library build-depends: base >=4.3 && <5, - mtl >=2.0 && <2.4 + mtl >=2.0 && <2.4, + transformers if impl(ghc <8.0) build-depends: - fail, transformers + fail executable grandparents buildable: False @@ -69,6 +70,7 @@ test-suite logict-tests async >=2.0, logict, mtl, + transformers, tasty, tasty-hunit diff --git a/test/Test.hs b/test/Test.hs index 9635afb..56efd79 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -18,15 +18,19 @@ import qualified Control.Monad.State.Lazy as SL import qualified Control.Monad.State.Strict as SS import Data.Maybe -#if MIN_VERSION_base(4,9,0) -#if MIN_VERSION_base(4,11,0) -#else +#if MIN_VERSION_base(4,9,0) && !MIN_VERSION_base(4,11,0) import Data.Semigroup (Semigroup (..)) #endif -#else + +-- required by base < 4.9 OR CPS Writer test +#if !MIN_VERSION_base(4,9,0) || MIN_VERSION_mtl(2,3,0) import Data.Monoid #endif +#if MIN_VERSION_mtl(2,3,0) +import qualified Control.Monad.Writer.CPS as CpsW (WriterT, execWriterT, tell) +import qualified Control.Monad.Trans.Writer.CPS as CpsW (runWriterT) +#endif monadReader1 :: Assertion monadReader1 = assertEqual "should be equal" [5 :: Int] $ @@ -53,6 +57,12 @@ nats, odds, oddsOrTwo, oddsOrTwoUnfair, oddsOrTwoFair, odds5down :: Monad m => LogicT m Integer +-- | A `WriterT` version of `evalStateT`. +#if MIN_VERSION_mtl(2,3,0) +evalWriterT :: (Monad m, Monoid w) => CpsW.WriterT w m a -> m a +evalWriterT = fmap fst . CpsW.runWriterT +#endif + #if MIN_VERSION_base(4,8,0) nats = pure 0 `mplus` ((1 +) <$> nats) #else @@ -152,6 +162,13 @@ main = defaultMain $ z = mzero in assertBool "ReaderT" $ null $ catMaybes $ runReaderT (msplit z) 0 +#if MIN_VERSION_mtl(2,3,0) + , testCase "msplit mzero :: CPS WriterT" $ + let z :: CpsW.WriterT (Sum Int) [] String + z = mzero + in assertBool "CPS WriterT" $ null $ catMaybes (evalWriterT (msplit z)) +#endif + , testCase "msplit mzero :: LogicT" $ let z :: LogicT [] String z = mzero @@ -181,6 +198,15 @@ main = defaultMain $ extract (msplit op) @?= [sample] extract (msplit op >>= (\(Just (_,nxt)) -> msplit nxt)) @?= [] +#if MIN_VERSION_mtl(2,3,0) + , testCase "msplit CPS WriterT" $ do + let op :: CpsW.WriterT (Sum Integer) [] () + op = CpsW.tell 1 `mplus` op + extract = CpsW.execWriterT + extract (msplit op) @?= [1] + extract (msplit op >>= \(Just (_,nxt)) -> msplit nxt) @?= [2] +#endif + , testCase "msplit LogicT" $ do let op :: LogicT [] Integer op = foldr (mplus . return) mzero sample @@ -240,6 +266,13 @@ main = defaultMain $ (take 4 $ runReaderT (let oddsR = return 1 `mplus` liftM (2+) oddsR in oddsR `interleave` return (2 :: Integer)) "go") +#if MIN_VERSION_mtl(2,3,0) + , testCase "fair disjunction :: CPS WriterT" $ [1,2,3,5] @=? + (take 4 $ evalWriterT (let oddsW :: CpsW.WriterT [Char] [] Integer + oddsW = return 1 `mplus` liftM (2+) oddsW + in oddsW `interleave` return (2 :: Integer))) +#endif + , testCase "fair disjunction :: strict StateT" $ [1,2,3,5] @=? (take 4 $ SS.evalStateT (let oddsS = return 1 `mplus` liftM (2+) oddsS in oddsS `interleave` return (2 :: Integer)) "go") @@ -343,6 +376,17 @@ main = defaultMain $ if even x then return x else mzero ) "env") +#if MIN_VERSION_mtl(2,3,0) + , testCase "fair conjunction :: CPS WriterT" $ [2,4,6,8] @=? + (take 4 $ evalWriterT $ + (let oddsW :: CpsW.WriterT [Char] [] Integer + oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW + oddsPlus n = oddsW >>= \a -> return (a + n) + in do x <- (return 0 `mplus` return 1) >>- oddsPlus + if even x then return x else mzero + )) +#endif + , testCase "fair conjunction :: strict StateT" $ [2,4,6,8] @=? (take 4 $ SS.evalStateT (let oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS oddsPlus n = oddsS >>= \a -> return (a + n) @@ -426,6 +470,20 @@ main = defaultMain $ in testCase "indivisible odds :: ReaderT" $ [3,5,7,11,13,17,19,23,29,31] @=? (take 10 $ runReaderT oc "env") +#if MIN_VERSION_mtl(2,3,0) + , let iota n = msum (map return [1..n]) + oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW + oc :: CpsW.WriterT [Char] [] Integer + oc = do n <- oddsW + guard (n > 1) + ifte (do d <- iota (n - 1) + guard (d > 1 && n `mod` d == 0)) + (const mzero) + (return n) + in testCase "indivisible odds :: CPS WriterT" $ [3,5,7,11,13,17,19,23,29,31] @=? + (take 10 $ (fmap fst . CpsW.runWriterT) oc) +#endif + , let iota n = msum (map return [1..n]) oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS oc = do n <- oddsS @@ -499,6 +557,14 @@ main = defaultMain $ lnot (isEven v) return v) "env") +#if MIN_VERSION_mtl(2,3,0) + , testCase "inversion :: CPS WriterT" $ [1,3,5,7,9] @=? + (take 5 $ (evalWriterT :: CpsW.WriterT [Char] [] Integer -> [Integer]) + (do v <- foldr (mplus . return) mzero [(1::Integer)..] + lnot (isEven v) + return v)) +#endif + , testCase "inversion :: strict StateT" $ [1,3,5,7,9] @=? (take 5 $ SS.evalStateT (do v <- foldr (mplus . return) mzero [(1::Integer)..] lnot (isEven v)