Skip to content

Commit

Permalink
Backtrack
Browse files Browse the repository at this point in the history
Fixes #23
  • Loading branch information
christiaanb committed Jan 30, 2020
1 parent e81b9ca commit d23c3c9
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 80 deletions.
2 changes: 1 addition & 1 deletion cabal.project
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
packages: ./ghc-typelits-natnormalise.cabal
flags: +deverror
flags: -deverror
allow-newer: *:base, *:template-haskell, *:ghc, *:stm
repository head.hackage
url: http://head.hackage.haskell.org/
Expand Down
3 changes: 2 additions & 1 deletion ghc-typelits-natnormalise.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ library
test-suite test-ghc-typelits-natnormalise
type: exitcode-stdio-1.0
main-is: Tests.hs
Other-Modules: ErrorTests
Other-Modules: AmbTests
ErrorTests
build-depends: base >=4.8 && <5,
ghc-typelits-natnormalise,
tasty >= 0.10,
Expand Down
263 changes: 185 additions & 78 deletions src/GHC/TypeLits/Normalise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ import Control.Monad ((<=<), forM)
import Control.Monad (replicateM)
#endif
import Control.Monad.Trans.Writer.Strict
import Data.Either (rights)
import Data.List (intersect, stripPrefix, find)
import Data.Either (lefts, rights)
import Data.List (intersect, stripPrefix, find, nub)
import Data.Maybe (mapMaybe, catMaybes)
import Data.Set (Set, empty, toList, notMember, fromList, union)
import Data.Set (Set, empty, toList, notMember, fromList, union, insert)
import GHC.TcPluginM.Extra (tracePlugin, newGiven, newWanted)
import qualified GHC.TcPluginM.Extra as TcPluginM
#if MIN_VERSION_ghc(8,4,0)
Expand Down Expand Up @@ -248,6 +248,7 @@ import TcType (isEqPrimPred)
#endif

-- internal
import GHC.TypeLits.Normalise.SOP
import GHC.TypeLits.Normalise.Unify

#if !MIN_VERSION_ghc(8,10,0)
Expand Down Expand Up @@ -367,7 +368,7 @@ decideEqualSOP opts gen'd givens _deriveds wanteds = do
tcPluginIO $
modifyIORef' gen'd $ union (fromList newlyDone)
let unit_givens = mapMaybe toNatEquality simplGivens
sr <- simplifyNats opts unit_givens unit_wanteds
sr <- solveWanted opts unit_givens unit_wanteds
tcPluginTrace "normalised" (ppr sr)
reds <- forM reducible_wanteds $ \(origCt,(term, ws, w)) -> do
wants <- fmap fst $ evSubtPreds origCt $ subToPred opts ws
Expand Down Expand Up @@ -455,84 +456,190 @@ instance Outputable SimplifyResult where
ppr (Simplified evs) = text "Simplified" $$ ppr evs
ppr (Impossible eq) = text "Impossible" <+> ppr eq

