Skip to content

Commit

Permalink
feat: Refactor GET /documents to return minimal fields (vana-com#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kahtaf authored Feb 29, 2024
1 parent 02cc326 commit db81e61
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 47 deletions.
11 changes: 5 additions & 6 deletions selfie/api/documents.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Optional, List
from typing import List

from fastapi import APIRouter
from pydantic import BaseModel

from selfie.database import DataManager
from selfie.parsers.chat import ChatFileParser
from selfie.embeddings import DataIndex
from selfie.parsers.chat import ChatFileParser

router = APIRouter()

Expand All @@ -20,8 +20,8 @@ class IndexDocumentsRequest(BaseModel):


@router.get("/documents")
async def get_documents(source_id: Optional[int] = None):
return DataManager().get_documents(source_id)
async def get_documents():
return DataManager().get_documents()


@router.delete("/documents/{document_id}")
Expand Down Expand Up @@ -57,13 +57,12 @@ async def index_documents(request: IndexDocumentsRequest):
False,
document.id
).conversations,
#source=document.source.name,
# source=document.source.name,
source_document_id=document.id
) if is_chat else None)
for document_id in document_ids
]


# @app.delete("/documents/{document-id}")
# async def delete_data_source(document_id: int):
# DataSourceManager().remove_document(document_id)
Expand Down
9 changes: 5 additions & 4 deletions selfie/connectors/chatgpt/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]:

return [
DocumentDTO(
content=data_uri_to_string(data_uri),
content_type="text/plain",
content=(content := data_uri_to_string(data_uri)),
content_type="application/json",
name="todo",
size=len(data_uri_to_string(data_uri).encode('utf-8'))
size=len(content.encode('utf-8'))
)
for data_uri in config.files
]
Expand All @@ -36,7 +36,8 @@ def validate_configuration(self, configuration: dict[str, Any]):
# TODO: check if file can be read from path
pass

def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[EmbeddingDocumentModel]:
def transform_for_embedding(self, configuration: dict[str, Any], documents: List[DocumentDTO]) -> List[
EmbeddingDocumentModel]:
return [
embeddingDocumentModel
for document in documents
Expand Down
4 changes: 2 additions & 2 deletions selfie/connectors/whatsapp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]:

return [
DocumentDTO(
content=data_uri_to_string(data_uri),
content=(content := data_uri_to_string(data_uri)),
content_type="text/plain",
name="todo",
size=len(data_uri_to_string(data_uri).encode('utf-8'))
size=len(content.encode('utf-8'))
)
for data_uri in config.files
]
Expand Down
67 changes: 32 additions & 35 deletions selfie/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import importlib
import json
import logging
import os
from datetime import datetime
from typing import List, Dict, Any, Callable

from llama_index.core.node_parser import SentenceSplitter
from peewee import (
Expand All @@ -9,23 +13,15 @@
TextField,
ForeignKeyField,
AutoField,
DoesNotExist,
Proxy,
IntegerField,
DateTimeField,
)
import json
import importlib
from typing import List, Dict, Any, Optional, Callable

from playhouse.shortcuts import model_to_dict

from selfie.config import get_app_config
from selfie.embeddings import DataIndex
from selfie.embeddings.document_types import EmbeddingDocumentModel

import logging

# TODO: This module should not be aware of DocumentDTO. Refactor its usage out of this module.
from selfie.types.documents import DocumentDTO

Expand Down Expand Up @@ -62,6 +58,7 @@ class DocumentConnectionModel(BaseModel):
# name = CharField()
connector_name = CharField()
configuration = TextField()

# last_loaded_timestamp = CharField(null=True)

class Meta:
Expand Down Expand Up @@ -107,9 +104,9 @@ def __init__(self, storage_path: str = config.database_storage_root):
self.db.create_tables([DocumentConnectionModel, DocumentModel])

def add_document_connection(
self,
connector_name: str,
configuration: Dict[str, Any],
self,
connector_name: str,
configuration: Dict[str, Any],
) -> int:
return DocumentConnectionModel.create(
connector_name=connector_name,
Expand All @@ -131,12 +128,14 @@ async def remove_document(self, document_id: int, delete_indexed_data: bool = Tr

document.delete_instance()

async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True, delete_indexed_data: bool = True):
async def remove_document_connection(self, document_connection_id: int, delete_documents: bool = True,
delete_indexed_data: bool = True):
if self.get_document_connection(document_connection_id) is None:
raise ValueError(f"No document connection found with ID {document_connection_id}")

if delete_indexed_data:
source_document_ids = [doc.id for doc in DocumentModel.select().where(DocumentModel.document_connection == document_connection_id)]
source_document_ids = [doc.id for doc in DocumentModel.select().where(
DocumentModel.document_connection == document_connection_id)]
await DataIndex("n/a").delete_documents_with_source_documents(source_document_ids)

with self.db.atomic():
Expand Down Expand Up @@ -205,7 +204,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel):

documents = self._fetch_documents(json.loads(document_connection.configuration))
documents = [
document for doc in documents for document in self._map_selfie_documents_to_index_documents(selfie_document=doc)
document for doc in documents for document in
self._map_selfie_documents_to_index_documents(selfie_document=doc)
]

await DataIndex("n/a").index(documents, extract_importance=False)
Expand All @@ -214,7 +214,8 @@ async def index_documents(self, document_connection: DocumentConnectionModel):

return {"message": f"{len(documents)} documents indexed successfully"}

async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[[DocumentDTO], List[EmbeddingDocumentModel]] = None):
async def index_document(self, document: DocumentDTO, selfie_documents_to_index_documents: Callable[
[DocumentDTO], List[EmbeddingDocumentModel]] = None):
print("Indexing document")

if selfie_documents_to_index_documents is None:
Expand Down Expand Up @@ -260,26 +261,22 @@ def get_document_connections(self):
for source in DocumentConnectionModel.select()
]

def get_documents(self, document_connection_id: Optional[int] = None):
if document_connection_id:
documents = DocumentModel.select().where(DocumentModel.document_connection == document_connection_id)
doc_ids = [str(document.id) for document in documents]
else:
documents = DocumentModel.select()
doc_ids = None

one_embedding_document_per_document = DataIndex("n/a").get_one_document_per_source_document(doc_ids)
indexed_documents = list(set([doc['source_document_id'] for doc in one_embedding_document_per_document]))

return [
{
**model_to_dict(doc),
"is_indexed": doc.id in indexed_documents,
# TODO: for some reason, initializing Embeddings in DataIndex with the SQLAlchemy driver returns indexed_documents as strings, not ints (requires str(doc.id)).
"num_index_documents": DataIndex("n/a").get_document_count([str(doc.id)])
}
for doc in documents
]
def get_documents(self):
documents = DocumentModel.select(DocumentModel.id, DocumentModel.name, DocumentModel.size,
DocumentModel.created_at, DocumentModel.updated_at,
DocumentModel.content_type, DocumentConnectionModel.connector_name).join(
DocumentConnectionModel)

result = []
for doc in documents:
doc_dict = model_to_dict(doc, backrefs=True, only=[
DocumentModel.id, DocumentModel.name, DocumentModel.size,
DocumentModel.created_at, DocumentModel.updated_at,
DocumentModel.content_type, DocumentConnectionModel.connector_name
])
doc_dict['connector_name'] = doc.document_connection.connector_name
result.append(doc_dict)
return result

def get_document(self, document_id: str):
return DocumentModel.get_by_id(document_id)
Expand Down

0 comments on commit db81e61

Please sign in to comment.