Skip to content

Commit

Permalink
[internal] scala: generate the JVM names seen by Java code for Scala …
Browse files Browse the repository at this point in the history
…code (#13696)

As described in #13662, Java code that tries to import Scala symbols is not having that dependency inferred by Pants. The cause is that the Scala backend does not expose the transformed names seen by Java code for Scala symbols.

For example, the instance for `object Foo` in package `org.pantsbuild.example` is actually `org.pantsbuild.example.Foo$.MODULE$`.

This PR is the first step in solving this by generating the transformed names as part of source analysis. Only the `object` case is handled for now. A subsequent PR will use this information for Java dependency inference.

[ci skip-rust]
  • Loading branch information
Tom Dyas authored Nov 23, 2021
1 parent a1fff7c commit 9e4d9c4
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@ import scala.meta._
import scala.meta.transversers.Traverser

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.reflect.NameTransformer

case class AnImport(name: String, isWildcard: Boolean)

case class Analysis(
providedNames: Vector[String],
importsByScope: HashMap[String, ArrayBuffer[AnImport]],
consumedSymbolsByScope: HashMap[String, HashSet[String]],
providedSymbols: Vector[String],
providedSymbolsEncoded: Vector[String],
importsByScope: HashMap[String, ArrayBuffer[AnImport]],
consumedSymbolsByScope: HashMap[String, HashSet[String]],
)

case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean)

class SourceAnalysisTraverser extends Traverser {
val nameParts = ArrayBuffer[String]()
var skipProvidedNames = false

val providedNames = ArrayBuffer[String]()
val providedSymbolsByScope = HashMap[String, HashMap[String, ProvidedSymbol]]()
val importsByScope = HashMap[String, ArrayBuffer[AnImport]]()
val consumedSymbolsByScope = HashMap[String, HashSet[String]]()

Expand Down Expand Up @@ -74,10 +78,29 @@ class SourceAnalysisTraverser extends Traverser {
}
}

