Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically detect and split mutually recursive blocks in let expressions #1894

61 changes: 44 additions & 17 deletions src/Juvix/Compiler/Abstract/Extra/DependencyBuilder.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, ExportsTable) where
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, buildDependencyInfoExpr, ExportsTable) where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
Expand All @@ -18,7 +18,23 @@ type ExportsTable = HashSet NameId

buildDependencyInfo :: NonEmpty TopModule -> ExportsTable -> NameDependencyInfo
buildDependencyInfo ms tab =
createDependencyInfo graph startNodes
buildDependencyInfoHelper tab (mapM_ goModule ms)

buildDependencyInfoExpr :: Expression -> NameDependencyInfo
buildDependencyInfoExpr = buildDependencyInfoHelper mempty . goExpression Nothing

buildDependencyInfoHelper ::
ExportsTable ->
( Sem
'[ Reader ExportsTable,
State DependencyGraph,
State StartNodes,
State VisitedModules
]
()
) ->
NameDependencyInfo
buildDependencyInfoHelper tbl m = createDependencyInfo graph startNodes
where
startNodes :: StartNodes
graph :: DependencyGraph
Expand All @@ -27,12 +43,14 @@ buildDependencyInfo ms tab =
evalState (HashSet.empty :: VisitedModules) $
runState HashSet.empty $
execState HashMap.empty $
runReader tab $
mapM_ goModule ms
runReader tbl m

addStartNode :: (Member (State StartNodes) r) => Name -> Sem r ()
addStartNode n = modify (HashSet.insert n)

addEdgeMay :: (Member (State DependencyGraph) r) => Maybe Name -> Name -> Sem r ()
addEdgeMay mn1 n2 = whenJust mn1 $ \n1 -> addEdge n1 n2

addEdge :: (Member (State DependencyGraph) r) => Name -> Name -> Sem r ()
addEdge n1 n2 =
modify
Expand Down Expand Up @@ -87,16 +105,16 @@ goStatement modName = \case
StatementAxiom ax -> do
checkStartNode (ax ^. axiomName)
addEdge (ax ^. axiomName) modName
goExpression (ax ^. axiomName) (ax ^. axiomType)
goExpression (Just (ax ^. axiomName)) (ax ^. axiomType)
StatementFunction f -> goTopFunctionDef modName f
StatementImport m -> guardNotVisited (m ^. moduleName) (goModule m)
StatementLocalModule m -> goLocalModule modName m
StatementInductive i -> do
checkStartNode (i ^. inductiveName)
checkBuiltinInductiveStartNode i
addEdge (i ^. inductiveName) modName
mapM_ (goFunctionParameter (i ^. inductiveName)) (i ^. inductiveParameters)
goExpression (i ^. inductiveName) (i ^. inductiveType)
mapM_ (goFunctionParameter (Just (i ^. inductiveName))) (i ^. inductiveParameters)
goExpression (Just (i ^. inductiveName)) (i ^. inductiveType)
mapM_ (goConstructorDef (i ^. inductiveName)) (i ^. inductiveConstructors)

goTopFunctionDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionDef -> Sem r ()
Expand All @@ -110,22 +128,22 @@ goFunctionDefHelper ::
Sem r ()
goFunctionDefHelper f = do
checkStartNode (f ^. funDefName)
goExpression (f ^. funDefName) (f ^. funDefTypeSig)
goExpression (Just (f ^. funDefName)) (f ^. funDefTypeSig)
mapM_ (goFunctionClause (f ^. funDefName)) (f ^. funDefClauses)

-- constructors of an inductive type depend on the inductive type, not the other
-- way round; an inductive type depends on the types of its constructors
goConstructorDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> InductiveConstructorDef -> Sem r ()
goConstructorDef indName c = do
addEdge (c ^. constructorName) indName
goExpression indName (c ^. constructorType)
goExpression (Just indName) (c ^. constructorType)

goFunctionClause :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionClause -> Sem r ()
goFunctionClause p c = do
mapM_ (goPattern p) (c ^. clausePatterns)
goExpression p (c ^. clauseBody)
mapM_ (goPattern (Just p)) (c ^. clausePatterns)
goExpression (Just p) (c ^. clauseBody)

goPattern :: forall r. (Member (State DependencyGraph) r) => Name -> PatternArg -> Sem r ()
goPattern :: forall r. (Member (State DependencyGraph) r) => Maybe Name -> PatternArg -> Sem r ()
goPattern n p = case p ^. patternArgPattern of
PatternVariable {} -> return ()
PatternWildcard {} -> return ()
Expand All @@ -134,12 +152,17 @@ goPattern n p = case p ^. patternArgPattern of
where
goApp :: ConstructorApp -> Sem r ()
goApp (ConstructorApp ctr ps) = do
addEdge n (ctr ^. constructorRefName)
addEdgeMay n (ctr ^. constructorRefName)
mapM_ (goPattern n) ps

goExpression :: forall r. (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> Expression -> Sem r ()
goExpression ::
forall r.
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
Expression ->
Sem r ()
goExpression p e = case e of
ExpressionIden i -> addEdge p (idenName i)
ExpressionIden i -> addEdgeMay p (idenName i)
ExpressionUniverse {} -> return ()
ExpressionFunction f -> do
goFunctionParameter p (f ^. funParameter)
Expand Down Expand Up @@ -177,8 +200,12 @@ goExpression p e = case e of
goLetClause :: LetClause -> Sem r ()
goLetClause = \case
LetFunDef f -> do
addEdge p (f ^. funDefName)
addEdgeMay p (f ^. funDefName)
goFunctionDefHelper f

goFunctionParameter :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionParameter -> Sem r ()
goFunctionParameter ::
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
FunctionParameter ->
Sem r ()
goFunctionParameter p param = goExpression p (param ^. paramType)
61 changes: 30 additions & 31 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ goFunctionDef ::
Sem r ()
goFunctionDef ((f, sym), ty) = do
mbody <- case f ^. Internal.funDefBuiltin of
Just b | isIgnoredBuiltin b -> return Nothing
Just _ -> Just <$> runReader initIndexTable (mkFunBody ty f)
Just b
| isIgnoredBuiltin b -> return Nothing
| otherwise -> Just <$> runReader initIndexTable (mkFunBody ty f)
Nothing -> Just <$> runReader initIndexTable (mkFunBody ty f)
forM_ mbody (registerIdentNode sym)
forM_ mbody setIdentArgsInfo'
Expand Down Expand Up @@ -461,35 +462,33 @@ goLet ::
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable] r) =>
Internal.Let ->
Sem r Node
goLet l = do
vars <- asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
let bs :: [Name]
bs = map (\(Internal.LetFunDef Internal.FunctionDef {..}) -> _funDefName) (toList $ l ^. Internal.letClauses)
(vars', varsNum') =
foldl'
( \(vs, k) name ->
(HashMap.insert (name ^. nameId) k vs, k + 1)
)
(vars, varsNum)
bs
(defs, value) <- do
values <-
mapM
( \(Internal.LetFunDef f) -> do
funTy <- goType (f ^. Internal.funDefType)

funBody <- local (set indexTableVars vars' . set indexTableVarsNum varsNum') (mkFunBody funTy f)
return (funTy, funBody)
)
(l ^. Internal.letClauses)

