llm_pytorch.py
import os
import argparse
import torch # Import torch to manage device allocation
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer
import uvicorn
from dotenv import load_dotenv
import logging
from typing import Union, Dict
load_dotenv()
OUTPUT_TOKENS = int(os.getenv("OUTPUT_TOKENS", 1000))
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.5))
LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 6001))
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
LOG_QUERIES = os.getenv("LOG_QUERIES", "false")
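# The settings above come from the environment (optionally via a .env file loaded by load_dotenv()).
# Illustrative .env sketch -- the values below are examples matching the defaults, not requirements:
#   OUTPUT_TOKENS=1000
#   TEMPERATURE=0.5
#   LLM_SERVER_PORT=6001
#   LOG_LEVEL=info
#   LOG_QUERIES=false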
# Configure logging
logging.basicConfig(filename='app.log', level=LOG_LEVEL.upper(),
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Define your models configuration
models = [
    {"friendly_name": "orca", "short_name": "Orca-2-7b", "vendor_name": "microsoft"},
    {"friendly_name": "tinyllama", "short_name": "TinyLlama-1.1B-Chat-v1.0", "vendor_name": "TinyLlama"},
    {"friendly_name": "phi2", "short_name": "phi-2", "vendor_name": "microsoft"},
    {"friendly_name": "mistral", "short_name": "Mistral-7B-Instruct-v0.2", "vendor_name": "mistralai"},
]
# Setup FastAPI app
app = FastAPI()
# Determine the best available device: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")
# Placeholder for the pipeline object
pipe = None
class PredictRequest(BaseModel):
    text: str

def ensure_path_exists(path: str) -> None:
    if not os.path.exists(path):
        os.makedirs(path)

def model_already_downloaded(model_directory: str) -> bool:
    """Check if the model and tokenizer have already been downloaded.

    Note: currently unused; download_model() only checks that the directory exists.
    """
    required_files = ["config.json", "pytorch_model.bin", "tokenizer_config.json", "vocab.txt"]
    return all(os.path.exists(os.path.join(model_directory, file)) for file in required_files)

def download_model(model_config: dict[str, str]) -> None:
    model_identifier = f"{model_config['vendor_name']}/{model_config['short_name']}"
    model_directory = f"./models/{model_config['short_name']}"
    # Only check for the existence of the directory, not its contents.
    if not os.path.exists(model_directory):
        print(f"Downloading model '{model_identifier}' to '{model_directory}'...")
        ensure_path_exists(model_directory)
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, cache_dir=model_directory)
        tokenizer.padding_side = 'left'  # Adjust padding side to left
        model = AutoModelForCausalLM.from_pretrained(model_identifier, cache_dir=model_directory)
        tokenizer.save_pretrained(model_directory)
        model.save_pretrained(model_directory)
        print("Model downloaded and saved successfully.")
    else:
        print(f"Model '{model_identifier}' is already downloaded.")

def load_model_from_disk(model_config: dict[str, str]) -> None:
    global pipe
    model_directory = f"./models/{model_config['short_name']}"
    print(f"Attempting to load model from {model_directory}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_directory)
        tokenizer.padding_side = 'left'  # Ensure tokenizer uses left padding
        config = AutoConfig.from_pretrained(model_directory)
        model = AutoModelForCausalLM.from_pretrained(model_directory, config=config).to(device)
        # Pass the torch.device directly; comparing it against strings would silently fall back to CPU.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
        print(f"Model '{model_config['short_name']}' loaded successfully.")
    except Exception as e:
        print(f"Failed to load model '{model_config['short_name']}' from disk. Error: {e}")
        print("Please delete the model directory and try downloading again.")

async def startup_event() -> None:
    parser = argparse.ArgumentParser(description="FastAPI model serving application.")
    parser.add_argument("-m", "--model", type=str, help="Model to use by friendly name", default="tinyllama")
    parser.add_argument("-d", "--download", action="store_true", help="Download the model from Hugging Face")
    args, unknown = parser.parse_known_args()
    selected_model_config = next((m for m in models if m["friendly_name"] == args.model), None)
    if not selected_model_config:
        available_models = ", ".join([m["friendly_name"] for m in models])
        print(f"Model '{args.model}' not found. Available models: {available_models}")
        print("Please select one of the available models by using the -m option.")
        exit(1)  # Exit the application with a non-zero status to indicate an error.
    # If the model is found, proceed with the download or load from disk.
    if args.download:
        download_model(selected_model_config)
    load_model_from_disk(selected_model_config)

async def shutdown_event() -> None:
    global pipe  # Without this, the assignment would only create a local variable.
    pipe = None

app.add_event_handler("startup", startup_event)
app.add_event_handler("shutdown", shutdown_event)
@app.post("/predict")
async def predict(request: PredictRequest) -> Union[Dict[str, str], str]:
if pipe is None:
raise HTTPException(status_code=503, detail="Model not loaded correctly")
if (LOG_QUERIES == "true"):
logging.info(f"LLM Prompt: {request.text}")
# Optionally format the input text based on the model requirements
#formatted_text = f"<s>[INST] {request.text} [/INST]</s>"
formatted_text = f"{request.text}"
result = pipe(formatted_text, max_length=OUTPUT_TOKENS, temperature=TEMPERATURE, truncation=True, do_sample=True, return_full_text=False)
response = result[0]["generated_text"]
#return {"result": response}
return response
def run_server() -> None:
    config = uvicorn.Config("llm_pytorch:app", host="0.0.0.0", port=LLM_SERVER_PORT, log_level=LOG_LEVEL)
    server = uvicorn.Server(config)
    try:
        server.run()
    except KeyboardInterrupt:
        print("Shutting down...")
    finally:
        print("Server stopped.")

if __name__ == "__main__":
    run_server()
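
# Example usage (illustrative sketch; assumes the server runs locally on the default port 6001):
#
#   python llm_pytorch.py -m tinyllama -d      # download the model on first run, then serve
#   curl -X POST http://localhost:6001/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Explain what a tokenizer does in one sentence."}'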