Skip to content

Commit

Permalink
feat: abstract out reading ByteArrays
Browse files Browse the repository at this point in the history
This closes #13
  • Loading branch information
esensar committed Feb 7, 2021
1 parent 18eefb0 commit 08b721a
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 75 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.ensarsarajcic.kotlinx.serialization.msgpack

import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.MsgPackDataBuffer
import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.toMsgPackBuffer
import kotlinx.serialization.BinaryFormat
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
Expand Down Expand Up @@ -38,7 +40,7 @@ class MsgPack @JvmOverloads constructor(
}

override fun <T> decodeFromByteArray(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
val decoder = MsgPackDecoder(configuration, serializersModule, bytes)
val decoder = MsgPackDecoder(configuration, serializersModule, bytes.toMsgPackBuffer())
return decoder.decodeSerializableValue(deserializer)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.ensarsarajcic.kotlinx.serialization.msgpack

import com.ensarsarajcic.kotlinx.serialization.msgpack.stream.MsgPackDataBuffer
import com.ensarsarajcic.kotlinx.serialization.msgpack.types.MsgPackType
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.builtins.ByteArraySerializer
Expand All @@ -12,20 +13,8 @@ import kotlinx.serialization.modules.SerializersModule
internal class MsgPackDecoder(
private val configuration: MsgPackConfiguration,
override val serializersModule: SerializersModule,
private val byteArray: ByteArray
private val dataBuffer: MsgPackDataBuffer
) : AbstractDecoder() {
// TODO extract into some form of ByteStream
private var index = 0
private fun nextByteOrNull(): Byte? = byteArray.getOrNull(index++)
private fun requireNextByte(): Byte = nextByteOrNull() ?: throw Exception("End of stream")
private fun takeNext(next: Int): ByteArray {
require(next > 0) { "Number of bytes to take must be greater than 0!" }
val result = ByteArray(next)
(0 until next).forEach {
result[it] = requireNextByte()
}
return result
}

// TODO Don't use flags, separate composite decoders for classes, lists and maps
private var decodingClass = false
Expand All @@ -43,17 +32,17 @@ internal class MsgPackDecoder(
override fun decodeSequentially(): Boolean = !decodingClass

override fun decodeNotNullMark(): Boolean {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
val next = dataBuffer.peek()
return next != MsgPackType.NULL
}

override fun decodeNull(): Nothing? {
val next = requireNextByte()
val next = dataBuffer.requireNextByte()
return if (next == MsgPackType.NULL) null else throw Exception("Invalid null $next")
}

override fun decodeBoolean(): Boolean {
return when (val next = requireNextByte()) {
return when (val next = dataBuffer.requireNextByte()) {
MsgPackType.Boolean.TRUE -> true
MsgPackType.Boolean.FALSE -> false
else -> throw Exception("Invalid boolean $next")
Expand All @@ -62,149 +51,138 @@ internal class MsgPackDecoder(

override fun decodeByte(): Byte {
// Check is it a single byte value
val next = requireNextByte()
val next = dataBuffer.requireNextByte()
return when {
MsgPackType.Int.POSITIVE_FIXNUM_MASK.test(next) or MsgPackType.Int.NEGATIVE_FIXNUM_MASK.test(next) -> next
// TODO reader is not handling overflows (when using unsigned types)
MsgPackType.Int.isByte(next) -> nextByteOrNull() ?: throw Exception("End of stream")
MsgPackType.Int.isByte(next) -> dataBuffer.requireNextByte()
else -> throw TODO("Add a more descriptive error when wrong type is found!")
}
}

override fun decodeShort(): Short {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
val next = dataBuffer.peek()
return when {
MsgPackType.Int.isShort(next) -> {
index++
takeNext(2).joinToNumber()
dataBuffer.skip(1)
dataBuffer.takeNext(2).joinToNumber()
}
next == MsgPackType.Int.UINT8 -> {
index++
(requireNextByte().toInt() and 0xff).toShort()
dataBuffer.skip(1)
(dataBuffer.requireNextByte().toInt() and 0xff).toShort()
}
else -> decodeByte().toShort()
}
}

override fun decodeInt(): Int {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
val next = dataBuffer.peek()
return when {
MsgPackType.Int.isInt(next) -> {
index++
takeNext(4).joinToNumber()
dataBuffer.skip(1)
dataBuffer.takeNext(4).joinToNumber()
}
next == MsgPackType.Int.UINT16 -> {
index++
takeNext(2).joinToNumber()
dataBuffer.skip(1)
dataBuffer.takeNext(2).joinToNumber()
}
else -> decodeShort().toInt()
}
}

override fun decodeLong(): Long {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
val next = dataBuffer.peek()
return when {
MsgPackType.Int.isLong(next) -> {
index++
takeNext(8).joinToNumber()
dataBuffer.skip(1)
dataBuffer.takeNext(8).joinToNumber()
}
next == MsgPackType.Int.UINT32 -> {
index++
takeNext(4).joinToNumber()
dataBuffer.skip(1)
dataBuffer.takeNext(4).joinToNumber()
}
else -> decodeInt().toLong()
}
}

override fun decodeFloat(): Float {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
return when (next) {
return when (dataBuffer.peek()) {
MsgPackType.Float.FLOAT -> {
index++
Float.fromBits(takeNext(4).joinToNumber())
dataBuffer.skip(1)
Float.fromBits(dataBuffer.takeNext(4).joinToNumber())
}
else -> TODO("Add a more descriptive error when wrong type is found!")
}
}

override fun decodeDouble(): Double {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
return when (next) {
return when (dataBuffer.peek()) {
MsgPackType.Float.DOUBLE -> {
index++
Double.fromBits(takeNext(8).joinToNumber())
dataBuffer.skip(1)
Double.fromBits(dataBuffer.takeNext(8).joinToNumber())
}
MsgPackType.Float.FLOAT -> decodeFloat().toDouble()
else -> TODO("Add a more descriptive error when wrong type is found!")
}
}

override fun decodeString(): String {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
index++
val next = dataBuffer.requireNextByte()
val length = when {
MsgPackType.String.FIXSTR_SIZE_MASK.test(next) -> MsgPackType.String.FIXSTR_SIZE_MASK.unMaskValue(next).toInt()
next == MsgPackType.String.STR8 -> requireNextByte().toInt() and 0xff
next == MsgPackType.String.STR16 -> takeNext(2).joinToNumber()
next == MsgPackType.String.STR8 -> dataBuffer.requireNextByte().toInt() and 0xff
next == MsgPackType.String.STR16 -> dataBuffer.takeNext(2).joinToNumber()
// TODO: this may have issues with long strings, since size will overflow
next == MsgPackType.String.STR32 -> takeNext(4).joinToNumber()
next == MsgPackType.String.STR32 -> dataBuffer.takeNext(4).joinToNumber()
else -> {
index--
throw TODO("Add a more descriptive error when wrong type is found!")
}
}
if (length == 0) return ""
return takeNext(length).decodeToString()
return dataBuffer.takeNext(length).decodeToString()
}

fun decodeByteArray(): ByteArray {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
index++
val next = dataBuffer.requireNextByte()
val length = when (next) {
MsgPackType.Bin.BIN8 -> requireNextByte().toInt() and 0xff
MsgPackType.Bin.BIN16 -> takeNext(2).joinToNumber()
MsgPackType.Bin.BIN8 -> dataBuffer.requireNextByte().toInt() and 0xff
MsgPackType.Bin.BIN16 -> dataBuffer.takeNext(2).joinToNumber()
// TODO: this may have issues with long byte arrays, since size will overflow
MsgPackType.Bin.BIN32 -> takeNext(4).joinToNumber()
MsgPackType.Bin.BIN32 -> dataBuffer.takeNext(4).joinToNumber()
else -> {
index--
throw TODO("Add a more descriptive error when wrong type is found!")
}
}
if (length == 0) return byteArrayOf()
return takeNext(length)
return dataBuffer.takeNext(length)
}

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
val next = byteArray.getOrNull(index) ?: throw Exception("End of stream")
index++

val next = dataBuffer.requireNextByte()
return when (descriptor.kind) {
StructureKind.LIST ->
when {
MsgPackType.Array.FIXARRAY_SIZE_MASK.test(next) -> MsgPackType.Array.FIXARRAY_SIZE_MASK.unMaskValue(next).toInt()
next == MsgPackType.Array.ARRAY16 -> takeNext(2).joinToNumber()
next == MsgPackType.Array.ARRAY16 -> dataBuffer.takeNext(2).joinToNumber()
// TODO: this may have issues with long arrays, since size will overflow
next == MsgPackType.Array.ARRAY32 -> takeNext(4).joinToNumber()
next == MsgPackType.Array.ARRAY32 -> dataBuffer.takeNext(4).joinToNumber()
else -> {
index--
throw TODO("Add a more descriptive error when wrong type is found!")
}
}

StructureKind.CLASS, StructureKind.OBJECT, StructureKind.MAP ->
when {
MsgPackType.Map.FIXMAP_SIZE_MASK.test(next) -> MsgPackType.Map.FIXMAP_SIZE_MASK.unMaskValue(next).toInt()
next == MsgPackType.Map.MAP16 -> takeNext(2).joinToNumber()
next == MsgPackType.Map.MAP16 -> dataBuffer.takeNext(2).joinToNumber()
// TODO: this may have issues with long objects, since size will overflow
next == MsgPackType.Map.MAP16 -> takeNext(4).joinToNumber()
next == MsgPackType.Map.MAP16 -> dataBuffer.takeNext(4).joinToNumber()
else -> {
index--
throw TODO("Add a more descriptive error when wrong type is found!")
}
}

else -> {
index--
TODO("Unsupported collection")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.ensarsarajcic.kotlinx.serialization.msgpack.extensions

import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule

class MsgPackExtensionDecoder(
override val serializersModule: SerializersModule
) : AbstractDecoder() {

override fun decodeElementIndex(descriptor: SerialDescriptor): Int = 0

override fun decodeValue(): Any {
return super.decodeValue()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.ensarsarajcic.kotlinx.serialization.msgpack.extensions

import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

class MsgPackExtensionSerializer<T> : KSerializer<T> {

override fun deserialize(decoder: Decoder): T {
TODO("Not yet implemented")
}

override val descriptor: SerialDescriptor
get() = TODO("Not yet implemented")

override fun serialize(encoder: Encoder, value: T) {
TODO("Not yet implemented")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.ensarsarajcic.kotlinx.serialization.msgpack.stream

internal class MsgPackDataBuffer(
private val byteArray: ByteArray
) {
private var index = 0

fun skip(bytes: Int) {
index += bytes
}

fun peek(): Byte = byteArray.getOrNull(index) ?: throw Exception("End of stream")

// Increases index only if next byte is not null
fun nextByteOrNull(): Byte? = byteArray.getOrNull(index)?.also { index++ }

fun requireNextByte(): Byte = nextByteOrNull() ?: throw Exception("End of stream")

fun takeNext(next: Int): ByteArray {
require(next > 0) { "Number of bytes to take must be greater than 0!" }
val result = ByteArray(next)
(0 until next).forEach {
result[it] = requireNextByte()
}
return result
}
}

internal fun ByteArray.toMsgPackBuffer() = MsgPackDataBuffer(this)
Loading

0 comments on commit 08b721a

Please sign in to comment.