Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(albert): chunkage des données + recherche des informations #16

Merged
merged 9 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ python = "~3.12"
asyncpg = "^0.30.0"
black = "^24.10.0"
python-dotenv = "^1.0.1"
ollama = "^0.3.3"
flagembedding = "^1.3.2"
numpy = "^2.1.3"
httpx = "^0.27.2"
pandas = "^2.2.3"

[build-system]
requires = ["poetry-core"]
Expand Down
7 changes: 5 additions & 2 deletions srdt_analysis/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dotenv import load_dotenv
from .exploit_data import exploit_data
from srdt_analysis.llm import get_llm
from srdt_analysis.exploit_data import exploit_data


load_dotenv()


def main():
exploit_data()
# exploit_data()
get_llm()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion srdt_analysis/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncpg
import asyncio
from typing import List, Tuple
from .models import Document, DocumentsList
from srdt_analysis.models import Document, DocumentsList


async def fetch_articles_code_du_travail(
Expand Down
10 changes: 5 additions & 5 deletions srdt_analysis/exploit_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .data import get_data
from .models import DocumentsList
from .llm import get_extended_data
from .save import save_to_csv, process_document
from .vector import generate_vector
from srdt_analysis.data import get_data
from srdt_analysis.models import DocumentsList
from srdt_analysis.llm import get_extended_data
from srdt_analysis.save import save_to_csv, process_document
from srdt_analysis.vector import generate_vector
from datetime import datetime


Expand Down
52 changes: 36 additions & 16 deletions srdt_analysis/llm.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
import ollama
import httpx
import os


def get_extended_data(message: str, is_summary: bool) -> str:
summary_prompt = "Tu es un chatbot expert en droit du travail français. Lis le texte donné et rédige un résumé clair, précis et concis, limité à 4096 tokens maximum. Dans ce résumé, fais ressortir les points clés suivants : Sujets principaux : identifie les droits et obligations des employeurs et des salariés, les conditions de travail, les procédures, etc. Langage clair : simplifie le langage juridique tout en restant précis pour éviter toute confusion. Organisation logique : commence par les informations principales, puis détaille les exceptions ou points secondaires s'ils existent. Neutralité : garde un ton factuel, sans jugement ou interprétation subjective. Longueur : si le texte est long, privilégie les informations essentielles pour respecter la limite de 4096 tokens."
keyword_prompt = "Tu es un chatbot expert en droit du travail français. Ta seule mission est d’extraire une liste de mots-clés à partir du texte fourni. Objectif : Les mots-clés doivent refléter les idées et thèmes principaux pour faciliter la compréhension et la recherche du contenu du texte. Sélection : Extrait uniquement les termes essentiels, comme les droits et devoirs des employeurs et des salariés, les conditions de travail, les procédures, les sanctions, etc. Non-redondance : Évite les répétitions ; chaque mot-clé doit apparaître une seule fois. Clarté et simplicité : Assure-toi que chaque mot-clé est compréhensible et pertinent. Format attendu pour la liste de mots-clés : une liste simple et directe, sans organisation par thèmes, comme dans cet exemple : code du travail, article 12, congés payés, heures supplémentaires, licenciement économique"
response = ollama.chat(
model="mistral-nemo",
messages=[
{
"role": "system",
"content": summary_prompt if is_summary else keyword_prompt,
},
{
"role": "user",
"content": message,
summary_prompt = "Tu es un chatbot expert en droit du travail français. Lis le texte donné et rédige un résumé clair, précis et concis, limité à 4096 tokens maximum. Dans ce résumé, fais ressortir les points clés suivants : Sujets principaux : identifie les droits et obligations des employeurs et des salariés, les conditions de travail, les procédures, etc. Langage clair : simplifie le langage juridique tout en restant précis pour éviter toute confusion. Organisation logique : commence par les informations principales, puis détaille les exceptions ou points secondaires s'ils existent. Neutralité : garde un ton factuel, sans jugement ou interprétation subjective. Longueur : si le texte est long, privilégie les informations essentielles pour respecter la limite de 4096 tokens. Ta réponse doit obligatoirement être en français."
keyword_prompt = "Tu es un chatbot expert en droit du travail français. Ta seule mission est d’extraire une liste de mots-clés à partir du texte fourni. Objectif : Les mots-clés doivent refléter les idées et thèmes principaux pour faciliter la compréhension et la recherche du contenu du texte. Sélection : Extrait uniquement les termes essentiels, comme les droits et devoirs des employeurs et des salariés, les conditions de travail, les procédures, les sanctions, etc. Non-redondance : Évite les répétitions ; chaque mot-clé doit apparaître une seule fois. Clarté et simplicité : Assure-toi que chaque mot-clé est compréhensible et pertinent. Format attendu pour la liste de mots-clés : une liste simple et directe, sans organisation par thèmes, comme dans cet exemple : code du travail, article 12, congés payés, heures supplémentaires, licenciement économique. Ta réponse doit obligatoirement être en français."

api_key = os.getenv("ALBERT_API_KEY")
if not api_key:
raise ValueError("API key for Albert is not set")

try:
response = httpx.post(
"https://albert.api.etalab.gouv.fr/v1/embeddings",
headers={"Authorization": f"Bearer {api_key}"},
json={
"messages": [
{
"role": "system",
"content": summary_prompt if is_summary else keyword_prompt,
},
{
"role": "user",
"content": message,
},
],
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
},
],
)
return response["message"]["content"]
)
response.raise_for_status()
chat_response = response.json()["choices"][0]["message"]["content"]
except httpx.HTTPStatusError as e:
raise RuntimeError(
f"Request failed: {e.response.status_code} - {e.response.text}"
)
except Exception as e:
raise RuntimeError(f"An error occurred: {str(e)}")

return chat_response
34 changes: 8 additions & 26 deletions srdt_analysis/save.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
import csv
import pandas as pd
from typing import List, Dict


def save_to_csv(data: List[Dict], filename: str) -> None:
headers = [
"cdtn_id",
"initial_id",
"title",
"content",
"idcc",
"keywords",
"summary",
"vector_summary",
"vector_keywords",
]

with open(f"data/{filename}", "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=headers)
writer.writeheader()
writer.writerows(data)


def remove_newlines(content: str) -> str:
return content.replace("\n", "-")
df = pd.DataFrame(data)
df.to_csv(f"data/{filename}", index=False)


def process_document(
Expand All @@ -40,10 +22,10 @@ def process_document(
"cdtn_id": cdtn_id,
"initial_id": initial_id,
"title": title,
"content": remove_newlines(content),
"keywords": remove_newlines(keywords),
"summary": remove_newlines(summary),
"vector_summary": remove_newlines(str(vector_summary)),
"vector_keywords": remove_newlines(str(vector_keywords)),
"content": content,
"keywords": keywords,
"summary": summary,
"vector_summary": vector_summary,
"vector_keywords": vector_keywords,
"idcc": idcc,
}
14 changes: 11 additions & 3 deletions srdt_analysis/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from FlagEmbedding import BGEM3FlagModel
import os
import httpx


def generate_vector(text: str) -> dict:
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
vector = model.encode(text)
response = httpx.post(
"https://albert.api.etalab.gouv.fr/v1/embeddings",
headers={"Authorization": f"Bearer {os.getenv('ALBERT_API_KEY')}"},
data={
"input": text,
"model": "BAAI/bge-m3",
},
)
vector = response.json()["data"]
return vector