diff --git a/ERROR.md b/ERROR.md new file mode 100644 index 0000000..c7d8179 --- /dev/null +++ b/ERROR.md @@ -0,0 +1,9 @@ +# ID Registry + +------ + +| ID | Owner | +|----|---------| +| 0 | common | +| 1 | agent | +| 2 | runtime | diff --git a/agentsmith-agent/Cargo.toml b/agentsmith-agent/Cargo.toml index b9826f2..1249bca 100644 --- a/agentsmith-agent/Cargo.toml +++ b/agentsmith-agent/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" authors = ["Kevin Bayes"] [dependencies] +arangors = { version = "0.6", default-features = false, features = ["reqwest", "reqwest_async", ] } pyo3 = { version = "0.22.3", features = ["extension-module"] } qdrant-client = "1" agentsmith-common = { path = "../agentsmith-common"} @@ -28,6 +29,9 @@ tracing-subscriber = "0.3" log = "0.4.22" short-uuid = "0.1" +sqlx = { version = "0.8", features = ["runtime-async-std-native-tls", "mysql", "chrono", "uuid"] } +uuid = { version = "1.10.0", features = ["v4"] } + [build-dependencies] tonic-build = "0.12" prost-build = "0.13" diff --git a/agentsmith-agent/ERROR.md b/agentsmith-agent/ERROR.md new file mode 100644 index 0000000..79ba95c --- /dev/null +++ b/agentsmith-agent/ERROR.md @@ -0,0 +1,10 @@ +# Code Registry + +------ + +| ID | Owner | Current | +|------|--------|---------| +| 1xxx | llm | | +| 2xxx | memory | | +| 3xxx | tools | | +| 4xxx | agent | | diff --git a/agentsmith-agent/src/error.properties b/agentsmith-agent/src/error.properties new file mode 100644 index 0000000..2d369fc --- /dev/null +++ b/agentsmith-agent/src/error.properties @@ -0,0 +1 @@ +1-2000= \ No newline at end of file diff --git a/agentsmith-agent/src/llm/prompt.rs b/agentsmith-agent/src/llm/prompt.rs index c9e5498..d9af29d 100644 --- a/agentsmith-agent/src/llm/prompt.rs +++ b/agentsmith-agent/src/llm/prompt.rs @@ -106,6 +106,24 @@ impl PromptMessage { role: "tool".to_string(), } } + + + pub fn role(&self) -> &String { + match self { + PromptMessage::System { role, .. } => { + role + } + PromptMessage::User { role, .. } => { + role + } + PromptMessage::Assistant { role, .. } => { + role + } + PromptMessage::Tool { role, .. } => { + role + } + } + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/agentsmith-agent/src/memory/general.rs b/agentsmith-agent/src/memory/general.rs index f87acac..b98d00f 100644 --- a/agentsmith-agent/src/memory/general.rs +++ b/agentsmith-agent/src/memory/general.rs @@ -1,96 +1,73 @@ -use std::sync::{Arc, Mutex, RwLock}; -use qdrant_client::config::QdrantConfig; -use qdrant_client::Qdrant; -use qdrant_client::qdrant::{CreateCollection, CreateCollectionBuilder, Distance, PointStruct, ScalarQuantizationBuilder, VectorParams, VectorParamsBuilder, VectorsConfig}; -use qdrant_client::qdrant::qdrant_client::QdrantClient; -use serde::{Deserialize, Serialize}; -use agentsmith_common::config::config::Config; -use agentsmith_common::error::error::{SystemError, SystemResult}; use crate::llm::prompt::PromptMessage; -use crate::memory::memory::{InitialiseMemory, RecordMemory, RetrieveMemory}; +use crate::memory::memory::{InitialiseMemory, MemoryBlock, MemoryContext, RecordMemory, RetrieveMemory}; +use crate::memory::repository::semantic_repository::{SemanticCommand, SemanticMemoryConfiguration, SemanticQuery, SemanticRepository, SemanticRepositoryFactory}; +use crate::memory::repository::working_repository::{WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery, WorkingMemoryRepository, WorkingMemoryRepositoryFactory}; +use agentsmith_common::config::config::Config; +use agentsmith_common::error::error::SystemResult; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; #[derive(Clone)] pub struct General { - pub message_log: Arc>>, - pub memory: Arc, + pub working: Arc, + pub semantic: Arc, } impl General { - pub fn new(config: &Config) -> Self { - let client = Qdrant::from_url(config.config.qdrant.host.clone().as_str()).build().unwrap(); + + pub async fn new(config: &Config, + working_memory_configuration: &WorkingMemoryConfiguration, + semantic_repository_configuration: &SemanticMemoryConfiguration) -> Self { + + let working_factory = WorkingMemoryRepositoryFactory::new(config); + let semantic_factory = SemanticRepositoryFactory::new(config); + + let working = working_factory.instance( + working_memory_configuration + ) + .await + .unwrap(); + + let semantic = semantic_factory.instance( + semantic_repository_configuration + ) + .await + .unwrap(); + Self { - message_log: Arc::new(RwLock::new(vec![])), - memory: Arc::new(client), + working: Arc::new(working), + semantic: Arc::new(semantic), } } } impl InitialiseMemory for General { - async fn initialise_collection(&self, collection: &str) -> SystemResult<()> { - - let creation_result = self.memory.create_collection( - CreateCollectionBuilder::new(collection.clone()) - .vectors_config(VectorParamsBuilder::new(1024, Distance::Cosine)) - .quantization_config(ScalarQuantizationBuilder::default()), - ).await; - - match creation_result { - Ok(_) => println!("Collection {} created", collection.clone()), - Err(e) => println!("Error creating collection {}: {}", collection.clone(), e) - } - - Ok(()) + async fn initialise(&self) -> SystemResult<()> { + self.working.initialise().await?; + self.semantic.initialise().await } } impl RecordMemory for General { - async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult { - - // let sample = embeddings.clone(); - // - // let mut payload: Payload = Payload::new(); - // - // for item in index_payload.iter() { - // payload.insert(item.0.clone(), Value::from(item.1.clone())); - // } - // - // let points = vec![PointStruct::new(id, sample, payload)]; - // - // let collection_name = collection.clone(); - // - // let response = self.memory.upsert_points(collection_name, None, points, None) - // .await - // .map_err(|e| { - // println!("Error indexing: {}", e); - // Error::MemoryError - // }) - // ?; - - Ok(true) + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { + self.semantic.record_memory_chunk(chunk).await } - async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec) -> SystemResult { - - if messages.is_empty() { - Ok(false) - } else { - let mut result = self.message_log.write().unwrap(); - result.extend(messages.iter().cloned()); - Ok(true) - } + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult { + self.working.record_prompt_messages(context, messages).await } } impl RetrieveMemory for General { - async fn retrieve_past_messages(&self) -> SystemResult> { - Ok(self.message_log.read().unwrap().iter().cloned().collect()) + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { + self.working.retrieve_past_n_messages(context, last_n).await } - async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult> { - todo!() + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { + self.semantic.retrieve_memory_chunks(query).await } } @@ -98,53 +75,42 @@ impl RetrieveMemory for General { #[cfg(test)] mod tests { - use std::fs; - use serde_json::{json, Value}; - use testcontainers::core::{IntoContainerPort, Mount, WaitFor}; - use testcontainers::{GenericImage, ImageExt}; - use testcontainers::runners::AsyncRunner; - use agentsmith_common::config::config::read_config; - use crate::llm::llm_factory::LLMFactory; - use crate::llm::prompt::UserContent; - use super::*; - - // #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test] async fn test_record_prompt_message() { tracing_subscriber::fmt::init(); - let client = Qdrant::from_url("http://localhost:6333").build().unwrap(); - - let memory = General { - message_log: Arc::new(RwLock::new(vec![])), - memory: Arc::new(client), - }; - - memory.initialise_collection(&String::from("test")).await.unwrap(); - - let current_storage = memory.retrieve_past_messages().await; - - assert_eq!(current_storage.unwrap().len(), 0); - - let result_true = memory.record_prompt_messages(&vec![PromptMessage::User { - role: "user".to_string(), - content: vec![UserContent::Text { - type_: "text".to_string(), - text: "Hello world!".to_string(), - }], - name: None, - }]).await.unwrap(); - - assert!(result_true); - - let result_false = memory.record_prompt_messages(&vec![]).await.unwrap(); - - assert!(!result_false); - - let current_storage = memory.retrieve_past_messages().await.unwrap(); - - assert_eq!(current_storage.len(), 1); + // let client = Qdrant::from_url("http://localhost:6333").build().unwrap(); + // + // let memory = General { + // message_log: Arc::new(RwLock::new(vec![])), + // memory: Arc::new(client), + // }; + // + // memory.initialise(&String::from("test")).await.unwrap(); + // + // let current_storage = memory.retrieve_past_messages().await; + // + // assert_eq!(current_storage.unwrap().len(), 0); + // + // let result_true = memory.record_prompt_messages(&vec![PromptMessage::User { + // role: "user".to_string(), + // content: vec![UserContent::Text { + // type_: "text".to_string(), + // text: "Hello world!".to_string(), + // }], + // name: None, + // }]).await.unwrap(); + // + // assert!(result_true); + // + // let result_false = memory.record_prompt_messages(&vec![]).await.unwrap(); + // + // assert!(!result_false); + // + // let current_storage = memory.retrieve_past_messages().await.unwrap(); + // + // assert_eq!(current_storage.len(), 1); } } \ No newline at end of file diff --git a/agentsmith-agent/src/memory/memory.rs b/agentsmith-agent/src/memory/memory.rs index 521c57c..56d0b17 100644 --- a/agentsmith-agent/src/memory/memory.rs +++ b/agentsmith-agent/src/memory/memory.rs @@ -10,6 +10,12 @@ pub enum Memory { GENERAL(General), } +#[derive(Clone)] +pub struct MemoryBlock { + pub address: String, + pub string: String, +} + #[derive(Clone, Debug)] pub struct MemoryConfiguration { pub r#type: String, @@ -27,6 +33,7 @@ impl MemoryFactory { } pub async fn instance(&self, memory_config: MemoryConfiguration) -> SystemResult { + match memory_config.r#type.as_ref() { "messages" => Ok(Memory::MESSAGES(Messages::new())), _ => panic!(), @@ -35,52 +42,59 @@ impl MemoryFactory { } impl RetrieveMemory for Memory { - async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult> { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { match self { - Memory::MESSAGES(memory) => memory.retrieve_memory_chunks(collection, query).await, - Memory::GENERAL(memory) => memory.retrieve_memory_chunks(collection, query).await, + Memory::MESSAGES(memory) => memory.retrieve_memory_chunks(query).await, + Memory::GENERAL(memory) => memory.retrieve_memory_chunks(query).await, } } - async fn retrieve_past_messages(&self) -> SystemResult> { + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { match self { - Memory::MESSAGES(memory) => memory.retrieve_past_messages().await, - Memory::GENERAL(memory) => memory.retrieve_past_messages().await, + Memory::MESSAGES(memory) => memory.retrieve_past_n_messages(context, last_n).await, + Memory::GENERAL(memory) => memory.retrieve_past_n_messages(context, last_n).await, } } } impl RecordMemory for Memory { - async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { match self { - Memory::MESSAGES(memory) => memory.record_memory_chunk(collection, chunk).await, - Memory::GENERAL(memory) => memory.record_memory_chunk(collection, chunk).await, + Memory::MESSAGES(memory) => memory.record_memory_chunk(chunk).await, + Memory::GENERAL(memory) => memory.record_memory_chunk(chunk).await, } } - async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec) -> SystemResult { + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult { match self { - Memory::MESSAGES(memory) => memory.record_prompt_messages(messages).await, - Memory::GENERAL(memory) => memory.record_prompt_messages(messages).await, + Memory::MESSAGES(memory) => memory.record_prompt_messages(context, messages).await, + Memory::GENERAL(memory) => memory.record_prompt_messages(context, messages).await, } } } +pub struct MemoryContext { + pub interaction_id: String, +} + pub trait InitialiseMemory { - async fn initialise_collection(&self, collection: &str) -> SystemResult<()>; + async fn initialise(&self) -> SystemResult<()>; } pub trait RecordMemory { - async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult; - async fn record_prompt_message(&self, message: &PromptMessage) -> SystemResult { - self.record_prompt_messages(&vec![message.clone()]).await + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult; + async fn record_prompt_message(&self, context: &MemoryContext, message: &PromptMessage) -> SystemResult { + self.record_prompt_messages(context, &vec![message.clone()]).await } - async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec) -> SystemResult; + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult; } pub trait RetrieveMemory { - async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult>; - async fn retrieve_past_messages(&self) -> SystemResult>; + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult>; + async fn retrieve_past_messages(&self, context: &MemoryContext) -> SystemResult> { + self.retrieve_past_n_messages(context, -1).await + } + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult>; } diff --git a/agentsmith-agent/src/memory/messages.rs b/agentsmith-agent/src/memory/messages.rs index e57ac74..67a046f 100644 --- a/agentsmith-agent/src/memory/messages.rs +++ b/agentsmith-agent/src/memory/messages.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use agentsmith_common::config::config::Config; use agentsmith_common::error::error::{SystemError, SystemResult}; use crate::llm::prompt::PromptMessage; -use crate::memory::memory::{InitialiseMemory, RecordMemory, RetrieveMemory}; +use crate::memory::memory::{InitialiseMemory, MemoryBlock, MemoryContext, RecordMemory, RetrieveMemory}; #[derive(Clone)] pub struct Messages { @@ -24,11 +24,11 @@ impl Messages { impl RecordMemory for Messages { - async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { Ok(true) } - async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec) -> SystemResult { + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult { if messages.is_empty() { Ok(false) } else { @@ -41,11 +41,11 @@ impl RecordMemory for Messages { impl RetrieveMemory for Messages { - async fn retrieve_past_messages(&self) -> SystemResult> { + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { Ok(self.message_log.read().unwrap().iter().cloned().collect()) } - async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult> { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { Ok(vec![]) } } @@ -71,15 +71,17 @@ mod tests { tracing_subscriber::fmt::init(); - let client = Qdrant::from_url("http://localhost:6333").build().unwrap(); - let memory = Messages::new(); - let current_storage = memory.retrieve_past_messages().await; + let context = MemoryContext { + interaction_id: "ut".to_string(), + }; + + let current_storage = memory.retrieve_past_messages(&context).await; assert_eq!(current_storage.unwrap().len(), 0); - let result_true = memory.record_prompt_messages(&vec![PromptMessage::User { + let result_true = memory.record_prompt_messages(&context, &vec![PromptMessage::User { role: "user".to_string(), content: vec![UserContent::Text { type_: "text".to_string(), @@ -90,11 +92,11 @@ mod tests { assert!(result_true); - let result_false = memory.record_prompt_messages(&vec![]).await.unwrap(); + let result_false = memory.record_prompt_messages(&context, &vec![]).await.unwrap(); assert!(!result_false); - let current_storage = memory.retrieve_past_messages().await.unwrap(); + let current_storage = memory.retrieve_past_messages(&context).await.unwrap(); assert_eq!(current_storage.len(), 1); } diff --git a/agentsmith-agent/src/memory/mod.rs b/agentsmith-agent/src/memory/mod.rs index 8b19b73..cfd5205 100644 --- a/agentsmith-agent/src/memory/mod.rs +++ b/agentsmith-agent/src/memory/mod.rs @@ -1,3 +1,4 @@ pub mod memory; pub mod messages; -pub mod general; \ No newline at end of file +pub mod general; +pub mod repository; \ No newline at end of file diff --git a/agentsmith-agent/src/memory/repository/mod.rs b/agentsmith-agent/src/memory/repository/mod.rs new file mode 100644 index 0000000..a638d6f --- /dev/null +++ b/agentsmith-agent/src/memory/repository/mod.rs @@ -0,0 +1,6 @@ +pub mod semantic_repository; +pub mod semantic_arango_repository; +pub mod semantic_disk_repository; +pub mod working_repository; +pub mod working_arango_repository; +pub mod working_disk_repository; \ No newline at end of file diff --git a/agentsmith-agent/src/memory/repository/semantic_arango_repository.rs b/agentsmith-agent/src/memory/repository/semantic_arango_repository.rs new file mode 100644 index 0000000..38841af --- /dev/null +++ b/agentsmith-agent/src/memory/repository/semantic_arango_repository.rs @@ -0,0 +1,82 @@ +use arangors::client::reqwest::ReqwestClient; +use std::sync::Arc; + +use crate::memory::memory::{InitialiseMemory, MemoryBlock}; +use crate::memory::repository::semantic_repository::{SemanticCommand, SemanticMemoryConfiguration, SemanticQuery}; +use crate::memory::repository::working_repository::WorkingMemoryConfiguration; +use agentsmith_common::error::error::{SystemError, SystemResult}; +use arangors::{Collection, Connection, Database}; +use qdrant_client::Qdrant; + +#[derive(Clone)] +pub struct SemanticArangoRepository { + pub agent_id: String, + pub repository: Arc, + pub index: Arc, +} + + +impl SemanticArangoRepository { + + const DATABASE_NAME: &'static str = "agentsmith"; + const COLLECTION_NAME: &'static str = "semantic-memory"; + const INDEX_NAME: &'static str = "semantic-memory"; + + pub async fn new(config: WorkingMemoryConfiguration) -> SystemResult { + + let arango_config = config.arango.unwrap(); + + let qdrant_config = arango_config.index.clone(); + let index = Qdrant::from_url(qdrant_config.host.as_str()) + .build() + .unwrap(); + + let repository_config = arango_config.connection.clone(); + let arango_url = format!("{}://{}:{}", repository_config.protocol, repository_config.host, repository_config.port); + let repository = Connection::establish_jwt(arango_url.as_str(), + repository_config.user.as_str(), + repository_config.pass.as_str()) + .await + .unwrap(); + + Ok(Self { + agent_id: config.id.clone(), + index: Arc::new(index), + repository: Arc::new(repository), + }) + } + + async fn database(&self) -> SystemResult> { + Ok(self.repository.db(Self::DATABASE_NAME) + .await + .map_err(|e| SystemError::MemoryError { id: 0, code: 1}) + ?) + } + + async fn collection(&self) -> SystemResult> { + let db = self.database().await?; + + match db.collection(Self::COLLECTION_NAME).await { + Ok(collection) => Ok(collection), + Err(_) => Ok(db.create_collection(Self::COLLECTION_NAME).await.map_err(|e| SystemError::MemoryError { id: 0, code: 2})?), + } + } +} + +impl InitialiseMemory for SemanticArangoRepository { + async fn initialise(&self) -> SystemResult<()> { + todo!() + } +} + +impl SemanticCommand for SemanticArangoRepository { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { + todo!() + } +} + +impl SemanticQuery for SemanticArangoRepository { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { + todo!() + } +} \ No newline at end of file diff --git a/agentsmith-agent/src/memory/repository/semantic_disk_repository.rs b/agentsmith-agent/src/memory/repository/semantic_disk_repository.rs new file mode 100644 index 0000000..62f8dad --- /dev/null +++ b/agentsmith-agent/src/memory/repository/semantic_disk_repository.rs @@ -0,0 +1,27 @@ +use crate::memory::memory::{InitialiseMemory, MemoryBlock}; +use agentsmith_common::error::error::SystemResult; +use crate::memory::repository::semantic_repository::{SemanticCommand, SemanticQuery, SemanticRepository}; + +#[derive(Clone)] +pub struct SemanticDiskRepository { + pub agent_id: String, +} + + +impl InitialiseMemory for SemanticDiskRepository { + async fn initialise(&self) -> SystemResult<()> { + todo!() + } +} + +impl SemanticCommand for SemanticDiskRepository { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { + todo!() + } +} + +impl SemanticQuery for SemanticDiskRepository { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { + todo!() + } +} \ No newline at end of file diff --git a/agentsmith-agent/src/memory/repository/semantic_repository.rs b/agentsmith-agent/src/memory/repository/semantic_repository.rs new file mode 100644 index 0000000..7e09157 --- /dev/null +++ b/agentsmith-agent/src/memory/repository/semantic_repository.rs @@ -0,0 +1,91 @@ +use crate::memory::memory::{InitialiseMemory, MemoryBlock}; +use crate::memory::repository::semantic_arango_repository::SemanticArangoRepository; +use crate::memory::repository::semantic_disk_repository::SemanticDiskRepository; +use agentsmith_common::config::arango::ArangoConfig; +use agentsmith_common::config::config::{Config, QdrantConfig}; +use agentsmith_common::error::error::SystemResult; +use serde::{Deserialize, Serialize}; + +#[derive(Clone)] +pub enum SemanticRepository { + Disk(SemanticDiskRepository), + Arango(SemanticArangoRepository), +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct SemanticMemoryConfiguration { + pub id: String, + r#type: String, + pub disk: Option, + pub arango: Option, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct SemanticDiskRepositoryConfiguration { + pub path: String, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct SemanticArangoRepositoryConfiguration { + pub connection: ArangoConfig, + pub index: QdrantConfig, +} + + +pub struct SemanticRepositoryFactory { + pub config: Config, +} + +impl SemanticRepositoryFactory { + + pub fn new(config: &Config) -> Self { + Self { + config: config.clone() + } + } + + pub async fn instance(&self, semantic_memory_configuration: &SemanticMemoryConfiguration) -> SystemResult { + let memory_type: &str = semantic_memory_configuration.r#type.as_str(); + match memory_type { + // TODO: Create a disk or arango repository. + _ => panic!(), + } + } +} + +pub struct Query { + q: String, +} + +impl InitialiseMemory for SemanticRepository { + async fn initialise(&self) -> SystemResult<()> { + match self { + SemanticRepository::Disk(disk) => { + disk.initialise().await + } + SemanticRepository::Arango(arango) => { + arango.initialise().await + } + } + } +} + +pub trait SemanticCommand { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult; +} + +impl SemanticCommand for SemanticRepository { + async fn record_memory_chunk(&self, chunk: &str) -> SystemResult { + todo!() + } +} + +pub trait SemanticQuery { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult>; +} + +impl SemanticQuery for SemanticRepository { + async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult> { + todo!() + } +} \ No newline at end of file diff --git a/agentsmith-agent/src/memory/repository/working_arango_repository.rs b/agentsmith-agent/src/memory/repository/working_arango_repository.rs new file mode 100644 index 0000000..aed8528 --- /dev/null +++ b/agentsmith-agent/src/memory/repository/working_arango_repository.rs @@ -0,0 +1,380 @@ +use crate::llm::prompt::PromptMessage; +use crate::memory::memory::{InitialiseMemory, MemoryContext}; +use crate::memory::repository::working_repository::{ + WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery, +}; +use agentsmith_common::error::error::{SystemError, SystemResult}; +use arangors::client::reqwest::ReqwestClient; +use arangors::document::options::InsertOptions; +use arangors::{AqlQuery, Collection, Connection, Database}; +use chrono::{DateTime, Utc}; +use log::debug; +use qdrant_client::Qdrant; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use arangors::transaction::{Status, TransactionCollections, TransactionSettings}; +use uuid::Uuid; + +#[derive(Debug, Serialize, Deserialize)] +struct WorkingMemoryItem { + #[serde(rename = "_key")] + id: String, + agent_id: String, + message: PromptMessage, + created_on: DateTime, + created_by: String, + modified_on: DateTime, + modified_by: String, +} + +#[derive(Clone)] +pub struct WorkingMemoryArangoRepository { + pub agent_id: String, + pub repository: Arc, + pub index: Arc, +} + +impl WorkingMemoryArangoRepository { + const DATABASE_NAME: &'static str = "agentsmith"; + const COLLECTION_NAME: &'static str = "working-memory"; + const INDEX_NAME: &'static str = "working-memory"; + + pub async fn new(config: &WorkingMemoryConfiguration) -> SystemResult { + let arango_config = config.arango.clone().unwrap(); + + let qdrant_config = arango_config.index.clone(); + let index = Qdrant::from_url(qdrant_config.host.as_str()) + .build() + .unwrap(); + + let repository_config = arango_config.connection.clone(); + let arango_url = format!( + "{}://{}:{}", + repository_config.protocol, repository_config.host, repository_config.port + ); + let repository = Connection::establish_jwt( + arango_url.as_str(), + repository_config.user.as_str(), + repository_config.pass.as_str(), + ) + .await + .unwrap(); + + Ok(Self { + agent_id: config.id.clone(), + index: Arc::new(index), + repository: Arc::new(repository), + }) + } + + pub async fn database(&self) -> SystemResult> { + Ok(self + .repository + .db(Self::DATABASE_NAME) + .await + .map_err(|e| SystemError::MemoryError { id: 0, code: 1 })?) + } + + pub async fn collection(&self) -> SystemResult> { + let db = self.database().await?; + + match db.collection(Self::COLLECTION_NAME).await { + Ok(collection) => Ok(collection), + Err(_) => Ok(db + .create_collection(Self::COLLECTION_NAME) + .await + .map_err(|e| SystemError::MemoryError { id: 0, code: 2 })?), + } + } +} + +impl InitialiseMemory for WorkingMemoryArangoRepository { + async fn initialise(&self) -> SystemResult<()> { + let collection = self.collection().await?; + println!("Collection {:?}", collection); + Ok(()) + } +} + +impl WorkingMemoryCommand for WorkingMemoryArangoRepository { + async fn record_prompt_messages<'a>( + &'a self, + context: &MemoryContext, + messages: &'a Vec, + ) -> SystemResult { + // Ensure the collection exists + let db = self.database().await?; + + let now = Utc::now(); + + let mut transaction_actions = Vec::new(); + + // Convert messages to working memory items and prepare transaction actions + for msg in messages { + let memory_item = WorkingMemoryItem { + id: Uuid::new_v4().to_string(), + agent_id: self.agent_id.clone(), + message: msg.clone(), + created_on: now, + created_by: self.agent_id.clone(), + modified_on: now, + modified_by: self.agent_id.clone(), + }; + + // Prepare parameters for the action + let doc = serde_json::json!({ + "id": memory_item.id, + "agent_id": memory_item.agent_id, + "message": memory_item.message, + "created_on": memory_item.created_on, + "created_by": memory_item.created_by, + "modified_on": memory_item.modified_on, + "modified_by": memory_item.modified_by + }); + + transaction_actions.push(doc); + } + + // Configure transaction options + let transaction_settings = TransactionSettings::builder() + .lock_timeout(60000) + .wait_for_sync(true) + .collections( + TransactionCollections::builder() + .write(vec![Self::COLLECTION_NAME.clone().to_owned()]) + .build(), + ) + .build(); + + // Begin transaction + let mut transaction = db.begin_transaction(transaction_settings) + .await + .map_err(|e| { + println!("Failed to insert documents - {}", e); + SystemError::MemoryError { id: 3, code: 1000 } + })?; + + + let collection = transaction.collection(Self::COLLECTION_NAME) + .await + .map_err(|e| { + println!("Failed to insert documents - {}", e); + SystemError::MemoryError { id: 3, code: 1001 } + })?; + + // Execute each document insert within the transaction + for doc in transaction_actions { + collection.create_document(doc, InsertOptions::builder().build()) + .await + .map_err(|e| { + println!("Failed to insert documents - {}", e); + SystemError::MemoryError { id: 3, code: 1002 } + })?; + } + + // Commit transaction + let commit_result = transaction.commit() + .await + .map_err(|e| { + println!("Failed to insert documents - {}", e); + SystemError::MemoryError { id: 3, code: 1003 } + })?; + + // Check transaction status + match commit_result { + Status::Committed => Ok(true), + _ => { + println!("Failed to insert documents."); + Err(SystemError::MemoryError { id: 3, code: 1004 }) + } + } + } +} + +impl WorkingMemoryQuery for WorkingMemoryArangoRepository { + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { + let db = self.database().await?; + + let limit = if last_n < 1 { 2000 } else { last_n }; + println!("Limit is {}", limit); + + // AQL query to get the last n messages ordered by creation time + let aql = AqlQuery::builder() + .query("FOR u IN @@collection FILTER u.agent_id==@agent_id LIMIT @limit RETURN u") + .bind_var("@collection", Self::COLLECTION_NAME) + .bind_var("agent_id", self.agent_id.clone()) + .bind_var("limit", limit) + .build(); + + let messages: Vec = db + .aql_query(aql) + .await + .map_err(|e| SystemError::MemoryError { id: 0, code: 0 })?; + + Ok(messages.iter().map(|i| i.message.clone()).collect()) + } +} + +#[cfg(test)] +mod tests { + use std::{fs, thread}; + use std::path::Path; + use std::time::Duration; + use arangors::Connection; + use log::Level; + use testcontainers::*; + use testcontainers::{ + core::{IntoContainerPort, WaitFor}, + runners::AsyncRunner, + GenericImage, + }; + use testcontainers::core::logs::consumer::LogConsumer; + use testcontainers::core::logs::consumer::logging_consumer::LoggingConsumer; + use testcontainers::core::logs::LogFrame; + use futures::{future::BoxFuture, FutureExt}; + use agentsmith_common::config::arango::ArangoConfig; + use agentsmith_common::config::config::QdrantConfig; + use crate::llm::prompt::{PromptMessage, UserContent}; + use crate::memory::memory::{InitialiseMemory, MemoryContext}; + use crate::memory::repository::working_arango_repository::WorkingMemoryArangoRepository; + use crate::memory::repository::working_repository::{WorkingMemoryArangoRepositoryConfiguration, WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery}; + + struct TestFixture { + id: String, + container_arango: ContainerAsync, + container_index: ContainerAsync, + } + + #[derive(Clone)] + struct LogPrinter{} + + impl LogConsumer for LogPrinter { + fn accept<'a>(&'a self, record: &'a LogFrame) -> BoxFuture<'a, ()> { + + async move { + match record { + LogFrame::StdOut(bytes) => { + println!("container: {:?}", String::from_utf8_lossy(bytes)); + } + LogFrame::StdErr(bytes) => { + println!("container: {:?}", String::from_utf8_lossy(bytes)); + } + } + }.boxed() + } + } + + impl TestFixture { + async fn new() -> Self { + + let log_consumer = LogPrinter {}; + + let container_arango = GenericImage::new("arangodb", "3.12") + .with_wait_for(WaitFor::message_on_stdout("Have fun!")) + .with_mapped_port(18529, 8529.tcp()) + .with_log_consumer(log_consumer.clone()) + .with_env_var("ARANGO_ROOT_PASSWORD", "password") + .start() + .await + .expect("Arango started"); + + let container_index = GenericImage::new("qdrant/qdrant", "latest") + .with_wait_for(WaitFor::message_on_stdout("Qdrant gRPC listening on 6334")) + .with_mapped_port(16333, 6333.tcp()) + .with_mapped_port(16334, 6334.tcp()) + .with_log_consumer(log_consumer.clone()) + .start() + .await + .expect("Qdrant started"); + + let repository = Connection::establish_jwt( + "http://localhost:18529", + "root", + "password", + ) + .await + .unwrap(); + + repository.create_database(WorkingMemoryArangoRepository::DATABASE_NAME).await.unwrap(); + + let id = "unittest-1".to_string(); + + Self { + id, + container_arango, + container_index, + } + } + } + + impl Drop for TestFixture { + fn drop(&mut self) { + } + } + + #[tokio::test] + async fn test_working_memory_initialisation() { + let fixture = TestFixture::new().await; + // thread::sleep(Duration::from_secs(30)); + println!("Container arango started: {:?}", fixture.container_arango); + println!("Container index started: {:?}", fixture.container_index); + + let working_memory_config = WorkingMemoryConfiguration { + id: "agent-ut".to_string(), + r#type: "arango".to_string(), + disk: None, + arango: Some(WorkingMemoryArangoRepositoryConfiguration { + index: QdrantConfig { + host: "http://192.168.1.151:16334".to_string(), + }, + connection: ArangoConfig { + protocol: "http".to_string(), + host: "localhost".to_string(), + port: "18529".to_string(), + user: "root".to_string(), + pass: "password".to_string() + } + }) + }; + + let repository = WorkingMemoryArangoRepository::new(&working_memory_config).await.unwrap(); + + let _ = repository.initialise().await; + + let context = MemoryContext { + interaction_id: "ut".to_string(), + }; + + let prompt = PromptMessage::System { + name: Some("system".to_string()), + content: String::from("Test system"), + role: "system".to_string(), + }; + + let result = repository.record_prompt_message(&context, &prompt).await.unwrap(); + assert!(result); + + let saved_messages = &repository.retrieve_past_messages(&context).await.unwrap(); + println!("Saved messages {:?}", saved_messages); + + + let prompt = PromptMessage::User { + name: Some("kevin".to_string()), + content: vec![UserContent::Text { type_: "text".to_string(), text: "Hello world!".to_string() }], + role: "user".to_string(), + }; + + let result = repository.record_prompt_message(&context, &prompt).await.unwrap(); + assert!(result); + + let saved_messages = &repository.retrieve_past_messages(&context).await.unwrap(); + + let saved_system_message = saved_messages.get(0).unwrap(); + let saved_user_message = saved_messages.get(1).unwrap(); + + assert_eq!(2, saved_messages.len()); + + assert_eq!("system".to_string(), saved_system_message.role().clone()); + assert_eq!("user".to_string(), saved_user_message.role().clone()); + } +} diff --git a/agentsmith-agent/src/memory/repository/working_disk_repository.rs b/agentsmith-agent/src/memory/repository/working_disk_repository.rs new file mode 100644 index 0000000..3d331f5 --- /dev/null +++ b/agentsmith-agent/src/memory/repository/working_disk_repository.rs @@ -0,0 +1,296 @@ +use std::fs; +use std::fs::{File, OpenOptions}; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::Path; +use std::sync::{Arc, RwLock}; +use arangors::AqlQuery; +use chrono::{DateTime, Utc}; +use log::debug; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; +use agentsmith_common::error::error::{SystemError, SystemResult}; +use crate::llm::prompt::PromptMessage; +use crate::memory::general::General; +use crate::memory::memory::{InitialiseMemory, MemoryContext}; +use crate::memory::messages::Messages; +use crate::memory::repository::working_arango_repository::WorkingMemoryArangoRepository; +use crate::memory::repository::working_repository::{WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery}; + +#[derive(Debug, Serialize, Deserialize)] +struct WorkingMemoryItem { + #[serde(rename = "_key")] + id: String, + agent_id: String, + message: PromptMessage, + created_on: DateTime, + created_by: String, + modified_on: DateTime, + modified_by: String, +} + +#[derive(Clone)] +pub struct WorkingMemoryDiskRepository { + pub agent_id: String, + pub config: WorkingMemoryConfiguration, +} + +impl WorkingMemoryDiskRepository { + + pub async fn new(config: &WorkingMemoryConfiguration) -> SystemResult { + + Ok(Self { + agent_id: config.id.clone(), + config: config.clone(), + }) + } + + pub fn file_path(&self,context: &MemoryContext) -> SystemResult { + + let disk_config = self.config.disk.as_ref().ok_or(SystemError::MemoryError { + code: 2009, + id: 1 + })?; + + Ok(format!("{}/{}/working-{}.db", disk_config.path, self.agent_id, context.interaction_id)) + } + + pub fn open_file(&self, context: &MemoryContext) -> SystemResult { + + let file_path_str = self.file_path(context)?; + + let file_path = Path::new(file_path_str.as_str()); + + if !file_path.exists() { + File::create(file_path) + .map_err(|e| { + println!("Error creating file {}.", e); + SystemError::MemoryError { code: 2008, id: 1 } + })?; + } + + let file = File::open(file_path).map_err(|e| { + debug!("Error opening file for reading: {}", e); + SystemError::MemoryError { code: 2007, id: 1 } + })?; + + Ok(file) + } +} + + +impl WorkingMemoryCommand for WorkingMemoryDiskRepository { + + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult { + + let file_path = self.file_path(context)?; + + // Read existing content + let file = self.open_file(context)?; + + let mut reader = BufReader::new(file); + let mut content = String::new(); + reader.read_to_string(&mut content).map_err(|e| { + println!("Error reading file: {}", e); + SystemError::MemoryError { code: 2000, id: 1 } + })?; + + // Parse existing content or create new array + let mut memory_items: Vec = if content.is_empty() { + Vec::new() + } else { + serde_json::from_str(&content).map_err(|e| { + debug!("Error parsing JSON: {}", e); + SystemError::MemoryError { code: 2001, id: 1 } + })? + }; + + // Add new messages + let now = Utc::now(); + for message in messages { + let item = WorkingMemoryItem { + id: Uuid::new_v4().to_string(), + agent_id: self.agent_id.clone(), + message: message.clone(), + created_on: now, + created_by: "system".to_string(), + modified_on: now, + modified_by: "system".to_string(), + }; + memory_items.push(item); + } + + // Write back to file + let file = OpenOptions::new() + .write(true) + .truncate(true) + .open(&file_path) + .map_err(|e| { + debug!("Error opening file for writing: {}", e); + SystemError::MemoryError { code: 2002, id: 1 } + })?; + + let writer = BufWriter::new(file); + serde_json::to_writer_pretty(writer, &memory_items).map_err(|e| { + debug!("Error writing JSON: {}", e); + SystemError::MemoryError { code: 2003, id: 1 } + })?; + + Ok(true) + } +} + +impl WorkingMemoryQuery for WorkingMemoryDiskRepository { + + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { + + // Read file content + let file = self.open_file(context)?; + + let mut reader = BufReader::new(file); + let mut content = String::new(); + reader.read_to_string(&mut content).map_err(|e| { + debug!("Error reading file: {}", e); + SystemError::MemoryError { code: 2004, id: 1 } + })?; + + // Parse content + if content.is_empty() { + return Ok(Vec::new()); + } + + let memory_items: Vec = serde_json::from_str(&content).map_err(|e| { + debug!("Error parsing JSON: {}", e); + SystemError::MemoryError { code: 2005, id: 1 } + })?; + + // Get last n messages + let mut messages: Vec = memory_items + .iter() + .rev() + .take(last_n as usize) + .map(|item| item.message.clone()) + .collect(); + + messages.reverse(); + + Ok(messages) + } +} + +impl InitialiseMemory for WorkingMemoryDiskRepository { + async fn initialise(&self) -> SystemResult<()> { + let id = &self.agent_id; + let config = &self.config; + let disk_config = &config.disk.clone().unwrap(); + let path_str = format!("{}/{}", disk_config.path, id); + + let path = Path::new(path_str.as_str()); + if path.exists() { + debug!("Directory exists!"); + } else { + fs::create_dir_all(path_str) + .map_err(|e| { + println!("Error path for file {}.", e); + SystemError::MemoryError { code: 2006, id: 1 } + })?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::path::Path; + use crate::llm::prompt::{PromptMessage, UserContent}; + use crate::memory::memory::{InitialiseMemory, MemoryContext}; + use crate::memory::repository::working_disk_repository::WorkingMemoryDiskRepository; + use crate::memory::repository::working_repository::{WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryDiskRepositoryConfiguration, WorkingMemoryQuery}; + + + struct TestFixture { + id: String, + path: String, + path_str: String, + } + + impl TestFixture { + + fn new() -> Self { + let id = "unittest-1".to_string(); + let path = "./tmp".to_string(); + let path_str = format!("{}/{}", path, id); + Self { + id, path, path_str, + } + } + } + + impl Drop for TestFixture { + + fn drop(&mut self) { + + let path = Path::new(self.path_str.as_str()); + let _ = fs::remove_dir_all(path); + } + } + + #[tokio::test] + async fn test_working_memory_save_and_read_message() { + + let fixture = TestFixture::new(); + + let id = fixture.id.as_str(); + let path = fixture.path.as_str(); + + let config = WorkingMemoryConfiguration { + id: id.to_string(), + r#type: "disk".to_string(), + disk: Some(WorkingMemoryDiskRepositoryConfiguration { + path: path.to_string(), + }), + arango: None, + }; + let repository = WorkingMemoryDiskRepository::new(&config).await.unwrap(); + + let _ = repository.initialise().await; + + let prompt = PromptMessage::System { + name: Some("system".to_string()), + content: String::from("Test system"), + role: "system".to_string(), + }; + + let context = MemoryContext { + interaction_id: "ut".to_string(), + }; + + let result = repository.record_prompt_message(&context, &prompt).await.unwrap(); + assert!(result); + + let saved_messages = &repository.retrieve_past_messages(&context).await.unwrap(); + println!("Saved messages {:?}", saved_messages); + + + let prompt = PromptMessage::User { + name: Some("kevin".to_string()), + content: vec![UserContent::Text { type_: "text".to_string(), text: "Hello world!".to_string() }], + role: "user".to_string(), + }; + + let result = repository.record_prompt_message(&context, &prompt).await.unwrap(); + assert!(result); + + let saved_messages = &repository.retrieve_past_messages(&context).await.unwrap(); + + let saved_system_message = saved_messages.get(0).unwrap(); + let saved_user_message = saved_messages.get(1).unwrap(); + + assert_eq!(2, saved_messages.len()); + + assert_eq!("system".to_string(), saved_system_message.role().clone()); + assert_eq!("user".to_string(), saved_user_message.role().clone()); + } +} + diff --git a/agentsmith-agent/src/memory/repository/working_repository.rs b/agentsmith-agent/src/memory/repository/working_repository.rs new file mode 100644 index 0000000..a9d9069 --- /dev/null +++ b/agentsmith-agent/src/memory/repository/working_repository.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; +use agentsmith_common::config::arango::ArangoConfig; +use agentsmith_common::config::config::{Config, QdrantConfig}; +use agentsmith_common::error::error::SystemResult; +use crate::llm::prompt::PromptMessage; +use crate::memory::memory::{InitialiseMemory, MemoryContext}; +use crate::memory::repository::working_arango_repository::WorkingMemoryArangoRepository; +use crate::memory::repository::working_disk_repository::WorkingMemoryDiskRepository; + +#[derive(Clone)] +pub enum WorkingMemoryRepository { + Disk(WorkingMemoryDiskRepository), + Arango(WorkingMemoryArangoRepository), +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct WorkingMemoryConfiguration { + pub id: String, + pub(crate) r#type: String, + pub disk: Option, + pub arango: Option, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct WorkingMemoryDiskRepositoryConfiguration { + pub path: String, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct WorkingMemoryArangoRepositoryConfiguration { + pub connection: ArangoConfig, + pub index: QdrantConfig, +} + +pub struct WorkingMemoryRepositoryFactory { + pub config: Config, +} + +impl WorkingMemoryRepositoryFactory { + + pub fn new(config: &Config) -> Self { + Self { + config: config.clone(), + } + } + + pub async fn instance(&self, working_memory_config: &WorkingMemoryConfiguration) -> SystemResult { + match working_memory_config.r#type.as_ref() { + "disk" => { + let disk = WorkingMemoryDiskRepository::new(working_memory_config) + .await?; + Ok(WorkingMemoryRepository::Disk(disk)) + }, + "arango" => { + let arango = WorkingMemoryArangoRepository::new(working_memory_config) + .await?; + Ok(WorkingMemoryRepository::Arango(arango)) + }, + _ => panic!(), + } + } +} + +impl InitialiseMemory for WorkingMemoryRepository { + async fn initialise(&self) -> SystemResult<()> { + match self { + WorkingMemoryRepository::Disk(disk) => { + disk.initialise().await + } + WorkingMemoryRepository::Arango(arango) => { + arango.initialise().await + } + } + } +} + +pub trait WorkingMemoryCommand { + async fn record_prompt_message(&self, context: &MemoryContext, message: &PromptMessage) -> SystemResult { + self.record_prompt_messages(context, &vec![message.clone()]).await + } + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult; +} + +impl WorkingMemoryCommand for WorkingMemoryRepository { + async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec) -> SystemResult { + todo!() + } +} + +pub trait WorkingMemoryQuery { + async fn retrieve_past_messages(&self, context: &MemoryContext) -> SystemResult> { + self.retrieve_past_n_messages(context, -1).await + } + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult>; +} + +impl WorkingMemoryQuery for WorkingMemoryRepository { + async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult> { + todo!() + } +} \ No newline at end of file diff --git a/agentsmith-common/src/config/arango.rs b/agentsmith-common/src/config/arango.rs new file mode 100644 index 0000000..29499d8 --- /dev/null +++ b/agentsmith-common/src/config/arango.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ArangoConfig { + // Define your configuration structure + pub protocol: String, + pub host: String, + pub port: String, + pub user: String, + pub pass: String, +} \ No newline at end of file diff --git a/agentsmith-common/src/config/config.rs b/agentsmith-common/src/config/config.rs index 09afc5d..4416d02 100644 --- a/agentsmith-common/src/config/config.rs +++ b/agentsmith-common/src/config/config.rs @@ -1,7 +1,8 @@ use std::{fs, io}; use std::collections::HashMap; use std::io::Read; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use crate::config::arango::ArangoConfig; pub fn read_config(file_path: &str) -> Result { @@ -27,6 +28,7 @@ pub struct ServerConfig { pub redis: RedisConfig, pub qdrant: QdrantConfig, pub database: DatabaseConfig, + pub arango: Option, pub host: HostConfig, pub security: SecurityConfig, pub gateways: GatewaysConfig, @@ -49,7 +51,7 @@ pub struct GatewayConfig { pub model: String, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct SecurityConfig { // Define your configuration structure #[serde(rename = "oauth")] @@ -59,7 +61,7 @@ pub struct SecurityConfig { pub jwt: SecurityJwtConfig, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct OAuthConfig { // Define your configuration structure #[serde(rename = "jwks_domain")] @@ -72,7 +74,7 @@ pub struct OAuthConfig { pub audience: String, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct SecurityJwtConfig { // Define your configuration structure #[serde(rename = "secret")] @@ -81,14 +83,14 @@ pub struct SecurityJwtConfig { pub issuer: String, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct HostConfig { // Define your configuration structure pub host: String, pub port: i32, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct RedisConfig { // Define your configuration structure pub host: String, @@ -96,13 +98,13 @@ pub struct RedisConfig { } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct QdrantConfig { // Define your configuration structure pub host: String, } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct DatabaseConfig { // Define your configuration structure pub connection: String, diff --git a/agentsmith-common/src/config/mod.rs b/agentsmith-common/src/config/mod.rs index a105933..2b75a18 100644 --- a/agentsmith-common/src/config/mod.rs +++ b/agentsmith-common/src/config/mod.rs @@ -1 +1,2 @@ +pub mod arango; pub mod config; \ No newline at end of file diff --git a/agentsmith-runtime/Cargo.toml b/agentsmith-runtime/Cargo.toml index 30ca50a..cc83bdd 100644 --- a/agentsmith-runtime/Cargo.toml +++ b/agentsmith-runtime/Cargo.toml @@ -13,6 +13,7 @@ tokio = { version = "1.40", features = ["full"] } tokio-tungstenite = "0.24" serde = "1.0" serde_json = "1.0" +uuid = { version = "1.10.0", features = ["v4"] } [build-dependencies] tonic-build = "0.12" diff --git a/agentsmith-runtime/src/simple/swarm.rs b/agentsmith-runtime/src/simple/swarm.rs index 717dcb3..7881cc6 100644 --- a/agentsmith-runtime/src/simple/swarm.rs +++ b/agentsmith-runtime/src/simple/swarm.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use agentsmith_agent::agent::agent::Agent; -use agentsmith_agent::memory::memory::{Memory, RecordMemory, RetrieveMemory}; +use agentsmith_agent::memory::memory::{Memory, MemoryContext, RecordMemory, RetrieveMemory}; use agentsmith_agent::memory::messages::Messages; use std::sync::Arc; use serde_json::json; @@ -9,8 +9,10 @@ use agentsmith_agent::llm::prompt::{Prompt, PromptMessage, Tool}; use agentsmith_agent::tools::registry::{SafeToolRegistry,}; use agentsmith_agent::tools::tool::{SimpleToolExecution, ToolResult}; use agentsmith_common::error::error::SystemResult; +use uuid::{uuid, Uuid}; pub struct Swarm { + pub memory_context: Arc, pub memory: Arc, pub agents: Arc>, pub tool_registry: Arc, @@ -38,9 +40,12 @@ impl Swarm { panic!("No agents configured, must have at least one!"); } + let interaction_id = Uuid::new_v4().to_string(); + let initial_agent = agents.first().unwrap().clone().id().to_string(); Self { + memory_context: Arc::new(MemoryContext { interaction_id }), memory: Arc::new(Memory::MESSAGES(Messages::new())), agents: agents.clone(), tool_registry: Arc::new(tool_registry.clone()), @@ -61,7 +66,7 @@ impl Swarm { pub async fn run(&mut self, initial: PromptMessage) -> SystemResult { - let _ = self.memory.record_prompt_messages(&vec![initial]).await; + let _ = self.memory.record_prompt_messages(&self.memory_context, &vec![initial]).await; while self.turn < self.max_turns { @@ -69,7 +74,7 @@ impl Swarm { .find(|item| item.id() == self.active_agent.clone()) .unwrap(); - let messages = self.memory.retrieve_past_messages() + let messages = self.memory.retrieve_past_messages(&self.memory_context) .await? .clone(); @@ -80,7 +85,7 @@ impl Swarm { let make_tool_calls = !result.tool_calls.is_empty(); let assistant_message = PromptMessage::from_assistant_message(&result); - let _ = self.memory.record_prompt_message(&assistant_message).await; + let _ = self.memory.record_prompt_message(&self.memory_context, &assistant_message).await; if make_tool_calls { @@ -174,7 +179,7 @@ impl Swarm { let prompt_message = PromptMessage::from_tool_result( &tool_result, tool.as_ref() ); - self.memory.record_prompt_messages(&vec![prompt_message]).await.unwrap(); + self.memory.record_prompt_messages(&self.memory_context, &vec![prompt_message]).await.unwrap(); } } } @@ -317,6 +322,6 @@ mod tests { let result = swarm.run(message).await.unwrap(); - println!("all messages: {:?}", result.memory.retrieve_past_messages().await); + println!("all messages: {:?}", result.memory.retrieve_past_messages(&swarm.memory_context).await); } } \ No newline at end of file diff --git a/local/docker-compose.yml b/local/docker-compose.yml index e9e49c6..a21c7fc 100644 --- a/local/docker-compose.yml +++ b/local/docker-compose.yml @@ -32,15 +32,15 @@ services: - "8529:8529" environment: - ARANGO_ROOT_PASSWORD=password - volumes: - - type: bind - read_only: false - source: ./data/arango/data/ - target: /var/lib/arangodb3 - - source: ./data/arango/backup/ - target: /opt/backup/ - read_only: false - type: bind +# volumes: +# - type: bind +# read_only: false +# source: ./data/arango/data/ +# target: /var/lib/arangodb3 +# - source: ./data/arango/backup/ +# target: /opt/backup/ +# read_only: false +# type: bind database: