Skip to content

Commit

Permalink
Improve default AI preferences
Browse files Browse the repository at this point in the history
  • Loading branch information
InAnYan committed Oct 12, 2024
1 parent b646c7d commit 51cd6c5
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.jabref.gui.DialogService;
import org.jabref.gui.desktop.os.NativeDesktop;
import org.jabref.gui.frame.ExternalApplicationsPreferences;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.model.ai.AiProvider;

Expand Down Expand Up @@ -72,12 +71,11 @@ private void initPrivacyHyperlink(TextFlow textFlow, AiProvider aiProvider) {
text.setText(replacedText);
text.wrappingWidthProperty().bind(this.widthProperty());

String link = AiDefaultPreferences.PROVIDERS_PRIVACY_POLICIES.get(aiProvider);
Hyperlink hyperlink = new Hyperlink(link);
Hyperlink hyperlink = new Hyperlink(aiProvider.getApiUrl());
hyperlink.setWrapText(true);
hyperlink.setFont(text.getFont());
hyperlink.setOnAction(event -> {
openBrowser(link);
openBrowser(aiProvider.getApiUrl());
});

textFlow.getChildren().add(hyperlink);
Expand Down
17 changes: 4 additions & 13 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import javafx.beans.property.BooleanProperty;
Expand Down Expand Up @@ -119,7 +118,7 @@ public AiTabViewModel(CliPreferences preferences) {
);

this.selectedAiProvider.addListener((observable, oldValue, newValue) -> {
List<String> models = AiDefaultPreferences.AVAILABLE_CHAT_MODELS.get(newValue);
List<String> models = AiDefaultPreferences.getAvailableModels(newValue);

// When we setAll on Hugging Face, models are empty, and currentChatModel become null.
// It becomes null beause currentChatModel is binded to combobox, and this combobox becomes empty.
Expand Down Expand Up @@ -186,14 +185,7 @@ public AiTabViewModel(CliPreferences preferences) {
case HUGGING_FACE -> huggingFaceChatModel.set(newValue);
}

Map<String, Integer> modelContextWindows = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.get(selectedAiProvider.get());

if (modelContextWindows == null) {
contextWindowSize.set(AiDefaultPreferences.CONTEXT_WINDOW_SIZE);
return;
}

contextWindowSize.set(modelContextWindows.getOrDefault(newValue, AiDefaultPreferences.CONTEXT_WINDOW_SIZE));
contextWindowSize.set(AiDefaultPreferences.getContextWindowSize(selectedAiProvider.get(), newValue));
});

this.currentApiKey.addListener((observable, oldValue, newValue) -> {
Expand Down Expand Up @@ -356,13 +348,12 @@ public void storeSettings() {
}

public void resetExpertSettings() {
String resetApiBaseUrl = AiDefaultPreferences.PROVIDERS_API_URLS.get(selectedAiProvider.get());
String resetApiBaseUrl = selectedAiProvider.get().getApiUrl();
currentApiBaseUrl.set(resetApiBaseUrl);

instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);

int resetContextWindowSize = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.getOrDefault(selectedAiProvider.get(), Map.of()).getOrDefault(currentChatModel.get(), 0);
contextWindowSize.set(resetContextWindowSize);
contextWindowSize.set(AiDefaultPreferences.getContextWindowSize(selectedAiProvider.get(), currentChatModel.get()));

temperature.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.TEMPERATURE));
documentSplitterChunkSize.set(AiDefaultPreferences.DOCUMENT_SPLITTER_CHUNK_SIZE);
Expand Down
108 changes: 60 additions & 48 deletions src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
@@ -1,67 +1,68 @@
package org.jabref.logic.ai;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;

public class AiDefaultPreferences {
public static final Map<AiProvider, List<String>> AVAILABLE_CHAT_MODELS = Map.of(
AiProvider.OPEN_AI, List.of("gpt-4o-mini", "gpt-4o", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"),
// "mistral" and "mixtral" are not language mistakes.
AiProvider.MISTRAL_AI, List.of("open-mistral-nemo", "open-mistral-7b", "open-mixtral-8x7b", "open-mixtral-8x22b", "mistral-large-latest"),
AiProvider.GEMINI, List.of("gemini-1.5-flash", "gemini-1.5-pro", "gemini-1.0-pro"),
AiProvider.HUGGING_FACE, List.of()
);
public enum PredefinedChatModel {
GPT_4O_MINI(AiProvider.OPEN_AI, "gpt-4o-mini", 128000),
GPT_4O(AiProvider.OPEN_AI, "gpt-4o", 128000),
GPT_4(AiProvider.OPEN_AI, "gpt-4", 8192),
GPT_4_TURBO(AiProvider.OPEN_AI, "gpt-4-turbo", 128000),
GPT_3_5_TURBO(AiProvider.OPEN_AI, "gpt-3.5-turbo", 16385),
OPEN_MISTRAL_NEMO(AiProvider.MISTRAL_AI, "open-mistral-nemo", 128000),
OPEN_MISTRAL_7B(AiProvider.MISTRAL_AI, "open-mistral-7b", 32000),
// "mixtral" is not a typo.
OPEN_MIXTRAL_8X7B(AiProvider.MISTRAL_AI, "open-mixtral-8x7b", 32000),
OPEN_MIXTRAL_8X22B(AiProvider.MISTRAL_AI, "open-mixtral-8x22b", 64000),
GEMINI_1_5_FLASH(AiProvider.GEMINI, "gemini-1.5-flash", 1048576),
GEMINI_1_5_PRO(AiProvider.GEMINI, "gemini-1.5-pro", 2097152),
GEMINI_1_0_PRO(AiProvider.GEMINI, "gemini-1.0-pro", 32000),
// Dummy variant for Hugging Face models.
HUGGING_FACE(AiProvider.HUGGING_FACE, "", 0);

public static final Map<AiProvider, String> PROVIDERS_PRIVACY_POLICIES = Map.of(
AiProvider.OPEN_AI, "https://openai.com/policies/privacy-policy/",
AiProvider.MISTRAL_AI, "https://mistral.ai/terms/#privacy-policy",
AiProvider.GEMINI, "https://ai.google.dev/gemini-api/terms",
AiProvider.HUGGING_FACE, "https://huggingface.co/privacy"
);
private final AiProvider aiProvider;
private final String name;
private final int contextWindowSize;

public static final Map<AiProvider, String> PROVIDERS_API_URLS = Map.of(
AiProvider.OPEN_AI, "https://api.openai.com/v1",
AiProvider.MISTRAL_AI, "https://api.mistral.ai/v1",
AiProvider.GEMINI, "https://generativelanguage.googleapis.com/v1beta/",
AiProvider.HUGGING_FACE, "https://huggingface.co/api"
);
PredefinedChatModel(AiProvider aiProvider, String name, int contextWindowSize) {
this.aiProvider = aiProvider;
this.name = name;
this.contextWindowSize = contextWindowSize;
}

public static final Map<AiProvider, Map<String, Integer>> CONTEXT_WINDOW_SIZES = Map.of(
AiProvider.OPEN_AI, Map.of(
"gpt-4o-mini", 128000,
"gpt-4o", 128000,
"gpt-4", 8192,
"gpt-4-turbo", 128000,
"gpt-3.5-turbo", 16385
),
AiProvider.MISTRAL_AI, Map.of(
"mistral-large-latest", 128000,
"open-mistral-nemo", 128000,
"open-mistral-7b", 32000,
"open-mixtral-8x7b", 32000,
"open-mixtral-8x22b", 64000
),
AiProvider.GEMINI, Map.of(
"gemini-1.5-flash", 1048576,
"gemini-1.5-pro", 2097152,
"gemini-1.0-pro", 32000
)
);
public AiProvider getAiProvider() {
return aiProvider;
}

public String getName() {
return name;
}

public int getContextWindowSize() {
return contextWindowSize;
}

public String toString() {
return aiProvider.toString() + " " + name;
}
}

public static final boolean ENABLE_CHAT = false;
public static final boolean AUTO_GENERATE_EMBEDDINGS = false;
public static final boolean AUTO_GENERATE_SUMMARIES = false;

public static final AiProvider PROVIDER = AiProvider.OPEN_AI;

public static final Map<AiProvider, String> CHAT_MODELS = Map.of(
AiProvider.OPEN_AI, "gpt-4o-mini",
AiProvider.MISTRAL_AI, "open-mixtral-8x22b",
AiProvider.GEMINI, "gemini-1.5-flash",
AiProvider.HUGGING_FACE, ""
public static final Map<AiProvider, PredefinedChatModel> CHAT_MODELS = Map.of(
AiProvider.OPEN_AI, PredefinedChatModel.GPT_4O_MINI,
AiProvider.MISTRAL_AI, PredefinedChatModel.OPEN_MIXTRAL_8X22B,
AiProvider.GEMINI, PredefinedChatModel.GEMINI_1_5_FLASH,
AiProvider.HUGGING_FACE, PredefinedChatModel.HUGGING_FACE
);

public static final boolean CUSTOMIZE_SETTINGS = false;
Expand All @@ -74,9 +75,20 @@ public class AiDefaultPreferences {
public static final int RAG_MAX_RESULTS_COUNT = 10;
public static final double RAG_MIN_SCORE = 0.3;

public static final int CONTEXT_WINDOW_SIZE = 8196;
public static final int FALLBACK_CONTEXT_WINDOW_SIZE = 8196;

public static List<String> getAvailableModels(AiProvider aiProvider) {
return Arrays.stream(AiDefaultPreferences.PredefinedChatModel.values())
.filter(model -> model.getAiProvider() == aiProvider)
.map(AiDefaultPreferences.PredefinedChatModel::getName)
.toList();
}

public static int getContextWindowSize(AiProvider aiProvider, String model) {
return CONTEXT_WINDOW_SIZES.getOrDefault(aiProvider, Map.of()).getOrDefault(model, 0);
public static int getContextWindowSize(AiProvider aiProvider, String modelName) {
return Arrays.stream(AiDefaultPreferences.PredefinedChatModel.values())
.filter(model -> model.getAiProvider() == aiProvider && model.getName().equals(modelName))
.map(AiDefaultPreferences.PredefinedChatModel::getContextWindowSize)
.findFirst()
.orElse(AiDefaultPreferences.FALLBACK_CONTEXT_WINDOW_SIZE);
}
}
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/logic/ai/AiPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ public String getSelectedApiBaseUrl() {
geminiApiBaseUrl.get();
};
} else {
return AiDefaultPreferences.PROVIDERS_API_URLS.get(aiProvider.get());
return aiProvider.get().getApiUrl();
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import javafx.collections.ListChangeListener;
import javafx.collections.ObservableList;

import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.ingestion.FileEmbeddingsManager;
import org.jabref.logic.ai.util.ErrorMessage;
Expand Down Expand Up @@ -160,7 +159,7 @@ public AiMessage execute(UserMessage message) {
// Message will be automatically added to ChatMemory through ConversationalRetrievalChain.

LOGGER.info("Sending message to AI provider ({}) for answering in {}: {}",
AiDefaultPreferences.PROVIDERS_API_URLS.get(aiPreferences.getAiProvider()),
aiPreferences.getAiProvider().getApiUrl(),
name.get(),
message.singleText());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -628,19 +628,19 @@ protected JabRefCliPreferences() {
defaults.put(AI_AUTO_GENERATE_EMBEDDINGS, AiDefaultPreferences.AUTO_GENERATE_EMBEDDINGS);
defaults.put(AI_AUTO_GENERATE_SUMMARIES, AiDefaultPreferences.AUTO_GENERATE_SUMMARIES);
defaults.put(AI_PROVIDER, AiDefaultPreferences.PROVIDER.name());
defaults.put(AI_OPEN_AI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.OPEN_AI));
defaults.put(AI_MISTRAL_AI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.MISTRAL_AI));
defaults.put(AI_GEMINI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.GEMINI));
defaults.put(AI_HUGGING_FACE_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.HUGGING_FACE));
defaults.put(AI_OPEN_AI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.OPEN_AI).getName());
defaults.put(AI_MISTRAL_AI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.MISTRAL_AI).getName());
defaults.put(AI_GEMINI_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.GEMINI).getName());
defaults.put(AI_HUGGING_FACE_CHAT_MODEL, AiDefaultPreferences.CHAT_MODELS.get(AiProvider.HUGGING_FACE).getName());
defaults.put(AI_CUSTOMIZE_SETTINGS, AiDefaultPreferences.CUSTOMIZE_SETTINGS);
defaults.put(AI_EMBEDDING_MODEL, AiDefaultPreferences.EMBEDDING_MODEL.name());
defaults.put(AI_OPEN_AI_API_BASE_URL, AiDefaultPreferences.PROVIDERS_API_URLS.get(AiProvider.OPEN_AI));
defaults.put(AI_MISTRAL_AI_API_BASE_URL, AiDefaultPreferences.PROVIDERS_API_URLS.get(AiProvider.MISTRAL_AI));
defaults.put(AI_GEMINI_API_BASE_URL, AiDefaultPreferences.PROVIDERS_API_URLS.get(AiProvider.GEMINI));
defaults.put(AI_HUGGING_FACE_API_BASE_URL, AiDefaultPreferences.PROVIDERS_API_URLS.get(AiProvider.HUGGING_FACE));
defaults.put(AI_OPEN_AI_API_BASE_URL, AiProvider.OPEN_AI.getApiUrl());
defaults.put(AI_MISTRAL_AI_API_BASE_URL, AiProvider.MISTRAL_AI.getApiUrl());
defaults.put(AI_GEMINI_API_BASE_URL, AiProvider.GEMINI.getApiUrl());
defaults.put(AI_HUGGING_FACE_API_BASE_URL, AiProvider.HUGGING_FACE.getApiUrl());
defaults.put(AI_SYSTEM_MESSAGE, AiDefaultPreferences.SYSTEM_MESSAGE);
defaults.put(AI_TEMPERATURE, AiDefaultPreferences.TEMPERATURE);
defaults.put(AI_CONTEXT_WINDOW_SIZE, AiDefaultPreferences.CONTEXT_WINDOW_SIZES.get(AiDefaultPreferences.PROVIDER).get(AiDefaultPreferences.CHAT_MODELS.get(AiDefaultPreferences.PROVIDER)));
defaults.put(AI_CONTEXT_WINDOW_SIZE, AiDefaultPreferences.getContextWindowSize(AiDefaultPreferences.PROVIDER, AiDefaultPreferences.CHAT_MODELS.get(AiDefaultPreferences.PROVIDER).getName()));
defaults.put(AI_DOCUMENT_SPLITTER_CHUNK_SIZE, AiDefaultPreferences.DOCUMENT_SPLITTER_CHUNK_SIZE);
defaults.put(AI_DOCUMENT_SPLITTER_OVERLAP_SIZE, AiDefaultPreferences.DOCUMENT_SPLITTER_OVERLAP);
defaults.put(AI_RAG_MAX_RESULTS_COUNT, AiDefaultPreferences.RAG_MAX_RESULTS_COUNT);
Expand Down
22 changes: 17 additions & 5 deletions src/main/java/org/jabref/model/ai/AiProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,33 @@
import java.io.Serializable;

