-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add arango and local disk as repositories for conversation messages
- Loading branch information
1 parent
a042504
commit fac87e5
Showing
22 changed files
with
1,189 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# ID Registry | ||
|
||
------ | ||
|
||
| ID | Owner | | ||
|----|---------| | ||
| 0 | common | | ||
| 1 | agent | | ||
| 2 | runtime | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Code Registry | ||
|
||
------ | ||
|
||
| ID | Owner | Current | | ||
|------|--------|---------| | ||
| 1xxx | llm | | | ||
| 2xxx | memory | | | ||
| 3xxx | tools | | | ||
| 4xxx | agent | | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
1-2000= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,150 +1,116 @@ | ||
use std::sync::{Arc, Mutex, RwLock}; | ||
use qdrant_client::config::QdrantConfig; | ||
use qdrant_client::Qdrant; | ||
use qdrant_client::qdrant::{CreateCollection, CreateCollectionBuilder, Distance, PointStruct, ScalarQuantizationBuilder, VectorParams, VectorParamsBuilder, VectorsConfig}; | ||
use qdrant_client::qdrant::qdrant_client::QdrantClient; | ||
use serde::{Deserialize, Serialize}; | ||
use agentsmith_common::config::config::Config; | ||
use agentsmith_common::error::error::{SystemError, SystemResult}; | ||
use crate::llm::prompt::PromptMessage; | ||
use crate::memory::memory::{InitialiseMemory, RecordMemory, RetrieveMemory}; | ||
use crate::memory::memory::{InitialiseMemory, MemoryBlock, MemoryContext, RecordMemory, RetrieveMemory}; | ||
use crate::memory::repository::semantic_repository::{SemanticCommand, SemanticMemoryConfiguration, SemanticQuery, SemanticRepository, SemanticRepositoryFactory}; | ||
use crate::memory::repository::working_repository::{WorkingMemoryCommand, WorkingMemoryConfiguration, WorkingMemoryQuery, WorkingMemoryRepository, WorkingMemoryRepositoryFactory}; | ||
use agentsmith_common::config::config::Config; | ||
use agentsmith_common::error::error::SystemResult; | ||
use serde::{Deserialize, Serialize}; | ||
use std::sync::Arc; | ||
|
||
#[derive(Clone)] | ||
pub struct General { | ||
pub message_log: Arc<RwLock<Vec<PromptMessage>>>, | ||
pub memory: Arc<Qdrant>, | ||
pub working: Arc<WorkingMemoryRepository>, | ||
pub semantic: Arc<SemanticRepository>, | ||
} | ||
|
||
impl General { | ||
pub fn new(config: &Config) -> Self { | ||
let client = Qdrant::from_url(config.config.qdrant.host.clone().as_str()).build().unwrap(); | ||
|
||
pub async fn new(config: &Config, | ||
working_memory_configuration: &WorkingMemoryConfiguration, | ||
semantic_repository_configuration: &SemanticMemoryConfiguration) -> Self { | ||
|
||
let working_factory = WorkingMemoryRepositoryFactory::new(config); | ||
let semantic_factory = SemanticRepositoryFactory::new(config); | ||
|
||
let working = working_factory.instance( | ||
working_memory_configuration | ||
) | ||
.await | ||
.unwrap(); | ||
|
||
let semantic = semantic_factory.instance( | ||
semantic_repository_configuration | ||
) | ||
.await | ||
.unwrap(); | ||
|
||
Self { | ||
message_log: Arc::new(RwLock::new(vec![])), | ||
memory: Arc::new(client), | ||
working: Arc::new(working), | ||
semantic: Arc::new(semantic), | ||
} | ||
} | ||
} | ||
|
||
impl InitialiseMemory for General { | ||
|
||
async fn initialise_collection(&self, collection: &str) -> SystemResult<()> { | ||
|
||
let creation_result = self.memory.create_collection( | ||
CreateCollectionBuilder::new(collection.clone()) | ||
.vectors_config(VectorParamsBuilder::new(1024, Distance::Cosine)) | ||
.quantization_config(ScalarQuantizationBuilder::default()), | ||
).await; | ||
|
||
match creation_result { | ||
Ok(_) => println!("Collection {} created", collection.clone()), | ||
Err(e) => println!("Error creating collection {}: {}", collection.clone(), e) | ||
} | ||
|
||
Ok(()) | ||
async fn initialise(&self) -> SystemResult<()> { | ||
self.working.initialise().await?; | ||
self.semantic.initialise().await | ||
} | ||
} | ||
|
||
impl RecordMemory for General { | ||
|
||
async fn record_memory_chunk(&self, collection: &str, chunk: &str) -> SystemResult<bool> { | ||
|
||
// let sample = embeddings.clone(); | ||
// | ||
// let mut payload: Payload = Payload::new(); | ||
// | ||
// for item in index_payload.iter() { | ||
// payload.insert(item.0.clone(), Value::from(item.1.clone())); | ||
// } | ||
// | ||
// let points = vec![PointStruct::new(id, sample, payload)]; | ||
// | ||
// let collection_name = collection.clone(); | ||
// | ||
// let response = self.memory.upsert_points(collection_name, None, points, None) | ||
// .await | ||
// .map_err(|e| { | ||
// println!("Error indexing: {}", e); | ||
// Error::MemoryError | ||
// }) | ||
// ?; | ||
|
||
Ok(true) | ||
async fn record_memory_chunk(&self, chunk: &str) -> SystemResult<bool> { | ||
self.semantic.record_memory_chunk(chunk).await | ||
} | ||
|
||
async fn record_prompt_messages<'a>(&'a self, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> { | ||
|
||
if messages.is_empty() { | ||
Ok(false) | ||
} else { | ||
let mut result = self.message_log.write().unwrap(); | ||
result.extend(messages.iter().cloned()); | ||
Ok(true) | ||
} | ||
async fn record_prompt_messages<'a>(&'a self, context: &MemoryContext, messages: &'a Vec<PromptMessage>) -> SystemResult<bool> { | ||
self.working.record_prompt_messages(context, messages).await | ||
} | ||
} | ||
|
||
impl RetrieveMemory for General { | ||
|
||
async fn retrieve_past_messages(&self) -> SystemResult<Vec<PromptMessage>> { | ||
Ok(self.message_log.read().unwrap().iter().cloned().collect()) | ||
async fn retrieve_past_n_messages(&self, context: &MemoryContext, last_n: i32) -> SystemResult<Vec<PromptMessage>> { | ||
self.working.retrieve_past_n_messages(context, last_n).await | ||
} | ||
|
||
async fn retrieve_memory_chunks(&self, collection: &str, query: &str) -> SystemResult<Vec<String>> { | ||
todo!() | ||
async fn retrieve_memory_chunks(&self, query: &str) -> SystemResult<Vec<MemoryBlock>> { | ||
self.semantic.retrieve_memory_chunks(query).await | ||
} | ||
} | ||
|
||
|
||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::fs; | ||
use serde_json::{json, Value}; | ||
use testcontainers::core::{IntoContainerPort, Mount, WaitFor}; | ||
use testcontainers::{GenericImage, ImageExt}; | ||
use testcontainers::runners::AsyncRunner; | ||
use agentsmith_common::config::config::read_config; | ||
use crate::llm::llm_factory::LLMFactory; | ||
use crate::llm::prompt::UserContent; | ||
use super::*; | ||
|
||
|
||
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)] | ||
#[tokio::test] | ||
async fn test_record_prompt_message() { | ||
|
||
tracing_subscriber::fmt::init(); | ||
|
||
let client = Qdrant::from_url("http://localhost:6333").build().unwrap(); | ||
|
||
let memory = General { | ||
message_log: Arc::new(RwLock::new(vec![])), | ||
memory: Arc::new(client), | ||
}; | ||
|
||
memory.initialise_collection(&String::from("test")).await.unwrap(); | ||
|
||
let current_storage = memory.retrieve_past_messages().await; | ||
|
||
assert_eq!(current_storage.unwrap().len(), 0); | ||
|
||
let result_true = memory.record_prompt_messages(&vec![PromptMessage::User { | ||
role: "user".to_string(), | ||
content: vec![UserContent::Text { | ||
type_: "text".to_string(), | ||
text: "Hello world!".to_string(), | ||
}], | ||
name: None, | ||
}]).await.unwrap(); | ||
|
||
assert!(result_true); | ||
|
||
let result_false = memory.record_prompt_messages(&vec![]).await.unwrap(); | ||
|
||
assert!(!result_false); | ||
|
||
let current_storage = memory.retrieve_past_messages().await.unwrap(); | ||
|
||
assert_eq!(current_storage.len(), 1); | ||
// let client = Qdrant::from_url("http://localhost:6333").build().unwrap(); | ||
// | ||
// let memory = General { | ||
// message_log: Arc::new(RwLock::new(vec![])), | ||
// memory: Arc::new(client), | ||
// }; | ||
// | ||
// memory.initialise(&String::from("test")).await.unwrap(); | ||
// | ||
// let current_storage = memory.retrieve_past_messages().await; | ||
// | ||
// assert_eq!(current_storage.unwrap().len(), 0); | ||
// | ||
// let result_true = memory.record_prompt_messages(&vec![PromptMessage::User { | ||
// role: "user".to_string(), | ||
// content: vec![UserContent::Text { | ||
// type_: "text".to_string(), | ||
// text: "Hello world!".to_string(), | ||
// }], | ||
// name: None, | ||
// }]).await.unwrap(); | ||
// | ||
// assert!(result_true); | ||
// | ||
// let result_false = memory.record_prompt_messages(&vec![]).await.unwrap(); | ||
// | ||
// assert!(!result_false); | ||
// | ||
// let current_storage = memory.retrieve_past_messages().await.unwrap(); | ||
// | ||
// assert_eq!(current_storage.len(), 1); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.