Skip to content

Commit

Permalink
Added custom parcelers
Browse files Browse the repository at this point in the history
  • Loading branch information
arkivanov committed Aug 17, 2023
1 parent d4db8e9 commit 124a9e7
Show file tree
Hide file tree
Showing 12 changed files with 363 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.arkivanov.parcelize.darwin

import org.jetbrains.kotlin.ir.backend.js.utils.typeArguments
import org.jetbrains.kotlin.ir.builders.IrBuilderWithScope
import org.jetbrains.kotlin.ir.builders.createTmpVariable
import org.jetbrains.kotlin.ir.builders.irBlock
Expand All @@ -16,6 +17,8 @@ import org.jetbrains.kotlin.ir.builders.irNotEquals
import org.jetbrains.kotlin.ir.builders.irNull
import org.jetbrains.kotlin.ir.builders.irString
import org.jetbrains.kotlin.ir.builders.irTrue
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrField
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
Expand All @@ -27,6 +30,10 @@ import org.jetbrains.kotlin.ir.types.makeNullable
import org.jetbrains.kotlin.ir.util.companionObject
import org.jetbrains.kotlin.ir.util.functions
import org.jetbrains.kotlin.ir.util.getPropertyGetter
import org.jetbrains.kotlin.ir.util.isAnnotation
import org.jetbrains.kotlin.ir.util.isObject
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.ir.util.render

