feat(albert): data chunking + information retrieval (#16)
* fix: remote

* feat(chunk): add the chunking part (#18)

* fix: chunk

* fix: finish

* fix: finish

* fix: finish

* fix: finish

* fix: done

* fix: format

* config: Disable some pylint and mypy rules that are not necessarily useful

* fix: review feedback

* fix: review feedback

* fix: review feedback

* fix: review feedback

* fix: config

---------

Co-authored-by: Victor DEGLIAME <[email protected]>
maxgfr and RealVidy authored Nov 27, 2024
1 parent 19f7cf5 commit d508a9f
Showing 18 changed files with 626 additions and 361 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -101,4 +101,4 @@ dmypy.json
local_dump/*

# Data
data/*.csv
data/*
3 changes: 3 additions & 0 deletions README.md
@@ -6,6 +6,9 @@
poetry shell
poetry install
poetry run start # or poetry run python -m srdt_analysis
ruff check --fix
ruff format
pyright # for type checking
```

## Statistiques sur les documents
41 changes: 34 additions & 7 deletions pyproject.toml
@@ -8,11 +8,14 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "~3.12"
asyncpg = "^0.30.0"
black = "^24.10.0"
python-dotenv = "^1.0.1"
ollama = "^0.3.3"
flagembedding = "^1.3.2"
numpy = "^2.1.3"
httpx = "^0.27.2"
pandas = "^2.2.3"
langchain-text-splitters = "^0.3.2"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.389"
ruff = "^0.8.0"

[build-system]
requires = ["poetry-core"]
@@ -21,6 +24,30 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts]
start = "srdt_analysis.__main__:main"

[tool.black]
line-length = 90
include = '\.pyi?$'
[tool.ruff]
exclude = [
".ruff_cache",
"__pycache__",
]
line-length = 88
indent-width = 4

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F"]
extend-select = ["I"]
ignore = []
fixable = ["ALL"]
unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
docstring-code-line-length = "dynamic"

[tool.pyright]
include = ["srdt_analysis"]
exclude = ["**/__pycache__"]
18 changes: 16 additions & 2 deletions srdt_analysis/__main__.py
@@ -1,11 +1,25 @@
from dotenv import load_dotenv
from .exploit_data import exploit_data

from srdt_analysis.collections import Collections
from srdt_analysis.data_exploiter import PageInfosExploiter
from srdt_analysis.database_manager import get_data

load_dotenv()


def main():
exploit_data()
data = get_data()
exploiter = PageInfosExploiter()
result = exploiter.process_documents(
[data[3][0]], "page_infos.csv", "cdtn_page_infos"
)
collections = Collections()
res = collections.search(
"combien de jour de congé payé par mois de travail effectif",
[result["id"]],
)

print(res)


if __name__ == "__main__":
20 changes: 20 additions & 0 deletions srdt_analysis/albert.py
@@ -0,0 +1,20 @@
import os
from typing import Any, Dict

import httpx

from srdt_analysis.constants import ALBERT_ENDPOINT


class AlbertBase:
def __init__(self):
self.api_key = os.getenv("ALBERT_API_KEY")
if not self.api_key:
raise ValueError(
"API key must be provided either in constructor or as environment variable"
)
self.headers = {"Authorization": f"Bearer {self.api_key}"}

def get_models(self) -> Dict[str, Any]:
response = httpx.get(f"{ALBERT_ENDPOINT}/v1/models", headers=self.headers)
return response.json()
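A minimal usage sketch for the new `AlbertBase` client — not part of the commit, and assuming `ALBERT_API_KEY` is set in a `.env` file as `__main__.py` expects:

```python
# Hypothetical example: list the models exposed by the Albert API.
from dotenv import load_dotenv

from srdt_analysis.albert import AlbertBase

load_dotenv()  # makes ALBERT_API_KEY visible to os.getenv
client = AlbertBase()
print(client.get_models())  # GET {ALBERT_ENDPOINT}/v1/models, returned as JSON
```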
41 changes: 41 additions & 0 deletions srdt_analysis/chunker.py
@@ -0,0 +1,41 @@
from typing import List

from langchain_text_splitters import (
MarkdownHeaderTextSplitter,
RecursiveCharacterTextSplitter,
)

from srdt_analysis.constants import CHUNK_OVERLAP, CHUNK_SIZE
from srdt_analysis.models import SplitDocument


class Chunker:
def __init__(self):
self._markdown_splitter = MarkdownHeaderTextSplitter(
[
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
],
strip_headers=False,
)
self._character_recursive_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
)

def split_markdown(self, markdown: str) -> List[SplitDocument]:
md_header_splits = self._markdown_splitter.split_text(markdown)
documents = self._character_recursive_splitter.split_documents(md_header_splits)
return [SplitDocument(doc.page_content, doc.metadata) for doc in documents]

def split_character_recursive(self, content: str) -> List[SplitDocument]:
text_splits = self._character_recursive_splitter.split_text(content)
return [SplitDocument(text, {}) for text in text_splits]

def split(self, content: str, content_type: str = "markdown"):
if content_type.lower() == "markdown":
return self.split_markdown(content)
raise ValueError(f"Unsupported content type: {content_type}")
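For illustration, a sketch of how the `Chunker` might be driven — assuming `SplitDocument` in `models.py` exposes `page_content` and `metadata` attributes, as `collections.py` relies on:

```python
# Hypothetical example: header-aware markdown splitting.
from srdt_analysis.chunker import Chunker

chunker = Chunker()
docs = chunker.split("# Congés payés\n\n## Durée\n\nTexte de la fiche...")
for doc in docs:
    # each chunk carries the header hierarchy it was split under
    print(doc.metadata, doc.page_content[:60])
```

Headers are first used to split the document into sections, then the recursive character splitter enforces `CHUNK_SIZE`/`CHUNK_OVERLAP` within each section.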
104 changes: 104 additions & 0 deletions srdt_analysis/collections.py
@@ -0,0 +1,104 @@
import json
from io import BytesIO
from typing import Any, Dict, List

import httpx

from srdt_analysis.albert import AlbertBase
from srdt_analysis.constants import ALBERT_ENDPOINT
from srdt_analysis.models import ChunkDataList, DocumentData


class Collections(AlbertBase):
def _create(self, collection_name: str, model: str) -> str:
payload = {"name": collection_name, "model": model}
response = httpx.post(
f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers, json=payload
)
return response.json()["id"]

def create(self, collection_name: str, model: str) -> str:
collections: List[Dict[str, Any]] = self.list()
for collection in collections:
if collection["name"] == collection_name:
self.delete(collection["id"])
return self._create(collection_name, model)

def list(self) -> List[Dict[str, Any]]:
response = httpx.get(f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers)
return response.json()["data"]

def delete(self, id_collection: str):
response = httpx.delete(
f"{ALBERT_ENDPOINT}/v1/collections/{id_collection}", headers=self.headers
)
response.raise_for_status()

def delete_all(self, collection_name) -> None:
collections = self.list()
for collection in collections:
if collection["name"] == collection_name:
self.delete(collection["id"])
return None

def search(
self,
prompt: str,
id_collections: List[str],
k: int = 5,
score_threshold: float = 0,
) -> ChunkDataList:
response = httpx.post(
f"{ALBERT_ENDPOINT}/v1/search",
headers=self.headers,
json={
"prompt": prompt,
"collections": id_collections,
"k": k,
"score_threshold": score_threshold,
},
)
return response.json()

def upload(
self,
data: List[DocumentData],
id_collection: str,
) -> None:
result = []
for dt in data:
dt: DocumentData
chunks = dt["content_chunked"]
for chunk in chunks:
result.append(
{
"text": chunk.page_content,
"title": dt["title"],
"metadata": {
"cdtn_id": dt["cdtn_id"],
"structure_du_chunk": chunk.metadata,
"url": dt["url"],
},
}
)

file_content = json.dumps(result).encode("utf-8")

files = {
"file": (
"content.json",
BytesIO(file_content),
"multipart/form-data",
)
}

request_data = {"request": '{"collection": "%s"}' % id_collection}
response = httpx.post(
f"{ALBERT_ENDPOINT}/v1/files",
headers=self.headers,
files=files,
data=request_data,
)

response.raise_for_status()
return
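Tying the pieces together, a hedged end-to-end sketch mirroring the flow in `__main__.py`. The `documents` literal here is hypothetical: it assumes `DocumentData` is a mapping with the `cdtn_id`, `title`, `url` and `content_chunked` keys that `upload()` reads.

```python
# Hypothetical example: create a collection, upload chunks, then search.
from srdt_analysis.chunker import Chunker
from srdt_analysis.collections import Collections
from srdt_analysis.constants import MODEL_VECTORISATION

chunker = Chunker()
documents = [
    {
        "cdtn_id": "hypothetical-id",          # illustrative value
        "title": "Les congés payés",
        "url": "https://code.travail.gouv.fr/fiche-service-public/conges-payes",
        "content_chunked": chunker.split("# Congés payés\n\nTexte de la fiche..."),
    }
]

collections = Collections()
collection_id = collections.create("cdtn_page_infos", MODEL_VECTORISATION)
collections.upload(documents, collection_id)
results = collections.search(
    "combien de jour de congé payé par mois de travail effectif",
    [collection_id],
)
print(results)
```

`create` deletes any existing collection with the same name before recreating it, so the sketch is idempotent across runs.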
6 changes: 6 additions & 0 deletions srdt_analysis/constants.py
@@ -0,0 +1,6 @@
ALBERT_ENDPOINT = "https://albert.api.etalab.gouv.fr"
MODEL_VECTORISATION = "BAAI/bge-m3"
LLM_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
CHUNK_SIZE = 5000
CHUNK_OVERLAP = 500
BASE_URL_CDTN = "https://code.travail.gouv.fr"
91 changes: 0 additions & 91 deletions srdt_analysis/data.py

This file was deleted.

