diff --git a/backend/backend/app/utils/index.py b/backend/backend/app/utils/index.py
index 9839db1..45350ee 100644
--- a/backend/backend/app/utils/index.py
+++ b/backend/backend/app/utils/index.py
@@ -34,6 +34,10 @@
     current_directory / "data"
 )  # directory containing the documents to index
+
+# set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
+model_kwargs = {"n_gpu_layers": 100} if DEVICE_TYPE == "cuda" else {}
+
 llm = LlamaCPP(
     model_url="https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf",
     temperature=0.1,
@@ -43,8 +47,7 @@
     # kwargs to pass to __call__()
     # generate_kwargs={},
     # kwargs to pass to __init__()
-    # set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
-    model_kwargs={"n_gpu_layers": 100},
+    model_kwargs=model_kwargs,
     # transform inputs into Llama2 format
     messages_to_prompt=messages_to_prompt,
     completion_to_prompt=completion_to_prompt,
diff --git a/backend/backend/main.py b/backend/backend/main.py
index 7e4d920..65cb640 100644
--- a/backend/backend/main.py
+++ b/backend/backend/main.py
@@ -9,6 +9,7 @@
 from dotenv import load_dotenv
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from torch.cuda import is_available as is_cuda_available
 
 load_dotenv()
 
@@ -16,6 +17,7 @@
 
 environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'development' if not set
 
+# TODO: Add reading allowed origins from environment variables
 if environment == "dev":
     logger = logging.getLogger("uvicorn")
 
@@ -28,10 +30,30 @@
         allow_headers=["*"],
     )
 
+if environment == "prod":
+    # In production, specify the allowed origins
+    allowed_origins = [
+        "https://your-production-domain.com",
+        "https://another-production-domain.com",
+        # Add more allowed origins as needed
+    ]
+
+    logger = logging.getLogger("uvicorn")
+    logger.info(f"Running in production mode - allowing CORS for {allowed_origins}")
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=allowed_origins,
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE"],
+        allow_headers=["*"],
+    )
+
+logger.info(f"CUDA available: {is_cuda_available()}")
+
 app.include_router(chat_router, prefix="/api/chat")
 app.include_router(query_router, prefix="/api/query")
 app.include_router(search_router, prefix="/api/search")
 app.include_router(healthcheck_router, prefix="/api/healthcheck")
 
-# try to create the index first on startup
+# Try to create the index first on startup
 create_index()