diff --git a/src/main/java/org/jabref/gui/JabRefGUI.java b/src/main/java/org/jabref/gui/JabRefGUI.java index b20c103d2c1..b0ef741918d 100644 --- a/src/main/java/org/jabref/gui/JabRefGUI.java +++ b/src/main/java/org/jabref/gui/JabRefGUI.java @@ -156,7 +156,7 @@ public void initialize() { JabRefGUI.clipBoardManager = new ClipBoardManager(); Injector.setModelOrService(ClipBoardManager.class, clipBoardManager); - JabRefGUI.aiService = new AiService(preferencesService.getAiPreferences(), dialogService, taskExecutor); + JabRefGUI.aiService = new AiService(preferencesService, dialogService, taskExecutor); Injector.setModelOrService(AiService.class, aiService); } diff --git a/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java b/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java index 5cbe5ad54fa..548a39a5fd3 100644 --- a/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java +++ b/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java @@ -46,6 +46,7 @@ public class AiChatTab extends EntryEditorTab { private final BibDatabaseContext bibDatabaseContext; private final TaskExecutor taskExecutor; private final CitationKeyGenerator citationKeyGenerator; + private final PreferencesService preferencesService; private final AiService aiService; private final List entriesUnderIngestion = new ArrayList<>(); @@ -64,6 +65,7 @@ public AiChatTab(LibraryTabContainer libraryTabContainer, this.bibDatabaseContext = bibDatabaseContext; this.taskExecutor = taskExecutor; this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences()); + this.preferencesService = preferencesService; setText(Localization.lang("AI chat")); setTooltip(new Tooltip(Localization.lang("Chat with AI about content of attached file(s)"))); @@ -88,7 +90,7 @@ protected void handleFocus() { protected void bindToEntry(BibEntry entry) { if (!aiService.getPreferences().getEnableAi()) { showPrivacyNotice(entry); - } else if (aiService.getPreferences().getSelectedApiKey().isEmpty()) { + } else if (aiService.getPreferences().getSelectedApiKey(preferencesService).isEmpty()) { showApiKeyMissing(); } else if (entry.getFiles().isEmpty()) { showErrorNoFiles(); diff --git a/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java b/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java index 9213a056da7..feb5635ac23 100644 --- a/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java +++ b/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java @@ -42,6 +42,7 @@ public class AiSummaryTab extends EntryEditorTab { private final TaskExecutor taskExecutor; private final CitationKeyGenerator citationKeyGenerator; private final AiService aiService; + private final PreferencesService preferencesService; private final List entriesUnderSummarization = new ArrayList<>(); @@ -59,6 +60,7 @@ public AiSummaryTab(LibraryTabContainer libraryTabContainer, this.bibDatabaseContext = bibDatabaseContext; this.taskExecutor = taskExecutor; this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences()); + this.preferencesService = preferencesService; setText(Localization.lang("AI summary")); setTooltip(new Tooltip(Localization.lang("AI-generated summary of attached file(s)"))); @@ -82,7 +84,7 @@ protected void handleFocus() { protected void bindToEntry(BibEntry entry) { if (!aiService.getPreferences().getEnableAi()) { showPrivacyNotice(entry); - } else if (aiService.getPreferences().getSelectedApiKey().isEmpty()) { + } else if (aiService.getPreferences().getSelectedApiKey(preferencesService).isEmpty()) { showApiKeyMissing(); } else if (bibDatabaseContext.getDatabasePath().isEmpty()) { showErrorNoDatabasePath(); diff --git a/src/main/java/org/jabref/logic/ai/AiService.java b/src/main/java/org/jabref/logic/ai/AiService.java index 920f8636e37..21839517c6e 100644 --- a/src/main/java/org/jabref/logic/ai/AiService.java +++ b/src/main/java/org/jabref/logic/ai/AiService.java @@ -17,6 +17,7 @@ import org.jabref.logic.ai.models.JabRefEmbeddingModel; import org.jabref.logic.ai.summarization.SummariesStorage; import org.jabref.logic.l10n.Localization; +import org.jabref.preferences.PreferencesService; import org.jabref.preferences.ai.AiPreferences; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -56,8 +57,8 @@ public class AiService implements AutoCloseable { private final SummariesStorage summariesStorage; - public AiService(AiPreferences aiPreferences, DialogService dialogService, TaskExecutor taskExecutor) { - this.aiPreferences = aiPreferences; + public AiService(PreferencesService preferencesService, DialogService dialogService, TaskExecutor taskExecutor) { + this.aiPreferences = preferencesService.getAiPreferences(); MVStore mvStore; try { @@ -74,7 +75,7 @@ public AiService(AiPreferences aiPreferences, DialogService dialogService, TaskE this.mvStore = mvStore; - this.jabRefChatLanguageModel = new JabRefChatLanguageModel(aiPreferences); + this.jabRefChatLanguageModel = new JabRefChatLanguageModel(preferencesService); this.bibDatabaseChatHistoryManager = new BibDatabaseChatHistoryManager(mvStore); this.jabRefEmbeddingModel = new JabRefEmbeddingModel(aiPreferences, dialogService, taskExecutor); this.fileEmbeddingsManager = new FileEmbeddingsManager(aiPreferences, shutdownSignal, jabRefEmbeddingModel, mvStore); diff --git a/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java b/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java index 1698929a43c..9119f96f24f 100644 --- a/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java +++ b/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java @@ -9,6 +9,7 @@ import org.jabref.logic.ai.AiChatLogic; import org.jabref.logic.l10n.Localization; +import org.jabref.preferences.PreferencesService; import org.jabref.preferences.ai.AiPreferences; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -29,6 +30,7 @@ public class JabRefChatLanguageModel implements ChatLanguageModel, AutoCloseable { private static final Duration CONNECTION_TIMEOUT = Duration.ofSeconds(5); + private final PreferencesService preferencesService; private final AiPreferences aiPreferences; private final HttpClient httpClient; @@ -38,8 +40,9 @@ public class JabRefChatLanguageModel implements ChatLanguageModel, AutoCloseable private Optional langchainChatModel = Optional.empty(); - public JabRefChatLanguageModel(AiPreferences aiPreferences) { - this.aiPreferences = aiPreferences; + public JabRefChatLanguageModel(PreferencesService preferencesService) { + this.preferencesService = preferencesService; + this.aiPreferences = preferencesService.getAiPreferences(); this.httpClient = HttpClient.newBuilder().connectTimeout(CONNECTION_TIMEOUT).executor(executorService).build(); if (aiPreferences.getEnableAi()) { @@ -56,20 +59,20 @@ public JabRefChatLanguageModel(AiPreferences aiPreferences) { * and using {@link org.jabref.logic.ai.chathistory.BibDatabaseChatHistoryManager}, where messages are stored in {@link MVStore}. */ private void rebuild() { - if (!aiPreferences.getEnableAi() || aiPreferences.getSelectedApiKey().isEmpty()) { + if (!aiPreferences.getEnableAi() || aiPreferences.getSelectedApiKey(preferencesService).isEmpty()) { langchainChatModel = Optional.empty(); return; } switch (aiPreferences.getAiProvider()) { case OPEN_AI -> { - langchainChatModel = Optional.of(new JvmOpenAiChatLanguageModel(aiPreferences, httpClient)); + langchainChatModel = Optional.of(new JvmOpenAiChatLanguageModel(preferencesService, httpClient)); } case MISTRAL_AI -> { langchainChatModel = Optional.of(MistralAiChatModel .builder() - .apiKey(aiPreferences.getSelectedApiKey()) + .apiKey(aiPreferences.getSelectedApiKey(preferencesService)) .modelName(aiPreferences.getSelectedChatModel()) .temperature(aiPreferences.getTemperature()) .baseUrl(aiPreferences.getSelectedApiBaseUrl()) @@ -83,7 +86,7 @@ private void rebuild() { // NOTE: {@link HuggingFaceChatModel} doesn't support API base url :( langchainChatModel = Optional.of(HuggingFaceChatModel .builder() - .accessToken(aiPreferences.getSelectedApiKey()) + .accessToken(aiPreferences.getSelectedApiKey(preferencesService)) .modelId(aiPreferences.getSelectedChatModel()) .temperature(aiPreferences.getTemperature()) .timeout(Duration.ofMinutes(2)) @@ -116,7 +119,7 @@ public Response generate(List list) { if (langchainChatModel.isEmpty()) { if (!aiPreferences.getEnableAi()) { throw new RuntimeException(Localization.lang("In order to use AI chat, you need to enable chatting with attached PDF files in JabRef preferences (AI tab).")); - } else if (aiPreferences.getSelectedApiKey().isEmpty()) { + } else if (aiPreferences.getSelectedApiKey(preferencesService).isEmpty()) { throw new RuntimeException(Localization.lang("In order to use AI chat, set OpenAI API key inside JabRef preferences (AI tab).")); } else { throw new RuntimeException(Localization.lang("Unable to chat with AI.")); diff --git a/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java b/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java index 7ac288d22a0..c48e76636d3 100644 --- a/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java +++ b/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java @@ -68,6 +68,10 @@ public void startRebuildingTask() { return; } + if (predictorProperty.get().isPresent()) { + predictorProperty.get().get().close(); + } + predictorProperty.set(Optional.empty()); new UpdateEmbeddingModelTask(aiPreferences, predictorProperty) diff --git a/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java b/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java index 43de70ba304..5d1bb58e92b 100644 --- a/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java +++ b/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java @@ -3,6 +3,7 @@ import java.net.http.HttpClient; import java.util.List; +import org.jabref.preferences.PreferencesService; import org.jabref.preferences.ai.AiPreferences; import dev.langchain4j.data.message.AiMessage; @@ -29,11 +30,11 @@ public class JvmOpenAiChatLanguageModel implements ChatLanguageModel { private final ChatClient chatClient; - public JvmOpenAiChatLanguageModel(AiPreferences aiPreferences, HttpClient httpClient) { - this.aiPreferences = aiPreferences; + public JvmOpenAiChatLanguageModel(PreferencesService preferencesService, HttpClient httpClient) { + this.aiPreferences = preferencesService.getAiPreferences(); OpenAI openAI = OpenAI - .newBuilder(aiPreferences.getSelectedApiKey()) + .newBuilder(aiPreferences.getSelectedApiKey(preferencesService)) .httpClient(httpClient) .baseUrl(aiPreferences.getSelectedApiBaseUrl()) .build(); diff --git a/src/main/java/org/jabref/preferences/JabRefPreferences.java b/src/main/java/org/jabref/preferences/JabRefPreferences.java index ffefb733805..efb81077339 100644 --- a/src/main/java/org/jabref/preferences/JabRefPreferences.java +++ b/src/main/java/org/jabref/preferences/JabRefPreferences.java @@ -2790,7 +2790,6 @@ public AiPreferences getAiPreferences() { boolean aiEnabled = getBoolean(AI_ENABLED); aiPreferences = new AiPreferences( - this, aiEnabled, AiProvider.valueOf(get(AI_PROVIDER)), get(AI_OPEN_AI_CHAT_MODEL), diff --git a/src/main/java/org/jabref/preferences/ai/AiPreferences.java b/src/main/java/org/jabref/preferences/ai/AiPreferences.java index c4496b35fd1..be20c24618c 100644 --- a/src/main/java/org/jabref/preferences/ai/AiPreferences.java +++ b/src/main/java/org/jabref/preferences/ai/AiPreferences.java @@ -17,8 +17,6 @@ import org.jabref.preferences.PreferencesService; public class AiPreferences { - private final PreferencesService preferencesService; - private final BooleanProperty enableAi; private final ObjectProperty aiProvider; @@ -46,8 +44,7 @@ public class AiPreferences { private final IntegerProperty ragMaxResultsCount; private final DoubleProperty ragMinScore; - public AiPreferences(PreferencesService preferencesService, - boolean enableAi, + public AiPreferences(boolean enableAi, AiProvider aiProvider, String openAiChatModel, String mistralAiChatModel, @@ -65,8 +62,6 @@ public AiPreferences(PreferencesService preferencesService, int ragMaxResultsCount, double ragMinScore ) { - this.preferencesService = preferencesService; - this.enableAi = new SimpleBooleanProperty(enableAi); this.aiProvider = new SimpleObjectProperty<>(aiProvider); @@ -467,16 +462,16 @@ public String getSelectedChatModel() { }; } - public String getSelectedApiKey() { + public String getSelectedApiKey(PreferencesService preferencesService) { if (!enableAi.get()) { return ""; } - retrieveKeys(); + retrieveKeys(preferencesService); return getKeys(); } - private void retrieveKeys() { + private void retrieveKeys(PreferencesService preferencesService) { switch (aiProvider.get()) { case OPEN_AI -> { if (openAiApiKey.get().isEmpty()) {