Skip to content

Commit

Permalink
Optimize check for missing fields in deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
shanshin committed Nov 24, 2020
1 parent a710013 commit a79d8d1
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@

package org.jetbrains.kotlinx.serialization.compiler.backend.common

import org.jetbrains.kotlin.config.ApiVersion
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.descriptorUtil.secondaryConstructors
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.VersionReader
import org.jetbrains.kotlinx.serialization.compiler.resolve.*

abstract class SerializableCodegen(
protected val serializableDescriptor: ClassDescriptor,
bindingContext: BindingContext
) : AbstractSerialGenerator(bindingContext, serializableDescriptor) {
protected val properties = bindingContext.serializablePropertiesFor(serializableDescriptor)
protected val staticDescriptor = serializableDescriptor.declaredTypeParameters.isEmpty()

private val fieldMissingOptimizationVersion = ApiVersion.parse("1.1")!!
protected val useFieldMissingOptimization = canUseFieldMissingOptimization()

fun generate() {
generateSyntheticInternalConstructor()
Expand All @@ -50,6 +57,40 @@ abstract class SerializableCodegen(
}
}

protected fun getGoldenMask(): Int {
var goldenMask = 0
var requiredBit = 1
for (property in properties.serializableProperties) {
if (!property.optional) {
goldenMask = goldenMask or requiredBit
}
requiredBit = requiredBit shl 1
}
return goldenMask
}

protected fun getGoldenMaskList(): List<Int> {
val maskSlotCount = properties.serializableProperties.bitMaskSlotCount()
val goldenMaskList = MutableList(maskSlotCount) { 0 }

for (i in properties.serializableProperties.indices) {
if (!properties.serializableProperties[i].optional) {
val slotNumber = i / 32
val bitInSlot = i % 32
goldenMaskList[slotNumber] = goldenMaskList[slotNumber] or (1 shl bitInSlot)
}
}
return goldenMaskList
}

private fun canUseFieldMissingOptimization(): Boolean {
val implementationVersion = VersionReader.getVersionsForCurrentModuleFromContext(
serializableDescriptor.module,
bindingContext
)?.implementationVersion
return if (implementationVersion != null) implementationVersion >= fieldMissingOptimizationVersion else false
}

protected abstract fun generateInternalConstructor(constructorDescriptor: ClassConstructorDescriptor)

protected open fun generateWriteSelfMethod(methodDescriptor: FunctionDescriptor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.descriptors.annotations.Annotations
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
Expand Down Expand Up @@ -67,6 +68,30 @@ interface IrBuilderExtension {
) { bodyGen(c) }
}

// function will not be created in the real class
fun IrClass.createInlinedFunction(
name: Name,
visibility: DescriptorVisibility,
origin: IrDeclarationOrigin,
returnType: IrType,
bodyGen: IrBlockBodyBuilder.(IrFunction) -> Unit
): IrSimpleFunction {
val function = factory.buildFun {
this.name = name
this.visibility = visibility
this.origin = origin
this.isInline = true
this.returnType = returnType
}
val functionSymbol = function.symbol
function.parent = this
function.body = DeclarationIrBuilder(compilerContext, functionSymbol, startOffset, endOffset).irBlockBody(
startOffset,
endOffset
) { bodyGen(function) }
return function
}

fun IrBuilderWithScope.irInvoke(
dispatchReceiver: IrExpression? = null,
callee: IrFunctionSymbol,
Expand Down Expand Up @@ -109,6 +134,19 @@ interface IrBuilderExtension {
}
}

fun IrBuilderWithScope.createPrimitiveArrayOfExpression(
elementPrimitiveType: IrType,
arrayElements: List<IrExpression>
): IrExpression {
val arrayType = compilerContext.irBuiltIns.primitiveArrayForType.getValue(elementPrimitiveType).defaultType
val arg0 = IrVarargImpl(startOffset, endOffset, arrayType, elementPrimitiveType, arrayElements)
val typeArguments = listOf(elementPrimitiveType)

return irCall(compilerContext.symbols.arrayOf, arrayType, typeArguments = typeArguments).apply {
putValueArgument(0, arg0)
}
}

fun IrBuilderWithScope.irBinOp(name: Name, lhs: IrExpression, rhs: IrExpression): IrExpression {
val classFqName = (lhs.type as IrSimpleType).classOrNull!!.owner.fqNameWhenAvailable!!
val symbol = compilerContext.referenceFunctions(classFqName.child(name)).single()
Expand Down Expand Up @@ -766,4 +804,4 @@ interface IrBuilderExtension {
return superClasses.singleOrNull { it.kind == ClassKind.CLASS }
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,52 @@ package org.jetbrains.kotlinx.serialization.compiler.backend.ir
import org.jetbrains.kotlin.backend.common.deepCopyWithVariables
import org.jetbrains.kotlin.backend.common.lower.irThrow
import org.jetbrains.kotlin.codegen.CompilationException
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.addField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrDelegatingConstructorCallImpl
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.getAnnotation
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlinx.serialization.compiler.backend.common.SerializableCodegen
import org.jetbrains.kotlinx.serialization.compiler.backend.common.serialName
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.serializableAnnotationIsUseless
import org.jetbrains.kotlinx.serialization.compiler.extensions.SerializationPluginContext
import org.jetbrains.kotlinx.serialization.compiler.resolve.*
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ARRAY_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.MISSING_FIELD_EXC
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESC_FIELD
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SINGLE_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.initializedDescriptorFieldName

class SerializableIrGenerator(
val irClass: IrClass,
override val compilerContext: SerializationPluginContext,
bindingContext: BindingContext
) : SerializableCodegen(irClass.descriptor, bindingContext), IrBuilderExtension {

private val descriptorGenerationFunctionName = "createInitializedDescriptor"

private val serialDescClass: ClassDescriptor = serializableDescriptor.module
.getClassFromSerializationDescriptorsPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS)

private val serialDescImplClass: ClassDescriptor = serializableDescriptor
.getClassFromInternalSerializationPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS_IMPL)

private val addElementFun = serialDescImplClass.findFunctionSymbol(CallingConventions.addElement)

val throwMissedFieldExceptionFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(SINGLE_MASK_FIELD_MISSING_FUNC_FQ).single() else null
val throwMissedFieldExceptionArrayFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(ARRAY_MASK_FIELD_MISSING_FUNC_FQ).single() else null

private fun IrClass.hasSerializableAnnotationWithoutArgs(): Boolean {
val annot = getAnnotation(SerializationAnnotations.serializableAnnotationFqName) ?: return false
Expand Down Expand Up @@ -64,6 +84,10 @@ class SerializableIrGenerator(
val thiz = irClass.thisReceiver!!
val superClass = irClass.getSuperClassOrAny()
var startPropOffset: Int = 0

if (useFieldMissingOptimization) {
generateOptimizedGoldenMaskCheck(seenVars)
}
when {
superClass.symbol == compilerContext.irBuiltIns.anyClass -> generateAnySuperConstructorCall(toBuilder = this@contributeConstructor)
superClass.isInternalSerializable -> {
Expand All @@ -83,7 +107,14 @@ class SerializableIrGenerator(
requireNotNull(transformFieldInitializer(prop.irField)) { "Optional value without an initializer" } // todo: filter abstract here
setProperty(irGet(thiz), prop.irProp, initializerBody)
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
// property required
if (useFieldMissingOptimization) {
// field definitely not empty as it's checked before - no need another IF, only assign property from param
+assignParamExpr
continue
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
}
}

val propNotSeenTest =
Expand Down Expand Up @@ -116,6 +147,153 @@ class SerializableIrGenerator(
}
}

private fun IrBlockBodyBuilder.generateOptimizedGoldenMaskCheck(seenVars: List<IrValueParameter>) {
if (serializableDescriptor.isAbstractSerializableClass() || serializableDescriptor.isSealedSerializableClass()) {
// for abstract classes fields MUST BE checked in child classes
return
}

val fieldsMissedTest: IrExpression
val throwErrorExpr: IrExpression

val maskSlotCount = seenVars.size
if (maskSlotCount == 1) {
val goldenMask = getGoldenMask()

throwErrorExpr = irInvoke(
null,
throwMissedFieldExceptionFunc!!,
irGet(seenVars[0]),
irInt(goldenMask),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)

fieldsMissedTest = irNotEquals(
irInt(goldenMask),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMask),
irGet(seenVars[0])
)
)
} else {
val goldenMaskList = getGoldenMaskList()

var compositeExpression: IrExpression? = null
for (i in goldenMaskList.indices) {
val singleCheckExpr = irNotEquals(
irInt(goldenMaskList[i]),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMaskList[i]),
irGet(seenVars[i])
)
)

compositeExpression = if (compositeExpression == null) {
singleCheckExpr
} else {
irBinOp(
OperatorNameConventions.OR,
compositeExpression,
singleCheckExpr
)
}
}

fieldsMissedTest = compositeExpression!!

throwErrorExpr = irBlock {
+irInvoke(
null,
throwMissedFieldExceptionArrayFunc!!,
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.indices.map { irGet(seenVars[it]) }),
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.map { irInt(it) }),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)
}
}

