From c2abf45214c95f97883e5c9bd3227fd18de92579 Mon Sep 17 00:00:00 2001 From: Andrew Brain Date: Mon, 26 Feb 2024 07:54:53 -0600 Subject: [PATCH 1/3] automatically handle database engine cleanup in cli --- augur/application/cli/__init__.py | 17 ++++++++++ augur/application/cli/backend.py | 44 +++++++++++++++---------- augur/application/cli/config.py | 43 +++++++++++++++--------- augur/application/cli/db.py | 54 +++++++++++++++++-------------- augur/application/cli/user.py | 2 +- 5 files changed, 102 insertions(+), 58 deletions(-) diff --git a/augur/application/cli/__init__.py b/augur/application/cli/__init__.py index aaf548432a..67f85c4d09 100644 --- a/augur/application/cli/__init__.py +++ b/augur/application/cli/__init__.py @@ -8,6 +8,7 @@ import json from augur.application.db.engine import DatabaseEngine +from augur.application.db import get_engine, dispose_database_engine from sqlalchemy.exc import OperationalError @@ -72,6 +73,22 @@ def new_func(ctx, *args, **kwargs): return update_wrapper(new_func, function_db_connection) + +class DatabaseContext(object): + def __init__(self): + self.engine = None + +def with_database(f): + @click.pass_context + def new_func(ctx, *args, **kwargs): + ctx.obj.engine = get_engine() + try: + return ctx.invoke(f, *args, **kwargs) + finally: + dispose_database_engine() + return new_func + + # def pass_application(f): # @click.pass_context # def new_func(ctx, *args, **kwargs): diff --git a/augur/application/cli/backend.py b/augur/application/cli/backend.py index 87e119cbb5..5a12de3f65 100644 --- a/augur/application/cli/backend.py +++ b/augur/application/cli/backend.py @@ -23,17 +23,17 @@ from augur.application.db.session import DatabaseSession from augur.application.logs import AugurLogger from augur.application.config import AugurConfig -from augur.application.cli import test_connection, test_db_connection +from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext import sqlalchemy as s logger = AugurLogger("augur", reset_logfiles=True).get_logger() - @click.group('server', short_help='Commands for controlling the backend API server & data collection workers') -def cli(): - pass +@click.pass_context +def cli(ctx): + ctx.obj = DatabaseContext() @cli.command("start") @click.option("--disable-collection", is_flag=True, default=False, help="Turns off data collection workers") @@ -41,7 +41,9 @@ def cli(): @click.option('--port') @test_connection @test_db_connection -def start(disable_collection, development, port): +@with_database +@click.pass_context +def start(ctx, disable_collection, development, port): """Start Augur's backend server.""" try: @@ -63,7 +65,7 @@ def start(disable_collection, development, port): except FileNotFoundError: logger.error("\n\nPlease run augur commands in the root directory\n\n") - with DatabaseSession(logger) as db_session: + with DatabaseSession(logger, engine=ctx.obj.engine) as db_session: config = AugurConfig(logger, db_session) host = config.get_value("Server", "host") @@ -85,7 +87,7 @@ def start(disable_collection, development, port): logger.info("Deleting old task schedule") os.remove("celerybeat-schedule.db") - with DatabaseSession(logger) as db_session: + with DatabaseSession(logger, engine=ctx.obj.engine) as db_session: config = AugurConfig(logger, db_session) log_level = config.get_value("Logging", "log_level") celery_beat_process = None @@ -94,7 +96,7 @@ def start(disable_collection, development, port): if not disable_collection: - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: clean_collection_status(session) assign_orphan_repos_to_default_user(session) @@ -132,7 +134,7 @@ def start(disable_collection, development, port): if not disable_collection: try: - cleanup_after_collection_halt(logger) + cleanup_after_collection_halt(logger, ctx.obj.engine) except RedisConnectionError: pass @@ -194,24 +196,32 @@ def determine_worker_processes(ratio,maximum): @cli.command('stop') -def stop(): +@test_connection +@test_db_connection +@with_database +@click.pass_context +def stop(ctx): """ Sends SIGTERM to all Augur server & worker processes """ logger = logging.getLogger("augur.cli") - augur_stop(signal.SIGTERM, logger) + augur_stop(signal.SIGTERM, logger, ctx.obj.engine) @cli.command('kill') -def kill(): +@test_connection +@test_db_connection +@with_database +@click.pass_context +def kill(ctx): """ Sends SIGKILL to all Augur server & worker processes """ logger = logging.getLogger("augur.cli") - augur_stop(signal.SIGKILL, logger) + augur_stop(signal.SIGKILL, logger, ctx.obj.engine) -def augur_stop(signal, logger): +def augur_stop(signal, logger, engine): """ Stops augur with the given signal, and cleans up collection if it was running @@ -224,13 +234,13 @@ def augur_stop(signal, logger): _broadcast_signal_to_processes(augur_processes, broadcast_signal=signal, given_logger=logger) if "celery" in process_names: - cleanup_after_collection_halt(logger) + cleanup_after_collection_halt(logger, engine) -def cleanup_after_collection_halt(logger): +def cleanup_after_collection_halt(logger, engine): clear_redis_caches() connection_string = "" - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=engine) as session: config = AugurConfig(logger, session) connection_string = config.get_section("RabbitMQ")['connection_string'] diff --git a/augur/application/cli/config.py b/augur/application/cli/config.py index 160ce92b32..e5beae92eb 100644 --- a/augur/application/cli/config.py +++ b/augur/application/cli/config.py @@ -9,7 +9,7 @@ from augur.application.db.session import DatabaseSession from augur.application.config import AugurConfig -from augur.application.cli import test_connection, test_db_connection +from augur.application.cli import DatabaseContext, test_connection, test_db_connection, with_database from augur.util.inspect_without_import import get_phase_names_without_import ROOT_AUGUR_DIRECTORY = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))) @@ -18,8 +18,9 @@ ENVVAR_PREFIX = "AUGUR_" @click.group('config', short_help='Generate an augur.config.json') -def cli(): - pass +@click.pass_context +def cli(ctx): + ctx.obj = DatabaseContext() @cli.command('init') @click.option('--github-api-key', help="GitHub API key for data collection from the GitHub API", envvar=ENVVAR_PREFIX + 'GITHUB_API_KEY') @@ -29,7 +30,9 @@ def cli(): @click.option('--rabbitmq-conn-string', help="String to connect to rabbitmq broker", envvar=ENVVAR_PREFIX + 'RABBITMQ_CONN_STRING') @test_connection @test_db_connection -def init_config(github_api_key, facade_repo_directory, gitlab_api_key, redis_conn_string, rabbitmq_conn_string): +@with_database +@click.pass_context +def init_config(ctx, github_api_key, facade_repo_directory, gitlab_api_key, redis_conn_string, rabbitmq_conn_string): if not github_api_key: @@ -59,7 +62,7 @@ def init_config(github_api_key, facade_repo_directory, gitlab_api_key, redis_con keys["github_api_key"] = github_api_key keys["gitlab_api_key"] = gitlab_api_key - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: config = AugurConfig(logger, session) @@ -104,9 +107,11 @@ def init_config(github_api_key, facade_repo_directory, gitlab_api_key, redis_con @click.option('--file', required=True) @test_connection @test_db_connection -def load_config(file): +@with_database +@click.pass_context +def load_config(ctx, file): - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: config = AugurConfig(logger, session) print("WARNING: This will override your current config") @@ -127,9 +132,11 @@ def load_config(file): @click.option('--file', required=True) @test_connection @test_db_connection -def add_section(section_name, file): +@with_database +@click.pass_context +def add_section(ctx, section_name, file): - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: config = AugurConfig(logger, session) if config.is_section_in_config(section_name): @@ -156,9 +163,11 @@ def add_section(section_name, file): @click.option('--data-type', required=True) @test_connection @test_db_connection -def config_set(section, setting, value, data_type): +@with_database +@click.pass_context +def config_set(ctx, section, setting, value, data_type): - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: config = AugurConfig(logger, session) if data_type not in config.accepted_types: @@ -180,9 +189,11 @@ def config_set(section, setting, value, data_type): @click.option('--setting') @test_connection @test_db_connection -def config_get(section, setting): +@with_database +@click.pass_context +def config_get(ctx, section, setting): - with DatabaseSession(logger) as session: + with DatabaseSession(logger, engine=ctx.obj.engine) as session: config = AugurConfig(logger, session) if setting: @@ -210,9 +221,11 @@ def config_get(section, setting): @cli.command('clear') @test_connection @test_db_connection -def clear_config(): +@with_database +@click.pass_context +def clear_config(ctx): - with DatabaseSession(logger) as session: + with DatabaseSession(logger, ctx.obj.engine) as session: config = AugurConfig(logger, session) if not config.empty(): diff --git a/augur/application/cli/db.py b/augur/application/cli/db.py index 16c75e16be..7b380af24b 100644 --- a/augur/application/cli/db.py +++ b/augur/application/cli/db.py @@ -12,7 +12,7 @@ import json import re -from augur.application.cli import test_connection, test_db_connection +from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext from augur.application.db.session import DatabaseSession from augur.application.db.engine import DatabaseEngine @@ -23,15 +23,18 @@ logger = logging.getLogger(__name__) @click.group("db", short_help="Database utilities") -def cli(): - pass +@click.pass_context +def cli(ctx): + ctx.obj = DatabaseContext() @cli.command("add-repos") @click.argument("filename", type=click.Path(exists=True)) @test_connection @test_db_connection -def add_repos(filename): +@with_database +@click.pass_context +def add_repos(ctx, filename): """Add repositories to Augur's database. The .csv file format should be repo_url,group_id @@ -42,7 +45,7 @@ def add_repos(filename): from augur.tasks.github.util.github_task_session import GithubTaskSession from augur.util.repo_load_controller import RepoLoadController - with GithubTaskSession(logger) as session: + with GithubTaskSession(logger, engine=ctx.obj.engine) as session: controller = RepoLoadController(session) @@ -67,12 +70,14 @@ def add_repos(filename): @cli.command("get-repo-groups") @test_connection @test_db_connection -def get_repo_groups(): +@with_database +@click.pass_context +def get_repo_groups(ctx): """ List all repo groups and their associated IDs """ - with DatabaseEngine() as engine, engine.connect() as connection: + with ctx.obj.engine.connect() as connection: df = pd.read_sql( s.sql.text( "SELECT repo_group_id, rg_name, rg_description FROM augur_data.repo_groups" @@ -80,20 +85,21 @@ def get_repo_groups(): connection, ) print(df) - engine.dispose() return df @cli.command("add-repo-groups") +@click.argument("filename", type=click.Path(exists=True)) @test_connection @test_db_connection -@click.argument("filename", type=click.Path(exists=True)) -def add_repo_groups(filename): +@with_database +@click.pass_context +def add_repo_groups(ctx, filename): """ Create new repo groups in Augur's database """ - with DatabaseEngine() as engine, engine.begin() as connection: + with ctx.obj.engine.begin() as connection: df = pd.read_sql( s.sql.text("SELECT repo_group_id FROM augur_data.repo_groups"), @@ -129,21 +135,20 @@ def add_repo_groups(filename): f"Repo group with ID {row[1]} for repo group {row[1]} already exists, skipping..." ) - engine.dispose() - - @cli.command("add-github-org") @click.argument("organization_name") @test_connection @test_db_connection -def add_github_org(organization_name): +@with_database +@click.pass_context +def add_github_org(ctx, organization_name): """ Create new repo groups in Augur's database """ from augur.tasks.github.util.github_task_session import GithubTaskSession from augur.util.repo_load_controller import RepoLoadController - with GithubTaskSession(logger) as session: + with GithubTaskSession(logger, engine=ctx.obj.engine) as session: controller = RepoLoadController(session) @@ -228,7 +233,9 @@ def generate_api_key(ctx): @click.argument("api_key") @test_connection @test_db_connection -def update_api_key(api_key): +@with_database +@click.pass_context +def update_api_key(ctx, api_key): """ Update the API key in the database to the given key """ @@ -242,18 +249,17 @@ def update_api_key(api_key): """ ) - with DatabaseEngine() as engine, engine.begin() as connection: + with ctx.obj.engine.begin() as connection: connection.execute(update_api_key_sql, api_key=api_key) logger.info(f"Updated Augur API key to: {api_key}") - engine.dispose() - - @cli.command("get-api-key") @test_connection @test_db_connection -def get_api_key(): +@with_database +@click.pass_context +def get_api_key(ctx): get_api_key_sql = s.sql.text( """ SELECT value FROM augur_operations.augur_settings WHERE setting='augur_api_key'; @@ -261,13 +267,11 @@ def get_api_key(): ) try: - with DatabaseEngine() as engine, engine.connect() as connection: + with ctx.obj.engine.connect() as connection: print(connection.execute(get_api_key_sql).fetchone()[0]) except TypeError: print("No Augur API key found.") - engine.dispose() - @cli.command( "check-pgpass", diff --git a/augur/application/cli/user.py b/augur/application/cli/user.py index 9d0b822be2..2cae5d7b22 100644 --- a/augur/application/cli/user.py +++ b/augur/application/cli/user.py @@ -12,7 +12,7 @@ from augur.application.db.engine import DatabaseEngine from sqlalchemy.orm import sessionmaker - +# TODO: Update these commands to use cli DatabaseContext so this engine is cleaned up engine = DatabaseEngine().engine Session = sessionmaker(bind=engine) From fb9463e9f68b57c4028f3f5f9b4b3c39b5d81337 Mon Sep 17 00:00:00 2001 From: Andrew Brain Date: Mon, 26 Feb 2024 13:07:12 -0600 Subject: [PATCH 2/3] add database decorator to last db command --- augur/application/cli/db.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/augur/application/cli/db.py b/augur/application/cli/db.py index 7b380af24b..68932f426a 100644 --- a/augur/application/cli/db.py +++ b/augur/application/cli/db.py @@ -368,11 +368,28 @@ def init_database( f"GRANT ALL PRIVILEGES ON DATABASE {target_db_name} TO {target_user};", ) +@cli.command("reset-repo-age") +@test_connection +@test_db_connection +@with_database +@click.pass_context +def reset_repo_age(ctx): + + with DatabaseSession(logger, engine=ctx.obj.engine) as session: + update_query = ( + update(Repo) + .values(repo_added=datetime.now()) + ) + + session.execute(update_query) + session.commit() + @cli.command("test-connection") @test_connection @test_db_connection def test_db_connection(): - pass + print("Successful db connection") + # TODO: Fix this function def run_psql_command_in_database(target_type, target): @@ -457,18 +474,3 @@ def check_pgpass_credentials(config): else: print("Credentials found in $HOME/.pgpass") - -#NOTE: For some reason when I try to add function decorators to this function -#click thinks it's an argument and tries to parse it but it errors since a function -#isn't an iterable. -@cli.command("reset-repo-age") -def reset_repo_age(): - - with DatabaseSession(logger) as session: - update_query = ( - update(Repo) - .values(repo_added=datetime.now()) - ) - - session.execute(update_query) - session.commit() From 6cafc46bb9057839df3c646920de3acd72e29263 Mon Sep 17 00:00:00 2001 From: Andrew Brain Date: Tue, 27 Feb 2024 07:44:49 -0600 Subject: [PATCH 3/3] remove object super class from DatabaseContext --- augur/application/cli/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/augur/application/cli/__init__.py b/augur/application/cli/__init__.py index 67f85c4d09..e07e880bd9 100644 --- a/augur/application/cli/__init__.py +++ b/augur/application/cli/__init__.py @@ -74,7 +74,7 @@ def new_func(ctx, *args, **kwargs): return update_wrapper(new_func, function_db_connection) -class DatabaseContext(object): +class DatabaseContext(): def __init__(self): self.engine = None