Skip to content

Commit

Permalink
Fix for #166
Browse files Browse the repository at this point in the history
  • Loading branch information
InAnYan committed Aug 31, 2024
1 parent ccb7369 commit 9d87a7b
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public AiChatComponent(AiService aiService,

this.aiChatLogic = aiService.getAiChatService().makeChat(name, chatHistory, entries);

aiService.getIngestionService().ingest(name, ListUtil.getLinkedFiles(entries).toList(), bibDatabaseContext);

ViewLoader.view(this)
.root(this)
.load();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.jabref.logic.ai.ingestion;

import java.util.List;
import java.util.concurrent.Future;

import javafx.beans.property.StringProperty;

import org.jabref.gui.util.BackgroundTask;
import org.jabref.gui.util.TaskExecutor;
import org.jabref.logic.ai.processingstatus.ProcessingInfo;
import org.jabref.logic.ai.processingstatus.ProcessingState;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.ProgressCounter;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.LinkedFile;
import org.jabref.preferences.FilePreferences;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Int;

/**
* This task generates embeddings for several {@link LinkedFile} (typically used for groups).
* It will check if embeddings were already generated.
* And it also will store the embeddings.
*/
public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class);

private final StringProperty name;
private final List<ProcessingInfo<LinkedFile, Void>> linkedFiles;
private final FileEmbeddingsManager fileEmbeddingsManager;
private final BibDatabaseContext bibDatabaseContext;
private final FilePreferences filePreferences;
private final TaskExecutor taskExecutor;

private final ProgressCounter progressCounter = new ProgressCounter();

public GenerateEmbeddingsForSeveralTask(
StringProperty name,
List<ProcessingInfo<LinkedFile, Void>> linkedFiles,
FileEmbeddingsManager fileEmbeddingsManager,
BibDatabaseContext bibDatabaseContext,
FilePreferences filePreferences,
TaskExecutor taskExecutor
) {
this.name = name;
this.linkedFiles = linkedFiles;
this.fileEmbeddingsManager = fileEmbeddingsManager;
this.bibDatabaseContext = bibDatabaseContext;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;

configure(name);
}

private void configure(StringProperty name) {
showToUser(true);
titleProperty().set(Localization.lang("Generating embeddings for %0", name.get()));
name.addListener((o, oldValue, newValue) -> titleProperty().set(Localization.lang("Generating embeddings for %0", newValue)));

progressCounter.increaseWorkMax(linkedFiles.size());
progressCounter.listenToAllProperties(this::updateProgress);
updateProgress();
}

@Override
protected Void call() throws Exception {
LOGGER.debug("Starting embeddings generation of several files for {}", name.get());

List<? extends Future<?>> futures = linkedFiles
.stream()
.map(processingInfo -> {
processingInfo.setState(ProcessingState.PROCESSING);
return new GenerateEmbeddingsTask(processingInfo.getObject(), fileEmbeddingsManager, bibDatabaseContext, filePreferences)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.onFinished(() -> progressCounter.increaseWorkDone(1))
.executeWith(taskExecutor);
})
.toList();


for (Future<?> future : futures) {
future.get();
}

LOGGER.debug("Finished embeddings generation task of several files for {}", name.get());
progressCounter.stop();
return null;
}

