Skip to content

Commit

Permalink
Remove constructors for lifting between IRs
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 2, 2024
1 parent 6507556 commit 72e36b9
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 153 deletions.
11 changes: 1 addition & 10 deletions src/lib/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPol
atomAsPoly :: SAtom i -> BlockTraverserM i o (Polynomial o)
atomAsPoly = \case
Stuck _ (Var v) -> atomVarAsPoly v
Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v'
IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])]
_ -> empty

Expand Down Expand Up @@ -206,15 +205,7 @@ emitPolynomial (Polynomial p) = do
asAtom = IdxRepVal . fromInteger

emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (SAtom n)
emitMonomial (Monomial m) = do
varAtoms <- forM (toList m) \(v, e) -> case v of
LeftE v' -> do
v'' <- toAtom <$> toAtomVar v'
ipow v'' e
RightE v' -> do
atom <- mkStuck $ RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy))
ipow atom e
foldM imul (IdxRepVal 1) varAtoms
emitMonomial (Monomial m) = undefined

ipow :: Emits n => SAtom n -> Int -> BuilderM SimpIR n (SAtom n)
ipow x i = foldM imul (IdxRepVal 1) (replicate i x)
Expand Down
55 changes: 2 additions & 53 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ module CheapReduction
, visitAtomDefault, visitTypeDefault, Visitor2, mkStuck, mkStuckTy
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated
, bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst
, repValAtom, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp
, reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck
, liftSimpAtom)
, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp
, reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck)
where

import Control.Applicative
Expand All @@ -35,7 +34,6 @@ import Types.Core
import Types.Top
import Types.Imp
import Types.Primitives
import Util
import GHC.Stack

-- Carry out the reductions we are willing to carry out during type
Expand Down Expand Up @@ -177,16 +175,13 @@ queryStuckType = \case
fTy <- queryStuckType f
typeOfTabApp fTy x
PtrVar t _ -> return $ PtrTy t
RepValAtom repVal -> return $ getType repVal
StuckUnwrap s -> queryStuckType s >>= \case
TyCon (NewtypeTyCon con) -> snd <$> unwrapNewtypeType con
_ -> error "not a newtype"
InstantiatedGiven f xs -> do
fTy <- queryStuckType f
typeOfApp fTy xs
SuperclassProj i s -> superclassProjType i =<< queryStuckType s
LiftSimp t _ -> return t
LiftSimpFun t _ -> return $ toType t

projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n)
projType i x = case getType x of
Expand Down Expand Up @@ -219,24 +214,6 @@ typeOfApp (TyCon (Pi piTy)) xs = withSubstReaderT $
withInstantiated piTy xs \ty -> substM ty
typeOfApp _ _ = error "expected a pi type"

repValAtom :: EnvReader m => RepVal n -> m n (SAtom n)
repValAtom (RepVal ty tree) = case ty of
TyCon (ProdType ts) -> case tree of
Branch trees -> toAtom <$> ProdCon <$> mapM repValAtom (zipWith RepVal ts trees)
_ -> malformed
TyCon (BaseType _) -> case tree of
Leaf x -> case x of
ILit l -> return $ toAtom $ Lit l
_ -> fallback
_ -> malformed
-- TODO: make sure this covers all the cases. Maybe only TabPi should hit the
-- fallback? This could be a place where we accidentally violate the `Stuck`
-- assumption
_ -> fallback
where fallback = return $ Stuck ty $ RepValAtom $ RepVal ty tree
malformed = error "malformed repval"
{-# INLINE repValAtom #-}

depPairLeftTy :: DepPairType r n -> Type r n
depPairLeftTy (DepPairType _ (_:>ty) _) = ty
{-# INLINE depPairLeftTy #-}
Expand Down Expand Up @@ -615,34 +592,6 @@ reduceStuck = \case
Just child' <- toMaybeDict <$> reduceStuck child
reduceSuperclassProjM superclassIx child'
PtrVar ptrTy ptr -> mkStuck =<< PtrVar ptrTy <$> substM ptr
RepValAtom repVal -> mkStuck =<< RepValAtom <$> substM repVal
LiftSimp t s -> do
t' <- substM t
s' <- reduceStuck s
liftSimpAtom t' s'
LiftSimpFun t f -> mkStuck =<< (LiftSimpFun <$> substM t <*> substM f)

liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n)
liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type"
liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of
Stuck _ stuck -> return $ Stuck ty $ LiftSimp ty stuck
Con con -> Con <$> case (tyCon, con) of
(NewtypeTyCon newtypeCon, _) -> do
(dataCon, repTy) <- unwrapNewtypeType newtypeCon
cAtom <- rec repTy (Con con)
return $ NewtypeCon dataCon cAtom
(BaseType _ , Lit v) -> return $ Lit v
(ProdType tys, ProdCon xs) -> ProdCon <$> zipWithM rec tys xs
(SumType tys, SumCon _ i x) -> SumCon tys i <$> rec (tys!!i) x
(DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do
x1' <- rec t1 x1
t2' <- applySubst (b@>SubstVal x1') t2
x2' <- rec t2' x2
return $ DepPair x1' x2' dpt
_ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty
where
rec = liftSimpAtom
{-# INLINE liftSimpAtom #-}

instance SubstE AtomSubstVal SpecializationSpec where
substE env (AppSpecialization (AtomVar f _) ab) = do
Expand Down
3 changes: 0 additions & 3 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,6 @@ instance IRRep r => CheckableE r (Stuck r) where
return $ InstantiatedGiven given' args'
SuperclassProj i d -> SuperclassProj <$> pure i <*> checkE d -- TODO: check index in range
PtrVar t v -> PtrVar t <$> renameM v
RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check
LiftSimp t x -> LiftSimp <$> checkE t <*> renameM x -- TODO: check
LiftSimpFun t x -> LiftSimpFun <$> checkE t <*> renameM x -- TODO: check

depPairLeftTy :: DepPairType r n -> Type r n
depPairLeftTy (DepPairType _ (_:>ty) _) = ty
Expand Down
43 changes: 20 additions & 23 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ import Types.Primitives
import Types.Top
import Util (forMFilter, Tree (..), zipTrees, enumerate)

repValAtom :: EnvReader m => RepVal n -> m n (SAtom n)
repValAtom = undefined

toImpFunction :: EnvReader m => CallingConvention -> STopLam n -> m n (ImpFunction n)
toImpFunction cc (TopLam True destTy lam) = do
LamExpr bsAndRefB body <- return lam
Expand Down Expand Up @@ -398,15 +401,12 @@ toImpMiscOp op = case op of
returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y)
SumTag con -> case con of
Con (SumCon _ tag _) -> return $ TagRepVal $ fromIntegral tag
Stuck _ (RepValAtom dRepVal) -> do
RepVal _ (Branch (tag:_)) <- return dRepVal
return $ toAtom $ RepVal (TagRepTy :: SType o) tag
_ -> error $ "Not a data constructor: " ++ pprint con
ToEnum ty i -> case ty of
TyCon (SumType cases) -> do
i' <- fromScalarAtom i
return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases
_ -> error $ "Not an enum: " ++ pprint ty
ToEnum ty i -> undefined
-- TyCon (SumType cases) -> do
-- i' <- fromScalarAtom i
-- return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases
-- _ -> error $ "Not an enum: " ++ pprint ty
OutputStream -> returnIExprVal =<< emitInstr IOutputStream
ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
ShowScalar x -> do
Expand Down Expand Up @@ -729,16 +729,13 @@ atomToRepVal x = RepVal (getType x) <$> go x where
tag' <- go $ TagRepVal $ fromIntegral tag
xs <- forM (enumerate cases) \(i, t) -> if i == tag
then go payload
else buildGarbageVal t <&> \(Stuck _ (RepValAtom (RepVal _ tree))) -> tree
else undefined -- buildGarbageVal t <&> \(Stuck _ (RepValAtom (RepVal _ tree))) -> tree
return $ Branch $ tag':xs
go (Stuck _ stuck) = case stuck of
Var v -> lookupAtomName (atomVarName v) >>= \case
TopDataBound (RepVal _ tree) -> return tree
_ -> error "should only have pointer and data atom names left"
PtrVar ty p -> return $ Leaf $ IPtrVar p ty
RepValAtom dRepVal -> do
(RepVal _ tree) <- return dRepVal
return tree
-- TODO: I think we want to be able to rule this one out by insisting that
-- RepValAtom is itself part of Stuck and it can't represent a product.
StuckProject i val -> do
Expand All @@ -750,12 +747,12 @@ atomToRepVal x = RepVal (getType x) <$> go x where
-- from the dest. This version is not that. It just lifts a dest into an atom of
-- type `Ref _`.
destToAtom :: Dest n -> SAtom n
destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy valTy) tree
destToAtom (Dest valTy tree) = undefined -- toAtom $ RepVal (RefTy valTy) tree

atomToDest :: EnvReader m => SAtom n -> m n (Dest n)
atomToDest (Stuck _ (RepValAtom val)) = do
(RepVal ~(RefTy valTy) valTree) <- return val
return $ Dest valTy valTree
-- atomToDest (Stuck _ (RepValAtom val)) = do
-- (RepVal ~(RefTy valTy) valTree) <- return val
-- return $ Dest valTy valTree
atomToDest atom = error $ "Expected a non-var atom of type `RawRef _`, got: " ++ pprint atom
{-# INLINE atomToDest #-}

Expand Down Expand Up @@ -847,11 +844,11 @@ litValToIExpr litval = case litval of
_ -> return $ ILit litval

buildGarbageVal :: Emits n => SType n -> SubstImpM i n (SAtom n)
buildGarbageVal ty =
toAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do
case getIExprInterpretation leafTy of
BufferPtr bufferTy -> allocBuffer Managed bufferTy
RawValue b -> return $ ILit $ emptyLit b
buildGarbageVal ty = undefined
-- toAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do
-- case getIExprInterpretation leafTy of
-- BufferPtr bufferTy -> allocBuffer Managed bufferTy
-- RawValue b -> return $ ILit $ emptyLit b

-- === Operations on dests ===

Expand Down Expand Up @@ -1176,7 +1173,7 @@ fromScalarAtom atom = atomToRepVal atom >>= \case
_ -> error $ "Not a scalar atom:" ++ pprint ty

toScalarAtom :: forall n. IExpr n -> SAtom n
toScalarAtom x = toAtom $ RepVal (BaseTy (getIType x) :: SType n) (Leaf x)
toScalarAtom x = undefined -- toAtom $ RepVal (BaseTy (getIType x) :: SType n) (Leaf x)

liftBuilderImp :: (Emits n, SubstE AtomSubstVal e, SinkableE e)
=> (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l))
Expand Down Expand Up @@ -1249,7 +1246,7 @@ singletonTypeVal ty = do
if length tree == 0 then do
-- The tree has 0 of these if the type is empty
let tree' = fmap (const $ ILit $ Int32Lit 0) tree
Just <$> mkStuck (RepValAtom $ RepVal ty tree')
undefined -- Just <$> mkStuck (RepValAtom $ RepVal ty tree')
else
return Nothing
{-# INLINE singletonTypeVal #-}
Expand Down
3 changes: 0 additions & 3 deletions src/lib/Inline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,6 @@ inlineStuck ctx = \case
PtrVar t p -> do
s <- mkStuck =<< (PtrVar t <$> substM p)
reconstruct ctx (toExpr s)
RepValAtom repVal -> do
s <- mkStuck =<< (RepValAtom <$> visitGeneric repVal)
reconstruct ctx (toExpr s)

inlineName :: Emits o => Context SExpr e o -> SAtomVar i -> InlineM i o (e o)
inlineName ctx name =
Expand Down
1 change: 0 additions & 1 deletion src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ linearizeStuck stuck = case stuck of
Nothing -> zero
Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx
PtrVar _ _ -> zero
RepValAtom _ -> zero
-- TODO: de-dup with the Expr versions of these
StuckProject i x -> do
x' <- linearizeStuck x
Expand Down
1 change: 0 additions & 1 deletion src/lib/OccAnalysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ instance HasOCC SStuck where
array' <- occ a' array
return $ StuckTabApp array' ixs'
PtrVar t p -> return $ PtrVar t p
RepValAtom x -> return $ RepValAtom x

instance HasOCC SType where
occ a (TyCon con) = liftM TyCon $ runOCCMVisitor a $ visitGeneric con
Expand Down
17 changes: 5 additions & 12 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

{-# LANGUAGE UndecidableInstances #-}

module Simplify (simplifyTopBlock, simplifyTopFunction, linearizeTopFun) where
module Simplify (simplifyTopBlock, simplifyTopFunction, linearizeTopFun, liftSimpAtom) where

import Control.Category ((>>>))
import Control.Monad
Expand Down Expand Up @@ -52,6 +52,9 @@ import Util (enumerate)

-- === Conversions between CoreIR, CoreToSimpIR, SimpIR ===

liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n)
liftSimpAtom = undefined

tryAsDataAtom :: Emits n => CAtom n -> SimplifyM i n (Maybe (SAtom n, Type CoreIR n))
tryAsDataAtom atom = do
let ty = getType atom
Expand Down Expand Up @@ -92,14 +95,6 @@ forceStuck stuck = withDistinct case stuck of
return $ CCFun $ CCNoInlineFun v'' t f
FFIFunBound t f -> return $ CCFun $ CCFFIFun t f
_ -> error "shouldn't have other CVars left"
LiftSimp _ x -> do
-- the subst should be rename-only for `x`. We should make subst IR-specific
s <- getSubst
let s' = newSubst \v -> case s ! v of
SubstVal _ -> error "subst should be rename-only for SimpIR vars" -- TODO: make subst IR-specific
Rename v' -> v'
x' <- runSubstReaderT s' $ renameM x
returnLifted x'
StuckProject i x -> forceStuck x >>= \case
CCLiftSimp _ x' -> returnLifted $ StuckProject i x'
CCCon (WithSubst s con) -> withSubst s case con of
Expand All @@ -124,7 +119,6 @@ forceStuck stuck = withDistinct case stuck of
PtrVar ty p -> do
p' <- substM p
returnLifted $ PtrVar ty p'
LiftSimpFun t f -> CCFun <$> (CCLiftSimpFun <$> substM t <*> substM f)
where
returnLifted :: SStuck o -> SimplifyM i o (ConcreteCAtom o)
returnLifted s = do
Expand Down Expand Up @@ -594,8 +588,7 @@ simplifyHof resultTy = \case
liftSimpAtom resultTy result

liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n)
liftSimpFun (TyCon (Pi piTy)) f = mkStuck $ LiftSimpFun piTy f
liftSimpFun _ _ = error "not a pi type"
liftSimpFun = undefined

-- === simplifying custom linearizations ===

Expand Down
13 changes: 7 additions & 6 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,13 @@ loadModuleSource config moduleName = do
{-# SCC loadModuleSource #-}

getDexString :: (MonadIO1 m, EnvReader m, Fallible1 m) => Val CoreIR n -> m n String
getDexString val = do
-- TODO: use a `ByteString` instead of `String`
Stuck _ (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val
Branch [Leaf (IIdxRepVal n), Leaf (IPtrVar ptrName _)] <- return tree
PtrBinding (CPU, Scalar Word8Type) (PtrLitVal ptr) <- lookupEnv ptrName
liftIO $ peekCStringLen (castPtr ptr, fromIntegral n)
getDexString val = undefined
-- getDexString val = do
-- -- TODO: use a `ByteString` instead of `String`
-- Stuck _ (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val
-- Branch [Leaf (IIdxRepVal n), Leaf (IPtrVar ptrName _)] <- return tree
-- PtrBinding (CPU, Scalar Word8Type) (PtrLitVal ptr) <- lookupEnv ptrName
-- liftIO $ peekCStringLen (castPtr ptr, fromIntegral n)

-- === saving cache to disk ===

Expand Down
1 change: 0 additions & 1 deletion src/lib/Transpose.hs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ transposeAtom atom ct = case atom of
LinTrivial -> return ()
StuckProject _ _ -> error "not linear"
StuckTabApp _ _ -> error "not linear"
RepValAtom _ -> error "not linear"
where notTangent = error $ "Not a tangent atom: " ++ pprint atom

transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o ()
Expand Down
Loading

0 comments on commit 72e36b9

Please sign in to comment.