diff --git a/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/OmitNullBenchmark.kt b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/OmitNullBenchmark.kt
new file mode 100644
index 000000000..5c32f7f40
--- /dev/null
+++ b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/OmitNullBenchmark.kt
@@ -0,0 +1,109 @@
+package kotlinx.benchmarks.json
+
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.decodeFromString
+import kotlinx.serialization.encodeToString
+import kotlinx.serialization.json.Json
+import org.openjdk.jmh.annotations.*
+import java.util.concurrent.TimeUnit
+
+/**
+ * Measures JSON encoding and decoding of a class with 32 nullable properties,
+ * using the default [Json] instance and an instance configured with `omitNull = true`.
+ *
+ * Every benchmark method returns its result so that JMH consumes the value and
+ * the measured call cannot be eliminated as dead code.
+ */
+@Warmup(iterations = 5, time = 1)
+@Measurement(iterations = 5, time = 1)
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+@Fork(2)
+open class OmitNullBenchmark {
+
+    @Serializable
+    data class Values(
+        val field0: Int?,
+        val field1: Int?,
+        val field2: Int?,
+        val field3: Int?,
+        val field4: Int?,
+        val field5: Int?,
+        val field6: Int?,
+        val field7: Int?,
+        val field8: Int?,
+        val field9: Int?,
+
+        val field10: Int?,
+        val field11: Int?,
+        val field12: Int?,
+        val field13: Int?,
+        val field14: Int?,
+        val field15: Int?,
+        val field16: Int?,
+        val field17: Int?,
+        val field18: Int?,
+        val field19: Int?,
+
+        val field20: Int?,
+        val field21: Int?,
+        val field22: Int?,
+        val field23: Int?,
+        val field24: Int?,
+        val field25: Int?,
+        val field26: Int?,
+        val field27: Int?,
+        val field28: Int?,
+        val field29: Int?,
+
+        val field30: Int?,
+        val field31: Int?
+    )
+
+    private val jsonOmitNull = Json { omitNull = true }
+
+    // Only field2 and field14 are set; every other property is null.
+    private val valueWithNulls = Values(
+        null, null, 2, null, null, null, null, null, null, null,
+        null, null, null, null, 14, null, null, null, null, null,
+        null, null, null, null, null, null, null, null, null, null,
+        null, null
+    )
+
+    private val jsonWithNulls = """{"field0":null,"field1":null,"field2":2,"field3":null,"field4":null,"field5":null,
+        |"field6":null,"field7":null,"field8":null,"field9":null,"field10":null,"field11":null,"field12":null,
+        |"field13":null,"field14":14,"field15":null,"field16":null,"field17":null,"field18":null,"field19":null,
+        |"field20":null,"field21":null,"field22":null,"field23":null,"field24":null,"field25":null,"field26":null,
+        |"field27":null,"field28":null,"field29":null,"field30":null,"field31":null}""".trimMargin()
+
+    private val jsonNoNulls = """{"field0":0,"field1":1,"field2":2,"field3":3,"field4":4,"field5":5,
+        |"field6":6,"field7":7,"field8":8,"field9":9,"field10":10,"field11":11,"field12":12,
+        |"field13":13,"field14":14,"field15":15,"field16":16,"field17":17,"field18":18,"field19":19,
+        |"field20":20,"field21":21,"field22":22,"field23":23,"field24":24,"field25":25,"field26":26,
+        |"field27":27,"field28":28,"field29":29,"field30":30,"field31":31}""".trimMargin()
+
+    // Same data as jsonWithNulls, but the null properties are absent instead of being written as `null`.
+    private val jsonWithAbsence = """{"field2":2, "field14":14}"""
+
+    @Benchmark
+    fun decodeNoNulls(): Values = Json.decodeFromString(jsonNoNulls)
+
+    @Benchmark
+    fun decodeNoNullsWithOmit(): Values = jsonOmitNull.decodeFromString(jsonNoNulls)
+
+    @Benchmark
+    fun decodeNulls(): Values = Json.decodeFromString(jsonWithNulls)
+
+    @Benchmark
+    fun decodeNullsWithOmit(): Values = jsonOmitNull.decodeFromString(jsonWithNulls)
+
+    @Benchmark
+    fun decodeAbsenceWithOmit(): Values = jsonOmitNull.decodeFromString(jsonWithAbsence)
+
+    @Benchmark
+    fun encodeNulls(): String = Json.encodeToString(valueWithNulls)
+
+    @Benchmark
+    fun encodeNullsWithOmit(): String = jsonOmitNull.encodeToString(valueWithNulls)
+}
diff --git a/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/TwitterBenchmark.kt b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/TwitterBenchmark.kt index 15e9ea46b..7097305b9 100644 --- a/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/TwitterBenchmark.kt +++ b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/TwitterBenchmark.kt @@ -25,6 +25,8 @@ open class TwitterBenchmark { private val input = TwitterBenchmark::class.java.getResource("/twitter.json").readBytes().decodeToString() private val twitter = Json.decodeFromString(Twitter.serializer(), input) + private val jsonOmitNull = Json { omitNull = true } + @Setup fun init() { require(twitter == Json.decodeFromString(Twitter.serializer(), Json.encodeToString(Twitter.serializer(), twitter))) @@ -34,6 +36,9 @@ open class TwitterBenchmark { @Benchmark fun decodeTwitter() = Json.decodeFromString(Twitter.serializer(), input) + @Benchmark + fun decodeTwitterOmitNull() = jsonOmitNull.decodeFromString(Twitter.serializer(), input) + @Benchmark fun encodeTwitter() = Json.encodeToString(Twitter.serializer(), twitter) } diff --git a/core/api/kotlinx-serialization-core.api b/core/api/kotlinx-serialization-core.api index a7004ddf7..75260bf9a 100644 --- a/core/api/kotlinx-serialization-core.api +++ b/core/api/kotlinx-serialization-core.api @@ -379,6 +379,7 @@ public abstract class kotlinx/serialization/encoding/AbstractEncoder : kotlinx/s public fun encodeValue (Ljava/lang/Object;)V public fun endStructure (Lkotlinx/serialization/descriptors/SerialDescriptor;)V public fun shouldEncodeElementDefault (Lkotlinx/serialization/descriptors/SerialDescriptor;I)Z + public fun skipNullElement (Lkotlinx/serialization/descriptors/SerialDescriptor;I)Z } public abstract interface class kotlinx/serialization/encoding/CompositeDecoder { @@ -628,6 +629,13 @@ public final class kotlinx/serialization/internal/DoubleSerializer : kotlinx/ser public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V } +public 
abstract class kotlinx/serialization/internal/ElementMarker { + public fun (Lkotlinx/serialization/descriptors/SerialDescriptor;)V + protected abstract fun isPoppedElement (Lkotlinx/serialization/descriptors/SerialDescriptor;I)Z + public final fun mark (I)V + public final fun popUnmarkedIndex ()I +} + public final class kotlinx/serialization/internal/EnumDescriptor : kotlinx/serialization/internal/PluginGeneratedSerialDescriptor { public fun (Ljava/lang/String;I)V public fun equals (Ljava/lang/Object;)Z @@ -1091,6 +1099,7 @@ public abstract class kotlinx/serialization/internal/TaggedEncoder : kotlinx/ser protected final fun popTag ()Ljava/lang/Object; protected final fun pushTag (Ljava/lang/Object;)V public fun shouldEncodeElementDefault (Lkotlinx/serialization/descriptors/SerialDescriptor;I)Z + protected fun skipNullElement (Lkotlinx/serialization/descriptors/SerialDescriptor;I)Z } public final class kotlinx/serialization/internal/TripleSerializer : kotlinx/serialization/KSerializer { diff --git a/core/commonMain/src/kotlinx/serialization/encoding/AbstractEncoder.kt b/core/commonMain/src/kotlinx/serialization/encoding/AbstractEncoder.kt index 616a759b2..f3dddaf68 100644 --- a/core/commonMain/src/kotlinx/serialization/encoding/AbstractEncoder.kt +++ b/core/commonMain/src/kotlinx/serialization/encoding/AbstractEncoder.kt @@ -40,6 +40,10 @@ public abstract class AbstractEncoder : Encoder, CompositeEncoder { throw SerializationException("'null' is not supported by default") } + public open fun skipNullElement(descriptor: SerialDescriptor, index: Int): Boolean { + return false + } + override fun encodeBoolean(value: Boolean): Unit = encodeValue(value) override fun encodeByte(value: Byte): Unit = encodeValue(value) override fun encodeShort(value: Short): Unit = encodeValue(value) @@ -86,6 +90,10 @@ public abstract class AbstractEncoder : Encoder, CompositeEncoder { serializer: SerializationStrategy, value: T? 
) { + if (value == null && skipNullElement(descriptor, index)) { + return + } + + if (encodeElement(descriptor, index)) encodeNullableSerializableValue(serializer, value) } diff --git a/core/commonMain/src/kotlinx/serialization/internal/ElementMarker.kt b/core/commonMain/src/kotlinx/serialization/internal/ElementMarker.kt new file mode 100644 index 000000000..c56c72bf9 --- /dev/null +++ b/core/commonMain/src/kotlinx/serialization/internal/ElementMarker.kt @@ -0,0 +1,93 @@ +package kotlinx.serialization.internal + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeDecoder + +/** Keeps one mark bit per element of [descriptor] and pops the unmarked element indices accepted by [isPoppedElement]. */ @InternalSerializationApi +@OptIn(ExperimentalSerializationApi::class) +public abstract class ElementMarker(private val descriptor: SerialDescriptor) { + /* + Marks of the elements, one bit per element. + The element index is the same as the bit position. + Marks for the lowest 64 elements are always stored in a single Long value, marks for higher elements are stored in a long array. + */ + private var lowerMarks: Long + private val highMarksArray: LongArray? 
+ + init { + val elementsCount = descriptor.elementsCount + if (elementsCount <= Long.SIZE_BITS) { + lowerMarks = if (elementsCount == Long.SIZE_BITS) { + // the number of bits in the mark is equal to the number of fields, so no bit is pre-marked + 0L + } else { + // the (Long.SIZE_BITS - elementsCount) highest bits are always 1 since there are no fields for them + -1L shl elementsCount + } + highMarksArray = null + } else { + lowerMarks = 0L + // (elementsCount - 1) because the first 64 fields live in lowerMarks, so e.g. 65..128 fields need exactly one high slot + val slotsCount = (elementsCount - 1) / Long.SIZE_BITS + val elementsInLastSlot = elementsCount % Long.SIZE_BITS + val highMarks = LongArray(slotsCount) + // (elementsCount % Long.SIZE_BITS) == 0 means that the fields occupy all bits of the last slot + if (elementsInLastSlot != 0) { + // all slots except the last one start as 0; in the last slot the bits above the last field are always 1 (shl uses only the six lowest bits of elementsCount as the shift distance) + highMarks[highMarks.lastIndex] = -1L shl elementsCount + } + highMarksArray = highMarks + } + } + + /** Decides whether the unmarked element at [index] is returned by [popUnmarkedIndex]; an element rejected here is still marked and never offered again. */ protected abstract fun isPoppedElement(descriptor: SerialDescriptor, index: Int): Boolean + + /** Marks and returns the lowest unmarked index accepted by [isPoppedElement], or [CompositeDecoder.DECODE_DONE] when no such index is left. */ public fun popUnmarkedIndex(): Int { + val elementsCount = descriptor.elementsCount + while (lowerMarks != -1L) { + val index = lowerMarks.inv().countTrailingZeroBits() + lowerMarks = lowerMarks or (1L shl index) + + if (isPoppedElement(descriptor, index)) { + return index + } + } + + if (elementsCount > Long.SIZE_BITS) { + val higherMarks = highMarksArray!! 
+ + for (slot in higherMarks.indices) { + // (slot + 1) because the first element in the high marks has index 64 + val slotOffset = (slot + 1) * Long.SIZE_BITS + // keep the slot in a local variable to avoid repeated array access + var mark = higherMarks[slot] + + while (mark != -1L) { + val indexInSlot = mark.inv().countTrailingZeroBits() + mark = mark or (1L shl indexInSlot) + + val index = slotOffset + indexInSlot + if (isPoppedElement(descriptor, index)) { + higherMarks[slot] = mark + return index + } + } + higherMarks[slot] = mark + } + return CompositeDecoder.DECODE_DONE + } + return CompositeDecoder.DECODE_DONE + } + + /** Marks the element at [index] so that [popUnmarkedIndex] never returns it. */ public fun mark(index: Int) { + if (index < Long.SIZE_BITS) { + lowerMarks = lowerMarks or (1L shl index) + } else { + val slot = (index / Long.SIZE_BITS) - 1 + val offsetInSlot = index % Long.SIZE_BITS + highMarksArray!![slot] = highMarksArray[slot] or (1L shl offsetInSlot) + } + } +} diff --git a/core/commonMain/src/kotlinx/serialization/internal/Tagged.kt b/core/commonMain/src/kotlinx/serialization/internal/Tagged.kt index 8f00b4ab0..5750da5ac 100644 --- a/core/commonMain/src/kotlinx/serialization/internal/Tagged.kt +++ b/core/commonMain/src/kotlinx/serialization/internal/Tagged.kt @@ -61,6 +61,10 @@ public abstract class TaggedEncoder : Encoder, CompositeEncoder { return true } + protected open fun skipNullElement(descriptor: SerialDescriptor, index: Int): Boolean { + return false + } + final override fun encodeNotNullMark() {} // Does nothing, open because is not really required open override fun encodeNull(): Unit = encodeTaggedNull(popTag()) final override fun encodeBoolean(value: Boolean): Unit = encodeTaggedBoolean(popTag(), value) @@ -143,6 +147,10 @@ public abstract class TaggedEncoder : Encoder, CompositeEncoder { serializer: SerializationStrategy, value: T? 
) { + if (value == null && skipNullElement(descriptor, index)) { + return + } + + if (encodeElement(descriptor, index)) encodeNullableSerializableValue(serializer, value) } diff --git a/core/commonTest/src/kotlinx/serialization/ElementMarkerTest.kt b/core/commonTest/src/kotlinx/serialization/ElementMarkerTest.kt new file mode 100644 index 000000000..62b9b4672 --- /dev/null +++ b/core/commonTest/src/kotlinx/serialization/ElementMarkerTest.kt @@ -0,0 +1,105 @@ +package kotlinx.serialization + +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.internal.ElementMarker +import kotlin.test.Test +import kotlin.test.assertEquals + +/** Tests for marking and popping element indices with [ElementMarker]. */ class ElementMarkerTest { + private class TestMarker(descriptor: SerialDescriptor, val predicate: (SerialDescriptor, Int) -> Boolean = { _, _ -> true }) : + ElementMarker(descriptor) { + + override fun isPoppedElement(descriptor: SerialDescriptor, index: Int): Boolean = + predicate(descriptor, index) + } + + + @Test + fun testNothingWasRead() { + val size = 5 + val descriptor = createClassDescriptor(size) + val reader = TestMarker(descriptor) + + for (i in 0 until size) { + assertEquals(i, reader.popUnmarkedIndex()) + } + assertEquals(CompositeDecoder.DECODE_DONE, reader.popUnmarkedIndex()) + } + + @Test + fun testAllWasRead() { + val size = 5 + val descriptor = createClassDescriptor(size) + val reader = TestMarker(descriptor) + for (i in 0 until size) { + reader.mark(i) + } + + assertEquals(CompositeDecoder.DECODE_DONE, reader.popUnmarkedIndex()) + } + + @Test + fun testFilteredRead() { + val size = 10 + val readIndex = 4 + + val predicate: (Any?, Int) -> Boolean = { _, i -> i % 2 == 0 } + val descriptor = createClassDescriptor(size) + val reader = TestMarker(descriptor, predicate) + reader.mark(readIndex) + + for (i in 0 until size) { + if (predicate(descriptor, i) && i != readIndex) { + // `readIndex` is already marked as read, so only the remaining elements accepted by the predicate must be popped + 
assertEquals(i, reader.popUnmarkedIndex()) + } + } + assertEquals(CompositeDecoder.DECODE_DONE, reader.popUnmarkedIndex()) + } + + @Test + fun testSmallPartiallyRead() { + testPartiallyRead(Long.SIZE_BITS / 3) + } + + @Test + fun test64PartiallyRead() { + testPartiallyRead(Long.SIZE_BITS) + } + + @Test + fun test128PartiallyRead() { + testPartiallyRead(Long.SIZE_BITS * 2) + } + + @Test + fun testLargePartiallyRead() { + testPartiallyRead(Long.SIZE_BITS * 2 + Long.SIZE_BITS / 3) + } + + /** Marks every even index of a descriptor with [size] elements and checks that exactly the odd indices are popped in ascending order. */ private fun testPartiallyRead(size: Int) { + val descriptor = createClassDescriptor(size) + val reader = TestMarker(descriptor) + for (i in 0 until size) { + if (i % 2 == 0) { + reader.mark(i) + } + } + + for (i in 0 until size) { + if (i % 2 != 0) { + assertEquals(i, reader.popUnmarkedIndex()) + } + } + assertEquals(CompositeDecoder.DECODE_DONE, reader.popUnmarkedIndex()) + } + + /** Builds a class descriptor with [size] elements named `element0`, `element1`, and so on. */ private fun createClassDescriptor(size: Int): SerialDescriptor { + return buildClassSerialDescriptor("descriptor") { + for (i in 0 until size) { + element("element$i", buildSerialDescriptor("int", PrimitiveKind.INT)) + } + } + } +} diff --git a/formats/json/api/kotlinx-serialization-json.api b/formats/json/api/kotlinx-serialization-json.api index 1797d5974..922fbcfc1 100644 --- a/formats/json/api/kotlinx-serialization-json.api +++ b/formats/json/api/kotlinx-serialization-json.api @@ -81,6 +81,7 @@ public final class kotlinx/serialization/json/JsonBuilder { public final fun getCoerceInputValues ()Z public final fun getEncodeDefaults ()Z public final fun getIgnoreUnknownKeys ()Z + public final fun getOmitNull ()Z public final fun getPrettyPrint ()Z public final fun getPrettyPrintIndent ()Ljava/lang/String; public final fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule; @@ -94,6 +95,7 @@ public final class kotlinx/serialization/json/JsonBuilder { public final fun setEncodeDefaults (Z)V public final fun setIgnoreUnknownKeys (Z)V public final fun setLenient (Z)V + public final fun 
setOmitNull (Z)V public final fun setPrettyPrint (Z)V public final fun setPrettyPrintIndent (Ljava/lang/String;)V public final fun setSerializersModule (Lkotlinx/serialization/modules/SerializersModule;)V @@ -109,6 +111,7 @@ public final class kotlinx/serialization/json/JsonConfiguration { public final fun getCoerceInputValues ()Z public final fun getEncodeDefaults ()Z public final fun getIgnoreUnknownKeys ()Z + public final fun getOmitNull ()Z public final fun getPrettyPrint ()Z public final fun getPrettyPrintIndent ()Ljava/lang/String; public final fun getUseAlternativeNames ()Z diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt index 7ef3f40ba..71bfd37ec 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt @@ -96,7 +96,7 @@ public sealed class Json( */ public final override fun decodeFromString(deserializer: DeserializationStrategy, string: String): T { val lexer = JsonLexer(string) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) val result = input.decodeSerializableValue(deserializer) lexer.expectEof() return result @@ -170,6 +170,11 @@ public class JsonBuilder internal constructor(json: Json) { */ public var encodeDefaults: Boolean = json.configuration.encodeDefaults + /** + * Specifies whether `null` values of nullable properties are omitted from the encoded JSON and whether absent nullable properties without a default value are decoded as `null`. `false` by default. + */ + public var omitNull: Boolean = json.configuration.omitNull + /** * Specifies whether encounters of unknown properties in the input JSON * should be ignored instead of throwing [SerializationException]. 
@@ -275,7 +280,7 @@ public class JsonBuilder internal constructor(json: Json) { return JsonConfiguration( encodeDefaults, ignoreUnknownKeys, isLenient, - allowStructuredMapKeys, prettyPrint, prettyPrintIndent, + allowStructuredMapKeys, prettyPrint, omitNull, prettyPrintIndent, coerceInputValues, useArrayPolymorphism, classDiscriminator, allowSpecialFloatingPointValues, useAlternativeNames ) diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/JsonConfiguration.kt b/formats/json/commonMain/src/kotlinx/serialization/json/JsonConfiguration.kt index c6a87ebea..557ac7f52 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/JsonConfiguration.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/JsonConfiguration.kt @@ -22,6 +22,7 @@ public class JsonConfiguration internal constructor( public val isLenient: Boolean = false, public val allowStructuredMapKeys: Boolean = false, public val prettyPrint: Boolean = false, + public val omitNull: Boolean = false, @ExperimentalSerializationApi public val prettyPrintIndent: String = " ", public val coerceInputValues: Boolean = false, @@ -33,6 +34,6 @@ public class JsonConfiguration internal constructor( /** @suppress Dokka **/ override fun toString(): String { - return "JsonConfiguration(encodeDefaults=$encodeDefaults, ignoreUnknownKeys=$ignoreUnknownKeys, isLenient=$isLenient, allowStructuredMapKeys=$allowStructuredMapKeys, prettyPrint=$prettyPrint, prettyPrintIndent='$prettyPrintIndent', coerceInputValues=$coerceInputValues, useArrayPolymorphism=$useArrayPolymorphism, classDiscriminator='$classDiscriminator', allowSpecialFloatingPointValues=$allowSpecialFloatingPointValues)" + return "JsonConfiguration(encodeDefaults=$encodeDefaults, ignoreUnknownKeys=$ignoreUnknownKeys, isLenient=$isLenient, allowStructuredMapKeys=$allowStructuredMapKeys, prettyPrint=$prettyPrint, omitNull=$omitNull, prettyPrintIndent='$prettyPrintIndent', coerceInputValues=$coerceInputValues, 
useArrayPolymorphism=$useArrayPolymorphism, classDiscriminator='$classDiscriminator', allowSpecialFloatingPointValues=$allowSpecialFloatingPointValues)" } } diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonAbsenceReader.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonAbsenceReader.kt new file mode 100644 index 000000000..b735d205a --- /dev/null +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonAbsenceReader.kt @@ -0,0 +1,18 @@ +package kotlinx.serialization.json.internal + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.internal.ElementMarker + +/** [ElementMarker] that pops only the unmarked elements which are nullable and not optional. */ @InternalSerializationApi +@OptIn(ExperimentalSerializationApi::class) +internal class JsonAbsenceReader(descriptor: SerialDescriptor) : ElementMarker(descriptor) { + /** Result of the most recent [isPoppedElement] call; `false` until the first call. */ internal var poppedNull: Boolean = false + private set + + /** Accepts the element at [index] when it is not optional and its descriptor is nullable, and stores the decision in [poppedNull]. */ override fun isPoppedElement(descriptor: SerialDescriptor, index: Int): Boolean { + poppedNull = !descriptor.isElementOptional(index) && descriptor.getElementDescriptor(index).isNullable + return poppedNull + } +} diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt index a73e863d9..9b9e5ff06 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt @@ -19,12 +19,14 @@ import kotlin.jvm.* internal open class StreamingJsonDecoder( final override val json: Json, private val mode: WriteMode, - @JvmField internal val lexer: JsonLexer + @JvmField internal val lexer: JsonLexer, + descriptor: SerialDescriptor ) : JsonDecoder, AbstractDecoder() { override val serializersModule: SerializersModule = 
json.serializersModule private var currentIndex = -1 private val configuration = json.configuration + private val absenceReader: JsonAbsenceReader? = if (configuration.omitNull) JsonAbsenceReader(descriptor) else null override fun decodeJsonElement(): JsonElement = JsonTreeReader(json.configuration, lexer).read() @@ -41,12 +43,13 @@ internal open class StreamingJsonDecoder( WriteMode.LIST, WriteMode.MAP, WriteMode.POLY_OBJ -> StreamingJsonDecoder( json, newMode, - lexer + lexer, + descriptor ) - else -> if (mode == newMode) { + else -> if (mode == newMode && !json.configuration.omitNull) { this } else { - StreamingJsonDecoder(json, newMode, lexer) + StreamingJsonDecoder(json, newMode, lexer, descriptor) } } } @@ -56,7 +59,7 @@ internal open class StreamingJsonDecoder( } override fun decodeNotNullMark(): Boolean { - return lexer.tryConsumeNotNull() + return !(absenceReader?.poppedNull?:false) && lexer.tryConsumeNotNull() } override fun decodeNull(): Nothing? { @@ -124,6 +127,7 @@ internal open class StreamingJsonDecoder( hasComma = lexer.tryConsumeComma() false // Known element, but coerced } else { + absenceReader?.mark(index) return index // Known element without coercing, return it } } else { @@ -135,7 +139,8 @@ internal open class StreamingJsonDecoder( } } if (hasComma) lexer.fail("Unexpected trailing comma") - return CompositeDecoder.DECODE_DONE + + return absenceReader?.popUnmarkedIndex() ?: CompositeDecoder.DECODE_DONE } private fun handleUnknown(key: String): Boolean { diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt index 5b68278c2..88e3eac0b 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt @@ -147,6 +147,11 @@ internal class StreamingJsonEncoder( return true } + override 
fun skipNullElement(descriptor: SerialDescriptor, index: Int): Boolean { + // this function is called only for nullable elements, so there is no need to check `descriptor` + return configuration.omitNull + } + override fun encodeInline(inlineDescriptor: SerialDescriptor): Encoder = if (inlineDescriptor.isUnsignedNumber) StreamingJsonEncoder( ComposerForUnsignedNumbers( diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt index e9de5b38a..ea4805a76 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt @@ -17,7 +17,7 @@ import kotlin.jvm.* internal fun Json.readJson(element: JsonElement, deserializer: DeserializationStrategy): T { val input = when (element) { - is JsonObject -> JsonTreeDecoder(this, element) + is JsonObject -> JsonTreeDecoder(this, element, deserializer.descriptor) is JsonArray -> JsonTreeListDecoder(this, element) is JsonLiteral, JsonNull -> JsonPrimitiveDecoder(this, element as JsonPrimitive) } @@ -29,7 +29,7 @@ internal fun Json.readPolymorphicJson( element: JsonObject, deserializer: DeserializationStrategy ): T { - return JsonTreeDecoder(this, element, discriminator, deserializer.descriptor).decodeSerializableValue(deserializer) + return JsonTreeDecoder(this, element, deserializer.descriptor, discriminator, deserializer.descriptor).decodeSerializableValue(deserializer) } private sealed class AbstractJsonTreeDecoder( @@ -62,7 +62,7 @@ private sealed class AbstractJsonTreeDecoder( { JsonTreeMapDecoder(json, cast(currentObject, descriptor)) }, { JsonTreeListDecoder(json, cast(currentObject, descriptor)) } ) - else -> JsonTreeDecoder(json, cast(currentObject, descriptor)) + else -> JsonTreeDecoder(json, cast(currentObject, descriptor), descriptor) } } @@ -179,11 +179,14 @@ private class 
JsonPrimitiveDecoder(json: Json, override val value: JsonPrimitive private open class JsonTreeDecoder( json: Json, override val value: JsonObject, + objectDescriptor: SerialDescriptor? = null, private val polyDiscriminator: String? = null, private val polyDescriptor: SerialDescriptor? = null ) : AbstractJsonTreeDecoder(json, value) { private var position = 0 + private val absenceReader: JsonAbsenceReader? = + if (objectDescriptor != null && json.configuration.omitNull) JsonAbsenceReader(objectDescriptor) else null /* * Checks whether JSON has `null` value for non-null property or unknown enum value for enum property */ @@ -197,11 +200,17 @@ private open class JsonTreeDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { while (position < descriptor.elementsCount) { val name = descriptor.getTag(position++) - if (name in value && (!configuration.coerceInputValues || !coerceInputValue(descriptor, position - 1, name))) { - return position - 1 + val index = position - 1 + if (name in value && (!configuration.coerceInputValues || !coerceInputValue(descriptor, index, name))) { + absenceReader?.mark(index) + return index } } - return CompositeDecoder.DECODE_DONE + return absenceReader?.popUnmarkedIndex() ?: CompositeDecoder.DECODE_DONE + } + + override fun decodeNotNullMark(): Boolean { + return !(absenceReader?.poppedNull?:false) && super.decodeNotNullMark() } override fun elementName(desc: SerialDescriptor, index: Int): String { diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt index 53d7e0174..8844ad7d3 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt @@ -171,6 +171,10 @@ private open class JsonTreeEncoder( content[key] = element } + override fun skipNullElement(descriptor: 
SerialDescriptor, index: Int): Boolean { + return configuration.omitNull + } + override fun getCurrent(): JsonElement = JsonObject(content) } diff --git a/formats/json/commonTest/src/kotlinx/serialization/json/AbstractJsonOmitNullTest.kt b/formats/json/commonTest/src/kotlinx/serialization/json/AbstractJsonOmitNullTest.kt new file mode 100644 index 000000000..d92342e25 --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/json/AbstractJsonOmitNullTest.kt @@ -0,0 +1,138 @@ +package kotlinx.serialization.json + +import kotlinx.serialization.* +import kotlin.test.* + +/** Test cases for a [Json] configured with `omitNull = true`; subclasses define how values are encoded and decoded through the abstract [encode] and [decode] functions. */ @Ignore +abstract class AbstractJsonOmitNullTest { + @Serializable + data class Nullable( + val f0: Int?, + val f1: Int?, + val f2: Int?, + val f3: Int?, + ) + + @Serializable + data class WithRequired( + val f0: Int?, + val f1: Int?, + val f2: Int, + ) + + @Serializable + data class WithOptional( + val f0: Int?, + val f1: Int? = 1, + val f2: Int = 2, + ) + + @Serializable + data class Outer(val i: Inner) + + @Serializable + data class Inner(val s1: String?, val s2: String?) + + @Serializable + data class ListWithNullable(val l: List) + + @Serializable + data class MapWithNullable(val m: Map) + + @Serializable + data class NullableList(val l: List?) + + @Serializable + data class NullableMap(val m: Map?) 
+ + + private val format = Json { omitNull = true } + + protected abstract fun Json.encode(value: T, serializer: KSerializer): String + + protected abstract fun Json.decode(json: String, serializer: KSerializer): T + + @Test + fun testNullable() { + val plain = Nullable(null, 10, null, null) + val json = """{"f1":10}""" + + assertEquals(json, format.encode(plain, Nullable.serializer())) + assertEquals(plain, format.decode(json, Nullable.serializer())) + } + + @Test + fun testMissingRequired() { + val json = """{"f1":10}""" + + assertFailsWith(SerializationException::class) { + format.decode(json, WithRequired.serializer()) + } + } + + @Test + fun testDecodeOptional() { + val json = """{}""" + + val decoded = format.decode(json, WithOptional.serializer()) + assertEquals(WithOptional(null), decoded) + } + + + @Test + fun testNestedJsonObject() { + val json = """{"i": {}}""" + + val decoded = format.decode(json, Outer.serializer()) + assertEquals(Outer(Inner(null, null)), decoded) + } + + @Test + fun testListWithNullable() { + val jsonWithNull = """{"l":[null]}""" + val jsonWithEmptyList = """{"l":[]}""" + + val encoded = format.encode(ListWithNullable(listOf(null)), ListWithNullable.serializer()) + assertEquals(jsonWithNull, encoded) + + val decoded = format.decode(jsonWithEmptyList, ListWithNullable.serializer()) + assertEquals(ListWithNullable(emptyList()), decoded) + } + + @Test + fun testMapWithNullable() { + val jsonWithNull = """{"m":{null:null}}""" + val jsonWithQuotedNull = """{"m":{"null":null}}""" + val jsonWithEmptyList = """{"m":{}}""" + + val encoded = format.encode(MapWithNullable(mapOf(null to null)), MapWithNullable.serializer()) + // Json encodes a null map key as `null:`, but other external utilities may encode it as a String `"null":`, so both forms are accepted + assertTrue { listOf(jsonWithNull, jsonWithQuotedNull).contains(encoded) } + + val decoded = format.decode(jsonWithEmptyList, MapWithNullable.serializer()) + assertEquals(MapWithNullable(emptyMap()), decoded) + } + + 
@Test + fun testNullableList() { + val json = """{}""" + + val encoded = format.encode(NullableList(null), NullableList.serializer()) + assertEquals(json, encoded) + + val decoded = format.decode(json, NullableList.serializer()) + assertEquals(NullableList(null), decoded) + } + + @Test + fun testNullableMap() { + val json = """{}""" + + val encoded = format.encode(NullableMap(null), NullableMap.serializer()) + assertEquals(json, encoded) + + val decoded = format.decode(json, NullableMap.serializer()) + assertEquals(NullableMap(null), decoded) + } + +} diff --git a/formats/json/commonTest/src/kotlinx/serialization/json/JsonOmitNullTest.kt b/formats/json/commonTest/src/kotlinx/serialization/json/JsonOmitNullTest.kt new file mode 100644 index 000000000..b7830a38b --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/json/JsonOmitNullTest.kt @@ -0,0 +1,13 @@ +package kotlinx.serialization.json + +import kotlinx.serialization.* + +/** Runs the omitNull test cases through `encodeToString` and `decodeFromString`. */ class JsonOmitNullTest: AbstractJsonOmitNullTest() { + override fun Json.encode(value: T, serializer: KSerializer): String { + return encodeToString(serializer, value) + } + + override fun Json.decode(json: String, serializer: KSerializer): T { + return decodeFromString(serializer, json) + } +} diff --git a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt index d0ab7821c..6adb71280 100644 --- a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt +++ b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt @@ -38,7 +38,7 @@ abstract class JsonTestBase { decodeFromString(deserializer, source) } else { val lexer = JsonLexer(source) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) val tree = input.decodeJsonElement() lexer.expectEof() readJson(tree, deserializer) diff --git 
a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTreeOmitNullTest.kt b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTreeOmitNullTest.kt new file mode 100644 index 000000000..aa30a6d14 --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTreeOmitNullTest.kt @@ -0,0 +1,14 @@ +package kotlinx.serialization.json + +import kotlinx.serialization.KSerializer + +/** Runs the shared omitNull tests through the JsonElement tree API (`encodeToJsonElement` / `decodeFromJsonElement`). */ class JsonTreeOmitNullTest: AbstractJsonOmitNullTest() { + override fun <T> Json.encode(value: T, serializer: KSerializer<T>): String { + return encodeToJsonElement(serializer, value).toString() + } + + override fun <T> Json.decode(json: String, serializer: KSerializer<T>): T { + val jsonElement = parseToJsonElement(json) + return decodeFromJsonElement(serializer, jsonElement) + } +} diff --git a/formats/json/commonTest/src/kotlinx/serialization/test/TestClass.kt b/formats/json/commonTest/src/kotlinx/serialization/test/TestClass.kt new file mode 100644 index 000000000..a5ea20682 --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/test/TestClass.kt @@ -0,0 +1,29 @@ +package kotlinx.serialization.test + +import kotlinx.serialization.Serializable +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.encodeToJsonElement +import kotlin.test.Test + +/** Ad-hoc check that prints encodings produced with `omitNull = true`. NOTE(review): its only test asserts nothing and the same cases appear in AbstractJsonOmitNullTest; consider dropping this file before merge. */ class TestClass { + @Serializable + data class Nullable(val i: Int?)
+ + /** List with nullable elements (the list itself is non-null). NOTE(review): type arguments were missing from this patch text; Int-based arguments restored, confirm against upstream. */ @Serializable + data class NullableList(val i: List<Int?>) + + /** Map with nullable keys and values (the map itself is non-null). */ @Serializable + data class NullableMap(val i: Map<Int?, Int?>) + + /** Prints the string and tree encodings of each fixture; contains no assertions. */ @Test + fun foo() { + println(Json {omitNull = true}.encodeToString(Nullable(null))) + println(Json {omitNull = true}.encodeToString(NullableList(listOf(null)))) + println(Json {omitNull = true}.encodeToString(NullableMap(mapOf(null to null)))) + + println(Json {omitNull = true}.encodeToJsonElement(Nullable(null))) + println(Json {omitNull = true}.encodeToJsonElement(NullableList(listOf(null)))) + println(Json {omitNull = true}.encodeToJsonElement(NullableMap(mapOf(null to null)))) + } +} diff --git a/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicDecoders.kt b/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicDecoders.kt index 0e71fb879..444a2c971 100644 --- a/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicDecoders.kt +++ b/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicDecoders.kt @@ -41,6 +41,8 @@ private open class DynamicInput( protected val keys: dynamic = js("Object").keys(value ?: js("{}")) protected open val size: Int = keys.length as Int + /** True while the element last returned by decodeElementIndex was absent from the input and must be read as null; consulted by decodeTaggedNotNullMark. */ private var forceNull: Boolean = false + override val serializersModule: SerializersModule get() = json.serializersModule @@ -81,14 +83,23 @@ private open class DynamicInput( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { while (currentPosition < descriptor.elementsCount) { val name = descriptor.getTag(currentPosition++) - if (hasName(name) && (!json.configuration.coerceInputValues || !coerceInputValue(descriptor, currentPosition - 1, name))) - return currentPosition - 1 + val index = currentPosition - 1 + /* Reset for every element; absenceIsNull below may set it again, and is only evaluated when the name is missing. */ forceNull = false + if ((hasName(name) || absenceIsNull(descriptor, index)) && (!json.configuration.coerceInputValues || !coerceInputValue(descriptor, index, name))) { + return index + } } return CompositeDecoder.DECODE_DONE } private fun hasName(name: String) = value[name] !== undefined + /** Decides whether a missing element should be decoded as null: true when `omitNull` is configured and the element is non-optional with a nullable type. Side effect: stores the result in forceNull. */ private fun
absenceIsNull(descriptor: SerialDescriptor, index: Int): Boolean { + /* Cache the decision so decodeTaggedNotNullMark can report null for this element. */ forceNull = json.configuration.omitNull + && !descriptor.isElementOptional(index) && descriptor.getElementDescriptor(index).isNullable + return forceNull + } + override fun elementName(desc: SerialDescriptor, index: Int): String { val mainName = desc.getElementName(index) if (!json.configuration.useAlternativeNames) return mainName @@ -141,6 +152,10 @@ private open class DynamicInput( } override fun decodeTaggedNotNullMark(tag: String): Boolean { + /* Element was selected by absenceIsNull: report null without looking up the tag, which would otherwise hit throwMissingTag below. */ if (forceNull) { + return false + } + val o = getByTag(tag) if (o === undefined) throwMissingTag(tag) @Suppress("SENSELESS_COMPARISON") // null !== undefined ! diff --git a/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicEncoders.kt b/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicEncoders.kt index 633ab35fc..1e53d30d9 100644 --- a/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicEncoders.kt +++ b/formats/json/jsMain/src/kotlinx/serialization/json/internal/DynamicEncoders.kt @@ -158,6 +158,10 @@ private class DynamicObjectEncoder( override fun shouldEncodeElementDefault(descriptor: SerialDescriptor, index: Int) = json.configuration.encodeDefaults + /** Returns the `omitNull` setting regardless of descriptor and index. NOTE(review): the caller is not shown here; presumably it drops null-valued properties when this returns true, confirm. */ override fun skipNullElement(descriptor: SerialDescriptor, index: Int): Boolean { + return json.configuration.omitNull + } + private fun enterNode(jsObject: dynamic, writeMode: WriteMode) { val child = Node(writeMode, jsObject) child.parent = current diff --git a/formats/json/jsTest/src/kotlinx/serialization/json/JsonDynamicOmitNullTest.kt b/formats/json/jsTest/src/kotlinx/serialization/json/JsonDynamicOmitNullTest.kt new file mode 100644 index 000000000..dabcb0fcc --- /dev/null +++ b/formats/json/jsTest/src/kotlinx/serialization/json/JsonDynamicOmitNullTest.kt @@ -0,0 +1,15 @@ +package kotlinx.serialization.json + +import kotlinx.serialization.KSerializer +import kotlin.test.Test + +/** Runs the shared omitNull tests through the JS dynamic API (`encodeToDynamic` / `decodeFromDynamic`), using `JSON.stringify` / `JSON.parse` for the text form. */ class JsonDynamicOmitNullTest : AbstractJsonOmitNullTest() { +
override fun <T> Json.encode(value: T, serializer: KSerializer<T>): String { + return JSON.stringify(encodeToDynamic(serializer, value)) + } + + override fun <T> Json.decode(json: String, serializer: KSerializer<T>): T { + val x: dynamic = JSON.parse(json) + return decodeFromDynamic(serializer, x) + } +} diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufAbsenceReader.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufAbsenceReader.kt new file mode 100644 index 000000000..c05d74fcf --- /dev/null +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufAbsenceReader.kt @@ -0,0 +1,34 @@ +package kotlinx.serialization.protobuf.internal + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.internal.ElementMarker + +/** Element marker used by the protobuf decoder to handle elements that never appeared in the input. It selects which absent elements are still worth visiting and remembers whether the selected one must be read as null. NOTE(review): the base class ElementMarker is not shown here; presumably its popUnmarkedIndex consults [isPoppedElement], confirm. */ @InternalSerializationApi +@OptIn(ExperimentalSerializationApi::class) +internal class ProtobufAbsenceReader(descriptor: SerialDescriptor) : ElementMarker(descriptor) { + /** Set when the element last accepted by [isPoppedElement] is nullable and not a map or list; cleared by [popNullValue]. */ private var nullValue: Boolean = false + + /** Accepts a non-optional element whose kind is MAP or LIST (clearing nullValue) or whose descriptor is nullable (setting nullValue); rejects optional elements and everything else. */ override fun isPoppedElement(descriptor: SerialDescriptor, index: Int): Boolean { + if (!descriptor.isElementOptional(index)) { + val elementDescriptor = descriptor.getElementDescriptor(index) + val kind = elementDescriptor.kind + if (kind == StructureKind.MAP || kind == StructureKind.LIST) { + nullValue = false + return true + } else if (elementDescriptor.isNullable) { + nullValue = true + return true + } + } + return false + } + + /** Returns the current nullValue flag and resets it to false, so the flag is reported at most once. */ fun popNullValue(): Boolean { + val prev = nullValue + nullValue = false + return prev + } +} diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt index 35b024e8e..60fa4b506 100644
--- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt @@ -28,46 +28,12 @@ internal open class ProtobufDecoder( private var indexCache: IntArray? = null private var sparseIndexCache: MutableMap? = null - /* - Element decoding marks from given bytes. - The element number is the same as the bit position. - Marks for the lowest 64 elements are always stored in a single Long value, higher elements stores in long array. - */ - private var lowerReadMark: Long = 0 - private val highReadMarks: LongArray? - - private var valueIsNull: Boolean = false + /** Replaces the hand-rolled read-mark bit sets: records which elements were read and, once the input ends, supplies absent elements for decodeElementIndex to return. */ private val absenceReader: ProtobufAbsenceReader = ProtobufAbsenceReader(descriptor) init { - highReadMarks = prepareReadMarks(descriptor) populateCache(descriptor) } - private fun prepareReadMarks(descriptor: SerialDescriptor): LongArray? { - val elementsCount = descriptor.elementsCount - return if (elementsCount <= Long.SIZE_BITS) { - lowerReadMark = if (elementsCount == Long.SIZE_BITS) { - // number og bits in the mark is equal to the number of fields - 0 - } else { - // (1 - elementsCount) bits are always 1 since there are no fields for them - -1L shl elementsCount - } - null - } else { - // (elementsCount - 1) because only one Long value is needed to store 64 fields etc - val slotsCount = (elementsCount - 1) / Long.SIZE_BITS - val elementsInLastSlot = elementsCount % Long.SIZE_BITS - val highReadMarks = LongArray(slotsCount) - // (elementsCount % Long.SIZE_BITS) == 0 this means that the fields occupy all bits in mark - if (elementsInLastSlot != 0) { - // all marks except the higher are always 0 - highReadMarks[highReadMarks.lastIndex] = -1L shl elementsCount - } - highReadMarks - } - } - public fun populateCache(descriptor: SerialDescriptor) { val elements = descriptor.elementsCount if (elements < 32) { @@ -247,97 +213,24 @@ internal open class ProtobufDecoder( override fun
SerialDescriptor.getTag(index: Int) = extractParameters(index) - private fun findUnreadElementIndex(): Int { - val elementsCount = descriptor.elementsCount - while (lowerReadMark != -1L) { - val index = lowerReadMark.inv().countTrailingZeroBits() - lowerReadMark = lowerReadMark or (1L shl index) - - if (!descriptor.isElementOptional(index)) { - val elementDescriptor = descriptor.getElementDescriptor(index) - val kind = elementDescriptor.kind - if (kind == StructureKind.MAP || kind == StructureKind.LIST) { - return index - } else if (elementDescriptor.isNullable) { - valueIsNull = true - return index - } - } - } - - if (elementsCount > Long.SIZE_BITS) { - val higherMarks = highReadMarks!! - - for (slot in higherMarks.indices) { - // (slot + 1) because first element in high marks has index 64 - val slotOffset = (slot + 1) * Long.SIZE_BITS - // store in a variable so as not to frequently use the array - var mark = higherMarks[slot] - - while (mark != -1L) { - val indexInSlot = mark.inv().countTrailingZeroBits() - mark = mark or (1L shl indexInSlot) - - val index = slotOffset + indexInSlot - if (!descriptor.isElementOptional(index)) { - val elementDescriptor = descriptor.getElementDescriptor(index) - val kind = elementDescriptor.kind - if (kind == StructureKind.MAP || kind == StructureKind.LIST) { - higherMarks[slot] = mark - return index - } else if (elementDescriptor.isNullable) { - higherMarks[slot] = mark - valueIsNull = true - return index - } - } - } - - higherMarks[slot] = mark - } - return -1 - } - return -1 - } - - private fun markElementAsRead(index: Int) { - if (index < Long.SIZE_BITS) { - lowerReadMark = lowerReadMark or (1L shl index) - } else { - val slot = (index / Long.SIZE_BITS) - 1 - val offsetInSlot = index % Long.SIZE_BITS - highReadMarks!![slot] = highReadMarks[slot] or (1L shl offsetInSlot) - } - } - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { while (true) { val protoId = reader.readTag() if (protoId == -1) { // EOF - val
absenceIndex = findUnreadElementIndex() - return if (absenceIndex == -1) { - CompositeDecoder.DECODE_DONE - } else { - absenceIndex - } + /* NOTE(review): the removed code mapped -1 to CompositeDecoder.DECODE_DONE explicitly; this now relies on popUnmarkedIndex (defined in ElementMarker, not shown) returning DECODE_DONE itself once no absent element is left, confirm. */ return absenceReader.popUnmarkedIndex() } val index = getIndexByTag(protoId) if (index == -1) { // not found reader.skipElement() } else { - markElementAsRead(index) + /* Record that this element was present in the input, taking over from markElementAsRead. */ absenceReader.mark(index) return index } } } override fun decodeNotNullMark(): Boolean { - return if (valueIsNull) { - valueIsNull = false - false - } else { - true - } + /* False only when the absence reader flagged the current element as an absent nullable; popNullValue also clears that flag, as the removed branch did. */ return !absenceReader.popNullValue() } }