From 65ac021473ed68ed6cd85a14dd99a25a5914abf6 Mon Sep 17 00:00:00 2001 From: Christiaan Baaij Date: Mon, 27 Jan 2020 18:42:46 +0100 Subject: [PATCH] Backtrack Fixes #23 --- ghc-typelits-natnormalise.cabal | 5 +- src/GHC/TypeLits/Normalise.hs | 290 +++++++++++++++++++--------- src/GHC/TypeLits/Normalise/Unify.hs | 1 + tests/AmbTests.hs | 52 +++++ tests/Tests.hs | 28 +++ 5 files changed, 281 insertions(+), 95 deletions(-) create mode 100644 tests/AmbTests.hs diff --git a/ghc-typelits-natnormalise.cabal b/ghc-typelits-natnormalise.cabal index b1f47ab..e5be4ab 100644 --- a/ghc-typelits-natnormalise.cabal +++ b/ghc-typelits-natnormalise.cabal @@ -68,7 +68,7 @@ library build-depends: base >=4.9 && <5, containers >=0.5.7.1 && <0.7, ghc >=8.0.1 && <8.11, - ghc-tcplugins-extra >=0.3.1, + ghc-tcplugins-extra >=0.3.2, integer-gmp >=1.0 && <1.1, syb >=0.7.1 && <0.8, transformers >=0.5.2.0 && < 0.6 @@ -86,7 +86,8 @@ library test-suite test-ghc-typelits-natnormalise type: exitcode-stdio-1.0 main-is: Tests.hs - Other-Modules: ErrorTests + Other-Modules: AmbTests + ErrorTests build-depends: base >=4.8 && <5, ghc-typelits-natnormalise, tasty >= 0.10, diff --git a/src/GHC/TypeLits/Normalise.hs b/src/GHC/TypeLits/Normalise.hs index 958f57f..9b0c220 100644 --- a/src/GHC/TypeLits/Normalise.hs +++ b/src/GHC/TypeLits/Normalise.hs @@ -163,14 +163,14 @@ import Control.Monad ((<=<), forM) import Control.Monad (replicateM) #endif import Control.Monad.Trans.Writer.Strict -import Data.Either (rights) -import Data.List (intersect, stripPrefix, find) +import Data.Either (lefts, rights) +import Data.List (intersect, stripPrefix, find, nub) import Data.Maybe (mapMaybe, catMaybes) -import Data.Set (Set, empty, toList, notMember, fromList, union) +import Data.Set (Set, empty, toList, notMember, fromList, union, insert) import GHC.TcPluginM.Extra (tracePlugin, newGiven, newWanted) import qualified GHC.TcPluginM.Extra as TcPluginM #if MIN_VERSION_ghc(8,4,0) -import GHC.TcPluginM.Extra (flattenGivens) +import GHC.TcPluginM.Extra (flattenGivens, substCt) #endif import Text.Read (readMaybe) @@ -207,11 +207,12 @@ import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon, import TcTypeNats (typeNatLeqTyCon) import TysWiredIn (promotedFalseDataCon, promotedTrueDataCon) +import UniqSet (elementOfUniqSet) import Data.IORef #if MIN_VERSION_ghc(8,10,0) import Constraint - (Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, + (Ct (..), CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan) import Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType, getEqPredTys, mkClassPred, @@ -219,7 +220,7 @@ import Predicate import Type (typeKind) #else import TcRnTypes - (Ct, CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, + (Ct (..), CtEvidence (..), CtLoc, TcEvDest (..), ctEvidence, ctEvLoc, ctEvPred, ctLoc, ctLocSpan, isGiven, isWanted, mkNonCanonical, setCtLoc, setCtLocSpan) import TcType (typeKind) import Type @@ -248,6 +249,7 @@ import TcType (isEqPrimPred) #endif -- internal +import GHC.TypeLits.Normalise.SOP import GHC.TypeLits.Normalise.Unify #if !MIN_VERSION_ghc(8,10,0) @@ -331,10 +333,7 @@ decideEqualSOP opts gen'd givens _deriveds wanteds = do #if MIN_VERSION_ghc(8,4,0) let simplGivens = givens ++ flattenGivens givens subst = fst $ unzip $ TcPluginM.mkSubst' givens - wanteds0 = map (\ct -> (OrigCt ct, - TcPluginM.substCt subst ct - ) - ) wanteds + wanteds0 = map (\ct -> (OrigCt ct,substCt subst ct)) wanteds #else let wanteds0 = map (\ct -> (OrigCt ct, ct)) wanteds simplGivens <- mapM zonkCt givens @@ -367,7 +366,7 @@ decideEqualSOP opts gen'd givens _deriveds wanteds = do tcPluginIO $ modifyIORef' gen'd $ union (fromList newlyDone) let unit_givens = mapMaybe toNatEquality simplGivens - sr <- simplifyNats opts unit_givens unit_wanteds + sr <- solveWanted opts unit_givens unit_wanteds tcPluginTrace "normalised" (ppr sr) reds <- forM reducible_wanteds $ \(origCt,(term, ws)) -> do wants <- fmap fst $ evSubtPreds origCt $ subToPred opts ws @@ -451,88 +450,193 @@ instance Outputable SimplifyResult where ppr (Simplified evs) = text "Simplified" $$ ppr evs ppr (Impossible eq) = text "Impossible" <+> ppr eq -simplifyNats - :: Opts - -- ^ Allow negated numbers (potentially unsound!) - -> [(Either NatEquality NatInEquality,[(Type,Type)])] - -- ^ Given constraints - -> [(Either NatEquality NatInEquality,[(Type,Type)])] - -- ^ Wanted constraints - -> TcPluginM SimplifyResult -simplifyNats opts@Opts {..} eqsG eqsW = - let eqs = map (second (const [])) eqsG ++ eqsW - in tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] [] eqs - where - simples :: [CoreUnify] - -> [((EvTerm, Ct), [Ct])] - -> [(CoreSOP,CoreSOP,Bool)] - -> [(Either NatEquality NatInEquality,[(Type,Type)])] - -> [(Either NatEquality NatInEquality,[(Type,Type)])] - -> TcPluginM SimplifyResult - simples _subst evs _leqsG _xs [] = return (Simplified evs) - simples subst evs leqsG xs (eq@(Left (ct,u,v),k):eqs') = do - let u' = substsSOP subst u - v' = substsSOP subst v - ur <- unifyNats ct u' v' - tcPluginTrace "unifyNats result" (ppr ur) - case ur of - Win -> do - evs' <- maybe evs (:evs) <$> evMagic ct empty (subToPred opts k) - simples subst evs' leqsG [] (xs ++ eqs') - Lose -> if null evs && null eqs' - then return (Impossible (fst eq)) - else simples subst evs leqsG xs eqs' - Draw [] -> simples subst evs [] (eq:xs) eqs' - Draw subst' -> do - evM <- evMagic ct empty (map unifyItemToPredType subst' ++ - subToPred opts k) - let leqsG' | isGiven (ctEvidence ct) = eqToLeq u' v' ++ leqsG - | otherwise = leqsG - case evM of - Nothing -> simples subst evs leqsG' xs eqs' - Just ev -> - simples (substsSubst subst' subst ++ subst') - (ev:evs) leqsG' [] (xs ++ eqs') - simples subst evs leqsG xs (eq@(Right (ct,u@(x,y,b)),k):eqs') = do - let u' = substsSOP subst (subtractIneq u) - x' = substsSOP subst x - y' = substsSOP subst y - uS = (x',y',b) - leqsG' | isGiven (ctEvidence ct) = (x',y',b):leqsG - | otherwise = leqsG - ineqs = concat [ leqsG - , map (substLeq subst) leqsG - , map snd (rights (map fst eqsG)) - ] - tcPluginTrace "unifyNats(ineq) results" (ppr (ct,u,u',ineqs)) - case runWriterT (isNatural u') of - Just (True,knW) -> do - evs' <- maybe evs (:evs) <$> evMagic ct knW (subToPred opts k) - simples subst evs' leqsG' xs eqs' - - Just (False,_) | null k -> return (Impossible (fst eq)) - _ -> do - let solvedIneq = mapMaybe runWriterT - -- it is an inequality that can be instantly solved, such as - -- `1 <= x^y` - -- OR - (instantSolveIneq depth u: - -- This inequality is either a given constraint, or it is a wanted - -- constraint, which in normal form is equal to another given - -- constraint, hence it can be solved. - -- OR - map (solveIneq depth u) ineqs ++ - -- The above, but with valid substitutions applied to the wanted. - map (solveIneq depth uS) ineqs) - smallest = solvedInEqSmallestConstraint solvedIneq - case smallest of - (True,kW) -> do - evs' <- maybe evs (:evs) <$> evMagic ct kW (subToPred opts k) - simples subst evs' leqsG' xs eqs' - _ -> simples subst evs leqsG (eq:xs) eqs' - - eqToLeq x y = [(x,y,True),(y,x,True)] - substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b) +findGivenSubst :: + [(Either NatEquality NatInEquality, [(Type, Type)])] -> + TcPluginM [CoreUnify] +findGivenSubst = go [] [] + where + go :: + [CoreUnify] -> + [(Either NatEquality NatInEquality, [(Type, Type)])] -> + [(Either NatEquality NatInEquality, [(Type, Type)])] -> + TcPluginM [CoreUnify] + go subst _ [] = pure subst + go subst tryAgain (eq@(Left (ct,u,v),_k):eqs) = do + let uvs = rewriteEQ False subst (u,v) + urs <- mapM (uncurry (unifyNats ct)) uvs + case foldr smashDraw Lose urs of + Draw qs@(_:_) -> go (qs ++ subst) [] (tryAgain ++ eqs) + _ -> go subst (eq:tryAgain) eqs + go subst tryAgain (_:eqs) = go subst tryAgain eqs + + smashDraw (Draw s1) (Draw s2) = Draw (nub (s1 ++ s2)) + smashDraw (Draw s) _ = Draw s + smashDraw _ (Draw s) = Draw s + smashDraw _ r = r + +-- | Witness that equality is a symmetric relation +newtype SymSOP = SymSOP (CoreSOP, CoreSOP) + +instance Eq SymSOP where + (SymSOP (p,q)) == (SymSOP (x,y)) = + case p == x of + False -> p == y && q == x + True -> q == y + +instance Ord SymSOP where + compare (SymSOP (p,q)) (SymSOP (x,y)) = + case compare p x of + LT -> case compare p y of + EQ -> case compare q x of + EQ -> EQ + _ -> LT + _ -> LT + EQ -> compare q y + GT -> GT + +rewriteEQ :: + Bool -> + [CoreUnify] -> + (CoreSOP, CoreSOP) -> + [(CoreSOP, CoreSOP)] +rewriteEQ isW subst = go empty + where + go hist (u,v) = + let + hist1 = insert (SymSOP (u,v)) hist + rewrites = mapMaybe (rewrite (u,v)) subst + new = nub (filter ((`notMember` hist1) . SymSOP) rewrites) + in (u,v):concatMap (go hist1) new + + rewrite _ UnifyItem {} = Nothing + rewrite (p,q) (SubstItem x e) = + let p1 = substSOP x e p + q1 = substSOP x e q + in if isW || p1 /= q1 then + Just (substSOP x e p, substSOP x e q) + else + Nothing + +rewriteINEQ :: + [CoreUnify] -> + (CoreSOP,CoreSOP,Bool) -> + [(CoreSOP,CoreSOP,Bool)] +rewriteINEQ subst = go empty + where + go hist (u,v,b) = + let hist1 = insert (u,v,b) hist + rewrites = mapMaybe (rewrite (u,v,b)) subst + new = nub (filter (`notMember` hist1) rewrites) + in (u,v,b):concatMap (go hist1) new + + rewrite _ UnifyItem {} = Nothing + rewrite (p,q,z) (SubstItem x e) = + let p1 = substSOP x e p + q1 = substSOP x e q + in if p1 /= q1 then + Just (substSOP x e p, substSOP x e q, z) + else + Nothing + +solveWanted :: + -- | Solver options (depth, whether to allow integers, etc.) + Opts -> + -- | Given constraints + [(Either NatEquality NatInEquality,[(Type,Type)])] -> + -- | Wanted constraints + [(Either NatEquality NatInEquality,[(Type,Type)])] -> + TcPluginM SimplifyResult +solveWanted opts@Opts {..} eqsG eqsW = do + subst <- nub . concatMap mkSubst <$> findGivenSubst eqsG + tcPluginTrace "solveWanted eqsG:" (ppr eqsG) + tcPluginTrace "solveWanted subst:" (ppr (subst)) + go subst [] eqsW + where + go _ evs [] = pure (Simplified evs) + go subst evs ((Left (ct,u,v),kW):eqs) = do + (subst1,s) <- goSingleWantedEq subst (ct,u,v,kW) + case s of + Simplified ev -> go subst1 (ev ++ evs) eqs + i -> if null evs && null eqs then + pure i + else + go subst1 evs eqs + go subst evs ((Right (ct,(u,v,b)),kW):eqs) = do + s <- goSingleWantedInEq subst (ct,u,v,b,kW) + case s of + Simplified ev -> go subst (ev ++ evs) eqs + i -> if null evs && null eqs then + pure i + else + go subst evs eqs + + goSingleWantedEq + :: [CoreUnify] + -> (Ct,CoreSOP,CoreSOP,[(Type,Type)]) + -> TcPluginM ([CoreUnify],SimplifyResult) + goSingleWantedEq subst (ct,u,v,kW) = do + tcPluginTrace "goSingleWanted" (ppr (ct,u,v,subst)) + let uvs = rewriteEQ True subst (u,v) + urs <- mapM (uncurry (unifyNats ct)) uvs + case foldr smashResult (Draw []) urs of + Win -> maybe (subst,Simplified []) ((subst,) . Simplified . (:[])) <$> + evMagic ct empty (subToPred opts kW) + Lose -> pure (subst,Impossible (Left (ct,u,v))) + Draw [] -> pure (subst,Simplified []) + Draw s -> + let subst1 = nub (concatMap mkSubst s ++ subst) + in maybe (subst,Simplified []) ((subst1,) . Simplified . (:[])) <$> + evMagic ct empty (map unifyItemToPredType s ++ subToPred opts kW) + + mkSubst UnifyItem {} = [] + mkSubst (SubstItem x e) + | S [P [V y]] <- e + = if x /= y then [SubstItem x e, SubstItem y (S [P [(V x)]])] else [] + | otherwise + = [SubstItem x e] + + smashResult Lose _ = Lose + smashResult _ Lose = Lose + smashResult Win _ = Win + smashResult _ Win = Win + smashResult (Draw s1) (Draw s2) = Draw (filter simpleSubst (s1 ++ s2)) + + simpleSubst UnifyItem {} = True + simpleSubst (SubstItem v s) = not (v `elementOfUniqSet` fvSOP s) + + goSingleWantedInEq + :: [CoreUnify] + -> (Ct,CoreSOP,CoreSOP,Bool,[(Type,Type)]) + -> TcPluginM SimplifyResult + goSingleWantedInEq subst (ct,u,v,b,kW) = do + let ineqs = map snd (rights (map fst eqsG)) + eqs = map (\(_,x,y) -> (x,y)) (lefts (map fst eqsG)) + ineqs1 = nub (concatMap (rewriteINEQ subst) ineqs) + ineqs2 = nub (concatMap (concatMap eqToLeq . rewriteEQ False subst) eqs) + ineqs3 = ineqs1 ++ ineqs2 + z = subtractIneq (u,v,b) + uvbs = rewriteINEQ subst (u,v,b) + solved = mapMaybe runWriterT + -- This inequality is either a given constraint, or it is a + -- wanted constraint, which in normal form is equal to + -- another given constraint, hence it can be solved. + (concatMap (\uvb -> map (solveIneq depth uvb) ineqs3) uvbs ++ + -- Or it is an inequality that can be instantly solved, such + -- as `1 <= x^y` + map (instantSolveIneq depth) uvbs) + smallest = solvedInEqSmallestConstraint solved + tcPluginTrace "goSingleWantedInEq" (ppr (ct,uvbs,ineqs3)) + case runWriterT (isNatural z) of + Just (True,knW) -> + maybe (Simplified []) (Simplified . (:[])) <$> + evMagic ct knW (subToPred opts kW) + Just (False,_) | null kW -> pure (Impossible (Right (ct,(u,v,b)))) + _ -> case smallest of + (True,knW) -> maybe (Simplified []) (Simplified . (:[])) <$> + evMagic ct knW (subToPred opts kW) + _ -> pure (Simplified []) + + eqToLeq (x,y) = [(x,y,True),(y,x,True)] -- If we allow negated numbers we simply do not emit the inequalities -- derived from the subtractions that are converted to additions with a diff --git a/src/GHC/TypeLits/Normalise/Unify.hs b/src/GHC/TypeLits/Normalise/Unify.hs index 5ccb510..51e88ef 100644 --- a/src/GHC/TypeLits/Normalise/Unify.hs +++ b/src/GHC/TypeLits/Normalise/Unify.hs @@ -26,6 +26,7 @@ module GHC.TypeLits.Normalise.Unify -- * Substitution on 'SOP' terms , UnifyItem (..) , CoreUnify + , substSOP , substsSOP , substsSubst -- * Find unifiers diff --git a/tests/AmbTests.hs b/tests/AmbTests.hs new file mode 100644 index 0000000..26f9b08 --- /dev/null +++ b/tests/AmbTests.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +#if __GLASGOW_HASKELL__ >= 805 +{-# LANGUAGE NoStarIsType #-} +#endif + +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -dcore-lint #-} + +module AmbTests where + +import GHC.TypeLits + +leToPlus2 + :: forall (k :: Nat) (n :: Nat) r + . ( k <= n + ) + => (forall m . (n ~ (k + m)) => r) + -- ^ Context with the (k + m ~ n) constraint + -> r +leToPlus2 r = r @(n - k) +{-# INLINE leToPlus2 #-} + +newtype BitVector (n :: Nat) = BV Integer + +class BitPack a where + type BitSize a :: Nat + +split :: forall a m n . + (BitPack a, BitSize a ~ (m + n)) => + a -> + (BitVector m, BitVector n) +split = split + +lastBits + :: forall n a + . ( BitPack a + , n <= BitSize a + , KnownNat (BitSize a) + , KnownNat n + ) + => a + -> BitVector n +lastBits = leToPlus2 @n @(BitSize a) $ snd . split @_ @_ @n diff --git a/tests/Tests.hs b/tests/Tests.hs index 8f81c5e..c3212cb 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -34,6 +34,7 @@ import Control.Exception import Test.Tasty import Test.Tasty.HUnit +import AmbTests import ErrorTests data Vec :: Nat -> Type -> Type where @@ -91,6 +92,11 @@ addUNat UZero y = y addUNat x UZero = x addUNat (USucc x) y = USucc (addUNat x y) +subUNat :: UNat (m+n) -> UNat n -> UNat m +subUNat x UZero = x +subUNat (USucc x) (USucc y) = subUNat x y +subUNat UZero _ = error "subUNat: impossible: 0 + (n + 1) ~ 0" + -- | Multiply two singleton natural numbers -- -- __NB__: Not synthesisable @@ -298,6 +304,28 @@ predBNat (B1 a) = case a of a' -> B0 a' predBNat (B0 x) = B1 (predBNat x) +succBNat :: BNat n -> BNat (n+1) +succBNat BT = B1 BT +succBNat (B0 a) = B1 a +succBNat (B1 a) = B0 (succBNat a) + +stripZeros :: BNat n -> BNat n +stripZeros BT = BT +stripZeros (B1 x) = B1 (stripZeros x) +stripZeros (B0 BT) = BT +stripZeros (B0 x) = case stripZeros x of + BT -> BT + k -> B0 k + +log2BNat :: BNat (2^n) -> BNat n +#if __GLASGOW_HASKELL__ >= 802 +log2BNat BT = error "log2BNat: log2(0) not defined" +#endif +log2BNat (B1 x) = case stripZeros x of + BT -> BT + _ -> error "log2BNat: impossible: 2^n ~ 2x+1" +log2BNat (B0 x) = succBNat (log2BNat x) + proxyInEq1 :: Proxy a -> Proxy (a+1) -> () proxyInEq1 = proxyInEq