Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kotlin2cpg] Implemented Lambda Parameter Destructuring #5095

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
package io.joern.kotlin2cpg

/** Named constants of the `io.joern.kotlin2cpg` package: mostly string literals (prefixes, placeholder names, code
  * strings) plus `init`, which aliases `io.joern.x2cpg.Defines.ConstructorMethodName`.
  *
  * NOTE(review): every name below is declared twice. This looks like the pre-change and post-change sides of a diff
  * rendered together without +/- markers; confirm against the real file, since duplicate `val`s in one object would
  * not compile.
  */
object Constants {
// ---- first listing (presumably the pre-change side; TODO confirm) ----
val alloc = "alloc"
val caseNodeParserTypeName = "CaseNode"
val caseNodePrefix = "case"
val codeForLoweredForBlock = "FOR-BLOCK" // TODO: improve this
val collectionsIteratorName = "kotlin.collections.Iterator"
val companionObjectMemberName = "object"
val componentNPrefix = "component"
val defaultCaseNode = "default"
val empty = "<empty>"
val getIteratorMethodName = "iterator"
val hasNextIteratorMethodName = "hasNext"
val importKeyword = "import"
val init = io.joern.x2cpg.Defines.ConstructorMethodName
val iteratorPrefix = "iterator_"
val javaUtilIterator = "java.util.Iterator"
val unknownLambdaBindingName = "<unknownBindingName>"
val unknownLambdaBaseClass = "<unknownLambdaBaseClass>"
val lambdaTypeDeclName = "LAMBDA_TYPE_DECL"
val nextIteratorMethodName = "next"
val codePropUndefinedValue = ""
val operatorSuffix = "<operator>"
// Only present in this first listing; the second listing has `destructedParamNamePrefix` in its place.
val paramNameLambdaDestructureDecl = "DESTRUCTURE_PARAM"
// Only present in this first listing.
val parserTypeName = "KOTLIN_PSI_PARSER"
// `retCode` and `ret` carry the same value; only `retCode` appears in the second listing.
val retCode = "RET"
val ret = "RET"
val root = "<root>"
val this_ = "this"
val tmpLocalPrefix = "tmp_"
// Only present in this first listing.
val tryCode = "try"
val unusedDestructuringEntryText = "_"
val unknownOperator = "<operator>.unknown"
val when = "when"
val wildcardImportName = "*"
// ---- second listing (presumably the post-change side; TODO confirm) ----
val alloc = "alloc"
val caseNodeParserTypeName = "CaseNode"
val caseNodePrefix = "case"
val codeForLoweredForBlock = "FOR-BLOCK" // TODO: improve this
val collectionsIteratorName = "kotlin.collections.Iterator"
val companionObjectMemberName = "object"
val componentNPrefix = "component"
val defaultCaseNode = "default"
val empty = "<empty>"
val getIteratorMethodName = "iterator"
val hasNextIteratorMethodName = "hasNext"
val importKeyword = "import"
val init = io.joern.x2cpg.Defines.ConstructorMethodName
val iteratorPrefix = "iterator_"
val javaUtilIterator = "java.util.Iterator"
val unknownLambdaBindingName = "<unknownBindingName>"
val unknownLambdaBaseClass = "<unknownLambdaBaseClass>"
val lambdaTypeDeclName = "LAMBDA_TYPE_DECL"
val nextIteratorMethodName = "next"
val codePropUndefinedValue = ""
val operatorSuffix = "<operator>"
// Only present in this second listing.
val destructedParamNamePrefix = "<destructed_param>"
val retCode = "RET"
val root = "<root>"
val this_ = "this"
val tmpLocalPrefix = "tmp_"
val unusedDestructuringEntryText = "_"
val unknownOperator = "<operator>.unknown"
val when = "when"
val wildcardImportName = "*"
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ class AstCreator(fileWithMeta: KtFileWithMeta, bindingContext: BindingContext, g
protected val lambdaBindingInfoQueue: mutable.ArrayBuffer[BindingInfo] = mutable.ArrayBuffer.empty
protected val methodAstParentStack: Stack[NewNode] = new Stack()

protected val tmpKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue)
protected val iteratorKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue)
protected val tmpKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue)
protected val destructedParamKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue)
protected val iteratorKeyPool = new IntervalKeyPool(first = 1, last = Long.MaxValue)

