Skip to content

Commit

Permalink
Make lambda lifting correct when free variables occur in the types of…
Browse files Browse the repository at this point in the history
… binders (#1609)
  • Loading branch information
janmasrovira authored Nov 9, 2022
1 parent 9d4f843 commit aa00d34
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 47 deletions.
1 change: 0 additions & 1 deletion src/Juvix/Compiler/Asm/Data/Stack.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Juvix.Compiler.Asm.Data.Stack where

import Data.IntMap (IntMap)
import Data.IntMap qualified as IntMap
import Juvix.Prelude hiding (empty)

Expand Down
8 changes: 4 additions & 4 deletions src/Juvix/Compiler/Asm/Translation/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ genCode infoTable fi =
Core.NCase c -> goCase isTail tempSize refs c

goVar :: Bool -> BinderList Value -> Core.Var -> Code'
goVar isTail refs (Core.Var {..}) =
goVar isTail refs Core.Var {..} =
snocReturn isTail $
DL.singleton $
mkInstr $
Push (BL.lookup _varIndex refs)

goIdent :: Bool -> Core.Ident -> Code'
goIdent isTail (Core.Ident {..}) =
goIdent isTail Core.Ident {..} =
if
| getArgsNum _identSymbol == 0 ->
DL.singleton $
Expand Down Expand Up @@ -108,7 +108,7 @@ genCode infoTable fi =
unimplemented
where
argsNum = getArgsNum _identSymbol
Core.FunVar (Core.Var {..}) ->
Core.FunVar Core.Var {..} ->
if
| argsNum > suppliedArgsNum ->
snocReturn isTail $
Expand Down Expand Up @@ -163,7 +163,7 @@ genCode infoTable fi =
goLet isTail tempSize refs (Core.Let {..}) =
DL.append
(DL.snoc (go False tempSize refs (_letItem ^. Core.letItemValue)) (mkInstr PushTemp))
(snocPopTemp isTail $ go isTail (tempSize + 1) (BL.extend (Ref (DRef (TempRef tempSize))) refs) _letBody)
(snocPopTemp isTail $ go isTail (tempSize + 1) (BL.cons (Ref (DRef (TempRef tempSize))) refs) _letBody)

goCase :: Bool -> Int -> BinderList Value -> Core.Case -> Code'
goCase isTail tempSize refs (Core.Case {..}) =
Expand Down
106 changes: 88 additions & 18 deletions src/Juvix/Compiler/Core/Data/BinderList.hs
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
module Juvix.Compiler.Core.Data.BinderList where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Core.Language.Base
import Juvix.Compiler.Core.Language hiding (cons, lookup, uncons)
import Juvix.Prelude qualified as Prelude

-- | if we have \x\y. b, the binderlist in b is [y, x]
data BinderList a = BinderList
{ _blLength :: Int,
_blMap :: HashMap Index a
_blMap :: [a]
}

makeLenses ''BinderList

fromList :: [a] -> BinderList a
fromList l = BinderList (length l) (HashMap.fromList (zip [0 ..] l))
fromList l = BinderList (length l) l

drop' :: Int -> BinderList a -> BinderList a
drop' k (BinderList n l) = BinderList (n - k) (dropExact k l)

tail' :: BinderList a -> BinderList a
tail' = snd . fromJust . uncons

uncons :: BinderList a -> Maybe (a, BinderList a)
uncons l = second helper <$> Prelude.uncons (l ^. blMap)
where
helper m =
BinderList
{ _blLength = l ^. blLength - 1,
_blMap = m
}

toIndexedList :: BinderList a -> [(Index, a)]
toIndexedList = zip [0 ..] . toList

instance Foldable BinderList where
foldr :: (a -> b -> b) -> b -> BinderList a -> b
Expand All @@ -24,14 +43,53 @@ instance Foldable BinderList where
length = (^. blLength)

toList :: BinderList a -> [a]
toList bl =
map snd $
sortBy (compare `on` fst) $
HashMap.toList (bl ^. blMap)
toList = (^. blMap)

-- | same as `lookupsSortedRev` but the result is in the same order as the input list.
lookupsSorted :: BinderList a -> [Var' i] -> [(Var' i, a)]
lookupsSorted bl = reverse . lookupsSortedRev bl

