Skip to content

Commit

Permalink
JVM IR: handle JvmStatic in object as module phase
Browse files Browse the repository at this point in the history
This allows to get rid of the situation where a JvmStatic function in
object can be seen in different states in different lowerings: unlowered
with a dispatch receiver parameter, declaration is lowered but calls are
not, and both declaration and calls are lowered.

Now it works like this:
1) JvmStatic functions in objects coming from dependencies are always
   loaded as lowered, without the extra dispatch receiver parameter. In
   psi2ir this is done via JVM-specific extension; in fir2ir it's done
   in place (but probably should be extracted to extension too).
2) Functions from sources are created as unlowered by both psi2ir and
   fir2ir, and are lowered in a module-wide phase at the beginning of
   JvmLower.
3) Calls to all JvmStatic functions from objects (from sources and
   dependencies) are lowered in the same phase at the beginning of
   JvmLower.

This ensures that all lowerings after the module-wide phase
`jvmStaticInObjectPhase`, which include all per-file phases, see all
JvmStatic functions in objects without the additional dispatch receiver
parameter, and calls do not have dispatch receiver either.

The only issue with this approach is that function/property reference
representation in reflection needs to have that dispatch receiver
parameter, and that is achieved via a hack in those lowerings, which
seems not too out of place anyway, given that they're handled specially
in kotlin-reflect as well.
  • Loading branch information
udalov committed Nov 24, 2020
1 parent 0a00cbe commit 7117932
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource
import org.jetbrains.kotlin.descriptors.DescriptorVisibility
import org.jetbrains.kotlin.ir.util.hasAnnotation
import org.jetbrains.kotlin.ir.util.isNonCompanionObject
import org.jetbrains.kotlin.ir.util.isObject
import org.jetbrains.kotlin.resolve.annotations.JVM_STATIC_ANNOTATION_FQ_NAME

