diff --git a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPack.kt b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPack.kt index 2cbf07f..0d4fe9b 100644 --- a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPack.kt +++ b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPack.kt @@ -1,5 +1,7 @@ package com.ensarsarajcic.kotlinx.serialization.msgpack +import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.MsgPackDataBuffer +import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.toMsgPackBuffer import kotlinx.serialization.BinaryFormat import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.ExperimentalSerializationApi @@ -38,7 +40,7 @@ class MsgPack @JvmOverloads constructor( } override fun decodeFromByteArray(deserializer: DeserializationStrategy, bytes: ByteArray): T { - val decoder = MsgPackDecoder(configuration, serializersModule, bytes) + val decoder = MsgPackDecoder(configuration, serializersModule, bytes.toMsgPackBuffer()) return decoder.decodeSerializableValue(deserializer) } diff --git a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPackDecoder.kt b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPackDecoder.kt index e9a556b..6426240 100644 --- a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPackDecoder.kt +++ b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/MsgPackDecoder.kt @@ -1,5 +1,6 @@ package com.ensarsarajcic.kotlinx.serialization.msgpack +import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.MsgPackDataBuffer import com.ensarsarajcic.kotlinx.serialization.msgpack.types.MsgPackType import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.builtins.ByteArraySerializer @@ -12,20 +13,8 @@ import kotlinx.serialization.modules.SerializersModule internal class MsgPackDecoder( private val configuration: MsgPackConfiguration, override val serializersModule: SerializersModule, - private val byteArray: ByteArray + private val dataBuffer: MsgPackDataBuffer ) : AbstractDecoder() { - // TODO extract into some form of ByteStream - private var index = 0 - private fun nextByteOrNull(): Byte? = byteArray.getOrNull(index++) - private fun requireNextByte(): Byte = nextByteOrNull() ?: throw Exception("End of stream") - private fun takeNext(next: Int): ByteArray { - require(next > 0) { "Number of bytes to take must be greater than 0!" } - val result = ByteArray(next) - (0 until next).forEach { - result[it] = requireNextByte() - } - return result - } // TODO Don't use flags, separate composite decoders for classes, lists and maps private var decodingClass = false @@ -43,17 +32,17 @@ internal class MsgPackDecoder( override fun decodeSequentially(): Boolean = !decodingClass override fun decodeNotNullMark(): Boolean { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") + val next = dataBuffer.peek() return next != MsgPackType.NULL } override fun decodeNull(): Nothing? { - val next = requireNextByte() + val next = dataBuffer.requireNextByte() return if (next == MsgPackType.NULL) null else throw Exception("Invalid null $next") } override fun decodeBoolean(): Boolean { - return when (val next = requireNextByte()) { + return when (val next = dataBuffer.requireNextByte()) { MsgPackType.Boolean.TRUE -> true MsgPackType.Boolean.FALSE -> false else -> throw Exception("Invalid boolean $next") @@ -62,77 +51,75 @@ internal class MsgPackDecoder( override fun decodeByte(): Byte { // Check is it a single byte value - val next = requireNextByte() + val next = dataBuffer.requireNextByte() return when { MsgPackType.Int.POSITIVE_FIXNUM_MASK.test(next) or MsgPackType.Int.NEGATIVE_FIXNUM_MASK.test(next) -> next // TODO reader is not handling overflows (when using unsigned types) - MsgPackType.Int.isByte(next) -> nextByteOrNull() ?: throw Exception("End of stream") + MsgPackType.Int.isByte(next) -> dataBuffer.requireNextByte() else -> throw TODO("Add a more descriptive error when wrong type is found!") } } override fun decodeShort(): Short { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") + val next = dataBuffer.peek() return when { MsgPackType.Int.isShort(next) -> { - index++ - takeNext(2).joinToNumber() + dataBuffer.skip(1) + dataBuffer.takeNext(2).joinToNumber() } next == MsgPackType.Int.UINT8 -> { - index++ - (requireNextByte().toInt() and 0xff).toShort() + dataBuffer.skip(1) + (dataBuffer.requireNextByte().toInt() and 0xff).toShort() } else -> decodeByte().toShort() } } override fun decodeInt(): Int { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") + val next = dataBuffer.peek() return when { MsgPackType.Int.isInt(next) -> { - index++ - takeNext(4).joinToNumber() + dataBuffer.skip(1) + dataBuffer.takeNext(4).joinToNumber() } next == MsgPackType.Int.UINT16 -> { - index++ - takeNext(2).joinToNumber() + dataBuffer.skip(1) + dataBuffer.takeNext(2).joinToNumber() } else -> decodeShort().toInt() } } override fun decodeLong(): Long { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") + val next = dataBuffer.peek() return when { MsgPackType.Int.isLong(next) -> { - index++ - takeNext(8).joinToNumber() + dataBuffer.skip(1) + dataBuffer.takeNext(8).joinToNumber() } next == MsgPackType.Int.UINT32 -> { - index++ - takeNext(4).joinToNumber() + dataBuffer.skip(1) + dataBuffer.takeNext(4).joinToNumber() } else -> decodeInt().toLong() } } override fun decodeFloat(): Float { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") - return when (next) { + return when (dataBuffer.peek()) { MsgPackType.Float.FLOAT -> { - index++ - Float.fromBits(takeNext(4).joinToNumber()) + dataBuffer.skip(1) + Float.fromBits(dataBuffer.takeNext(4).joinToNumber()) } else -> TODO("Add a more descriptive error when wrong type is found!") } } override fun decodeDouble(): Double { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") - return when (next) { + return when (dataBuffer.peek()) { MsgPackType.Float.DOUBLE -> { - index++ - Double.fromBits(takeNext(8).joinToNumber()) + dataBuffer.skip(1) + Double.fromBits(dataBuffer.takeNext(8).joinToNumber()) } MsgPackType.Float.FLOAT -> decodeFloat().toDouble() else -> TODO("Add a more descriptive error when wrong type is found!") @@ -140,53 +127,46 @@ internal class MsgPackDecoder( } override fun decodeString(): String { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") - index++ + val next = dataBuffer.requireNextByte() val length = when { MsgPackType.String.FIXSTR_SIZE_MASK.test(next) -> MsgPackType.String.FIXSTR_SIZE_MASK.unMaskValue(next).toInt() - next == MsgPackType.String.STR8 -> requireNextByte().toInt() and 0xff - next == MsgPackType.String.STR16 -> takeNext(2).joinToNumber() + next == MsgPackType.String.STR8 -> dataBuffer.requireNextByte().toInt() and 0xff + next == MsgPackType.String.STR16 -> dataBuffer.takeNext(2).joinToNumber() // TODO: this may have issues with long strings, since size will overflow - next == MsgPackType.String.STR32 -> takeNext(4).joinToNumber() + next == MsgPackType.String.STR32 -> dataBuffer.takeNext(4).joinToNumber() else -> { - index-- throw TODO("Add a more descriptive error when wrong type is found!") } } if (length == 0) return "" - return takeNext(length).decodeToString() + return dataBuffer.takeNext(length).decodeToString() } fun decodeByteArray(): ByteArray { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") - index++ + val next = dataBuffer.requireNextByte() val length = when (next) { - MsgPackType.Bin.BIN8 -> requireNextByte().toInt() and 0xff - MsgPackType.Bin.BIN16 -> takeNext(2).joinToNumber() + MsgPackType.Bin.BIN8 -> dataBuffer.requireNextByte().toInt() and 0xff + MsgPackType.Bin.BIN16 -> dataBuffer.takeNext(2).joinToNumber() // TODO: this may have issues with long byte arrays, since size will overflow - MsgPackType.Bin.BIN32 -> takeNext(4).joinToNumber() + MsgPackType.Bin.BIN32 -> dataBuffer.takeNext(4).joinToNumber() else -> { - index-- throw TODO("Add a more descriptive error when wrong type is found!") } } if (length == 0) return byteArrayOf() - return takeNext(length) + return dataBuffer.takeNext(length) } override fun decodeCollectionSize(descriptor: SerialDescriptor): Int { - val next = byteArray.getOrNull(index) ?: throw Exception("End of stream") - index++ - + val next = dataBuffer.requireNextByte() return when (descriptor.kind) { StructureKind.LIST -> when { MsgPackType.Array.FIXARRAY_SIZE_MASK.test(next) -> MsgPackType.Array.FIXARRAY_SIZE_MASK.unMaskValue(next).toInt() - next == MsgPackType.Array.ARRAY16 -> takeNext(2).joinToNumber() + next == MsgPackType.Array.ARRAY16 -> dataBuffer.takeNext(2).joinToNumber() // TODO: this may have issues with long arrays, since size will overflow - next == MsgPackType.Array.ARRAY32 -> takeNext(4).joinToNumber() + next == MsgPackType.Array.ARRAY32 -> dataBuffer.takeNext(4).joinToNumber() else -> { - index-- throw TODO("Add a more descriptive error when wrong type is found!") } } @@ -194,17 +174,15 @@ internal class MsgPackDecoder( StructureKind.CLASS, StructureKind.OBJECT, StructureKind.MAP -> when { MsgPackType.Map.FIXMAP_SIZE_MASK.test(next) -> MsgPackType.Map.FIXMAP_SIZE_MASK.unMaskValue(next).toInt() - next == MsgPackType.Map.MAP16 -> takeNext(2).joinToNumber() + next == MsgPackType.Map.MAP16 -> dataBuffer.takeNext(2).joinToNumber() // TODO: this may have issues with long objects, since size will overflow - next == MsgPackType.Map.MAP16 -> takeNext(4).joinToNumber() + next == MsgPackType.Map.MAP16 -> dataBuffer.takeNext(4).joinToNumber() else -> { - index-- throw TODO("Add a more descriptive error when wrong type is found!") } } else -> { - index-- TODO("Unsupported collection") } } diff --git a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionDecoder.kt b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionDecoder.kt new file mode 100644 index 0000000..09574f1 --- /dev/null +++ b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionDecoder.kt @@ -0,0 +1,16 @@ +package com.ensarsarajcic.kotlinx.serialization.msgpack.extensions + +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.AbstractDecoder +import kotlinx.serialization.modules.SerializersModule + +class MsgPackExtensionDecoder( + override val serializersModule: SerializersModule +) : AbstractDecoder() { + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int = 0 + + override fun decodeValue(): Any { + return super.decodeValue() + } +} \ No newline at end of file diff --git a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionSerializer.kt b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionSerializer.kt new file mode 100644 index 0000000..9be0caf --- /dev/null +++ b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/extensions/MsgPackExtensionSerializer.kt @@ -0,0 +1,20 @@ +package com.ensarsarajcic.kotlinx.serialization.msgpack.extensions + +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder + +class MsgPackExtensionSerializer : KSerializer { + + override fun deserialize(decoder: Decoder): T { + TODO("Not yet implemented") + } + + override val descriptor: SerialDescriptor + get() = TODO("Not yet implemented") + + override fun serialize(encoder: Encoder, value: T) { + TODO("Not yet implemented") + } +} \ No newline at end of file diff --git a/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/stream/MsgPackDataBuffer.kt b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/stream/MsgPackDataBuffer.kt new file mode 100644 index 0000000..39d280e --- /dev/null +++ b/serialization-msgpack/src/commonMain/kotlin/com.ensarsarajcic.kotlinx.serialization.msgpack/stream/MsgPackDataBuffer.kt @@ -0,0 +1,29 @@ +package com.ensarsarajcic.kotlinx.serialization.msgpack.stream + +internal class MsgPackDataBuffer( + private val byteArray: ByteArray +) { + private var index = 0 + + fun skip(bytes: Int) { + index += bytes + } + + fun peek(): Byte = byteArray.getOrNull(index) ?: throw Exception("End of stream") + + // Increases index only if next byte is not null + fun nextByteOrNull(): Byte? = byteArray.getOrNull(index)?.also { index++ } + + fun requireNextByte(): Byte = nextByteOrNull() ?: throw Exception("End of stream") + + fun takeNext(next: Int): ByteArray { + require(next > 0) { "Number of bytes to take must be greater than 0!" } + val result = ByteArray(next) + (0 until next).forEach { + result[it] = requireNextByte() + } + return result + } +} + +internal fun ByteArray.toMsgPackBuffer() = MsgPackDataBuffer(this) \ No newline at end of file diff --git a/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/MsgPackDecoderTest.kt b/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/MsgPackDecoderTest.kt index 9d885e9..a068e73 100644 --- a/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/MsgPackDecoderTest.kt +++ b/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/MsgPackDecoderTest.kt @@ -1,5 +1,6 @@ package com.ensarsarajcic.kotlinx.serialization.msgpack +import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.toMsgPackBuffer import kotlinx.serialization.builtins.ArraySerializer import kotlinx.serialization.builtins.MapSerializer import kotlinx.serialization.builtins.serializer @@ -20,7 +21,7 @@ internal class MsgPackDecoderTest { @Test fun testNullDecode() { - val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, byteArrayOf(0xc0.toByte())) + val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, byteArrayOf(0xc0.toByte()).toMsgPackBuffer()) assertEquals(null, decoder.decodeNull()) } @@ -71,7 +72,7 @@ internal class MsgPackDecoderTest { @Test fun testFloatDecode() { TestData.floatTestPairs.forEach { (input, result) -> - MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()).also { + MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()).also { // Tests in JS were failing when == comparison was used, so threshold is now used val threshold = 0.00001f val right = it.decodeFloat() @@ -83,7 +84,7 @@ internal class MsgPackDecoderTest { @Test fun testDoubleDecode() { TestData.doubleTestPairs.forEach { (input, result) -> - MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()).also { + MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()).also { // Tests in JS were failing when == comparison was used, so threshold is now used val threshold = 0.000000000000000000000000000000000000000000001 val right = it.decodeDouble() @@ -105,7 +106,7 @@ internal class MsgPackDecoderTest { @Test fun testByteArrayDecode() { TestData.bin8TestPairs.forEach { (input, result) -> - MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()).also { + MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()).also { assertTrue { result.contentEquals(it.decodeSerializableValue(serializer())) } } } @@ -114,7 +115,7 @@ internal class MsgPackDecoderTest { @Test fun testArrayDecodeStringArrays() { TestData.stringArrayTestPairs.forEach { (input, result) -> - val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()) + val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()) val serializer = ArraySerializer(String.serializer()) assertEquals(result.toList(), serializer.deserialize(decoder).toList()) } @@ -123,7 +124,7 @@ internal class MsgPackDecoderTest { @Test fun testArrayDecodeIntArrays() { TestData.intArrayTestPairs.forEach { (input, result) -> - val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()) + val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()) val serializer = ArraySerializer(Int.serializer()) assertEquals(result.toList(), serializer.deserialize(decoder).toList()) } @@ -132,7 +133,7 @@ internal class MsgPackDecoderTest { @Test fun testMapDecode() { TestData.mapTestPairs.forEach { (input, result) -> - val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()) + val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()) val serializer = MapSerializer(String.serializer(), String.serializer()) assertEquals(result, serializer.deserialize(decoder)) } @@ -141,7 +142,7 @@ internal class MsgPackDecoderTest { @Test fun testSampleClassDecode() { TestData.sampleClassTestPairs.forEach { (input, result) -> - val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()) + val decoder = MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()) val serializer = TestData.SampleClass.serializer() assertEquals(result, serializer.deserialize(decoder)) } @@ -149,7 +150,7 @@ internal class MsgPackDecoderTest { private fun testPairs(decodeFunction: MsgPackDecoder.() -> RESULT, vararg pairs: Pair) { pairs.forEach { (input, result) -> - MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray()).also { + MsgPackDecoder(MsgPackConfiguration.default, SerializersModule {}, input.hexStringToByteArray().toMsgPackBuffer()).also { assertEquals(result, it.decodeFunction()) } } diff --git a/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/stream/MsgPackDataBufferTest.kt b/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/stream/MsgPackDataBufferTest.kt new file mode 100644 index 0000000..95252bc --- /dev/null +++ b/serialization-msgpack/src/commonTest/kotlin/com/ensarsarajcic/kotlinx/serialization/msgpack/stream/MsgPackDataBufferTest.kt @@ -0,0 +1,56 @@ +package com.ensarsarajcic.kotlinx.serialization.msgpack.stream + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.fail + +internal class MsgPackDataBufferTest { + + @Test + fun testEmptyBuffer() { + val buffer = MsgPackDataBuffer(byteArrayOf()) + + try { + buffer.peek() + fail("Peeking in empty buffer should fail!") + } catch (e: Exception) { + } + + try { + buffer.requireNextByte() + fail("Requiring next byte in empty buffer should fail!") + } catch (e: Exception) { + } + + assertNull(buffer.nextByteOrNull()) + } + + @Test + fun testBuffer() { + val buffer = MsgPackDataBuffer(byteArrayOf(0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06)) + + assertEquals(0x00, buffer.peek()) + assertEquals(0x00, buffer.nextByteOrNull()) + assertEquals(0x01, buffer.requireNextByte()) + assertEquals(byteArrayOf(0x02, 0x03).toList(), buffer.takeNext(2).toList()) + assertEquals(byteArrayOf(0x04, 0x05).toList(), buffer.takeNext(2).toList()) + assertEquals(0x06, buffer.requireNextByte()) + try { + buffer.peek() + fail("Peeking in at the end of buffer should fail!") + } catch (e: Exception) { + } + + try { + buffer.requireNextByte() + fail("Requiring next byte should fail!") + } catch (e: Exception) { + } + + assertNull(buffer.nextByteOrNull()) + + buffer.skip(-7) + assertEquals(0x00, buffer.peek()) + } +} \ No newline at end of file