Merge pull request #2782 from GaryPWhite/cli_load_improvements
handle rate limiting and more repo load messages
sgoggins authored May 6, 2024
2 parents 77388ca + 154cd14 commit 23388e4
Showing 3 changed files with 81 additions and 47 deletions.
25 changes: 17 additions & 8 deletions augur/application/cli/__init__.py
@@ -3,9 +3,9 @@
from functools import update_wrapper
import os
import sys
import socket
import re
import json
import httpx

from augur.application.db.engine import DatabaseEngine
from augur.application.db import get_engine, dispose_database_engine
@@ -16,13 +16,22 @@ def test_connection(function_internet_connection):
@click.pass_context
def new_func(ctx, *args, **kwargs):
usage = re.search(r"Usage:\s(.*)\s\[OPTIONS\]", str(ctx.get_usage())).groups()[0]
try:
#try to ping google's dns server
socket.create_connection(("8.8.8.8",53))
return ctx.invoke(function_internet_connection, *args, **kwargs)
except OSError as e:
print(e)
print(f"\n\n{usage} command setup failed\nYou are not connect to the internet. Please connect to the internet to run Augur\n")
with httpx.Client() as client:
try:
_ = client.request(
method="GET", url="http://chaoss.community", timeout=10, follow_redirects=True)

return ctx.invoke(function_internet_connection, *args, **kwargs)
except (TimeoutError, httpx.TimeoutException):
print("Request timed out.")
except httpx.NetworkError as e:
print(f"Network Error: {e}")
except httpx.ProtocolError as e:
print(f"Protocol Error: {e}")
print(f"\n\n{usage} command setup failed\n \
You are not connected to the internet.\n \
Please connect to the internet to run Augur\n \
Consider setting http_proxy variables for limited access installations.")
sys.exit()

return update_wrapper(new_func, function_internet_connection)
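
A side note on the new check: httpx.Client() trusts the environment by default, so exporting the standard http_proxy/https_proxy variables (as the new error message suggests) is enough for the probe to work behind a proxy. A minimal standalone sketch of the same pattern, with an illustrative helper name:

    import httpx

    # Illustrative sketch only, not the decorator above; trust_env=True is the
    # httpx default, so http_proxy/https_proxy environment variables are honoured.
    def has_internet(url: str = "http://chaoss.community", timeout: float = 10.0) -> bool:
        try:
            with httpx.Client(trust_env=True) as client:
                client.get(url, timeout=timeout, follow_redirects=True)
            return True
        except (httpx.TimeoutException, httpx.NetworkError, httpx.ProtocolError):
            return False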
82 changes: 43 additions & 39 deletions augur/application/cli/db.py
@@ -14,7 +14,12 @@
import re
import stat as stat_module

from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext
from augur.application.cli import (
test_connection,
test_db_connection,
with_database,
DatabaseContext,
)

from augur.application.db.session import DatabaseSession
from sqlalchemy import update
@@ -23,8 +28,9 @@

logger = logging.getLogger(__name__)


@click.group("db", short_help="Database utilities")
@click.pass_context
def cli(ctx):
ctx.obj = DatabaseContext()

@@ -36,36 +42,43 @@ def cli(ctx):
@with_database
@click.pass_context
def add_repos(ctx, filename):
"""Add repositories to Augur's database.
"""Add repositories to Augur's database.
The .csv file format should be repo_url,group_id
NOTE: The Group ID must already exist in the REPO_Groups Table.
If you want to add an entire GitHub organization, refer to the command: augur db add-github-org"""
from augur.tasks.github.util.github_task_session import GithubTaskSession
from augur.util.repo_load_controller import RepoLoadController

with GithubTaskSession(logger, engine=ctx.obj.engine) as session:

controller = RepoLoadController(session)

line_total = len(open(filename).readlines())
with open(filename) as upload_repos_file:
data = csv.reader(upload_repos_file, delimiter=",")
for row in data:

for line_num, row in enumerate(data):
repo_data = {}
repo_data["url"] = row[0]
try:
repo_data["repo_group_id"] = int(row[1])
except ValueError:
print(f"Invalid repo group_id: {row[1]} for Git url: `{repo_data['url']}`")
print(
f"Invalid repo group_id: {row[1]} for Git url: `{repo_data['url']}`"
)
continue

print(
f"Inserting repo with Git URL `{repo_data['url']}` into repo group {repo_data['repo_group_id']}")
controller.add_cli_repo(repo_data)

print(
f"Inserting repo {line_num}/{line_total} with Git URL `{repo_data['url']}` into repo group {repo_data['repo_group_id']}"
)

succeeded, message = controller.add_cli_repo(repo_data)
if not succeeded:
logger.error(f"insert repo failed with error: {message['status']}`")
else:
logger.info(f"Repo added: {repo_data}")
print("Success")


@cli.command("get-repo-groups")
@@ -101,7 +114,6 @@ def add_repo_groups(ctx, filename):
Create new repo groups in Augur's database
"""
with ctx.obj.engine.begin() as connection:

df = pd.read_sql(
s.sql.text("SELECT repo_group_id FROM augur_data.repo_groups"),
connection,
@@ -117,7 +129,6 @@ def add_repo_groups(ctx, filename):
with open(filename) as create_repo_groups_file:
data = csv.reader(create_repo_groups_file, delimiter=",")
for row in data:

# Handle case where there's a hanging empty row.
if not row:
logger.info("Skipping empty data...")
@@ -137,6 +148,7 @@
f"Repo group with ID {row[1]} for repo group {row[1]} already exists, skipping..."
)


@cli.command("add-github-org")
@click.argument("organization_name")
@test_connection
@@ -151,29 +163,26 @@ def add_github_org(ctx, organization_name):
from augur.util.repo_load_controller import RepoLoadController

with GithubTaskSession(logger, engine=ctx.obj.engine) as session:

controller = RepoLoadController(session)

controller.add_cli_org(organization_name)


# get_db_version is a helper function to print_db_version and upgrade_db_version
def get_db_version(engine):

db_version_sql = s.sql.text(
"""
SELECT * FROM augur_operations.augur_settings WHERE setting = 'augur_data_version'
"""
)

with engine.connect() as connection:

result = int(connection.execute(db_version_sql).fetchone()[2])

engine.dispose()
return result



@cli.command("print-db-version")
@test_connection
@test_db_connection
@@ -252,10 +261,10 @@ def update_api_key(ctx, api_key):
)

with ctx.obj.engine.begin() as connection:

connection.execute(update_api_key_sql, api_key=api_key)
logger.info(f"Updated Augur API key to: {api_key}")


@cli.command("get-api-key")
@test_connection
@test_db_connection
@@ -282,36 +291,35 @@ def get_api_key(ctx):
def check_pgpass():
augur_db_env_var = getenv("AUGUR_DB")
if augur_db_env_var:

# gets the user, password, host, port, and database_name out of environment variable
# assumes database string of structure <beginning_of_db_string>//<user>:<password>@<host>:<port>/<database_name>
# it returns a tuple like (<user>, <password>, <host>, <port>, <database_name)
db_string_parsed = re.search(r"^.+:\/\/([a-zA-Z0-9_]+):(.+)@([a-zA-Z0-9-_~\.]+):(\d{1,5})\/([a-zA-Z0-9_-]+)", augur_db_env_var).groups()
db_string_parsed = re.search(
r"^.+:\/\/([a-zA-Z0-9_]+):(.+)@([a-zA-Z0-9-_~\.]+):(\d{1,5})\/([a-zA-Z0-9_-]+)",
augur_db_env_var,
).groups()

if db_string_parsed:

db_config = {
"user": db_string_parsed[0],
"password": db_string_parsed[1],
"host": db_string_parsed[2],
"host": db_string_parsed[2],
"port": db_string_parsed[3],
"database_name": db_string_parsed[4]
"database_name": db_string_parsed[4],
}

check_pgpass_credentials(db_config)

else:
print("Database string is invalid and cannot be used")


else:
with open("db.config.json", "r") as f:
with open("db.config.json", "r") as f:
config = json.load(f)
print(f"Config: {config}")
check_pgpass_credentials(config)
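
For illustration, the connection-string regex in check_pgpass() parses an AUGUR_DB value like the following (made-up credentials):

    import re

    # Hypothetical AUGUR_DB value; the pattern is the one used above.
    augur_db = "postgresql+psycopg2://augur:mypassword@127.0.0.1:5432/augur"
    user, password, host, port, database_name = re.search(
        r"^.+:\/\/([a-zA-Z0-9_]+):(.+)@([a-zA-Z0-9-_~\.]+):(\d{1,5})\/([a-zA-Z0-9_-]+)",
        augur_db,
    ).groups()
    # -> ('augur', 'mypassword', '127.0.0.1', '5432', 'augur')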



@cli.command("init-database")
@click.option("--default-db-name", default="postgres")
@click.option("--default-user", default="postgres")
@@ -370,22 +378,20 @@ def init_database(
f"GRANT ALL PRIVILEGES ON DATABASE {target_db_name} TO {target_user};",
)


@cli.command("reset-repo-age")
@test_connection
@test_db_connection
@with_database
@click.pass_context
def reset_repo_age(ctx):

with DatabaseSession(logger, engine=ctx.obj.engine) as session:
update_query = (
update(Repo)
.values(repo_added=datetime.now())
)
update_query = update(Repo).values(repo_added=datetime.now())

session.execute(update_query)
session.commit()


@cli.command("test-connection")
@test_connection
@test_db_connection
@@ -406,14 +412,13 @@ def run_psql_command_in_database(target_type, target):

if augur_db_environment_var:
pass
#TODO: Add functionality for environment variable
# TODO: Add functionality for environment variable
else:
with open("db.config.json", 'r') as f:
with open("db.config.json", "r") as f:
db_config = json.load(f)

host = db_config['host']
database_name = db_config['database_name']

host = db_config["host"]
database_name = db_config["database_name"]

db_conn_string = f"postgresql+psycopg2://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database_name']}"
engine = s.create_engine(db_conn_string)
@@ -442,7 +447,7 @@ def check_pgpass_credentials(config):

if not path.isfile(pgpass_file_path):
print("~/.pgpass does not exist, creating.")
with open(pgpass_file_path, "w+",encoding="utf-8") as _:
with open(pgpass_file_path, "w+", encoding="utf-8") as _:
chmod(pgpass_file_path, stat_module.S_IWRITE | stat_module.S_IREAD)

pgpass_file_mask = oct(os.stat(pgpass_file_path).st_mode & 0o777)
@@ -451,7 +456,7 @@ def check_pgpass_credentials(config):
print("Updating ~/.pgpass file permissions.")
chmod(pgpass_file_path, stat_module.S_IWRITE | stat_module.S_IREAD)

with open(pgpass_file_path, "a+",encoding="utf-8") as pgpass_file:
with open(pgpass_file_path, "a+", encoding="utf-8") as pgpass_file:
end = pgpass_file.tell()
pgpass_file.seek(0)

@@ -475,4 +480,3 @@ def check_pgpass_credentials(config):
pgpass_file.write(credentials_string + "\n")
else:
print("Credentials found in $HOME/.pgpass")

21 changes: 21 additions & 0 deletions augur/application/db/models/augur_data.py
@@ -22,6 +22,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.sql import text
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from time import sleep, mktime, gmtime, time, localtime
import logging
import re
import json
@@ -924,6 +925,17 @@ def is_valid_github_repo(gh_session, url: str) -> bool:
continue

data = result.json()
if result.status_code == 403: #GH Rate limiting
wait_until = int(result.headers.get("x-ratelimit-reset"))
# use time package to find how many seconds to wait
wait_in_seconds = int(
mktime(gmtime(wait_until)) -
mktime(gmtime(time()))
)
wait_until_time = localtime(wait_until)
logger.error(f"rate limited fetching {url}z")
logger.error(f"sleeping until {wait_until_time.tm_hour}:{wait_until_time.tm_min} ({wait_in_seconds} seconds)")
sleep(wait_in_seconds)
# if there was an error return False
if "message" in data.keys():

@@ -934,6 +946,8 @@

return True, {"status": "Valid repo", "repo_type": data["owner"]["type"]}

return False, {"status": "Failed to validate repo after multiple attempts"}

@staticmethod
def is_valid_gitlab_repo(gl_session, url: str) -> bool:
"""Determine whether a GitLab repo URL is valid.
@@ -961,13 +975,20 @@ def is_valid_gitlab_repo(gl_session, url: str) -> bool:
while attempts < 10:
response = hit_api(gl_session.oauths, url, logger)

if (wait_in_seconds := response.headers.get("Retry-After")) is not None:
logger.info(f"rate limited fetching {url}, sleeping for {wait_in_seconds}")
print(f"rate limited fetching {url}, sleeping for {wait_in_seconds}")
sleep(int(wait_in_seconds))

if response.status_code == 404:
return False, {"status": "Invalid repo"}

if response.status_code == 200:
return True, {"status": "Valid repo"}

attempts += 1
logger.info(f"could not validate {url}, will attempt again in {attempts*5} seconds")
sleep(attempts*3)

return False, {"status": "Failed to validate repo after multiple attempts"}
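
A quick sanity check of the wait computation in the new GitHub 403 branch: mktime(gmtime(x)) shifts x by the local UTC offset, but the shift is the same for both timestamps, so the subtraction still gives roughly the number of seconds until the reset. With made-up values:

    from time import gmtime, mktime, time

    # Pretend x-ratelimit-reset points 15 minutes into the future.
    reset_epoch = int(time()) + 900
    wait_in_seconds = int(mktime(gmtime(reset_epoch)) - mktime(gmtime(time())))
    print(wait_in_seconds)  # ~900, i.e. effectively reset_epoch - now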

