Skip to content

Commit

Permalink
Lazy boolean operators
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Jan 19, 2023
1 parent ba03576 commit b89608a
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 66 deletions.
14 changes: 7 additions & 7 deletions c-runtime/builtins/bool.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ typedef bool prim_bool;
#define prim_true true
#define prim_false false

bool is_prim_true(prim_bool b) {
return b == prim_true;
}
bool is_prim_true(prim_bool b) { return b == prim_true; }

bool is_prim_false(prim_bool b) {
return b == prim_false;
}
bool is_prim_false(prim_bool b) { return b == prim_false; }

#define prim_if(b, ifThen, ifElse) (b ? ifThen : ifElse)

#endif // BOOL_H_
#define prim_or(a, b) ((a) || (b))

#define prim_and(a, b) ((a) && (b))

#endif // BOOL_H_
2 changes: 1 addition & 1 deletion juvix-stdlib
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Abstract/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ registerBuiltinFunction d = \case
BuiltinNatLt -> registerNatLt d
BuiltinNatEq -> registerNatEq d
BuiltinBoolIf -> registerIf d
BuiltinBoolOr -> registerOr d
BuiltinBoolAnd -> registerAnd d

registerBuiltinAxiom ::
Members '[InfoTableBuilder, Error ScoperError, Builtins] r =>
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Backend/C/Data/BuiltinTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ builtinFunctionName = \case
BuiltinNatLt -> Just natlt
BuiltinNatEq -> Just nateq
BuiltinBoolIf -> Just boolif
BuiltinBoolOr -> Just boolor
BuiltinBoolAnd -> Just booland

builtinName :: BuiltinPrim -> Maybe Text
builtinName = \case
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Backend/C/Data/CNames.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ nateq = primPrefix "nateq"
boolif :: Text
boolif = primPrefix "if"

boolor :: Text
boolor = primPrefix "or"

booland :: Text
booland = primPrefix "and"

funField :: Text
funField = "fun"

Expand Down
51 changes: 33 additions & 18 deletions src/Juvix/Compiler/Builtins/Bool.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Juvix.Compiler.Builtins.Bool where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Effect
Expand Down Expand Up @@ -36,34 +35,50 @@ registerIf f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
vart <- freshVar "t"
let if_ = f ^. funDefName
ty = f ^. funDefTypeSig
freeTVars = HashSet.fromList [vart]
u = ExpressionUniverse (Universe {_universeLevel = Nothing, _universeLoc = error "Universe with no location"})
unless (((u <>--> bool_ --> vart --> vart --> vart) ==% ty) freeTVars) (error "Bool if has the wrong type signature")
registerBuiltin BuiltinBoolIf if_
vart <- freshVar "t"
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
freeVars = HashSet.fromList [vare]
(=%) :: (IsExpression a, IsExpression b) => a -> b -> Bool
a =% b = (a ==% b) freeVars
exClauses :: [(Expression, Expression)]
exClauses =
[ (if_ @@ true_ @@ e @@ hole, e),
(if_ @@ false_ @@ hole @@ e, e)
]
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (f ^. funDefClauses)
registerFun f BuiltinBoolIf (u <>--> bool_ --> vart --> vart --> vart) exClauses [vare] [vart]

registerOr :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerOr f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
let or_ = f ^. funDefName
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
exClauses :: [(Expression, Expression)]
exClauses =
[ (or_ @@ true_ @@ hole, true_),
(or_ @@ false_ @@ e, e)
]
registerFun f BuiltinBoolOr (bool_ --> bool_ --> bool_) exClauses [vare] []

registerAnd :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerAnd f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
let and_ = f ^. funDefName
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
exClauses :: [(Expression, Expression)]
exClauses =
[ (and_ @@ true_ @@ e, e),
(and_ @@ false_ @@ hole, false_)
]
case zipExactMay exClauses clauses of
Nothing -> error "Bool if has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)
registerFun f BuiltinBoolAnd (bool_ --> bool_ --> bool_) exClauses [vare] []

registerBoolPrint :: Members '[Builtins] r => AxiomDef -> Sem r ()
registerBoolPrint f = do
Expand Down
29 changes: 29 additions & 0 deletions src/Juvix/Compiler/Builtins/Effect.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module Juvix.Compiler.Builtins.Effect
)
where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Error
import Juvix.Prelude

