Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API endpoint to export accuracy metrics from user feedback + created_at timestamp #803

Merged
merged 9 commits into from
Feb 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion haystack/document_store/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,15 @@ def _create_label_index(self, index_name: str):
"answer": {"type": "text"},
"is_correct_answer": {"type": "boolean"},
"is_correct_document": {"type": "boolean"},
"origin": {"type": "keyword"},
"origin": {"type": "keyword"}, # e.g. user-feedback or gold-label
"document_id": {"type": "keyword"},
"offset_start_in_doc": {"type": "long"},
"no_answer": {"type": "boolean"},
"model_id": {"type": "keyword"},
"type": {"type": "keyword"},
"created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||epoch_millis"},
"updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||epoch_millis"}
# TODO: add pipeline_hash and pipeline_name once we have migrated the REST API to pipelines
}
}
}
Expand Down Expand Up @@ -364,6 +367,12 @@ def write_labels(
else:
label = l

# create timestamps if not available yet
if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at:
label.updated_at = label.created_at

_label = {
"_op_type": "index" if self.update_existing_documents else "create",
"_index": index,
Expand Down
6 changes: 6 additions & 0 deletions haystack/document_store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy
from typing import Dict, List, Optional, Union, Generator
from uuid import uuid4
import time

import numpy as np
from scipy.spatial.distance import cosine
Expand Down Expand Up @@ -87,6 +88,11 @@ def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[s

for label in label_objects:
label_id = str(uuid4())
# create timestamps if not available yet
if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at:
label.updated_at = label.created_at
self.indexes[index][label_id] = label

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
Expand Down
6 changes: 4 additions & 2 deletions haystack/document_store/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ORMBase(Base):
__abstract__ = True

id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
created = Column(DateTime, server_default=func.now())
updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
created_at = Column(DateTime, server_default=func.now())
updated_at = Column(DateTime, server_default=func.now(), server_onupdate=func.now())


class DocumentORM(ORMBase):
Expand Down Expand Up @@ -424,6 +424,8 @@ def _convert_sql_row_to_label(self, row) -> Label:
answer=row.answer,
offset_start_in_doc=row.offset_start_in_doc,
model_id=row.model_id,
created_at=row.created_at,
updated_at=row.updated_at
)
return label

Expand Down
20 changes: 15 additions & 5 deletions haystack/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from abc import abstractmethod
from typing import Any, Optional, Dict, List
from uuid import uuid4

import numpy as np


Expand Down Expand Up @@ -78,6 +76,7 @@ def __repr__(self):
def __str__(self):
    """Return the string form of the dictionary produced by ``self.to_dict()``."""
    return str(self.to_dict())


class Label:
def __init__(self, question: str,
answer: str,
Expand All @@ -88,7 +87,9 @@ def __init__(self, question: str,
document_id: Optional[str] = None,
offset_start_in_doc: Optional[int] = None,
no_answer: Optional[bool] = None,
model_id: Optional[int] = None):
model_id: Optional[int] = None,
created_at: Optional[str] = None,
updated_at: Optional[str] = None):
"""
Object used to represent label/feedback in a standardized way within Haystack.
This includes labels from dataset like SQuAD, annotations from labeling tools,
Expand All @@ -106,6 +107,10 @@ def __init__(self, question: str,
:param offset_start_in_doc: the answer start offset in the document.
:param no_answer: whether the question is unanswerable.
:param model_id: model_id used for prediction (in-case of user feedback).
:param created_at: Timestamp of creation with format yyyy-MM-dd HH:mm:ss.
Generate in Python via time.strftime("%Y-%m-%d %H:%M:%S").
:param updated_at: Timestamp of update with format yyyy-MM-dd HH:mm:ss.
Generate in Python via time.strftime("%Y-%m-%d %H:%M:%S").
"""

# Create a unique ID (either new one, or one from user input)
Expand All @@ -114,6 +119,8 @@ def __init__(self, question: str,
else:
self.id = str(uuid4())

self.created_at = created_at
self.updated_at = updated_at
self.question = question
self.answer = answer
self.is_correct_answer = is_correct_answer
Expand Down Expand Up @@ -142,7 +149,9 @@ def __eq__(self, other):
getattr(other, 'document_id', None) == self.document_id and
getattr(other, 'offset_start_in_doc', None) == self.offset_start_in_doc and
getattr(other, 'no_answer', None) == self.no_answer and
getattr(other, 'model_id', None) == self.model_id)
getattr(other, 'model_id', None) == self.model_id and
getattr(other, 'created_at', None) == self.created_at and
getattr(other, 'updated_at', None) == self.updated_at)

def __hash__(self):
return hash(self.question +
Expand All @@ -153,7 +162,8 @@ def __hash__(self):
str(self.document_id) +
str(self.offset_start_in_doc) +
str(self.no_answer) +
str(self.model_id))
str(self.model_id)
)

