Skip to content

Commit

Permalink
[pysrc2cpg] refactor <module>, ANY and __init__ constants
Browse files Browse the repository at this point in the history
  • Loading branch information
xavierpinho committed May 28, 2024
1 parent 038c849 commit 241c1ec
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ object Constants {
val builtinIntType = s"${builtinPrefix}int"
val builtinFloatType = s"${builtinPrefix}float"
val builtinComplexType = s"${builtinPrefix}complex"

val moduleName = "<module>"
val initName = "__init__"
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class ContextStack {
*/
def considerAsGlobalVariable(lhs: NewNode): Unit = {
lhs match {
case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains("<module>") =>
case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains(Constants.moduleName) =>
addGlobalVariable(n.name)
case _ =>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgN
}

private def pythonicTypeNameToImport(fullName: String): String =
fullName.replaceFirst("\\.py:<module>", "").replaceAll(Pattern.quote(File.separator), ".")
fullName.replaceFirst(s"\\.py:${Constants.moduleName}", "").replaceAll(Pattern.quote(File.separator), ".")

private def setTypeHints(
diffGraph: BatchedUpdate.DiffGraphBuilder,
Expand All @@ -86,7 +86,7 @@ class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgN
val typeFilePath = typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator))
val pythonicTypeFullName = importFullPath.split("\\.").lastOption match {
case Some(typeName) =>
typeFilePath.stripSuffix(s"${File.separator}$typeName").concat(s".py:<module>.$typeName")
typeFilePath.stripSuffix(s"${File.separator}$typeName").concat(s".py:${Constants.moduleName}.$typeName")
case None => typeHintFullName
}
cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class PythonAstVisitor(
edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1)
contextStack.setFileNamespaceBlock(namespaceBlockNode)

val methodFullName = calculateFullNameFromContext("<module>")
val methodFullName = calculateFullNameFromContext(Constants.moduleName)

val firstLineAndCol = module.stmts.headOption.map(lineAndColOf)
val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf)
Expand All @@ -105,9 +105,9 @@ class PythonAstVisitor(

val moduleMethodNode =
createMethod(
"<module>",
Constants.moduleName,
methodFullName,
Some("<module>"),
Some(Constants.moduleName),
ModifierTypes.VIRTUAL :: ModifierTypes.MODULE :: Nil,
parameterProvider = () => MethodParameters.empty(),
bodyProvider = () => createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert),
Expand Down Expand Up @@ -402,7 +402,7 @@ class PythonAstVisitor(

// For every method that is a module, the local variables can be imported by other modules. This behaviour is
// much like fields so they are to be linked as fields to this method type
if (name == "<module>") contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge)
if (name == Constants.moduleName) contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge)

contextStack.pop()
edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc)
Expand Down Expand Up @@ -487,7 +487,7 @@ class PythonAstVisitor(
val functions = classDef.body.collect { case func: ast.FunctionDef => func }

// __init__ method has to be in functions because "async def __init__" is invalid.
val initFunctionOption = functions.find(_.name == "__init__")
val initFunctionOption = functions.find(_.name == Constants.initName)

val initParameters = initFunctionOption.map(_.args).getOrElse {
// Create arguments of a default __init__ function.
Expand Down Expand Up @@ -773,7 +773,7 @@ class PythonAstVisitor(

val initCall = createXDotYCall(
() => createIdentifierNode("cls", Load, lineAndColumn),
"__init__",
Constants.initName,
xMayHaveSideEffects = false,
lineAndColumn,
argumentWithInstance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {
private val moduleCache: mutable.HashMap[String, ImportableEntity] = mutable.HashMap.empty

override def init(): Unit = {
cpg.typeDecl.isExternal(false).nameExact("<module>").foreach { moduleType =>
cpg.typeDecl.isExternal(false).nameExact(Constants.moduleName).foreach { moduleType =>
val modulePath = fileToPythonImportNotation(moduleType.filename)
cpg.method.fullNameExact(moduleType.fullName).headOption.foreach { moduleMethod =>
moduleCache.put(modulePath, Module(moduleType, moduleMethod))
Expand All @@ -48,7 +48,7 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {
.stripPrefix(codeRootDir)
.replaceAll(Matcher.quoteReplacement(JFile.separator), ".")
.stripSuffix(".py")
.stripSuffix(".__init__")
.stripSuffix(s".${Constants.initName}")

override protected def optionalResolveImport(
fileName: String,
Expand Down Expand Up @@ -103,16 +103,20 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {

def toUnresolvedImport(pseudoPath: String): Set[EvaluatedImport] = {
if (isMaybeConstructor) {
Set(UnknownMethod(Seq(pseudoPath, "__init__").mkString(pathSep.toString), alias), UnknownTypeDecl(pseudoPath))
Set(
UnknownMethod(Seq(pseudoPath, Constants.initName).mkString(pathSep.toString), alias),
UnknownTypeDecl(pseudoPath)
)
} else {
Set(UnknownImport(pseudoPath))
}
}

expEntity.split(pathSep).reverse.toList match
case name :: Nil => toUnresolvedImport(s"$name.py:<module>")
case name :: xs => toUnresolvedImport(s"${xs.reverse.mkString(JFile.separator)}.py:<module>$pathSep$name")
case Nil => Set.empty
case name :: Nil => toUnresolvedImport(s"$name.py:${Constants.moduleName}")
case name :: xs =>
toUnresolvedImport(s"${xs.reverse.mkString(JFile.separator)}.py:${Constants.moduleName}$pathSep$name")
case Nil => Set.empty
}

private sealed trait ImportableEntity {
Expand Down Expand Up @@ -140,6 +144,6 @@ class PythonImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) {

private case class ImportableType(typ: TypeDecl) extends ImportableEntity {
override def toResolvedImport(alias: String): List[EvaluatedImport] =
List(ResolvedTypeDecl(typ.fullName), ResolvedMethod(s"${typ.fullName}.__init__", typ.name))
List(ResolvedTypeDecl(typ.fullName), ResolvedMethod(s"${typ.fullName}.${Constants.initName}", typ.name))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import io.shiftleft.codepropertygraph.Cpg
*/
class PythonInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg) {

override val moduleName: String = "<module>"
override val moduleName: String = Constants.moduleName
override val fileExt: String = ".py"

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) {

override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map {
// Python call from a type
case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__"
case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.${Constants.initName}"
// Python call from a function pointer
case typ => typ
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
.member
.nameExact(memberName)
.flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName)
.filterNot(_ == "ANY")
.filterNot(_ == Constants.ANY)
.toSet
symbolTable.put(LocalVar(entityName), memberTypes)
case UnknownMethod(fullName, alias, receiver, _) =>
Expand Down Expand Up @@ -96,7 +96,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
/** If the parent method is module then it can be used as a field.
*/
override def isFieldUncached(i: Identifier): Boolean =
i.method.name.matches("(<module>|__init__)") || super.isFieldUncached(i)
i.method.name.matches(s"(${Constants.moduleName}|${Constants.initName})") || super.isFieldUncached(i)

override def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = {
operation match {
Expand All @@ -110,7 +110,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
}

override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = {
val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__"))
val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"$pathSep${Constants.initName}"))
associateTypes(i, constructorPaths)
}

Expand Down Expand Up @@ -143,7 +143,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
}

override def getFieldParents(fa: FieldAccess): Set[String] = {
if (fa.method.name == "<module>") {
if (fa.method.name == Constants.moduleName) {
Set(fa.method.fullName)
} else if (fa.method.typeDecl.nonEmpty) {
val parentTypes = fa.method.typeDecl.fullName.toSet
Expand Down Expand Up @@ -203,7 +203,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
.foreach { cls =>
val clsPath = classMethod.typeDecl.fullName.toSet
symbolTable.put(LocalVar(cls.name), clsPath)
if (cls.typeFullName == "ANY")
if (cls.typeFullName == Constants.ANY)
builder.setNodeProperty(cls, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, clsPath.toSeq)
}
}
Expand All @@ -223,7 +223,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder
funcName: String,
baseName: Option[String]
): Unit = {
if (funcName != "<module>")
if (funcName != Constants.moduleName)
super.handlePotentialFunctionPointer(funcPtr, baseTypes, funcName, baseName)
}

Expand Down

0 comments on commit 241c1ec

Please sign in to comment.