diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 94841d99da..16998a3ab0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -83,3 +83,4 @@ robovm = { module = "com.mobidevelop.robovm:robovm-rt", version.ref = "robovm" } googleJavaFormat = "com.google.googlejavaformat:google-java-format:1.25.0" ktlint = "com.pinterest.ktlint:ktlint-cli:1.5.0" compileTesting = "com.google.testing.compile:compile-testing:0.21.0" +testParameterInjector = "com.google.testparameterinjector:test-parameter-injector:1.18" diff --git a/retrofit-converters/wire/build.gradle b/retrofit-converters/wire/build.gradle index 6f7c414a9c..06aef0cc81 100644 --- a/retrofit-converters/wire/build.gradle +++ b/retrofit-converters/wire/build.gradle @@ -12,6 +12,7 @@ dependencies { testImplementation libs.junit testImplementation libs.truth testImplementation libs.okhttp.mockwebserver + testImplementation libs.testParameterInjector } jar { diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java index f3472b5744..aaa4c26cf6 100644 --- a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireConverterFactory.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import okhttp3.RequestBody; import okhttp3.ResponseBody; +import retrofit2.Call; import retrofit2.Converter; import retrofit2.Retrofit; @@ -31,11 +32,30 @@ *

This converter only applies for types which extend from {@link Message}. */ public final class WireConverterFactory extends Converter.Factory { + /** + * Create an instance which serializes request messages to bytes eagerly on the caller thread + * when either {@link Call#execute()} or {@link Call#enqueue} is called. Response bytes are + * always converted to message instances on one of OKHttp's background threads. + */ public static WireConverterFactory create() { - return new WireConverterFactory(); + return new WireConverterFactory(false); } - private WireConverterFactory() {} + /** + * Create an instance which streams serialization of request messages to bytes on the HTTP thread + * This is either the calling thread for {@link Call#execute()}, or one of OKHttp's background + * threads for {@link Call#enqueue}. Response bytes are always converted to message instances on + * one of OKHttp's background threads. + */ + public static WireConverterFactory createStreaming() { + return new WireConverterFactory(true); + } + + private final boolean streaming; + + private WireConverterFactory(boolean streaming) { + this.streaming = streaming; + } @Override public @Nullable Converter responseBodyConverter( @@ -67,6 +87,6 @@ private WireConverterFactory() {} } //noinspection unchecked ProtoAdapter adapter = ProtoAdapter.get((Class) c); - return new WireRequestBodyConverter<>(adapter); + return new WireRequestBodyConverter<>(adapter, streaming); } } diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java index 1d6be146f6..c1d96d3a02 100644 --- a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireRequestBodyConverter.java @@ -24,16 +24,22 @@ import retrofit2.Converter; final class WireRequestBodyConverter> implements Converter { - private static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf"); + static final MediaType MEDIA_TYPE = MediaType.get("application/x-protobuf"); private final ProtoAdapter adapter; + private final boolean streaming; - WireRequestBodyConverter(ProtoAdapter adapter) { + WireRequestBodyConverter(ProtoAdapter adapter, boolean streaming) { this.adapter = adapter; + this.streaming = streaming; } @Override public RequestBody convert(T value) throws IOException { + if (streaming) { + return new WireStreamingRequestBody<>(adapter, value); + } + Buffer buffer = new Buffer(); adapter.encode(buffer, value); return RequestBody.create(MEDIA_TYPE, buffer.snapshot()); diff --git a/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java new file mode 100644 index 0000000000..f441ff4a6d --- /dev/null +++ b/retrofit-converters/wire/src/main/java/retrofit2/converter/wire/WireStreamingRequestBody.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2015 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.converter.wire; + +import static retrofit2.converter.wire.WireRequestBodyConverter.MEDIA_TYPE; + +import com.squareup.wire.Message; +import com.squareup.wire.ProtoAdapter; +import java.io.IOException; +import okhttp3.MediaType; +import okhttp3.RequestBody; +import okio.BufferedSink; + +final class WireStreamingRequestBody> extends RequestBody { + private final ProtoAdapter adapter; + private final T value; + + WireStreamingRequestBody(ProtoAdapter adapter, T value) { + this.adapter = adapter; + this.value = value; + } + + @Override + public MediaType contentType() { + return MEDIA_TYPE; + } + + @Override + public void writeTo(BufferedSink sink) throws IOException { + adapter.encode(sink, value); + } +} diff --git a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java new file mode 100644 index 0000000000..f020475e37 --- /dev/null +++ b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/CrashingPhone.java @@ -0,0 +1,129 @@ +// Code generated by Wire protocol buffer compiler, do not edit. +// Source file: phone.proto at 6:1 +package retrofit2.converter.wire; + +import com.squareup.wire.FieldEncoding; +import com.squareup.wire.Message; +import com.squareup.wire.ProtoAdapter; +import com.squareup.wire.ProtoReader; +import com.squareup.wire.ProtoWriter; +import com.squareup.wire.WireField; +import com.squareup.wire.internal.Internal; +import java.io.EOFException; +import java.io.IOException; +import okio.ByteString; + +public final class CrashingPhone extends Message { + public static final ProtoAdapter ADAPTER = new ProtoAdapter_CrashingPhone(); + + private static final long serialVersionUID = 0L; + + public static final String DEFAULT_NUMBER = ""; + + @WireField(tag = 1, adapter = "com.squareup.wire.ProtoAdapter#STRING") + public final String number; + + public CrashingPhone(String number) { + this(number, ByteString.EMPTY); + } + + public CrashingPhone(String number, ByteString unknownFields) { + super(ADAPTER, unknownFields); + this.number = number; + } + + @Override + public Builder newBuilder() { + Builder builder = new Builder(); + builder.number = number; + builder.addUnknownFields(unknownFields()); + return builder; + } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if (!(other instanceof CrashingPhone)) return false; + CrashingPhone o = (CrashingPhone) other; + return Internal.equals(unknownFields(), o.unknownFields()) && Internal.equals(number, o.number); + } + + @Override + public int hashCode() { + int result = super.hashCode; + if (result == 0) { + result = unknownFields().hashCode(); + result = result * 37 + (number != null ? number.hashCode() : 0); + super.hashCode = result; + } + return result; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + if (number != null) builder.append(", number=").append(number); + return builder.replace(0, 2, "Phone{").append('}').toString(); + } + + public static final class Builder extends Message.Builder { + public String number; + + public Builder() {} + + public Builder number(String number) { + this.number = number; + return this; + } + + @Override + public CrashingPhone build() { + return new CrashingPhone(number, buildUnknownFields()); + } + } + + private static final class ProtoAdapter_CrashingPhone extends ProtoAdapter { + ProtoAdapter_CrashingPhone() { + super(FieldEncoding.LENGTH_DELIMITED, CrashingPhone.class); + } + + @Override + public int encodedSize(CrashingPhone value) { + return (value.number != null ? ProtoAdapter.STRING.encodedSizeWithTag(1, value.number) : 0) + + value.unknownFields().size(); + } + + @Override + public void encode(ProtoWriter writer, CrashingPhone value) throws IOException { + throw new EOFException("oops!"); + } + + @Override + public CrashingPhone decode(ProtoReader reader) throws IOException { + Builder builder = new Builder(); + long token = reader.beginMessage(); + for (int tag; (tag = reader.nextTag()) != -1; ) { + switch (tag) { + case 1: + builder.number(ProtoAdapter.STRING.decode(reader)); + break; + default: + { + FieldEncoding fieldEncoding = reader.peekFieldEncoding(); + Object value = fieldEncoding.rawProtoAdapter().decode(reader); + builder.addUnknownField(tag, fieldEncoding, value); + } + } + } + reader.endMessage(token); + return builder.build(); + } + + @Override + public CrashingPhone redact(CrashingPhone value) { + Builder builder = value.newBuilder(); + builder.clearUnknownFields(); + return builder.build(); + } + } +} diff --git a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java index ff6c73748f..443a9dee91 100644 --- a/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java +++ b/retrofit-converters/wire/src/test/java/retrofit2/converter/wire/WireConverterFactoryTest.java @@ -17,25 +17,32 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import java.io.EOFException; import java.io.IOException; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import okio.Buffer; import okio.ByteString; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.runner.RunWith; import retrofit2.Call; +import retrofit2.Callback; import retrofit2.Response; import retrofit2.Retrofit; import retrofit2.http.Body; import retrofit2.http.GET; import retrofit2.http.POST; +@RunWith(TestParameterInjector.class) public final class WireConverterFactoryTest { interface Service { @GET("/") @@ -44,6 +51,9 @@ interface Service { @POST("/") Call post(@Body Phone impl); + @POST("/") + Call postCrashing(@Body CrashingPhone impl); + @GET("/") Call wrongClass(); @@ -53,14 +63,17 @@ interface Service { @Rule public final MockWebServer server = new MockWebServer(); - private Service service; + private final Service service; + private final boolean streaming; + + public WireConverterFactoryTest(@TestParameter boolean streaming) { + this.streaming = streaming; - @Before - public void setUp() { Retrofit retrofit = new Retrofit.Builder() .baseUrl(server.url("/")) - .addConverterFactory(WireConverterFactory.create()) + .addConverterFactory( + streaming ? WireConverterFactory.createStreaming() : WireConverterFactory.create()) .build(); service = retrofit.create(Service.class); } @@ -80,6 +93,36 @@ public void serializeAndDeserialize() throws IOException, InterruptedException { assertThat(request.getHeader("Content-Type")).isEqualTo("application/x-protobuf"); } + @Test + public void serializeIsStreamed() throws IOException, InterruptedException { + assumeTrue(streaming); + + Call call = service.postCrashing(new CrashingPhone("(519) 867-5309")); + + final AtomicReference throwableRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + // If streaming were broken, the call to enqueue would throw the exception synchronously. + call.enqueue( + new Callback() { + @Override + public void onResponse(Call call, Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Call call, Throwable t) { + throwableRef.set(t); + latch.countDown(); + } + }); + latch.await(); + + Throwable throwable = throwableRef.get(); + assertThat(throwable).isInstanceOf(EOFException.class); + assertThat(throwable).hasMessageThat().isEqualTo("oops!"); + } + @Test public void deserializeEmpty() throws IOException { server.enqueue(new MockResponse());