Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable different chunking methods #128

Merged
merged 5 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cognee/api/v1/cognify/cognify_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ async def run_cognify_pipeline(dataset: Dataset):
data: list[Data] = await get_dataset_data(dataset_id = dataset.id)

documents = [
PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "pdf" else
AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "audio" else
ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location) if data_item.extension == "image" else
TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location)
PdfDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location, chunking_strategy="paragraph") if data_item.extension == "pdf" else
AudioDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location, chunking_strategy="paragraph") if data_item.extension == "audio" else
ImageDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location, chunking_strategy="paragraph") if data_item.extension == "image" else
TextDocument(id = data_item.id, title=f"{data_item.name}.{data_item.extension}", file_path=data_item.raw_data_location, chunking_strategy="paragraph")
for data_item in data
]

Expand Down
2 changes: 2 additions & 0 deletions cognee/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class BaseConfig(BaseSettings):
monitoring_tool: object = MonitoringTool.LANGFUSE
graphistry_username: Optional[str] = None
graphistry_password: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None

model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

Expand Down
2 changes: 2 additions & 0 deletions cognee/infrastructure/data/chunking/DefaultChunkEngine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from cognee.shared.data_models import ChunkStrategy


# /Users/vasa/Projects/cognee/cognee/infrastructure/data/chunking/DefaultChunkEngine.py

class DefaultChunkEngine():
def __init__(self, chunk_strategy=None, chunk_size=None, chunk_overlap=None):
self.chunk_strategy = chunk_strategy
Expand Down
17 changes: 11 additions & 6 deletions cognee/modules/data/processing/document_types/AudioDocument.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.data.chunking import chunk_by_paragraph

from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types.Document import Document

from cognee.tasks.chunking.chunking_registry import get_chunking_function

class AudioReader:
id: UUID
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, file_path: str):
def __init__(self, id: UUID, file_path: str, chunking_strategy:str = "paragraph"):
self.id = id
self.file_path = file_path
self.llm_client = get_llm_client() # You can choose different models like "tiny", "base", "small", etc.

self.chunking_function = get_chunking_function(chunking_strategy)

def read(self, max_chunk_size: Optional[int] = 1024):
chunk_index = 0
Expand All @@ -37,7 +40,7 @@ def read_text_chunks(text, chunk_size):
chunked_pages.append(page_index)
page_index += 1

for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs=True):
for chunk_data in self.chunking_function(page_text, max_chunk_size, batch_paragraphs=True):
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
paragraph_chunks.append(chunk_data)
chunk_size += chunk_data["word_count"]
Expand Down Expand Up @@ -86,14 +89,16 @@ class AudioDocument(Document):
type: str = "audio"
title: str
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, title: str, file_path: str):
def __init__(self, id: UUID, title: str, file_path: str, chunking_strategy:str="paragraph"):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.file_path = file_path
self.chunking_strategy = chunking_strategy

def get_reader(self) -> AudioReader:
reader = AudioReader(self.id, self.file_path)
reader = AudioReader(self.id, self.file_path, self.chunking_strategy)
return reader

def to_dict(self) -> dict:
Expand Down
1 change: 1 addition & 0 deletions cognee/modules/data/processing/document_types/Document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class Document(Protocol):
type: str
title: str
file_path: str
chunking_strategy:str
12 changes: 8 additions & 4 deletions cognee/modules/data/processing/document_types/ImageDocument.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.data.chunking import chunk_by_paragraph

from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types.Document import Document
from cognee.tasks.chunking import chunk_by_paragraph
from cognee.tasks.chunking.chunking_registry import get_chunking_function


class ImageReader:
id: UUID
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, file_path: str):
def __init__(self, id: UUID, file_path: str, chunking_strategy:str = "paragraph"):
self.id = id
self.file_path = file_path
self.llm_client = get_llm_client() # You can choose different models like "tiny", "base", "small", etc.

self.llm_client = get_llm_client() # You can choose different models like "tiny", "base", "small", etc.
self.chunking_function = get_chunking_function(chunking_strategy)

def read(self, max_chunk_size: Optional[int] = 1024):
chunk_index = 0
Expand Down
18 changes: 13 additions & 5 deletions cognee/modules/data/processing/document_types/PdfDocument.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from typing import Optional
from pypdf import PdfReader as pypdf_PdfReader
from cognee.modules.data.chunking import chunk_by_paragraph

from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.tasks.chunking import chunk_by_paragraph
from cognee.tasks.chunking.chunking_registry import get_chunking_function
from .Document import Document

