Skip to content

Commit

Permalink
Cover Data.List interface more comprehensively
Browse files Browse the repository at this point in the history
  • Loading branch information
utdemir committed Sep 22, 2020
1 parent 4fa4757 commit 9e8f2b7
Showing 1 changed file with 130 additions and 2 deletions.
132 changes: 130 additions & 2 deletions src/Data/List/Linear.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,31 @@ original "Data.List" module for more detailed information.
-}
module Data.List.Linear
( -- * Basic functions
map
(++)
, map
, filter
, head
, uncons
, tail
, last
, init
, reverse
, length
, splitAt
, span
, partition
, takeWhile
, dropWhile
, find
, intersperse
, intercalate
, transpose
-- * Folds
, foldr
, foldl
, foldl'
, foldl1'
, foldr
, foldr1
, foldMap
, foldMap'
-- * Special folds
Expand All @@ -35,9 +46,15 @@ module Data.List.Linear
, or
, any
, all
, sum
, product
-- * Building lists
, scanl
, scanl1
, scanr
, scanr1
, replicate
, cycle
, unfoldr
-- * Zipping lists
, zip
Expand All @@ -57,12 +74,17 @@ import Data.Bool.Linear
import Data.Unrestricted.Linear
import Data.Functor.Linear
import Data.Monoid.Linear
import Data.Num.Linear
import Data.List.NonEmpty (NonEmpty ((:|)))
import GHC.Stack
import qualified Data.List as NonLinear

-- # Basic functions
--------------------------------------------------

(++) :: [a] #-> [a] #-> [a]
(++) = Unsafe.toLinear2 (NonLinear.++)

map :: (a #-> b) -> [a] #-> [b]
map = fmap

Expand All @@ -78,10 +100,30 @@ filter p (x:xs) =
then x'' : filter p xs
else x'' `lseq` filter p xs

-- | __NOTE__: This does not short-circuit and always traverses the
-- entire array to consume the rest of the elements.
head :: (HasCallStack, Consumable a) => [a] #-> a
head [] = Prelude.error "head: empty list"
head (x:xs) = xs `lseq` x

uncons :: [a] #-> Maybe (a, [a])
uncons [] = Nothing
uncons (x:xs) = Just (x, xs)

tail :: (HasCallStack, Consumable a) => [a] #-> [a]
tail [] = Prelude.error "tail: empty list"
tail (x:xs) = x `lseq` xs

last :: (HasCallStack, Consumable a) => [a] #-> a
last [] = Prelude.error "last: empty list"
last [x] = x
last (x:xs) = x `lseq` last xs

init :: (HasCallStack, Consumable a) => [a] #-> [a]
init [] = Prelude.error "init: empty list"
init [x] = x `lseq` []
init (x:xs) = x : init xs

reverse :: [a] #-> [a]
reverse = Unsafe.toLinear NonLinear.reverse

Expand All @@ -92,6 +134,16 @@ length = Unsafe.toLinear $ \xs ->
-- We can only do this because of the fact that 'NonLinear.length'
-- does not inspect the elements.

-- | __NOTE__: This does not short-circuit and always traverses the
-- entire array to consume the rest of the elements.
find :: Dupable a => (a #-> Bool) -> [a] #-> Maybe a
find _ [] = Nothing
find p (x:xs) =
dup x & \(x', x'') ->
if p x'
then xs `lseq` Just x''
else x'' `lseq` find p xs

-- 'splitAt' @n xs@ returns a tuple where first element is @xs@ prefix of
-- length @n@ and second element is the remainder of the list.
splitAt :: Int -> [a] #-> ([a], [a])
Expand All @@ -108,6 +160,37 @@ span f (x:xs) = dup x & \case
then span f xs & \case (ts, fs) -> (x'':ts, fs)
else ([x''], xs)

-- The partition function takes a predicate a list and returns the
-- pair of lists of elements which do and do not satisfy the predicate,
-- respectively.
partition :: Dupable a => (a #-> Bool) -> [a] #-> ([a], [a])
partition p (xs :: [a]) = foldr select ([], []) xs
where
select :: a #-> ([a], [a]) #-> ([a], [a])
select x (ts, fs) =
dup2 x & \(x', x'') ->
if p x'
then (x'':ts, fs)
else (ts, x'':fs)

-- | __NOTE__: This does not short-circuit and always traverses the
-- entire array to consume the rest of the elements.
takeWhile :: Dupable a => (a #-> Bool) -> [a] #-> [a]
takeWhile _ [] = []
takeWhile p (x:xs) =
dup2 x & \(x', x'') ->
if p x'
then x'' : takeWhile p xs
else (x'', xs) `lseq` []

dropWhile :: Dupable a => (a #-> Bool) -> [a] #-> [a]
dropWhile _ [] = []
dropWhile p (x:xs) =
dup2 x & \(x', x'') ->
if p x'
then x'' `lseq` dropWhile p xs
else x'' : xs

-- | The intersperse function takes an element and a list and
-- `intersperses' that element between the elements of the list.
intersperse :: a -> [a] #-> [a]
Expand All @@ -129,12 +212,21 @@ transpose = Unsafe.toLinear NonLinear.transpose
foldr :: (a #-> b #-> b) -> b #-> [a] #-> b
foldr f = Unsafe.toLinear2 (NonLinear.foldr (\a b -> f a b))

foldr1 :: HasCallStack => (a #-> a #-> a) -> [a] #-> a
foldr1 f = Unsafe.toLinear (NonLinear.foldr1 (\a b -> f a b))

foldl :: (b #-> a #-> b) -> b #-> [a] #-> b
foldl f = Unsafe.toLinear2 (NonLinear.foldl (\b a -> f b a))

foldl' :: (b #-> a #-> b) -> b #-> [a] #-> b
foldl' f = Unsafe.toLinear2 (NonLinear.foldl' (\b a -> f b a))

foldl1 :: HasCallStack => (a #-> a #-> a) -> [a] #-> a
foldl1 f = Unsafe.toLinear (NonLinear.foldl1 (\a b -> f a b))

foldl1' :: HasCallStack => (a #-> a #-> a) -> [a] #-> a
foldl1' f = Unsafe.toLinear (NonLinear.foldl1' (\a b -> f a b))

-- | Map each element of the structure to a monoid,
-- and combine the results.
foldMap :: Monoid m => (a #-> m) -> [a] #-> m
Expand All @@ -150,6 +242,12 @@ concat = Unsafe.toLinear NonLinear.concat
concatMap :: (a #-> [b]) -> [a] #-> [b]
concatMap f = Unsafe.toLinear (NonLinear.concatMap (forget f))

sum :: AddIdentity a => [a] #-> a
sum = foldl' (+) zero

product :: MultIdentity a => [a] #-> a
product = foldl' (*) one

-- | __NOTE:__ This does not short-circuit, and always consumes the
-- entire container.
any :: (a #-> Bool) -> [a] #-> Bool
Expand All @@ -173,10 +271,40 @@ or = foldl' (||) False
-- # Building Lists
--------------------------------------------------

iterate :: Dupable a => (a #-> a) -> a #-> [a]
iterate f a = dup2 a & \(a', a'') ->
a' : iterate f (f a'')

repeat :: Dupable a => a #-> [a]
repeat = iterate id

cycle :: (HasCallStack, Dupable a) => [a] #-> [a]
cycle [] = Prelude.error "cycle: empty list"
cycle xs = dup2 xs & \(xs', xs'') -> xs' ++ cycle xs''

scanl :: Dupable b => (b #-> a #-> b) -> b #-> [a] #-> [b]
scanl _ b [] = [b]
scanl f b (x:xs) = dup2 b & \(b', b'') -> b' : scanl f (f b'' x) xs

scanl1 :: Dupable a => (a #-> a #-> a) -> [a] #-> [a]
scanl1 _ [] = []
scanl1 f (x:xs) = scanl f x xs

scanr :: Dupable b => (a #-> b #-> b) -> b #-> [a] #-> [b]
scanr _ b [] = [b]
scanr f b (a:as) =
scanr f b as & \(b':bs') ->
dup2 b' & \(b'', b''') ->
f a b'' : b''' : bs'

scanr1 :: Dupable a => (a #-> a #-> a) -> [a] #-> [a]
scanr1 _ [] = []
scanr1 _ [a] = [a]
scanr1 f (a:as) =
scanr1 f as & \(a':as') ->
dup2 a' & \(a'', a''') ->
f a a'' : a''' : as'

replicate :: Dupable a => Int -> a #-> [a]
replicate i a
| i Prelude.< 1 = a `lseq` []
Expand Down

0 comments on commit 9e8f2b7

Please sign in to comment.