+irIfThen(compilerContext.irBuiltIns.unitType, fieldsMissedTest, throwErrorExpr)
}

private fun IrBlockBodyBuilder.getSerialDescriptorExpr(): IrExpression {
return if (serializableDescriptor.shouldHaveGeneratedSerializer && staticDescriptor) {
val serializer = serializableDescriptor.classSerializer!!
val serialDescriptorGetter = compilerContext.referenceClass(serializer.fqNameSafe)!!.getPropertyGetter(SERIAL_DESC_FIELD)!!
irGet(
serialDescriptorGetter.owner.returnType,
irGetObject(serializer),
serialDescriptorGetter.owner.symbol
)
} else {
irGetField(null, generateStaticDescriptorField())
}
}

private fun IrBlockBodyBuilder.generateStaticDescriptorField(): IrField {
val serialDescItType = serialDescClass.defaultType.toIrType()

val function = irClass.createInlinedFunction(
Name.identifier(descriptorGenerationFunctionName),
DescriptorVisibilities.PRIVATE,
SERIALIZABLE_PLUGIN_ORIGIN,
serialDescItType
) {
val serialDescVar = irTemporary(
getInstantiateDescriptorExpr(),
nameHint = "serialDesc"
)
for (property in properties.serializableProperties) {
+getAddElementToDescriptorExpr(property, serialDescVar)
}
+irReturn(irGet(serialDescVar))
}

return irClass.addField {
name = Name.identifier(initializedDescriptorFieldName)
visibility = DescriptorVisibilities.PRIVATE
origin = SERIALIZABLE_PLUGIN_ORIGIN
isFinal = true
isStatic = true
type = serialDescItType
}.apply { initializer = irClass.factory.createExpressionBody(irCall(function)) }
}

