From 0d028ca43ab029ab3107aef4e54a0ab36aaaab29 Mon Sep 17 00:00:00 2001 From: Roy Wiggins Date: Tue, 22 Oct 2024 21:40:59 +0000 Subject: [PATCH] fix tests: accidentally was reading from /opt/mercure, fixing required delaying resolving database connection information until runtime add basic booting test for bookkeeper --- alembic/env.py | 24 ++- bookkeeper.py | 103 ++++++----- bookkeeping/config.py | 33 ++-- bookkeeping/database.py | 339 ++++++++++++++++++----------------- bookkeeping/query.py | 62 +++---- common/config.py | 92 +++++----- dev-requirements.in | 2 + dev-requirements.txt | 27 +++ test.py | 2 +- tests/data/test_config.json | 1 + tests/dispatch/test_retry.py | 6 - tests/test_bookkeeper.py | 23 ++- tests/test_query.py | 20 ++- tests/test_router.py | 22 +-- tests/test_studies.py | 7 +- tests/testing_common.py | 26 ++- webgui.py | 69 +++---- 17 files changed, 483 insertions(+), 375 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index c9c01945..f4a14f51 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -10,18 +10,16 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config -config.set_main_option('sqlalchemy.url', os.getenv("DATABASE_URL", "not set")) +_os_env_database_url = os.getenv("DATABASE_URL") +if _os_env_database_url is not None: + config.set_main_option('sqlalchemy.url', _os_env_database_url) # Interpret the config file for Python logging. # This line sets up loggers basically. -fileConfig(config.config_file_name) # type: ignore +# fileConfig(config.config_file_name) # type: ignore # add your model's MetaData object here # for 'autogenerate' support -import bookkeeper -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = bookkeeper.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -41,7 +39,15 @@ def run_migrations_offline() -> None: script output. """ - url = config.get_main_option("sqlalchemy.url") + import bookkeeping.database as db + # url = config.get_main_option("sqlalchemy.url") + url = os.environ.get('DATABASE_URL') + schema = os.environ.get('DATABASE_SCHEMA') + if schema == 'None': + schema = None + db.init_database(url, schema) + target_metadata = db.metadata + context.configure( url=url, target_metadata=target_metadata, @@ -60,6 +66,10 @@ def run_migrations_online() -> None: and associate a connection with the context. """ + import bookkeeping.database as db + db.init_database(os.environ.get('DATABASE_URL'), os.environ.get('DATABASE_SCHEMA')) + target_metadata = db.metadata + connectable = engine_from_config( config.get_section(config.config_ini_section), prefix="sqlalchemy.", diff --git a/bookkeeper.py b/bookkeeper.py index 5a8c1f4d..d1877318 100755 --- a/bookkeeper.py +++ b/bookkeeper.py @@ -2,7 +2,7 @@ bookkeeper.py ============= The bookkeeper service of mercure, which receives notifications from all mercure services -and stores the information in a Postgres database. +and stores the information in a database. 
""" # Standard python includes @@ -32,7 +32,7 @@ from common import config import common.monitor as monitor from common.constants import mercure_defs -from bookkeeping.database import * +import bookkeeping.database as db import bookkeeping.query as query import bookkeeping.config as bk_config from decoRouter import Router as decoRouter @@ -63,18 +63,14 @@ async def verify(self, token: str): ################################################################################### -def create_database() -> None: - """Creates all tables in the database if they do not exist.""" - subprocess.run( - ["alembic", "upgrade", "head"], - check=True, - env={ - **os.environ, - "PATH": "/opt/mercure/env/bin:" + os.environ["PATH"], - "DATABASE_URL": bk_config.DATABASE_URL, - }, - ) +from alembic.config import Config +from alembic import command +def create_database() -> None: + alembic_cfg = Config() + alembic_cfg.set_main_option('script_location', os.path.dirname(os.path.realpath(__file__))+'/alembic') + alembic_cfg.set_main_option('sqlalchemy.url', bk_config.DATABASE_URL) + command.upgrade(alembic_cfg, 'head') ################################################################################### @@ -83,7 +79,7 @@ def create_database() -> None: # async def execute_db_operation(operation) -> None: # global connection -# """Executes a previously prepared database operation.""" +# """Executes a previously prepared db.database operation.""" # try: # connection.execute(operation) # except: @@ -106,10 +102,10 @@ async def post_mercure_event(request) -> JSONResponse: severity = int(payload.get("severity", monitor.severity.INFO)) description = payload.get("description", "") - query = mercure_events.insert().values( + query = db.mercure_events.insert().values( sender=sender, event=event, severity=severity, description=description, time=datetime.datetime.now() ) - result = await database.execute(query) + result = await db.database.execute(query) logger.debug(result) return JSONResponse({"ok": ""}) @@ -138,16 +134,16 @@ async def processor_logs(request) -> JSONResponse: if (logs_folder_str := config.mercure.processing_logs.logs_file_store) and ( logs_path := Path(logs_folder_str) ).exists(): - query = processor_logs_table.insert().values(task_id=task_id, module_name=module_name, time=time, logs=None) - result = await database.execute(query) + query = db.processor_logs_table.insert().values(task_id=task_id, module_name=module_name, time=time, logs=None) + result = await db.database.execute(query) logs_path = logs_path / task_id logs_path.mkdir(exist_ok=True) logs_file = logs_path / f"{module_name}.{str(result)}.txt" logs_file.write_text(logs, encoding="utf-8") else: - query = processor_logs_table.insert().values(task_id=task_id, module_name=module_name, time=time, logs=logs) - result = await database.execute(query) + query = db.processor_logs_table.insert().values(task_id=task_id, module_name=module_name, time=time, logs=logs) + result = await db.database.execute(query) logger.debug(result) return JSONResponse({"ok": ""}) @@ -163,10 +159,10 @@ async def post_webgui_event(request) -> JSONResponse: user = payload.get("user", "UNKNOWN") description = payload.get("description", "") - query = webgui_events.insert().values( + query = db.webgui_events.insert().values( sender=sender, event=event, user=user, description=description, time=datetime.datetime.now() ) - await database.execute(query) + await db.database.execute(query) # tasks = BackgroundTasks() # tasks.add_task(execute_db_operation, operation=query) return 
JSONResponse({"ok": ""}) @@ -181,10 +177,10 @@ async def register_dicom(request) -> JSONResponse: file_uid = payload.get("file_uid", "") series_uid = payload.get("series_uid", "") - query = dicom_files.insert().values( + query = db.dicom_files.insert().values( filename=filename, file_uid=file_uid, series_uid=series_uid, time=datetime.datetime.now() ) - result = await database.execute(query) + result = await db.database.execute(query) logger.debug(f"Result: {result}") # tasks = BackgroundTasks() @@ -194,7 +190,7 @@ async def register_dicom(request) -> JSONResponse: async def parse_and_submit_tags(payload) -> None: """Helper function that reads series information from the request body.""" - query = dicom_series.insert().values( + query = db.dicom_series.insert().values( time=datetime.datetime.now(), series_uid=payload.get("SeriesInstanceUID", ""), study_uid=payload.get("StudyInstanceUID", ""), @@ -227,7 +223,7 @@ async def parse_and_submit_tags(payload) -> None: tag_softwareversions=payload.get("SoftwareVersions", ""), tag_stationname=payload.get("StationName", ""), ) - await database.execute(query) + await db.database.execute(query) @router.post("/register-series") @@ -251,7 +247,7 @@ async def register_task(request) -> JSONResponse: # Registering the task ordinarily happens first, but if "update-task" # came in first, we need to update the task instead. So we do an upsert. query = ( - insert(tasks_table) + insert(db.tasks_table) .values( id=payload["id"], series_uid=payload["series_uid"], @@ -267,7 +263,7 @@ async def register_task(request) -> JSONResponse: ) ) - await database.execute(query) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -295,14 +291,14 @@ async def update_task(request) -> JSONResponse: # Ordinarily, update-task is called on an existing task. But if the task is # not yet registered, we need to create it. So we use an upsert here. 
query = ( - insert(tasks_table) + insert(db.tasks_table) .values(**update_values) .on_conflict_do_update( # update if exists index_elements=["id"], set_=update_values, ) ) - await database.execute(query) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -314,7 +310,7 @@ async def test_begin(request) -> JSONResponse: type = payload.get("type", "route") rule_type = payload.get("rule_type", "series") task_id = payload.get("task_id", None) - query_a = insert(tests_table).values( + query_a = insert(db.tests_table).values( id=id, time_begin=datetime.datetime.now(), type=type, status="begin", task_id=task_id, rule_type=rule_type ) @@ -324,7 +320,7 @@ async def test_begin(request) -> JSONResponse: "task_id": task_id or query_a.excluded.task_id, }, ) - await database.execute(query) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -335,8 +331,8 @@ async def test_end(request) -> JSONResponse: id = payload["id"] status = payload.get("status", "") - query = tests_table.update(tests_table.c.id == id).values(time_end=datetime.datetime.now(), status=status) - await database.execute(query) + query = db.tests_table.update(db.tests_table.c.id == id).values(time_end=datetime.datetime.now(), status=status) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -373,7 +369,7 @@ async def post_task_event(request) -> JSONResponse: info = payload.get("info", "") task_id = payload.get("task_id") - query = task_events.insert().values( + query = db.task_events.insert().values( sender=sender, event=event, task_id=task_id, @@ -384,7 +380,7 @@ async def post_task_event(request) -> JSONResponse: time=event_time, client_timestamp=client_timestamp, ) - await database.execute(query) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -393,8 +389,8 @@ async def post_task_event(request) -> JSONResponse: async def store_processor_output(request) -> JSONResponse: payload = dict(await request.json()) values_dict = {k:payload[k] for k in ("task_id", "task_acc", "task_mrn", "module", "index", "settings", "output")} - query = processor_outputs_table.insert().values(**values_dict) - await database.execute(query) + query = db.processor_outputs_table.insert().values(**values_dict) + await db.database.execute(query) return JSONResponse({"ok": ""}) @@ -405,11 +401,12 @@ async def store_processor_output(request) -> JSONResponse: @contextlib.asynccontextmanager async def lifespan(app): - await database.connect() + await db.database.connect() + assert db.metadata create_database() bk_config.set_api_key() yield - await database.disconnect() + await db.database.disconnect() async def server_error(request, exc) -> Response: @@ -422,14 +419,19 @@ async def server_error(request, exc) -> Response: 500: server_error } +app = None -app = Starlette(debug=bk_config.DEBUG_MODE, routes=router, lifespan=lifespan, exception_handlers=exception_handlers) -app.add_middleware( - AuthenticationMiddleware, - backend=TokenAuth(), - on_error=lambda _, exc: PlainTextResponse(str(exc), status_code=401), -) -app.mount("/query", query.query_app) +def create_app() -> Starlette: + global app + bk_config.read_bookkeeper_config() + app = Starlette(debug=bk_config.DEBUG_MODE, routes=router, lifespan=lifespan, exception_handlers=exception_handlers) + app.add_middleware( + AuthenticationMiddleware, + backend=TokenAuth(), + on_error=lambda _, exc: PlainTextResponse(str(exc), status_code=401), + ) + app.mount("/query", query.query_app) + return app def main(args=sys.argv[1:]) -> None: if "--reload" in 
args or os.getenv("MERCURE_ENV", "PROD").lower() == "dev": @@ -445,15 +447,18 @@ def main(args=sys.argv[1:]) -> None: logger.info("") try: + bk_config.read_bookkeeper_config() + db.init_database() config.read_config() query.set_timezone_conversion() + app = create_app() + uvicorn.run(app, host=bk_config.BOOKKEEPER_HOST, port=bk_config.BOOKKEEPER_PORT) + except Exception as e: logger.error(f"Could not read configuration file: {e}") logger.info("Going down.") sys.exit(1) - uvicorn.run(app, host=bk_config.BOOKKEEPER_HOST, port=bk_config.BOOKKEEPER_PORT) - if __name__ == "__main__": main() diff --git a/bookkeeping/config.py b/bookkeeping/config.py index 145207bb..0c949418 100644 --- a/bookkeeping/config.py +++ b/bookkeeping/config.py @@ -6,7 +6,7 @@ # Standard python includes import os -from typing import Any +from typing import Any, Optional import daiquiri # Starlette-related includes @@ -15,17 +15,26 @@ # Create local logger instance logger = daiquiri.getLogger("config") - - -bookkeeper_config = Config((os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/bookkeeper.env") - -BOOKKEEPER_PORT = bookkeeper_config("PORT", cast=int, default=8080) -BOOKKEEPER_HOST = bookkeeper_config("HOST", default="0.0.0.0") -DATABASE_URL = bookkeeper_config("DATABASE_URL", default="postgresql://mercure@localhost") -DATABASE_SCHEMA: Any = bookkeeper_config("DATABASE_SCHEMA", default=None) -DEBUG_MODE = bookkeeper_config("DEBUG", cast=bool, default=False) -API_KEY = None - +bookkeeper_config: Config +config_filename:str = (os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/bookkeeper.env" +DATABASE_URL: str +BOOKKEEPER_PORT: int +BOOKKEEPER_HOST: str +DATABASE_SCHEMA: Optional[str] +API_KEY: Optional[str] +DEBUG_MODE: bool + +def read_bookkeeper_config() -> Config: + global bookkeeper_config, BOOKKEEPER_PORT, BOOKKEEPER_HOST, DATABASE_URL, DATABASE_SCHEMA, DEBUG_MODE, API_KEY + bookkeeper_config = Config(config_filename) + + BOOKKEEPER_PORT = bookkeeper_config("PORT", cast=int, default=8080) + BOOKKEEPER_HOST = bookkeeper_config("HOST", default="0.0.0.0") + DATABASE_URL = bookkeeper_config("DATABASE_URL", default="postgresql://mercure@localhost") + DATABASE_SCHEMA = bookkeeper_config("DATABASE_SCHEMA", default=None) + DEBUG_MODE = bookkeeper_config("DEBUG", cast=bool, default=False) + API_KEY = None + return bookkeeper_config def set_api_key() -> None: global API_KEY diff --git a/bookkeeping/database.py b/bookkeeping/database.py index 427ccddf..8f3ac88e 100644 --- a/bookkeeping/database.py +++ b/bookkeeping/database.py @@ -19,165 +19,180 @@ ################################################################################### ## Definition of database tables ################################################################################### - - -database = databases.Database(bk_config.DATABASE_URL) -metadata = sqlalchemy.MetaData(schema=bk_config.DATABASE_SCHEMA) - -# SQLite does not support JSONB natively, so we use TEXT instead -JSONB = sqlalchemy.types.Text() if 'sqlite://' in bk_config.DATABASE_URL else sqlalchemy.dialects.postgresql.JSONB -# -mercure_events = sqlalchemy.Table( - "mercure_events", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), - sqlalchemy.Column("event", sqlalchemy.String, default=monitor.m_events.UNKNOWN), - sqlalchemy.Column("severity", sqlalchemy.Integer, default=monitor.severity.INFO), - 
sqlalchemy.Column("description", sqlalchemy.String, default=""), -) - -webgui_events = sqlalchemy.Table( - "webgui_events", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), - sqlalchemy.Column("event", sqlalchemy.String, default=monitor.w_events.UNKNOWN), - sqlalchemy.Column("user", sqlalchemy.String, default=""), - sqlalchemy.Column("description", sqlalchemy.String, default=""), -) - -dicom_files = sqlalchemy.Table( - "dicom_files", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("filename", sqlalchemy.String), - sqlalchemy.Column("file_uid", sqlalchemy.String), - sqlalchemy.Column("series_uid", sqlalchemy.String), -) - -dicom_series = sqlalchemy.Table( - "dicom_series", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("series_uid", sqlalchemy.String, unique=True), - sqlalchemy.Column("study_uid", sqlalchemy.String), - sqlalchemy.Column("tag_patientname", sqlalchemy.String), - sqlalchemy.Column("tag_patientid", sqlalchemy.String), - sqlalchemy.Column("tag_accessionnumber", sqlalchemy.String), - sqlalchemy.Column("tag_seriesnumber", sqlalchemy.String), - sqlalchemy.Column("tag_studyid", sqlalchemy.String), - sqlalchemy.Column("tag_patientbirthdate", sqlalchemy.String), - sqlalchemy.Column("tag_patientsex", sqlalchemy.String), - sqlalchemy.Column("tag_acquisitiondate", sqlalchemy.String), - sqlalchemy.Column("tag_acquisitiontime", sqlalchemy.String), - sqlalchemy.Column("tag_modality", sqlalchemy.String), - sqlalchemy.Column("tag_bodypartexamined", sqlalchemy.String), - sqlalchemy.Column("tag_studydescription", sqlalchemy.String), - sqlalchemy.Column("tag_seriesdescription", sqlalchemy.String), - sqlalchemy.Column("tag_protocolname", sqlalchemy.String), - sqlalchemy.Column("tag_codevalue", sqlalchemy.String), - sqlalchemy.Column("tag_codemeaning", sqlalchemy.String), - sqlalchemy.Column("tag_sequencename", sqlalchemy.String), - sqlalchemy.Column("tag_scanningsequence", sqlalchemy.String), - sqlalchemy.Column("tag_sequencevariant", sqlalchemy.String), - sqlalchemy.Column("tag_slicethickness", sqlalchemy.String), - sqlalchemy.Column("tag_contrastbolusagent", sqlalchemy.String), - sqlalchemy.Column("tag_referringphysicianname", sqlalchemy.String), - sqlalchemy.Column("tag_manufacturer", sqlalchemy.String), - sqlalchemy.Column("tag_manufacturermodelname", sqlalchemy.String), - sqlalchemy.Column("tag_magneticfieldstrength", sqlalchemy.String), - sqlalchemy.Column("tag_deviceserialnumber", sqlalchemy.String), - sqlalchemy.Column("tag_softwareversions", sqlalchemy.String), - sqlalchemy.Column("tag_stationname", sqlalchemy.String), -) - -task_events = sqlalchemy.Table( - "task_events", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("task_id", sqlalchemy.String, sqlalchemy.ForeignKey("tasks.id"), nullable=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), - sqlalchemy.Column("event", sqlalchemy.String), - # sqlalchemy.Column("series_uid", sqlalchemy.String), - sqlalchemy.Column("file_count", sqlalchemy.Integer), - sqlalchemy.Column("target", sqlalchemy.String), 
- sqlalchemy.Column("info", sqlalchemy.String), - sqlalchemy.Column("client_timestamp", sqlalchemy.Integer), -) - -file_events = sqlalchemy.Table( - "file_events", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("dicom_file", sqlalchemy.Integer), - sqlalchemy.Column("event", sqlalchemy.Integer), -) - -dicom_series_map = sqlalchemy.Table( - "dicom_series_map", - metadata, - sqlalchemy.Column("id_file", sqlalchemy.Integer, primary_key=True), - sqlalchemy.Column("id_series", sqlalchemy.Integer), -) - -series_sequence_data = sqlalchemy.Table( - "series_sequence_data", - metadata, - sqlalchemy.Column("uid", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("data", sqlalchemy.JSON), -) - -tasks_table = sqlalchemy.Table( - "tasks", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("parent_id", sqlalchemy.String, nullable=True), - sqlalchemy.Column("time", sqlalchemy.DateTime), - sqlalchemy.Column("series_uid", sqlalchemy.String, nullable=True), - sqlalchemy.Column("study_uid", sqlalchemy.String, nullable=True), - sqlalchemy.Column("data", JSONB), -) - -tests_table = sqlalchemy.Table( - "tests", - metadata, - sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), - sqlalchemy.Column("type", sqlalchemy.String, nullable=True), - sqlalchemy.Column("rule_type", sqlalchemy.String, nullable=True), - sqlalchemy.Column("time_begin", sqlalchemy.DateTime, nullable=True), - sqlalchemy.Column("time_end", sqlalchemy.DateTime, nullable=True), - sqlalchemy.Column("status", sqlalchemy.String, nullable=True), - sqlalchemy.Column("task_id", sqlalchemy.String, nullable=True), - sqlalchemy.Column("data", JSONB, nullable=True), -) - -processor_logs_table = sqlalchemy.Table( - "processor_logs", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("task_id", sqlalchemy.String, sqlalchemy.ForeignKey("tasks.id"), nullable=True), - sqlalchemy.Column("module_name", sqlalchemy.String, nullable=True), - sqlalchemy.Column("logs", sqlalchemy.String, nullable=True), - sqlalchemy.Column("time", sqlalchemy.DateTime, nullable=True), -) - -processor_outputs_table = sqlalchemy.Table( - "processor_outputs", - metadata, - sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), - sqlalchemy.Column("time", sqlalchemy.DateTime(timezone=True), server_default=func.now()), - sqlalchemy.Column("task_id", sqlalchemy.String, sqlalchemy.ForeignKey("tasks.id"),nullable=True), - sqlalchemy.Column("task_acc", sqlalchemy.String), - sqlalchemy.Column("task_mrn", sqlalchemy.String), - sqlalchemy.Column("module", sqlalchemy.String), - sqlalchemy.Column("index", sqlalchemy.Integer), - sqlalchemy.Column("settings", JSONB), - sqlalchemy.Column("output", JSONB), -) +database: databases.Database +metadata: sqlalchemy.MetaData +mercure_events: sqlalchemy.Table +webgui_events: sqlalchemy.Table +dicom_files: sqlalchemy.Table +dicom_series: sqlalchemy.Table +task_events: sqlalchemy.Table +file_events: sqlalchemy.Table +dicom_series_map: sqlalchemy.Table +series_sequence_data: sqlalchemy.Table +tasks_table: sqlalchemy.Table +tests_table: sqlalchemy.Table +processor_logs_table: sqlalchemy.Table +processor_outputs_table: sqlalchemy.Table + +def init_database(url=None, schema=None) -> databases.Database: + global database, metadata, mercure_events, webgui_events, dicom_files, dicom_series, 
task_events, file_events, dicom_series_map, series_sequence_data, tasks_table, tests_table, processor_logs_table, processor_outputs_table + database = databases.Database(url or bk_config.DATABASE_URL) + metadata = sqlalchemy.MetaData(schema=(schema or bk_config.DATABASE_SCHEMA)) + + # SQLite does not support JSONB natively, so we use TEXT instead + JSONB = sqlalchemy.types.Text() if 'sqlite://' in (url or bk_config.DATABASE_URL) else sqlalchemy.dialects.postgresql.JSONB + # + mercure_events = sqlalchemy.Table( + "mercure_events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), + sqlalchemy.Column("event", sqlalchemy.String, default=monitor.m_events.UNKNOWN), + sqlalchemy.Column("severity", sqlalchemy.Integer, default=monitor.severity.INFO), + sqlalchemy.Column("description", sqlalchemy.String, default=""), + ) + + webgui_events = sqlalchemy.Table( + "webgui_events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), + sqlalchemy.Column("event", sqlalchemy.String, default=monitor.w_events.UNKNOWN), + sqlalchemy.Column("user", sqlalchemy.String, default=""), + sqlalchemy.Column("description", sqlalchemy.String, default=""), + ) + + dicom_files = sqlalchemy.Table( + "dicom_files", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("filename", sqlalchemy.String), + sqlalchemy.Column("file_uid", sqlalchemy.String), + sqlalchemy.Column("series_uid", sqlalchemy.String), + ) + + dicom_series = sqlalchemy.Table( + "dicom_series", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("series_uid", sqlalchemy.String, unique=True), + sqlalchemy.Column("study_uid", sqlalchemy.String), + sqlalchemy.Column("tag_patientname", sqlalchemy.String), + sqlalchemy.Column("tag_patientid", sqlalchemy.String), + sqlalchemy.Column("tag_accessionnumber", sqlalchemy.String), + sqlalchemy.Column("tag_seriesnumber", sqlalchemy.String), + sqlalchemy.Column("tag_studyid", sqlalchemy.String), + sqlalchemy.Column("tag_patientbirthdate", sqlalchemy.String), + sqlalchemy.Column("tag_patientsex", sqlalchemy.String), + sqlalchemy.Column("tag_acquisitiondate", sqlalchemy.String), + sqlalchemy.Column("tag_acquisitiontime", sqlalchemy.String), + sqlalchemy.Column("tag_modality", sqlalchemy.String), + sqlalchemy.Column("tag_bodypartexamined", sqlalchemy.String), + sqlalchemy.Column("tag_studydescription", sqlalchemy.String), + sqlalchemy.Column("tag_seriesdescription", sqlalchemy.String), + sqlalchemy.Column("tag_protocolname", sqlalchemy.String), + sqlalchemy.Column("tag_codevalue", sqlalchemy.String), + sqlalchemy.Column("tag_codemeaning", sqlalchemy.String), + sqlalchemy.Column("tag_sequencename", sqlalchemy.String), + sqlalchemy.Column("tag_scanningsequence", sqlalchemy.String), + sqlalchemy.Column("tag_sequencevariant", sqlalchemy.String), + sqlalchemy.Column("tag_slicethickness", sqlalchemy.String), + sqlalchemy.Column("tag_contrastbolusagent", sqlalchemy.String), + sqlalchemy.Column("tag_referringphysicianname", sqlalchemy.String), + sqlalchemy.Column("tag_manufacturer", 
sqlalchemy.String), + sqlalchemy.Column("tag_manufacturermodelname", sqlalchemy.String), + sqlalchemy.Column("tag_magneticfieldstrength", sqlalchemy.String), + sqlalchemy.Column("tag_deviceserialnumber", sqlalchemy.String), + sqlalchemy.Column("tag_softwareversions", sqlalchemy.String), + sqlalchemy.Column("tag_stationname", sqlalchemy.String), + ) + + task_events = sqlalchemy.Table( + "task_events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("task_id", sqlalchemy.String, sqlalchemy.ForeignKey("tasks.id"), nullable=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("sender", sqlalchemy.String, default="Unknown"), + sqlalchemy.Column("event", sqlalchemy.String), + # sqlalchemy.Column("series_uid", sqlalchemy.String), + sqlalchemy.Column("file_count", sqlalchemy.Integer), + sqlalchemy.Column("target", sqlalchemy.String), + sqlalchemy.Column("info", sqlalchemy.String), + sqlalchemy.Column("client_timestamp", sqlalchemy.Integer), + ) + + file_events = sqlalchemy.Table( + "file_events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("dicom_file", sqlalchemy.Integer), + sqlalchemy.Column("event", sqlalchemy.Integer), + ) + + dicom_series_map = sqlalchemy.Table( + "dicom_series_map", + metadata, + sqlalchemy.Column("id_file", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("id_series", sqlalchemy.Integer), + ) + + series_sequence_data = sqlalchemy.Table( + "series_sequence_data", + metadata, + sqlalchemy.Column("uid", sqlalchemy.String, primary_key=True), + sqlalchemy.Column("data", sqlalchemy.JSON), + ) + + tasks_table = sqlalchemy.Table( + "tasks", + metadata, + sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), + sqlalchemy.Column("parent_id", sqlalchemy.String, nullable=True), + sqlalchemy.Column("time", sqlalchemy.DateTime), + sqlalchemy.Column("series_uid", sqlalchemy.String, nullable=True), + sqlalchemy.Column("study_uid", sqlalchemy.String, nullable=True), + sqlalchemy.Column("data", JSONB), + ) + + tests_table = sqlalchemy.Table( + "tests", + metadata, + sqlalchemy.Column("id", sqlalchemy.String, primary_key=True), + sqlalchemy.Column("type", sqlalchemy.String, nullable=True), + sqlalchemy.Column("rule_type", sqlalchemy.String, nullable=True), + sqlalchemy.Column("time_begin", sqlalchemy.DateTime, nullable=True), + sqlalchemy.Column("time_end", sqlalchemy.DateTime, nullable=True), + sqlalchemy.Column("status", sqlalchemy.String, nullable=True), + sqlalchemy.Column("task_id", sqlalchemy.String, nullable=True), + sqlalchemy.Column("data", JSONB, nullable=True), + ) + + processor_logs_table = sqlalchemy.Table( + "processor_logs", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("task_id", sqlalchemy.String, sqlalchemy.ForeignKey("tasks.id"), nullable=True), + sqlalchemy.Column("module_name", sqlalchemy.String, nullable=True), + sqlalchemy.Column("logs", sqlalchemy.String, nullable=True), + sqlalchemy.Column("time", sqlalchemy.DateTime, nullable=True), + ) + + processor_outputs_table = sqlalchemy.Table( + "processor_outputs", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True), + sqlalchemy.Column("time", sqlalchemy.DateTime(timezone=True), server_default=func.now()), + sqlalchemy.Column("task_id", sqlalchemy.String, 
sqlalchemy.ForeignKey("tasks.id"),nullable=True), + sqlalchemy.Column("task_acc", sqlalchemy.String), + sqlalchemy.Column("task_mrn", sqlalchemy.String), + sqlalchemy.Column("module", sqlalchemy.String), + sqlalchemy.Column("index", sqlalchemy.Integer), + sqlalchemy.Column("settings", JSONB), + sqlalchemy.Column("output", JSONB), + ) diff --git a/bookkeeping/query.py b/bookkeeping/query.py index 0f44932d..d38483b1 100644 --- a/bookkeeping/query.py +++ b/bookkeeping/query.py @@ -17,7 +17,7 @@ # App-specific includes import bookkeeping.config as bk_config -from bookkeeping.database import * +import bookkeeping.database as db from bookkeeping.helper import * from common import config from decoRouter import Router as decoRouter @@ -41,11 +41,11 @@ def set_timezone_conversion() -> None: async def get_series(request) -> JSONResponse: """Endpoint for retrieving series in the database.""" series_uid = request.query_params.get("series_uid", "") - query = dicom_series.select() + query = db.dicom_series.select() if series_uid: - query = query.where(dicom_series.c.series_uid == series_uid) + query = query.where(db.dicom_series.c.series_uid == series_uid) - result = await database.fetch_all(query) + result = await db.database.fetch_all(query) series = [dict(row) for row in result] for i, line in enumerate(series): @@ -61,14 +61,14 @@ async def get_tasks(request) -> JSONResponse: """Endpoint for retrieving tasks in the database.""" query = ( sqlalchemy.select( - tasks_table.c.id, tasks_table.c.time, dicom_series.c.tag_seriesdescription, dicom_series.c.tag_modality + db.tasks_table.c.id, db.tasks_table.c.time, db.dicom_series.c.tag_seriesdescription, db.dicom_series.c.tag_modality ) - .where(tasks_table.c.parent_id.is_(None)) # only show tasks without parents + .where(db.tasks_table.c.parent_id.is_(None)) # only show tasks without parents .join( - dicom_series, + db.dicom_series, # sqlalchemy.or_( # (dicom_series.c.study_uid == tasks_table.c.study_uid), - (dicom_series.c.series_uid == tasks_table.c.series_uid), + (db.dicom_series.c.series_uid == db.tasks_table.c.series_uid), # ), isouter=True, ) @@ -77,14 +77,14 @@ async def get_tasks(request) -> JSONResponse: # """ select tasks.id as task_id, tasks.time, tasks.series_uid, tasks.study_uid, "tag_seriesdescription", "tag_modality" from tasks # join dicom_series on tasks.study_uid = dicom_series.study_uid or tasks.series_uid = dicom_series.series_uid """ # ) - results = await database.fetch_all(query) + results = await db.database.fetch_all(query) return CustomJSONResponse(results) @router.get("/tests") @requires("authenticated") async def get_test_task(request) -> JSONResponse: - query = tests_table.select().order_by(tests_table.c.time_begin.desc()) + query = db.tests_table.select().order_by(db.tests_table.c.time_begin.desc()) # query = ( # sqlalchemy.select( # tasks_table.c.id, tasks_table.c.time, dicom_series.c.tag_seriesdescription, dicom_series.c.tag_modality @@ -98,7 +98,7 @@ async def get_test_task(request) -> JSONResponse: # ) # .where(dicom_series.c.tag_seriesdescription == "self_test_series " + request.query_params.get("id", "")) # ) - result_rows = await database.fetch_all(query) + result_rows = await db.database.fetch_all(query) results = [dict(row) for row in result_rows] for k in results: if not k["time_end"]: @@ -113,11 +113,11 @@ async def get_task_events(request) -> JSONResponse: """Endpoint for getting all events related to one task.""" task_id = request.query_params.get("task_id", "") - subtask_query = 
sqlalchemy.select(tasks_table.c.id).where(tasks_table.c.parent_id == task_id) + subtask_query = sqlalchemy.select(db.tasks_table.c.id).where(db.tasks_table.c.parent_id == task_id) # Note: The space at the end is needed for the case that there are no subtasks subtask_ids_str = "" - for row in await database.fetch_all(subtask_query): + for row in await db.database.fetch_all(subtask_query): subtask_ids_str += f"'{row[0]}'," subtask_ids_filter = "" @@ -139,7 +139,7 @@ async def get_task_events(request) -> JSONResponse: #print("SQL Query = " + query_string) query = sqlalchemy.text(query_string) - results = await database.fetch_all(query) + results = await db.database.fetch_all(query) return CustomJSONResponse(results) @@ -148,10 +148,10 @@ async def get_task_events(request) -> JSONResponse: async def get_dicom_files(request) -> JSONResponse: """Endpoint for getting all events related to one series.""" series_uid = request.query_params.get("series_uid", "") - query = dicom_files.select().order_by(dicom_files.c.time) + query = db.dicom_files.select().order_by(db.dicom_files.c.time) if series_uid: - query = query.where(dicom_files.c.series_uid == series_uid) - results = await database.fetch_all(query) + query = query.where(db.dicom_files.c.series_uid == series_uid) + results = await db.database.fetch_all(query) return CustomJSONResponse(results) @@ -162,16 +162,16 @@ async def get_task_process_logs(request) -> JSONResponse: task_id = request.query_params.get("task_id", "") subtask_query = ( - tasks_table.select() - .order_by(tasks_table.c.id) - .where(sqlalchemy.or_(tasks_table.c.id == task_id, tasks_table.c.parent_id == task_id)) + db.tasks_table.select() + .order_by(db.tasks_table.c.id) + .where(sqlalchemy.or_(db.tasks_table.c.id == task_id, db.tasks_table.c.parent_id == task_id)) ) - subtasks = await database.fetch_all(subtask_query) + subtasks = await db.database.fetch_all(subtask_query) subtask_ids = [row[0] for row in subtasks] - query = processor_logs_table.select(processor_logs_table.c.task_id.in_(subtask_ids)).order_by(processor_logs_table.c.id) - results = [dict(r) for r in await database.fetch_all(query)] + query = db.processor_logs_table.select(db.processor_logs_table.c.task_id.in_(subtask_ids)).order_by(db.processor_logs_table.c.id) + results = [dict(r) for r in await db.database.fetch_all(query)] for result in results: if result["logs"] == None: if logs_folder := config.mercure.processing_logs.logs_file_store: @@ -187,8 +187,8 @@ async def get_task_process_results(request) -> JSONResponse: """Endpoint for getting all processing results from a task.""" task_id = request.query_params.get("task_id", "") - query = processor_outputs_table.select().where(processor_outputs_table.c.task_id == task_id).order_by(processor_outputs_table.c.id) - results = [dict(r) for r in await database.fetch_all(query)] + query = db.processor_outputs_table.select().where(db.processor_outputs_table.c.task_id == task_id).order_by(db.processor_outputs_table.c.id) + results = [dict(r) for r in await db.database.fetch_all(query)] return CustomJSONResponse(results) @@ -253,7 +253,7 @@ async def find_task(request) -> JSONResponse: query = sqlalchemy.text(query_string) response: Dict = {} - result_rows = await database.fetch_all(query) + result_rows = await db.database.fetch_all(query) results = [dict(row) for row in result_rows] for item in results: @@ -330,7 +330,7 @@ async def get_task_info(request) -> JSONResponse: limit 1""" ) # TODO: use sqlalchemy interpolation - info_rows = await database.fetch_all(info_query) 
+ info_rows = await db.database.fetch_all(info_query) info_results = [dict(row) for row in info_rows] if info_results: @@ -352,11 +352,11 @@ async def get_task_info(request) -> JSONResponse: # Now, get the task files embedded into the task or its subtasks query = ( - tasks_table.select() - .order_by(tasks_table.c.id) - .where(sqlalchemy.or_(tasks_table.c.id == task_id, tasks_table.c.parent_id == task_id)) + db.tasks_table.select() + .order_by(db.tasks_table.c.id) + .where(sqlalchemy.or_(db.tasks_table.c.id == task_id, db.tasks_table.c.parent_id == task_id)) ) - result_rows = await database.fetch_all(query) + result_rows = await db.database.fetch_all(query) results = [dict(row) for row in result_rows] for item in results: diff --git a/common/config.py b/common/config.py index cc2ecd39..9fe704c1 100755 --- a/common/config.py +++ b/common/config.py @@ -26,7 +26,11 @@ logger = get_logger() configuration_timestamp: float = 0 -configuration_filename = (os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/mercure.json" +_os_config_file = os.getenv("MERCURE_CONFIG_FILE") +if _os_config_file is not None: + configuration_filename = _os_config_file +else: + configuration_filename = (os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/mercure.json" mercure_defaults = { "appliance_name": "master", @@ -91,51 +95,51 @@ def read_config() -> Config: if lock_file.exists(): raise ResourceWarning(f"Configuration file locked: {lock_file}") - if configuration_file.exists(): - # Get the modification date/time of the configuration file - stat = os.stat(configuration_filename) - try: - timestamp = stat.st_mtime - except AttributeError: - timestamp = 0 - - # Check if the configuration file is newer than the version - # loaded into memory. If not, return - if timestamp <= configuration_timestamp: - return mercure - - logger.info(f"Reading configuration from: {configuration_filename}") - - with open(configuration_file, "r") as json_file: - loaded_config = json.load(json_file) - # Reset configuration to default values (to ensure all needed - # keys are present in the configuration) - merged: Dict = {**mercure_defaults, **loaded_config} - mercure = Config(**merged) - - # TODO: Check configuration for errors (esp targets and rules) - - # Check if directories exist - if not check_folders(): - raise FileNotFoundError("Configured folders missing") - - # logger.info("") - # logger.info("Active configuration: ") - # logger.info(json.dumps(mercure, indent=4)) - # logger.info("") - - try: - read_tagslist() - except Exception as e: - logger.info(e) - logger.info("Unable to parse list of additional tags. Check configuration file.") - - configuration_timestamp = timestamp - monitor.send_event(monitor.m_events.CONFIG_UPDATE, monitor.severity.INFO, "Configuration updated") - return mercure - else: + if not configuration_file.exists(): raise FileNotFoundError(f"Configuration file not found: {configuration_file}") + # Get the modification date/time of the configuration file + stat = os.stat(configuration_filename) + try: + timestamp = stat.st_mtime + except AttributeError: + timestamp = 0 + + # Check if the configuration file is newer than the version + # loaded into memory. 
If not, return + if timestamp <= configuration_timestamp: + return mercure + + logger.info(f"Reading configuration from: {configuration_filename}") + + with open(configuration_file, "r") as json_file: + loaded_config = json.load(json_file) + # Reset configuration to default values (to ensure all needed + # keys are present in the configuration) + merged: Dict = {**mercure_defaults, **loaded_config} + mercure = Config(**merged) + + # TODO: Check configuration for errors (esp targets and rules) + + # Check if directories exist + if not check_folders(): + raise FileNotFoundError("Configured folders missing") + + # logger.info("") + # logger.info("Active configuration: ") + # logger.info(json.dumps(mercure, indent=4)) + # logger.info("") + + try: + read_tagslist() + except Exception as e: + logger.info(e) + logger.info("Unable to parse list of additional tags. Check configuration file.") + + configuration_timestamp = timestamp + monitor.send_event(monitor.m_events.CONFIG_UPDATE, monitor.severity.INFO, "Configuration updated") + return mercure + def save_config() -> None: """Saves the current configuration in a file on the disk. Raises an exception if the file has diff --git a/dev-requirements.in b/dev-requirements.in index 038ef3e8..887ee0ed 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,3 +1,5 @@ -c requirements.txt supervisor fakeredis +pytest +pytest-asyncio \ No newline at end of file diff --git a/dev-requirements.txt b/dev-requirements.txt index 8ef5a98f..9a047bdc 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,8 +8,31 @@ async-timeout==4.0.3 # via # -c requirements.txt # redis +exceptiongroup==1.2.1 + # via + # -c requirements.txt + # pytest fakeredis==2.25.1 # via -r dev-requirements.in +iniconfig==2.0.0 + # via + # -c requirements.txt + # pytest +packaging==24.1 + # via + # -c requirements.txt + # pytest +pluggy==1.5.0 + # via + # -c requirements.txt + # pytest +pytest==8.2.2 + # via + # -c requirements.txt + # -r dev-requirements.in + # pytest-asyncio +pytest-asyncio==0.24.0 + # via -r dev-requirements.in redis==5.0.4 # via # -c requirements.txt @@ -18,6 +41,10 @@ sortedcontainers==2.4.0 # via fakeredis supervisor==4.2.5 # via -r dev-requirements.in +tomli==2.0.1 + # via + # -c requirements.txt + # pytest typing-extensions==4.12.2 # via # -c requirements.txt diff --git a/test.py b/test.py index 85d513ad..50d4d5a7 100755 --- a/test.py +++ b/test.py @@ -14,7 +14,7 @@ def run_test() -> None: config.mercure.study_complete_trigger = 2 config.mercure.series_complete_trigger = 1 config.save_config() - + app = create_app() client = TestClient(app) startup(app) form_data = { diff --git a/tests/data/test_config.json b/tests/data/test_config.json index cd23cbc3..7900326e 100755 --- a/tests/data/test_config.json +++ b/tests/data/test_config.json @@ -10,6 +10,7 @@ "processing_folder": "/var/processing", "jobs_folder": "/var/jobs", "bookkeeper": "0.0.0.0:8080", + "bookkeeper_api_key": "12345", "graphite_ip": "", "graphite_port": 2003, "router_scan_interval": 1, diff --git a/tests/dispatch/test_retry.py b/tests/dispatch/test_retry.py index 2732dd17..ea3b1659 100755 --- a/tests/dispatch/test_retry.py +++ b/tests/dispatch/test_retry.py @@ -1,10 +1,4 @@ import json -import os -import time -from pathlib import Path -from subprocess import CalledProcessError - -import pytest from dispatch.retry import increase_retry from common.constants import mercure_names diff --git a/tests/test_bookkeeper.py b/tests/test_bookkeeper.py index 2b475192..545b5652 100755 --- 
a/tests/test_bookkeeper.py +++ b/tests/test_bookkeeper.py @@ -2,9 +2,26 @@ test_bookkeeper.py ================== """ -import bookkeeper as b +import multiprocessing +import time +import requests +import bookkeeper +from testing_common import * -def test_bookkeeper_no_syntax_errors(): +# def run_server(app, port): +# b.uvicorn.run(app, host="localhost", port=port) + +def test_bookkeeper_starts(fs, bookkeeper_port): """ Checks if bookkeeper.py can be started. """ - assert b + bookkeeper_process = multiprocessing.Process(target=bookkeeper.main) + bookkeeper_process.start() + # Wait for the server to start + time.sleep(2) + response = requests.get(f"http://127.0.0.1:{bookkeeper_port}/test") + assert response.status_code == 200 + assert response.text == '{"ok":""}' + # Shutdown the server + print("Shutting down the server...") + bookkeeper_process.terminate() + bookkeeper_process.join() diff --git a/tests/test_query.py b/tests/test_query.py index 3bd41fb7..f7c60cea 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,6 @@ import os from pathlib import Path +import subprocess import tempfile from typing import Dict, Optional, Tuple import pydicom @@ -16,7 +17,7 @@ from common.types import DicomTarget, DicomWebTarget from webinterface.common import redis from pydicom.uid import ExplicitVRLittleEndian, ImplicitVRLittleEndian -from testing_common import receiver_port, mercure_config +from testing_common import receiver_port, mercure_config, bookkeeper_port from logging import getLogger from rq import SimpleWorker, Queue, Connection from fakeredis import FakeStrictRedis @@ -183,17 +184,30 @@ def test_get_accession_job(dicom_server, dicomweb_server, mercure_config): assert results[0].remaining == 0 assert pydicom.dcmread(next(k for k in Path(config.jobs_folder).rglob("*.dcm"))).AccessionNumber == MOCK_ACCESSIONS[0] -def test_query_job(dicom_server, tempdir, rq_connection): +def test_query_job(dicom_server, tempdir, rq_connection,fs): """ Test the create_job function. We use mocker to mock the queue and avoid actually creating jobs. 
""" + fs.pause() + try: + if (subprocess.run(['systemctl', 'is-active', "mercure_worker*"],capture_output=True,text=True,check=False, + ).stdout.strip() == 'active'): + raise Exception("At least one mercure worker is running, stop it before running test.") + except subprocess.CalledProcessError: + pass + fs.resume() job = QueryPipeline.create([MOCK_ACCESSIONS[0]], {}, dicom_server, str(tempdir)) w = SimpleWorker(["mercure_fast", "mercure_slow"], connection=rq_connection) + w.work(burst=True) # assert len(list(Path(config.mercure.jobs_folder).iterdir())) == 1 print([k for k in Path(tempdir).rglob('*')]) - assert pydicom.dcmread(next(k for k in Path(tempdir).rglob("*.dcm"))).AccessionNumber == MOCK_ACCESSIONS[0] + try: + example_dcm = next(k for k in Path(tempdir).rglob("*.dcm")) + except StopIteration: + assert False, f"No DICOM file found in {tempdir}" + assert pydicom.dcmread(example_dcm).AccessionNumber == MOCK_ACCESSIONS[0] def tree(path, prefix='', level=0) -> None: if level==0: diff --git a/tests/test_router.py b/tests/test_router.py index 6a766bdd..f3a4b47b 100755 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -67,8 +67,7 @@ def create_series(mocked, fs, config, tags, name="bar") -> Tuple[str, str]: mock_incoming_uid(config, fs, series_uid, tags,name) return task_id, series_uid -@pytest.mark.asyncio -async def test_route_series_fail1(fs: FakeFilesystem, mercure_config, mocked): +def test_route_series_fail1(fs: FakeFilesystem, mercure_config, mocked): config = mercure_config(rules) tags = {"asdfasdfas": "foo"} @@ -88,7 +87,6 @@ async def test_route_series_fail1(fs: FakeFilesystem, mercure_config, mocked): f"Invalid tag for series {series_uid}", ) - def test_route_series_fail2(fs: FakeFilesystem, mercure_config, mocked): config = mercure_config(rules) @@ -104,7 +102,6 @@ def test_route_series_fail2(fs: FakeFilesystem, mercure_config, mocked): common.monitor.send_task_event.assert_any_call(task_event.DISCARD, task_id,1, "","Discard by default.") # type: ignore common.monitor.send_task_event.reset_mock() # type: ignore - def test_route_series_fail3(fs: FakeFilesystem, mercure_config, mocked): config = mercure_config(rules) @@ -145,7 +142,6 @@ def no_create_destination(dest): # f"Creating folder not possible {config.outgoing_folder}/{task_id}", # ) - def test_route_series_fail4(fs: FakeFilesystem, mercure_config, mocked): config = mercure_config(rules) @@ -164,7 +160,6 @@ def test_route_series_fail4(fs: FakeFilesystem, mercure_config, mocked): ) assert list(Path(config.outgoing_folder).glob("**/*.dcm")) == [] - def task_will_dispatch_to(task, config, fake_process) -> None: for target_item in task.dispatch.target_name: t = config.targets[target_item] @@ -188,8 +183,7 @@ def task_will_dispatch_to(task, config, fake_process) -> None: ) -@pytest.mark.asyncio -async def test_route_study(fs: FakeFilesystem, mercure_config, mocked, fake_process): +def test_route_study(fs: FakeFilesystem, mercure_config, mocked, fake_process): config = mercure_config(rules) study_uid = str(uuid.uuid4()) @@ -243,8 +237,7 @@ async def test_route_study(fs: FakeFilesystem, mercure_config, mocked, fake_proc # f"Routed to test_target", # ) -@pytest.mark.asyncio -async def test_route_series_success(fs: FakeFilesystem, mercure_config, mocked, fake_process): +def test_route_series_success(fs: FakeFilesystem, mercure_config, mocked, fake_process): config = mercure_config(rules) # attach_spies(mocker) # mocker.patch("routing.route_series.parse_ascconv", new=lambda x: {}) @@ -302,8 +295,7 @@ async def 
test_route_series_success(fs: FakeFilesystem, mercure_config, mocked, # common.monitor.send_event.assert_not_called() -@pytest.mark.asyncio -async def test_route_series_new_rule(fs: FakeFilesystem, mercure_config, mocked, fake_process): +def test_route_series_new_rule(fs: FakeFilesystem, mercure_config, mocked, fake_process): config = mercure_config(rules) # attach_spies(mocker) # mocker.patch("routing.route_series.parse_ascconv", new=lambda x: {}) @@ -341,8 +333,7 @@ async def test_route_series_new_rule(fs: FakeFilesystem, mercure_config, mocked, assert task.info.triggered_rules["route_series_new_rule"] == True # type: ignore task_will_dispatch_to(task, config, fake_process) -@pytest.mark.asyncio -async def test_route_series_with_bad_tags(fs: FakeFilesystem, mercure_config, mocked, fake_process): +def test_route_series_with_bad_tags(fs: FakeFilesystem, mercure_config, mocked, fake_process): config = mercure_config(rules) # attach_spies(mocker) # mocker.patch("routing.route_series.parse_ascconv", new=lambda x: {}) @@ -370,8 +361,7 @@ async def test_route_series_with_bad_tags(fs: FakeFilesystem, mercure_config, mo task: Task = Task(**json.load(e)) task_will_dispatch_to(task, config, fake_process) -@pytest.mark.asyncio -async def test_route_series_fail_with_bad_tags(fs: FakeFilesystem, mercure_config, mocked, fake_process): +def test_route_series_fail_with_bad_tags(fs: FakeFilesystem, mercure_config, mocked, fake_process): config = mercure_config(rules) # attach_spies(mocker) # mocker.patch("routing.route_series.parse_ascconv", new=lambda x: {}) diff --git a/tests/test_studies.py b/tests/test_studies.py index 563e11fb..63a34d7b 100644 --- a/tests/test_studies.py +++ b/tests/test_studies.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timedelta import importlib @@ -171,9 +172,8 @@ def test_route_study_error(fs: FakeFilesystem, mercure_config, mocked): assert list(Path(config.error_folder).glob("**/*")) != [] -@pytest.mark.asyncio @pytest.mark.parametrize("do_error", [True, False]) -async def test_route_study_processing(fs: FakeFilesystem, mercure_config, mocked, do_error): +def test_route_study_processing(fs: FakeFilesystem, mercure_config, mocked, do_error): config = mercure_config( { "modules": { @@ -232,7 +232,8 @@ def fake_processor(tag=None, meta=None, do_process=True, **kwargs): mocked.patch.object(Job, "dispatch_job", new=fake_run) mocked.patch.object(Job, "get_job", new=lambda x, y: dict(Status="dead")) - await processor.run_processor() + asyncio.run(processor.run_processor()) + # await processor.run_processor() if do_error: assert list(Path(config.error_folder).glob("**/*")) != [] else: diff --git a/tests/testing_common.py b/tests/testing_common.py index 09f247a7..0d0a3ab1 100755 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -7,7 +7,7 @@ from pathlib import Path import shutil import socket -from typing import Callable, Dict, Any, Iterator, Optional, Tuple +from typing import Callable, Dict, Any, Iterator, List, Optional, Tuple import uuid import pydicom @@ -20,7 +20,7 @@ import pytest import process -import routing, common, router, processor +import routing, common, router, processor, bookkeeper import common.config as config from common.types import Config import docker.errors @@ -95,9 +95,13 @@ def mocked(mocker): attach_spies(mocker) return mocker +@pytest.fixture(scope="module") +def bookkeeper_port(): + return random_port() + @pytest.fixture(scope="function", autouse=True) -def mercure_config(fs) -> Callable[[Dict], Config]: +def 
mercure_config(fs, bookkeeper_port) -> Callable[[Dict], Config]: # TODO: config from previous calls seems to leak in here config_path = os.path.realpath(os.path.dirname(os.path.realpath(__file__)) + "/data/test_config.json") @@ -112,7 +116,17 @@ def set_config(extra: Dict[Any, Any] = {}) -> Config: config.save_config() return config.mercure - set_config() + # set_config() + set_config({"bookkeeper": "sqlite:///tmp/mercure_bookkeeper_"+str(uuid.uuid4())+".db"}) # sqlite3 is not inside the fakefs so this is going to be a real file + + bookkeeper_env = f"""PORT={bookkeeper_port} +HOST=0.0.0.0 +DATABASE_URL={config.mercure.bookkeeper}""" + fs.create_file(bookkeeper.bk_config.config_filename, contents=bookkeeper_env) + + fs.add_real_directory(os.path.dirname(os.path.realpath(__file__))+'/../alembic') + + fs.add_real_file(os.path.dirname(os.path.realpath(__file__))+'/../alembic.ini',read_only=True) return set_config @@ -252,7 +266,3 @@ def random_port() -> int: @pytest.fixture(scope="module") def receiver_port(): return random_port() - -@pytest.fixture(scope="module") -def bookkeeper_port(): - return random_port() diff --git a/webgui.py b/webgui.py index ed7981a8..d3702633 100755 --- a/webgui.py +++ b/webgui.py @@ -116,18 +116,24 @@ async def authenticate(self, request): return AuthCredentials(credentials), ExtendedUser(username, is_admin) - -webgui_config = Config((os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/webgui.env") - - -# Note: PutSomethingRandomHere is the default value in the shipped configuration file. -# The app will not start with this value, forcing the users to set their onw secret -# key. Therefore, the value is used as default here as well. -SECRET_KEY = webgui_config("SECRET_KEY", cast=Secret, default="PutSomethingRandomHere") -WEBGUI_PORT = webgui_config("PORT", cast=int, default=8000) -WEBGUI_HOST = webgui_config("HOST", default="0.0.0.0") -DEBUG_MODE = webgui_config("DEBUG", cast=bool, default=True) - +webgui_config = None +SECRET_KEY: Secret +WEBGUI_PORT: int +WEBGUI_HOST: str +DEBUG_MODE: bool +def read_webgui_config() -> Config: + global webgui_config, SECRET_KEY, WEBGUI_HOST, WEBGUI_PORT, DEBUG_MODE + webgui_config = Config((os.getenv("MERCURE_CONFIG_FOLDER") or "/opt/mercure/config") + "/webgui.env") + + + # Note: PutSomethingRandomHere is the default value in the shipped configuration file. + # The app will not start with this value, forcing the users to set their onw secret + # key. Therefore, the value is used as default here as well. + SECRET_KEY = webgui_config("SECRET_KEY", cast=Secret, default=Secret("PutSomethingRandomHere")) + WEBGUI_PORT = webgui_config("PORT", cast=int, default=8000) + WEBGUI_HOST = webgui_config("HOST", default="0.0.0.0") + DEBUG_MODE = webgui_config("DEBUG", cast=bool, default=True) + return webgui_config @contextlib.asynccontextmanager async def lifespan(app): @@ -914,21 +920,24 @@ async def server_error(request, exc) -> Response: 404: not_found, 500: server_error } - -app = Starlette(debug=DEBUG_MODE, lifespan=lifespan, exception_handlers=exception_handlers, routes=router) -# Don't check the existence of the static folder because the wrong parent folder is used if the -# source code is parsed by sphinx. This would raise an exception and lead to failure of sphinx. 
-app.mount("/static", StaticFiles(directory="webinterface/statics", check_dir=False), name="static") -app.add_middleware(AuthenticationMiddleware, backend=SessionAuthBackend()) -app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, session_cookie="mercure_session") -app.mount("/rules", rules.rules_app, name="rules") -app.mount("/targets", targets.targets_app) -app.mount("/modules", modules.modules_app) -app.mount("/users", users.users_app) -app.mount("/queue", queue.queue_app) -app.mount("/api", api.api_app) -app.mount("/tools", dashboards.dashboards_app) - +app = None + +def create_app() -> Starlette: + global app + app = Starlette(debug=DEBUG_MODE, lifespan=lifespan, exception_handlers=exception_handlers, routes=router) + # Don't check the existence of the static folder because the wrong parent folder is used if the + # source code is parsed by sphinx. This would raise an exception and lead to failure of sphinx. + app.mount("/static", StaticFiles(directory="webinterface/statics", check_dir=False), name="static") + app.add_middleware(AuthenticationMiddleware, backend=SessionAuthBackend()) + app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, session_cookie="mercure_session") + app.mount("/rules", rules.rules_app, name="rules") + app.mount("/targets", targets.targets_app) + app.mount("/modules", modules.modules_app) + app.mount("/users", users.users_app) + app.mount("/queue", queue.queue_app) + app.mount("/api", api.api_app) + app.mount("/tools", dashboards.dashboards_app) + return app @@ -967,6 +976,7 @@ def main(args=sys.argv[1:]) -> None: logging.getLogger("watchdog").setLevel(logging.WARNING) try: + read_webgui_config() services.read_services() config_ = config.read_config() users.read_users() @@ -976,6 +986,8 @@ def main(args=sys.argv[1:]) -> None: if str(SECRET_KEY) == "PutSomethingRandomHere": logger.error("You need to change the SECRET_KEY in configuration/webgui.env") raise Exception("Invalid or missing SECRET_KEY in webgui.env") + app = create_app() + uvicorn.run(app, host=WEBGUI_HOST, port=WEBGUI_PORT) except Exception as e: logger.error(e) logger.error("Cannot start service. Showing emergency message.") @@ -983,8 +995,5 @@ def main(args=sys.argv[1:]) -> None: logger.info("Going down.") sys.exit(1) - uvicorn.run(app, host=WEBGUI_HOST, port=WEBGUI_PORT) - - if __name__ == "__main__": main()