Skip to content

Commit

Permalink
Backtrack
Browse files Browse the repository at this point in the history
Fixes #23
  • Loading branch information
christiaanb committed Feb 3, 2020
1 parent 99086c2 commit 65ac021
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 95 deletions.
5 changes: 3 additions & 2 deletions ghc-typelits-natnormalise.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ library
build-depends: base >=4.9 && <5,
containers >=0.5.7.1 && <0.7,
ghc >=8.0.1 && <8.11,
ghc-tcplugins-extra >=0.3.1,
ghc-tcplugins-extra >=0.3.2,
integer-gmp >=1.0 && <1.1,
syb >=0.7.1 && <0.8,
transformers >=0.5.2.0 && < 0.6
Expand All @@ -86,7 +86,8 @@ library
test-suite test-ghc-typelits-natnormalise
type: exitcode-stdio-1.0
main-is: Tests.hs
Other-Modules: ErrorTests
Other-Modules: AmbTests
ErrorTests
build-depends: base >=4.8 && <5,
ghc-typelits-natnormalise,
tasty >= 0.10,
Expand Down
290 changes: 197 additions & 93 deletions src/GHC/TypeLits/Normalise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ import Control.Monad ((<=<), forM)
import Control.Monad (replicateM)
#endif
import Control.Monad.Trans.Writer.Strict
import Data.Either (rights)
import Data.List (intersect, stripPrefix, find)
import Data.Either (lefts, rights)
import Data.List (intersect, stripPrefix, find, nub)
import Data.Maybe (mapMaybe, catMaybes)
import Data.Set (Set, empty, toList, notMember, fromList, union)
import Data.Set (Set, empty, toList, notMember, fromList, union, insert)
import GHC.TcPluginM.Extra (tracePlugin, newGiven, newWanted)
import qualified GHC.TcPluginM.Extra as TcPluginM
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
import GHC.TcPluginM.Extra (flattenGivens, substCt)
#endif
import Text.Read (readMaybe)

Expand Down Expand Up @@ -207,19 +207,20 @@ import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,

import TcTypeNats (typeNatLeqTyCon)
import TysWiredIn (promotedFalseDataCon, promotedTrueDataCon)
import UniqSet (elementOfUniqSet)
import Data.IORef

