From 27d2858ad914f8224785bc8a430bad1921f8a2c8 Mon Sep 17 00:00:00 2001 From: Tim Nunamaker Date: Sat, 9 Mar 2024 13:06:11 -0600 Subject: [PATCH 1/3] docs: write API documentation --- selfie-ui/src/app/page.tsx | 5 +- selfie/api/__init__.py | 43 ++++++++- selfie/api/completions.py | 17 +++- selfie/api/connectors.py | 47 ++++++++- selfie/api/data_sources.py | 2 +- selfie/api/document_connections.py | 107 +++++++++++++++++++-- selfie/api/documents.py | 25 ++--- selfie/api/index_documents.py | 2 +- selfie/api/models.py | 69 ++++++++++---- selfie/connectors/whatsapp/connector.py | 2 +- selfie/types/completion_requests.py | 121 ++++++++++++++++-------- 11 files changed, 348 insertions(+), 92 deletions(-) diff --git a/selfie-ui/src/app/page.tsx b/selfie-ui/src/app/page.tsx index 68d077b..64dae80 100644 --- a/selfie-ui/src/app/page.tsx +++ b/selfie-ui/src/app/page.tsx @@ -88,9 +88,12 @@ const App = () => { ))}
  • - + API Docs + + API Sandbox +
  • LlamaCppChatCompletionResponse | LitellmCompletionResponse: @@ -19,7 +25,12 @@ async def create_chat_completion( # TODO can StreamingResponse's schema be defined? -@router.post("/completions") +@router.post("/completions", + # tags=["OpenAI"], + # summary="Create completion", + description=""" + Creates a response for the given prompt in [the style of OpenAI](https://platform.openai.com/docs/api-reference/completions/create). + """) async def create_completion( request: CompletionRequest, ) -> LlamaCppCompletionResponse | LitellmCompletionResponse: diff --git a/selfie/api/connectors.py b/selfie/api/connectors.py index 5c57e1c..bdfa1b9 100644 --- a/selfie/api/connectors.py +++ b/selfie/api/connectors.py @@ -3,9 +3,38 @@ from fastapi import APIRouter from pydantic import BaseModel +from selfie.connectors.whatsapp.connector import WhatsAppConnector from selfie.connectors.factory import ConnectorFactory -router = APIRouter() +router = APIRouter(tags=["Configuration"]) + +example = { + "id": WhatsAppConnector().id, + "name": WhatsAppConnector().name, + "documentation": "WhatsApp is a popular messaging app...", # TODO: Read this from connectors/whatsapp, maybe truncate + "form_schema": { # TODO: Read this from connectors/whatsapp + "title": "Upload WhatsApp Conversations", + "type": "object", + "properties": { + "files": { + "type": "array", + "title": "Files", + "description": "Upload .txt files exported from WhatsApp", + "items": { + "type": "object" + } + } + } + }, + "ui_schema": { # TODO: Read this from connectors/whatsapp + "files": { + "ui:widget": "nativeFile", + "ui:options": { + "accept": ".txt" + } + } + } +} class Connector(BaseModel): @@ -15,18 +44,30 @@ class Connector(BaseModel): form_schema: Optional[dict] = None ui_schema: Optional[dict] = None + model_config = { + "json_schema_extra": { + "example": example + } + } + class ConnectorsResponse(BaseModel): connectors: List[Connector] = [] -@router.get("/connectors") +@router.get("/connectors", + description="""List all available connectors. This endpoint fetches and returns a comprehensive list of all connectors configured in the system, along with their respective details including ID, name, optional documentation, form schema, and UI schema if available. + +### Response Format +Returns a `ConnectorsResponse` object containing a list of `Connector` objects each detailing a connector available in the system. +""") async def get_connectors() -> ConnectorsResponse: connectors = ConnectorFactory.get_all_connectors() return ConnectorsResponse(connectors=connectors) -@router.get("/connectors/{connector_id}") +@router.get("/connectors/{connector_id}", + description="Retrieve detailed information about a specific connector by its ID. 
This includes its name, documentation, and any schemas related to form and UI configurations.") async def get_connector(connector_id: str) -> Connector: connector_instance = ConnectorFactory.get_connector(connector_name=connector_id) return Connector( diff --git a/selfie/api/data_sources.py b/selfie/api/data_sources.py index 407305a..e407cf2 100644 --- a/selfie/api/data_sources.py +++ b/selfie/api/data_sources.py @@ -6,7 +6,7 @@ from selfie.database import DataManager -router = APIRouter() +router = APIRouter(tags=["Deprecated"]) class DataLoaderRequest(BaseModel): diff --git a/selfie/api/document_connections.py b/selfie/api/document_connections.py index b1dc350..8a2d62e 100644 --- a/selfie/api/document_connections.py +++ b/selfie/api/document_connections.py @@ -1,21 +1,44 @@ import base64 import json -from typing import Any +from typing import Any, Dict from urllib.parse import quote -from fastapi import APIRouter, Request, UploadFile, Form, HTTPException, Depends -from pydantic import BaseModel +from fastapi import APIRouter, Request, UploadFile, Form, HTTPException, Depends, Body +from pydantic import BaseModel, Field from selfie.connectors.factory import ConnectorFactory from selfie.database import DataManager, DocumentModel from selfie.embeddings import DataIndex -router = APIRouter() +router = APIRouter(tags=["Data Management"]) class DocumentConnectionRequest(BaseModel): - connector_id: str - configuration: Any + connector_id: str = Field(..., description="The ID of the connector to use for creating the document connection.") + configuration: Dict[str, Any] = Field(..., description="The configuration object for the document connection.") + + model_config = { + "json_schema_extra": { + "example": { + "connector_id": "whatsapp", + "configuration": { + "files": ["data:text/plain;name=example.txt;base64,SGVsbG8gV29ybGQ="] + } + } + } + } + + +class DocumentConnectionResponse(BaseModel): + message: str = Field(..., description="A message indicating the status of the document connection creation.") + + model_config = { + "json_schema_extra": { + "example": { + "message": "Document connection created successfully" + } + } + } async def file_to_data_uri(file: UploadFile): @@ -67,8 +90,74 @@ async def parse_create_document_connection_request(request: Request): return connector_id, configuration -@router.post("/document-connections") -async def create_document_connection(request: Request, parsed_data: tuple = Depends(parse_create_document_connection_request)): +@router.post("/document-connections", + description="""Create a new document connection using the specified connector and configuration. + +The request can be sent as multipart/form-data or application/json. Because browsers tend to limit the size of data URIs, it is recommended to use multipart/form-data for large files. + +### Using multipart/form-data + +For multipart/form-data, the connector_id and configuration should be provided as form fields. The configuration object should contain placeholders for file references, which will be replaced with data URIs from corresponding file fields in the request body. + +For example, the following form fields can be used to create a document connection: + + connector_id: whatsapp + configuration: {"files": ["file-0"]} + file-0: + +### Using application/json + +For application/json, the connector_id and configuration should be provided in the request body. The configuration object should contain files as data URIs. 
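+
+In Python, such a request might look like the following sketch (the base URL assumes a default local deployment; adjust it to your setup):
+
+    import requests
+
+    payload = {
+        "connector_id": "whatsapp",
+        "configuration": {"files": ["data:text/plain;name=example.txt;base64,SGVsbG8gV29ybGQ="]},
+    }
+    requests.post("http://localhost:8181/v1/document-connections", json=payload)
+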
+ +For example, the following JSON can be used to create a document connection: + + { + "connector_id": "whatsapp", + "configuration": { + "files": ["data:text/plain;name=example.txt;base64,SGVsbG8gV29ybGQ="] + } + } +""", + openapi_extra={ + "requestBody": { + "content": { + "application/json": { + "schema": DocumentConnectionRequest.schema() + }, + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "connector_id": {"type": "string"}, + "configuration": {"type": "string"} + }, + "required": ["connector_id", "configuration"], + "patternProperties": { + "^file-\\d+$": { + "type": "string", + "format": "binary" + } + } + }, + "examples": { + "example1": { + "summary": "Example 1", + "value": """connector_id: whatsapp +configuration: {"files": ["file-0"]} +file-0: +""" + } + } + } + } + } + }) +# request_body=DocumentConnectionRequest) +async def create_document_connection( + request: Request, + # document_connection_request: DocumentConnectionRequest = Body(..., description="The document connection request."), + parsed_data: tuple = Depends(parse_create_document_connection_request) +) -> DocumentConnectionResponse: connector_id, configuration = parsed_data connector_instance = ConnectorFactory.get_connector(connector_name=connector_id) connector_instance.validate_configuration(configuration=configuration) @@ -96,4 +185,4 @@ async def create_document_connection(request: Request, parsed_data: tuple = Depe # Save embedding_documents to Vector DB await DataIndex("n/a").index(embedding_documents, extract_importance=False) - return {"message": "Document connection created successfully"} + return DocumentConnectionResponse(message="Document connection created successfully") diff --git a/selfie/api/documents.py b/selfie/api/documents.py index 474de6a..c963b00 100644 --- a/selfie/api/documents.py +++ b/selfie/api/documents.py @@ -7,6 +7,7 @@ from selfie.embeddings import DataIndex from selfie.parsers.chat import ChatFileParser +# router = APIRouter(tags=["Data Management"]) router = APIRouter() @@ -23,30 +24,35 @@ class DeleteDocumentsRequest(BaseModel): document_ids: List[str] = [] -@router.get("/documents") +@router.get("/documents", + tags=["Data Management"]) async def get_documents(): return DataManager().get_documents() -@router.delete("/documents") -async def index_documents(request: DeleteDocumentsRequest): +@router.delete("/documents", + tags=["Data Management"]) +async def delete_documents(request: DeleteDocumentsRequest): await DataManager().remove_documents([int(document_id) for document_id in request.document_ids]) return {"message": "Documents removed successfully"} -@router.delete("/documents/{document_id}") -async def delete_data_source(document_id: int, delete_indexed_data: bool = True): +@router.delete("/documents/{document_id}", + tags=["Data Management"]) +async def delete_document(document_id: int, delete_indexed_data: bool = True): await DataManager().remove_document(document_id, delete_indexed_data) return {"message": "Document removed successfully"} -@router.post("/documents/unindex") +@router.post("/documents/unindex", + tags=["Deprecated"]) async def unindex_documents(request: UnindexDocumentsRequest): await DataIndex("n/a").delete_documents_with_source_documents(request.document_ids) return {"message": "Document unindexed successfully"} -@router.post("/documents/index") +@router.post("/documents/index", + tags=["Deprecated"]) async def index_documents(request: IndexDocumentsRequest): is_chat = request.is_chat document_ids = request.document_ids 
@@ -72,8 +78,3 @@ async def index_documents(request: IndexDocumentsRequest): ) if is_chat else None) for document_id in document_ids ] - -# @app.delete("/documents/{document-id}") -# async def delete_data_source(document_id: int): -# DataSourceManager().remove_document(document_id) -# return {"message": "Document removed successfully"} diff --git a/selfie/api/index_documents.py b/selfie/api/index_documents.py index c21f3bf..bef0c96 100644 --- a/selfie/api/index_documents.py +++ b/selfie/api/index_documents.py @@ -12,7 +12,7 @@ from datetime import datetime import importlib -router = APIRouter() +router = APIRouter(tags=["Deprecated"]) @router.get("/index_documents") diff --git a/selfie/api/models.py b/selfie/api/models.py index bb98713..ee85725 100644 --- a/selfie/api/models.py +++ b/selfie/api/models.py @@ -1,11 +1,43 @@ +from datetime import datetime from fastapi import APIRouter from huggingface_hub import scan_cache_dir +from typing import List, Optional +from pydantic import BaseModel, Field -router = APIRouter() +router = APIRouter(tags=["Configuration"]) -@router.get("/models") -async def get_models(): +class Model(BaseModel): + id: str = Field(..., description="The unique identifier of the model, formatted as 'repo_id/filename'") + object: str = Field(..., description="The type of the object, typically 'model'") + created: datetime = Field(..., description="The timestamp when the model was last modified in the cache") + owned_by: str = Field(..., description="Indicates the ownership of the model, typically 'user' for user-uploaded models") + + +class ModelsResponse(BaseModel): + object: str = Field(..., description="The type of the response, always 'list'") + data: List[Model] = Field(..., description="A list of model objects detailing available models in the cache") + + model_config = { + "json_schema_extra": { + "example": { + "object": "list", + "data": [ + { + "id": "user/repo/modelname.gguf", + "object": "model", + "created": "2021-01-01T00:00:00Z", + "owned_by": "user" + } + ] + } + } + } + + +@router.get("/models", + description="Retrieve a list of **already-downloaded llama.cpp models** (in the Hugging Face Hub cache). 
This endpoint scans the cache directory for model files (specifically looking for files with a '.gguf' extension within each repository revision) and returns a list of models including their ID, object type, creation timestamp, and ownership information.") +async def get_models() -> ModelsResponse: hf_cache_info = scan_cache_dir() models = [] @@ -14,21 +46,18 @@ async def get_models(): gguf_files = [file for file in revision.files if file.file_name.endswith('.gguf')] if gguf_files: for gguf_file in gguf_files: - models.append({ - "id": f"{repo.repo_id}/{gguf_file.file_name}", - "object": "model", - "created": gguf_file.blob_last_modified, - "owned_by": "user" - }) + models.append(Model( + id=f"{repo.repo_id}/{gguf_file.file_name}", + object="model", + created=gguf_file.blob_last_modified, + owned_by="user" + )) else: - models.append({ - "id": repo.repo_id, - "object": "model", - "created": min(file.last_modified for file in repo.revisions), - "owned_by": "user" - }) - - return { - "object": "list", - "data": models - } + models.append(Model( + id=repo.repo_id, + object="model", + created=min(file.last_modified for file in repo.revisions), + owned_by="user" + )) + + return ModelsResponse(object="list", data=models) diff --git a/selfie/connectors/whatsapp/connector.py b/selfie/connectors/whatsapp/connector.py index 9946214..d80a387 100644 --- a/selfie/connectors/whatsapp/connector.py +++ b/selfie/connectors/whatsapp/connector.py @@ -16,7 +16,7 @@ class WhatsAppConfiguration(BaseModel): class WhatsAppConnector(BaseConnector, ABC): def __init__(self): super().__init__() - self.id = "whatsapp" + self.id = "whatsapp" # TODO: this should be static self.name = "WhatsApp" def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]: diff --git a/selfie/types/completion_requests.py b/selfie/types/completion_requests.py index f9107a6..9974849 100644 --- a/selfie/types/completion_requests.py +++ b/selfie/types/completion_requests.py @@ -1,4 +1,6 @@ from typing import Optional, Literal, List, ClassVar, Union, Dict, Any +from pydantic import BaseModel, TypeAdapter, Extra, Field + from litellm import ModelResponse as LitellmCompletionResponse from llama_cpp import ( CreateCompletionResponse as LlamaCppCompletionResponse, @@ -7,7 +9,6 @@ from openai.types.chat import ChatCompletionMessage from openai.types.completion_create_params import CompletionCreateParams, CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming from openai.types.chat import CompletionCreateParams as ChatCompletionCreateParams -from pydantic import BaseModel, TypeAdapter, Extra from sse_starlette import EventSourceResponse # from typing_extensions import TypedDict @@ -16,48 +17,64 @@ class Message(BaseModel): - role: str - content: str + role: str = Field(..., description="The role of the message sender, e.g., 'system', 'user', or 'assistant'") + content: str = Field(..., description="The content of the message") class FunctionCall(BaseModel): - name: str - parameters: Optional[Dict[str, Any]] = None + name: str = Field(..., description="The name of the function to call") + parameters: Optional[Dict[str, Any]] = Field(None, description="The parameters to pass to the function") class Tool(BaseModel): - type: str - function: Optional[FunctionCall] = None + type: str = Field(..., description="The type of the tool") + function: Optional[FunctionCall] = Field(None, description="The function to call when using the tool") class BaseCompletionRequest(BaseModel): # OpenAI parameters - model: Optional[str] = None - 
frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[int, float]] = None
+    model: Optional[str] = Field(None, description="ID of the model to use for completion")
+    frequency_penalty: Optional[float] = Field(0.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.")
+    logit_bias: Optional[Dict[int, float]] = Field(None, description="Modify the likelihood of specified tokens appearing in the completion.")
     # logprobs: Optional[bool] = False
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate in the completion.")
     # n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0.0
-    response_format: Optional[Dict[str, str]] = None
-    seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
-    stream: Optional[bool] = False
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    tools: Optional[List[Tool]] = None
-    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    presence_penalty: Optional[float] = Field(0.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.")
+    response_format: Optional[Dict[str, str]] = Field(None, description="An object specifying the format that the model must output.")
+    seed: Optional[int] = Field(None, description="If specified, repeated requests with the same seed and parameters should return the same result; determinism is best-effort.")
+    stop: Optional[Union[str, List[str]]] = Field(None, description="Up to 4 sequences where the API will stop generating further tokens.")
+    stream: Optional[bool] = Field(False, description="If set, partial completion results will be sent as they become available.")
+    temperature: Optional[float] = Field(1.0, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
+    top_p: Optional[float] = Field(1.0, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.")
+    tools: Optional[List[Tool]] = Field(None, description="A list of tools the model may call.")
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Controls which (if any) function is called by the model. Options are 'none', 'auto', or a specific function.")
     # user: Optional[str] = None
 
     # Selfie parameters
-    method: Optional[Literal["litellm", "llama.cpp", "transformers"]] = None
-    api_base: Optional[str] = None
-    api_key: Optional[str] = None
-    disable_augmentation: Optional[bool] = False
+    method: Optional[Literal["litellm", "llama.cpp", "transformers"]] = Field(None, description="The method to use for completion, e.g., 'litellm', 'llama.cpp', or 'transformers'.")
+    api_base: Optional[str] = Field(None, description="The base URL for the API")
+    api_key: Optional[str] = Field(None, description="The API key to use for authentication")
+    disable_augmentation: Optional[bool] = Field(False, description="Whether to disable augmenting the completion with Selfie data")
 
-    # Custom parameters, e.g. for a custom API
-    class Config:
-        extra = Extra.allow
+    # Allow custom parameters, e.g. for a custom API. In Pydantic v2 a model cannot
+    # define both "class Config" and "model_config", so extra="allow" lives here,
+    # alongside the schema example.
+    model_config = {
+        "extra": "allow",
+        "json_schema_extra": {
+            "example": {
+                "model": "gpt-3.5-turbo",
+                "prompt": "Hello, how are you?",
+                "max_tokens": 50,
+                "temperature": 0.8,
+                "method": "litellm",
+                "api_key": "your_api_key"
+            }
+        }
+    }
 
     custom_params: ClassVar[List[str]] = ["method", "api_base", "api_key", "disable_augmentation"]
 
@@ -78,27 +95,53 @@ def extra_params(self):
     """
     return {k: v for k, v in self.model_dump().items() if k not in self.model_fields.keys()}
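+
+    # For example, with extra="allow" an unknown field is preserved, so
+    #     CompletionRequest(prompt="Hi", repetition_penalty=1.1).extra_params()
+    # evaluates to {"repetition_penalty": 1.1} (repetition_penalty is an arbitrary,
+    # illustrative parameter), which downstream code can forward to a custom API.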
 
 
 class ChatCompletionRequest(BaseCompletionRequest):
-    messages: List[Message]
+    messages: List[Message] = Field(..., description="A list of messages comprising the conversation so far.")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "model": "gpt-3.5-turbo",
+                "messages": [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "What is the capital of France?"}
+                ],
+                "max_tokens": 50,
+                "temperature": 0.8,
+                "method": "litellm",
+                "api_key": "your_api_key"
+            }
+        }
+    }
 
 
 class CompletionRequest(BaseCompletionRequest):
-    prompt: Union[str, List[str]]
+    prompt: Union[str, List[str]] = Field(..., description="The prompt(s) to generate completions for. Can be a string or a list of strings.")
     # best_of: Optional[int] = None
-    echo: Optional[bool] = None
-    logprobs: Optional[int] = None
+    echo: Optional[bool] = Field(None, description="Whether to echo the prompt in the response.")
+    logprobs: Optional[int] = Field(None, description="Include the log probabilities on the logprobs most likely tokens, as well as the chosen tokens. For example, if logprobs is 10, the API will return a list of the 10 most likely tokens. If logprobs is supplied, the API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.")
     # n: Optional[int] = None
-    suffix: Optional[str] = None
+    suffix: Optional[str] = Field(None, description="The suffix that comes after a completion of inserted text.")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "model": "gpt-3.5-turbo",
+                "prompt": "Once upon a time",
+                "max_tokens": 50,
+                "temperature": 0.8,
+                "method": "litellm",
+                "api_key": "your_api_key"
+            }
+        }
+    }
 
 
-class ChatCompletionResponse(BaseModel):
-    id: str
-    object: str = "chat.completion"
-    created: int
-    model: Optional[str]
-    choices: List[Dict[str, Any]]
-    usage: Dict[str, int]
+# class ChatCompletionResponse(BaseModel):
+#     id: str = Field(..., description="The ID of the completion")
+#     object: str = Field('chat.completion', description="The object type, e.g., 'chat.completion'")
+#     created: int = Field(..., description="The timestamp of the completion creation")
+#     model: Optional[str] = Field(None, description="The model used for the completion")
+#     choices: List[Dict[str, Any]] = Field(..., description="The choices of completions")
+#     usage: Dict[str, int] = Field(..., description="The usage of the completion")
+# ##########################

From 52cc93f2366ddd10fedc145277cca240709d1470 Mon Sep 17 00:00:00 2001
From: Tim Nunamaker
Date: Mon, 11 Mar 2024 15:56:19 -0500
Subject: [PATCH 2/3] Document query endpoint

---
 .../components/Playground/PlaygroundQuery.tsx | 160 ++++++++++++++----
 selfie/api/__init__.py                        |  22 ++-
 selfie/api/completions.py                     |   7 +-
 selfie/api/documents.py                       | 110 +++++++++++-
 selfie/embeddings/__init__.py                 |   1 +
 selfie/embeddings/document_types.py           |  24 +++
 selfie/types/completion_requests.py           |  12 +-
 7 files changed, 270 insertions(+), 66
deletions(-) diff --git a/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx b/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx index b8bf0d7..54af79b 100644 --- a/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx +++ b/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx @@ -3,10 +3,25 @@ import { apiBaseUrl } from "@/app/config"; import useAsyncTask from "@/app/hooks/useAsyncTask"; import TaskToast from "@/app/components/TaskToast"; - -const fetchDocuments = async (topic: string, limit?: number, minScore?: number, includeSummary?: boolean) => { - const params = new URLSearchParams({ topic, ...(limit && { limit: limit.toString() }), ...(minScore && { min_score: minScore.toString() }), ...(includeSummary !== undefined && { include_summary: includeSummary.toString() }) }); - const url = `${apiBaseUrl}/v1/index_documents/summary?${params.toString()}`; +const fetchDocuments = async ( + query: string, + limit?: number, + minScore?: number, + includeSummary?: boolean, + relevanceWeight?: number, + recencyWeight?: number, + importanceWeight?: number +) => { + const params = new URLSearchParams({ + query, + ...(limit && { limit: limit.toString() }), + ...(minScore && { min_score: minScore.toString() }), + ...(includeSummary !== undefined && { include_summary: includeSummary.toString() }), + ...(relevanceWeight && { relevance_weight: relevanceWeight.toString() }), + ...(recencyWeight && { recency_weight: recencyWeight.toString() }), + ...(importanceWeight && { importance_weight: importanceWeight.toString() }), + }); + const url = `${apiBaseUrl}/v1/documents/search?${params.toString()}`; try { const response = await fetch(url); @@ -23,10 +38,14 @@ const PlaygroundQuery = () => { const [documents, setDocuments] = useState([]); const [summary, setSummary] = useState(""); const [isSummaryLoading, setSummaryLoading] = useState(false); - const [score, setScore] = useState(0); + const [averageScore, setAverageScore] = useState(0); + const [totalResults, setTotalResults] = useState(0); const [limit, setLimit] = useState(); const [minScore, setMinScore] = useState(); const [includeSummary, setIncludeSummary] = useState(true); + const [relevanceWeight, setRelevanceWeight] = useState(); + const [recencyWeight, setRecencyWeight] = useState(); + const [importanceWeight, setImportanceWeight] = useState(); const handleInputChange = (setter: React.Dispatch>) => (e: React.ChangeEvent) => { const value = e.target.type === "number" ? Number(e.target.value) || undefined : e.target.value; @@ -41,16 +60,26 @@ const PlaygroundQuery = () => { e.preventDefault(); executeTask(async () => { - setScore(0); + setAverageScore(0); + setTotalResults(0); setDocuments([]); setSummary(""); setSummaryLoading(true); - const results = await fetchDocuments(query, limit, minScore, includeSummary); - setScore(results.score); + const results = await fetchDocuments( + query, + limit, + minScore, + includeSummary, + relevanceWeight, + recencyWeight, + importanceWeight + ); + setAverageScore(results.average_score); + setTotalResults(results.total_results); setDocuments(results.documents); setSummary(results.summary); setSummaryLoading(false); - console.log("Searching with:", query, limit, minScore, includeSummary); + console.log("Searching with:", query, limit, minScore, includeSummary, relevanceWeight, recencyWeight, importanceWeight); }, { start: "Searching...", success: "Search complete", @@ -62,12 +91,9 @@ const PlaygroundQuery = () => { return (
    - {/*

    Document {doc.id}

    */} -

    Embedding document {i}

    +

    Embedding Document {doc.id}

    {doc.text}
      - {/*
    • Score: {doc.score}
    • */} - {/* only 2 decimal */}
    • Overall score: {doc.score.toFixed(2)}
    • Relevance score: {doc.relevance.toFixed(2)}
    • Recency score: {doc.recency.toFixed(2)}
    • @@ -114,7 +140,39 @@ const PlaygroundQuery = () => { onChange={(e) => setMinScore(Number(e.target.value) || undefined)} min="0" max="1" - step="0.1" + step="0.01" + /> +
    + +
    + {/**/} + setRelevanceWeight(e.target.value ? Number(e.target.value) : undefined)} + min="0" + max="1" + step="0.01" + /> +
    + +
    + {/**/} + setRecencyWeight(e.target.value ? Number(e.target.value) : undefined)} + min="0" + max="1" + step="0.01" />
    @@ -145,34 +203,62 @@ const PlaygroundQuery = () => {
    - {!!summary &&
    - {/**/} -

    {summary}

    - {documents.length ?

    Result Score: {score.toFixed(2)}

    : null } -
    } + {!!summary && ( +
    + {/**/} +

    {summary}

    + {documents.length ? ( +
    +

    Total Results: {totalResults}

    +

    Average Score: {averageScore.toFixed(2)}

    +
    + ) : null} +
    + )} - {!!score &&
    - {documents.map(renderDocument)} -
    } + {documents.length > 0 &&
    {documents.map(renderDocument)}
    } ); }; const SearchIcon = () => ( + + + +); + +const LoadingIcon = () => ( + + + + + +); PlaygroundQuery.displayName = "PlaygroundQuery";

diff --git a/selfie/api/__init__.py b/selfie/api/__init__.py
index 911c409..489745d 100644
--- a/selfie/api/__init__.py
+++ b/selfie/api/__init__.py
@@ -30,22 +30,28 @@
 tags_metadata = [
     {
         "name": "Completions",
-        "description": "Endpoints for generating chat and text completions.",
+        "description": """Endpoints for generating chat and text completions. Selfie completion endpoints can be used as drop-in replacements for endpoints in the [OpenAI API](https://platform.openai.com/docs/api-reference).
+
+These endpoints generally include additional functionality not present in the OpenAI API; for example, you can use a flag to control whether Selfie data is used during text generation.
+
+Please see the [API Usage Guide](https://github.com/vana-com/selfie/?tab=readme-ov-file#api-usage-guide) for more information on how to use these endpoints.
+""",
     },
     {
-        "name": "Configuration",
-        "description": "Endpoints for configuring and managing Selfie.",
+        "name": "Search",
+        "description": "Endpoints for searching and analyzing documents.",
     },
     {
         "name": "Data Management",
-        "description": "Endpoints for managing data sources, documents, and data indexing.",
+        "description": """Endpoints for managing data sources, documents, and data indexing.
+
+These endpoints are primarily intended to be used by the Selfie UI."""
     },
     {
-        "name": "OpenAI",
-        "description": """Endpoints that can be used as drop-in replacements for endpoints in the [OpenAI API](https://platform.openai.com/docs/api-reference).
+        "name": "Configuration",
+        "description": """Endpoints for configuring and managing Selfie.
 
-These endpoints generally include additional functionality not present in the OpenAI API. For example, completion endpoints support a flag for controlling whether or not to use Selfie data during text generation.
-""",
+These endpoints are primarily intended to be used by the Selfie UI."""
     },
     {
         "name": "Deprecated",
diff --git a/selfie/api/completions.py b/selfie/api/completions.py
index 62f6e0a..96f5cc2 100644
--- a/selfie/api/completions.py
+++ b/selfie/api/completions.py
@@ -8,13 +8,10 @@
 from selfie.text_generation import completion
 
-# router = APIRouter(tags=["OpenAI"])
-router = APIRouter(tags=["Completions", "OpenAI"])
+router = APIRouter(tags=["Completions"])
 
 
 @router.post("/chat/completions",
-             # tags=["OpenAI"],
-             # summary="Create chat completion",
              description="""
     Creates a response for the given conversation in [the style of OpenAI](https://platform.openai.com/docs/api-reference/chat/create).
     """)
@@ -26,8 +23,6 @@ async def create_chat_completion(
 
 # TODO can StreamingResponse's schema be defined?
 @router.post("/completions",
-             # tags=["OpenAI"],
-             # summary="Create completion",
              description="""
     Creates a response for the given prompt in [the style of OpenAI](https://platform.openai.com/docs/api-reference/completions/create).
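+
+    For example, a minimal request might look like this Python sketch (the base URL
+    assumes a default local deployment; adjust the host, port, and parameters as needed):
+
+        import requests
+
+        response = requests.post(
+            "http://localhost:8181/v1/completions",
+            json={"prompt": "Once upon a time", "max_tokens": 50},
+        )
+        print(response.json())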
""") diff --git a/selfie/api/documents.py b/selfie/api/documents.py index c963b00..c019f8c 100644 --- a/selfie/api/documents.py +++ b/selfie/api/documents.py @@ -1,8 +1,9 @@ -from typing import List +from typing import List, Optional -from fastapi import APIRouter -from pydantic import BaseModel +from fastapi import APIRouter, Query +from pydantic import BaseModel, Field +from selfie.embeddings import ScoredEmbeddingDocumentModel from selfie.database import DataManager from selfie.embeddings import DataIndex from selfie.parsers.chat import ChatFileParser @@ -24,24 +25,115 @@ class DeleteDocumentsRequest(BaseModel): document_ids: List[str] = [] +class FetchedDocument(BaseModel): + id: int = Field(..., description="The unique identifier of the document") + name: str = Field(..., description="The name of the document") + size: int = Field(..., description="The size of the document") + created_at: str = Field(..., description="The timestamp of the document creation") + updated_at: str = Field(..., description="The timestamp of the document update") + content_type: str = Field(..., description="The content type of the document") + connector_name: str = Field(..., description="The name of the connector") + + model_config = { + "json_schema_extra": { + "example": { + "id": 1, + "name": "example.txt", + "size": 1024, + "created_at": "2024-03-11T18:33:04.733583", + "updated_at": "2024-03-11T18:33:04.733590", + "content_type": "text/plain", + "connector_name": "whatsapp", + } + } + } + + @router.get("/documents", tags=["Data Management"]) -async def get_documents(): +async def get_documents() -> List[FetchedDocument]: return DataManager().get_documents() @router.delete("/documents", - tags=["Data Management"]) + tags=["Data Management"], + description="Remove multiple documents by their IDs.", + status_code=204) async def delete_documents(request: DeleteDocumentsRequest): await DataManager().remove_documents([int(document_id) for document_id in request.document_ids]) - return {"message": "Documents removed successfully"} @router.delete("/documents/{document_id}", - tags=["Data Management"]) -async def delete_document(document_id: int, delete_indexed_data: bool = True): + tags=["Data Management"], + description="Remove a document by its ID.", + status_code=204) +async def delete_document(document_id: int, delete_indexed_data: Optional[bool] = True): await DataManager().remove_document(document_id, delete_indexed_data) - return {"message": "Document removed successfully"} + + +class SearchDocumentsResponse(BaseModel): + query: str = Field(..., description="The search query") + total_results: int = Field(..., description="The total number of documents found") + average_score: float = Field(..., description="The mean relevance score of the documents") + documents: List[ScoredEmbeddingDocumentModel] = Field(..., description="The documents found") + summary: Optional[str] = Field(None, description="A summary of the search results") + + model_config = { + "json_schema_extra": { + "example": { + "query": "What is the meaning of life?", + "total_results": 1, + "average_score": 0.4206031938249788, + "documents": [ + { + "id": 1, + "text": "The meaning of life is 42.", + "source": "whatsapp", + "timestamp": "2023-03-12T11:35:00Z", + "created_timestamp": "2024-03-11T18:33:04.733583", + "updated_timestamp": "2024-03-11T18:33:04.733590", + "source_document_id": 3, + "score": 0.4206031938249788, + "relevance": 0.08712080866098404, + "recency": 0.4720855789889736, + "importance": None, + }, + ], + "summary": "The 
meaning of life is 42." + } + } + } + + +@router.get("/documents/search", + tags=["Search"], + description="Search for embedding documents that most closely match a query.") +async def search_documents( + query: str, + limit: Optional[int] = Query(3, ge=1, le=100, description="Maximum number of documents to fetch"), + min_score: Optional[float] = Query(0.4, ge=0.0, le=1.0, description="Minimum score for embedding documents"), + include_summary: Optional[bool] = Query(False, description="Include a summary of the search results"), + relevance_weight: Optional[float] = Query(1.0, le=1.0, ge=0.0, description="Weight for relevance in the scoring algorithm"), + recency_weight: Optional[float] = Query(1.0, le=1.0, ge=0.0, description="Weight for recency in the scoring algorithm"), + importance_weight: Optional[float] = Query(0, le=1.0, ge=0.0, description="**Importance scores are currently not calculated, so this weight has no effect!** Weight for document importance in the scoring algorithm.") +) -> SearchDocumentsResponse: + result = await DataIndex("n/a").recall( + topic=query, + limit=limit, + min_score=min_score, + include_summary=include_summary, + relevance_weight=relevance_weight, + recency_weight=recency_weight, + importance_weight=importance_weight, + ) + + return SearchDocumentsResponse( + query=query, + total_results=len(result["documents"]), + average_score=result["mean_score"], + documents=result["documents"], + summary=result["summary"] if include_summary else None + ) @router.post("/documents/unindex", diff --git a/selfie/embeddings/__init__.py b/selfie/embeddings/__init__.py index 5c9870a..c597319 100644 --- a/selfie/embeddings/__init__.py +++ b/selfie/embeddings/__init__.py @@ -323,6 +323,7 @@ async def recall( documents_list: List[ScoredEmbeddingDocumentModel] = [] for result in results: document = EmbeddingDocumentModel( + id=result["id"], text=result["text"], timestamp=result["timestamp"], importance=result["importance"], diff --git a/selfie/embeddings/document_types.py b/selfie/embeddings/document_types.py index 0092d5f..1dd0d30 100644 --- a/selfie/embeddings/document_types.py +++ b/selfie/embeddings/document_types.py @@ -23,6 +23,18 @@ def to_dict(self, *args, **kwargs): class Config: validate_assignment = True + json_schema_extra = { + "example": { + "id": 42, + "text": "What is the meaning of life?", + "source": "whatsapp", + "importance": None, + "timestamp": "2022-01-01T00:00:00Z", + "created_timestamp": "2022-01-01T00:00:00Z", + "updated_timestamp": "2022-01-01T00:00:00Z", + "source_document_id": 42 + } + } @model_validator(mode='before') def autofill_timestamps(cls, values): @@ -37,3 +49,15 @@ class ScoredEmbeddingDocumentModel(EmbeddingDocumentModel): importance: Optional[float] = Field(..., description="Importance score of the document, [0, 1]") relevance: float = Field(..., description="Relevance score of the document, for a query, [0, 1]") recency: float = Field(..., description="Recency score of the document, for a query (time), [0, 1]") + + model_config = { + "json_schema_extra": { + "example": { + **EmbeddingDocumentModel.model_config['json_schema_extra']['example'], + "score": 0.42, + "relevance": 0.42, + "recency": 0.42, + "importance": None, # For now + } + } + } diff --git a/selfie/types/completion_requests.py b/selfie/types/completion_requests.py index 9974849..4e1c9ad 100644 --- a/selfie/types/completion_requests.py +++ b/selfie/types/completion_requests.py @@ -62,12 +62,12 @@ class Config: model_config = { "json_schema_extra": { "example": { + "method": 
"litellm", "model": "gpt-3.5-turbo", + "api_key": "your-api-key", "prompt": "Hello, how are you?", "max_tokens": 50, "temperature": 0.8, - "method": "litellm", - "api_key": "your_api_key" } } } @@ -98,15 +98,15 @@ def extra_params(self): model_config = { "json_schema_extra": { "example": { + "method": "litellm", "model": "gpt-3.5-turbo", + "api_key": "your-api-key", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"} ], "max_tokens": 50, "temperature": 0.8, - "method": "litellm", - "api_key": "your_api_key" } } } @@ -123,12 +123,12 @@ class CompletionRequest(BaseCompletionRequest): model_config = { "json_schema_extra": { "example": { + "method": "litellm", "model": "gpt-3.5-turbo", + "api_key": "your-api-key", "prompt": "Once upon a time", "max_tokens": 50, "temperature": 0.8, - "method": "litellm", - "api_key": "your_api_key" } } } From 49a0d7302e216929470552188a712257838ffc02 Mon Sep 17 00:00:00 2001 From: Tim Nunamaker Date: Mon, 11 Mar 2024 16:00:17 -0500 Subject: [PATCH 3/3] Delete deprecated endpoints --- selfie/api/__init__.py | 13 +-- selfie/api/data_sources.py | 59 ------------- selfie/api/documents.py | 38 --------- selfie/api/index_documents.py | 153 ---------------------------------- selfie/api/models.py | 2 +- 5 files changed, 5 insertions(+), 260 deletions(-) delete mode 100644 selfie/api/data_sources.py delete mode 100644 selfie/api/index_documents.py diff --git a/selfie/api/__init__.py b/selfie/api/__init__.py index 489745d..518ff1f 100644 --- a/selfie/api/__init__.py +++ b/selfie/api/__init__.py @@ -10,10 +10,8 @@ from selfie.api.completions import router as completions_router from selfie.api.connectors import router as connectors_router -from selfie.api.data_sources import router as data_sources_router from selfie.api.document_connections import router as document_connections_router from selfie.api.documents import router as documents_router -from selfie.api.index_documents import router as index_documents_router from selfie.api.models import router as models_router from selfie.api.connectors import router as connectors_router from selfie.config import get_app_config @@ -53,11 +51,10 @@ These endpoints are primarily intended to be used by the Selfie UI.""" }, - { - "name": "Deprecated", - "description": "Endpoints that are deprecated and should not be used.", - } - + # { + # "name": "Deprecated", + # "description": "Endpoints that are deprecated and should not be used.", + # } ] @@ -93,9 +90,7 @@ app.include_router(completions_router) app.include_router(connectors_router) app.include_router(document_connections_router) -app.include_router(data_sources_router) app.include_router(documents_router) -app.include_router(index_documents_router) app.include_router(models_router) app.include_router(connectors_router) diff --git a/selfie/api/data_sources.py b/selfie/api/data_sources.py deleted file mode 100644 index e407cf2..0000000 --- a/selfie/api/data_sources.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Optional, Any, List, Dict - -from fastapi import APIRouter, HTTPException -from playhouse.shortcuts import model_to_dict -from pydantic import BaseModel - -from selfie.database import DataManager - -router = APIRouter(tags=["Deprecated"]) - - -class DataLoaderRequest(BaseModel): - name: Optional[str] - loader_module: str - constructor_args: Optional[List[Any]] = [] - constructor_kwargs: Optional[Dict[str, Any]] = {} - load_data_args: Optional[List[Any]] = [] - load_data_kwargs: 
Optional[Dict[str, Any]] = {} - - -@router.get("/data-sources") -async def get_data_sources(): - return DataManager().get_document_connections() - - -@router.post("/data-sources") -async def add_data_source(request: DataLoaderRequest): - configuration = { - "loader_name": request.loader_module, - "constructor_args": request.constructor_args, - "constructor_kwargs": request.constructor_kwargs, - "load_data_args": request.load_data_args, - "load_data_kwargs": request.load_data_kwargs, - } - - return model_to_dict(DataManager().add_document_connection( - request.name, - configuration - )) - - -@router.delete("/data-sources/{source_id}") -async def delete_data_source(source_id: int, delete_documents: bool = True, delete_indexed_data: bool = True): - await DataManager().remove_document_connection(source_id, delete_documents, delete_indexed_data) - return {"message": "Data source removed successfully"} - - -@router.post("/data-sources/{source_id}/scan") -async def scan_data_sources(source_id: int): - return DataManager().scan_document_connections([source_id]) - - -@router.post("/data-sources/{source_id}/index") -async def index_data_source(source_id: int): - manager = DataManager() - data_source = manager.get_document_connection(source_id) - if isinstance(data_source, dict) and "error" in data_source: - raise HTTPException(status_code=404, detail=data_source["error"]) - return await manager.index_documents(data_source) diff --git a/selfie/api/documents.py b/selfie/api/documents.py index c019f8c..c70d27e 100644 --- a/selfie/api/documents.py +++ b/selfie/api/documents.py @@ -6,9 +6,7 @@ from selfie.embeddings import ScoredEmbeddingDocumentModel from selfie.database import DataManager from selfie.embeddings import DataIndex -from selfie.parsers.chat import ChatFileParser -# router = APIRouter(tags=["Data Management"]) router = APIRouter() @@ -134,39 +132,3 @@ async def search_documents( documents=result["documents"], summary=result["summary"] if include_summary else None ) - - -@router.post("/documents/unindex", - tags=["Deprecated"]) -async def unindex_documents(request: UnindexDocumentsRequest): - await DataIndex("n/a").delete_documents_with_source_documents(request.document_ids) - return {"message": "Document unindexed successfully"} - - -@router.post("/documents/index", - tags=["Deprecated"]) -async def index_documents(request: IndexDocumentsRequest): - is_chat = request.is_chat - document_ids = request.document_ids - - manager = DataManager() - parser = ChatFileParser() - - # TODO: figure out what to do about this - speaker_aliases = {} - - return [ - await manager.index_document(manager.get_document(document_id), - lambda document: DataIndex.map_share_gpt_data( - parser.parse_document( - document.content, - None, - speaker_aliases, - False, - document.id - ).conversations, - # source=document.source.name, - source_document_id=document.id - ) if is_chat else None) - for document_id in document_ids - ] diff --git a/selfie/api/index_documents.py b/selfie/api/index_documents.py deleted file mode 100644 index bef0c96..0000000 --- a/selfie/api/index_documents.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Optional, List - -from fastapi import APIRouter, UploadFile, File, Form - -from selfie.api.data_sources import DataLoaderRequest -from selfie.parsers.chat import ChatFileParser -from selfie.parsers.chat.chat_file_parsing_helper import get_files_with_configs, delete_uploaded_files -from selfie.embeddings import DataIndex -from selfie.embeddings.document_types import EmbeddingDocumentModel - 
-from llama_index.core.node_parser.text import SentenceSplitter -from datetime import datetime -import importlib - -router = APIRouter(tags=["Deprecated"]) - - -@router.get("/index_documents") -async def get_documents(offset: int = 0, limit: int = 10): - return await DataIndex("n/a").get_documents(offset=offset, limit=limit) - - -@router.get("/index_documents/summary") -async def get_index_documents_summary(topic: str, limit: Optional[int] = 5, min_score: Optional[float] = None, include_summary: Optional[bool] = True): - result = await DataIndex("n/a").recall(topic, limit=limit, min_score=min_score, include_summary=include_summary) - return { - "summary": result["summary"], - "score": result["mean_score"], - "documents": result["documents"], - } - - -@router.post("/index_documents") -async def create_index_document(document: EmbeddingDocumentModel): - return (await DataIndex("n/a").index([document]))[0] - - -@router.get("/index_documents/{document_id}") -async def get_index_document(document_id: int): - return await DataIndex("n/a").get_document(document_id) - - -@router.put("/index_documents/{document_id}") -async def update_index_document(document_id: int, document: EmbeddingDocumentModel): - await DataIndex("n/a").update_document(document_id, document) - return {"message": "Document updated successfully"} - - -@router.delete("/index_documents/{document_id}") -async def delete_index_document(document_id: int): - # Sometimes self.embeddings.save() errors on "database is locked", bricks it - # raise HTTPException(status_code=501, detail="Not implemented") - DataIndex("n/a").delete_document(document_id) - return {"message": "Document deleted successfully"} - - -@router.delete("/index_documents") -async def delete_index_documents(): - DataIndex("n/a").delete_all() - return {"message": "All documents deleted successfully"} - - -# TODO: Deprecate this endpoint, it should be not be allowed to embed documents that are not tracked -@router.post("/index_documents/llama-hub-loader") -async def load_data(request: DataLoaderRequest): - # TODO: extract document metadata from request? 
- - # Adapted from: - # https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html#build-an-ingestion-pipeline-from-scratch - module_name, class_name = request.loader_module.rsplit(".", 1) - module = importlib.import_module(module_name) - loader_class = getattr(module, class_name) - - loader = loader_class(*request.constructor_args, **request.constructor_kwargs) - - documents = loader.load_data(*request.load_data_args, **request.load_data_kwargs) - - print(documents) - - text_parser = SentenceSplitter( - chunk_size=1024, - # separator=" ", - ) - - text_chunks = [] - # maintain relationship with source doc index, to help inject doc metadata in (3) - doc_idxs = [] - for doc_idx, doc in enumerate(documents): - cur_text_chunks = text_parser.split_text(doc.text) - text_chunks.extend(cur_text_chunks) - doc_idxs.extend([doc_idx] * len(cur_text_chunks)) - - embedding_documents = [] - for idx, text_chunk in enumerate(text_chunks): - src_doc = documents[doc_idxs[idx]] - document = EmbeddingDocumentModel( - text=text_chunk, - # source=request.loader_module, - # importance=0.0, - # timestamp=datetime.strptime(src_doc.metadata['last_modified'], "%Y-%m-%d"), - # use last_modified if available, otherwise use current time - timestamp=datetime.strptime(src_doc["last_modified"], "%Y-%m-%d") - # source_document_id=src_doc["id"] - if "last_modified" in src_doc - else datetime.now(), - ) - - embedding_documents.append(document) - - return {"documents": await DataIndex("n/a").index(embedding_documents, extract_importance=False)} - - -@router.post("/index_documents/chat-processor") -async def process_chat_files( - character_name: str, - files: List[UploadFile] = File(..., description="Upload chat files here."), - parser_configs: str = Form( - "[]", - description="JSON string of parser configurations. Format can be whatsapp, google, discord. Example: " - + '[{"main_speaker": "Alice", "format": "whatsapp", ' - + '"speaker_aliases": {"alicex": "Alice", "bobby": "Bob"}}]', - ), - extract_importance: bool = False, -): - parser = ChatFileParser() - data_index = DataIndex(character_name) - - files_with_settings = get_files_with_configs(files, parser_configs) - - documents = [] - new_document_count = 0 - for file_with_settings in files_with_settings: - file_data = parser.parse_file( - file_with_settings["file"], - file_with_settings["config"].format, - file_with_settings["config"].speaker_aliases, - ) - - mapped_documents = DataIndex.map_share_gpt_data( - file_data.conversations, file_with_settings["file"].split("/")[-1] - ) - - file_documents = await data_index.index(mapped_documents, extract_importance) - new_document_count += len(file_documents) - documents.extend(file_documents) - - delete_uploaded_files(files_with_settings) - - return { - "success": True, - "num_documents": len(documents), - "num_new_documents": new_document_count, - } diff --git a/selfie/api/models.py b/selfie/api/models.py index ee85725..3b2cabe 100644 --- a/selfie/api/models.py +++ b/selfie/api/models.py @@ -1,7 +1,7 @@ from datetime import datetime from fastapi import APIRouter from huggingface_hub import scan_cache_dir -from typing import List, Optional +from typing import List from pydantic import BaseModel, Field router = APIRouter(tags=["Configuration"])
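A quick end-to-end sketch of the search endpoint added in PATCH 2/3. The base URL assumes a default local deployment, and the parameter values are illustrative:

```python
import requests

# Query /v1/documents/search with explicit scoring weights.
response = requests.get(
    "http://localhost:8181/v1/documents/search",
    params={
        "query": "What is the meaning of life?",
        "limit": 3,
        "min_score": 0.4,
        "include_summary": True,
        "relevance_weight": 1.0,
        "recency_weight": 1.0,
    },
)
results = response.json()
print(results["total_results"], results["average_score"], results["summary"])
for doc in results["documents"]:
    # Each hit carries an overall score plus per-factor relevance and recency scores.
    print(doc["score"], doc["relevance"], doc["recency"], doc["text"])
```

Note that `importance_weight` is accepted but currently has no effect, since importance scores are not yet calculated.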