simplifyNats
:: Opts
-- ^ Allow negated numbers (potentially unsound!)
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Given constraints
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-- ^ Wanted constraints
-> TcPluginM SimplifyResult
simplifyNats opts@Opts {..} eqsG eqsW =
let eqs = map (second (const [])) eqsG ++ eqsW
in tcPluginTrace "simplifyNats" (ppr eqs) >> simples [] [] [] [] eqs
where
simples :: [CoreUnify]
-> [((EvTerm, Ct), [Ct])]
-> [(CoreSOP,CoreSOP,Bool)]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> [(Either NatEquality NatInEquality,[(Type,Type)])]
-> TcPluginM SimplifyResult
simples _subst evs _leqsG _xs [] = return (Simplified evs)
simples subst evs leqsG xs (eq@(Left (ct,u,v),k):eqs') = do
let u' = substsSOP subst u
v' = substsSOP subst v
ur <- unifyNats ct u' v'
tcPluginTrace "unifyNats result" (ppr ur)
case ur of
Win -> do
evs' <- maybe evs (:evs) <$> evMagic ct empty (subToPred opts k)
simples subst evs' leqsG [] (xs ++ eqs')
Lose -> if null evs && null eqs'
then return (Impossible (fst eq))
else simples subst evs leqsG xs eqs'
Draw [] -> simples subst evs [] (eq:xs) eqs'
Draw subst' -> do
evM <- evMagic ct empty (map unifyItemToPredType subst' ++
subToPred opts k)
let leqsG' | isGiven (ctEvidence ct) = eqToLeq u' v' ++ leqsG
| otherwise = leqsG
case evM of
Nothing -> simples subst evs leqsG' xs eqs'
Just ev ->
simples (substsSubst subst' subst ++ subst')
(ev:evs) leqsG' [] (xs ++ eqs')
simples subst evs leqsG xs (eq@(Right (ct,u@(x,y,b)),k):eqs') = do
let u' = substsSOP subst (subtractIneq u)
x' = substsSOP subst x
y' = substsSOP subst y
uS = (x',y',b)
leqsG' | isGiven (ctEvidence ct) = (x',y',b):leqsG
| otherwise = leqsG
ineqs = concat [ leqsG
, map (substLeq subst) leqsG
, map snd (rights (map fst eqsG))
]
tcPluginTrace "unifyNats(ineq) results" (ppr (ct,u,u',ineqs))
case runWriterT (isNatural u') of
Just (True,knW) -> do
evs' <- maybe evs (:evs) <$> evMagic ct knW (subToPred opts k)
simples subst evs' leqsG' xs eqs'

Just (False,_) | null k -> return (Impossible (fst eq))
_
findGivenSubst ::
[CoreUnify] ->
[(Either NatEquality NatInEquality, [(Type, Type)])] ->
TcPluginM [CoreUnify]
findGivenSubst subst [] = pure subst
findGivenSubst subst ((Left (ct,u,v),_k):eqs) = do
let uvs = rewriteEQ False subst (u,v)
urs <- mapM (uncurry (unifyNats ct)) uvs
case foldr smashDraw Lose urs of
Draw qs -> findGivenSubst (qs ++ subst) eqs
_ -> findGivenSubst subst eqs
where
smashDraw (Draw s1) (Draw s2) = Draw (nub (s1 ++ s2))
smashDraw (Draw s) _ = Draw s
smashDraw _ (Draw s) = Draw s
smashDraw _ r = r

findGivenSubst subst (_:eqs) = findGivenSubst subst eqs

newtype SymSOP = SymSOP (CoreSOP, CoreSOP)

instance Eq SymSOP where
(SymSOP (p,q)) == (SymSOP (x,y)) =
case p == x of
False -> p == y && q == x
True -> q == y

instance Ord SymSOP where
compare (SymSOP (p,q)) (SymSOP (x,y)) =
case compare p x of
LT -> case compare p y of
EQ -> case compare q x of
EQ -> EQ
_ -> LT
_ -> LT
EQ -> compare q y
GT -> GT

rewriteEQ ::
Bool ->
[CoreUnify] ->
(CoreSOP, CoreSOP) ->
[(CoreSOP, CoreSOP)]
rewriteEQ isW subst = go empty
where
go hist (u,v) =
let
hist1 = insert (SymSOP (u,v)) hist
rewrites = mapMaybe (rewrite (u,v)) subst
new = nub (filter ((`notMember` hist1).SymSOP) rewrites)
in (u,v):concatMap (go hist1) new

rewrite _ UnifyItem {} = Nothing
rewrite (p,q) (SubstItem x e) =
let p1 = substSOP x e p
q1 = substSOP x e q
in if isW || p1 /= q1 then
Just (substSOP x e p, substSOP x e q)
else
Nothing

rewriteINEQ ::
[CoreUnify] ->
(CoreSOP,CoreSOP,Bool) ->
[(CoreSOP,CoreSOP,Bool)]
rewriteINEQ subst = go empty
where
go hist (u,v,b) =
let hist1 = insert (u,v,b) hist
rewrites = mapMaybe (rewrite (u,v,b)) subst
new = nub (filter (`notMember` hist1) rewrites)
in (u,v,b):concatMap (go hist1) new

rewrite _ UnifyItem {} = Nothing
rewrite (p,q,z) (SubstItem x e) =
let p1 = substSOP x e p
q1 = substSOP x e q
in if p1 /= q1 then
Just (substSOP x e p, substSOP x e q, z)
else
Nothing

solveWanted ::
-- | Solver options (depth, whether to allow integers, etc.)
Opts ->
-- | Given constraints
[(Either NatEquality NatInEquality,[(Type,Type)])] ->
-- | Wanted constraints
[(Either NatEquality NatInEquality,[(Type,Type)])] ->
TcPluginM SimplifyResult
solveWanted opts@Opts {..} eqsG eqsW = do
subst <- nub . concatMap mkSubst <$> findGivenSubst [] eqsG
tcPluginTrace "solveWanted eqsG:" (ppr eqsG)
tcPluginTrace "solveWanted subst:" (ppr (subst))
go subst [] eqsW
where
go _ evs [] = pure (Simplified evs)
go subst evs ((Left (ct,u,v),kW):eqs) = do
s <- goSingleWantedEq subst (ct,u,v,kW)
case s of
Simplified ev -> go subst (ev ++ evs) eqs
i -> if null evs && null eqs then
pure i
else
go subst evs eqs
go subst evs ((Right (ct,(u,v,b)),kW):eqs) = do
s <- goSingleWantedInEq subst (ct,u,v,b,kW)
case s of
Simplified ev -> go subst (ev ++ evs) eqs
i -> if null evs && null eqs then
pure i
else
go subst evs eqs

goSingleWantedEq
:: [CoreUnify]
-> (Ct,CoreSOP,CoreSOP,[(Type,Type)])
-> TcPluginM SimplifyResult
goSingleWantedEq subst (ct,u,v,kW) = do
tcPluginTrace "goSingleWanted" (ppr (ct,u,v))
let uvs = rewriteEQ True subst (u,v)
urs <- mapM (uncurry (unifyNats ct)) uvs
case foldr smashResult (Draw []) urs of
Win -> maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct empty (subToPred opts kW)
Lose -> pure (Impossible (Left (ct,u,v)))
Draw [] -> pure (Simplified [])
Draw s -> maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct empty (map unifyItemToPredType s ++
subToPred opts kW)

mkSubst UnifyItem {} = []
mkSubst (SubstItem x e)
| S [P [V y]] <- e
= [SubstItem x e, SubstItem y (S [P [(V x)]])]
| otherwise
= [SubstItem x e]

smashResult Lose _ = Lose
smashResult _ Lose = Lose
smashResult Win _ = Win
smashResult _ Win = Win
smashResult (Draw s1) (Draw s2) = Draw (filter simpleSubst (s1 ++ s2))

simpleSubst UnifyItem {} = True
simpleSubst (SubstItem _ (S [P [l]]))
| V _ <- l
= True
| I _ <- l
= True
simpleSubst _ = False

goSingleWantedInEq
:: [CoreUnify]
-> (Ct,CoreSOP,CoreSOP,Bool,[(Type,Type)])
-> TcPluginM SimplifyResult
goSingleWantedInEq subst (ct,u,v,b,kW) = do
let ineqs = map snd (rights (map fst eqsG))
eqs = map (\(_,x,y) -> (x,y)) (lefts (map fst eqsG))
ineqs1 = nub (concatMap (rewriteINEQ subst) ineqs)
ineqs2 = nub (concatMap (concatMap eqToLeq . rewriteEQ False subst) eqs)
ineqs3 = ineqs1 ++ ineqs2
z = subtractIneq (u,v,b)
uvbs = rewriteINEQ subst (u,v,b)
solved =
-- This inequality is either a given constraint, or it is a wanted
-- constraint, which in normal form is equal to another given
-- constraint, hence it can be solved.
| or (mapMaybe (solveIneq depth u) ineqs) ||
-- Or the above, but with valid substitutions applied to the wanted.
or (mapMaybe (solveIneq depth uS) ineqs) ||
or (concatMap (\uvb -> mapMaybe (solveIneq depth uvb) ineqs3) uvbs) ||
-- Or it is an inequality that can be instantly solved, such as
-- `1 <= x^y`
instantSolveIneq depth u
-> do
evs' <- maybe evs (:evs) <$> evMagic ct empty (subToPred opts k)
simples subst evs' leqsG' xs eqs'
| otherwise
-> simples subst evs leqsG (eq:xs) eqs'

eqToLeq x y = [(x,y,True),(y,x,True)]
substLeq s (x,y,b) = (substsSOP s x, substsSOP s y, b)
any (instantSolveIneq depth) uvbs
tcPluginTrace "goSingleWantedInEq" (ppr (ct,uvbs,ineqs3))
case runWriterT (isNatural z) of
Just (True,knW) ->
maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct knW (subToPred opts kW)
Just (False,_) | null kW -> pure (Impossible (Right (ct,(u,v,b))))
_ | solved
-> maybe (Simplified []) (Simplified . (:[])) <$>
evMagic ct empty (subToPred opts kW)
| otherwise
-> pure (Simplified [])

eqToLeq (x,y) = [(x,y,True),(y,x,True)]

-- If we allow negated numbers we simply do not emit the inequalities
-- derived from the subtractions that are converted to additions with a
Expand Down
1 change: 1 addition & 0 deletions src/GHC/TypeLits/Normalise/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module GHC.TypeLits.Normalise.Unify
-- * Substitution on 'SOP' terms
, UnifyItem (..)
, CoreUnify
, substSOP
, substsSOP
, substsSubst
-- * Find unifiers
Expand Down
52 changes: 52 additions & 0 deletions tests/AmbTests.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
#if __GLASGOW_HASKELL__ >= 805
{-# LANGUAGE NoStarIsType #-}
#endif

{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -dcore-lint #-}

module AmbTests where

import GHC.TypeLits

leToPlus2
:: forall (k :: Nat) (n :: Nat) r
. ( k <= n
)
=> (forall m . (n ~ (k + m)) => r)
-- ^ Context with the (k + m ~ n) constraint
-> r
leToPlus2 r = r @(n - k)
{-# INLINE leToPlus2 #-}

newtype BitVector (n :: Nat) = BV Integer

class BitPack a where
type BitSize a :: Nat

split :: forall a m n .
(BitPack a, BitSize a ~ (m + n)) =>
a ->
(BitVector m, BitVector n)
split = split

lastBits
:: forall n a
. ( BitPack a
, n <= BitSize a
, KnownNat (BitSize a)
, KnownNat n
)
=> a
-> BitVector n
lastBits = leToPlus2 @n @(BitSize a) $ snd . split @_ @_ @n
1 change: 1 addition & 0 deletions tests/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import Control.Exception
import Test.Tasty
import Test.Tasty.HUnit

import AmbTests
import ErrorTests

data Vec :: Nat -> Type -> Type where
Expand Down

0 comments on commit d23c3c9

Please sign in to comment.