Skip to content

Commit

Permalink
Add streaming support for Wire request bodies
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeWharton committed Dec 5, 2024
1 parent 9d8286f commit ce4b98f
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 10 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ robovm = { module = "com.mobidevelop.robovm:robovm-rt", version.ref = "robovm" }
googleJavaFormat = "com.google.googlejavaformat:google-java-format:1.25.0"
ktlint = "com.pinterest.ktlint:ktlint-cli:1.5.0"
compileTesting = "com.google.testing.compile:compile-testing:0.21.0"
testParameterInjector = "com.google.testparameterinjector:test-parameter-injector:1.18"
1 change: 1 addition & 0 deletions retrofit-converters/wire/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies {
testImplementation libs.junit
testImplementation libs.truth
testImplementation libs.okhttp.mockwebserver
testImplementation libs.testParameterInjector
}

jar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import javax.annotation.Nullable;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.Converter;
import retrofit2.Retrofit;

Expand All @@ -31,11 +32,30 @@
* <p>This converter only applies for types which extend from {@link Message}.
*/
public final class WireConverterFactory extends Converter.Factory {
/**
* Create an instance which serializes request messages to bytes eagerly on the caller thread
* when either {@link Call#execute()} or {@link Call#enqueue} is called. Response bytes are
* always converted to message instances on one of OKHttp's background threads.
*/
public static WireConverterFactory create() {
return new WireConverterFactory();
return new WireConverterFactory(false);
}

private WireConverterFactory() {}
/**
* Create an instance which streams serialization of request messages to bytes on the HTTP thread
* This is either the calling thread for {@link Call#execute()}, or one of OKHttp's background
* threads for {@link Call#enqueue}. Response bytes are always converted to message instances on
* one of OKHttp's background threads.
*/
public static WireConverterFactory createStreaming() {
return new WireConverterFactory(true);
}

private final boolean streaming;

private WireConverterFactory(boolean streaming) {
this.streaming = streaming;
}

@Override
public @Nullable Converter<ResponseBody, ?> responseBodyConverter(
Expand Down Expand Up @@ -67,6 +87,6 @@ private WireConverterFactory() {}
}
//noinspection unchecked
ProtoAdapter<? extends Message> adapter = ProtoAdapter.get((Class<? extends Message>) c);
return new WireRequestBodyConverter<>(adapter);
return new WireRequestBodyConverter<>(adapter, streaming);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@
import retrofit2.Converter;

final class WireRequestBodyConverter<T extends Message<T, ?>> implements Converter<T, RequestBody> {
private static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf");
static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf");

private final ProtoAdapter<T> adapter;
private final boolean streaming;

WireRequestBodyConverter(ProtoAdapter<T> adapter) {
WireRequestBodyConverter(ProtoAdapter<T> adapter, boolean streaming) {
this.adapter = adapter;
this.streaming = streaming;
}

@Override
public RequestBody convert(T value) throws IOException {
if (streaming) {
return new WireStreamingRequestBody<>(adapter, value);
}

Buffer buffer = new Buffer();
adapter.encode(buffer, value);
return RequestBody.create(MEDIA_TYPE, buffer.snapshot());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (C) 2015 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package retrofit2.converter.wire;

import static retrofit2.converter.wire.WireRequestBodyConverter.MEDIA_TYPE;

import com.squareup.wire.Message;
import com.squareup.wire.ProtoAdapter;
import java.io.IOException;
import okhttp3.MediaType;
import okhttp3.RequestBody;
import okio.BufferedSink;

final class WireStreamingRequestBody<T extends Message<T, ?>> extends RequestBody {
private final ProtoAdapter<T> adapter;
private final T value;

WireStreamingRequestBody(ProtoAdapter<T> adapter, T value) {
this.adapter = adapter;
this.value = value;
}

@Override
public MediaType contentType() {
return MEDIA_TYPE;
}

@Override
public void writeTo(BufferedSink sink) throws IOException {
adapter.encode(sink, value);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Code generated by Wire protocol buffer compiler, do not edit.
// Source file: phone.proto at 6:1
package retrofit2.converter.wire;

import com.squareup.wire.FieldEncoding;
import com.squareup.wire.Message;
import com.squareup.wire.ProtoAdapter;
import com.squareup.wire.ProtoReader;
import com.squareup.wire.ProtoWriter;
import com.squareup.wire.WireField;
import com.squareup.wire.internal.Internal;
import java.io.EOFException;
import java.io.IOException;
import okio.ByteString;

public final class CrashingPhone extends Message<CrashingPhone, CrashingPhone.Builder> {
public static final ProtoAdapter<CrashingPhone> ADAPTER = new ProtoAdapter_CrashingPhone();

private static final long serialVersionUID = 0L;

public static final String DEFAULT_NUMBER = "";

@WireField(tag = 1, adapter = "com.squareup.wire.ProtoAdapter#STRING")
public final String number;

public CrashingPhone(String number) {
this(number, ByteString.EMPTY);
}

public CrashingPhone(String number, ByteString unknownFields) {
super(ADAPTER, unknownFields);
this.number = number;
}

@Override
public Builder newBuilder() {
Builder builder = new Builder();
builder.number = number;
builder.addUnknownFields(unknownFields());
return builder;
}

@Override
public boolean equals(Object other) {
if (other == this) return true;
if (!(other instanceof CrashingPhone)) return false;
CrashingPhone o = (CrashingPhone) other;
return Internal.equals(unknownFields(), o.unknownFields()) && Internal.equals(number, o.number);
}

@Override
public int hashCode() {
int result = super.hashCode;
if (result == 0) {
result = unknownFields().hashCode();
result = result * 37 + (number != null ? number.hashCode() : 0);
super.hashCode = result;
}
return result;
}

@Override
public String toString() {
StringBuilder builder = new StringBuilder();
if (number != null) builder.append(", number=").append(number);
return builder.replace(0, 2, "Phone{").append('}').toString();
}

public static final class Builder extends Message.Builder<CrashingPhone, Builder> {
public String number;

public Builder() {}

public Builder number(String number) {
this.number = number;
return this;
}

@Override
public CrashingPhone build() {
return new CrashingPhone(number, buildUnknownFields());
}
}

private static final class ProtoAdapter_CrashingPhone extends ProtoAdapter<CrashingPhone> {
ProtoAdapter_CrashingPhone() {
super(FieldEncoding.LENGTH_DELIMITED, CrashingPhone.class);
}

@Override
public int encodedSize(CrashingPhone value) {
return (value.number != null ? ProtoAdapter.STRING.encodedSizeWithTag(1, value.number) : 0)
+ value.unknownFields().size();
}

@Override
public void encode(ProtoWriter writer, CrashingPhone value) throws IOException {
throw new EOFException("oops!");
}

@Override
public CrashingPhone decode(ProtoReader reader) throws IOException {
Builder builder = new Builder();
long token = reader.beginMessage();
for (int tag; (tag = reader.nextTag()) != -1; ) {
switch (tag) {
case 1:
builder.number(ProtoAdapter.STRING.decode(reader));
break;
default:
{
FieldEncoding fieldEncoding = reader.peekFieldEncoding();
Object value = fieldEncoding.rawProtoAdapter().decode(reader);
builder.addUnknownField(tag, fieldEncoding, value);
}
}
}
reader.endMessage(token);
return builder.build();
}

@Override
public CrashingPhone redact(CrashingPhone value) {
Builder builder = value.newBuilder();
builder.clearUnknownFields();
return builder.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,32 @@

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import java.io.EOFException;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import okio.Buffer;
import okio.ByteString;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;

@RunWith(TestParameterInjector.class)
public final class WireConverterFactoryTest {
interface Service {
@GET("/")
Expand All @@ -44,6 +51,9 @@ interface Service {
@POST("/")
Call<Phone> post(@Body Phone impl);

@POST("/")
Call<Void> postCrashing(@Body CrashingPhone impl);

@GET("/")
Call<String> wrongClass();

Expand All @@ -53,14 +63,17 @@ interface Service {

@Rule public final MockWebServer server = new MockWebServer();

private Service service;
private final Service service;
private final boolean streaming;

public WireConverterFactoryTest(@TestParameter boolean streaming) {
this.streaming = streaming;

@Before
public void setUp() {
Retrofit retrofit =
new Retrofit.Builder()
.baseUrl(server.url("/"))
.addConverterFactory(WireConverterFactory.create())
.addConverterFactory(
streaming ? WireConverterFactory.createStreaming() : WireConverterFactory.create())
.build();
service = retrofit.create(Service.class);
}
Expand All @@ -80,6 +93,36 @@ public void serializeAndDeserialize() throws IOException, InterruptedException {
assertThat(request.getHeader("Content-Type")).isEqualTo("application/x-protobuf");
}

@Test
public void serializeIsStreamed() throws IOException, InterruptedException {
assumeTrue(streaming);

Call<Void> call = service.postCrashing(new CrashingPhone("(519) 867-5309"));

final AtomicReference<Throwable> throwableRef = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);

// If streaming were broken, the call to enqueue would throw the exception synchronously.
call.enqueue(
new Callback<Void>() {
@Override
public void onResponse(Call<Void> call, Response<Void> response) {
latch.countDown();
}

@Override
public void onFailure(Call<Void> call, Throwable t) {
throwableRef.set(t);
latch.countDown();
}
});
latch.await();

Throwable throwable = throwableRef.get();
assertThat(throwable).isInstanceOf(EOFException.class);
assertThat(throwable).hasMessageThat().isEqualTo("oops!");
}

@Test
public void deserializeEmpty() throws IOException {
server.enqueue(new MockResponse());
Expand Down

0 comments on commit ce4b98f

Please sign in to comment.