public enum AiProvider implements Serializable {
OPEN_AI("OpenAI"),
MISTRAL_AI("Mistral AI"),
GEMINI("Gemini"),
HUGGING_FACE("Hugging Face");
OPEN_AI("OpenAI", "https://openai.com/policies/privacy-policy/", "https://openai.com/policies/privacy-policy/"),
MISTRAL_AI("Mistral AI", "https://mistral.ai/terms/#privacy-policy", "https://mistral.ai/terms/#privacy-policy"),
GEMINI("Gemini", "https://huggingface.co/privacy", "https://ai.google.dev/gemini-api/terms"),
HUGGING_FACE("Hugging Face", "https://huggingface.co/api", "https://huggingface.co/privacy");

private final String label;
private final String apiUrl;
private final String privacyPolicyUrl;

AiProvider(String label) {
AiProvider(String label, String apiUrl, String privacyPolicyUrl) {
this.label = label;
this.apiUrl = apiUrl;
this.privacyPolicyUrl = privacyPolicyUrl;
}

public String getLabel() {
return label;
}

public String getApiUrl() {
return apiUrl;
}

public String getPrivacyPolicyUrl() {
return privacyPolicyUrl;
}

public String toString() {
return label;
}
Expand Down

0 comments on commit 51cd6c5

Please sign in to comment.