Commit 43b8085 (0 parents)
Showing 5 changed files with 99 additions and 0 deletions.
@@ -0,0 +1,5 @@
.env
/LLMenv
__pycache__
venv
LLMenv
@@ -0,0 +1,9 @@
from fastapi import HTTPException
from app.utils.model import text_generation_pipeline


async def process_prompt(prompt: str, return_full_text: bool = False, max_new_tokens: int = 200, temperature: float = 0.7, repetition_penalty: float = 1.1, top_k: int = 50, top_p: float = 0.95, **kwargs):
    try:
        result = await text_generation_pipeline(prompt, return_full_text, max_new_tokens, temperature, repetition_penalty, top_k, top_p, **kwargs)
        return [{"generated_text": result[0]["generated_text"]}]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
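Judging by the import in the application module further down, this controller lives at app/controllers/prompt.py. A minimal unit-test sketch for it follows, assuming pytest and the pytest-asyncio plugin are installed; importing the controller still loads the model in app.utils.model, so a small model would need to be configured in .env. The test path, test name, and stubbed output are hypothetical, not part of the commit.

# tests/test_prompt.py (hypothetical path)
import pytest

from app.controllers import prompt as prompt_controller


@pytest.mark.asyncio
async def test_process_prompt_wraps_pipeline_output(monkeypatch):
    # Stub the generation pipeline so the test never touches the real model.
    async def fake_pipeline(*args, **kwargs):
        return [{"generated_text": "stubbed output"}]

    monkeypatch.setattr(prompt_controller, "text_generation_pipeline", fake_pipeline)

    result = await prompt_controller.process_prompt("hi")
    assert result == [{"generated_text": "stubbed output"}]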
@@ -0,0 +1,36 @@
from fastapi import FastAPI, Depends, UploadFile, File, Body
from fastapi.middleware.cors import CORSMiddleware as CORS
from pydantic import BaseModel
import os
from dotenv import load_dotenv

from app.controllers import prompt as prompt_controller

load_dotenv()

app = FastAPI()

# CORS setup (unchanged)
# ...

# Define a Pydantic model for the request body
class PromptRequest(BaseModel):
    prompt: str
    return_full_text: bool = False
    max_new_tokens: int = 200
    temperature: float = 0.7
    repetition_penalty: float = 1.1
    top_k: int = 50
    top_p: float = 0.95


@app.post("/prompt")
async def prompt_endpoint(request: PromptRequest = Body(...)):
    return await prompt_controller.process_prompt(
        prompt=request.prompt,
        return_full_text=request.return_full_text,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        top_k=request.top_k,
        top_p=request.top_p
    )
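This module is app/main.py, going by the uvicorn target in the run script at the end of the commit, and it exposes a single POST /prompt endpoint whose body matches PromptRequest. A hedged client sketch, assuming the server is reachable at http://localhost:8000 (the real host and port come from the HOST and PORT environment variables) and that the requests package is available; the prompt text and parameter values are placeholders.

import requests

# Only "prompt" is required; omitted fields fall back to the defaults
# declared on PromptRequest.
payload = {
    "prompt": "Write one sentence about FastAPI.",
    "max_new_tokens": 60,
    "temperature": 0.8,
}

response = requests.post("http://localhost:8000/prompt", json=payload)
response.raise_for_status()
print(response.json())  # e.g. [{"generated_text": "..."}]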
@@ -0,0 +1,33 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import os
import torch

load_dotenv()

device = os.getenv("DEVICE")
text_generation_model_name = os.getenv("TEXT_GENERATION_MODEL_NAME")

tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device)


async def text_generation_pipeline(prompt, return_full_text, max_new_tokens, temperature, repetition_penalty, top_k, top_p, **kwargs):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            **kwargs
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if not return_full_text:
        generated_text = generated_text[len(prompt):]

    return [{"generated_text": generated_text}]
@@ -0,0 +1,16 @@
import uvicorn
import sys
from pathlib import Path
from dotenv import load_dotenv
import os

load_dotenv()

current_dir = Path(__file__).resolve().parent
sys.path.append(str(current_dir))

host = os.getenv("HOST")
port = int(os.getenv("PORT"))

if __name__ == "__main__":
    uvicorn.run("app.main:app", host=host, port=port, reload=True)
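Between them, this run script and app.utils.model read four environment variables, and the .gitignore at the top of the commit keeps the file that holds them out of version control. The script provides no defaults, so HOST and PORT must be set or int(os.getenv("PORT")) will fail. A hypothetical .env illustrating the expected keys; every value shown is an assumption, not something taken from the commit.

# .env (placeholder values)
HOST=0.0.0.0
PORT=8000
DEVICE=cpu
TEXT_GENERATION_MODEL_NAME=gpt2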