class Fir2IrLazySimpleFunction(
components: Fir2IrComponents,
Expand Down Expand Up @@ -90,7 +94,9 @@ class Fir2IrLazySimpleFunction(

override var dispatchReceiverParameter: IrValueParameter? by lazyVar {
val containingClass = parent as? IrClass
if (!fir.isStatic && containingClass != null) {
if (containingClass != null && !fir.isStatic &&
!(containingClass.isNonCompanionObject && hasAnnotation(JVM_STATIC_ANNOTATION_FQ_NAME))
) {
declarationStorage.enterScope(this)
declareThisReceiverParameter(
symbolTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import org.jetbrains.kotlin.backend.common.ir.Ir
import org.jetbrains.kotlin.backend.common.lower.irThrow
import org.jetbrains.kotlin.backend.common.phaser.PhaseConfig
import org.jetbrains.kotlin.backend.jvm.codegen.*
import org.jetbrains.kotlin.backend.jvm.codegen.createFakeContinuation
import org.jetbrains.kotlin.backend.jvm.descriptors.JvmSharedVariablesManager
import org.jetbrains.kotlin.backend.jvm.intrinsics.IrIntrinsicMethods
import org.jetbrains.kotlin.backend.jvm.lower.BridgeLowering
Expand Down Expand Up @@ -128,8 +127,6 @@ class JvmBackendContext(
val suspendFunctionOriginalToView = mutableMapOf<IrFunction, IrFunction>()
val fakeContinuation: IrExpression = createFakeContinuation(this)

val jvmStaticObjectFunctionToStaticFunctionMap = mutableMapOf<IrSimpleFunction, IrSimpleFunction>()

val staticDefaultStubs = mutableMapOf<IrSimpleFunctionSymbol, IrSimpleFunction>()

val inlineClassReplacements = MemoizedInlineClassReplacements(state.functionsWithInlineClassReturnTypesMangled, irFactory, this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi2ir.generators.GeneratorExtensions
import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.resolve.annotations.hasJvmStaticAnnotation
import org.jetbrains.kotlin.resolve.jvm.JvmClassName
import org.jetbrains.kotlin.resolve.jvm.annotations.hasJvmFieldAnnotation
import org.jetbrains.kotlin.resolve.scopes.MemberScope
Expand Down Expand Up @@ -79,6 +81,11 @@ class JvmGeneratorExtensions(private val generateFacades: Boolean = true) : Gene
override fun isPropertyWithPlatformField(descriptor: PropertyDescriptor): Boolean =
descriptor.hasJvmFieldAnnotation()

override fun isStaticFunction(descriptor: FunctionDescriptor): Boolean =
DescriptorUtils.isNonCompanionObject(descriptor.containingDeclaration) &&
(descriptor.hasJvmStaticAnnotation() ||
descriptor is PropertyAccessorDescriptor && descriptor.correspondingProperty.hasJvmStaticAnnotation())

override val enhancedNullability: EnhancedNullability
get() = JvmEnhancedNullability

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ private val jvmFilePhases = listOf(
initializersPhase,
initializersCleanupPhase,
functionNVarargBridgePhase,
jvmStaticAnnotationPhase,
jvmStaticInCompanionPhase,
staticDefaultFunctionPhase,
bridgePhase,
syntheticAccessorPhase,
Expand Down Expand Up @@ -392,6 +392,7 @@ val jvmPhases = NamedCompilerPhase(
expectDeclarationsRemovingPhase then
scriptsToClassesPhase then
fileClassPhase then
jvmStaticInObjectPhase then
performByIrFile(lower = jvmFilePhases) then
generateMultifileFacadesPhase then
resolveInlineCallsPhase then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.addConstructor
import org.jetbrains.kotlin.ir.builders.declarations.addFunction
import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
import org.jetbrains.kotlin.ir.builders.declarations.buildClass
import org.jetbrains.kotlin.ir.builders.declarations.*
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrClassReferenceImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionReferenceImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetObjectValueImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrInstanceInitializerCallImpl
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.types.*
Expand Down Expand Up @@ -111,7 +109,9 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
// However, when we bind a value of an inline class type as a receiver, the receiver will turn into an argument of
// the function in question. Yet we still need to record it as the "receiver" in CallableReference in order for reflection
// to work correctly.
private val boundReceiver: Pair<IrValueParameter, IrExpression>? = irFunctionReference.getArgumentsWithIr().singleOrNull()
private val boundReceiver: Pair<IrValueParameter, IrExpression>? =
if (callee.isJvmStaticInObject()) createFakeBoundReceiverForJvmStaticInObject()
else irFunctionReference.getArgumentsWithIr().singleOrNull()

// The type of the reference is KFunction<in A1, ..., in An, out R>
private val parameterTypes = (irFunctionReference.type as IrSimpleType).arguments.map { (it as IrTypeProjection).type }
Expand Down Expand Up @@ -485,6 +485,19 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
)
}

private fun createFakeBoundReceiverForJvmStaticInObject(): Pair<IrValueParameter, IrGetObjectValueImpl> {
// JvmStatic functions in objects are special in that they are generated as static methods in the bytecode, and JVM IR lowers
// both declarations and call sites early on in jvmStaticInObjectPhase because it's easier that way in subsequent lowerings.
// However from the point of view of Kotlin language (and thus reflection), these functions still take the dispatch receiver
// parameter of the object type. So we pretend here that a JvmStatic function in object has an additional dispatch receiver
// parameter, so that the correct function reference object will be created and reflective calls will work at runtime.
val objectClass = callee.parentAsClass
return buildValueParameter(callee) {
name = Name.identifier("\$this")
type = objectClass.typeWith()
} to IrGetObjectValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, objectClass.typeWith(), objectClass.symbol)
}

private fun createLegacyMethodOverride(
superFunction: IrSimpleFunction,
generator: JvmIrBuilder.() -> IrExpression
Expand All @@ -494,7 +507,6 @@ internal class FunctionReferenceLowering(private val context: JvmBackendContext)
irExprBody(generator())
}
}

}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,21 @@ import org.jetbrains.kotlin.backend.common.ir.*
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.backend.common.lower.irBlock
import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase
import org.jetbrains.kotlin.backend.common.phaser.makeIrModulePhase
import org.jetbrains.kotlin.backend.common.runOnFilePostfix
import org.jetbrains.kotlin.backend.jvm.JvmBackendContext
import org.jetbrains.kotlin.backend.jvm.JvmLoweredDeclarationOrigin
import org.jetbrains.kotlin.backend.jvm.ir.copyCorrespondingPropertyFrom
import org.jetbrains.kotlin.backend.jvm.ir.isInCurrentModule
import org.jetbrains.kotlin.backend.jvm.ir.replaceThisByStaticReference
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.builders.declarations.addFunction
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irExprBody
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irGetField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrMemberAccessExpression
import org.jetbrains.kotlin.ir.expressions.IrTypeOperator
import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
import org.jetbrains.kotlin.ir.util.*
Expand All @@ -36,35 +34,37 @@ import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.resolve.annotations.JVM_STATIC_ANNOTATION_FQ_NAME

internal val jvmStaticAnnotationPhase = makeIrFilePhase(
::JvmStaticAnnotationLowering,
name = "JvmStaticAnnotation",
description = "Handle JvmStatic annotations"
internal val jvmStaticInObjectPhase = makeIrModulePhase(
::JvmStaticInObjectLowering,
name = "JvmStaticInObject",
description = "Make JvmStatic functions in non-companion objects static and replace all call sites in the module"
)

/*
* For @JvmStatic functions within companion objects of classes, we synthesize proxy static functions that redirect
* to the actual implementation.
* For @JvmStatic functions within static objects, we make the actual function static and modify all call sites.
*/
private class JvmStaticAnnotationLowering(val context: JvmBackendContext) : IrElementTransformerVoid(), FileLoweringPass {
internal val jvmStaticInCompanionPhase = makeIrFilePhase(
::JvmStaticInCompanionLowering,
name = "JvmStaticInCompanion",
description = "Synthesize static proxy functions for JvmStatic functions in companion objects"
)

private class JvmStaticInObjectLowering(val context: JvmBackendContext) : IrElementTransformerVoid(), FileLoweringPass {
override fun lower(irFile: IrFile) {
CompanionObjectJvmStaticLowering(context).runOnFilePostfix(irFile)
SingletonObjectJvmStaticLowering(context).runOnFilePostfix(irFile)
irFile.transformChildrenVoid(MakeCallsStatic(context))
}
}

private class CompanionObjectJvmStaticLowering(val context: JvmBackendContext) : ClassLoweringPass {
private class JvmStaticInCompanionLowering(val context: JvmBackendContext) : IrElementTransformerVoid(), ClassLoweringPass {
override fun lower(irClass: IrClass) {
val companion = irClass.declarations.find {
it is IrClass && it.isCompanion
} as? IrClass ?: return
val companion = irClass.companionObject() ?: return

companion.declarations
// In case of companion objects, proxy functions for '$default' methods for @JvmStatic functions with default parameters
// are not created in the host class.
.filter { isJvmStaticFunction(it) && it.origin != IrDeclarationOrigin.FUNCTION_FOR_DEFAULT_PARAMETER }
.filter {
it.isJvmStaticDeclaration() &&
it.origin != IrDeclarationOrigin.FUNCTION_FOR_DEFAULT_PARAMETER &&
it.origin != JvmLoweredDeclarationOrigin.SYNTHETIC_METHOD_FOR_PROPERTY_ANNOTATIONS
}
.forEach { declaration ->
val jvmStaticFunction = declaration as IrSimpleFunction
if (jvmStaticFunction.isExternal) {
Expand Down Expand Up @@ -129,107 +129,46 @@ private class CompanionObjectJvmStaticLowering(val context: JvmBackendContext) :

private class SingletonObjectJvmStaticLowering(val context: JvmBackendContext) : ClassLoweringPass {
override fun lower(irClass: IrClass) {
if (!irClass.isObject || irClass.isCompanion) return

val jvmStaticFunctionsToReplace = irClass.declarations.filter {
// dispatch receiver parameter is already null for synthetic property annotation methods
isJvmStaticFunction(it) && it is IrSimpleFunction && it.dispatchReceiverParameter != null
}
jvmStaticFunctionsToReplace.forEach { function ->
val replacement = createReplacement(context, function as IrSimpleFunction)
// Set dispatch receiver parameter for body move operation.
replacement.dispatchReceiverParameter = function.dispatchReceiverParameter
replacement.body = function.moveBodyTo(replacement)
replacement.replaceThisByStaticReference(context.cachedDeclarations, irClass, function.dispatchReceiverParameter!!)
// Clear dispatch receiver parameter again after body move operation.
replacement.dispatchReceiverParameter = null
irClass.declarations.remove(function)
irClass.declarations.add(replacement)
if (!irClass.isNonCompanionObject) return

for (function in irClass.simpleFunctions()) {
if (function.isJvmStaticDeclaration()) {
// dispatch receiver parameter is already null for synthetic property annotation methods
function.dispatchReceiverParameter?.let { oldDispatchReceiverParameter ->
function.dispatchReceiverParameter = null
function.replaceThisByStaticReference(context.cachedDeclarations, irClass, oldDispatchReceiverParameter)
}
}
}
}
}

private fun createReplacement(
context: JvmBackendContext,
jvmStaticFunction: IrSimpleFunction
): IrSimpleFunction =
context.jvmStaticObjectFunctionToStaticFunctionMap.getOrPut(jvmStaticFunction) {
val irClass = jvmStaticFunction.parentAsClass
val newFunction = context.irFactory.buildFun {
updateFrom(jvmStaticFunction)
name = jvmStaticFunction.name
returnType = jvmStaticFunction.returnType
}.apply {
parent = irClass
copyTypeParametersFrom(jvmStaticFunction)
copyAnnotationsFrom(jvmStaticFunction)
extensionReceiverParameter = jvmStaticFunction.extensionReceiverParameter?.copyTo(this)
valueParameters = jvmStaticFunction.valueParameters.map { it.copyTo(this) }
copyAttributes(jvmStaticFunction)
copyCorrespondingPropertyFrom(jvmStaticFunction)
metadata = jvmStaticFunction.metadata
}
context.jvmStaticObjectFunctionToStaticFunctionMap[jvmStaticFunction] = newFunction
newFunction
}


private fun IrFunction.isJvmStaticInSingleton(): Boolean {
val parentClass = parent as? IrClass ?: return false
return isJvmStaticFunction(this) && parentClass.isObject && !parentClass.isCompanion
}
internal fun IrDeclaration.isJvmStaticInObject(): Boolean =
isJvmStaticDeclaration() && (parent as? IrClass)?.isNonCompanionObject == true

private class MakeCallsStatic(val context: JvmBackendContext) : IrElementTransformerVoid() {
override fun visitCall(expression: IrCall): IrExpression {
if (expression.symbol.owner.isJvmStaticInSingleton() && expression.dispatchReceiver != null) {
// Imported functions do not have their receiver parameter nulled by SingletonObjectJvmStaticLowering,
// so we have to do it here.
// TODO: would be better handled by lowering imported declarations.
val callee = expression.symbol.owner
val newCallee = if (!callee.isInCurrentModule()) {
callee.copyRemovingDispatchReceiver() // TODO: cache these
} else {
createReplacement(context, callee)
}

override fun visitMemberAccess(expression: IrMemberAccessExpression<*>): IrExpression {
val callee = expression.symbol.owner
if (callee is IrDeclaration && callee.isJvmStaticInObject() && expression.dispatchReceiver != null) {
return context.createIrBuilder(expression.symbol, expression.startOffset, expression.endOffset).irBlock(expression) {
// OldReceiver has to be evaluated for its side effects.
val oldReceiver = super.visitExpression(expression.dispatchReceiver!!)
// `coerceToUnit()` is private in InsertImplicitCasts, have to reproduce it here
val oldReceiverVoid = IrTypeOperatorCallImpl(
+IrTypeOperatorCallImpl(
oldReceiver.startOffset, oldReceiver.endOffset,
context.irBuiltIns.unitType,
IrTypeOperator.IMPLICIT_COERCION_TO_UNIT,
context.irBuiltIns.unitType,
oldReceiver
)

+super.visitExpression(oldReceiverVoid)
+super.visitCall(
irCall(expression, newFunction = newCallee).apply { dispatchReceiver = null }
)
expression.dispatchReceiver = null
+super.visitMemberAccess(expression)
}
}
return super.visitCall(expression)
return super.visitMemberAccess(expression)
}

private fun IrSimpleFunction.copyRemovingDispatchReceiver(): IrSimpleFunction =
factory.buildFun {
updateFrom(this@copyRemovingDispatchReceiver)
name = this@copyRemovingDispatchReceiver.name
returnType = this@copyRemovingDispatchReceiver.returnType
}.also {
it.parent = parent
it.copyCorrespondingPropertyFrom(this)
it.annotations += annotations
it.copyParameterDeclarationsFrom(this)
it.dispatchReceiverParameter = null
it.copyAttributes(this)
}
}

private fun isJvmStaticFunction(declaration: IrDeclaration): Boolean =
declaration is IrSimpleFunction &&
(declaration.hasAnnotation(JVM_STATIC_ANNOTATION_FQ_NAME) ||
declaration.correspondingPropertySymbol?.owner?.hasAnnotation(JVM_STATIC_ANNOTATION_FQ_NAME) == true) &&
declaration.origin != JvmLoweredDeclarationOrigin.SYNTHETIC_METHOD_FOR_PROPERTY_ANNOTATIONS
private fun IrDeclaration.isJvmStaticDeclaration(): Boolean =
hasAnnotation(JVM_STATIC_ANNOTATION_FQ_NAME) ||
(this as? IrSimpleFunction)?.correspondingPropertySymbol?.owner?.hasAnnotation(JVM_STATIC_ANNOTATION_FQ_NAME) == true
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ import org.jetbrains.kotlin.ir.expressions.impl.IrExpressionBodyImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetFieldImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrSetFieldImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrAnonymousInitializerSymbolImpl
import org.jetbrains.kotlin.ir.util.filterOutAnnotations
import org.jetbrains.kotlin.ir.util.isObject
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.resolve.deprecation.DeprecationResolver

Expand All @@ -45,7 +42,7 @@ internal val remapObjectFieldAccesses = makeIrFilePhase(

private class MoveOrCopyCompanionObjectFieldsLowering(val context: JvmBackendContext) : ClassLoweringPass {
override fun lower(irClass: IrClass) {
if (irClass.isObject && !irClass.isCompanion) {
if (irClass.isNonCompanionObject) {
irClass.handle()
} else {
(irClass.declarations.singleOrNull { it is IrClass && it.isCompanion } as IrClass?)?.handle()
Expand Down
Loading

0 comments on commit 7117932

Please sign in to comment.