Skip to content

Commit

Permalink
Add arango and local disk as repositories for conversation messages
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinbayes committed Nov 26, 2024
1 parent a042504 commit fac87e5
Show file tree
Hide file tree
Showing 22 changed files with 1,189 additions and 161 deletions.
9 changes: 9 additions & 0 deletions ERROR.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ID Registry

------

| ID | Owner |
|----|---------|
| 0 | common |
| 1 | agent |
| 2 | runtime |
4 changes: 4 additions & 0 deletions agentsmith-agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"
authors = ["Kevin Bayes"]

[dependencies]
arangors = { version = "0.6", default-features = false, features = ["reqwest", "reqwest_async", ] }
pyo3 = { version = "0.22.3", features = ["extension-module"] }
qdrant-client = "1"
agentsmith-common = { path = "../agentsmith-common"}
Expand All @@ -28,6 +29,9 @@ tracing-subscriber = "0.3"
log = "0.4.22"
short-uuid = "0.1"

sqlx = { version = "0.8", features = ["runtime-async-std-native-tls", "mysql", "chrono", "uuid"] }
uuid = { version = "1.10.0", features = ["v4"] }

[build-dependencies]
tonic-build = "0.12"
prost-build = "0.13"
Expand Down
10 changes: 10 additions & 0 deletions agentsmith-agent/ERROR.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Code Registry

------

| ID | Owner | Current |
|------|--------|---------|
| 1xxx | llm | |
| 2xxx | memory | |
| 3xxx | tools | |
| 4xxx | agent | |
1 change: 1 addition & 0 deletions agentsmith-agent/src/error.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1-2000=
18 changes: 18 additions & 0 deletions agentsmith-agent/src/llm/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ impl PromptMessage {
role: "tool".to_string(),
}
}


pub fn role(&self) -> &String {
match self {
PromptMessage::System { role, .. } => {
role
}
PromptMessage::User { role, .. } => {
role
}
PromptMessage::Assistant { role, .. } => {
role
}
PromptMessage::Tool { role, .. } => {
role
}
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down
180 changes: 73 additions & 107 deletions agentsmith-agent/src/memory/general.rs
Original file line number Diff line number Diff line change
@@ -1,150 +1,116 @@
use std::sync::{Arc, Mutex, RwLock};
use qdrant_client::config::QdrantConfig;
use qdrant_client::Qdrant;
use qdrant_client::qdrant::{CreateCollection, CreateCollectionBuilder, Distance, PointStruct, ScalarQuantizationBuilder, VectorParams, VectorParamsBuilder, VectorsConfig};
use qdrant_client::qdrant::qdrant_client::QdrantClient;
use serde::{Deserialize, Serialize};
use agentsmith_common::config::config::Config;
use agentsmith_common::error::error::{SystemError, SystemResult};
use crate::llm::prompt::PromptMessage;
use crate::memory::memory::{InitialiseMemory, RecordMemory, RetrieveMemory};
use crate::memory::memory::{InitialiseMemory, MemoryBlock, MemoryContext, RecordMemory, RetrieveMemory};
use crate::memory::repository::semantic_repository::{SemanticCommand, SemanticMemoryConfiguration, SemanticQuery, SemanticRepository, SemanticRepositoryFactory};
use crate::memory::repository::working_repository::{WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery, WorkingMemoryRepository, WorkingMemoryRepositoryFactory};
use agentsmith_common::config::config::Config;
use agentsmith_common::error::error::SystemResult;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Clone)]
pub struct General {
pub message_log: Arc<RwLock<Vec<PromptMessage>>>,
pub memory: Arc<Qdrant>,
pub working: Arc<WorkingMemoryRepository>,
pub semantic: Arc<SemanticRepository>,
}