def __repr__(self):
    """Return the string form of the dictionary produced by ``self.to_dict()``."""
    return str(self.to_dict())
Expand Down
46 changes: 46 additions & 0 deletions rest_api/controller/feedback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Optional
import time

from fastapi import APIRouter
from pydantic import BaseModel, Field
from typing import Dict, Union, List

from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from rest_api.config import (
Expand Down Expand Up @@ -65,6 +67,8 @@ class DocQAFeedback(FAQQAFeedback):
..., description="The answer start offset in the original doc. Only required for doc-qa feedback."
)

class FilterRequest(BaseModel):
    """Request body that carries optional label filters (used by ``eval_doc_qa_feedback``)."""

    # Maps a field name to either a single value or a list of accepted values,
    # e.g. {"document_id": ["XRR3xnEBCYVTkbTystOB"]}. Defaults to None (no filters).
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None

@router.post("/doc-qa-feedback")
def doc_qa_feedback(feedback: DocQAFeedback):
Expand All @@ -77,6 +81,48 @@ def faq_qa_feedback(feedback: FAQQAFeedback):
document_store.write_labels([{"origin": "user-feedback-faq", **feedback_payload}])


@router.post("/eval-doc-qa-feedback")
def eval_doc_qa_feedback(filters: FilterRequest = None):
    """
    Return basic accuracy metrics based on the user feedback.
    Which ratio of answers was correct? Which ratio of documents was correct?
    You can supply filters in the request to only use a certain subset of labels.

    **Example:**

        ```
        | curl --location --request POST 'http://127.0.0.1:8000/eval-doc-qa-feedback' \
        | --header 'Content-Type: application/json' \
        | --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
        ```

    :param filters: Optional request body. Both the body itself and its ``filters``
                    field may be missing/None, in which case all user-feedback labels are used.
    :return: Dict with ``answer_accuracy`` and ``document_accuracy`` (ratios in [0, 1], or
             None when no feedback matched) and ``n_feedback`` (number of labels evaluated).
    """

    # The body and its `filters` field are both optional, so either one can be None.
    # Previously a body without filters (e.g. `{}`) crashed on `None["origin"] = ...`.
    user_filters = filters.filters if filters is not None else None

    # Work on a copy so the incoming request model is never mutated.
    label_filters = dict(user_filters) if user_filters else {}

    # Always restrict the evaluation to user feedback; this overrides any "origin"
    # value supplied by the client.
    label_filters["origin"] = ["user-feedback"]

    labels = document_store.get_all_labels(
        index=DB_INDEX_FEEDBACK,
        filters=label_filters
    )

    n_feedback = len(labels)
    if n_feedback == 0:
        # No matching feedback: accuracies are undefined rather than 0.
        return {"answer_accuracy": None,
                "document_accuracy": None,
                "n_feedback": 0}

    n_correct_answers = sum(1 for label in labels if label.is_correct_answer)
    n_correct_documents = sum(1 for label in labels if label.is_correct_document)

    return {"answer_accuracy": n_correct_answers / n_feedback,
            "document_accuracy": n_correct_documents / n_feedback,
            "n_feedback": n_feedback}

@router.get("/export-doc-qa-feedback")
def export_doc_qa_feedback(context_size: int = 2_000):
"""
Expand Down