Expand Down Expand Up @@ -61,3 +63,30 @@ re = reinterpret $ \case

runBuiltins :: Member (Error JuvixError) r => BuiltinsState -> Sem (Builtins ': r) a -> Sem r (BuiltinsState, a)
runBuiltins s = runState s . re

registerFun ::
Members '[Builtins, NameIdGen] r =>
FunctionDef ->
BuiltinFunction ->
Expression ->
[(Expression, Expression)] ->
[VarName] ->
[VarName] ->
Sem r ()
registerFun f blt sig exClauses fvs ftvs = do
let op = f ^. funDefName
ty = f ^. funDefTypeSig
unless ((sig ==% ty) (HashSet.fromList ftvs)) (error "builtin has the wrong type signature")
registerBuiltin blt op
let freeVars = HashSet.fromList fvs
a =% b = (a ==% b) freeVars
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (f ^. funDefClauses)
]
case zipExactMay exClauses clauses of
Nothing -> error "builtin has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)
45 changes: 9 additions & 36 deletions src/Juvix/Compiler/Builtins/Nat.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Juvix.Compiler.Builtins.Nat where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Effect
Expand Down Expand Up @@ -39,32 +38,6 @@ registerNatPrint f = do
unless (f ^. axiomType === (nat --> io)) (error "Nat print has the wrong type signature")
registerBuiltin BuiltinNatPrint (f ^. axiomName)

registerNatFun ::
Members '[Builtins, NameIdGen] r =>
FunctionDef ->
BuiltinFunction ->
Expression ->
[(Expression, Expression)] ->
[VarName] ->
Sem r ()
registerNatFun f blt sig exClauses fvs = do
let op = f ^. funDefName
ty = f ^. funDefTypeSig
unless (ty === sig) (error "builtin has the wrong type signature")
registerBuiltin blt op
let freeVars = HashSet.fromList fvs
a =% b = (a ==% b) freeVars
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (f ^. funDefClauses)
]
case zipExactMay exClauses clauses of
Nothing -> error "builtin has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)

