Skip to content

Commit

Permalink
Provide tests for a selection of method
Browse files Browse the repository at this point in the history
The core algorithm is still not tested because I do not see how I could
get the required objects in the necessary state.  The unfortunate thing
about the implemented selection algorithms is that they rely on an open
editor for a normal execution, which seems to be hard to set up for a
unit test.  Any feedback and help on this is highly appreciated!
  • Loading branch information
stephanlukasczyk committed Nov 24, 2024
1 parent 833956d commit 52fef0d
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class LLMSampleSelectorBuilder(private val project: Project, private val languag
* A selector for samples for the LLM.
*/
class LLMSampleSelector {
private val defaultTestName = "<html>provide manually</html"
private val defaultTestName = "<html>provide manually</html>"
private val defaultTestCode = "// provide test method code here"
private val testNames = mutableSetOf(defaultTestName)
private val initialTestCodes = mutableListOf(createTestSampleClass("", defaultTestCode))
Expand Down Expand Up @@ -340,7 +340,7 @@ class LLMSampleSelector {
* @param psiJavaFile The PSI Java file object
* @return The PSI class object.
*/
private fun retrievePsiClass(psiJavaFile: PsiJavaFile): PsiClass {
fun retrievePsiClass(psiJavaFile: PsiJavaFile): PsiClass {
return psiJavaFile.classes[
psiJavaFile.classes.stream().map { it.name }.toArray()
.indexOf(psiJavaFile.name.removeSuffix(".java")),
Expand All @@ -354,7 +354,7 @@ class LLMSampleSelector {
* @param psiClass The PSI class object.
* @return A string of import statements.
*/
private fun retrieveImportStatements(psiJavaFile: PsiJavaFile, psiClass: PsiClass): String {
fun retrieveImportStatements(psiJavaFile: PsiJavaFile, psiClass: PsiClass): String {
var imports = psiJavaFile.importList?.allImportStatements?.map { it.text }?.toList()
?.joinToString("\n") ?: ""
if (psiClass.qualifiedName != null && psiClass.qualifiedName!!.contains(".")) {
Expand Down Expand Up @@ -393,12 +393,12 @@ class LLMSampleSelector {
* @param methodCode The code of the method.
* @return A class wrapping the given method.
*/
private fun createTestSampleClass(imports: String, methodCode: String): String {
fun createTestSampleClass(imports: String, methodCode: String): String {
var normalizedImports = imports
if (normalizedImports.isNotBlank()) normalizedImports += "\n\n"
return normalizedImports +
"public class TestSample {\n" +
" $methodCode\n" +
" $methodCode\n" +
"}"
}

Expand All @@ -409,6 +409,6 @@ class LLMSampleSelector {
* @param method The method object.
* @return A fully-qualified method name.
*/
private fun createMethodName(psiClass: PsiClass, method: PsiMethod): String =
fun createMethodName(psiClass: PsiClass, method: PsiMethod): String =
"<html>${psiClass.qualifiedName}#${method.name}</html>"
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package org.jetbrains.research.testspark.actions.llm

import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.intellij.openapi.application.runReadAction
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiMethod
import com.intellij.testFramework.fixtures.CodeInsightTestFixture
import com.intellij.testFramework.fixtures.IdeaProjectTestFixture
import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory
import com.intellij.testFramework.fixtures.JavaTestFixtureFactory
import com.intellij.testFramework.fixtures.TestFixtureBuilder
import com.intellij.util.containers.stream
import net.jqwik.api.ForAll
import net.jqwik.api.Property
import net.jqwik.api.lifecycle.BeforeTry
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import kotlin.test.assertContentEquals

class LLMSampleSelectorTest {

private lateinit var selector: LLMSampleSelector
private lateinit var fixture: CodeInsightTestFixture
private lateinit var openFile: PsiJavaFile

@BeforeEach
@BeforeTry
fun setUpSelector() {
selector = LLMSampleSelector()
}

@BeforeEach
fun setUpProject() {
val projectBuilder: TestFixtureBuilder<IdeaProjectTestFixture> =
IdeaTestFixtureFactory.getFixtureFactory().createFixtureBuilder("project")

fixture = JavaTestFixtureFactory.getFixtureFactory()
.createCodeInsightFixture(projectBuilder.fixture)
fixture.setUp()

addFilesFromJSONToFixture()
}

private fun classFromFile(psiJavaFile: PsiJavaFile): PsiClass {
return runReadAction {
psiJavaFile.classes[
psiJavaFile.classes.stream().map { it.name }.toArray()
.indexOf(psiJavaFile.name.removeSuffix(".java")),
]
}
}

private fun methodsFromClass(psiClass: PsiClass): Array<out PsiMethod> {
return psiClass.methods
}

private fun addFilesFromJSONToFixture() {
val jsonFileContent = getResourceAsText("/llm/sampleSelectorTestFiles.json")
val type = object : TypeToken<Map<String, String>>() {}.type
val fileMap = Gson().fromJson(jsonFileContent, type) as Map<String, String>
fileMap.forEach { (fileName, fileContent) ->
val psiFile = fixture.addFileToProject(fileName, fileContent)
if (fileName == "test/dummy/CarTest.java") {
openFile = psiFile as PsiJavaFile
}
}
}

private fun getResourceAsText(path: String): String? =
object {}.javaClass.getResource(path)?.readText()

@Test
fun `test the initial test names`() {
val expected = mutableListOf("<html>provide manually</html>")
val actual = selector.getTestNames()
assertContentEquals(expected, actual)
}

@Test
fun `test the initial test code`() {
val initialCode = """
public class TestSample {
// provide test method code here
}
""".trimIndent()
val expected = mutableListOf(initialCode)
val actual = selector.getInitialTestCodes()
assertContentEquals(expected, actual)
}

@Property
fun `test the append of test sample code`(@ForAll index: Int, @ForAll code: String) {
selector.appendTestSampleCode(index, code)
val expected = "Test sample number ${index + 1}\n```\n$code\n```\n"
val actual = selector.getTestSamplesCode()
assertEquals(expected, actual)
}

@Test
fun `test the class retrieval from a Java file`() {
val expectedName = "CarTest"
val actual = runReadAction { selector.retrievePsiClass(openFile as PsiJavaFile) }
assertEquals(expectedName, actual.name)
}

@Test
fun `test the import-statement retrieval`() {
val expected = """
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import dummy.*;
""".trimIndent()
val actual = runReadAction {
val file = openFile as PsiJavaFile
selector.retrieveImportStatements(file, classFromFile(file))
}
assertEquals(expected, actual)
}

@Test
fun `test the create test-sample class`() {
val expected = """
import org.junit.jupiter.api.Test;
public class TestSample {
// provide test method code here
}
""".trimIndent()
val actual = selector.createTestSampleClass(
"import org.junit.jupiter.api.Test;",
"// provide test method code here",
)
assertEquals(expected, actual)
}

@Test
fun `test the expected method name`() {
val expected = "<html>dummy.CarTest#testCar</html>"
val cls = classFromFile(openFile)
val actual = selector.createMethodName(cls, methodsFromClass(cls)[0])
assertEquals(expected, actual)
}
}

0 comments on commit 52fef0d

Please sign in to comment.