From 8577581a4a92a6493763a0afa78d0c2226edf375 Mon Sep 17 00:00:00 2001 From: Christiaan Baaij Date: Wed, 5 Feb 2020 12:43:08 +0100 Subject: [PATCH] Backtrack Fixes #23 --- src/GHC/TypeLits/Normalise.hs | 78 +++++++++++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 7 deletions(-) diff --git a/src/GHC/TypeLits/Normalise.hs b/src/GHC/TypeLits/Normalise.hs index 958f57f..7652b9e 100644 --- a/src/GHC/TypeLits/Normalise.hs +++ b/src/GHC/TypeLits/Normalise.hs @@ -163,8 +163,8 @@ import Control.Monad ((<=<), forM) import Control.Monad (replicateM) #endif import Control.Monad.Trans.Writer.Strict -import Data.Either (rights) -import Data.List (intersect, stripPrefix, find) +import Data.Either (partitionEithers, rights) +import Data.List (intersect, partition, stripPrefix, find) import Data.Maybe (mapMaybe, catMaybes) import Data.Set (Set, empty, toList, notMember, fromList, union) import GHC.TcPluginM.Extra (tracePlugin, newGiven, newWanted) @@ -212,7 +212,8 @@ import Data.IORef #if MIN_VERSION_ghc(8,10,0) import Constraint (Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, - ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan) + ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan, + isWantedCt) import Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType, getEqPredTys, mkClassPred, mkPrimEqPred) @@ -220,7 +221,8 @@ import Type (typeKind) #else import TcRnTypes (Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, - ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan) + ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan, + isWantedCt) import TcType (typeKind) import Type (EqRel (NomEq), PredTree (EqPred), classifyPredType, getEqPredTys, mkClassPred, @@ -248,6 +250,7 @@ import TcType (isEqPrimPred) #endif -- internal +import GHC.TypeLits.Normalise.SOP import GHC.TypeLits.Normalise.Unify #if !MIN_VERSION_ghc(8,10,0) @@ -459,9 +462,21 @@ simplifyNats -> [(Either NatEquality NatInEquality,[(Type,Type)])] -- ^ Wanted constraints -> TcPluginM SimplifyResult -simplifyNats opts@Opts {..} eqsG eqsW = - let eqs = map (second (const [])) eqsG ++ eqsW - in tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] [] eqs +simplifyNats opts@Opts {..} eqsG eqsW = do + let (varEqs,otherEqs) = partition isVarEqs eqsG + fancyGivens = concatMap (makeGivensSet otherEqs) varEqs + eqs = map (second (const [])) otherEqs ++ eqsW + case varEqs of + [] -> do + tcPluginTrace "simplifyNats" (ppr eqs) + simples [] [] [] [] eqs + _ -> do + tcPluginTrace ("simplifyNats(backtrack: " ++ show (length fancyGivens) ++ ")") + (ppr varEqs) + foldr findFirstSimpliedWanted (Simplified []) <$> + mapM (\v -> tcPluginTrace "simplifyNats" (ppr (v ++ eqs)) >> + simples [] [] [] [] (v ++ eqs)) + fancyGivens where simples :: [CoreUnify] -> [((EvTerm, Ct), [Ct])] @@ -534,6 +549,55 @@ simplifyNats opts@Opts {..} eqsG eqsW = eqToLeq x y = [(x,y,True),(y,x,True)] substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b) + isVarEqs (Left (_,S [P [V _]], S [P [V _]]), _) = True + isVarEqs _ = False + + makeGivensSet otherEqs varEq + = let (noMentionsV,mentionsV) = partitionEithers + (map (matchesVarEq varEq) otherEqs) + (mentionsLHS,mentionsRHS) = partitionEithers mentionsV + vS = swapVar varEq + givensLHS = case mentionsLHS of + [] -> [] + _ -> [mentionsLHS ++ (vS:noMentionsV) + ,varEq:mentionsLHS ++ noMentionsV + ] + givensRHS = case mentionsRHS of + [] -> [] + _ -> [mentionsRHS ++ (varEq:noMentionsV) + ,vS:mentionsRHS ++ noMentionsV + ] + in case mentionsV of + [] -> [noMentionsV] + _ -> givensLHS ++ givensRHS + + matchesVarEq (Left (_, S [P [V v1]], S [P [V v2]]),_) r = case r of + (Left (_,S [P [V v3]],_),_) + | v1 == v3 -> Right (Left r) + | v2 == v3 -> Right (Right r) + (Left (_,_,S [P [V v3]]),_) + | v1 == v3 -> Right (Left r) + | v2 == v3 -> Right (Right r) + (Right (_,(S [P [V v3]],_,_)),_) + | v1 == v3 -> Right (Left r) + | v2 == v3 -> Right (Right r) + (Right (_,(_,S [P [V v3]],_)),_) + | v1 == v3 -> Right (Left r) + | v2 == v3 -> Right (Right r) + _ -> Left r + matchesVarEq _ _ = error "internal error" + + swapVar (Left (ct,S [P [V v1]], S [P [V v2]]),ps) = + (Left (ct,S [P [V v2]], S [P [V v1]]),ps) + swapVar _ = error "internal error" + + findFirstSimpliedWanted (Impossible e) _ = Impossible e + findFirstSimpliedWanted (Simplified evs) s2 + | any (isWantedCt . snd . fst) evs + = Simplified evs + | otherwise + = s2 + -- If we allow negated numbers we simply do not emit the inequalities -- derived from the subtractions that are converted to additions with a -- negated operand