Skip to content

Commit

Permalink
Preserve the target type in letrec lifting (#1945)
Browse files Browse the repository at this point in the history
- Closes #1887
  • Loading branch information
janmasrovira authored Mar 30, 2023
1 parent e1e4216 commit 9e9a884
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 62 deletions.
61 changes: 39 additions & 22 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ _NLam f = \case
cosmos :: SimpleFold Node Node
cosmos f = ufoldA reassemble f

-- | The list should not contain repeated indices.
-- if fv = x1, x2, .., xn
-- the result is of the form λx1 λx2 .. λ xn b
captureFreeVars :: [(Index, Binder)] -> Node -> Node
captureFreeVars freevars = goBinders freevars . mapFreeVars
-- | The free vars are given in the context of the node.
captureFreeVarsType :: [(Index, Binder)] -> (Node, Type) -> (Node, Type)
captureFreeVarsType freevars (n, ty) =
let bodyTy = mapFreeVars ty
body' = mapFreeVars n
in ( mkLambdasB captureBinders' body',
mkPis captureBinders' bodyTy
)
where
mapFreeVars :: Node -> Node
mapFreeVars = dmapN go
Expand All @@ -112,25 +115,33 @@ captureFreeVars freevars = goBinders freevars . mapFreeVars
NVar (Var i u)
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
m -> m

goBinders :: [(Index, Binder)] -> Node -> Node
goBinders fv = case unsnoc fv of
Nothing -> id
Just (fvs, (idx, bin)) -> goBinders fvs . mkLambdaB (mapBinder idx bin)
captureBinders' :: [Binder]
captureBinders' = goBinders freevars []
where
indices = map fst fv
mapBinder :: Index -> Binder -> Binder
mapBinder binderIndex = over binderType (dmapN go)
goBinders :: [(Index, Binder)] -> [Binder] -> [Binder]
goBinders fv acc = case unsnoc fv of
Nothing -> acc
Just (fvs, (idx, bin)) -> goBinders fvs (mapBinder idx bin : acc)
where
go :: Index -> Node -> Node
go k = \case
NVar u
| u ^. varIndex >= k ->
let uCtx = u ^. varIndex - k + binderIndex + 1
err = error ("impossible: could not find " <> show uCtx <> " in " <> show indices)
u' = length indices - 2 - fromMaybe err (elemIndex uCtx indices) + k
in NVar (set varIndex u' u)
m -> m
indices = map fst fv
mapBinder :: Index -> Binder -> Binder
mapBinder binderIndex = over binderType (dmapN go)
where
go :: Index -> Node -> Node
go k = \case
NVar u
| u ^. varIndex >= k ->
let uCtx = u ^. varIndex - k + binderIndex + 1
err = error ("impossible: could not find " <> show uCtx <> " in " <> show indices)
u' = length indices - 2 - fromMaybe err (elemIndex uCtx indices) + k
in NVar (set varIndex u' u)
m -> m

-- | The list should not contain repeated indices.
-- if fv = x1, x2, .., xn
-- the result is of the form λx1 λx2 .. λ xn b
captureFreeVars :: [(Index, Binder)] -> Node -> Node
captureFreeVars freevars n = fst (captureFreeVarsType freevars (n, mkDynamic'))

-- | Captures all free variables of a node. It also returns the list of captured
-- variables in left-to-right order: if snd is of the form λxλy... then fst is
Expand All @@ -140,6 +151,12 @@ captureFreeVarsCtx bl n =
let assocs = freeVarsCtx bl n
in (assocs, captureFreeVars (map (first (^. varIndex)) assocs) n)

captureFreeVarsCtxType :: BinderList Binder -> (Node, Type) -> ([(Var, Binder)], (Node, Type))
captureFreeVarsCtxType bl (n, ty) =
let assocs = freeVarsCtx bl n
assocsi = map (first (^. varIndex)) assocs
in (assocs, captureFreeVarsType assocsi (n, ty))

freeVarsCtxMany' :: BinderList Binder -> [Node] -> [Var]
freeVarsCtxMany' bl = map fst . freeVarsCtxMany bl

Expand Down
30 changes: 21 additions & 9 deletions src/Juvix/Compiler/Core/Transformation/LambdaLetRecLifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType)

lambdaLiftBinder :: Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder bl = traverseOf binderType (lambdaLiftNode bl)
Expand All @@ -22,10 +23,8 @@ lambdaLiftNode aboveBl top =
(topArgs, body) = unfoldLambdas top
in goTop aboveBl body topArgs
where
typeFromArgs :: [ArgumentInfo] -> Type
typeFromArgs = \case
[] -> mkDynamic' -- change this when we have type info about the body
(a : as) -> mkPi mempty (binderFromArgumentInfo a) (typeFromArgs as)
nodeType :: Node -> Sem r Type
nodeType n = flip computeNodeType n <$> getInfoTable

goTop :: BinderList Binder -> Node -> [LambdaLhs] -> Sem r Node
goTop bl body = \case
Expand Down Expand Up @@ -58,13 +57,14 @@ lambdaLiftNode aboveBl top =
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
ty <- nodeType fBody'
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierType = ty,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
Expand All @@ -78,10 +78,13 @@ lambdaLiftNode aboveBl top =
goLetRec letr = do
let defs :: [Node]
defs = letr ^.. letRecValues . each . letItemValue
defsTypes :: [Type]
defsTypes = letr ^.. letRecValues . each . letItemBinder . binderType
ndefs :: Int
ndefs = length defs
binders :: [Binder]
binders = letr ^.. letRecValues . each . letItemBinder

letRecBinders' :: [Binder] <- mapM (lambdaLiftBinder bl) binders
topSyms :: [Symbol] <- forM defs (const freshSymbol)
let bl' :: BinderList Binder
Expand All @@ -98,7 +101,7 @@ lambdaLiftNode aboveBl top =
helper :: Var -> Maybe (Var, Binder)
helper v
| v ^. varIndex < ndefs = Nothing
| otherwise = Just (set varIndex idx' v, BL.lookup idx' bl)
| otherwise = Just (shiftVar (-ndefs) v, BL.lookup idx' bl)
where
idx' = v ^. varIndex - ndefs

Expand All @@ -120,7 +123,10 @@ lambdaLiftNode aboveBl top =
declareTopSyms =
sequence_
[ do
let topBody = captureFreeVars (map (first (^. varIndex)) recItemsFreeVars) b
let (topBody, topTy) =
captureFreeVarsType
(map (first (^. varIndex)) recItemsFreeVars)
(b, bty)
argsInfo :: [ArgumentInfo]
argsInfo =
map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas topBody))
Expand All @@ -131,13 +137,19 @@ lambdaLiftNode aboveBl top =
{ _identifierSymbol = sym,
_identifierName = name,
_identifierLocation = itemBinder ^. binderLocation,
_identifierType = typeFromArgs argsInfo,
_identifierType = topTy,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
| ((sym, name), (itemBinder, b)) <- zipExact topSymsWithName (zipExact letRecBinders' liftedDefs)
| ((sym, name), (itemBinder, (b, bty))) <-
zipExact
topSymsWithName
( zipExact
letRecBinders'
(zipExact liftedDefs defsTypes)
)
]
declareTopSyms

Expand Down
15 changes: 11 additions & 4 deletions src/Juvix/Compiler/Core/Translation/Stripped/FromCore.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
module Juvix.Compiler.Core.Translation.Stripped.FromCore (fromCore) where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core
import Juvix.Compiler.Core.Data.Stripped.InfoTable qualified as Stripped
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Extra.Stripped.Base qualified as Stripped
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Language
import Juvix.Compiler.Core.Language.Stripped qualified as Stripped
import Juvix.Compiler.Core.Pretty

fromCore :: InfoTable -> Stripped.InfoTable
fromCore tab =
Expand Down Expand Up @@ -87,7 +86,15 @@ translateFunction :: Int -> Node -> Stripped.Node
translateFunction argsNum node =
let (k, body) = unfoldLambdas' node
in if
| k /= argsNum -> error "wrong number of arguments"
| k /= argsNum ->
error
( "wrong number of arguments. argsNum = "
<> show argsNum
<> ", unfoldLambdas = "
<> show k
<> "\nNode = "
<> ppTrace node
)
| otherwise -> translateNode body

translateNode :: Node -> Stripped.Node
Expand Down
2 changes: 1 addition & 1 deletion test/Core/Compile/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ignoredTests =
"Fast exponentiation",
"Nested 'case', 'let' and 'if' with variable capture",
"Mutual recursion",
"LetRec",
"LetRec - fib, fact",
"Big numbers"
]

Expand Down
2 changes: 1 addition & 1 deletion test/Core/Eval/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ tests =
$(mkRelFile "test039.jvc")
$(mkRelFile "out/test039.out"),
PosTest
"LetRec"
"LetRec - fib, fact"
$(mkRelDir ".")
$(mkRelFile "test040.jvc")
$(mkRelFile "out/test040.out"),
Expand Down
36 changes: 18 additions & 18 deletions tests/Core/positive/test040.jvc
Original file line number Diff line number Diff line change
@@ -1,42 +1,42 @@
-- letrec
-- letrec - fib, fact

def sum := letrec sum := \x if x = 0 then 0 else x + sum (x - 1) in sum;
def sum : Int -> Int := letrec sum : Int -> Int := \x if x = 0 then 0 else x + sum (x - 1) in sum;

def fact := \x
letrec fact' := \x \acc if x = 0 then acc else fact' (x - 1) (acc * x)
def fact : Int -> Int := \x
letrec fact' : Int -> Int -> Int := \x \acc if x = 0 then acc else fact' (x - 1) (acc * x)
in fact' x 1;

def fib :=
letrec fib' := \n \x \y if n = 0 then x else fib' (n - 1) y (x + y)
def fib : Int -> Int :=
letrec fib' : Int -> Int -> Int -> Int := \n \x \y if n = 0 then x else fib' (n - 1) y (x + y)
in \n fib' n 0 1;

def writeLn := \x write x >> write "\n";

def mutrec :=
let two := 2 in
let one := 1 in
def mutrec : IO :=
let two : Int := 2 in
let one : Int := 1 in
letrec[f g h]
f := \x {
f : Int -> Int := \x {
if x < one then
one
else
g (x - one) + two * x
};
g := \x {
g : Int -> Int := \x {
if x < one then
one
else
x + h (x - one)
};
h := \x letrec z := {
h : Int -> Int := \x letrec z : Int := {
if x < one then
one
else
x * f (x - one)
} in z;
in writeLn (f 5) >> writeLn (f 10) >> writeLn (f 100) >> writeLn (g 5) >> writeLn (h 5);

letrec x := 3
letrec x : Int := 3
in
writeLn x >>
writeLn (sum 10000) >>
Expand All @@ -47,9 +47,9 @@ writeLn (fib 10) >>
writeLn (fib 100) >>
writeLn (fib 1000) >>
mutrec >>
letrec x := 1 in
letrec x' := x + letrec x := 2 in x in
letrec x := x' * x' in
letrec y := x + 2 in
letrec z := x + y in
letrec x : Int := 1 in
letrec x' : Int := x + letrec x : Int := 2 in x in
letrec x : Int := x' * x' in
letrec y : Int := x + 2 in
letrec z : Int := x + y in
writeLn (x + y + z)
36 changes: 29 additions & 7 deletions tests/Core/positive/test043.jvc
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
-- dependent lambda-abstractions

def fun := λ(A : Type) λ(x : A) let f := λ(h : A → A) h (h x) in f (λ(y : A) x);
def fun :
Π A : Type,
A → A :=
λ(A : Type)
λ(x : A)
let f : (A → A) → A := λ(h : A → A) h (h x) in
f (λ(y : A) x);

def fun' : Π T : Type → Type, Π X : Type, Π A : Type, Any :=
λ(T : Type → Type) λ(_ : Type) λ(A : Type) λ(B : T A) λ(x : B)
let f := λ(g : B → B) g (g x) in
let h := λ(b : B) λ(a : A) a * b - b in
f (λ(y : B) h y x);
def helper : Int → Int → Int :=
λ(a : Int)
λ(b : Int)
a * b - b;

fun Int 2 + fun' (λ(A : Type) A) Bool Int Int 3
def fun' : Π T : Type → Type,
Π unused : Type,
Π C : Type,
Π A : Type,
(T A → A → C)
→ A
→ C :=
λ(T : Type → Type)
λ(unused : Type)
λ(C : Type)
λ(A : Type)
λ(mhelper : T A → A → C)
λ(a' : A)
let f : (A → A) → A := λ(g : A → A) g (g a') in
let h : A → A → C := λ(a1 : A) λ(a2 : A) mhelper a2 a1 in
f (λ(y : A) h y a');

fun Int 2 + fun' (λ(A : Type) A) Bool Int Int helper 3

0 comments on commit 9e9a884

Please sign in to comment.