diff --git a/src/org/jgroups/protocols/ReliableUnicast.java b/src/org/jgroups/protocols/ReliableUnicast.java index 1c518a8fb6..a757272766 100644 --- a/src/org/jgroups/protocols/ReliableUnicast.java +++ b/src/org/jgroups/protocols/ReliableUnicast.java @@ -233,6 +233,8 @@ public T setLevel(String level) { public ReliableUnicast sendAtomically(boolean f) {send_atomically=f; return this;} public boolean loopback() {return loopback;} public ReliableUnicast loopback(boolean b) {this.loopback=b; return this;} + public ReliableUnicast timeService(TimeService ts) {this.time_service=ts; return this;} // testing only! + public ReliableUnicast lastSync(ExpiryCache
c) {this.last_sync_sent=c; return this;} // testing only! @ManagedOperation @@ -438,30 +440,6 @@ public void stop() { msg_cache.clear(); } - public Object up(Message msg) { - Address dest=msg.dest(), sender=msg.src(); - if(dest == null || dest.isMulticast() || msg.isFlagSet(NO_RELIABILITY)) // only handle unicast messages - return up_prot.up(msg); // pass up - - UnicastHeader hdr=msg.getHeader(this.id); - if(hdr == null) - return up_prot.up(msg); - switch(hdr.type) { - case DATA: // received regular message - if(is_trace) - log.trace("%s <-- %s: DATA(#%d, conn_id=%d%s)", local_addr, sender, hdr.seqno, hdr.conn_id, hdr.first? ", first" : ""); - if(Objects.equals(local_addr, sender)) - handleDataReceivedFromSelf(sender, hdr.seqno, msg); - else - handleDataReceived(sender, hdr.seqno, hdr.conn_id, hdr.first, msg); - break; // we pass the deliverable message up in handleDataReceived() - default: - handleUpEvent(sender, msg, hdr); - break; - } - return null; - } - protected void handleUpEvent(Address sender, Message msg, UnicastHeader hdr) { try { switch(hdr.type) { @@ -494,6 +472,30 @@ protected void handleUpEvent(Address sender, Message msg, UnicastHeader hdr) { } } + public Object up(Message msg) { + Address dest=msg.dest(), sender=msg.src(); + if(dest == null || dest.isMulticast() || msg.isFlagSet(NO_RELIABILITY)) // only handle unicast messages + return up_prot.up(msg); // pass up + + UnicastHeader hdr=msg.getHeader(this.id); + if(hdr == null) + return up_prot.up(msg); + switch(hdr.type) { + case DATA: // received regular message + if(is_trace) + log.trace("%s <-- %s: DATA(#%d, conn_id=%d%s)", local_addr, sender, hdr.seqno, hdr.conn_id, hdr.first? ", first" : ""); + if(Objects.equals(local_addr, sender)) + handleDataReceivedFromSelf(sender, hdr.seqno, msg); + else + handleDataReceived(sender, hdr.seqno, hdr.conn_id, hdr.first, msg); + break; // we pass the deliverable message up in handleDataReceived() + default: + handleUpEvent(sender, msg, hdr); + break; + } + return null; + } + public void up(MessageBatch batch) { if(batch.dest() == null || batch.dest().isMulticast()) { // not a unicast batch up_prot.up(batch); @@ -553,6 +555,10 @@ else if(entry == null) { if(queued_msgs != null) addQueuedMessages(sender, entry, queued_msgs); } + + // the code below removes messages that have a HIGHER conn_id; instead we should replace the + // ReceiverEntry with one with the higher conn_id! + if(msgs.keySet().retainAll(Collections.singletonList(entry.connId()))) // remove all conn-ids that don't match sendRequestForFirstSeqno(sender, batch.dest()); List> list=msgs.get(entry.connId()); @@ -611,8 +617,6 @@ protected void handleBatchFromSelf(MessageBatch batch, Entry entry) { up_prot.up(batch); } - - public Object down(Event evt) { switch (evt.getType()) { @@ -751,7 +755,6 @@ public void removeReceiveConnection(Address mbr) { entry.state(State.CLOSED); } - /** * This method is public only so it can be invoked by unit testing, but should not otherwise be used ! */ @@ -761,7 +764,6 @@ public void removeAllConnections() { recv_table.clear(); } - /** Sends a retransmit request to the given sender */ protected void retransmit(SeqnoList missing, Address sender, Address real_dest) { Message xmit_msg=new ObjectMessage(sender, missing).setFlag(OOB, NO_FC) @@ -774,7 +776,6 @@ protected void retransmit(SeqnoList missing, Address sender, Address real_dest) xmit_reqs_sent.add(missing.size()); } - /** Called by the sender to resend messages for which no ACK has been received yet */ protected void retransmit(Message msg) { if(is_trace) { @@ -926,7 +927,6 @@ protected void removeAndDeliver(Entry entry, Address sender, AsciiString cluster while(mb != null || adders.decrementAndGet() != 0); } - protected String printMessageList(List> list) { StringBuilder sb=new StringBuilder(); int size=list.size(); @@ -945,34 +945,30 @@ protected String printMessageList(List> list) { return sb.toString(); } - protected ReceiverEntry getReceiverEntry(Address sender, long seqno, boolean first, short conn_id, Address dest) { + protected ReceiverEntry getReceiverEntry(Address sender, long seqno, boolean first, short conn_id, Address real_dest) { ReceiverEntry entry=recv_table.get(sender); if(entry != null && entry.connId() == conn_id) return entry; + return _getReceiverEntry(sender, seqno, first, conn_id, real_dest); + } + // public for unit testing - don't use in app code! + public ReceiverEntry _getReceiverEntry(Address sender, long seqno, boolean first, short conn_id, Address real_dest) { + ReceiverEntry entry; recv_table_lock.lock(); try { entry=recv_table.get(sender); - if(first) { - if(entry == null) { - entry=createReceiverEntry(sender, seqno, conn_id, dest); - } - else { // entry != null && win != null - if(conn_id != entry.connId()) { - log.trace("%s: conn_id=%d != %d; resetting receiver window", local_addr, conn_id, entry.connId()); - recv_table.remove(sender); - entry=createReceiverEntry(sender, seqno, conn_id, dest); - } - } - } - else { // entry == null && win == null OR entry != null && win == null OR entry != null && win != null - if(entry == null || entry.connId() != conn_id) { + if(entry == null) { + if(first) + return createReceiverEntry(sender, seqno, conn_id, real_dest); + else { recv_table_lock.unlock(); - sendRequestForFirstSeqno(sender, dest); // drops the message and returns (see below) + sendRequestForFirstSeqno(sender, real_dest); // drops the message and returns (see below) return null; } } - return entry; + // entry != null + return compareConnIds(conn_id, entry.connId(), first, entry, sender, seqno, real_dest); } finally { if(recv_table_lock.isHeldByCurrentThread()) @@ -980,6 +976,26 @@ protected ReceiverEntry getReceiverEntry(Address sender, long seqno, boolean fir } } + protected ReceiverEntry compareConnIds(short other, short mine, boolean first, ReceiverEntry e, + Address sender, long seqno, Address real_dest) { + if(other == mine) + return e; + if(other < mine) + return null; + // other_conn_id > my_conn_id + if(first) { + log.trace("%s: other conn_id (%d) > mine (%d); creating new receiver window", local_addr, other, mine); + recv_table.remove(sender); + return createReceiverEntry(sender, seqno, other, real_dest); + } + else { + log.trace("%s: other conn_id (%d) > mine (%d) (!first); asking for first message", local_addr, other, mine); + recv_table_lock.unlock(); + sendRequestForFirstSeqno(sender, real_dest); // drops the message and returns (see below) + return null; + } + } + protected SenderEntry getSenderEntry(Address dst) { SenderEntry entry=send_table.get(dst); if(entry == null || entry.state() == State.CLOSED) { @@ -1459,8 +1475,8 @@ protected Entry(short conn_id, Buffer buf) { update(); } - protected Buffer buf() {return buf;} - protected short connId() {return conn_id;} + public Buffer buf() {return buf;} + public short connId() {return conn_id;} protected void update() {timestamp.set(getTimestamp());} protected State state() {return state;} protected Entry state(State s) {if(this.state != s) {this.state=s; update();} return this;} @@ -1513,7 +1529,8 @@ public String toString() { } } - protected final class ReceiverEntry extends Entry { + // public for unit testing + public final class ReceiverEntry extends Entry { private final Address real_dest ; // if real_dest != local_addr (https://issues.redhat.com/browse/JGRP-2729) public ReceiverEntry(Buffer received_msgs, short recv_conn_id, Address real_dest) { diff --git a/tests/junit-functional/org/jgroups/protocols/UNICAST_OOB_Test.java b/tests/junit-functional/org/jgroups/protocols/UNICAST_OOB_Test.java index 9ce8a4c4cb..8403045973 100644 --- a/tests/junit-functional/org/jgroups/protocols/UNICAST_OOB_Test.java +++ b/tests/junit-functional/org/jgroups/protocols/UNICAST_OOB_Test.java @@ -39,6 +39,7 @@ protected void setup(Class unicast_class) throws Exception { a.connect("UNICAST_OOB_Test"); b.connect("UNICAST_OOB_Test"); Util.waitUntilAllChannelsHaveSameView(3000, 100, a,b); + System.out.printf("-- cluster formed: %s\n", b.view()); } @AfterMethod @@ -75,6 +76,7 @@ public void testSecondMessageReceivedFirstOOB(Class unicast_ _testSecondMessageReceivedFirst(true, false); } + // @Test(invocationCount=100) public void testSecondMessageReceivedFirstOOBBatched(Class unicast_class) throws Exception { setup(unicast_class); _testSecondMessageReceivedFirst(true, true); @@ -84,14 +86,22 @@ protected void _testSecondMessageReceivedFirst(boolean oob, boolean use_batches) Address dest=a.getAddress(), src=b.getAddress(); Protocol u_a=a.getProtocolStack().findProtocol(Util.getUnicastProtocols()), u_b=b.getProtocolStack().findProtocol(Util.getUnicastProtocols()); - Util.invoke(u_a, "removeReceiveConnection", src); - Util.invoke(u_a, "removeSendConnection", src); - Util.invoke(u_b, "removeReceiveConnection", dest); - Util.invoke(u_b, "removeSendConnection", dest); - System.out.println("=============== removed connection between A and B ==========="); - - REVERSE reverse=new REVERSE().numMessagesToReverse(5) - .filter(msg -> msg.getDest() != null && src.equals(msg.getSrc()) && (msg.getFlags(false) == 0 || msg.isFlagSet(Message.Flag.OOB))); + for(int i=0; i < 10; i++) { + Util.invoke(u_a, "removeReceiveConnection", src); + Util.invoke(u_a, "removeSendConnection", src); + Util.invoke(u_b, "removeReceiveConnection", dest); + Util.invoke(u_b, "removeSendConnection", dest); + int num_connections=(int)Util.invoke(u_a, "getNumConnections") + + (int)Util.invoke(u_b, "getNumConnections"); + if(num_connections == 0) + break; + Util.sleep(100); + } + System.out.println("=============== removed connections between A and B ==========="); + + Protocol reverse=new REVERSE().numMessagesToReverse(5) + .filter(msg -> msg.getDest() != null && src.equals(msg.src()) && (msg.getFlags(false) == 0 || msg.isFlagSet(Message.Flag.OOB))); + // REVERSE2 reverse=new REVERSE2().filter(m -> m.dest() != null && m.isFlagSet(Message.Flag.OOB) && src.equals(m.src())); a.getProtocolStack().insertProtocol(reverse, ProtocolStack.Position.BELOW, UNICAST3.class,UNICAST4.class); if(use_batches) { @@ -100,17 +110,28 @@ protected void _testSecondMessageReceivedFirst(boolean oob, boolean use_batches) mb.start(); } - MyReceiver r=new MyReceiver<>(); + MyReceiver r=new MyReceiver().name(a.getName()).verbose(true); a.setReceiver(r); - System.out.println("========== B sends messages 1-5 to A =========="); + System.out.printf("========== B sending %s messages 1-5 to A ==========\n", oob? "OOB" : "regular"); + //u_a.setLevel("trace"); u_b.setLevel("trace"); + long start=System.currentTimeMillis(); for(int i=1; i <= 5; i++) { Message msg=new ObjectMessage(dest, (long)i); if(oob) msg.setFlag(Message.Flag.OOB); b.send(msg); + System.out.printf("-- %s: sent %s, hdrs: %s\n", b.address(), msg, msg.printHeaders()); } - Util.waitUntil(10000, 100, () -> r.size() == 5); + if(reverse instanceof REVERSE2) { + REVERSE2 rr=((REVERSE2)reverse); + Util.waitUntilTrue(2000, 100, () -> rr.size() == 5); + rr.filter(null); // from now on, all msgs are passed up + rr.deliver(); + } + + Util.waitUntil(5000, 100, () -> r.size() == 5, + () -> String.format("expected 5 messages but got %s", r.list())); long time=System.currentTimeMillis() - start; System.out.printf("===== list: %s (in %d ms)\n", r.list(), time); long expected_time=XMIT_INTERVAL * 10; // increased because times might increase with the increase in parallel tests @@ -167,11 +188,11 @@ protected static JChannel createChannel(String name, Class u Protocol p=unicast_class.getConstructor().newInstance(); Util.invoke(p, "setXmitInterval", XMIT_INTERVAL); return new JChannel( - new SHARED_LOOPBACK(), + new SHARED_LOOPBACK(), // .bundler("nb"), new SHARED_LOOPBACK_PING(), new NAKACK2(), p, - new GMS()) + new GMS().printLocalAddress(false)) .name(name); } diff --git a/tests/junit-functional/org/jgroups/tests/ReliableUnicastTest.java b/tests/junit-functional/org/jgroups/tests/ReliableUnicastTest.java new file mode 100644 index 0000000000..9f5bc5e766 --- /dev/null +++ b/tests/junit-functional/org/jgroups/tests/ReliableUnicastTest.java @@ -0,0 +1,113 @@ +package org.jgroups.tests; + +import org.jgroups.Address; +import org.jgroups.Global; +import org.jgroups.Message; +import org.jgroups.protocols.ReliableUnicast; +import org.jgroups.protocols.UNICAST4; +import org.jgroups.protocols.UnicastHeader; +import org.jgroups.stack.Protocol; +import org.jgroups.util.*; +import org.testng.annotations.*; + +import java.util.concurrent.atomic.LongAdder; + +/** + * @author Bela Ban + * @since 5.4 + */ +@Test(groups=Global.FUNCTIONAL,singleThreaded=true,dataProvider="createUnicast") +public class ReliableUnicastTest { + protected ReliableUnicast unicast; + protected DownProtocol down_prot; + protected TimeScheduler timer; + protected static final Address DEST=Util.createRandomAddress("A"); + + @DataProvider + static Object[][] createUnicast() { + return new Object[][]{ + {UNICAST4.class} + }; + } + + @BeforeClass + protected void setupTimer() { + timer=new TimeScheduler3(); + } + + @AfterClass + protected void stopTimer() { + timer.stop(); + } + + protected void setup(Class unicast_cl) throws Exception { + unicast=unicast_cl.getConstructor().newInstance(); + down_prot=new DownProtocol(); + unicast.setDownProtocol(down_prot); + down_prot.setUpProtocol(unicast); + TimeService time_service=new TimeService(timer); + unicast.timeService(time_service); + unicast.lastSync(new ExpiryCache<>(5000)); + } + + @AfterMethod + protected void destroy() { + unicast.stop(); + } + + public void testGetReceiverEntryFirst(Class unicast_class) throws Exception { + setup(unicast_class); + ReliableUnicast.ReceiverEntry entry=unicast._getReceiverEntry(DEST, 1L, true, (short)0, null); + assert entry != null && entry.connId() == 0; + entry=unicast._getReceiverEntry(DEST, 1L, true, (short)0, null); + assert entry != null && entry.connId() == 0; + assert unicast.getNumReceiveConnections() == 1; + } + + public void testGetReceiverEntryNotFirst(Class unicast_class) throws Exception { + setup(unicast_class); + ReliableUnicast.ReceiverEntry entry=unicast._getReceiverEntry(DEST, 2L, false, (short)0, null); + assert entry == null; + assert down_prot.numSendFirstReqs() == 1; + } + + public void testGetReceiverEntryExists(Class unicast_class) throws Exception { + setup(unicast_class); + ReliableUnicast.ReceiverEntry entry=unicast._getReceiverEntry(DEST, 1L, true, (short)1, null); + ReliableUnicast.ReceiverEntry old=entry; + assert entry != null && entry.connId() == 1; + + // entry exists, but this conn-ID is smaller + entry=unicast._getReceiverEntry(DEST, 1L, true, (short)0, null); + assert entry == null; + + // entry exists and conn-IDs match + ReliableUnicast.ReceiverEntry e=unicast._getReceiverEntry(DEST, 2L, true, (short)1, null); + assert e != null && e == old; + + // entry exists, but is replaced by higher conn_id + entry=unicast._getReceiverEntry(DEST, 5L, true, (short)2, null); + assert entry.connId() == 2; + assert entry.buf().high() == 4; + + entry=unicast._getReceiverEntry(DEST, 10L, false, (short)3, null); + assert entry == null; + assert down_prot.numSendFirstReqs() == 1; + } + + + protected static class DownProtocol extends Protocol { + protected final LongAdder num_send_first_reqs=new LongAdder(); + + protected long numSendFirstReqs() {return num_send_first_reqs.sum();} + protected DownProtocol clear() {num_send_first_reqs.reset(); return this;} + + @Override + public Object down(Message msg) { + UnicastHeader hdr=msg.getHeader(up_prot.getId()); + if(hdr != null && hdr.type() == UnicastHeader.SEND_FIRST_SEQNO) + num_send_first_reqs.increment(); + return null; + } + } +}