class PdfReader():
id: UUID
file_path: str
chunking_strategy: str

def __init__(self, id: UUID, file_path: str):
def __init__(self, id: UUID, file_path: str, chunking_strategy:str = "paragraph"):
self.id = id
self.file_path = file_path
self.chunking_strategy = chunking_strategy
self.chunking_function = get_chunking_function(chunking_strategy)


def get_number_of_pages(self):
file = pypdf_PdfReader(self.file_path)
Expand All @@ -33,7 +39,7 @@ def read(self, max_chunk_size: Optional[int] = 1024):
page_text = page.extract_text()
chunked_pages.append(page_index)

for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
for chunk_data in self.chunking_function(page_text, max_chunk_size, batch_paragraphs = True):
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
paragraph_chunks.append(chunk_data)
chunk_size += chunk_data["word_count"]
Expand Down Expand Up @@ -85,18 +91,20 @@ class PdfDocument(Document):
title: str
num_pages: int
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, title: str, file_path: str):
def __init__(self, id: UUID, title: str, file_path: str, chunking_strategy:str="paragraph"):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.file_path = file_path
logging.debug("file_path: %s", self.file_path)
reader = PdfReader(self.id, self.file_path)
self.num_pages = reader.get_number_of_pages()
self.chunking_strategy = chunking_strategy

def get_reader(self) -> PdfReader:
logging.debug("file_path: %s", self.file_path)
reader = PdfReader(self.id, self.file_path)
reader = PdfReader(self.id, self.file_path, self.chunking_strategy)
return reader

def to_dict(self) -> dict:
Expand Down
18 changes: 13 additions & 5 deletions cognee/modules/data/processing/document_types/TextDocument.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from typing import Optional
from cognee.modules.data.chunking import chunk_by_paragraph

from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.tasks.chunking.chunking_registry import get_chunking_function
from .Document import Document

class TextReader():
id: UUID
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, file_path: str):
def __init__(self, id: UUID, file_path: str, chunking_strategy:str="paragraph"):
self.id = id
self.file_path = file_path
self.chunking_strategy = chunking_strategy
self.chunking_function = get_chunking_function(chunking_strategy)



def get_number_of_pages(self):
num_pages = 1 # Pure text is not formatted
Expand Down Expand Up @@ -39,7 +45,7 @@ def read_text_chunks(file_path):
chunked_pages.append(page_index)
page_index += 1

for chunk_data in chunk_by_paragraph(page_text, max_chunk_size, batch_paragraphs = True):
for chunk_data in self.chunking_function(page_text, max_chunk_size, batch_paragraphs = True):
if chunk_size + chunk_data["word_count"] <= max_chunk_size:
paragraph_chunks.append(chunk_data)
chunk_size += chunk_data["word_count"]
Expand Down Expand Up @@ -89,17 +95,19 @@ class TextDocument(Document):
title: str
num_pages: int
file_path: str
chunking_strategy:str

def __init__(self, id: UUID, title: str, file_path: str):
def __init__(self, id: UUID, title: str, file_path: str, chunking_strategy:str="paragraph"):
self.id = id or uuid5(NAMESPACE_OID, title)
self.title = title
self.file_path = file_path
self.chunking_strategy = chunking_strategy

reader = TextReader(self.id, self.file_path)
self.num_pages = reader.get_number_of_pages()

def get_reader(self) -> TextReader:
reader = TextReader(self.id, self.file_path)
reader = TextReader(self.id, self.file_path, self.chunking_strategy)
return reader

def to_dict(self) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

if __name__ == "__main__":
test_file_path = os.path.join(os.path.dirname(__file__), "artificial-inteligence.pdf")
pdf_doc = PdfDocument("Test document.pdf", test_file_path)
pdf_doc = PdfDocument("Test document.pdf", test_file_path, chunking_strategy="paragraph")
reader = pdf_doc.get_reader()

for paragraph_data in reader.read():
Expand Down
Empty file.
39 changes: 39 additions & 0 deletions cognee/tasks/chunk_translate/translate_chunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

import logging

from cognee.base_config import get_base_config

BaseConfig = get_base_config()

async def translate_text(data, source_language: str = 'sr', target_language: str = 'en', region_name: str = 'eu-west-1'):
    """
    Translate text from source language to target language using AWS Translate.

    This is an async generator that yields exactly one string: the translated
    text, or a human-readable error message when validation or the AWS call
    fails.

    Parameters:
        data (str): The text to be translated.
        source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
        target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
        region_name (str): AWS region name.
    Yields:
        str: Translated text or an error message.
    """
    import boto3
    from botocore.exceptions import BotoCoreError, ClientError

    if not data:
        yield "No text provided for translation."
        return  # was missing: without it we'd still call AWS with empty text

    if not source_language or not target_language:
        yield "Both source and target language codes are required."
        return  # was missing: stop after reporting the validation error

    try:
        translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
        result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
        yield result.get('TranslatedText', 'No translation found.')

    except BotoCoreError as e:
        # These are genuine failures, so log at error level (was logging.info).
        logging.error(f"BotoCoreError occurred: {e}")
        yield "Error with AWS Translate service configuration or request."

    except ClientError as e:
        logging.error(f"ClientError occurred: {e}")
        yield "Error with AWS client or network issue."
Comment on lines +8 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor translate_text to use return instead of yield.

Using yield in an async function is unconventional and may cause unexpected behavior. Consider refactoring to use return or an async-compatible structure.

