diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index 644f362491f..97a881c711f 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -11,6 +11,7 @@ import io.circe.syntax._ import scala.meta._ import scala.meta.transversers.Traverser +import scala.collection.SortedSet import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.reflect.NameTransformer @@ -25,14 +26,15 @@ case class AnImport( ) case class Analysis( - providedSymbols: Vector[Analysis.ProvidedSymbol], - providedSymbolsEncoded: Vector[Analysis.ProvidedSymbol], + providedSymbols: SortedSet[Analysis.ProvidedSymbol], + providedSymbolsEncoded: SortedSet[Analysis.ProvidedSymbol], importsByScope: HashMap[String, ArrayBuffer[AnImport]], consumedSymbolsByScope: HashMap[String, HashSet[String]], scopes: Vector[String] ) object Analysis { case class ProvidedSymbol(name: String, recursive: Boolean) + implicit val providedSymbolOrdering: Ordering[ProvidedSymbol] = Ordering.by(_.name) } case class ProvidedSymbol( @@ -80,7 +82,7 @@ class SourceAnalysisTraverser extends Traverser { case Type.Name(name) => Vector(name) case Type.Select(qual, Type.Name(name)) => { val qualName = extractName(qual) - Vector(s"${qualName}.${name}") + Vector(qualifyName(qualName, name)) } case Type.Apply(tpe, args) => extractNamesFromTypeTree(tpe) ++ args.toVector.flatMap(extractNamesFromTypeTree(_)) @@ -404,17 +406,17 @@ class SourceAnalysisTraverser extends Traverser { case node => super.apply(node) } - def gatherProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = { + def gatherProvidedSymbols(): SortedSet[Analysis.ProvidedSymbol] = { providedSymbolsByScope .flatMap({ case (scopeName, symbolsForScope) => symbolsForScope.map { case (symbolName, symbol) => - Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive) + Analysis.ProvidedSymbol(qualifyName(scopeName, symbolName), symbol.recursive) }.toVector }) - .toVector + .to(SortedSet) } - def gatherEncodedProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = { + def gatherEncodedProvidedSymbols(): SortedSet[Analysis.ProvidedSymbol] = { providedSymbolsByScope .flatMap({ case (scopeName, symbolsForScope) => val encodedSymbolsForScope = symbolsForScope.flatMap({ @@ -433,9 +435,9 @@ class SourceAnalysisTraverser extends Traverser { } }) - encodedSymbolsForScope.map(symbol => symbol.copy(name = s"${scopeName}.${symbol.name}")) + encodedSymbolsForScope.map(symbol => symbol.copy(name = qualifyName(scopeName, symbol.name))) }) - .toVector + .to(SortedSet) } def toAnalysis: Analysis = { @@ -447,6 +449,11 @@ class SourceAnalysisTraverser extends Traverser { scopes = scopes.toVector ) } + + private def qualifyName(qualifier: String, name: String): String = { + if (qualifier.length > 0) s"$qualifier.$name" + else name + } } object ScalaParser { diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index c3f4023a555..6e969e89f5a 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -595,6 +595,40 @@ def test_package_object_extends_trait(rule_runner: RuleRunner) -> None: assert sorted(analysis.fully_qualified_consumed_symbols()) == ["foo.Trait", "foo.bar.Trait"] +def test_types_at_toplevel_package(rule_runner: RuleRunner) -> None: + analysis = _analyze( + rule_runner, + textwrap.dedent( + """\ + trait Foo + + class Bar + + object Quxx + """ + ), + ) + + expected_symbols = [ + ScalaProvidedSymbol("Foo", False), + ScalaProvidedSymbol("Bar", False), + ScalaProvidedSymbol("Quxx", False), + ] + + expected_symbols_encoded = expected_symbols.copy() + expected_symbols_encoded.extend( + [ScalaProvidedSymbol("Quxx$", False), ScalaProvidedSymbol("Quxx$.MODULE$", False)] + ) + + def by_name(symbol: ScalaProvidedSymbol) -> str: + return symbol.name + + assert analysis.provided_symbols == FrozenOrderedSet(sorted(expected_symbols, key=by_name)) + assert analysis.provided_symbols_encoded == FrozenOrderedSet( + sorted(expected_symbols_encoded, key=by_name) + ) + + def test_type_constaint(rule_runner: RuleRunner) -> None: analysis = _analyze( rule_runner,