forked from JabRef/jabref
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
137 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsForSeveralTask.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package org.jabref.logic.ai.ingestion; | ||
|
||
import java.util.List; | ||
import java.util.concurrent.Future; | ||
|
||
import javafx.beans.property.StringProperty; | ||
|
||
import org.jabref.gui.util.BackgroundTask; | ||
import org.jabref.gui.util.TaskExecutor; | ||
import org.jabref.logic.ai.processingstatus.ProcessingInfo; | ||
import org.jabref.logic.ai.processingstatus.ProcessingState; | ||
import org.jabref.logic.l10n.Localization; | ||
import org.jabref.logic.util.ProgressCounter; | ||
import org.jabref.model.database.BibDatabaseContext; | ||
import org.jabref.model.entry.LinkedFile; | ||
import org.jabref.preferences.FilePreferences; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import scala.Int; | ||
|
||
/** | ||
* This task generates embeddings for several {@link LinkedFile} (typically used for groups). | ||
* It will check if embeddings were already generated. | ||
* And it also will store the embeddings. | ||
*/ | ||
public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> { | ||
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class); | ||
|
||
private final StringProperty name; | ||
private final List<ProcessingInfo<LinkedFile, Void>> linkedFiles; | ||
private final FileEmbeddingsManager fileEmbeddingsManager; | ||
private final BibDatabaseContext bibDatabaseContext; | ||
private final FilePreferences filePreferences; | ||
private final TaskExecutor taskExecutor; | ||
|
||
private final ProgressCounter progressCounter = new ProgressCounter(); | ||
|
||
public GenerateEmbeddingsForSeveralTask( | ||
StringProperty name, | ||
List<ProcessingInfo<LinkedFile, Void>> linkedFiles, | ||
FileEmbeddingsManager fileEmbeddingsManager, | ||
BibDatabaseContext bibDatabaseContext, | ||
FilePreferences filePreferences, | ||
TaskExecutor taskExecutor | ||
) { | ||
this.name = name; | ||
this.linkedFiles = linkedFiles; | ||
this.fileEmbeddingsManager = fileEmbeddingsManager; | ||
this.bibDatabaseContext = bibDatabaseContext; | ||
this.filePreferences = filePreferences; | ||
this.taskExecutor = taskExecutor; | ||
|
||
configure(name); | ||
} | ||
|
||
private void configure(StringProperty name) { | ||
showToUser(true); | ||
titleProperty().set(Localization.lang("Generating embeddings for %0", name.get())); | ||
name.addListener((o, oldValue, newValue) -> titleProperty().set(Localization.lang("Generating embeddings for %0", newValue))); | ||
|
||
progressCounter.increaseWorkMax(linkedFiles.size()); | ||
progressCounter.listenToAllProperties(this::updateProgress); | ||
updateProgress(); | ||
} | ||
|
||
@Override | ||
protected Void call() throws Exception { | ||
LOGGER.debug("Starting embeddings generation of several files for {}", name.get()); | ||
|
||
List<? extends Future<?>> futures = linkedFiles | ||
.stream() | ||
.map(processingInfo -> { | ||
processingInfo.setState(ProcessingState.PROCESSING); | ||
return new GenerateEmbeddingsTask(processingInfo.getObject(), fileEmbeddingsManager, bibDatabaseContext, filePreferences) | ||
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS)) | ||
.onFailure(processingInfo::setException) | ||
.onFinished(() -> progressCounter.increaseWorkDone(1)) | ||
.executeWith(taskExecutor); | ||
}) | ||
.toList(); | ||
|
||
|
||
for (Future<?> future : futures) { | ||
future.get(); | ||
} | ||
|
||
LOGGER.debug("Finished embeddings generation task of several files for {}", name.get()); | ||
progressCounter.stop(); | ||
return null; | ||
} | ||
|
||
private void updateProgress() { | ||
updateProgress(progressCounter.getWorkDone(), progressCounter.getWorkMax()); | ||
updateMessage(progressCounter.getMessage()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters