feat(albert): data chunking + information search #16

Merged · 9 commits · Nov 27, 2024
Changes from all commits
.gitignore: 2 changes (1 addition, 1 deletion)
@@ -101,4 +101,4 @@ dmypy.json
local_dump/*

# Data
data/*.csv
data/*
README.md: 3 changes (3 additions, 0 deletions)
@@ -6,6 +6,9 @@
poetry shell
poetry install
poetry run start # or poetry run python -m srdt_analysis
ruff check --fix
ruff format
pyright # for type checking
```

## Statistiques sur les documents
pyproject.toml: 41 changes (34 additions, 7 deletions)
@@ -8,11 +8,14 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "~3.12"
asyncpg = "^0.30.0"
black = "^24.10.0"
python-dotenv = "^1.0.1"
ollama = "^0.3.3"
flagembedding = "^1.3.2"
numpy = "^2.1.3"
httpx = "^0.27.2"
pandas = "^2.2.3"
langchain-text-splitters = "^0.3.2"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.389"
ruff = "^0.8.0"

[build-system]
requires = ["poetry-core"]
@@ -21,6 +24,30 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts]
start = "srdt_analysis.__main__:main"

[tool.black]
line-length = 90
include = '\.pyi?$'
[tool.ruff]
exclude = [
".ruff_cache",
"__pycache__",
]
line-length = 88
indent-width = 4

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = ["I"]
ignore = []
fixable = ["ALL"]
unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
docstring-code-line-length = "dynamic"

[tool.pyright]
include = ["srdt_analysis"]
exclude = ["**/__pycache__"]
srdt_analysis/__main__.py: 18 changes (16 additions, 2 deletions)
@@ -1,11 +1,25 @@
from dotenv import load_dotenv
from .exploit_data import exploit_data

from srdt_analysis.collections import Collections
from srdt_analysis.data_exploiter import PageInfosExploiter
from srdt_analysis.database_manager import get_data

load_dotenv()


def main():
    exploit_data()
    data = get_data()
    exploiter = PageInfosExploiter()
    result = exploiter.process_documents(
        [data[3][0]], "page_infos.csv", "cdtn_page_infos"
    )
    collections = Collections()
    res = collections.search(
        "combien de jour de congé payé par mois de travail effectif",
        [result["id"]],
    )

    print(res)


if __name__ == "__main__":
srdt_analysis/albert.py: 20 changes (20 additions, 0 deletions)
@@ -0,0 +1,20 @@
import os
from typing import Any, Dict

import httpx

from srdt_analysis.constants import ALBERT_ENDPOINT


class AlbertBase:
    def __init__(self):
        self.api_key = os.getenv("ALBERT_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key must be provided either in constructor or as environment variable"
            )
        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    def get_models(self) -> Dict[str, Any]:
        response = httpx.get(f"{ALBERT_ENDPOINT}/v1/models", headers=self.headers)
        return response.json()
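
A minimal usage sketch of the new `AlbertBase` client (not part of the PR): it assumes `ALBERT_API_KEY` is exported or present in a local `.env` file loaded with python-dotenv, the same pattern `__main__.py` uses.

```python
# Hedged usage sketch; only relies on AlbertBase as added in this PR.
from dotenv import load_dotenv

from srdt_analysis.albert import AlbertBase

load_dotenv()  # picks up ALBERT_API_KEY from .env, as __main__.py does

client = AlbertBase()       # raises ValueError if ALBERT_API_KEY is missing
print(client.get_models())  # GET {ALBERT_ENDPOINT}/v1/models, returned as JSON
```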
srdt_analysis/chunker.py: 41 changes (41 additions, 0 deletions)
@@ -0,0 +1,41 @@
from typing import List

from langchain_text_splitters import (
    MarkdownHeaderTextSplitter,
    RecursiveCharacterTextSplitter,
)

from srdt_analysis.constants import CHUNK_OVERLAP, CHUNK_SIZE
from srdt_analysis.models import SplitDocument


class Chunker:
    def __init__(self):
        self._markdown_splitter = MarkdownHeaderTextSplitter(
            [
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
                ("####", "Header 4"),
                ("#####", "Header 5"),
                ("######", "Header 6"),
            ],
            strip_headers=False,
        )
        self._character_recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )

    def split_markdown(self, markdown: str) -> List[SplitDocument]:
        md_header_splits = self._markdown_splitter.split_text(markdown)
        documents = self._character_recursive_splitter.split_documents(md_header_splits)
        return [SplitDocument(doc.page_content, doc.metadata) for doc in documents]

    def split_character_recursive(self, content: str) -> List[SplitDocument]:
        text_splits = self._character_recursive_splitter.split_text(content)
        return [SplitDocument(text, {}) for text in text_splits]

    def split(self, content: str, content_type: str = "markdown"):
        if content_type.lower() == "markdown":
            return self.split_markdown(content)
        raise ValueError(f"Unsupported content type: {content_type}")
srdt_analysis/collections.py: 104 changes (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
import json
from io import BytesIO
from typing import Any, Dict, List

import httpx

from srdt_analysis.albert import AlbertBase
from srdt_analysis.constants import ALBERT_ENDPOINT
from srdt_analysis.models import ChunkDataList, DocumentData


class Collections(AlbertBase):
    def _create(self, collection_name: str, model: str) -> str:
        payload = {"name": collection_name, "model": model}
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers, json=payload
        )
        return response.json()["id"]

    def create(self, collection_name: str, model: str) -> str:
        collections: List[Dict[str, Any]] = self.list()
        for collection in collections:
            if collection["name"] == collection_name:
                self.delete(collection["id"])
        return self._create(collection_name, model)

    def list(self) -> List[Dict[str, Any]]:
        response = httpx.get(f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers)
        return response.json()["data"]

    def delete(self, id_collection: str):
        response = httpx.delete(
            f"{ALBERT_ENDPOINT}/v1/collections/{id_collection}", headers=self.headers
        )
        response.raise_for_status()

    def delete_all(self, collection_name) -> None:
        collections = self.list()
        for collection in collections:
            if collection["name"] == collection_name:
                self.delete(collection["id"])
        return None

    def search(
        self,
        prompt: str,
        id_collections: List[str],
        k: int = 5,
        score_threshold: float = 0,
    ) -> ChunkDataList:
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/search",
            headers=self.headers,
            json={
                "prompt": prompt,
                "collections": id_collections,
                "k": k,
                "score_threshold": score_threshold,
            },
        )
        return response.json()

    def upload(
        self,
        data: List[DocumentData],
        id_collection: str,
    ) -> None:
        result = []
        for dt in data:
            dt: DocumentData
            chunks = dt["content_chunked"]
            for chunk in chunks:
                result.append(
                    {
                        "text": chunk.page_content,
                        "title": dt["title"],
                        "metadata": {
                            "cdtn_id": dt["cdtn_id"],
                            "structure_du_chunk": chunk.metadata,
                            "url": dt["url"],
                        },
                    }
                )

        file_content = json.dumps(result).encode("utf-8")

        files = {
            "file": (
                "content.json",
                BytesIO(file_content),
                "multipart/form-data",
            )
        }

        request_data = {"request": '{"collection": "%s"}' % id_collection}
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/files",
            headers=self.headers,
            files=files,
            data=request_data,
        )

        response.raise_for_status()
        return
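
Taken together with `__main__.py`, the intended flow is roughly: build a collection, upload chunked documents, then search it. A hedged end-to-end sketch follows; `documents` is assumed to be a `List[DocumentData]` produced elsewhere (for example by `PageInfosExploiter`), and using `MODEL_VECTORISATION` as the embedding model is an assumption, not something shown in this diff.

```python
# Rough sketch mirroring __main__.py; `documents` and the model choice are assumptions.
from srdt_analysis.collections import Collections
from srdt_analysis.constants import MODEL_VECTORISATION

collections = Collections()
collection_id = collections.create("cdtn_page_infos", MODEL_VECTORISATION)
collections.upload(documents, collection_id)  # documents: List[DocumentData]
results = collections.search(
    "combien de jour de congé payé par mois de travail effectif",
    [collection_id],
    k=5,
)
print(results)
```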
srdt_analysis/constants.py: 6 changes (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
ALBERT_ENDPOINT = "https://albert.api.etalab.gouv.fr"
MODEL_VECTORISATION = "BAAI/bge-m3"
LLM_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
CHUNK_SIZE = 5000
CHUNK_OVERLAP = 500
BASE_URL_CDTN = "https://code.travail.gouv.fr"
srdt_analysis/data.py: 91 changes (0 additions, 91 deletions)

This file was deleted.
