Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: db queries not using agentId in all memory queries #539

Merged
merged 6 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
871 changes: 357 additions & 514 deletions packages/adapter-postgres/src/index.ts

Large diffs are not rendered by default.

40 changes: 16 additions & 24 deletions packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,17 @@ export class SqliteDatabaseAdapter
}

async getMemoriesByRoomIds(params: {
agentId: UUID;
roomIds: UUID[];
tableName: string;
agentId?: UUID;
}): Promise<Memory[]> {
if (!params.tableName) {
// default to messages
params.tableName = "messages";
}
const placeholders = params.roomIds.map(() => "?").join(", ");
let sql = `SELECT * FROM memories WHERE type = ? AND roomId IN (${placeholders})`;
const queryParams = [params.tableName, ...params.roomIds];

if (params.agentId) {
sql += ` AND agentId = ?`;
queryParams.push(params.agentId);
}
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;
let queryParams = [params.tableName, params.agentId, ...params.roomIds];

const stmt = this.db.prepare(sql);
const rows = stmt.all(...queryParams) as (Memory & {
Expand Down Expand Up @@ -189,8 +184,8 @@ export class SqliteDatabaseAdapter

async createMemory(memory: Memory, tableName: string): Promise<void> {
// Delete any existing memory with the same ID first
const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
this.db.prepare(deleteSql).run(memory.id, tableName);
// const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
// this.db.prepare(deleteSql).run(memory.id, tableName);

let isUnique = true;

Expand All @@ -200,6 +195,7 @@ export class SqliteDatabaseAdapter
memory.embedding,
{
tableName,
agentId: memory.agentId,
roomId: memory.roomId,
match_threshold: 0.95, // 5% similarity threshold
count: 1,
Expand Down Expand Up @@ -281,7 +277,7 @@ export class SqliteDatabaseAdapter
match_threshold?: number;
count?: number;
roomId?: UUID;
agentId?: UUID;
agentId: UUID;
unique?: boolean;
tableName: string;
}
Expand All @@ -290,20 +286,17 @@ export class SqliteDatabaseAdapter
// JSON.stringify(embedding),
new Float32Array(embedding),
params.tableName,
params.agentId,
];

let sql = `
SELECT *, vec_distance_L2(embedding, ?) AS similarity
FROM memories
WHERE type = ?`;
WHERE embedding IS NOT NULL type = ? AND agentId = ?`;

if (params.unique) {
sql += " AND `unique` = 1";
}
if (params.agentId) {
sql += " AND agentId = ?";
queryParams.push(params.agentId);
}

if (params.roomId) {
sql += " AND roomId = ?";
Expand Down Expand Up @@ -418,7 +411,7 @@ export class SqliteDatabaseAdapter
count?: number;
unique?: boolean;
tableName: string;
agentId?: UUID;
agentId: UUID;
start?: number;
end?: number;
}): Promise<Memory[]> {
Expand All @@ -428,19 +421,18 @@ export class SqliteDatabaseAdapter
if (!params.roomId) {
throw new Error("roomId is required");
}
let sql = `SELECT * FROM memories WHERE type = ? AND roomId = ?`;
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId = ?`;

const queryParams = [params.tableName, params.roomId] as any[];
const queryParams = [
params.tableName,
params.agentId,
params.roomId,
] as any[];

if (params.unique) {
sql += " AND `unique` = 1";
}

if (params.agentId) {
sql += " AND agentId = ?";
queryParams.push(params.agentId);
}

if (params.start) {
sql += ` AND createdAt >= ?`;
queryParams.push(params.start);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ const summarizeAction = {
// 2. get these memories from the database
const memories = await runtime.messageManager.getMemories({
roomId,
agentId: runtime.agentId,
// subtract start from current time
start: parseInt(start as string),
end: parseInt(end as string),
Expand Down
3 changes: 0 additions & 3 deletions packages/client-twitter/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ export class ClientBase extends EventEmitter {
// Get the existing memories from the database
const existingMemories =
await this.runtime.messageManager.getMemoriesByRoomIds({
agentId: this.runtime.agentId,
roomIds: cachedTimeline.map((tweet) =>
stringToUuid(
tweet.conversationId + "-" + this.runtime.agentId
Expand Down Expand Up @@ -462,7 +461,6 @@ export class ClientBase extends EventEmitter {
// Check the existing memories in the database
const existingMemories =
await this.runtime.messageManager.getMemoriesByRoomIds({
agentId: this.runtime.agentId,
roomIds: Array.from(roomIds),
});

Expand Down Expand Up @@ -564,7 +562,6 @@ export class ClientBase extends EventEmitter {
const recentMessage = await this.runtime.messageManager.getMemories(
{
roomId: message.roomId,
agentId: this.runtime.agentId,
count: 1,
unique: false,
}
Expand Down
2 changes: 1 addition & 1 deletion packages/client-twitter/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export async function buildConversationThread(
"twitter"
);

client.runtime.messageManager.createMemory({
await client.runtime.messageManager.createMemory({
id: stringToUuid(
currentTweet.id + "-" + client.runtime.agentId
),
Expand Down
5 changes: 4 additions & 1 deletion packages/core/src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
* @returns A Promise that resolves to an array of Memory objects.
*/
abstract getMemories(params: {
agentId: UUID;
roomId: UUID;
count?: number;
unique?: boolean;
tableName: string;
}): Promise<Memory[]>;

abstract getMemoriesByRoomIds(params: {
agentId?: UUID;
agentId: UUID;
roomIds: UUID[];
tableName: string;
}): Promise<Memory[]>;
Expand Down Expand Up @@ -105,6 +106,7 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
*/
abstract searchMemories(params: {
tableName: string;
agentId: UUID;
roomId: UUID;
embedding: number[];
match_threshold: number;
Expand Down Expand Up @@ -188,6 +190,7 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
* @returns A Promise that resolves to an array of Goal objects.
*/
abstract getGoals(params: {
agentId: UUID;
roomId: UUID;
userId?: UUID | null;
onlyInProgress?: boolean;
Expand Down
3 changes: 3 additions & 0 deletions packages/core/src/goals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ import {
} from "./types.ts";

export const getGoals = async ({
agentId,
runtime,
roomId,
userId,
onlyInProgress = true,
count = 5,
}: {
runtime: IAgentRuntime;
agentId: UUID;
roomId: UUID;
userId?: UUID;
onlyInProgress?: boolean;
count?: number;
}) => {
return runtime.databaseAdapter.getGoals({
agentId,
roomId,
userId,
onlyInProgress,
Expand Down
3 changes: 1 addition & 2 deletions packages/core/src/knowledge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ async function get(runtime: AgentRuntime, message: Memory): Promise<string[]> {
embedding,
{
roomId: message.agentId,
agentId: message.agentId,
count: 3,
match_threshold: 0.1,
}
Expand Down Expand Up @@ -50,13 +49,13 @@ async function set(
bleed: number = 20
) {
await runtime.documentsManager.createMemory({
embedding: embeddingZeroVector,
id: item.id,
agentId: runtime.agentId,
roomId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: item.content,
embedding: embeddingZeroVector,
});

const preprocessed = preprocess(item.content.text);
Expand Down
34 changes: 16 additions & 18 deletions packages/core/src/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,12 @@ export class MemoryManager implements IMemoryManager {
roomId,
count = 10,
unique = true,
agentId,
start,
end,
}: {
roomId: UUID;
count?: number;
unique?: boolean;
agentId?: UUID;
start?: number;
end?: number;
}): Promise<Memory[]> {
Expand All @@ -107,7 +105,7 @@ export class MemoryManager implements IMemoryManager {
count,
unique,
tableName: this.tableName,
agentId,
agentId: this.runtime.agentId,
start,
end,
});
Expand Down Expand Up @@ -143,7 +141,6 @@ export class MemoryManager implements IMemoryManager {
embedding: number[],
opts: {
match_threshold?: number;
agentId?: UUID;
count?: number;
roomId: UUID;
unique?: boolean;
Expand All @@ -154,20 +151,19 @@ export class MemoryManager implements IMemoryManager {
count = defaultMatchCount,
roomId,
unique,
agentId,
} = opts;

const searchOpts = {
const result = await this.runtime.databaseAdapter.searchMemories({
tableName: this.tableName,
roomId,
agentId,
embedding,
match_threshold,
agentId: this.runtime.agentId,
embedding: embedding,
match_threshold: match_threshold,
match_count: count,
unique: !!unique,
};
});

return await this.runtime.databaseAdapter.searchMemories(searchOpts);
return result;
}

/**
Expand All @@ -177,6 +173,8 @@ export class MemoryManager implements IMemoryManager {
* @returns A Promise that resolves when the operation completes.
*/
async createMemory(memory: Memory, unique = false): Promise<void> {
// TODO: check memory.agentId == this.runtime.agentId

const existingMessage =
await this.runtime.databaseAdapter.getMemoryById(memory.id);

Expand All @@ -185,26 +183,26 @@ export class MemoryManager implements IMemoryManager {
return;
}

elizaLogger.debug("Creating Memory", memory.id, memory.content.text);
elizaLogger.log("Creating Memory", memory.id, memory.content.text);

await this.runtime.databaseAdapter.createMemory(
memory,
this.tableName,
unique
);
}

async getMemoriesByRoomIds(params: {
agentId?: UUID;
roomIds: UUID[];
}): Promise<Memory[]> {
async getMemoriesByRoomIds(params: { roomIds: UUID[] }): Promise<Memory[]> {
return await this.runtime.databaseAdapter.getMemoriesByRoomIds({
agentId: params.agentId,
agentId: this.runtime.agentId,
roomIds: params.roomIds,
});
}

async getMemoryById(id: UUID): Promise<Memory | null> {
return await this.runtime.databaseAdapter.getMemoryById(id);
const result = await this.runtime.databaseAdapter.getMemoryById(id);
if (result && result.agentId !== this.runtime.agentId) return null;
return result;
}

/**
Expand Down
4 changes: 1 addition & 3 deletions packages/core/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -737,11 +737,11 @@ export class AgentRuntime implements IAgentRuntime {
getActorDetails({ runtime: this, roomId }),
this.messageManager.getMemories({
roomId,
agentId: this.agentId,
count: conversationLength,
unique: false,
}),
getGoals({
agentId: this.agentId,
runtime: this,
count: 10,
onlyInProgress: false,
Expand Down Expand Up @@ -877,7 +877,6 @@ Text: ${attachment.text}
// Check the existing memories in the database
const existingMemories =
await this.messageManager.getMemoriesByRoomIds({
agentId: this.agentId,
// filter out the current room id from rooms
roomIds: rooms.filter((room) => room !== roomId),
});
Expand Down Expand Up @@ -1172,7 +1171,6 @@ Text: ${attachment.text}
const conversationLength = this.getConversationLength();
const recentMessagesData = await this.messageManager.getMemories({
roomId: state.roomId,
agentId: this.agentId,
count: conversationLength,
unique: false,
});
Expand Down
Loading
Loading