Skip to content

Commit

Permalink
Refactor the LLM sample selection for (hopefully) better testability
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanlukasczyk committed Nov 24, 2024
1 parent 2cf904f commit 833956d
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiManager
import com.intellij.psi.PsiMethod
import com.intellij.psi.search.ProjectAndLibrariesScope
import com.intellij.psi.search.SearchScope
import com.intellij.psi.search.searches.ReferencesSearch
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.containers.stream
Expand All @@ -36,12 +37,8 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
private val selectionTypeButtonGroup = ButtonGroup()
private val radioButtonsPanel = JPanel()

private val defaultTestName = "<html>provide manually</html>"
private val defaultTestCode = "// provide test method code here"
private val testNames = mutableSetOf(defaultTestName)
private val initialTestCodes = mutableListOf(createTestSampleClass("", defaultTestCode))
private val testSamplePanelFactories: MutableList<TestSamplePanelBuilder> = mutableListOf()
private var testSamplesCode: String = ""
private val sampleSelector = LLMSampleSelector()

private val addButtonPanel = JPanel()
private val addButton = JButton(PluginLabelsBundle.get("addTestSample"))
Expand Down Expand Up @@ -107,7 +104,7 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
override fun applyUpdates() {
if (selectionTypeButtons[0].isSelected) {
for (index in testSamplePanelFactories.indices) {
testSamplesCode += "Test sample number ${index + 1}\n```\n${testSamplePanelFactories[index].getCode()}\n```\n"
sampleSelector.appendTestSampleCode(index, testSamplePanelFactories[index].getCode())
}
}
}
Expand All @@ -124,7 +121,7 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
*
* @return The test samples code.
*/
fun getTestSamplesCode(): String = testSamplesCode
fun getTestSamplesCode(): String = sampleSelector.getTestSamplesCode()

