Skip to content

Commit

Permalink
Introduce @JsonClassDiscriminator to configure discriminator per poly…
Browse files Browse the repository at this point in the history
…morphic base class

Fixes #546
  • Loading branch information
sandwwraith committed May 25, 2021
1 parent 9d965ba commit 4af5547
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.json

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.json.internal.*
import kotlin.native.concurrent.*

/**
* Indicates that the field can be represented in JSON
* with multiple possible alternative names.
* [Json] format recognizes this annotation and is able to decode
* the data using any of the alternative names.
*
* Unlike [SerialName] annotation, does not affect JSON encoding in any way.
*
* Example of usage:
* ```
* @Serializable
* data class Project(@JsonNames("title") val name: String)
*
* val project = Json.decodeFromString<Project>("""{"name":"kotlinx.serialization"}""")
* println(project)
* val oldProject = Json.decodeFromString<Project>("""{"title":"kotlinx.coroutines"}""")
* println(oldProject)
* ```
*
* This annotation has lesser priority than [SerialName].
*
* @see JsonBuilder.useAlternativeNames
*/
@SerialInfo
@Target(AnnotationTarget.PROPERTY)
@ExperimentalSerializationApi
public annotation class JsonNames(vararg val names: String)

/**
* Specifies key for class discriminator value used during polymorphic serialization in [Json].
* Provided key is used only for an annotated class, to configure global class discriminator, use [JsonBuilder.classDiscriminator]
* property.
*
* It is possible to define different class discriminators for different parts of class hierarchy.
* Pay attention to the fact that class discriminator, same as polymorphic serializer's base class, is
* determined statically.
*
* Example:
* ```
* @Serializable
* @JsonTypeDiscriminator("class")
* abstract class Base
*
* @Serializable
* @JsonTypeDiscriminator("error_class")
* abstract class ErrorClass: Base()
*
* @Serializable
* class Message(val object: Base, val error: ErrorClass?)
*
* val message = Json.decodeFromString<Message>("""{"object": {"class":"my.app.BaseMessage", "message": "not found"}, "error": {"error_class":"my.app.GenericError", "error_code": 404}}""")
* ```
*
* @see JsonBuilder.classDiscriminator
*/
@SerialInfo
@Target(AnnotationTarget.CLASS)
@ExperimentalSerializationApi
public annotation class JsonClassDiscriminator(val discriminator: String)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,22 @@ import kotlinx.serialization.internal.*
import kotlinx.serialization.json.*

@Suppress("UNCHECKED_CAST")
internal inline fun <T> JsonEncoder.encodePolymorphically(serializer: SerializationStrategy<T>, value: T, ifPolymorphic: () -> Unit) {
internal inline fun <T> JsonEncoder.encodePolymorphically(
serializer: SerializationStrategy<T>,
value: T,
ifPolymorphic: (String) -> Unit
) {
if (serializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
serializer.serialize(this, value)
return
}
val actualSerializer = findActualSerializer(serializer as SerializationStrategy<Any>, value as Any)
ifPolymorphic()
actualSerializer.serialize(this, value)
}

private fun JsonEncoder.findActualSerializer(
serializer: SerializationStrategy<Any>,
value: Any
): SerializationStrategy<Any> {
val casted = serializer as AbstractPolymorphicSerializer<Any>
val actualSerializer = casted.findPolymorphicSerializer(this, value)
validateIfSealed(casted, actualSerializer, json.configuration.classDiscriminator)
val kind = actualSerializer.descriptor.kind
checkKind(kind)
return actualSerializer
val baseClassDiscriminator = serializer.descriptor.classDiscriminator(json)
val actualSerializer = casted.findPolymorphicSerializer(this, value as Any)
validateIfSealed(casted, actualSerializer, baseClassDiscriminator)
checkKind(actualSerializer.descriptor.kind)
ifPolymorphic(baseClassDiscriminator)
actualSerializer.serialize(this, value)
}

private fun validateIfSealed(
Expand Down Expand Up @@ -64,7 +60,7 @@ internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De
}

val jsonTree = cast<JsonObject>(decodeJsonElement(), deserializer.descriptor)
val discriminator = json.configuration.classDiscriminator
val discriminator = deserializer.descriptor.classDiscriminator(json)
val type = jsonTree[discriminator]?.jsonPrimitive?.content
val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
?: throwSerializerNotFound(type, jsonTree)
Expand All @@ -79,3 +75,8 @@ private fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothin
else "class discriminator '$type'"
throw JsonDecodingException(-1, "Polymorphic serializer was not found for $suffix", jsonTree.toString())
}

