Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Headless mode Kotlin #367

Open
wants to merge 10 commits into
base: headless-mode
Choose a base branch
from
28 changes: 25 additions & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import org.jetbrains.intellij.platform.gradle.IntelliJPlatformType
import org.jetbrains.intellij.platform.gradle.TestFrameworkType
import org.jetbrains.intellij.platform.gradle.models.ProductRelease
import org.jetbrains.intellij.platform.gradle.tasks.RunIdeTask
import org.jetbrains.intellij.platform.gradle.tasks.aware.SplitModeAware.SplitModeTarget
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import java.io.FileOutputStream
import java.net.URL
Expand Down Expand Up @@ -80,7 +81,10 @@ if (spaceCredentialsProvided()) {
dependencies {
add(hasGrazieAccess.implementationConfigurationName, kotlin("stdlib"))
add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
add(hasGrazieAccess.implementationConfigurationName, "org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion")
add(
hasGrazieAccess.implementationConfigurationName,
"org.jetbrains.research:grazie-test-generation:$grazieTestGenerationVersion"
)
}

tasks.register("checkCredentials") {
Expand Down Expand Up @@ -428,6 +432,7 @@ fun String?.orDefault(default: String): String = this ?: default
* @param prompt a txt file containing the LLM's prompt template
* @param out The output directory for the project.
* @param enableCoverage flag to enable/disable coverage computation
* @param methodName indicates the name of the method under test or empty for class level generation
*/
tasks.create<RunIdeTask>("headless") {
val root: String? by project
Expand All @@ -440,8 +445,22 @@ tasks.create<RunIdeTask>("headless") {
val prompt: String? by project
val out: String? by project
val enableCoverage: String? by project

args = listOfNotNull("testspark", root, file, cut, cp, junitv, llm, token, prompt, out, enableCoverage.orDefault("false"))
val methodName: String? by project

args = listOfNotNull(
"testspark",
root,
file,
cut,
cp,
junitv,
llm,
token,
prompt,
out,
enableCoverage.orDefault("false"),
methodName.orDefault("")
)

jvmArgs(
"-Xmx16G",
Expand All @@ -450,6 +469,9 @@ tasks.create<RunIdeTask>("headless") {
"java.base/jdk.internal.vm=ALL-UNNAMED",
"-Didea.system.path",
)

splitMode = false
splitModeTarget = SplitModeTarget.BACKEND
}

fun spaceCredentialsProvided() = spaceUsername.isNotEmpty() && spacePassword.isNotEmpty()
7 changes: 4 additions & 3 deletions runTestSparkHeadless.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ if [ $# -ne "12" ]; then
9) Output directory
10) Enable/disable coverage computation ('true' or 'false')
11) Space username
12) Space password"
12) Space password
13) Method under test name(or empty for class-level generation)"
exit 1
fi

echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}"
echo -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}"
"$DIR/gradlew" -p "$DIR" headless -Proot="$1" -Pfile="$2" -Pcut="$3" -Pcp="$4" -Pjunitv="$5" -Pllm="$6" -Ptoken="$7" -Pprompt="$8" -Pout="$9" -PenableCoverage="${10}" -Dspace.username="${11}" -Dspace.pass="${12}" -PmethodName="${13}"
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,21 @@ import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiManager
import com.intellij.psi.PsiMethod
import kotlinx.serialization.ExperimentalSerializationApi
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle
import org.jetbrains.research.testspark.core.data.JUnitVersion
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor
import org.jetbrains.research.testspark.core.test.SupportedLanguage
import org.jetbrains.research.testspark.core.test.TestCompiler
import org.jetbrains.research.testspark.core.test.data.CodeType
import org.jetbrains.research.testspark.data.FragmentToTestData
import org.jetbrains.research.testspark.data.ProjectContext
import org.jetbrains.research.testspark.data.llm.JsonEncoding
import org.jetbrains.research.testspark.java.JavaPsiMethodWrapper
import org.jetbrains.research.testspark.kotlin.KotlinPsiHelperProvider
import org.jetbrains.research.testspark.langwrappers.PsiHelperProvider
import org.jetbrains.research.testspark.progress.HeadlessProgressIndicator
import org.jetbrains.research.testspark.services.LLMSettingsService
Expand Down Expand Up @@ -71,8 +76,13 @@ class TestSparkStarter : ApplicationStarter {
val output = args[9]
// Run coverage
val runCoverage = args[10].toBoolean()
// Method under test name(or empty string for class level generation)
val methodName = args[11]

val testsExecutionResultManager = TestsExecutionResultManager()
// TODO check for suitable refactoring
val language =
if (cutSourceFilePath.toString().endsWith(".kt")) SupportedLanguage.Kotlin else SupportedLanguage.Java

println("Test generation requested for $projectPath")

Expand Down Expand Up @@ -108,10 +118,15 @@ class TestSparkStarter : ApplicationStarter {
println("Couldn't open file $cutSourceFilePath")
exitProcess(1)
}

// get target PsiClass
val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile) as PsiJavaFile
val targetPsiClass = detectPsiClass(psiFile.classes, classUnderTestName) ?: run {
val psiFile = PsiManager.getInstance(project).findFile(cutSourceVirtualFile)
val targetPsiClass = detectPsiClass(
when (language) {
SupportedLanguage.Java -> psiFile as PsiJavaFile
SupportedLanguage.Kotlin -> psiFile as KtFile
}.classes,
classUnderTestName
) ?: run {
println("Couldn't find $classUnderTestName in $cutSourceFilePath")
exitProcess(1)
}
Expand Down Expand Up @@ -159,7 +174,10 @@ class TestSparkStarter : ApplicationStarter {
val packageName = packageList.joinToString(".")

// Get PsiHelper
val psiHelper = PsiHelperProvider.getPsiHelper(psiFile)
val psiHelper = when (language) {
SupportedLanguage.Kotlin -> KotlinPsiHelperProvider().getPsiHelper(psiFile as KtFile)
SupportedLanguage.Java -> PsiHelperProvider.getPsiHelper(psiFile as PsiJavaFile)
}
if (psiHelper == null) {
// TODO exception: the support for the current language does not exist
}
Expand All @@ -183,9 +201,23 @@ class TestSparkStarter : ApplicationStarter {
psiHelper.language,
projectSDKPath.toString(),
)
val codeType = when (methodName) {
"" -> FragmentToTestData(CodeType.CLASS)
else -> {
val psiMethod = targetPsiClass.methods.find { it.name == methodName } ?: run {
println("Couldn't find method $methodName")
exitProcess(1)
}
FragmentToTestData(
CodeType.METHOD,
psiHelper.generateMethodDescriptor(JavaPsiMethodWrapper(psiMethod as PsiMethod))
)
}
}

val uiContext = llmProcessManager.runTestGenerator(
indicator,
FragmentToTestData(CodeType.CLASS),
codeType,
packageName,
projectContext,
testGenerationData,
Expand Down Expand Up @@ -257,6 +289,7 @@ class TestSparkStarter : ApplicationStarter {
val targetDirectory = "$out${File.separator}${packageList.joinToString(File.separator)}"
println("Run tests in $targetDirectory")
File(targetDirectory).walk().forEach {
// TODO Doesn't work for compiled kotlin files
if (it.name.endsWith(".class")) {
println("Running test ${it.name}")
var testcaseName = it.nameWithoutExtension.removePrefix("Generated")
Expand Down
Loading