WIP: Bookmark embeddings #834

Open · wants to merge 5 commits into main
22 changes: 22 additions & 0 deletions apps/web/components/admin/AdminActions.tsx
@@ -69,6 +69,21 @@ export default function AdminActions() {
},
});

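// tRPC mutation that asks the backend to re-embed all bookmarks; on success a
// toast confirms the request was enqueued.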
const { mutateAsync: reEmbedBookmarks, isPending: isReEmbedPending } =
api.admin.reEmbedAllBookmarks.useMutation({
onSuccess: () => {
toast({
description: "ReEmbed request has been enqueued!",
});
},
onError: (e) => {
toast({
variant: "destructive",
description: e.message,
});
},
});

return (
<div>
<div className="mb-2 text-xl font-medium">{t("common.actions")}</div>
@@ -124,6 +139,13 @@ export default function AdminActions() {
>
{t("admin.actions.reindex_all_bookmarks")}
</ActionButton>
<ActionButton
variant="destructive"
loading={isReEmbedPending}
onClick={() => reEmbedBookmarks()}
>
{t("admin.actions.reembed_all_bookmarks")}
</ActionButton>
<ActionButton
variant="destructive"
loading={isTidyAssetsPending}
1 change: 1 addition & 0 deletions apps/web/lib/i18n/locales/en/translation.json
@@ -144,6 +144,7 @@
"regenerate_ai_tags_for_failed_bookmarks_only": "Regenerate AI Tags for Failed Bookmarks Only",
"regenerate_ai_tags_for_all_bookmarks": "Regenerate AI Tags for All Bookmarks",
"reindex_all_bookmarks": "Reindex All Bookmarks",
"reembed_all_bookmarks": "Re-embed All Bookmarks",
"compact_assets": "Compact Assets"
},
"users_list": {
182 changes: 182 additions & 0 deletions apps/workers/embeddingsWorker.ts
@@ -0,0 +1,182 @@
import { RecursiveCharacterTextSplitter } from "@langchain/textsplitters";
import { eq } from "drizzle-orm";
import { DequeuedJob, Runner } from "liteque";

import type { EmbeddingsRequest } from "@hoarder/shared/queues";
import { db } from "@hoarder/db";
import { bookmarks } from "@hoarder/db/schema";
import serverConfig from "@hoarder/shared/config";
import { InferenceClientFactory } from "@hoarder/shared/inference";
import logger from "@hoarder/shared/logger";
import {
EmbeddingsQueue,
zEmbeddingsRequestSchema,
} from "@hoarder/shared/queues";
import { getBookmarkVectorDb } from "@hoarder/shared/vectorDb";

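// A single piece of text to embed, tagged with where it came from:
// "description" for short summaries, "content_full" for an entire document,
// and "content_chunk" for one split-out piece of long content.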
interface EmbeddingChunk {
embeddingType: "description" | "content_full" | "content_chunk";
fromOffset: number;
toOffset: number;
text: string;
}

export class EmbeddingsWorker {
static build() {
logger.info("Starting embeddings worker ...");
const worker = new Runner<EmbeddingsRequest>(
EmbeddingsQueue,
{
run: runEmbeddings,
onComplete: async (job) => {
const jobId = job.id;
logger.info(`[embeddings][${jobId}] Completed successfully`);
return Promise.resolve();
},
onError: async (job) => {
const jobId = job.id;
logger.error(
`[embeddings][${jobId}] embeddings job failed: ${job.error}\n${job.error.stack}`,
);
return Promise.resolve();
},
},
{
concurrency: 1,
pollIntervalMs: 1000,
timeoutSecs: serverConfig.inference.jobTimeoutSec,
validator: zEmbeddingsRequestSchema,
},
);

return worker;
}
}

async function fetchBookmark(linkId: string) {
return await db.query.bookmarks.findFirst({
where: eq(bookmarks.id, linkId),
with: {
link: true,
text: true,
asset: true,
},
});
}

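// Split long content into small, non-overlapping chunks (100 characters each)
// so that each chunk gets its own embedding. Offsets here are relative to the
// chunk itself, not to the original document.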
async function chunkText(text: string): Promise<EmbeddingChunk[]> {
const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: 100,
chunkOverlap: 0,
});
const texts = await textSplitter.splitText(text);
return texts.map((t) => ({
embeddingType: "content_chunk",
text: t,
fromOffset: 0,
toOffset: t.length,
}));
}

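// Build the list of embedding requests for a bookmark: the link description,
// the full link/asset content, the raw text of text bookmarks, and chunked
// versions of any long content.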
async function prepareEmbeddings(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
) {
const reqs: EmbeddingChunk[] = [];

if (bookmark.link) {
if (bookmark.link.description) {
reqs.push({
embeddingType: "description",
fromOffset: 0,
toOffset: bookmark.link.description?.length ?? 0,
text: bookmark.link.description ?? "",
});
}
if (bookmark.link.content) {
reqs.push({
embeddingType: "content_full",
fromOffset: 0,
toOffset: bookmark.link.content?.length ?? 0,
text: bookmark.link.content ?? "",
});
reqs.push(...(await chunkText(bookmark.link.content ?? "")));
}
}

if (bookmark.text) {
if (bookmark.text.text) {
reqs.push({
embeddingType: "description",
fromOffset: 0,
toOffset: bookmark.text.text?.length ?? 0,
text: bookmark.text.text ?? "",
});
reqs.push(...(await chunkText(bookmark.text.text)));
}
}

if (bookmark.asset) {
if (bookmark.asset.content) {
reqs.push({
embeddingType: "content_full",
fromOffset: 0,
toOffset: bookmark.asset.content?.length ?? 0,
text: bookmark.asset.content ?? "",
});
reqs.push(...(await chunkText(bookmark.asset.content)));
}
}
return reqs;
}

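// Job handler: fetch the bookmark, build its embedding requests, generate the
// embeddings via the configured inference client, and replace the bookmark's
// vectors in the LanceDB table.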
async function runEmbeddings(job: DequeuedJob<EmbeddingsRequest>) {
const jobId = job.id;

const inferenceClient = InferenceClientFactory.build();
if (!inferenceClient) {
logger.debug(
`[embeddings][${jobId}] No inference client configured, nothing to do now`,
);
return;
}

const { bookmarkId } = job.data;
const bookmark = await fetchBookmark(bookmarkId);
if (!bookmark) {
throw new Error(
`[embeddings][${jobId}] bookmark with id ${bookmarkId} was not found`,
);
}

logger.info(
`[embeddings][${jobId}] Starting an embeddings job for bookmark with id "${bookmark.id}"`,
);

const reqs = await prepareEmbeddings(bookmark);

logger.info(`[embeddings][${jobId}] Got ${reqs.length} embeddings requests`);
if (reqs.length === 0) {
logger.info(`[embeddings][${jobId}] No embeddings requests to process`);
return;
}

const embeddings = await inferenceClient.generateEmbeddingFromText(
reqs.map((r) => r.text),
);

const resps = reqs.map((req, i) => ({
...req,
embedding: embeddings.embeddings[i],
}));

const vectorDb = await getBookmarkVectorDb();
// Delete the old vectors
await vectorDb.delete(`bookmarkid = "${bookmark.id}"`);
// Add the new vectors
await vectorDb.add(
resps.map((r) => ({
vector: r.embedding,
bookmarkid: bookmarkId,
})),
);
}
34 changes: 23 additions & 11 deletions apps/workers/index.ts
@@ -1,6 +1,7 @@
import "dotenv/config";

import { AssetPreprocessingWorker } from "assetPreprocessingWorker";
import { EmbeddingsWorker } from "embeddingsWorker";
import { FeedRefreshingWorker, FeedWorker } from "feedWorker";
import { TidyAssetsWorker } from "tidyAssetsWorker";

@@ -18,16 +19,25 @@ async function main() {
logger.info(`Workers version: ${serverConfig.serverVersion ?? "not set"}`);
runQueueDBMigrations();

const [crawler, openai, search, tidyAssets, video, feed, assetPreprocessing] =
[
await CrawlerWorker.build(),
OpenAiWorker.build(),
SearchIndexingWorker.build(),
TidyAssetsWorker.build(),
VideoWorker.build(),
FeedWorker.build(),
AssetPreprocessingWorker.build(),
];
const [
crawler,
openai,
search,
tidyAssets,
video,
feed,
assetPreprocessing,
embeddingsWorker,
] = [
await CrawlerWorker.build(),
OpenAiWorker.build(),
SearchIndexingWorker.build(),
TidyAssetsWorker.build(),
VideoWorker.build(),
FeedWorker.build(),
AssetPreprocessingWorker.build(),
EmbeddingsWorker.build(),
];
FeedRefreshingWorker.start();

await Promise.any([
@@ -39,11 +49,12 @@
video.run(),
feed.run(),
assetPreprocessing.run(),
embeddingsWorker.run(),
]),
shutdownPromise,
]);
logger.info(
"Shutting down crawler, openai, tidyAssets, video, feed, assetPreprocessing and search workers ...",
"Shutting down crawler, openai, tidyAssets, video, feed, assetPreprocessing, embeddingsWorker and search workers ...",
);

FeedRefreshingWorker.stop();
@@ -54,6 +65,7 @@
video.stop();
feed.stop();
assetPreprocessing.stop();
embeddingsWorker.stop();
}

main();
2 changes: 2 additions & 0 deletions apps/workers/package.json
@@ -9,6 +9,8 @@
"@hoarder/shared": "workspace:^0.1.0",
"@hoarder/trpc": "workspace:^0.1.0",
"@hoarder/tsconfig": "workspace:^0.1.0",
"@langchain/core": "^0.3.26",
"@langchain/textsplitters": "^0.1.0",
"@mozilla/readability": "^0.5.0",
"@tsconfig/node21": "^21.0.1",
"async-mutex": "^0.4.1",
2 changes: 2 additions & 0 deletions packages/shared/package.json
@@ -5,6 +5,8 @@
"private": true,
"type": "module",
"dependencies": {
"@lancedb/lancedb": "^0.14.1",
"apache-arrow": "^18.1.0",
"glob": "^11.0.0",
"liteque": "^0.3.0",
"meilisearch": "^0.37.0",
16 changes: 16 additions & 0 deletions packages/shared/queues.ts
@@ -158,3 +158,19 @@ export const AssetPreprocessingQueue =
keepFailedJobs: false,
},
);

// Embeddings Queue
export const zEmbeddingsRequestSchema = z.object({
bookmarkId: z.string(),
});
export type EmbeddingsRequest = z.infer<typeof zEmbeddingsRequestSchema>;
export const EmbeddingsQueue = new SqliteQueue<EmbeddingsRequest>(
"embeddings_queue",
queueDB,
{
defaultJobArgs: {
numRetries: 3,
},
keepFailedJobs: false,
},
);
25 changes: 25 additions & 0 deletions packages/shared/vectorDb.ts
@@ -0,0 +1,25 @@
import path from "path";
import * as lancedb from "@lancedb/lancedb";
import { Field, FixedSizeList, Float32, Schema, Utf8 } from "apache-arrow";

import serverConfig from "./config";

export async function getBookmarkVectorDb() {
const dbPath = path.join(serverConfig.dataDir, "vectordb");
const db = await lancedb.connect(dbPath);
const table = await db.createEmptyTable(
"bookmarks",
new Schema([
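// 1536-dimensional float vectors; this matches the output size of OpenAI's
// default embedding models (e.g. text-embedding-3-small), so other providers
// may need a different dimension here.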
new Field(
"vector",
new FixedSizeList(1536, new Field("item", new Float32(), true)),
),
new Field("bookmarkid", new Utf8()),
]),
{
mode: "create",
existOk: true,
},
);
return table;
}
12 changes: 12 additions & 0 deletions packages/trpc/routers/admin.ts
@@ -5,6 +5,7 @@ import { z } from "zod";
import { assets, bookmarkLinks, bookmarks, users } from "@hoarder/db/schema";
import serverConfig from "@hoarder/shared/config";
import {
EmbeddingsQueue,
LinkCrawlerQueue,
OpenAIQueue,
SearchIndexingQueue,
@@ -154,6 +155,17 @@ export const adminAppRouter = router({

await Promise.all(bookmarkIds.map((b) => triggerSearchReindex(b.id)));
}),
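// Enqueues one embeddings job per bookmark; the embeddings worker picks these
// up from EmbeddingsQueue and writes the resulting vectors to the vector DB.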
reEmbedAllBookmarks: adminProcedure.mutation(async ({ ctx }) => {
const bookmarkIds = await ctx.db.query.bookmarks.findMany({
columns: {
id: true,
},
});

await Promise.all(
bookmarkIds.map((b) => EmbeddingsQueue.enqueue({ bookmarkId: b.id })),
);
}),
reRunInferenceOnAllBookmarks: adminProcedure
.input(
z.object({