From fc2c8a83602e9aa830014444b4986083b65a7c10 Mon Sep 17 00:00:00 2001 From: Levent Erkok Date: Wed, 13 Nov 2024 14:20:42 -0800 Subject: [PATCH] A first shot at more higher-order list functions --- Data/SBV/Core/Symbolic.hs | 50 ++++++++++++++---------- Data/SBV/Lambda.hs | 8 +++- Data/SBV/List.hs | 38 ++++++++++++++++--- Data/SBV/SMT/SMTLib2.hs | 80 ++++++++++++++++++++++++++++++++------- 4 files changed, 135 insertions(+), 41 deletions(-) diff --git a/Data/SBV/Core/Symbolic.hs b/Data/SBV/Core/Symbolic.hs index e0ecc12a8..5197e7a85 100644 --- a/Data/SBV/Core/Symbolic.hs +++ b/Data/SBV/Core/Symbolic.hs @@ -520,21 +520,24 @@ instance Show RegExOp where show (RegExNEq r1 r2) = "(distinct " ++ regExpToSMTString r1 ++ " " ++ regExpToSMTString r2 ++ ")" -- | Sequence operations. -data SeqOp = SeqConcat -- ^ See StrConcat - | SeqLen -- ^ See StrLen - | SeqUnit -- ^ See StrUnit - | SeqNth -- ^ See StrNth - | SeqSubseq -- ^ See StrSubseq - | SeqIndexOf -- ^ See StrIndexOf - | SeqContains -- ^ See StrContains - | SeqPrefixOf -- ^ See StrPrefixOf - | SeqSuffixOf -- ^ See StrSuffixOf - | SeqReplace -- ^ See StrReplace - | SeqMap String -- ^ Mapping over sequences - | SeqMapI String -- ^ Mapping over sequences with offset - | SeqFoldLeft String -- ^ Folding of sequences - | SeqFoldLeftI String -- ^ Folding of sequences with offset - | SBVReverse Kind -- ^ Reversal of sequences. NB. Also works for strings; hence the name. +data SeqOp = SeqConcat -- ^ See StrConcat + | SeqLen -- ^ See StrLen + | SeqUnit -- ^ See StrUnit + | SeqNth -- ^ See StrNth + | SeqSubseq -- ^ See StrSubseq + | SeqIndexOf -- ^ See StrIndexOf + | SeqContains -- ^ See StrContains + | SeqPrefixOf -- ^ See StrPrefixOf + | SeqSuffixOf -- ^ See StrSuffixOf + | SeqReplace -- ^ See StrReplace + | SeqMap String -- ^ Mapping over sequences + | SeqMapI String -- ^ Mapping over sequences with offset + | SeqFoldLeft String -- ^ Folding of sequences + | SeqFoldLeftI String -- ^ Folding of sequences with offset + | SBVReverse Kind -- ^ Reversal of sequences. NB. Also works for strings; hence the name. + | SBVSeqFilter Kind String -- ^ filter the list. Kind is the element type + | SBVSeqAll Kind String -- ^ map the function and reduce via and, with base true. Kind is the element type. + | SBVSeqAny Kind String -- ^ map the function and reduce via or, with base false. Kind is the element type. deriving (Eq, Ord, G.Data, NFData, Generic) -- | Show instance for SeqOp. Again, mapping is important. @@ -554,11 +557,18 @@ instance Show SeqOp where show (SeqFoldLeft s) = "seq.foldl " ++ s show (SeqFoldLeftI s) = "seq.foldli " ++ s - -- Note: This isn't part of SMTLib, we explicitly handle it - show (SBVReverse k) = let sk = show k - ssk | any isSpace sk = '(' : sk ++ ")" - | True = sk - in "sbv.reverse @" ++ ssk + -- Note: The followings aren't part of SMTLib, we explicitly handle it + show (SBVReverse k) = funcWithKind "sbv.reverse" k Nothing + show (SBVSeqFilter k s) = funcWithKind "sbv.seqFilter" k (Just s) + show (SBVSeqAll k s) = funcWithKind "sbv.seqAll" k (Just s) + show (SBVSeqAny k s) = funcWithKind "sbv.seqAny" k (Just s) + +-- helper for above +funcWithKind :: String -> Kind -> Maybe String -> String +funcWithKind f k mbExtra = f ++ " @" ++ ssk ++ maybe "" (' ':) mbExtra + where sk = show k + ssk | any isSpace sk = '(' : sk ++ ")" + | True = sk -- | Set operations. data SetOp = SetEqual diff --git a/Data/SBV/Lambda.hs b/Data/SBV/Lambda.hs index 50791e9da..855192b75 100644 --- a/Data/SBV/Lambda.hs +++ b/Data/SBV/Lambda.hs @@ -35,6 +35,8 @@ import Data.SBV.Core.Kind import Data.SBV.SMT.SMTLib2 import Data.SBV.Utils.PrettyNum +import qualified Data.Map.Strict as M + import Data.SBV.Core.Symbolic hiding (mkNewState) import qualified Data.SBV.Core.Symbolic as S (mkNewState) @@ -336,7 +338,7 @@ toLambda level curProgInfo cfg expectedKind result@Result{resAsgns = SBVPgm asgn where mkAsgn (sv, e@(SBVApp (Label l) _)) = ((sv, converter e), Just l) mkAsgn (sv, e) = ((sv, converter e), Nothing) - converter = cvtExp curProgInfo (capabilities (solver cfg)) rm tableMap + converter = cvtExp curProgInfo (capabilities (solver cfg)) rm tableMap funcMap out :: SV @@ -348,6 +350,10 @@ toLambda level curProgInfo cfg expectedKind result@Result{resAsgns = SBVPgm asgn rm = roundingMode cfg + -- NB. The following isn't really kosher since a lambda might refer to an SBV function + -- like reverse/any/all etc. But let's cross that bridge when we get to it + funcMap = M.empty + -- NB. The following is dead-code, since we ensure tbls is empty -- We used to support this, but there are issues, so dropping support -- See, for instance, https://github.com/LeventErkok/sbv/issues/664 diff --git a/Data/SBV/List.hs b/Data/SBV/List.hs index 93ceb6cac..8078ed36e 100644 --- a/Data/SBV/List.hs +++ b/Data/SBV/List.hs @@ -561,8 +561,15 @@ concat = foldl (++) [] -- True -- >>> all isEven [2, 4, 6, 1, 8, 10 :: Integer] -- False -all :: SymVal a => (SBV a -> SBool) -> SList a -> SBool -all f = foldl (\sofar e -> sofar .&& f e) sTrue +all :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SBool +all f l + | Just l' <- unliteral l + = sAll f (P.map literal l') + | True + = SBV $ SVal KBool $ Right $ cache r + where r st = do sva <- sbvToSV st l + lam <- lambdaStr st KBool f + newExpr st KBool (SBVApp (SeqOp (SBVSeqAll (kindOf (Proxy @a)) lam)) [sva]) -- | Check some element satisfies the predicate. -- -- @@ -571,8 +578,15 @@ all f = foldl (\sofar e -> sofar .&& f e) sTrue -- False -- >>> any isEven [2, 4, 6, 1, 8, 10 :: Integer] -- True -any :: SymVal a => (SBV a -> SBool) -> SList a -> SBool -any f = foldl (\sofar e -> sofar .|| f e) sFalse +any :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SBool +any f l + | Just l' <- unliteral l + = sAny f (P.map literal l') + | True + = SBV $ SVal KBool $ Right $ cache r + where r st = do sva <- sbvToSV st l + lam <- lambdaStr st KBool f + newExpr st KBool (SBVApp (SeqOp (SBVSeqAny (kindOf (Proxy @a)) lam)) [sva]) -- | @filter f xs@ filters the list with the given predicate. -- @@ -580,8 +594,20 @@ any f = foldl (\sofar e -> sofar .|| f e) sFalse -- [2,4,6,8,10] :: [SInteger] -- >>> filter (\x -> x `sMod` 2 ./= 0) [1 .. 10 :: Integer] -- [1,3,5,7,9] :: [SInteger] -filter :: SymVal a => (SBV a -> SBool) -> SList a -> SList a -filter f = foldl (\sofar e -> sofar ++ ite (f e) (singleton e) []) [] +filter :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SList a +filter f l + | Just l' <- unliteral l, Just concResult <- concreteFilter l' + = literal concResult + | True + = SBV $ SVal k $ Right $ cache r + where concreteFilter l' = case P.map (unliteral . f . literal) l' of + xs | P.any isNothing xs -> Nothing + | True -> Just [e | (True, e) <- P.zip (catMaybes xs) l'] + + k = kindOf (Proxy @(SList a)) + r st = do sva <- sbvToSV st l + lam <- lambdaStr st KBool f + newExpr st k (SBVApp (SeqOp (SBVSeqFilter (kindOf (Proxy @a)) lam)) [sva]) -- | Lift a unary operator over lists. lift1 :: forall a b. (SymVal a, SymVal b) => Bool -> SeqOp -> Maybe (a -> b) -> SBV a -> SBV b diff --git a/Data/SBV/SMT/SMTLib2.hs b/Data/SBV/SMT/SMTLib2.hs index 5bcf49660..79fe07a42 100644 --- a/Data/SBV/SMT/SMTLib2.hs +++ b/Data/SBV/SMT/SMTLib2.hs @@ -232,11 +232,11 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs ++ [ "; --- uninterpreted constants ---" ] ++ concatMap (declUI curProgInfo) uis ++ [ "; --- SBV Function definitions" | not (null funcMap) ] - ++ concat [ declSBVFunc op nm | (op, nm) <- M.toAscList funcMap ] + ++ concat [declSBVFunc op nm | (op, nm) <- M.toAscList funcMap] ++ [ "; --- user defined functions ---"] ++ userDefs ++ [ "; --- assignments ---" ] - ++ concatMap (declDef curProgInfo cfg tableMap) asgns + ++ concatMap (declDef curProgInfo cfg tableMap funcMap) asgns ++ [ "; --- delayedEqualities ---" ] ++ map (\s -> "(assert " ++ s ++ ")") delayedEqualities ++ [ "; --- formula ---" ] @@ -290,8 +290,22 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs | True = Just $ Left s -- SBV only functions. - funcMap = M.fromList [(op, "|sbv.reverse_" ++ show k ++ "|") | (op, k) <- revs] - where revs = nub [(op, k) | op@(SeqOp (SBVReverse k)) <- G.universeBi asgnsSeq] + funcMap = M.fromList $ [(op, "|sbv.reverse_" ++ show k ++ "|") | (op@(SeqOp (SBVReverse k )), _) <- specials] + ++ [(op, "|sbv.seqFilter_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqFilter k _)), i) <- specials] + ++ [(op, "|sbv.seqAll_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAll k _)), i) <- specials] + ++ [(op, "|sbv.seqAny_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAny k _)), i) <- specials] + where specials = zip (nub [op | op@(SeqOp so) <- G.universeBi asgnsSeq, isSpecial so]) [0..] + + isSpecial SBVReverse{} = True + isSpecial SBVSeqFilter{} = True + isSpecial SBVSeqAll{} = True + isSpecial SBVSeqAny{} = True + isSpecial _ = False + + -- if index 0, then ignore it; other wise add it. This distinguishes different functions passed to all/any + idx :: Int -> String + idx 0 = "" + idx i = show i asgns = F.toList asgnsSeq @@ -305,6 +319,9 @@ declSBVFunc :: Op -> String -> [String] declSBVFunc op nm = case op of SeqOp (SBVReverse KString) -> mkStringRev SeqOp (SBVReverse (KList k)) -> mkSeqRev (KList k) + SeqOp (SBVSeqFilter ek f) -> mkFilter ek f + SeqOp (SBVSeqAll ek f) -> mkAnyAll True ek f + SeqOp (SBVSeqAny ek f) -> mkAnyAll False ek f _ -> error $ "Data.SBV.declSBVFunc: Unexpected internal function: " ++ show (op, nm) where mkStringRev = [ "(define-fun-rec " ++ nm ++ " ((str String)) String" , " (ite (= str \"\")" @@ -321,6 +338,24 @@ declSBVFunc op nm = case op of ] where t = smtType k + mkAnyAll isAll ek f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) Bool" + , " (ite (= lst (as seq.empty " ++ t ++ "))" + , " " ++ base + , " (" ++ conn ++ " (select " ++ f ++ " (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))" + ] + where t = smtType (KList ek) + (base, conn) | isAll = ("true", "and") + | True = ("false", "or") + + mkFilter k f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) " ++ t + , " (ite (= lst (as seq.empty " ++ t ++ "))" + , " (as seq.empty " ++ t ++ ")" + , " (ite (select " ++ f ++ " (seq.nth lst 0))" + , " (seq.++ (seq.unit (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))" + , " (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))" + ] + where t = smtType (KList k) + -- | Declare new sorts declSort :: (String, Maybe [String]) -> [String] declSort (s, _) @@ -472,7 +507,7 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg = -- table declarations ++ tableDecls -- expressions - ++ concatMap (declDef curProgInfo cfg tableMap) (F.toList asgnsSeq) + ++ concatMap (declDef curProgInfo cfg tableMap funcMap) (F.toList asgnsSeq) -- table setups ++ concat tableAssigns -- extra constraints @@ -488,6 +523,10 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg = (tableDecls, tableAssigns) = unzip $ map mkTable allTables + -- This isn't super kosher, since we might refer to an internal function in + -- the incremental context. But let's cross that bridge when we come to it. + funcMap = M.empty + -- If we need flattening in models, do emit the required lines if preset settings | any needsFlattening newKinds @@ -496,11 +535,11 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg = = [] where solverCaps = capabilities (solver cfg) -declDef :: ProgInfo -> SMTConfig -> TableMap -> (SV, SBVExpr) -> [String] -declDef curProgInfo cfg tableMap (s, expr) = +declDef :: ProgInfo -> SMTConfig -> TableMap -> FuncMap -> (SV, SBVExpr) -> [String] +declDef curProgInfo cfg tableMap funcMap (s, expr) = case expr of - SBVApp (Label m) [e] -> defineFun cfg (s, cvtSV e) (Just m) - e -> defineFun cfg (s, cvtExp curProgInfo caps rm tableMap e) Nothing + SBVApp (Label m) [e] -> defineFun cfg (s, cvtSV e) (Just m) + e -> defineFun cfg (s, cvtExp curProgInfo caps rm tableMap funcMap e) Nothing where caps = capabilities (solver cfg) rm = roundingMode cfg @@ -684,7 +723,8 @@ cvtType (SBVType []) = error "SBV.SMT.SMTLib2.cvtType: internal: received an emp cvtType (SBVType xs) = "(" ++ unwords (map smtType body) ++ ") " ++ smtType ret where (body, ret) = (init xs, last xs) -type TableMap = IM.IntMap String +type TableMap = IM.IntMap String +type FuncMap = M.Map Op String -- Present an SV, simply show cvtSV :: SV -> String @@ -698,8 +738,8 @@ getTable m i | Just tn <- i `IM.lookup` m = tn | True = "table" ++ show i -- constant tables are always named this way -cvtExp :: ProgInfo -> SolverCapabilities -> RoundingMode -> TableMap -> SBVExpr -> String -cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr +cvtExp :: ProgInfo -> SolverCapabilities -> RoundingMode -> TableMap -> FuncMap -> SBVExpr -> String +cvtExp curProgInfo caps rm tableMap funcMap expr@(SBVApp _ arguments) = sh expr where hasPB = supportsPseudoBooleans caps hasInt2bv = supportsInt2bv caps hasDistinct = supportsDistinct caps @@ -788,6 +828,15 @@ cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr = let idx v = "(" ++ s ++ "_constrIndex " ++ v ++ ")" in "(" ++ o ++ " " ++ idx a ++ " " ++ idx b ++ ")" unintComp o sbvs = error $ "SBV.SMT.SMTLib2.sh.unintComp: Unexpected arguments: " ++ show (o, sbvs, map kindOf arguments) + getFuncName op = case op `M.lookup` funcMap of + Just n -> n + Nothing -> error $ unlines [ "" + , "*** Cannot translate operator: " ++ show op + , "***" + , "*** Note that this operator isn't currently supported in incremental query mode." + , "*** If you are not in query mode, or would like support for this feature, please report!" + ] + stringOrChar KString = True stringOrChar KChar = True stringOrChar _ = False @@ -974,8 +1023,11 @@ cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr sh (SBVApp (RegExOp o@RegExEq{}) []) = show o sh (SBVApp (RegExOp o@RegExNEq{}) []) = show o - -- Reverse is special, since we need to generate call to the internally generated function - sh (SBVApp (SeqOp (SBVReverse k)) args) = "(|sbv.reverse_" ++ show k ++ "| " ++ unwords (map cvtSV args) ++ ")" + -- Reverse and higher order functions are special + sh (SBVApp o@(SeqOp SBVReverse{}) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")" + sh (SBVApp o@(SeqOp SBVSeqFilter{}) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")" + sh (SBVApp o@(SeqOp SBVSeqAll{} ) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")" + sh (SBVApp o@(SeqOp SBVSeqAny{} ) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")" sh (SBVApp (SeqOp op) args) = "(" ++ show op ++ " " ++ unwords (map cvtSV args) ++ ")"