Skip to content

Commit

Permalink
Implement a new multi-SAT algorithm for What4 solvers.
Browse files Browse the repository at this point in the history
Fixes #1125
  • Loading branch information
robdockins committed Jul 20, 2021
1 parent dd8fca4 commit 62d413e
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 36 deletions.
20 changes: 20 additions & 0 deletions src/Cryptol/Symbolic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ module Cryptol.Symbolic
, modelPred
, varModelPred
, varToExpr
, flattenShape
, flattenShapes
) where


Expand Down Expand Up @@ -222,6 +224,24 @@ ppVarShape sym (VarRecord fs) =
ppField (f,v) = pp f <+> char '=' <+> ppVarShape sym v


-- | Flatten structured shapes (like tuples and sequences), leaving only
-- a sequence of variable shapes of base type.
flattenShapes :: [VarShape sym] -> [VarShape sym] -> [VarShape sym]
flattenShapes [] tl = tl
flattenShapes (x:xs) tl = flattenShape x (flattenShapes xs tl)

flattenShape :: VarShape sym -> [VarShape sym] -> [VarShape sym]
flattenShape x tl =
case x of
VarBit{} -> x : tl
VarInteger{} -> x : tl
VarRational{} -> x : tl
VarWord{} -> x : tl
VarFloat{} -> x : tl
VarFinSeq _ vs -> flattenShapes vs tl
VarTuple vs -> flattenShapes vs tl
VarRecord fs -> flattenShapes (recordElements fs) tl

varShapeToValue :: Backend sym => sym -> VarShape sym -> GenValue sym
varShapeToValue sym var =
case var of
Expand Down
159 changes: 123 additions & 36 deletions src/Cryptol/Symbolic/What4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
Expand Down Expand Up @@ -39,7 +40,7 @@ import qualified Control.Exception as X
import System.IO (Handle)
import Data.Time
import Data.IORef
import Data.List (intercalate)
import Data.List (intercalate, tails, inits)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import qualified Data.Map as Map
Expand All @@ -62,7 +63,9 @@ import Cryptol.Backend.What4
import qualified Cryptol.Eval as Eval
import qualified Cryptol.Eval.Concrete as Concrete
import qualified Cryptol.Eval.Value as Eval
import Cryptol.Eval.Type (TValue)
import Cryptol.Eval.What4

import Cryptol.Parser.Position (emptyRange)
import Cryptol.Symbolic
import Cryptol.TypeCheck.AST
Expand Down Expand Up @@ -462,12 +465,14 @@ satProveOffline hashConsing warnUninterp ProverCommand{ .. } outputContinuation
onError msg minp = pure (Right (Just msg, M.minpModuleEnv minp), [])


{-
decSatNum :: SatNum -> SatNum
decSatNum (SomeSat n) | n > 0 = SomeSat (n-1)
decSatNum n = n
-}


multiSATQuery ::
multiSATQuery :: forall sym t fm.
sym ~ W4.ExprBuilder t CryptolState fm =>
What4 sym ->
W4ProverConfig ->
Expand Down Expand Up @@ -500,39 +505,122 @@ multiSATQuery sym (W4ProverConfig (AnOnlineAdapter nm fs _opts (_ :: Proxy s)))
(\ (proc :: W4.SolverProcess t s) ->
do W4.assume (W4.solverConn proc) query
res <- W4.checkAndGetModel proc "query"
pres <- case res of
W4.Unknown -> return (Left (ProverError "Solver returned UNKNOWN"))
W4.Unsat _ -> return (Left (ThmResult (map unFinType ts)))
W4.Sat evalFn ->
do xs <- mapM (varShapeToConcrete evalFn) args
let model = computeModel primMap ts xs
blockingPred <- computeBlockingPred sym args xs
return (Right (model, blockingPred))
case pres of
Left x -> pure (Just nm, x)
Right (mdl,block) ->
do W4.assume (W4.solverConn proc) block
mdls <- (mdl:) <$> computeMoreModels proc (decSatNum satNum0)
case res of
W4.Unknown -> return (Just nm, ProverError "Solver returned UNKNOWN")
W4.Unsat _ -> return (Just nm, ThmResult (map unFinType ts))
W4.Sat evalFn ->
do xs <- mapM (varShapeToConcrete evalFn) args
let mdl = computeModel primMap ts xs
-- NB, we flatten these shapes to make sure that we can split
-- our search across all of the atomic variables
let vs = flattenShapes args []
let cs = flattenShapes xs []
mdls <- runMultiSat satNum0 $
do yield mdl
computeMoreModels proc vs cs
return (Just nm, AllSatResult mdls))

where
computeMoreModels _proc (SomeSat n) | n <= 0 = return [] -- should never happen...

