Skip to content

Commit

Permalink
feat: instance MonadLogic for WriterT
Browse files Browse the repository at this point in the history
I have implemented `MonadLogic` in `WriterT` as well.

Until now, `WriterT` should not be used because it easily causes space leaks.
However, [Control.Monad.Trans.Writer.CPS](https://www.stackage.org/haddock/nightly-2022-12-08/transformers-0.5.6.2/Control-Monad-Trans-Writer-CPS.html) no longer causes space leaks and can be used.
Therefore, we would like to use `WriterT`, but it is inconvenient if it cannot be combined with `MonadLogic`.

Therefore, I implemented it and added tests where `ReaderT` and `StateT` are tested.

To use the CPS version of `WriterT`, I raised the mtl version requirement to [feat(Control.Monad.Writer.CPS): re export runWriterT by ncaq · Pull Request #136 · haskell/mtl](haskell/mtl#136) has not yet been imported, so we have no choice but to rely directly on transformers.
In the end, mtl depends on transformers, so we have determined that this is not a critical problem.
We have taken care to make it easy to remove them when they are no longer needed by doing a limited `import`.
  • Loading branch information
ncaq committed Jan 18, 2023
1 parent 3070bfb commit 10dbf8f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 6 deletions.
14 changes: 14 additions & 0 deletions Control/Monad/Logic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ import Control.Monad.Trans (MonadTrans(..))
import qualified Control.Monad.State.Lazy as LazyST
import qualified Control.Monad.State.Strict as StrictST

#if MIN_VERSION_mtl(2,3,0)
import qualified Control.Monad.Writer.CPS as CpsW
import qualified Control.Monad.Trans.Writer.CPS as CpsW(writerT, runWriterT)
#endif

-- | A backtracking, logic programming monad.
class (Monad m, Alternative m) => MonadLogic m where
-- | Attempts to __split__ the computation, giving access to the first
Expand Down Expand Up @@ -347,6 +352,15 @@ instance MonadLogic m => MonadLogic (ReaderT e m) where
Nothing -> pure Nothing
Just (a, m) -> pure (Just (a, lift m))

#if MIN_VERSION_mtl(2,3,0)
instance (Monoid w, MonadLogic m, MonadPlus m) => MonadLogic (CpsW.WriterT w m) where
msplit wm = CpsW.writerT $ do
r <- msplit $ CpsW.runWriterT wm
case r of
Nothing -> pure (Nothing, mempty)
Just ((a, w), m) -> pure (Just (a, CpsW.writerT m), w)
#endif

-- | See note on splitting above.
instance (MonadLogic m, MonadPlus m) => MonadLogic (StrictST.StateT s m) where
msplit sm = StrictST.StateT $ \s ->
Expand Down
6 changes: 4 additions & 2 deletions logict.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ library

build-depends:
base >=4.3 && <5,
mtl >=2.0 && <2.4
mtl >=2.0 && <2.4,
transformers

if impl(ghc <8.0)
build-depends:
fail, transformers
fail

executable grandparents
buildable: False
Expand All @@ -69,6 +70,7 @@ test-suite logict-tests
async >=2.0,
logict,
mtl,
transformers,
tasty,
tasty-hunit

Expand Down
74 changes: 70 additions & 4 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ import qualified Control.Monad.State.Lazy as SL
import qualified Control.Monad.State.Strict as SS
import Data.Maybe

#if MIN_VERSION_base(4,9,0)
#if MIN_VERSION_base(4,11,0)
#else
#if MIN_VERSION_base(4,9,0) && !MIN_VERSION_base(4,11,0)
import Data.Semigroup (Semigroup (..))
#endif
#else

-- required by base < 4.9 OR CPS Writer test
#if !MIN_VERSION_base(4,9,0) || MIN_VERSION_mtl(2,3,0)
import Data.Monoid
#endif

#if MIN_VERSION_mtl(2,3,0)
import qualified Control.Monad.Writer.CPS as CpsW (WriterT, execWriterT, tell)
import qualified Control.Monad.Trans.Writer.CPS as CpsW (runWriterT)
#endif

monadReader1 :: Assertion
monadReader1 = assertEqual "should be equal" [5 :: Int] $
Expand All @@ -53,6 +57,12 @@ nats, odds, oddsOrTwo,
oddsOrTwoUnfair, oddsOrTwoFair,
odds5down :: Monad m => LogicT m Integer

-- | A `WriterT` version of `evalStateT`.
#if MIN_VERSION_mtl(2,3,0)
evalWriterT :: (Monad m, Monoid w) => CpsW.WriterT w m a -> m a
evalWriterT = fmap fst . CpsW.runWriterT
#endif

#if MIN_VERSION_base(4,8,0)
nats = pure 0 `mplus` ((1 +) <$> nats)
#else
Expand Down Expand Up @@ -152,6 +162,13 @@ main = defaultMain $
z = mzero
in assertBool "ReaderT" $ null $ catMaybes $ runReaderT (msplit z) 0

#if MIN_VERSION_mtl(2,3,0)
, testCase "msplit mzero :: CPS WriterT" $
let z :: CpsW.WriterT (Sum Int) [] String
z = mzero
in assertBool "CPS WriterT" $ null $ catMaybes (evalWriterT (msplit z))
#endif

, testCase "msplit mzero :: LogicT" $
let z :: LogicT [] String
z = mzero
Expand Down Expand Up @@ -181,6 +198,15 @@ main = defaultMain $
extract (msplit op) @?= [sample]
extract (msplit op >>= (\(Just (_,nxt)) -> msplit nxt)) @?= []

#if MIN_VERSION_mtl(2,3,0)
, testCase "msplit CPS WriterT" $ do
let op :: CpsW.WriterT (Sum Integer) [] ()
op = CpsW.tell 1 `mplus` op
extract = CpsW.execWriterT
extract (msplit op) @?= [1]
extract (msplit op >>= \(Just (_,nxt)) -> msplit nxt) @?= [2]
#endif

, testCase "msplit LogicT" $ do
let op :: LogicT [] Integer
op = foldr (mplus . return) mzero sample
Expand Down Expand Up @@ -240,6 +266,13 @@ main = defaultMain $
(take 4 $ runReaderT (let oddsR = return 1 `mplus` liftM (2+) oddsR
in oddsR `interleave` return (2 :: Integer)) "go")

#if MIN_VERSION_mtl(2,3,0)
, testCase "fair disjunction :: CPS WriterT" $ [1,2,3,5] @=?
(take 4 $ evalWriterT (let oddsW :: CpsW.WriterT [Char] [] Integer
oddsW = return 1 `mplus` liftM (2+) oddsW
in oddsW `interleave` return (2 :: Integer)))
#endif

, testCase "fair disjunction :: strict StateT" $ [1,2,3,5] @=?
(take 4 $ SS.evalStateT (let oddsS = return 1 `mplus` liftM (2+) oddsS
in oddsS `interleave` return (2 :: Integer)) "go")
Expand Down Expand Up @@ -343,6 +376,17 @@ main = defaultMain $
if even x then return x else mzero
) "env")

#if MIN_VERSION_mtl(2,3,0)
, testCase "fair conjunction :: CPS WriterT" $ [2,4,6,8] @=?
(take 4 $ evalWriterT $
(let oddsW :: CpsW.WriterT [Char] [] Integer
oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW
oddsPlus n = oddsW >>= \a -> return (a + n)
in do x <- (return 0 `mplus` return 1) >>- oddsPlus
if even x then return x else mzero
))
#endif

, testCase "fair conjunction :: strict StateT" $ [2,4,6,8] @=?
(take 4 $ SS.evalStateT (let oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS
oddsPlus n = oddsS >>= \a -> return (a + n)
Expand Down Expand Up @@ -426,6 +470,20 @@ main = defaultMain $
in testCase "indivisible odds :: ReaderT" $ [3,5,7,11,13,17,19,23,29,31] @=?
(take 10 $ runReaderT oc "env")

#if MIN_VERSION_mtl(2,3,0)
, let iota n = msum (map return [1..n])
oddsW = return (1 :: Integer) `mplus` liftM (2+) oddsW
oc :: CpsW.WriterT [Char] [] Integer
oc = do n <- oddsW
guard (n > 1)
ifte (do d <- iota (n - 1)
guard (d > 1 && n `mod` d == 0))
(const mzero)
(return n)
in testCase "indivisible odds :: CPS WriterT" $ [3,5,7,11,13,17,19,23,29,31] @=?
(take 10 $ (fmap fst . CpsW.runWriterT) oc)
#endif

, let iota n = msum (map return [1..n])
oddsS = return (1 :: Integer) `mplus` liftM (2+) oddsS
oc = do n <- oddsS
Expand Down Expand Up @@ -499,6 +557,14 @@ main = defaultMain $
lnot (isEven v)
return v) "env")

#if MIN_VERSION_mtl(2,3,0)
, testCase "inversion :: CPS WriterT" $ [1,3,5,7,9] @=?
(take 5 $ (evalWriterT :: CpsW.WriterT [Char] [] Integer -> [Integer])
(do v <- foldr (mplus . return) mzero [(1::Integer)..]
lnot (isEven v)
return v))
#endif

, testCase "inversion :: strict StateT" $ [1,3,5,7,9] @=?
(take 5 $ SS.evalStateT (do v <- foldr (mplus . return) mzero [(1::Integer)..]
lnot (isEven v)
Expand Down

0 comments on commit 10dbf8f

Please sign in to comment.