diff --git a/retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt b/retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt index 260fd7ab98..958f96761a 100644 --- a/retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt +++ b/retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt @@ -26,15 +26,14 @@ import okhttp3.mockwebserver.MockWebServer import okhttp3.mockwebserver.SocketPolicy.DISCONNECT_AFTER_REQUEST import okhttp3.mockwebserver.SocketPolicy.NO_RESPONSE import org.assertj.core.api.Assertions.assertThat -import org.junit.Assert.assertTrue -import org.junit.Assert.fail -import org.junit.Ignore +import org.junit.Assert.* import org.junit.Rule import org.junit.Test import retrofit2.helpers.ToStringConverterFactory import retrofit2.http.GET import retrofit2.http.HEAD import retrofit2.http.Path +import retrofit2.http.Query import java.io.IOException import java.lang.reflect.ParameterizedType import java.lang.reflect.Type @@ -43,9 +42,23 @@ import kotlin.coroutines.CoroutineContext class KotlinSuspendTest { @get:Rule val server = MockWebServer() - interface Service { + interface SuperService { + @GET("/") suspend fun noBody(@Query("x") arg: Long) + } + + interface Service : SuperService { @GET("/") suspend fun body(): String @GET("/") suspend fun bodyNullable(): String? + @GET("/") suspend fun noBody() + @GET("/") suspend fun noBody(@Query("x") arg: String) + @GET("/") suspend fun noBody(@Query("x") arg: Int) + @GET("/") suspend fun noBody(@Query("x") arg: Array) + @GET("/") suspend fun noBody(@Query("x") arg: Array) + @GET("/") suspend fun noBody(@Query("x") arg: IntArray) + + @UseExperimental(ExperimentalUnsignedTypes::class) + @GET("/") suspend fun noBody(@Query("x") arg: UInt) + @GET("/") suspend fun response(): Response @GET("/") suspend fun unit() @HEAD("/") suspend fun headUnit() @@ -124,7 +137,6 @@ class KotlinSuspendTest { } } - @Ignore("Not working yet") @Test fun bodyNullable() { val retrofit = Retrofit.Builder() .baseUrl(server.url("/")) @@ -138,6 +150,42 @@ class KotlinSuspendTest { assertThat(body).isNull() } + @Test fun noBody() { + val retrofit = Retrofit.Builder() + .baseUrl(server.url("/")) + .addConverterFactory(ToStringConverterFactory()) + .build() + val example = retrofit.create(Service::class.java) + + server.enqueue(MockResponse().setResponseCode(204)) + + val body = runBlocking { example.noBody(intArrayOf(1)) } + assertThat(body).isEqualTo(Unit) + } + + @Test fun signatureMatch() { + val retrofit = Retrofit.Builder() + .baseUrl(server.url("/")) + .addConverterFactory(ToStringConverterFactory()) + .build() + val example = retrofit.create(Service::class.java) + + repeat(8) { + server.enqueue(MockResponse()) + } + + runBlocking { + example.noBody() + example.noBody("") + example.noBody(1) + example.noBody(arrayOf(1)) + example.noBody(intArrayOf(1)) + example.noBody(arrayOf("")) + example.noBody(1u) + example.noBody(1L) + } + } + @Test fun response() { val retrofit = Retrofit.Builder() .baseUrl(server.url("/")) diff --git a/retrofit/src/main/java/retrofit2/HttpServiceMethod.java b/retrofit/src/main/java/retrofit2/HttpServiceMethod.java index 2ca5c7f0ae..a35c15e588 100644 --- a/retrofit/src/main/java/retrofit2/HttpServiceMethod.java +++ b/retrofit/src/main/java/retrofit2/HttpServiceMethod.java @@ -54,10 +54,7 @@ static HttpServiceMethod parseAnnotatio continuationWantsResponse = true; } else { continuationIsUnit = Utils.isUnit(responseType); - // TODO figure out if type is nullable or not - // Metadata metadata = method.getDeclaringClass().getAnnotation(Metadata.class) - // Find the entry for method - // Determine if return type is nullable or not + continuationBodyNullable = KotlinMetadata.isReturnTypeNullable(method); } adapterType = new Utils.ParameterizedTypeImpl(null, Call.class, responseType); diff --git a/retrofit/src/main/java/retrofit2/KotlinMetadata.kt b/retrofit/src/main/java/retrofit2/KotlinMetadata.kt new file mode 100644 index 0000000000..84132f656b --- /dev/null +++ b/retrofit/src/main/java/retrofit2/KotlinMetadata.kt @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2 + +import retrofit2.kotlin.metadata.deserialization.BitEncoding +import retrofit2.kotlin.metadata.deserialization.ByteArrayInput +import retrofit2.kotlin.metadata.deserialization.JvmMetadataVersion +import retrofit2.kotlin.metadata.deserialization.MetadataParser +import retrofit2.kotlin.metadata.deserialization.ProtobufReader +import java.lang.reflect.Method +import java.util.concurrent.ConcurrentHashMap +import kotlin.coroutines.Continuation + +object KotlinMetadata { + + data class Function(val signature: String, val returnType: ReturnType) + data class ReturnType(val isNullable: Boolean, val isUnit: Boolean) + + private val kotlinFunctionsMap = ConcurrentHashMap, List>() + + /** + * This helps to parse kotlin metadata of a compiled class to find out the nullability of a suspending method return + * type. + * + * For example a suspending method with following declaration: + * + * ``` + * @GET("/") suspend fun foo(@Query("x") arg: IntArray): String? + * ``` + * + * Will be compiled as a method returning [Object] and with injected [Continuation] argument with following java + * method: + * + * ``` + * public Object foo(int[], Continuation) + * ``` + * + * The information about the return type and its nullability is stored in a [Metadata] annotation of the containing + * class. We process the metadata of a class the first time [isReturnTypeNullable] is called on one of its methods. + * We extract necessary information about all of its methods and store this info in cache, so each class is + * processed only once. Then we try to match the currently inspected [Method] to one extracted from metadata by + * comparing their signatures. + * + * We use the method signature because: + * - it uniquely identifies a method + * - it requires just comparing 2 strings + * - it is already stored in kotlin metadata + * - it is trivial to create one from java reflection's [Method] instance + * + * For example the previous method's signature would be: + * + * ``` + * foo([ILkotlin/coroutines/Continuation;)Ljava/lang/Object; + * ``` + */ + @JvmStatic fun isReturnTypeNullable(method: Method): Boolean { + if (method.declaringClass.getAnnotation(Metadata::class.java) == null) return false + + val javaMethodSignature = method.createSignature() + val kotlinFunctions = loadKotlinFunctions(method.declaringClass) + val candidates = kotlinFunctions.filter { it.signature == javaMethodSignature } + + require(candidates.isNotEmpty()) { "No match found in metadata for '${method}'" } + require(candidates.size == 1) { "Multiple function matches found in metadata for '${method}'" } + val match = candidates.first() + + return match.returnType.isNullable || match.returnType.isUnit + } + + private fun Method.createSignature() = buildString { + append(name) + append('(') + + parameterTypes.forEach { + append(it.typeToSignature()) + } + + append(')') + + append(returnType.typeToSignature()) + } + + private fun loadKotlinFunctions(clazz: Class<*>): List { + var result = kotlinFunctionsMap[clazz] + if (result != null) return result + + synchronized(kotlinFunctionsMap) { + result = kotlinFunctionsMap[clazz] + if (result == null) { + result = readFunctionsFromMetadata(clazz) + } + } + + return result!! + } + + private fun readFunctionsFromMetadata(clazz: Class<*>): List { + val metadataAnnotation = clazz.getAnnotation(Metadata::class.java) + + val isStrictSemantics = (metadataAnnotation.extraInt and (1 shl 3)) != 0 + val isCompatible = JvmMetadataVersion(metadataAnnotation.metadataVersion, isStrictSemantics).isCompatible() + + require(isCompatible) { "Metadata version not compatible" } + require(metadataAnnotation.kind == 1) { "Metadata of wrong kind: ${metadataAnnotation.kind}" } + require(metadataAnnotation.data1.isNotEmpty()) { "data1 must not be empty" } + + val bytes: ByteArray = BitEncoding.decodeBytes(metadataAnnotation.data1) + val reader = ProtobufReader(ByteArrayInput(bytes)) + val parser = MetadataParser(reader, metadataAnnotation.data2) + + return parser.parse() + } + + private fun Class<*>.typeToSignature() = when { + isPrimitive -> javaTypesMap[name] + isArray -> name.replace('.', '/') + else -> "L${name.replace('.', '/')};" + } + + private val javaTypesMap = mapOf( + "int" to "I", + "long" to "J", + "boolean" to "Z", + "byte" to "B", + "char" to "C", + "float" to "F", + "double" to "D", + "short" to "S", + "void" to "V" + ) +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BinaryVersion.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BinaryVersion.kt new file mode 100644 index 0000000000..005e105bb2 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BinaryVersion.kt @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * This file was adapted from https://github.com/JetBrains/kotlin/blob/af18b10da9d1e20b1b35831a3fb5e508048a2576/core/metadata/src/org/jetbrains/kotlin/metadata/deserialization/BinaryVersion.kt + * by removing unused parts. + */ + +/** + * Subclasses of this class are used to identify different versions of the binary output of the compiler and their compatibility guarantees. + * - Major version should be increased only when the new binary format is neither forward- nor backward compatible. + * This shouldn't really ever happen at all. + * - Minor version should be increased when the new format is backward compatible, + * i.e. the new compiler can process old data, but the old compiler will not be able to process new data. + * - Patch version can be increased freely and is only supposed to be used for debugging. Increase the patch version when you + * make a change to binaries which is both forward- and backward compatible. + */ +abstract class BinaryVersion(private vararg val numbers: Int) { + val major: Int = numbers.getOrNull(0) ?: UNKNOWN + val minor: Int = numbers.getOrNull(1) ?: UNKNOWN + val patch: Int = numbers.getOrNull(2) ?: UNKNOWN + val rest: List = if (numbers.size > 3) { + if (numbers.size > MAX_LENGTH) + throw IllegalArgumentException("BinaryVersion with length more than $MAX_LENGTH are not supported. Provided length ${numbers.size}.") + else + numbers.asList().subList(3, numbers.size).toList() + } else emptyList() + + abstract fun isCompatible(): Boolean + + fun toArray(): IntArray = numbers + + /** + * Returns true if this version of some format loaded from some binaries is compatible + * to the expected version of that format in the current compiler. + * + * @param ourVersion the version of this format in the current compiler + */ + protected fun isCompatibleTo(ourVersion: BinaryVersion): Boolean { + return if (major == 0) ourVersion.major == 0 && minor == ourVersion.minor + else major == ourVersion.major && minor <= ourVersion.minor + } + + override fun toString(): String { + val versions = toArray().takeWhile { it != UNKNOWN } + return if (versions.isEmpty()) "unknown" else versions.joinToString(".") + } + + override fun equals(other: Any?) = + other != null && + this::class.java == other::class.java && + major == (other as BinaryVersion).major && minor == other.minor && patch == other.patch && rest == other.rest + + override fun hashCode(): Int { + var result = major + result += 31 * result + minor + result += 31 * result + patch + result += 31 * result + rest.hashCode() + return result + } + + companion object { + const val MAX_LENGTH = 1024 + private const val UNKNOWN = -1 + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BitEncoding.java b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BitEncoding.java new file mode 100644 index 0000000000..2382306e64 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/BitEncoding.java @@ -0,0 +1,143 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization; + +import static retrofit2.kotlin.metadata.deserialization.UtfEncodingKt.MAX_UTF8_INFO_LENGTH; + +import org.jetbrains.annotations.NotNull; + +/** + * This file was adapted from + * https://github.com/JetBrains/kotlin/blob/af18b10da9d1e20b1b35831a3fb5e508048a2576/core/metadata.jvm/src/org/jetbrains/kotlin/metadata/jvm/deserialization/BitEncoding.java + * by removing the unused parts. + */ +public class BitEncoding { + + private static final char _8TO7_MODE_MARKER = (char) -1; + + private BitEncoding() {} + + private static void addModuloByte(@NotNull byte[] data, int increment) { + for (int i = 0, n = data.length; i < n; i++) { + data[i] = (byte) ((data[i] + increment) & 0x7f); + } + } + + /** Converts encoded array of {@code String} back to a byte array. */ + @NotNull + public static byte[] decodeBytes(@NotNull String[] data) { + if (data.length > 0 && !data[0].isEmpty()) { + char possibleMarker = data[0].charAt(0); + if (possibleMarker == UtfEncodingKt.UTF8_MODE_MARKER) { + return UtfEncodingKt.stringsToBytes(dropMarker(data)); + } + if (possibleMarker == _8TO7_MODE_MARKER) { + data = dropMarker(data); + } + } + + byte[] bytes = combineStringArrayIntoBytes(data); + // Adding 0x7f modulo max byte value is equivalent to subtracting 1 the same modulo, which is + // inverse to what happens in encodeBytes + addModuloByte(bytes, 0x7f); + return decode7to8(bytes); + } + + @NotNull + private static String[] dropMarker(@NotNull String[] data) { + // Clone because the clients should be able to use the passed array for their own purposes. + // This is cheap because the size of the array is 1 or 2 almost always. + String[] result = data.clone(); + result[0] = result[0].substring(1); + return result; + } + + /** Combines the array of strings resulted from encodeBytes() into one long byte array */ + @NotNull + private static byte[] combineStringArrayIntoBytes(@NotNull String[] data) { + int resultLength = 0; + for (String s : data) { + assert s.length() <= MAX_UTF8_INFO_LENGTH : "String is too long: " + s.length(); + resultLength += s.length(); + } + + byte[] result = new byte[resultLength]; + int p = 0; + for (String s : data) { + for (int i = 0, n = s.length(); i < n; i++) { + result[p++] = (byte) s.charAt(i); + } + } + + return result; + } + + /** + * Decodes the byte array resulted from encode8to7(). + * + *

Each byte of the input array has at most 7 valuable bits of information. So the decoding is + * equivalent to the following: least significant 7 bits of all input bytes are combined into one + * long bit string. This bit string is then split into groups of 8 bits, each of which forms a + * byte in the output. If there are any leftovers, they are ignored, since they were added just as + * a padding and do not comprise a full byte. + * + *

Suppose the following encoded byte array is given (bits are numbered the same way as in + * encode8to7() doc): + * + *

01234567 01234567 01234567 01234567 + * + *

The output of the following form would be produced: + * + *

01234560 12345601 23456012 + * + *

Note how all most significant bits and leftovers are dropped, since they don't contain any + * useful information + */ + @NotNull + private static byte[] decode7to8(@NotNull byte[] data) { + // floor(7 * data.length / 8) + int resultLength = 7 * data.length / 8; + + byte[] result = new byte[resultLength]; + + // We maintain a pointer to an input bit in the same fashion as in encode8to7(): it's + // represented as two numbers: index of the + // current byte in the input and index of the bit in the byte + int byteIndex = 0; + int bit = 0; + + // A resulting byte is comprised of 8 bits, starting from the current bit. Since each input byte + // only "contains 7 bytes", a + // resulting byte always consists of two parts: several most significant bits of the current + // byte and several least significant bits + // of the next byte + for (int i = 0; i < resultLength; i++) { + int firstPart = (data[byteIndex] & 0xff) >>> bit; + byteIndex++; + int secondPart = (data[byteIndex] & ((1 << (bit + 1)) - 1)) << 7 - bit; + result[i] = (byte) (firstPart + secondPart); + + if (bit == 6) { + byteIndex++; + bit = 0; + } else { + bit++; + } + } + + return result; + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmMetadataVersion.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmMetadataVersion.kt new file mode 100644 index 0000000000..45cbd488a6 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmMetadataVersion.kt @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * This file was adapted from https://github.com/JetBrains/kotlin/blob/af18b10da9d1e20b1b35831a3fb5e508048a2576/core/metadata.jvm/src/org/jetbrains/kotlin/metadata/jvm/deserialization/JvmMetadataVersion.kt + * by removing the unused parts. + */ + +/** + * The version of the metadata serialized by the compiler and deserialized by the compiler and reflection. + * This version includes the version of the core protobuf messages (metadata.proto) as well as JVM extensions (jvm_metadata.proto). + */ +class JvmMetadataVersion(versionArray: IntArray, val isStrictSemantics: Boolean) : BinaryVersion(*versionArray) { + constructor(vararg numbers: Int) : this(numbers, isStrictSemantics = false) + + override fun isCompatible(): Boolean = + // NOTE: 1.0 is a pre-Kotlin-1.0 metadata version, with which the current compiler is incompatible + (major != 1 || minor != 0) && + if (isStrictSemantics) { + isCompatibleTo(INSTANCE) + } else { + // Kotlin 1.N is able to read metadata of versions up to Kotlin 1.{N+1} (unless the version has strict semantics). + major == INSTANCE.major && minor <= INSTANCE.minor + 1 + } + + companion object { + @JvmField + val INSTANCE = JvmMetadataVersion(1, 6, 0) + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmNameResolver.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmNameResolver.kt new file mode 100644 index 0000000000..d3c10f19ea --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/JvmNameResolver.kt @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * This file was adapted from https://github.com/JetBrains/kotlin/blob/26673d2b08f01dec1a9007b9b75436a50fa497e9/core/metadata.jvm/src/org/jetbrains/kotlin/metadata/jvm/deserialization/JvmNameResolver.kt + * by removing the unused parts. + */ + +internal class JvmNameResolver(private val types: StringTableTypes, private val strings: Array) { + + fun getString(index: Int): String { + val record = types.records[index] + + var string = when { + record.hasString() -> record.string + record.hasPredefinedIndex() && record.predefinedIndex in PREDEFINED_STRINGS.indices -> + PREDEFINED_STRINGS[record.predefinedIndex] + else -> strings[index] + } + requireNotNull(string) + + if (record.substringIndexList.size >= 2) { + val (begin, end) = record.substringIndexList + if (begin in 0..end && end <= string.length) { + string = string.substring(begin, end) + } + } + + if (record.replaceCharList.size >= 2) { + val (from, to) = record.replaceCharList + string = string.replace(from.toChar(), to.toChar()) + } + + when (record.operation) { + Record.OPERATION_NONE -> { + // Do nothing + } + Record.OPERATION_INTERNAL_TO_CLASS_ID -> { + string = string.replace('$', '.') + } + Record.OPERATION_DESC_TO_CLASS_ID -> { + if (string.length >= 2) { + string = string.substring(1, string.length - 1) + } + string = string.replace('$', '.') + } + } + + return string + } + + companion object { + private val PREDEFINED_STRINGS = listOf( + "kotlin/Any", + "kotlin/Nothing", + "kotlin/Unit", + "kotlin/Throwable", + "kotlin/Number", + + "kotlin/Byte", "kotlin/Double", "kotlin/Float", "kotlin/Int", + "kotlin/Long", "kotlin/Short", "kotlin/Boolean", "kotlin/Char", + + "kotlin/CharSequence", + "kotlin/String", + "kotlin/Comparable", + "kotlin/Enum", + + "kotlin/Array", + "kotlin/ByteArray", "kotlin/DoubleArray", "kotlin/FloatArray", "kotlin/IntArray", + "kotlin/LongArray", "kotlin/ShortArray", "kotlin/BooleanArray", "kotlin/CharArray", + + "kotlin/Cloneable", + "kotlin/Annotation", + + "kotlin/collections/Iterable", "kotlin/collections/MutableIterable", + "kotlin/collections/Collection", "kotlin/collections/MutableCollection", + "kotlin/collections/List", "kotlin/collections/MutableList", + "kotlin/collections/Set", "kotlin/collections/MutableSet", + "kotlin/collections/Map", "kotlin/collections/MutableMap", + "kotlin/collections/Map.Entry", "kotlin/collections/MutableMap.MutableEntry", + + "kotlin/collections/Iterator", "kotlin/collections/MutableIterator", + "kotlin/collections/ListIterator", "kotlin/collections/MutableListIterator" + ) + } +} \ No newline at end of file diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/MetadataParser.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/MetadataParser.kt new file mode 100644 index 0000000000..1238674752 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/MetadataParser.kt @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +import retrofit2.KotlinMetadata + +/** + * The class metadata is in protobuf format with [StringTableTypes] object followed by [Klass] object. + */ +internal class MetadataParser(private val reader: ProtobufReader, private val strings: Array) { + + fun parse(): List { + val table = StringTableTypes.parse(makeDelimited(reader, tagless = true)) + val klass = Klass.parse(reader) + + val nameResolver = JvmNameResolver(table, strings) + + val functions = mutableListOf() + + klass.functions.forEach { f -> + val functionName = f.getName(nameResolver) + val signatureDesc = f.signature.getDesc(nameResolver) + + val returnType = KotlinMetadata.ReturnType(f.returnType.isNullable, f.returnType.isUnit(nameResolver)) + + functions += KotlinMetadata.Function(functionName + signatureDesc, returnType) + } + + return functions + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/Models.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/Models.kt new file mode 100644 index 0000000000..9913230311 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/Models.kt @@ -0,0 +1,251 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * These are classes representing https://github.com/JetBrains/kotlin/blob/c6697499153329c088b60404bc584bf4b85a105b/core/metadata.jvm/src/jvm_metadata.proto + * and https://github.com/JetBrains/kotlin/blob/92d200e093c693b3c06e53a39e0b0973b84c7ec5/core/metadata/src/metadata.proto + */ + +internal class StringTableTypes(val records: List) { + + companion object { + private const val ID_RECORD = 1 + + fun parse(reader: ProtobufReader): StringTableTypes { + val records = mutableListOf() + + while (reader.readTag() != -1) { + when (reader.currentId) { + ID_RECORD -> { + val record = Record.parse(makeDelimited(reader)) + repeat(record.range) { records += record } + } + else -> { reader.skipElement() } + } + } + + return StringTableTypes(records) + } + } +} + +internal class Record( + val range: Int, + val predefinedIndex: Int, + val operation: Int, + val string: String?, + val substringIndexList: List, + val replaceCharList: List +) { + fun hasString() = string != null + fun hasPredefinedIndex() = predefinedIndex != -1 + + companion object { + private const val ID_RANGE = 1 + private const val ID_PREDEFINED_INDEX = 2 + private const val ID_STRING = 6 + private const val ID_OPERATION = 3 + private const val ID_SUBSTRING_INDEX = 4 + private const val ID_REPLACE_CHAR = 5 + + internal const val OPERATION_NONE = 0 + internal const val OPERATION_INTERNAL_TO_CLASS_ID = 1 + internal const val OPERATION_DESC_TO_CLASS_ID = 2 + + fun parse(reader: ProtobufReader): Record { + var range = 1 + var operation = 0 + var predefinedIndex = -1 + var string: String? = null + val substringIndexList = mutableListOf() + val replaceCharList = mutableListOf() + + while (reader.readTag() != -1) { + when (reader.currentId) { + ID_RANGE -> { + range = reader.readInt(ProtoIntegerType.DEFAULT) + } + ID_PREDEFINED_INDEX -> { + predefinedIndex = reader.readInt(ProtoIntegerType.DEFAULT) + } + ID_OPERATION -> { + operation = reader.readInt(ProtoIntegerType.DEFAULT) + } + ID_SUBSTRING_INDEX -> { + readIntoList(reader, substringIndexList) + } + ID_REPLACE_CHAR -> { + readIntoList(reader, replaceCharList) + } + ID_STRING -> { + string = reader.readString() + } + else -> { reader.skipElement() } + } + } + + return Record(range, predefinedIndex, operation, string, substringIndexList, replaceCharList) + } + } +} + +/** + * Renamed from `Class` in .proto definition to [Klass] to avoid conflicting with [java.lang.Class]. We only parse the + * info needed to determine nullability of the method's return type and skip other stuff. Thanks to + * the protobuf format we will also be able to skip any new fields added in the future. + */ +internal class Klass(val functions: List) { + + companion object { + private const val ID_FUNCTION = 9 + + fun parse(reader: ProtobufReader): Klass { + val functions = mutableListOf() + + while (reader.readTag() != -1) { + if (reader.currentId == ID_FUNCTION) { + functions += Function.parse(makeDelimited(reader)) + } else { + reader.skipElement() + } + } + + return Klass(functions) + } + } +} + +internal class Function(val nameIndex: Int, val returnType: Type, val signature: JvmMethodSignature) { + + fun getName(nameResolver: JvmNameResolver): String { + return signature.getName(nameResolver) ?: nameResolver.getString(nameIndex) + } + + companion object { + private const val ID_NAME = 2 + private const val ID_RETURN_TYPE = 3 + private const val ID_SIGNATURE = 100 + + fun parse(reader: ProtobufReader): Function { + lateinit var returnType: Type + var nameIndex = -1 + lateinit var signature: JvmMethodSignature + + while (reader.readTag() != -1) { + when (reader.currentId) { + ID_NAME -> { + nameIndex = reader.readInt(ProtoIntegerType.DEFAULT) + } + ID_RETURN_TYPE -> { + returnType = Type.parse(makeDelimited(reader)) + } + ID_SIGNATURE -> { + signature = JvmMethodSignature.parse(makeDelimited(reader)) + } + else -> { + reader.skipElement() + } + } + } + + return Function(nameIndex, returnType, signature) + } + } +} + +internal class Type(val isNullable: Boolean, private val nameIndex: Int) { + + fun isUnit(nameResolver: JvmNameResolver): Boolean = nameResolver.getString(nameIndex) == "kotlin/Unit" + + companion object { + private const val ID_NULLABLE = 3 + private const val ID_CLASS_NAME = 6 + + fun parse(reader: ProtobufReader): Type { + var nullable = false + var nameIndex = -1 + + while (reader.readTag() != -1) { + when (reader.currentId) { + ID_NULLABLE -> { + nullable = reader.readInt(ProtoIntegerType.DEFAULT) != 0 + } + ID_CLASS_NAME -> { + nameIndex = reader.readInt(ProtoIntegerType.DEFAULT) + } + else -> { + reader.skipElement() + } + } + } + + return Type(nullable, nameIndex) + } + } +} + +internal class JvmMethodSignature(val nameIndex: Int, val descIndex: Int) { + + fun getName(nameResolver: JvmNameResolver): String? { + return if (nameIndex != -1) nameResolver.getString(nameIndex) else null + } + + fun getDesc(nameResolver: JvmNameResolver): String = nameResolver.getString(descIndex) + + + companion object { + private const val ID_NAME = 1 + private const val ID_DESC = 2 + + fun parse(reader: ProtobufReader): JvmMethodSignature { + var nameIndex = -1 + var descIndex = -1 + while (reader.readTag() != -1) { + when (reader.currentId) { + ID_NAME -> { + nameIndex = reader.readInt(ProtoIntegerType.DEFAULT) + } + ID_DESC -> { + descIndex = reader.readInt(ProtoIntegerType.DEFAULT) + } + else -> reader.skipElement() + } + } + + return JvmMethodSignature(nameIndex, descIndex) + } + } +} + +internal fun makeDelimited(decoder: ProtobufReader, tagless: Boolean = false): ProtobufReader { + val input = if (tagless) decoder.objectTaglessInput() else decoder.objectInput() + return ProtobufReader(input) +} + +private fun readIntoList( + reader: ProtobufReader, + mutableList: MutableList +) { + if (reader.currentType == VARINT) { + mutableList += reader.readInt(ProtoIntegerType.DEFAULT) + } else { + val arrayReader = ProtobufReader(reader.objectInput()) + while (0 < arrayReader.availableBytes) { + mutableList += arrayReader.readInt(ProtoIntegerType.DEFAULT) + } + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/ProtobufUtils.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/ProtobufUtils.kt new file mode 100644 index 0000000000..482f269259 --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/ProtobufUtils.kt @@ -0,0 +1,294 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * This file was adapted from https://github.com/Kotlin/kotlinx.serialization/blob/1814a92b871dac128db67c765c9df2b6be8405c7/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Streams.kt + * and https://github.com/Kotlin/kotlinx.serialization/blob/1814a92b871dac128db67c765c9df2b6be8405c7/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt + * by removing the unused parts. + */ + +internal open class SerializationException(message: String?) : IllegalArgumentException(message) + +internal class ProtobufDecodingException(message: String) : SerializationException(message) + +class ByteArrayInput(private var array: ByteArray, private val endIndex: Int = array.size) { + private var position: Int = 0 + val availableBytes: Int get() = endIndex - position + + fun slice(size: Int): ByteArrayInput { + ensureEnoughBytes(size) + val result = ByteArrayInput(array, position + size) + result.position = position + position += size + return result + } + + fun read(): Int { + return if (position < endIndex) array[position++].toInt() and 0xFF else -1 + } + + fun readExactNBytes(bytesCount: Int): ByteArray { + ensureEnoughBytes(bytesCount) + val b = ByteArray(bytesCount) + val length = b.size + // Are there any bytes available? + val copied = if (endIndex - position < length) endIndex - position else length + array.copyInto(destination = b, destinationOffset = 0, startIndex = position, endIndex = position + copied) + position += copied + return b + } + + private fun ensureEnoughBytes(bytesCount: Int) { + if (bytesCount > availableBytes) { + throw SerializationException("Unexpected EOF, available $availableBytes bytes, requested: $bytesCount") + } + } + + fun readString(length: Int): String { + val result = array.decodeToString(position, position + length) + position += length + return result + } + + fun readVarint32(): Int { + if (position == endIndex) { + eof() + } + + // Fast-path: unrolled loop for single and two byte values + var currentPosition = position + var result = array[currentPosition++].toInt() + if (result >= 0) { + position = currentPosition + return result + } else if (endIndex - position > 1) { + result = result xor (array[currentPosition++].toInt() shl 7) + if (result < 0) { + position = currentPosition + return result xor (0.inv() shl 7) + } + } + + return readVarint32SlowPath() + } + + fun readVarint64(eofAllowed: Boolean): Long { + if (position == endIndex) { + if (eofAllowed) return -1 + else eof() + } + + // Fast-path: single and two byte values + var currentPosition = position + var result = array[currentPosition++].toLong() + if (result >= 0) { + position = currentPosition + return result + } else if (endIndex - position > 1) { + result = result xor (array[currentPosition++].toLong() shl 7) + if (result < 0) { + position = currentPosition + return result xor (0L.inv() shl 7) + } + } + + return readVarint64SlowPath() + } + + private fun eof() { + throw SerializationException("Unexpected EOF") + } + + private fun readVarint64SlowPath(): Long { + var result = 0L + var shift = 0 + while (shift < 64) { + val byte = read() + result = result or ((byte and 0x7F).toLong() shl shift) + if (byte and 0x80 == 0) { + return result + } + shift += 7 + } + throw SerializationException("Input stream is malformed: Varint too long (exceeded 64 bits)") + } + + private fun readVarint32SlowPath(): Int { + var result = 0 + var shift = 0 + while (shift < 32) { + val byte = read() + result = result or ((byte and 0x7F) shl shift) + if (byte and 0x80 == 0) { + return result + } + shift += 7 + } + throw SerializationException("Input stream is malformed: Varint too long (exceeded 32 bits)") + } +} + +internal enum class ProtoIntegerType { + DEFAULT, + SIGNED, + FIXED; +} + +internal const val VARINT = 0 +internal const val i64 = 1 +internal const val SIZE_DELIMITED = 2 +internal const val i32 = 5 + +internal class ProtobufReader(private val input: ByteArrayInput) { + @JvmField + var currentId = -1 + @JvmField + var currentType = -1 + + val availableBytes: Int + get() = input.availableBytes + + fun readTag(): Int { + val header = input.readVarint64(true).toInt() + return if (header == -1) { + currentId = -1 + currentType = -1 + -1 + } else { + currentId = header ushr 3 + currentType = header and 0b111 + currentId + } + } + + fun skipElement() { + when (currentType) { + VARINT -> readInt(ProtoIntegerType.DEFAULT) + i64 -> readLong(ProtoIntegerType.FIXED) + SIZE_DELIMITED -> readByteArray() + i32 -> readInt(ProtoIntegerType.FIXED) + else -> throw ProtobufDecodingException("Unsupported start group or end group wire type: $currentType") + } + } + + @Suppress("NOTHING_TO_INLINE") + private inline fun assertWireType(expected: Int) { + if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found $currentType") + } + + private fun readByteArray(): ByteArray { + assertWireType(SIZE_DELIMITED) + return readByteArrayNoTag() + } + + private fun readByteArrayNoTag(): ByteArray { + val length = decode32() + checkLength(length) + return input.readExactNBytes(length) + } + + fun objectInput(): ByteArrayInput { + assertWireType(SIZE_DELIMITED) + return objectTaglessInput() + } + + fun objectTaglessInput(): ByteArrayInput { + val length = decode32() + checkLength(length) + return input.slice(length) + } + + fun readInt(format: ProtoIntegerType): Int { + val wireType = if (format == ProtoIntegerType.FIXED) i32 else VARINT + assertWireType(wireType) + return decode32(format) + } + + private fun readLong(format: ProtoIntegerType): Long { + val wireType = if (format == ProtoIntegerType.FIXED) i64 else VARINT + assertWireType(wireType) + return decode64(format) + } + + private fun readIntLittleEndian(): Int { + // TODO this could be optimized by extracting method to the IS + var result = 0 + for (i in 0..3) { + val byte = input.read() and 0x000000FF + result = result or (byte shl (i * 8)) + } + return result + } + + private fun readLongLittleEndian(): Long { + // TODO this could be optimized by extracting method to the IS + var result = 0L + for (i in 0..7) { + val byte = (input.read() and 0x000000FF).toLong() + result = result or (byte shl (i * 8)) + } + return result + } + + fun readString(): String { + assertWireType(SIZE_DELIMITED) + val length = decode32() + checkLength(length) + return input.readString(length) + } + + private fun checkLength(length: Int) { + if (length < 0) { + throw ProtobufDecodingException("Unexpected negative length: $length") + } + } + + private fun decode32(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Int = when (format) { + ProtoIntegerType.DEFAULT -> input.readVarint64(false).toInt() + ProtoIntegerType.SIGNED -> decodeSignedVarintInt(input) + ProtoIntegerType.FIXED -> readIntLittleEndian() + } + + private fun decode64(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Long = when (format) { + ProtoIntegerType.DEFAULT -> input.readVarint64(false) + ProtoIntegerType.SIGNED -> decodeSignedVarintLong(input) + ProtoIntegerType.FIXED -> readLongLittleEndian() + } + + /** + * Source for all varint operations: + * https://github.com/addthis/stream-lib/blob/master/src/main/java/com/clearspring/analytics/util/Varint.java + */ + private fun decodeSignedVarintInt(input: ByteArrayInput): Int { + val raw = input.readVarint32() + val temp = raw shl 31 shr 31 xor raw shr 1 + // This extra step lets us deal with the largest signed values by treating + // negative results from read unsigned methods as like unsigned values. + // Must re-flip the top bit if the original read value had it set. + return temp xor (raw and (1 shl 31)) + } + + private fun decodeSignedVarintLong(input: ByteArrayInput): Long { + val raw = input.readVarint64(false) + val temp = raw shl 63 shr 63 xor raw shr 1 + // This extra step lets us deal with the largest signed values by treating + // negative results from read unsigned methods as like unsigned values + // Must re-flip the top bit if the original read value had it set. + return temp xor (raw and (1L shl 63)) + + } +} diff --git a/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/utfEncoding.kt b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/utfEncoding.kt new file mode 100644 index 0000000000..01880c71fc --- /dev/null +++ b/retrofit/src/main/java/retrofit2/kotlin/metadata/deserialization/utfEncoding.kt @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.kotlin.metadata.deserialization + +/** + * This file was adapted from https://github.com/JetBrains/kotlin/blob/af18b10da9d1e20b1b35831a3fb5e508048a2576/core/metadata.jvm/src/org/jetbrains/kotlin/metadata/jvm/deserialization/utfEncoding.kt + * by removing the unused parts. + */ + +// The maximum possible length of the byte array in the CONSTANT_Utf8_info structure in the bytecode, as per JVMS7 4.4.7 +const val MAX_UTF8_INFO_LENGTH = 65535 + +const val UTF8_MODE_MARKER = 0.toChar() + +fun stringsToBytes(strings: Array): ByteArray { + val resultLength = strings.sumBy { it.length } + val result = ByteArray(resultLength) + + var i = 0 + for (s in strings) { + for (si in 0..s.length - 1) { + result[i++] = s[si].toInt().toByte() + } + } + + assert(i == result.size) { "Should have reached the end" } + + return result +} diff --git a/retrofit/test-helpers/src/main/java/retrofit2/helpers/ToNullStringResponseConverterFactory.java b/retrofit/test-helpers/src/main/java/retrofit2/helpers/ToNullStringResponseConverterFactory.java new file mode 100644 index 0000000000..b7623a6dc0 --- /dev/null +++ b/retrofit/test-helpers/src/main/java/retrofit2/helpers/ToNullStringResponseConverterFactory.java @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package retrofit2.helpers; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import javax.annotation.Nullable; +import okhttp3.ResponseBody; +import retrofit2.Converter; +import retrofit2.Retrofit; + +public class ToNullStringResponseConverterFactory extends Converter.Factory { + + @Override + public @Nullable Converter responseBodyConverter( + Type type, Annotation[] annotations, Retrofit retrofit) { + if (String.class.equals(type)) { + return value -> null; + } + return null; + } +}