From 30b9325d6f6d4351706f437fe2ec5d5acb631288 Mon Sep 17 00:00:00 2001 From: Yichao 'Peak' Ji Date: Wed, 19 Apr 2023 00:16:15 +0800 Subject: [PATCH] feat(server): support cross-origin resource sharing (#148) --- Dockerfile | 1 + basaran/__init__.py | 1 + basaran/__main__.py | 5 +++++ requirements.txt | 1 + setup.py | 9 ++++++++- 5 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 734f6b40..bb8efa84 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,6 +35,7 @@ ENV SERVER_CONNECTION_LIMIT="1024" ENV SERVER_CHANNEL_TIMEOUT="300" ENV SERVER_MODEL_NAME="" ENV SERVER_NO_PLAYGROUND="false" +ENV SERVER_CORS_ORIGINS="*" ENV COMPLETION_MAX_PROMPT="4096" ENV COMPLETION_MAX_TOKENS="4096" ENV COMPLETION_MAX_N="5" diff --git a/basaran/__init__.py b/basaran/__init__.py index dd453ab8..067c2e26 100644 --- a/basaran/__init__.py +++ b/basaran/__init__.py @@ -33,6 +33,7 @@ def is_true(value): SERVER_CHANNEL_TIMEOUT = int(os.getenv("SERVER_CHANNEL_TIMEOUT", "300")) SERVER_MODEL_NAME = os.getenv("SERVER_MODEL_NAME", "") or MODEL SERVER_NO_PLAYGROUND = is_true(os.getenv("SERVER_NO_PLAYGROUND", "")) +SERVER_CORS_ORIGINS = os.getenv("SERVER_CORS_ORIGINS", "*") # Completion-related arguments: COMPLETION_MAX_PROMPT = int(os.getenv("COMPLETION_MAX_PROMPT", "4096")) diff --git a/basaran/__main__.py b/basaran/__main__.py index 3df87649..75bad8f5 100644 --- a/basaran/__main__.py +++ b/basaran/__main__.py @@ -7,6 +7,7 @@ import waitress from flask import Flask, Response, abort, jsonify, render_template, request +from flask_cors import CORS from . import is_true from .choice import reduce_choice @@ -28,6 +29,7 @@ from . import SERVER_CHANNEL_TIMEOUT from . import SERVER_MODEL_NAME from . import SERVER_NO_PLAYGROUND +from . import SERVER_CORS_ORIGINS from . import COMPLETION_MAX_PROMPT from . import COMPLETION_MAX_TOKENS from . import COMPLETION_MAX_N @@ -52,6 +54,9 @@ app.json.compact = True app.url_map.strict_slashes = False +# Configure cross-origin resource sharing (CORS). +CORS(app, origins=SERVER_CORS_ORIGINS.split(",")) + def parse_options(schema): """Parse options specified in query parameters and request body.""" diff --git a/requirements.txt b/requirements.txt index a0b6c2b2..fd5c9252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bitsandbytes~=0.38.1 build>=0.10.0 coverage>=6.4.2 flake8>=3.7.9 +flask-cors~=3.0.10 flask>=2.2.1 huggingface-hub~=0.13.4 jinja2>=3.1.2 diff --git a/setup.py b/setup.py index 739cc564..e7e31a63 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,14 @@ packages=find_packages(), include_package_data=True, python_requires=">=3.8.0", - install_requires=["flask", "jinja2", "torch", "transformers", "waitress"], + install_requires=[ + "flask-cors", + "flask", + "jinja2", + "torch", + "transformers", + "waitress", + ], keywords=["api", "huggingface", "nlp", "openai", "transformer"], classifiers=[ "Development Status :: 4 - Beta",