diff --git a/build.gradle.kts b/build.gradle.kts index 0f44089cb..3474bcfcf 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -3,6 +3,7 @@ import org.jetbrains.intellij.platform.gradle.IntelliJPlatformType import org.jetbrains.intellij.platform.gradle.TestFrameworkType import org.jetbrains.intellij.platform.gradle.models.ProductRelease import org.jetbrains.intellij.platform.gradle.tasks.RunIdeTask +import org.jetbrains.intellij.platform.gradle.tasks.aware.SplitModeAware.SplitModeTarget import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import java.io.FileOutputStream import java.net.URL @@ -80,7 +81,10 @@ if (spaceCredentialsProvided()) { dependencies { add(hasGrazieAccess.implementationConfigurationName, kotlin("stdlib")) add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3") - add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion") + add( + hasGrazieAccess.implementationConfigurationName, + "org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion" + ) } tasks.register("checkCredentials") { @@ -428,6 +432,7 @@ fun String?.orDefault(default: String): String = this ?: default * @param prompt a txt file containing the LLM's prompt template * @param out The output directory for the project. * @param enableCoverage flag to enable/disable coverage computation + * @param methodName indicates the name of the method under test or empty for class level generation */ tasks.create("headless") { val root: String? by project @@ -440,8 +445,22 @@ tasks.create("headless") { val prompt: String? by project val out: String? by project val enableCoverage: String? by project - - args = listOfNotNull("testspark", root, file, cut, cp, junitv, llm, token, prompt, out, enableCoverage.orDefault("false")) + val methodName: String? by project + + args = listOfNotNull( + "testspark", + root, + file, + cut, + cp, + junitv, + llm, + token, + prompt, + out, + enableCoverage.orDefault("false"), + methodName.orDefault("") + ) jvmArgs( "-Xmx16G", @@ -450,6 +469,9 @@ tasks.create("headless") { "java.base/jdk.internal.vm=ALL-UNNAMED", "-Didea.system.path", ) + + splitMode = false + splitModeTarget = SplitModeTarget.BACKEND } fun spaceCredentialsProvided() = spaceUsername.isNotEmpty() && spacePassword.isNotEmpty() diff --git a/runTestSparkHeadless.sh b/runTestSparkHeadless.sh index a5b64f474..bbb5c0f8c 100644 --- a/runTestSparkHeadless.sh +++ b/runTestSparkHeadless.sh @@ -23,9 +23,10 @@ if [ $# -ne "12" ]; then 9) Output directory 10) Enable/disable coverage computation ('true' or 'false') 11) Space username - 12) Space password" + 12) Space password + 13) Method under test name(or empty for class-level generation)" exit 1 fi -echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" +echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}" +"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}" diff --git a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt index af2bdef41..4f4cb9d0b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/appstarter/TestSparkStarter.kt @@ -13,16 +13,21 @@ import com.intellij.openapi.vfs.LocalFileSystem import com.intellij.psi.PsiClass import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiManager +import com.intellij.psi.PsiMethod import kotlinx.serialization.ExperimentalSerializationApi +import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor +import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData import org.jetbrains.research.testspark.data.ProjectContext import org.jetbrains.research.testspark.data.llm.JsonEncoding +import org.jetbrains.research.testspark.java.JavaPsiMethodWrapper +import org.jetbrains.research.testspark.kotlin.KotlinPsiHelperProvider import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator import org.jetbrains.research.testspark.services.LLMSettingsService @@ -71,8 +76,13 @@ class TestSparkStarter : ApplicationStarter { val output = args[9] // Run coverage val runCoverage = args[10].toBoolean() + // Method under test name(or empty string for class level generation) + val methodName = args[11] val testsExecutionResultManager = TestsExecutionResultManager() + // TODO check for suitable refactoring + val language = + if (cutSourceFilePath.toString().endsWith(".kt")) SupportedLanguage.Kotlin else SupportedLanguage.Java println("Test generation requested for $projectPath") @@ -108,10 +118,15 @@ class TestSparkStarter : ApplicationStarter { println("Couldn't open file $cutSourceFilePath") exitProcess(1) } - // get target PsiClass - val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile) as PsiJavaFile - val targetPsiClass = detectPsiClass(psiFile.classes, classUnderTestName) ?: run { + val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile) + val targetPsiClass = detectPsiClass( + when (language) { + SupportedLanguage.Java -> psiFile as PsiJavaFile + SupportedLanguage.Kotlin -> psiFile as KtFile + }.classes, + classUnderTestName + ) ?: run { println("Couldn't find $classUnderTestName in $cutSourceFilePath") exitProcess(1) } @@ -159,7 +174,10 @@ class TestSparkStarter : ApplicationStarter { val packageName = packageList.joinToString(".") // Get PsiHelper - val psiHelper = PsiHelperProvider.getPsiHelper(psiFile) + val psiHelper = when (language) { + SupportedLanguage.Kotlin -> KotlinPsiHelperProvider().getPsiHelper(psiFile as KtFile) + SupportedLanguage.Java -> PsiHelperProvider.getPsiHelper(psiFile as PsiJavaFile) + } if (psiHelper == null) { // TODO exception: the support for the current language does not exist } @@ -183,9 +201,23 @@ class TestSparkStarter : ApplicationStarter { psiHelper.language, projectSDKPath.toString(), ) + val codeType = when (methodName) { + "" -> FragmentToTestData(CodeType.CLASS) + else -> { + val psiMethod = targetPsiClass.methods.find { it.name == methodName } ?: run { + println("Couldn't find method $methodName") + exitProcess(1) + } + FragmentToTestData( + CodeType.METHOD, + psiHelper.generateMethodDescriptor(JavaPsiMethodWrapper(psiMethod as PsiMethod)) + ) + } + } + val uiContext = llmProcessManager.runTestGenerator( indicator, - FragmentToTestData(CodeType.CLASS), + codeType, packageName, projectContext, testGenerationData, @@ -257,6 +289,7 @@ class TestSparkStarter : ApplicationStarter { val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}" println("Run tests in $targetDirectory") File(targetDirectory).walk().forEach { + // TODO Doesn't work for compiled kotlin files if (it.name.endsWith(".class")) { println("Running test ${it.name}") var testcaseName = it.nameWithoutExtension.removePrefix("Generated")