From 6fb3b4abec51a4ccb29f1ee34be983245c2637aa Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 19:57:47 +0100 Subject: [PATCH 01/16] Add versioning to the data point model --- .../infrastructure/engine/models/DataPoint.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index e08041146..d293e25d1 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,45 +1,64 @@ + + from datetime import datetime, timezone -from typing import Optional +from typing import Optional, Any, Dict from uuid import UUID, uuid4 from pydantic import BaseModel, Field from typing_extensions import TypedDict +# Define metadata type class MetaData(TypedDict): index_fields: list[str] + +# Updated DataPoint model with versioning and new fields class DataPoint(BaseModel): __tablename__ = "data_point" - id: UUID = Field(default_factory = uuid4) - updated_at: Optional[datetime] = datetime.now(timezone.utc) + id: UUID = Field(default_factory=uuid4) + created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) + updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) + version: str = "0.1" # Default version + source: Optional[str] = None # Path to file, URL, etc. + type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 - _metadata: Optional[MetaData] = { - "index_fields": [], - "type": "DataPoint" - } + extra: Optional[Dict[str, Any]] = None # For additional properties + _metadata: Optional[MetaData] = Field( + default={"index_fields": [], "type": "DataPoint"} + ) - # class Config: - # underscore_attrs_are_private = True + # Override the Pydantic configuration + class Config: + underscore_attrs_are_private = True @classmethod - def get_embeddable_data(self, data_point): - if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \ - and hasattr(data_point, data_point._metadata["index_fields"][0]): + def get_embeddable_data(cls, data_point): + """Retrieve embeddable data based on metadata's index_fields.""" + if ( + data_point._metadata + and len(data_point._metadata["index_fields"]) > 0 + and hasattr(data_point, data_point._metadata["index_fields"][0]) + ): attribute = getattr(data_point, data_point._metadata["index_fields"][0]) if isinstance(attribute, str): return attribute.strip() - else: - return attribute + return attribute @classmethod - def get_embeddable_properties(self, data_point): + def get_embeddable_properties(cls, data_point): + """Retrieve all embeddable properties.""" if data_point._metadata and len(data_point._metadata["index_fields"]) > 0: return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]] - return [] @classmethod - def get_embeddable_property_names(self, data_point): - return data_point._metadata["index_fields"] or [] \ No newline at end of file + def get_embeddable_property_names(cls, data_point): + """Retrieve names of embeddable properties.""" + return data_point._metadata["index_fields"] or [] + + def update_version(self, new_version: str): + """Update the version and updated_at timestamp.""" + self.version = new_version + self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) From 87bc5d8266c1bcee0e7421dd5e46ab0f336587f0 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:06:29 
+0100 Subject: [PATCH 02/16] Add versioning to the data point model --- .../infrastructure/engine/models/DataPoint.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index d293e25d1..4cd7664e1 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -24,22 +24,20 @@ class DataPoint(BaseModel): type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 extra: Optional[Dict[str, Any]] = None # For additional properties - _metadata: Optional[MetaData] = Field( - default={"index_fields": [], "type": "DataPoint"} - ) + _metadata: Optional[MetaData] = { + "index_fields": [], + "type": "DataPoint" + } # Override the Pydantic configuration class Config: underscore_attrs_are_private = True @classmethod - def get_embeddable_data(cls, data_point): - """Retrieve embeddable data based on metadata's index_fields.""" - if ( - data_point._metadata - and len(data_point._metadata["index_fields"]) > 0 - and hasattr(data_point, data_point._metadata["index_fields"][0]) - ): + @classmethod + def get_embeddable_data(self, data_point): + if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \ + and hasattr(data_point, data_point._metadata["index_fields"][0]): attribute = getattr(data_point, data_point._metadata["index_fields"][0]) if isinstance(attribute, str): @@ -47,14 +45,14 @@ def get_embeddable_data(cls, data_point): return attribute @classmethod - def get_embeddable_properties(cls, data_point): + def get_embeddable_properties(self, data_point): """Retrieve all embeddable properties.""" if data_point._metadata and len(data_point._metadata["index_fields"]) > 0: return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]] return [] @classmethod - def get_embeddable_property_names(cls, data_point): + def get_embeddable_property_names(self, data_point): """Retrieve names of embeddable properties.""" return data_point._metadata["index_fields"] or [] From 15d8effa3bc40d477d98c527c996cdbe8e4e108c Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:13:03 +0100 Subject: [PATCH 03/16] Add versioning to the data point model --- .../infrastructure/engine/models/DataPoint.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 4cd7664e1..e0a86bac4 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from typing_extensions import TypedDict - +import pickle # Define metadata type class MetaData(TypedDict): @@ -23,7 +23,7 @@ class DataPoint(BaseModel): source: Optional[str] = None # Path to file, URL, etc. 
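# --- Illustrative usage sketch (editor's addition, not part of the patch series) ---
# A minimal sketch of how the versioning fields from PATCH 01 are meant to be
# used, assuming the import path matches the file touched above. Note that
# `created_at` and `updated_at` are integer Unix timestamps in milliseconds,
# not datetime objects.
from cognee.infrastructure.engine.models.DataPoint import DataPoint

point = DataPoint()
assert isinstance(point.created_at, int)    # e.g. 1734462467000 (ms since epoch)

point.update_version("0.2")                 # version is still a string at this stage
assert point.version == "0.2"
assert point.updated_at >= point.created_at
# -----------------------------------------------------------------------------------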
type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 - extra: Optional[Dict[str, Any]] = None # For additional properties + extra: Optional[Dict[str, Dict]] = None # For additional properties _metadata: Optional[MetaData] = { "index_fields": [], "type": "DataPoint" @@ -60,3 +60,24 @@ def update_version(self, new_version: str): """Update the version and updated_at timestamp.""" self.version = new_version self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) + + # JSON Serialization + def to_json(self) -> str: + """Serialize the instance to a JSON string.""" + return self.json() + + @classmethod + def from_json(self, json_str: str): + """Deserialize the instance from a JSON string.""" + return self.model_validate_json(json_str) + + # Pickle Serialization + def to_pickle(self) -> bytes: + """Serialize the instance to pickle-compatible bytes.""" + return pickle.dumps(self.dict()) + + @classmethod + def from_pickle(self, pickled_data: bytes): + """Deserialize the instance from pickled bytes.""" + data = pickle.loads(pickled_data) + return self(**data) From 52b91b4b328aec8f55d47f1f3a55aa532da298c1 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:15:32 +0100 Subject: [PATCH 04/16] Add versioning to the data point model --- cognee/infrastructure/engine/models/DataPoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index e0a86bac4..3d019d586 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -23,7 +23,7 @@ class DataPoint(BaseModel): source: Optional[str] = None # Path to file, URL, etc. type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 - extra: Optional[Dict[str, Dict]] = None # For additional properties + extra: Optional[str] = None # For additional properties _metadata: Optional[MetaData] = { "index_fields": [], "type": "DataPoint" From f455ba9843ffed15480e5592a157c25a787269e5 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:15:50 +0100 Subject: [PATCH 05/16] Add versioning to the data point model --- cognee/infrastructure/engine/models/DataPoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 3d019d586..cf49cb4e1 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -23,7 +23,7 @@ class DataPoint(BaseModel): source: Optional[str] = None # Path to file, URL, etc. 
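# --- Illustrative round-trip sketch (editor's addition, not part of the patch series) ---
# Exercises the JSON and pickle helpers introduced in PATCH 03. The patch mixes
# pydantic v1-style calls (`.json()`, `.dict()`) with the v2-only
# `model_validate_json`, so this sketch assumes pydantic v2, where the v1-style
# calls still work but emit deprecation warnings.
from cognee.infrastructure.engine.models.DataPoint import DataPoint

original = DataPoint(topological_rank=3)

restored = DataPoint.from_json(original.to_json())          # JSON round-trip
assert restored.id == original.id

restored_2 = DataPoint.from_pickle(original.to_pickle())    # pickle round-trip
assert restored_2.topological_rank == 3
# ----------------------------------------------------------------------------------------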
type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 - extra: Optional[str] = None # For additional properties + extra: Optional[str] = "extra" # For additional properties _metadata: Optional[MetaData] = { "index_fields": [], "type": "DataPoint" From f71485ea2b918edab0c60ca25e71a3af331c864c Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:20:10 +0100 Subject: [PATCH 06/16] Add versioning to the data point model --- .github/workflows/profiling.yaml | 1 + cognee/infrastructure/engine/models/DataPoint.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/profiling.yaml b/.github/workflows/profiling.yaml index 93ed82f82..6633e376e 100644 --- a/.github/workflows/profiling.yaml +++ b/.github/workflows/profiling.yaml @@ -57,6 +57,7 @@ jobs: run: | poetry install --no-interaction --all-extras poetry run pip install pyinstrument + poetry run pip install parso # Set environment variables for SHAs diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index cf49cb4e1..9cc818228 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -81,3 +81,12 @@ def from_pickle(self, pickled_data: bytes): """Deserialize the instance from pickled bytes.""" data = pickle.loads(pickled_data) return self(**data) + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Serialize model to a dictionary.""" + return self.model_dump(**kwargs) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DataPoint": + """Deserialize model from a dictionary.""" + return cls.model_validate(data) From fe31bcdd177b0e1be70c1ad49e9f40ff0e3f1d16 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:27:18 +0100 Subject: [PATCH 07/16] Add versioning to the data point model --- .github/workflows/profiling.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/profiling.yaml b/.github/workflows/profiling.yaml index 6633e376e..5f81f8f7e 100644 --- a/.github/workflows/profiling.yaml +++ b/.github/workflows/profiling.yaml @@ -58,6 +58,7 @@ jobs: poetry install --no-interaction --all-extras poetry run pip install pyinstrument poetry run pip install parso + poetry run pip install jedi # Set environment variables for SHAs From 7657b8e6b91df78bb95bcf4d9870bb8c586f98a2 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:42:41 +0100 Subject: [PATCH 08/16] Add versioning to the data point model --- cognee/base_config.py | 3 +++ cognee/infrastructure/llm/openai/adapter.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/cognee/base_config.py b/cognee/base_config.py index 0e70b7652..34b0ac744 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -10,6 +10,9 @@ class BaseConfig(BaseSettings): monitoring_tool: object = MonitoringTool.LANGFUSE graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME") graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD") + langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") + langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY") + langfuse_host: Optional[str] = os.environ["LANGFUSE_HOST"] model_config = SettingsConfigDict(env_file = ".env", extra = "allow") diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index b2929c6c0..a6bccdf7e 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ 
b/cognee/infrastructure/llm/openai/adapter.py @@ -6,10 +6,11 @@ import litellm import instructor from pydantic import BaseModel - +from cognee.shared.data_models import MonitoringTool from cognee.exceptions import InvalidValueError from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.base_config import get_base_config class OpenAIAdapter(LLMInterface): name = "OpenAI" @@ -35,6 +36,15 @@ def __init__( self.endpoint = endpoint self.api_version = api_version self.streaming = streaming + base_config = get_base_config() + if base_config.monitoring_tool == MonitoringTool.LANGFUSE: + # set callbacks + # litellm.success_callback = ["langfuse"] + # litellm.failure_callback = ["langfuse"] + self.aclient.success_callback = ["langfuse"] + self.aclient.failure_callback = ["langfuse"] + self.client.success_callback = ["langfuse"] + self.client.failure_callback = ["langfuse"] async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: """Generate a response from a user query.""" From b976f5b7a62852b07cd0b3f914275c702d556a39 Mon Sep 17 00:00:00 2001 From: vasilije Date: Tue, 17 Dec 2024 20:59:45 +0100 Subject: [PATCH 09/16] First draft of relationship embeddings --- cognee/base_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/base_config.py b/cognee/base_config.py index 34b0ac744..4ef9bf8f6 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -12,7 +12,7 @@ class BaseConfig(BaseSettings): graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD") langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY") - langfuse_host: Optional[str] = os.environ["LANGFUSE_HOST"] + langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") model_config = SettingsConfigDict(env_file = ".env", extra = "allow") From 2bfc657e90c3185a082b3c92d76f83c85849fbe4 Mon Sep 17 00:00:00 2001 From: vasilije Date: Sun, 5 Jan 2025 20:30:34 +0100 Subject: [PATCH 10/16] Implement PR review --- cognee/infrastructure/engine/models/DataPoint.py | 9 +++------ cognee/infrastructure/llm/openai/adapter.py | 3 --- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 9cc818228..0f2b0d34f 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -19,11 +19,9 @@ class DataPoint(BaseModel): id: UUID = Field(default_factory=uuid4) created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) - version: str = "0.1" # Default version - source: Optional[str] = None # Path to file, URL, etc. 
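# --- Editor's note on the DataPoint hunk below ---
# After this patch `version` is still typed as a string ("1") while the new
# zero-argument update_version() executes `self.version += 1`, which raises
# TypeError ("can only concatenate str (not 'int') to str") on the first bump.
# PATCH 11 below resolves this by declaring `version: int = 1`, after which:
#
#     point = DataPoint()
#     point.update_version()    # 1 -> 2, and updated_at is refreshed
#     assert point.version == 2
# --------------------------------------------------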
+ version: str = "1" # Default version type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 - extra: Optional[str] = "extra" # For additional properties _metadata: Optional[MetaData] = { "index_fields": [], "type": "DataPoint" @@ -33,7 +31,6 @@ class DataPoint(BaseModel): class Config: underscore_attrs_are_private = True - @classmethod @classmethod def get_embeddable_data(self, data_point): if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \ @@ -56,9 +53,9 @@ def get_embeddable_property_names(self, data_point): """Retrieve names of embeddable properties.""" return data_point._metadata["index_fields"] or [] - def update_version(self, new_version: str): + def update_version(self): """Update the version and updated_at timestamp.""" - self.version = new_version + self.version += 1 self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) # JSON Serialization diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index a6bccdf7e..835096007 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -38,9 +38,6 @@ def __init__( self.streaming = streaming base_config = get_base_config() if base_config.monitoring_tool == MonitoringTool.LANGFUSE: - # set callbacks - # litellm.success_callback = ["langfuse"] - # litellm.failure_callback = ["langfuse"] self.aclient.success_callback = ["langfuse"] self.aclient.failure_callback = ["langfuse"] self.client.success_callback = ["langfuse"] From d5243b47e07fc66498a7dc772c84152187a36ea3 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:49:35 +0100 Subject: [PATCH 11/16] Update cognee/infrastructure/engine/models/DataPoint.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- cognee/infrastructure/engine/models/DataPoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 0f2b0d34f..6a677a1ea 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -19,7 +19,7 @@ class DataPoint(BaseModel): id: UUID = Field(default_factory=uuid4) created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) - version: str = "1" # Default version + version: int = 1 # Default version type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = { From cc29cd003fb4d630ac19499f2f598a31da2cf3ce Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:19:43 +0100 Subject: [PATCH 12/16] Update base_config.py --- cognee/base_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/base_config.py b/cognee/base_config.py index 085ede2cd..6b1b8811e 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -16,7 +16,6 @@ class BaseConfig(BaseSettings): langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") model_config = SettingsConfigDict(env_file=".env", extra="allow") - def to_dict(self) -> dict: return { "data_root_directory": self.data_root_directory, From 0a9c9438edd486ecac6652eeb63bcd99cdca2239 Mon Sep 17 00:00:00 2001 From: Vasilije 
<8619304+Vasilije1990@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:23:48 +0100 Subject: [PATCH 13/16] Update DataPoint.py --- cognee/infrastructure/engine/models/DataPoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 60ba85151..4f419eee9 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -21,7 +21,6 @@ class DataPoint(BaseModel): created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) version: int = 1 # Default version - type: Optional[str] = "text" # "text", "file", "image", "video" topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"} From b6e82dfb4fc99af1df0fb73421d8d45244485e01 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:24:42 +0100 Subject: [PATCH 14/16] Update adapter.py --- cognee/infrastructure/llm/openai/adapter.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index 6ed5f3c48..7f12947fd 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -35,22 +35,7 @@ def __init__( transcription_model: str, streaming: bool = False, ): - self.aclient = instructor.from_litellm(litellm.acompletion) - self.client = instructor.from_litellm(litellm.completion) - self.transcription_model = transcription_model - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.streaming = streaming - - base_config = get_base_config() - - if base_config.monitoring_tool == MonitoringTool.LANGFUSE: - self.aclient.success_callback = ["langfuse"] - self.aclient.failure_callback = ["langfuse"] - self.client.success_callback = ["langfuse"] - self.client.failure_callback = ["langfuse"] + From 6c1c8abc263f5d15da3d7c8eb1929f6d58bc1db5 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:26:37 +0100 Subject: [PATCH 15/16] Update adapter.py --- cognee/infrastructure/llm/openai/adapter.py | 22 ++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index 7f12947fd..d45662380 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -12,7 +12,6 @@ from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.base_config import get_base_config - monitoring = get_base_config().monitoring_tool if monitoring == MonitoringTool.LANGFUSE: from langfuse.decorators import observe @@ -35,14 +34,19 @@ def __init__( transcription_model: str, streaming: bool = False, ): - - - - - @observe(as_type='generation') - async def acreate_structured_output(self, text_input: str, system_prompt: str, - response_model: Type[BaseModel]) -> BaseModel: - + self.aclient = instructor.from_litellm(litellm.acompletion) + self.client = instructor.from_litellm(litellm.completion) + self.transcription_model = transcription_model + self.model = model + self.api_key = api_key + self.endpoint = endpoint + self.api_version = api_version + 
self.streaming = streaming + + @observe(as_type="generation") + async def acreate_structured_output( + self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + ) -> BaseModel: """Generate a response from a user query.""" return await self.aclient.chat.completions.create( From 0a02886d769178f0335f88479c4a0ddb5b941e49 Mon Sep 17 00:00:00 2001 From: vasilije Date: Thu, 16 Jan 2025 13:28:35 +0100 Subject: [PATCH 16/16] Update format --- .../databases/graph/neo4j_driver/adapter.py | 12 ++++--- .../hybrid/falkordb/FalkorDBAdapter.py | 24 ++++++++----- .../sqlalchemy/SqlAlchemyAdapter.py | 6 ++-- .../infrastructure/engine/models/DataPoint.py | 17 +++++---- .../index_graphiti_objects.py | 6 ++-- .../documents/AudioDocument_test.py | 18 +++++----- .../documents/ImageDocument_test.py | 18 +++++----- .../integration/documents/PdfDocument_test.py | 18 +++++----- .../documents/TextDocument_test.py | 18 +++++----- .../documents/UnstructuredDocument_test.py | 30 ++++++++-------- cognee/tests/test_deduplication.py | 12 +++---- cognee/tests/test_falkordb.py | 6 ++-- cognee/tests/test_library.py | 6 ++-- cognee/tests/test_pgvector.py | 36 +++++++++---------- .../chunks/chunk_by_paragraph_2_test.py | 18 +++++----- .../chunks/chunk_by_paragraph_test.py | 6 ++-- .../chunks/chunk_by_sentence_test.py | 12 +++---- .../processing/chunks/chunk_by_word_test.py | 6 ++-- 18 files changed, 142 insertions(+), 127 deletions(-) diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index d632fbd81..41bfb891d 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -62,10 +62,12 @@ async def has_node(self, node_id: str) -> bool: async def add_node(self, node: DataPoint): serialized_properties = self.serialize_properties(node.model_dump()) - query = dedent("""MERGE (node {id: $node_id}) + query = dedent( + """MERGE (node {id: $node_id}) ON CREATE SET node += $properties, node.updated_at = timestamp() ON MATCH SET node += $properties, node.updated_at = timestamp() - RETURN ID(node) AS internal_id, node.id AS nodeId""") + RETURN ID(node) AS internal_id, node.id AS nodeId""" + ) params = { "node_id": str(node.id), @@ -182,13 +184,15 @@ async def add_edge( ): serialized_properties = self.serialize_properties(edge_properties) - query = dedent("""MATCH (from_node {id: $from_node}), + query = dedent( + """MATCH (from_node {id: $from_node}), (to_node {id: $to_node}) MERGE (from_node)-[r]->(to_node) ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name ON MATCH SET r += $properties, r.updated_at = timestamp() RETURN r - """) + """ + ) params = { "from_node": str(from_node), diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index 77bfd74e6..dd8934c34 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -88,23 +88,27 @@ async def create_data_point_query(self, data_point: DataPoint, vectorized_values } ) - return dedent(f""" + return dedent( + f""" MERGE (node:{node_label} {{id: '{str(data_point.id)}'}}) ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp() ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp() - """).strip() + """ + ).strip() async def 
create_edge_query(self, edge: tuple[str, str, str, dict]) -> str: properties = await self.stringify_properties(edge[3]) properties = f"{{{properties}}}" - return dedent(f""" + return dedent( + f""" MERGE (source {{id:'{edge[0]}'}}) MERGE (target {{id: '{edge[1]}'}}) MERGE (source)-[edge:{edge[2]} {properties}]->(target) ON MATCH SET edge.updated_at = timestamp() ON CREATE SET edge.updated_at = timestamp() - """).strip() + """ + ).strip() async def create_collection(self, collection_name: str): pass @@ -195,12 +199,14 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict]]): self.query(query) async def has_edges(self, edges): - query = dedent(""" + query = dedent( + """ UNWIND $edges AS edge MATCH (a)-[r]->(b) WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists - """).strip() + """ + ).strip() params = { "edges": [ @@ -279,14 +285,16 @@ async def search( [label, attribute_name] = collection_name.split(".") - query = dedent(f""" + query = dedent( + f""" CALL db.idx.vector.queryNodes( '{label}', '{attribute_name}', {limit}, vecf32({query_vector}) ) YIELD node, score - """).strip() + """ + ).strip() result = self.query(query) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 6c3c5029d..68561979d 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -93,10 +93,12 @@ async def get_schema_list(self) -> List[str]: if self.engine.dialect.name == "postgresql": async with self.engine.begin() as connection: result = await connection.execute( - text(""" + text( + """ SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema'); - """) + """ + ) ) return [schema[0] for schema in result.fetchall()] return [] diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 4f419eee9..8317f9401 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,5 +1,3 @@ - - from datetime import datetime, timezone from typing import Optional, Any, Dict from uuid import UUID, uuid4 @@ -8,18 +6,22 @@ from typing_extensions import TypedDict import pickle + # Define metadata type class MetaData(TypedDict): index_fields: list[str] - # Updated DataPoint model with versioning and new fields class DataPoint(BaseModel): __tablename__ = "data_point" id: UUID = Field(default_factory=uuid4) - created_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) - updated_at: int = Field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)) + created_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) + updated_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) version: int = 1 # Default version topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"} @@ -45,7 +47,6 @@ def get_embeddable_data(self, data_point): def get_embeddable_properties(self, data_point): """Retrieve all embeddable properties.""" if data_point._metadata and 
len(data_point._metadata["index_fields"]) > 0: - return [ getattr(data_point, field, None) for field in data_point._metadata["index_fields"] ] @@ -54,7 +55,6 @@ def get_embeddable_properties(self, data_point): @classmethod def get_embeddable_property_names(self, data_point): - """Retrieve names of embeddable properties.""" return data_point._metadata["index_fields"] or [] @@ -63,7 +63,7 @@ def update_version(self): self.version += 1 self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000) - # JSON Serialization + # JSON Serialization def to_json(self) -> str: """Serialize the instance to a JSON string.""" return self.json() @@ -92,4 +92,3 @@ def to_dict(self, **kwargs) -> Dict[str, Any]: def from_dict(cls, data: Dict[str, Any]) -> "DataPoint": """Deserialize model from a dictionary.""" return cls.model_validate(data) - diff --git a/cognee/tasks/temporal_awareness/index_graphiti_objects.py b/cognee/tasks/temporal_awareness/index_graphiti_objects.py index cb616ed82..1fbc1f41a 100644 --- a/cognee/tasks/temporal_awareness/index_graphiti_objects.py +++ b/cognee/tasks/temporal_awareness/index_graphiti_objects.py @@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges(): raise RuntimeError("Initialization error") from e await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""") - await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id, + await graph_engine.query( + """MATCH (source)-[r]->(target) SET r.source_node_id = source.id, r.target_node_id = target.id, - r.relationship_name = type(r) RETURN r""") + r.relationship_name = type(r) RETURN r""" + ) await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""") nodes_data, edges_data = await graph_engine.get_model_independent_graph_data() diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py index e07a2431b..dbd43ddda 100644 --- a/cognee/tests/integration/documents/AudioDocument_test.py +++ b/cognee/tests/integration/documents/AudioDocument_test.py @@ -36,12 +36,12 @@ def test_AudioDocument(): for ground_truth, paragraph_data in zip( GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") ): - assert ground_truth["word_count"] == paragraph_data.word_count, ( - f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' - ) - assert ground_truth["len_text"] == len(paragraph_data.text), ( - f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' - ) - assert ground_truth["cut_type"] == paragraph_data.cut_type, ( - f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' - ) + assert ( + ground_truth["word_count"] == paragraph_data.word_count + ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' + assert ground_truth["len_text"] == len( + paragraph_data.text + ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' + assert ( + ground_truth["cut_type"] == paragraph_data.cut_type + ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py index b8d585419..c0877ae99 100644 --- a/cognee/tests/integration/documents/ImageDocument_test.py +++ b/cognee/tests/integration/documents/ImageDocument_test.py @@ -25,12 +25,12 @@ def test_ImageDocument(): for ground_truth, paragraph_data in zip( GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") ): - assert 
ground_truth["word_count"] == paragraph_data.word_count, ( - f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' - ) - assert ground_truth["len_text"] == len(paragraph_data.text), ( - f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' - ) - assert ground_truth["cut_type"] == paragraph_data.cut_type, ( - f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' - ) + assert ( + ground_truth["word_count"] == paragraph_data.word_count + ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' + assert ground_truth["len_text"] == len( + paragraph_data.text + ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' + assert ( + ground_truth["cut_type"] == paragraph_data.cut_type + ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py index fc4307846..8f28815d3 100644 --- a/cognee/tests/integration/documents/PdfDocument_test.py +++ b/cognee/tests/integration/documents/PdfDocument_test.py @@ -27,12 +27,12 @@ def test_PdfDocument(): for ground_truth, paragraph_data in zip( GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker") ): - assert ground_truth["word_count"] == paragraph_data.word_count, ( - f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' - ) - assert ground_truth["len_text"] == len(paragraph_data.text), ( - f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' - ) - assert ground_truth["cut_type"] == paragraph_data.cut_type, ( - f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' - ) + assert ( + ground_truth["word_count"] == paragraph_data.word_count + ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' + assert ground_truth["len_text"] == len( + paragraph_data.text + ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' + assert ( + ground_truth["cut_type"] == paragraph_data.cut_type + ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index 6daec62b7..1e143d563 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size): for ground_truth, paragraph_data in zip( GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker") ): - assert ground_truth["word_count"] == paragraph_data.word_count, ( - f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' - ) - assert ground_truth["len_text"] == len(paragraph_data.text), ( - f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' - ) - assert ground_truth["cut_type"] == paragraph_data.cut_type, ( - f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' - ) + assert ( + ground_truth["word_count"] == paragraph_data.word_count + ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' + assert ground_truth["len_text"] == len( + paragraph_data.text + ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }' + assert ( + ground_truth["cut_type"] == paragraph_data.cut_type + ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' diff --git a/cognee/tests/integration/documents/UnstructuredDocument_test.py b/cognee/tests/integration/documents/UnstructuredDocument_test.py index 
773dc2293..e0278de81 100644 --- a/cognee/tests/integration/documents/UnstructuredDocument_test.py +++ b/cognee/tests/integration/documents/UnstructuredDocument_test.py @@ -71,32 +71,32 @@ def test_UnstructuredDocument(): for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"): assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" - assert "sentence_cut" == paragraph_data.cut_type, ( - f" sentence_cut != {paragraph_data.cut_type = }" - ) + assert ( + "sentence_cut" == paragraph_data.cut_type + ), f" sentence_cut != {paragraph_data.cut_type = }" # Test DOCX for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"): assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" - assert "sentence_end" == paragraph_data.cut_type, ( - f" sentence_end != {paragraph_data.cut_type = }" - ) + assert ( + "sentence_end" == paragraph_data.cut_type + ), f" sentence_end != {paragraph_data.cut_type = }" # TEST CSV for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"): assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" - assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( - f"Read text doesn't match expected text: {paragraph_data.text}" - ) - assert "sentence_cut" == paragraph_data.cut_type, ( - f" sentence_cut != {paragraph_data.cut_type = }" - ) + assert ( + "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text + ), f"Read text doesn't match expected text: {paragraph_data.text}" + assert ( + "sentence_cut" == paragraph_data.cut_type + ), f" sentence_cut != {paragraph_data.cut_type = }" # Test XLSX for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"): assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" - assert "sentence_cut" == paragraph_data.cut_type, ( - f" sentence_cut != {paragraph_data.cut_type = }" - ) + assert ( + "sentence_cut" == paragraph_data.cut_type + ), f" sentence_cut != {paragraph_data.cut_type = }" diff --git a/cognee/tests/test_deduplication.py b/cognee/tests/test_deduplication.py index 89c866f12..9c2df032d 100644 --- a/cognee/tests/test_deduplication.py +++ b/cognee/tests/test_deduplication.py @@ -30,9 +30,9 @@ async def test_deduplication(): result = await relational_engine.get_all_data_from_table("data") assert len(result) == 1, "More than one data entity was found." - assert result[0]["name"] == "Natural_language_processing_copy", ( - "Result name does not match expected value." - ) + assert ( + result[0]["name"] == "Natural_language_processing_copy" + ), "Result name does not match expected value." result = await relational_engine.get_all_data_from_table("datasets") assert len(result) == 2, "Unexpected number of datasets found." @@ -61,9 +61,9 @@ async def test_deduplication(): result = await relational_engine.get_all_data_from_table("data") assert len(result) == 1, "More than one data entity was found." - assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], ( - "Content hash is not a part of file name." - ) + assert ( + hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"] + ), "Content hash is not a part of file name." 
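# --- Editor's note on PATCH 16 ---
# This final patch is a pure formatting pass (commit subject "Update format"):
# every hunk rewraps long assert statements and dedent() calls without changing
# behavior. The two assert spellings it toggles between are equivalent:
#
#     assert len(result) == 1, (
#         "More than one data entity was found."
#     )
#     assert (
#         len(result) == 1
#     ), "More than one data entity was found."
#
# Both evaluate the same condition and raise AssertionError with the same message.
# ----------------------------------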
await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) diff --git a/cognee/tests/test_falkordb.py b/cognee/tests/test_falkordb.py index af0e87916..07ece9eb2 100755 --- a/cognee/tests/test_falkordb.py +++ b/cognee/tests/test_falkordb.py @@ -85,9 +85,9 @@ async def main(): from cognee.infrastructure.databases.relational import get_relational_engine - assert not os.path.exists(get_relational_engine().db_path), ( - "SQLite relational database is not empty" - ) + assert not os.path.exists( + get_relational_engine().db_path + ), "SQLite relational database is not empty" from cognee.infrastructure.databases.graph import get_graph_config diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 192b67506..8352b4161 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -82,9 +82,9 @@ async def main(): from cognee.infrastructure.databases.relational import get_relational_engine - assert not os.path.exists(get_relational_engine().db_path), ( - "SQLite relational database is not empty" - ) + assert not os.path.exists( + get_relational_engine().db_path + ), "SQLite relational database is not empty" from cognee.infrastructure.databases.graph import get_graph_config diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 73b6be974..c241177f0 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location): data_hash = hashlib.md5(encoded_text).hexdigest() # Get data entry from database based on hash contents data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one() - assert os.path.isfile(data.raw_data_location), ( - f"Data location doesn't exist: {data.raw_data_location}" - ) + assert os.path.isfile( + data.raw_data_location + ), f"Data location doesn't exist: {data.raw_data_location}" # Test deletion of data along with local files created by cognee await engine.delete_data_entity(data.id) - assert not os.path.exists(data.raw_data_location), ( - f"Data location still exists after deletion: {data.raw_data_location}" - ) + assert not os.path.exists( + data.raw_data_location + ), f"Data location still exists after deletion: {data.raw_data_location}" async with engine.get_async_session() as session: # Get data entry from database based on file path data = ( await session.scalars(select(Data).where(Data.raw_data_location == file_location)) ).one() - assert os.path.isfile(data.raw_data_location), ( - f"Data location doesn't exist: {data.raw_data_location}" - ) + assert os.path.isfile( + data.raw_data_location + ), f"Data location doesn't exist: {data.raw_data_location}" # Test local files not created by cognee won't get deleted await engine.delete_data_entity(data.id) - assert os.path.exists(data.raw_data_location), ( - f"Data location doesn't exists: {data.raw_data_location}" - ) + assert os.path.exists( + data.raw_data_location + ), f"Data location doesn't exists: {data.raw_data_location}" async def test_getting_of_documents(dataset_name_1): @@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1): user = await get_default_user() document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) - assert len(document_ids) == 1, ( - f"Number of expected documents doesn't match {len(document_ids)} != 1" - ) + assert ( + len(document_ids) == 1 + ), f"Number of expected documents doesn't match {len(document_ids)} != 1" # Test getting of documents for search when no 
dataset is provided user = await get_default_user() document_ids = await get_document_ids_for_user(user.id) - assert len(document_ids) == 2, ( - f"Number of expected documents doesn't match {len(document_ids)} != 2" - ) + assert ( + len(document_ids) == 2 + ), f"Number of expected documents doesn't match {len(document_ids)} != 2" async def main(): diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py index d8680a604..53098fc67 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py @@ -17,9 +17,9 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs): chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) reconstructed_text = "".join([chunk["text"] for chunk in chunks]) - assert reconstructed_text == input_text, ( - f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" - ) + assert ( + reconstructed_text == input_text + ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" @pytest.mark.parametrize( @@ -36,9 +36,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks]) larger_chunks = chunk_lengths[chunk_lengths > paragraph_length] - assert np.all(chunk_lengths <= paragraph_length), ( - f"{paragraph_length = }: {larger_chunks} are too large" - ) + assert np.all( + chunk_lengths <= paragraph_length + ), f"{paragraph_length = }: {larger_chunks} are too large" @pytest.mark.parametrize( @@ -50,6 +50,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_ data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs ) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) - assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( - f"{chunk_indices = } are not monotonically increasing" - ) + assert np.all( + chunk_indices == np.arange(len(chunk_indices)) + ), f"{chunk_indices = } are not monotonically increasing" diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py index e420b2e9f..e7d9a54ba 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py @@ -58,9 +58,9 @@ def run_chunking_test(test_text, expected_chunks): for expected_chunks_item, chunk in zip(expected_chunks, chunks): for key in ["text", "word_count", "cut_type"]: - assert chunk[key] == expected_chunks_item[key], ( - f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }" - ) + assert ( + chunk[key] == expected_chunks_item[key] + ), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }" def test_chunking_whole_text(): diff --git a/cognee/tests/unit/processing/chunks/chunk_by_sentence_test.py b/cognee/tests/unit/processing/chunks/chunk_by_sentence_test.py index efa053077..d1c75d7ed 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_sentence_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_sentence_test.py @@ -16,9 +16,9 @@ def test_chunk_by_sentence_isomorphism(input_text, maximum_length): chunks = chunk_by_sentence(input_text, maximum_length) reconstructed_text = "".join([chunk[1] for chunk in chunks]) - assert reconstructed_text == input_text, ( - 
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" - ) + assert ( + reconstructed_text == input_text + ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" @pytest.mark.parametrize( @@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length): chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks]) larger_chunks = chunk_lengths[chunk_lengths > maximum_length] - assert np.all(chunk_lengths <= maximum_length), ( - f"{maximum_length = }: {larger_chunks} are too large" - ) + assert np.all( + chunk_lengths <= maximum_length + ), f"{maximum_length = }: {larger_chunks} are too large" diff --git a/cognee/tests/unit/processing/chunks/chunk_by_word_test.py b/cognee/tests/unit/processing/chunks/chunk_by_word_test.py index d79fcdbc8..fb26638cb 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_word_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_word_test.py @@ -17,9 +17,9 @@ def test_chunk_by_word_isomorphism(input_text): chunks = chunk_by_word(input_text) reconstructed_text = "".join([chunk[0] for chunk in chunks]) - assert reconstructed_text == input_text, ( - f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" - ) + assert ( + reconstructed_text == input_text + ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" @pytest.mark.parametrize(