Improve knowledge embeddings
tarrencev committed Nov 21, 2024
1 parent 3ab32a9 commit f8d6b57
Showing 8 changed files with 183 additions and 130 deletions.
@@ -252,7 +252,7 @@ const summarizeAction = {
const model = models[runtime.character.settings.model];
const chunkSize = model.settings.maxContextLength - 1000;

const chunks = await splitChunks(formattedMemories, chunkSize, 0);
const chunks = await splitChunks(formattedMemories, chunkSize, "gpt-4o-mini", 0);

const datestr = new Date().toUTCString().replace(/:/g, "-");

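This call follows the updated splitChunks signature (see packages/core/src/generation.ts below), which now takes the tiktoken model name ahead of the bleed so the chunk size is counted in that model's tokens. A minimal sketch of an equivalent caller; only maxContextLength and the 1000-token headroom come from the hunk above, the rest is illustrative:

```ts
import { splitChunks } from "@ai16z/eliza";

// Illustrative only: chunk text to fit a model context, counting tokens
// with the same encoding ("gpt-4o-mini") that will consume the chunks.
async function chunkForContext(text: string, maxContextLength: number) {
    const chunkSize = maxContextLength - 1000; // leave room for the prompt
    // (content, chunk size in tokens, tiktoken model name, bleed in tokens)
    return splitChunks(text, chunkSize, "gpt-4o-mini", 0);
}
```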
45 changes: 3 additions & 42 deletions packages/client-github/src/index.ts
@@ -10,12 +10,8 @@ import {
AgentRuntime,
Client,
IAgentRuntime,
Content,
Memory,
knowledge,
stringToUuid,
embeddingZeroVector,
splitChunks,
embed,
} from "@ai16z/eliza";

export interface GitHubConfig {
@@ -111,11 +107,8 @@ export class GitHubClient {
relativePath
);

const memory: Memory = {
await knowledge.set(this.runtime, {
id: knowledgeId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
roomId: this.runtime.agentId,
content: {
text: content,
hash: contentHash,
@@ -127,39 +120,7 @@ export class GitHubClient {
owner: this.config.owner,
},
},
embedding: embeddingZeroVector,
};

await this.runtime.documentsManager.createMemory(memory);

// Only split if content exceeds 4000 characters
const fragments =
content.length > 4000
? await splitChunks(content, 2000, 200)
: [content];

for (const fragment of fragments) {
// Skip empty fragments
if (!fragment.trim()) continue;

// Add file path context to the fragment before embedding
const fragmentWithPath = `File: ${relativePath}\n\n${fragment}`;
const embedding = await embed(this.runtime, fragmentWithPath);

await this.runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(knowledgeId + fragment),
roomId: this.runtime.agentId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
content: {
source: knowledgeId,
text: fragment,
},
embedding,
});
}
});
}
}

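The client now hands the document to the shared knowledge module instead of writing document and fragment memories by hand; chunking, embedding, and fragment namespacing all happen inside knowledge.set (see packages/core/src/knowledge.ts below). A minimal sketch of the new call shape, with an illustrative id scheme and field names taken from the hunk above:

```ts
import { knowledge, stringToUuid, AgentRuntime } from "@ai16z/eliza";

// Illustrative only: register one repository file as agent knowledge.
// knowledge.set() stores the full document and creates the embedded
// fragments itself, so the client no longer calls splitChunks/embed directly.
async function indexRepoFile(
    runtime: AgentRuntime,
    relativePath: string,
    content: string
) {
    await knowledge.set(runtime, {
        id: stringToUuid(`github-${relativePath}`), // hypothetical id scheme
        content: {
            text: content,
            source: "github",
            metadata: { path: relativePath },
        },
    });
}
```

Centralizing this in core keeps every client's knowledge entries chunked and embedded the same way.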
23 changes: 14 additions & 9 deletions packages/core/src/generation.ts
@@ -468,27 +468,32 @@ export async function generateShouldRespond({
export async function splitChunks(
content: string,
chunkSize: number,
model: string,
bleed: number = 100
): Promise<string[]> {
const encoding = encoding_for_model("gpt-4o-mini");
const encoding = encoding_for_model(model as TiktokenModel);

const tokens = encoding.encode(content);
const chunks: string[] = [];
const textDecoder = new TextDecoder();

for (let i = 0; i < tokens.length; i += chunkSize) {
const chunk = tokens.slice(i, i + chunkSize);
const decodedChunk = textDecoder.decode(encoding.decode(chunk));
let chunk = tokens.slice(i, i + chunkSize);

// Prepend bleed tokens from the previous chunk
const startBleed = i > 0 ? content.slice(i - bleed, i) : "";
if (i > 0) {
chunk = new Uint32Array([...tokens.slice(i - bleed, i), ...chunk]);
}

// Append bleed tokens from the next chunk
const endBleed =
i + chunkSize < tokens.length
? content.slice(i + chunkSize, i + chunkSize + bleed)
: "";
if (i + chunkSize < tokens.length) {
chunk = new Uint32Array([
...chunk,
...tokens.slice(i + chunkSize, i + chunkSize + bleed),
]);
}

chunks.push(startBleed + decodedChunk + endBleed);
chunks.push(textDecoder.decode(encoding.decode(chunk)));
}

return chunks;
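The rewritten splitChunks measures both the chunk size and the bleed in tokens and only decodes once per chunk, instead of mixing token offsets into character slices as the old code did. A small usage sketch, with illustrative sizes:

```ts
import { splitChunks } from "@ai16z/eliza";

// Illustrative only: 512-token chunks with a 64-token bleed on each side.
async function chunkWithOverlap(text: string): Promise<string[]> {
    const chunks = await splitChunks(text, 512, "gpt-4o-mini", 64);
    // Consecutive chunks now share up to 64 tokens of context, so a sentence
    // that straddles a boundary still appears whole in at least one chunk.
    return chunks;
}
```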
1 change: 1 addition & 0 deletions packages/core/src/index.ts
@@ -18,3 +18,4 @@ export * from "./types.ts";
export * from "./logger.ts";
export * from "./parsing.ts";
export * from "./uuid.ts";
export { default as knowledge } from "./knowledge.ts";
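Re-exporting the module from the package entry point means clients can query the knowledge base through one call instead of driving the memory managers directly. A minimal sketch, assuming the runtime and message come from the caller:

```ts
import { knowledge, AgentRuntime, type Memory } from "@ai16z/eliza";

// Illustrative only: fetch the documents most relevant to an incoming message.
// knowledge.get embeds the message text, searches the knowledge fragments,
// then resolves each fragment back to its source document's text.
async function relevantKnowledge(
    runtime: AgentRuntime,
    message: Memory
): Promise<string[]> {
    return knowledge.get(runtime, message);
}
```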
120 changes: 120 additions & 0 deletions packages/core/src/knowledge.ts
@@ -0,0 +1,120 @@
import { UUID } from "crypto";

import { AgentRuntime } from "./runtime.ts";
import { embed } from "./embedding.ts";
import { Content, ModelClass, type Memory } from "./types.ts";
import { stringToUuid } from "./uuid.ts";
import { embeddingZeroVector } from "./memory.ts";
import { splitChunks } from "./generation.ts";
import { models } from "./models.ts";

async function get(runtime: AgentRuntime, message: Memory): Promise<string[]> {
const embedding = await embed(runtime, message.content.text);
const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding(
embedding,
{
roomId: message.agentId,
agentId: message.agentId,
count: 3,
match_threshold: 0.1,
}
);

const uniqueSources = [
...new Set(
fragments.map((memory) => {
console.log((memory as any).similarity);
return memory.content.source;
})
),
];

const knowledgeDocuments = await Promise.all(
uniqueSources.map((source) =>
runtime.documentsManager.getMemoryById(source as UUID)
)
);

const knowledge = knowledgeDocuments
.filter((memory) => memory !== null)
.map((memory) => memory.content.text);
return knowledge;
}

export type KnowledgeItem = {
id: UUID;
content: Content;
};

async function set(runtime: AgentRuntime, item: KnowledgeItem) {
await runtime.documentsManager.createMemory({
embedding: embeddingZeroVector,
id: item.id,
agentId: runtime.agentId,
roomId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: item.content,
});

const preprocessed = preprocess(item.content.text);
const fragments = await splitChunks(
preprocessed,
10,
models[runtime.character.modelProvider].model?.[ModelClass.EMBEDDING],
5
);

for (const fragment of fragments) {
const embedding = await embed(runtime, fragment);
await runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(item.id + fragment),
roomId: runtime.agentId,
agentId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: {
source: item.id,
text: fragment,
},
embedding,
});
}
}

export function preprocess(content: string): string {
return (
content
// Remove code blocks and their content
.replace(/```[\s\S]*?```/g, "")
// Remove inline code
.replace(/`.*?`/g, "")
// Convert headers to plain text with emphasis
.replace(/#{1,6}\s*(.*)/g, "$1")
// Remove image links but keep alt text
.replace(/!\[(.*?)\]\(.*?\)/g, "$1")
// Remove links but keep text
.replace(/\[(.*?)\]\(.*?\)/g, "$1")
// Remove HTML tags
.replace(/<[^>]*>/g, "")
// Remove horizontal rules
.replace(/^\s*[-*_]{3,}\s*$/gm, "")
// Remove comments
.replace(/\/\*[\s\S]*?\*\//g, "")
.replace(/\/\/.*/g, "")
// Normalize whitespace
.replace(/\s+/g, " ")
// Remove multiple newlines
.replace(/\n{3,}/g, "\n\n")
.trim()
.toLowerCase()
);
}

export default {
get,
set,
preprocess,
};
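As a quick illustration of the preprocessing step above, here is roughly what it does to a small markdown snippet before chunking and embedding (a sketch; it assumes preprocess is reachable on the exported knowledge object, per the default export above):

```ts
import { knowledge } from "@ai16z/eliza";

// Illustrative input: markdown with a header, inline code, and a link.
const raw = "# Setup\nRun `npm install`, then see [the docs](https://example.com).";

// Headers are flattened, inline code is dropped, link text is kept,
// whitespace is collapsed, and the result is lowercased:
console.log(knowledge.preprocess(raw));
// -> "setup run , then see the docs."
```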