private fun IrBlockBodyBuilder.getInstantiateDescriptorExpr(): IrExpression {
val classConstructors = compilerContext.referenceConstructors(serialDescImplClass.fqNameSafe)
val serialClassDescImplCtor = classConstructors.single { it.owner.isPrimary }
return irInvoke(
null, serialClassDescImplCtor,
irString(serializableDescriptor.serialName()), irNull(), irInt(properties.serializableProperties.size)
)
}

private fun IrBlockBodyBuilder.getAddElementToDescriptorExpr(
property: SerializableProperty,
serialDescVar: IrVariable
): IrExpression {
return irInvoke(
irGet(serialDescVar),
addElementFun,
irString(property.name),
irBoolean(property.optional),
typeHint = compilerContext.irBuiltIns.unitType
)
}

private inline fun ClassDescriptor.findFunctionSymbol(
functionName: String,
predicate: (IrSimpleFunction) -> Boolean = { true }
): IrFunctionSymbol {
val irClass = compilerContext.referenceClass(fqNameSafe)?.owner ?: error("Couldn't load class $this")
val simpleFunctions = irClass.declarations.filterIsInstance<IrSimpleFunction>()

return simpleFunctions.filter { it.name.asString() == functionName }.single { predicate(it) }.symbol
}

private fun IrBlockBodyBuilder.generateSuperNonSerializableCall(superClass: IrClass) {
val ctorRef = superClass.declarations.filterIsInstance<IrConstructor>().singleOrNull { it.valueParameters.isEmpty() }
?: error("Non-serializable parent of serializable $serializableDescriptor must have no arg constructor")
Expand Down Expand Up @@ -189,4 +367,4 @@ class SerializableIrGenerator(
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ST
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.STRUCTURE_ENCODER_CLASS
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.UNKNOWN_FIELD_EXC

object SERIALIZABLE_PLUGIN_ORIGIN : IrDeclarationOriginImpl("SERIALIZER")
object SERIALIZABLE_PLUGIN_ORIGIN : IrDeclarationOriginImpl("SERIALIZER", true)

// TODO: use in places where elements need to have ACC_SYNTHETIC on JVM
object SERIALIZABLE_SYNTHETIC_ORIGIN : IrDeclarationOriginImpl("SERIALIZER")
Expand Down
Loading

0 comments on commit a79d8d1

Please sign in to comment.