impl General {
pub fn new(config: &Config) -> Self {
let client = Qdrant::from_url(config.config.qdrant.host.clone().as_str()).build().unwrap();

pub async fn new(config: &Config,
working_memory_configuration: &WorkingMemoryConfiguration,
semantic_repository_configuration: &SemanticMemoryConfiguration) -> Self {

let working_factory = WorkingMemoryRepositoryFactory::new(config);
let semantic_factory = SemanticRepositoryFactory::new(config);

let working = working_factory.instance(
working_memory_configuration
)
.await
.unwrap();

let semantic = semantic_factory.instance(
semantic_repository_configuration
)
.await
.unwrap();

Self {
message_log: Arc::new(RwLock::new(vec![])),
memory: Arc::new(client),
working: Arc::new(working),
semantic: Arc::new(semantic),
}
}
}

impl InitialiseMemory for General {

async fn initialise_collection(&self, collection: &str) -> SystemResult<()> {

let creation_result = self.memory.create_collection(
CreateCollectionBuilder::new(collection.clone())
.vectors_config(VectorParamsBuilder::new(1024, Distance::Cosine))
.quantization_config(ScalarQuantizationBuilder::default()),
).await;

match creation_result {
Ok(_) => println!("Collection {} created", collection.clone()),
Err(e) => println!("Error creating collection {}: {}", collection.clone(), e)
}

Ok(())
async fn initialise(&self) -> SystemResult<()> {
self.working.initialise().await?;
self.semantic.initialise().await
}
}

impl RecordMemory for General {

async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult<bool> {

// let sample = embeddings.clone();
//
// let mut payload: Payload = Payload::new();
//
// for item in index_payload.iter() {
// payload.insert(item.0.clone(), Value::from(item.1.clone()));
// }
//
// let points = vec![PointStruct::new(id, sample, payload)];
//
// let collection_name = collection.clone();
//
// let response = self.memory.upsert_points(collection_name, None, points, None)
// .await
// .map_err(|e| {
// println!("Error indexing: {}", e);
// Error::MemoryError
// })
// ?;

Ok(true)
async fn record_memory_chunk(&self, chunk: &str) -> SystemResult<bool> {
self.semantic.record_memory_chunk(chunk).await
}

async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> {

if messages.is_empty() {
Ok(false)
} else {
let mut result = self.message_log.write().unwrap();
result.extend(messages.iter().cloned());
Ok(true)
}
async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> {
self.working.record_prompt_messages(context, messages).await
}
}

impl RetrieveMemory for General {

async fn retrieve_past_messages(&self) -> SystemResult<Vec<PromptMessage>> {
Ok(self.message_log.read().unwrap().iter().cloned().collect())
async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult<Vec<PromptMessage>> {
self.working.retrieve_past_n_messages(context, last_n).await
}

async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult<Vec<String>> {
todo!()
async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult<Vec<MemoryBlock>> {
self.semantic.retrieve_memory_chunks(query).await
}
}



#[cfg(test)]
mod tests {
use std::fs;
use serde_json::{json, Value};
use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
use testcontainers::{GenericImage, ImageExt};
use testcontainers::runners::AsyncRunner;
use agentsmith_common::config::config::read_config;
use crate::llm::llm_factory::LLMFactory;
use crate::llm::prompt::UserContent;
use super::*;


// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[tokio::test]
async fn test_record_prompt_message() {

tracing_subscriber::fmt::init();

let client = Qdrant::from_url("http://localhost:6333").build().unwrap();

let memory = General {
message_log: Arc::new(RwLock::new(vec![])),
memory: Arc::new(client),
};

memory.initialise_collection(&String::from("test")).await.unwrap();

let current_storage = memory.retrieve_past_messages().await;

assert_eq!(current_storage.unwrap().len(), 0);

let result_true = memory.record_prompt_messages(&vec![PromptMessage::User {
role: "user".to_string(),
content: vec![UserContent::Text {
type_: "text".to_string(),
text: "Hello world!".to_string(),
}],
name: None,
}]).await.unwrap();

assert!(result_true);

let result_false = memory.record_prompt_messages(&vec![]).await.unwrap();

assert!(!result_false);

let current_storage = memory.retrieve_past_messages().await.unwrap();

assert_eq!(current_storage.len(), 1);
// let client = Qdrant::from_url("http://localhost:6333").build().unwrap();
//
// let memory = General {
// message_log: Arc::new(RwLock::new(vec![])),
// memory: Arc::new(client),
// };
//
// memory.initialise(&String::from("test")).await.unwrap();
//
// let current_storage = memory.retrieve_past_messages().await;
//
// assert_eq!(current_storage.unwrap().len(), 0);
//
// let result_true = memory.record_prompt_messages(&vec![PromptMessage::User {
// role: "user".to_string(),
// content: vec![UserContent::Text {
// type_: "text".to_string(),
// text: "Hello world!".to_string(),
// }],
// name: None,
// }]).await.unwrap();
//
// assert!(result_true);
//
// let result_false = memory.record_prompt_messages(&vec![]).await.unwrap();
//
// assert!(!result_false);
//
// let current_storage = memory.retrieve_past_messages().await.unwrap();
//
// assert_eq!(current_storage.len(), 1);
}
}
52 changes: 33 additions & 19 deletions agentsmith-agent/src/memory/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ pub enum Memory {
GENERAL(General),
}

#[derive(Clone)]
pub struct MemoryBlock {
pub address: String,
pub string: String,
}

#[derive(Clone, Debug)]
pub struct MemoryConfiguration {
pub r#type: String,
Expand All @@ -27,6 +33,7 @@ impl MemoryFactory {
}

pub async fn instance(&self, memory_config: MemoryConfiguration) -> SystemResult<Memory> {

match memory_config.r#type.as_ref() {
"messages" => Ok(Memory::MESSAGES(Messages::new())),
_ => panic!(),
Expand All @@ -35,52 +42,59 @@ impl MemoryFactory {
}

impl RetrieveMemory for Memory {
async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult<Vec<String>> {
async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult<Vec<MemoryBlock>> {
match self {
Memory::MESSAGES(memory) => memory.retrieve_memory_chunks(collection, query).await,
Memory::GENERAL(memory) => memory.retrieve_memory_chunks(collection, query).await,
Memory::MESSAGES(memory) => memory.retrieve_memory_chunks(query).await,
Memory::GENERAL(memory) => memory.retrieve_memory_chunks(query).await,
}
}

async fn retrieve_past_messages(&self) -> SystemResult<Vec<PromptMessage>> {
async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult<Vec<PromptMessage>> {
match self {
Memory::MESSAGES(memory) => memory.retrieve_past_messages().await,
Memory::GENERAL(memory) => memory.retrieve_past_messages().await,
Memory::MESSAGES(memory) => memory.retrieve_past_n_messages(context, last_n).await,
Memory::GENERAL(memory) => memory.retrieve_past_n_messages(context, last_n).await,
}
}
}

impl RecordMemory for Memory {
async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult<bool> {
async fn record_memory_chunk(&self, chunk: &str) -> SystemResult<bool> {
match self {
Memory::MESSAGES(memory) => memory.record_memory_chunk(collection, chunk).await,
Memory::GENERAL(memory) => memory.record_memory_chunk(collection, chunk).await,
Memory::MESSAGES(memory) => memory.record_memory_chunk(chunk).await,
Memory::GENERAL(memory) => memory.record_memory_chunk(chunk).await,
}
}

async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> {
async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> {
match self {
Memory::MESSAGES(memory) => memory.record_prompt_messages(messages).await,
Memory::GENERAL(memory) => memory.record_prompt_messages(messages).await,
Memory::MESSAGES(memory) => memory.record_prompt_messages(context, messages).await,
Memory::GENERAL(memory) => memory.record_prompt_messages(context, messages).await,
}
}
}

pub struct MemoryContext {
pub interaction_id: String,
}

pub trait InitialiseMemory {
async fn initialise_collection(&self, collection: &str) -> SystemResult<()>;
async fn initialise(&self) -> SystemResult<()>;
}

pub trait RecordMemory {
async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult<bool>;
async fn record_prompt_message(&self, message: &PromptMessage) -> SystemResult<bool> {
self.record_prompt_messages(&vec![message.clone()]).await
async fn record_memory_chunk(&self, chunk: &str) -> SystemResult<bool>;
async fn record_prompt_message(&self, context: &MemoryContext, message: &PromptMessage) -> SystemResult<bool> {
self.record_prompt_messages(context, &vec![message.clone()]).await
}
async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec<PromptMessage>) -> SystemResult<bool>;
async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec<PromptMessage>) -> SystemResult<bool>;
}

pub trait RetrieveMemory {
async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult<Vec<String>>;
async fn retrieve_past_messages(&self) -> SystemResult<Vec<PromptMessage>>;
async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult<Vec<MemoryBlock>>;
async fn retrieve_past_messages(&self, context: &MemoryContext) -> SystemResult<Vec<PromptMessage>> {
self.retrieve_past_n_messages(context, -1).await
}
async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult<Vec<PromptMessage>>;
}


Expand Down
Loading

0 comments on commit fac87e5

Please sign in to comment.