lbody <-
local
(set indexTableVars vars' . set indexTableVarsNum varsNum')
(goExpression (l ^. Internal.letExpression))
return (values, lbody)
return $ mkLetRec' defs value
goLet l = goClauses (toList (l ^. Internal.letClauses))
where
goClauses :: [Internal.LetClause] -> Sem r Node
goClauses = \case
[] -> goExpression (l ^. Internal.letExpression)
c : cs -> case c of
Internal.LetFunDef f -> goNonRecFun f
Internal.LetMutualBlock m -> goMutual m
where
goNonRecFun :: Internal.FunctionDef -> Sem r Node
goNonRecFun f =
do
funTy <- goType (f ^. Internal.funDefType)
funBody <- mkFunBody funTy f
rest <- localAddName (f ^. Internal.funDefName) (goClauses cs)
return $ mkLet' funTy funBody rest
goMutual :: Internal.MutualBlock -> Sem r Node
goMutual (Internal.MutualBlock funs) = do
let lfuns = toList funs
names = map (^. Internal.funDefName) lfuns
tys = map (^. Internal.funDefType) lfuns
tys' <- mapM goType tys
localAddNames names $ do
vals' <- sequence [mkFunBody ty f | (ty, f) <- zipExact tys' lfuns]
let items = nonEmpty' (zip tys' vals')
rest <- goClauses cs
return (mkLetRec' items rest)

goAxiomInductive ::
forall r.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,24 @@ makeLenses ''IndexTable
initIndexTable :: IndexTable
initIndexTable = IndexTable 0 mempty

localAddName :: forall r a. (Member (Reader IndexTable) r) => Name -> Sem r a -> Sem r a
localAddName n s = do
localAddName :: Member (Reader IndexTable) r => Name -> Sem r a -> Sem r a
localAddName n = localAddNames [n]

localAddNames :: forall r a. (Member (Reader IndexTable) r) => [Name] -> Sem r a -> Sem r a
localAddNames names s = do
updateFn <- update
local updateFn s
where
len :: Int = length names
insertMany :: [(NameId, Index)] -> HashMap NameId Index -> HashMap NameId Index
insertMany l t = foldl' (\m (k, v) -> HashMap.insert k v m) t l
update :: Sem r (IndexTable -> IndexTable)
update = do
idx <- asks (^. indexTableVarsNum)
let newElems = zip (map (^. nameId) names) [idx ..]
return
( over indexTableVars (HashMap.insert (n ^. nameId) idx)
. over indexTableVarsNum (+ 1)
( over indexTableVars (insertMany newElems)
. over indexTableVarsNum (+ len)
)

underBinders :: Members '[Reader IndexTable] r => Int -> Sem r a -> Sem r a
Expand Down
20 changes: 18 additions & 2 deletions src/Juvix/Compiler/Internal/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,25 @@ extendWithReplExpression e =
over
infoFunctions
( HashMap.union
(HashMap.fromList [(f ^. funDefName, FunctionInfo f) | LetFunDef f <- universeBi e])
( HashMap.fromList
[ (f ^. funDefName, FunctionInfo f)
| f <- letFunctionDefs e
]
)
)

letFunctionDefs :: Data from => from -> [FunctionDef]
letFunctionDefs e =
concat
[ concatMap (toList . flattenClause) _letClauses
| Let {..} <- universeBi e
]
where
flattenClause :: LetClause -> NonEmpty FunctionDef
flattenClause = \case
LetFunDef f -> pure f
LetMutualBlock (MutualBlock fs) -> fs

-- | moduleName ↦ infoTable
type Cache = HashMap Name InfoTable

Expand Down Expand Up @@ -117,7 +133,7 @@ buildTable1' m = do
]
<> [ (f ^. funDefName, FunctionInfo f)
| s <- filter (not . isInclude) ss,
LetFunDef f <- universeBi s
f <- letFunctionDefs s
]
where
isInclude :: Statement -> Bool
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Internal/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,14 @@ instance HasExpressions Case where
where
_caseParens = l ^. caseParens

instance HasExpressions MutualBlock where
leafExpressions f (MutualBlock defs) =
MutualBlock <$> traverse (leafExpressions f) defs

instance HasExpressions LetClause where
leafExpressions f = \case
LetFunDef d -> LetFunDef <$> leafExpressions f d
LetMutualBlock b -> LetMutualBlock <$> leafExpressions f b

instance HasExpressions Let where
leafExpressions f l = do
Expand Down
14 changes: 11 additions & 3 deletions src/Juvix/Compiler/Internal/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ data Statement
newtype MutualBlock = MutualBlock
{ _mutualFunctions :: NonEmpty FunctionDef
}
deriving stock (Data)
deriving stock (Eq, Generic, Data)

instance Hashable MutualBlock

data AxiomDef = AxiomDef
{ _axiomName :: AxiomName,
Expand Down Expand Up @@ -98,8 +100,10 @@ data TypedExpression = TypedExpression
_typedExpression :: Expression
}

newtype LetClause
= LetFunDef FunctionDef
data LetClause
= -- | Non-recursive let definition
LetFunDef FunctionDef
| LetMutualBlock MutualBlock
deriving stock (Eq, Generic, Data)

instance Hashable LetClause
Expand Down Expand Up @@ -367,9 +371,13 @@ instance HasLoc FunctionClause where
instance HasLoc FunctionDef where
getLoc f = getLoc (f ^. funDefName) <> getLocSpan (f ^. funDefClauses)

instance HasLoc MutualBlock where
getLoc (MutualBlock defs) = getLocSpan defs

instance HasLoc LetClause where
getLoc = \case
LetFunDef f -> getLoc f
LetMutualBlock f -> getLoc f

instance HasLoc Let where
getLoc l = getLocSpan (l ^. letClauses) <> getLoc (l ^. letExpression)
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Internal/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,17 @@ instance PrettyCode Let where
return $ kwLet <+> letClauses' <+> kwIn <+> letExpression'

instance PrettyCode LetClause where
ppCode :: forall r. Member (Reader Options) r => LetClause -> Sem r (Doc Ann)
ppCode = \case
LetFunDef f -> ppCode f
LetMutualBlock b -> ppMutual b
where
ppMutual :: MutualBlock -> Sem r (Doc Ann)
ppMutual m@(MutualBlock b)
| [_] <- toList b = ppCode b
| otherwise = do
b' <- ppCode m
return (kwMutual <+> braces (line <> indent' b' <> line))
paulcadman marked this conversation as resolved.
Show resolved Hide resolved

ppPipeBlock :: (PrettyCode a, Members '[Reader Options] r, Traversable t) => t a -> Sem r (Doc Ann)
ppPipeBlock items = vsep <$> mapM (fmap (kwPipe <+>) . ppCode) items
Expand Down
Loading