Skip to content

Commit

Permalink
Merge pull request #868 from GaloisInc/foldl
Browse files Browse the repository at this point in the history
Foldl
  • Loading branch information
robdockins authored Sep 15, 2020
2 parents 751ed48 + e79786e commit ac0aca1
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 126 deletions.
69 changes: 59 additions & 10 deletions lib/Cryptol.cry
Original file line number Diff line number Diff line change
Expand Up @@ -928,14 +928,38 @@ pmod x y = if y == 0 then 0/0 else last zs
/**
* Parallel map. The given function is applied to each element in the
* given finite seqeuence, and the results are computed in parallel.
* The values in the resulting sequence are reduced to normal form,
* as is done with the deepseq operation.
*
* The Eq constraint restricts this operation to types
* where reduction to normal form makes sense.
*
* This function is experimental.
*/
primitive parmap : {a, b, n} (fin n) => (a -> b) -> [n]a -> [n]b
primitive parmap : {a, b, n} (Eq b, fin n) => (a -> b) -> [n]a -> [n]b


// Utility operations -----------------------------------------------------------------

/**
* A strictness-increasing operation. The first operand
* is reduced to normal form before evaluating the second
* argument.
*
* The Eq constraint restricts this operation to types
* where reduction to normal form makes sense.
*/
primitive deepseq : {a, b} Eq a => a -> b -> b

/**
* Reduce to normal form.
*
* The Eq constraint restricts this operation to types
* where reduction to normal form makes sense.
*/
rnf : {a} Eq a => a -> a
rnf x = deepseq x x

/**
* Raise a run-time error with the given message.
* This function can be called at any type.
Expand Down Expand Up @@ -1009,13 +1033,13 @@ or xs = zero != xs
* Conjunction after applying a predicate to all elements.
*/
all : {n, a} (fin n) => (a -> Bit) -> [n]a -> Bit
all f xs = and (map f xs)
all f xs = foldl' (/\) True (map f xs)

/**
* Disjunction after applying a predicate to all elements.
*/
any : {n, a} (fin n) => (a -> Bit) -> [n]a -> Bit
any f xs = or (map f xs)
any f xs = foldl' (\/) False (map f xs)

/**
* Map a function over a sequence.
Expand All @@ -1028,24 +1052,49 @@ map f xs = [f x | x <- xs]
*
* foldl (+) 0 [1,2,3] = ((0 + 1) + 2) + 3
*/
foldl : {n, a, b} (fin n) => (a -> b -> a) -> a -> [n]b -> a
foldl f acc xs = ys ! 0
where ys = [acc] # [f a x | a <- ys | x <- xs]
primitive foldl : {n, a, b} (fin n) => (a -> b -> a) -> a -> [n]b -> a

/**
* Functional left fold, with strict evaluation of the accumulator value.
* The accumulator is reduced to normal form at each step. The Eq constraint
* restricts the accumulator to types where reduction to normal form makes sense.
*
* foldl' (+) 0 [1,2,3] = ((0 + 1) + 2) + 3
*/
primitive foldl' : {n, a, b} (fin n, Eq a) => (a -> b -> a) -> a -> [n]b -> a

/**
* Functional right fold.
*
* foldr (-) 0 [1,2,3] = 0 - (1 - (2 - 3))
*/
foldr : {n, a, b} (fin n) => (a -> b -> b) -> b -> [n]a -> b
foldr f acc xs = ys ! 0
where ys = [acc] # [f x a | a <- ys | x <- reverse xs]
foldr f acc xs = foldl g acc (reverse xs)
where g b a = f a b

/**
* Functional right fold, with strict evaluation of the accumulator value.
* The accumulator is reduced to weak head normal form at each step.
*
* foldr' (-) 0 [1,2,3] = 0 - (1 - (2 - 3))
*/
foldr' : {n, a, b} (fin n, Eq b) => (a -> b -> b) -> b -> [n]a -> b
foldr' f acc xs = foldl' g acc (reverse xs)
where g b a = f a b

/**
* Compute the sum of the values in the sequence.
*/
sum : {n, a} (fin n, Ring a) => [n]a -> a
sum xs = foldl (+) (fromInteger 0) xs
sum : {n, a} (fin n, Eq a, Ring a) => [n]a -> a
sum xs = foldl' (+) (fromInteger 0) xs


/**
* Compute the product of the values in the sequence.
*/
product : {n, a} (fin n, Eq a, Ring a) => [n]a -> a
product xs = foldl' (*) (fromInteger 1) xs


