diff --git a/plugins/tactics/src/Ide/Plugin/Tactic.hs b/plugins/tactics/src/Ide/Plugin/Tactic.hs index 25874bf242..1922d8a8ea 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic.hs @@ -214,7 +214,7 @@ filterBindingType filterBindingType p tp dflags plId uri range jdg = let hy = jHypothesis jdg g = jGoal jdg - in fmap join $ for (M.toList hy) $ \(occ, CType ty) -> + in fmap join $ for (M.toList hy) $ \(occ, hi_type -> CType ty) -> case p (unCType g) ty of True -> tp occ ty dflags plId uri range jdg False -> pure [] diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs index 5ce8605e70..e07aa1dfb2 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs @@ -23,6 +23,6 @@ auto = do commit knownStrategies . tracing "auto" . localTactic (auto' 4) - . disallowing + . disallowing RecursiveCall $ fmap fst current diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs index db20420ede..7d6a58561f 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs @@ -43,8 +43,8 @@ useOccName jdg name = ------------------------------------------------------------------------------ -- | Doing recursion incurs a small penalty in the score. -penalizeRecursion :: MonadState TacticState m => m () -penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1 +countRecursiveCall :: TacticState -> TacticState +countRecursiveCall = field @"ts_recursion_count" +~ 1 ------------------------------------------------------------------------------ @@ -57,13 +57,13 @@ addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals destructMatches :: (DataCon -> Judgement -> Rule) -- ^ How to construct each match - -> ([(OccName, CType)] -> Judgement -> Judgement) - -- ^ How to derive each match judgement + -> Maybe OccName + -- ^ Scrutinee -> CType -- ^ Type being destructed -> Judgement -> RuleM (Trace, [RawMatch]) -destructMatches f f2 t jdg = do +destructMatches f scrut t jdg = do let hy = jHypothesis jdg g = jGoal jdg case splitTyConApp_maybe $ unCType t of @@ -78,9 +78,8 @@ destructMatches f f2 t jdg = do let hy' = zip names $ coerce args dcon_name = nameOccName $ dataConName dc - let j = f2 hy' - $ withPositionMapping dcon_name names - $ introducingPat hy' + let j = withPositionMapping dcon_name names + $ introducingPat scrut dc hy' $ withNewGoal g jdg (tr, sg) <- f dc j modify $ withIntroducedVals $ mappend $ S.fromList names @@ -142,12 +141,12 @@ destruct' f term jdg = do let hy = jHypothesis jdg case find ((== term) . fst) $ toList hy of Nothing -> throwError $ UndefinedHypothesis term - Just (_, t) -> do + Just (_, hi_type -> t) -> do useOccName jdg term (tr, ms) <- destructMatches f - (\cs -> setParents term (fmap fst cs) . destructing term) + (Just term) t jdg pure ( rose ("destruct " <> show term) $ pure tr @@ -165,7 +164,7 @@ destructLambdaCase' f jdg = do case splitFunTy_maybe (unCType g) of Just (arg, _) | isAlgType arg -> fmap (fmap noLoc $ lambdaCase) <$> - destructMatches f (const id) (CType arg) jdg + destructMatches f Nothing (CType arg) jdg _ -> throwError $ GoalMismatch "destructLambdaCase'" g @@ -178,12 +177,11 @@ buildDataCon -> RuleM (Trace, LHsExpr GhcPs) buildDataCon jdg dc apps = do let args = dataConInstOrigArgTys' dc apps - dcon_name = nameOccName $ dataConName dc (tr, sgs) <- fmap unzipTrace $ traverse ( \(arg, n) -> newSubgoal - . filterSameTypeFromOtherPositions dcon_name n + . filterSameTypeFromOtherPositions''' dc n . blacklistingDestruct . flip withNewGoal jdg $ CType arg diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs index 3beb40daa4..6b1f30b5ac 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs @@ -1,6 +1,7 @@ -{-# LANGUAGE TupleSections #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} @@ -14,7 +15,9 @@ import Data.Generics.Product (field) import Data.Map (Map) import qualified Data.Map as M import Data.Maybe +import Data.Set (Set) import qualified Data.Set as S +import DataCon (DataCon) import Development.IDE.Spans.LocalBindings import Ide.Plugin.Tactic.Types import OccName @@ -27,6 +30,7 @@ import Type hypothesisFromBindings :: RealSrcSpan -> Bindings -> Map OccName CType hypothesisFromBindings span bs = buildHypothesis $ getLocalScope bs span + ------------------------------------------------------------------------------ -- | Convert a @Set Id@ into a hypothesis. buildHypothesis :: [(Name, Maybe Type)] -> Map OccName CType @@ -40,12 +44,8 @@ buildHypothesis | otherwise = Nothing -hasDestructed :: Judgement -> OccName -> Bool -hasDestructed j n = S.member n $ _jDestructed j - - -destructing :: OccName -> Judgement -> Judgement -destructing n = field @"_jDestructed" <>~ S.singleton n +hasDestructed :: Judgement' a -> OccName -> Bool +hasDestructed jdg n = S.member n $ _jDestructed jdg blacklistingDestruct :: Judgement -> Judgement @@ -70,30 +70,133 @@ withNewGoal :: a -> Judgement' a -> Judgement' a withNewGoal t = field @"_jGoal" .~ t -introducing :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducing ns = - field @"_jHypothesis" <>~ M.fromList ns +introducingLambda + :: Maybe OccName -- ^ top level function, or Nothing for any other function + -> [(OccName, a)] + -> Judgement' a + -> Judgement' a +introducingLambda func ns = + field @"_jHypothesis" <>~ M.fromList (zip [0..] ns <&> \(pos, (name, ty)) -> + -- TODO(sandy): cleanup + (name, HyInfo (maybe UserPrv (\x -> TopLevelArgPrv x pos) func) ty)) ------------------------------------------------------------------------------ -- | Add some terms to the ambient hypothesis -introducingAmbient :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducingAmbient ns = - field @"_jAmbientHypothesis" <>~ M.fromList ns +introducingRecursively :: [(OccName, a)] -> Judgement' a -> Judgement' a +introducingRecursively ns = + field @"_jHypothesis" <>~ M.fromList (ns <&> \(name, ty) -> + -- TODO(sandy): cleanup + (name, HyInfo RecursivePrv ty + )) filterPosition :: OccName -> Int -> Judgement -> Judgement filterPosition defn pos jdg = - withHypothesis (M.filterWithKey go) jdg + disallowing (WrongBranch pos) (M.keys $ M.filterWithKey go $ jHypothesis jdg) jdg + where + go name _ = not . isJust $ hasPositionalAncestry jdg defn pos name + +hasPositionalAncestry' + :: Foldable t + => t OccName + -> Judgement + -> OccName -- ^ thing to check ancestry + -> Maybe Bool -- ^ Just True if the result is the oldest positional ancestor + -- just false if it's a descendent + -- otherwise nothing +hasPositionalAncestry' ancestors jdg name + | not $ null ancestors + = case any (== name) ancestors of + True -> Just True + False -> + case M.lookup name $ traceIdX "ancestry" $ jAncestryMap jdg of + Just ancestry -> + bool Nothing (Just False) $ any (flip S.member ancestry) ancestors + Nothing -> Nothing + | otherwise = Nothing + +filterPosition' :: OccName -> Int -> Judgement -> Judgement +filterPosition' defn pos jdg = + disallowing (WrongBranch pos) (M.keys $ M.filterWithKey go $ jHypothesis jdg) jdg + where + go name _ + = not + . isJust + $ hasPositionalAncestry' (findPositionVal' jdg defn pos) jdg name + +filterPosition''' :: OccName -> Int -> Judgement -> Judgement +filterPosition''' defn pos jdg = + let broken_ancestors = findPositionVal' jdg defn pos + ancestors = toListOf (_Just . traversed . ix pos) + $ M.lookup defn + $ _jPositionMaps jdg + working = filterPosition defn pos jdg + in case maybeToList broken_ancestors == ancestors of + True -> working + -- TODO(sandy): THE BUG IS THAT WE FILTER OUT FROM THE HYPOTHESIS + -- WHICH REMOVES THE EQUIVALENT OF A POSITION MAPPING; BUT THAT _USED_ + -- TO BE THERE. + False -> error $ show (broken_ancestors, ancestors, defn, pos, jHypothesis jdg) + -- broken = filterPosition' defn pos jdg + -- in case working == broken of + -- True -> working + -- False -> error $ show (working, broken) + +filterDconPosition' :: DataCon -> Int -> Judgement -> Judgement +filterDconPosition' dcon pos jdg = + disallowing (WrongBranch pos) (M.keys $ M.filterWithKey go $ jHypothesis jdg) jdg where - go name _ = isJust $ hasPositionalAncestry jdg defn pos name + go name _ + = not + . isJust + $ hasPositionalAncestry' (findDconPositionVals' jdg dcon pos) jdg name + +findPositionVal' :: Judgement' a -> OccName -> Int -> Maybe OccName +findPositionVal' jdg defn pos = listToMaybe $ do + (name, hi) <- M.toList $ M.map (overProvenance expandDisallowed) $ jEntireHypothesis jdg + case hi_provenance hi of + TopLevelArgPrv defn' pos' + | defn == defn' + , pos == pos' -> pure name + PatternMatchPrv pv + | pv_scrutinee pv == Just defn + , pv_position pv == pos -> pure name + _ -> [] + +findDconPositionVals' :: Judgement' a -> DataCon -> Int -> [OccName] +findDconPositionVals' jdg dcon pos = do + (name, hi) <- M.toList $ jHypothesis jdg + case hi_provenance hi of + PatternMatchPrv pv + | pv_datacon pv == Uniquely dcon + , pv_position pv == pos -> pure name + _ -> [] + +filterSameTypeFromOtherPositions''' :: DataCon -> Int -> Judgement -> Judgement +filterSameTypeFromOtherPositions''' dcon pos jdg = filterSameTypeFromOtherPositions' dcon pos jdg + +filterSameTypeFromOtherPositions' :: DataCon -> Int -> Judgement -> Judgement +filterSameTypeFromOtherPositions' dcon pos jdg = + let hy = jHypothesis $ filterDconPosition' dcon pos jdg + tys = S.fromList $ fmap (hi_type . snd) $ M.toList hy + to_remove = M.filter (flip S.member tys . hi_type) (jHypothesis jdg) M.\\ hy + in disallowing (WrongBranch pos) (M.keys to_remove) jdg filterSameTypeFromOtherPositions :: OccName -> Int -> Judgement -> Judgement filterSameTypeFromOtherPositions defn pos jdg = let hy = jHypothesis $ filterPosition defn pos jdg - tys = S.fromList $ fmap snd $ M.toList hy - in withHypothesis (\hy2 -> M.filter (not . flip S.member tys) hy2 <> hy) jdg + tys = S.fromList $ fmap (hi_type . snd) $ M.toList hy + to_remove = M.filter (flip S.member tys . hi_type) (jHypothesis jdg) M.\\ hy + in disallowing (WrongBranch pos) (M.keys to_remove) jdg + + +getAncestry :: Judgement' a -> OccName -> Set OccName +getAncestry jdg name = + case M.lookup name $ jPatHypothesis jdg of + Just pv -> pv_ancestry pv + Nothing -> mempty hasPositionalAncestry @@ -109,7 +212,7 @@ hasPositionalAncestry jdg defn n name = case any (== name) ancestors of True -> Just True False -> - case M.lookup name $ _jAncestry jdg of + case M.lookup name $ jAncestryMap jdg of Just ancestry -> bool Nothing (Just False) $ any (flip S.member ancestry) ancestors Nothing -> Nothing @@ -120,17 +223,9 @@ hasPositionalAncestry jdg defn n name $ _jPositionMaps jdg -setParents - :: OccName -- ^ parent - -> [OccName] -- ^ children - -> Judgement - -> Judgement -setParents p cs jdg = - let ancestry = mappend (S.singleton p) - $ fromMaybe mempty - $ M.lookup p - $ _jAncestry jdg - in jdg & field @"_jAncestry" <>~ M.fromList (fmap (, ancestry) cs) +jAncestryMap :: Judgement' a -> Map OccName (Set OccName) +jAncestryMap jdg = + flip M.map (jPatHypothesis jdg) pv_ancestry withPositionMapping :: OccName -> [OccName] -> Judgement -> Judgement @@ -150,43 +245,69 @@ extremelyStupid__definingFunction = withHypothesis - :: (Map OccName a -> Map OccName a) + :: (Map OccName (HyInfo a) -> Map OccName (HyInfo a)) -> Judgement' a -> Judgement' a withHypothesis f = field @"_jHypothesis" %~ f ------------------------------------------------------------------------------- --- | Pattern vals are currently tracked in jHypothesis, with an extra piece of data sitting around in jPatternVals. -introducingPat :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducingPat ns jdg = jdg - & field @"_jHypothesis" <>~ M.fromList ns - & field @"_jPatternVals" <>~ S.fromList (fmap fst ns) - -disallowing :: [OccName] -> Judgement' a -> Judgement' a -disallowing ns = - field @"_jHypothesis" %~ flip M.withoutKeys (S.fromList ns) +------------------------------------------------------------------------------ +-- | Pattern vals are currently tracked in jHypothesis, with an extra piece of +-- data sitting around in jPatternVals. +introducingPat + :: Maybe OccName + -> DataCon + -> [(OccName, a)] + -> Judgement' a + -> Judgement' a +introducingPat scrutinee dc ns jdg = jdg + & field @"_jHypothesis" <>~ (M.fromList $ zip [0..] ns <&> \(pos, (name, ty)) -> + ( name + , HyInfo + (PatternMatchPrv $ PatVal + scrutinee + (maybe + mempty + (\scrut -> S.singleton scrut <> getAncestry jdg scrut) + scrutinee) + (Uniquely dc) + pos) + ty)) + & maybe id (\scrut -> field @"_jDestructed" <>~ S.singleton scrut) scrutinee + + +disallowing :: DisallowReason -> [OccName] -> Judgement' a -> Judgement' a +disallowing reason (S.fromList -> ns) = + field @"_jHypothesis" %~ (M.mapWithKey $ \name hi -> + case S.member name ns of + True -> overProvenance (DisallowedPrv reason) hi + False -> hi + ) ------------------------------------------------------------------------------ -- | The hypothesis, consisting of local terms and the ambient environment -- (includes and class methods.) -jHypothesis :: Judgement' a -> Map OccName a -jHypothesis = _jHypothesis <> _jAmbientHypothesis +jHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jHypothesis = M.filter (not . isDisallowed . hi_provenance) . jEntireHypothesis ------------------------------------------------------------------------------ --- | Just the local hypothesis. -jLocalHypothesis :: Judgement' a -> Map OccName a -jLocalHypothesis = _jHypothesis +-- | The whole hypothesis, including things disallowed. +jEntireHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jEntireHypothesis = _jHypothesis + +------------------------------------------------------------------------------ +-- | Just the local hypothesis. +jLocalHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jLocalHypothesis = M.filter (isLocalHypothesis . hi_provenance) . jHypothesis -isPatVal :: Judgement' a -> OccName -> Bool -isPatVal j n = S.member n $ _jPatternVals j -isTopHole :: Judgement' a -> Bool -isTopHole = _jIsTopHole +isTopHole :: Context -> Judgement' a -> Maybe OccName +isTopHole ctx = + bool Nothing (Just $ extremelyStupid__definingFunction ctx) . _jIsTopHole unsetIsTopHole :: Judgement' a -> Judgement' a unsetIsTopHole = field @"_jIsTopHole" .~ False @@ -194,9 +315,15 @@ unsetIsTopHole = field @"_jIsTopHole" .~ False ------------------------------------------------------------------------------ -- | Only the hypothesis members which are pattern vals -jPatHypothesis :: Judgement' a -> Map OccName a -jPatHypothesis jdg - = M.restrictKeys (jHypothesis jdg) $ _jPatternVals jdg +jPatHypothesis :: Judgement' a -> Map OccName PatVal +jPatHypothesis = M.mapMaybe (getPatVal . hi_provenance) . jHypothesis + + +getPatVal :: Provenance-> Maybe PatVal +getPatVal prov = + case prov of + PatternMatchPrv pv -> Just pv + _ -> Nothing jGoal :: Judgement' a -> a @@ -206,6 +333,7 @@ jGoal = _jGoal substJdg :: TCvSubst -> Judgement -> Judgement substJdg subst = fmap $ coerce . substTy subst . coerce + mkFirstJudgement :: M.Map OccName CType -- ^ local hypothesis -> M.Map OccName CType -- ^ ambient hypothesis @@ -214,15 +342,41 @@ mkFirstJudgement -> Type -> Judgement' CType mkFirstJudgement hy ambient top posvals goal = Judgement - { _jHypothesis = hy - , _jAmbientHypothesis = ambient - , _jDestructed = mempty - , _jPatternVals = mempty + { _jHypothesis = M.map mkLocalHypothesisInfo hy + <> M.map mkAmbientHypothesisInfo ambient , _jBlacklistDestruct = False , _jWhitelistSplit = True + , _jDestructed = mempty , _jPositionMaps = posvals - , _jAncestry = mempty , _jIsTopHole = top , _jGoal = CType goal } + +mkLocalHypothesisInfo :: a -> HyInfo a +mkLocalHypothesisInfo = HyInfo UserPrv + + +mkAmbientHypothesisInfo :: a -> HyInfo a +mkAmbientHypothesisInfo = HyInfo ImportPrv + + +isLocalHypothesis :: Provenance -> Bool +isLocalHypothesis UserPrv{} = True +isLocalHypothesis PatternMatchPrv{} = True +isLocalHypothesis TopLevelArgPrv{} = True +isLocalHypothesis _ = False + + +isPatternMatch :: Provenance -> Bool +isPatternMatch PatternMatchPrv{} = True +isPatternMatch _ = False + +isDisallowed :: Provenance -> Bool +isDisallowed DisallowedPrv{} = True +isDisallowed _ = False + +expandDisallowed :: Provenance -> Provenance +expandDisallowed (DisallowedPrv _ prv) = expandDisallowed prv +expandDisallowed prv = prv + diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs index f3e41c0061..26108c5cbc 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs @@ -74,7 +74,7 @@ runTactic ctx jdg t = let skolems = nub $ foldMap (tyCoVarsOfTypeWellScoped . unCType) $ jGoal jdg - : (toList $ jHypothesis jdg) + : (fmap hi_type $ toList $ jHypothesis jdg) unused_topvals = nub $ join $ join $ toList $ _jPositionMaps jdg tacticState = defaultTacticState @@ -118,20 +118,29 @@ tracing s (TacticT m) mapExtract' (first $ rose s . pure) $ runStateT m jdg -recursiveCleanup +------------------------------------------------------------------------------ +-- | Recursion is allowed only when we can prove it is on a structurally +-- smaller argument. The top of the 'ts_recursion_stack' is set to 'True' iff +-- one of the recursive arguments is a pattern val (ie. came from a pattern +-- match.) +guardStructurallySmallerRecursion :: TacticState -> Maybe TacticError -recursiveCleanup s = - let r = head $ ts_recursion_stack s - in case r of - True -> Nothing - False -> Just NoProgress +guardStructurallySmallerRecursion s = + case head $ ts_recursion_stack s of + True -> Nothing + False -> Just NoProgress -setRecursionFrameData :: MonadState TacticState m => Bool -> m () -setRecursionFrameData b = do +------------------------------------------------------------------------------ +-- | Mark that the current recursive call is structurally smaller, due to +-- having been matched on a pattern value. +-- +-- Implemented by setting the top of the 'ts_recursion_stack'. +markStructuralySmallerRecursion :: MonadState TacticState m => m () +markStructuralySmallerRecursion = do modify $ withRecursionStack $ \case - (_ : bs) -> b : bs + (_ : bs) -> True : bs [] -> [] @@ -159,7 +168,7 @@ scoreSolution ext TacticState{..} holes , Penalize $ S.size ts_unused_top_vals , Penalize $ S.size ts_intro_vals , Reward $ S.size ts_used_vals - , Penalize $ ts_recursion_penality + , Penalize $ ts_recursion_count , Penalize $ solutionSize ext ) diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs index f1c2a6d220..3b1ce087e5 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs @@ -55,10 +55,9 @@ assume :: OccName -> TacticsM () assume name = rule $ \jdg -> do let g = jGoal jdg case M.lookup name $ jHypothesis jdg of - Just ty -> do + Just (hi_type -> ty) -> do unify ty $ jGoal jdg - when (M.member name $ jPatHypothesis jdg) $ - setRecursionFrameData True + when (M.member name $ jPatHypothesis jdg) markStructuralySmallerRecursion useOccName jdg name pure $ (tracePrim $ "assume " <> occNameString name, ) $ noLoc $ var' name Nothing -> throwError $ UndefinedHypothesis name @@ -68,11 +67,11 @@ recursion :: TacticsM () recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions attemptOn (const $ fmap fst defs) $ \name -> do - modify $ withRecursionStack (False :) - penalizeRecursion - ensure recursiveCleanup (withRecursionStack tail) $ do - (localTactic (apply name) $ introducingAmbient defs) - <@> fmap (localTactic assumption . filterPosition name) [0..] + modify $ pushRecursionStack . countRecursiveCall + jdg <- goal + ensure guardStructurallySmallerRecursion popRecursionStack $ do + (localTactic (apply name) $ introducingRecursively defs) + <@> fmap (localTactic assumption . filterPosition''' name) [0..] ------------------------------------------------------------------------------ @@ -86,17 +85,19 @@ intros = rule $ \jdg -> do ([], _) -> throwError $ GoalMismatch "intros" g (as, b) -> do vs <- mkManyGoodNames hy as - let jdg' = introducing (zip vs $ coerce as) + let top_hole = isTopHole ctx jdg + let jdg' = traceIdX "introduced lambda" + $ introducingLambda top_hole (zip vs $ coerce as) $ withNewGoal (CType b) jdg modify $ withIntroducedVals $ mappend $ S.fromList vs - when (isTopHole jdg) $ addUnusedTopVals $ S.fromList vs + when (isJust top_hole) $ addUnusedTopVals $ S.fromList vs (tr, sg) <- newSubgoal $ bool id (withPositionMapping (extremelyStupid__definingFunction ctx) vs) - (isTopHole jdg) + (isJust top_hole) $ jdg' pure . (rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") $ pure tr, ) @@ -110,20 +111,24 @@ intros = rule $ \jdg -> do destructAuto :: OccName -> TacticsM () destructAuto name = requireConcreteHole $ tracing "destruct(auto)" $ do jdg <- goal - case hasDestructed jdg name of - True -> throwError $ AlreadyDestructed name - False -> - let subtactic = rule $ destruct' (const subgoal) name - in case isPatVal jdg name of - True -> - pruning subtactic $ \jdgs -> - let getHyTypes = S.fromList . fmap snd . M.toList . jHypothesis - new_hy = foldMap getHyTypes jdgs - old_hy = getHyTypes jdg - in case S.null $ new_hy S.\\ old_hy of - True -> Just $ UnhelpfulDestruct name - False -> Nothing - False -> subtactic + case M.lookup name $ jHypothesis jdg of + Nothing -> throwError $ NotInScope name + Just hi -> + case hasDestructed jdg name of + True -> throwError $ AlreadyDestructed name + False -> + let subtactic = rule $ destruct' (const subgoal) name + in case isPatternMatch $ hi_provenance hi of + True -> + pruning subtactic $ \jdgs -> + let getHyTypes = S.fromList . fmap (hi_type . snd) . M.toList . jHypothesis + new_hy = foldMap getHyTypes jdgs + old_hy = getHyTypes jdg + in case S.null $ new_hy S.\\ old_hy of + True -> Just $ UnhelpfulDestruct name + False -> Nothing + False -> subtactic + ------------------------------------------------------------------------------ -- | Case split, and leave holes in the matches. @@ -167,7 +172,7 @@ apply func = requireConcreteHole $ tracing ("apply' " <> show func) $ do let hy = jHypothesis jdg g = jGoal jdg case M.lookup func hy of - Just (CType ty) -> do + Just (hi_type -> CType ty) -> do ty' <- freshTyvars ty let (_, _, args, ret) = tacticsSplitFunTy ty' requireNewHoles $ rule $ \jdg -> do @@ -283,12 +288,12 @@ auto' n = do overFunctions :: (OccName -> TacticsM ()) -> TacticsM () overFunctions = - attemptOn $ M.keys . M.filter (isFunction . unCType) . jHypothesis + attemptOn $ M.keys . M.filter (isFunction . unCType . hi_type) . jHypothesis overAlgebraicTerms :: (OccName -> TacticsM ()) -> TacticsM () overAlgebraicTerms = attemptOn $ - M.keys . M.filter (isJust . algebraicTyCon . unCType) . jHypothesis + M.keys . M.filter (isJust . algebraicTyCon . unCType . hi_type) . jHypothesis allNames :: Judgement -> [OccName] diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs index 6b4201b49a..0f76560b43 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeApplications #-} @@ -39,7 +40,7 @@ import Refinery.Tactic import System.IO.Unsafe (unsafePerformIO) import Type import UniqSupply (takeUniqFromSupply, mkSplitUniqSupply, UniqSupply) -import Unique (Unique) +import Unique (nonDetCmpUnique, Uniquable, getUnique, Unique) ------------------------------------------------------------------------------ @@ -70,6 +71,9 @@ instance Show (LHsExpr GhcPs) where instance Show DataCon where show = unsafeRender +instance Show Class where + show = unsafeRender + ------------------------------------------------------------------------------ data TacticState = TacticState @@ -87,7 +91,7 @@ data TacticState = TacticState -- ^ Stack for tracking whether or not the current recursive call has -- used at least one smaller pat val. Recursive calls for which this -- value is 'False' are guaranteed to loop, and must be pruned. - , ts_recursion_penality :: !Int + , ts_recursion_count :: !Int -- ^ Number of calls to recursion. We penalize each. , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) @@ -113,7 +117,7 @@ defaultTacticState = , ts_intro_vals = mempty , ts_unused_top_vals = mempty , ts_recursion_stack = mempty - , ts_recursion_penality = 0 + , ts_recursion_count = 0 , ts_unique_gen = unsafeDefaultUniqueSupply } @@ -132,6 +136,12 @@ withRecursionStack withRecursionStack f = field @"ts_recursion_stack" %~ f +pushRecursionStack :: TacticState -> TacticState +pushRecursionStack = withRecursionStack (False :) + +popRecursionStack :: TacticState -> TacticState +popRecursionStack = withRecursionStack tail + withUsedVals :: (Set OccName -> Set OccName) -> TacticState -> TacticState withUsedVals f = @@ -143,26 +153,72 @@ withIntroducedVals f = field @"ts_intro_vals" %~ f +data Provenance + = TopLevelArgPrv + OccName -- ^ Function name + Int -- ^ Position + | PatternMatchPrv PatVal + | ClassMethodPrv + (Uniquely Class) -- ^ Class + | UserPrv + | ImportPrv + | RecursivePrv + | DisallowedPrv DisallowReason Provenance + deriving stock (Eq, Show, Generic, Ord) + + +data DisallowReason + = WrongBranch Int + | RecursiveCall + deriving stock (Eq, Show, Generic, Ord) + + +data PatVal = PatVal + { pv_scrutinee :: Maybe OccName + -- ^ Original scrutinee which created this PatVal. Nothing, for lambda + -- case. + , pv_ancestry :: Set OccName + , pv_datacon :: Uniquely DataCon + , pv_position :: Int + } deriving stock (Eq, Show, Generic, Ord) + + +newtype Uniquely a = Uniquely { getViaUnique :: a } + deriving Show via a + +instance Uniquable a => Eq (Uniquely a) where + (==) = (==) `on` getUnique . getViaUnique + +instance Uniquable a => Ord (Uniquely a) where + compare = nonDetCmpUnique `on` getUnique . getViaUnique + + +data HyInfo a = HyInfo + { hi_provenance :: Provenance + , hi_type :: a + } + deriving stock (Functor, Eq, Show, Generic, Ord) + +overProvenance :: (Provenance -> Provenance) -> HyInfo a -> HyInfo a +overProvenance f (HyInfo prv ty) = HyInfo (f prv) ty + ------------------------------------------------------------------------------ -- | The current bindings and goal for a hole to be filled by refinery. data Judgement' a = Judgement - { _jHypothesis :: !(Map OccName a) - , _jAmbientHypothesis :: !(Map OccName a) - -- ^ Things in the hypothesis that were imported. Solutions don't get - -- points for using the ambient hypothesis. + { _jHypothesis :: !(Map OccName (HyInfo a)) , _jDestructed :: !(Set OccName) - -- ^ These should align with keys of _jHypothesis - , _jPatternVals :: !(Set OccName) - -- ^ These should align with keys of _jHypothesis + -- ^ Set of names we've already destructed. These should align with keys of + -- _jHypothesis. You might think we could just inspect the hypothesis and + -- find any PatVals whose scrutinee is the name in question, but this fails + -- for nullary data constructors. , _jBlacklistDestruct :: !(Bool) , _jWhitelistSplit :: !(Bool) , _jPositionMaps :: !(Map OccName [[OccName]]) - , _jAncestry :: !(Map OccName (Set OccName)) , _jIsTopHole :: !Bool , _jGoal :: !(a) } - deriving stock (Eq, Ord, Generic, Functor, Show) + deriving stock (Eq, Generic, Functor, Show) type Judgement = Judgement' CType @@ -191,6 +247,7 @@ data TacticError | UnhelpfulDestruct OccName | UnhelpfulSplit OccName | TooPolymorphic + | NotInScope OccName deriving stock (Eq) instance Show TacticError where @@ -229,6 +286,8 @@ instance Show TacticError where "Splitting constructor " <> show n <> " leads to no new goals" show TooPolymorphic = "The tactic isn't applicable because the goal is too polymorphic" + show (NotInScope name) = + "Tried to do something with the out of scope name " <> show name ------------------------------------------------------------------------------