Skip to content

Commit

Permalink
Merge pull request #397 from kofemann/custom-model
Browse files Browse the repository at this point in the history
[poc] add a way to specify custom model name
  • Loading branch information
stephanj authored Dec 17, 2024
2 parents 120e72b + f331fb9 commit d028e2e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
builder.baseUrl(DevoxxGenieStateService.getInstance().getCustomOpenAIUrl());
}

if (Strings.isNotBlank(DevoxxGenieStateService.getInstance().getCustomOpenAIModel())) {
builder.modelName(DevoxxGenieStateService.getInstance().getCustomOpenAIModel());
}

return builder.build();
}

Expand All @@ -51,6 +55,11 @@ public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel ch
if (Strings.isNotBlank(DevoxxGenieStateService.getInstance().getCustomOpenAIUrl())) {
builder.baseUrl(DevoxxGenieStateService.getInstance().getCustomOpenAIUrl());
}

if (Strings.isNotBlank(DevoxxGenieStateService.getInstance().getCustomOpenAIModel())) {
builder.modelName(DevoxxGenieStateService.getInstance().getCustomOpenAIModel());
}

return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ private void addOpenAiModels() {
.contextWindow(128_000)
.apiKeyUsed(true)
.build());

String custom = "custom";
models.put(ModelProvider.OpenAI.getName() + ":" + gpt4oMini,
LanguageModel.builder()
.provider(ModelProvider.OpenAI)
.modelName(custom)
.displayName("Custom Model")
.inputCost(0.15)
.outputCost(0.6)
.contextWindow(128_000)
.apiKeyUsed(true)
.build());
}

private void addDeepInfraModels() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public static DevoxxGenieStateService getInstance() {
private String llamaCPPUrl = LLAMA_CPP_MODEL_URL;
private String jlamaUrl = JLAMA_MODEL_URL;
private String customOpenAIUrl = "";
private String customOpenAIModel = "";

// Local LLM Providers
private boolean isOllamaEnabled = true;
Expand All @@ -87,6 +88,7 @@ public static DevoxxGenieStateService getInstance() {
private boolean isLlamaCPPEnabled = true;
private boolean isJlamaEnabled = true;
private boolean isCustomOpenAIEnabled = false;
private boolean isCustomOpenAIModelEnabled = false;

// Remote LLM Providers
private boolean isOpenAIEnabled = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public class LLMProvidersComponent extends AbstractSettingsComponent {
@Getter
private final JTextField customOpenAIUrlField = new JTextField(stateService.getCustomOpenAIUrl());
@Getter
private final JTextField customOpenAIModelField = new JTextField(stateService.getCustomOpenAIModel());
@Getter
private final JPasswordField openAIKeyField = new JPasswordField(stateService.getOpenAIKey());
@Getter
private final JTextField azureOpenAIEndpointField = new JTextField(stateService.getAzureOpenAIEndpoint());
Expand Down Expand Up @@ -72,6 +74,8 @@ public class LLMProvidersComponent extends AbstractSettingsComponent {
private final JCheckBox jlamaEnabledCheckBox = new JCheckBox("", stateService.isJlamaEnabled());
@Getter
private final JCheckBox customOpenAIEnabledCheckBox = new JCheckBox("", stateService.isCustomOpenAIEnabled());
@Getter
private final JCheckBox customOpenAIModelEnabledCheckBox = new JCheckBox("", stateService.isCustomOpenAIModelEnabled());

@Getter
private final JCheckBox openAIEnabledCheckBox = new JCheckBox("", stateService.isOpenAIEnabled());
Expand Down Expand Up @@ -129,6 +133,7 @@ public JPanel createPanel() {
addProviderSettingRow(panel, gbc, "JLama URL", jlamaEnabledCheckBox,
createTextWithLinkButton(jlamaModelUrlField, "https://github.com/tjake/Jlama"));
addProviderSettingRow(panel, gbc, "Custom OpenAI URL", customOpenAIEnabledCheckBox, customOpenAIUrlField);
addProviderSettingRow(panel, gbc, "Custom OpenAI Model", customOpenAIModelEnabledCheckBox, customOpenAIModelField);

// Cloud LLM Providers section
addSection(panel, gbc, "Cloud LLM Providers");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public boolean isModified() {
isModified |= isFieldModified(llmSettingsComponent.getJanModelUrlField(), stateService.getJanModelUrl());
isModified |= isFieldModified(llmSettingsComponent.getExoModelUrlField(), stateService.getExoModelUrl());
isModified |= isFieldModified(llmSettingsComponent.getCustomOpenAIUrlField(), stateService.getCustomOpenAIUrl());
isModified |= isFieldModified(llmSettingsComponent.getCustomOpenAIModelField(), stateService.getCustomOpenAIModel());

isModified |= !stateService.getShowAzureOpenAIFields().equals(llmSettingsComponent.getEnableAzureOpenAICheckBox().isSelected());
isModified |= isFieldModified(llmSettingsComponent.getAzureOpenAIEndpointField(), stateService.getAzureOpenAIEndpoint());
Expand Down Expand Up @@ -118,6 +119,7 @@ public void apply() {
settings.setLlamaCPPUrl(llmSettingsComponent.getLlamaCPPModelUrlField().getText());
settings.setJlamaUrl(llmSettingsComponent.getJlamaModelUrlField().getText());
settings.setCustomOpenAIUrl(llmSettingsComponent.getCustomOpenAIUrlField().getText());
settings.setCustomOpenAIModel(llmSettingsComponent.getCustomOpenAIModelField().getText());

settings.setOpenAIKey(new String(llmSettingsComponent.getOpenAIKeyField().getPassword()));
settings.setMistralKey(new String(llmSettingsComponent.getMistralApiKeyField().getPassword()));
Expand Down Expand Up @@ -187,6 +189,7 @@ public void reset() {
llmSettingsComponent.getLlamaCPPModelUrlField().setText(settings.getLlamaCPPUrl());
llmSettingsComponent.getJlamaModelUrlField().setText(settings.getJlamaUrl());
llmSettingsComponent.getCustomOpenAIUrlField().setText(settings.getCustomOpenAIUrl());
llmSettingsComponent.getCustomOpenAIModelField().setText(settings.getCustomOpenAIModel());

llmSettingsComponent.getOpenAIKeyField().setText(settings.getOpenAIKey());
llmSettingsComponent.getMistralApiKeyField().setText(settings.getMistralKey());
Expand Down

0 comments on commit d028e2e

Please sign in to comment.