/**
* Adds action listeners to the selectionTypeButtons array to enable the nextButton if any button is selected.
Expand All @@ -141,9 +138,14 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
}

addButton.addActionListener {
collectTestSamples()
val testSamplePanelBuilder =
TestSamplePanelBuilder(project, middlePanel, testNames.toMutableList(), initialTestCodes, language)
sampleSelector.collectTestSamples(project)
val testSamplePanelBuilder = TestSamplePanelBuilder(
project,
middlePanel,
sampleSelector.getTestNames(),
sampleSelector.getInitialTestCodes(),
language,
)
testSamplePanelFactories.add(testSamplePanelBuilder)
val testSamplePanel = testSamplePanelBuilder.getTestSamplePanel()
val codeScrollPanel = testSamplePanelBuilder.getCodeScrollPanel()
Expand Down Expand Up @@ -187,45 +189,171 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
testSamplePanelFactory.enabledComponents(isEnabled)
}
}
}

private fun collectTestSamples() {
/**
* A selector for samples for the LLM.
*/
class LLMSampleSelector {
private val defaultTestName = "<html>provide manually</html"
private val defaultTestCode = "// provide test method code here"
private val testNames = mutableSetOf(defaultTestName)
private val initialTestCodes = mutableListOf(createTestSampleClass("", defaultTestCode))
private var testSamplesCode: String = ""

/**
* Retrieves the test samples code.
*
* @return The test samples code.
*/
fun getTestSamplesCode(): String = testSamplesCode

/**
* Provides the list of test names.
*
* @return The list of test names.
*/
fun getTestNames(): MutableList<String> = testNames.toMutableList()

/**
* Provides the initial test codes.
*
* @return The initial test codes
*/
fun getInitialTestCodes(): MutableList<String> = initialTestCodes

fun appendTestSampleCode(index: Int, code: String) {
testSamplesCode += "Test sample number ${index + 1}\n```\n${code}\n```\n"
}

/**
* Collects the test samples for the LLM from the current project.
*/
fun collectTestSamples(project: Project) {
val currentDocument = FileEditorManager.getInstance(project).selectedTextEditor?.document
val currentFile = currentDocument?.let { FileDocumentManager.getInstance().getFile(it) }
collectTestSamplesForCurrentFile(currentFile)

collectTestSamplesForCurrentFile(currentFile!!, project)

if (testNames.size == 1) {
// Only the default test name is there, thus we did not find any tests related to the
// current file; collect all test samples and provide them to the user instead
collectTestSamples(project)
// Only the default test name is there, thus we did not find any tests related to the current file;
// collect all test samples and provide them to the user instead
collectTestSamplesFromProject(project)
}
}

fun collectTestSamplesForCurrentFile(currentFile: VirtualFile?) {
val javaFileType: FileType = FileTypeManager.getInstance().getFileTypeByExtension("java")
/**
* Collects all test methods as samples from a given {@link Project}.
*
* @param project The project to retrieve all test samples from.
*/
fun collectTestSamplesFromProject(project: Project) {
val projectFileIndex: ProjectFileIndex = ProjectRootManager.getInstance(project).fileIndex

projectFileIndex.iterateContent { file ->
if (isJavaFileTypes(file)) {
try {
val psiJavaFile = findJavaFileFromProject(file, project)
val psiClass = retrievePsiClass(psiJavaFile)
val imports = retrieveImportStatements(psiJavaFile, psiClass)
psiClass.allMethods.forEach { method -> processCandidateMethod(method, imports, psiClass) }
} catch (_: Exception) {}
}
true
}
}

/**
* Collect the test samples relevant for the current file.
*
* These test samples are those methods that call a method in the current file of a project.
*
* @param currentFile The current file.
* @param project The project.
*/
fun collectTestSamplesForCurrentFile(currentFile: VirtualFile, project: Project) {
val projectScope = ProjectAndLibrariesScope(project)
if (currentFile!!.fileType === javaFileType) {
val psiJavaFile = (PsiManager.getInstance(project).findFile(currentFile!!) as PsiJavaFile)
val psiClass = psiJavaFile.classes[
psiJavaFile.classes.stream().map { it.name }.toArray()
.indexOf(psiJavaFile.name.removeSuffix(".java")),
]
if (isJavaFileTypes(currentFile)) {
val psiJavaFile = findJavaFileFromProject(currentFile, project)
val psiClass = retrievePsiClass(psiJavaFile)
psiClass.methods.forEach { method ->
ReferencesSearch.search(method, projectScope).forEach { reference ->
val enclosingMethod = PsiTreeUtil.getParentOfType(
reference.element,
PsiMethod::class.java,
)
if (enclosingMethod != null) {
val enclosingClass = enclosingMethod.containingClass
val enclosingFile = (enclosingMethod.containingFile as PsiJavaFile)
val imports = retrieveImportStatements(enclosingFile, enclosingClass!!)
processCandidateMethod(enclosingMethod, imports, psiClass)
}
}
processMethod(psiClass, method, projectScope)
}
}
}

/**
* Returns, whether the file type is a Java file type.
*
* @param file The file object.
* @return True, if the file is a Java file, false otherwise.
*/
private fun isJavaFileTypes(file: VirtualFile): Boolean {
val javaFileType: FileType = FileTypeManager.getInstance().getFileTypeByExtension("java")
return file.fileType === javaFileType
}

/**
* Finds a {@link PsiJavaFile} from a {@link Project} and a {@link VirtualFile}.
*
* @param file The virtual file object.
* @param project The project instance.
* @return The PSI Java file for the given file and project.
*/
private fun findJavaFileFromProject(file: VirtualFile, project: Project): PsiJavaFile {
val psiManager = PsiManager.getInstance(project)
return psiManager.findFile(file) as PsiJavaFile
}

/**
* Processes a method and searches for methods that reference this method.
*
* @param psiClass The class that defines the method.
* @param psiMethod The method from which the search for references starts.
* @param scope The scope of the search
*/
private fun processMethod(psiClass: PsiClass, psiMethod: PsiMethod, scope: SearchScope) {
ReferencesSearch.search(psiMethod, scope).forEach { reference ->
val enclosingMethod = PsiTreeUtil.getParentOfType(reference.element, PsiMethod::class.java)
if (enclosingMethod != null) {
processEnclosingMethod(enclosingMethod, psiClass)
}
}
}

/**
* Processes an enclosing method of a statement.
*
* @param enclosingMethod The enclosing method.
* @param psiClass The class that defines the method.
*/
private fun processEnclosingMethod(enclosingMethod: PsiMethod, psiClass: PsiClass) {
val enclosingClass = enclosingMethod.containingClass
val enclosingFile = (enclosingMethod.containingFile as PsiJavaFile)
val imports = retrieveImportStatements(enclosingFile, enclosingClass!!)
processCandidateMethod(enclosingMethod, imports, psiClass)
}

/**
* Retrieves the {@link PsiClass} instance from a {@link PsiJavaFile}.
*
* @param psiJavaFile The PSI Java file object
* @return The PSI class object.
*/
private fun retrievePsiClass(psiJavaFile: PsiJavaFile): PsiClass {
return psiJavaFile.classes[
psiJavaFile.classes.stream().map { it.name }.toArray()
.indexOf(psiJavaFile.name.removeSuffix(".java")),
]
}

/**
* Retrieves the import statements for a {@link PsiJavaFile} and {@link PsiClass}.
*
* @param psiJavaFile The PSI Java file object.
* @param psiClass The PSI class object.
* @return A string of import statements.
*/
private fun retrieveImportStatements(psiJavaFile: PsiJavaFile, psiClass: PsiClass): String {
var imports = psiJavaFile.importList?.allImportStatements?.map { it.text }?.toList()
?.joinToString("\n") ?: ""
Expand All @@ -235,55 +363,36 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
return imports
}

/**
* Processes a {@link PsiMethod} as a candidate.
*
* If it is a candidate, i.e., it is annotated as a JUnit test, add it to {@link #testNames} and
* {@link #initialTestCodes}, respectively.
*
* @param psiMethod The PSI method object.
* @param imports The imports required for the code generation.
* @param psiClass The PSI class object.
*/
private fun processCandidateMethod(psiMethod: PsiMethod, imports: String, psiClass: PsiClass) {
val annotations = psiMethod.annotations
annotations.forEach { annotation ->
if (annotation.qualifiedName == "org.junit.jupiter.api.Test" ||
annotation.qualifiedName == "org.junit.Test"
) {
val code: String = createTestSampleClass(imports, psiMethod.text)
if (testNames.add(createMethodName(psiClass, psiMethod))) {
initialTestCodes.add(code)
}
testNames.add(createMethodName(psiClass, psiMethod))
initialTestCodes.add(code)
}
}
}

/**
* Retrieves a list of test samples from the given project.
* Creates a class from the imports and the method codes as a test sample.
*
* @return A list of strings, representing the names of the test samples.
* @param imports The imports required for the code.
* @param methodCode The code of the method.
* @return A class wrapping the given method.
*/
fun collectTestSamples(project: Project) {
val projectFileIndex: ProjectFileIndex = ProjectRootManager.getInstance(project).fileIndex
val javaFileType: FileType = FileTypeManager.getInstance().getFileTypeByExtension("java")

projectFileIndex.iterateContent { file ->
if (file.fileType === javaFileType) {
try {
val psiJavaFile = (PsiManager.getInstance(project).findFile(file) as PsiJavaFile)
val psiClass = psiJavaFile.classes[
psiJavaFile.classes.stream().map { it.name }.toArray()
.indexOf(psiJavaFile.name.removeSuffix(".java")),
]
var imports = retrieveImportStatements(psiJavaFile, psiClass)
psiClass.allMethods.forEach { method ->
val annotations = method.modifierList.annotations
annotations.forEach { annotation ->
if (annotation.qualifiedName == "org.junit.jupiter.api.Test" || annotation.qualifiedName == "org.junit.Test") {
val code: String = createTestSampleClass(imports, method.text)
testNames.add(createMethodName(psiClass, method))
initialTestCodes.add(code)
}
}
}
} catch (_: Exception) {
}
}
true
}
}

private fun createTestSampleClass(imports: String, methodCode: String): String {
var normalizedImports = imports
if (normalizedImports.isNotBlank()) normalizedImports += "\n\n"
Expand All @@ -293,6 +402,13 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
"}"
}

/**
* Creates a fully-qualified method name from a class and method instance.
*
* @param psiClass The class object.
* @param method The method object.
* @return A fully-qualified method name.
*/
private fun createMethodName(psiClass: PsiClass, method: PsiMethod): String =
"<html>${psiClass.qualifiedName}#${method.name}</html>"
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.jetbrains.research.testspark.actions.llm

import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiFile
import com.intellij.testFramework.IndexingTestUtil
Expand All @@ -13,7 +12,6 @@ import com.intellij.testFramework.fixtures.JavaTestFixtureFactory
import com.intellij.testFramework.fixtures.TestFixtureBuilder
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class LLMSampleSelectorBuilderTest {

Expand Down Expand Up @@ -53,13 +51,13 @@ class LLMSampleSelectorBuilderTest {
private fun getResourceAsText(path: String): String? =
object {}.javaClass.getResource(path)?.readText()

@Test
fun collectTestSampleForCurrentFile() {
runReadAction { builder.collectTestSamplesForCurrentFile(openFile.virtualFile) }
}
// @Test
// fun collectTestSampleForCurrentFile() {
// runReadAction { builder.collectTestSamplesForCurrentFile(openFile.virtualFile) }
// }

@Test
fun collectTestSamples() {
runReadAction { builder.collectTestSamples(project) }
}
// @Test
// fun collectTestSamples() {
// runReadAction { builder.collectTestSamples(project) }
// }
}

0 comments on commit 833956d

Please sign in to comment.