protected val relativizedPath: String = fileWithMeta.relativizedPath

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
val fieldAccessCallAst = callAst(fieldAccessCall, List(thisAst, Ast(fieldIdentifier)))
val methodBlockAst = blockAst(
blockNode(valueParam, fieldAccessCall.code, typeFullName),
List(returnAst(returnNode(valueParam, Constants.ret), List(fieldAccessCallAst)))
List(returnAst(returnNode(valueParam, Constants.retCode), List(fieldAccessCallAst)))
)

val componentIdx = idx + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import io.joern.x2cpg.datastructures.Stack.StackWrapper
import io.joern.x2cpg.utils.NodeBuilders
import io.joern.x2cpg.utils.NodeBuilders.newBindingNode
import io.joern.x2cpg.utils.NodeBuilders.newClosureBindingNode
import io.joern.x2cpg.utils.NodeBuilders.newIdentifierNode
import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode
import io.joern.x2cpg.utils.NodeBuilders.newModifierNode
import io.shiftleft.codepropertygraph.generated.EvaluationStrategies
import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.ClassDescriptor
Expand Down Expand Up @@ -186,9 +188,92 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
)
}

/** Builds the ASTs that unpack a destructured lambda parameter (e.g. `{ (k, v) -> ... }`) into its entries.
  *
  * A fresh temporary local named `tmp_<n>` is created and registered in the current scope. It is assigned either
  * the destructuring declaration's initializer (when one exists) or an `it` identifier backed by an additional
  * local. Each entry that is not the unused marker `_` is then assigned from the temporary via
  * `assignmentAstForDestructuringEntry`.
  *
  * The component index handed to `assignmentAstForDestructuringEntry` is the entry's 1-based position in the
  * declaration, counting `_` entries. In Kotlin `(_, v)` binds `v` to `component2()`, so the index must be taken
  * before the `_` entries are dropped.
  *
  * @param param
  *   the lambda parameter carrying the destructuring declaration
  * @return
  *   in order: the temporary local, the `it` local (only when there is no initializer), the locals for the
  *   destructuring entries, the assignment initialising the temporary, and one assignment per used entry
  */