-- | efficient multiple lookups. The input list needs to be in non-decreasing order.
-- | The result is in reversed order (non-increasing order)
lookupsSortedRev :: BinderList a -> [Var' i] -> [(Var' i, a)]
lookupsSortedRev bl = go [] 0 bl
where
go :: [(Var' i, a)] -> Index -> BinderList a -> [Var' i] -> [(Var' i, a)]
go acc off ctx = \case
[] -> acc
(v : vs) ->
let skipped = v ^. varIndex - off
off' = off + skipped
ctx' = drop' skipped ctx
in go ((v, head' ctx') : acc) off' ctx' vs
head' :: BinderList a -> a
head' = lookup 0

-- | lookup de Bruijn Index
lookup :: Index -> BinderList a -> a
lookup idx bl =
fromMaybe err (HashMap.lookup target (bl ^. blMap))
lookup idx bl
| target < bl ^. blLength = (bl ^. blMap) !! target
| otherwise = err
where
target = idx
err :: a
err =
error
( "invalid binder lookup. Got index "
<> show idx
<> " that targets "
<> show target
<> " and the length is "
<> show (bl ^. blLength)
<> ". Actual length is "
<> show (length (bl ^. blMap))
)

-- | lookup de Bruijn Level
lookupLevel :: Level -> BinderList a -> a
lookupLevel idx bl
| target < bl ^. blLength = (bl ^. blMap) !! target
| otherwise = err
where
target = bl ^. blLength - 1 - idx
err :: a
Expand All @@ -46,7 +104,11 @@ lookup idx bl =
)

instance Semigroup (BinderList a) where
a <> b = prepend (toList a) b
(BinderList la ta) <> (BinderList lb tb) =
BinderList
{ _blLength = la + lb,
_blMap = ta <> tb
}

instance Monoid (BinderList a) where
mempty =
Expand All @@ -55,15 +117,23 @@ instance Monoid (BinderList a) where
_blMap = mempty
}

extend :: a -> BinderList a -> BinderList a
extend a bl =
BinderList
(bl ^. blLength + 1)
(HashMap.insert (bl ^. blLength) a (bl ^. blMap))

instance Functor BinderList where
fmap :: (a -> b) -> BinderList a -> BinderList b
fmap f = over blMap (fmap f)

cons :: a -> BinderList a -> BinderList a
cons a (BinderList l m) = BinderList (l + 1) (a : m)

-- | prepend [a,b] [c,d] = [a,b,c,d]
prepend :: [a] -> BinderList a -> BinderList a
prepend l bl = foldr extend bl l
prepend l bl = fromList l <> bl

-- | prependRev [a,b] [c,d] = [b,a,c,d]
-- more efficient than 'prepend' since it is tail recursive.
-- One example use case is prepending a list of binders in a letrec.
prependRev :: [a] -> BinderList a -> BinderList a
prependRev l (BinderList s m) =
BinderList
{ _blLength = length l + s,
_blMap = foldl' (flip (:)) m l
}
39 changes: 37 additions & 2 deletions src/Juvix/Compiler/Core/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.Set qualified as Set
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Equality
Expand All @@ -25,6 +27,9 @@ import Juvix.Compiler.Core.Language
isClosed :: Node -> Bool
isClosed = not . has freeVars

freeVarsSorted :: Node -> Set Var
freeVarsSorted n = Set.fromList (n ^.. freeVars)

freeVarsSet :: Node -> HashSet Var
freeVarsSet n = HashSet.fromList (n ^.. freeVars)

Expand Down Expand Up @@ -76,8 +81,7 @@ _NLam f = \case
cosmos :: SimpleFold Node Node
cosmos f = ufoldA reassemble f

-- | The list should not contain repeated indices. The 'Info' corresponds to the
-- binder of the variable.
-- | The list should not contain repeated indices.
-- if fv = x1, x2, .., xn
-- the result is of the form λx1 λx2 .. λ xn b
captureFreeVars :: [(Index, Binder)] -> Node -> Node
Expand All @@ -98,6 +102,37 @@ captureFreeVars fv
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
m -> m

-- captures all free variables of a node. It also returns the list of captured
-- variables in left-to-right order: if snd is of the form λxλy... then fst is
-- [x, y]
captureFreeVarsCtx :: BinderList Binder -> Node -> ([(Var, Binder)], Node)
captureFreeVarsCtx bl n =
let assocs = freeVarsCtx bl n
in (assocs, captureFreeVars (map (first (^. varIndex)) assocs) n)

freeVarsCtx' :: BinderList Binder -> Node -> [Var]
freeVarsCtx' bl = map fst . freeVarsCtx bl

-- | the output list does not contain repeated elements and is sorted by *decreasing* variable index.
-- The indices are relative to the given binder list
freeVarsCtx :: BinderList Binder -> Node -> [(Var, Binder)]
freeVarsCtx ctx n =
BL.lookupsSortedRev ctx . run . fmap fst . runOutputList $ go (freeVarsSorted n)
where
go ::
-- set of free variables relative to the original ctx
Set Var ->
Sem '[Output Var] ()
go fv = case Set.minView fv of
Nothing -> return ()
Just (v, vs) -> do
output v
let idx = v ^. varIndex
bi = BL.lookup idx ctx
freevarsbi' :: Set Var
freevarsbi' = Set.mapMonotonic (over varIndex (+ (idx + 1))) (freeVarsSorted (bi ^. binderType))
go (freevarsbi' <> vs)

-- | subst for multiple bindings
substs :: [Node] -> Node -> Node
substs t = umapN go
Expand Down
12 changes: 11 additions & 1 deletion src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module Juvix.Compiler.Core.Extra.Base where
module Juvix.Compiler.Core.Extra.Base
( module Juvix.Compiler.Core.Extra.Base,
module Juvix.Compiler.Core.Data.BinderList,
)
where

import Data.Functor.Identity
import Juvix.Compiler.Core.Data.BinderList (BinderList)
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Language
import Polysemy.Input
Expand Down Expand Up @@ -217,6 +222,10 @@ reLambda lhs = mkLambda (lhs ^. lambdaLhsInfo) (lhs ^. lambdaLhsBinder)
reLambdas :: [LambdaLhs] -> Node -> Node
reLambdas is n = foldl' (flip reLambda) n (reverse is)

-- | useful with unfoldLambdasRev
reLambdasRev :: [LambdaLhs] -> Node -> Node
reLambdasRev is n = foldl' (flip reLambda) n is

mkLambdaB :: Binder -> Node -> Node
mkLambdaB = mkLambda mempty

Expand All @@ -228,6 +237,7 @@ mkLambdas' k
| k < 0 = impossible
| otherwise = mkLambdasB (replicate k emptyBinder)

-- | \x\y b gives ([y, x], b)
unfoldLambdasRev :: Node -> ([LambdaLhs], Node)
unfoldLambdasRev = go []
where
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ makeLenses ''LetItem'
instance Eq (Var' i) where
(Var _ idx1) == (Var _ idx2) = idx1 == idx2

instance Ord (Var' i) where
compare = compare `on` (^. varIndex)

instance Eq (Ident' i) where
(Ident _ sym1) == (Ident _ sym2) = sym1 == sym2

Expand Down
34 changes: 34 additions & 0 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module Juvix.Compiler.Core.Pretty.Base
where

import Data.HashMap.Strict qualified as HashMap
import Data.Map.Strict qualified as Map
import Juvix.Compiler.Core.Data.BinderList as BL
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Data.Stripped.InfoTable qualified as Stripped
import Juvix.Compiler.Core.Extra
Expand Down Expand Up @@ -124,6 +126,38 @@ ppCodeConstr' name c = do
Nothing -> ppCode (c ^. constrTag)
return $ foldl' (<+>) n' args'

instance (Pretty k, PrettyCode a) => PrettyCode (Map k a) where
ppCode m = do
m' <-
sep . punctuate ","
<$> sequence
[ do
a' <- ppCode a
let k' = pretty k
return $ k' <+> kwMapsto <+> a'
| (k, a) <- Map.toList m
]
return $ braces m'

instance PrettyCode a => PrettyCode (BinderList a) where
ppCode bl = do
m <-
sequence
[ do
v' <- ppCode v
return (pretty k <+> kwMapsto <+> v')
| (k, v) <- BL.toIndexedList bl
]
return $ brackets (hsep $ punctuate "," m)

instance PrettyCode a => PrettyCode (Binder' a) where
ppCode (Binder mname ty) = do
name' <- case mname of
Nothing -> return "_"
Just n -> ppCode n
ty' <- ppCode ty
return (parens (name' <+> kwColon <+> ty'))

ppCodeLet' :: (PrettyCode a, Member (Reader Options) r) => Maybe Name -> Maybe (Doc Ann) -> Let' i a ty -> Sem r (Doc Ann)
ppCodeLet' name mty lt = do
n' <- case name of
Expand Down
Loading

0 comments on commit aa00d34

Please sign in to comment.