Skip to content

Commit

Permalink
Branch selection for sync command (#73), better cli help, improved db…
Browse files Browse the repository at this point in the history
… rollback handling on error
  • Loading branch information
soad003 committed Feb 14, 2023
1 parent a091711 commit 114bebd
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 22 deletions.
43 changes: 30 additions & 13 deletions src/tagpack/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import time
from argparse import ArgumentParser
from functools import partial
from multiprocessing import Pool, cpu_count

import pandas as pd
Expand Down Expand Up @@ -755,25 +756,31 @@ def sync_repos(args):

for repo_url in repos:
repo_url = repo_url.strip()
print_line(f"Syncing {repo_url}. Temp files in: {temp_dir_tt}")
print(f"Syncing {repo_url}. Temp files in: {temp_dir_tt}")

try:
print_line("Cloning...")
Repo.clone_from(repo_url, temp_dir_tt)

print_line("Inserting actorpacks ...")
print_info("Cloning...")
repo_url, *branch = repo_url.split(" ")
repo = Repo.clone_from(repo_url, temp_dir_tt)
if len(branch) > 0:
branch = branch[0]
print_info(f"Using branch {branch}")
repo.git.checkout(branch)

print("Inserting actorpacks ...")
exec_cli_command(["actorpack", "insert", "--add_new", temp_dir_tt])

print_line("Inserting tagpacks ...")
print("Inserting tagpacks ...")
exec_cli_command(["tagpack", "insert", "--add_new", temp_dir_tt])
finally:
print_line(f"Removing temp files in: {temp_dir_tt}")
rmtree(temp_dir_tt)
if os.path.isdir(temp_dir_tt):
print_info(f"Removing temp files in: {temp_dir_tt}")
rmtree(temp_dir_tt)

print_line("Removeing duplicates ...")
print("Removing duplicates ...")
exec_cli_command(["tagstore", "remove_duplicates"])

print_line("Refreshing db views ...")
print("Refreshing db views ...")
exec_cli_command(["tagstore", "refresh_views"])

print_success("Your tagstore is now up-to-date again.")
Expand All @@ -792,6 +799,16 @@ def main():
get_version()
),
)

def set_print_help_on_error(parser):
    """Register a fallback ``func`` default on *parser*.

    The fallback prints the help text of *parser* followed by a failure
    message; it is what ``args.func`` resolves to unless another parser
    overrides the ``func`` default.
    """

    def print_help_subparser(subparser, args):
        # `args` is accepted only to match the `func(args)` calling
        # convention; it is not used here.
        subparser.print_help()
        print_fail("No action was requested. Please use as specified above.")

    # Bind the parser now so the stored callable only needs `args` later.
    fallback = partial(print_help_subparser, parser)
    parser.set_defaults(func=fallback)

set_print_help_on_error(parser)

