Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: (db) add limit param to memory retrieval across adapters #2264

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions packages/adapter-pglite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ export class PGLiteDatabaseAdapter
roomIds: UUID[];
agentId?: UUID;
tableName: string;
limit?: number;
}): Promise<Memory[]> {
return this.withDatabase(async () => {
if (params.roomIds.length === 0) return [];
Expand All @@ -167,6 +168,13 @@ export class PGLiteDatabaseAdapter
queryParams = [...queryParams, params.agentId];
}

// Add ordering and limit
query += ` ORDER BY "createdAt" DESC`;
if (params.limit) {
query += ` LIMIT $${queryParams.length + 1}`;
queryParams.push(params.limit.toString());
}

const { rows } = await this.query<Memory>(query, queryParams);
return rows.map((row) => ({
...row,
Expand Down
12 changes: 11 additions & 1 deletion packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,29 @@ export class SqliteDatabaseAdapter
agentId: UUID;
roomIds: UUID[];
tableName: string;
limit?: number;
}): Promise<Memory[]> {
if (!params.tableName) {
// default to messages
params.tableName = "messages";
}

const placeholders = params.roomIds.map(() => "?").join(", ");
const sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;

const queryParams = [
params.tableName,
params.agentId,
...params.roomIds,
];

// Add ordering and limit
sql += ` ORDER BY createdAt DESC`;
if (params.limit) {
sql += ` LIMIT ?`;
queryParams.push(params.limit.toString());
}

const stmt = this.db.prepare(sql);
const rows = stmt.all(...queryParams) as (Memory & {
content: string;
Expand Down
67 changes: 46 additions & 21 deletions packages/adapter-sqljs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
type Relationship,
type UUID,
RAGKnowledgeItem,
elizaLogger
elizaLogger,
} from "@elizaos/core";
import { v4 } from "uuid";
import { sqliteTables } from "./sqliteTables.ts";
Expand Down Expand Up @@ -81,15 +81,26 @@ export class SqlJsDatabaseAdapter
agentId: UUID;
roomIds: UUID[];
tableName: string;
limit?: number;
}): Promise<Memory[]> {
const placeholders = params.roomIds.map(() => "?").join(", ");
const sql = `SELECT * FROM memories WHERE 'type' = ? AND agentId = ? AND roomId IN (${placeholders})`;
const stmt = this.db.prepare(sql);
let sql = `SELECT * FROM memories WHERE 'type' = ? AND agentId = ? AND roomId IN (${placeholders})`;

const queryParams = [
params.tableName,
params.agentId,
...params.roomIds,
];

// Add ordering and limit
sql += ` ORDER BY createdAt DESC`;
if (params.limit) {
sql += ` LIMIT ?`;
queryParams.push(params.limit.toString());
}

const stmt = this.db.prepare(sql);

elizaLogger.log({ queryParams });
stmt.bind(queryParams);
elizaLogger.log({ queryParams });
Expand Down Expand Up @@ -834,8 +845,10 @@ export class SqlJsDatabaseAdapter
id: row.id,
agentId: row.agentId,
content: JSON.parse(row.content),
embedding: row.embedding ? new Float32Array(row.embedding) : undefined, // Convert Uint8Array back to Float32Array
createdAt: row.createdAt
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined, // Convert Uint8Array back to Float32Array
createdAt: row.createdAt,
});
}
stmt.free();
Expand All @@ -852,7 +865,7 @@ export class SqlJsDatabaseAdapter
const cacheKey = `embedding_${params.agentId}_${params.searchText}`;
const cachedResult = await this.getCache({
key: cacheKey,
agentId: params.agentId
agentId: params.agentId,
});

if (cachedResult) {
Expand Down Expand Up @@ -901,11 +914,11 @@ export class SqlJsDatabaseAdapter
stmt.bind([
new Uint8Array(params.embedding.buffer),
params.agentId,
`%${params.searchText || ''}%`,
`%${params.searchText || ""}%`,
params.agentId,
params.agentId,
params.match_threshold,
params.match_count
params.match_count,
]);

const results: RAGKnowledgeItem[] = [];
Expand All @@ -915,17 +928,19 @@ export class SqlJsDatabaseAdapter
id: row.id,
agentId: row.agentId,
content: JSON.parse(row.content),
embedding: row.embedding ? new Float32Array(row.embedding) : undefined,
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined,
createdAt: row.createdAt,
similarity: row.keyword_score
similarity: row.keyword_score,
});
}
stmt.free();

await this.setCache({
key: cacheKey,
agentId: params.agentId,
value: JSON.stringify(results)
value: JSON.stringify(results),
});

return results;
Expand All @@ -947,31 +962,41 @@ export class SqlJsDatabaseAdapter
knowledge.id,
metadata.isShared ? null : knowledge.agentId,
JSON.stringify(knowledge.content),
knowledge.embedding ? new Uint8Array(knowledge.embedding.buffer) : null,
knowledge.embedding
? new Uint8Array(knowledge.embedding.buffer)
: null,
knowledge.createdAt || Date.now(),
metadata.isMain ? 1 : 0,
metadata.originalId || null,
metadata.chunkIndex || null,
metadata.isShared ? 1 : 0
metadata.isShared ? 1 : 0,
]);
stmt.free();
} catch (error: any) {
const isShared = knowledge.content.metadata?.isShared;
const isPrimaryKeyError = error?.code === 'SQLITE_CONSTRAINT_PRIMARYKEY';
const isPrimaryKeyError =
error?.code === "SQLITE_CONSTRAINT_PRIMARYKEY";

if (isShared && isPrimaryKeyError) {
elizaLogger.info(`Shared knowledge ${knowledge.id} already exists, skipping`);
elizaLogger.info(
`Shared knowledge ${knowledge.id} already exists, skipping`
);
return;
} else if (!isShared && !error.message?.includes('SQLITE_CONSTRAINT_PRIMARYKEY')) {
} else if (
!isShared &&
!error.message?.includes("SQLITE_CONSTRAINT_PRIMARYKEY")
) {
elizaLogger.error(`Error creating knowledge ${knowledge.id}:`, {
error,
embeddingLength: knowledge.embedding?.length,
content: knowledge.content
content: knowledge.content,
});
throw error;
}

elizaLogger.debug(`Knowledge ${knowledge.id} already exists, skipping`);
elizaLogger.debug(
`Knowledge ${knowledge.id} already exists, skipping`
);
}
}

Expand All @@ -983,9 +1008,9 @@ export class SqlJsDatabaseAdapter
}

async clearKnowledge(agentId: UUID, shared?: boolean): Promise<void> {
const sql = shared ?
`DELETE FROM knowledge WHERE ("agentId" = ? OR "isShared" = 1)` :
`DELETE FROM knowledge WHERE "agentId" = ?`;
const sql = shared
? `DELETE FROM knowledge WHERE ("agentId" = ? OR "isShared" = 1)`
: `DELETE FROM knowledge WHERE "agentId" = ?`;

const stmt = this.db.prepare(sql);
stmt.run([agentId]);
Expand Down
Loading
Loading