-async def translate_text(data, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
+async def translate_text(data, source_language: str = 'sr', target_language: str = 'en', region_name: str = 'eu-west-1') -> str:
    ...
-        yield "No text provided for translation."
+        return "No text provided for translation."
    ...
-        yield "Both source and target language codes are required."
+        return "Both source and target language codes are required."
    ...
-        yield result.get('TranslatedText', 'No translation found.')
+        return result.get('TranslatedText', 'No translation found.')
    ...
-        yield "Error with AWS Translate service configuration or request."
+        return "Error with AWS Translate service configuration or request."
    ...
-        yield "Error with AWS client or network issue."
+        return "Error with AWS client or network issue."
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def translate_text(data, source_language:str='sr', target_language:str='en', region_name='eu-west-1'):
"""
Translate text from source language to target language using AWS Translate.
Parameters:
data (str): The text to be translated.
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
region_name (str): AWS region name.
Returns:
str: Translated text or an error message.
"""
import boto3
from botocore.exceptions import BotoCoreError, ClientError
if not data:
yield "No text provided for translation."
if not source_language or not target_language:
yield "Both source and target language codes are required."
try:
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
yield result.get('TranslatedText', 'No translation found.')
except BotoCoreError as e:
logging.info(f"BotoCoreError occurred: {e}")
yield "Error with AWS Translate service configuration or request."
except ClientError as e:
logging.info(f"ClientError occurred: {e}")
yield "Error with AWS client or network issue."
async def translate_text(data, source_language: str = 'sr', target_language: str = 'en', region_name: str = 'eu-west-1') -> str:
"""
Translate text from source language to target language using AWS Translate.
Parameters:
data (str): The text to be translated.
source_language (str): The source language code (e.g., 'sr' for Serbian). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
target_language (str): The target language code (e.g., 'en' for English). ISO 639-2 Code https://www.loc.gov/standards/iso639-2/php/code_list.php
region_name (str): AWS region name.
Returns:
str: Translated text or an error message.
"""
import boto3
from botocore.exceptions import BotoCoreError, ClientError
if not data:
return "No text provided for translation."
if not source_language or not target_language:
return "Both source and target language codes are required."
try:
translate = boto3.client(service_name='translate', region_name=region_name, use_ssl=True)
result = translate.translate_text(Text=data, SourceLanguageCode=source_language, TargetLanguageCode=target_language)
return result.get('TranslatedText', 'No translation found.')
except BotoCoreError as e:
logging.info(f"BotoCoreError occurred: {e}")
return "Error with AWS Translate service configuration or request."
except ClientError as e:
logging.info(f"ClientError occurred: {e}")
return "Error with AWS client or network issue."

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cognee.modules.data.chunking import chunk_by_paragraph
from cognee.tasks.chunking import chunk_by_paragraph

if __name__ == "__main__":
def test_chunking_on_whole_text():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from uuid import uuid5, NAMESPACE_OID
from .chunk_by_sentence import chunk_by_sentence
from cognee.tasks.chunking.chunking_registry import register_chunking_function

@register_chunking_function("paragraph")
def chunk_by_paragraph(data: str, paragraph_length: int = 1024, batch_paragraphs = True):
paragraph = ""
last_cut_type = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@



from uuid import uuid4
from .chunk_by_word import chunk_by_word
from cognee.tasks.chunking.chunking_registry import register_chunking_function

@register_chunking_function("sentence")
def chunk_by_sentence(data: str):
sentence = ""
paragraph_id = uuid4()
Expand Down
10 changes: 10 additions & 0 deletions cognee/tasks/chunking/chunking_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Global registry mapping a strategy name (e.g. "paragraph", "sentence")
# to its chunking callable. Populated via @register_chunking_function.
chunking_registry = {}

def register_chunking_function(name):
    """Decorator that registers the decorated function under *name*.

    Parameters:
        name (str): Strategy name used later to look the function up.
    Returns:
        The decorator, which returns the function unchanged.
    """
    def decorator(func):
        chunking_registry[name] = func
        return func
    return decorator

def get_chunking_function(name: str):
    """Return the chunking function registered under *name*.

    Raises:
        ValueError: if no function was registered under *name*. Failing fast
            here beats silently returning None, which would only surface as an
            opaque TypeError when the caller tries to invoke it.
    """
    func = chunking_registry.get(name)
    if func is None:
        raise ValueError(f"Chunking function '{name}' not found in registry.")
    return func
Comment on lines +9 to +10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error handling for get_chunking_function.

Consider adding error handling or logging if a requested chunking function is not found in the registry. This will help diagnose issues when an invalid function name is used.

def get_chunking_function(name: str):
    func = chunking_registry.get(name)
    if func is None:
        raise ValueError(f"Chunking function '{name}' not found in registry.")
    return func

Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

import logging



async def detect_language(data: str):
    """
    Detect the language of the given text and yield its ISO 639-1 language code.

    If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
    The text is trimmed to the first 100 characters for efficient processing.

    This is an async generator that yields exactly one value.

    Parameters:
        data (str): The text for language detection.
    Yields:
        str | None: The ISO 639-1 language code of the detected language,
            or None in case of an error.
    """
    from langdetect import detect, LangDetectException

    # Trim the text to the first 100 characters
    trimmed_text = data[:100]

    try:
        # Detect the language using langdetect
        detected_lang_iso639_1 = detect(trimmed_text)
        logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")

        # Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
        if detected_lang_iso639_1 == 'hr':
            yield 'sr'
            return  # was missing: both 'sr' and 'hr' were yielded
        yield detected_lang_iso639_1
        return  # was missing: None was yielded even after a successful detection

    except LangDetectException as e:
        logging.error(f"Language detection error: {e}")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")

    # Reached only when detection raised: signal failure with None.
    yield None
Comment on lines +6 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor detect_language to use return instead of yield.

Using yield in an async function is unconventional and may cause unexpected behavior. Consider refactoring to use return or an async-compatible structure.

-async def detect_language(data:str):
+async def detect_language(data: str) -> str:
    ...
-        if detected_lang_iso639_1 == 'hr':
-            yield 'sr'
-        yield detected_lang_iso639_1
+        return 'sr' if detected_lang_iso639_1 == 'hr' else detected_lang_iso639_1
    ...
-    yield None
+    return None
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def detect_language(data:str):
"""
Detect the language of the given text and return its ISO 639-1 language code.
If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
The text is trimmed to the first 100 characters for efficient processing.
Parameters:
text (str): The text for language detection.
Returns:
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error.
"""
# Trim the text to the first 100 characters
from langdetect import detect, LangDetectException
trimmed_text = data[:100]
try:
# Detect the language using langdetect
detected_lang_iso639_1 = detect(trimmed_text)
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
if detected_lang_iso639_1 == 'hr':
yield 'sr'
yield detected_lang_iso639_1
except LangDetectException as e:
logging.error(f"Language detection error: {e}")
except Exception as e:
logging.error(f"Unexpected error: {e}")
yield None
async def detect_language(data: str) -> str:
"""
Detect the language of the given text and return its ISO 639-1 language code.
If the detected language is Croatian ('hr'), it maps to Serbian ('sr').
The text is trimmed to the first 100 characters for efficient processing.
Parameters:
text (str): The text for language detection.
Returns:
str: The ISO 639-1 language code of the detected language, or 'None' in case of an error.
"""
# Trim the text to the first 100 characters
from langdetect import detect, LangDetectException
trimmed_text = data[:100]
try:
# Detect the language using langdetect
detected_lang_iso639_1 = detect(trimmed_text)
logging.info(f"Detected ISO 639-1 code: {detected_lang_iso639_1}")
# Special case: map 'hr' (Croatian) to 'sr' (Serbian ISO 639-2)
return 'sr' if detected_lang_iso639_1 == 'hr' else detected_lang_iso639_1
except LangDetectException as e:
logging.error(f"Language detection error: {e}")
except Exception as e:
logging.error(f"Unexpected error: {e}")
return None

Loading