Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: postgres adapter settings not being applied #1379

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions packages/adapter-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
elizaLogger,
getEmbeddingConfig,
DatabaseAdapter,
EmbeddingProvider,
} from "@elizaos/core";
import fs from "fs";
import { fileURLToPath } from "url";
Expand Down Expand Up @@ -189,6 +190,19 @@ export class PostgresDatabaseAdapter
try {
await client.query("BEGIN");

// Set application settings for embedding dimension
const embeddingConfig = getEmbeddingConfig();
if (embeddingConfig.provider === EmbeddingProvider.OpenAI) {
await client.query("SET app.use_openai_embedding = 'true'");
await client.query("SET app.use_ollama_embedding = 'false'");
} else if (embeddingConfig.provider === EmbeddingProvider.Ollama) {
await client.query("SET app.use_openai_embedding = 'false'");
await client.query("SET app.use_ollama_embedding = 'true'");
} else {
await client.query("SET app.use_openai_embedding = 'false'");
await client.query("SET app.use_ollama_embedding = 'false'");
}

// Check if schema already exists (check for a core table)
const { rows } = await client.query(`
SELECT EXISTS (
Expand Down
40 changes: 32 additions & 8 deletions packages/core/src/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,36 @@ interface EmbeddingOptions {
provider?: string;
}

// Add the embedding configuration
export const getEmbeddingConfig = () => ({
export const EmbeddingProvider = {
OpenAI: "OpenAI",
Ollama: "Ollama",
GaiaNet: "GaiaNet",
BGE: "BGE",
} as const;

export type EmbeddingProvider =
(typeof EmbeddingProvider)[keyof typeof EmbeddingProvider];

export namespace EmbeddingProvider {
export type OpenAI = typeof EmbeddingProvider.OpenAI;
export type Ollama = typeof EmbeddingProvider.Ollama;
export type GaiaNet = typeof EmbeddingProvider.GaiaNet;
export type BGE = typeof EmbeddingProvider.BGE;
}

export type EmbeddingConfig = {
readonly dimensions: number;
readonly model: string;
readonly provider: EmbeddingProvider;
};

export const getEmbeddingConfig = (): EmbeddingConfig => ({
dimensions:
settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true"
? 1536 // OpenAI
: settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true"
? 1024 // Ollama mxbai-embed-large
:settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
: settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
? 768 // GaiaNet
: 384, // BGE
model:
Expand Down Expand Up @@ -171,7 +193,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
const isNode = typeof process !== "undefined" && process.versions?.node;

// Determine which embedding path to use
if (config.provider === "OpenAI") {
if (config.provider === EmbeddingProvider.OpenAI) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint: "https://api.openai.com/v1",
Expand All @@ -180,7 +202,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
});
}

if (config.provider === "Ollama") {
if (config.provider === EmbeddingProvider.Ollama) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint:
Expand All @@ -191,7 +213,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
});
}

if (config.provider=="GaiaNet") {
if (config.provider == EmbeddingProvider.GaiaNet) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint:
Expand Down Expand Up @@ -252,9 +274,11 @@ export async function embed(runtime: IAgentRuntime, input: string) {
return await import("fastembed");
} catch {
elizaLogger.error("Failed to load fastembed.");
throw new Error("fastembed import failed, falling back to remote embedding");
throw new Error(
"fastembed import failed, falling back to remote embedding"
);
}
})()
})(),
]);

const [fs, { fileURLToPath }, fastEmbed] = moduleImports;
Expand Down
Loading