registerNatPlus :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatPlus f = do
nat <- getBuiltinName (getLoc f) BuiltinNat
Expand All @@ -82,7 +55,7 @@ registerNatPlus f = do
[ (zero .+. m, m),
((suc @@ n) .+. m, suc @@ (n .+. m))
]
registerNatFun f BuiltinNatPlus (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatPlus (nat --> nat --> nat) exClauses [varn, varm] []

registerNatMul :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatMul f = do
Expand All @@ -103,7 +76,7 @@ registerNatMul f = do
[ (zero .*. h, zero),
((suc @@ n) .*. m, plus @@ m @@ (n .*. m))
]
registerNatFun f BuiltinNatMul (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatMul (nat --> nat --> nat) exClauses [varn, varm] []

registerNatSub :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatSub f = do
Expand All @@ -124,7 +97,7 @@ registerNatSub f = do
(n .-. zero, n),
((suc @@ n) .-. (suc @@ m), n .-. m)
]
registerNatFun f BuiltinNatSub (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatSub (nat --> nat --> nat) exClauses [varn, varm] []

registerNatUDiv :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatUDiv f = do
Expand All @@ -145,7 +118,7 @@ registerNatUDiv f = do
[ (zero ./. h, zero),
(n ./. m, suc @@ ((sub @@ n @@ m) ./. m))
]
registerNatFun f BuiltinNatUDiv (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatUDiv (nat --> nat --> nat) exClauses [varn, varm] []

registerNatDiv :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatDiv f = do
Expand All @@ -164,7 +137,7 @@ registerNatDiv f = do
exClauses =
[ (n ./. m, udiv @@ (sub @@ (suc @@ n) @@ m) @@ m)
]
registerNatFun f BuiltinNatDiv (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatDiv (nat --> nat --> nat) exClauses [varn, varm] []

registerNatMod :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatMod f = do
Expand All @@ -180,7 +153,7 @@ registerNatMod f = do
exClauses =
[ (modop @@ n @@ m, sub @@ n @@ (mul @@ (divop @@ n @@ m) @@ m))
]
registerNatFun f BuiltinNatMod (nat --> nat --> nat) exClauses [varn, varm]
registerFun f BuiltinNatMod (nat --> nat --> nat) exClauses [varn, varm] []

registerNatLe :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatLe f = do
Expand All @@ -204,7 +177,7 @@ registerNatLe f = do
(h .<=. zero, false),
((suc @@ n) .<=. (suc @@ m), n .<=. m)
]
registerNatFun f BuiltinNatLe (nat --> nat --> tybool) exClauses [varn, varm]
registerFun f BuiltinNatLe (nat --> nat --> tybool) exClauses [varn, varm] []

registerNatLt :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatLt f = do
Expand All @@ -221,7 +194,7 @@ registerNatLt f = do
exClauses =
[ (lt @@ n @@ m, le @@ (suc @@ n) @@ m)
]
registerNatFun f BuiltinNatLt (nat --> nat --> tybool) exClauses [varn, varm]
registerFun f BuiltinNatLt (nat --> nat --> tybool) exClauses [varn, varm] []

registerNatEq :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerNatEq f = do
Expand All @@ -246,4 +219,4 @@ registerNatEq f = do
(h .==. zero, false),
((suc @@ n) .==. (suc @@ m), n .==. m)
]
registerNatFun f BuiltinNatEq (nat --> nat --> tybool) exClauses [varn, varm]
registerFun f BuiltinNatEq (nat --> nat --> tybool) exClauses [varn, varm] []
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ data BuiltinFunction
| BuiltinNatLt
| BuiltinNatEq
| BuiltinBoolIf
| BuiltinBoolOr
| BuiltinBoolAnd
deriving stock (Show, Eq, Ord, Enum, Bounded, Generic, Data)

instance Hashable BuiltinFunction
Expand All @@ -88,6 +90,8 @@ instance Pretty BuiltinFunction where
BuiltinNatLt -> Str.natLt
BuiltinNatEq -> Str.natEq
BuiltinBoolIf -> Str.boolIf
BuiltinBoolOr -> Str.boolOr
BuiltinBoolAnd -> Str.boolAnd

data BuiltinAxiom
= BuiltinNatPrint
Expand Down
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,18 @@ goApplication a = do
case as of
(_ : v : b1 : b2 : xs) -> return (mkApps' (mkIf' sym v b1 b2) xs)
_ -> error "if must be called with 3 arguments"
Just Internal.BuiltinBoolOr -> do
sym <- getBoolSymbol
as <- exprArgs
case as of
(x : y : xs) -> return (mkApps' (mkIf' sym x (mkConstr' (BuiltinTag TagTrue) []) y) xs)
_ -> error "|| must be called with 2 arguments"
Just Internal.BuiltinBoolAnd -> do
sym <- getBoolSymbol
as <- exprArgs
case as of
(x : y : xs) -> return (mkApps' (mkIf' sym x y (mkConstr' (BuiltinTag TagFalse) [])) xs)
_ -> error "&& must be called with 2 arguments"
_ -> app
_ -> app

Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Extra/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ natEq = "nat-eq"
boolIf :: IsString s => s
boolIf = "bool-if"

boolOr :: IsString s => s
boolOr = "bool-or"

boolAnd :: IsString s => s
boolAnd = "bool-and"

builtin :: IsString s => s
builtin = "builtin"

Expand Down
2 changes: 1 addition & 1 deletion test/Compilation/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ tests =
$(mkRelFile "test005.juvix")
$(mkRelFile "out/test005.out"),
posTest
"If-then-else"
"If-then-else and lazy boolean operators"
$(mkRelDir ".")
$(mkRelFile "test006.juvix")
$(mkRelFile "out/test006.out"),
Expand Down
2 changes: 2 additions & 0 deletions tests/Compilation/positive/out/test006.out
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
2
true
false
7 changes: 4 additions & 3 deletions tests/Compilation/positive/test006.juvix
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- if then else
-- if-then-else and lazy boolean operators
module test006;

open import Stdlib.Prelude;
Expand All @@ -9,7 +9,8 @@ loop : Nat;
loop := loop;

main : IO;
main := printNatLn $ (if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1));

main := printNatLn ((if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1))) >>
printBoolLn (2 > 0 || loop == 0) >>
printBoolLn (2 < 0 && loop == 0);

end;

0 comments on commit b89608a

Please sign in to comment.