diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/ClientSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/ClientSpec.scala index 05a920f4cb2..2b01c87b27e 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/ClientSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/ClientSpec.scala @@ -6,49 +6,48 @@ import com.twitter.finagle.kestrel.protocol.Kestrel import com.twitter.finagle.memcached.util.ChannelBufferUtils._ import com.twitter.io.Charsets import com.twitter.util.Await -import org.specs.SpecificationWithJUnit import com.twitter.finagle.thrift.{ClientId, ThriftClientFramedCodec} -class ClientSpec extends SpecificationWithJUnit { - "ConnectedClient" should { - skip("This test requires a Kestrel server to run. Please run manually") - - "simple client" in { - val serviceFactory = ClientBuilder() - .hosts("localhost:22133") - .codec(Kestrel()) - .hostConnectionLimit(1) - .buildFactory() - val client = Client(serviceFactory) - - Await.result(client.delete("foo")) - - "set & get" in { - Await.result(client.get("foo")) mustEqual None - Await.result(client.set("foo", "bar")) - Await.result(client.get("foo")) map { _.toString(Charsets.Utf8) } mustEqual Some("bar") - } - } +import org.scalatest.junit.JUnitRunner +import org.junit.runner.RunWith +import org.scalatest.{BeforeAndAfter, FunSuite, Suites} + +@RunWith(classOf[JUnitRunner]) +class ClientTest extends Suites ( + new ConnectedClientTest, + new ThriftConnectedClientTest +) + +class ConnectedClientTest extends FunSuite { + val serviceFactory = ClientBuilder() + .hosts("localhost:22133") + .codec(Kestrel()) + .hostConnectionLimit(1) + .buildFactory() + val client = Client(serviceFactory) + + Await.result(client.delete("foo")) + + ignore("simple clientset & get") { + assert(Await.result(client.get("foo")) === None) + Await.result(client.set("foo", "bar")) + assert(Await.result(client.get("foo")).map(f => f.toString(Charsets.Utf8)) === Some("bar")) } +} + +class ThriftConnectedClientTest extends FunSuite { + val serviceFactory = ClientBuilder() + .hosts("localhost:2229") + .codec(ThriftClientFramedCodec(Some(ClientId("testcase")))) + .hostConnectionLimit(1) + .buildFactory() + val client = Client.makeThrift(serviceFactory) + + Await.result(client.delete("foo")) - "ThriftConnectedClient" should { - skip("This test requires a Kestrel server to run. Please run manually") - - "simple client" in { - val serviceFactory = ClientBuilder() - .hosts("localhost:2229") - .codec(ThriftClientFramedCodec(Some(ClientId("testcase")))) - .hostConnectionLimit(1) - .buildFactory() - val client = Client.makeThrift(serviceFactory) - - Await.result(client.delete("foo")) - - "set & get" in { - Await.result(client.get("foo")) mustEqual None - Await.result(client.set("foo", "bar")) - Await.result(client.get("foo")) map { _.toString(Charsets.Utf8) } mustEqual Some("bar") - } - } + ignore("set & get") { + assert(Await.result(client.get("foo")) === None) + Await.result(client.set("foo", "bar")) + assert(Await.result(client.get("foo")).map(f => f.toString(Charsets.Utf8)) == Some("bar")) } } diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/InterpreterServiceSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/InterpreterServiceSpec.scala index f0f037f7ef0..e482c5a4159 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/InterpreterServiceSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/integration/InterpreterServiceSpec.scala @@ -8,50 +8,50 @@ import com.twitter.finagle.kestrel.protocol._ import com.twitter.finagle.memcached.util.ChannelBufferUtils._ import com.twitter.util.{Await, Time} import java.net.InetSocketAddress -import org.specs.SpecificationWithJUnit - -class InterpreterServiceSpec extends SpecificationWithJUnit { - "InterpreterService" should { - var server: Server = null - var client: Service[Command, Response] = null - var address: InetSocketAddress = null - - doBefore { - server = new Server(new InetSocketAddress(0)) - address = server.start().localAddress.asInstanceOf[InetSocketAddress] - client = ClientBuilder() - .hosts("localhost:" + address.getPort) - .codec(Kestrel()) - .hostConnectionLimit(1) - .build() - } - - doAfter { - server.stop() - } - - val queueName = "name" - val value = "value" - - "set & get" in { - val result = for { - _ <- client(Flush(queueName)) - _ <- client(Set(queueName, Time.now, value)) - r <- client(Get(queueName)) - } yield r - Await.result(result, 1.second) mustEqual Values(Seq(Value(queueName, value))) - } - - "transactions" in { - "set & get/open & get/abort" in { - val result = for { - _ <- client(Set(queueName, Time.now, value)) - _ <- client(Open(queueName)) - _ <- client(Abort(queueName)) - r <- client(Open(queueName)) - } yield r - Await.result(result, 1.second) mustEqual Values(Seq(Value(queueName, value))) - } - } + +import org.scalatest.junit.JUnitRunner +import org.junit.runner.RunWith +import org.scalatest.{BeforeAndAfter, FunSuite} + +@RunWith(classOf[JUnitRunner]) +class InterpreterServiceTest extends FunSuite with BeforeAndAfter { + var server: Server = null + var client: Service[Command, Response] = null + var address: InetSocketAddress = null + + before { + server = new Server(new InetSocketAddress(0)) + address = server.start().localAddress.asInstanceOf[InetSocketAddress] + client = ClientBuilder() + .hosts("localhost:" + address.getPort) + .codec(Kestrel()) + .hostConnectionLimit(1) + .build() + } + + after { + server.stop() + } + + val queueName = "name" + val value = "value" + + test("set & get") { + val result = for { + _ <- client(Flush(queueName)) + _ <- client(Set(queueName, Time.now, value)) + r <- client(Get(queueName)) + } yield r + assert(Await.result(result, 1.second) === Values(Seq(Value(queueName, value)))) + } + + test("transactions - set & get/open & get/abort") { + val result = for { + _ <- client(Set(queueName, Time.now, value)) + _ <- client(Open(queueName)) + _ <- client(Abort(queueName)) + r <- client(Open(queueName)) + } yield r + assert(Await.result(result, 1.second) === Values(Seq(Value(queueName, value)))) } } diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ClientSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ClientSpec.scala index 9cf07c781af..e53fdfe1cdf 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ClientSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ClientSpec.scala @@ -1,8 +1,11 @@ package com.twitter.finagle.kestrel package unit -import org.specs.SpecificationWithJUnit -import org.specs.mock.Mockito +import org.junit.runner.RunWith +import org.scalatest.{FunSuite, Suites} +import org.scalatest.junit.JUnitRunner +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito.{verify, times, when} import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} import com.twitter.util.{Await, Future, Duration, Time, MockTimer, Promise} @@ -16,6 +19,13 @@ import com.twitter.finagle.memcached.util.ChannelBufferUtils._ import com.twitter.finagle.thrift.ThriftClientRequest import com.twitter.finagle.kestrel.net.lag.kestrel.thriftscala.Item +@RunWith(classOf[JUnitRunner]) +class ClientTest extends Suites( + new ClientReadReliablyTest, + new ConnectedClientReadTest, + new ThriftConnectedClientReadTest +) + // all this so we can spy() on a client. class MockClient extends Client { def set(queueName: String, value: ChannelBuffer, expiry: Time = Time.epoch) = null @@ -27,149 +37,149 @@ class MockClient extends Client { def close() {} } -class ClientSpec extends SpecificationWithJUnit with Mockito { +class ClientReadReliablyTest extends FunSuite with MockitoSugar { def buf(i: Int) = ChannelBuffers.wrappedBuffer("%d".format(i).getBytes) def msg(i: Int) = { val m = mock[ReadMessage] - m.bytes returns buf(i) + when(m.bytes).thenReturn(buf(i)) m } - "Client.readReliably" should { - val messages = new Broker[ReadMessage] - val error = new Broker[Throwable] - val client = spy(new MockClient) - val rh = mock[ReadHandle] - rh.messages returns messages.recv - rh.error returns error.recv - client.read("foo") returns rh - - "proxy messages" in { - val h = client.readReliably("foo") - there was one(client).read("foo") + val messages = new Broker[ReadMessage] + val error = new Broker[Throwable] + val client = mock[Client] + val rh = mock[ReadHandle] + when(rh.messages).thenReturn(messages.recv) + when(rh.error).thenReturn(error.recv) + when(client.read("foo")).thenReturn(rh) - val f = (h.messages?) - f.isDefined must beFalse + test("proxy messages") { + val h = client.readReliably("foo") + verify(client).read("foo") - val m = msg(0) + val f = (h.messages?) + assert(f.isDefined === false) - messages ! m - f.isDefined must beTrue - Await.result(f) must be(m) + val m = msg(0) - (h.messages?).isDefined must beFalse - } - - "reconnect on failure" in { - val h = client.readReliably("foo") - there was one(client).read("foo") - val m = msg(0) - messages ! m - (h.messages??) must be(m) - - val messages2 = new Broker[ReadMessage] - val error2 = new Broker[Throwable] - val rh2 = mock[ReadHandle] - rh2.messages returns messages2.recv - rh2.error returns error2.recv - client.read("foo") returns rh2 - - error ! new Exception("wtf") - there were two(client).read("foo") - - messages ! m // an errant message on broken channel - - // new messages must make it - val f = (h.messages?) - f.isDefined must beFalse - - val m2 = msg(2) - messages2 ! m2 - f.isDefined must beTrue - Await.result(f) must be(m2) - } + messages ! m + assert(f.isDefined === true) + assert(Await.result(f) === m) - "reconnect on failure (with delay)" in Time.withCurrentTimeFrozen { tc => - val timer = new MockTimer - val delays = Stream(1.seconds, 2.seconds, 3.second) - val h = client.readReliably("foo", timer, delays) - there was one(client).read("foo") - - val errf = (h.error?) + assert((h.messages?).isDefined === false) + } - delays.zipWithIndex foreach { case (delay, i) => - there were (i + 1).times(client).read("foo") - error ! new Exception("sad panda") - tc.advance(delay) - timer.tick() - there were (i + 2).times(client).read("foo") - errf.isDefined must beFalse - } + test("reconnect on failure") { + val h = client.readReliably("foo") + verify(client).read("foo") + val m = msg(0) + messages ! m + assert((h.messages??) === m) + + val messages2 = new Broker[ReadMessage] + val error2 = new Broker[Throwable] + val rh2 = mock[ReadHandle] + when(rh2.messages).thenReturn(messages2.recv) + when(rh2.error).thenReturn(error2.recv) + when(client.read("foo")).thenReturn(rh2) + + error ! new Exception("wtf") + verify(client, times(2)).read("foo") + + messages ! m // an errant message on broken channel + + // new messages must make it + val f = (h.messages?) + assert(f.isDefined === false) + + val m2 = msg(2) + messages2 ! m2 + assert(f.isDefined === true) + assert(Await.result(f) === m2) + } - error ! new Exception("final sad panda") + // test("reconnect on failure (with delay)") in Time.withCurrentTimeFrozen { tc => + // val timer = new MockTimer + // val delays = Stream(1.seconds, 2.seconds, 3.second) + // val h = client.readReliably("foo", timer, delays) + // verify(client).read("foo") + + // val errf = (h.error?) + + // delays.zipWithIndex foreach { case (delay, i) => + // verify(client, times(i + 1)).read("foo") + // error ! new Exception("sad panda") + // tc.advance(delay) + // timer.tick() + // verify(client, times(i + 2)).read("foo") + // assert(errf.isDefined === false) + // } + + // error ! new Exception("final sad panda") + + // assert(errf.isDefined === true) + // intercept[OutOfRetriesException] { + // Await.result(errf) + // } + // } + + test("close on close requested") { + val h = client.readReliably("foo") + verify(rh, times(0)).close() + h.close() + verify(rh).close() + } +} - errf.isDefined must beTrue - Await.result(errf) must be_==(OutOfRetriesException) +class ConnectedClientReadTest extends FunSuite with MockitoSugar { + val queueName = "foo" + val factory = mock[ServiceFactory[Command, Response]] + val service = mock[Service[Command, Response]] + val client = new ConnectedClient(factory) + val open = Open(queueName, Some(Duration.Top)) + val closeAndOpen = CloseAndOpen(queueName, Some(Duration.Top)) + val abort = Abort(queueName) + + test("interrupt current request on close") { + when(factory.apply()).thenReturn(Future(service)) + val promise = new Promise[Response]() + @volatile var wasInterrupted = false + promise.setInterruptHandler { case _cause => + wasInterrupted = true } + when(service(open)).thenReturn(promise) + when(service(closeAndOpen)).thenReturn(promise) + when(service(abort)).thenReturn(Future(Values(Seq()))) - "close on close requested" in { - val h = client.readReliably("foo") - there was no(rh).close() - h.close() - there was one(rh).close() - } - } + val rh = client.read(queueName) - "ConnectedClient.read" should { - val queueName = "foo" - val factory = mock[ServiceFactory[Command, Response]] - val service = mock[Service[Command, Response]] - val client = new ConnectedClient(factory) - val open = Open(queueName, Some(Duration.Top)) - val closeAndOpen = CloseAndOpen(queueName, Some(Duration.Top)) - val abort = Abort(queueName) - - "interrupt current request on close" in { - factory.apply() returns Future(service) - val promise = new Promise[Response]() - @volatile var wasInterrupted = false - promise.setInterruptHandler { case _cause => - wasInterrupted = true - } - service(open) returns promise - service(closeAndOpen) returns promise - service(abort) returns Future(Values(Seq())) - - val rh = client.read(queueName) - - wasInterrupted must beFalse - rh.close() - wasInterrupted must beTrue - } + assert(wasInterrupted === false) + rh.close() + assert(wasInterrupted === true) } +} - "ThriftConnectedClient.read" should { - val queueName = "foo" - val clientFactory = mock[FinagledClientFactory] - val finagledClient = mock[FinagledClosableClient] - val client = new ThriftConnectedClient(clientFactory) +class ThriftConnectedClientReadTest extends FunSuite with MockitoSugar { + val queueName = "foo" + val clientFactory = mock[FinagledClientFactory] + val finagledClient = mock[FinagledClosableClient] + val client = new ThriftConnectedClient(clientFactory) - "interrupt current thrift request on close" in { - clientFactory.apply() returns Future(finagledClient) - val promise = new Promise[Seq[Item]]() + test("interrupt current thrift request on close") { + when(clientFactory.apply()).thenReturn(Future(finagledClient)) + val promise = new Promise[Seq[Item]]() - @volatile var wasInterrupted = false - promise.setInterruptHandler { case _cause => - wasInterrupted = true - } + @volatile var wasInterrupted = false + promise.setInterruptHandler { case _cause => + wasInterrupted = true + } - finagledClient.get(queueName, 1, Int.MaxValue, Int.MaxValue) returns promise + when(finagledClient.get(queueName, 1, Int.MaxValue, Int.MaxValue)).thenReturn(promise) - val rh = client.read(queueName) + val rh = client.read(queueName) - wasInterrupted must beFalse - rh.close() - wasInterrupted must beTrue - } + assert(wasInterrupted === false) + rh.close() + assert(wasInterrupted === true) } -} +} \ No newline at end of file diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/InterpreterSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/InterpreterSpec.scala index 90d9036b1bf..bdcd87118a0 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/InterpreterSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/InterpreterSpec.scala @@ -10,80 +10,88 @@ import com.twitter.util.Time import java.util.concurrent.{BlockingDeque, LinkedBlockingDeque} import org.jboss.netty.buffer.ChannelBuffer import org.jboss.netty.buffer.ChannelBuffers.copiedBuffer -import org.specs.SpecificationWithJUnit -class InterpreterSpec extends SpecificationWithJUnit { - "Interpreter" should { - val queues = CacheBuilder.newBuilder() - .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { - def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] - }) - val interpreter = new Interpreter(queues) +import org.junit.runner.RunWith +import org.scalatest.{FunSuite, Suites} +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class InterpreterTest extends Suites( + new InterpreterTests, + new DecodingCommandTests +) + +class InterpreterTests extends FunSuite { + val queues = CacheBuilder.newBuilder() + .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { + def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] + }) + val interpreter = new Interpreter(queues) + + test("set & get") { + interpreter(Set("name", Time.now, "rawr")) + assert(interpreter(Get("name")) === Values(Seq(Value("name", "rawr")))) + } - "set & get" in { + test("transactions") { + test("set & get/open & get/open") { interpreter(Set("name", Time.now, "rawr")) - interpreter(Get("name")) mustEqual - Values(Seq(Value("name", "rawr"))) - } - - "transactions" in { - "set & get/open & get/open" in { - interpreter(Set("name", Time.now, "rawr")) + interpreter(Open("name")) + intercept[InvalidStateTransition] { interpreter(Open("name")) - interpreter(Open("name")) must throwA[InvalidStateTransition] - } - - "set & get/abort" in { - interpreter(Set("name", Time.now, "rawr")) - interpreter(Abort("name")) must throwA[InvalidStateTransition] - } - - "set & get/open & get/close" in { - interpreter(Set("name", Time.now, "rawr")) - interpreter(Open("name")) mustEqual - Values(Seq(Value("name", "rawr"))) - interpreter(Close("name")) mustEqual Values(Seq()) - interpreter(Open("name")) mustEqual Values(Seq()) - } - - "set & get/open & get/abort" in { - interpreter(Set("name", Time.now, "rawr")) - interpreter(Open("name")) mustEqual - Values(Seq(Value("name", "rawr"))) - interpreter(Abort("name")) mustEqual Values(Seq()) - interpreter(Open("name")) mustEqual - Values(Seq(Value("name", "rawr"))) } } - "timeouts" in { - "set & get/t=1" in { - interpreter(Get("name", Some(1.millisecond))) mustEqual Values(Seq()) - interpreter(Set("name", Time.now, "rawr")) - interpreter(Get("name", Some(1.second))) mustEqual Values(Seq(Value("name", "rawr"))) + test("set & get/abort") { + interpreter(Set("name", Time.now, "rawr")) + intercept[InvalidStateTransition] { + interpreter(Abort("name")) } } - "delete" in { + test("set & get/open & get/close") { interpreter(Set("name", Time.now, "rawr")) - interpreter(Delete("name")) - interpreter(Get("name")) mustEqual Values(Seq.empty) + assert(interpreter(Open("name")) === Values(Seq(Value("name", "rawr")))) + assert(interpreter(Close("name")) === Values(Seq())) + assert(interpreter(Open("name")) === Values(Seq())) } - "flush" in { + test("set & get/open & get/abort") { interpreter(Set("name", Time.now, "rawr")) - interpreter(Flush("name")) - interpreter(Get("name")) mustEqual Values(Seq.empty) + assert(interpreter(Open("name")) === Values(Seq(Value("name", "rawr")))) + assert(interpreter(Abort("name")) === Values(Seq())) + assert(interpreter(Open("name")) === Values(Seq(Value("name", "rawr")))) } + } - "flushAll" in { + test("timeouts") { + test("set & get/t=1") { + assert(interpreter(Get("name", Some(1.millisecond))) === Values(Seq())) interpreter(Set("name", Time.now, "rawr")) - interpreter(FlushAll()) - interpreter(Get("name")) mustEqual Values(Seq.empty) + assert(interpreter(Get("name", Some(1.second))) === Values(Seq(Value("name", "rawr")))) } } - "Decoding to command" should { + test("delete") { + interpreter(Set("name", Time.now, "rawr")) + interpreter(Delete("name")) + assert(interpreter(Get("name")) === Values(Seq.empty)) + } + + test("flush") { + interpreter(Set("name", Time.now, "rawr")) + interpreter(Flush("name")) + assert(interpreter(Get("name")) === Values(Seq.empty)) + } + + test("flushAll") { + interpreter(Set("name", Time.now, "rawr")) + interpreter(FlushAll()) + assert(interpreter(Get("name")) === Values(Seq.empty)) + } +} + +class DecodingCommandTests extends FunSuite { val dtc = new DecodingToCommand def getCmdSeq(subCmd: String) = { dtc.parseNonStorageCommand( @@ -91,20 +99,19 @@ class InterpreterSpec extends SpecificationWithJUnit { copiedBuffer(subCmd.getBytes))) } - "parse get with timeout" in { - getCmdSeq("foo/t=123") mustEqual Get(copiedBuffer("foo".getBytes), Some(123.milliseconds)) + test("parse get with timeout") { + assert(getCmdSeq("foo/t=123") === Get(copiedBuffer("foo".getBytes), Some(123.milliseconds))) } - "parse close/open with timeout" in { - getCmdSeq("foo/close/open/t=123") mustEqual CloseAndOpen(copiedBuffer("foo".getBytes), Some(123.milliseconds)) + test("parse close/open with timeout") { + assert(getCmdSeq("foo/close/open/t=123") === CloseAndOpen(copiedBuffer("foo".getBytes), Some(123.milliseconds))) } - "parse close/open with timeout in between" in { - getCmdSeq("foo/t=123/close/open") mustEqual CloseAndOpen(copiedBuffer("foo".getBytes), Some(123.milliseconds)) + test("parse close/open with timeout in between") { + assert(getCmdSeq("foo/t=123/close/open") === CloseAndOpen(copiedBuffer("foo".getBytes), Some(123.milliseconds))) } - "parse without timeout" in { - getCmdSeq("foo") mustEqual Get(copiedBuffer("foo".getBytes), None) + test("parse without timeout") { + assert(getCmdSeq("foo") === Get(copiedBuffer("foo".getBytes), None)) } - } } diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/MultiReaderSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/MultiReaderSpec.scala index 9ca137be8f1..c9f5fc1899e 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/MultiReaderSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/MultiReaderSpec.scala @@ -13,46 +13,85 @@ import com.twitter.finagle.kestrel.protocol._ import com.twitter.finagle.memcached.util.ChannelBufferUtils._ import com.twitter.util.{Await, Future, Promise, Return, Time, Updatable, Var} import org.jboss.netty.buffer.ChannelBuffer -import org.specs.SpecificationWithJUnit -import org.specs.mock.Mockito import scala.collection.mutable.{ArrayBuffer, Set => MSet} import scala.collection.immutable.{Set => ISet} -class MultiReaderSpec extends SpecificationWithJUnit with Mockito { - noDetailedDiffs() +import org.junit.runner.RunWith +import org.scalatest.{FunSuite, Suites} +import org.scalatest.matchers._ +import org.scalatest.junit.JUnitRunner +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito.{verify, times, when} + +trait CollectionComparison { + def sameAs[A](c: Traversable[A], d: Traversable[A]): Boolean = + if (c.isEmpty) d.isEmpty + else { + val (e, f) = d span (c.head !=) + if (f.isEmpty) false else sameAs(c.tail, e ++ f.tail) + } +} - class MockHandle extends ReadHandle { - val _messages = new Broker[ReadMessage] - val _error = new Broker[Throwable] +@RunWith(classOf[JUnitRunner]) +class MultiReaderTest extends Suites( + new StaticReadHandleClusterTest, + new VarAddrBasedClusterTest, + new DynamicSocketAddressClusterTest +) - val messages = _messages.recv - val error = _error.recv - def close() {} // to spy on! - } +class MockHandle extends ReadHandle { + val _messages = new Broker[ReadMessage] + val _error = new Broker[Throwable] - "MultiReader" should { - val queueName = "the_queue" - "static ReadHandle cluster" in { - val N = 3 - val handles = (0 until N) map { _ => spy(new MockHandle) } - val va: Var[Return[ISet[ReadHandle]]] = Var.value(Return(handles.toSet)) + val messages = _messages.recv + val error = _error.recv + def close() {} // to spy on! +} - "always grab the first available message" in { - val handle = MultiReaderHelper.merge(va) +class StaticReadHandleClusterTest extends FunSuite with MockitoSugar { + val N = 3 + val queueName = "the_queue" + val handles = (0 until N) map { _ => mock[MockHandle] } + val va: Var[Return[ISet[ReadHandle]]] = Var.value(Return(handles.toSet)) - val messages = new ArrayBuffer[ReadMessage] - handle.messages foreach { messages += _ } + test("always grab the first available message") { + val handle = MultiReaderHelper.merge(va) - // stripe some messages across - val sentMessages = 0 until N*100 map { _ => mock[ReadMessage] } + val messages = new ArrayBuffer[ReadMessage] + handle.messages foreach { messages += _ } - messages must beEmpty - sentMessages.zipWithIndex foreach { case (m, i) => - handles(i % handles.size)._messages ! m - } + // stripe some messages across + val sentMessages = 0 until N*100 map { _ => mock[ReadMessage] } - messages must be_==(sentMessages) - } + assert(messages === Seq.empty) + sentMessages.zipWithIndex foreach { case (m, i) => + handles(i % handles.size)._messages ! m + } + + assert(messages === sentMessages) + } + + test("propagate closes") { + handles foreach { h => verify(h, times(0)).close() } + val handle = MultiReaderHelper.merge(va) + handle.close() + handles foreach { h => verify(h).close() } + } + + test("propagate errors when everything's errored out") { + val handle = MultiReaderHelper.merge(va) + val e = handle.error.sync() + handles foreach { h => + assert(!e.isDefined) + h._error ! new Exception("sad panda") + } + + assert(e.isDefined) + intercept[AllHandlesDiedException.type] { + Await.result(e) + } + } +} // TODO - this test stopped working // // We use frozen time for deterministic randomness. @@ -74,423 +113,408 @@ class MultiReaderSpec extends SpecificationWithJUnit with Mockito { // } // } - "propagate closes" in { - handles foreach { h => there was no(h).close() } - val handle = MultiReaderHelper.merge(va) - handle.close() - handles foreach { h => there was one(h).close() } - } +class VarAddrBasedClusterTest extends FunSuite with MockitoSugar with ShouldMatchers { + val N = 3 + val queueName = "the_queue" + val hosts = 0 until N map { i => + InetSocketAddress.createUnresolved("10.0.0.%d".format(i), 22133) + } - "propagate errors when everything's errored out" in { - val handle = MultiReaderHelper.merge(va) - val e = handle.error.sync() - handles foreach { h => - e.isDefined must beFalse - h._error ! new Exception("sad panda") + val executor = Executors.newCachedThreadPool() + + def newKestrelService(executor: Option[ExecutorService], + queues: LoadingCache[ChannelBuffer, BlockingDeque[ChannelBuffer]]): Service[Command, Response] = { + val interpreter = new Interpreter(queues) + new Service[Command, Response] { + def apply(request: Command) = { + val promise = new Promise[Response]() + executor match { + case Some(executor) => + executor.submit(new Runnable { + def run() { + promise.setValue(interpreter(request)) + } + }) + case None => promise.setValue(interpreter(request)) } - - e.isDefined must beTrue - Await.result(e) must be(AllHandlesDiedException) + promise } } + } - "Var[Addr]-based cluster" in { - val N = 3 - val hosts = 0 until N map { i => - InetSocketAddress.createUnresolved("10.0.0.%d".format(i), 22133) - } - - val executor = Executors.newCachedThreadPool() - - def newKestrelService(executor: Option[ExecutorService], - queues: LoadingCache[ChannelBuffer, BlockingDeque[ChannelBuffer]]): Service[Command, Response] = { - val interpreter = new Interpreter(queues) - new Service[Command, Response] { - def apply(request: Command) = { - val promise = new Promise[Response]() - executor match { - case Some(executor) => - executor.submit(new Runnable { - def run() { - promise.setValue(interpreter(request)) - } - }) - case None => promise.setValue(interpreter(request)) - } - promise - } - } - } - - val hostQueuesMap = hosts.map { host => - val queues = CacheBuilder.newBuilder() - .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { - def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] - }) - (host, queues) - }.toMap - - lazy val mockClientBuilder = { - val result = mock[ClientBuilder[Command, Response, Nothing, ClientConfig.Yes, ClientConfig.Yes]] - - hosts.foreach { host => - val mockHostClientBuilder = - mock[ClientBuilder[Command, Response, ClientConfig.Yes, ClientConfig.Yes, ClientConfig.Yes]] - result.hosts(host) returns mockHostClientBuilder - - val queues = hostQueuesMap(host) - val factory = new ServiceFactory[Command, Response] { - // use an executor so readReliably doesn't block waiting on an empty queue - def apply(conn: ClientConnection) = - Future.value(newKestrelService(Some(executor), queues)) - def close(deadline: Time) = Future.Done - override def toString = "ServiceFactory for %s".format(host) - } - mockHostClientBuilder.buildFactory() returns factory - } - result - } - - val services = hosts.map { host => - val queues = hostQueuesMap(host) - // no executor here: this one is used for writing to the queues - newKestrelService(None, queues) + val hostQueuesMap = hosts.map { host => + val queues = CacheBuilder.newBuilder() + .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { + def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] + }) + (host, queues) + }.toMap + + lazy val mockClientBuilder = { + val result = mock[ClientBuilder[Command, Response, Nothing, ClientConfig.Yes, ClientConfig.Yes]] + + hosts.foreach { host => + val mockHostClientBuilder = + mock[ClientBuilder[Command, Response, ClientConfig.Yes, ClientConfig.Yes, ClientConfig.Yes]] + when(result.hosts(host)).thenReturn(mockHostClientBuilder) + + val queues = hostQueuesMap(host) + val factory = new ServiceFactory[Command, Response] { + // use an executor so readReliably doesn't block waiting on an empty queue + def apply(conn: ClientConnection) = + Future.value(newKestrelService(Some(executor), queues)) + def close(deadline: Time) = Future.Done + override def toString = "ServiceFactory for %s".format(host) } + when(mockHostClientBuilder.buildFactory()).thenReturn(factory) + } + result + } - def configureMessageReader(handle: ReadHandle): MSet[String] = { - val messages = MSet[String]() - val UTF8 = Charset.forName("UTF-8") + val services = hosts.map { host => + val queues = hostQueuesMap(host) + // no executor here: this one is used for writing to the queues + newKestrelService(None, queues) + } - handle.messages foreach { msg => - val str = msg.bytes.toString(UTF8) - messages += str - msg.ack.sync() - } - messages - } + def configureMessageReader(handle: ReadHandle): MSet[String] = { + val messages = MSet[String]() + val UTF8 = Charset.forName("UTF-8") - "read messages from a ready cluster" in { - val va = Var(Addr.Bound(hosts: _*)) - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + handle.messages foreach { msg => + val str = msg.bytes.toString(UTF8) + messages += str + msg.ack.sync() + } + messages + } - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) - } + test("read messages from a ready cluster") { + val va = Var(Addr.Bound(hosts: _*)) + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages.toSeq === Seq.empty) - messages must eventually(be_==(sentMessages.toSet)) - } + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) + } + assert(messages.toList === sentMessages.toList) + } - "read messages as cluster hosts are added" in { - val va = Var(Addr.Bound(hosts.head)) - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + test("read messages as cluster hosts are added") { + val va = Var(Addr.Bound(hosts.head)) + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages.toSeq === Seq.empty) - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) - } + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) + } - // 0, 3, 6 ... - messages must eventually(be_==(sentMessages.grouped(N).map { _.head }.toSet)) - messages.clear() + // 0, 3, 6 ... + assert(messages === sentMessages.grouped(N).map { _.head }.toSet) + messages.clear() - va.update(Addr.Bound(hosts: _*)) + va.update(Addr.Bound(hosts: _*)) - // 1, 2, 4, 5, ... - messages must eventually(be_==(sentMessages.grouped(N).map { _.tail }.flatten.toSet)) - } + // 1, 2, 4, 5, ... + assert(messages === sentMessages.grouped(N).map { _.tail }.flatten.toSet) + } - "read messages as cluster hosts are removed" in { - var mutableHosts: Seq[SocketAddress] = hosts - val va = Var(Addr.Bound(mutableHosts: _*)) - val rest = hosts.tail.reverse - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + test("read messages as cluster hosts are removed") { + var mutableHosts: Seq[SocketAddress] = hosts + val va = Var(Addr.Bound(mutableHosts: _*)) + val rest = hosts.tail.reverse + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) - } - - messages must eventually(be_==(sentMessages.toSet)) - rest.zipWithIndex.foreach { case (host, hostIndex) => - messages.clear() - mutableHosts = (mutableHosts.toSet - host).toSeq - va.update(Addr.Bound(mutableHosts: _*)) + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) + } - // write to all 3 - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) - } + assert(messages === sentMessages.toSet) + rest.zipWithIndex.foreach { case (host, hostIndex) => + messages.clear() + mutableHosts = (mutableHosts.toSet - host).toSeq + va.update(Addr.Bound(mutableHosts: _*)) - // expect fewer to be read on each pass - val expectFirstN = N - hostIndex - 1 - messages must eventually(be_==(sentMessages.grouped(N).map { _.take(expectFirstN) }.flatten.toSet)) - } + // write to all 3 + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) } - "wait for cluster to become ready before snapping initial hosts" in { - val va = Var(Addr.Bound()) - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val error = handle.error.sync() - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + // expect fewer to be read on each pass + val expectFirstN = N - hostIndex - 1 + assert(messages === sentMessages.grouped(N).map { _.take(expectFirstN) }.flatten.toSet) + } + } - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) - } + test("wait for cluster to become ready before snapping initial hosts") { + val va = Var(Addr.Bound()) + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val error = handle.error.sync() + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - messages must beEmpty // cluster not ready - error.isDefined must beFalse + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set(queueName, Time.now, m))) + } - va.update(Addr.Bound(hosts: _*)) + assert(messages === Seq.empty) // cluster not ready + assert(!error.isDefined) - messages must eventually(be_==(sentMessages.toSet)) - } + va.update(Addr.Bound(hosts: _*)) - "report an error if all hosts are removed" in { - val va = Var(Addr.Bound(hosts: _*)) - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val error = handle.error.sync() - va.update(Addr.Bound()) + assert(messages === sentMessages.toSet) + } - error.isDefined must beTrue - Await.result(error) must be(AllHandlesDiedException) - } + test("report an error if all hosts are removed") { + val va = Var(Addr.Bound(hosts: _*)) + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + val error = handle.error.sync() + va.update(Addr.Bound()) - "propagate exception if cluster fails" in { - val ex = new Exception("uh oh") - val va: Var[Addr] with Updatable[Addr] = Var(Addr.Bound(hosts: _*)) - val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() - val error = handle.error.sync() - va.update(Addr.Failed(ex)) + assert(error.isDefined) + intercept[AllHandlesDiedException.type] { + Await.result(error) + } + } - error.isDefined must beTrue - Await.result(error) must be(ex) - } - } + test("propagate exception if cluster fails") { + var ex = new Exception("uh oh") + val va: Var[Addr] with Updatable[Addr] = Var(Addr.Bound(hosts: _*)) + val handle = MultiReader(va, queueName).clientBuilder(mockClientBuilder).build() + val error = handle.error.sync() + va.update(Addr.Failed(ex)) - "[deprecated] dynamic SocketAddress cluster" in { - class DynamicCluster[U](initial: Seq[U]) extends Cluster[U] { - def this() = this(Seq[U]()) + assert(error.isDefined) + intercept[Exception] { + Await.result(error) + } + } +} - var set = initial.toSet - var s = new Promise[Spool[Cluster.Change[U]]] +class DynamicSocketAddressClusterTest extends FunSuite with MockitoSugar { + class DynamicCluster[U](initial: Seq[U]) extends Cluster[U] { + def this() = this(Seq[U]()) - def add(f: U) = { - set += f - performChange(Cluster.Add(f)) - } + var set = initial.toSet + var s = new Promise[Spool[Cluster.Change[U]]] - def del(f: U) = { - set -= f - performChange(Cluster.Rem(f)) - } + def add(f: U) = { + set += f + performChange(Cluster.Add(f)) + } - private[this] def performChange (change: Cluster.Change[U]) = synchronized { - val newTail = new Promise[Spool[Cluster.Change[U]]] - s() = Return(change *:: newTail) - s = newTail - } + def del(f: U) = { + set -= f + performChange(Cluster.Rem(f)) + } - def snap = (set.toSeq, s) - } + private[this] def performChange (change: Cluster.Change[U]) = synchronized { + val newTail = new Promise[Spool[Cluster.Change[U]]] + s() = Return(change *:: newTail) + s = newTail + } - val N = 3 - val hosts = 0 until N map { i => InetSocketAddress.createUnresolved("10.0.0.%d".format(i), 22133) } - - val executor = Executors.newCachedThreadPool() - - def newKestrelService(executor: Option[ExecutorService], - queues: LoadingCache[ChannelBuffer, BlockingDeque[ChannelBuffer]]): Service[Command, Response] = { - val interpreter = new Interpreter(queues) - new Service[Command, Response] { - def apply(request: Command) = { - val promise = new Promise[Response]() - executor match { - case Some(executor) => - executor.submit(new Runnable { - def run() { - promise.setValue(interpreter(request)) - } - }) - case None => promise.setValue(interpreter(request)) - } - promise - } - } - } + def snap = (set.toSeq, s) + } - val hostQueuesMap = hosts.map { host => - val queues = CacheBuilder.newBuilder() - .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { - def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] - }) - (host, queues) - }.toMap - - lazy val mockClientBuilder = { - val result = mock[ClientBuilder[Command, Response, Nothing, ClientConfig.Yes, ClientConfig.Yes]] - - hosts.foreach { host => - val mockHostClientBuilder = - mock[ClientBuilder[Command, Response, ClientConfig.Yes, ClientConfig.Yes, ClientConfig.Yes]] - result.hosts(host) returns mockHostClientBuilder - - val queues = hostQueuesMap(host) - val factory = new ServiceFactory[Command, Response] { - // use an executor so readReliably doesn't block waiting on an empty queue - def apply(conn: ClientConnection) = Future(newKestrelService(Some(executor), queues)) - def close(deadline: Time) = Future.Done - override def toString = "ServiceFactory for %s".format(host) - } - mockHostClientBuilder.buildFactory() returns factory + val N = 3 + val hosts = 0 until N map { i => InetSocketAddress.createUnresolved("10.0.0.%d".format(i), 22133) } + + val executor = Executors.newCachedThreadPool() + + def newKestrelService(executor: Option[ExecutorService], + queues: LoadingCache[ChannelBuffer, BlockingDeque[ChannelBuffer]]): Service[Command, Response] = { + val interpreter = new Interpreter(queues) + new Service[Command, Response] { + def apply(request: Command) = { + val promise = new Promise[Response]() + executor match { + case Some(executor) => + executor.submit(new Runnable { + def run() { + promise.setValue(interpreter(request)) + } + }) + case None => promise.setValue(interpreter(request)) } - result + promise } + } + } - val services = hosts.map { host => - val queues = hostQueuesMap(host) - // no executor here: this one is used for writing to the queues - newKestrelService(None, queues) + val hostQueuesMap = hosts.map { host => + val queues = CacheBuilder.newBuilder() + .build(new CacheLoader[ChannelBuffer, BlockingDeque[ChannelBuffer]] { + def load(k: ChannelBuffer) = new LinkedBlockingDeque[ChannelBuffer] + }) + (host, queues) + }.toMap + + lazy val mockClientBuilder = { + val result = mock[ClientBuilder[Command, Response, Nothing, ClientConfig.Yes, ClientConfig.Yes]] + + hosts.foreach { host => + val mockHostClientBuilder = + mock[ClientBuilder[Command, Response, ClientConfig.Yes, ClientConfig.Yes, ClientConfig.Yes]] + when(result.hosts(host)).thenReturn(mockHostClientBuilder) + + val queues = hostQueuesMap(host) + val factory = new ServiceFactory[Command, Response] { + // use an executor so readReliably doesn't block waiting on an empty queue + def apply(conn: ClientConnection) = Future(newKestrelService(Some(executor), queues)) + def close(deadline: Time) = Future.Done + override def toString = "ServiceFactory for %s".format(host) } + when(mockHostClientBuilder.buildFactory()).thenReturn(factory) + } + result + } - def configureMessageReader(handle: ReadHandle): MSet[String] = { - val messages = MSet[String]() - val UTF8 = Charset.forName("UTF-8") + val services = hosts.map { host => + val queues = hostQueuesMap(host) + // no executor here: this one is used for writing to the queues + newKestrelService(None, queues) + } - handle.messages foreach { msg => - val str = msg.bytes.toString(UTF8) - messages += str - msg.ack.sync() - } - messages - } + def configureMessageReader(handle: ReadHandle): MSet[String] = { + val messages = MSet[String]() + val UTF8 = Charset.forName("UTF-8") - "read messages from a ready cluster" in { - val cluster = new DynamicCluster[SocketAddress](hosts) - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + handle.messages foreach { msg => + val str = msg.bytes.toString(UTF8) + messages += str + msg.ack.sync() + } + messages + } - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } + test("read messages from a ready cluster") { + val cluster = new DynamicCluster[SocketAddress](hosts) + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - messages must eventually(be_==(sentMessages.toSet)) - } + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) + } - "read messages as cluster hosts are added" in { - val (host, rest) = (hosts.head, hosts.tail) - val cluster = new DynamicCluster[SocketAddress](List(host)) - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + assert(messages === sentMessages.toSet) + } - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } + test("read messages as cluster hosts are added") { + val (host, rest) = (hosts.head, hosts.tail) + val cluster = new DynamicCluster[SocketAddress](List(host)) + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - // 0, 3, 6 ... - messages must eventually(be_==(sentMessages.grouped(N).map { _.head }.toSet)) - messages.clear() + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) + } - rest.foreach { host => cluster.add(host) } + // 0, 3, 6 ... + assert(messages === sentMessages.grouped(N).map { _.head }.toSet) + messages.clear() - // 1, 2, 4, 5, ... - messages must eventually(be_==(sentMessages.grouped(N).map { _.tail }.flatten.toSet)) - } + rest.foreach { host => cluster.add(host) } - "read messages as cluster hosts are removed" in { - val cluster = new DynamicCluster[SocketAddress](hosts) - val rest = hosts.tail - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + // 1, 2, 4, 5, ... + assert(messages === sentMessages.grouped(N).map { _.tail }.flatten.toSet) + } - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + test("read messages as cluster hosts are removed") { + val cluster = new DynamicCluster[SocketAddress](hosts) + val rest = hosts.tail + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - messages must eventually(be_==(sentMessages.toSet)) + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) + } - rest.reverse.zipWithIndex.foreach { case (host, hostIndex) => - messages.clear() - cluster.del(host) + assert(messages === sentMessages.toSet) - // write to all 3 - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } + rest.reverse.zipWithIndex.foreach { case (host, hostIndex) => + messages.clear() + cluster.del(host) - // expect fewer to be read on each pass - val expectFirstN = N - hostIndex - 1 - messages must eventually(be_==(sentMessages.grouped(N).map { _.take(expectFirstN) }.flatten.toSet)) - } + // write to all 3 + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) } - "wait for cluster to become ready before snapping initial hosts" in { - val cluster = new DynamicCluster[SocketAddress](Seq()) - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - val messages = configureMessageReader(handle) - val errors = handle.error? - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - messages must beEmpty + // expect fewer to be read on each pass + val expectFirstN = N - hostIndex - 1 + assert(messages === sentMessages.grouped(N).map { _.take(expectFirstN) }.flatten.toSet) + } + } - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } + test("wait for cluster to become ready before snapping initial hosts") { + val cluster = new DynamicCluster[SocketAddress](Seq()) + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + val messages = configureMessageReader(handle) + val errors = handle.error? + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + assert(messages === Seq.empty) - messages must beEmpty // cluster not ready - errors.isDefined must beFalse + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) + } - hosts.foreach { host => cluster.add(host) } + assert(messages === Seq.empty) // cluster not ready + assert(!errors.isDefined) - messages must eventually(be_==(sentMessages.toSet)) - } + hosts.foreach { host => cluster.add(host) } - "report an error if all hosts are removed" in { - val cluster = new DynamicCluster[SocketAddress](hosts) - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - val e = (handle.error?) - hosts.foreach { host => cluster.del(host) } + assert(messages === sentMessages.toSet) + } - e.isDefined must beTrue - Await.result(e) must be(AllHandlesDiedException) - } + test("report an error if all hosts are removed") { + val cluster = new DynamicCluster[SocketAddress](hosts) + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + val e = (handle.error?) + hosts.foreach { host => cluster.del(host) } - "silently handle the removal of a host that was never added" in { - val cluster = new DynamicCluster[SocketAddress](hosts) - val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() + assert(e.isDefined) + intercept[AllHandlesDiedException.type] { + Await.result(e) + } + } - val messages = configureMessageReader(handle) - val sentMessages = 0 until N*10 map { i => "message %d".format(i) } + test("silently handle the removal of a host that was never added") { + val cluster = new DynamicCluster[SocketAddress](hosts) + val handle = MultiReader(cluster, "the_queue").clientBuilder(mockClientBuilder).build() - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } - messages must eventually(be_==(sentMessages.toSet)) - messages.clear() + val messages = configureMessageReader(handle) + val sentMessages = 0 until N*10 map { i => "message %d".format(i) } - cluster.del(InetSocketAddress.createUnresolved("10.0.0.100", 22133)) + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) + } + assert(messages === sentMessages.toSet) + messages.clear() - sentMessages.zipWithIndex foreach { case (m, i) => - Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) - } - messages must eventually(be_==(sentMessages.toSet)) - } + cluster.del(InetSocketAddress.createUnresolved("10.0.0.100", 22133)) + + sentMessages.zipWithIndex foreach { case (m, i) => + Await.result(services(i % services.size).apply(Set("the_queue", Time.now, m))) } + + assert(messages === sentMessages.toSet) } -} +} \ No newline at end of file diff --git a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ReadHandleSpec.scala b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ReadHandleSpec.scala index 2cc5aecb3a6..5a3d2ec3318 100644 --- a/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ReadHandleSpec.scala +++ b/finagle-kestrel/src/test/scala/com/twitter/finagle/kestrel/unit/ReadHandleSpec.scala @@ -5,10 +5,12 @@ import com.twitter.finagle.kestrel._ import com.twitter.io.Charsets import com.twitter.util.Await import org.jboss.netty.buffer.ChannelBuffers -import org.specs.SpecificationWithJUnit -import org.specs.mock.Mockito -class ReadHandleSpec extends SpecificationWithJUnit with Mockito { +import org.junit.runner.RunWith +import org.scalatest.{FunSuite, Suites} +import org.scalatest.junit.JUnitRunner + +trait Messaging { def msg_(i: Int) = { val ack = new Broker[Unit] val abort = new Broker[Unit] @@ -16,138 +18,144 @@ class ReadHandleSpec extends SpecificationWithJUnit with Mockito { } def msg(i: Int) = { val (_, m) = msg_(i); m } +} - "ReadHandle.buffered" should { - val N = 10 - val messages = new Broker[ReadMessage] - val error = new Broker[Throwable] - val close = new Broker[Unit] - val handle = ReadHandle(messages.recv, error.recv, close.send(())) - val buffered = handle.buffered(N) - - "acknowledge howmany messages" in { - 0 until N foreach { i => - val (ack, m) = msg_(i) - messages ! m - (ack?).isDefined must beTrue - } - val (ack, m) = msg_(0) - (ack?).isDefined must beFalse +@RunWith(classOf[JUnitRunner]) +class ReadHandleTest extends Suites( + new ReadHandleBufferedTest, + new ReadHandleMergedTest +) + +class ReadHandleBufferedTest extends FunSuite with Messaging { + val N = 10 + val messages = new Broker[ReadMessage] + val error = new Broker[Throwable] + val close = new Broker[Unit] + val handle = ReadHandle(messages.recv, error.recv, close.send(())) + val buffered = handle.buffered(N) + + test("acknowledge howmany messages") { + 0 until N foreach { i => + val (ack, m) = msg_(i) + messages ! m + assert((ack?).isDefined) } + val (ack, m) = msg_(0) + assert(!(ack?).isDefined) + } - "not synchronize on send when buffer is full" in { - 0 until N foreach { _ => - (messages ! msg(0)).isDefined must beTrue - } - (messages ! msg(0)).isDefined must beFalse + test("not synchronize on send when buffer is full") { + 0 until N foreach { _ => + assert((messages ! msg(0)).isDefined) } + assert(!(messages ! msg(0)).isDefined) + } - "keep the buffer full" in { - 0 until N foreach { _ => - messages ! msg(0) - } - val sent = messages ! msg(0) - sent.isDefined must beFalse - val recvd = (buffered.messages?) - recvd.isDefined must beTrue - Await.result(recvd).ack.sync() - sent.isDefined must beTrue + test("keep the buffer full") { + 0 until N foreach { _ => + messages ! msg(0) } + val sent = messages ! msg(0) + assert(!sent.isDefined) + val recvd = (buffered.messages?) + assert(recvd.isDefined) + Await.result(recvd).ack.sync() + assert(sent.isDefined) + } - "preserve FIFO order" in { - 0 until N foreach { i => - messages ! msg(i) - } - - 0 until N foreach { i => - val recvd = (buffered.messages?) - recvd.isDefined must beTrue - Await.result(recvd).bytes.toString(Charsets.Utf8) must be_==(i.toString) - } + test("preserve FIFO order") { + 0 until N foreach { i => + messages ! msg(i) } - "propagate errors" in { - val errd = (buffered.error?) - errd.isDefined must beFalse - val e = new Exception("sad panda") - error ! e - errd.isDefined must beTrue - Await.result(errd) must be(e) + 0 until N foreach { i => + val recvd = (buffered.messages?) + assert(recvd.isDefined) + assert(Await.result(recvd).bytes.toString(Charsets.Utf8) === i.toString) } + } - "when closed" in { - "propagate immediately if empty" in { - val closed = (close?) - closed.isDefined must beFalse - buffered.close() - closed.isDefined must beTrue - } - - "wait for outstanding acks before closing underlying" in { - val closed = (close?) - closed.isDefined must beFalse - messages ! msg(0) - messages ! msg(1) - buffered.close() - closed.isDefined must beFalse - val m0 = (buffered.messages?) - m0.isDefined must beTrue - Await.result(m0).ack.sync() - closed.isDefined must beFalse - val m1 = (buffered.messages?) - m1.isDefined must beTrue - Await.result(m1).ack.sync() - closed.isDefined must beTrue - } + test("propagate errors") { + val errd = (buffered.error?) + assert(!errd.isDefined) + val e = new Exception("Sad panda") + error ! e + assert(errd.isDefined) + intercept[Exception] { + Await.result(errd) } } - "ReadHandle.merged" should { - val messages0 = new Broker[ReadMessage] - val error0 = new Broker[Throwable] - val close0 = new Broker[Unit] - val handle0 = ReadHandle(messages0.recv, error0.recv, close0.send(())) - val messages1 = new Broker[ReadMessage] - val error1 = new Broker[Throwable] - val close1 = new Broker[Unit] - val handle1 = ReadHandle(messages1.recv, error1.recv, close1.send(())) + test("when closed, propagate immediately if empty") { + val closed = (close?) + assert(!closed.isDefined) + buffered.close() + assert(closed.isDefined) + } + + test("when closed, wait for outstanding acks before closing underlying") { + val closed = (close?) + assert(!closed.isDefined) + messages ! msg(0) + messages ! msg(1) + buffered.close() + assert(!closed.isDefined) + val m0 = (buffered.messages?) + assert(m0.isDefined) + Await.result(m0).ack.sync() + assert(!closed.isDefined) + val m1 = (buffered.messages?) + assert(m1.isDefined) + Await.result(m1).ack.sync() + assert(closed.isDefined) + } +} - val merged = ReadHandle.merged(Seq(handle0, handle1)) +class ReadHandleMergedTest extends FunSuite with Messaging { + val messages0 = new Broker[ReadMessage] + val error0 = new Broker[Throwable] + val close0 = new Broker[Unit] + val handle0 = ReadHandle(messages0.recv, error0.recv, close0.send(())) + val messages1 = new Broker[ReadMessage] + val error1 = new Broker[Throwable] + val close1 = new Broker[Unit] + val handle1 = ReadHandle(messages1.recv, error1.recv, close1.send(())) - "provide a merged stream of messages" in { - var count = 0 - merged.messages.foreach { _ => count += 1 } - count must be_==(0) + val merged = ReadHandle.merged(Seq(handle0, handle1)) - messages0 ! msg(0) - messages1 ! msg(1) - messages0 ! msg(2) + test("provide a merged stream of messages") { + var count = 0 + merged.messages.foreach { _ => count += 1 } + assert(count === 0) - count must be_==(3) - } + messages0 ! msg(0) + messages1 ! msg(1) + messages0 ! msg(2) + + assert(count === 3) + } - "provide a merged stream of errors" in { - var count = 0 - merged.error.foreach { _ => count += 1 } - count must be_==(0) + test("provide a merged stream of errors") { + var count = 0 + merged.error.foreach { _ => count += 1 } + assert(count === 0) - error0 ! new Exception("sad panda") - error1 ! new Exception("sad panda #2") + error0 ! new Exception("sad panda") + error1 ! new Exception("sad panda #2") - count must be_==(2) - } + assert(count === 2) + } - "propagate closes to all underlying handles" in { - val closed0 = (close0?) - val closed1 = (close1?) + test("propagate closes to all underlying handles") { + val closed0 = (close0?) + val closed1 = (close1?) - closed0.isDefined must beFalse - closed1.isDefined must beFalse + assert(!closed0.isDefined) + assert(!closed1.isDefined) - merged.close() + merged.close() - closed0.isDefined must beTrue - closed1.isDefined must beTrue - } + assert(closed0.isDefined) + assert(closed1.isDefined) } -} +} \ No newline at end of file