interface Coder {

Expand All @@ -46,8 +53,21 @@ class CoderFactory(
private val symbols: Symbols,
) {

fun get(type: IrType): Coder =
get(type = type.toSupportedType(symbols))
fun get(field: IrField): Coder =
get(
type = IrTypeToSupportedTypeMapper(
symbols = symbols,
typeParcelers = field.parentAsClass.extractTypeParcelers(),
).map(type = field.type),
)

private fun IrClass.extractTypeParcelers(): Map<IrType, IrType> =
annotations
.filter { it.isAnnotation(typeParcelerName) }
.associateBy(
keySelector = { requireNotNull(it.typeArguments[0]) },
valueTransform = { requireNotNull(it.typeArguments[1]) },
)

fun get(type: SupportedType): Coder =
when (type) {
Expand Down Expand Up @@ -121,6 +141,13 @@ class CoderFactory(
decodeFunction = symbols.decodeDouble,
)

is SupportedType.Custom ->
CustomCoder(
symbols = symbols,
type = type.type,
parcelerType = type.parcelerType,
)

is SupportedType.String -> StringCoder(symbols = symbols)

is SupportedType.Enum ->
Expand Down Expand Up @@ -246,6 +273,85 @@ private class PrimitiveCoder(
}
}

private class CustomCoder(
private val symbols: Symbols,
private val type: IrType,
private val parcelerType: IrType,
) : Coder {
init {
require(parcelerType.requireClass().isObject) { "Parceler must be an object: ${parcelerType.render()}" }
}

override fun IrBuilderWithScope.encode(coder: IrExpression, value: IrExpression, key: IrExpression): IrExpression =
irBlock {
val archiver =
createTmpVariable(
irCallCompat(
callee = symbols.nsKeyedArchiverConstructor,
arguments = listOf(irTrue()),
)
)

+irCallCompat(
callee = parcelerType.requireClass().requireFunction(
name = "write",
valueParameterTypes = listOf(symbols.nsCoderType),
extensionReceiverParameterType = type,
),
extensionReceiver = value,
dispatchReceiver = irGetObject(parcelerType.classOrNull!!),
arguments = listOf(irGet(archiver)),
)

+irCallCompat(
callee = symbols.encodeObject,
extensionReceiver = coder,
arguments = listOf(
irCallCompat(callee = symbols.encodedData, dispatchReceiver = irGet(archiver)),
key,
)
)
}

override fun IrBuilderWithScope.decode(coder: IrExpression, key: IrExpression): IrExpression =
irBlock {
val data =
createTmpVariable(
irCallCompat(
callee = symbols.decodeObject,
extensionReceiver = coder,
arguments = listOf(
irGetObject(symbols.nsDataClass.owner.companionObject()!!.symbol),
key,
),
)
)

val unarchiver =
createTmpVariable(
irCallCompat(
callee = symbols.nsKeyedUnarchiverConstructor,
arguments = listOf(irGet(data)),
)
)

+irCallCompat(
callee = symbols.setRequireSecureCoding,
dispatchReceiver = irGet(unarchiver),
arguments = listOf(irTrue()),
)

+irCallCompat(
callee = parcelerType.requireClass().requireFunction(
name = "create",
valueParameterTypes = listOf(symbols.nsCoderType),
),
dispatchReceiver = irGetObject(parcelerType.classOrNull!!),
arguments = listOf(irGet(unarchiver)),
)
}
}

private class StringCoder(
private val symbols: Symbols,
) : Coder {
Expand Down Expand Up @@ -474,7 +580,8 @@ private class CollectionCoder(
condition = irNotEquals(arg1 = irGet(index), arg2 = irGet(size)),
body = irBlock {
+irCallCompat(
callee = collectionConstructor.owner.returnType.requireClass().requireFunction(name = "add"),
callee = collectionConstructor.owner.returnType.requireClass()
.requireFunction(name = "add"),
dispatchReceiver = irGet(collection),
arguments = listOf(
with(itemCoder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class ParcelizeClassLoweringPass(
}

private fun IrBlockBuilder.addEncodeFieldStatement(field: IrField, data: IrExpression, coder: IrExpression) {
+with(coderFactory.get(field.type)) {
+with(coderFactory.get(field)) {
encode(
coder = coder,
value = irGetField(data, field),
Expand Down Expand Up @@ -396,7 +396,7 @@ class ParcelizeClassLoweringPass(
) {
dataConstructorCall.putValueArgument(
index = index,
valueArgument = with(coderFactory.get(field.type)) {
valueArgument = with(coderFactory.get(field)) {
decode(
coder = coder,
key = irString(field.name.identifier),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.arkivanov.parcelize.darwin

import org.jetbrains.kotlin.backend.jvm.ir.erasedUpperBound
import org.jetbrains.kotlin.ir.backend.js.utils.typeArguments
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.typeOrNull
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.ir.util.getAnnotation
import org.jetbrains.kotlin.ir.util.hasAnnotation
import org.jetbrains.kotlin.ir.util.isEnumClass
import org.jetbrains.kotlin.ir.util.render

Expand All @@ -17,6 +20,7 @@ sealed interface SupportedType {
data class PrimitiveFloat(val isNullable: Boolean) : SupportedType
data class PrimitiveDouble(val isNullable: Boolean) : SupportedType
data class PrimitiveBoolean(val isNullable: Boolean) : SupportedType
data class Custom(val type: IrType, val parcelerType: IrType) : SupportedType
object String : SupportedType
data class Enum(val type: IrType) : SupportedType
object Parcelable : SupportedType
Expand All @@ -28,58 +32,76 @@ sealed interface SupportedType {
data class MutableMap(val keyType: SupportedType, val valueType: SupportedType) : SupportedType
}

fun IrType.toSupportedType(symbols: Symbols): SupportedType =
when {
this == symbols.intType -> SupportedType.PrimitiveInt(isNullable = false)
this == symbols.intNType -> SupportedType.PrimitiveInt(isNullable = true)
this == symbols.longType -> SupportedType.PrimitiveLong(isNullable = false)
this == symbols.longNType -> SupportedType.PrimitiveLong(isNullable = true)
this == symbols.shortType -> SupportedType.PrimitiveShort(isNullable = false)
this == symbols.shortNType -> SupportedType.PrimitiveShort(isNullable = true)
this == symbols.byteType -> SupportedType.PrimitiveByte(isNullable = false)
this == symbols.byteNType -> SupportedType.PrimitiveByte(isNullable = true)
this == symbols.charType -> SupportedType.PrimitiveChar(isNullable = false)
this == symbols.charNType -> SupportedType.PrimitiveChar(isNullable = true)
this == symbols.floatType -> SupportedType.PrimitiveFloat(isNullable = false)
this == symbols.floatNType -> SupportedType.PrimitiveFloat(isNullable = true)
this == symbols.doubleType -> SupportedType.PrimitiveDouble(isNullable = false)
this == symbols.doubleNType -> SupportedType.PrimitiveDouble(isNullable = true)
this == symbols.booleanType -> SupportedType.PrimitiveBoolean(isNullable = false)
this == symbols.booleanNType -> SupportedType.PrimitiveBoolean(isNullable = true)
(this == symbols.stringType) || (this == symbols.stringNType) -> SupportedType.String
erasedUpperBound.isEnumClass -> SupportedType.Enum(type = this)
isParcelable() -> SupportedType.Parcelable

erasedUpperBoundType == symbols.listType ->
SupportedType.List(itemType = getTypeArgument(0).toSupportedType(symbols))

erasedUpperBoundType == symbols.mutableListType ->
SupportedType.MutableList(itemType = getTypeArgument(0).toSupportedType(symbols))

erasedUpperBoundType == symbols.setType ->
SupportedType.Set(itemType = getTypeArgument(0).toSupportedType(symbols))

erasedUpperBoundType == symbols.mutableSetType ->
SupportedType.MutableSet(itemType = getTypeArgument(0).toSupportedType(symbols))

erasedUpperBoundType == symbols.mapType ->
SupportedType.Map(
keyType = getTypeArgument(0).toSupportedType(symbols),
valueType = getTypeArgument(1).toSupportedType(symbols),
)

erasedUpperBoundType == symbols.mutableMapType ->
SupportedType.MutableMap(
keyType = getTypeArgument(0).toSupportedType(symbols),
valueType = getTypeArgument(1).toSupportedType(symbols),
)

else -> error("Unsupported type: ${render()}")
}

private fun IrType.getTypeArgument(index: Int): IrType =
asIrSimpleType().arguments[index].typeOrNull!!

private val IrType.erasedUpperBoundType: IrType
get() = erasedUpperBound.defaultType

class IrTypeToSupportedTypeMapper(
private val symbols: Symbols,
private val typeParcelers: Map<IrType, IrType>,
) {
fun map(type: IrType): SupportedType =
when {
type == symbols.intType -> SupportedType.PrimitiveInt(isNullable = false)
type == symbols.intNType -> SupportedType.PrimitiveInt(isNullable = true)
type == symbols.longType -> SupportedType.PrimitiveLong(isNullable = false)
type == symbols.longNType -> SupportedType.PrimitiveLong(isNullable = true)
type == symbols.shortType -> SupportedType.PrimitiveShort(isNullable = false)
type == symbols.shortNType -> SupportedType.PrimitiveShort(isNullable = true)
type == symbols.byteType -> SupportedType.PrimitiveByte(isNullable = false)
type == symbols.byteNType -> SupportedType.PrimitiveByte(isNullable = true)
type == symbols.charType -> SupportedType.PrimitiveChar(isNullable = false)
type == symbols.charNType -> SupportedType.PrimitiveChar(isNullable = true)
type == symbols.floatType -> SupportedType.PrimitiveFloat(isNullable = false)
type == symbols.floatNType -> SupportedType.PrimitiveFloat(isNullable = true)
type == symbols.doubleType -> SupportedType.PrimitiveDouble(isNullable = false)
type == symbols.doubleNType -> SupportedType.PrimitiveDouble(isNullable = true)
type == symbols.booleanType -> SupportedType.PrimitiveBoolean(isNullable = false)
type == symbols.booleanNType -> SupportedType.PrimitiveBoolean(isNullable = true)

type.hasAnnotation(writeWithName) ->
SupportedType.Custom(
type = type,
parcelerType = requireNotNull(type.getAnnotation(writeWithName)?.typeArguments?.first()),
)

type in typeParcelers ->
SupportedType.Custom(
type = type,
parcelerType = typeParcelers.getValue(type),
)

(type == symbols.stringType) || (type == symbols.stringNType) -> SupportedType.String
type.erasedUpperBound.isEnumClass -> SupportedType.Enum(type = type)
type.isParcelable() -> SupportedType.Parcelable

type.erasedUpperBoundType == symbols.listType ->
SupportedType.List(itemType = map(type.getTypeArgument(0)))

type.erasedUpperBoundType == symbols.mutableListType ->
SupportedType.MutableList(itemType = map(type.getTypeArgument(0)))

type.erasedUpperBoundType == symbols.setType ->
SupportedType.Set(itemType = map(type.getTypeArgument(0)))

type.erasedUpperBoundType == symbols.mutableSetType ->
SupportedType.MutableSet(itemType = map(type.getTypeArgument(0)))

type.erasedUpperBoundType == symbols.mapType ->
SupportedType.Map(
keyType = map(type.getTypeArgument(0)),
valueType = map(type.getTypeArgument(1)),
)

type.erasedUpperBoundType == symbols.mutableMapType ->
SupportedType.MutableMap(
keyType = map(type.getTypeArgument(0)),
valueType = map(type.getTypeArgument(1)),
)

else -> error("Unsupported type: ${type.render()}")
}

private fun IrType.getTypeArgument(index: Int): IrType =
asIrSimpleType().arguments[index].typeOrNull!!

private val IrType.erasedUpperBoundType: IrType
get() = erasedUpperBound.defaultType
}
Loading

0 comments on commit 124a9e7

Please sign in to comment.