Skip to content

Commit

Permalink
Add CORS for prod & checking if CUDA is available before loading model
Browse files Browse the repository at this point in the history
  • Loading branch information
xKhronoz committed Jan 26, 2024
1 parent 8500091 commit afc8144
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
7 changes: 5 additions & 2 deletions backend/backend/app/utils/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
current_directory / "data"
) # directory containing the documents to index


# set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
model_kwargs = {"n_gpu_layers": 100} if DEVICE_TYPE == "cuda" else {}

llm = LlamaCPP(
model_url="https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf",
temperature=0.1,
Expand All @@ -43,8 +47,7 @@
# kwargs to pass to __call__()
# generate_kwargs={},
# kwargs to pass to __init__()
# set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
model_kwargs={"n_gpu_layers": 100},
model_kwargs=model_kwargs,
# transform inputs into Llama2 format
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
Expand Down
24 changes: 23 additions & 1 deletion backend/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from torch.cuda import is_available as is_cuda_available

load_dotenv()

app = FastAPI()

environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set

# TODO: Add reading allowed origins from environment variables

if environment == "dev":
logger = logging.getLogger("uvicorn")
Expand All @@ -28,10 +30,30 @@
allow_headers=["*"],
)

if environment == "prod":
# In production, specify the allowed origins
allowed_origins = [
"https://your-production-domain.com",
"https://another-production-domain.com",
# Add more allowed origins as needed
]

logger = logging.getLogger("uvicorn")
logger.info(f"Running in production mode - allowing CORS for {allowed_origins}")
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)

logger.info(f"CUDA available: {is_cuda_available()}")

app.include_router(chat_router, prefix="/api/chat")
app.include_router(query_router, prefix="/api/query")
app.include_router(search_router, prefix="/api/search")
app.include_router(healthcheck_router, prefix="/api/healthcheck")

# try to create the index first on startup
# Try to create the index first on startup
create_index()

0 comments on commit afc8144

Please sign in to comment.