Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed type discriminator value for custom serializer that uses encodeJsonElement #2628

Merged
merged 4 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.features

import kotlinx.serialization.*
import kotlinx.serialization.builtins.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlin.test.*

class PolymorphismForCustomTest : JsonTestBase() {

private val customSerializer = object : KSerializer<VImpl> {
override val descriptor: SerialDescriptor =
buildClassSerialDescriptor("VImpl") {
element("a", String.serializer().descriptor)
element("b", Int.serializer().descriptor)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no b in VImpl though

}

override fun deserialize(decoder: Decoder): VImpl {
decoder as JsonDecoder
val jsonObject = decoder.decodeJsonElement() as JsonObject
return VImpl(
(jsonObject["a"] as JsonPrimitive).content
)
}

override fun serialize(encoder: Encoder, value: VImpl) {
encoder as JsonEncoder
encoder.encodeJsonElement(
JsonObject(mapOf("a" to JsonPrimitive(value.a)))
)
}
}

@Serializable
data class ValueHolder<V : Any>(
@Polymorphic val value: V,
)

data class VImpl(val a: String)

val json = Json {
serializersModule = SerializersModule {
polymorphic(Any::class, VImpl::class, customSerializer)
}
}

@Test
fun test() = parametrizedTest { mode ->
val valueHolder = ValueHolder(VImpl("aaa"))
val encoded = json.encodeToString(ValueHolder.serializer(customSerializer), valueHolder, mode)
assertEquals("""{"value":{"type":"VImpl","a":"aaa"}}""", encoded)

val decoded = json.decodeFromString<ValueHolder<*>>(ValueHolder.serializer(customSerializer), encoded, mode)

assertEquals(valueHolder, decoded)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import kotlin.jvm.*
internal inline fun <T> JsonEncoder.encodePolymorphically(
serializer: SerializationStrategy<T>,
value: T,
ifPolymorphic: (String) -> Unit
ifPolymorphic: (discriminatorName: String, serialName: String) -> Unit
) {
if (json.configuration.useArrayPolymorphism) {
serializer.serialize(this, value)
Expand All @@ -42,7 +42,7 @@ internal inline fun <T> JsonEncoder.encodePolymorphically(
actual as SerializationStrategy<T>
} else serializer

if (baseClassDiscriminator != null) ifPolymorphic(baseClassDiscriminator)
if (baseClassDiscriminator != null) ifPolymorphic(baseClassDiscriminator, actualSerializer.descriptor.serialName)
actualSerializer.serialize(this, value)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ internal class StreamingJsonEncoder(
}

override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
encodePolymorphically(serializer, value) {
polymorphicDiscriminator = it
encodePolymorphically(serializer, value) { discriminatorName, serialName ->
polymorphicDiscriminator = discriminatorName
polymorphicSerialName = serialName
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ private sealed class AbstractJsonTreeEncoder(
override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
// Writing non-structured data (i.e. primitives) on top-level (e.g. without any tag) requires special output
if (currentTagOrNull != null || !serializer.descriptor.carrierDescriptor(serializersModule).requiresTopLevelTag) {
encodePolymorphically(serializer, value) { polymorphicDiscriminator = it }
encodePolymorphically(serializer, value) { discriminatorName, serialName ->
polymorphicDiscriminator = discriminatorName
polymorphicSerialName = serialName
}
} else JsonPrimitiveEncoder(json, nodeConsumer).apply {
encodeSerializableValue(serializer, value)
}
Expand Down Expand Up @@ -155,7 +158,14 @@ private sealed class AbstractJsonTreeEncoder(

val discriminator = polymorphicDiscriminator
if (discriminator != null) {
encoder.putElement(discriminator, JsonPrimitive(polymorphicSerialName ?: descriptor.serialName))
if (encoder is JsonTreeMapEncoder) {
// first parameter is ignored in JsonTreeMapEncoder
encoder.putElement("key", JsonPrimitive(discriminator))
encoder.putElement("value", JsonPrimitive(polymorphicSerialName ?: descriptor.serialName))

} else {
encoder.putElement(discriminator, JsonPrimitive(polymorphicSerialName ?: descriptor.serialName))
}
polymorphicDiscriminator = null
polymorphicSerialName = null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ private class DynamicObjectEncoder(
* Flag of usage polymorphism with discriminator attribute
*/
private var polymorphicDiscriminator: String? = null
private var polymorphicSerialName: String? = null


private object NoOutputMark

Expand Down Expand Up @@ -183,8 +185,9 @@ private class DynamicObjectEncoder(
private fun isNotStructured() = result === NoOutputMark

override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
encodePolymorphically(serializer, value) {
polymorphicDiscriminator = it
encodePolymorphically(serializer, value) { discriminatorName, serialName ->
polymorphicDiscriminator = discriminatorName
polymorphicSerialName = serialName
}
}

Expand All @@ -209,8 +212,9 @@ private class DynamicObjectEncoder(
}

if (polymorphicDiscriminator != null) {
current.jsObject[polymorphicDiscriminator!!] = descriptor.serialName
current.jsObject[polymorphicDiscriminator!!] = polymorphicSerialName ?: descriptor.serialName
polymorphicDiscriminator = null
polymorphicSerialName = null
}

current.index = 0
Expand Down