Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 15, 2024
1 parent 4867802 commit 5a68480
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
10 changes: 8 additions & 2 deletions comps/cores/proto/docarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Union, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from docarray import BaseDoc, DocList
Expand All @@ -19,12 +19,15 @@ class TopologyInfo:
# Document model carrying a single piece of text. Inherits from both BaseDoc
# and TopologyInfo, so it also carries whatever fields TopologyInfo declares.
class TextDoc(BaseDoc, TopologyInfo):
    # The text payload.
    text: str


# Document model that refers to an image by a string path.
class ImageDoc(BaseDoc):
    # Location of the image, stored as a plain string.
    # NOTE(review): presumably a filesystem path or URL — confirm against producers.
    image_path: str



# Document model wrapping either a text document or an image document.
class TextImageDoc(BaseDoc):
    # As annotated, this is a tuple of exactly ONE element that is either a
    # TextDoc or an ImageDoc.
    # NOTE(review): if a variable-length tuple was intended, the annotation
    # would need to be Tuple[Union[TextDoc, ImageDoc], ...] — confirm intent.
    doc: Tuple[Union[TextDoc, ImageDoc]]


# Document model carrying a byte payload as a string.
class Base64ByteStrDoc(BaseDoc):
    # The payload, stored as str.
    # NOTE(review): the class name suggests this is base64-encoded — confirm.
    byte_str: str

Expand Down Expand Up @@ -71,6 +74,7 @@ class SearchedDoc(BaseDoc):
class Config:
json_encoders = {np.ndarray: lambda x: x.tolist()}


class SearchedMultimodalDoc(BaseDoc):
retrieved_docs: List[TextImageDoc]
initial_query: str
Expand All @@ -80,6 +84,7 @@ class SearchedMultimodalDoc(BaseDoc):
class Config:
json_encoders = {np.ndarray: lambda x: x.tolist()}


class GeneratedDoc(BaseDoc):
text: str
prompt: str
Expand Down Expand Up @@ -184,6 +189,7 @@ class LVMDoc(BaseDoc):
repetition_penalty: float = 1.03
streaming: bool = False


class LVMVideoDoc(BaseDoc):
video_url: str
chunk_start: float
Expand Down
3 changes: 2 additions & 1 deletion comps/reranks/video-rag-qna/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ done
```

Configuration available through environment variables:

- CHUNK_DURATION: target chunk duration in seconds; it should match the chunk duration used by the VideoRAGQnA dataprep. Default: 10s.

# ✅ 2. Test

``` bash
```bash
export ip_address=$(hostname -I | awk '{print $1}')
curl -X 'POST' \
"http://${ip_address}:8000/v1/reranking" \
Expand Down
35 changes: 19 additions & 16 deletions comps/reranks/video-rag-qna/local_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from langsmith import traceable

from comps import (
SearchedMultimodalDoc,
LVMVideoDoc,
SearchedMultimodalDoc,
ServiceType,
opea_microservices,
register_microservice,
Expand All @@ -23,10 +23,9 @@
file_server_url = os.getenv("FILE_SERVER_URL") or "http://0.0.0.0:6005"

# Configure the root logger at import time: INFO level, with each record
# rendered as "LEVEL: [timestamp] message".
# NOTE(review): "%I" is the 12-hour clock and the date format has no "%p"
# (AM/PM) marker, so morning and afternoon timestamps are indistinguishable —
# confirm this is intended before changing the format string.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: [%(asctime)s] %(message)s",
    datefmt="%d/%m/%Y %I:%M:%S",
)


def get_top_doc(top_n, videos) -> list:
hit_score = {}
Expand All @@ -40,23 +39,21 @@ def get_top_doc(top_n, videos) -> list:
except KeyError as r:
logging.info(f"no video name {r}")

x = dict(sorted(hit_score.items(), key=lambda item: -item[1])) # sorted dict of video name and score
x = dict(sorted(hit_score.items(), key=lambda item: -item[1])) # sorted dict of video name and score
top_n_names = list(x.keys())[:top_n]
logging.info(f"top docs = {x}")
logging.info(f"top n docs names = {top_n_names}")

return top_n_names


def find_timestamp_from_video(metadata_list, video):
    """Return the timestamp of the first metadata entry for *video*.

    Args:
        metadata_list: Iterable of mappings, each providing a "video" key;
            the matching entry must also provide a "timestamp" key.
        video: Video name to look up; compared with ``==`` against each
            entry's "video" value.

    Returns:
        The "timestamp" value of the first matching entry, or ``None`` when
        no entry matches (including when ``metadata_list`` is empty).

    Raises:
        KeyError: If an entry inspected before a match lacks "video", or the
            matching entry lacks "timestamp".
    """
    # Lazy generator + next(): stops scanning at the first match.
    return next(
        (metadata["timestamp"] for metadata in metadata_list if metadata["video"] == video),
        None,
    )


@register_microservice(
name="opea_service@reranking_visual_rag",
service_type=ServiceType.RERANK,
Expand All @@ -70,18 +67,24 @@ def find_timestamp_from_video(metadata_list, video):
@register_statistics(names=["opea_service@reranking_visual_rag"])
def reranking(input: SearchedMultimodalDoc) -> LVMVideoDoc:
    """Select the top-ranked video from the search metadata and build an LVM request.

    Args:
        input: Search result document. This function reads ``input.metadata``
            (entries indexed by "video", and by "timestamp" for the selected
            one), ``input.top_n`` and ``input.initial_query``.

    Returns:
        An ``LVMVideoDoc`` pointing at the first top-ranked video on the file
        server, with the original query as the prompt, the chunk start taken
        from the matching metadata entry, and a fixed ``max_new_tokens`` of 512.

    Raises:
        IndexError: If ``get_top_doc`` returns an empty list, since only its
            first element is used.
        KeyError: If a metadata entry lacks the "video" key.
    """
    start = time.time()

    # get top video name from metadata
    video_names = [meta["video"] for meta in input.metadata]
    top_video_names = get_top_doc(input.top_n, video_names)

    # only use the first top video
    top_video = top_video_names[0]
    timestamp = find_timestamp_from_video(input.metadata, top_video)
    # Strip any trailing slash from the base URL so the join never yields "//".
    video_url = f"{file_server_url.rstrip('/')}/{top_video}"

    # chunk_duration is a module-level value defined outside this function;
    # it is coerced to float here.
    result = LVMVideoDoc(
        video_url=video_url,
        prompt=input.initial_query,
        chunk_start=timestamp,
        chunk_duration=float(chunk_duration),
        max_new_tokens=512,
    )
    # Record the elapsed time for this request under the service's statistics key.
    statistics_dict["opea_service@reranking_visual_rag"].append_latency(time.time() - start, None)

    return result


Expand Down
4 changes: 2 additions & 2 deletions comps/reranks/video-rag-qna/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
datasets
docarray
fastapi
uvicorn
langsmith
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
prometheus-fastapi-instrumentator
pydub
shortuuid
langsmith
uvicorn

0 comments on commit 5a68480

Please sign in to comment.