From d764eca3da3c6f9f52d757fcf971b561b2b28921 Mon Sep 17 00:00:00 2001 From: Masahiro Sakai Date: Mon, 2 Dec 2024 23:32:51 +0900 Subject: [PATCH] use SAT.VarMap to represent definitions in SAT-related encoders/converters --- app/toysat/toysat.hs | 6 +++--- src/ToySolver/Converter/PB.hs | 21 ++++++++------------- src/ToySolver/Converter/SAT2KSAT.hs | 7 ++++--- src/ToySolver/Converter/SAT2MaxSAT.hs | 4 ++-- src/ToySolver/Converter/Tseitin.hs | 17 ++++++----------- src/ToySolver/SAT/Encoder/Tseitin.hs | 23 ++++++++++++++--------- test/Test/SAT/Encoder.hs | 8 ++++---- 7 files changed, 41 insertions(+), 45 deletions(-) diff --git a/app/toysat/toysat.hs b/app/toysat/toysat.hs index 49104806..9b090b14 100644 --- a/app/toysat/toysat.hs +++ b/app/toysat/toysat.hs @@ -853,7 +853,7 @@ solvePB opt solver formula = do where -- Use BOXED array to tie the knot a :: Array SAT.Var Bool - a = array (1,nv') $ assocs m ++ [(v, Tseitin.evalFormula a phi) | (v,phi) <- defs] + a = array (1,nv') $ assocs m ++ [(v, Tseitin.evalFormula a phi) | (v,phi) <- IntMap.toList defs] pbo <- PBO.newOptimizer2 solver obj'' (\m -> SAT.evalPBSum m obj') setupOptimizer pbo opt @@ -960,8 +960,8 @@ solveWBO' opt solver isMaxSat formula (wcnf, wbo2maxsat_info) wcnfFileName = do a :: Array SAT.Var Bool a = array (1,nv') $ assocs m ++ - [(v, Tseitin.evalFormula a phi) | (v, phi) <- defsTseitin] ++ - [(v, SAT.evalPBConstraint a constr) | (v, constr) <- defsPB] + [(v, Tseitin.evalFormula a phi) | (v, phi) <- IntMap.toList defsTseitin] ++ + [(v, SAT.evalPBConstraint a constr) | (v, constr) <- IntMap.toList defsPB] let softConstrs = [(c, constr) | (Just c, constr) <- PBFile.wboConstraints formula] diff --git a/src/ToySolver/Converter/PB.hs b/src/ToySolver/Converter/PB.hs index 3d2e3077..228d4718 100644 --- a/src/ToySolver/Converter/PB.hs +++ b/src/ToySolver/Converter/PB.hs @@ -216,7 +216,7 @@ quadratizePB' (formula, maxObj) = } , maxObj ) - , PBQuadratizeInfo $ TseitinInfo nv1 nv2 [(v, And [atom l1, atom l2]) | (v, (l1,l2)) <- prodDefs] + , PBQuadratizeInfo $ TseitinInfo nv1 nv2 (IntMap.fromList [(v, And [atom l1, atom l2]) | (v, (l1,l2)) <- prodDefs]) ) where nv1 = PBFile.pbNumVars formula @@ -556,13 +556,8 @@ wbo2pb wbo = runST $ do , WBO2PBInfo nv (PBFile.pbNumVars formula) defs ) -data WBO2PBInfo = WBO2PBInfo !Int !Int [(SAT.Var, PBFile.Constraint)] - deriving (Show) - --- TODO: change defs representation to SAT.VarMap -instance Eq WBO2PBInfo where - WBO2PBInfo nv1 nv2 defs == WBO2PBInfo nv1' nv2' defs' = - nv1 == nv1' && nv2 == nv2' && sortOn fst defs == sortOn fst defs' +data WBO2PBInfo = WBO2PBInfo !Int !Int (SAT.VarMap PBFile.Constraint) + deriving (Show, Eq) instance Transformer WBO2PBInfo where type Source WBO2PBInfo = SAT.Model @@ -570,7 +565,7 @@ instance Transformer WBO2PBInfo where instance ForwardTransformer WBO2PBInfo where transformForward (WBO2PBInfo _nv1 nv2 defs) m = - array (1, nv2) $ assocs m ++ [(v, SAT.evalPBConstraint m constr) | (v, constr) <- defs] + array (1, nv2) $ assocs m ++ [(v, SAT.evalPBConstraint m constr) | (v, constr) <- IntMap.toList defs] instance BackwardTransformer WBO2PBInfo where transformBackward (WBO2PBInfo nv1 _nv2 _defs) = SAT.restrictModel nv1 @@ -583,7 +578,7 @@ instance J.ToJSON WBO2PBInfo where , "num_transformed_variables" .= nv2 , "definitions" .= J.object [ fromString ("x" ++ show v) .= jPBConstraint constr - | (v, constr) <- defs + | (v, constr) <- IntMap.toList defs ] ] @@ -593,14 +588,14 @@ instance J.FromJSON WBO2PBInfo where WBO2PBInfo <$> obj .: "num_original_variables" <*> obj .: "num_transformed_variables" - <*> mapM f (Map.toList defs) + <*> (IntMap.fromList <$> mapM f (Map.toList defs)) where f (name, constr) = do v <- parseVarNameText name constr' <- parsePBConstraint constr return (v, constr') -addWBO :: (PrimMonad m, SAT.AddPBNL m enc) => enc -> PBFile.SoftFormula -> m (SAT.PBSum, [(SAT.Var, PBFile.Constraint)]) +addWBO :: (PrimMonad m, SAT.AddPBNL m enc) => enc -> PBFile.SoftFormula -> m (SAT.PBSum, (SAT.VarMap PBFile.Constraint)) addWBO db wbo = do SAT.newVars_ db $ PBFile.wboNumVars wbo @@ -671,7 +666,7 @@ addWBO db wbo = do modifyMutVar objRef ((offset,[trueLit]) :) obj <- liftM reverse $ readMutVar objRef - defs <- liftM reverse $ readMutVar defsRef + defs <- liftM IntMap.fromList $ readMutVar defsRef case PBFile.wboTopCost wbo of Nothing -> return () diff --git a/src/ToySolver/Converter/SAT2KSAT.hs b/src/ToySolver/Converter/SAT2KSAT.hs index 69b23aa0..12c51af5 100644 --- a/src/ToySolver/Converter/SAT2KSAT.hs +++ b/src/ToySolver/Converter/SAT2KSAT.hs @@ -22,6 +22,7 @@ module ToySolver.Converter.SAT2KSAT import Control.Monad import Control.Monad.ST import Data.Foldable (toList) +import qualified Data.IntMap.Lazy as IntMap import Data.Sequence ((<|), (|>)) import qualified Data.Sequence as Seq import Data.STRef @@ -38,7 +39,7 @@ sat2ksat k _ | k < 3 = error "ToySolver.Converter.SAT2KSAT.sat2ksat: k must be > sat2ksat k cnf = runST $ do let nv1 = CNF.cnfNumVars cnf db <- newCNFStore - defsRef <- newSTRef Seq.empty + defsRef <- newSTRef IntMap.empty SAT.newVars_ db nv1 forM_ (CNF.cnfClauses cnf) $ \clause -> do let loop lits = do @@ -49,12 +50,12 @@ sat2ksat k cnf = runST $ do case Seq.splitAt (k-1) lits of (lits1, lits2) -> do SAT.addClause db (toList (lits1 |> (-v))) - modifySTRef' defsRef (|> (v, toList lits1)) + modifySTRef' defsRef (IntMap.insert v (toList lits1)) loop (v <| lits2) loop $ Seq.fromList $ SAT.unpackClause clause cnf2 <- getCNFFormula db defs <- readSTRef defsRef - return (cnf2, TseitinInfo nv1 (CNF.cnfNumVars cnf2) [(v, Or [atom lit | lit <- clause]) | (v, clause) <- toList defs]) + return (cnf2, TseitinInfo nv1 (CNF.cnfNumVars cnf2) (fmap (\clause -> Or [atom lit | lit <- clause]) defs)) where atom l | l < 0 = Not (Atom (- l)) diff --git a/src/ToySolver/Converter/SAT2MaxSAT.hs b/src/ToySolver/Converter/SAT2MaxSAT.hs index 300fda2d..0f077eaf 100644 --- a/src/ToySolver/Converter/SAT2MaxSAT.hs +++ b/src/ToySolver/Converter/SAT2MaxSAT.hs @@ -94,7 +94,7 @@ sat3ToMaxSAT2 cnf = } , t ) - , TseitinInfo (CNF.cnfNumVars cnf) nv + , TseitinInfo (CNF.cnfNumVars cnf) nv $ IntMap.fromList [ (d, SAT.And [atom a, atom b, atom c]) -- we define d as "a && b && c", but "a + b + c >= 2" is also fine. | (d, (a,b,c)) <- ds @@ -142,7 +142,7 @@ simplifyMaxSAT2 (wcnf, threshold) = case foldl' f (nv1, Set.empty, IntMap.empty, threshold) (CNF.wcnfClauses wcnf) of (nv2, cs, defs, threshold2) -> ( (nv2, cs, threshold2) - , TseitinInfo nv1 nv2 [(v, atom (- a)) | (v, (a, _b)) <- IntMap.toList defs] + , TseitinInfo nv1 nv2 (fmap (\(a, _b) -> atom (- a)) defs) -- we deine v as "~a" but "~b" is also fine. ) where diff --git a/src/ToySolver/Converter/Tseitin.hs b/src/ToySolver/Converter/Tseitin.hs index 58af9fef..f0812d90 100644 --- a/src/ToySolver/Converter/Tseitin.hs +++ b/src/ToySolver/Converter/Tseitin.hs @@ -21,7 +21,7 @@ import qualified Data.Aeson as J import qualified Data.Aeson.Types as J import Data.Aeson ((.=), (.:)) import Data.Array.IArray -import Data.List (sortOn) +import qualified Data.IntMap.Strict as IntMap import qualified Data.Map.Lazy as Map import Data.String import qualified Data.Text as T @@ -32,13 +32,8 @@ import ToySolver.SAT.Internal.JSON import qualified ToySolver.SAT.Types as SAT import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin -data TseitinInfo = TseitinInfo !Int !Int [(SAT.Var, Tseitin.Formula)] - deriving (Show, Read) - --- TODO: change defs representation to SAT.VarMap -instance Eq TseitinInfo where - TseitinInfo nv1 nv2 defs == TseitinInfo nv1' nv2' defs' = - nv1 == nv1' && nv2 == nv2' && sortOn fst defs == sortOn fst defs' +data TseitinInfo = TseitinInfo !Int !Int (SAT.VarMap Tseitin.Formula) + deriving (Show, Read, Eq) instance Transformer TseitinInfo where type Source TseitinInfo = SAT.Model @@ -50,7 +45,7 @@ instance ForwardTransformer TseitinInfo where -- Use BOXED array to tie the knot a :: Array SAT.Var Bool a = array (1, nv2) $ - assocs m ++ [(v, Tseitin.evalFormula a phi) | (v, phi) <- defs] + assocs m ++ [(v, Tseitin.evalFormula a phi) | (v, phi) <- IntMap.toList defs] instance BackwardTransformer TseitinInfo where transformBackward (TseitinInfo nv1 _nv2 _defs) = SAT.restrictModel nv1 @@ -63,7 +58,7 @@ instance J.ToJSON TseitinInfo where , "num_transformed_variables" .= nv2 , "definitions" .= J.object [ fromString ("x" ++ show v) .= formula - | (v, formula) <- defs + | (v, formula) <- IntMap.toList defs ] ] @@ -73,7 +68,7 @@ instance J.FromJSON TseitinInfo where TseitinInfo <$> obj .: "num_original_variables" <*> obj .: "num_transformed_variables" - <*> mapM f (Map.toList defs) + <*> (IntMap.fromList <$> mapM f (Map.toList defs)) where f :: (T.Text, SAT.Formula) -> J.Parser (SAT.Var, SAT.Formula) f (name, formula) = do diff --git a/src/ToySolver/SAT/Encoder/Tseitin.hs b/src/ToySolver/SAT/Encoder/Tseitin.hs index ea25ced2..58435548 100644 --- a/src/ToySolver/SAT/Encoder/Tseitin.hs +++ b/src/ToySolver/SAT/Encoder/Tseitin.hs @@ -89,6 +89,7 @@ module ToySolver.SAT.Encoder.Tseitin import Control.Monad import Control.Monad.Primitive import Data.Primitive.MutVar +import qualified Data.IntMap.Lazy as IntMap import Data.Map (Map) import qualified Data.Map as Map import qualified Data.IntSet as IntSet @@ -429,7 +430,7 @@ encodeFACarryWithPolarity encoder polarity a b c = do encodeWithPolarityHelper encoder (encFACarryTable encoder) definePos defineNeg polarity (a,b,c) -getDefinitions :: PrimMonad m => Encoder m -> m [(SAT.Var, Formula)] +getDefinitions :: PrimMonad m => Encoder m -> m (SAT.VarMap Formula) getDefinitions encoder = do tableConj <- readMutVar (encConjTable encoder) tableITE <- readMutVar (encITETable encoder) @@ -439,14 +440,18 @@ getDefinitions encoder = do let atom l | l < 0 = Not (Atom (- l)) | otherwise = Atom l - m1 = [(v, andB [atom l1 | l1 <- IntSet.toList ls]) | (ls, (v, _, _)) <- Map.toList tableConj] - m2 = [(v, ite (atom c) (atom t) (atom e)) | ((c,t,e), (v, _, _)) <- Map.toList tableITE] - m3 = [(v, (atom a .||. atom b) .&&. (atom (-a) .||. atom (-b))) | ((a,b), (v, _, _)) <- Map.toList tableXOR] - m4 = [(v, orB [andB [atom l | l <- ls] | ls <- [[a,b,c],[a,-b,-c],[-a,b,-c],[-a,-b,c]]]) - | ((a,b,c), (v, _, _)) <- Map.toList tableFASum] - m5 = [(v, orB [andB [atom l | l <- ls] | ls <- [[a,b],[a,c],[b,c]]]) - | ((a,b,c), (v, _, _)) <- Map.toList tableFACarry] - return $ concat [m1, m2, m3, m4, m5] + m1 = IntMap.fromList [(v, andB [atom l1 | l1 <- IntSet.toList ls]) | (ls, (v, _, _)) <- Map.toList tableConj] + m2 = IntMap.fromList [(v, ite (atom c) (atom t) (atom e)) | ((c,t,e), (v, _, _)) <- Map.toList tableITE] + m3 = IntMap.fromList [(v, (atom a .||. atom b) .&&. (atom (-a) .||. atom (-b))) | ((a,b), (v, _, _)) <- Map.toList tableXOR] + m4 = IntMap.fromList + [ (v, orB [andB [atom l | l <- ls] | ls <- [[a,b,c],[a,-b,-c],[-a,b,-c],[-a,-b,c]]]) + | ((a,b,c), (v, _, _)) <- Map.toList tableFASum + ] + m5 = IntMap.fromList + [ (v, orB [andB [atom l | l <- ls] | ls <- [[a,b],[a,c],[b,c]]]) + | ((a,b,c), (v, _, _)) <- Map.toList tableFACarry + ] + return $ IntMap.unions [m1, m2, m3, m4, m5] data Polarity diff --git a/test/Test/SAT/Encoder.hs b/test/Test/SAT/Encoder.hs index 7bcdf286..9c6e6b28 100644 --- a/test/Test/SAT/Encoder.hs +++ b/test/Test/SAT/Encoder.hs @@ -161,7 +161,7 @@ prop_PBEncoder_addPBAtLeast = QM.monadicIO $ do return (cnf, defs) forM_ (allAssignments 4) $ \m -> do let m2 :: Array SAT.Var Bool - m2 = array (1, CNF.cnfNumVars cnf) $ assocs m ++ [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- defs] + m2 = array (1, CNF.cnfNumVars cnf) $ assocs m ++ [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- IntMap.toList defs] b1 = SAT.evalPBLinAtLeast m (lhs,rhs) b2 = evalCNF (array (bounds m2) (assocs m2)) cnf QM.assert $ b1 == b2 @@ -189,7 +189,7 @@ prop_PBEncoder_encodePBLinAtLeastWithPolarity = QM.monadicIO $ do let m2 :: Array SAT.Var Bool m2 = array (1, CNF.cnfNumVars cnf) $ assocs m ++ - [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- defs] ++ + [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- IntMap.toList defs] ++ Cardinality.evalTotalizerDefinitions m2 defs2 b1 = evalCNF (array (bounds m2) (assocs m2)) cnf cmp a b = isJust $ do @@ -269,7 +269,7 @@ prop_CardinalityEncoder_addAtLeast = QM.monadicIO $ do let m2 :: Array SAT.Var Bool m2 = array (1, CNF.cnfNumVars cnf) $ assocs m ++ - [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- defs] ++ + [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- IntMap.toList defs] ++ Cardinality.evalTotalizerDefinitions m2 defs2 b1 = SAT.evalAtLeast m (lhs,rhs) b2 = evalCNF (array (bounds m2) (assocs m2)) cnf @@ -375,7 +375,7 @@ prop_encodeAtLeastWithPolarity = QM.monadicIO $ do let m2 :: Array SAT.Var Bool m2 = array (1, CNF.cnfNumVars cnf) $ assocs m ++ - [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- defs] ++ + [(v, Tseitin.evalFormula m2 phi) | (v,phi) <- IntMap.toList defs] ++ Cardinality.evalTotalizerDefinitions m2 defs2 b1 = evalCNF (array (bounds m2) (assocs m2)) cnf cmp a b = isJust $ do