computeMoreModels proc satNum =
do res <- W4.checkAndGetModel proc "more models"
-- This search procedure uses incremental solving and push/pop to split on the concrete
-- values of variables, while also helping to prevent the accumulation of unhelpful
-- lemmas in the solver state. This algorithm is basically taken from:
-- http://theory.stanford.edu/%7Enikolaj/programmingz3.html#sec-blocking-evaluations
computeMoreModels ::
W4.SolverProcess t s ->
[VarShape (What4 sym)] ->
[VarShape Concrete.Concrete] ->
MultiSat ()
computeMoreModels proc vs cs =
-- Enumerate all the ways to split up the current model
forM_ (computeSplits vs cs) $ \ (prefix, vi, ci, suffix) ->
do -- open a new solver frame
liftIO $ W4.push proc
-- force the selected pair to be different
liftIO $ W4.assume (W4.solverConn proc) =<< W4.notPred (w4 sym) =<< computeModelPred sym vi ci
-- force the prefix values to be the same
liftIO $ forM_ prefix $ \(v,c) ->
W4.assume (W4.solverConn proc) =<< computeModelPred sym v c
-- under these assumptions, find all the remaining models
findMoreModels proc (vi:suffix)
-- pop the current assumption frame
liftIO $ W4.pop proc

findMoreModels ::
W4.SolverProcess t s ->
[VarShape (What4 sym)] ->
MultiSat ()
findMoreModels proc vs =
-- see if our current assumptions are consistent
do res <- liftIO (W4.checkAndGetModel proc "find model")
case res of
W4.Unknown -> return []
W4.Unsat _ -> return []
-- if the solver gets stuck, drop all the way out and stop search
W4.Unknown -> done
-- if our assumptions are already unsatisfiable, stop search and return
W4.Unsat _ -> return ()
W4.Sat evalFn ->
do xs <- mapM (varShapeToConcrete evalFn) args
let model = computeModel primMap ts xs
case satNum of
-- final model
SomeSat n | n <= 1 -> return [model]
-- keep going
_ -> do blockingPred <- computeBlockingPred sym args xs
W4.assume (W4.solverConn proc) blockingPred
(model:) <$> computeMoreModels proc (decSatNum satNum)
-- We found a model. Record it and then use it to split the remaining
-- search variables some more.
do xs <- liftIO (mapM (varShapeToConcrete evalFn) args)
yield (computeModel primMap ts xs)
cs <- liftIO (mapM (varShapeToConcrete evalFn) vs)
computeMoreModels proc vs cs

-- == Support operations for multi-SAT ==
type Models = [[(TValue, Expr, Concrete.Value)]]

newtype MultiSat a =
MultiSat { unMultiSat :: Models -> SatNum -> (a -> Models -> SatNum -> IO Models) -> IO Models }

instance Functor MultiSat where
fmap f m = MultiSat (\ms satNum k -> unMultiSat m ms satNum (k . f))

instance Applicative MultiSat where
pure x = MultiSat (\ms satNum k -> k x ms satNum)
mf <*> mx = mf >>= \f -> fmap f mx

instance Monad MultiSat where
m >>= f = MultiSat (\ms satNum k -> unMultiSat m ms satNum (\x ms' satNum' -> unMultiSat (f x) ms' satNum' k))

instance MonadIO MultiSat where
liftIO m = MultiSat (\ms satNum k -> do x <- m; k x ms satNum)

runMultiSat :: SatNum -> MultiSat a -> IO Models
runMultiSat satNum m = reverse <$> unMultiSat m [] satNum (\_ ms _ -> return ms)

done :: MultiSat a
done = MultiSat (\ms _satNum _k -> return ms)

yield :: [(TValue, Expr, Concrete.Value)] -> MultiSat ()
yield mdl = MultiSat (\ms satNum k ->
case satNum of
SomeSat n
| n > 1 -> k () (mdl:ms) (SomeSat (n-1))
| otherwise -> return (mdl:ms)
_ -> k () (mdl:ms) satNum)

-- Compute all the ways to split a sequences of variables
-- and concrete values for those variables. Each element
-- of the list consists of a prefix of (varaible,value)
-- pairs whose values we will fix, a single (varaible,value)
-- pair that we will force to be different, and a list of
-- additional unconstrained variables.
computeSplits ::
[VarShape (What4 sym)] ->
[VarShape Concrete.Concrete] ->
[ ( [(VarShape (What4 sym), VarShape Concrete.Concrete)]
, VarShape (What4 sym)
, VarShape Concrete.Concrete
, [VarShape (What4 sym)]
)
]
computeSplits vs cs = reverse
[ (prefix, v, c, tl)
| prefix <- inits (zip vs cs)
| v <- vs
| c <- cs
| tl <- tail (tails vs)
]
-- == END Support operations for multi-SAT ==

singleQuery ::
sym ~ W4.ExprBuilder t CryptolState fm =>
Expand Down Expand Up @@ -610,15 +698,14 @@ singleQuery sym (W4ProverConfig (AnOnlineAdapter nm fs _opts (_ :: Proxy s)))
)


computeBlockingPred ::
computeModelPred ::
sym ~ W4.ExprBuilder t CryptolState fm =>
What4 sym ->
[VarShape (What4 sym)] ->
[VarShape Concrete.Concrete] ->
VarShape (What4 sym) ->
VarShape Concrete.Concrete ->
IO (W4.Pred sym)
computeBlockingPred sym vs xs =
do res <- doW4Eval (w4 sym) (modelPred sym vs xs)
W4.notPred (w4 sym) (snd res)
computeModelPred sym v c =
snd <$> doW4Eval (w4 sym) (varModelPred sym (v, c))

varShapeToConcrete ::
W4.GroundEvalFn t ->
Expand Down

0 comments on commit 62d413e

Please sign in to comment.