Skip to content

Commit

Permalink
Refactor to use ExecutorService and BaseEncoding
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-philippe-martin committed May 5, 2016
1 parent 7d7212a commit 966679a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package com.google.cloud.examples.nio;

import com.google.common.base.Stopwatch;
import com.google.common.io.BaseEncoding;

import javax.xml.bind.annotation.adapters.HexBinaryAdapter;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
Expand All @@ -35,7 +35,7 @@
* <p>This example shows how to read a file size using NIO.
* File.size returns the size of the file as saved in Storage metadata.
* This class also shows how to read all of the file's contents using NIO,
* and reports how long it took.
* computes a MD5 hash, and reports how long it took.
*
* <p>See the README for compilation instructions. Run this code with
* {@code target/appassembler/bin/CountBytes <file>}
Expand Down Expand Up @@ -85,7 +85,7 @@ private static void countFile(String fname) {
long elapsed = sw.elapsed(TimeUnit.SECONDS);
System.out.println("Read all " + total + " bytes in " + elapsed + "s. " +
"(" + readCalls +" calls to chan.read)");
String hex = (new HexBinaryAdapter()).marshal(md.digest());
String hex = String.valueOf(BaseEncoding.base16().encode(md.digest()));
System.out.println("The MD5 is: 0x" + hex);
if (total != size) {
System.out.println("Wait, this doesn't match! We saw " + total + " bytes, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package com.google.cloud.examples.nio;

import com.google.common.base.Stopwatch;
import com.google.common.io.BaseEncoding;

import javax.xml.bind.annotation.adapters.HexBinaryAdapter;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
Expand All @@ -27,28 +27,58 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/**
* ParallelCountBytes will read through the whole file given as input.
*
* <p>This example shows how to go through all the contents of a file,
* in order, using multithreaded NIO reads.It also reports how long it took.
* in order, using multithreaded NIO reads.
* It prints a MD5 hash and reports how long it took.
*
* <p>See the README for compilation instructions. Run this code with
* {@code target/appassembler/bin/ParallelCountBytes <file>}
*/
public class ParallelCountBytes {

private class BufWithLock {
public Object lock;
public ByteBuffer buf;
public boolean full;
public Thread t;
/**
* WorkUnit holds a buffer and the instructions for what to put in it.
*/
private class WorkUnit implements Callable<WorkUnit> {
public final ByteBuffer buf;
final SeekableByteChannel chan;
final int blockSize;
int blockIndex;

public BufWithLock(int size) {
this.buf = ByteBuffer.allocate(size);
this.lock = new Object();
public WorkUnit(SeekableByteChannel chan, int blockSize, int blockIndex) {
this.chan = chan;
this.buf = ByteBuffer.allocate(blockSize);
this.blockSize = blockSize;
this.blockIndex = blockIndex;
}

@Override
public WorkUnit call() throws IOException {
int pos = blockSize * blockIndex;
if (pos > chan.size()) {
return this;
}
chan.position(pos);
// read until buffer is full, or EOF
while (chan.read(buf) > 0) {};
return this;
}

public WorkUnit resetForIndex(int blockIndex) {
this.blockIndex = blockIndex;
buf.flip();
return this;
}
}

Expand All @@ -69,37 +99,6 @@ public void start(String[] args) throws IOException {
}
}

private void stridedRead(SeekableByteChannel chan, int blockSize, int firstBlock, int stride, BufWithLock output) {
try {
// stagger the threads a little bit.
Thread.sleep(250 * firstBlock);
long pos = firstBlock * blockSize;
synchronized(output.lock) {
while (true) {
if (pos > chan.size()) {
break;
}
chan.position(pos);
// read until buffer is full, or EOF
while (chan.read(output.buf) > 0) {};
output.full = true;
output.lock.notifyAll();
if (output.buf.hasRemaining()) {
break;
}
// wait for main thread to process it
while (output.full) {
output.lock.wait();
}
output.buf.flip();
pos += stride * blockSize;
}
}
} catch (InterruptedException | IOException o) {
// this simple example doesn't handle errors, sorry.
}
}

/**
* Print the length of the indicated file.
*
Expand All @@ -109,49 +108,36 @@ private void stridedRead(SeekableByteChannel chan, int blockSize, int firstBlock
private void countFile(String fname) throws IOException{
// large buffers pay off
final int bufSize = 50 * 1024 * 1024;
Queue<Future<WorkUnit>> work = new ArrayDeque<>();
try {
Path path = Paths.get(new URI(fname));
long size = Files.size(path);
System.out.println(fname + ": " + size + " bytes.");
ByteBuffer buf = ByteBuffer.allocate(bufSize);
int nBlocks = (int)Math.ceil( size / (double)bufSize);
int nThreads = nBlocks;
int nThreads = (int) Math.ceil(size / (double) bufSize);
if (nThreads > 4) nThreads = 4;
System.out.println("Reading the whole file using " + nThreads + " threads...");
Stopwatch sw = Stopwatch.createStarted();
final BufWithLock[] bufs = new BufWithLock[nThreads];
for (int i = 0; i < nThreads; i++) {
bufs[i] = new BufWithLock(bufSize);
final SeekableByteChannel chan = Files.newByteChannel(path);
final int finalNThreads = nThreads;
final int finalI = i;
bufs[i].t = new Thread(new Runnable() {
@Override
public void run() {
stridedRead(chan, bufSize, finalI, finalNThreads, bufs[finalI]);
}
});
bufs[i].t.start();
}

long total = 0;
MessageDigest md = MessageDigest.getInstance("MD5");
for (int block = 0; block < nBlocks; block++) {
BufWithLock bwl = bufs[block % bufs.length];
synchronized (bwl.lock) {
while (!bwl.full) {
bwl.lock.wait();
}
md.update(bwl.buf.array(), 0, bwl.buf.position());
total += bwl.buf.position();
bwl.full = false;
bwl.lock.notifyAll();

ExecutorService exec = Executors.newFixedThreadPool(nThreads);
int blockIndex;
for (blockIndex = 0; blockIndex < nThreads; blockIndex++) {
work.add(exec.submit(new WorkUnit(Files.newByteChannel(path), bufSize, blockIndex)));
}
while (true) {
WorkUnit full = work.remove().get();
md.update(full.buf.array(), 0, full.buf.position());
total += full.buf.position();
if (full.buf.hasRemaining()) {
break;
}
work.add(exec.submit(full.resetForIndex(blockIndex++)));
}

long elapsed = sw.elapsed(TimeUnit.SECONDS);
System.out.println("Read all " + total + " bytes in " + elapsed + "s. ");
String hex = (new HexBinaryAdapter()).marshal(md.digest());
String hex = String.valueOf(BaseEncoding.base16().encode(md.digest()));
System.out.println("The MD5 is: 0x" + hex);
if (total != size) {
System.out.println("Wait, this doesn't match! We saw " + total + " bytes, " +
Expand Down

0 comments on commit 966679a

Please sign in to comment.