Skip to content

Commit

Permalink
feat!: instance MonadLogic for WriterT
Browse files Browse the repository at this point in the history
I have implemented `MonadLogic` in `WriterT` as well.

Until now, `WriterT` should not be used because it easily causes space leaks.
However, [Control.Monad.Trans.Writer.CPS](https://www.stackage.org/haddock/nightly-2022-12-08/transformers-0.5.6.2/Control-Monad-Trans-Writer-CPS.html) no longer causes space leaks and can be used.
Therefore, we would like to use `WriterT`, but it is inconvenient if it cannot be combined with `MonadLogic`.

Therefore, I implemented it and added tests where `ReaderT` and `StateT` are tested.

To use the CPS version of `WriterT`, I raised the mtl version requirement to [feat(Control.Monad.Writer.CPS): re export runWriterT by ncaq · Pull Request #136 · haskell/mtl](haskell/mtl#136) has not yet been imported, so we have no choice but to rely directly on transformers.
In the end, mtl depends on transformers, so we have determined that this is not a critical problem.
We have taken care to make it easy to remove them when they are no longer needed by doing a limited `import`.

I was surprised myself that many of the CPP macros in the test code were removed.
It was not my intention.
I believe it was probably done by the HLS plugin.
I've been trying to find out why it deleted them. I found out that the new mtl raises the lower limit for the base library in the cabal file.
So when raising the lower version limit of mtl, we thought that it is certainly not a very effective practice to control the base version here.
Therefore, we decided to leave the deleted part as it is.
  • Loading branch information
ncaq committed Dec 23, 2022
1 parent 3070bfb commit cd44aa9
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
9 changes: 9 additions & 0 deletions Control/Monad/Logic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import Control.Applicative
import Control.Monad
import Control.Monad.Reader (ReaderT(..))
import Control.Monad.Trans (MonadTrans(..))
import qualified Control.Monad.Writer.CPS as CpsW
import qualified Control.Monad.Trans.Writer.CPS as CpsW(writerT, runWriterT)
import qualified Control.Monad.State.Lazy as LazyST
import qualified Control.Monad.State.Strict as StrictST

Expand Down Expand Up @@ -347,6 +349,13 @@ instance MonadLogic m => MonadLogic (ReaderT e m) where
Nothing -> pure Nothing
Just (a, m) -> pure (Just (a, lift m))

instance (Monoid w, MonadLogic m, MonadPlus m) => MonadLogic (CpsW.WriterT w m) where
msplit wm = CpsW.writerT $ do
r <- msplit $ CpsW.runWriterT wm
case r of
Nothing -> pure (Nothing, mempty)
Just ((a, w), m) -> pure (Just (a, CpsW.writerT m), w)

-- | See note on splitting above.
instance (MonadLogic m, MonadPlus m) => MonadLogic (StrictST.StateT s m) where
msplit sm = StrictST.StateT $ \s ->
Expand Down
6 changes: 4 additions & 2 deletions logict.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ library

build-depends:
base >=4.3 && <5,
mtl >=2.0 && <2.4
mtl >=2.3 && <2.4,
transformers

if impl(ghc <8.0)
build-depends:
fail, transformers
fail

executable grandparents
buildable: False
Expand All @@ -69,6 +70,7 @@ test-suite logict-tests
async >=2.0,
logict,
mtl,
transformers,
tasty,
tasty-hunit

Expand Down
66 changes: 52 additions & 14 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@ import Control.Monad
import Control.Monad.Identity
import Control.Monad.Logic
import Control.Monad.Reader
import qualified Control.Monad.Writer.CPS as CpsW
import qualified Control.Monad.Trans.Writer.CPS as CpsW(runWriterT)
import qualified Control.Monad.State.Lazy as SL
import qualified Control.Monad.State.Strict as SS
import Data.Maybe

#if MIN_VERSION_base(4,9,0)
#if MIN_VERSION_base(4,11,0)
#else
import Data.Semigroup (Semigroup (..))
#endif
#else
import Data.Monoid
#endif


monadReader1 :: Assertion
Expand Down Expand Up @@ -53,11 +47,11 @@ nats, odds, oddsOrTwo,
oddsOrTwoUnfair, oddsOrTwoFair,
odds5down :: Monad m => LogicT m Integer

#if MIN_VERSION_base(4,8,0)
-- | A `WriterT` version of `evalStateT`.
evalWriterT :: (Monad m, Monoid w) => CpsW.WriterT w m a -> m a
evalWriterT = fmap fst . CpsW.runWriterT

nats = pure 0 `mplus` ((1 +) <$> nats)
#else
nats = return 0 `mplus` liftM (1 +) nats
#endif

odds = return 1 `mplus` liftM (2+) odds

Expand All @@ -77,9 +71,9 @@ yieldWords = go

main :: IO ()
main = defaultMain $
#if __GLASGOW_HASKELL__ >= 702

localOption (mkTimeout 3000000) $ -- 3 second deadman timeout
#endif

testGroup "All"
[ testGroup "Monad Reader + env"
[ testCase "Monad Reader 1" monadReader1
Expand Down Expand Up @@ -152,6 +146,11 @@ main = defaultMain $
z = mzero
in assertBool "ReaderT" $ null $ catMaybes $ runReaderT (msplit z) 0

, testCase "msplit mzero :: CPS WriterT" $
let z :: CpsW.WriterT (Sum Int) [] String
z = mzero
in assertBool "CPS WriterT" $ null $ catMaybes (evalWriterT (msplit z))

, testCase "msplit mzero :: LogicT" $
let z :: LogicT [] String
z = mzero
Expand Down Expand Up @@ -181,6 +180,13 @@ main = defaultMain $
extract (msplit op) @?= [sample]
extract (msplit op >>= (\(Just (_,nxt)) -> msplit nxt)) @?= []

, testCase "msplit CPS WriterT" $ do
let op :: CpsW.WriterT (Sum Integer) [] ()
op = CpsW.tell 1 `mplus` op
extract = CpsW.execWriterT
extract (msplit op) @?= [1]
extract (msplit op >>= \(Just (_,nxt)) -> msplit nxt) @?= [2]

, testCase "msplit LogicT" $ do
let op :: LogicT [] Integer
op = foldr (mplus . return) mzero sample
Expand Down Expand Up @@ -240,6 +246,11 @@ main = defaultMain $
(take 4 $ runReaderT (let oddsR = return 1 `mplus` liftM (2+) oddsR
in oddsR `interleave` return (2 :: Integer)) "go")

, testCase "fair disjunction :: CPS WriterT" $ [1,2,3,5] @=?
(take 4 $ evalWriterT (let oddsW :: CpsW.WriterT [Char] [] Integer
oddsW = return 1 `mplus` liftM (2+) oddsW
in oddsW `interleave` return (2 :: Integer)))

, testCase "fair disjunction :: strict StateT" $ [1,2,3,5] @=?
(take 4 $ SS.evalStateT (let oddsS = return 1 `mplus` liftM (2+) oddsS
in oddsS `interleave` return (2 :: Integer)) "go")
Expand Down Expand Up @@ -343,6 +354,15 @@ main = defaultMain $
if even x then return x else mzero
) "env")

, testCase "fair conjunction :: CPS WriterT" $ [2,4,6,8] @=?
(take 4 $ evalWriterT $
(let oddsW :: CpsW.WriterT [Char] [] Integer
oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW
oddsPlus n = oddsW >>= \a -> return (a + n)
in do x <- (return 0 `mplus` return 1) >>- oddsPlus
if even x then return x else mzero
))

, testCase "fair conjunction :: strict StateT" $ [2,4,6,8] @=?
(take 4 $ SS.evalStateT (let oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS
oddsPlus n = oddsS >>= \a -> return (a + n)
Expand Down Expand Up @@ -426,6 +446,18 @@ main = defaultMain $
in testCase "indivisible odds :: ReaderT" $ [3,5,7,11,13,17,19,23,29,31] @=?
(take 10 $ runReaderT oc "env")

, let iota n = msum (map return [1..n])
oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW
oc :: CpsW.WriterT [Char] [] Integer
oc = do n <- oddsW
guard (n > 1)
ifte (do d <- iota (n - 1)
guard (d > 1 && n `mod` d == 0))
(const mzero)
(return n)
in testCase "indivisible odds :: CPS WriterT" $ [3,5,7,11,13,17,19,23,29,31] @=?
(take 10 $ (fmap fst . CpsW.runWriterT) oc)

, let iota n = msum (map return [1..n])
oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS
oc = do n <- oddsS
Expand Down Expand Up @@ -499,6 +531,12 @@ main = defaultMain $
lnot (isEven v)
return v) "env")

, testCase "inversion :: CPS WriterT" $ [1,3,5,7,9] @=?
(take 5 $ (evalWriterT :: CpsW.WriterT [Char] [] Integer -> [Integer])
(do v <- foldr (mplus . return) mzero [(1::Integer)..]
lnot (isEven v)
return v))

, testCase "inversion :: strict StateT" $ [1,3,5,7,9] @=?
(take 5 $ SS.evalStateT (do v <- foldr (mplus . return) mzero [(1::Integer)..]
lnot (isEven v)
Expand Down

0 comments on commit cd44aa9

Please sign in to comment.