From d8b98b5ed3ff5f8c6074a1b59e6f0e23f62f0dea Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 5 Dec 2023 18:19:18 +0000 Subject: [PATCH] Fix: Hocon polymorphic serialization (#2151) Fixes #1581 --- .../kotlinx/serialization/hocon/Hocon.kt | 31 ++++++++++-- .../hocon/HoconPolymorphismTest.kt | 48 +++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/formats/hocon/src/main/kotlin/kotlinx/serialization/hocon/Hocon.kt b/formats/hocon/src/main/kotlin/kotlinx/serialization/hocon/Hocon.kt index f2f277948..5ca445ec5 100644 --- a/formats/hocon/src/main/kotlin/kotlinx/serialization/hocon/Hocon.kt +++ b/formats/hocon/src/main/kotlin/kotlinx/serialization/hocon/Hocon.kt @@ -145,7 +145,7 @@ public sealed class Hocon( } - private inner class ConfigReader(val conf: Config) : ConfigConverter() { + private inner class ConfigReader(val conf: Config, private val isPolymorphic: Boolean = false) : ConfigConverter() { private var ind = -1 override fun decodeElementIndex(descriptor: SerialDescriptor): Int { @@ -161,8 +161,10 @@ public sealed class Hocon( private fun composeName(parentName: String, childName: String) = if (parentName.isEmpty()) childName else "$parentName.$childName" - override fun SerialDescriptor.getTag(index: Int): String = - composeName(currentTagOrNull.orEmpty(), getConventionElementName(index, useConfigNamingConvention)) + override fun SerialDescriptor.getTag(index: Int): String { + val conventionName = getConventionElementName(index, useConfigNamingConvention) + return if (!isPolymorphic) composeName(currentTagOrNull.orEmpty(), conventionName) else conventionName + } override fun decodeNotNullMark(): Boolean { // Tag might be null for top-level deserialization @@ -206,6 +208,27 @@ public sealed class Hocon( } } + private inner class PolymorphConfigReader(private val conf: Config) : ConfigConverter() { + private var ind = -1 + + override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder = + when { + descriptor.kind.objLike -> ConfigReader(conf, isPolymorphic = true) + else -> this + } + + override fun SerialDescriptor.getTag(index: Int): String = getElementName(index) + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + ind++ + return if (ind >= descriptor.elementsCount) DECODE_DONE else ind + } + + override fun getValueFromTaggedConfig(tag: String, valueResolver: (Config, String) -> E): E { + return valueResolver(conf, tag) + } + } + private inner class ListConfigReader(private val list: ConfigList) : ConfigConverter() { private var ind = -1 @@ -216,6 +239,7 @@ public sealed class Hocon( override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder = when { + descriptor.kind is PolymorphicKind -> PolymorphConfigReader((list[currentTag] as ConfigObject).toConfig()) descriptor.kind.listLike -> ListConfigReader(list[currentTag] as ConfigList) descriptor.kind.objLike -> ConfigReader((list[currentTag] as ConfigObject).toConfig()) descriptor.kind == StructureKind.MAP -> MapConfigReader(list[currentTag] as ConfigObject) @@ -256,6 +280,7 @@ public sealed class Hocon( override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder = when { + descriptor.kind is PolymorphicKind -> PolymorphConfigReader((values[currentTag / 2] as ConfigObject).toConfig()) descriptor.kind.listLike -> ListConfigReader(values[currentTag / 2] as ConfigList) descriptor.kind.objLike -> ConfigReader((values[currentTag / 2] as ConfigObject).toConfig()) descriptor.kind == StructureKind.MAP -> MapConfigReader(values[currentTag / 2] as ConfigObject) diff --git a/formats/hocon/src/test/kotlin/kotlinx/serialization/hocon/HoconPolymorphismTest.kt b/formats/hocon/src/test/kotlin/kotlinx/serialization/hocon/HoconPolymorphismTest.kt index db038e70b..1dbc1f90a 100644 --- a/formats/hocon/src/test/kotlin/kotlinx/serialization/hocon/HoconPolymorphismTest.kt +++ b/formats/hocon/src/test/kotlin/kotlinx/serialization/hocon/HoconPolymorphismTest.kt @@ -23,6 +23,12 @@ class HoconPolymorphismTest { data class AnnotatedTypeChild(@SerialName("my_type") val type: String) : Sealed(3) } + @Serializable + data class SealedCollectionContainer(val sealed: Collection) + + @Serializable + data class SealedMapContainer(val sealed: Map) + @Serializable data class CompositeClass(var sealed: Sealed) @@ -102,4 +108,46 @@ class HoconPolymorphismTest { serializer = Sealed.serializer(), ) } + + @Test + fun testCollectionContainer() { + objectHocon.assertStringFormAndRestored( + expected = """ + sealed = [ + { type = annotated_type_child, my_type = override, intField = 3 } + { type = object } + { type = data_class, name = testDataClass, intField = 1 } + ] + """.trimIndent(), + original = SealedCollectionContainer( + listOf( + Sealed.AnnotatedTypeChild(type = "override"), + Sealed.ObjectChild, + Sealed.DataClassChild(name = "testDataClass"), + ) + ), + serializer = SealedCollectionContainer.serializer(), + ) + } + + @Test + fun testMapContainer() { + objectHocon.assertStringFormAndRestored( + expected = """ + sealed = { + "annotated_type_child" = { type = annotated_type_child, my_type = override, intField = 3 } + "object" = { type = object } + "data_class" = { type = data_class, name = testDataClass, intField = 1 } + } + """.trimIndent(), + original = SealedMapContainer( + mapOf( + "annotated_type_child" to Sealed.AnnotatedTypeChild(type = "override"), + "object" to Sealed.ObjectChild, + "data_class" to Sealed.DataClassChild(name = "testDataClass"), + ) + ), + serializer = SealedMapContainer.serializer(), + ) + } }