Skip to content

Commit

Permalink
Optimize check for missing fields in deserialization (#3862)
Browse files Browse the repository at this point in the history
  • Loading branch information
shanshin authored Nov 25, 2020
1 parent f9503ef commit b5143ba
Show file tree
Hide file tree
Showing 10 changed files with 425 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@

package org.jetbrains.kotlinx.serialization.compiler.backend.common

import org.jetbrains.kotlin.config.ApiVersion
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.descriptorUtil.secondaryConstructors
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.VersionReader
import org.jetbrains.kotlinx.serialization.compiler.resolve.*

abstract class SerializableCodegen(
protected val serializableDescriptor: ClassDescriptor,
bindingContext: BindingContext
) : AbstractSerialGenerator(bindingContext, serializableDescriptor) {
protected val properties = bindingContext.serializablePropertiesFor(serializableDescriptor)
protected val staticDescriptor = serializableDescriptor.declaredTypeParameters.isEmpty()

private val fieldMissingOptimizationVersion = ApiVersion.parse("1.1")!!
protected val useFieldMissingOptimization = canUseFieldMissingOptimization()

fun generate() {
generateSyntheticInternalConstructor()
Expand All @@ -50,6 +57,40 @@ abstract class SerializableCodegen(
}
}

protected fun getGoldenMask(): Int {
var goldenMask = 0
var requiredBit = 1
for (property in properties.serializableProperties) {
if (!property.optional) {
goldenMask = goldenMask or requiredBit
}
requiredBit = requiredBit shl 1
}
return goldenMask
}

protected fun getGoldenMaskList(): List<Int> {
val maskSlotCount = properties.serializableProperties.bitMaskSlotCount()
val goldenMaskList = MutableList(maskSlotCount) { 0 }

for (i in properties.serializableProperties.indices) {
if (!properties.serializableProperties[i].optional) {
val slotNumber = i / 32
val bitInSlot = i % 32
goldenMaskList[slotNumber] = goldenMaskList[slotNumber] or (1 shl bitInSlot)
}
}
return goldenMaskList
}

private fun canUseFieldMissingOptimization(): Boolean {
val implementationVersion = VersionReader.getVersionsForCurrentModuleFromContext(
serializableDescriptor.module,
bindingContext
)?.implementationVersion
return if (implementationVersion != null) implementationVersion >= fieldMissingOptimizationVersion else false
}

protected abstract fun generateInternalConstructor(constructorDescriptor: ClassConstructorDescriptor)

protected open fun generateWriteSelfMethod(methodDescriptor: FunctionDescriptor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.descriptors.annotations.Annotations
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
Expand Down Expand Up @@ -67,6 +68,30 @@ interface IrBuilderExtension {
) { bodyGen(c) }
}

// function will not be created in the real class
fun IrClass.createInlinedFunction(
name: Name,
visibility: DescriptorVisibility,
origin: IrDeclarationOrigin,
returnType: IrType,
bodyGen: IrBlockBodyBuilder.(IrFunction) -> Unit
): IrSimpleFunction {
val function = factory.buildFun {
this.name = name
this.visibility = visibility
this.origin = origin
this.isInline = true
this.returnType = returnType
}
val functionSymbol = function.symbol
function.parent = this
function.body = DeclarationIrBuilder(compilerContext, functionSymbol, startOffset, endOffset).irBlockBody(
startOffset,
endOffset
) { bodyGen(function) }
return function
}

fun IrBuilderWithScope.irInvoke(
dispatchReceiver: IrExpression? = null,
callee: IrFunctionSymbol,
Expand Down Expand Up @@ -109,6 +134,19 @@ interface IrBuilderExtension {
}
}

fun IrBuilderWithScope.createPrimitiveArrayOfExpression(
elementPrimitiveType: IrType,
arrayElements: List<IrExpression>
): IrExpression {
val arrayType = compilerContext.irBuiltIns.primitiveArrayForType.getValue(elementPrimitiveType).defaultType
val arg0 = IrVarargImpl(startOffset, endOffset, arrayType, elementPrimitiveType, arrayElements)
val typeArguments = listOf(elementPrimitiveType)

return irCall(compilerContext.symbols.arrayOf, arrayType, typeArguments = typeArguments).apply {
putValueArgument(0, arg0)
}
}

fun IrBuilderWithScope.irBinOp(name: Name, lhs: IrExpression, rhs: IrExpression): IrExpression {
val classFqName = (lhs.type as IrSimpleType).classOrNull!!.owner.fqNameWhenAvailable!!
val symbol = compilerContext.referenceFunctions(classFqName.child(name)).single()
Expand Down Expand Up @@ -766,4 +804,4 @@ interface IrBuilderExtension {
return superClasses.singleOrNull { it.kind == ClassKind.CLASS }
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ class SerialInfoImplJvmIrGenerator(
generateSimplePropertyWithBackingField(property.descriptor, irClass, Name.identifier("_" + property.name.asString()))

val getter = property.getter!!
getter.origin = SERIALIZABLE_SYNTHETIC_ORIGIN
getter.origin = SERIALIZABLE_PLUGIN_ORIGIN
// Add JvmName annotation to property getters to force the resulting JVM method name for 'x' be 'x', instead of 'getX',
// and to avoid having useless bridges for it generated in BridgeLowering.
// Unfortunately, this results in an extra `@JvmName` annotation in the bytecode, but it shouldn't matter very much.
getter.annotations += jvmName(property.name.asString())

val field = property.backingField!!
field.visibility = DescriptorVisibilities.PRIVATE
field.origin = SERIALIZABLE_SYNTHETIC_ORIGIN
field.origin = SERIALIZABLE_PLUGIN_ORIGIN

val parameter = ctor.addValueParameter(property.name.asString(), getter.returnType)
ctorBody.statements += IrSetFieldImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,52 @@ package org.jetbrains.kotlinx.serialization.compiler.backend.ir
import org.jetbrains.kotlin.backend.common.deepCopyWithVariables
import org.jetbrains.kotlin.backend.common.lower.irThrow
import org.jetbrains.kotlin.codegen.CompilationException
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.addField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrDelegatingConstructorCallImpl
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.getAnnotation
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlinx.serialization.compiler.backend.common.SerializableCodegen
import org.jetbrains.kotlinx.serialization.compiler.backend.common.serialName
import org.jetbrains.kotlinx.serialization.compiler.diagnostic.serializableAnnotationIsUseless
import org.jetbrains.kotlinx.serialization.compiler.extensions.SerializationPluginContext
import org.jetbrains.kotlinx.serialization.compiler.resolve.*
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.ARRAY_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.MISSING_FIELD_EXC
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SERIAL_DESC_FIELD
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.SINGLE_MASK_FIELD_MISSING_FUNC_FQ
import org.jetbrains.kotlinx.serialization.compiler.resolve.SerialEntityNames.initializedDescriptorFieldName

class SerializableIrGenerator(
val irClass: IrClass,
override val compilerContext: SerializationPluginContext,
bindingContext: BindingContext
) : SerializableCodegen(irClass.descriptor, bindingContext), IrBuilderExtension {

private val descriptorGenerationFunctionName = "createInitializedDescriptor"

private val serialDescClass: ClassDescriptor = serializableDescriptor.module
.getClassFromSerializationDescriptorsPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS)

private val serialDescImplClass: ClassDescriptor = serializableDescriptor
.getClassFromInternalSerializationPackage(SerialEntityNames.SERIAL_DESCRIPTOR_CLASS_IMPL)

private val addElementFun = serialDescImplClass.findFunctionSymbol(CallingConventions.addElement)

val throwMissedFieldExceptionFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(SINGLE_MASK_FIELD_MISSING_FUNC_FQ).single() else null
val throwMissedFieldExceptionArrayFunc =
if (useFieldMissingOptimization) compilerContext.referenceFunctions(ARRAY_MASK_FIELD_MISSING_FUNC_FQ).single() else null

private fun IrClass.hasSerializableAnnotationWithoutArgs(): Boolean {
val annot = getAnnotation(SerializationAnnotations.serializableAnnotationFqName) ?: return false
Expand Down Expand Up @@ -64,6 +84,10 @@ class SerializableIrGenerator(
val thiz = irClass.thisReceiver!!
val superClass = irClass.getSuperClassOrAny()
var startPropOffset: Int = 0

if (useFieldMissingOptimization) {
generateOptimizedGoldenMaskCheck(seenVars)
}
when {
superClass.symbol == compilerContext.irBuiltIns.anyClass -> generateAnySuperConstructorCall(toBuilder = this@contributeConstructor)
superClass.isInternalSerializable -> {
Expand All @@ -83,7 +107,14 @@ class SerializableIrGenerator(
requireNotNull(transformFieldInitializer(prop.irField)) { "Optional value without an initializer" } // todo: filter abstract here
setProperty(irGet(thiz), prop.irProp, initializerBody)
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
// property required
if (useFieldMissingOptimization) {
// field definitely not empty as it's checked before - no need another IF, only assign property from param
+assignParamExpr
continue
} else {
irThrow(irInvoke(null, exceptionCtorRef, irString(prop.name), typeHint = exceptionType))
}
}

val propNotSeenTest =
Expand Down Expand Up @@ -116,6 +147,153 @@ class SerializableIrGenerator(
}
}

private fun IrBlockBodyBuilder.generateOptimizedGoldenMaskCheck(seenVars: List<IrValueParameter>) {
if (serializableDescriptor.isAbstractSerializableClass() || serializableDescriptor.isSealedSerializableClass()) {
// for abstract classes fields MUST BE checked in child classes
return
}

val fieldsMissedTest: IrExpression
val throwErrorExpr: IrExpression

val maskSlotCount = seenVars.size
if (maskSlotCount == 1) {
val goldenMask = getGoldenMask()

throwErrorExpr = irInvoke(
null,
throwMissedFieldExceptionFunc!!,
irGet(seenVars[0]),
irInt(goldenMask),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)

fieldsMissedTest = irNotEquals(
irInt(goldenMask),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMask),
irGet(seenVars[0])
)
)
} else {
val goldenMaskList = getGoldenMaskList()

var compositeExpression: IrExpression? = null
for (i in goldenMaskList.indices) {
val singleCheckExpr = irNotEquals(
irInt(goldenMaskList[i]),
irBinOp(
OperatorNameConventions.AND,
irInt(goldenMaskList[i]),
irGet(seenVars[i])
)
)

compositeExpression = if (compositeExpression == null) {
singleCheckExpr
} else {
irBinOp(
OperatorNameConventions.OR,
compositeExpression,
singleCheckExpr
)
}
}

fieldsMissedTest = compositeExpression!!

throwErrorExpr = irBlock {
+irInvoke(
null,
throwMissedFieldExceptionArrayFunc!!,
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.indices.map { irGet(seenVars[it]) }),
createPrimitiveArrayOfExpression(compilerContext.irBuiltIns.intType, goldenMaskList.map { irInt(it) }),
getSerialDescriptorExpr(),
typeHint = compilerContext.irBuiltIns.unitType
)
}
}

+irIfThen(compilerContext.irBuiltIns.unitType, fieldsMissedTest, throwErrorExpr)
}

private fun IrBlockBodyBuilder.getSerialDescriptorExpr(): IrExpression {
return if (serializableDescriptor.shouldHaveGeneratedSerializer && staticDescriptor) {
val serializer = serializableDescriptor.classSerializer!!
val serialDescriptorGetter = compilerContext.referenceClass(serializer.fqNameSafe)!!.getPropertyGetter(SERIAL_DESC_FIELD)!!
irGet(
serialDescriptorGetter.owner.returnType,
irGetObject(serializer),
serialDescriptorGetter.owner.symbol
)
} else {
irGetField(null, generateStaticDescriptorField())
}
}

private fun IrBlockBodyBuilder.generateStaticDescriptorField(): IrField {
val serialDescItType = serialDescClass.defaultType.toIrType()

val function = irClass.createInlinedFunction(
Name.identifier(descriptorGenerationFunctionName),
DescriptorVisibilities.PRIVATE,
SERIALIZABLE_PLUGIN_ORIGIN,
serialDescItType
) {
val serialDescVar = irTemporary(
getInstantiateDescriptorExpr(),
nameHint = "serialDesc"
)
for (property in properties.serializableProperties) {
+getAddElementToDescriptorExpr(property, serialDescVar)
}
+irReturn(irGet(serialDescVar))
}

return irClass.addField {
name = Name.identifier(initializedDescriptorFieldName)
visibility = DescriptorVisibilities.PRIVATE
origin = SERIALIZABLE_PLUGIN_ORIGIN
isFinal = true
isStatic = true
type = serialDescItType
}.apply { initializer = irClass.factory.createExpressionBody(irCall(function)) }
}

private fun IrBlockBodyBuilder.getInstantiateDescriptorExpr(): IrExpression {
val classConstructors = compilerContext.referenceConstructors(serialDescImplClass.fqNameSafe)
val serialClassDescImplCtor = classConstructors.single { it.owner.isPrimary }
return irInvoke(
null, serialClassDescImplCtor,
irString(serializableDescriptor.serialName()), irNull(), irInt(properties.serializableProperties.size)
)
}

private fun IrBlockBodyBuilder.getAddElementToDescriptorExpr(
property: SerializableProperty,
serialDescVar: IrVariable
): IrExpression {
return irInvoke(
irGet(serialDescVar),
addElementFun,
irString(property.name),
irBoolean(property.optional),
typeHint = compilerContext.irBuiltIns.unitType
)
}

private inline fun ClassDescriptor.findFunctionSymbol(
functionName: String,
predicate: (IrSimpleFunction) -> Boolean = { true }
): IrFunctionSymbol {
val irClass = compilerContext.referenceClass(fqNameSafe)?.owner ?: error("Couldn't load class $this")
val simpleFunctions = irClass.declarations.filterIsInstance<IrSimpleFunction>()

return simpleFunctions.filter { it.name.asString() == functionName }.single { predicate(it) }.symbol
}

private fun IrBlockBodyBuilder.generateSuperNonSerializableCall(superClass: IrClass) {
val ctorRef = superClass.declarations.filterIsInstance<IrConstructor>().singleOrNull { it.valueParameters.isEmpty() }
?: error("Non-serializable parent of serializable $serializableDescriptor must have no arg constructor")
Expand Down Expand Up @@ -189,4 +367,4 @@ class SerializableIrGenerator(
}
}
}
}
}
Loading

0 comments on commit b5143ba

Please sign in to comment.