#if MIN_VERSION_ghc(8,10,0)
import Constraint
(Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
(Ct (..), CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan)
import Predicate
(EqRel (NomEq), Pred (EqPred), classifyPredType, getEqPredTys, mkClassPred,
mkPrimEqPred)
import Type (typeKind)
#else
import TcRnTypes
(Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
(Ct (..), CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred,
ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan)
import TcType (typeKind)
import Type
Expand Down Expand Up @@ -248,6 +249,7 @@ import TcType (isEqPrimPred)
#endif

-- internal
import GHC.TypeLits.Normalise.SOP
import GHC.TypeLits.Normalise.Unify

#if !MIN_VERSION_ghc(8,10,0)
Expand Down Expand Up @@ -331,10 +333,7 @@ decideEqualSOP opts gen'd givens _deriveds wanteds = do
#if MIN_VERSION_ghc(8,4,0)
let simplGivens = givens ++ flattenGivens givens
subst = fst $ unzip $ TcPluginM.mkSubst' givens
wanteds0 = map (\ct -> (OrigCt ct,
TcPluginM.substCt subst ct
)
) wanteds
wanteds0 = map (\ct -> (OrigCt ct,substCt subst ct)) wanteds
#else
let wanteds0 = map (\ct -> (OrigCt ct, ct)) wanteds
simplGivens <- mapM zonkCt givens
Expand Down Expand Up @@ -367,7 +366,7 @@ decideEqualSOP opts gen'd givens _deriveds wanteds = do
tcPluginIO $
modifyIORef' gen'd $ union (fromList newlyDone)
let unit_givens = mapMaybe toNatEquality simplGivens
sr <- simplifyNats opts unit_givens unit_wanteds
sr <- solveWanted opts unit_givens unit_wanteds
tcPluginTrace "normalised" (ppr sr)
reds <- forM reducible_wanteds $ \(origCt,(term, ws)) -> do
wants <- fmap fst $ evSubtPreds origCt $ subToPred opts ws
Expand Down Expand Up @@ -451,88 +450,193 @@ instance Outputable SimplifyResult where
ppr (Simplified evs) = text "Simplified" $$ ppr evs
ppr (Impossible eq) = text "Impossible" <+> ppr eq

simplifyNats
:: Opts
-- ^ Allow negated numbers (potentially unsound!)
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Given constraints
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Wanted constraints
-> TcPluginM SimplifyResult
simplifyNats opts@Opts {..} eqsG eqsW =
let eqs = map (second (const [])) eqsG ++ eqsW
in tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] [] eqs
where
simples :: [CoreUnify]
-> [((EvTerm, Ct), [Ct])]
-> [(CoreSOP,CoreSOP,Bool)]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> TcPluginM SimplifyResult
simples _subst evs _leqsG _xs [] = return (Simplified evs)
simples subst evs leqsG xs (eq@(Left (ct,u,v),k):eqs') = do
let u' = substsSOP subst u
v' = substsSOP subst v
ur <- unifyNats ct u' v'
tcPluginTrace "unifyNats result" (ppr ur)
case ur of
Win -> do
evs' <- maybe evs (:evs) <$> evMagic ct empty (subToPred opts k)
simples subst evs' leqsG [] (xs ++ eqs')
Lose -> if null evs && null eqs'
then return (Impossible (fst eq))
else simples subst evs leqsG xs eqs'
Draw [] -> simples subst evs [] (eq:xs) eqs'
Draw subst' -> do
evM <- evMagic ct empty (map unifyItemToPredType subst' ++
subToPred opts k)
let leqsG' | isGiven (ctEvidence ct) = eqToLeq u' v' ++ leqsG
| otherwise = leqsG
case evM of
Nothing -> simples subst evs leqsG' xs eqs'
Just ev ->
simples (substsSubst subst' subst ++ subst')
(ev:evs) leqsG' [] (xs ++ eqs')
simples subst evs leqsG xs (eq@(Right (ct,u@(x,y,b)),k):eqs') = do
let u' = substsSOP subst (subtractIneq u)
x' = substsSOP subst x
y' = substsSOP subst y
uS = (x',y',b)
leqsG' | isGiven (ctEvidence ct) = (x',y',b):leqsG
| otherwise = leqsG
ineqs = concat [ leqsG
, map (substLeq subst) leqsG
, map snd (rights (map fst eqsG))
]
tcPluginTrace "unifyNats(ineq) results" (ppr (ct,u,u',ineqs))
case runWriterT (isNatural u') of
Just (True,knW) -> do
evs' <- maybe evs (:evs) <$> evMagic ct knW (subToPred opts k)
simples subst evs' leqsG' xs eqs'

Just (False,_) | null k -> return (Impossible (fst eq))
_ -> do
let solvedIneq = mapMaybe runWriterT
-- it is an inequality that can be instantly solved, such as
-- `1 <= x^y`
-- OR
(instantSolveIneq depth u:
-- This inequality is either a given constraint, or it is a wanted
-- constraint, which in normal form is equal to another given
-- constraint, hence it can be solved.
-- OR
map (solveIneq depth u) ineqs ++
-- The above, but with valid substitutions applied to the wanted.
map (solveIneq depth uS) ineqs)
smallest = solvedInEqSmallestConstraint solvedIneq
case smallest of
(True,kW) -> do
evs' <- maybe evs (:evs) <$> evMagic ct kW (subToPred opts k)
simples subst evs' leqsG' xs eqs'
_ -> simples subst evs leqsG (eq:xs) eqs'

eqToLeq x y = [(x,y,True),(y,x,True)]
substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b)
findGivenSubst ::
[(Either NatEquality NatInEquality, [(Type, Type)])] ->
TcPluginM [CoreUnify]
findGivenSubst = go [] []
where
go ::
[CoreUnify] ->
[(Either NatEquality NatInEquality, [(Type, Type)])] ->
[(Either NatEquality NatInEquality, [(Type, Type)])] ->
TcPluginM [CoreUnify]
go subst _ [] = pure subst
go subst tryAgain (eq@(Left (ct,u,v),_k):eqs) = do
let uvs = rewriteEQ False subst (u,v)
urs <- mapM (uncurry (unifyNats ct)) uvs
case foldr smashDraw Lose urs of
Draw qs@(_:_) -> go (qs ++ subst) [] (tryAgain ++ eqs)
_ -> go subst (eq:tryAgain) eqs
go subst tryAgain (_:eqs) = go subst tryAgain eqs

smashDraw (Draw s1) (Draw s2) = Draw (nub (s1 ++ s2))
smashDraw (Draw s) _ = Draw s
smashDraw _ (Draw s) = Draw s
smashDraw _ r = r

-- | Witness that equality is a symmetric relation
newtype SymSOP = SymSOP (CoreSOP, CoreSOP)

instance Eq SymSOP where
(SymSOP (p,q)) == (SymSOP (x,y)) =
case p == x of
False -> p == y && q == x
True -> q == y

instance Ord SymSOP where
compare (SymSOP (p,q)) (SymSOP (x,y)) =
case compare p x of
LT -> case compare p y of
EQ -> case compare q x of
EQ -> EQ
_ -> LT
_ -> LT
EQ -> compare q y
GT -> GT

rewriteEQ ::
Bool ->
[CoreUnify] ->
(CoreSOP, CoreSOP) ->
[(CoreSOP, CoreSOP)]
rewriteEQ isW subst = go empty
where
go hist (u,v) =
let
hist1 = insert (SymSOP (u,v)) hist
rewrites = mapMaybe (rewrite (u,v)) subst
new = nub (filter ((`notMember` hist1) . SymSOP) rewrites)
in (u,v):concatMap (go hist1) new

rewrite _ UnifyItem {} = Nothing
rewrite (p,q) (SubstItem x e) =
let p1 = substSOP x e p
q1 = substSOP x e q
in if isW || p1 /= q1 then
Just (substSOP x e p, substSOP x e q)
else
Nothing

rewriteINEQ ::
[CoreUnify] ->
(CoreSOP,CoreSOP,Bool) ->
[(CoreSOP,CoreSOP,Bool)]
rewriteINEQ subst = go empty
where
go hist (u,v,b) =
let hist1 = insert (u,v,b) hist
rewrites = mapMaybe (rewrite (u,v,b)) subst
new = nub (filter (`notMember` hist1) rewrites)
in (u,v,b):concatMap (go hist1) new

rewrite _ UnifyItem {} = Nothing
rewrite (p,q,z) (SubstItem x e) =
let p1 = substSOP x e p
q1 = substSOP x e q
in if p1 /= q1 then
Just (substSOP x e p, substSOP x e q, z)
else
Nothing

solveWanted ::
-- | Solver options (depth, whether to allow integers, etc.)
Opts ->
-- | Given constraints
[(Either NatEquality NatInEquality,[(Type,Type)])] ->
-- | Wanted constraints
[(Either NatEquality NatInEquality,[(Type,Type)])] ->
TcPluginM SimplifyResult
solveWanted opts@Opts {..} eqsG eqsW = do
subst <- nub . concatMap mkSubst <$> findGivenSubst eqsG
tcPluginTrace "solveWanted eqsG:" (ppr eqsG)
tcPluginTrace "solveWanted subst:" (ppr (subst))
go subst [] eqsW
where
go _ evs [] = pure (Simplified evs)
go subst evs ((Left (ct,u,v),kW):eqs) = do
(subst1,s) <- goSingleWantedEq subst (ct,u,v,kW)
case s of
Simplified ev -> go subst1 (ev ++ evs) eqs
i -> if null evs && null eqs then
pure i
else
go subst1 evs eqs
go subst evs ((Right (ct,(u,v,b)),kW):eqs) = do
s <- goSingleWantedInEq subst (ct,u,v,b,kW)
case s of
Simplified ev -> go subst (ev ++ evs) eqs
i -> if null evs && null eqs then
pure i
else
go subst evs eqs

goSingleWantedEq
:: [CoreUnify]
-> (Ct,CoreSOP,CoreSOP,[(Type,Type)])
-> TcPluginM ([CoreUnify],SimplifyResult)
goSingleWantedEq subst (ct,u,v,kW) = do
tcPluginTrace "goSingleWanted" (ppr (ct,u,v,subst))
let uvs = rewriteEQ True subst (u,v)
urs <- mapM (uncurry (unifyNats ct)) uvs
case foldr smashResult (Draw []) urs of
Win -> maybe (subst,Simplified []) ((subst,) . Simplified . (:[])) <$>
evMagic ct empty (subToPred opts kW)
Lose -> pure (subst,Impossible (Left (ct,u,v)))
Draw [] -> pure (subst,Simplified [])
Draw s ->
let subst1 = nub (concatMap mkSubst s ++ subst)
in maybe (subst,Simplified []) ((subst1,) . Simplified . (:[])) <$>
evMagic ct empty (map unifyItemToPredType s ++ subToPred opts kW)

mkSubst UnifyItem {} = []
mkSubst (SubstItem x e)
| S [P [V y]] <- e
= if x /= y then [SubstItem x e, SubstItem y (S [P [(V x)]])] else []
| otherwise
= [SubstItem x e]

smashResult Lose _ = Lose
smashResult _ Lose = Lose
smashResult Win _ = Win
smashResult _ Win = Win
smashResult (Draw s1) (Draw s2) = Draw (filter simpleSubst (s1 ++ s2))

simpleSubst UnifyItem {} = True
simpleSubst (SubstItem v s) = not (v `elementOfUniqSet` fvSOP s)

goSingleWantedInEq
:: [CoreUnify]
-> (Ct,CoreSOP,CoreSOP,Bool,[(Type,Type)])
-> TcPluginM SimplifyResult
goSingleWantedInEq subst (ct,u,v,b,kW) = do
let ineqs = map snd (rights (map fst eqsG))
eqs = map (\(_,x,y) -> (x,y)) (lefts (map fst eqsG))
ineqs1 = nub (concatMap (rewriteINEQ subst) ineqs)
ineqs2 = nub (concatMap (concatMap eqToLeq . rewriteEQ False subst) eqs)
ineqs3 = ineqs1 ++ ineqs2
z = subtractIneq (u,v,b)
uvbs = rewriteINEQ subst (u,v,b)
solved = mapMaybe runWriterT
-- This inequality is either a given constraint, or it is a
-- wanted constraint, which in normal form is equal to
-- another given constraint, hence it can be solved.
(concatMap (\uvb -> map (solveIneq depth uvb) ineqs3) uvbs ++
-- Or it is an inequality that can be instantly solved, such
-- as `1 <= x^y`
map (instantSolveIneq depth) uvbs)
smallest = solvedInEqSmallestConstraint solved
tcPluginTrace "goSingleWantedInEq" (ppr (ct,uvbs,ineqs3))
case runWriterT (isNatural z) of
Just (True,knW) ->
maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct knW (subToPred opts kW)
Just (False,_) | null kW -> pure (Impossible (Right (ct,(u,v,b))))
_ -> case smallest of
(True,knW) -> maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct knW (subToPred opts kW)
_ -> pure (Simplified [])

eqToLeq (x,y) = [(x,y,True),(y,x,True)]

-- If we allow negated numbers we simply do not emit the inequalities
-- derived from the subtractions that are converted to additions with a
Expand Down
1 change: 1 addition & 0 deletions src/GHC/TypeLits/Normalise/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module GHC.TypeLits.Normalise.Unify
-- * Substitution on 'SOP' terms
, UnifyItem (..)
, CoreUnify
, substSOP
, substsSOP
, substsSubst
-- * Find unifiers
Expand Down
Loading

0 comments on commit 65ac021

Please sign in to comment.