Skip to content

Commit

Permalink
Merge pull request #472 from tarrencev/main
Browse files Browse the repository at this point in the history
feat: Improve knowledge embeddings
  • Loading branch information
ponderingdemocritus authored Nov 21, 2024
2 parents 8450877 + 0d75334 commit 9123996
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,12 @@ const summarizeAction = {
const model = models[runtime.character.settings.model];
const chunkSize = model.settings.maxContextLength - 1000;

const chunks = await splitChunks(formattedMemories, chunkSize, 0);
const chunks = await splitChunks(
formattedMemories,
chunkSize,
"gpt-4o-mini",
0
);

const datestr = new Date().toUTCString().replace(/:/g, "-");

Expand Down
5 changes: 3 additions & 2 deletions packages/client-discord/src/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,13 @@ export class MessageManager {
await this.runtime.messageManager.createMemory(memory);
}

let state = (await this.runtime.composeState(userMessage, {
let state = await this.runtime.composeState(userMessage, {
discordClient: this.client,
discordMessage: message,
agentName:
this.runtime.character.name ||
this.client.user?.displayName,
})) as State;
});

if (!canSendMessage(message.channel).canSend) {
return elizaLogger.warn(
Expand Down Expand Up @@ -649,6 +649,7 @@ export class MessageManager {
message: DiscordMessage
): Promise<{ processedContent: string; attachments: Media[] }> {
let processedContent = message.content;

let attachments: Media[] = [];

// Process code blocks in the message content
Expand Down
45 changes: 3 additions & 42 deletions packages/client-github/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ import {
AgentRuntime,
Client,
IAgentRuntime,
Content,
Memory,
knowledge,
stringToUuid,
embeddingZeroVector,
splitChunks,
embed,
} from "@ai16z/eliza";
import { validateGithubConfig } from "./enviroment";

Expand Down Expand Up @@ -112,11 +108,8 @@ export class GitHubClient {
relativePath
);

const memory: Memory = {
await knowledge.set(this.runtime, {
id: knowledgeId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
roomId: this.runtime.agentId,
content: {
text: content,
hash: contentHash,
Expand All @@ -128,39 +121,7 @@ export class GitHubClient {
owner: this.config.owner,
},
},
embedding: embeddingZeroVector,
};

await this.runtime.documentsManager.createMemory(memory);

// Only split if content exceeds 4000 characters
const fragments =
content.length > 4000
? await splitChunks(content, 2000, 200)
: [content];

for (const fragment of fragments) {
// Skip empty fragments
if (!fragment.trim()) continue;

// Add file path context to the fragment before embedding
const fragmentWithPath = `File: ${relativePath}\n\n${fragment}`;
const embedding = await embed(this.runtime, fragmentWithPath);

await this.runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(knowledgeId + fragment),
roomId: this.runtime.agentId,
agentId: this.runtime.agentId,
userId: this.runtime.agentId,
content: {
source: knowledgeId,
text: fragment,
},
embedding,
});
}
});
}
}

Expand Down
26 changes: 15 additions & 11 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,34 +463,38 @@ export async function generateShouldRespond({
* Splits content into chunks of specified size with optional overlapping bleed sections
* @param content - The text content to split into chunks
* @param chunkSize - The maximum size of each chunk in tokens
* @param bleed - Number of characters to overlap between chunks (default: 100)
* @param model - The model name to use for tokenization (default: runtime.model)
* @param bleed - Number of characters to overlap between chunks (default: 100)
* @returns Promise resolving to array of text chunks with bleed sections
*/
export async function splitChunks(
content: string,
chunkSize: number,
model: string,
bleed: number = 100
): Promise<string[]> {
const encoding = encoding_for_model("gpt-4o-mini");

const encoding = encoding_for_model(model as TiktokenModel);
const tokens = encoding.encode(content);
const chunks: string[] = [];
const textDecoder = new TextDecoder();

for (let i = 0; i < tokens.length; i += chunkSize) {
const chunk = tokens.slice(i, i + chunkSize);
const decodedChunk = textDecoder.decode(encoding.decode(chunk));
let chunk = tokens.slice(i, i + chunkSize);

// Append bleed characters from the previous chunk
const startBleed = i > 0 ? content.slice(i - bleed, i) : "";
if (i > 0) {
chunk = new Uint32Array([...tokens.slice(i - bleed, i), ...chunk]);
}

// Append bleed characters from the next chunk
const endBleed =
i + chunkSize < tokens.length
? content.slice(i + chunkSize, i + chunkSize + bleed)
: "";
if (i + chunkSize < tokens.length) {
chunk = new Uint32Array([
...chunk,
...tokens.slice(i + chunkSize, i + chunkSize + bleed),
]);
}

chunks.push(startBleed + decodedChunk + endBleed);
chunks.push(textDecoder.decode(encoding.decode(chunk)));
}

return chunks;
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ export * from "./parsing.ts";
export * from "./uuid.ts";
export * from "./enviroment.ts";
export * from "./cache.ts";
export { default as knowledge } from "./knowledge.ts";
129 changes: 129 additions & 0 deletions packages/core/src/knowledge.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import { UUID } from "crypto";

import { AgentRuntime } from "./runtime.ts";
import { embed } from "./embedding.ts";
import { Content, ModelClass, type Memory } from "./types.ts";
import { stringToUuid } from "./uuid.ts";
import { embeddingZeroVector } from "./memory.ts";
import { splitChunks } from "./generation.ts";
import { models } from "./models.ts";
import elizaLogger from "./logger.ts";

async function get(runtime: AgentRuntime, message: Memory): Promise<string[]> {
const processed = preprocess(message.content.text);
elizaLogger.log(`Querying knowledge for: ${processed}`);
const embedding = await embed(runtime, processed);
const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding(
embedding,
{
roomId: message.agentId,
agentId: message.agentId,
count: 3,
match_threshold: 0.1,
}
);

const uniqueSources = [
...new Set(
fragments.map((memory) => {
elizaLogger.log(
`Matched fragment: ${memory.content.text} with similarity: ${message.similarity}`
);
return memory.content.source;
})
),
];

const knowledgeDocuments = await Promise.all(
uniqueSources.map((source) =>
runtime.documentsManager.getMemoryById(source as UUID)
)
);

const knowledge = knowledgeDocuments
.filter((memory) => memory !== null)
.map((memory) => memory.content.text);
return knowledge;
}

export type KnowledgeItem = {
id: UUID;
content: Content;
};

async function set(runtime: AgentRuntime, item: KnowledgeItem) {
await runtime.documentsManager.createMemory({
embedding: embeddingZeroVector,
id: item.id,
agentId: runtime.agentId,
roomId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: item.content,
});

const preprocessed = preprocess(item.content.text);
const fragments = await splitChunks(
preprocessed,
10,
models[runtime.character.modelProvider].model?.[ModelClass.EMBEDDING],
5
);

for (const fragment of fragments) {
const embedding = await embed(runtime, fragment);
await runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
id: stringToUuid(item.id + fragment),
roomId: runtime.agentId,
agentId: runtime.agentId,
userId: runtime.agentId,
createdAt: Date.now(),
content: {
source: item.id,
text: fragment,
},
embedding,
});
}
}

export function preprocess(content: string): string {
return (
content
// Remove code blocks and their content
.replace(/```[\s\S]*?```/g, "")
// Remove inline code
.replace(/`.*?`/g, "")
// Convert headers to plain text with emphasis
.replace(/#{1,6}\s*(.*)/g, "$1")
// Remove image links but keep alt text
.replace(/!\[(.*?)\]\(.*?\)/g, "$1")
// Remove links but keep text
.replace(/\[(.*?)\]\(.*?\)/g, "$1")
// Remove HTML tags
.replace(/<[^>]*>/g, "")
// Remove horizontal rules
.replace(/^\s*[-*_]{3,}\s*$/gm, "")
// Remove comments
.replace(/\/\*[\s\S]*?\*\//g, "")
.replace(/\/\/.*/g, "")
// Normalize whitespace
.replace(/\s+/g, " ")
// Remove multiple newlines
.replace(/\n{3,}/g, "\n\n")
// strip all special characters
.replace(/[^a-zA-Z0-9\s]/g, "")
// Remove Discord mentions
.replace(/<@!?\d+>/g, "")
.trim()
.toLowerCase()
);
}

export default {
get,
set,
process,
};
Loading

0 comments on commit 9123996

Please sign in to comment.