diff --git a/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx b/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx
index b8bf0d7..54af79b 100644
--- a/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx
+++ b/selfie-ui/src/app/components/Playground/PlaygroundQuery.tsx
@@ -3,10 +3,25 @@
 import { apiBaseUrl } from "@/app/config";
 import useAsyncTask from "@/app/hooks/useAsyncTask";
 import TaskToast from "@/app/components/TaskToast";
 
-const fetchDocuments = async (topic: string, limit?: number, minScore?: number, includeSummary?: boolean) => {
-  const params = new URLSearchParams({ topic, ...(limit && { limit: limit.toString() }), ...(minScore && { min_score: minScore.toString() }), ...(includeSummary !== undefined && { include_summary: includeSummary.toString() }) });
-  const url = `${apiBaseUrl}/v1/index_documents/summary?${params.toString()}`;
+const fetchDocuments = async (
+  query: string,
+  limit?: number,
+  minScore?: number,
+  includeSummary?: boolean,
+  relevanceWeight?: number,
+  recencyWeight?: number,
+  importanceWeight?: number
+) => {
+  const params = new URLSearchParams({
+    query,
+    ...(limit && { limit: limit.toString() }),
+    ...(minScore && { min_score: minScore.toString() }),
+    ...(includeSummary !== undefined && { include_summary: includeSummary.toString() }),
+    ...(relevanceWeight && { relevance_weight: relevanceWeight.toString() }),
+    ...(recencyWeight && { recency_weight: recencyWeight.toString() }),
+    ...(importanceWeight && { importance_weight: importanceWeight.toString() }),
+  });
+  const url = `${apiBaseUrl}/v1/documents/search?${params.toString()}`;
 
   try {
     const response = await fetch(url);
@@ -23,10 +38,14 @@ const PlaygroundQuery = () => {
   const [documents, setDocuments] = useState([]);
   const [summary, setSummary] = useState("");
   const [isSummaryLoading, setSummaryLoading] = useState(false);
-  const [score, setScore] = useState(0);
+  const [averageScore, setAverageScore] = useState(0);
+  const [totalResults, setTotalResults] = useState(0);
   const [limit, setLimit] = useState<number | undefined>();
   const [minScore, setMinScore] = useState<number | undefined>();
   const [includeSummary, setIncludeSummary] = useState(true);
+  const [relevanceWeight, setRelevanceWeight] = useState<number | undefined>();
+  const [recencyWeight, setRecencyWeight] = useState<number | undefined>();
+  const [importanceWeight, setImportanceWeight] = useState<number | undefined>();
 
   const handleInputChange = (setter: React.Dispatch<React.SetStateAction<any>>) => (e: React.ChangeEvent<HTMLInputElement>) => {
     const value = e.target.type === "number" ? Number(e.target.value) || undefined : e.target.value;
@@ -41,16 +60,26 @@ const PlaygroundQuery = () => {
     e.preventDefault();
 
     executeTask(async () => {
-      setScore(0);
+      setAverageScore(0);
+      setTotalResults(0);
       setDocuments([]);
       setSummary("");
       setSummaryLoading(true);
-      const results = await fetchDocuments(query, limit, minScore, includeSummary);
-      setScore(results.score);
+      const results = await fetchDocuments(
+        query,
+        limit,
+        minScore,
+        includeSummary,
+        relevanceWeight,
+        recencyWeight,
+        importanceWeight
+      );
+      setAverageScore(results.average_score);
+      setTotalResults(results.total_results);
       setDocuments(results.documents);
       setSummary(results.summary);
       setSummaryLoading(false);
-      console.log("Searching with:", query, limit, minScore, includeSummary);
+      console.log("Searching with:", query, limit, minScore, includeSummary, relevanceWeight, recencyWeight, importanceWeight);
     }, {
       start: "Searching...",
       success: "Search complete",
@@ -62,12 +91,9 @@ const PlaygroundQuery = () => {
 
   return (
-      {/*<div>
-        Document {doc.id}
-      </div>*/}
-      <div>
-        Embedding document {i}
+      <div>
+        Embedding Document {doc.id}
         {doc.text}
-        {/*
-        <li>Score: {doc.score}</li>
-        */}
-        {/* only 2 decimal */}
         <li>Overall score: {doc.score.toFixed(2)}</li>
         <li>Relevance score: {doc.relevance.toFixed(2)}</li>
         <li>Recency score: {doc.recency.toFixed(2)}</li>
@@ -114,7 +140,39 @@ const PlaygroundQuery = () => {
           onChange={(e) => setMinScore(Number(e.target.value) || undefined)}
           min="0"
           max="1"
-          step="0.1"
+          step="0.01"
+        />
+      </div>
+
+      <div>
+        {/**/}
+        <input
+          type="number"
+          onChange={(e) => setRelevanceWeight(e.target.value ? Number(e.target.value) : undefined)}
+          min="0"
+          max="1"
+          step="0.01"
+        />
+      </div>
+
+      <div>
+        {/**/}
+        <input
+          type="number"
+          onChange={(e) => setRecencyWeight(e.target.value ? Number(e.target.value) : undefined)}
+          min="0"
+          max="1"
+          step="0.01"
+        />
@@ -145,34 +203,62 @@ const PlaygroundQuery = () => {
- {!!summary &&
- {/**/} -

{summary}

- {documents.length ?

Result Score: {score.toFixed(2)}

: null } -
} + {!!summary && ( +
+ {/**/} +

{summary}

+ {documents.length ? ( +
+

Total Results: {totalResults}

+

Average Score: {averageScore.toFixed(2)}

+
+ ) : null} +
+ )} - {!!score &&
- {documents.map(renderDocument)} -
} + {documents.length > 0 &&
{documents.map(renderDocument)}
} ); }; -const SearchIcon = () => - -; - -const LoadingIcon = () => - - - -; +const SearchIcon = () => ( + + + +); + +const LoadingIcon = () => ( + + + + + +); PlaygroundQuery.displayName = "PlaygroundQuery"; diff --git a/selfie-ui/src/app/page.tsx b/selfie-ui/src/app/page.tsx index 68d077b..64dae80 100644 --- a/selfie-ui/src/app/page.tsx +++ b/selfie-ui/src/app/page.tsx @@ -88,9 +88,12 @@ const App = () => { ))}
-        …
+          API Docs
+        …
+          API Sandbox
+        …
diff --git a/selfie/api/completions.py b/selfie/api/completions.py
--- a/selfie/api/completions.py
+++ b/selfie/api/completions.py
@@ -19,7 +22,10 @@ async def create_chat_completion(
 ) -> LlamaCppChatCompletionResponse | LitellmCompletionResponse:
 
 
 # TODO can StreamingResponse's schema be defined?
-@router.post("/completions")
+@router.post("/completions",
+             description="""
+             Creates a response for the given prompt in [the style of OpenAI](https://platform.openai.com/docs/api-reference/completions/create).
+             """)
 async def create_completion(
     request: CompletionRequest,
 ) -> LlamaCppCompletionResponse | LitellmCompletionResponse:
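For reference, a call to the endpoint documented above might look like this (a sketch using `requests`; the local base URL and port are assumptions, not taken from this patch):

```python
import requests

# Sketch: request a completion from a locally running Selfie server
# (the base URL is an assumption, not part of this diff).
response = requests.post(
    "http://localhost:8181/v1/completions",
    json={"prompt": "Once upon a time", "max_tokens": 50, "temperature": 0.8},
)
print(response.json())
```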
diff --git a/selfie/api/connectors.py b/selfie/api/connectors.py
index 5c57e1c..bdfa1b9 100644
--- a/selfie/api/connectors.py
+++ b/selfie/api/connectors.py
@@ -3,9 +3,38 @@
 from fastapi import APIRouter
 from pydantic import BaseModel
 
+from selfie.connectors.whatsapp.connector import WhatsAppConnector
 from selfie.connectors.factory import ConnectorFactory
 
-router = APIRouter()
+router = APIRouter(tags=["Configuration"])
+
+example = {
+    "id": WhatsAppConnector().id,
+    "name": WhatsAppConnector().name,
+    "documentation": "WhatsApp is a popular messaging app...",  # TODO: Read this from connectors/whatsapp, maybe truncate
+    "form_schema": {  # TODO: Read this from connectors/whatsapp
+        "title": "Upload WhatsApp Conversations",
+        "type": "object",
+        "properties": {
+            "files": {
+                "type": "array",
+                "title": "Files",
+                "description": "Upload .txt files exported from WhatsApp",
+                "items": {
+                    "type": "object"
+                }
+            }
+        }
+    },
+    "ui_schema": {  # TODO: Read this from connectors/whatsapp
+        "files": {
+            "ui:widget": "nativeFile",
+            "ui:options": {
+                "accept": ".txt"
+            }
+        }
+    }
+}
 
 
 class Connector(BaseModel):
@@ -15,18 +44,30 @@ class Connector(BaseModel):
     form_schema: Optional[dict] = None
     ui_schema: Optional[dict] = None
 
+    model_config = {
+        "json_schema_extra": {
+            "example": example
+        }
+    }
+
 
 class ConnectorsResponse(BaseModel):
     connectors: List[Connector] = []
 
 
-@router.get("/connectors")
+@router.get("/connectors",
+            description="""List all available connectors. This endpoint fetches and returns a comprehensive list of all connectors configured in the system, along with their respective details including ID, name, optional documentation, form schema, and UI schema if available.
+
+### Response Format
+Returns a `ConnectorsResponse` object containing a list of `Connector` objects, each detailing a connector available in the system.
+""")
 async def get_connectors() -> ConnectorsResponse:
     connectors = ConnectorFactory.get_all_connectors()
     return ConnectorsResponse(connectors=connectors)
 
 
-@router.get("/connectors/{connector_id}")
+@router.get("/connectors/{connector_id}",
+            description="Retrieve detailed information about a specific connector by its ID. This includes its name, documentation, and any schemas related to form and UI configurations.")
 async def get_connector(connector_id: str) -> Connector:
     connector_instance = ConnectorFactory.get_connector(connector_name=connector_id)
     return Connector(
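A minimal sketch of exercising the two connector endpoints above (the base URL is assumed, not part of this patch):

```python
import requests

base = "http://localhost:8181/v1"  # assumed base URL

# List every configured connector, then fetch one by its ID.
connectors = requests.get(f"{base}/connectors").json()["connectors"]
for connector in connectors:
    print(connector["id"], "-", connector["name"])

whatsapp = requests.get(f"{base}/connectors/whatsapp").json()
print(whatsapp["form_schema"])
```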
"configuration": { + "files": ["data:text/plain;name=example.txt;base64,SGVsbG8gV29ybGQ="] + } + } + } + } + + +class DocumentConnectionResponse(BaseModel): + message: str = Field(..., description="A message indicating the status of the document connection creation.") + + model_config = { + "json_schema_extra": { + "example": { + "message": "Document connection created successfully" + } + } + } async def file_to_data_uri(file: UploadFile): @@ -67,8 +90,74 @@ async def parse_create_document_connection_request(request: Request): return connector_id, configuration -@router.post("/document-connections") -async def create_document_connection(request: Request, parsed_data: tuple = Depends(parse_create_document_connection_request)): +@router.post("/document-connections", + description="""Create a new document connection using the specified connector and configuration. + +The request can be sent as multipart/form-data or application/json. Because browsers tend to limit the size of data URIs, it is recommended to use multipart/form-data for large files. + +### Using multipart/form-data + +For multipart/form-data, the connector_id and configuration should be provided as form fields. The configuration object should contain placeholders for file references, which will be replaced with data URIs from corresponding file fields in the request body. + +For example, the following form fields can be used to create a document connection: + + connector_id: whatsapp + configuration: {"files": ["file-0"]} + file-0: + +### Using application/json + +For application/json, the connector_id and configuration should be provided in the request body. The configuration object should contain files as data URIs. + +For example, the following JSON can be used to create a document connection: + + { + "connector_id": "whatsapp", + "configuration": { + "files": ["data:text/plain;name=example.txt;base64,SGVsbG8gV29ybGQ="] + } + } +""", + openapi_extra={ + "requestBody": { + "content": { + "application/json": { + "schema": DocumentConnectionRequest.schema() + }, + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "connector_id": {"type": "string"}, + "configuration": {"type": "string"} + }, + "required": ["connector_id", "configuration"], + "patternProperties": { + "^file-\\d+$": { + "type": "string", + "format": "binary" + } + } + }, + "examples": { + "example1": { + "summary": "Example 1", + "value": """connector_id: whatsapp +configuration: {"files": ["file-0"]} +file-0: +""" + } + } + } + } + } + }) +# request_body=DocumentConnectionRequest) +async def create_document_connection( + request: Request, + # document_connection_request: DocumentConnectionRequest = Body(..., description="The document connection request."), + parsed_data: tuple = Depends(parse_create_document_connection_request) +) -> DocumentConnectionResponse: connector_id, configuration = parsed_data connector_instance = ConnectorFactory.get_connector(connector_name=connector_id) connector_instance.validate_configuration(configuration=configuration) @@ -96,4 +185,4 @@ async def create_document_connection(request: Request, parsed_data: tuple = Depe # Save embedding_documents to Vector DB await DataIndex("n/a").index(embedding_documents, extract_importance=False) - return {"message": "Document connection created successfully"} + return DocumentConnectionResponse(message="Document connection created successfully") diff --git a/selfie/api/documents.py b/selfie/api/documents.py index 8ebaf08..68bb382 100644 --- a/selfie/api/documents.py 
diff --git a/selfie/api/documents.py b/selfie/api/documents.py
index 8ebaf08..68bb382 100644
--- a/selfie/api/documents.py
+++ b/selfie/api/documents.py
@@ -1,11 +1,11 @@
-from typing import List
+from typing import List, Optional
 
-from fastapi import APIRouter
-from pydantic import BaseModel
+from fastapi import APIRouter, Query
+from pydantic import BaseModel, Field
 
+from selfie.embeddings import ScoredEmbeddingDocumentModel
 from selfie.database import DataManager
 from selfie.embeddings import DataIndex
-from selfie.parsers.chat import ChatFileParser
 
 router = APIRouter()
 
@@ -23,57 +23,112 @@ class DeleteDocumentsRequest(BaseModel):
     document_ids: List[int] = []
 
 
-@router.get("/documents")
-async def get_documents():
+class FetchedDocument(BaseModel):
+    id: int = Field(..., description="The unique identifier of the document")
+    name: str = Field(..., description="The name of the document")
+    size: int = Field(..., description="The size of the document")
+    created_at: str = Field(..., description="The timestamp of the document creation")
+    updated_at: str = Field(..., description="The timestamp of the document update")
+    content_type: str = Field(..., description="The content type of the document")
+    connector_name: str = Field(..., description="The name of the connector")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "id": 1,
+                "name": "example.txt",
+                "size": 1024,
+                "created_at": "2024-03-11T18:33:04.733583",
+                "updated_at": "2024-03-11T18:33:04.733590",
+                "content_type": "text/plain",
+                "connector_name": "whatsapp",
+            }
+        }
+    }
+
+
+@router.get("/documents",
+            tags=["Data Management"])
+async def get_documents() -> List[FetchedDocument]:
     return DataManager().get_documents()
 
 
-@router.delete("/documents")
-async def index_documents(request: DeleteDocumentsRequest):
+@router.delete("/documents",
+               tags=["Data Management"],
+               description="Remove multiple documents by their IDs.",
+               status_code=204)
+async def delete_documents(request: DeleteDocumentsRequest):
     await DataManager().remove_documents([int(document_id) for document_id in request.document_ids])
-    return {"message": "Documents removed successfully"}
 
 
-@router.delete("/documents/{document_id}")
-async def delete_data_source(document_id: int, delete_indexed_data: bool = True):
+@router.delete("/documents/{document_id}",
+               tags=["Data Management"],
+               description="Remove a document by its ID.",
+               status_code=204)
+async def delete_document(document_id: int, delete_indexed_data: Optional[bool] = True):
     await DataManager().remove_document(document_id, delete_indexed_data)
-    return {"message": "Document removed successfully"}
-
-
-@router.post("/documents/unindex")
-async def unindex_documents(request: UnindexDocumentsRequest):
-    await DataIndex("n/a").delete_documents_with_source_documents(request.document_ids)
-    return {"message": "Document unindexed successfully"}
-
-
-@router.post("/documents/index")
-async def index_documents(request: IndexDocumentsRequest):
-    is_chat = request.is_chat
-    document_ids = request.document_ids
-
-    manager = DataManager()
-    parser = ChatFileParser()
-
-    # TODO: figure out what to do about this
-    speaker_aliases = {}
-
-    return [
-        await manager.index_document(manager.get_document(document_id),
-                                     lambda document: DataIndex.map_share_gpt_data(
-                                         parser.parse_document(
-                                             document.content,
-                                             None,
-                                             speaker_aliases,
-                                             False,
-                                             document.id
-                                         ).conversations,
-                                         # source=document.source.name,
-                                         source_document_id=document.id
-                                     ) if is_chat else None)
-        for document_id in document_ids
-    ]
-
-# @app.delete("/documents/{document-id}")
-# async def delete_data_source(document_id: int):
-#     DataSourceManager().remove_document(document_id)
-#     return {"message": "Document removed successfully"}
+
+
+class SearchDocumentsResponse(BaseModel):
+    query: str = Field(..., description="The search query")
+    total_results: int = Field(..., description="The total number of documents found")
+    average_score: float = Field(..., description="The mean score of the returned documents")
+    documents: List[ScoredEmbeddingDocumentModel] = Field(..., description="The documents found")
+    summary: Optional[str] = Field(None, description="A summary of the search results")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "query": "What is the meaning of life?",
+                "total_results": 1,
+                "average_score": 0.4206031938249788,
+                "documents": [
+                    {
+                        "id": 1,
+                        "text": "The meaning of life is 42.",
+                        "source": "whatsapp",
+                        "timestamp": "2023-03-12T11:35:00Z",
+                        "created_timestamp": "2024-03-11T18:33:04.733583",
+                        "updated_timestamp": "2024-03-11T18:33:04.733590",
+                        "source_document_id": 3,
+                        "score": 0.4206031938249788,
+                        "relevance": 0.08712080866098404,
+                        "recency": 0.4720855789889736,
+                        "importance": None,
+                    },
+                ],
+                "summary": "The meaning of life is 42."
+            }
+        }
+    }
+
+
+@router.get("/documents/search",
+            tags=["Search"],
+            description="Search for embedding documents that most closely match a query.")
+async def search_documents(
+        query: str,
+        limit: Optional[int] = Query(3, ge=1, le=100, description="Maximum number of documents to fetch"),
+        min_score: Optional[float] = Query(0.4, ge=0.0, le=1.0, description="Minimum score for embedding documents"),
+        include_summary: Optional[bool] = Query(False, description="Include a summary of the search results"),
+        relevance_weight: Optional[float] = Query(1.0, le=1.0, ge=0.0, description="Weight for relevance in the scoring algorithm"),
+        recency_weight: Optional[float] = Query(1.0, le=1.0, ge=0.0, description="Weight for recency in the scoring algorithm"),
+        importance_weight: Optional[float] = Query(0, le=1.0, ge=0.0, description="**Importance scores are currently not calculated, so this weight has no effect!** Weight for document importance in the scoring algorithm.")
+) -> SearchDocumentsResponse:
+    result = await DataIndex("n/a").recall(
+        topic=query,
+        limit=limit,
+        min_score=min_score,
+        include_summary=include_summary,
+        relevance_weight=relevance_weight,
+        recency_weight=recency_weight,
+        importance_weight=importance_weight,
+    )
+
+    return SearchDocumentsResponse(
+        query=query,
+        total_results=len(result["documents"]),
+        average_score=result["mean_score"],
+        documents=result["documents"],
+        summary=result["summary"] if include_summary else None
+    )
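The new search endpoint and its scoring weights can be exercised like so (a sketch; base URL assumed):

```python
import requests

params = {
    "query": "What is the meaning of life?",
    "limit": 3,
    "min_score": 0.4,
    "include_summary": True,
    "relevance_weight": 1.0,
    "recency_weight": 0.5,  # favor relevance over recency
}
result = requests.get("http://localhost:8181/v1/documents/search", params=params).json()

print(result["total_results"], result["average_score"])
for doc in result["documents"]:
    # Each document carries the combined score plus its relevance/recency components.
    print(round(doc["score"], 2), doc["text"])
```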
diff --git a/selfie/api/models.py b/selfie/api/models.py
index bb98713..3b2cabe 100644
--- a/selfie/api/models.py
+++ b/selfie/api/models.py
@@ -1,11 +1,43 @@
+from datetime import datetime
 from fastapi import APIRouter
 from huggingface_hub import scan_cache_dir
+from typing import List
+from pydantic import BaseModel, Field
 
-router = APIRouter()
+router = APIRouter(tags=["Configuration"])
 
 
-@router.get("/models")
-async def get_models():
+class Model(BaseModel):
+    id: str = Field(..., description="The unique identifier of the model, formatted as 'repo_id/filename'")
+    object: str = Field(..., description="The type of the object, typically 'model'")
+    created: datetime = Field(..., description="The timestamp when the model was last modified in the cache")
+    owned_by: str = Field(..., description="Indicates the ownership of the model, typically 'user' for user-uploaded models")
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(..., description="The type of the response, always 'list'")
+    data: List[Model] = Field(..., description="A list of model objects detailing available models in the cache")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "object": "list",
+                "data": [
+                    {
+                        "id": "user/repo/modelname.gguf",
+                        "object": "model",
+                        "created": "2021-01-01T00:00:00Z",
+                        "owned_by": "user"
+                    }
+                ]
+            }
+        }
+    }
+
+
+@router.get("/models",
+            description="Retrieve a list of **already-downloaded llama.cpp models** (in the Hugging Face Hub cache). This endpoint scans the cache directory for model files (specifically looking for files with a '.gguf' extension within each repository revision) and returns a list of models including their ID, object type, creation timestamp, and ownership information.")
+async def get_models() -> ModelsResponse:
     hf_cache_info = scan_cache_dir()
 
     models = []
@@ -14,21 +46,18 @@ async def get_models():
         gguf_files = [file for file in revision.files if file.file_name.endswith('.gguf')]
         if gguf_files:
             for gguf_file in gguf_files:
-                models.append({
-                    "id": f"{repo.repo_id}/{gguf_file.file_name}",
-                    "object": "model",
-                    "created": gguf_file.blob_last_modified,
-                    "owned_by": "user"
-                })
+                models.append(Model(
+                    id=f"{repo.repo_id}/{gguf_file.file_name}",
+                    object="model",
+                    created=gguf_file.blob_last_modified,
+                    owned_by="user"
+                ))
         else:
-            models.append({
-                "id": repo.repo_id,
-                "object": "model",
-                "created": min(file.last_modified for file in repo.revisions),
-                "owned_by": "user"
-            })
-
-    return {
-        "object": "list",
-        "data": models
-    }
+            models.append(Model(
+                id=repo.repo_id,
+                object="model",
+                created=min(file.last_modified for file in repo.revisions),
+                owned_by="user"
+            ))
+
+    return ModelsResponse(object="list", data=models)
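A quick sketch of listing the cached models via the endpoint above (base URL assumed):

```python
import requests

models = requests.get("http://localhost:8181/v1/models").json()["data"]
for model in models:
    # IDs are formatted as 'repo_id/filename' for .gguf files in the HF cache.
    print(model["id"], model["created"])
```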
diff --git a/selfie/connectors/whatsapp/connector.py b/selfie/connectors/whatsapp/connector.py
index 9946214..d80a387 100644
--- a/selfie/connectors/whatsapp/connector.py
+++ b/selfie/connectors/whatsapp/connector.py
@@ -16,7 +16,7 @@ class WhatsAppConfiguration(BaseModel):
 class WhatsAppConnector(BaseConnector, ABC):
     def __init__(self):
         super().__init__()
-        self.id = "whatsapp"
+        self.id = "whatsapp"  # TODO: this should be static
         self.name = "WhatsApp"
 
     def load_document(self, configuration: dict[str, Any]) -> List[DocumentDTO]:
diff --git a/selfie/embeddings/__init__.py b/selfie/embeddings/__init__.py
index 493a28a..7c6630b 100644
--- a/selfie/embeddings/__init__.py
+++ b/selfie/embeddings/__init__.py
@@ -358,6 +358,7 @@ async def recall(
         documents_list: List[ScoredEmbeddingDocumentModel] = []
         for result in results:
             document = EmbeddingDocumentModel(
+                id=result["id"],
                 text=result["text"],
                 timestamp=result["timestamp"],
                 importance=result["importance"],
diff --git a/selfie/embeddings/document_types.py b/selfie/embeddings/document_types.py
index 0092d5f..1dd0d30 100644
--- a/selfie/embeddings/document_types.py
+++ b/selfie/embeddings/document_types.py
@@ -23,6 +23,18 @@ def to_dict(self, *args, **kwargs):
 
     class Config:
         validate_assignment = True
+        json_schema_extra = {
+            "example": {
+                "id": 42,
+                "text": "What is the meaning of life?",
+                "source": "whatsapp",
+                "importance": None,
+                "timestamp": "2022-01-01T00:00:00Z",
+                "created_timestamp": "2022-01-01T00:00:00Z",
+                "updated_timestamp": "2022-01-01T00:00:00Z",
+                "source_document_id": 42
+            }
+        }
 
     @model_validator(mode='before')
     def autofill_timestamps(cls, values):
@@ -37,3 +49,15 @@ class ScoredEmbeddingDocumentModel(EmbeddingDocumentModel):
     importance: Optional[float] = Field(..., description="Importance score of the document, [0, 1]")
     relevance: float = Field(..., description="Relevance score of the document, for a query, [0, 1]")
     recency: float = Field(..., description="Recency score of the document, for a query (time), [0, 1]")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                **EmbeddingDocumentModel.model_config['json_schema_extra']['example'],
+                "score": 0.42,
+                "relevance": 0.42,
+                "recency": 0.42,
+                "importance": None,  # For now
+            }
+        }
+    }
diff --git a/selfie/types/completion_requests.py b/selfie/types/completion_requests.py
index f9107a6..4e1c9ad 100644
--- a/selfie/types/completion_requests.py
+++ b/selfie/types/completion_requests.py
@@ -1,4 +1,6 @@
 from typing import Optional, Literal, List, ClassVar, Union, Dict, Any
+from pydantic import BaseModel, TypeAdapter, Extra, Field
+
 from litellm import ModelResponse as LitellmCompletionResponse
 from llama_cpp import (
     CreateCompletionResponse as LlamaCppCompletionResponse,
@@ -7,7 +9,6 @@
 from openai.types.chat import ChatCompletionMessage
 from openai.types.completion_create_params import CompletionCreateParams, CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming
 from openai.types.chat import CompletionCreateParams as ChatCompletionCreateParams
-from pydantic import BaseModel, TypeAdapter, Extra
 from sse_starlette import EventSourceResponse
 
 # from typing_extensions import TypedDict
@@ -16,48 +17,64 @@
 
 
 class Message(BaseModel):
-    role: str
-    content: str
+    role: str = Field(..., description="The role of the message sender, e.g., 'system', 'user', or 'assistant'")
+    content: str = Field(..., description="The content of the message")
 
 
 class FunctionCall(BaseModel):
-    name: str
-    parameters: Optional[Dict[str, Any]] = None
+    name: str = Field(..., description="The name of the function to call")
+    parameters: Optional[Dict[str, Any]] = Field(None, description="The parameters to pass to the function")
 
 
 class Tool(BaseModel):
-    type: str
-    function: Optional[FunctionCall] = None
+    type: str = Field(..., description="The type of the tool")
+    function: Optional[FunctionCall] = Field(None, description="The function to call when using the tool")
 
 
 class BaseCompletionRequest(BaseModel):
     # OpenAI parameters
-    model: Optional[str] = None
-    frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[int, float]] = None
+    model: Optional[str] = Field(None, description="ID of the model to use for completion")
+    frequency_penalty: Optional[float] = Field(0.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.")
+    logit_bias: Optional[Dict[int, float]] = Field(None, description="Modify the likelihood of specified tokens appearing in the completion.")
     # logprobs: Optional[bool] = False
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate in the completion.")
     # n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0.0
-    response_format: Optional[Dict[str, str]] = None
-    seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
-    stream: Optional[bool] = False
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    tools: Optional[List[Tool]] = None
-    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    presence_penalty: Optional[float] = Field(0.0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.")
+    response_format: Optional[Dict[str, str]] = Field(None, description="An object specifying the format that the model must output.")
+    seed: Optional[int] = Field(None, description="If specified, the system will attempt to sample deterministically, so that repeated requests with the same seed return the same result.")
+    stop: Optional[Union[str, List[str]]] = Field(None, description="Up to 4 sequences where the API will stop generating further tokens.")
+    stream: Optional[bool] = Field(False, description="If set, partial completion results will be sent as they become available.")
+    temperature: Optional[float] = Field(1.0, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
+    top_p: Optional[float] = Field(1.0, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.")
+    tools: Optional[List[Tool]] = Field(None, description="A list of tools the model may call.")
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Controls which (if any) function is called by the model. Options are 'none', 'auto', or a specific function.")
     # user: Optional[str] = None
 
     # Selfie parameters
-    method: Optional[Literal["litellm", "llama.cpp", "transformers"]] = None
-    api_base: Optional[str] = None
-    api_key: Optional[str] = None
-    disable_augmentation: Optional[bool] = False
+    method: Optional[Literal["litellm", "llama.cpp", "transformers"]] = Field(None, description="The method to use for completion, e.g., 'litellm', 'llama.cpp', or 'transformers'.")
+    api_base: Optional[str] = Field(None, description="The base URL for the API")
+    api_key: Optional[str] = Field(None, description="The API key to use for authentication")
+    disable_augmentation: Optional[bool] = Field(False, description="Whether to disable data augmentation during completion")
 
-    # Custom parameters, e.g. for a custom API
     class Config:
+        # Allow custom parameters, e.g. for a custom API
         extra = Extra.allow
 
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "method": "litellm",
+                "model": "gpt-3.5-turbo",
+                "api_key": "your-api-key",
+                "prompt": "Hello, how are you?",
+                "max_tokens": 50,
+                "temperature": 0.8,
+            }
+        }
+    }
+
+
+class ChatCompletionRequest(BaseCompletionRequest):
+    messages: List[Message] = Field(..., description="A list of messages comprising the conversation so far.")
 
     custom_params: ClassVar[List[str]] = ["method", "api_base", "api_key", "disable_augmentation"]
 
@@ -78,27 +95,53 @@ def extra_params(self):
         """
         return {k: v for k, v in self.model_dump().items() if k not in self.model_fields.keys()}
 
-
-class ChatCompletionRequest(BaseCompletionRequest):
-    messages: List[Message]
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "method": "litellm",
+                "model": "gpt-3.5-turbo",
+                "api_key": "your-api-key",
+                "messages": [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "What is the capital of France?"}
+                ],
+                "max_tokens": 50,
+                "temperature": 0.8,
+            }
+        }
+    }
 
 
 class CompletionRequest(BaseCompletionRequest):
-    prompt: Union[str, List[str]]
+    prompt: Union[str, List[str]] = Field(..., description="The prompt(s) to generate completions for. Can be a string or a list of strings.")
     # best_of: Optional[int] = None
-    echo: Optional[bool] = None
-    logprobs: Optional[int] = None
+    echo: Optional[bool] = Field(None, description="Whether to echo the prompt in the response.")
+    logprobs: Optional[int] = Field(None, description="Include the log probabilities on the logprobs most likely tokens, as well as the chosen tokens. For example, if logprobs is 10, the API will return a list of the 10 most likely tokens. If logprobs is supplied, the API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.")
     # n: Optional[int] = None
-    suffix: Optional[str] = None
+    suffix: Optional[str] = Field(None, description="The suffix that comes after a completion of inserted text.")
+
+    model_config = {
+        "json_schema_extra": {
+            "example": {
+                "method": "litellm",
+                "model": "gpt-3.5-turbo",
+                "api_key": "your-api-key",
+                "prompt": "Once upon a time",
+                "max_tokens": 50,
+                "temperature": 0.8,
+            }
+        }
+    }
 
 
-class ChatCompletionResponse(BaseModel):
-    id: str
-    object: str = "chat.completion"
-    created: int
-    model: Optional[str]
-    choices: List[Dict[str, Any]]
-    usage: Dict[str, int]
+# class ChatCompletionResponse(BaseModel):
+#     id: str = Field(..., description="The ID of the completion")
+#     object: str = Field('chat.completion', description="The object type, e.g., 'chat.completion'")
+#     created: int = Field(..., description="The timestamp of the completion creation")
+#     model: Optional[str] = Field(None, description="The model used for the completion")
+#     choices: List[Dict[str, Any]] = Field(..., description="The choices of completions")
+#     usage: Dict[str, int] = Field(..., description="The usage of the completion")
+
 # ##########################
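To illustrate the request models documented above, a chat completion request mirroring the embedded example schema might be sent as follows (a sketch; base URL assumed, and any OpenAI-compatible client would work equally well):

```python
import requests

payload = {
    "method": "litellm",            # Selfie-specific parameter
    "model": "gpt-3.5-turbo",
    "api_key": "your-api-key",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_tokens": 50,
    "temperature": 0.8,
}
response = requests.post("http://localhost:8181/v1/chat/completions", json=payload)
print(response.json())
```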