internal fun SerialDescriptor.classDiscriminator(json: Json): String =
annotations.filterIsInstance<JsonClassDiscriminator>().singleOrNull()?.discriminator
?: json.configuration.classDiscriminator

Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ internal class StreamingJsonEncoder(

// Forces serializer to wrap all values into quotes
private var forceQuoting: Boolean = false
private var writePolymorphic = false
private var polymorphicDiscriminator: String? = null

init {
val i = mode.ordinal
Expand All @@ -64,13 +64,13 @@ internal class StreamingJsonEncoder(

override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
encodePolymorphically(serializer, value) {
writePolymorphic = true
polymorphicDiscriminator = it
}
}

private fun encodeTypeInfo(descriptor: SerialDescriptor) {
composer.nextItem()
encodeString(configuration.classDiscriminator)
encodeString(polymorphicDiscriminator!!)
composer.print(COLON)
composer.space()
encodeString(descriptor.serialName)
Expand All @@ -83,9 +83,9 @@ internal class StreamingJsonEncoder(
composer.indent()
}

if (writePolymorphic) {
writePolymorphic = false
if (polymorphicDiscriminator != null) {
encodeTypeInfo(descriptor)
polymorphicDiscriminator = null
}

if (mode == newMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private sealed class AbstractJsonTreeEncoder(
@JvmField
protected val configuration = json.configuration

private var writePolymorphic = false
private var polymorphicDiscriminator: String? = null

override fun encodeJsonElement(element: JsonElement) {
encodeSerializableValue(JsonElementSerializer, element)
Expand Down Expand Up @@ -70,7 +70,7 @@ private sealed class AbstractJsonTreeEncoder(
override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
// Writing non-structured data (i.e. primitives) on top-level (e.g. without any tag) requires special output
if (currentTagOrNull != null || serializer.descriptor.kind !is PrimitiveKind && serializer.descriptor.kind !== SerialKind.ENUM) {
encodePolymorphically(serializer, value) { writePolymorphic = true }
encodePolymorphically(serializer, value) { polymorphicDiscriminator = it }
} else JsonPrimitiveEncoder(json, nodeConsumer).apply {
encodeSerializableValue(serializer, value)
endEncode(serializer.descriptor)
Expand Down Expand Up @@ -126,9 +126,9 @@ private sealed class AbstractJsonTreeEncoder(
else -> JsonTreeEncoder(json, consumer)
}

if (writePolymorphic) {
writePolymorphic = false
encoder.putElement(configuration.classDiscriminator, JsonPrimitive(descriptor.serialName))
if (polymorphicDiscriminator != null) {
encoder.putElement(polymorphicDiscriminator!!, JsonPrimitive(descriptor.serialName))
polymorphicDiscriminator = null
}

return encoder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.features

import kotlinx.serialization.*
import kotlinx.serialization.builtins.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlin.test.*

class JsonClassDiscriminatorTest : JsonTestBase() {
@Serializable
@JsonClassDiscriminator("sealedType")
sealed class SealedMessage {
@Serializable
@SerialName("SealedMessage.StringMessage")
data class StringMessage(val description: String, val message: String) : SealedMessage()

@SerialName("EOF")
@Serializable
object EOF : SealedMessage()
}

@Serializable
@JsonClassDiscriminator("abstractType")
abstract class Message {
@Serializable
@SerialName("Message.StringMessage")
data class StringMessage(val description: String, val message: String) : Message()

@Serializable
@SerialName("Message.IntMessage")
data class IntMessage(val description: String, val message: Int) : Message()
}

@Test
fun testSealedClassesHaveCustomDiscriminator() {
val messages = listOf(
SealedMessage.StringMessage("string message", "foo"),
SealedMessage.EOF
)
val expected =
"""[{"sealedType":"SealedMessage.StringMessage","description":"string message","message":"foo"},{"sealedType":"EOF"}]"""
assertJsonFormAndRestored(
ListSerializer(SealedMessage.serializer()),
messages,
expected,
)
}

@Test
fun testAbstractClassesHaveCustomDiscriminator() {
val messages = listOf(
Message.StringMessage("string message", "foo"),
Message.IntMessage("int message", 42),
)
val module = SerializersModule {
polymorphic(Message::class) {
subclass(Message.StringMessage.serializer())
subclass(Message.IntMessage.serializer())
}
}
val json = Json { serializersModule = module }
val expected =
"""[{"abstractType":"Message.StringMessage","description":"string message","message":"foo"},{"abstractType":"Message.IntMessage","description":"int message","message":42}]"""
assertJsonFormAndRestored(ListSerializer(Message.serializer()), messages, expected, json)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private class DynamicObjectEncoder(
/**
* Flag of usage polymorphism with discriminator attribute
*/
private var writePolymorphic = false
private var polymorphicDiscriminator: String? = null

private object NoOutputMark

Expand Down Expand Up @@ -173,7 +173,7 @@ private class DynamicObjectEncoder(

override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
encodePolymorphically(serializer, value) {
writePolymorphic = true
polymorphicDiscriminator = it
}
}

Expand All @@ -197,9 +197,9 @@ private class DynamicObjectEncoder(
enterNode(child, newMode)
}

if (writePolymorphic) {
writePolymorphic = false
current.jsObject[json.configuration.classDiscriminator] = descriptor.serialName
if (polymorphicDiscriminator != null) {
current.jsObject[polymorphicDiscriminator!!] = descriptor.serialName
polymorphicDiscriminator = null
}

current.index = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ class DynamicPolymorphismTest {
data class DefaultChild(val default: String? = "default"): Sealed(5)
}

@Serializable
@JsonClassDiscriminator("sealed_custom")
sealed class SealedCustom {
@Serializable
@SerialName("data_class")
data class DataClassChild(val name: String) : SealedCustom()
}

@Serializable
data class CompositeClass(val mark: String, val nested: Sealed)

Expand Down Expand Up @@ -75,6 +83,16 @@ class DynamicPolymorphismTest {
}
}

@Test
fun testCustomClassDiscriminator() {
val value = SealedCustom.DataClassChild("custom-discriminator-test")
encodeAndDecode(SealedCustom.serializer(), value, objectJson) {
assertEquals("data_class", this["sealed_custom"])
assertEquals(undefined, this.type)
assertEquals(2, fieldsCount(this))
}
}

@Test
fun testComposite() {
val nestedValue = Sealed.DataClassChild("child")
Expand Down

0 comments on commit 4af5547

Please sign in to comment.