[API] init
jonathanung committed Oct 6, 2024 · 0 parents · commit 43b8085
Showing 5 changed files with 99 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
.env
/LLMenv
__pycache__
venv
LLMenv
9 changes: 9 additions & 0 deletions app/controllers/prompt.py
@@ -0,0 +1,9 @@
from fastapi import HTTPException
from app.utils.model import text_generation_pipeline

async def process_prompt(prompt: str, return_full_text: bool = False, max_new_tokens: int = 200, temperature: float = 0.7, repetition_penalty: float = 1.1, top_k: int = 50, top_p: float = 0.95, **kwargs):
    try:
        result = await text_generation_pipeline(prompt, return_full_text, max_new_tokens, temperature, repetition_penalty, top_k, top_p, **kwargs)
        return [{"generated_text": result[0]["generated_text"]}]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
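
Because process_prompt is a plain coroutine, it can be exercised outside FastAPI. A minimal sketch, run from the project root and assuming the .env variables and the model loaded by app/utils/model.py are available:

import asyncio
from app.controllers.prompt import process_prompt

# importing the controller loads the configured model, so this needs a valid .env
result = asyncio.run(process_prompt("Write a haiku about APIs.", max_new_tokens=60))
print(result[0]["generated_text"])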
36 changes: 36 additions & 0 deletions app/main.py
@@ -0,0 +1,36 @@
from fastapi import FastAPI, Depends, UploadFile, File, Body
from fastapi.middleware.cors import CORSMiddleware as CORS
from pydantic import BaseModel
import os
from dotenv import load_dotenv

from app.controllers import prompt as prompt_controller

load_dotenv()

app = FastAPI()

# CORS setup (unchanged)
# ...

# Define a Pydantic model for the request body
class PromptRequest(BaseModel):
    prompt: str
    return_full_text: bool = False
    max_new_tokens: int = 200
    temperature: float = 0.7
    repetition_penalty: float = 1.1
    top_k: int = 50
    top_p: float = 0.95

@app.post("/prompt")
async def prompt_endpoint(request: PromptRequest = Body(...)):
return await prompt_controller.process_prompt(
prompt=request.prompt,
return_full_text=request.return_full_text,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
top_k=request.top_k,
top_p=request.top_p
)
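
With the server running (see run.py below), the endpoint can be called from any HTTP client. A minimal sketch using the requests library, assuming the app listens on localhost:8000:

import requests

# POST a prompt with a couple of optional generation parameters
resp = requests.post(
    "http://localhost:8000/prompt",
    json={"prompt": "Explain FastAPI in one sentence.", "max_new_tokens": 100},
)
resp.raise_for_status()
print(resp.json()[0]["generated_text"])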
33 changes: 33 additions & 0 deletions app/utils/model.py
@@ -0,0 +1,33 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import os
import torch

load_dotenv()

device = os.getenv("DEVICE")
text_generation_model_name = os.getenv("TEXT_GENERATION_MODEL_NAME")

tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device)

async def text_generation_pipeline(prompt, return_full_text, max_new_tokens, temperature, repetition_penalty, top_k, top_p, **kwargs):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # temperature/top_k/top_p only take effect when sampling is enabled
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            **kwargs
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if not return_full_text:
        # strip the echoed prompt so only the continuation is returned
        generated_text = generated_text[len(prompt):]

    return [{"generated_text": generated_text}]
16 changes: 16 additions & 0 deletions run.py
@@ -0,0 +1,16 @@
import uvicorn
import sys
from pathlib import Path
from dotenv import load_dotenv
import os

load_dotenv()

current_dir = Path(__file__).resolve().parent
sys.path.append(str(current_dir))

host = os.getenv("HOST")
port = int(os.getenv("PORT"))

if __name__ == "__main__":
    uvicorn.run("app.main:app", host=host, port=port, reload=True)
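
Across the five files, the app reads four environment variables: HOST and PORT in run.py, DEVICE and TEXT_GENERATION_MODEL_NAME in app/utils/model.py. A sample .env for local use — every value below is an assumption for illustration, not part of the commit:

# illustrative values only; adjust to your hardware and model of choice
HOST=0.0.0.0
PORT=8000
DEVICE=cpu
TEXT_GENERATION_MODEL_NAME=distilgpt2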
