diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index f8cfd820e6c..03906bcb22d 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -35,11 +35,16 @@ object Analysis { case class ProvidedSymbol(name: String, recursive: Boolean) } -case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean, recursive: Boolean) +case class ProvidedSymbol( + sawClass: Boolean, + sawTrait: Boolean, + sawObject: Boolean, + recursive: Boolean +) class SourceAnalysisTraverser extends Traverser { val nameParts = ArrayBuffer[String]() - var skipProvidedNames = false + var skipProvidedNames = false val providedSymbolsByScope = HashMap[String, HashMap[String, ProvidedSymbol]]() val importsByScope = HashMap[String, ArrayBuffer[AnImport]]() @@ -56,7 +61,7 @@ class SourceAnalysisTraverser extends Traverser { case (None, None) => None } - def maybeExtractName(tree: Tree): Option[String] = + def maybeExtractName(tree: Tree): Option[String] = tree match { case Term.Select(qual, name) => extractNameSelect(qual, name) case Type.Select(qual, name) => extractNameSelect(qual, name) @@ -192,9 +197,9 @@ class SourceAnalysisTraverser extends Traverser { def visitMods(mods: List[Mod]): Unit = { mods.foreach({ - case Mod.Annot(init) => - apply(init) // rely on `Init` extraction in main parsing match code - case _ => () + case Mod.Annot(init) => + apply(init) // rely on `Init` extraction in main parsing match code + case _ => () }) } @@ -209,8 +214,25 @@ class SourceAnalysisTraverser extends Traverser { visitMods(mods) val name = extractName(nameNode) recordScope(name) - recordProvidedName(name, sawObject = true) - visitTemplate(templ, name) + + // TODO: should object already be recursive? + // an object is recursive if extends another type because we cannot figure out the provided types + // in the parents, we just mark the object as recursive (which is indicated by non-empty inits) + val recursive = !templ.inits.isEmpty + recordProvidedName(name, sawObject = true, recursive = recursive) + + // visitTemplate visits the inits part of the template in the outer scope, + // however for a package object the inits part can actually be found both in the inner scope as well (package inner). + // therefore we are not calling visitTemplate, calling all the apply methods in the inner scope. + // issue https://github.com/pantsbuild/pants/issues/16259 + withNamePart( + name, + () => { + templ.inits.foreach(init => apply(init)) + apply(templ.early) + apply(templ.stats) + } + ) } case Defn.Class(mods, nameNode, _tparams, ctor, templ) => { @@ -238,7 +260,7 @@ class SourceAnalysisTraverser extends Traverser { // in the parents, we just mark the object as recursive (which is indicated by non-empty inits) val recursive = !templ.inits.isEmpty recordProvidedName(name, sawObject = true, recursive = recursive) - + // If the object is recursive, no need to provide the symbols inside if (recursive) withSuppressProvidedNames(() => visitTemplate(templ, name)) @@ -333,7 +355,7 @@ class SourceAnalysisTraverser extends Traverser { } case Init(tpe, _name, argss) => { - extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) + extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) argss.foreach(_.foreach(apply)) } @@ -376,7 +398,9 @@ class SourceAnalysisTraverser extends Traverser { def gatherProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = { providedSymbolsByScope .flatMap({ case (scopeName, symbolsForScope) => - symbolsForScope.map { case(symbolName, symbol) => Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive)}.toVector + symbolsForScope.map { case (symbolName, symbol) => + Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive) + }.toVector }) .toVector } @@ -387,10 +411,14 @@ class SourceAnalysisTraverser extends Traverser { val encodedSymbolsForScope = symbolsForScope.flatMap({ case (symbolName, symbol) => { val encodedSymbolName = NameTransformer.encode(symbolName) - val result = ArrayBuffer[Analysis.ProvidedSymbol](Analysis.ProvidedSymbol(encodedSymbolName, symbol.recursive)) + val result = ArrayBuffer[Analysis.ProvidedSymbol]( + Analysis.ProvidedSymbol(encodedSymbolName, symbol.recursive) + ) if (symbol.sawObject) { result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$", symbol.recursive)) - result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$.MODULE$", symbol.recursive)) + result.append( + Analysis.ProvidedSymbol(encodedSymbolName + "$.MODULE$", symbol.recursive) + ) } result.toVector } diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index f58ceba4e5c..a67f1b6d91e 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -577,3 +577,19 @@ def test_object_extends_ctor(rule_runner: RuleRunner) -> None: "foo.Bar", "foo.hello", ] + + +def test_package_object_extends_trait(rule_runner: RuleRunner) -> None: + analysis = _analyze( + rule_runner, + textwrap.dedent( + """ + package foo + + package object bar extends Trait { + } + """ + ), + ) + + assert sorted(analysis.fully_qualified_consumed_symbols()) == ["foo.Trait", "foo.bar.Trait"]