/**
* Scan left is like a foldl that also emits the intermediate values.
Expand Down
14 changes: 14 additions & 0 deletions src/Cryptol/Eval/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,20 @@ primTable eOpts = let sym = Concrete in
updatePrim sym updateBack_word updateBack)

-- Misc
, ("foldl" , {-# SCC "Prelude::foldl" #-}
foldlV sym)

, ("foldl'" , {-# SCC "Prelude::foldl'" #-}
foldl'V sym)

, ("deepseq" , {-# SCC "Prelude::deepseq" #-}
tlam $ \_a ->
tlam $ \_b ->
lam $ \x -> pure $
lam $ \y ->
do _ <- forceValue =<< x
y)

, ("parmap" , {-# SCC "Prelude::parmap" #-}
parmapV sym)

Expand Down
59 changes: 58 additions & 1 deletion src/Cryptol/Eval/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MultiParamTypeClasses #-}
Expand Down Expand Up @@ -1909,6 +1910,57 @@ mergeSeqMap sym c x y =
iteValue sym c (lookupSeqMap x i) (lookupSeqMap y i)



foldlV :: Backend sym => sym -> GenValue sym
foldlV sym =
ilam $ \_n ->
tlam $ \_a ->
tlam $ \_b ->
lam $ \f -> pure $
lam $ \z -> pure $
lam $ \v ->
v >>= \case
VSeq n m -> go0 f z (enumerateSeqMap n m)
VWord _n wv -> go0 f z . map (pure . VBit) =<< (enumerateWordValue sym =<< wv)
_ -> panic "Cryptol.Eval.Generic.foldlV" ["Expected finite sequence"]
where
go0 _f a [] = a
go0 f a bs =
do f' <- fromVFun <$> f
go1 f' a bs

go1 _f a [] = a
go1 f a (b:bs) =
do f' <- fromVFun <$> (f a)
go1 f (f' b) bs

foldl'V :: Backend sym => sym -> GenValue sym
foldl'V sym =
ilam $ \_n ->
tlam $ \_a ->
tlam $ \_b ->
lam $ \f -> pure $
lam $ \z -> pure $
lam $ \v ->
v >>= \case
VSeq n m -> go0 f z (enumerateSeqMap n m)
VWord _n wv -> go0 f z . map (pure . VBit) =<< (enumerateWordValue sym =<< wv)
_ -> panic "Cryptol.Eval.Generic.foldlV" ["Expected finite sequence"]
where
go0 _f a [] = a
go0 f a bs =
do f' <- fromVFun <$> f
a' <- sDelay sym Nothing a
forceValue =<< a'
go1 f' a' bs

go1 _f a [] = a
go1 f a (b:bs) =
do f' <- fromVFun <$> (f a)
a' <- sDelay sym Nothing (f' b)
forceValue =<< a'
go1 f a' bs

--------------------------------------------------------------------------------
-- Experimental parallel primitives

Expand Down Expand Up @@ -1940,7 +1992,12 @@ sparkParMap ::
SeqMap sym ->
SEval sym (SeqMap sym)
sparkParMap sym f n m =
finiteSeqMap sym <$> mapM (sSpark sym . f) (enumerateSeqMap n m)
finiteSeqMap sym <$> mapM (sSpark sym . g) (enumerateSeqMap n m)
where
g x =
do z <- sDelay sym Nothing (f x)
forceValue =<< z
z

--------------------------------------------------------------------------------
-- Floating Point Operations
Expand Down
11 changes: 11 additions & 0 deletions src/Cryptol/Eval/SBV.hs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,17 @@ primTable = let sym = SBV in

, ("fromZ" , fromZV sym)

, ("foldl" , foldlV sym)
, ("foldl'" , foldl'V sym)

, ("deepseq" ,
tlam $ \_a ->
tlam $ \_b ->
lam $ \x -> pure $
lam $ \y ->
do _ <- forceValue =<< x
y)

, ("parmap" , parmapV sym)

-- {at,len} (fin len) => [len][8] -> at
Expand Down
11 changes: 11 additions & 0 deletions src/Cryptol/Eval/What4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ primTable w4sym = let sym = What4 w4sym in

-- Misc

, ("foldl" , foldlV sym)
, ("foldl'" , foldl'V sym)

, ("deepseq" ,
tlam $ \_a ->
tlam $ \_b ->
lam $ \x -> pure $
lam $ \y ->
do _ <- forceValue =<< x
y)

, ("parmap" , parmapV sym)

, ("fromZ" , fromZV sym)
Expand Down
16 changes: 8 additions & 8 deletions tests/issues/T146.icry.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ Loading module Cryptol
Loading module Main

[error] at T146.cry:1:18--6:10:
The type ?fv`959 is not sufficiently polymorphic.
It cannot depend on quantified variables: fv`943
The type ?fv`980 is not sufficiently polymorphic.
It cannot depend on quantified variables: fv`964
where
?fv`959 is type argument 'fv' of 'Main::ec_v1' at T146.cry:4:19--4:24
fv`943 is signature variable 'fv' at T146.cry:11:10--11:12
?fv`980 is type argument 'fv' of 'Main::ec_v1' at T146.cry:4:19--4:24
fv`964 is signature variable 'fv' at T146.cry:11:10--11:12
[error] at T146.cry:5:19--5:24:
The type ?fv`961 is not sufficiently polymorphic.
It cannot depend on quantified variables: fv`943
The type ?fv`982 is not sufficiently polymorphic.
It cannot depend on quantified variables: fv`964
where
?fv`961 is type argument 'fv' of 'Main::ec_v2' at T146.cry:5:19--5:24
fv`943 is signature variable 'fv' at T146.cry:11:10--11:12
?fv`982 is type argument 'fv' of 'Main::ec_v2' at T146.cry:5:19--5:24
fv`964 is signature variable 'fv' at T146.cry:11:10--11:12
9 changes: 7 additions & 2 deletions tests/issues/issue226.icry.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,17 @@ Symbols
ceiling : {a} (Round a) => a -> Integer
complement : {a} (Logic a) => a -> a
curry : {a, b, c} ((a, b) -> c) -> a -> b -> c
deepseq : {a, b} (Eq a) => a -> b -> b
demote : {val, rep} (Literal val rep) => rep
drop : {front, back, a} (fin front) => [front + back]a -> [back]a
elem : {n, a} (fin n, Eq a) => a -> [n]a -> Bit
error : {a, n} (fin n) => String n -> a
False : Bit
floor : {a} (Round a) => a -> Integer
foldl : {n, a, b} (fin n) => (a -> b -> a) -> a -> [n]b -> a
foldl' : {n, a, b} (fin n, Eq a) => (a -> b -> a) -> a -> [n]b -> a
foldr : {n, a, b} (fin n) => (a -> b -> b) -> b -> [n]a -> b
foldr' : {n, a, b} (fin n, Eq b) => (a -> b -> b) -> b -> [n]a -> b
fraction : {m, n, r, a} (FLiteral m n r a) => a
fromInteger : {a} (Ring a) => Integer -> a
fromThenTo :
Expand Down Expand Up @@ -162,16 +165,18 @@ Symbols
negate : {a} (Ring a) => a -> a
number : {val, rep} (Literal val rep) => rep
or : {n} (fin n) => [n] -> Bit
parmap : {a, b, n} (fin n) => (a -> b) -> [n]a -> [n]b
parmap : {a, b, n} (Eq b, fin n) => (a -> b) -> [n]a -> [n]b
pdiv : {u, v} (fin u, fin v) => [u] -> [v] -> [u]
pmod : {u, v} (fin u, fin v) => [u] -> [1 + v] -> [v]
pmult :
{u, v} (fin u, fin v) => [1 + u] -> [1 + v] -> [1 + (u + v)]
product : {n, a} (fin n, Eq a, Ring a) => [n]a -> a
random : {a} [256] -> a
ratio : Integer -> Integer -> Rational
recip : {a} (Field a) => a -> a
repeat : {n, a} a -> [n]a
reverse : {n, a} (fin n) => [n]a -> [n]a
rnf : {a} (Eq a) => a -> a
roundAway : {a} (Round a) => a -> Integer
roundToEven : {a} (Round a) => a -> Integer
sborrow : {n} (fin n, n >= 1) => [n] -> [n] -> Bit
Expand All @@ -184,7 +189,7 @@ Symbols
splitAt :
{front, back, a} (fin front) =>
[front + back]a -> ([front]a, [back]a)
sum : {n, a} (fin n, Ring a) => [n]a -> a
sum : {n, a} (fin n, Eq a, Ring a) => [n]a -> a
True : Bit
tail : {n, a} [1 + n]a -> [n]a
take : {front, back, a} (fin front) => [front + back]a -> [front]a
Expand Down
4 changes: 2 additions & 2 deletions tests/issues/issue290v2.icry.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ Loading module Main

[error] at issue290v2.cry:2:1--2:19:
Unsolved constraints:
• n`940 == 1
• n`961 == 1
arising from
checking a pattern: type of 1st argument of Main::minMax
at issue290v2.cry:2:8--2:11
where
n`940 is signature variable 'n' at issue290v2.cry:1:11--1:12
n`961 is signature variable 'n' at issue290v2.cry:1:11--1:12
4 changes: 2 additions & 2 deletions tests/issues/issue723.icry.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ Loading module Main
assuming
• fin k
the following constraints hold:
• k == n`940
• k == n`961
arising from
matching types
at issue723.cry:7:17--7:19
where
n`940 is signature variable 'n' at issue723.cry:1:6--1:7
n`961 is signature variable 'n' at issue723.cry:1:6--1:7
Loading

0 comments on commit ac0aca1

Please sign in to comment.