diff --git a/Control/Monad/Logic.hs b/Control/Monad/Logic.hs index 5788564..be0ea9d 100644 --- a/Control/Monad/Logic.hs +++ b/Control/Monad/Logic.hs @@ -25,13 +25,8 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} - -#if MIN_VERSION_base(4,17,0) -{-# LANGUAGE Safe #-} -#else {-# LANGUAGE Trustworthy #-} -#endif +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Avoid restricted function" #-} @@ -60,6 +55,7 @@ module Control.Monad.Logic ( import Prelude (error, (-)) import Control.Applicative (Alternative(..), Applicative, liftA2, pure, (<*>), (*>)) +import Control.Exception (Exception, evaluate, throw) import Control.Monad (join, MonadPlus(..), Monad(..), fail) import Control.Monad.Catch (MonadThrow, MonadCatch, throwM, catch) import Control.Monad.Error.Class (MonadError(..)) @@ -71,18 +67,19 @@ import Control.Monad.State.Class (MonadState(..)) import Control.Monad.Trans (MonadTrans(..)) import Control.Monad.Zip (MonadZip (..)) -import Data.Bool (otherwise) +import Data.Bool (Bool (..), otherwise) import Data.Eq ((==)) import qualified Data.Foldable as F import Data.Function (($), (.), const) import Data.Functor (Functor(..), (<$>)) import Data.Int import qualified Data.List as L -import Data.Maybe (Maybe(..)) +import Data.Maybe (Maybe(..), maybe) import Data.Monoid (Monoid (..)) import Data.Ord ((<=), (>)) import Data.Semigroup (Semigroup (..)) import qualified Data.Traversable as T +import System.IO.Unsafe (unsafePerformIO) import Text.Show (Show, showsPrec, showParen, showString, shows) import Text.Read (Read, readPrec, Lexeme (Ident), parens, lexP, prec, readListPrec, readListPrecDefault) @@ -459,7 +456,7 @@ instance MonadTrans LogicT where instance (MonadIO m) => MonadIO (LogicT m) where liftIO = lift . liftIO -instance (Monad m) => MonadLogic (LogicT m) where +instance {-# OVERLAPPABLE #-} (Monad m) => MonadLogic (LogicT m) where -- 'msplit' is quite costly even if the base 'Monad' is 'Identity'. -- Try to avoid it. msplit m = lift $ unLogicT m ssk (return Nothing) @@ -468,6 +465,29 @@ instance (Monad m) => MonadLogic (LogicT m) where once m = LogicT $ \sk fk -> unLogicT m (\a _ -> sk a fk) fk lnot m = LogicT $ \sk fk -> unLogicT m (\_ _ -> fk) (sk () fk) +instance {-# INCOHERENT #-} MonadLogic Logic where + -- Same as in the generic instance above + msplit m = lift $ unLogicT m ssk (return Nothing) + where + ssk a fk = return $ Just (a, lift fk >>= reflect) + once m = LogicT $ \sk fk -> unLogicT m (\a _ -> sk a fk) fk + lnot m = LogicT $ \sk fk -> unLogicT m (\_ _ -> fk) (sk () fk) + + m >>- f + | isConstantFailure f = empty + -- Otherwise apply default definition from Control.Monad.Logic.Class + | otherwise = msplit m >>= maybe empty (\(a, m') -> interleave (f a) (m' >>- f)) + +data MyException = MyException + deriving (Show) + +instance Exception MyException + +isConstantFailure :: (a -> Logic b) -> Bool +isConstantFailure f = unsafePerformIO $ do + let eval foo = runIdentity (unLogicT foo (const $ const $ Identity False) (Identity True)) + evaluate (eval (f (throw MyException))) `catch` (\MyException -> pure False) + -- | @since 0.5.0 instance {-# OVERLAPPABLE #-} (Applicative m, F.Foldable m) => F.Foldable (LogicT m) where foldMap f m = F.fold $ unLogicT m (fmap . mappend . f) (pure mempty)