private void updateProgress() {
updateProgress(progressCounter.getWorkDone(), progressCounter.getWorkMax());
updateMessage(progressCounter.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/**
* This task generates embeddings for a {@link LinkedFile}.
* It will check if embeddings were already generated.
* And it also will store the summary.
* And it also will store the embeddings.
*/
public class GenerateEmbeddingsTask extends BackgroundTask<Void> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class);
Expand All @@ -47,7 +47,6 @@ public GenerateEmbeddingsTask(LinkedFile linkedFile,

private void configure(LinkedFile linkedFile) {
titleProperty().set(Localization.lang("Generating embeddings for file '%0'", linkedFile.getLink()));
showToUser(true);

progressCounter.listenToAllProperties(this::updateProgress);
}
Expand All @@ -62,7 +61,6 @@ protected Void call() throws Exception {
LOGGER.debug("There is a embeddings generation task for file \"{}\". It will be cancelled, because user quits JabRef.", linkedFile.getLink());
}

showToUser(false);
LOGGER.debug("Finished embeddings generation task for file \"{}\"", linkedFile.getLink());
progressCounter.stop();
return null;
Expand Down
42 changes: 37 additions & 5 deletions src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package org.jabref.logic.ai.ingestion;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javafx.beans.Observable;
import javafx.beans.property.ReadOnlyBooleanProperty;
import javafx.beans.property.StringProperty;
import javafx.collections.ObservableList;

import org.jabref.gui.util.TaskExecutor;
import org.jabref.logic.ai.processingstatus.ProcessingInfo;
Expand All @@ -25,6 +29,8 @@
public class IngestionService {
private final Map<LinkedFile, ProcessingInfo<LinkedFile, Void>> ingestionStatusMap = new HashMap<>();

private final List<List<LinkedFile>> listsUnderIngestion = new ArrayList<>();

private final FilePreferences filePreferences;
private final TaskExecutor taskExecutor;

Expand Down Expand Up @@ -57,11 +63,7 @@ public IngestionService(AiPreferences aiPreferences,
* on the same {@link LinkedFile}, the method will return the same {@link ProcessingInfo}.
*/
public ProcessingInfo<LinkedFile, Void> ingest(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext) {
ProcessingInfo<LinkedFile, Void> processingInfo = ingestionStatusMap.computeIfAbsent(linkedFile, file -> {
ProcessingInfo<LinkedFile, Void> newProcessingInfo = new ProcessingInfo<>(linkedFile, ProcessingState.PROCESSING);
startEmbeddingsGenerationTask(linkedFile, bibDatabaseContext, newProcessingInfo);
return newProcessingInfo;
});
ProcessingInfo<LinkedFile, Void> processingInfo = getProcessingInfo(linkedFile);

if (processingInfo.getState() == ProcessingState.STOPPED) {
startEmbeddingsGenerationTask(linkedFile, bibDatabaseContext, processingInfo);
Expand All @@ -70,13 +72,43 @@ public ProcessingInfo<LinkedFile, Void> ingest(LinkedFile linkedFile, BibDatabas
return processingInfo;
}

/**
* Get {@link ProcessingInfo} of a {@link LinkedFile}. Initially, it is in state {@link ProcessingState#STOPPED}.
* This method will not start ingesting. If you need to start it, use {@link IngestionService#ingest(LinkedFile, BibDatabaseContext)}.
*/
public ProcessingInfo<LinkedFile, Void> getProcessingInfo(LinkedFile linkedFile) {
return ingestionStatusMap.computeIfAbsent(linkedFile, file -> new ProcessingInfo<>(linkedFile, ProcessingState.STOPPED));
}

public List<ProcessingInfo<LinkedFile, Void>> getProcessingInfo(List<LinkedFile> linkedFiles) {
return linkedFiles.stream().map(this::getProcessingInfo).toList();
}

public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty name, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
List<ProcessingInfo<LinkedFile, Void>> result = getProcessingInfo(linkedFiles);

if (listsUnderIngestion.contains(linkedFiles)) {
return result;
}

List<ProcessingInfo<LinkedFile, Void>> needToProcess = result.stream().filter(processingInfo -> processingInfo.getState() == ProcessingState.STOPPED).toList();
startEmbeddingsGenerationTask(name, needToProcess, bibDatabaseContext);

return result;
}

private void startEmbeddingsGenerationTask(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext, ProcessingInfo<LinkedFile, Void> processingInfo) {
new GenerateEmbeddingsTask(linkedFile, fileEmbeddingsManager, bibDatabaseContext, filePreferences)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.executeWith(taskExecutor);
}

private void startEmbeddingsGenerationTask(StringProperty name, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
new GenerateEmbeddingsForSeveralTask(name, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor)
.executeWith(taskExecutor);
}

public void clearEmbeddingsFor(List<LinkedFile> linkedFiles) {
fileEmbeddingsManager.clearEmbeddingsFor(linkedFiles);
ingestionStatusMap.values().forEach(processingInfo -> processingInfo.setState(ProcessingState.STOPPED));
Expand Down
4 changes: 0 additions & 4 deletions src/main/java/org/jabref/logic/util/ProgressCounter.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@ private void update() {
Duration eta = oneWorkTime.multipliedBy(workMax.get() - workDone.get() <= 0 ? 1 : workMax.get() - workDone.get());

updateMessage(eta);

if (workDone.get() != 0 && workMax.get() != 0 && workDone.get() == workMax.get()) {
stop();
}
}

@Override
Expand Down

0 comments on commit 9d87a7b

Please sign in to comment.