From 94c5d0fda5935239a32690f06b88b12f37249466 Mon Sep 17 00:00:00 2001 From: Wyatt Hepler Date: Fri, 6 Jan 2023 15:47:40 +0000 Subject: [PATCH] pw_rpc: Support opening and closing RPC channels in Java client Fixes: b/250065568 Change-Id: I8a467755d1a9f67dcc9c60d0fda2c9e326b783fb Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/119570 Commit-Queue: Wyatt Hepler Reviewed-by: Alexei Frolov --- .../main/dev/pigweed/pw_rpc/AbstractCall.java | 4 ++ .../java/main/dev/pigweed/pw_rpc/Client.java | 19 +++++++ .../main/dev/pigweed/pw_rpc/Endpoint.java | 17 ++++++ .../pw_rpc/InvalidRpcChannelException.java | 4 ++ .../test/dev/pigweed/pw_rpc/ClientTest.java | 55 +++++++++++++++++++ .../test/dev/pigweed/pw_rpc/EndpointTest.java | 1 - 6 files changed, 99 insertions(+), 1 deletion(-) diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/AbstractCall.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/AbstractCall.java index 9672e147ac..901ebe2938 100644 --- a/pw_rpc/java/main/dev/pigweed/pw_rpc/AbstractCall.java +++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/AbstractCall.java @@ -75,6 +75,10 @@ public final boolean finish() throws ChannelOutputException { return rpcs.clientStreamEnd(this); } + final int getChannelId() { + return rpc.channel().id(); + } + final void sendPacket(byte[] packet) throws ChannelOutputException { rpc.channel().send(packet); } diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/Client.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/Client.java index 38d6c5c842..361ba893da 100644 --- a/pw_rpc/java/main/dev/pigweed/pw_rpc/Client.java +++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/Client.java @@ -93,6 +93,25 @@ public void onError(Status status) { }); } + /** + * Adds a new channel to this RPC client. + * + * @throws InvalidRpcChannelException if the channel's ID is already in use + */ + public void openChannel(Channel channel) { + rpcs.openChannel(channel); + } + + /** + * Closes a channel and aborts and RPCs using it. + * + * @param id the channel ID to close + * @return true if the channel was closed; false if the channel was not found + */ + public boolean closeChannel(int id) { + return rpcs.closeChannel(id); + } + /** * Returns a MethodClient with the given name for the provided channelID * diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java index f731b70ae4..5d3a7a8cad 100644 --- a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java +++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java @@ -142,6 +142,23 @@ private boolean sendPacket(AbstractCall call, byte[] packet) throws Channe return true; } + public synchronized void openChannel(Channel channel) { + if (channels.putIfAbsent(channel.id(), channel) != null) { + throw InvalidRpcChannelException.duplicate(channel.id()); + } + } + + public synchronized boolean closeChannel(int id) { + if (channels.remove(id) == null) { + return false; + } + pending.values() + .stream() + .filter(call -> call.getChannelId() == id) + .forEach(call -> call.handleError(Status.ABORTED)); + return true; + } + public synchronized boolean handleNext(PendingRpc rpc, ByteString payload) { AbstractCall call = pending.get(rpc); if (call == null) { diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/InvalidRpcChannelException.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/InvalidRpcChannelException.java index 47d8af59f3..5916907cc5 100644 --- a/pw_rpc/java/main/dev/pigweed/pw_rpc/InvalidRpcChannelException.java +++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/InvalidRpcChannelException.java @@ -19,6 +19,10 @@ static InvalidRpcChannelException unknown(int channelId) { return new InvalidRpcChannelException("Invalid or closed RPC channel " + channelId); } + static InvalidRpcChannelException duplicate(int channelId) { + return new InvalidRpcChannelException("A channel with ID " + channelId + " already exists!"); + } + private InvalidRpcChannelException(String message) { super(message); } diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java index ac97d858a8..b3e31984b4 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java @@ -362,4 +362,59 @@ public void streamObserverClient_create_invokeMethod() throws Exception { verify(mockChannelOutput) .send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", payload).toByteArray()); } + + @Test + public void closeChannel_abortsExisting() throws Exception { + MethodClient serverStreamMethod = + client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeServerStreaming"); + + Call call1 = serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer); + Call call2 = client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeClientStreaming") + .invokeClientStreaming(observer); + assertThat(call1.active()).isTrue(); + assertThat(call2.active()).isTrue(); + + assertThat(client.closeChannel(CHANNEL_ID)).isTrue(); + + assertThat(call1.active()).isFalse(); + assertThat(call2.active()).isFalse(); + + verify(observer, times(2)).onError(Status.ABORTED); + + assertThrows(InvalidRpcChannelException.class, + () -> serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer)); + } + + @Test + public void closeChannel_noCalls() { + assertThat(client.closeChannel(CHANNEL_ID)).isTrue(); + } + + @Test + public void closeChannel_knownChannel() { + assertThat(client.closeChannel(CHANNEL_ID + 100)).isFalse(); + } + + @Test + public void openChannel_uniqueChannel() throws Exception { + int newChannelId = CHANNEL_ID + 100; + Channel.Output channelOutput = Mockito.mock(Channel.Output.class); + client.openChannel(new Channel(newChannelId, channelOutput)); + + client.method(newChannelId, "pw.rpc.test1.TheTestService", "SomeUnary") + .invokeUnary(REQUEST_PAYLOAD, observer); + + verify(channelOutput) + .send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", REQUEST_PAYLOAD) + .toBuilder() + .setChannelId(newChannelId) + .build() + .toByteArray()); + } + + @Test + public void openChannel_alreadyExists_throwsException() { + assertThrows(InvalidRpcChannelException.class, + () -> client.openChannel(new Channel(CHANNEL_ID, packet -> {}))); + } } diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java index 19f46ec17e..adad853b41 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java @@ -30,7 +30,6 @@ import org.junit.Rule; import org.junit.Test; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule;