Skip to content

Commit

Permalink
PayloadInterceptorRSocket retains all payloads
Browse files Browse the repository at this point in the history
Flux#skip discards its corresponding elements, meaning that they
aren't intended for reuse. When using RSocket's ByteBufPayloads,
this means that the bytes are releaseed back into RSocket's pool.

Since the downstream request may still need the skipped payload,
we should construct the publisher in a different way so as to
avoid the preemptive release.

Deferring Spring JavaFormat to clarify what changed.

Closes gh-9345
  • Loading branch information
jzheaux committed Jun 4, 2021
1 parent 895ae0a commit 63cd52d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -92,13 +92,16 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
return Flux.from(payloads).switchOnFirst((signal, innerFlux) -> {
Payload firstPayload = signal.get();
return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload).flatMapMany((context) -> innerFlux
.skip(1).flatMap((p) -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
.transform((securedPayloads) -> Flux.concat(Flux.just(firstPayload), securedPayloads))
.index().concatMap((tuple) -> justOrIntercept(tuple.getT1(), tuple.getT2()))
.transform((securedPayloads) -> this.source.requestChannel(securedPayloads))
.subscriberContext(context));
});
}

private Mono<Payload> justOrIntercept(Long index, Payload payload) {
return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload);
}

@Override
public Mono<Void> metadataPush(Payload payload) {
return intercept(PayloadExchangeType.METADATA_PUSH, payload)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,14 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.RSocketProxy;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -32,13 +36,17 @@
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import reactor.test.publisher.TestPublisher;
import reactor.util.context.Context;

import org.springframework.http.MediaType;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
Expand All @@ -56,6 +64,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;

Expand Down Expand Up @@ -265,6 +274,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
verify(this.delegate).requestChannel(any());
}

// gh-9345
@Test
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
ExecutorService executors = Executors.newSingleThreadExecutor();
Payload payload = ByteBufPayload.create("data");
Payload payloadTwo = ByteBufPayload.create("moredata");
Payload payloadThree = ByteBufPayload.create("stillmoredata");
Context ctx = Context.empty();
Flux<Payload> payloads = this.payloadResult.flux();
given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty())
.willReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
given(this.delegate.requestChannel(any())).willAnswer((invocation) -> {
Flux<Payload> input = invocation.getArgument(0);
return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
.transform((data) -> Flux.<String>create((emitter) -> {
Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
@Override
public void onSubscribe(Subscription s) {
s.request(3);
}

@Override
public void onNext(String s) {
emitter.next(s);
}

@Override
public void onError(Throwable t) {
emitter.error(t);
}

@Override
public void onComplete() {
emitter.complete();
}
});
executors.execute(run);
})).map(DefaultPayload::create));
});
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
.then(() -> this.payloadResult.assertSubscribers())
.then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
.assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
.verifyError(AccessDeniedException.class);
verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
verify(this.delegate).requestChannel(any());
}

@Test
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
RuntimeException expected = new RuntimeException("Oops");
Expand Down

0 comments on commit 63cd52d

Please sign in to comment.