From 114bebd96004b677b15459645c765018d0c65556 Mon Sep 17 00:00:00 2001 From: Michael F Date: Tue, 14 Feb 2023 09:44:08 +0100 Subject: [PATCH] Branch selection for sync command (#73), better cli help, improved db rollback handling on error --- src/tagpack/cli.py | 43 +++++++++++++++++++++++----------- src/tagpack/tagstore.py | 51 +++++++++++++++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/src/tagpack/cli.py b/src/tagpack/cli.py index 3623960..b3311ba 100644 --- a/src/tagpack/cli.py +++ b/src/tagpack/cli.py @@ -4,6 +4,7 @@ import tempfile import time from argparse import ArgumentParser +from functools import partial from multiprocessing import Pool, cpu_count import pandas as pd @@ -755,25 +756,31 @@ def sync_repos(args): for repo_url in repos: repo_url = repo_url.strip() - print_line(f"Syncing {repo_url}. Temp files in: {temp_dir_tt}") + print(f"Syncing {repo_url}. Temp files in: {temp_dir_tt}") try: - print_line("Cloning...") - Repo.clone_from(repo_url, temp_dir_tt) - - print_line("Inserting actorpacks ...") + print_info("Cloning...") + repo_url, *branch = repo_url.split(" ") + repo = Repo.clone_from(repo_url, temp_dir_tt) + if len(branch) > 0: + branch = branch[0] + print_info(f"Using branch {branch}") + repo.git.checkout(branch) + + print("Inserting actorpacks ...") exec_cli_command(["actorpack", "insert", "--add_new", temp_dir_tt]) - print_line("Inserting tagpacks ...") + print("Inserting tagpacks ...") exec_cli_command(["tagpack", "insert", "--add_new", temp_dir_tt]) finally: - print_line(f"Removing temp files in: {temp_dir_tt}") - rmtree(temp_dir_tt) + if os.path.isdir(temp_dir_tt): + print_info(f"Removing temp files in: {temp_dir_tt}") + rmtree(temp_dir_tt) - print_line("Removeing duplicates ...") + print("Removing duplicates ...") exec_cli_command(["tagstore", "remove_duplicates"]) - print_line("Refreshing db views ...") + print("Refreshing db views ...") exec_cli_command(["tagstore", "refresh_views"]) print_success("Your tagstore is now up-to-date again.") @@ -792,6 +799,16 @@ def main(): get_version() ), ) + + def set_print_help_on_error(parser): + def print_help_subparser(subparser, args): + subparser.print_help() + print_fail("No action was requested. Please use as specified above.") + + parser.set_defaults(func=partial(print_help_subparser, parser)) + + set_print_help_on_error(parser) + parser.add_argument("-v", "--version", action="version", version=show_version()) parser.add_argument( "--config", @@ -825,6 +842,7 @@ def main(): # parsers for tagpack command parser_tp = subparsers.add_parser("tagpack", help="tagpack commands") + set_print_help_on_error(parser) ptp = parser_tp.add_subparsers(title="TagPack commands") @@ -932,6 +950,7 @@ def main(): # parsers for actorpack command parser_ap = subparsers.add_parser("actorpack", help="actorpack commands") + set_print_help_on_error(parser_ap) app = parser_ap.add_subparsers(title="ActorPack commands") @@ -1099,6 +1118,7 @@ def main(): # parsers for database housekeeping parser_db = subparsers.add_parser("tagstore", help="database housekeeping commands") + set_print_help_on_error(parser_db) pdp = parser_db.add_subparsers(title="TagStore commands") @@ -1276,9 +1296,6 @@ def main(): print_warn(url_msg) parser.error("No postgresql URL connection was provided. Exiting.") - if not hasattr(args, "func"): - parser.error("No action was requested. Exiting.") - args.func(args) diff --git a/src/tagpack/tagstore.py b/src/tagpack/tagstore.py index 97d3b7b..1af7f7d 100644 --- a/src/tagpack/tagstore.py +++ b/src/tagpack/tagstore.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from datetime import datetime +from functools import wraps import numpy as np from cashaddress.convert import to_legacy_address @@ -12,6 +13,29 @@ register_adapter(np.int64, AsIs) +def auto_commit(function): + @wraps(function) + def wrapper(*args, **kwargs): + """ + Automatically calls commit at the end of a function or + rollback if an error occurs. If rollback is not execute + int leaves the connection in a broken state. + https://stackoverflow.com/questions/2979369/databaseerror-current-transaction-is-aborted-commands-ignored-until-end-of-tra + """ + self, *_ = args + try: + output = function(*args, **kwargs) + except Exception as e: + # self.cursor.execute("rollback") + self.conn.rollback() + raise e + finally: + self.conn.commit() + return output + + return wrapper + + class TagStore(object): def __init__(self, url, schema): self.conn = connect(url, options=f"-c search_path={schema}") @@ -22,6 +46,7 @@ def __init__(self, url, schema): self.existing_packs = None self.existing_actorpacks = None + @auto_commit def insert_taxonomy(self, taxonomy): if taxonomy.key == "confidence": self.insert_confidence_scores(taxonomy) @@ -39,8 +64,9 @@ def insert_taxonomy(self, taxonomy): v = (c.id, c.label, c.taxonomy.key, c.uri, c.description) self.cursor.execute(statement, v) - self.conn.commit() + # self.conn.commit() + @auto_commit def insert_confidence_scores(self, confidence): statement = "INSERT INTO confidence (id, label, description, level)" statement += " VALUES (%s, %s, %s, %s)" @@ -49,7 +75,7 @@ def insert_confidence_scores(self, confidence): values = (c.id, c.label, c.description, c.level) self.cursor.execute(statement, values) - self.conn.commit() + # self.conn.commit() def tp_exists(self, prefix, rel_path): if not self.existing_packs: @@ -59,6 +85,7 @@ def tp_exists(self, prefix, rel_path): def create_id(self, prefix, rel_path): return ":".join([prefix, rel_path]) if prefix else rel_path + @auto_commit def insert_tagpack( self, tagpack, is_public, force_insert, prefix, rel_path, batch=1000 ): @@ -109,7 +136,7 @@ def insert_tagpack( execute_batch(self.cursor, addr_sql, address_data) execute_batch(self.cursor, tag_sql, tag_data) - self.conn.commit() + # self.conn.commit() def actorpack_exists(self, prefix, actorpack_name): if not self.existing_actorpacks: @@ -124,6 +151,7 @@ def get_ingested_actorpacks(self) -> list: self.cursor.execute("SELECT id from actorpack") return [i[0] for i in self.cursor.fetchall()] + @auto_commit def insert_actorpack( self, actorpack, is_public, force_insert, prefix, rel_path, batch=1000 ): @@ -180,7 +208,7 @@ def insert_actorpack( execute_batch(self.cursor, act_cat_sql, cat_data) execute_batch(self.cursor, act_jur_sql, jur_data) - self.conn.commit() + # self.conn.commit() def low_quality_address_labels(self, th=0.25, currency="", category="") -> dict: """ @@ -253,6 +281,7 @@ def remove_duplicates(self): self.conn.commit() return self.cursor.rowcount + @auto_commit def refresh_db(self): self.cursor.execute("REFRESH MATERIALIZED VIEW label") self.cursor.execute("REFRESH MATERIALIZED VIEW statistics") @@ -261,7 +290,7 @@ def refresh_db(self): "REFRESH MATERIALIZED VIEW " "cluster_defining_tags_by_frequency_and_maxconfidence" ) # noqa - self.conn.commit() + # self.conn.commit() def get_addresses(self, update_existing): if update_existing: @@ -298,6 +327,7 @@ def get_tagstore_composition(self, by_currency=False): for record in self.cursor: yield record + @auto_commit def insert_cluster_mappings(self, clusters): if not clusters.empty: q = "INSERT INTO address_cluster_mapping (address, currency, \ @@ -317,16 +347,17 @@ def insert_cluster_mappings(self, clusters): data = clusters[cols].to_records(index=False) execute_batch(self.cursor, q, data) - self.conn.commit() + # self.conn.commit() def _supports_currency(self, tag): return tag.all_fields.get("currency") in self.supported_currencies + @auto_commit def finish_mappings_update(self, keys): q = "UPDATE address SET is_mapped=true WHERE NOT is_mapped \ AND currency IN %s" self.cursor.execute(q, (tuple(keys),)) - self.conn.commit() + # self.conn.commit() def get_ingested_tagpacks(self) -> list: self.cursor.execute("SELECT id from tagpack") @@ -400,6 +431,7 @@ def list_address_actors(self, currency=""): self.cursor.execute(q, v) return self.cursor.fetchall() + @auto_commit def update_tags_actors(self): """ Update the `tag.actor` field by searching an actor.id that matches with @@ -432,9 +464,10 @@ def update_tags_actors(self): "AND (label ILIKE 'okex%' OR label ILIKE 'okb%')" ) self.cursor.execute(q) - self.conn.commit() + # self.conn.commit() return rowcount + @auto_commit def update_quality_actors(self): """ Update all entries in `address_quality` having a unique actor in table @@ -470,7 +503,7 @@ def update_quality_actors(self): ) self.cursor.execute(q) rowcount = self.cursor.rowcount - self.conn.commit() + # self.conn.commit() return rowcount