Use diffbot-kg client library #4

Open · wants to merge 6 commits into base: main
api/app/enhance.py (19 changes: 9 additions & 10 deletions)
@@ -1,33 +1,32 @@
 import logging
 import os
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
-from urllib.parse import urlencode
+from typing import Dict, List, Literal, Optional, Tuple, Union

-import requests
+from diffbot_kg import DiffbotEnhanceClient
 from utils import graph

 CATEGORY_THRESHOLD = 0.50
-params = []

 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]

+client = DiffbotEnhanceClient(DIFF_TOKEN)
+

 def get_datetime(value: Optional[Union[str, int, float]]) -> datetime:
     if not value:
         return value
     return datetime.fromtimestamp(float(value) / 1000.0)


-def process_entities(entity: str, type: str) -> Dict[str, Any]:
+async def process_entities(entity: str, type: str) -> Tuple[str, List[Dict]]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
-    search_host = "https://kg.diffbot.com/kg/v3/enhance?"
-    params = {"type": type, "name": entity, "token": DIFF_TOKEN}
-    encoded_query = urlencode(params)
-    url = f"{search_host}{encoded_query}"
-    return entity, requests.get(url).json()
+    params = {"type": type, "name": entity}
+    response = await client.enhance(params)
+
+    return entity, response.entities


 def get_people_params(row: Dict) -> Optional[Dict]:
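Review note: the new coroutine can be driven concurrently with plain asyncio. A minimal sketch follows; the entity names, the import path, and the asyncio.run driver are illustrative, not part of this PR.

import asyncio

from enhance import process_entities  # import path assumed from the api/app/ layout


async def main():
    # Hypothetical (name, type) pairs; any Person/Organization entities work.
    pairs = [("OpenAI", "Organization"), ("Sam Altman", "Person")]
    results = await asyncio.gather(
        *(process_entities(name, label) for name, label in pairs)
    )
    for entity, matches in results:
        print(entity, len(matches))


asyncio.run(main())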
api/app/importing.py (37 changes: 22 additions & 15 deletions)
@@ -1,33 +1,40 @@
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional

-import requests
+from diffbot_kg import DiffbotSearchClient
 from utils import embeddings, text_splitter

 CATEGORY_THRESHOLD = 0.50
-params = []

 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]

+client = DiffbotSearchClient(token=DIFF_TOKEN)
+

-def get_articles(
-    query: Optional[str], tag: Optional[str], size: int = 5, offset: int = 0
-) -> Dict[str, Any]:
+async def get_articles(
+    query: Optional[str],
+    tag: Optional[str],
+    size: int = 5,
+    offset: int = 0,
+) -> List[Dict]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
+    search_query = "type:Article language:en sortBy:date"
+    if query:
+        search_query += f' strict:text:"{query}"'
+    if tag:
+        search_query += f' tags.label:"{tag}"'
+
+    params = {"query": search_query, "size": size, "offset": offset}
+
+    logging.info(f"Fetching articles with params: {params}")
+
     try:
-        search_host = "https://kg.diffbot.com/kg/v3/dql?"
-        search_query = f'query=type%3AArticle+strict%3Alanguage%3A"en"+sortBy%3Adate'
-        if query:
-            search_query += f'+text%3A"{query}"'
-        if tag:
-            search_query += f'+tags.label%3A"{tag}"'
-        url = (
-            f"{search_host}{search_query}&token={DIFF_TOKEN}&from={offset}&size={size}"
-        )
-        return requests.get(url).json()
+        response = await client.search(params)
+        return response.entities
     except Exception as ex:
         raise ex
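Review note: a minimal usage sketch for the reworked get_articles; the filter values, import path, and asyncio.run driver are illustrative. With both filters set, the DQL string built above becomes 'type:Article language:en sortBy:date strict:text:"LLM" tags.label:"Machine Learning"'. Note that import_articles_endpoint in main.py also passes article_data.category, which this signature does not list; presumably that is reconciled in a part of the diff not shown here.

import asyncio

from importing import get_articles  # import path assumed from the api/app/ layout


async def main():
    articles = await get_articles(query="LLM", tag="Machine Learning", size=3)
    for article in articles:
        print(article)


asyncio.run(main())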
api/app/main.py (63 changes: 41 additions & 22 deletions)
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
@@ -20,7 +21,8 @@
 )

 # Multithreading for Diffbot API
-MAX_WORKERS = min(os.cpu_count() * 5, 20)
+MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20)
+MAX_TASKS = 20

 app = FastAPI()
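Review note: the `or 1` guard matters because os.cpu_count() returns None when the CPU count cannot be determined, and None * 5 would raise a TypeError:

import os

# cpu_count() may be None on some platforms; fall back to 1 before scaling.
MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20)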

@@ -35,21 +37,24 @@


 @app.post("/import_articles/")
-def import_articles_endpoint(article_data: ArticleData) -> int:
+async def import_articles_endpoint(article_data: ArticleData) -> int:
     logging.info(f"Starting to process article import with params: {article_data}")
-    if not article_data.query and not article_data.tag:
+    if not article_data.query and not article_data.category and not article_data.tag:
         raise HTTPException(
-            status_code=500, detail="Either `query` or `tag` must be provided"
+            status_code=500,
+            detail="Either `query` or `category` or `tag` must be provided",
         )
-    data = get_articles(article_data.query, article_data.tag, article_data.size)
-    logging.info(f"Articles fetched: {len(data['data'])} articles.")
+    articles = await get_articles(
+        article_data.query, article_data.category, article_data.tag, article_data.size
+    )
+    logging.info(f"Articles fetched: {len(articles)} articles.")
     try:
-        params = process_params(data)
+        params = process_params(articles)
     except Exception as e:
         # You could log the exception here if needed
-        raise HTTPException(status_code=500, detail=e)
+        raise HTTPException(status_code=500, detail=e) from e
     graph.query(import_cypher_query, params={"data": params})
-    logging.info(f"Article import query executed successfully.")
+    logging.info("Article import query executed successfully.")
     return len(params)
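Review note: a sketch of exercising the updated endpoint against a locally running server; the host, port, and payload values are assumptions, not part of this PR.

import requests

# ArticleData accepts query, category, tag, and size; at least one of the
# first three must be set or the endpoint raises HTTP 500.
resp = requests.post(
    "http://localhost:8000/import_articles/",
    json={"query": "LLM", "size": 5},
)
print(resp.json())  # the endpoint returns the number of imported articles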


@@ -124,26 +129,40 @@ def fetch_unprocessed_count(count_data: CountData) -> int:


 @app.post("/enhance_entities/")
-def enhance_entities(entity_data: EntityData) -> str:
+async def enhance_entities(entity_data: EntityData) -> str:
     entities = graph.query(
         "MATCH (a:Person|Organization) WHERE a.processed IS NULL "
         "WITH a LIMIT toInteger($limit) "
         "RETURN [el in labels(a) WHERE el <> '__Entity__' | el][0] "
         "AS label, collect(a.name) AS entities",
         params={"limit": entity_data.size},
     )
-    enhanced_data = []
-    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
-        # Submitting all tasks and creating a list of future objects
-        for row in entities:
-            futures = [
-                executor.submit(process_entities, el, row["label"])
-                for el in row["entities"]
-            ]
-
-            for future in futures:
-                response = future.result()
-                enhanced_data.append(response)
+    enhanced_data = {}
+
+    queue = asyncio.Queue()
+    for row in entities:
+        for el in row["entities"]:
+            await queue.put((el, row["label"]))
+
+    async def worker():
+        while True:
+            el, label = await queue.get()
+            logging.info("Processing entity: %s", el)
+            try:
+                response = await process_entities(el, label)
+                enhanced_data[response[0]] = response[1]
+            finally:
+                logging.info("Processed: %s", el)
+                queue.task_done()
+
+    num_workers = min(queue.qsize(), MAX_TASKS)
+    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
+    await queue.join()
+
+    for w in workers:
+        w.cancel()
+    await asyncio.gather(*workers, return_exceptions=True)

     store_enhanced_data(enhanced_data)
     return "Finished enhancing entities."