private def astsForDestructuring(param: KtParameter): Seq[Ast] = {
  val decl    = param.getDestructuringDeclaration
  val tmpName = s"${Constants.tmpLocalPrefix}${tmpKeyPool.next}"

  // Both branches yield the temporary local, any extra local ASTs, and the `tmp = ...` assignment.
  val (tmpLocal, extraLocalAsts, initCallAst): (NewLocal, Seq[Ast], Ast) =
    if (decl.hasInitializer) {
      val init = decl.getInitializer
      val asts = astsForExpression(init, Some(2))
      // Several ASTs for one initializer are wrapped in a block so the assignment has a single RHS.
      val initAst =
        if (asts.size == 1) asts.head
        else blockAst(blockNode(init, "", ""), asts.toList)
      val local = localNode(decl, tmpName, tmpName, TypeConstants.any)
      scope.addToScope(tmpName, local)
      val tmpIdentifier    = newIdentifierNode(tmpName, TypeConstants.any)
      val tmpIdentifierAst = Ast(tmpIdentifier).withRefEdge(tmpIdentifier, local)
      val assignmentCallNode = NodeBuilders.newOperatorCallNode(
        Operators.assignment,
        s"$tmpName = ${init.getText}",
        None,
        line(init),
        column(init)
      )
      (local, Seq.empty[Ast], callAst(assignmentCallNode, List(tmpIdentifierAst, initAst)))
    } else {
      val explicitTypeName = Option(param.getTypeReference)
        .map(typeRef => fullNameByImportPath(typeRef, param.getContainingKtFile).getOrElse(typeRef.getText))
        .getOrElse(TypeConstants.any)
      // Fall back to the explicit type name when no descriptor is available, instead of calling `.get`.
      val typeFullName = registerType(
        bindingUtils
          .getVariableDesc(param)
          .flatMap(desc => nameRenderer.typeFullName(desc.getType))
          .getOrElse(explicitTypeName)
      )
      val localForIt      = localNode(decl, "it", "it", typeFullName)
      val identifierForIt = newIdentifierNode("it", typeFullName)
      val initAst         = Ast(identifierForIt).withRefEdge(identifierForIt, localForIt)
      val tmpIdentifier   = newIdentifierNode(tmpName, typeFullName)
      val local           = localNode(decl, tmpName, tmpName, typeFullName)
      scope.addToScope(tmpName, local)
      val tmpIdentifierAst = Ast(tmpIdentifier).withRefEdge(tmpIdentifier, local)
      val assignmentCallNode =
        NodeBuilders.newOperatorCallNode(Operators.assignment, s"$tmpName = it", None, line(decl), column(decl))
      (local, Seq(Ast(localForIt)), callAst(assignmentCallNode, List(tmpIdentifierAst, initAst)))
    }

  val localsForDestructuringVars = localsForDestructuringEntries(decl)
  // Index first, filter second: the component number must reflect the entry's position in the declaration.
  val assignmentsForEntries = decl.getEntries.asScala.zipWithIndex.collect {
    case (entry, idx) if entry.getText != Constants.unusedDestructuringEntryText =>
      val rhsBaseAst =
        astWithRefEdgeMaybe(tmpName, identifierNode(entry, tmpName, tmpName, tmpLocal.typeFullName))
      assignmentAstForDestructuringEntry(entry, rhsBaseAst, idx + 1)
  }.toSeq

  Seq(Ast(tmpLocal)) ++ extraLocalAsts ++ localsForDestructuringVars ++ (initCallAst +: assignmentsForEntries)
}

/** Creates the parameter AST for a lambda parameter that is written as a destructuring declaration.
  *
  * Such a parameter has no name of its own, so a synthetic one is generated from
  * `Constants.destructedParamNamePrefix` and the next key of `destructedParamKeyPool`. The resulting parameter node
  * is registered in the current scope under that name.
  *
  * The type is taken from the parameter's variable descriptor when one is available and can be rendered. Otherwise
  * it falls back to the explicit type reference (resolved through the file's imports, else its source text), and
  * finally to `TypeConstants.any`.
  *
  * @param param
  *   the destructured lambda parameter
  * @param order
  *   the value passed as the parameter node's index
  * @return
  *   the parameter node's AST with the parameter's annotation ASTs as children
  */
private def astForDestructedParameter(param: KtParameter, order: Int): Ast = {
  val name = s"${Constants.destructedParamNamePrefix}${destructedParamKeyPool.next}"
  val explicitTypeName = Option(param.getTypeReference)
    .map(typeRef => fullNameByImportPath(typeRef, param.getContainingKtFile).getOrElse(typeRef.getText))
    .getOrElse(TypeConstants.any)
  // A missing descriptor now falls back to the explicit type name instead of throwing via `.get`.
  val typeFullName = registerType(
    bindingUtils
      .getVariableDesc(param)
      .flatMap(desc => nameRenderer.typeFullName(desc.getType))
      .getOrElse(explicitTypeName)
  )
  val node = parameterInNode(param, name, name, order, false, EvaluationStrategies.BY_VALUE, typeFullName)
  scope.addToScope(name, node)

  val annotations = param.getAnnotationEntries.asScala.map(astForAnnotationEntry).toSeq
  Ast(node).withChildren(annotations)
}

