diff --git a/gptme/server/api.py b/gptme/server/api.py index 72bb9eef..424b6099 100644 --- a/gptme/server/api.py +++ b/gptme/server/api.py @@ -14,6 +14,7 @@ import flask from flask import current_app, request +from flask_cors import CORS from ..commands import execute_cmd from ..dirs import get_logs_dir @@ -165,8 +166,16 @@ def favicon(): return flask.send_from_directory(media_path, "logo.png") -def create_app() -> flask.Flask: - """Create the Flask app.""" +def create_app(cors_origin: str | None = None) -> flask.Flask: + """Create the Flask app. + + Args: + cors_origin: CORS origin to allow. Use '*' to allow all origins. + """ app = flask.Flask(__name__, static_folder=static_path) app.register_blueprint(api) + + if cors_origin: + CORS(app, resources={r"/api/*": {"origins": cors_origin}}) + return app diff --git a/gptme/server/cli.py b/gptme/server/cli.py index f0decfdf..7ea19592 100644 --- a/gptme/server/cli.py +++ b/gptme/server/cli.py @@ -27,6 +27,11 @@ help="Port to run the server on.", ) @click.option("--tools", default=None, help="Tools to enable, comma separated.") +@click.option( + "--cors-origin", + default=None, + help="CORS origin to allow. Use '*' to allow all origins.", +) def main( debug: bool, verbose: bool, @@ -34,6 +39,7 @@ def main( host: str, port: str, tools: str | None, + cors_origin: str | None, ): # pragma: no cover """ Starts a server and web UI for gptme. @@ -58,5 +64,5 @@ def main( exit(1) click.echo("Initialization complete, starting server") - app = create_app() + app = create_app(cors_origin=cors_origin) app.run(debug=debug, host=host, port=int(port)) diff --git a/poetry.lock b/poetry.lock index 5bcbc787..1dc94d29 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -634,6 +634,20 @@ Werkzeug = ">=3.0.0" async = ["asgiref (>=3.2)"] dotenv = ["python-dotenv"] +[[package]] +name = "flask-cors" +version = "4.0.2" +description = "A Flask extension adding a decorator for CORS support" +optional = true +python-versions = "*" +files = [ + {file = "Flask_Cors-4.0.2-py2.py3-none-any.whl", hash = "sha256:38364faf1a7a5d0a55bd1d2e2f83ee9e359039182f5e6a029557e1f56d92c09a"}, + {file = "flask_cors-4.0.2.tar.gz", hash = "sha256:493b98e2d1e2f1a4720a7af25693ef2fe32fbafec09a2f72c59f3e475eda61d2"}, +] + +[package.dependencies] +Flask = ">=0.9" + [[package]] name = "fonttools" version = "4.54.1" @@ -3425,13 +3439,13 @@ files = [ requests = "*" [extras] -all = ["flask", "matplotlib", "numpy", "pandas", "pillow", "playwright", "python-xlib"] +all = ["flask", "flask-cors", "matplotlib", "numpy", "pandas", "pillow", "playwright", "python-xlib"] browser = ["playwright"] computer = ["pillow", "python-xlib"] datascience = ["matplotlib", "numpy", "pandas", "pillow"] -server = ["flask"] +server = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1a818642af5d1853c1f8bbc628218757ec088e756a1098e5caca4b4403268258" +content-hash = "3ee6ba8b5b968ee3f907a96335ca60cad440e661dc9b73bb22c835f6b5c8b2ac" diff --git a/pyproject.toml b/pyproject.toml index 9dbebe56..722e8728 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ pillow = {version = "*", optional=true} # server flask = {version = "^3.0", optional=true} +flask-cors = {version = "^4.0", optional=true} [tool.poetry.group.dev.dependencies] # lint @@ -83,13 +84,13 @@ types-tabulate = "*" types-lxml = "*" [tool.poetry.extras] -server = ["flask"] +server = ["flask", "flask-cors"] browser = ["playwright"] datascience = ["matplotlib", "pandas", "numpy", "pillow"] computer = ["python-xlib", "pillow"] # pillow already in datascience but listed for clarity all = [ # server - "flask", + "flask", "flask-cors", # browser "playwright", # datascience