title | permalink |
---|---|
Functional Programming in Java |
/13ln-fp2/ |
Let's start with a recap of the three main concepts of last week's lecture. Working on lists (and later streams), we defined three main methods:

- `filter`, utilizing a `Predicate<T>` to retain only certain elements,
- `map`, utilizing a `Function<T, R>` to transform a list of elements of type `T` into a list of type `R`, and finally
- `forEach`, utilizing a `Consumer<T>` that accepts each element (in order of the list).
Working on lists, we defined those recursively:

```java
static <T> List<T> filter(List<T> xs, Predicate<T> p) {
    if (xs.isEmpty()) return xs;
    else if (p.test(xs.head)) return list(xs.head, filter(xs.tail, p));
    else return filter(xs.tail, p);
}

static <A, B> List<B> map(List<A> xs, Function<A, B> f) {
    if (xs.isEmpty()) return empty();
    else return list(f.apply(xs.head), map(xs.tail, f));
}

static <A> void forEach(List<A> xs, Consumer<A> c) {
    if (xs.isEmpty()) return;
    else {
        c.accept(xs.head);
        forEach(xs.tail, c);
        // return (added for clarity)
    }
}
```
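The snippets above rely on the immutable list type from last week's lecture, which is not repeated here. To try them out on their own, here is a minimal sketch of such a type. It assumes a hypothetical `ConsList` class with public `head`/`tail` fields and static `list`/`empty` factories; the names mirror the calls above but are not the lecture's actual code.

```java
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

// Hypothetical minimal immutable singly linked list, mirroring the
// head/tail/isEmpty/list/empty calls used in the lecture snippets.
final class ConsList<T> {
    final T head;
    final ConsList<T> tail;

    private ConsList(T head, ConsList<T> tail) {
        this.head = head;
        this.tail = tail;
    }

    // Shared sentinel object representing the empty list.
    private static final ConsList<?> EMPTY = new ConsList<>(null, null);

    @SuppressWarnings("unchecked")
    static <T> ConsList<T> empty() {
        return (ConsList<T>) EMPTY;
    }

    static <T> ConsList<T> list(T head, ConsList<T> tail) {
        return new ConsList<>(head, tail);
    }

    boolean isEmpty() {
        return this == EMPTY;
    }

    static <T> ConsList<T> filter(ConsList<T> xs, Predicate<T> p) {
        if (xs.isEmpty()) return xs;
        else if (p.test(xs.head)) return list(xs.head, filter(xs.tail, p));
        else return filter(xs.tail, p);
    }

    static <A, B> ConsList<B> map(ConsList<A> xs, Function<A, B> f) {
        if (xs.isEmpty()) return empty();
        else return list(f.apply(xs.head), map(xs.tail, f));
    }

    static <A> void forEach(ConsList<A> xs, Consumer<A> c) {
        if (xs.isEmpty()) return;
        c.accept(xs.head);
        forEach(xs.tail, c);
    }
}
```

A list is built from the back, e.g. `ConsList.list(1, ConsList.list(3, ConsList.<Integer>empty()))` for the elements 1 and 3.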
Here are three key observations:

- All three methods "iterate" over the list, i.e. all elements are visited.
- The `forEach` method is tail recursive, meaning the recursive call is the very last operation before `return`.
- The `filter` and `map` methods return another list, while `forEach` returns nothing (`void`).
In this (final) chapter, we'll talk about list (or stream) reduction, that is reducing a sequence of values to a single value.
Let's start with a simple example: sum all numbers of a list.
```java
static int sum(List<Integer> xs) {
    if (xs.isEmpty()) return 0;  // sum of an empty list is zero
    else return xs.head + sum(xs.tail);
}
```

For `list(1, 3, 3, 7)`, this function evaluates to

```
sum(list(1, 3, 3, 7))
-> 1 + sum(list(3, 3, 7))
-> 1 + (3 + sum(list(3, 7)))
-> 1 + (3 + (3 + sum(list(7))))
-> 1 + (3 + (3 + (7 + sum(empty()))))
-> 1 + (3 + (3 + (7 + 0)))
-> 1 + (3 + (3 + 7))
-> 1 + (3 + 10)
-> 1 + 13
-> 14
```
As you can see, the recursion expands until the base case is reached and the first `return` happens.
Then the additions are performed all the way back up the call stack.
From last week, we know that this is unfavorable: the recursion depth equals the number of list elements. Here is a better, tail-recursive variant:
```java
static int sum(List<Integer> xs, int z) {
    if (xs.isEmpty()) return z;
    else return sum(xs.tail, z + xs.head);
}
```

which evaluates to

```
sum(list(1, 3, 3, 7), 0)
-> sum(list(3, 3, 7), 0 + 1)
-> sum(list(3, 7), 1 + 3)
-> sum(list(7), 4 + 3)
-> sum(empty(), 7 + 7)
-> 14
```
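As mentioned last week, tail-recursive calls are much more efficient: depending on the language, the compiler can turn them into a loop that reuses the current stack frame (tail-call elimination). Note that Java itself does not perform this optimization, so in Java the recursion depth still grows with the length of the list.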
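To make this concrete, here is a sketch of what eliminating the tail call from the accumulator version of `sum` amounts to: instead of calling itself, the method overwrites `xs` and `z` and loops. For a self-contained example it uses `java.util.List`, with `subList(1, size)` standing in for `xs.tail`; this is an assumption, not the lecture's list type, and the class name is made up.

```java
import java.util.List;

class TailSum {
    // Tail-recursive version: the recursive call is the last thing that happens.
    static int sumRec(List<Integer> xs, int z) {
        if (xs.isEmpty()) return z;
        else return sumRec(xs.subList(1, xs.size()), z + xs.get(0));
    }

    // Hand-written loop equivalent: the "parameters" are updated in place
    // and control jumps back to the start instead of making a new call.
    static int sumLoop(List<Integer> xs, int z) {
        while (!xs.isEmpty()) {
            z = z + xs.get(0);
            xs = xs.subList(1, xs.size());
        }
        return z;
    }
}
```

Both return 14 for `List.of(1, 3, 3, 7)` and `z = 0`; only the loop version runs in constant stack space in Java.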
Let's consider another example: joining Strings together by concatenating them.
```java
static String join(List<String> xs, String z) {
    if (xs.isEmpty()) return z;
    else return join(xs.tail, z + xs.head);
}
```
Clearly, the `sum` and `join` functions are almost identical: the only difference is the types, `Integer` versus `String`.
So why not generalize?
```java
static <T> T reduce(List<T> xs, T z) {
    if (xs.isEmpty()) return z;
    else return reduce(xs.tail, z + xs.head);  // oops :-(
}
```
Unfortunately, the `+` operator is only defined for the primitive numeric types (and their wrappers, via unboxing) and `java.lang.String`, and Java does not support operator overloading, so it cannot be applied to an arbitrary type `T`.
But look closer at what `+` actually is: a binary operation that combines two values into a single value of the same type.
Both `String` and `Integer` actually offer such methods:
```java
static int reduce(List<Integer> xs, int z) {
    if (xs.isEmpty()) return z;
    else return reduce(xs.tail, Integer.sum(z, xs.head));
}
```

```java
import org.apache.commons.lang3.StringUtils;

static String reduce(List<String> xs, String z) {
    if (xs.isEmpty()) return z;
    else return reduce(xs.tail, StringUtils.join(z, xs.head));
}
```
Let's isolate the operation, using the interface `java.util.function.BinaryOperator<T>`:

```java
// simplified: the actual BinaryOperator<T> extends BiFunction<T, T, T>,
// from which it inherits the apply method
interface BinaryOperator<T> {
    T apply(T t1, T t2);
}

static <T> T reduce(List<T> xs, T z, BinaryOperator<T> op) {
    if (xs.isEmpty()) return z;
    else return reduce(xs.tail, op.apply(z, xs.head), op);
}

reduce(list(1, 3, 3, 7), 0, (i, j) -> Integer.sum(i, j));     // 14
reduce(list(1, 3, 3, 7), 0, Integer::sum);                    // 14
reduce(list("a", "b", "c", "d"), "", (a, b) -> a.concat(b));  // abcd
reduce(list("a", "b", "c", "d"), "", String::concat);         // abcd
```
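A self-contained sketch of the generic `reduce`, again using `java.util.List` with `subList(1, size)` in place of `xs.tail` (an assumption; the lecture uses its own list type). Any `BinaryOperator` can be plugged in, not just addition; `Integer::max` is added here as a third illustration.

```java
import java.util.List;
import java.util.function.BinaryOperator;

class Reduce {
    // Tail-recursive reduction: combine the accumulator z with the head,
    // then continue with the rest of the list.
    static <T> T reduce(List<T> xs, T z, BinaryOperator<T> op) {
        if (xs.isEmpty()) return z;
        else return reduce(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }
}
```

For example, `Reduce.reduce(List.of(1, 3, 3, 7), 0, Integer::sum)` yields 14, `Reduce.reduce(List.of("a", "b", "c", "d"), "", String::concat)` yields `"abcd"`, and `Reduce.reduce(List.of(1, 3, 3, 7), Integer.MIN_VALUE, Integer::max)` yields 7.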
It may sound odd, but `forEach` is actually a special case of `reduce`:

```java
// prints 1, 3, 3, 7 (one per line); the accumulated value is ignored
reduce(list(1, 3, 3, 7), 0, (i, j) -> { System.out.println(j); return j; });
```
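To see that this trick really visits every element in order, the sketch below wraps it into a `forEach` built on top of `reduce`. As before, `java.util.List` and `subList` stand in for the lecture's list type, and the method name `forEachViaReduce` is hypothetical. The operator ignores the accumulator, performs the side effect, and passes the current element along; the final result of `reduce` is discarded.

```java
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;

class ForEachViaReduce {
    static <T> T reduce(List<T> xs, T z, BinaryOperator<T> op) {
        if (xs.isEmpty()) return z;
        else return reduce(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }

    // forEach expressed as a reduction: the operator consumes the current
    // element and returns it as the (irrelevant) new accumulator.
    static <T> void forEachViaReduce(List<T> xs, Consumer<T> c) {
        reduce(xs, null, (acc, x) -> { c.accept(x); return x; });
    }
}
```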
The `reduce` function derived above is a bit restricted: it can only reduce elements of type `T` to a result of type `T`.
This might be a problem: consider summing up a very long list of potentially large `Integer`s; you may run into an overflow.
The solution would be to add the `Integer`s from the list to a `BigInteger`, which has arbitrary precision.
In terms of a `for`-loop, this would be
```java
BigInteger sum = BigInteger.ZERO;
for (Integer i : xs) {
    sum = sum.add(BigInteger.valueOf(i));
}
```
So our hypothetical `reduce` function for this would be

```java
static BigInteger reduce(List<Integer> xs, BigInteger z) {
    if (xs.isEmpty()) return z;
    else return reduce(xs.tail, z.add(BigInteger.valueOf(xs.head)));
}

reduce(list(1, 3, 3, 7), BigInteger.ZERO);
```
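The overflow problem is easy to reproduce: adding `Integer.MAX_VALUE` to itself wraps around in `int` arithmetic, while the `BigInteger` accumulator gives the exact result. This sketch compares both reductions; it uses `java.util.List` in place of the lecture's list type, and the class and method names are made up for illustration.

```java
import java.math.BigInteger;
import java.util.List;

class OverflowDemo {
    // int accumulator: silently wraps around on overflow.
    static int sumInt(List<Integer> xs, int z) {
        if (xs.isEmpty()) return z;
        else return sumInt(xs.subList(1, xs.size()), z + xs.get(0));
    }

    // BigInteger accumulator: arbitrary precision, no overflow.
    static BigInteger sumBig(List<Integer> xs, BigInteger z) {
        if (xs.isEmpty()) return z;
        else return sumBig(xs.subList(1, xs.size()), z.add(BigInteger.valueOf(xs.get(0))));
    }
}
```

For the list `[Integer.MAX_VALUE, Integer.MAX_VALUE]`, `sumInt` returns -2, whereas `sumBig` returns 4294967294.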
By now, you probably already guessed it: we'll isolate the actual operation!
We need a function that takes a `BigInteger` (the accumulator) and an `Integer`, adds them, and returns a `BigInteger`.
We'll do so with the interface `java.util.function.BiFunction<T, U, R>`, tying the accumulator type to the result type (i.e. `BiFunction<R, T, R>`), and call the resulting function `foldl` (read: fold left).
```java
static <T, R> R foldl(List<T> xs, R z, BiFunction<R, T, R> op) {
    if (xs.isEmpty()) return z;
    else return foldl(xs.tail, op.apply(z, xs.head), op);
}

foldl(xs, BigInteger.ZERO, (b, i) -> b.add(BigInteger.valueOf(i)));
```
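Here is `foldl` as a standalone sketch, once more with `java.util.List` and `subList` in place of the lecture's list type (an assumption). Because the accumulator type `R` is independent of the element type `T`, the same function can sum `Integer`s into a `BigInteger` or add up the lengths of `String`s into an `int`; `bigSum` and `totalLength` are hypothetical helper names.

```java
import java.math.BigInteger;
import java.util.List;
import java.util.function.BiFunction;

class FoldLeft {
    // Left fold: op(op(op(z, x0), x1), x2) ..., evaluated front to back.
    static <T, R> R foldl(List<T> xs, R z, BiFunction<R, T, R> op) {
        if (xs.isEmpty()) return z;
        else return foldl(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }

    // Integer elements, BigInteger accumulator.
    static BigInteger bigSum(List<Integer> xs) {
        return foldl(xs, BigInteger.ZERO, (b, i) -> b.add(BigInteger.valueOf(i)));
    }

    // String elements, Integer accumulator.
    static int totalLength(List<String> xs) {
        return foldl(xs, 0, (n, s) -> n + s.length());
    }
}
```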
The function is called left fold because the list is folded to the left, as the evaluation shows:

```
foldl(list(1, 3, 3, 7), 0)
-> foldl(list(3, 3, 7), 0+1)
-> foldl(list(3, 7), 1+3)
-> foldl(list(7), 4+3)
-> foldl(empty(), 7+7)
-> 14
```

Visualized as a tree, the operations are performed in this order:
```
// start at the bottom left!
            op
           /  \
         op    7
        /  \
      op    3
     /  \
   op    3
  /  \
 z    1
```
Look at that tree again: doesn't it look oddly familiar?
If we define `z` as the empty list and `op` as the list constructor (with its arguments swapped, since the accumulator comes first), you end up with the reverse of the original list:

```java
foldl(list(1, 3, 3, 7), List.<Integer>empty(),
    (xs, x) -> list(x, xs));  // 7, 3, 3, 1
```
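The reversal trick can be tried out without the lecture's list type as well. This sketch assumes `java.util.List` and a hypothetical `prepend` helper that copies the list to put an element in front (O(n) per step, only meant to mirror the list constructor); folding from the empty list with `prepend` yields the reversed list.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

class ReverseViaFoldl {
    static <T, R> R foldl(List<T> xs, R z, BiFunction<R, T, R> op) {
        if (xs.isEmpty()) return z;
        else return foldl(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }

    // Hypothetical helper mimicking list(x, xs): a new list with x in front.
    static <T> List<T> prepend(T x, List<T> xs) {
        List<T> result = new ArrayList<>();
        result.add(x);
        result.addAll(xs);
        return result;
    }

    static <T> List<T> reverse(List<T> xs) {
        List<T> start = new ArrayList<>();
        return foldl(xs, start, (acc, x) -> prepend(x, acc));
    }
}
```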
Let's go back to the original, non-tail-recursive definition of `sum`, and its `BigInteger` counterpart:

```java
static int sum(List<Integer> xs) {
    if (xs.isEmpty()) return 0;  // sum of an empty list is zero
    else return xs.head + sum(xs.tail);
}

static BigInteger sum(List<Integer> xs, BigInteger z) {
    if (xs.isEmpty()) return z;
    else return BigInteger.valueOf(xs.head).add(sum(xs.tail, z));
}
```
If you isolate the operation (`+` or `.add()`, respectively), you end up with a right fold:

```java
static <T, R> R foldr(List<T> xs, R z, BiFunction<T, R, R> op) {
    if (xs.isEmpty()) return z;
    else return op.apply(xs.head, foldr(xs.tail, z, op));
}

foldr(list(1, 3, 3, 7), BigInteger.ZERO,
    (i, b) -> BigInteger.valueOf(i).add(b));  // 14
```
Again, look at the order of operations:
```
  op
 /  \
1    op
    /  \
   3    op
       /  \
      3    op
          /  \
         7    z
```
To complete the topmost operation, you need to descend all the way down the fold.
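The difference between the two trees only becomes visible when the operation is not associative. Subtraction is a handy test case: the left fold computes `((0 - 1) - 2) - 3 = -6`, while the right fold computes `1 - (2 - (3 - 0)) = 2`. The sketch below implements both folds side by side over `java.util.List` (an assumption, as in the earlier sketches).

```java
import java.util.List;
import java.util.function.BiFunction;

class LeftVsRight {
    // Left fold: accumulator on the left, tail-recursive.
    static <T, R> R foldl(List<T> xs, R z, BiFunction<R, T, R> op) {
        if (xs.isEmpty()) return z;
        else return foldl(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }

    // Right fold: the head is combined with the folded rest, so the
    // recursive call must finish first (not tail-recursive).
    static <T, R> R foldr(List<T> xs, R z, BiFunction<T, R, R> op) {
        if (xs.isEmpty()) return z;
        else return op.apply(xs.get(0), foldr(xs.subList(1, xs.size()), z, op));
    }
}
```

With `LeftVsRight.foldl(List.of(1, 2, 3), 0, (acc, x) -> acc - x)` the result is -6; with `LeftVsRight.foldr(List.of(1, 2, 3), 0, (x, acc) -> x - acc)` it is 2.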
Again, does that look familiar?
If we define `z` as a list and `op` as the list constructor, we end up with `append`:
```java
// assuming xs = list(1, 3, 3, 7)
foldr(xs, List.<Integer>list(49), (x, ys) -> list(x, ys));
// 1, 3, 3, 7, 49
```

If we add in some logic, we get `map`:

```java
foldr(xs, List.<Integer>empty(), (x, ys) -> list(x * x, ys));
// squares: 1, 9, 9, 49
```

And even `filter`:

```java
// drop all values less than 5
foldr(xs, List.<Integer>empty(), (x, ys) -> {
    if (x < 5) return ys;
    else return list(x, ys);
});
// 7
```
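Generalizing the last two examples, `map` and `filter` can each be written as a single `foldr` call. A sketch over `java.util.List`, with a hypothetical copying `prepend` standing in for the list constructor (both assumptions, not the lecture's code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;

class FoldrListOps {
    static <T, R> R foldr(List<T> xs, R z, BiFunction<T, R, R> op) {
        if (xs.isEmpty()) return z;
        else return op.apply(xs.get(0), foldr(xs.subList(1, xs.size()), z, op));
    }

    // Hypothetical helper mimicking list(x, xs).
    static <T> List<T> prepend(T x, List<T> xs) {
        List<T> result = new ArrayList<>();
        result.add(x);
        result.addAll(xs);
        return result;
    }

    // map: transform the head, then put it in front of the mapped rest.
    static <T, R> List<R> map(List<T> xs, Function<T, R> f) {
        List<R> empty = new ArrayList<>();
        return foldr(xs, empty, (x, ys) -> prepend(f.apply(x), ys));
    }

    // filter: keep the head only if it passes the predicate.
    static <T> List<T> filter(List<T> xs, Predicate<T> p) {
        List<T> empty = new ArrayList<>();
        return foldr(xs, empty, (x, ys) -> p.test(x) ? prepend(x, ys) : ys);
    }
}
```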
Unfortunately, the right fold is not tail-recursive: every element adds a stack frame, which makes it an undesirable operation for long lists.
The trick is to apply a left fold twice: first, we use `foldl` to reverse the list; then we use it again to reverse it back to its original order, applying the mapping function along the way:
```java
static <T, R> List<R> maptr(List<T> xs, Function<T, R> op) {
    List<T> reverse = foldl(xs, empty(), (ys, y) -> list(y, ys));
    List<R> mapped = foldl(reverse, empty(), (ys, y) -> list(op.apply(y), ys));
    return mapped;
}
```
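A runnable version of `maptr`, under the same assumptions as before (`java.util.List`, a hypothetical copying `prepend` in place of `list(x, xs)`), confirms that the double reversal restores the original order while both passes are left folds:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

class MapViaTwoFoldls {
    static <T, R> R foldl(List<T> xs, R z, BiFunction<R, T, R> op) {
        if (xs.isEmpty()) return z;
        else return foldl(xs.subList(1, xs.size()), op.apply(z, xs.get(0)), op);
    }

    // Hypothetical helper mimicking list(x, xs).
    static <T> List<T> prepend(T x, List<T> xs) {
        List<T> result = new ArrayList<>();
        result.add(x);
        result.addAll(xs);
        return result;
    }

    static <T, R> List<R> maptr(List<T> xs, Function<T, R> op) {
        // Pass 1: reverse the input.
        List<T> start1 = new ArrayList<>();
        List<T> reverse = foldl(xs, start1, (ys, y) -> prepend(y, ys));
        // Pass 2: reverse again, applying op to each element on the way.
        List<R> start2 = new ArrayList<>();
        return foldl(reverse, start2, (ys, y) -> prepend(op.apply(y), ys));
    }
}
```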
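The Java docs have a nice summary of stream operations (in the `java.util.stream` package documentation). The main distinction is between intermediate operations, which return a new stream, and terminal operations, which produce a result (a single value or a collection) or a side effect (as with `forEach`).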
You already know most of the intermediate operations:

- `filter(Predicate<T> p)` removes/skips unwanted elements in the stream
- `map(Function<T, R> f)` transforms a `Stream<T>` into a `Stream<R>` using the provided `Function`
- `sorted(Comparator<T> comp)` returns a sorted stream
- `Stream.concat(Stream<T> a, Stream<T> b)` (a static method) appends one stream to another
- `distinct()` removes duplicates
- `skip(long n)` and `limit(long n)` skip elements and truncate the stream, respectively
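A small pipeline chaining several of these operations; the input numbers are arbitrary, and the comments trace the stream contents after each step. Note that `concat` is called on `Stream` itself, and that nothing is computed until the terminal `collect`.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class IntermediateOps {
    static List<Integer> pipeline() {
        Stream<Integer> a = Stream.of(5, 3, 8, 3);
        Stream<Integer> b = Stream.of(1, 8, 6);
        return Stream.concat(a, b)                  // 5, 3, 8, 3, 1, 8, 6
                .filter(x -> x > 1)                 // 5, 3, 8, 3, 8, 6
                .distinct()                         // 5, 3, 8, 6
                .map(x -> x * 10)                   // 50, 30, 80, 60
                .sorted(Comparator.reverseOrder())  // 80, 60, 50, 30
                .skip(1)                            // 60, 50, 30
                .limit(2)                           // 60, 50
                .collect(Collectors.toList());      // terminal operation
    }
}
```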
Another notable intermediate operation is `flatMap`, which transforms a stream of sequences (lists, streams, etc.) into a single flat stream.

```java
// list-of-lists
Stream<List<Integer>> lol = Stream.of(
    Arrays.asList(1, 2),
    Arrays.asList(3, 4),
    Arrays.asList(5)
);

Stream<Integer> integerStream = lol.flatMap(al -> al.stream());
integerStream.forEach(System.out::print);  // 12345
```
Last week, we already talked about `forEach(Consumer<T> c)`, which can be used to iterate over the whole stream and pass each element to the `Consumer`.
This week, we learned about the `reduce` functions, which are implemented in Java as `reduce(T identity, BinaryOperator<T> op)` and the more generic `reduce(U identity, BiFunction<U, ? super T, U> op, BinaryOperator<U> combiner)`.
The combiner merges partial results, which is needed when a parallel stream reduces several chunks independently.
```java
Stream.of(1, 3, 3, 7).reduce(0, Integer::sum);
// 14

Stream.of(1, 3, 3, 7).reduce(BigInteger.ZERO,
    (bi, i) -> bi.add(BigInteger.valueOf(i)),  // accumulator: add one element to a partial sum
    (bi1, bi2) -> bi1.add(bi2));               // combiner: merge two partial sums
// 14
```
Nota Bene: The second operation can often be written more simply as a `map` followed by a `reduce`.
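For the `BigInteger` sum, this means converting each element first and then reducing with the two-argument form, so no separate combiner is needed. The sketch also runs the three-argument version on a parallel stream, which is where the combiner actually gets used to merge partial sums; the class and method names are made up.

```java
import java.math.BigInteger;
import java.util.stream.Stream;

class BigSum {
    // map to BigInteger first, then reduce with BigInteger::add.
    static BigInteger mapThenReduce() {
        return Stream.of(1, 3, 3, 7)
                .map(BigInteger::valueOf)
                .reduce(BigInteger.ZERO, BigInteger::add);
    }

    // Three-argument reduce: the accumulator folds elements into a partial sum,
    // the combiner merges partial sums computed for different chunks.
    static BigInteger threeArgParallel() {
        return Stream.of(1, 3, 3, 7)
                .parallel()
                .reduce(BigInteger.ZERO,
                        (bi, i) -> bi.add(BigInteger.valueOf(i)),
                        BigInteger::add);
    }
}
```

Both return 14. The identity (`BigInteger.ZERO`) must really be neutral for the combiner, since each parallel chunk starts from it.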
Another powerful tool provided by the Java Streams API is `collect`, which is a special form of stream reduction (a so-called mutable reduction).
The idea is to iterate over the stream and pass each element to an accumulator that builds up a data structure.
A classic example is to turn a `Stream` into a `List`:
```java
List<Integer> list1 = new LinkedList<>();
Stream.of(1, 3, 3, 7).forEach(i -> list1.add(i));

// or shorter, using collect
List<Integer> list2 = Stream.of(1, 3, 3, 7).collect(Collectors.toList());
```
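`Collectors.toList()` hides the moving parts of this mutable reduction. `Stream` also offers a three-argument `collect(supplier, accumulator, combiner)` that makes them explicit; a short sketch:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

class ExplicitCollect {
    static List<Integer> toList() {
        return Stream.of(1, 3, 3, 7).collect(
                ArrayList::new,      // supplier: creates an empty container
                ArrayList::add,      // accumulator: adds one element to a container
                ArrayList::addAll);  // combiner: merges two containers (parallel case)
    }
}
```

Unlike `reduce`, the accumulator here mutates the container in place instead of returning a new value.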
The `Collectors` class provides a lengthy list of collectors for your convenience.
Here are a few examples from the docs, most notably `groupingBy` and `partitioningBy`.
```java
// Accumulate names into a List
List<String> list = people.stream()
    .map(Person::getName)
    .collect(Collectors.toList());

// Accumulate names into a TreeSet
Set<String> set = people.stream()
    .map(Person::getName)
    .collect(Collectors.toCollection(TreeSet::new));

// Convert elements to strings and concatenate them, separated by commas
String joined = things.stream()
    .map(Object::toString)
    .collect(Collectors.joining(", "));

// Compute sum of salaries of employees
int total = employees.stream()
    .collect(Collectors.summingInt(Employee::getSalary));

// Group employees by department
Map<Department, List<Employee>> byDept = employees.stream()
    .collect(Collectors.groupingBy(Employee::getDepartment));

// Compute sum of salaries by department
Map<Department, Integer> totalByDept = employees.stream()
    .collect(Collectors.groupingBy(Employee::getDepartment,
        Collectors.summingInt(Employee::getSalary)));

// Partition students into passing and failing
Map<Boolean, List<Student>> passingFailing = students.stream()
    .collect(Collectors.partitioningBy(s -> s.getGrade() <= 400));
```
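The examples above assume `Person`, `Employee` and `Student` classes that are not defined here. For something that runs on its own, this sketch groups and partitions plain `String`s instead; the word list and method names are made up for illustration.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class GroupingDemo {
    static final List<String> WORDS = List.of("fold", "map", "filter", "zip", "reduce", "sum");

    // Group words by length: length -> words of that length (in encounter order).
    static Map<Integer, List<String>> byLength() {
        return WORDS.stream().collect(Collectors.groupingBy(String::length));
    }

    // Count words per length using a downstream collector.
    static Map<Integer, Long> countByLength() {
        return WORDS.stream().collect(Collectors.groupingBy(String::length, Collectors.counting()));
    }

    // Split into words shorter than 4 characters (true) and the rest (false).
    static Map<Boolean, List<String>> shortOrLong() {
        return WORDS.stream().collect(Collectors.partitioningBy(w -> w.length() < 4));
    }
}
```

Here `byLength()` maps 3 to `[map, zip, sum]`, 4 to `[fold]` and 6 to `[filter, reduce]`; `partitioningBy` always yields exactly the two keys `true` and `false`.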
Often you need to find certain values in a stream, e.g. with `findFirst()`, `min(comparator)` or `max(comparator)`.
Since the stream may be empty, these methods return an `Optional`.
Optionals are somewhat similar to futures, in that you can `get()` the content if it `isPresent()`.
They can also be mapped to another `Optional` with `map()`, or turned into a stream with `.stream()` (since Java 9).
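A short sketch of these terminal operations and the typical ways of consuming the resulting `Optional` (`isPresent`/`get`, `map`, `orElse`). Note how an empty stream yields an empty `Optional` instead of throwing; the method names are hypothetical.

```java
import java.util.Comparator;
import java.util.Optional;
import java.util.stream.Stream;

class OptionalDemo {
    // Empty Optional if the stream contains no even number.
    static Optional<Integer> firstEven(Stream<Integer> s) {
        return s.filter(x -> x % 2 == 0).findFirst();
    }

    static String describeMax(Stream<Integer> s) {
        return s.max(Comparator.naturalOrder())  // Optional<Integer>
                .map(m -> "max is " + m)         // Optional<String>
                .orElse("empty stream");         // fallback if no value is present
    }
}
```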
Another frequent use case is to verify whether all, any or none of the elements in a stream match a certain criterion.
Use the `allMatch`, `anyMatch` and `noneMatch` functions, which take a `Predicate<T>` as argument.
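A quick sketch of the three matching operations. They short-circuit, i.e. they stop as soon as the answer is known. On an empty stream, `allMatch` and `noneMatch` return `true`, while `anyMatch` returns `false`.

```java
import java.util.List;

class MatchDemo {
    static final List<Integer> XS = List.of(1, 3, 3, 7);

    static boolean allOdd()       { return XS.stream().allMatch(x -> x % 2 == 1); }  // true
    static boolean allAboveTwo()  { return XS.stream().allMatch(x -> x > 2); }       // false (1 fails)
    static boolean anyAboveFive() { return XS.stream().anyMatch(x -> x > 5); }       // true (7)
    static boolean noneNegative() { return XS.stream().noneMatch(x -> x < 0); }      // true

    // Empty stream: anyMatch is false even for an always-true predicate.
    static boolean anyOnEmpty()   { return List.<Integer>of().stream().anyMatch(x -> true); }  // false
}
```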
There is a separate document on parallel streams, but in short: use `parallelStream()` on a collection (or `parallel()` on an existing stream) to enable parallel processing.
For example, to group people by their gender, you can use
```java
Map<Person.Sex, List<Person>> byGender = allPeople
    .stream()
    .collect(Collectors.groupingBy(Person::getGender));

// or parallel
ConcurrentMap<Person.Sex, List<Person>> byGenderConcurrent = allPeople
    .parallelStream()
    .collect(Collectors.groupingByConcurrent(Person::getGender));
```
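A self-contained variant of the parallel grouping, using `String`s instead of the `Person` class (an assumption for illustration). With `groupingByConcurrent`, all threads insert into one shared `ConcurrentMap`, so the order of elements within a group is not guaranteed; the check below therefore sorts a group before comparing it.

```java
import java.util.List;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;

class ParallelGrouping {
    static final List<String> WORDS = List.of("fold", "map", "filter", "zip", "reduce", "sum");

    // Group words by length on a parallel stream into a concurrent map.
    static ConcurrentMap<Integer, List<String>> byLength() {
        return WORDS.parallelStream()
                .collect(Collectors.groupingByConcurrent(String::length));
    }
}
```

If the encounter order within each group matters, the sequential `groupingBy` is the safer choice.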
∎