def recordProvidedName(name: String): Unit = {
def recordProvidedName(symbolName: String, sawClass: Boolean = false, sawTrait: Boolean = false, sawObject: Boolean = false): Unit = {
if (!skipProvidedNames) {
val fullPackageName = nameParts.mkString(".")
providedNames.append(s"${fullPackageName}.${name}")
if (!providedSymbolsByScope.contains(fullPackageName)) {
providedSymbolsByScope(fullPackageName) = HashMap[String, ProvidedSymbol]()
}
val providedSymbols = providedSymbolsByScope(fullPackageName)

if (providedSymbols.contains(symbolName)) {
val existingSymbol = providedSymbols(symbolName)
val newSymbol = ProvidedSymbol(
sawClass = existingSymbol.sawClass || sawClass,
sawTrait = existingSymbol.sawTrait || sawTrait,
sawObject = existingSymbol.sawObject || sawObject,
)
providedSymbols(symbolName) = newSymbol
} else {
providedSymbols(symbolName) = ProvidedSymbol(
sawClass = sawClass,
sawTrait = sawTrait,
sawObject = sawObject
)
}
}
}

Expand Down Expand Up @@ -126,19 +149,19 @@ class SourceAnalysisTraverser extends Traverser {

case Defn.Class(_mods, nameNode, _tparams, _ctor, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
recordProvidedName(name, sawClass = true)
visitTemplate(templ, name)
}

case Defn.Trait(_mods, nameNode, _tparams, _ctor, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
recordProvidedName(name, sawTrait = true)
visitTemplate(templ, name)
}

case Defn.Object(_mods, nameNode, templ) => {
val name = extractName(nameNode)
recordProvidedName(name)
recordProvidedName(name, sawObject = true)
visitTemplate(templ, name)
}

Expand Down Expand Up @@ -231,6 +254,37 @@ class SourceAnalysisTraverser extends Traverser {

case node => super.apply(node)
}

def gatherProvidedSymbols(): Vector[String] = {
providedSymbolsByScope.flatMap({ case (scopeName, symbolsForScope) =>
symbolsForScope.keys.map(symbolName => s"${scopeName}.${symbolName}").toVector
}).toVector
}

def gatherEncodedProvidedSymbols(): Vector[String] = {
providedSymbolsByScope.flatMap({ case (scopeName, symbolsForScope) =>
val encodedSymbolsForScope = symbolsForScope.flatMap({ case (symbolName, symbol) => {
val encodedSymbolName = NameTransformer.encode(symbolName)
val result = ArrayBuffer[String](encodedSymbolName)
if (symbol.sawObject) {
result.append(encodedSymbolName + "$")
result.append(encodedSymbolName + "$.MODULE$")
}
result.toVector
}})

encodedSymbolsForScope.map(symbolName => s"${scopeName}.${symbolName}")
}).toVector
}

def toAnalysis: Analysis = {
Analysis(
providedSymbols = gatherProvidedSymbols(),
providedSymbolsEncoded = gatherEncodedProvidedSymbols(),
importsByScope = importsByScope,
consumedSymbolsByScope = consumedSymbolsByScope,
)
}
}

object ScalaParser {
Expand All @@ -244,12 +298,7 @@ object ScalaParser {

val analysisTraverser = new SourceAnalysisTraverser()
analysisTraverser.apply(tree)

Analysis(
providedNames = analysisTraverser.providedNames.toVector,
importsByScope = analysisTraverser.importsByScope,
consumedSymbolsByScope = analysisTraverser.consumedSymbolsByScope,
)
analysisTraverser.toAnalysis
}

def main(args: Array[String]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def to_debug_json_dict(self) -> dict[str, Any]:

@dataclass(frozen=True)
class ScalaSourceDependencyAnalysis:
provided_names: FrozenOrderedSet[str]
provided_symbols: FrozenOrderedSet[str]
provided_symbols_encoded: FrozenOrderedSet[str]
imports_by_scope: FrozenDict[str, tuple[ScalaImport, ...]]
consumed_symbols_by_scope: FrozenDict[str, FrozenOrderedSet[str]]

Expand Down Expand Up @@ -165,7 +166,8 @@ def fully_qualified_consumed_symbols(self) -> Iterator[str]:
@classmethod
def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:
return cls(
provided_names=FrozenOrderedSet(d["providedNames"]),
provided_symbols=FrozenOrderedSet(d["providedSymbols"]),
provided_symbols_encoded=FrozenOrderedSet(d["providedSymbolsEncoded"]),
imports_by_scope=FrozenDict(
{
key: tuple(ScalaImport.from_json_dict(v) for v in values)
Expand All @@ -182,7 +184,8 @@ def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:

def to_debug_json_dict(self) -> dict[str, Any]:
return {
"provided_names": list(self.provided_names),
"provided_symbols": list(self.provided_symbols),
"provided_symbols_encoded": list(self.provided_symbols_encoded),
"imports_by_scope": {
key: [v.to_debug_json_dict() for v in values]
for key, values in self.imports_by_scope.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,39 +147,79 @@ def this(bar: SomeTypeInSecondaryConstructor) {
[source_files],
)

assert analysis.provided_names == FrozenOrderedSet(
[
"org.pantsbuild.example.OuterClass",
"org.pantsbuild.example.OuterClass.NestedVal",
"org.pantsbuild.example.OuterClass.NestedVar",
"org.pantsbuild.example.OuterClass.NestedTrait",
"org.pantsbuild.example.OuterClass.NestedClass",
"org.pantsbuild.example.OuterClass.NestedType",
"org.pantsbuild.example.OuterClass.NestedObject",
"org.pantsbuild.example.OuterClass.NestedObject.valWithType",
"org.pantsbuild.example.OuterTrait",
"org.pantsbuild.example.OuterTrait.NestedVal",
"org.pantsbuild.example.OuterTrait.NestedVar",
"org.pantsbuild.example.OuterTrait.NestedTrait",
"org.pantsbuild.example.OuterTrait.NestedClass",
"org.pantsbuild.example.OuterTrait.NestedType",
"org.pantsbuild.example.OuterTrait.NestedObject",
"org.pantsbuild.example.OuterObject",
"org.pantsbuild.example.OuterObject.NestedVal",
"org.pantsbuild.example.OuterObject.NestedVar",
"org.pantsbuild.example.OuterObject.NestedTrait",
"org.pantsbuild.example.OuterObject.NestedClass",
"org.pantsbuild.example.OuterObject.NestedType",
"org.pantsbuild.example.OuterObject.NestedObject",
"org.pantsbuild.example.Functions",
"org.pantsbuild.example.Functions.func1",
"org.pantsbuild.example.Functions.func2",
"org.pantsbuild.example.Functions.func3",
"org.pantsbuild.example.ASubClass",
"org.pantsbuild.example.ASubTrait",
"org.pantsbuild.example.HasPrimaryConstructor",
]
)
assert sorted(list(analysis.provided_symbols)) == [
"org.pantsbuild.example.ASubClass",
"org.pantsbuild.example.ASubTrait",
"org.pantsbuild.example.Functions",
"org.pantsbuild.example.Functions.func1",
"org.pantsbuild.example.Functions.func2",
"org.pantsbuild.example.Functions.func3",
"org.pantsbuild.example.HasPrimaryConstructor",
"org.pantsbuild.example.OuterClass",
"org.pantsbuild.example.OuterClass.NestedClass",
"org.pantsbuild.example.OuterClass.NestedObject",
"org.pantsbuild.example.OuterClass.NestedObject.valWithType",
"org.pantsbuild.example.OuterClass.NestedTrait",
"org.pantsbuild.example.OuterClass.NestedType",
"org.pantsbuild.example.OuterClass.NestedVal",
"org.pantsbuild.example.OuterClass.NestedVar",
"org.pantsbuild.example.OuterObject",
"org.pantsbuild.example.OuterObject.NestedClass",
"org.pantsbuild.example.OuterObject.NestedObject",
"org.pantsbuild.example.OuterObject.NestedTrait",
"org.pantsbuild.example.OuterObject.NestedType",
"org.pantsbuild.example.OuterObject.NestedVal",
"org.pantsbuild.example.OuterObject.NestedVar",
"org.pantsbuild.example.OuterTrait",
"org.pantsbuild.example.OuterTrait.NestedClass",
"org.pantsbuild.example.OuterTrait.NestedObject",
"org.pantsbuild.example.OuterTrait.NestedTrait",
"org.pantsbuild.example.OuterTrait.NestedType",
"org.pantsbuild.example.OuterTrait.NestedVal",
"org.pantsbuild.example.OuterTrait.NestedVar",
]

assert sorted(list(analysis.provided_symbols_encoded)) == [
"org.pantsbuild.example.ASubClass",
"org.pantsbuild.example.ASubTrait",
"org.pantsbuild.example.Functions",
"org.pantsbuild.example.Functions$",
"org.pantsbuild.example.Functions$.MODULE$",
"org.pantsbuild.example.Functions.func1",
"org.pantsbuild.example.Functions.func2",
"org.pantsbuild.example.Functions.func3",
"org.pantsbuild.example.HasPrimaryConstructor",
"org.pantsbuild.example.OuterClass",
"org.pantsbuild.example.OuterClass.NestedClass",
"org.pantsbuild.example.OuterClass.NestedObject",
"org.pantsbuild.example.OuterClass.NestedObject$",
"org.pantsbuild.example.OuterClass.NestedObject$.MODULE$",
"org.pantsbuild.example.OuterClass.NestedObject.valWithType",
"org.pantsbuild.example.OuterClass.NestedTrait",
"org.pantsbuild.example.OuterClass.NestedType",
"org.pantsbuild.example.OuterClass.NestedVal",
"org.pantsbuild.example.OuterClass.NestedVar",
"org.pantsbuild.example.OuterObject",
"org.pantsbuild.example.OuterObject$",
"org.pantsbuild.example.OuterObject$.MODULE$",
"org.pantsbuild.example.OuterObject.NestedClass",
"org.pantsbuild.example.OuterObject.NestedObject",
"org.pantsbuild.example.OuterObject.NestedObject$",
"org.pantsbuild.example.OuterObject.NestedObject$.MODULE$",
"org.pantsbuild.example.OuterObject.NestedTrait",
"org.pantsbuild.example.OuterObject.NestedType",
"org.pantsbuild.example.OuterObject.NestedVal",
"org.pantsbuild.example.OuterObject.NestedVar",
"org.pantsbuild.example.OuterTrait",
"org.pantsbuild.example.OuterTrait.NestedClass",
"org.pantsbuild.example.OuterTrait.NestedObject",
"org.pantsbuild.example.OuterTrait.NestedObject$",
"org.pantsbuild.example.OuterTrait.NestedObject$.MODULE$",
"org.pantsbuild.example.OuterTrait.NestedTrait",
"org.pantsbuild.example.OuterTrait.NestedType",
"org.pantsbuild.example.OuterTrait.NestedVal",
"org.pantsbuild.example.OuterTrait.NestedVar",
]

assert analysis.imports_by_scope == FrozenDict(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def map_first_party_scala_targets_to_symbols(

symbol_map = SymbolMap()
for address, analysis in address_and_analysis:
for symbol in analysis.provided_names:
for symbol in analysis.provided_symbols:
symbol_map.add_symbol(symbol, address)

return symbol_map
Expand Down

0 comments on commit 9e4d9c4

Please sign in to comment.