From cd44aa9eb0af36829bc807973bd986771af387a9 Mon Sep 17 00:00:00 2001 From: ncaq Date: Sat, 24 Dec 2022 03:07:14 +0900 Subject: [PATCH] feat!: instance `MonadLogic` for `WriterT` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I have implemented `MonadLogic` in `WriterT` as well. Until now, `WriterT` should not be used because it easily causes space leaks. However, [Control.Monad.Trans.Writer.CPS](https://www.stackage.org/haddock/nightly-2022-12-08/transformers-0.5.6.2/Control-Monad-Trans-Writer-CPS.html) no longer causes space leaks and can be used. Therefore, we would like to use `WriterT`, but it is inconvenient if it cannot be combined with `MonadLogic`. Therefore, I implemented it and added tests where `ReaderT` and `StateT` are tested. To use the CPS version of `WriterT`, I raised the mtl version requirement to [feat(Control.Monad.Writer.CPS): re export runWriterT by ncaq · Pull Request #136 · haskell/mtl](https://github.com/haskell/mtl/pull/136) has not yet been imported, so we have no choice but to rely directly on transformers. In the end, mtl depends on transformers, so we have determined that this is not a critical problem. We have taken care to make it easy to remove them when they are no longer needed by doing a limited `import`. I was surprised myself that many of the CPP macros in the test code were removed. It was not my intention. I believe it was probably done by the HLS plugin. I've been trying to find out why it deleted them. I found out that the new mtl raises the lower limit for the base library in the cabal file. So when raising the lower version limit of mtl, we thought that it is certainly not a very effective practice to control the base version here. Therefore, we decided to leave the deleted part as it is. --- Control/Monad/Logic/Class.hs | 9 +++++ logict.cabal | 6 ++-- test/Test.hs | 66 ++++++++++++++++++++++++++++-------- 3 files changed, 65 insertions(+), 16 deletions(-) diff --git a/Control/Monad/Logic/Class.hs b/Control/Monad/Logic/Class.hs index 061f1ff..0cc3a9b 100644 --- a/Control/Monad/Logic/Class.hs +++ b/Control/Monad/Logic/Class.hs @@ -30,6 +30,8 @@ import Control.Applicative import Control.Monad import Control.Monad.Reader (ReaderT(..)) import Control.Monad.Trans (MonadTrans(..)) +import qualified Control.Monad.Writer.CPS as CpsW +import qualified Control.Monad.Trans.Writer.CPS as CpsW(writerT, runWriterT) import qualified Control.Monad.State.Lazy as LazyST import qualified Control.Monad.State.Strict as StrictST @@ -347,6 +349,13 @@ instance MonadLogic m => MonadLogic (ReaderT e m) where Nothing -> pure Nothing Just (a, m) -> pure (Just (a, lift m)) +instance (Monoid w, MonadLogic m, MonadPlus m) => MonadLogic (CpsW.WriterT w m) where + msplit wm = CpsW.writerT $ do + r <- msplit $ CpsW.runWriterT wm + case r of + Nothing -> pure (Nothing, mempty) + Just ((a, w), m) -> pure (Just (a, CpsW.writerT m), w) + -- | See note on splitting above. instance (MonadLogic m, MonadPlus m) => MonadLogic (StrictST.StateT s m) where msplit sm = StrictST.StateT $ \s -> diff --git a/logict.cabal b/logict.cabal index 47b75c8..5feaa52 100644 --- a/logict.cabal +++ b/logict.cabal @@ -40,11 +40,12 @@ library build-depends: base >=4.3 && <5, - mtl >=2.0 && <2.4 + mtl >=2.3 && <2.4, + transformers if impl(ghc <8.0) build-depends: - fail, transformers + fail executable grandparents buildable: False @@ -69,6 +70,7 @@ test-suite logict-tests async >=2.0, logict, mtl, + transformers, tasty, tasty-hunit diff --git a/test/Test.hs b/test/Test.hs index 9635afb..9b16012 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -14,18 +14,12 @@ import Control.Monad import Control.Monad.Identity import Control.Monad.Logic import Control.Monad.Reader +import qualified Control.Monad.Writer.CPS as CpsW +import qualified Control.Monad.Trans.Writer.CPS as CpsW(runWriterT) import qualified Control.Monad.State.Lazy as SL import qualified Control.Monad.State.Strict as SS import Data.Maybe - -#if MIN_VERSION_base(4,9,0) -#if MIN_VERSION_base(4,11,0) -#else -import Data.Semigroup (Semigroup (..)) -#endif -#else import Data.Monoid -#endif monadReader1 :: Assertion @@ -53,11 +47,11 @@ nats, odds, oddsOrTwo, oddsOrTwoUnfair, oddsOrTwoFair, odds5down :: Monad m => LogicT m Integer -#if MIN_VERSION_base(4,8,0) +-- | A `WriterT` version of `evalStateT`. +evalWriterT :: (Monad m, Monoid w) => CpsW.WriterT w m a -> m a +evalWriterT = fmap fst . CpsW.runWriterT + nats = pure 0 `mplus` ((1 +) <$> nats) -#else -nats = return 0 `mplus` liftM (1 +) nats -#endif odds = return 1 `mplus` liftM (2+) odds @@ -77,9 +71,9 @@ yieldWords = go main :: IO () main = defaultMain $ -#if __GLASGOW_HASKELL__ >= 702 + localOption (mkTimeout 3000000) $ -- 3 second deadman timeout -#endif + testGroup "All" [ testGroup "Monad Reader + env" [ testCase "Monad Reader 1" monadReader1 @@ -152,6 +146,11 @@ main = defaultMain $ z = mzero in assertBool "ReaderT" $ null $ catMaybes $ runReaderT (msplit z) 0 + , testCase "msplit mzero :: CPS WriterT" $ + let z :: CpsW.WriterT (Sum Int) [] String + z = mzero + in assertBool "CPS WriterT" $ null $ catMaybes (evalWriterT (msplit z)) + , testCase "msplit mzero :: LogicT" $ let z :: LogicT [] String z = mzero @@ -181,6 +180,13 @@ main = defaultMain $ extract (msplit op) @?= [sample] extract (msplit op >>= (\(Just (_,nxt)) -> msplit nxt)) @?= [] + , testCase "msplit CPS WriterT" $ do + let op :: CpsW.WriterT (Sum Integer) [] () + op = CpsW.tell 1 `mplus` op + extract = CpsW.execWriterT + extract (msplit op) @?= [1] + extract (msplit op >>= \(Just (_,nxt)) -> msplit nxt) @?= [2] + , testCase "msplit LogicT" $ do let op :: LogicT [] Integer op = foldr (mplus . return) mzero sample @@ -240,6 +246,11 @@ main = defaultMain $ (take 4 $ runReaderT (let oddsR = return 1 `mplus` liftM (2+) oddsR in oddsR `interleave` return (2 :: Integer)) "go") + , testCase "fair disjunction :: CPS WriterT" $ [1,2,3,5] @=? + (take 4 $ evalWriterT (let oddsW :: CpsW.WriterT [Char] [] Integer + oddsW = return 1 `mplus` liftM (2+) oddsW + in oddsW `interleave` return (2 :: Integer))) + , testCase "fair disjunction :: strict StateT" $ [1,2,3,5] @=? (take 4 $ SS.evalStateT (let oddsS = return 1 `mplus` liftM (2+) oddsS in oddsS `interleave` return (2 :: Integer)) "go") @@ -343,6 +354,15 @@ main = defaultMain $ if even x then return x else mzero ) "env") + , testCase "fair conjunction :: CPS WriterT" $ [2,4,6,8] @=? + (take 4 $ evalWriterT $ + (let oddsW :: CpsW.WriterT [Char] [] Integer + oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW + oddsPlus n = oddsW >>= \a -> return (a + n) + in do x <- (return 0 `mplus` return 1) >>- oddsPlus + if even x then return x else mzero + )) + , testCase "fair conjunction :: strict StateT" $ [2,4,6,8] @=? (take 4 $ SS.evalStateT (let oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS oddsPlus n = oddsS >>= \a -> return (a + n) @@ -426,6 +446,18 @@ main = defaultMain $ in testCase "indivisible odds :: ReaderT" $ [3,5,7,11,13,17,19,23,29,31] @=? (take 10 $ runReaderT oc "env") + , let iota n = msum (map return [1..n]) + oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW + oc :: CpsW.WriterT [Char] [] Integer + oc = do n <- oddsW + guard (n > 1) + ifte (do d <- iota (n - 1) + guard (d > 1 && n `mod` d == 0)) + (const mzero) + (return n) + in testCase "indivisible odds :: CPS WriterT" $ [3,5,7,11,13,17,19,23,29,31] @=? + (take 10 $ (fmap fst . CpsW.runWriterT) oc) + , let iota n = msum (map return [1..n]) oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS oc = do n <- oddsS @@ -499,6 +531,12 @@ main = defaultMain $ lnot (isEven v) return v) "env") + , testCase "inversion :: CPS WriterT" $ [1,3,5,7,9] @=? + (take 5 $ (evalWriterT :: CpsW.WriterT [Char] [] Integer -> [Integer]) + (do v <- foldr (mplus . return) mzero [(1::Integer)..] + lnot (isEven v) + return v)) + , testCase "inversion :: strict StateT" $ [1,3,5,7,9] @=? (take 5 $ SS.evalStateT (do v <- foldr (mplus . return) mzero [(1::Integer)..] lnot (isEven v)