Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ICTL-826] add model selection for Grazie platform #104

Merged
merged 9 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ dependencies {
// Dependencies for hasGrazieAccess variant
"hasGrazieAccessImplementation"(kotlin("stdlib"))
"hasGrazieAccessImplementation"("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.3")
"hasGrazieAccessImplementation"("org.jetbrains.research:grazie-test-generation:1.0.1")
"hasGrazieAccessImplementation"("org.jetbrains.research:grazie-test-generation:1.0.4")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jetbrains.research.grazie

import org.jetbrains.research.testSpark.grazie.TestGeneration
import org.jetbrains.research.testspark.tools.llm.generation.Info

/**
 * Bridge exposing the Grazie test-generation profiles to the plugin.
 *
 * The supertype is written fully qualified because this class deliberately
 * shares its simple name with the plugin-side `Info` interface; relying on
 * the explicit import to win name resolution over the class's own name is
 * fragile and confusing to readers.
 */
class Info : org.jetbrains.research.testspark.tools.llm.generation.Info {
    // Profiles (models) supported by the bundled grazie-test-generation library.
    override fun availableProfiles(): Set<String> = TestGeneration.availableProfiles
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class Request : Request {
override fun request(
token: String,
messages: List<Pair<String, String>>,
profile: String,
testsAssembler: TestsAssembler,
): Pair<String, TestsAssembler> {
val generation = TestGeneration(token)
var errorMessage = ""

runBlocking {
generation.generate(messages).catch {
generation.generate(messages, profile).catch {
errorMessage = it.message.toString()
}.collect {
testsAssembler.receiveResponse(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ class LLMPanelFactory : ToolPanelFactory {
private val defaultModulesArray = arrayOf("")
private var modelSelector = ComboBox(defaultModulesArray)
private var llmUserTokenField = JTextField(30)
private var platformSelector = ComboBox(arrayOf("OpenAI"))
private var lastChosenModule = ""
private var platformSelector = ComboBox(arrayOf(TestSparkLabelsBundle.defaultValue("openAI")))
private val backLlmButton = JButton("Back")
private val okLlmButton = JButton("OK")

Expand All @@ -34,7 +33,6 @@ class LLMPanelFactory : ToolPanelFactory {
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)
}

Expand Down Expand Up @@ -65,14 +63,18 @@ class LLMPanelFactory : ToolPanelFactory {
titlePanel.add(textTitle)

if (isGrazieClassLoaded()) {
platformSelector.model = DefaultComboBoxModel(arrayOf("Grazie", "OpenAI"))
platformSelector.model = DefaultComboBoxModel(arrayOf(TestSparkLabelsBundle.defaultValue("grazie"), TestSparkLabelsBundle.defaultValue("openAI")))
platformSelector.selectedItem = settingsState.llmPlatform
} else {
platformSelector.isEnabled = false
}

llmUserTokenField.toolTipText = TestSparkToolTipsBundle.defaultValue("llmToken")
llmUserTokenField.text = settingsState.llmUserToken
if (platformSelector.selectedItem!!.toString() == TestSparkLabelsBundle.defaultValue("grazie")) {
llmUserTokenField.text = settingsState.grazieToken
} else {
llmUserTokenField.text = settingsState.openAIToken
}

modelSelector.toolTipText = TestSparkToolTipsBundle.defaultValue("model")
modelSelector.isEnabled = false
Expand All @@ -82,7 +84,6 @@ class LLMPanelFactory : ToolPanelFactory {
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)

val bottomButtons = JPanel()
Expand Down Expand Up @@ -131,7 +132,12 @@ class LLMPanelFactory : ToolPanelFactory {
*/
override fun settingsStateUpdate() {
    settingsState.llmPlatform = platformSelector.selectedItem!!.toString()
    // Persist the shared token/model widgets into the platform-specific
    // settings fields, so switching platforms does not clobber the other
    // platform's saved token and model.
    if (platformSelector.selectedItem!!.toString() == TestSparkLabelsBundle.defaultValue("grazie")) {
        settingsState.grazieToken = llmUserTokenField.text
        settingsState.grazieModel = modelSelector.selectedItem!!.toString()
    } else {
        settingsState.openAIToken = llmUserTokenField.text
        settingsState.openAIModel = modelSelector.selectedItem!!.toString()
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import com.google.gson.JsonParser
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.ui.ComboBox
import com.intellij.util.io.HttpRequests
import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
import org.jetbrains.research.testspark.services.SettingsApplicationService
import org.jetbrains.research.testspark.tools.llm.generation.Info
import java.net.HttpURLConnection
import javax.swing.DefaultComboBoxModel
import javax.swing.JTextField
Expand All @@ -25,21 +27,28 @@ fun isGrazieClassLoaded(): Boolean {
}
}

/**
 * Reflectively instantiates the Grazie bridge class, which is only bundled
 * in the hasGrazieAccess build variant.
 *
 * @return the bridge instance, or null when the class is absent from the build.
 */
fun loadGrazieInfo(): Info? =
    try {
        Class.forName("org.jetbrains.research.grazie.Info")
            .getDeclaredConstructor()
            .newInstance() as Info
    } catch (e: ClassNotFoundException) {
        // Variant without Grazie access: the bridge class is not on the classpath.
        null
    }

/**
* Adds listeners to the given components to handle events and perform necessary actions.
*
* @param platformSelector The combo box used for selecting platforms.
* @param modelSelector The combo box used for selecting models.
* @param llmUserTokenField The text field used for entering the user token.
* @param defaultModulesArray An array of default module names.
* @param lastChosenModule The name of the last chosen module.
*/
fun addLLMPanelListeners(
platformSelector: ComboBox<String>,
modelSelector: ComboBox<String>,
llmUserTokenField: JTextField,
defaultModulesArray: Array<String>,
lastChosenModule: String,
) {
llmUserTokenField.document.addDocumentListener(object : DocumentListener {
override fun insertUpdate(e: DocumentEvent?) {
Expand All @@ -48,7 +57,6 @@ fun addLLMPanelListeners(
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)
}

Expand All @@ -58,7 +66,6 @@ fun addLLMPanelListeners(
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)
}

Expand All @@ -68,18 +75,23 @@ fun addLLMPanelListeners(
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)
}
})

platformSelector.addItemListener {
val settingsState = SettingsApplicationService.getInstance().state!!
if (platformSelector.selectedItem!!.toString() == TestSparkLabelsBundle.defaultValue("grazie")) {
llmUserTokenField.text = settingsState.grazieToken
} else {
llmUserTokenField.text = settingsState.openAIToken
}

updateModelSelector(
platformSelector,
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)
}
}
Expand All @@ -97,25 +109,26 @@ fun updateModelSelector(
modelSelector: ComboBox<String>,
llmUserTokenField: JTextField,
defaultModulesArray: Array<String>,
lastChosenModule: String,
) {
val settingsState = SettingsApplicationService.getInstance().state!!

if (platformSelector.selectedItem!!.toString() == "Grazie") {
modelSelector.model = DefaultComboBoxModel(arrayOf("GPT-4"))
modelSelector.isEnabled = false
return
}
ApplicationManager.getApplication().executeOnPooledThread {
val modules = getOpenAIModules(llmUserTokenField.text, lastChosenModule)
modelSelector.removeAllItems()
if (modules != null) {
modelSelector.model = DefaultComboBoxModel(modules)
if (modules.contains(settingsState.model)) modelSelector.selectedItem = settingsState.model
modelSelector.isEnabled = true
} else {
modelSelector.model = DefaultComboBoxModel(defaultModulesArray)
modelSelector.isEnabled = false
if (platformSelector.selectedItem!!.toString() == TestSparkLabelsBundle.defaultValue("grazie")) {
val info = loadGrazieInfo()
val modules = info?.availableProfiles() ?: emptySet()
modelSelector.model = DefaultComboBoxModel(modules.toTypedArray())
if (modules.contains(settingsState.grazieModel)) modelSelector.selectedItem = settingsState.grazieModel
modelSelector.isEnabled = true
} else {
ApplicationManager.getApplication().executeOnPooledThread {
val modules = getOpenAIModules(llmUserTokenField.text)
if (modules != null) {
modelSelector.model = DefaultComboBoxModel(modules)
if (modules.contains(settingsState.openAIModel)) modelSelector.selectedItem = settingsState.openAIModel
modelSelector.isEnabled = true
} else {
modelSelector.model = DefaultComboBoxModel(defaultModulesArray)
modelSelector.isEnabled = false
}
}
}
}
Expand All @@ -126,7 +139,7 @@ fun updateModelSelector(
* @param token Authorization token for the OpenAI API.
* @return An array of model names if request is successful, otherwise null.
*/
private fun getOpenAIModules(token: String, lastChosenModule: String): Array<String>? {
private fun getOpenAIModules(token: String): Array<String>? {
val url = "https://api.openai.com/v1/models"

val httpRequest = HttpRequests.request(url).tuner {
Expand All @@ -152,8 +165,6 @@ private fun getOpenAIModules(token: String, lastChosenModule: String): Array<Str

val gptComparator = Comparator<String> { s1, s2 ->
when {
s1 == lastChosenModule -> -1
s2 == lastChosenModule -> 1
s1.contains("gpt") && s2.contains("gpt") -> s2.compareTo(s1)
s1.contains("gpt") -> -1
s2.contains("gpt") -> 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class LLMChatService {
* @return True if the token is set, false otherwise.
*/
private fun updateToken(project: Project): Boolean {
    // Re-read the token from settings before validating it for this project.
    requestManager.token = SettingsArguments.getToken()
    return isCorrectToken(project)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ data class SettingsApplicationState(
var criterionMethodNoException: Boolean = DefaultSettingsApplicationState.criterionMethodNoException,
var criterionCBranch: Boolean = DefaultSettingsApplicationState.criterionCBranch,
var minimize: Boolean = DefaultSettingsApplicationState.minimize,
var llmUserToken: String = DefaultSettingsApplicationState.llmUserToken,
var model: String = DefaultSettingsApplicationState.model,
var openAIToken: String = DefaultSettingsApplicationState.openAIToken,
var grazieToken: String = DefaultSettingsApplicationState.grazieToken,
var openAIModel: String = DefaultSettingsApplicationState.openAIModel,
var grazieModel: String = DefaultSettingsApplicationState.grazieModel,
var llmPlatform: String = DefaultSettingsApplicationState.llmPlatform,
var maxLLMRequest: Int = DefaultSettingsApplicationState.maxLLMRequest,
var maxInputParamsDepth: Int = DefaultSettingsApplicationState.maxInputParamsDepth,
Expand Down Expand Up @@ -56,8 +58,10 @@ data class SettingsApplicationState(
val criterionMethod: Boolean = TestSparkDefaultsBundle.defaultValue("criterionMethod").toBoolean()
val criterionMethodNoException: Boolean = TestSparkDefaultsBundle.defaultValue("criterionMethodNoException").toBoolean()
val criterionCBranch: Boolean = TestSparkDefaultsBundle.defaultValue("criterionCBranch").toBoolean()
val llmUserToken: String = TestSparkDefaultsBundle.defaultValue("llmToken")
var model: String = TestSparkDefaultsBundle.defaultValue("model")
val openAIToken: String = TestSparkDefaultsBundle.defaultValue("openAIToken")
val grazieToken: String = TestSparkDefaultsBundle.defaultValue("grazieToken")
var openAIModel: String = TestSparkDefaultsBundle.defaultValue("openAIModel")
var grazieModel: String = TestSparkDefaultsBundle.defaultValue("grazieModel")
var llmPlatform: String = TestSparkDefaultsBundle.defaultValue("llmPlatform")
val maxLLMRequest: Int = TestSparkDefaultsBundle.defaultValue("maxLLMRequest").toInt()
val maxInputParamsDepth: Int = TestSparkDefaultsBundle.defaultValue("maxInputParamsDepth").toInt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ class SettingsLLMComponent {
// Models
private val defaultModulesArray = arrayOf("")
private var modelSelector = ComboBox(defaultModulesArray)
private var platformSelector = ComboBox(arrayOf("OpenAI"))
private var platformSelector = ComboBox(arrayOf(TestSparkLabelsBundle.defaultValue("openAI")))

// Prompt Editor
private var promptSeparator = JXTitledSeparator(TestSparkLabelsBundle.defaultValue("PromptSeparator"))
private var promptEditorTabbedPane = creatTabbedPane()

private var lastChosenModule = ""

// Maximum number of LLM requests
private var maxLLMRequestsField =
JBIntSpinner(UINumericRange(SettingsApplicationState.DefaultSettingsApplicationState.maxLLMRequest, 1, 20))
Expand Down Expand Up @@ -146,7 +144,6 @@ class SettingsLLMComponent {
modelSelector,
llmUserTokenField,
defaultModulesArray,
lastChosenModule,
)

addHighlighterListeners()
Expand Down Expand Up @@ -180,7 +177,7 @@ class SettingsLLMComponent {
private fun createSettingsPanel() {
// Check if the Grazie platform access is available in the current build
if (isGrazieClassLoaded()) {
platformSelector.model = DefaultComboBoxModel(arrayOf("Grazie", "OpenAI"))
platformSelector.model = DefaultComboBoxModel(arrayOf(TestSparkLabelsBundle.defaultValue("grazie"), TestSparkLabelsBundle.defaultValue("openAI")))
} else {
platformSelector.isEnabled = false
}
Expand Down Expand Up @@ -233,16 +230,27 @@ class SettingsLLMComponent {
return (promptEditorTabbedPane.getComponentAt(editorType.index) as JPanel).getComponent(0) as EditorTextField
}

// OpenAI API token, backed by the shared llmUserTokenField text field.
// NOTE(review): openAIToken and grazieToken are both backed by the same
// field, so reading one after writing the other returns the shared field's
// current text — confirm this is the intended settings-panel behavior.
var openAIToken: String
    get() = llmUserTokenField.text
    set(newText) {
        llmUserTokenField.text = newText
    }

// Grazie authentication token, backed by the shared llmUserTokenField.
// NOTE(review): this property and openAIToken read/write the same text
// field; only the token for the currently selected platform is meaningful
// at any moment — confirm callers account for this.
var grazieToken: String
    get() = llmUserTokenField.text
    set(newText) {
        llmUserTokenField.text = newText
    }

// OpenAI model name, backed by the shared modelSelector combo box.
// NOTE(review): openAIModel and grazieModel share the same combo box, so
// the value reflects whichever platform is currently selected — confirm.
var openAIModel: String
    get() = modelSelector.item
    set(newAlg) {
        modelSelector.item = newAlg
    }

// Grazie model name, backed by the shared modelSelector combo box.
// The combo box is the single source of truth for the selected model; the
// stale assignment to the removed lastChosenModule field has been dropped.
var grazieModel: String
    get() = modelSelector.item
    set(newAlg) {
        modelSelector.item = newAlg
    }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.jetbrains.research.testspark.settings.llm

import com.intellij.openapi.components.service
import com.intellij.openapi.options.Configurable
import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
import org.jetbrains.research.testspark.services.PromptParserService
import org.jetbrains.research.testspark.services.SettingsApplicationService
import org.jetbrains.research.testspark.settings.SettingsApplicationState
Expand Down Expand Up @@ -31,8 +32,10 @@ class SettingsLLMConfigurable : Configurable {
*/
override fun reset() {
val settingsState: SettingsApplicationState = SettingsApplicationService.getInstance().state!!
settingsComponent!!.llmUserToken = settingsState.llmUserToken
settingsComponent!!.model = settingsState.model
settingsComponent!!.openAIToken = settingsState.openAIToken
settingsComponent!!.grazieToken = settingsState.grazieToken
settingsComponent!!.openAIModel = settingsState.openAIModel
settingsComponent!!.grazieModel = settingsState.grazieModel
settingsComponent!!.llmPlatform = settingsState.llmPlatform
settingsComponent!!.maxLLMRequest = settingsState.maxLLMRequest
settingsComponent!!.maxPolyDepth = settingsState.maxPolyDepth
Expand All @@ -49,8 +52,10 @@ class SettingsLLMConfigurable : Configurable {
*/
override fun isModified(): Boolean {
val settingsState: SettingsApplicationState = SettingsApplicationService.getInstance().state!!
var modified: Boolean = settingsComponent!!.llmUserToken != settingsState.llmUserToken
modified = modified or (settingsComponent!!.model != settingsState.model)
var modified: Boolean = settingsComponent!!.openAIToken != settingsState.openAIToken
modified = modified or (settingsComponent!!.grazieToken != settingsState.grazieToken)
modified = modified or (settingsComponent!!.openAIModel != settingsState.openAIModel)
modified = modified or (settingsComponent!!.grazieModel != settingsState.grazieModel)
modified = modified or (settingsComponent!!.llmPlatform != settingsState.llmPlatform)
modified = modified or (settingsComponent!!.maxLLMRequest != settingsState.maxLLMRequest)
modified = modified or (settingsComponent!!.maxPolyDepth != settingsState.maxPolyDepth)
Expand All @@ -73,8 +78,13 @@ class SettingsLLMConfigurable : Configurable {
*/
override fun apply() {
val settingsState: SettingsApplicationState = SettingsApplicationService.getInstance().state!!
settingsState.llmUserToken = settingsComponent!!.llmUserToken
settingsState.model = settingsComponent!!.model
if (settingsComponent!!.llmPlatform == TestSparkLabelsBundle.defaultValue("grazie")) {
settingsState.grazieToken = settingsComponent!!.grazieToken
settingsState.grazieModel = settingsComponent!!.grazieModel
} else {
settingsState.openAIToken = settingsComponent!!.openAIToken
settingsState.openAIModel = settingsComponent!!.openAIModel
}
settingsState.llmPlatform = settingsComponent!!.llmPlatform
settingsState.maxLLMRequest = settingsComponent!!.maxLLMRequest
settingsState.maxPolyDepth = settingsComponent!!.maxPolyDepth
Expand Down
Loading
Loading