parser.add_argument("-v", "--version", action="version", version=show_version())
parser.add_argument(
"--config",
Expand Down Expand Up @@ -825,6 +842,7 @@ def main():

# parsers for tagpack command
parser_tp = subparsers.add_parser("tagpack", help="tagpack commands")
# Register the help fallback on the tagpack sub-parser itself (not on the
# root parser), matching how the actorpack and tagstore sub-parsers are set
# up, so `tagpack` without an action prints the tagpack-specific help.
set_print_help_on_error(parser_tp)

ptp = parser_tp.add_subparsers(title="TagPack commands")

Expand Down Expand Up @@ -932,6 +950,7 @@ def main():

# parsers for actorpack command
parser_ap = subparsers.add_parser("actorpack", help="actorpack commands")
set_print_help_on_error(parser_ap)

app = parser_ap.add_subparsers(title="ActorPack commands")

Expand Down Expand Up @@ -1099,6 +1118,7 @@ def main():

# parsers for database housekeeping
parser_db = subparsers.add_parser("tagstore", help="database housekeeping commands")
set_print_help_on_error(parser_db)

pdp = parser_db.add_subparsers(title="TagStore commands")

Expand Down Expand Up @@ -1276,9 +1296,6 @@ def main():
print_warn(url_msg)
parser.error("No postgresql URL connection was provided. Exiting.")

if not hasattr(args, "func"):
parser.error("No action was requested. Exiting.")

args.func(args)


Expand Down
51 changes: 42 additions & 9 deletions src/tagpack/tagstore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from functools import wraps

import numpy as np
from cashaddress.convert import to_legacy_address
Expand All @@ -12,6 +13,29 @@
register_adapter(np.int64, AsIs)


def auto_commit(function):
    """Decorate a method so its work is committed or rolled back.

    The wrapped callable must be a method of an object exposing a DB
    connection as ``self.conn``. If the method returns normally the
    transaction is committed; if it raises an ``Exception`` the transaction
    is rolled back and the original exception is re-raised unchanged.

    Without the rollback the connection is left in an aborted state and
    rejects further commands, see
    https://stackoverflow.com/questions/2979369/databaseerror-current-transaction-is-aborted-commands-ignored-until-end-of-tra
    """

    @wraps(function)
    def wrapper(self, *args, **kwargs):
        try:
            output = function(self, *args, **kwargs)
        except Exception:
            self.conn.rollback()
            # Bare raise keeps the original traceback intact.
            raise
        else:
            # Commit only on success: committing in a `finally` clause would
            # also run after a rollback (where a failing commit could mask
            # the real error) and on interrupts that are not `Exception`s,
            # which could persist partially written data.
            self.conn.commit()
        return output

    return wrapper


class TagStore(object):
def __init__(self, url, schema):
self.conn = connect(url, options=f"-c search_path={schema}")
Expand All @@ -22,6 +46,7 @@ def __init__(self, url, schema):
self.existing_packs = None
self.existing_actorpacks = None

@auto_commit
def insert_taxonomy(self, taxonomy):
if taxonomy.key == "confidence":
self.insert_confidence_scores(taxonomy)
Expand All @@ -39,8 +64,9 @@ def insert_taxonomy(self, taxonomy):
v = (c.id, c.label, c.taxonomy.key, c.uri, c.description)
self.cursor.execute(statement, v)

self.conn.commit()
# self.conn.commit()

@auto_commit
def insert_confidence_scores(self, confidence):
statement = "INSERT INTO confidence (id, label, description, level)"
statement += " VALUES (%s, %s, %s, %s)"
Expand All @@ -49,7 +75,7 @@ def insert_confidence_scores(self, confidence):
values = (c.id, c.label, c.description, c.level)
self.cursor.execute(statement, values)

self.conn.commit()
# self.conn.commit()

def tp_exists(self, prefix, rel_path):
if not self.existing_packs:
Expand All @@ -59,6 +85,7 @@ def tp_exists(self, prefix, rel_path):
def create_id(self, prefix, rel_path):
    """Return ``"<prefix>:<rel_path>"``, or just *rel_path* when *prefix* is falsy."""
    if not prefix:
        return rel_path
    return ":".join((prefix, rel_path))

@auto_commit
def insert_tagpack(
self, tagpack, is_public, force_insert, prefix, rel_path, batch=1000
):
Expand Down Expand Up @@ -109,7 +136,7 @@ def insert_tagpack(
execute_batch(self.cursor, addr_sql, address_data)
execute_batch(self.cursor, tag_sql, tag_data)

self.conn.commit()
# self.conn.commit()

def actorpack_exists(self, prefix, actorpack_name):
if not self.existing_actorpacks:
Expand All @@ -124,6 +151,7 @@ def get_ingested_actorpacks(self) -> list:
self.cursor.execute("SELECT id from actorpack")
return [i[0] for i in self.cursor.fetchall()]

@auto_commit
def insert_actorpack(
self, actorpack, is_public, force_insert, prefix, rel_path, batch=1000
):
Expand Down Expand Up @@ -180,7 +208,7 @@ def insert_actorpack(
execute_batch(self.cursor, act_cat_sql, cat_data)
execute_batch(self.cursor, act_jur_sql, jur_data)

self.conn.commit()
# self.conn.commit()

def low_quality_address_labels(self, th=0.25, currency="", category="") -> dict:
"""
Expand Down Expand Up @@ -253,6 +281,7 @@ def remove_duplicates(self):
self.conn.commit()
return self.cursor.rowcount

@auto_commit
def refresh_db(self):
self.cursor.execute("REFRESH MATERIALIZED VIEW label")
self.cursor.execute("REFRESH MATERIALIZED VIEW statistics")
Expand All @@ -261,7 +290,7 @@ def refresh_db(self):
"REFRESH MATERIALIZED VIEW "
"cluster_defining_tags_by_frequency_and_maxconfidence"
) # noqa
self.conn.commit()
# self.conn.commit()

def get_addresses(self, update_existing):
if update_existing:
Expand Down Expand Up @@ -298,6 +327,7 @@ def get_tagstore_composition(self, by_currency=False):
for record in self.cursor:
yield record

@auto_commit
def insert_cluster_mappings(self, clusters):
if not clusters.empty:
q = "INSERT INTO address_cluster_mapping (address, currency, \
Expand All @@ -317,16 +347,17 @@ def insert_cluster_mappings(self, clusters):
data = clusters[cols].to_records(index=False)

execute_batch(self.cursor, q, data)
self.conn.commit()
# self.conn.commit()

def _supports_currency(self, tag):
    """Tell whether the tag's ``currency`` field is in ``self.supported_currencies``."""
    # A missing "currency" field yields None, which is then looked up like
    # any other value.
    currency = tag.all_fields.get("currency")
    return currency in self.supported_currencies

@auto_commit
def finish_mappings_update(self, keys):
    """Flag not-yet-mapped addresses of the given currencies as mapped.

    Sets ``is_mapped=true`` on every row of ``address`` that is not mapped
    yet and whose ``currency`` is one of *keys*. Committing (or rolling
    back on error) is left to the ``auto_commit`` decorator, so no explicit
    commit is issued here.

    :param keys: iterable of currency identifiers; an empty iterable is a
        no-op.
    """
    currencies = tuple(keys)
    if not currencies:
        # An empty tuple would be rendered as "IN ()", which PostgreSQL
        # rejects as a syntax error; there is nothing to update anyway.
        return

    q = "UPDATE address SET is_mapped=true WHERE NOT is_mapped \
        AND currency IN %s"
    self.cursor.execute(q, (currencies,))

def get_ingested_tagpacks(self) -> list:
self.cursor.execute("SELECT id from tagpack")
Expand Down Expand Up @@ -400,6 +431,7 @@ def list_address_actors(self, currency=""):
self.cursor.execute(q, v)
return self.cursor.fetchall()

@auto_commit
def update_tags_actors(self):
"""
Update the `tag.actor` field by searching an actor.id that matches with
Expand Down Expand Up @@ -432,9 +464,10 @@ def update_tags_actors(self):
"AND (label ILIKE 'okex%' OR label ILIKE 'okb%')"
)
self.cursor.execute(q)
self.conn.commit()
# self.conn.commit()
return rowcount

@auto_commit
def update_quality_actors(self):
"""
Update all entries in `address_quality` having a unique actor in table
Expand Down Expand Up @@ -470,7 +503,7 @@ def update_quality_actors(self):
)
self.cursor.execute(q)
rowcount = self.cursor.rowcount
self.conn.commit()
# self.conn.commit()
return rowcount


Expand Down

0 comments on commit 114bebd

Please sign in to comment.