def astForParameter(param: KtParameter, order: Int): Ast = {
val name = if (param.getDestructuringDeclaration != null) {
Constants.paramNameLambdaDestructureDecl
s"${Constants.destructedParamNamePrefix}${destructedParamKeyPool.next}"
} else {
param.getName
}
Expand Down Expand Up @@ -316,8 +401,6 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
.withChildren(annotations.map(astForAnnotationEntry))
}

// TODO Handling for destructuring of lambda parameters is missing.
// More specifically the creation and initialisation of the thereby introduced variables.
def astForLambda(
expr: KtLambdaExpression,
argIdxMaybe: Option[Int],
Expand Down Expand Up @@ -363,7 +446,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
node
}

val paramAsts = mutable.ArrayBuffer.empty[Ast]
val paramAsts = mutable.ArrayBuffer.empty[Ast]
val destructedParamAsts = mutable.ArrayBuffer.empty[Ast]
val valueParamStartIndex =
if (funcDesc.getExtensionReceiverParameter != null) {
// Lambdas which are arguments to function parameters defined
Expand All @@ -382,13 +466,18 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
case parameters =>
parameters.zipWithIndex.foreach { (paramDesc, idx) =>
val param = paramDesc.getSource.asInstanceOf[KotlinSourceElement].getPsi.asInstanceOf[KtParameter]
paramAsts.append(astForParameter(param, valueParamStartIndex + idx))
if (param.getDestructuringDeclaration != null) {
paramAsts.append(astForDestructedParameter(param, valueParamStartIndex + idx))
val destructAsts = astsForDestructuring(param)
destructedParamAsts.appendAll(destructAsts)
} else {
paramAsts.append(astForParameter(param, valueParamStartIndex + idx))
}
}
}

val lastChildNotReturnExpression = !expr.getBodyExpression.getLastChild.isInstanceOf[KtReturnExpression]
val needsReturnExpression =
lastChildNotReturnExpression
val needsReturnExpression = lastChildNotReturnExpression
val bodyAsts = Option(expr.getBodyExpression)
.map(
astsForBlock(
Expand All @@ -397,7 +486,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
None,
pushToScope = false,
localsForCaptured,
implicitReturnAroundLastStatement = needsReturnExpression
implicitReturnAroundLastStatement = needsReturnExpression,
Some(destructedParamAsts.toSeq)
)
)
.getOrElse(Seq(Ast(NewBlock())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = true) {
implicit val resolver: ICallResolver = NoResolve

"CPG for code containing a lambda with parameter destructuring" should {
val cpg = code("""|package mypkg
val cpg = code("""
|package mypkg
|
|fun f1(p: String) {
| val m = mapOf(p to 1, "two" to 2, "three" to 3)
Expand All @@ -20,20 +21,15 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = true) {
val source = cpg.method.name("f1").parameter
val sink = cpg.call.methodFullName(".*println.*").argument
val flows = sink.reachableByFlows(source)
// fixme: The nature of destructed parameters causes a loss in granularity here, what we see
// is the over-approximation of container `m` tainting called method-ref parameters
flows.map(flowToResultPairs).toSet shouldBe
Set(
List(("f1(p)", Some(3)), ("println(k)", Some(5)))
// List(
// ("f1(p)", Some(3)),
// ("p to 1", Some(4)),
// ("mapOf(p to 1, \"two\" to 2, \"three\" to 3)", Some(4)),
// ("val m = mapOf(p to 1, \"two\" to 2, \"three\" to 3)", Some(4)),
// ("m.forEach { (k, v) -> println(k) }", Some(5)),
// ("<lambda>0(k, v)", Some(5)),
// ("println(k)", Some(5))
// )
List(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting this back up!

("f1(p)", Some(4)),
("tmp_1 = it", None),
("tmp_1.component1()", Some(6)),
("k = tmp_1.component1()", Some(6)),
("println(k)", Some(6))
)
)
}
}
Expand Down
Loading
Loading