Skip to content

Commit

Permalink
[Commonizer] Extract CIR classifiers cache from the root node
Browse files Browse the repository at this point in the history
  • Loading branch information
ddolovov committed Nov 26, 2020
1 parent 8d9abed commit eca231a
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.jetbrains.kotlin.descriptors.commonizer.utils.internedClassId
import org.jetbrains.kotlin.descriptors.commonizer.utils.isUnderStandardKotlinPackages

internal class CommonizationVisitor(
private val cache: CirClassifiersCache,
private val root: CirRootNode
) : CirNodeVisitor<Unit, Unit> {
override fun visitRootNode(node: CirRootNode, data: Unit) {
Expand Down Expand Up @@ -87,7 +88,7 @@ internal class CommonizationVisitor(
val companionObjectName = node.targetDeclarations.mapTo(HashSet()) { it!!.companion }.singleOrNull()
if (companionObjectName != null) {
val companionObjectClassId = internedClassId(node.classId, companionObjectName)
val companionObjectNode = root.cache.classes[companionObjectClassId]
val companionObjectNode = cache.classNode(companionObjectClassId)
?: error("Can't find companion object with class ID $companionObjectClassId")

if (companionObjectNode.commonDeclaration() != null) {
Expand Down Expand Up @@ -131,7 +132,7 @@ internal class CommonizationVisitor(
if (expandedClassId.packageFqName.isUnderStandardKotlinPackages)
return null // this case is not supported

val expandedClassNode = root.cache.classes[expandedClassId] ?: return null
val expandedClassNode = cache.classNode(expandedClassId) ?: return null
val expandedClass = expandedClassNode.targetDeclarations[index]
?: error("Can't find expanded class with class ID $expandedClassId and index $index for type alias $classId")

Expand All @@ -147,7 +148,7 @@ internal class CommonizationVisitor(
if (supertypesMap.isNullOrEmpty())
emptyList()
else
supertypesMap.values.compactMapNotNull { supertypesGroup -> commonize(supertypesGroup, TypeCommonizer(root.cache)) }
supertypesMap.values.compactMapNotNull { supertypesGroup -> commonize(supertypesGroup, TypeCommonizer(cache)) }
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.jetbrains.kotlin.descriptors.DescriptorVisibility
import org.jetbrains.kotlin.descriptors.commonizer.cir.*
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeFactory
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirClassifiersCache
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirNodeWithClassId
import org.jetbrains.kotlin.descriptors.commonizer.utils.isUnderStandardKotlinPackages
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.types.Variance
Expand Down Expand Up @@ -66,7 +66,7 @@ private class ClassTypeCommonizer(private val cache: CirClassifiersCache) : Abst
isMarkedNullable == next.isMarkedNullable
&& classId == next.classifierId
&& outerType.commonizeWith(next.outerType)
&& commonizeClassifier(classId, cache.classes).first
&& commonizeClassifier(classId) { cache.classNode(classId) }.first
&& arguments.commonizeWith(next.arguments)
}

Expand Down Expand Up @@ -102,7 +102,7 @@ private class TypeAliasTypeCommonizer(private val cache: CirClassifiersCache) :
return false

if (commonizedTypeBuilder == null) {
val (commonized, commonClassifier) = commonizeClassifier(typeAliasId, cache.typeAliases)
val (commonized, commonClassifier) = commonizeClassifier(typeAliasId) { cache.typeAliasNode(typeAliasId) }
if (!commonized)
return false

Expand Down Expand Up @@ -218,15 +218,15 @@ private class TypeArgumentListCommonizer(cache: CirClassifiersCache) : AbstractL

private inline fun <reified T : CirClassifier> commonizeClassifier(
classifierId: ClassId,
classifierNodes: Map<ClassId, CirNode<*, T>>,
classifierNode: (classifierId: ClassId) -> CirNodeWithClassId<*, T>?
): Pair<Boolean, T?> {
if (classifierId.packageFqName.isUnderStandardKotlinPackages) {
/* either class or type alias from Kotlin stdlib */
return true to null
}

/* or descriptors themselves can be commonized */
return when (val node = classifierNodes[classifierId]) {
return when (val node = classifierNode(classifierId)) {
null -> {
// No node means that the class or type alias was not subject for commonization at all, probably it lays
// not in commonized module descriptors but somewhere in their dependencies.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.kotlin.descriptors.commonizer.builder.DeclarationsBuilderVi
import org.jetbrains.kotlin.descriptors.commonizer.builder.createGlobalBuilderComponents
import org.jetbrains.kotlin.descriptors.commonizer.core.CommonizationVisitor
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirTreeMerger
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.DefaultCirClassifiersCache
import org.jetbrains.kotlin.storage.LockBasedStorageManager

fun runCommonization(parameters: Parameters): Result {
Expand All @@ -19,11 +20,12 @@ fun runCommonization(parameters: Parameters): Result {
val storageManager = LockBasedStorageManager("Declaration descriptors commonization")

// build merged tree:
val mergeResult = CirTreeMerger(storageManager, parameters).merge()
val cache = DefaultCirClassifiersCache()
val mergeResult = CirTreeMerger(storageManager, cache, parameters).merge()

// commonize:
val mergedTree = mergeResult.root
mergedTree.accept(CommonizationVisitor(mergedTree), Unit)
mergedTree.accept(CommonizationVisitor(cache, mergedTree), Unit)
parameters.progressLogger?.invoke("Commonized declarations")

// build resulting descriptors:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,31 @@

package org.jetbrains.kotlin.descriptors.commonizer.mergedtree

import gnu.trove.THashMap
import org.jetbrains.kotlin.name.ClassId

interface CirClassifiersCache {
val classes: Map<ClassId, CirClassNode>
val typeAliases: Map<ClassId, CirTypeAliasNode>
fun classNode(classId: ClassId): CirClassNode?
fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode?

fun addClassNode(classId: ClassId, node: CirClassNode)
fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode)
}

class DefaultCirClassifiersCache : CirClassifiersCache {
private val classNodes = THashMap<ClassId, CirClassNode>()
private val typeAliases = THashMap<ClassId, CirTypeAliasNode>()

override fun classNode(classId: ClassId): CirClassNode? = classNodes[classId]
override fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode? = typeAliases[typeAliasId]

override fun addClassNode(classId: ClassId, node: CirClassNode) {
val oldNode = classNodes.put(classId, node)
check(oldNode == null) { "Rewriting class node $classId" }
}

override fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode) {
val oldNode = typeAliases.put(typeAliasId, node)
check(oldNode == null) { "Rewriting type alias node $typeAliasId" }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ package org.jetbrains.kotlin.descriptors.commonizer.mergedtree
import gnu.trove.THashMap
import org.jetbrains.kotlin.descriptors.commonizer.cir.CirRoot
import org.jetbrains.kotlin.descriptors.commonizer.utils.CommonizedGroup
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.storage.NullableLazyValue

class CirRootNode(
override val targetDeclarations: CommonizedGroup<CirRoot>,
override val commonDeclaration: NullableLazyValue<CirRoot>
) : CirNode<CirRoot, CirRoot> {
class CirClassifiersCacheImpl : CirClassifiersCache {
override val classes = THashMap<ClassId, CirClassNode>()
override val typeAliases = THashMap<ClassId, CirTypeAliasNode>()
}

val modules: MutableMap<Name, CirModuleNode> = THashMap()
val cache = CirClassifiersCacheImpl()

override fun <T, R> accept(visitor: CirNodeVisitor<T, R>, data: T): R =
visitor.visitRootNode(this, data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import org.jetbrains.kotlin.descriptors.commonizer.Parameters
import org.jetbrains.kotlin.descriptors.commonizer.TargetProvider
import org.jetbrains.kotlin.descriptors.commonizer.cir.CirClass
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.*
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.utils.intern
import org.jetbrains.kotlin.descriptors.commonizer.utils.internedClassId
import org.jetbrains.kotlin.name.ClassId
Expand All @@ -24,6 +23,7 @@ import org.jetbrains.kotlin.storage.StorageManager

class CirTreeMerger(
private val storageManager: StorageManager,
private val cache: CirClassifiersCache,
private val parameters: Parameters
) {
class CirTreeMergeResult(
Expand All @@ -32,11 +32,9 @@ class CirTreeMerger(
)

private val size = parameters.targetProviders.size
private lateinit var cacheRW: CirClassifiersCacheImpl

fun merge(): CirTreeMergeResult {
val rootNode: CirRootNode = buildRootNode(storageManager, size)
cacheRW = rootNode.cache

val allModuleInfos: List<Map<String, ModuleInfo>> = parameters.targetProviders.map { it.modulesProvider.loadModuleInfos() }
val commonModuleNames = allModuleInfos.map { it.keys }.reduce { a, b -> a intersect b }
Expand Down Expand Up @@ -140,7 +138,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val propertyNode: CirPropertyNode = properties.getOrPut(PropertyApproximationKey(propertyDescriptor)) {
buildPropertyNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildPropertyNode(storageManager, size, cache, parentCommonDeclaration)
}
propertyNode.targetDeclarations[targetIndex] = CirPropertyFactory.create(propertyDescriptor)
}
Expand All @@ -152,7 +150,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val functionNode: CirFunctionNode = functions.getOrPut(FunctionApproximationKey(functionDescriptor)) {
buildFunctionNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildFunctionNode(storageManager, size, cache, parentCommonDeclaration)
}
functionNode.targetDeclarations[targetIndex] = CirFunctionFactory.create(functionDescriptor)
}
Expand All @@ -168,7 +166,7 @@ class CirTreeMerger(
val classId = classIdFunction(className)

val classNode: CirClassNode = classes.getOrPut(className) {
buildClassNode(storageManager, size, cacheRW, parentCommonDeclaration, classId)
buildClassNode(storageManager, size, cache, parentCommonDeclaration, classId)
}
classNode.targetDeclarations[targetIndex] = CirClassFactory.create(classDescriptor)

Expand Down Expand Up @@ -203,7 +201,7 @@ class CirTreeMerger(
parentCommonDeclaration: NullableLazyValue<*>?
) {
val constructorNode: CirClassConstructorNode = constructors.getOrPut(ConstructorApproximationKey(constructorDescriptor)) {
buildClassConstructorNode(storageManager, size, cacheRW, parentCommonDeclaration)
buildClassConstructorNode(storageManager, size, cache, parentCommonDeclaration)
}
constructorNode.targetDeclarations[targetIndex] = CirClassConstructorFactory.create(constructorDescriptor)
}
Expand All @@ -218,7 +216,7 @@ class CirTreeMerger(
val typeAliasClassId = internedClassId(packageFqName, typeAliasName)

val typeAliasNode: CirTypeAliasNode = typeAliases.getOrPut(typeAliasName) {
buildTypeAliasNode(storageManager, size, cacheRW, typeAliasClassId)
buildTypeAliasNode(storageManager, size, cache, typeAliasClassId)
}
typeAliasNode.targetDeclarations[targetIndex] = CirTypeAliasFactory.create(typeAliasDescriptor)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.jetbrains.kotlin.descriptors.commonizer.cir.*
import org.jetbrains.kotlin.descriptors.commonizer.cir.impl.CirClassRecursionMarker
import org.jetbrains.kotlin.descriptors.commonizer.cir.impl.CirClassifierRecursionMarker
import org.jetbrains.kotlin.descriptors.commonizer.core.*
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.utils.CommonizedGroup
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
Expand Down Expand Up @@ -80,18 +79,18 @@ internal fun buildFunctionNode(
internal fun buildClassNode(
storageManager: StorageManager,
size: Int,
cacheRW: CirClassifiersCacheImpl,
cache: CirClassifiersCache,
parentCommonDeclaration: NullableLazyValue<*>?,
classId: ClassId
): CirClassNode = buildNode(
storageManager = storageManager,
size = size,
parentCommonDeclaration = parentCommonDeclaration,
commonizerProducer = { ClassCommonizer(cacheRW) },
commonizerProducer = { ClassCommonizer(cache) },
recursionMarker = CirClassRecursionMarker,
nodeProducer = { targetDeclarations, commonDeclaration ->
CirClassNode(targetDeclarations, commonDeclaration, classId).also {
cacheRW.classes[classId] = it
cache.addClassNode(classId, it)
}
}
)
Expand All @@ -112,16 +111,16 @@ internal fun buildClassConstructorNode(
internal fun buildTypeAliasNode(
storageManager: StorageManager,
size: Int,
cacheRW: CirClassifiersCacheImpl,
classId: ClassId
cache: CirClassifiersCache,
typeAliasId: ClassId
): CirTypeAliasNode = buildNode(
storageManager = storageManager,
size = size,
commonizerProducer = { TypeAliasCommonizer(cacheRW) },
commonizerProducer = { TypeAliasCommonizer(cache) },
recursionMarker = CirClassifierRecursionMarker,
nodeProducer = { targetDeclarations, commonDeclaration ->
CirTypeAliasNode(targetDeclarations, commonDeclaration, classId).also {
cacheRW.typeAliases[classId] = it
CirTypeAliasNode(targetDeclarations, commonDeclaration, typeAliasId).also {
cache.addTypeAliasNode(typeAliasId, it)
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ import org.jetbrains.kotlin.descriptors.commonizer.cir.CirType
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirClassFactory
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeAliasFactory
import org.jetbrains.kotlin.descriptors.commonizer.cir.factory.CirTypeFactory
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirClassifiersCache
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.CirRootNode.CirClassifiersCacheImpl
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.buildClassNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.buildTypeAliasNode
import org.jetbrains.kotlin.descriptors.commonizer.mergedtree.*
import org.jetbrains.kotlin.descriptors.commonizer.utils.mockClassType
import org.jetbrains.kotlin.descriptors.commonizer.utils.mockTAType
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.resolve.descriptorUtil.classId
import org.jetbrains.kotlin.storage.LockBasedStorageManager
import org.jetbrains.kotlin.types.KotlinType
Expand All @@ -26,11 +24,11 @@ import org.junit.Test

class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {

private lateinit var cache: CirClassifiersCacheImpl
private lateinit var cache: CirClassifiersCache

@Before
fun initialize() {
cache = CirClassifiersCacheImpl() // reset cache
cache = DefaultCirClassifiersCache() // reset cache
}

@Test
Expand Down Expand Up @@ -467,25 +465,25 @@ class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
when (descriptor) {
is ClassDescriptor -> {
val classId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.classes.getOrPut(classId) {
val node = cache.classNode(classId) {
buildClassNode(
storageManager = LockBasedStorageManager.NO_LOCKS,
size = variants.size,
cacheRW = cache,
cache = cache,
parentCommonDeclaration = null,
classId = classId
)
}
node.targetDeclarations[index] = CirClassFactory.create(descriptor)
}
is TypeAliasDescriptor -> {
val classId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.typeAliases.getOrPut(classId) {
val typeAliasId = descriptor.classId ?: error("No class ID for ${descriptor::class.java}, $descriptor")
val node = cache.typeAliasNode(typeAliasId) {
buildTypeAliasNode(
storageManager = LockBasedStorageManager.NO_LOCKS,
size = variants.size,
cacheRW = cache,
classId = classId
cache = cache,
typeAliasId = typeAliasId
)
}
node.targetDeclarations[index] = CirTypeAliasFactory.create(descriptor)
Expand Down Expand Up @@ -526,5 +524,11 @@ class TypeCommonizerTest : AbstractCommonizerTest<CirType, CirType>() {
companion object {
fun areEqual(cache: CirClassifiersCache, a: CirType, b: CirType): Boolean =
TypeCommonizer(cache).run { commonizeWith(a) && commonizeWith(b) }

private fun CirClassifiersCache.classNode(classId: ClassId, computation: () -> CirClassNode) =
classNode(classId) ?: computation()

private fun CirClassifiersCache.typeAliasNode(typeAliasId: ClassId, computation: () -> CirTypeAliasNode) =
typeAliasNode(typeAliasId) ?: computation()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ private fun createPackageFragmentForClassifier(classifierFqName: FqName): Packag
}

internal val EMPTY_CLASSIFIERS_CACHE = object : CirClassifiersCache {
override val classes: Map<ClassId, CirClassNode> get() = emptyMap()
override val typeAliases: Map<ClassId, CirTypeAliasNode> get() = emptyMap()
override fun classNode(classId: ClassId): CirClassNode? = null
override fun typeAliasNode(typeAliasId: ClassId): CirTypeAliasNode? = null
override fun addClassNode(classId: ClassId, node: CirClassNode) = error("This method should not be called")
override fun addTypeAliasNode(typeAliasId: ClassId, node: CirTypeAliasNode) = error("This method should not be called")
}

internal class MockBuiltInsProvider(private val builtIns: KotlinBuiltIns) : BuiltInsProvider {
Expand Down

0 comments on commit eca231a

Please sign in to comment.