Skip to content

Commit

Permalink
lancedb based embeddings generation
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamedBassem committed Dec 29, 2024
1 parent c89b0c5 commit 01c74df
Show file tree
Hide file tree
Showing 14 changed files with 2,395 additions and 60 deletions.
22 changes: 22 additions & 0 deletions apps/web/components/dashboard/admin/AdminActions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ export default function AdminActions() {
},
});

const { mutateAsync: reEmbedBookmarks, isPending: isReEmbedPending } =
api.admin.reEmbedAllBookmarks.useMutation({
onSuccess: () => {
toast({
description: "ReEmbed request has been enqueued!",
});
},
onError: (e) => {
toast({
variant: "destructive",
description: e.message,
});
},
});

return (
<div>
<div className="mb-2 mt-8 text-xl font-medium">{t("common.actions")}</div>
Expand Down Expand Up @@ -124,6 +139,13 @@ export default function AdminActions() {
>
{t("admin.actions.reindex_all_bookmarks")}
</ActionButton>
<ActionButton
variant="destructive"
loading={isReEmbedPending}
onClick={() => reEmbedBookmarks()}
>
{t("admin.actions.reembed_all_bookmarks")}
</ActionButton>
<ActionButton
variant="destructive"
loading={isTidyAssetsPending}
Expand Down
1 change: 1 addition & 0 deletions apps/web/lib/i18n/locales/en/translation.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
"regenerate_ai_tags_for_failed_bookmarks_only": "Regenerate AI Tags for Failed Bookmarks Only",
"regenerate_ai_tags_for_all_bookmarks": "Regenerate AI Tags for All Bookmarks",
"reindex_all_bookmarks": "Reindex All Bookmarks",
"reembed_all_bookmarks": "Re-embed All Bookmarks",
"compact_assets": "Compact Assets"
},
"users_list": {
Expand Down
180 changes: 180 additions & 0 deletions apps/workers/embeddingsWorker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import { RecursiveCharacterTextSplitter } from "@langchain/textsplitters";
import { eq } from "drizzle-orm";
import { DequeuedJob, Runner } from "liteque";

import type { EmbeddingsRequest, ZOpenAIRequest } from "@hoarder/shared/queues";
import { db } from "@hoarder/db";
import { bookmarkEmbeddings, bookmarks } from "@hoarder/db/schema";
import serverConfig from "@hoarder/shared/config";
import { InferenceClientFactory } from "@hoarder/shared/inference";
import logger from "@hoarder/shared/logger";
import {
EmbeddingsQueue,
zEmbeddingsRequestSchema,
} from "@hoarder/shared/queues";
import { getBookmarkVectorDb } from "@hoarder/shared/vectorDb";

// A single piece of text to be embedded for a bookmark. `embeddingType` tags
// which part of the bookmark the text came from (per the bookmarkEmbeddings
// schema), and `fromOffset`/`toOffset` describe the character range the text
// covers within that source field.
type EmbeddingChunk = Pick<
  typeof bookmarkEmbeddings.$inferSelect,
  "embeddingType" | "fromOffset" | "toOffset"
> & { text: string };

export class EmbeddingsWorker {
static build() {
logger.info("Starting embeddings worker ...");
const worker = new Runner<ZOpenAIRequest>(
EmbeddingsQueue,
{
run: runEmbeddings,
onComplete: async (job) => {
const jobId = job.id;
logger.info(`[embeddings][${jobId}] Completed successfully`);
return Promise.resolve();
},
onError: async (job) => {
const jobId = job.id;
logger.error(
`[embeddings][${jobId}] embeddings job failed: ${job.error}\n${job.error.stack}`,
);
return Promise.resolve();
},
},
{
concurrency: 1,
pollIntervalMs: 1000,
timeoutSecs: serverConfig.inference.jobTimeoutSec,
validator: zEmbeddingsRequestSchema,
},
);

return worker;
}
}

async function fetchBookmark(linkId: string) {
return await db.query.bookmarks.findFirst({
where: eq(bookmarks.id, linkId),
with: {
link: true,
text: true,
asset: true,
},
});
}

/**
 * Splits `text` into small chunks suitable for embedding.
 *
 * Fix: the offsets were previously hard-coded to `0..chunk.length` for every
 * chunk, so they never reflected where the chunk actually sits in the source
 * text. We now locate each chunk with a forward-moving cursor (valid because
 * `chunkOverlap` is 0, so chunks appear in order). If a chunk cannot be
 * found verbatim (the splitter may trim whitespace), we fall back to the old
 * `0..length` offsets rather than failing.
 *
 * @param text - the full source text to split.
 * @returns one `content_chunk` request per split piece.
 */
async function chunkText(text: string): Promise<EmbeddingChunk[]> {
  const textSplitter = new RecursiveCharacterTextSplitter({
    chunkSize: 100,
    chunkOverlap: 0,
  });
  const texts = await textSplitter.splitText(text);
  let cursor = 0;
  return texts.map((t): EmbeddingChunk => {
    const start = text.indexOf(t, cursor);
    if (start >= 0) {
      cursor = start + t.length;
      return {
        embeddingType: "content_chunk",
        text: t,
        fromOffset: start,
        toOffset: start + t.length,
      };
    }
    // Chunk text was normalized by the splitter; offsets are best-effort.
    return {
      embeddingType: "content_chunk",
      text: t,
      fromOffset: 0,
      toOffset: t.length,
    };
  });
}

/**
 * Collects all embedding requests for a bookmark: a whole-text request per
 * available source field (link description, link content, text content,
 * asset content) plus per-chunk requests for the long-form content fields.
 *
 * Idiom cleanup: the original re-applied `?.`/`?? 0`/`?? ""` inside branches
 * already guarded by truthiness checks; TypeScript narrowing makes those
 * redundant, and the repeated literal construction is extracted to a helper.
 *
 * @param bookmark - a bookmark loaded by {@link fetchBookmark} (non-null).
 * @returns the list of embedding requests (possibly empty).
 */
async function prepareEmbeddings(
  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
) {
  // Builds a whole-text request; callers guarantee `text` is non-empty.
  const wholeText = (
    embeddingType: EmbeddingChunk["embeddingType"],
    text: string,
  ): EmbeddingChunk => ({
    embeddingType,
    fromOffset: 0,
    toOffset: text.length,
    text,
  });

  const reqs: EmbeddingChunk[] = [];

  if (bookmark.link) {
    if (bookmark.link.description) {
      reqs.push(wholeText("description", bookmark.link.description));
    }
    if (bookmark.link.content) {
      reqs.push(wholeText("content_full", bookmark.link.content));
      reqs.push(...(await chunkText(bookmark.link.content)));
    }
  }

  if (bookmark.text?.text) {
    // NOTE(review): text bookmarks are tagged "description" while link/asset
    // content uses "content_full" — confirm this asymmetry is intentional.
    reqs.push(wholeText("description", bookmark.text.text));
    reqs.push(...(await chunkText(bookmark.text.text)));
  }

  if (bookmark.asset?.content) {
    reqs.push(wholeText("content_full", bookmark.asset.content));
    reqs.push(...(await chunkText(bookmark.asset.content)));
  }
  return reqs;
}

/**
 * Processes one embeddings job: loads the bookmark, builds its embedding
 * requests, generates vectors via the inference client, and replaces the
 * bookmark's vectors in the vector database.
 *
 * Fixes: the local vector-db handle was named `db`, shadowing the drizzle
 * `db` imported at module scope; loose `==` replaced with `===`; and the
 * inference response length is now validated before zipping, so a mismatched
 * response can no longer silently store `undefined` vectors.
 *
 * @param job - dequeued job carrying the bookmark id to (re)embed.
 * @throws when the bookmark is missing or the inference response is malformed.
 */
async function runEmbeddings(job: DequeuedJob<EmbeddingsRequest>) {
  const jobId = job.id;

  const inferenceClient = InferenceClientFactory.build();
  if (!inferenceClient) {
    logger.debug(
      `[embeddings][${jobId}] No inference client configured, nothing to do now`,
    );
    return;
  }

  const { bookmarkId } = job.data;
  const bookmark = await fetchBookmark(bookmarkId);
  if (!bookmark) {
    throw new Error(
      `[embeddings][${jobId}] bookmark with id ${bookmarkId} was not found`,
    );
  }

  logger.info(
    `[embeddings][${jobId}] Starting an embeddings job for bookmark with id "${bookmark.id}"`,
  );

  const reqs = await prepareEmbeddings(bookmark);

  logger.info(`[embeddings][${jobId}] Got ${reqs.length} embeddings requests`);
  if (reqs.length === 0) {
    logger.info(`[embeddings][${jobId}] No embeddings requests to process`);
    return;
  }

  const embeddings = await inferenceClient.generateEmbeddingFromText(
    reqs.map((r) => r.text),
  );
  if (embeddings.embeddings.length !== reqs.length) {
    throw new Error(
      `[embeddings][${jobId}] Expected ${reqs.length} embeddings but got ${embeddings.embeddings.length}`,
    );
  }

  const resps = reqs.map((req, i) => ({
    ...req,
    embedding: embeddings.embeddings[i],
  }));

  const vectorDb = await getBookmarkVectorDb();
  // Delete the old vectors.
  // NOTE(review): the filter is built by string interpolation; if a bookmark
  // id could ever contain a double quote this filter would break — confirm
  // ids are always server-generated and quote-free.
  await vectorDb.delete(`bookmarkid = "${bookmark.id}"`);
  // Add the new vectors
  await vectorDb.add(
    resps.map((r) => ({
      vector: r.embedding,
      bookmarkid: bookmarkId,
    })),
  );
}
34 changes: 23 additions & 11 deletions apps/workers/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import "dotenv/config";

import { AssetPreprocessingWorker } from "assetPreprocessingWorker";
import { EmbeddingsWorker } from "embeddingsWorker";
import { FeedRefreshingWorker, FeedWorker } from "feedWorker";
import { TidyAssetsWorker } from "tidyAssetsWorker";

Expand All @@ -18,16 +19,25 @@ async function main() {
logger.info(`Workers version: ${serverConfig.serverVersion ?? "not set"}`);
runQueueDBMigrations();

const [crawler, openai, search, tidyAssets, video, feed, assetPreprocessing] =
[
await CrawlerWorker.build(),
OpenAiWorker.build(),
SearchIndexingWorker.build(),
TidyAssetsWorker.build(),
VideoWorker.build(),
FeedWorker.build(),
AssetPreprocessingWorker.build(),
];
const [
crawler,
openai,
search,
tidyAssets,
video,
feed,
assetPreprocessing,
embeddingsWorker,
] = [
await CrawlerWorker.build(),
OpenAiWorker.build(),
SearchIndexingWorker.build(),
TidyAssetsWorker.build(),
VideoWorker.build(),
FeedWorker.build(),
AssetPreprocessingWorker.build(),
EmbeddingsWorker.build(),
];
FeedRefreshingWorker.start();

await Promise.any([
Expand All @@ -39,11 +49,12 @@ async function main() {
video.run(),
feed.run(),
assetPreprocessing.run(),
embeddingsWorker.run(),
]),
shutdownPromise,
]);
logger.info(
"Shutting down crawler, openai, tidyAssets, video, feed, assetPreprocessing and search workers ...",
"Shutting down crawler, openai, tidyAssets, video, feed, assetPreprocessing, embeddingsWorker and search workers ...",
);

FeedRefreshingWorker.stop();
Expand All @@ -54,6 +65,7 @@ async function main() {
video.stop();
feed.stop();
assetPreprocessing.stop();
embeddingsWorker.stop();
}

main();
2 changes: 2 additions & 0 deletions apps/workers/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"@hoarder/shared": "workspace:^0.1.0",
"@hoarder/trpc": "workspace:^0.1.0",
"@hoarder/tsconfig": "workspace:^0.1.0",
"@langchain/core": "^0.3.26",
"@langchain/textsplitters": "^0.1.0",
"@mozilla/readability": "^0.5.0",
"@tsconfig/node21": "^21.0.1",
"async-mutex": "^0.4.1",
Expand Down
14 changes: 14 additions & 0 deletions packages/db/drizzle/0037_sturdy_leper_queen.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- Drizzle-generated migration: metadata table for per-bookmark embedding
-- chunks (the vectors themselves live in the LanceDB vector store).
-- `embeddingType` tags which bookmark field the text came from;
-- `fromOffset`/`toOffset` are nullable character offsets into that field.
-- Rows cascade-delete with their bookmark and owning user.
-- NOTE(review): migration files are hash-tracked by drizzle's journal —
-- confirm this file has not yet been applied before editing it further.
CREATE TABLE `bookmarkEmbeddings` (
`id` text PRIMARY KEY NOT NULL,
`bookmarkId` text NOT NULL,
`userId` text NOT NULL,
`embedding` text NOT NULL,
`embeddingType` text NOT NULL,
`fromOffset` integer,
`toOffset` integer,
FOREIGN KEY (`bookmarkId`) REFERENCES `bookmarks`(`id`) ON UPDATE no action ON DELETE cascade,
FOREIGN KEY (`userId`) REFERENCES `user`(`id`) ON UPDATE no action ON DELETE cascade
);
--> statement-breakpoint
CREATE INDEX `bookmarkEmbeddings_bookmarkId_idx` ON `bookmarkEmbeddings` (`bookmarkId`);--> statement-breakpoint
CREATE INDEX `bookmarkEmbeddings_userId_idx` ON `bookmarkEmbeddings` (`userId`);
Loading

0 comments on commit 01c74df

Please sign in to comment.