Skip to content

Commit

Permalink
Implement incoherent instance MonadLogic Logic and elaborate (>>-) in…
Browse files Browse the repository at this point in the history
… instance MonadLogic []
  • Loading branch information
Bodigrim committed Sep 9, 2024
1 parent 85270d6 commit cd0457b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 10 deletions.
38 changes: 29 additions & 9 deletions Control/Monad/Logic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

#if MIN_VERSION_base(4,17,0)
{-# LANGUAGE Safe #-}
#else
{-# LANGUAGE Trustworthy #-}
#endif
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Avoid restricted function" #-}
Expand Down Expand Up @@ -60,6 +55,7 @@ module Control.Monad.Logic (
import Prelude (error, (-))

import Control.Applicative (Alternative(..), Applicative, liftA2, pure, (<*>), (*>))
import Control.Exception (Exception, evaluate, throw)
import Control.Monad (join, MonadPlus(..), Monad(..), fail)
import Control.Monad.Catch (MonadThrow, MonadCatch, throwM, catch)
import Control.Monad.Error.Class (MonadError(..))
Expand All @@ -71,18 +67,19 @@ import Control.Monad.State.Class (MonadState(..))
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.Zip (MonadZip (..))

import Data.Bool (otherwise)
import Data.Bool (Bool (..), otherwise)
import Data.Eq (Eq, (==))
import qualified Data.Foldable as F
import Data.Function (($), (.), const, on)
import Data.Functor (Functor(..), (<$>))
import Data.Int
import qualified Data.List as L
import Data.Maybe (Maybe(..))
import Data.Maybe (Maybe(..), maybe)
import Data.Monoid (Monoid (..))
import Data.Ord (Ord, (<=), (>), compare)
import Data.Semigroup (Semigroup (..))
import qualified Data.Traversable as T
import System.IO.Unsafe (unsafePerformIO)
import Text.Show (Show, showsPrec, showParen, showString, shows)
import Text.Read (Read, readPrec, Lexeme (Ident), parens, lexP, prec, readListPrec, readListPrecDefault)

Expand Down Expand Up @@ -459,7 +456,7 @@ instance MonadTrans LogicT where
instance (MonadIO m) => MonadIO (LogicT m) where
liftIO = lift . liftIO

instance (Monad m) => MonadLogic (LogicT m) where
instance {-# OVERLAPPABLE #-} (Monad m) => MonadLogic (LogicT m) where
-- 'msplit' is quite costly even if the base 'Monad' is 'Identity'.
-- Try to avoid it.
msplit m = lift $ unLogicT m ssk (return Nothing)
Expand All @@ -468,6 +465,29 @@ instance (Monad m) => MonadLogic (LogicT m) where
once m = LogicT $ \sk fk -> unLogicT m (\a _ -> sk a fk) fk
lnot m = LogicT $ \sk fk -> unLogicT m (\_ _ -> fk) (sk () fk)

instance {-# INCOHERENT #-} MonadLogic Logic where
-- Same as in the generic instance above
msplit m = lift $ unLogicT m ssk (return Nothing)
where
ssk a fk = return $ Just (a, lift fk >>= reflect)
once m = LogicT $ \sk fk -> unLogicT m (\a _ -> sk a fk) fk
lnot m = LogicT $ \sk fk -> unLogicT m (\_ _ -> fk) (sk () fk)

m >>- f
| isConstantFailure f = empty
-- Otherwise apply the default definition from Control.Monad.Logic.Class
| otherwise = msplit m >>= maybe empty (\(a, m') -> interleave (f a) (m' >>- f))

data MyException = MyException
deriving (Show)

instance Exception MyException

isConstantFailure :: (a -> Logic b) -> Bool
isConstantFailure f = unsafePerformIO $ do
let eval foo = runIdentity (unLogicT foo (const $ const $ Identity False) (Identity True))
evaluate (eval (f (throw MyException))) `catch` (\MyException -> pure False)

-- | @since 0.5.0
instance {-# OVERLAPPABLE #-} (Applicative m, F.Foldable m) => F.Foldable (LogicT m) where
foldMap f m = F.fold $ unLogicT m (fmap . mappend . f) (pure mempty)
Expand Down
24 changes: 23 additions & 1 deletion Control/Monad/Logic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,28 @@
-------------------------------------------------------------------------

{-# LANGUAGE CPP #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Avoid restricted function" #-}

module Control.Monad.Logic.Class (MonadLogic(..), reflect) where

import Prelude ()

import Control.Applicative (Alternative(..), Applicative(..))
import Control.Exception (Exception, evaluate, catch, throw)
import Control.Monad (MonadPlus, Monad(..))
import Control.Monad.Reader (ReaderT(..))
import Control.Monad.Trans (MonadTrans(..))
import qualified Control.Monad.State.Lazy as LazyST
import qualified Control.Monad.State.Strict as StrictST
import Data.Bool (Bool(..), otherwise)
import Data.Function (const, ($))
import Data.List (null)
import Data.Maybe (Maybe(..), maybe)
import System.IO.Unsafe (unsafePerformIO)
import Text.Show (Show)

#if MIN_VERSION_mtl(2,3,0)
import qualified Control.Monad.Writer.CPS as CpsW
Expand Down Expand Up @@ -364,6 +372,20 @@ instance MonadLogic [] where
msplit [] = pure Nothing
msplit (x:xs) = pure $ Just (x, xs)

m >>- f
| isConstantFailure f = []
-- Otherwise apply the default definition
| otherwise = msplit m >>= maybe empty (\(a, m') -> interleave (f a) (m' >>- f))

data MyException = MyException
deriving (Show)

instance Exception MyException

isConstantFailure :: (a -> [b]) -> Bool
isConstantFailure f = unsafePerformIO $
evaluate (null (f (throw MyException))) `catch` (\MyException -> pure False)

-- | Note that splitting a transformer does
-- not allow you to provide different input
-- to the monadic object returned.
Expand Down
15 changes: 15 additions & 0 deletions Control/Monad/Logic/Do.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module Control.Monad.Logic.Do
( (>>=)
, (>>)
) where

import qualified Prelude as P
import Control.Applicative (empty)
import Data.Maybe (maybe)
import Control.Monad.Logic.Class

(>>=) :: MonadLogic m => m a -> (a -> m b) -> m b
(>>=) = (>>-)

(>>) :: MonadLogic m => m () -> m a -> m a
m >> k = msplit m P.>>= maybe empty (\(_, m') -> interleave k (m' >> k))
17 changes: 17 additions & 0 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import qualified Control.Monad.State.Lazy as SL
import qualified Control.Monad.State.Strict as SS
import Data.Maybe

#if MIN_VERSION_base(4,17,0)
import GHC.IsList (IsList(..))
#else
import GHC.Exts (IsList(..))
#endif

#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup (Semigroup (..))
#endif
Expand Down Expand Up @@ -134,6 +140,12 @@ main = defaultMain $

, testCase "observeMany multi" $ [5,3] @=? observeMany 2 odds5down
, testCase "observeMany none" $ ([] :: [Integer]) @=? observeMany 2 mzero

, testCase "(>>-) Logic" $ do
let sample = fromList [1, 2, 3] :: Logic Integer
(sample >>- const (mempty :: Logic Integer)) @?= mempty
(sample >>- (\x -> fmap (+ x) (fromList [100, 200, 300]))) @?= fromList [101,102,201,103,301,202,203,302,303]
(sample >>- (\x -> if odd x then fmap (+ x) (fromList [100, 200, 300]) else mempty)) @?= fromList [101,103,201,203,301,303]
]

--------------------------------------------------
Expand Down Expand Up @@ -182,6 +194,11 @@ main = defaultMain $
extract (msplit op) @?= [Just 1]
extract (msplit op >>= (\(Just (_,nxt)) -> msplit nxt)) @?= [Just 2]

, testCase "(>>-) []" $ do
(sample >>- const ([] :: [Integer])) @?= []
(sample >>- (\x -> fmap (+ x) [100, 200, 300])) @?= [101,102,201,103,301,202,203,302,303]
(sample >>- (\x -> if odd x then fmap (+ x) [100, 200, 300] else [])) @?= [101,103,201,203,301,303]

, testCase "msplit ReaderT" $ do
let op = ask
extract = fmap fst . catMaybes . flip runReaderT sample
Expand Down

0 comments on commit cd0457b

Please sign in to comment.