Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤖 Refactor Celery Initialization and Database Setup Checks #1591

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/celery_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,27 @@
logger = logging.getLogger(__name__)
celery_app = Celery("seer")

def _is_celery_worker() -> bool:
"""Check if we're running in a Celery worker process"""
return (
os.environ.get("CELERY_APP") is not None
or "celery" in sys.argv[0].lower()
)

# This abstract helps tests that want to validate the entry point process.
def setup_celery_entrypoint(app: Celery):
app.on_configure.connect(init_celery_app)

# Only initialize when the worker actually starts
app.on_worker_init.connect(init_celery_app)

@inject
def init_celery_app(*args: Any, sender: Celery, config: CeleryConfig = injected, **kwargs: Any):
for k, v in config.items():
setattr(sender.conf, k, v)
bootup(start_model_loading=False, integrations=[CeleryIntegration(propagate_traces=True)])
from celery_app.tasks import setup_periodic_tasks

# Only run bootup if we're in a Celery worker process
if _is_celery_worker():
bootup(start_model_loading=False, integrations=[CeleryIntegration(propagate_traces=True)])
from celery_app.tasks import setup_periodic_tasks

sender.on_after_finalize.connect(setup_periodic_tasks)

Expand Down
18 changes: 17 additions & 1 deletion src/seer/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from enum import StrEnum
from typing import Any, List, Optional


import sqlalchemy
from flask import Flask
from flask import Flask, current_app
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pgvector.sqlalchemy import Vector # type: ignore
Expand Down Expand Up @@ -35,21 +36,36 @@
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected

def is_db_initialized() -> bool:
"""Check if database is already initialized in current Flask app"""
return bool(current_app and hasattr(current_app, "_database_initialized"))

def mark_db_initialized() -> None:
"""Mark database as initialized in current Flask app"""
if current_app:
setattr(current_app, "_database_initialized", True)
else:
logger.warning("No Flask application context found when marking database as initialized")

@inject
def initialize_database(
config: AppConfig = injected,
app: Flask = injected,
):
if is_db_initialized():
return

app.config["SQLALCHEMY_DATABASE_URI"] = config.DATABASE_URL
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"connect_args": {"prepare_threshold": None}}


db.init_app(app)
migrate.init_app(app, db)

with app.app_context():
Session.configure(bind=db.engine)

mark_db_initialized()

class Base(DeclarativeBase):
pass
Expand Down
Loading