diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index a169ce4da..ccdb17108 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -165,15 +165,14 @@ aInstanceDef (CInstanceDef (WithSrc clNameId clName) args givens methods instNam aDef :: CDef -> SyntaxM (SourceNameW, ULamExpr VoidS) aDef (CDef name params optRhs optGivens body) = do explicitParams <- explicitBindersOptAnn params - let rhsDefault = (ExplicitApp, Nothing, Nothing) - (expl, effs, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, optEffs, resultTy) -> do - effs <- fromMaybeM optEffs UPure aEffects + let rhsDefault = (ExplicitApp, Nothing) + (expl, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, resultTy) -> do resultTy' <- expr resultTy - return (expl, Just effs, Just resultTy') + return (expl, Just resultTy') implicitParams <- aOptGivens optGivens let allParams = implicitParams >>> explicitParams body' <- block body - return (name, ULamExpr allParams expl effs resultTy body') + return (name, ULamExpr allParams expl resultTy body') stripParens :: GroupW -> GroupW stripParens (WithSrcs _ _ (CParens [g])) = stripParens g @@ -356,26 +355,6 @@ identifier ctx (WithSrcs sid _ g) = case g of CLeaf (CIdentifier name) -> return $ WithSrc sid name _ -> throw sid $ ExpectedIdentifier ctx -aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS) -aEffects (WithSrcs _ _ (effs, optEffTail)) = do - lhs <- mapM effect effs - rhs <- forM optEffTail \effTail -> - fromSourceNameW <$> identifier "effect row remainder variable" effTail - return $ UEffectRow (S.fromList lhs) rhs - -effect :: GroupW -> SyntaxM (UEffect VoidS) -effect (WithSrcs grpSid _ grp) = case grp of - CParens [g] -> effect g - CJuxtapose True (Identifier "Read" ) (WithSrcs sid _ (CLeaf (CIdentifier h))) -> - return $ URWSEffect Reader $ fromSourceNameW (WithSrc sid h) - CJuxtapose True (Identifier "Accum") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> - return $ URWSEffect Writer $ fromSourceNameW (WithSrc sid h) - CJuxtapose True (Identifier "State") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> - return $ URWSEffect State $ fromSourceNameW (WithSrc sid h) - CLeaf (CIdentifier "Except") -> return UExceptionEffect - CLeaf (CIdentifier "IO" ) -> return UIOEffect - _ -> throw grpSid UnexpectedEffectForm - aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS)) aMethod (WithSrcs _ _ CPass) = return Nothing aMethod (WithSrcs sid _ d) = Just . WithSrcE sid <$> case d of @@ -383,7 +362,7 @@ aMethod (WithSrcs sid _ d) = Just . WithSrcE sid <$> case d of (WithSrc nameSid name, lam) <- aDef def return $ UMethodDef (SourceName nameSid name) lam CLet (WithSrcs lhsSid _ (CLeaf (CIdentifier name))) rhs -> do - rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs + rhs' <- ULamExpr Empty ImplicitApp Nothing <$> block rhs return $ UMethodDef (fromSourceNameW (WithSrc lhsSid name)) rhs' _ -> throw sid UnexpectedMethodDef @@ -407,7 +386,7 @@ blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs body <- block $ IndentedBlock sid ds -- Not really the right SrcId - let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing Nothing body + let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing body return (Empty, WithSrcE sid $ extendAppRight rhs' (WithSrcE sid lam)) blockDecls (d:ds) = do d' <- decl PlainLet d @@ -428,13 +407,12 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of -- should be detected upstream, before calling expr. CBrackets gs -> UTabCon <$> mapM expr gs CGivens _ -> throw sid UnexpectedGivenClause - CArrow lhs effs rhs -> do + CArrow lhs rhs -> do case lhs of WithSrcs _ _ (CParens gs) -> do bs <- aPiBinders gs - effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs - return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy + return $ UPi $ UPiExpr bs ExplicitApp resultTy WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens CDo b -> UDo <$> block b CJuxtapose hasSpace lhs rhs -> case hasSpace of @@ -476,7 +454,7 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of WithSrcs _ _ (CParens gs) -> do bs <- aPiBinders gs resultTy <- expr rhs - return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy + return $ UPi $ UPiExpr bs ImplicitApp resultTy WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens FatArrow -> do lhs' <- tyOptPat lhs @@ -501,7 +479,7 @@ expr (WithSrcs sid sids grp) = WithSrcE sid <$> case grp of CLambda params body -> do params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body - return $ ULam $ ULamExpr params' ExplicitApp Nothing Nothing body' + return $ ULam $ ULamExpr params' ExplicitApp Nothing body' CFor kind indices body -> do let (dir, trailingUnit) = case kind of KFor -> (Fwd, False) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index c867b6bf4..76d1a3202 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -15,8 +15,8 @@ import Control.Monad.Reader import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M -import Data.Foldable (fold) import Data.Graph (graphFromEdges, topSort) +import Data.Foldable (fold) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe @@ -709,7 +709,7 @@ buildCase' scrut resultTy indexedAltBody = case scrut of blk <- buildBlock $ indexedAltBody i $ toAtom $ sink x return $ blk `PairE` getEffects blk return (Abs b' body, ignoreHoistFailure $ hoist b' eff') - return $ Case scrut alts $ EffTy (mconcat effs) resultTy + return $ Case scrut alts $ EffTy (fold effs) resultTy buildCase :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n @@ -717,20 +717,6 @@ buildCase :: (Emits n, ScopableBuilder r m) -> m n (Atom r n) buildCase s r b = emit =<< buildCase' s r b -buildEffLam - :: ScopableBuilder r m - => NameHint -> Type r n - -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) - -> m n (LamExpr r n) -buildEffLam hint ty body = do - withFreshBinder noHint (TyCon HeapType) \h -> do - let ty' = RefTy (toAtom $ binderVar h) (sink ty) - withFreshBinder hint ty' \b -> do - let ref = binderVar b - hVar <- sinkM $ binderVar h - body' <- buildBlock $ body (sink hVar) $ sink ref - return $ LamExpr (BinaryNest h b) body' - emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) emitHof hof = mkTypedHof hof >>= emit @@ -766,35 +752,6 @@ buildMap xs f = do buildFor noHint Fwd (tabIxType t) \i -> tabApp (sink xs) (toAtom i) >>= f -emitRunWriter - :: (Emits n, ScopableBuilder r m) - => NameHint -> Type r n -> BaseMonoid r n - -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -emitRunWriter hint accTy bm body = do - lam <- buildEffLam hint accTy \h ref -> body h ref - emitHof $ RunWriter Nothing bm lam - -emitRunState - :: (Emits n, ScopableBuilder r m) - => NameHint -> Atom r n - -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -emitRunState hint initVal body = do - stateTy <- return $ getType initVal - lam <- buildEffLam hint stateTy \h ref -> body h ref - emitHof $ RunState Nothing initVal lam - -emitRunReader - :: (Emits n, ScopableBuilder r m) - => NameHint -> Atom r n - -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -emitRunReader hint r body = do - rTy <- return $ getType r - lam <- buildEffLam hint rTy \h ref -> body h ref - emitHof $ RunReader r lam - emitSeq :: (Emits n, ScopableBuilder SimpIR m) => Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n -> m n (Atom SimpIR n) @@ -806,8 +763,7 @@ mkSeq :: EnvReader m => Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n -> m n (DAMOp SimpIR n) mkSeq d t x f = do - effTy <- functionEffs f - return $ Seq effTy d t x f + return $ Seq undefined d t x f buildRememberDest :: (Emits n, ScopableBuilder SimpIR m) => NameHint -> SAtom n @@ -816,8 +772,7 @@ buildRememberDest :: (Emits n, ScopableBuilder SimpIR m) buildRememberDest hint dest cont = do ty <- return $ getType dest doit <- buildUnaryLamExpr hint ty cont - effs <- functionEffs doit - emit $ PrimOp $ DAMOp $ RememberDest effs dest doit + emit $ PrimOp $ DAMOp $ RememberDest undefined dest doit -- === vector space (ish) type class === @@ -1040,15 +995,16 @@ mkBlock (Abs decls body) = do return $ Block effTy block blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) -blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do - effs <- declsEffects decls mempty - return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result - where - declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) - declsEffects Empty !acc = return acc - declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do - expr' <- sinkM expr - declsEffects rest $ acc <> getEffects expr' +blockEffTy _ = undefined +-- blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do +-- effs <- declsEffects decls mempty +-- return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result +-- where +-- declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) +-- declsEffects Empty !acc = return acc +-- declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do +-- expr' <- sinkM expr +-- declsEffects rest $ acc <> getEffects expr' mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n) mkApp f xs = do @@ -1084,15 +1040,9 @@ mkInstanceDict instanceName args = do mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) mkCase scrut resultTy alts = liftEnvReaderM do - eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do - return $ ignoreHoistFailure $ hoist b $ getEffects body + eff' <- undefined return $ Case scrut alts (EffTy eff' resultTy) -mkCatchException :: EnvReader m => CExpr n -> m n (Hof CoreIR n) -mkCatchException body = do - resultTy <- makePreludeMaybeTy (getType body) - return $ CatchException resultTy body - app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) app x i = mkApp x [i] >>= emit @@ -1134,9 +1084,7 @@ ptrOffset x i = emit $ MemOp $ PtrOffset x i {-# INLINE ptrOffset #-} unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -unsafePtrLoad x = do - body <- liftEmitBuilder $ buildBlock $ emit . MemOp . PtrLoad =<< sinkM x - emitHof $ RunIO body +unsafePtrLoad x = emit . MemOp . PtrLoad =<< sinkM x mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n) mkIndexRef ref i = do @@ -1198,102 +1146,6 @@ emitIf predicate resultTy trueCase falseCase = do 1 -> trueCase _ -> error "should only have two cases" -emitMaybeCase :: (Emits n, ScopableBuilder r m) - => Atom r n -> Type r n - -> (forall l. (Emits l, DExt n l) => m l (Atom r l)) - -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) - -> m n (Atom r n) -emitMaybeCase scrut resultTy nothingCase justCase = do - buildCase scrut resultTy \i v -> - case i of - 0 -> nothingCase - 1 -> justCase v - _ -> error "should be a binary scrutinee" - --- Maybe a -> a -fromJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) -fromJustE x = liftEmitBuilder do - MaybeTy a <- return $ getType x - emitMaybeCase x a - (emit $ MiscOp $ ThrowError $ sink a) - (return) - --- Maybe a -> Bool -isJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) -isJustE x = liftEmitBuilder $ - emitMaybeCase x BoolTy (return FalseAtom) (\_ -> return TrueAtom) - --- Monoid a -> (n=>a) -> a -reduceE :: (Emits n, SBuilder m) => BaseMonoid SimpIR n -> SAtom n -> m n (SAtom n) -reduceE monoid xs = liftEmitBuilder do - TabPi tabPi <- return $ getTyCon xs - let a = assumeConst tabPi - getSnd =<< emitRunWriter noHint a monoid \_ ref -> - buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do - x <- tabApp (sink xs) (toAtom i) - emit $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x - -andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n) -andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $ - buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y -> - emit $ BinOp BAnd (sink $ toAtom x) (toAtom y) - --- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b) -mapE :: (Emits n, ScopableBuilder SimpIR m) - => (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l)) - -> SAtom n -> m n (SAtom n) -mapE cont xs = do - TabPi tabPi <- return $ getTyCon xs - buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> do - tabApp (sink xs) (toAtom i) >>= cont - --- (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = -catMaybesE :: (Emits n, SBuilder m) => SAtom n -> m n (SAtom n) -catMaybesE maybes = do - TabTy d n (MaybeTy a) <- return $ getType maybes - justs <- liftEmitBuilder $ mapE isJustE maybes - monoid <- andMonoid - allJust <- reduceE monoid justs - liftEmitBuilder $ emitIf allJust (MaybeTy $ TabTy d n a) - (JustAtom (sink $ TabTy d n a) <$> mapE fromJustE (sink maybes)) - (return (NothingAtom $ sink $ TabTy d n a)) - -emitWhile :: (Emits n, ScopableBuilder r m) - => (forall l. (Emits l, DExt n l) => m l (Atom r l)) - -> m n () -emitWhile cont = do - body <- buildBlock cont - void $ emitHof $ While body - --- Dex implementation, for reference --- def whileMaybe (eff:Effects) -> (body: Unit -> {|eff} (Maybe Word8)) : {|eff} Maybe Unit = --- hadError = yieldState False \ref. --- while do --- ans = liftState ref body () --- case ans of --- Nothing -> --- ref := True --- False --- Just cond -> W8ToB cond --- if hadError --- then Nothing --- else Just () - -runMaybeWhile :: (Emits n, ScopableBuilder r m) - => (forall l. (Emits l, DExt n l) => m l (Atom r l)) - -> m n (Atom r n) -runMaybeWhile body = do - hadError <- getSnd =<< emitRunState noHint FalseAtom \_ ref -> do - emitWhile do - ans <- body - emitMaybeCase ans Word8Ty - (emit (RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom) - (return) - return UnitVal - emitIf hadError (MaybeTy UnitTy) - (return $ NothingAtom UnitTy) - (return $ JustAtom UnitTy UnitVal) - -- === capturing closures with telescopes === type ReconAbs r e = Abs (ReconBinders r) e diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index d1a4daf71..81da044d9 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -232,7 +232,7 @@ typeOfTabApp _ _ = error "expected a TabPi type" typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) typeOfApp (TyCon (Pi piTy)) xs = withSubstReaderT $ - withInstantiated piTy xs \(EffTy _ ty) -> substM ty + withInstantiated piTy xs \ty -> substM ty typeOfApp _ _ = error "expected a pi type" repValAtom :: EnvReader m => RepVal n -> m n (SAtom n) @@ -277,7 +277,6 @@ unwrapNewtypeType = \case def <- lookupTyCon defName ty' <- dataDefRep <$> instantiateTyConDef def params return (UserADTData sn defName params, ty') - ty -> error $ "Shouldn't be projecting: " ++ pprint ty {-# INLINE unwrapNewtypeType #-} instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) @@ -440,13 +439,7 @@ instance IRRep r => VisitGeneric (TypedHof r) r where instance IRRep r => VisitGeneric (Hof r) r where visitGeneric = \case For ann d lam -> For ann <$> visitGeneric d <*> visitGeneric lam - RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body - RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body - RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body While b -> While <$> visitBlock b - RunIO b -> RunIO <$> visitBlock b - RunInit b -> RunInit <$> visitBlock b - CatchException t b -> CatchException <$> visitType t <*> visitBlock b Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x @@ -461,23 +454,10 @@ instance IRRep r => VisitGeneric (DAMOp r) r where Place x y -> Place <$> visitGeneric x <*> visitGeneric y Freeze x -> Freeze <$> visitGeneric x -instance IRRep r => VisitGeneric (Effect r) r where +instance IRRep r => VisitGeneric (Effects r) r where visitGeneric = \case - RWSEffect rws h -> RWSEffect rws <$> visitGeneric h - ExceptionEffect -> pure ExceptionEffect - IOEffect -> pure IOEffect - InitEffect -> pure InitEffect - -instance IRRep r => VisitGeneric (EffectRow r) r where - visitGeneric (EffectRow effs tailVar) = do - effs' <- eSetFromList <$> mapM visitGeneric (eSetToList effs) - tailEffRow <- case tailVar of - NoTail -> return $ EffectRow mempty NoTail - EffectRowTail v -> visitGeneric (toAtom v) <&> \case - Stuck _ (Var v') -> EffectRow mempty (EffectRowTail v') - Con (Eff r) -> r - _ -> error "Not a valid effect substitution" - return $ extendEffRow effs' tailEffRow + Pure -> return Pure + Effectful -> return Effectful instance IRRep r => VisitGeneric (DictCon r) r where visitGeneric = \case @@ -491,14 +471,12 @@ instance IRRep r => VisitGeneric (Con r) r where Lit l -> return $ Lit l ProdCon xs -> ProdCon <$> mapM visitGeneric xs SumCon ty con arg -> SumCon <$> mapM visitGeneric ty <*> return con <*> visitGeneric arg - HeapVal -> return HeapVal DepPair x y t -> do x' <- visitGeneric x y' <- visitGeneric y ~(DepPairTy t') <- visitGeneric $ DepPairTy t return $ DepPair x' y' t' Lam lam -> Lam <$> visitGeneric lam - Eff eff -> Eff <$> visitGeneric eff DictConAtom d -> DictConAtom <$> visitGeneric d TyConAtom t -> TyConAtom <$> visitGeneric t NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x @@ -513,7 +491,6 @@ instance VisitGeneric NewtypeTyCon CoreIR where visitGeneric = \case Nat -> return Nat Fin x -> Fin <$> visitGeneric x - EffectRowKind -> return EffectRowKind UserADTType n v params -> UserADTType n <$> renameN v <*> visitGeneric params instance VisitGeneric TyConParams CoreIR where @@ -539,14 +516,14 @@ instance VisitGeneric CorePiType CoreIR where instance IRRep r => VisitGeneric (TabPiType r) r where visitGeneric (TabPiType d b eltTy) = do d' <- visitGeneric d - visitGeneric (PiType (UnaryNest b) (EffTy Pure eltTy)) <&> \case - PiType (UnaryNest b') (EffTy Pure eltTy') -> TabPiType d' b' eltTy' + visitGeneric (PiType (UnaryNest b) eltTy) <&> \case + PiType (UnaryNest b') eltTy' -> TabPiType d' b' eltTy' _ -> error "not a table pi type" instance IRRep r => VisitGeneric (DepPairType r) r where visitGeneric (DepPairType expl b ty) = do - visitGeneric (PiType (UnaryNest b) (EffTy Pure ty)) <&> \case - PiType (UnaryNest b') (EffTy Pure ty') -> DepPairType expl b' ty' + visitGeneric (PiType (UnaryNest b) ty) <&> \case + PiType (UnaryNest b') ty' -> DepPairType expl b' ty' _ -> error "not a dependent pair type" instance VisitGeneric RepVal SimpIR where @@ -573,7 +550,7 @@ instance VisitGeneric DataConDefs CoreIR where instance VisitGeneric DataConDef CoreIR where visitGeneric (DataConDef sn (Abs bs UnitE) repTy ps) = do - PiType bs' _ <- visitGeneric $ PiType bs $ EffTy Pure UnitTy + PiType bs' _ <- visitGeneric $ PiType bs UnitTy repTy' <- visitGeneric repTy return $ DataConDef sn (Abs bs' UnitE) repTy' ps @@ -582,8 +559,7 @@ instance IRRep r => VisitGeneric (TyCon r) r where BaseType bt -> return $ BaseType bt ProdType tys -> ProdType <$> mapM visitGeneric tys SumType tys -> SumType <$> mapM visitGeneric tys - RefType h t -> RefType <$> visitGeneric h <*> visitGeneric t - HeapType -> return HeapType + RefType h t -> RefType h <$> visitGeneric t TabPi t -> TabPi <$> visitGeneric t DepPairTy t -> DepPairTy <$> visitGeneric t TypeKind -> return TypeKind @@ -702,22 +678,6 @@ liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of rec = liftSimpAtom {-# INLINE liftSimpAtom #-} -instance IRRep r => SubstE AtomSubstVal (EffectRow r) where - substE env (EffectRow effs tailVar) = do - let effs' = eSetFromList $ map (substE env) (eSetToList effs) - let tailEffRow = case tailVar of - NoTail -> EffectRow mempty NoTail - EffectRowTail (AtomVar v _) -> case snd env ! v of - Rename v' -> do - let v'' = runEnvReaderM (fst env) $ toAtomVar v' - EffectRow mempty (EffectRowTail v'') - SubstVal (Stuck _ (Var v')) -> EffectRow mempty (EffectRowTail v') - SubstVal (Con (Eff r)) -> r - _ -> error "Not a valid effect substitution" - extendEffRow effs' tailEffRow - -instance IRRep r => SubstE AtomSubstVal (Effect r) - instance SubstE AtomSubstVal SpecializationSpec where substE env (AppSpecialization (AtomVar f _) ab) = do let f' = case snd env ! f of @@ -726,8 +686,11 @@ instance SubstE AtomSubstVal SpecializationSpec where _ -> error "bad substitution" AppSpecialization f' (substE env ab) -instance SubstE AtomSubstVal EffectDef -instance SubstE AtomSubstVal EffectOpType +instance SubstE AtomSubstVal (Effects r) where + substE _ = \case + Pure -> Pure + Effectful -> Effectful + instance SubstE AtomSubstVal IExpr instance SubstE AtomSubstVal RepVal instance SubstE AtomSubstVal TyConParams diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index b15b95331..d8d7cc047 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -101,9 +101,6 @@ class HasNamesB b => CheckableB (r::IR) (b::B) | b -> r where -> (forall o'. DExt o o' => b o o' -> TyperM r i' o' a) -> TyperM r i o a -class SinkableE e => CheckableWithEffects (r::IR) (e::E) | e -> r where - checkWithEffects :: EffectRow r o -> e i -> TyperM r i o (e o) - checkBEvidenced :: CheckableB r b => b i i' -> (forall o'. Distinct o' => ExtEvidence o o' -> b o o' -> TyperM r i' o' a) @@ -124,11 +121,8 @@ checkAndGetType x = do x' <- checkE x return (x', getType x') -checkWithEffTy :: (CheckableWithEffects r e, HasType r e, IRRep r) => EffTy r o -> e i -> TyperM r i o (e o) -checkWithEffTy (EffTy effs ty) e = do - e' <- checkWithEffects effs e - checkTypesEq ty (getType e') - return e' +checkWithEffTy :: (HasType r e, CheckableE r e, IRRep r) => e i -> EffTy r o -> TyperM r i o (e o) +checkWithEffTy e (EffTy _ ty) = e |: ty instance CheckableE CoreIR SourceMap where checkE sm = renameM sm -- TODO? @@ -197,15 +191,15 @@ checkBinderType ty b cont = do checkTypesEq (sink $ binderType b') (sink ty) cont b' -instance IRRep r => CheckableWithEffects r (Expr r) where - checkWithEffects allowedEffs expr = case expr of +instance IRRep r => CheckableE r (Expr r) where + checkE = \case App effTy f xs -> do - effTy' <- checkEffTy allowedEffs effTy + effTy'@(EffTy _ ty) <- checkE effTy f' <- checkE f TyCon (Pi piTy) <- return $ getType f' xs' <- mapM checkE xs - effTy'' <- checkInstantiation piTy xs' - checkAlphaEq effTy' effTy'' + ty' <- checkInstantiation piTy xs' + checkAlphaEq ty ty' return $ App effTy' f' xs' TabApp reqTy f x -> do reqTy' <- checkE reqTy @@ -216,38 +210,39 @@ instance IRRep r => CheckableWithEffects r (Expr r) where return $ TabApp reqTy' f' x' TopApp effTy f xs -> do f' <- renameM f - effTy' <- checkEffTy allowedEffs effTy + effTy'@(EffTy _ ty') <- checkE effTy piTy <- getTypeTopFun f' xs' <- mapM checkE xs - effTy'' <- checkInstantiation piTy xs' - checkAlphaEq effTy' effTy'' + ty'' <- checkInstantiation piTy xs' + checkAlphaEq ty' ty'' return $ TopApp effTy' f' xs' Atom x -> Atom <$> checkE x - PrimOp op -> PrimOp <$> checkWithEffects allowedEffs op + PrimOp op -> PrimOp <$> checkE op Block effTy (Abs decls body) -> do - effTy'@(EffTy effs ty) <- checkEffTy allowedEffs effTy - checkDecls effs decls \decls' -> do - body' <- checkWithEffects (sink effs) body + effTy'@(EffTy _ ty) <- checkE effTy + checkDecls decls \decls' -> do + body' <- checkE body checkTypesEq (sink ty) (getType body') return $ Block effTy' $ Abs decls' body' Case scrut alts effTy -> do - effTy' <- checkEffTy allowedEffs effTy + effTy' <- checkE effTy scrut' <- checkE scrut TyCon (SumType altsBinderTys) <- return $ getType scrut' assertEq (length altsBinderTys) (length alts) "" alts' <- parallelAffines $ (zip alts altsBinderTys) <&> \(Abs b body, reqBinderTy) -> do checkB b \b' -> do checkTypesEq (sink reqBinderTy) (sink $ binderType b') - Abs b' <$> checkWithEffTy (sink effTy') body + Abs b' <$> checkWithEffTy body (sink effTy') return $ Case scrut' alts' effTy' - ApplyMethod effTy dict i args -> do - effTy' <- checkEffTy allowedEffs effTy + ApplyMethod (EffTy eff resultTy) dict i args -> do + eff' <- checkE eff + resultTy' <- checkE resultTy Just dict' <- toMaybeDict <$> checkE dict args' <- mapM checkE args methodTy <- getMethodType dict' i - effTy'' <- checkInstantiation methodTy args' - checkAlphaEq effTy' effTy'' - return $ ApplyMethod effTy' (toAtom dict') i args' + resultTy'' <- checkInstantiation methodTy args' + checkAlphaEq resultTy' resultTy'' + return $ ApplyMethod (EffTy eff' resultTy') (toAtom dict') i args' TabCon ty xs -> do ty'@(TyCon (TabPi (TabPiType _ b restTy))) <- checkE ty xs' <- case fromConstAbs (Abs b restTy) of @@ -298,9 +293,8 @@ instance IRRep r => CheckableE r (Stuck r) where return $ StuckTabApp f' x' InstantiatedGiven given args -> do given' <- checkE given - TyCon (Pi piTy) <- queryStuckType given' + TyCon (Pi _) <- queryStuckType given' args' <- mapM checkE args - EffTy Pure _ <- checkInstantiation piTy args' return $ InstantiatedGiven given' args' SuperclassProj i d -> SuperclassProj <$> pure i <*> checkE d -- TODO: check index in range PtrVar t v -> PtrVar t <$> renameM v @@ -378,9 +372,8 @@ instance IRRep r => CheckableE r (TyCon r) where BaseType b -> return $ BaseType b ProdType tys -> ProdType <$> mapM checkE tys SumType cs -> SumType <$> mapM checkE cs - RefType r a -> RefType <$> r|:TyCon HeapType <*> checkE a + RefType r a -> RefType r <$> checkE a TypeKind -> return TypeKind - HeapType -> return HeapType Pi t -> Pi <$> checkE t TabPi t -> TabPi <$> checkE t NewtypeTyCon t -> NewtypeTyCon <$> checkE t @@ -407,7 +400,6 @@ instance IRRep r => CheckableE r (Con r) where unless (0 <= tag && tag < length tys') $ throwInternal "Invalid SumType tag" payload' <- payload |: (tys' !! tag) return $ SumCon tys' tag payload' - HeapVal -> return HeapVal Lam lam -> Lam <$> checkE lam DepPair l r ty -> do l' <- checkE l @@ -415,7 +407,6 @@ instance IRRep r => CheckableE r (Con r) where rTy <- checkInstantiation ty' [l'] r' <- r |: rTy return $ DepPair l' r' ty' - Eff eff -> Eff <$> checkE eff -- TODO: check against cached type DictConAtom con -> DictConAtom <$> checkE con NewtypeCon con x -> do @@ -443,7 +434,6 @@ instance CheckableE CoreIR NewtypeTyCon where checkE = \case Nat -> return Nat Fin n -> Fin <$> n|:NatTy - EffectRowKind -> return EffectRowKind UserADTType sn d params -> do d' <- renameM d TyConParams expls params' <- checkE params @@ -451,16 +441,11 @@ instance CheckableE CoreIR NewtypeTyCon where void $ checkInstantiation def params' return $ UserADTType sn d' (TyConParams expls params') -instance IRRep r => CheckableWithEffects r (PrimOp r) where - checkWithEffects effs = \case +instance IRRep r => CheckableE r (PrimOp r) where + checkE = \case Hof (TypedHof effTy hof) -> do - effTy'@(EffTy effs' resultTy) <- checkE effTy - checkExtends effs effs' - -- TODO: we should be able to use the `effTy` from the `TypedHof`, which - -- might have fewer effects than `effs`. But that exposes an error in - -- which we under-report the `Init` effect in the `TypedHof` effect - -- annotation. We should fix that. - hof' <- checkHof (EffTy effs resultTy) hof + effTy' <- checkE effTy + hof' <- checkHof effTy' hof return $ Hof (TypedHof effTy' hof') VectorOp vOp -> VectorOp <$> checkE vOp BinOp binop x y -> do @@ -475,22 +460,20 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where TyCon (BaseType xTy) <- return $ getType x' checkUnOp unop xTy return $ UnOp unop x' - MiscOp op -> MiscOp <$> checkWithEffects effs op - MemOp op -> MemOp <$> checkWithEffects effs op - DAMOp op -> DAMOp <$> checkWithEffects effs op + MiscOp op -> MiscOp <$> checkE op + MemOp op -> MemOp <$> checkE op + DAMOp op -> DAMOp <$> checkE op RefOp ref m -> do (ref', TyCon (RefType h s)) <- checkAndGetType ref m' <- case m of - MGet -> declareEff effs (RWSEffect State h) $> MGet + MGet -> return MGet MPut x -> do x' <- x|:s - declareEff effs (RWSEffect State h) return $ MPut x' - MAsk -> declareEff effs (RWSEffect Reader h) $> MAsk + MAsk -> return MAsk MExtend b x -> do b' <- checkE b x' <- x|:s - declareEff effs (RWSEffect Writer h) return $ MExtend b' x' IndexRef givenTy i -> do givenTy' <- checkE givenTy @@ -518,23 +501,19 @@ instance IRRep r => CheckableE r (EffTy r) where instance IRRep r => CheckableE r (BaseMonoid r) where checkE = renameM -- TODO: check -instance IRRep r => CheckableWithEffects r (MemOp r) where - checkWithEffects effs = \case +instance IRRep r => CheckableE r (MemOp r) where + checkE = \case IOAlloc n -> do - declareEff effs IOEffect IOAlloc <$> (n |: IdxRepTy) IOFree ptr -> do - declareEff effs IOEffect IOFree <$> checkIsPtr ptr PtrOffset ptr off -> do ptr' <- checkIsPtr ptr off' <- off |: IdxRepTy return $ PtrOffset ptr' off' PtrLoad ptr -> do - declareEff effs IOEffect PtrLoad <$> checkIsPtr ptr PtrStore ptr val -> do - declareEff effs IOEffect ptr' <- checkE ptr PtrTy (_, t) <- return $ getType ptr' val' <- val |: BaseTy t @@ -546,8 +525,8 @@ checkIsPtr ptr = do PtrTy _ <- return $ getType ptr' return ptr' -instance IRRep r => CheckableWithEffects r (MiscOp r) where - checkWithEffects effs = \case +instance IRRep r => CheckableE r (MiscOp r) where + checkE = \case Select p x y -> do p' <- p |: (BaseTy $ Scalar Word8Type) x' <- checkE x @@ -587,9 +566,6 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where BaseTy (Scalar _) <- return $ getType x' return $ ShowScalar x' ThrowError ty -> ThrowError <$> checkE ty - ThrowException ty -> ThrowException <$> do - declareEff effs ExceptionEffect - checkE ty checkSomeSumType :: IRRep r => Type r o -> TyperM r i o [Type r o] checkSomeSumType = \case @@ -624,83 +600,40 @@ instance IRRep r => CheckableE r (VectorOp r) where return $ VectorSubref ref' i' ty' checkHof :: IRRep r => EffTy r o -> Hof r i -> TyperM r i o (Hof r o) -checkHof (EffTy effs reqTy) = \case +checkHof (EffTy _ reqTy) = \case For dir ixTy f -> do IxType t d <- checkE ixTy LamExpr (UnaryNest b) body <- return f TyCon (TabPi tabTy) <- return reqTy checkBinderType t b \b' -> do resultTy <- checkInstantiation (sink tabTy) [toAtom $ binderVar b'] - body' <- checkWithEffTy (EffTy (sink effs) resultTy) body + body' <- body |: resultTy return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') While body -> do - let effTy = EffTy effs (BaseTy $ Scalar Word8Type) checkTypesEq reqTy UnitTy - While <$> checkWithEffTy effTy body + body' <- body |: BaseTy (Scalar Word8Type) + return $ While body' Linearize f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkBinderType xTy b \b' -> do PairTy resultTy fLinTy <- sinkM reqTy - body' <- checkWithEffTy (EffTy Pure resultTy) body - checkTypesEq fLinTy (toType $ nonDepPiType [sink xTy] Pure resultTy) + body' <- body |: resultTy + checkTypesEq fLinTy (toType $ nonDepPiType [sink xTy] resultTy) return $ Linearize (LamExpr (UnaryNest b') body') x' Transpose f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkB b \b' -> do - body' <- checkWithEffTy (EffTy Pure (sink xTy)) body + body' <- body |: sink xTy checkTypesEq (sink $ binderType b') (sink reqTy) return $ Transpose (LamExpr (UnaryNest b') body') x' - RunReader r f -> do - (r', rTy) <- checkAndGetType r - f' <- checkRWSAction reqTy rTy effs Reader f - return $ RunReader r' f' - RunWriter d bm f -> do - -- XXX: We can't verify compatibility between the base monoid and f, because - -- the only way in which they are related in the runAccum definition is via - -- the AccumMonoid typeclass. The frontend constraints should be sufficient - -- to ensure that only well typed programs are accepted, but it is a bit - -- disappointing that we cannot verify that internally. We might want to consider - -- e.g. only disabling this check for prelude. - bm' <- checkE bm - PairTy resultTy accTy <- return reqTy - f' <- checkRWSAction resultTy accTy effs Writer f - d' <- case d of - Nothing -> return Nothing - Just dest -> do - dest' <- dest |: RawRefTy accTy - declareEff effs InitEffect - return $ Just dest' - return $ RunWriter d' bm' f' - RunState d s f -> do - (s', sTy) <- checkAndGetType s - PairTy resultTy sTy' <- return reqTy - checkTypesEq sTy sTy' - f' <- checkRWSAction resultTy sTy effs State f - d' <- case d of - Nothing -> return Nothing - Just dest -> do - declareEff effs InitEffect - Just <$> dest |: RawRefTy sTy - return $ RunState d' s' f' - RunIO body -> RunIO <$> checkWithEffTy (EffTy (extendEffect IOEffect effs) reqTy) body - RunInit body -> RunInit <$> checkWithEffTy (EffTy (extendEffect InitEffect effs) reqTy) body - CatchException reqTy' body -> do - reqTy'' <- checkE reqTy' - checkTypesEq reqTy reqTy'' - -- TODO: take more care in unpacking Maybe - TyCon (NewtypeTyCon (UserADTType _ _ (TyConParams _ [ty]))) <- return reqTy'' - Just ty' <- return $ toMaybeType ty - body' <- checkWithEffTy (EffTy (extendEffect ExceptionEffect effs) ty') body - return $ CatchException reqTy'' body' - -instance IRRep r => CheckableWithEffects r (DAMOp r) where - checkWithEffects effs = \case + +instance IRRep r => CheckableE r (DAMOp r) where + checkE = \case Seq effAnn dir ixTy carry lam -> do LamExpr (UnaryNest b) body <- return lam effAnn' <- checkE effAnn - checkExtends effs effAnn' ixTy' <- checkE ixTy (carry', carryTy') <- checkAndGetType carry let badCarry = throwInternal $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' @@ -709,21 +642,19 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where _ -> badCarry let binderReqTy = PairTy (ixTypeType ixTy') carryTy' checkBinderType binderReqTy b \b' -> do - body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body + body' <- checkE body return $ Seq effAnn' dir ixTy' carry' $ LamExpr (UnaryNest b') body' RememberDest effAnn d lam -> do LamExpr (UnaryNest b) body <- return lam effAnn' <- checkE effAnn - checkExtends effs effAnn' (d', dTy@(RawRefTy _)) <- checkAndGetType d checkBinderType dTy b \b' -> do - body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body + body' <- checkE body return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' AllocDest ty -> AllocDest <$> checkE ty Place ref val -> do val' <- checkE val ref' <- ref |: RawRefTy (getType val') - declareEff effs InitEffect return $ Place ref' val' Freeze ref -> do ref' <- checkE ref @@ -733,35 +664,22 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkLamExpr :: IRRep r => PiType r o -> LamExpr r i -> TyperM r i o (LamExpr r o) checkLamExpr piTy (LamExpr bs body) = checkB bs \bs' -> do - effTy <- checkInstantiation (sink piTy) (toAtom <$> bindersVars bs') - body' <- checkWithEffTy effTy body + void $ checkInstantiation (sink piTy) (toAtom <$> bindersVars bs') + body' <- checkE body return $ LamExpr bs' body' checkDecls :: IRRep r - => EffectRow r o -> Decls r i i' + => Decls r i i' -> (forall o'. DExt o o' => Decls r o o' -> TyperM r i' o' a) -> TyperM r i o a -checkDecls _ Empty cont = getDistinct >>= \Distinct -> cont Empty -checkDecls effs (Nest (Let b (DeclBinding ann expr)) decls) cont = do - rhs <- DeclBinding ann <$> checkWithEffects effs expr +checkDecls Empty cont = getDistinct >>= \Distinct -> cont Empty +checkDecls (Nest (Let b (DeclBinding ann expr)) decls) cont = do + rhs <- DeclBinding ann <$> checkE expr withFreshBinder (getNameHint b) rhs \(b':>_) -> do extendRenamer (b@>binderName b') do let decl' = Let b' rhs - checkDecls (sink effs) decls \decls' -> cont $ Nest decl' decls' - -checkRWSAction - :: IRRep r => Type r o -> Type r o -> EffectRow r o - -> RWS -> LamExpr r i -> TyperM r i o (LamExpr r o) -checkRWSAction resultTy referentTy effs rws f = do - BinaryLamExpr bH bR body <- return f - checkBinderType (TyCon HeapType) bH \bH' -> do - let h = toAtom $ binderVar bH' - let refTy = RefTy h (sink referentTy) - checkBinderType refTy bR \bR' -> do - let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) - body' <- checkWithEffTy (EffTy effs' (sink resultTy)) body - return $ BinaryLamExpr bH' bR' body' + checkDecls decls \decls' -> cont $ Nest decl' decls' checkProject :: (IRRep r) => Int -> Atom r o -> TyperM r i o (Type r o) checkProject i x = case getType x of @@ -898,31 +816,7 @@ checkUnOp op x = checkOpArgType argTy x where u = SomeUIntArg; f = SomeFloatArg; --- === effects === - -instance IRRep r => CheckableE r (EffectRow r) where - checkE (EffectRow effs effTail) = do - effs' <- eSetFromList <$> forM (eSetToList effs) \eff -> case eff of - RWSEffect rws v -> do - v' <- v |: TyCon HeapType - return $ RWSEffect rws v' - ExceptionEffect -> return ExceptionEffect - IOEffect -> return IOEffect - InitEffect -> return InitEffect - effTail' <- case effTail of - NoTail -> return NoTail - EffectRowTail v -> do - v' <- renameM v - ty <- getType <$> lookupAtomName (atomVarName v') - checkTypesEq EffKind ty - return $ EffectRowTail v' - return $ EffectRow effs' effTail' - -declareEff :: IRRep r => EffectRow r o -> Effect r o -> TyperM r i o () -declareEff allowedEffs eff = checkExtends allowedEffs $ OneEffect eff - -checkEffTy :: IRRep r => EffectRow r o -> EffTy r i -> TyperM r i o (EffTy r o) -checkEffTy allowedEffs effTy = do - EffTy declaredEffs resultTy <- checkE effTy - checkExtends allowedEffs declaredEffs - return $ EffTy declaredEffs resultTy +instance IRRep r => CheckableE r (Effects r) where + checkE = \case + Pure -> return Pure + Effectful -> return Effectful diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 5209466d7..5d60cc1fe 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -74,7 +74,6 @@ interpOperator sid = \case "." -> atomic Dot ",>" -> atomic DepComma ":" -> atomic Colon - "|" -> atomic Pipe "::" -> atomic DoubleColon "$" -> atomic Dollar "->>" -> atomic ImplicitArrow @@ -328,9 +327,8 @@ funDefLet = label "function definition" do params <- explicitParams rhs <- optional do expl <- explicitness - effs <- optional cEffs resultTy <- cGroupNoEqual - return (expl, effs, resultTy) + return (expl, resultTy) givens <- optional givenClause mayNotBreak do sym "=" @@ -387,12 +385,6 @@ givenClause = do withClause :: Parser WithClause withClause = keyWord WithKW >> parenList cGroup -cEffs :: Parser CEffs -cEffs = withSrcs $ braces do - effs <- commaSep cGroupNoPipe - effTail <- optional $ sym "|" >> cGroup - return (effs, effTail) - commaSep :: Parser a -> Parser [a] commaSep p = sepBy p (sym ",") @@ -410,10 +402,6 @@ cGroupNoEqual :: Parser GroupW cGroupNoEqual = makeExprParser leafGroup $ withoutOp "=" ops -cGroupNoPipe :: Parser GroupW -cGroupNoPipe = makeExprParser leafGroup $ - withoutOp "|" ops - cGroupNoArrow :: Parser GroupW cGroupNoArrow = makeExprParser leafGroup $ withoutOp "->" ops @@ -672,8 +660,7 @@ symOp s = binApp do arrowOp :: Parser (GroupW -> GroupW -> GroupW) arrowOp = addSrcIdToBinOp do sid <- symWithId "->" - optEffs <- optional cEffs - return \lhs rhs -> ([sid], CArrow lhs optEffs rhs) + return \lhs rhs -> ([sid], CArrow lhs rhs) unOpPre :: String -> (SourceName, Expr.Operator Parser GroupW) unOpPre s = (fromString s, Expr.Prefix $ prefixOp s) diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 466afbb5d..92ef67a0c 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -102,18 +102,12 @@ liftExportSigM cont = do corePiToExportSig :: CallingConvention -> CorePiType i -> ExportSigM CoreIR i o (ExportedSignature o) -corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do - case effs of - Pure -> return () - _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" +corePiToExportSig cc (CorePiType _ expls tbs resultTy) = do goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention -> PiType SimpIR i -> ExportSigM SimpIR i o (ExportedSignature o) -simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do - case effs of - Pure -> return () - _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" +simpPiToExportSig cc (PiType bs resultTy) = do bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index ac7efabee..ba8cfbca2 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -143,11 +143,9 @@ traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter? SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t TypeKind -> return TypeKind - HeapType -> return HeapType NewtypeTyCon con -> NewtypeTyCon <$> case con of Nat -> return Nat Fin n -> Fin <$> f DataParam NatTy n - EffectRowKind -> return EffectRowKind UserADTType sn def (TyConParams infs params) -> do Abs roleBinders UnitE <- getDataDefRoleBinders def params' <- traverseRoleBinders f roleBinders params diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7c1ab7138..ff74e4ecd 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -118,18 +118,18 @@ getNaryLamImpArgTypes :: EnvReader m => PiType SimpIR n -> m n ([[BaseType]], [BaseType]) getNaryLamImpArgTypes t = liftEnvReaderM $ go t where go :: PiType SimpIR n -> EnvReaderM n ([[BaseType]], [BaseType]) - go (PiType bs effTy) = case bs of + go (PiType bs resultTy) = case bs of Nest piB rest -> do ts <- getRepBaseTypes $ binderType piB - refreshAbs (Abs piB (PiType rest effTy)) \_ restPi -> do + refreshAbs (Abs piB (PiType rest resultTy)) \_ restPi -> do (argTys, resultTys) <- go restPi return (ts:argTys, resultTys) - Empty -> ([],) <$> getDestBaseTypes (etTy effTy) + Empty -> ([],) <$> getDestBaseTypes resultTy interpretImpArgsWithDest :: EnvReader m => PiType SimpIR n -> [IExpr n] -> m n ([SAtom n], Dest n) interpretImpArgsWithDest t xs = do - (PiType bs (EffTy _ resultTy)) <- return t + (PiType bs resultTy) <- return t (args, xsLeft) <- _interpretImpArgs (EmptyAbs bs) xs resultTy' <- applySubst (bs @@> (SubstVal <$> args)) resultTy (destTree, xsRest) <- listToTree resultTy' xsLeft @@ -423,7 +423,7 @@ toImpVectorOp = \case refi <- destToAtom <$> indexDest refDest i refi' <- fromScalarAtom refi resultVal <- castPtrToVectorType refi' (toIVectorType vty) - repValAtom $ RepVal (RefTy (Con HeapVal) vty) (Leaf resultVal) + repValAtom $ RepVal (RefTy State vty) (Leaf resultVal) where returnIExprVal x = return $ toScalarAtom x @@ -470,7 +470,6 @@ toImpMiscOp op = case op of return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases _ -> error $ "Not an enum: " ++ pprint ty OutputStream -> returnIExprVal =<< emitInstr IOutputStream - ThrowException _ -> error "shouldn't have ThrowException left" -- also, should be replaced with user-defined errors ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar x -> do resultTy <- return $ getType $ PrimOp $ MiscOp op @@ -514,71 +513,14 @@ toImpMemOp op = case op of returnIExprVal x = return $ toScalarAtom x toImpTypedHof :: Emits o => TypedHof SimpIR i -> SubstImpM i o (SAtom o) -toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do - resultTy <- substM resultTy' - case hof of - For _ _ _ -> error $ "Unexpected `for` in Imp pass " ++ pprint hof - While body -> do - body' <- buildBlockImp do - ans <- fromScalarAtom =<< translateExpr body - return [ans] - emitStatement $ IWhile body' - return UnitVal - RunReader r f -> do - BinaryLamExpr h ref body <- return f - r' <- substM r - rDest <- allocDest $ getType r' - storeAtom rDest r' - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $ - translateExpr body - RunWriter d (BaseMonoid e _) f -> do - BinaryLamExpr h ref body <- return f - let PairTy ansTy accTy = resultTy - (aDest, wDest) <- case d of - Nothing -> destPairUnpack <$> allocDest resultTy - Just d' -> do - aDest <- allocDest ansTy - wDest <- atomToDest =<< substM d' - return (aDest, wDest) - e' <- substM e - PairE accTy' e'' <- sinkM $ PairE accTy e' - liftMonoidEmpty wDest accTy' e'' - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom wDest)) $ - translateExpr body >>= storeAtom aDest - PairVal <$> loadAtom aDest <*> loadAtom wDest - RunState d s f -> do - BinaryLamExpr h ref body <- return f - let PairTy ansTy _ = resultTy - (aDest, sDest) <- case d of - Nothing -> destPairUnpack <$> allocDest resultTy - Just d' -> do - aDest <- allocDest ansTy - sDest <- atomToDest =<< substM d' - return (aDest, sDest) - storeAtom sDest =<< substM s - extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $ - translateExpr body >>= storeAtom aDest - PairVal <$> loadAtom aDest <*> loadAtom sDest - RunIO body -> translateExpr body - RunInit body -> translateExpr body - where - liftMonoidEmpty :: Emits n => Dest n -> SType n -> SAtom n -> SubstImpM i n () - liftMonoidEmpty accDest accTy x = do - xTy <- return $ getType x - alphaEq xTy accTy >>= \case - True -> storeAtom accDest x - False -> case accTy of - TyCon (TabPi t) -> do - let ixTy = tabIxType t - n <- indexSetSizeImp ixTy - emitLoop noHint Fwd n \i -> do - idx <- unsafeFromOrdinalImp (sink ixTy) i - x' <- sinkM x - eltTy <- instantiate (sink t) [idx] - ithDest <- indexDest (sink accDest) idx - liftMonoidEmpty ithDest eltTy x' - _ -> error $ "Base monoid type mismatch: can't lift " ++ - pprint xTy ++ " to " ++ pprint accTy +toImpTypedHof (TypedHof _ hof) = case hof of + For _ _ _ -> error $ "Unexpected `for` in Imp pass " ++ pprint hof + While body -> do + body' <- buildBlockImp do + ans <- fromScalarAtom =<< translateExpr body + return [ans] + emitStatement $ IWhile body' + return UnitVal -- === Runtime representation of values and refs === @@ -697,7 +639,6 @@ typeToTree tyTop = return $ go REmpty tyTop let tag = rec TagRepTy let xs = map rec ts Branch $ tag:xs - HeapType -> Branch [] where rec = go ctx traverseScalarRepTys :: EnvReader m => SType n -> (LeafType n -> m n a) -> m n (Tree a) @@ -746,7 +687,6 @@ valueToTree (RepVal tyTop valTop) = do results <- zipWithM rec ts vals return $ Branch $ tag : results _ -> error "expected a branch" - _ -> error $ "not implemented " ++ pprint ty where rec = go ctx {-# INLINE valueToTree #-} @@ -857,7 +797,6 @@ atomToRepVal x = RepVal (getType x) <$> go x where then go payload else buildGarbageVal t <&> \(Stuck _ (RepValAtom (RepVal _ tree))) -> tree return $ Branch $ tag':xs - HeapVal -> return $ Branch [] go (Stuck _ stuck) = case stuck of Var v -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree @@ -880,7 +819,7 @@ atomToRepVal x = RepVal (getType x) <$> go x where -- from the dest. This version is not that. It just lifts a dest into an atom of -- type `Ref _`. destToAtom :: Dest n -> SAtom n -destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy (Con HeapVal) valTy) tree +destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy State valTy) tree atomToDest :: EnvReader m => SAtom n -> m n (Dest n) atomToDest (Stuck _ (RepValAtom val)) = do @@ -1039,12 +978,6 @@ projectDest i (Dest (TyCon (ProdType tys)) (Branch ds)) = Dest (tys!!i) (ds!!i) projectDest _ (Dest ty _) = error $ "Can't project dest: " ++ pprint ty -destPairUnpack :: Dest n -> (Dest n, Dest n) -destPairUnpack (Dest (PairTy t1 t2) (Branch [d1, d2])) = - ( Dest t1 d1, Dest t2 d2 ) -destPairUnpack (Dest ty tree) = - error $ "Can't unpack dest: " ++ pprint ty ++ "\n" ++ show tree - -- === Determining buffer sizes and offsets using polynomials === type SBuilderM = BuilderM SimpIR @@ -1185,7 +1118,7 @@ withFreshIBinder hint ty cont = do emitCall :: Emits n => PiType SimpIR n -> ImpFunName n -> [SAtom n] -> SubstImpM i n (SAtom n) -emitCall (PiType bs (EffTy _ resultTy)) f xs = do +emitCall (PiType bs resultTy) f xs = do resultTy' <- applySubst (bs @@> map SubstVal xs) resultTy dest <- allocDest resultTy' argsImp <- forM xs \x -> repValToList <$> atomToRepVal x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index c827523d7..9457bed85 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -18,7 +18,7 @@ import Control.Monad import Control.Monad.State.Strict import Control.Monad.Reader import Data.Either (partitionEithers) -import Data.Foldable (toList, asum) +import Data.Foldable (asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) @@ -132,16 +132,15 @@ inferTopUDecl (ULocalDecl (WithSrcB _ decl)) result = case decl of asTopBlock :: EnvReader m => CExpr n -> m n (TopBlock CoreIR n, CType n) asTopBlock block = do - let effs = getEffects block let ty = getType block - return (TopLam False (PiType Empty (EffTy effs ty)) (LamExpr Empty block), ty) + return (TopLam False (PiType Empty ty) (LamExpr Empty block), ty) getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do className' <- sinkM className dTy <- toType <$> dictType className' params' - return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy + return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' dTy -- === Inferer monad === @@ -151,8 +150,7 @@ emptySolverSubst :: SolverSubst n emptySolverSubst = SolverSubst mempty data InfState (n::S) = InfState - { givens :: Givens n - , infEffects :: EffectRow CoreIR n } + { givens :: Givens n } newtype InfererM (i::S) (o::S) (a:: *) = InfererM { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR (ExceptT (State TypingInfo)))) i o a } @@ -176,7 +174,7 @@ liftInfererMPure cont = do return $ runState (runExceptT ansM) mempty where emptyInfState :: InfState n - emptyInfState = InfState (Givens HM.empty) Pure + emptyInfState = InfState (Givens HM.empty) {-# INLINE liftInfererMPure #-} -- === Solver monad === @@ -343,10 +341,6 @@ withInfState :: (InfState o -> InfState o) -> InfererM i o a -> InfererM i o a withInfState f cont = InfererM $ local f (runInfererM' cont) {-# INLINE withInfState #-} -withAllowedEffects :: EffectRow CoreIR o -> InfererM i o a -> InfererM i o a -withAllowedEffects effs cont = withInfState (\(InfState g _) -> InfState g effs) cont -{-# INLINE withAllowedEffects #-} - getTypeAndEmit :: HasType CoreIR e => SrcId -> e o -> InfererM i o (e o) getTypeAndEmit sid e = do emitExprType sid (getType e) @@ -383,7 +377,6 @@ data PartialPiType (n::S) where PartialPiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l - -> EffectRow CoreIR l -> RequiredTy l -> PartialPiType n @@ -414,16 +407,15 @@ etaExpandPartialPi :: PartialPiType o -> (forall o'. (Emits o', DExt o o') => RequiredTy o' -> [CAtom o'] -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do - withFreshBindersInf expls (Abs bs (PairE effs reqTy)) \bs' (PairE effs' reqTy') -> do +etaExpandPartialPi (PartialPiType appExpl expls bs reqTy) cont = do + withFreshBindersInf expls (Abs bs reqTy) \bs' reqTy' -> do let args = zip expls (toAtom <$> bindersVars bs') explicits <- return $ catMaybes $ args <&> \case (Explicit, arg) -> Just arg _ -> Nothing - withAllowedEffects effs' do - body <- buildBlock $ cont (sink reqTy') (sink <$> explicits) - let piTy = CorePiType appExpl expls bs' (EffTy effs' $ getType body) - return $ CoreLamExpr piTy $ LamExpr bs' body + body <- buildBlock $ cont (sink reqTy') (sink <$> explicits) + let piTy = CorePiType appExpl expls bs' (getType body) + return $ CoreLamExpr piTy $ LamExpr bs' body -- Doesn't introduce implicit pi binders or dependent pairs topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) @@ -516,12 +508,12 @@ bottomUpExplicit (WithSrcE sid expr) = getTypeAndEmit sid =<< case expr of UTabApp tab args -> do tab' <- bottomUp tab SigmaAtom Nothing <$> inferTabApp (getSrcId tab) tab' args - UPi (UPiExpr bs appExpl effs ty) -> do + UPi (UPiExpr bs appExpl ty) -> do -- TODO: check explicitness constraints withUBinders bs \(ZipB expls bs') -> do - effTy' <- EffTy <$> checkUEffRow effs <*> checkUType ty + ty' <- checkUType ty return $ SigmaAtom Nothing $ toAtom $ - Pi $ CorePiType appExpl expls bs' effTy' + Pi $ CorePiType appExpl expls bs' ty' UTabPi (UTabPiExpr b ty) -> do Abs b' ty' <- withUBinder b \(WithAttrB _ b') -> liftM (Abs b') $ checkUType ty @@ -558,7 +550,7 @@ bottomUpExplicit (WithSrcE sid expr) = getTypeAndEmit sid =<< case expr of UPrim UExplicitApply (f:xs) -> do f' <- bottomUpExplicit f xs' <- mapM bottomUp xs - SigmaAtom Nothing <$> applySigmaAtom sid f' xs' + SigmaAtom Nothing <$> applySigmaAtom f' xs' UPrim UProjNewtype [x] -> do x' <- bottomUp x >>= unwrapNewtype return $ SigmaAtom Nothing x' @@ -594,17 +586,17 @@ fromNatLit sid x ty = do instantiateSigma :: Emits o => SrcId -> RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) instantiateSigma sid reqTy sigmaAtom = case sigmaAtom of SigmaUVar _ _ _ -> case getType sigmaAtom of - TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy))) -> do + TyCon (Pi (CorePiType ImplicitApp expls bs resultTy)) -> do bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do case reqTy of Infer -> return [] Check reqTy' -> return [TypeConstraint sid (sink reqTy') resultTy'] args <- inferMixedArgs @UExpr sid fDesc expls bsConstrained ([], []) - applySigmaAtom sid sigmaAtom args + applySigmaAtom sigmaAtom args _ -> fallback _ -> fallback where - fallback = forceSigmaAtom sid sigmaAtom >>= matchReq sid reqTy + fallback = forceSigmaAtom sigmaAtom >>= matchReq sid reqTy fDesc = getSourceName sigmaAtom matchReq :: Ext o o' => SrcId -> RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') @@ -614,12 +606,12 @@ matchReq sid (Check reqTy) x = do matchReq _ Infer x = return x {-# INLINE matchReq #-} -forceSigmaAtom :: Emits o => SrcId -> SigmaAtom o -> InfererM i o (CAtom o) -forceSigmaAtom sid sigmaAtom = case sigmaAtom of +forceSigmaAtom :: Emits o => SigmaAtom o -> InfererM i o (CAtom o) +forceSigmaAtom sigmaAtom = case sigmaAtom of SigmaAtom _ x -> return x SigmaUVar _ _ v -> case v of UAtomVar v' -> inlineTypeAliases v' - _ -> applySigmaAtom sid sigmaAtom [] + _ -> applySigmaAtom sigmaAtom [] SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? withBlockDecls @@ -653,7 +645,7 @@ withUDecl (WithSrcB _ d) cont = case d of considerInlineAnn :: LetAnn -> CType n -> LetAnn considerInlineAnn PlainLet TyKind = InlineLet -considerInlineAnn PlainLet (TyCon (Pi (CorePiType _ _ _ (EffTy Pure TyKind)))) = InlineLet +considerInlineAnn PlainLet (TyCon (Pi (CorePiType _ _ _ TyKind))) = InlineLet considerInlineAnn ann _ = ann applyFromLiteralMethod @@ -778,7 +770,7 @@ checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = do checkExplicitArity appSrcId expls posArgs bsConstrained <- buildAppConstraints appSrcId reqTy piTy args <- inferMixedArgs appSrcId fDesc expls bsConstrained (posArgs, namedArgs) - applySigmaAtom appSrcId f args + applySigmaAtom f args ImplicitApp -> error "should already have handled this case" ty -> throw funSrcId $ EliminationErr "function type" (pprint ty) where @@ -786,17 +778,11 @@ checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = do fDesc = getSourceName f' buildAppConstraints :: SrcId -> RequiredTy n -> CorePiType n -> InfererM i n (ConstrainedBinders n) -buildAppConstraints appSrcId reqTy (CorePiType _ _ bs effTy) = do - effsAllowed <- infEffects <$> getInfState - buildConstraints (Abs bs effTy) \_ (EffTy effs resultTy) -> do - resultTyConstraints <- return case reqTy of +buildAppConstraints appSrcId reqTy (CorePiType _ _ bs ty) = do + buildConstraints (Abs bs ty) \_ resultTy -> do + return case reqTy of Infer -> [] Check reqTy' -> [TypeConstraint appSrcId (sink reqTy') resultTy] - EffectRow _ t <- return effs - effConstraints <- case t of - NoTail -> return [] - EffectRowTail _ -> return [EffectConstraint appSrcId (sink effsAllowed) effs] - return $ resultTyConstraints ++ effConstraints maybeInterpretPunsAsTyCons :: RequiredTy n -> SigmaAtom n -> InfererM i n (SigmaAtom n) maybeInterpretPunsAsTyCons (Check TyKind) (SigmaUVar sn _ (UPunVar v)) = do @@ -813,12 +799,12 @@ inlineTypeAliases v = do LetBound (DeclBinding InlineLet (Atom e)) -> return e _ -> toAtom <$> toAtomVar v -applySigmaAtom :: Emits o => SrcId -> SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) -applySigmaAtom appSrcId (SigmaAtom _ f) args = emitWithEffects appSrcId =<< mkApp f args -applySigmaAtom appSrcId (SigmaUVar _ _ f) args = case f of +applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) +applySigmaAtom (SigmaAtom _ f) args = emit =<< mkApp f args +applySigmaAtom (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do f'' <- inlineTypeAliases f' - emitWithEffects appSrcId =<< mkApp f'' args + emit =<< mkApp f'' args UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls @@ -846,9 +832,9 @@ applySigmaAtom appSrcId (SigmaUVar _ _ f) args = case f of let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args - emitWithEffects appSrcId =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' -applySigmaAtom appSrcId (SigmaPartialApp _ f prevArgs) args = - emitWithEffects appSrcId =<< mkApp f (prevArgs ++ args) + emit =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' +applySigmaAtom (SigmaPartialApp _ f prevArgs) args = + emit =<< mkApp f (prevArgs ++ args) splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do @@ -888,11 +874,6 @@ applyDataCon tc conIx topArgs = do where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty -emitWithEffects :: Emits o => SrcId -> CExpr o -> InfererM i o (CAtom o) -emitWithEffects sid expr = do - addEffects sid $ getEffects expr - emit expr - checkExplicitArity :: SrcId -> [Explicitness] -> [a] -> InfererM i o () checkExplicitArity sid expls args = do let arity = length [() | Explicit <- expls] @@ -900,10 +881,7 @@ checkExplicitArity sid expls args = do when (numArgs /= arity) $ throw sid $ ArityErr arity numArgs type MixedArgs arg = ([arg], [(SourceName, arg)]) -- positional args, named args -data Constraint (n::S) = - TypeConstraint SrcId (CType n) (CType n) - -- permitted effects (no inference vars), proposed effects - | EffectConstraint SrcId (EffectRow CoreIR n) (EffectRow CoreIR n) +data Constraint (n::S) = TypeConstraint SrcId (CType n) (CType n) type Constraints = ListE Constraint type ConstrainedBinders n = ([IsDependent], Abs (Nest CBinder) Constraints n) @@ -955,13 +933,6 @@ inferMixedArgs appSrcId fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, case hoist bs c of HoistSuccess c' -> case c' of TypeConstraint _ _ _ -> applyConstraint c' >> return Nothing - EffectConstraint _ _ (EffectRow specificEffs _) -> - hasInferenceVars specificEffs >>= \case - False -> applyConstraint c' >> return Nothing - -- we delay applying the constraint in this case because we might - -- learn more about the specific effects after we've seen more - -- arguments (like a `Ref h a` that tells us about the `h`) - True -> return $ Just c HoistFailure _ -> return $ Just c inferMixedArg @@ -1044,18 +1015,15 @@ matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case UNat -> \case ~[] -> return $ toAtom $ NewtypeTyCon Nat UFin -> \case ~[n] -> return $ toAtom $ NewtypeTyCon (Fin n) - UEffectRowKind -> \case ~[] -> return $ toAtom $ NewtypeTyCon EffectRowKind UBaseType b -> \case ~[] -> return $ toAtomR $ BaseType b UNatCon -> \case ~[x] -> return $ toAtom $ NewtypeCon NatCon x UPrimTC tc -> case tc of P.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts P.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts - P.RefType -> \case ~[h, a] -> return $ toAtom $ RefType h (fromJust $ toMaybeType a) + P.RefType -> \case ~[h, a] -> undefined -- return $ toAtom $ RefType h (fromJust $ toMaybeType a) P.TypeKind -> \case ~[] -> return $ Con $ TyConAtom $ TypeKind - P.HeapType -> \case ~[] -> return $ Con $ TyConAtom $ HeapType UCon con -> case con of P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs - P.HeapVal -> \case ~[] -> return $ toAtom HeapVal P.SumCon _ -> error "not supported" UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x @@ -1068,17 +1036,7 @@ matchPrimApp = \case UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x - URunReader -> \case ~[x, f] -> do f' <- lam2 f; emitHof $ RunReader x f' - URunState -> \case ~[x, f] -> do f' <- lam2 f; emitHof $ RunState Nothing x f' - UWhile -> \case ~[f] -> do f' <- lam0 f; emitHof $ While f' - URunIO -> \case ~[f] -> do f' <- lam0 f; emitHof $ RunIO f' - UCatchException-> \case ~[f] -> do f' <- lam0 f; emitHof =<< mkCatchException f' UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emit $ RefOp r $ MExtend (BaseMonoid z f') x - URunWriter -> \args -> do - [idVal, combiner, f] <- return args - combiner' <- lam2 combiner - f' <- lam2 f - emitHof $ RunWriter Nothing (BaseMonoid idVal combiner') f' p -> \case xs -> throwInternal $ "Bad primitive application: " ++ show (p, xs) where lam2 :: Fallible m => CAtom n -> m (LamExpr CoreIR n) @@ -1091,11 +1049,6 @@ matchPrimApp = \case ExplicitCoreLam (UnaryNest b) body <- return x return $ UnaryLamExpr b body - lam0 :: Fallible m => CAtom n -> m (CExpr n) - lam0 x = do - ExplicitCoreLam Empty body <- return x - return body - matchGenericOp :: GenericOp op => OpConst op CoreIR -> [CAtom n] -> InfererM i n (op CoreIR n) matchGenericOp op xs = do (tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do @@ -1190,8 +1143,7 @@ instanceFun instanceName appExpl = do liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' result <- toAtom <$> mkInstanceDict (sink instanceName) (toAtom <$> args) - let effTy = EffTy Pure (getType result) - let piTy = CorePiType appExpl (snd<$>expls) bs' effTy + let piTy = CorePiType appExpl (snd<$>expls) bs' (getType result) return $ toAtom $ CoreLamExpr piTy (LamExpr bs' $ Atom result) checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) @@ -1280,7 +1232,7 @@ inferClassDef className methodNames paramBs methodTys = do methodTys' <- forM methodTys \m -> do checkUType m >>= \case TyCon (Pi t) -> return t - t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) + t -> return $ CorePiType ImplicitApp [] Empty t PairB paramBs'' superclassBs <- partitionBinders rootSrcId (zipAttrs roleExpls paramBs') $ \b@(WithAttrB (_, expl) b') -> case expl of Explicit -> return $ LeftB b @@ -1360,26 +1312,22 @@ inferAnn binderSrcId ann cs = case ann of checkULamPartial :: PartialPiType o -> SrcId -> ULamExpr i -> InfererM i o (CoreLamExpr o) checkULamPartial partialPiTy sid lamExpr = do - PartialPiType piAppExpl expls piBs piEffs piReqTy <- return partialPiTy - ULamExpr lamBs lamAppExpl lamEffs lamResultTy body <- return lamExpr + PartialPiType piAppExpl expls piBs piReqTy <- return partialPiTy + ULamExpr lamBs lamAppExpl lamResultTy body <- return lamExpr checkExplicitArity sid expls (nestToList (const ()) lamBs) when (piAppExpl /= lamAppExpl) $ throw sid $ WrongArrowErr (pprint piAppExpl) (pprint lamAppExpl) checkLamBinders expls piBs lamBs \lamBs' -> do - PairE piEffs' piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) (PairE piEffs piReqTy) + piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) piReqTy resultTy <- case (lamResultTy, piReqTy') of (Nothing, Infer ) -> return Infer (Just t , Infer ) -> Check <$> checkUType t (Nothing, Check t) -> Check <$> return t (Just t , Check t') -> checkUType t >>= expectEq (getSrcId t) t' >> return (Check t') - forM_ lamEffs \lamEffs' -> do - lamEffs'' <- checkUEffRow lamEffs' - expectEq sid (Eff piEffs') (Eff lamEffs'') -- TODO: add source annotations to lambda effects too - body' <- withAllowedEffects piEffs' do - buildBlock $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result + body' <- buildBlock $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result resultTy' <- case resultTy of Infer -> return $ getType body' Check t -> return t - let piTy = CorePiType piAppExpl expls lamBs' (EffTy piEffs' resultTy') + let piTy = CorePiType piAppExpl expls lamBs' resultTy' return $ CoreLamExpr piTy (LamExpr lamBs' body') where checkLamBinders @@ -1412,41 +1360,38 @@ inferUForExpr (UForExpr b body) = do checkUForExpr :: Emits o => SrcId -> UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) checkUForExpr sid (UForExpr bFor body) (TabPiType _ bPi resultTy) = do - let uLamExpr = ULamExpr (UnaryNest bFor) ExplicitApp Nothing Nothing body - effsAllowed <- infEffects <$> getInfState + let uLamExpr = ULamExpr (UnaryNest bFor) ExplicitApp Nothing body partialPi <- liftEnvReaderM $ refreshAbs (Abs bPi resultTy) \bPi' resultTy' -> do - return $ PartialPiType ExplicitApp [Explicit] (UnaryNest bPi') (sink effsAllowed) (Check resultTy') + return $ PartialPiType ExplicitApp [Explicit] (UnaryNest bPi') (Check resultTy') CoreLamExpr _ lamExpr <- checkULamPartial partialPi sid uLamExpr return lamExpr inferULam :: ULamExpr i -> InfererM i o (CoreLamExpr o) -inferULam (ULamExpr bs appExpl effs resultTy body) = do - Abs (ZipB expls bs') (PairE effTy body') <- inferUBinders bs \_ -> do - effs' <- fromMaybe Pure <$> mapM checkUEffRow effs +inferULam (ULamExpr bs appExpl resultTy body) = do + Abs (ZipB expls bs') (PairE ty body') <- inferUBinders bs \_ -> do resultTy' <- mapM checkUType resultTy - body' <- buildBlock $ withAllowedEffects (sink effs') do + body' <- buildBlock do withBlockDecls body \result -> case resultTy' of Nothing -> bottomUp result Just resultTy'' -> topDown (sink resultTy'') result - let effTy = EffTy effs' (getType body') - return $ PairE effTy body' - return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body') + return $ PairE (getType body') body' + return $ CoreLamExpr (CorePiType appExpl expls bs' ty) (LamExpr bs' body') checkULam :: SrcId -> ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) checkULam sid ulam piTy = checkULamPartial (piAsPartialPi piTy) sid ulam piAsPartialPi :: CorePiType n -> PartialPiType n -piAsPartialPi (CorePiType appExpl expls bs (EffTy effs ty)) = - PartialPiType appExpl expls bs effs (Check ty) +piAsPartialPi (CorePiType appExpl expls bs ty) = + PartialPiType appExpl expls bs (Check ty) typeAsPartialType :: CType n -> PartialType n typeAsPartialType (TyCon (Pi piTy)) = PartialType $ piAsPartialPi piTy typeAsPartialType ty = FullType ty piAsPartialPiDropResultTy :: CorePiType n -> PartialPiType n -piAsPartialPiDropResultTy (CorePiType appExpl expls bs (EffTy effs _)) = - PartialPiType appExpl expls bs effs Infer +piAsPartialPiDropResultTy (CorePiType appExpl expls bs _) = + PartialPiType appExpl expls bs Infer checkInstanceParams :: Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] checkInstanceParams bsTop paramsTop = go bsTop paramsTop @@ -1493,26 +1438,6 @@ checkMethodDef className methodTys (WithSrcE sid m) = do throw sid $ NotAMethod (pprint sourceName) (pprint $ getSourceName classDef) (i,) <$> toAtom <$> Lam <$> checkULam sid rhs (methodTys !! i) -checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) -checkUEffRow (UEffectRow effs t) = do - effs' <- liftM eSetFromList $ mapM checkUEff $ toList effs - t' <- case t of - Nothing -> return NoTail - Just (SourceOrInternalName ~(InternalName sid _ v)) -> do - v' <- toAtomVar =<< renameM v - expectEq sid EffKind (getType v') - return $ EffectRowTail v' - return $ EffectRow effs' t' - -checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) -checkUEff eff = case eff of - URWSEffect rws (SourceOrInternalName ~(InternalName sid _ region)) -> do - region' <- renameM region >>= toAtomVar - expectEq sid (TyCon HeapType) (getType region') - return $ RWSEffect rws (toAtom region') - UExceptionEffect -> return ExceptionEffect - UIOEffect -> return IOEffect - type CaseAltIndex = Int checkCaseAlt :: Emits o => RequiredTy o -> CType o -> UAlt i -> InfererM i o (IndexedAlt o) @@ -1626,7 +1551,7 @@ checkUType t = do checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) checkUParam k uty = - withReducibleEmissions (getSrcId uty) msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty + withReducibleEmissions (getSrcId uty) msg $ topDownExplicit (sink k) uty where msg = CantReduceType $ pprint uty inferTabCon :: forall i o. Emits o => SrcId -> [UExpr i] -> InfererM i o (CAtom o) @@ -1652,14 +1577,6 @@ checkTabCon tabTy@(TabPiType _ b elemTy) sid xs = do topDown elemTy' x emit $ TabCon (TyCon (TabPi tabTy)) xs' -addEffects :: SrcId -> EffectRow CoreIR o -> InfererM i o () -addEffects _ Pure = return () -addEffects sid eff = do - effsAllowed <- infEffects <$> getInfState - case checkExtends effsAllowed eff of - Success () -> return () - Failure _ -> expectEq sid (Eff effsAllowed) (Eff eff) - getIxDict :: SrcId -> CType o -> InfererM i o (IxDict CoreIR o) getIxDict sid t = fromJust <$> toMaybeDict <$> trySynthTerm sid (toType $ IxDictType t) Full @@ -1681,21 +1598,6 @@ lookupSolverSubst (SolverSubst m) name = applyConstraint :: Constraint o -> SolverM i o () applyConstraint = \case TypeConstraint sid t1 t2 -> constrainEq sid t1 t2 - EffectConstraint sid r1 r2' -> do - -- r1 shouldn't have inference variables. And we can't infer anything about - -- any inference variables in r2's explicit effects because we don't know - -- how they line up with r1's. So this is just about figuring out r2's tail. - r2 <- zonk r2' - let msg = DisallowedEffects (pprint r1) (pprint r2) - case checkExtends r1 r2 of - Success () -> return () - Failure _ -> searchFailureAsTypeErr sid msg do - EffectRow effs1 t1 <- return r1 - EffectRow effs2 (EffectRowTail v2) <- return r2 - guard =<< isUnificationName (atomVarName v2) - guard $ null (eSetToList $ effs2 `eSetDifference` effs1) - let extras1 = effs1 `eSetDifference` effs2 - extendSolution v2 (toAtom $ EffectRow extras1 t1) constrainEq :: ToAtom e CoreIR => SrcId -> e o -> e o -> SolverM i o () constrainEq sid t1 t2 = do @@ -1774,10 +1676,6 @@ instance Unifiable (Con CoreIR) where { SumCon ts' i' x' <- matchit; unifyLists ts ts'; guard (i==i'); unify x x'} ( DepPair t x y ) -> do { DepPair t' x' y' <- matchit; unify t t'; unify x x'; unify y y'} - ( HeapVal ) -> do - { HeapVal <- matchit; return ()} - ( Eff eff ) -> do - { Eff eff' <- matchit; unify eff eff'} ( Lam lam ) -> do { Lam lam' <- matchit; unifyEq lam lam'} ( NewtypeCon con x ) -> do @@ -1792,8 +1690,6 @@ instance Unifiable (TyCon CoreIR) where unify t1 t2 = case t1 of ( BaseType b ) -> do { BaseType b' <- matchit; guard $ b == b'} - ( HeapType ) -> do - { HeapType <- matchit; return () } ( TypeKind ) -> do { TypeKind <- matchit; return () } ( Pi piTy ) -> do @@ -1809,7 +1705,7 @@ instance Unifiable (TyCon CoreIR) where ( ProdType ts ) -> do { ProdType ts' <- matchit; unifyLists ts ts'} ( RefType h t ) -> do - { RefType h' t' <- matchit; unify h h'; unify t t'} + { RefType h' t' <- matchit; guard (h == h'); unify t t'} ( DepPairTy t ) -> do { DepPairTy t' <- matchit; unify t t'} where matchit = return t2 @@ -1834,8 +1730,6 @@ instance Unifiable NewtypeTyCon where { Nat <- matchit; return ()} ( Fin n ) -> do { Fin n' <- matchit; unify n n'} - ( EffectRowKind ) -> do - { EffectRowKind <- matchit; return ()} ( UserADTType _ c params ) -> do { UserADTType _ c' params' <- matchit; guard (c == c') >> unify params params' } where matchit = return e2 @@ -1848,56 +1742,21 @@ instance Unifiable TyConParams where -- We ignore the dictionaries because we assume coherence unify ps ps' = zipWithM_ unify (ignoreSynthParams ps) (ignoreSynthParams ps') -instance Unifiable (EffectRow CoreIR) where - unify x1 x2 = - unifyDirect x1 x2 - <|> unifyDirect x2 x1 - <|> unifyZip x1 x2 - - where - unifyDirect :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () - unifyDirect r@(EffectRow effs' mv') (EffectRow effs (EffectRowTail v)) | null (eSetToList effs) = - case mv' of - EffectRowTail v' | v == v' -> guard $ null $ eSetToList effs' - _ -> extendSolution v (Con $ Eff r) - unifyDirect _ _ = empty - {-# INLINE unifyDirect #-} - - unifyZip :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () - unifyZip r1 r2 = case (r1, r2) of - (EffectRow effs1 t1, EffectRow effs2 t2) | not (eSetNull effs1 || eSetNull effs2) -> do - let extras1 = effs1 `eSetDifference` effs2 - let extras2 = effs2 `eSetDifference` effs1 - void $ withFreshEff \newRow -> do - unify (EffectRow mempty (sink t1)) (extendEffRow (sink extras2) newRow) - unify (extendEffRow (sink extras1) newRow) (EffectRow mempty (sink t2)) - return UnitE - _ -> unifyEq r1 r2 - -withFreshEff - :: Zonkable e - => (forall o'. DExt o o' => EffectRow CoreIR o' -> SolverM i o' (e o')) - -> SolverM i o (e o) -withFreshEff cont = - withFreshUnificationVarNoEmits rootSrcId MiscInfVar EffKind \v -> do - cont $ EffectRow mempty $ EffectRowTail v -{-# INLINE withFreshEff #-} - unifyEq :: AlphaEqE e => e n -> e n -> SolverM i n () unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} instance Unifiable CorePiType where unify (CorePiType appExpl1 expls1 bsTop1 effTy1) - (CorePiType appExpl2 expls2 bsTop2 effTy2) = do + (CorePiType appExpl2 expls2 bsTop2 effTy2) = do unless (appExpl1 == appExpl2) empty unless (expls1 == expls2) empty go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where - go :: Abs (Nest CBinder) (EffTy CoreIR) n - -> Abs (Nest CBinder) (EffTy CoreIR) n + go :: Abs (Nest CBinder) (Type CoreIR) n + -> Abs (Nest CBinder) (Type CoreIR) n -> SolverM i n () - go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 + go (Abs Empty t1) (Abs Empty t2) = unify t1 t2 go (Abs (Nest (b1:>t1) bs1) rest1) (Abs (Nest (b2:>t2) bs2) rest2) = do unify t1 t2 @@ -2121,7 +1980,7 @@ getSynthType x = ignoreExcept $ typeAsSynthType rootSrcId (getType x) typeAsSynthType :: SrcId -> CType n -> Except (SynthType n) typeAsSynthType sid = \case TyCon (DictTy dictTy) -> return $ SynthDictType dictTy - TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> + TyCon (Pi (CorePiType ImplicitApp expls bs (TyCon (DictTy d)))) -> return $ SynthPiType (expls, Abs bs d) ty -> Failure $ toErr sid $ NotASynthType $ pprint ty {-# SCC typeAsSynthType #-} @@ -2171,7 +2030,7 @@ synthTerm sid targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of ab' <- withFreshBindersInf expls ab \bs' targetTy' -> do Abs bs' <$> synthTerm sid (SynthDictType targetTy') reqMethodAccess Abs bs' synthExpr <- return ab' - let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) + let piTy = CorePiType ImplicitApp expls bs' (getType synthExpr) let lamExpr = LamExpr bs' (Atom synthExpr) return $ toAtom $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of @@ -2213,7 +2072,7 @@ synthDictFromInstance :: SrcId -> DictType n -> InfererM i n (SynthAtom n) synthDictFromInstance sid targetTy = do instances <- getInstanceDicts targetTy asum $ instances <&> \candidate -> typeErrAsSearchFailure do - CorePiType _ expls bs (EffTy _ (TyCon (DictTy candidateTy))) <- lookupInstanceTy candidate + CorePiType _ expls bs (TyCon (DictTy candidateTy)) <- lookupInstanceTy candidate args <- instantiateSynthArgs sid targetTy (expls, Abs bs candidateTy) return $ toAtom $ InstanceDict (toType targetTy) candidate args @@ -2282,18 +2141,18 @@ asFFIFunType ty = return do return (impTy, piTy) checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType -checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do +checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) ty) = do argTy <- checkScalar $ binderType b case bs of Empty -> do - resultTys <- checkScalarOrPairType (etTy effTy) + resultTys <- checkScalarOrPairType ty let cc = case length resultTys of 0 -> error "Not implemented" 1 -> FFICC _ -> FFIMultiResultCC return $ IFunType cc [argTy] resultTys Nest b' rest -> do - let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy + let naryPiRest = CorePiType appExpl expls (Nest b' rest) ty IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest return $ IFunType cc (argTy:argTys) resultTys checkFFIFunTypeM _ = error "expected at least one argument" @@ -2388,16 +2247,10 @@ instance RenameE SynthType instance SubstE AtomSubstVal SynthType instance GenericE Constraint where - type RepE Constraint = PairE - (LiftE SrcId) - (EitherE - (PairE CType CType) - (PairE (EffectRow CoreIR) (EffectRow CoreIR))) - fromE (TypeConstraint sid t1 t2) = LiftE sid `PairE` LeftE (PairE t1 t2) - fromE (EffectConstraint sid e1 e2) = LiftE sid `PairE` RightE (PairE e1 e2) + type RepE Constraint = PairE (LiftE SrcId) (PairE CType CType) + fromE (TypeConstraint sid t1 t2) = LiftE sid `PairE` PairE t1 t2 {-# INLINE fromE #-} - toE (LiftE sid `PairE` LeftE (PairE t1 t2)) = TypeConstraint sid t1 t2 - toE (LiftE sid `PairE` RightE (PairE e1 e2)) = EffectConstraint sid e1 e2 + toE (LiftE sid `PairE` PairE t1 t2) = TypeConstraint sid t1 t2 {-# INLINE toE #-} instance SinkableE Constraint @@ -2445,11 +2298,10 @@ instance RenameE SolverSubst where instance HoistableE SolverSubst instance GenericE PartialPiType where - type RepE PartialPiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) - (EffectRow CoreIR `PairE` RequiredTy) - fromE (PartialPiType ex exs b eff ty) = LiftE (ex, exs) `PairE` Abs b (PairE eff ty) + type RepE PartialPiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) RequiredTy + fromE (PartialPiType ex exs b ty) = LiftE (ex, exs) `PairE` Abs b ty {-# INLINE fromE #-} - toE (LiftE (ex, exs) `PairE` Abs b (PairE eff ty)) = PartialPiType ex exs b eff ty + toE (LiftE (ex, exs) `PairE` Abs b ty) = PartialPiType ex exs b ty {-# INLINE toE #-} instance SinkableE PartialPiType diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index f72f24bef..7a22882c5 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -199,7 +199,7 @@ data Context (from::E) (to::E) (o::S) where Stop :: Context e e o TabAppCtx :: SAtom i -> Subst InlineSubstVal i o -> Context SExpr e o -> Context SExpr e o - CaseCtx :: [SAlt i] -> SType i -> EffectRow SimpIR i + CaseCtx :: [SAlt i] -> SType i -> Effects SimpIR i -> Subst InlineSubstVal i o -> Context SExpr e o -> Context SExpr e o EmitToAtomCtx :: Context SAtom e o -> Context SExpr e o @@ -327,7 +327,7 @@ reconstructTabApp ctx expr i = case expr of reconstruct ctx =<< mkTabApp array' i' reconstructCase :: Emits o - => Context SExpr e o -> SExpr o -> [SAlt i] -> SType i -> EffectRow SimpIR i + => Context SExpr e o -> SExpr o -> [SAlt i] -> SType i -> Effects SimpIR i -> InlineM i o (e o) reconstructCase ctx scrutExpr alts resultTy effs = case scrutExpr of @@ -360,8 +360,8 @@ reconstructCase ctx scrutExpr alts resultTy effs = effs' <- inline Stop effs reconstruct ctx $ Case scrut alts' (EffTy effs' resultTy') -instance Inlinable (EffectRow SimpIR) -instance Inlinable (EffTy SimpIR) +instance Inlinable (Effects SimpIR) +instance Inlinable (EffTy SimpIR) -- === NoteReconstructTabAppDecisions === diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 1c496cf5e..1cdbcd72a 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -33,11 +33,10 @@ import Util (enumerate) -- === linearization monad === data ActivePrimals (n::S) = ActivePrimals - { activeVars :: [AtomVar SimpIR n] -- includes refs and regions - , activeEffs :: EffectRow SimpIR n } + { activeVars :: [AtomVar SimpIR n] } -- includes refs emptyActivePrimals :: ActivePrimals n -emptyActivePrimals = ActivePrimals [] Pure +emptyActivePrimals = ActivePrimals [] data TangentArgs (n::S) = TangentArgs [SAtomVar n] @@ -72,10 +71,6 @@ extendActiveSubst => b i i' -> SAtomVar o -> PrimalM i' o a -> PrimalM i o a extendActiveSubst b v cont = extendSubst (b@>atomVarName v) $ extendActivePrimals v cont -extendActiveEffs :: Effect SimpIR o -> PrimalM i o a -> PrimalM i o a -extendActiveEffs eff = local \primals -> - primals { activeEffs = extendEffRow (eSetSingleton eff) (activeEffs primals)} - extendActivePrimals :: SAtomVar o -> PrimalM i o a -> PrimalM i o a extendActivePrimals v = extendActivePrimalss [v] @@ -89,9 +84,6 @@ getTangentArg idx = asks \(TangentArgs vs) -> toAtom $ vs !! idx extendTangentArgs :: SAtomVar n -> TangentM n a -> TangentM n a extendTangentArgs v m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ [v]) m -extendTangentArgss :: [SAtomVar n] -> TangentM n a -> TangentM n a -extendTangentArgss vs' m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ vs') m - getTangentArgs :: TangentM o (TangentArgs o) getTangentArgs = ask @@ -167,39 +159,26 @@ tangentFunAsLambda => (forall o'. (DExt o o', Emits o') => TangentM o' (Atom SimpIR o')) -> PrimalM i o (SLam o) tangentFunAsLambda cont = do - ActivePrimals primalVars _ <- getActivePrimals + ActivePrimals primalVars <- getActivePrimals tangentTys <- getTangentArgTys primalVars buildLamExpr tangentTys \tangentVars -> do liftTangentM (TangentArgs $ map sink tangentVars) cont getTangentArgTys :: (Fallible1 m, EnvExtender m) => [SAtomVar n] -> m n (EmptyAbs (Nest SBinder) n) -getTangentArgTys topVs = go mempty topVs where - go :: (Fallible1 m, EnvExtender m) - => EMap SAtomName SAtomVar n -> [SAtomVar n] -> m n (EmptyAbs (Nest SBinder) n) - go _ [] = return $ EmptyAbs Empty - go heapMap (v:vs) = case getType v of - -- This is a hack to handle heaps/references. They normally come in pairs - -- like this, but there's nothing to prevent users writing programs that - -- sling around heap variables by themselves. We should try to do something - -- better... - TyCon HeapType -> do - withFreshBinder (getNameHint v) (TyCon HeapType) \hb -> do - let newHeapMap = sink heapMap <> eMapSingleton (sink (atomVarName v)) (binderVar hb) - Abs bs UnitE <- go newHeapMap $ sinkList vs - return $ EmptyAbs $ Nest hb bs - RefTy (Stuck _ (Var h)) referentTy -> do - case lookupEMap heapMap (atomVarName h) of - Nothing -> error "shouldn't happen?" - Just h' -> do - tt <- tangentType referentTy - let refTy = RefTy (toAtom h') tt - withFreshBinder (getNameHint v) refTy \refb -> do - Abs bs UnitE <- go (sink heapMap) $ sinkList vs - return $ EmptyAbs $ Nest refb bs +getTangentArgTys topVs = go topVs where + go :: (Fallible1 m, EnvExtender m) => [SAtomVar n] -> m n (EmptyAbs (Nest SBinder) n) + go [] = return $ EmptyAbs Empty + go (v:vs) = case getType v of + RefTy rws referentTy -> do + tt <- tangentType referentTy + let refTy = RefTy rws tt + withFreshBinder (getNameHint v) refTy \refb -> do + Abs bs UnitE <- go $ sinkList vs + return $ EmptyAbs $ Nest refb bs ty -> do tt <- tangentType ty withFreshBinder (getNameHint v) tt \b -> do - Abs bs UnitE <- go (sink heapMap) $ sinkList vs + Abs bs UnitE <- go $ sinkList vs return $ EmptyAbs $ Nest b bs class ReconFunctor (f :: E -> E) where @@ -487,7 +466,6 @@ linearizeMiscOp op = case op of BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented - ThrowException _ -> notImplemented ThrowError _ -> zero OutputStream -> zero ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" @@ -597,7 +575,6 @@ linearizePrimCon con = case con of Lit _ -> zero ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs) SumCon _ _ _ -> notImplemented - HeapVal -> zero DepPair _ _ _ -> notImplemented where zero = emitZeroT con @@ -625,82 +602,18 @@ linearizeHof hof = case hof of residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs extendSubst (bs @@> (SubstVal <$> residuals')) $ applyLinLam linLam' - RunReader r lam -> do - WithTangent r' rLin <- linearizeAtom r - (lam', recon) <- linearizeEffectFun Reader lam - primalAux <- emitHof $ RunReader r' lam' - referentTy <- renameM $ getType r - (primal, linLam) <- reconstruct primalAux recon - return $ WithTangent primal do - rLin' <- rLin - tt <- tangentType $ sink referentTy - tanEffLam <- buildEffLam noHint tt \h ref -> - extendTangentArgss [h, ref] do - withSubstReaderT $ applyLinLam $ sink linLam - emitHofLin $ RunReader rLin' tanEffLam - RunState Nothing sInit lam -> do - WithTangent sInit' sLin <- linearizeAtom sInit - (lam', recon) <- linearizeEffectFun State lam - (primalAux, sFinal) <- fromPair =<< emitHof (RunState Nothing sInit' lam') - referentTy <- snd <$> getTypeRWSAction lam' - (primal, linLam) <- reconstruct primalAux recon - return $ WithTangent (PairVal primal sFinal) do - sLin' <- sLin - tt <- tangentType $ sink referentTy - tanEffLam <- buildEffLam noHint tt \h ref -> - extendTangentArgss [h, ref] do - withSubstReaderT $ applyLinLam $ sink linLam - emitHofLin $ RunState Nothing sLin' tanEffLam - RunWriter Nothing bm lam -> do - -- TODO: check it's actually the 0/+ monoid (or should we just build that in?) - bm' <- renameM bm - (lam', recon) <- linearizeEffectFun Writer lam - (primalAux, wFinal) <- fromPair =<< emitHof (RunWriter Nothing bm' lam') - (primal, linLam) <- reconstruct primalAux recon - referentTy <- snd <$> getTypeRWSAction lam' - return $ WithTangent (PairVal primal wFinal) do - bm'' <- sinkM bm' - tt <- tangentType $ sink referentTy - tanEffLam <- buildEffLam noHint tt \h ref -> - extendTangentArgss [h, ref] do - withSubstReaderT $ applyLinLam $ sink linLam - emitHofLin $ RunWriter Nothing bm'' tanEffLam - RunIO body -> do - (body', recon) <- linearizeExprDefunc body - primalAux <- emitHof $ RunIO body' - (primal, linLam) <- reconstruct primalAux recon - return $ WithTangent primal do - withSubstReaderT $ applyLinLam $ sink linLam _ -> error $ "not implemented: " ++ pprint hof -linearizeEffectFun :: RWS -> SLam i -> PrimalM i o (SLam o, LinLamAbs o) -linearizeEffectFun rws (BinaryLamExpr hB refB body) = do - withFreshBinder noHint (TyCon HeapType) \h -> do - bTy <- extendSubst (hB@>binderName h) $ renameM $ binderType refB - withFreshBinder noHint bTy \b -> do - let ref = binderVar b - hVar <- sinkM $ binderVar h - (body', linLam) <- extendActiveSubst hB hVar $ extendActiveSubst refB ref $ - -- TODO: maybe we should check whether we need to extend the active effects - extendActiveEffs (RWSEffect rws (toAtom hVar)) do - linearizeExprDefunc body - -- TODO: this assumes that references aren't returned. Our type system - -- ensures that such references can never be *used* once the effect runner - -- returns, but technically it's legal to return them. - let linLam' = ignoreHoistFailure $ hoist (PairB h b) linLam - return (BinaryLamExpr h b body', linLam') -linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" - notImplemented :: HasCallStack => a notImplemented = error "Not implemented" -- === boring instances === instance GenericE ActivePrimals where - type RepE ActivePrimals = PairE (ListE SAtomVar) (EffectRow SimpIR) - fromE (ActivePrimals vs effs) = ListE vs `PairE` effs + type RepE ActivePrimals = ListE SAtomVar + fromE (ActivePrimals vs) = ListE vs {-# INLINE fromE #-} - toE (ListE vs `PairE` effs) = ActivePrimals vs effs + toE (ListE vs) = ActivePrimals vs {-# INLINE toE #-} instance SinkableE ActivePrimals diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 9f160776a..cdf136dfd 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -67,7 +67,7 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE liftAtomSubstBuilder case wantDestStyle of True -> do xs <- bindersToAtoms bs' - EffTy _ resultTy <- instantiate (sink piTy) xs + resultTy <- instantiate (sink piTy) xs let resultDestTy = RawRefTy resultTy withFreshBinder "ans" resultDestTy \destBinder -> do let dest = toAtom $ binderVar destBinder @@ -252,16 +252,6 @@ lowerExpr dest expr = case expr of PrimOp (Hof (TypedHof (EffTy _ ansTy) (For dir ixDict body))) -> do ansTy' <- substM ansTy lowerFor ansTy' dest dir ixDict body - PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter Nothing m body))) -> do - PairTy _ ansTy <- visitType ty - traverseRWS ansTy body \ref' body' -> do - m' <- visitGeneric m - emitHof $ RunWriter ref' m' body' - PrimOp (Hof (TypedHof (EffTy _ ty) (RunState Nothing s body))) -> do - PairTy _ ansTy <- visitType ty - traverseRWS ansTy body \ref' body' -> do - s' <- visitAtom s - emitHof $ RunState ref' s' body' -- this case is important because this pass changes effects PrimOp (Hof (TypedHof _ hof)) -> do hof' <- emit =<< (visitGeneric hof >>= mkTypedHof) @@ -280,27 +270,6 @@ lowerExpr dest expr = case expr of place d e return e - traverseRWS - :: SType o -> LamExpr SimpIR i - -> (OptDest o -> LamExpr SimpIR o -> LowerM i o (SAtom o)) - -> LowerM i o (SAtom o) - traverseRWS referentTy (LamExpr (BinaryNest hb rb) body) cont = do - unpackRWSDest dest >>= \case - Nothing -> generic - Just (bodyDest, refDest) -> do - cont refDest =<< - buildEffLam (getNameHint rb) referentTy \hb' rb' -> - extendRenamer (hb@>atomVarName hb' <.> rb@>atomVarName rb') do - lowerExpr (sink <$> bodyDest) body - traverseRWS _ _ _ = error "Expected a binary lambda expression" - - unpackRWSDest = \case - Nothing -> return Nothing - Just d -> do - bd <- getProjRef (ProjectProduct 0) d - rd <- getProjRef (ProjectProduct 1) d - return $ Just (Just bd, Just rd) - place :: Emits o => Dest o -> SAtom o -> LowerM i o () place d x = void $ emit $ Place d x diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 0e75165be..f09a07db2 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -203,10 +203,7 @@ summary atom = case atom of Lit _ -> return $ Deterministic [] ProdCon elts -> Product <$> mapM summary elts SumCon _ tag payload -> Inject tag <$> summary payload - HeapVal -> invalid "HeapVal" DepPair _ _ _ -> error "not implemented" - where - invalid tag = error $ "Unexpected indexing by " ++ tag unknown :: HoistableE e => e n -> OCCM n (IxExpr n) unknown _ = return IxAll @@ -417,7 +414,7 @@ instance HasOCC (TypedHof SimpIR) where occ a (TypedHof effTy hof) = TypedHof <$> occ a effTy <*> occ a hof instance HasOCC (Hof SimpIR) where - occ a hof = case hof of + occ _ hof = case hof of For ann ixDict (UnaryLamExpr b body) -> do ixDict' <- inlinedLater ixDict occWithBinder (Abs b body) \b' body' -> do @@ -426,47 +423,6 @@ instance HasOCC (Hof SimpIR) where return $ For ann ixDict' (UnaryLamExpr b' body'') For _ _ _ -> error "For body should be a unary lambda expression" While body -> While <$> censored useManyTimes (occ accessOnce body) - RunReader ini bd -> do - iniIx <- summary ini - bd' <- oneShot a [Deterministic [], iniIx] bd - ini' <- occ accessOnce ini - return $ RunReader ini' bd' - RunWriter Nothing (BaseMonoid empty combine) bd -> do - -- There is no way to read from the reference in a Writer, so the only way - -- an indexing expression can depend on it is by referring to the - -- reference itself. One way to so refer that is opaque to occurrence - -- analysis would be to pass the reference to a standalone function which - -- returns an index (presumably without actually reading any information - -- from said reference). - -- - -- To cover this case, we write `Deterministic []` here. This is correct, - -- because RunWriter creates the reference without reading any external - -- names. In particular, in the event of `RunWriter` in a loop, the - -- different references across loop iterations are not distinguishable. - -- The same argument holds for the heap parameter. - bd' <- oneShot a [Deterministic [], Deterministic []] bd - -- We will process the combining function when we meet it in MExtend ops - -- (but we won't attempt to eliminate dead code in it). - empty' <- occ accessOnce empty - return $ RunWriter Nothing (BaseMonoid empty' combine) bd' - RunWriter (Just _) _ _ -> - error "Expecting to do occurrence analysis before destination passing." - RunState Nothing ini bd -> do - -- If we wanted to be more precise, the summary for the reference should - -- be something about the stuff that might flow into the `put` operations - -- affecting that reference. Using `IxAll` is a conservative - -- approximation (in downstream analysis it means "assume I touch every - -- value"). - bd' <- oneShot a [Deterministic [], IxAll] bd - ini' <- occ accessOnce ini - return $ RunState Nothing ini' bd' - RunState (Just _) _ _ -> - error "Expecting to do occurrence analysis before destination passing." - RunIO bd -> RunIO <$> occ a bd - RunInit _ -> - -- Though this is probably not too hard to implement. Presumably - -- the lambda is one-shot. - error "Expecting to do occurrence analysis before lowering." oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) oneShot acc [] (LamExpr Empty body) = diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 05546104a..884a9a6af 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -80,7 +80,7 @@ ulExpr expr = case expr of extendSubst (b' @> SubstVal (IdxRepVal i)) $ ulExpr block' inc $ fromIntegral n -- To account for the TabCon we emit below getLamExprType body' >>= \case - PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do + PiType (UnaryNest (tb:>_)) valTy -> do let tabTy = toType $ TabPiType (DictCon $ IxRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy emit $ TabCon tabTy vals _ -> error "Expected `for` body to have a Pi type" @@ -336,5 +336,5 @@ dceBlock (Abs decls ans) = case decls of modify (<>FV (freeVarsB b')) return $ Abs (Nest (Let b' (DeclBinding ann expr')) bs'') ans'' -instance HasDCE (EffectRow SimpIR) -instance HasDCE (EffTy SimpIR) +instance HasDCE (Effects SimpIR) +instance HasDCE (EffTy SimpIR) diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 37348773a..c99d7bbd9 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -20,7 +20,6 @@ import Types.Top import Types.Imp import IRVariants import Core -import Err import Name hiding (withFreshM) import Subst import Util @@ -43,14 +42,10 @@ caseAltsBinderTys ty = case ty of _ -> error msg where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty -extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n -extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t - piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of - Just (PairB bs (_:>RawRefTy ansTy)) -> do - PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here + Just (PairB bs (_:>RawRefTy ansTy)) -> PiType bs ansTy _ -> error "expected trailing dest binder" typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> Atom r n -> m n (Type r n) @@ -65,7 +60,8 @@ typeOfApplyMethod d i args = do typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) typeOfTopApp f xs = do piTy <- getTypeTopFun f - instantiate piTy xs + ty <- instantiate piTy xs + return $ EffTy undefined ty -- TODO typeOfIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Type r n -> Atom r n -> m n (Type r n) typeOfIndexRef (TyCon (RefType h s)) i = do @@ -87,7 +83,9 @@ typeOfProjRef (TyCon (RefType h s)) p = do typeOfProjRef _ _ = error "expected a reference" appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (TyCon (Pi piTy)) xs = instantiate piTy xs +appEffTy (TyCon (Pi piTy)) xs = do + ty <- instantiate piTy xs + return $ EffTy undefined ty -- TODO appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) @@ -103,46 +101,25 @@ effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof typeOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (Type r n) typeOfHof = \case For _ ixTy f -> getLamExprType f >>= \case - PiType (UnaryNest b) (EffTy _ eltTy) -> return $ TabTy (ixTypeDict ixTy) b eltTy + PiType (UnaryNest b) eltTy -> return $ TabTy (ixTypeDict ixTy) b eltTy _ -> error "expected a unary pi type" While _ -> return UnitTy Linearize f _ -> getLamExprType f >>= \case - PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do + PiType (UnaryNest (binder:>a)) b -> do let b' = ignoreHoistFailure $ hoist binder b - let fLinTy = toType $ nonDepPiType [a] Pure b' + let fLinTy = toType $ nonDepPiType [a] b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" Transpose f _ -> getLamExprType f >>= \case PiType (UnaryNest (_:>a)) _ -> return a _ -> error "expected a unary pi type" - RunReader _ f -> do - (resultTy, _) <- getTypeRWSAction f - return resultTy - RunWriter _ _ f -> uncurry PairTy <$> getTypeRWSAction f - RunState _ _ f -> do - (resultTy, stateTy) <- getTypeRWSAction f - return $ PairTy resultTy stateTy - RunIO f -> return $ getType f - RunInit f -> return $ getType f - CatchException ty _ -> return ty - -hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) + +hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (Effects r n) hofEffects = \case - For _ _ f -> functionEffs f + For _ _ _ -> undefined -- TODO While body -> return $ getEffects body Linearize _ _ -> return Pure -- Body has to be a pure function Transpose _ _ -> return Pure -- Body has to be a pure function - RunReader _ f -> rwsFunEffects Reader f - RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f - RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> return $ deleteEff IOEffect $ getEffects f - RunInit f -> return $ deleteEff InitEffect $ getEffects f - CatchException _ f -> return $ deleteEff ExceptionEffect $ getEffects f - where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) - maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id - -deleteEff :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n -deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleton eff) t getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do @@ -160,7 +137,7 @@ getUVarType = \case UPunVar v -> getStructDataConType v UClassVar v -> do ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef v - return $ toType $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind + return $ toType $ CorePiType ExplicitApp (map snd roleExpls) bs TyKind UMethodVar v -> getMethodNameType v getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) @@ -196,14 +173,14 @@ getMethodType dict i = do mkCorePiType :: EnvReader m => [CType n] -> CType n -> m n (CorePiType n) mkCorePiType argTys resultTy = liftEnvReaderM $ withFreshBinders argTys \bs _ -> do expls <- return $ nestToList (const Explicit) bs - return $ CorePiType ExplicitApp expls bs (EffTy Pure (sink resultTy)) + return $ CorePiType ExplicitApp expls bs (sink resultTy) getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) getTyConNameType v = do TyConDef _ expls bs _ <- lookupTyCon v case bs of Empty -> return TyKind - _ -> return $ toType $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind + _ -> return $ toType $ CorePiType ExplicitApp (snd <$> expls) bs TyKind getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do @@ -217,7 +194,7 @@ getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do _ -> ExplicitApp let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) let dataExpls = nestToList (const $ Explicit) dataBs - return $ toType $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) + return $ toType $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) resultTy getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do @@ -228,7 +205,7 @@ getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) params Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy let dataExpls = nestToList (const Explicit) dataBs - return $ toType $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') + return $ toType $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) resultTy' buildDataConType :: (EnvReader m, EnvExtender m) @@ -266,35 +243,8 @@ makePreludeMaybeTy ty = do let params = TyConParams [Explicit] [toAtom ty] return $ toType $ UserADTType "Maybe" tyConName params --- === computing effects === - -functionEffs :: (IRRep r, EnvReader m) => LamExpr r n -> m n (EffectRow r n) -functionEffs f = getLamExprType f >>= \case - PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs - -rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) -rwsFunEffects rws f = getLamExprType f >>= \case - PiType (BinaryNest h ref) et -> do - let effs' = ignoreHoistFailure $ hoist ref (etEff et) - let hVal = toAtom $ AtomVar (binderName h) (TyCon HeapType) - let effs'' = deleteEff (RWSEffect rws hVal) effs' - return $ ignoreHoistFailure $ hoist h effs'' - _ -> error "Expected a binary function type" - getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) -getLamExprType (LamExpr bs body) = - return $ PiType bs $ EffTy (getEffects body) (getType body) - -getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) -getTypeRWSAction f = getLamExprType f >>= \case - PiType (BinaryNest regionBinder refBinder) (EffTy _ resultTy) -> do - case binderType refBinder of - RefTy _ referentTy -> do - let referentTy' = ignoreHoistFailure $ hoist regionBinder referentTy - let resultTy' = ignoreHoistFailure $ hoist (PairB regionBinder refBinder) resultTy - return (resultTy', referentTy') - _ -> error "expected a ref" - _ -> error "expected a pi type" +getLamExprType (LamExpr bs body) = return $ PiType bs (getType body) getSuperclassDicts :: EnvReader m => CDict n -> m n ([CAtom n]) getSuperclassDicts dict = do @@ -326,7 +276,7 @@ liftIFunType :: (IRRep r, EnvReader m) => IFunType -> m n (PiType r n) liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where go :: IRRep r => [BaseType] -> EnvReaderM n (PiType r n) go = \case - [] -> return $ PiType Empty (EffTy (OneEffect IOEffect) resultTy) + [] -> return $ PiType Empty resultTy where resultTy = case resultTys of [] -> UnitTy [t] -> toType $ BaseType t @@ -358,19 +308,7 @@ isData ty = do ProdType as -> mapM_ go as SumType cs -> mapM_ go cs RefType _ _ -> return () - HeapType -> return () TypeKind -> notData DictTy _ -> notData Pi _ -> notData where notData = empty - -checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () -checkExtends allowed (EffectRow effs effTail) = do - let (EffectRow allowedEffs allowedEffTail) = allowed - case effTail of - EffectRowTail _ -> assertEq allowedEffTail effTail "" - NoTail -> return () - forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ - throwInternal $ "Unexpected effect: " ++ pprint eff ++ - "\nAllowed: " ++ pprint allowed -{-# INLINE checkExtends #-} diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 26b17feb3..0deff5111 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -16,15 +16,15 @@ class HasType (r::IR) (e::E) | e -> r where getType :: e n -> Type r n class HasEffects (e::E) (r::IR) | e -> r where - getEffects :: e n -> EffectRow r n + getEffects :: e n -> Effects r n getTyCon :: HasType SimpIR e => e n -> TyCon SimpIR n getTyCon e = con where TyCon con = getType e isPure :: (IRRep r, HasEffects e r) => e n -> Bool isPure e = case getEffects e of - Pure -> True - _ -> False + Pure -> True + Effectful -> False -- === querying types implementation === @@ -107,10 +107,8 @@ instance IRRep r => HasType r (Con r) where Lit l -> toType $ BaseType $ litType l ProdCon xs -> toType $ ProdType $ map getType xs SumCon tys _ _ -> toType $ SumType tys - HeapVal -> toType HeapType Lam (CoreLamExpr piTy _) -> toType $ Pi piTy DepPair _ _ ty -> toType $ DepPairTy ty - Eff _ -> EffKind DictConAtom d -> getType d NewtypeCon con _ -> getNewtypeType con TyConAtom _ -> TyKind @@ -195,7 +193,6 @@ instance IRRep r => HasType r (MiscOp r) where getType = \case Select _ x _ -> getType x ThrowError t -> t - ThrowException t -> t CastOp t _ -> t BitcastOp t _ -> t UnsafeCoerce t _ -> t @@ -225,11 +222,11 @@ typesAsBinderNest => [Type r n] -> e n -> Abs (Nest (Binder r)) e n typesAsBinderNest types body = toConstBinderNest types body -nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n -nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of - Abs bs (PairE eff' resultTy') -> do +nonDepPiType :: [CType n] -> CType n -> CorePiType n +nonDepPiType argTys resultTy = case typesAsBinderNest argTys resultTy of + Abs bs resultTy' -> do let expls = nestToList (const Explicit) bs - CorePiType ExplicitApp expls bs $ EffTy eff' resultTy' + CorePiType ExplicitApp expls bs resultTy' nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n nonDepTabPiType (IxType t d) resultTy = @@ -277,13 +274,12 @@ instance IRRep r => HasEffects (PrimOp r) r where BinOp _ _ _ -> Pure VectorOp _ -> Pure MemOp op -> case op of - IOAlloc _ -> OneEffect IOEffect - IOFree _ -> OneEffect IOEffect - PtrLoad _ -> OneEffect IOEffect - PtrStore _ _ -> OneEffect IOEffect + IOAlloc _ -> Effectful + IOFree _ -> Effectful + PtrLoad _ -> Effectful + PtrStore _ _ -> Effectful PtrOffset _ _ -> Pure MiscOp op -> case op of - ThrowException _ -> OneEffect ExceptionEffect Select _ _ _ -> Pure ThrowError _ -> Pure CastOp _ _ -> Pure @@ -296,17 +292,17 @@ instance IRRep r => HasEffects (PrimOp r) r where ShowAny _ -> Pure ShowScalar _ -> Pure RefOp ref m -> case getType ref of - TyCon (RefType h _) -> case m of - MGet -> OneEffect (RWSEffect State h) - MPut _ -> OneEffect (RWSEffect State h) - MAsk -> OneEffect (RWSEffect Reader h) + TyCon (RefType _ _) -> case m of + MGet -> Effectful + MPut _ -> Effectful + MAsk -> Effectful -- XXX: We don't verify the base monoid. See note about RunWriter. - MExtend _ _ -> OneEffect (RWSEffect Writer h) + MExtend _ _ -> Effectful IndexRef _ _ -> Pure ProjRef _ _ -> Pure _ -> error "not a ref" DAMOp op -> case op of - Place _ _ -> OneEffect InitEffect + Place _ _ -> Effectful Seq eff _ _ _ _ -> eff RememberDest eff _ _ -> eff AllocDest _ -> Pure -- is this correct? diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 285ac87f8..1cba35575 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -76,7 +76,6 @@ showAnyTyCon tyCon atom = case tyCon of -- aren't user-facing. SumType _ -> printAsConstant RefType _ _ -> printTypeOnly "reference" - HeapType -> printAsConstant ProdType _ -> do xs <- getUnpacked atom parens $ sepBy ", " $ map rec xs @@ -94,7 +93,6 @@ showAnyTyCon tyCon atom = case tyCon of -- Cast to Int so that it prints in decimal instead of hex let intTy = toType $ BaseType (Scalar Int64Type) emit (CastOp intTy n) >>= rec - EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype UserADTType "List" _ (TyConParams [Explicit] [Con (TyConAtom (BaseType (Scalar (Word8Type))))]) -> do charTab <- applyProjections [ProjectProduct 1, UnwrapNewtype] atom @@ -164,38 +162,33 @@ withBuffer => (forall l . (Emits l, DExt n l) => CAtom l -> BuilderM CoreIR l ()) -> BuilderM CoreIR n (CAtom n) withBuffer cont = do - lam <- withFreshBinder "h" (TyCon HeapType) \h -> do - bufTy <- bufferTy (toAtom $ binderVar h) - withFreshBinder "buf" bufTy \b -> do - let eff = OneEffect (RWSEffect State (toAtom $ sink $ binderVar h)) - body <- buildBlock do - cont $ sink $ toAtom $ binderVar b - return UnitVal - let binders = BinaryNest h b - let expls = [Inferred Nothing Unify, Explicit] - let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy - let lam = LamExpr (BinaryNest h b) body - return $ toAtom $ CoreLamExpr piTy lam + bufTy <- bufferTy + lam <- withFreshBinder "buf" bufTy \b -> do + body <- buildBlock do + cont $ sink $ toAtom $ binderVar b + return UnitVal + let binders = UnaryNest b + let expls = [Inferred Nothing Unify, Explicit] + let piTy = CorePiType ExplicitApp expls binders UnitTy + let lam = LamExpr (UnaryNest b) body + return $ toAtom $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] -bufferTy :: EnvReader m => CAtom n -> m n (CType n) -bufferTy h = do +bufferTy :: EnvReader m => m n (CType n) +bufferTy = do t <- strType - return $ RefTy h (PairTy NatTy t) + return $ RefTy State (PairTy NatTy t) -- argument has type `Fin n => Word8` extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do - RefTy h _ <- return $ getType buf TyCon (TabPi t) <- return $ getType tab n <- applyIxMethodCore Size (tabIxType t) [] - void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] + void $ applyPreludeFunction "stack_extend_internal" [n, buf, tab] -- argument has type `Word8` pushBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () -pushBuffer buf x = do - RefTy h _ <- return $ getType buf - void $ applyPreludeFunction "stack_push_internal" [h, buf, x] +pushBuffer buf x = void $ applyPreludeFunction "stack_push_internal" [buf, x] stringLitAsCharTab :: (Emits n, CBuilder m) => String -> m n (CAtom n) stringLitAsCharTab s = do diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 34308c573..f8cfe7738 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -30,7 +30,6 @@ import QueryType import RuntimePrint import Transpose import Types.Core -import Types.Source import Types.Top import Types.Primitives import Util (enumerate) @@ -196,8 +195,7 @@ getRepType (TyCon con) = case con of BaseType b -> return $ toType $ BaseType b ProdType ts -> toType . ProdType <$> mapM getRepType ts SumType ts -> toType . SumType <$> mapM getRepType ts - RefType h a -> toType <$> (RefType <$> toDataAtomAssumeNoDecls h <*> getRepType a) - HeapType -> return $ toType HeapType + RefType h a -> toType <$> (RefType h <$> getRepType a) DepPairTy (DepPairType expl b r) -> do withSimplifiedBinder b \b' -> do r' <- getRepType r @@ -220,14 +218,12 @@ toDataAtom (Con con) = case con of Lit v -> return $ toAtom $ Lit v ProdCon xs -> toAtom . ProdCon <$> mapM rec xs SumCon tys tag x -> toAtom <$> (SumCon <$> mapM getRepType tys <*> pure tag <*> rec x) - HeapVal -> return $ toAtom HeapVal DepPair x y ty -> do TyCon (DepPairTy ty') <- getRepType $ TyCon $ DepPairTy ty toAtom <$> (DepPair <$> rec x <*> rec y <*> pure ty') NewtypeCon _ x -> rec x Lam _ -> notData DictConAtom _ -> notData - Eff _ -> notData TyConAtom _ -> notData where rec = toDataAtom @@ -632,9 +628,9 @@ simplifyDictMethod absDict@(Abs bs dict) method = do ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) ixMethodType method absDict = do refreshAbs absDict \extraArgBs dict -> do - CorePiType _ _ methodArgs (EffTy _ resultTy) <- getMethodType dict (fromEnum method) + CorePiType _ _ methodArgs resultTy <- getMethodType dict (fromEnum method) let allBs = extraArgBs >>> methodArgs - return $ PiType allBs (EffTy Pure resultTy) + return $ PiType allBs resultTy simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o) simplifyAtom = substM @@ -790,43 +786,6 @@ simplifyHof resultTy = \case SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body result <- emitHof $ While body' liftSimpAtom resultTy result - RunReader r lam -> do - r' <- toDataAtom r - (lam', Abs b recon) <- simplifyLam lam - ans <- emitHof $ RunReader r' lam' - let recon' = ignoreHoistFailure $ hoist b recon - applyRecon recon' ans - RunWriter Nothing (BaseMonoid e combine) lam -> do - LamExpr (BinaryNest h (_:>RefTy _ wTy)) _ <- return lam - wTy' <- substM $ ignoreHoistFailure $ hoist h wTy - e' <- toDataAtom e - (combine', CoerceReconAbs) <- simplifyLam combine - (lam', Abs b recon) <- simplifyLam lam - (ans, w) <- fromPair =<< emitHof (RunWriter Nothing (BaseMonoid e' combine') lam') - let recon' = ignoreHoistFailure $ hoist b recon - ans' <- applyRecon recon' ans - w' <- liftSimpAtom wTy' w - return $ PairVal ans' w' - RunWriter _ _ _ -> error "Shouldn't see a RunWriter with a dest in Simplify" - RunState Nothing s lam -> do - s' <- toDataAtom s - sTy <- substM $ getType s - (lam', Abs b recon) <- simplifyLam lam - resultPair <- emitHof $ RunState Nothing s' lam' - (ans, sOut) <- fromPair resultPair - let recon' = ignoreHoistFailure $ hoist b recon - ans' <- applyRecon recon' ans - sOut' <- liftSimpAtom sTy sOut - return $ PairVal ans' sOut' - RunState _ _ _ -> error "Shouldn't see a RunState with a dest in Simplify" - RunIO body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body - ans <- emitHof $ RunIO body' - applyRecon recon ans - RunInit body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body - ans <- emitHof $ RunInit body' - applyRecon recon ans Linearize lam x -> do x' <- toDataAtom x -- XXX: we're ignoring the result type here, which only makes sense if we're @@ -843,52 +802,6 @@ simplifyHof resultTy = \case x' <- toDataAtom x result <- transpose lam' x' liftSimpAtom resultTy result - CatchException _ body-> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body - block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ - exceptToMaybeExpr body' - result <- emit block - case recon of - CoerceRecon ty -> do - maybeTy <- makePreludeMaybeTy ty - liftSimpAtom maybeTy result - LamRecon reconAbs -> fmapMaybe result (applyReconAbs $ sink reconAbs) - --- takes an internal SimpIR Maybe to a CoreIR "prelude Maybe" -fmapMaybe - :: SAtom n -> (forall l. DExt n l => SAtom l -> SimplifyM i l (CAtom l)) - -> SimplifyM i n (CAtom n) -fmapMaybe scrut f = do - ~(MaybeTy justTy) <- return $ getType scrut - (justAlt, resultJustTy) <- withFreshBinder noHint justTy \b -> do - result <- f (toAtom $ binderVar b) - resultTy <- return $ ignoreHoistFailure $ hoist b (getType result) - result' <- preludeJustVal result - return (Abs b result', resultTy) - nothingAlt <- buildAbs noHint UnitTy \_ -> preludeNothingVal $ sink resultJustTy - resultMaybeTy <- makePreludeMaybeTy resultJustTy - reduceACase scrut [nothingAlt, justAlt] resultMaybeTy - --- This is wrong! The correct implementation is below. And yet there's some --- compensatory bug somewhere that means that the wrong answer works and the --- right answer doesn't. Need to investigate. -preludeJustVal :: EnvReader m => CAtom n -> m n (CAtom n) -preludeJustVal x = return x - -- xTy <- getType x - -- con <- preludeMaybeNewtypeCon xTy - -- return $ NewtypeCon con (JustAtom xTy x) - -preludeNothingVal :: EnvReader m => CType n -> m n (CAtom n) -preludeNothingVal ty = do - con <- preludeMaybeNewtypeCon ty - return $ Con $ NewtypeCon con (NothingAtom ty) - -preludeMaybeNewtypeCon :: EnvReader m => CType n -> m n (NewtypeCon n) -preludeMaybeNewtypeCon ty = do - ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - TyConDef sn _ _ _ <- lookupTyCon tyConName - let params = TyConParams [Explicit] [toAtom ty] - return $ UserADTData sn tyConName params liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) liftSimpFun (TyCon (Pi piTy)) f = mkStuck $ LiftSimpFun piTy f @@ -1019,85 +932,6 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun --- === exception-handling pass === - -type HandlerM = SubstReaderT AtomSubstVal (BuilderM SimpIR) - -exceptToMaybeBlock :: Emits o => SType o -> SBlock i -> HandlerM i o (SAtom o) -exceptToMaybeBlock ty (Abs Empty result) = do - result' <- exceptToMaybeExpr result - return $ JustAtom ty result' -exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalResult) = do - maybeResult <- exceptToMaybeExpr rhs - case maybeResult of - -- This case is just an optimization (but an important one!) - JustAtom _ x -> - extendSubst (b@> SubstVal x) $ exceptToMaybeBlock resultTy (Abs decls finalResult) - _ -> emitMaybeCase maybeResult (MaybeTy resultTy) - (return $ NothingAtom $ sink resultTy) - (\v -> extendSubst (b@> SubstVal v) $ - exceptToMaybeBlock (sink resultTy) (Abs decls finalResult)) - -exceptToMaybeExpr :: Emits o => SExpr i -> HandlerM i o (SAtom o) -exceptToMaybeExpr expr = case expr of - Block (EffTy _ ty) body -> do - ty' <- substM ty - exceptToMaybeBlock ty' body - Case e alts (EffTy _ resultTy) -> do - e' <- substM e - resultTy' <- substM $ MaybeTy resultTy - buildCase e' resultTy' \i v -> do - Abs b body <- return $ alts !! i - extendSubst (b @> SubstVal v) do - exceptToMaybeExpr body - Atom x -> do - x' <- substM x - let ty = getType x' - return $ JustAtom ty x' - PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do - ixTy <- substM ixTy' - maybes <- buildFor (getNameHint b) ann ixTy \i -> do - extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body - catMaybesE maybes - PrimOp (MiscOp (ThrowException _)) -> do - ty <- substM $ getType expr - return $ NothingAtom ty - PrimOp (Hof (TypedHof _ (RunState Nothing s lam))) -> do - s' <- substM s - BinaryLamExpr h ref body <- return lam - result <- emitRunState noHint s' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - exceptToMaybeExpr body - (maybeAns, newState) <- fromPair result - a <- substM $ getType expr - emitMaybeCase maybeAns (MaybeTy a) - (return $ NothingAtom $ sink a) - (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink newState)) - PrimOp (Hof (TypedHof (EffTy _ resultTy) (RunWriter Nothing monoid (BinaryLamExpr h ref body)))) -> do - monoid' <- substM monoid - PairTy _ accumTy <- substM resultTy - result <- emitRunWriter noHint accumTy monoid' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - exceptToMaybeExpr body - (maybeAns, accumResult) <- fromPair result - a <- substM $ getType expr - emitMaybeCase maybeAns (MaybeTy a) - (return $ NothingAtom $ sink a) - (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink accumResult)) - PrimOp (Hof (TypedHof _ (While body))) -> runMaybeWhile $ exceptToMaybeExpr body - _ -> do - expr' <- substM expr - case hasExceptions expr' of - True -> error $ "Unexpected exception-throwing expression: " ++ pprint expr - False -> do - v <- emit expr' - let ty = getType v - return $ JustAtom ty v - -hasExceptions :: SExpr n -> Bool -hasExceptions expr = case getEffects expr of - EffectRow effs NoTail -> ExceptionEffect `eSetMember` effs - -- === instances === instance GenericE ReconstructAtom where diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 9a1a9e09b..2c5e74914 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -73,7 +73,7 @@ instance IsTree Group where CCase scrut alts -> visit scrut >> visit alts CIf scrut ifTrue ifFalse -> visit scrut >> visit ifTrue >> visit ifFalse CDo body -> visit body - CArrow l effs r -> visit l >> visit effs >> visit r + CArrow l r -> visit l >> visit r CWith b body -> visit b >> visit body instance IsTree Bin where diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 2250b3853..16e484de5 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -170,9 +170,9 @@ instance SourceRenamableE UExpr where UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam - UPi (UPiExpr pats appExpl eff body) -> + UPi (UPiExpr pats appExpl body) -> sourceRenameB pats \pats' -> - UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) + UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE body) UApp f xs ys -> UApp <$> sourceRenameE f <*> forM xs sourceRenameE <*> forM ys (\(name, y) -> (name,) <$> sourceRenameE y) @@ -205,16 +205,6 @@ instance SourceRenamableE UAlt where sourceRenameB pat \pat' -> UAlt pat' <$> sourceRenameE body -instance SourceRenamableE UEffectRow where - sourceRenameE (UEffectRow row tailVar) = - UEffectRow <$> row' <*> mapM sourceRenameE tailVar - where row' = S.fromList <$> traverse sourceRenameE (S.toList row) - -instance SourceRenamableE UEffect where - sourceRenameE (URWSEffect rws name) = URWSEffect rws <$> sourceRenameE name - sourceRenameE UExceptionEffect = return UExceptionEffect - sourceRenameE UIOEffect = return UIOEffect - instance SourceRenamableB UTopDecl where sourceRenameB decl cont = case decl of ULocalDecl d -> sourceRenameB d \d' -> cont $ ULocalDecl d' @@ -255,10 +245,9 @@ instance SourceRenamableB UDecl where UPass -> cont $ WithSrcB sid UPass instance SourceRenamableE ULamExpr where - sourceRenameE (ULamExpr args expl effs resultTy body) = + sourceRenameE (ULamExpr args expl resultTy body) = sourceRenameB args \args' -> ULamExpr args' <$> pure expl - <*> mapM sourceRenameE effs <*> mapM sourceRenameE resultTy <*> sourceRenameE body diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 9ab73bf7d..e58c549a2 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -486,7 +486,7 @@ evalBlock typed = do llvmOpt <- packageLLVMCallable impOpt resultVals <- liftIO $ callEntryFun llvmOpt [] TopLam _ destTy _ <- return lOpt - EffTy _ resultTy <- return $ assumeConst $ piTypeWithoutDest destTy + resultTy <- return $ assumeConst $ piTypeWithoutDest destTy repValAtom =<< repValFromFlatList resultTy resultVals applyReconTop recon simpResult {-# SCC evalBlock #-} @@ -837,7 +837,7 @@ instance Generic TopStateEx where getLinearizationType :: SymbolicZeros -> CType n -> EnvReaderT Except n (Int, Int, CType n) getLinearizationType zeros = \case - TyCon (Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy))) -> do + TyCon (Pi (CorePiType ExplicitApp expls bs resultTy)) -> do (numIs, numEs) <- getNumImplicits expls refreshAbs (Abs bs resultTy) \bs' resultTy' -> do PairB _ bsE <- return $ splitNestAt numIs bs' @@ -850,8 +850,8 @@ getLinearizationType zeros = \case resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt Nothing -> throwErr $ MiscErr $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' - let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy - let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) + let tanFunTy = toType $ Pi $ nonDepPiType argTanTys resultTanTy + let fullTy = CorePiType ExplicitApp expls bs' $ PairTy resultTy' tanFunTy return (numIs, numEs, toType $ Pi fullTy) _ -> throwErr $ MiscErr $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" where diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index b9c002468..44e6d4db3 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -13,7 +13,6 @@ import GHC.Stack import Builder import Core -import Imp import IRVariants import Name import PPrint @@ -49,7 +48,7 @@ transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do withAccumulator inTy \refSubstVal -> extendSubst (bLin @> refSubstVal) $ transposeExpr body (sink ct) - let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure (getType body')) + let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (getType body') let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' return $ TopLam False piTy lamT transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" @@ -61,7 +60,7 @@ unpackLinearLamExpr unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 PairB bsNonlin (UnaryNest bLin) <- return $ splitNestAt numNonlin bs - PiType bsTy (EffTy _ resultTy) <- getLamExprType lam + PiType bsTy resultTy <- getLamExprType lam PairB bsNonlinTy (UnaryNest bLinTy) <- return $ splitNestAt numNonlin bsTy let resultTy' = ignoreHoistFailure $ hoist bLinTy resultTy return ( Abs bsNonlin $ Abs bLin body @@ -91,19 +90,19 @@ withAccumulator => SType o -> (forall o'. (Emits o', DExt o o') => TransposeSubstVal (AtomNameC SimpIR) o' -> TransposeM i o' ()) -> TransposeM i o (SAtom o) -withAccumulator ty cont = do - singletonTypeVal ty >>= \case - Nothing -> do - baseMonoid <- tangentBaseMonoidFor ty - getSnd =<< emitRunWriter noHint ty baseMonoid \_ ref -> - cont (LinRef $ toAtom ref) >> return UnitVal - Just val -> do - -- If the accumulator's type is inhabited by just one value, we - -- don't need any actual accumulation, and can just return that - -- value. (We still run `cont` because it may emit decls that - -- have effects.) - Distinct <- getDistinct - cont LinTrivial >> return val +withAccumulator ty cont = undefined + -- singletonTypeVal ty >>= \case + -- Nothing -> do + -- baseMonoid <- tangentBaseMonoidFor ty + -- getSnd =<< emitRunWriter noHint ty baseMonoid \_ ref -> + -- cont (LinRef $ toAtom ref) >> return UnitVal + -- Just val -> do + -- -- If the accumulator's type is inhabited by just one value, we + -- -- don't need any actual accumulation, and can just return that + -- -- value. (We still run `cont` because it may emit decls that + -- -- have effects.) + -- Distinct <- getDistinct + -- cont LinTrivial >> return val emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do @@ -228,7 +227,6 @@ transposeMiscOp op _ = case op of ThrowError _ -> notLinear SumTag _ -> notLinear ToEnum _ _ -> notLinear - ThrowException _ -> notLinear OutputStream -> notLinear Select _ _ _ -> notImplemented CastOp _ _ -> notImplemented @@ -263,28 +261,6 @@ transposeHof hof ct = case hof of ctElt <- tabApp (sink ct) (toAtom i) extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal - RunState Nothing s (BinaryLamExpr hB refB body) -> do - (ctBody, ctState) <- fromPair ct - (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - transposeExpr body (sink ctBody) - return UnitVal - transposeAtom s cts - RunReader r (BinaryLamExpr hB refB body) -> do - accumTy <- substNonlin $ getType r - baseMonoid <- tangentBaseMonoidFor accumTy - (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - transposeExpr body (sink ct) - return UnitVal - transposeAtom r ct' - RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do - -- TODO: check we have the 0/+ monoid - (ctBody, ctEff) <- fromPair ct - void $ emitRunReader noHint ctEff \h ref -> do - extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - transposeExpr body (sink ctBody) - return UnitVal _ -> notImplemented transposeCon :: Emits o => Con SimpIR i -> SAtom o -> TransposeM i o () @@ -293,9 +269,7 @@ transposeCon con ct = case con of ProdCon [] -> return () ProdCon xs -> forM_ (enumerate xs) \(i, x) -> proj i ct >>= transposeAtom x SumCon _ _ _ -> notImplemented - HeapVal -> notTangent DepPair _ _ _ -> notImplemented - where notTangent = error $ "Not a tangent atom: " ++ pprint (Con con) notImplemented :: HasCallStack => a notImplemented = error "Not implemented" diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 63e005c2f..b82f835fa 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -52,10 +52,8 @@ data Con (r::IR) (n::S) where Lit :: LitVal -> Con r n ProdCon :: [Atom r n] -> Con r n SumCon :: [Type r n] -> Int -> Atom r n -> Con r n -- type, tag, payload - HeapVal :: Con r n DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Con r n Lam :: CoreLamExpr n -> Con CoreIR n - Eff :: EffectRow CoreIR n -> Con CoreIR n NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Con CoreIR n DictConAtom :: DictCon CoreIR n -> Con CoreIR n TyConAtom :: TyCon CoreIR n -> Con CoreIR n @@ -80,8 +78,7 @@ data TyCon (r::IR) (n::S) where BaseType :: BaseType -> TyCon r n ProdType :: [Type r n] -> TyCon r n SumType :: [Type r n] -> TyCon r n - RefType :: Atom r n -> Type r n -> TyCon r n - HeapType :: TyCon r n + RefType :: RWS -> Type r n -> TyCon r n TabPi :: TabPiType r n -> TyCon r n DepPairTy :: DepPairType r n -> TyCon r n TypeKind :: TyCon CoreIR n @@ -108,7 +105,6 @@ deriving via WrapE (TyCon r) n instance IRRep r => Generic (TyCon r n) deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) --- TODO: factor out the EffTy and maybe merge with PrimOp data Expr r n where Block :: EffTy r n -> Block r n -> Expr r n TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n @@ -149,7 +145,6 @@ type AtomNameBinder (r::IR) = NameBinder (AtomNameC r) type ClassName = Name ClassNameC type TyConName = Name TyConNameC type DataConName = Name DataConNameC -type EffectName = Name EffectNameC type InstanceName = Name InstanceNameC type MethodName = Name MethodNameC type ModuleName = Name ModuleNameC @@ -216,10 +211,10 @@ data TabPiType (r::IR) (n::S) where TabPiType :: IxDict r n -> Binder r n l -> Type r l -> TabPiType r n data PiType (r::IR) (n::S) where - PiType :: Nest (Binder r) n l -> EffTy r l -> PiType r n + PiType :: Nest (Binder r) n l -> Type r l -> PiType r n data CorePiType (n::S) where - CorePiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l -> EffTy CoreIR l -> CorePiType n + CorePiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l -> CType l -> CorePiType n data DepPairType (r::IR) (n::S) where DepPairType :: DepPairExplicitness -> Binder r n l -> Type r l -> DepPairType r n @@ -239,8 +234,8 @@ data NonDepNest r ann n l = NonDepNest (Nest (AtomNameBinder r) n l) [ann n] class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where toAbs :: e n -> Abs (Nest (Binder r)) body n -instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where - toAbs (CorePiType _ _ bs effTy) = Abs bs effTy +instance ToBindersAbs CorePiType (Type CoreIR) CoreIR where + toAbs (CorePiType _ _ bs ty) = Abs bs ty instance ToBindersAbs CoreLamExpr (Expr CoreIR) CoreIR where toAbs (CoreLamExpr _ lam) = toAbs lam @@ -248,8 +243,8 @@ instance ToBindersAbs CoreLamExpr (Expr CoreIR) CoreIR where instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where toAbs = id -instance ToBindersAbs (PiType r) (EffTy r) r where - toAbs (PiType bs effTy) = Abs bs effTy +instance ToBindersAbs (PiType r) (Type r) r where + toAbs (PiType bs ty) = Abs bs ty instance ToBindersAbs (LamExpr r) (Expr r) r where toAbs (LamExpr bs body) = Abs bs body @@ -344,9 +339,7 @@ data MiscOp (r::IR) (n::S) = | BitcastOp (Type r n) (Atom r n) -- (2) Type, then value. See CheckType.hs for valid coercions. | UnsafeCoerce (Type r n) (Atom r n) -- type, then value. Assumes runtime representation is the same. | GarbageVal (Type r n) -- type of value (assume `Data` constraint) - -- Effects | ThrowError (Type r n) -- (1) Hard error (parameterized by result type) - | ThrowException (Type r n) -- (1) Catchable exceptions (unlike `ThrowError`) -- Tag of a sum type | SumTag (Atom r n) -- Create an enum (payload-free ADT) from a Word8 @@ -375,22 +368,16 @@ data TypedHof r n = TypedHof (EffTy r n) (Hof r n) data Hof r n where For :: ForAnn -> IxType r n -> LamExpr r n -> Hof r n While :: Expr r n -> Hof r n - RunReader :: Atom r n -> LamExpr r n -> Hof r n - RunWriter :: Maybe (Atom r n) -> BaseMonoid r n -> LamExpr r n -> Hof r n - RunState :: Maybe (Atom r n) -> Atom r n -> LamExpr r n -> Hof r n -- dest, initial value, body lambda - RunIO :: Expr r n -> Hof r n - RunInit :: Expr r n -> Hof r n - CatchException :: CType n -> Expr CoreIR n -> Hof CoreIR n - Linearize :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n - Transpose :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n + Linearize :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n + Transpose :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n deriving instance IRRep r => Show (Hof r n) deriving via WrapE (Hof r) n instance IRRep r => Generic (Hof r n) -- Ops for "Dex Abstract Machine" data DAMOp r n = - Seq (EffectRow r n) Direction (IxType r n) (Atom r n) (LamExpr r n) -- ix dict, carry dests, body lambda - | RememberDest (EffectRow r n) (Atom r n) (LamExpr r n) + Seq (Effects r n) Direction (IxType r n) (Atom r n) (LamExpr r n) -- ix dict, carry dests, body lambda + | RememberDest (Effects r n) (Atom r n) (LamExpr r n) | AllocDest (Type r n) -- type | Place (Atom r n) (Atom r n) -- reference, value | Freeze (Atom r n) -- reference @@ -445,7 +432,6 @@ data NewtypeCon (n::S) = data NewtypeTyCon (n::S) = Nat | Fin (Atom CoreIR n) - | EffectRowKind | UserADTType SourceName (TyConName n) (TyConParams n) deriving (Show, Generic) @@ -501,120 +487,21 @@ data DictCon (r::IR) (n::S) where IxRawFin :: Atom r n -> DictCon r n IxSpecialized :: SpecDictName n -> [SAtom n] -> DictCon SimpIR n +data Effects (r::IR) (n::S) = Pure | Effectful + deriving (Generic, Show) -data EffectOpDef (n::S) where - EffectOpDef :: EffectName n -- name of associated effect - -> EffectOpIdx -- index in effect definition - -> EffectOpDef n - deriving (Show, Generic) +instance Semigroup (Effects r n) where + Pure <> Pure = Pure + _ <> _ = Effectful -instance GenericE EffectOpDef where - type RepE EffectOpDef = - EffectName `PairE` LiftE EffectOpIdx - fromE (EffectOpDef name idx) = name `PairE` LiftE idx - toE (name `PairE` LiftE idx) = EffectOpDef name idx - -instance SinkableE EffectOpDef -instance HoistableE EffectOpDef -instance RenameE EffectOpDef - -data EffectOpIdx = ReturnOp | OpIdx Int - deriving (Show, Eq, Generic) - -data EffectDef (n::S) where - EffectDef :: SourceName - -> [(SourceName, EffectOpType n)] - -> EffectDef n - -instance GenericE EffectDef where - type RepE EffectDef = - LiftE SourceName `PairE` ListE (LiftE SourceName `PairE` EffectOpType) - fromE (EffectDef name ops) = - LiftE name `PairE` ListE (map (\(x, y) -> LiftE x `PairE` y) ops) - toE (LiftE name `PairE` ListE ops) = - EffectDef name (map (\(LiftE x `PairE` y)->(x,y)) ops) - -instance SinkableE EffectDef -instance HoistableE EffectDef -instance AlphaEqE EffectDef -instance AlphaHashableE EffectDef -instance RenameE EffectDef -deriving instance Show (EffectDef n) -deriving via WrapE EffectDef n instance Generic (EffectDef n) - -data EffectOpType (n::S) where - EffectOpType :: UResumePolicy -> CType n -> EffectOpType n - -instance GenericE EffectOpType where - type RepE EffectOpType = - LiftE UResumePolicy `PairE` CType - fromE (EffectOpType pol ty) = LiftE pol `PairE` ty - toE (LiftE pol `PairE` ty) = EffectOpType pol ty - -instance SinkableE EffectOpType -instance HoistableE EffectOpType -instance AlphaEqE EffectOpType -instance AlphaHashableE EffectOpType -instance RenameE EffectOpType -deriving instance Show (EffectOpType n) -deriving via WrapE EffectOpType n instance Generic (EffectOpType n) - --- === effects === - -data Effect (r::IR) (n::S) = - RWSEffect RWS (Atom r n) - | ExceptionEffect - | IOEffect - | InitEffect -- Internal effect modeling writing to a destination. - deriving (Generic, Show) - -data EffectRow (r::IR) (n::S) = - EffectRow (ESet (Effect r) n) (EffectRowTail r n) - deriving (Generic) - -data EffectRowTail (r::IR) (n::S) where - EffectRowTail :: AtomVar CoreIR n -> EffectRowTail CoreIR n - NoTail :: EffectRowTail r n -deriving instance IRRep r => Show (EffectRowTail r n) -deriving instance IRRep r => Eq (EffectRowTail r n) -deriving via WrapE (EffectRowTail r) n instance IRRep r => Generic (EffectRowTail r n) +instance Monoid (Effects r n) where + mempty = Pure data EffTy (r::IR) (n::S) = - EffTy { etEff :: EffectRow r n + EffTy { etEff :: Effects r n , etTy :: Type r n } deriving (Generic, Show) -deriving instance IRRep r => Show (EffectRow r n) - -pattern Pure :: IRRep r => EffectRow r n -pattern Pure <- ((\(EffectRow effs t) -> (eSetToList effs, t)) -> ([], NoTail)) - where Pure = EffectRow mempty NoTail - -pattern OneEffect :: IRRep r => Effect r n -> EffectRow r n -pattern OneEffect eff <- ((\(EffectRow effs t) -> (eSetToList effs, t)) -> ([eff], NoTail)) - where OneEffect eff = EffectRow (eSetSingleton eff) NoTail - -instance IRRep r => Semigroup (EffectRow r n) where - EffectRow effs t <> EffectRow effs' t' = - EffectRow (effs <> effs') newTail - where - newTail = case (t, t') of - (NoTail, effTail) -> effTail - (effTail, NoTail) -> effTail - _ | t == t' -> t - | otherwise -> error "Can't combine effect rows with mismatched tails" - -instance IRRep r => Monoid (EffectRow r n) where - mempty = EffectRow mempty NoTail - -extendEffRow :: IRRep r => ESet (Effect r) n -> EffectRow r n -> EffectRow r n -extendEffRow effs (EffectRow effs' t) = EffectRow (effs <> effs') t -{-# INLINE extendEffRow #-} - -instance IRRep r => Store (EffectRowTail r n) -instance IRRep r => Store (EffectRow r n) -instance IRRep r => Store (Effect r n) - -- === Binder utils === binderType :: Binder r n l -> Type r n @@ -639,7 +526,6 @@ instance ToAtom (Atom r) r where toAtom = id instance ToAtom (Con r) r where toAtom = Con instance ToAtom (TyCon CoreIR) CoreIR where toAtom = Con . TyConAtom instance ToAtom (DictCon CoreIR) CoreIR where toAtom = Con . DictConAtom -instance ToAtom (EffectRow CoreIR) CoreIR where toAtom = Con . Eff instance ToAtom CoreLamExpr CoreIR where toAtom = Con . Lam instance ToAtom DictType CoreIR where toAtom = Con . TyConAtom . DictTy instance ToAtom NewtypeTyCon CoreIR where toAtom = Con . TyConAtom . NewtypeTyCon @@ -763,11 +649,11 @@ pattern BaseTy b = TyCon (BaseType b) pattern PtrTy :: PtrType -> Type r n pattern PtrTy ty = TyCon (BaseType (PtrType ty)) -pattern RefTy :: Atom r n -> Type r n -> Type r n +pattern RefTy :: RWS -> Type r n -> Type r n pattern RefTy r a = TyCon (RefType r a) pattern RawRefTy :: Type r n -> Type r n -pattern RawRefTy a = TyCon (RefType (Con HeapVal) a) +pattern RawRefTy a = TyCon (RefType State a) pattern TabTy :: IxDict r n -> Binder r n l -> Type r l -> Type r n pattern TabTy d b body = TyCon (TabPi (TabPiType d b body)) @@ -784,9 +670,6 @@ pattern NatVal n = Con (NewtypeCon NatCon (IdxRepVal n)) pattern TyKind :: Kind CoreIR n pattern TyKind = TyCon TypeKind -pattern EffKind :: Kind CoreIR n -pattern EffKind = TyCon (NewtypeTyCon EffectRowKind) - pattern FinConst :: Word32 -> Type CoreIR n pattern FinConst n = TyCon (NewtypeTyCon (Fin (NatVal n))) @@ -934,23 +817,20 @@ instance AlphaHashableE NewtypeCon instance RenameE NewtypeCon instance GenericE NewtypeTyCon where - type RepE NewtypeTyCon = EitherE4 + type RepE NewtypeTyCon = EitherE3 {- Nat -} UnitE {- Fin -} CAtom - {- EffectRowKind -} UnitE {- UserADTType -} (LiftE SourceName `PairE` TyConName `PairE` TyConParams) fromE = \case Nat -> Case0 UnitE Fin n -> Case1 n - EffectRowKind -> Case2 UnitE - UserADTType s d p -> Case3 (LiftE s `PairE` d `PairE` p) + UserADTType s d p -> Case2 (LiftE s `PairE` d `PairE` p) {-# INLINE fromE #-} toE = \case Case0 UnitE -> Nat Case1 n -> Fin n - Case2 UnitE -> EffectRowKind - Case3 (LiftE s `PairE` d `PairE` p) -> UserADTType s d p + Case2 (LiftE s `PairE` d `PairE` p) -> UserADTType s d p _ -> error "impossible" {-# INLINE toE #-} @@ -975,8 +855,8 @@ instance IRRep r => AlphaHashableE (BaseMonoid r) instance IRRep r => GenericE (DAMOp r) where type RepE (DAMOp r) = EitherE5 - {- Seq -} (EffectRow r `PairE` LiftE Direction `PairE` IxType r `PairE` Atom r `PairE` LamExpr r) - {- RememberDest -} (EffectRow r `PairE` Atom r `PairE` LamExpr r) + {- Seq -} (Effects r `PairE` LiftE Direction `PairE` IxType r `PairE` Atom r `PairE` LamExpr r) + {- RememberDest -} (Effects r `PairE` Atom r `PairE` LamExpr r) {- AllocDest -} (Type r) {- Place -} (Atom r `PairE` Atom r) {- Freeze -} (Atom r) @@ -1016,47 +896,23 @@ instance IRRep r => AlphaEqE (TypedHof r) instance IRRep r => AlphaHashableE (TypedHof r) instance IRRep r => GenericE (Hof r) where - type RepE (Hof r) = EitherE2 - (EitherE6 + type RepE (Hof r) = EitherE4 {- For -} (LiftE ForAnn `PairE` IxType r `PairE` LamExpr r) {- While -} (Expr r) - {- RunReader -} (Atom r `PairE` LamExpr r) - {- RunWriter -} (MaybeE (Atom r) `PairE` BaseMonoid r `PairE` LamExpr r) - {- RunState -} (MaybeE (Atom r) `PairE` Atom r `PairE` LamExpr r) - {- RunIO -} (Expr r) - ) (EitherE4 - {- RunInit -} (Expr r) - {- CatchException -} (WhenCore r (Type r `PairE` Expr r)) - {- Linearize -} (WhenCore r (LamExpr r `PairE` Atom r)) - {- Transpose -} (WhenCore r (LamExpr r `PairE` Atom r))) + {- Linearize -} (WhenCore r (LamExpr r `PairE` Atom r)) + {- Transpose -} (WhenCore r (LamExpr r `PairE` Atom r)) fromE = \case - For ann d body -> Case0 (Case0 (LiftE ann `PairE` d `PairE` body)) - While body -> Case0 (Case1 body) - RunReader x body -> Case0 (Case2 (x `PairE` body)) - RunWriter d bm body -> Case0 (Case3 (toMaybeE d `PairE` bm `PairE` body)) - RunState d x body -> Case0 (Case4 (toMaybeE d `PairE` x `PairE` body)) - RunIO body -> Case0 (Case5 body) - RunInit body -> Case1 (Case0 body) - CatchException ty body -> Case1 (Case1 (WhenIRE (ty `PairE` body))) - Linearize body x -> Case1 (Case2 (WhenIRE (PairE body x))) - Transpose body x -> Case1 (Case3 (WhenIRE (PairE body x))) + For ann d body -> Case0 (LiftE ann `PairE` d `PairE` body) + While body -> Case1 body + Linearize body x -> Case2 (WhenIRE (PairE body x)) + Transpose body x -> Case3 (WhenIRE (PairE body x)) {-# INLINE fromE #-} toE = \case - Case0 hof -> case hof of - Case0 (LiftE ann `PairE` d `PairE` body) -> For ann d body - Case1 body -> While body - Case2 (x `PairE` body) -> RunReader x body - Case3 (d `PairE` bm `PairE` body) -> RunWriter (fromMaybeE d) bm body - Case4 (d `PairE` x `PairE` body) -> RunState (fromMaybeE d) x body - Case5 body -> RunIO body - _ -> error "impossible" - Case1 hof -> case hof of - Case0 body -> RunInit body - Case1 (WhenIRE (ty `PairE` body)) -> CatchException ty body - Case2 (WhenIRE (PairE body x)) -> Linearize body x - Case3 (WhenIRE (PairE body x)) -> Transpose body x - _ -> error "impossible" + Case0 (LiftE ann `PairE` d `PairE` body) -> For ann d body + Case1 body -> While body + Case2 (WhenIRE (PairE body x)) -> Linearize body x + Case3 (WhenIRE (PairE body x)) -> Transpose body x _ -> error "impossible" {-# INLINE toE #-} @@ -1380,7 +1236,6 @@ instance GenericOp MiscOp where UnsafeCoerce t x -> GenericOpRep P.UnsafeCoerce [t] [x] [] GarbageVal t -> GenericOpRep P.GarbageVal [t] [] [] ThrowError t -> GenericOpRep P.ThrowError [t] [] [] - ThrowException t -> GenericOpRep P.ThrowException [t] [] [] SumTag x -> GenericOpRep P.SumTag [] [x] [] ToEnum t x -> GenericOpRep P.ToEnum [t] [x] [] OutputStream -> GenericOpRep P.OutputStream [] [] [] @@ -1394,7 +1249,6 @@ instance GenericOp MiscOp where GenericOpRep P.UnsafeCoerce [t] [x] [] -> Just $ UnsafeCoerce t x GenericOpRep P.GarbageVal [t] [] [] -> Just $ GarbageVal t GenericOpRep P.ThrowError [t] [] [] -> Just $ ThrowError t - GenericOpRep P.ThrowException [t] [] [] -> Just $ ThrowException t GenericOpRep P.SumTag [] [x] [] -> Just $ SumTag x GenericOpRep P.ToEnum [t] [x] [] -> Just $ ToEnum t x GenericOpRep P.OutputStream [] [] [] -> Just $ OutputStream @@ -1415,15 +1269,13 @@ instance IRRep r => RenameE (MiscOp r) instance IRRep r => GenericE (Con r) where type RepE (Con r) = EitherE2 - (EitherE5 + (EitherE4 {- Lit -} (LiftE LitVal) {- ProdCon -} (ListE (Atom r)) {- SumCon -} (ListE (Type r) `PairE` LiftE Int `PairE` Atom r) - {- HeapVal -} UnitE {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r)) - (WhenCore r (EitherE5 + (WhenCore r (EitherE4 {- Lam -} CoreLamExpr - {- Eff -} (EffectRow CoreIR) {- NewtypeCon -} (NewtypeCon `PairE` CAtom) {- DictConAtom -} (DictCon CoreIR) {- TyConAtom -} (TyCon CoreIR))) @@ -1431,28 +1283,24 @@ instance IRRep r => GenericE (Con r) where Lit l -> Case0 $ Case0 $ LiftE l ProdCon xs -> Case0 $ Case1 $ ListE xs SumCon ts i x -> Case0 $ Case2 $ ListE ts `PairE` LiftE i `PairE` x - HeapVal -> Case0 $ Case3 $ UnitE - DepPair x y t -> Case0 $ Case4 $ x `PairE` y `PairE` t + DepPair x y t -> Case0 $ Case3 $ x `PairE` y `PairE` t Lam lam -> Case1 $ WhenIRE $ Case0 lam - Eff effs -> Case1 $ WhenIRE $ Case1 effs - NewtypeCon con x -> Case1 $ WhenIRE $ Case2 $ con `PairE` x - DictConAtom con -> Case1 $ WhenIRE $ Case3 con - TyConAtom tc -> Case1 $ WhenIRE $ Case4 tc + NewtypeCon con x -> Case1 $ WhenIRE $ Case1 $ con `PairE` x + DictConAtom con -> Case1 $ WhenIRE $ Case2 con + TyConAtom tc -> Case1 $ WhenIRE $ Case3 tc {-# INLINE fromE #-} toE = \case Case0 con -> case con of Case0 (LiftE l) -> Lit l Case1 (ListE xs) -> ProdCon xs Case2 (ListE ts `PairE` LiftE i `PairE` x) -> SumCon ts i x - Case3 UnitE -> HeapVal - Case4 (x `PairE` y `PairE` t) -> DepPair x y t + Case3 (x `PairE` y `PairE` t) -> DepPair x y t _ -> error "impossible" Case1 (WhenIRE con) -> case con of Case0 lam -> Lam lam - Case1 effs -> Eff effs - Case2 (con' `PairE` x) -> NewtypeCon con' x - Case3 con' -> DictConAtom con' - Case4 tc -> TyConAtom tc + Case1 (con' `PairE` x) -> NewtypeCon con' x + Case2 con' -> DictConAtom con' + Case3 tc -> TyConAtom tc _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1469,9 +1317,8 @@ instance IRRep r => GenericE (TyCon r) where {- BaseType -} (LiftE BaseType) {- ProdType -} (ListE (Type r)) {- SumType -} (ListE (Type r)) - {- RefType -} (Atom r `PairE` Type r)) - (EitherE4 - {- HeapType -} UnitE + {- RefType -} (LiftE RWS `PairE` Type r)) + (EitherE3 {- TabPi -} (TabPiType r) {- DepPairTy -} (DepPairType r) {- TypeKind -} (WhenCore r UnitE)) @@ -1483,11 +1330,10 @@ instance IRRep r => GenericE (TyCon r) where BaseType b -> Case0 (Case0 (LiftE b)) ProdType ts -> Case0 (Case1 (ListE ts)) SumType ts -> Case0 (Case2 (ListE ts)) - RefType h t -> Case0 (Case3 (h `PairE` t)) - HeapType -> Case1 (Case0 UnitE) - TabPi t -> Case1 (Case1 t) - DepPairTy t -> Case1 (Case2 t) - TypeKind -> Case1 (Case3 (WhenIRE UnitE)) + RefType h t -> Case0 (Case3 (LiftE h `PairE` t)) + TabPi t -> Case1 (Case0 t) + DepPairTy t -> Case1 (Case1 t) + TypeKind -> Case1 (Case2 (WhenIRE UnitE)) DictTy t -> Case2 (Case0 (WhenIRE t)) Pi t -> Case2 (Case1 (WhenIRE t)) NewtypeTyCon t -> Case2 (Case2 (WhenIRE t)) @@ -1497,13 +1343,12 @@ instance IRRep r => GenericE (TyCon r) where Case0 (LiftE b ) -> BaseType b Case1 (ListE ts) -> ProdType ts Case2 (ListE ts) -> SumType ts - Case3 (h `PairE` t) -> RefType h t + Case3 (LiftE h `PairE` t) -> RefType h t _ -> error "impossible" Case1 c -> case c of - Case0 UnitE -> HeapType - Case1 t -> TabPi t - Case2 t -> DepPairTy t - Case3 (WhenIRE UnitE) -> TypeKind + Case0 t -> TabPi t + Case1 t -> DepPairTy t + Case2 (WhenIRE UnitE) -> TypeKind _ -> error "impossible" Case2 c -> case c of Case0 (WhenIRE t) -> DictTy t @@ -1670,7 +1515,7 @@ instance AlphaHashableE CoreLamExpr instance RenameE CoreLamExpr instance GenericE CorePiType where - type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) + type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (Type CoreIR) fromE (CorePiType ex exs b effTy) = LiftE (ex, exs) `PairE` Abs b effTy {-# INLINE fromE #-} toE (LiftE (ex, exs) `PairE` Abs b effTy) = CorePiType ex exs b effTy @@ -1725,7 +1570,7 @@ deriving instance IRRep r => Show (TabPiType r n) deriving via WrapE (TabPiType r) n instance IRRep r => Generic (TabPiType r n) instance GenericE (PiType r) where - type RepE (PiType r) = Abs (Nest (Binder r)) (EffTy r) + type RepE (PiType r) = Abs (Nest (Binder r)) (Type r) fromE (PiType bs effTy) = Abs bs effTy {-# INLINE fromE #-} toE (Abs bs effTy) = PiType bs effTy @@ -1768,6 +1613,36 @@ instance RenameE DotMethods instance AlphaEqE DotMethods instance AlphaHashableE DotMethods +instance IRRep r => GenericE (Effects r) where + type RepE (Effects r) = EitherE UnitE UnitE + fromE = \case + Pure -> LeftE UnitE + Effectful -> RightE UnitE + {-# INLINE fromE #-} + toE = \case + LeftE UnitE -> Pure + RightE UnitE -> Effectful + {-# INLINE toE #-} + +instance IRRep r => SinkableE (Effects r) +instance IRRep r => HoistableE (Effects r) +instance IRRep r => RenameE (Effects r) +instance IRRep r => AlphaEqE (Effects r) +instance IRRep r => AlphaHashableE (Effects r) + +instance IRRep r => GenericE (EffTy r) where + type RepE (EffTy r) = PairE (Effects r) (Type r) + fromE (EffTy eff ty) = eff `PairE` ty + {-# INLINE fromE #-} + toE (eff `PairE` ty) = EffTy eff ty + {-# INLINE toE #-} + +instance IRRep r => SinkableE (EffTy r) +instance IRRep r => HoistableE (EffTy r) +instance IRRep r => RenameE (EffTy r) +instance IRRep r => AlphaEqE (EffTy r) +instance IRRep r => AlphaHashableE (EffTy r) + instance IRRep r => GenericE (DeclBinding r) where type RepE (DeclBinding r) = LiftE LetAnn `PairE` Expr r fromE (DeclBinding ann expr) = LiftE ann `PairE` expr @@ -1796,74 +1671,6 @@ instance IRRep r => AlphaHashableB (Decl r) instance IRRep r => ProvesExt (Decl r) instance IRRep r => BindsNames (Decl r) -instance IRRep r => GenericE (Effect r) where - type RepE (Effect r) = - EitherE3 (PairE (LiftE RWS) (Atom r)) - (LiftE (Either () ())) - UnitE - fromE = \case - RWSEffect rws h -> Case0 (PairE (LiftE rws) h) - ExceptionEffect -> Case1 (LiftE (Left ())) - IOEffect -> Case1 (LiftE (Right ())) - InitEffect -> Case2 UnitE - {-# INLINE fromE #-} - toE = \case - Case0 (PairE (LiftE rws) h) -> RWSEffect rws h - Case1 (LiftE (Left ())) -> ExceptionEffect - Case1 (LiftE (Right ())) -> IOEffect - Case2 UnitE -> InitEffect - _ -> error "unreachable" - {-# INLINE toE #-} - -instance IRRep r => SinkableE (Effect r) -instance IRRep r => HoistableE (Effect r) -instance IRRep r => AlphaEqE (Effect r) -instance IRRep r => AlphaHashableE (Effect r) -instance IRRep r => RenameE (Effect r) - -instance IRRep r => GenericE (EffectRow r) where - type RepE (EffectRow r) = PairE (ListE (Effect r)) (EffectRowTail r) - fromE (EffectRow effs ext) = ListE (eSetToList effs) `PairE` ext - {-# INLINE fromE #-} - toE (ListE effs `PairE` ext) = EffectRow (eSetFromList effs) ext - {-# INLINE toE #-} - -instance IRRep r => SinkableE (EffectRow r) -instance IRRep r => HoistableE (EffectRow r) -instance IRRep r => RenameE (EffectRow r) -instance IRRep r => AlphaEqE (EffectRow r) -instance IRRep r => AlphaHashableE (EffectRow r) - -instance IRRep r => GenericE (EffectRowTail r) where - type RepE (EffectRowTail r) = EitherE (WhenCore r (AtomVar CoreIR)) UnitE - fromE = \case - EffectRowTail v -> LeftE (WhenIRE v) - NoTail -> RightE UnitE - {-# INLINE fromE #-} - toE = \case - LeftE (WhenIRE v) -> EffectRowTail v - RightE UnitE -> NoTail - {-# INLINE toE #-} - -instance IRRep r => SinkableE (EffectRowTail r) -instance IRRep r => HoistableE (EffectRowTail r) -instance IRRep r => RenameE (EffectRowTail r) -instance IRRep r => AlphaEqE (EffectRowTail r) -instance IRRep r => AlphaHashableE (EffectRowTail r) - -instance IRRep r => GenericE (EffTy r) where - type RepE (EffTy r) = PairE (EffectRow r) (Type r) - fromE (EffTy eff ty) = eff `PairE` ty - {-# INLINE fromE #-} - toE (eff `PairE` ty) = EffTy eff ty - {-# INLINE toE #-} - -instance IRRep r => SinkableE (EffTy r) -instance IRRep r => HoistableE (EffTy r) -instance IRRep r => RenameE (EffTy r) -instance IRRep r => AlphaEqE (EffTy r) -instance IRRep r => AlphaHashableE (EffTy r) - instance IRRep r => BindsAtMostOneName (Decl r) (AtomNameC r) where Let b _ @> x = b @> x {-# INLINE (@>) #-} @@ -1884,6 +1691,7 @@ instance IRRep r => Store (Con r n) instance IRRep r => Store (PrimOp r n) instance Store (RepVal n) instance IRRep r => Store (Type r n) +instance IRRep r => Store (Effects r n) instance IRRep r => Store (EffTy r n) instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) @@ -1907,10 +1715,6 @@ instance Store (InstanceDef n) instance Store (InstanceBody n) instance Store (DictType n) instance IRRep r => Store (DictCon r n) -instance Store (EffectDef n) -instance Store (EffectOpDef n) -instance Store (EffectOpType n) -instance Store (EffectOpIdx) instance Store (ann n) => Store (NonDepNest r ann n l) instance Store IxMethod instance Store ParamRole @@ -1931,17 +1735,8 @@ instance IRRep r => PrettyPrec (Hof r n) where prettyPrec hof = atPrec LowestPrec case hof of For _ _ lam -> "for" <+> pLowest lam While body -> "while" <+> pArg body - RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) - RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) - RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) - RunIO body -> "runIO" <+> pArg body - RunInit body -> "runInit" <+> pArg body - CatchException _ body -> "catchException" <+> pArg body Linearize body x -> "linearize" <+> pArg body <+> pArg x Transpose body x -> "transpose" <+> pArg body <+> pArg x - where - p :: Pretty a => a -> Doc ann - p = pretty instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (DAMOp r n) where @@ -1964,9 +1759,8 @@ instance IRRep r => PrettyPrec (TyCon r n) where encloseSep "(" ")" ", " $ fmap pApp as SumType cs -> atPrec ArgPrec $ align $ group $ encloseSep "(|" "|)" " | " $ fmap pApp cs - RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a + RefType _ a -> atPrec AppPrec $ "Ref" <+> p a TypeKind -> atPrec ArgPrec "Type" - HeapType -> atPrec ArgPrec "Heap" Pi piType -> atPrec LowestPrec $ align $ p piType TabPi piType -> atPrec LowestPrec $ align $ p piType DepPairTy ty -> prettyPrec ty @@ -1986,7 +1780,6 @@ instance PrettyPrec (NewtypeTyCon n) where prettyPrec = \case Nat -> atPrec ArgPrec $ "Nat" Fin n -> atPrec AppPrec $ "Fin" <+> pArg n - EffectRowKind -> atPrec ArgPrec "EffKind" UserADTType name _ (TyConParams infs params) -> case (infs, params) of ([], []) -> atPrec ArgPrec $ pretty name ([Explicit, Explicit], [l, r]) @@ -2010,11 +1803,9 @@ instance IRRep r => PrettyPrec (Con r n) where encloseSep "(" ")" ", " $ fmap pLowest xs SumCon _ tag payload -> atPrec ArgPrec $ "(" <> p tag <> "|" <+> pApp payload <+> "|)" - HeapVal -> atPrec ArgPrec "HeapValue" Lam lam -> atPrec LowestPrec $ p lam DepPair x y _ -> atPrec ArgPrec $ align $ group $ parens $ p x <+> ",>" <+> p y - Eff e -> atPrec ArgPrec $ p e DictConAtom d -> atPrec LowestPrec $ p d NewtypeCon con x -> prettyPrecNewtype con x TyConAtom ty -> prettyPrec ty @@ -2101,7 +1892,7 @@ instance IRRep r => PrettyPrec (Expr r n) where App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) - Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs + Case e alts _ -> prettyPrecCase "case" e alts TabCon _ es -> atPrec ArgPrec $ list $ pApp <$> es PrimOp op -> prettyPrec op ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs @@ -2111,15 +1902,10 @@ instance IRRep r => PrettyPrec (Expr r n) where p :: Pretty a => a -> Doc ann p = pretty -prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann -prettyPrecCase name e alts effs = atPrec LowestPrec $ +prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> DocPrec ann +prettyPrecCase name e alts = atPrec LowestPrec $ name <+> pApp e <+> "of" <> - nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts - <> effectLine effs) - where - effectLine :: IRRep r => EffectRow r n -> Doc ann - effectLine Pure = "" - effectLine row = hardline <> "case annotated with effects" <+> pretty row + nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts) prettyAlt :: IRRep r => Alt r n -> Doc ann prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (pretty body) @@ -2136,8 +1922,8 @@ instance IRRep r => Pretty (Decl r n l) where where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann instance IRRep r => Pretty (PiType r n) where - pretty (PiType bs (EffTy effs resultTy)) = - (spaced $ unsafeFromNest $ bs) <+> "->" <+> "{" <> pretty effs <> "}" <+> pretty resultTy + pretty (PiType bs resultTy) = + (spaced $ unsafeFromNest $ bs) <+> "->" <+> pretty resultTy instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (LamExpr r n) where @@ -2206,21 +1992,6 @@ instance PrettyPrec (AtomVar r n) where prettyPrec (AtomVar v _) = prettyPrec v instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec -instance IRRep r => Pretty (EffectRow r n) where - pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map pretty (eSetToList effs))) <> pretty t - -instance IRRep r => Pretty (EffectRowTail r n) where - pretty = \case - NoTail -> mempty - EffectRowTail v -> "|" <> pretty v - -instance IRRep r => Pretty (Effect r n) where - pretty eff = case eff of - RWSEffect rws h -> pretty rws <+> pretty h - ExceptionEffect -> "Except" - IOEffect -> "IO" - InitEffect -> "Init" - prettyLam :: Pretty a => Doc ann -> a -> Doc ann prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ pretty body) @@ -2245,12 +2016,8 @@ prettyBinderHelper (b:>ty) body = else pretty ty instance Pretty (CorePiType n) where - pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = - prettyBindersWithExpl expls bs <+> pretty appExpl <> prettyEff <> pretty resultTy - where - prettyEff = case eff of - Pure -> space - _ -> space <> pretty eff <> space + pretty (CorePiType appExpl expls bs resultTy) = + prettyBindersWithExpl expls bs <+> pretty appExpl <> pretty resultTy prettyBindersWithExpl :: forall b n l ann. PrettyB b => [Explicitness] -> Nest b n l -> Doc ann diff --git a/src/lib/Types/OpNames.hs b/src/lib/Types/OpNames.hs index 344329ac6..8ecacd0f0 100644 --- a/src/lib/Types/OpNames.hs +++ b/src/lib/Types/OpNames.hs @@ -16,8 +16,8 @@ import Data.Store (Store (..)) import PPrint -data TC = ProdType | SumType | RefType | TypeKind | HeapType -data Con = ProdCon | SumCon Int | HeapVal +data TC = ProdType | SumType | RefType | TypeKind +data Con = ProdCon | SumCon Int data BinOp = IAdd | ISub | IMul | IDiv | ICmp CmpOp | FAdd | FSub | FMul diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 66e7ffd14..139c3be4d 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -24,11 +24,10 @@ import Data.Aeson (ToJSON (..)) import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M -import qualified Data.Set as S import qualified Data.Text as T import Data.Text (Text) import Data.Word -import Data.Text.Prettyprint.Doc (line, group, parens, nest, align, punctuate, hsep) +import Data.Text.Prettyprint.Doc (line, group, parens, nest, align) import Data.Text (snoc, unsnoc) import Data.Tuple (swap) @@ -199,7 +198,6 @@ data CSDecl | CPass deriving (Show, Generic) -type CEffs = WithSrcs ([GroupW], Maybe GroupW) data CDef = CDef SourceNameW ExplicitParams @@ -208,7 +206,7 @@ data CDef = CDef CSBlock deriving (Show, Generic) -type CDefRhs = (AppExplicitness, Maybe CEffs, GroupW) +type CDefRhs = (AppExplicitness, GroupW) data CInstanceDef = CInstanceDef SourceNameW -- interface name @@ -232,7 +230,7 @@ data Group | CCase GroupW [CaseAlt] -- scrutinee, alternatives | CIf GroupW CSBlock (Maybe CSBlock) | CDo CSBlock - | CArrow GroupW (Maybe CEffs) GroupW + | CArrow GroupW GroupW | CWith GroupW WithClause deriving (Show, Generic) @@ -281,20 +279,6 @@ data CSBlock = -- === Untyped IR === -- The AST of Dex surface language. -data UEffect (n::S) = - URWSEffect RWS (SourceOrInternalName (AtomNameC CoreIR) n) - | UExceptionEffect - | UIOEffect - deriving (Generic) - -data UEffectRow (n::S) = - UEffectRow (S.Set (UEffect n)) (Maybe (SourceOrInternalName (AtomNameC CoreIR) n)) - deriving (Generic) - -pattern UPure :: UEffectRow n -pattern UPure <- ((\(UEffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) - where UPure = UEffectRow mempty Nothing - data UVar (n::S) = UAtomVar (Name (AtomNameC CoreIR) n) | UTyConVar (Name TyConNameC n) @@ -361,13 +345,12 @@ data ULamExpr (n::S) where ULamExpr :: Nest UAnnBinder n l -- args -> AppExplicitness - -> Maybe (UEffectRow l) -- optional effect -> Maybe (UType l) -- optional result type -> UBlock l -- body -> ULamExpr n data UPiExpr (n::S) where - UPiExpr :: Nest UAnnBinder n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n + UPiExpr :: Nest UAnnBinder n l -> AppExplicitness -> UType l -> UPiExpr n data UTabPiExpr (n::S) where UTabPiExpr :: UAnnBinder n l -> UType l -> UTabPiExpr n @@ -669,7 +652,6 @@ data PrimName = | UBinOp BinOp | UMAsk | UMExtend | UMGet | UMPut | UWhile | ULinearize | UTranspose - | URunReader | URunWriter | URunState | URunIO | UCatchException | UProjNewtype | UExplicitApply | UMonoLiteral | UIndexRef | UApplyMethod Int | UNat | UNatCon | UFin | UEffectRowKind @@ -696,8 +678,6 @@ primNames = M.fromList , ("get" , UMGet), ("put" , UMPut) , ("while" , UWhile) , ("linearize", ULinearize), ("linearTranspose", UTranspose) - , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) - , ("runIO" , URunIO ), ("catchException" , UCatchException) , ("iadd" , binary IAdd), ("isub" , binary ISub) , ("imul" , binary IMul), ("fdiv" , binary FDiv) , ("fadd" , binary FAdd), ("fsub" , binary FSub) @@ -736,8 +716,7 @@ primNames = M.fromList , ("Fin" , UFin) , ("EffKind" , UEffectRowKind) , ("NatCon" , UNatCon) - , ("Ref" , UPrimTC $ P.RefType) - , ("HeapType" , UPrimTC $ P.HeapType) + , ("Ref" , UPrimTC $ P.RefType) , ("indexRef" , UIndexRef) , ("alloc" , memOp $ P.IOAlloc) , ("free" , memOp $ P.IOFree) @@ -745,7 +724,6 @@ primNames = M.fromList , ("ptrLoad" , memOp $ P.PtrLoad) , ("ptrStore" , memOp $ P.PtrStore) , ("throwError" , miscOp $ P.ThrowError) - , ("throwException", miscOp $ P.ThrowException) , ("dataConTag" , miscOp $ P.SumTag) , ("toEnum" , miscOp $ P.ToEnum) , ("outputStream" , miscOp $ P.OutputStream) @@ -980,14 +958,6 @@ deriving instance Show (UBlock' n) deriving instance Show (UForExpr n) deriving instance Show (UAlt n) -deriving instance Show (UEffect n) -deriving instance Eq (UEffect n) -deriving instance Ord (UEffect n) - -deriving instance Show (UEffectRow n) -deriving instance Eq (UEffectRow n) -deriving instance Ord (UEffectRow n) - instance ToJSON LexemeType instance ToJSON PassName @@ -1116,10 +1086,6 @@ instance Pretty (UDecl' n l) where UExprDecl expr -> pretty expr UPass -> "pass" -instance Pretty (UEffectRow n) where - pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (pretty <$> toList x) - pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (pretty <$> toList x)) <+> "|" <+> pretty y <> "}" - instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = pretty x instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x @@ -1141,15 +1107,13 @@ instance Pretty (SourceOrInternalName c n) where instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec instance PrettyPrec (ULamExpr n) where - prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ + prettyPrec (ULamExpr bs _ _ body) = atPrec LowestPrec $ "\\" <> pretty bs <+> "." <+> indented (pretty body) instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec instance PrettyPrec (UPiExpr n) where - prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ + prettyPrec (UPiExpr pats appExpl ty) = atPrec LowestPrec $ align $ pretty pats <+> pretty appExpl <+> pLowest ty - prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ - pretty pats <+> pretty appExpl <+> pretty eff <+> pLowest ty instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec instance PrettyPrec (UTabPiExpr n) where @@ -1233,12 +1197,6 @@ instance Pretty FieldName' where FieldName s -> pretty s FieldNum n -> pretty n -instance Pretty (UEffect n) where - pretty eff = case eff of - URWSEffect rws h -> pretty rws <+> pretty h - UExceptionEffect -> "Except" - UIOEffect -> "IO" - prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann prettyOpDefault name args = case length args of diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index daa606fc6..4122276aa 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -160,53 +160,7 @@ vectorizeLoopsExpr expr = do let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth case expr of Block _ (Abs decls body) -> vectorizeLoopsDecls decls $ vectorizeLoopsExpr body - PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do - sz <- simplifyIxSize =<< renameM ixty - case sz of - Just n | n `mod` loopWidth == 0 -> (do - safe <- vectorSafeEffect effs - if safe - then (do - Distinct <- getDistinct - let vn = n `div` loopWidth - body' <- vectorizeSeq loopWidth ixty body - dest' <- renameM dest - emit =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') - else renameM expr >>= emit) - `catchErr` \err -> do - modify (\(LiftE errs) -> LiftE (err:errs)) - recurSeq expr - _ -> recurSeq expr - PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do - item' <- renameM item - itemTy <- return $ getType item' - lam <- buildEffLam noHint itemTy \hb refb -> - extendRenamer (hb' @> atomVarName hb) do - extendRenamer (refb' @> atomVarName refb) do - vectorizeLoopsExpr body - emit =<< mkTypedHof (RunReader item' lam) - PrimOp (Hof (TypedHof (EffTy _ ty) - (RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do - dest' <- renameM dest - monoid' <- renameM monoid - commutativity <- monoidCommutativity monoid' - PairTy _ accTy <- renameM ty - lam <- buildEffLam noHint accTy \hb refb -> - extendRenamer (hb' @> atomVarName hb) do - extendRenamer (refb' @> atomVarName refb) do - extendCommuteMap (atomVarName hb) commutativity do - vectorizeLoopsExpr body - emit =<< mkTypedHof (RunWriter (Just dest') monoid' lam) _ -> renameM expr >>= emit - where - recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) - recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do - effs' <- renameM effs - ixty' <- renameM ixty - dest' <- renameM dest - body' <- vectorizeLoopsLamExpr body - emit $ Seq effs' dir ixty' dest' body' - recurSeq _ = error "Impossible" simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) => IxType SimpIR n -> m n (Maybe Word32) @@ -269,58 +223,6 @@ _isZeroLit = \case Float64Lit 0.0 -> True _ -> False --- Vectorizing a loop with an effect is safe when the operation reordering --- produced by vectorization doesn't change the semantics. This is guaranteed --- to happen when: --- - It's the Init effect (because the writes are non-aliasing), or --- - It's the Reader effect, or --- - Every reference in the effect is accessed in non-aliasing --- fashion across iterations (e.g., for i. ... ref!i ...), or --- - It's a Writer effect with a commutative monoid, or --- - It's a Writer effect and the body writes to each set of --- potentially overlapping references in scope at most once --- (and the vector operations have in-order reductions --- available) --- - The Exception effect should have been transformed away by now --- - The IO effect is in general not safe --- This check doesn't have enough information to test the above, --- but we crudely approximate for now. -vectorSafeEffect :: EffectRow SimpIR i -> TopVectorizeM i o Bool -vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where - safe :: Effect SimpIR i -> TopVectorizeM i o Bool - safe InitEffect = return True - safe (RWSEffect Reader _) = return True - safe (RWSEffect Writer (Stuck _ (Var h))) = do - h' <- renameM $ atomVarName h - commuteMap <- ask - case lookupNameMapE h' commuteMap of - Just (LiftE Commutes) -> return True - Just (LiftE DoesNotCommute) -> return False - Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?" - safe _ = return False - -vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i - -> TopVectorizeM i o (LamExpr SimpIR o) -vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do - newLoopTy <- case ty of - TyCon (ProdType [_ixType, ref]) -> do - ref' <- renameM ref - return $ TyCon $ ProdType [IdxRepTy, ref'] - _ -> error "Unexpected seq binder type" - ixty' <- renameM ixty - liftVectorizeM loopWidth $ - buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do - -- The per-tile loop iterates on `Fin` - (viOrd, dest) <- fromPair $ toAtom ci - iOrd <- imul viOrd $ IdxRepVal loopWidth - -- TODO: It would be nice to cancel this UnsafeFromOrdinal with the - -- Ordinal that will be taken later when indexing, but that should - -- probably be a separate pass. - i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd] - extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $ - vectorizeExpr body $> UnitVal -vectorizeSeq _ _ _ = error "expected a unary lambda expression" - newtype VectorizeM i o a = VectorizeM { runVectorizeM :: SubstReaderT VSubstValC (BuilderT SimpIR (ReaderT Word32 Except)) i o a } @@ -494,16 +396,6 @@ vectorizePrimOp op = case op of BaseTy av <- getVectorType $ BaseTy a ptr' <- emit $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr VVal Varying <$> emit (PtrLoad ptr') - -- Vectorizing IO might not always be safe! Here, we depend on vectorizeOp - -- being picky about the IO-inducing ops it supports, and expect it to - -- complain about FFI calls and the like. - Hof (TypedHof _ (RunIO body)) -> do - -- TODO: buildBlockAux? - Abs decls (LiftE vy `PairE` y) <- buildScoped do - VVal vy y <- vectorizeExpr body - return $ PairE (LiftE vy) y - block <- mkBlock (Abs decls y) - VVal vy <$> emitHof (RunIO block) _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op vectorizeType :: SType i -> VectorizeM i o (SType o)