Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User Zero / Super Admin #1021

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions api/api_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import faulthandler
faulthandler.enable(file=sys.__stderr__) # will catch segfaults and write to stderr

from urllib.parse import urlparse

from collections import OrderedDict
from functools import cache
from html import escape as html_escape
Expand Down Expand Up @@ -153,7 +151,7 @@ def get_machine_list():

return DB().fetch_all(query)

def get_run_info(run_id):
def get_run_info(user, run_id):
query = """
SELECT
id, name, uri, branch, commit_hash,
Expand All @@ -163,20 +161,22 @@ def get_run_info(run_id):
measurement_config, machine_specs, machine_id, usage_scenario,
created_at, invalid_run, phases, logs, failed
FROM runs
WHERE id = %s
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
AND id = %s
"""
params = (run_id,)
params = (user.is_super_user(), user.visible_users(), run_id)
return DB().fetch_one(query, params=params, fetch_mode='dict')

def get_timeline_query(uri, filename, machine_id, branch, metrics, phase, start_date=None, end_date=None, detail_name=None, limit_365=False, sorting='run'):
def get_timeline_query(user, uri, filename, machine_id, branch, metrics, phase, start_date=None, end_date=None, detail_name=None, limit_365=False, sorting='run'):

if filename is None or filename.strip() == '':
filename = 'usage_scenario.yml'

if branch is None or branch.strip() == '':
branch = 'main'

params = [uri, filename, branch, machine_id, f"%{phase}"]
params = [user.is_super_user(), user.visible_users(), uri, filename, branch, machine_id, f"%{phase}"]

metrics_condition = ''
if metrics is None or metrics.strip() == '' or metrics.strip() == 'key':
Expand Down Expand Up @@ -217,7 +217,8 @@ def get_timeline_query(uri, filename, machine_id, branch, metrics, phase, start_
LEFT JOIN phase_stats as p ON
r.id = p.run_id
WHERE
r.uri = %s
(TRUE = %s OR r.user_id = ANY(%s::int[]))
AND r.uri = %s
AND r.filename = %s
AND r.branch = %s
AND r.end_measurement IS NOT NULL
Expand All @@ -239,17 +240,20 @@ def get_timeline_query(uri, filename, machine_id, branch, metrics, phase, start_

return (query, params)

def get_comparison_details(ids, comparison_db_key):
def get_comparison_details(user, ids, comparison_db_key):

query = sql.SQL('''
SELECT
id, name, created_at, commit_hash, commit_timestamp, gmt_hash, {}
FROM runs
WHERE id = ANY(%s::uuid[])
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
AND id = ANY(%s::uuid[])
ORDER BY created_at -- Must be same order as in get_phase_stats
''').format(sql.Identifier(comparison_db_key))

data = DB().fetch_all(query, (ids, ))
params = (user.is_super_user(), user.visible_users(), ids)
data = DB().fetch_all(query, params=params)
if data is None or data == []:
raise RuntimeError('Could not get comparison details')

Expand All @@ -276,21 +280,23 @@ def get_comparison_details(ids, comparison_db_key):

return comparison_details

def determine_comparison_case(ids):
def determine_comparison_case(user, ids):

query = '''
WITH uniques as (
SELECT uri, filename, machine_id, commit_hash, branch FROM runs
WHERE id = ANY(%s::uuid[])
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
AND id = ANY(%s::uuid[])
GROUP BY uri, filename, machine_id, commit_hash, branch
)
SELECT
COUNT(DISTINCT uri ), COUNT(DISTINCT filename), COUNT(DISTINCT machine_id),
COUNT(DISTINCT commit_hash ), COUNT(DISTINCT branch)
FROM uniques
'''

data = DB().fetch_one(query, (ids, ))
params = (user.is_super_user(), user.visible_users(), ids)
data = DB().fetch_one(query, params=params)
if data is None or data == [] or data[1] is None: # special check for data[1] as this is aggregate query which always returns result
raise RuntimeError('Could not determine compare case')

Expand Down Expand Up @@ -376,7 +382,7 @@ def determine_comparison_case(ids):

raise RuntimeError('Could not determine comparison case after checking all conditions')

def get_phase_stats(ids):
def get_phase_stats(user, ids):
query = """
SELECT
a.phase, a.metric, a.detail_name, a.value, a.type, a.max_value, a.min_value, a.unit,
Expand All @@ -387,11 +393,13 @@ def get_phase_stats(ids):
LEFT JOIN machines as c on c.id = b.machine_id

WHERE
a.run_id = ANY(%s::uuid[])
(TRUE = %s OR b.user_id = ANY(%s::int[]))
AND a.run_id = ANY(%s::uuid[])
ORDER BY
b.created_at ASC -- Must be same order as in get_comparison_details
"""
return DB().fetch_all(query, (ids, ))
params = (user.is_super_user(), user.visible_users(), ids)
return DB().fetch_all(query, params=params)

# Would be interesting to know whether, in an application server like gunicorn, @cache
# will also work for subsequent requests ...?
Expand Down Expand Up @@ -695,23 +703,23 @@ def __init__(
)

def authenticate(authentication_token=Depends(header_scheme), request: Request = None):
parsed_url = urlparse(str(request.url))

try:
if not authentication_token or authentication_token.strip() == '': # Note that if no token is supplied this will authenticate as the DEFAULT user, which in FOSS systems has full capabilities
authentication_token = 'DEFAULT'

user = User.authenticate(SecureVariable(authentication_token))

if not user.can_use_route(parsed_url.path):
if not user.can_use_route(request.scope["route"].path):
raise HTTPException(status_code=401, detail="Route not allowed") from UserAuthenticationError

if not user.has_api_quota(parsed_url.path):
if not user.has_api_quota(request.scope["route"].path):
raise HTTPException(status_code=401, detail="Quota exceeded") from UserAuthenticationError

user.deduct_api_quota(parsed_url.path, 1)
user.deduct_api_quota(request.scope["route"].path, 1)

except UserAuthenticationError:
raise HTTPException(status_code=401, detail="Invalid token") from UserAuthenticationError
except UserAuthenticationError as exc:
raise HTTPException(status_code=401, detail=str(exc)) from UserAuthenticationError
return user

def get_connecting_ip(request):
Expand Down
64 changes: 40 additions & 24 deletions api/eco_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ async def post_ci_measurement_add(
return Response(status_code=204)

@router.get('/v1/ci/measurements')
async def get_ci_measurements(repo: str, branch: str, workflow: str, start_date: date, end_date: date):
async def get_ci_measurements(repo: str, branch: str, workflow: str, start_date: date, end_date: date, user: User = Depends(authenticate)):

query = """
query = '''
SELECT energy_uj, run_id, created_at, label, cpu, commit_hash, duration_us, source, cpu_util_avg,
(SELECT workflow_name FROM ci_measurements AS latest_workflow
WHERE latest_workflow.repo = ci_measurements.repo
Expand All @@ -166,12 +166,14 @@ async def get_ci_measurements(repo: str, branch: str, workflow: str, start_date:
lat, lon, city, carbon_intensity_g, carbon_ug
FROM ci_measurements
WHERE
repo = %s AND branch = %s AND workflow_id = %s
(TRUE = %s OR user_id = ANY(%s::int[]))
AND repo = %s AND branch = %s AND workflow_id = %s
AND DATE(created_at) >= TO_DATE(%s, 'YYYY-MM-DD')
AND DATE(created_at) <= TO_DATE(%s, 'YYYY-MM-DD')
ORDER BY run_id ASC, created_at ASC
"""
params = (repo, branch, workflow, str(start_date), str(end_date))
'''

params = (user.is_super_user(), user.visible_users(), repo, branch, workflow, str(start_date), str(end_date))
data = DB().fetch_all(query, params=params)

if data is None or data == []:
Expand All @@ -180,7 +182,7 @@ async def get_ci_measurements(repo: str, branch: str, workflow: str, start_date:
return ORJSONResponse({'success': True, 'data': data})

@router.get('/v1/ci/stats')
async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date, end_date: date):
async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date, end_date: date, user: User = Depends(authenticate)):


query = '''
Expand All @@ -193,7 +195,8 @@ async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date,
SUM(carbon_ug) as e
FROM ci_measurements
WHERE
repo = %s AND branch = %s AND workflow_id = %s
(TRUE = %s OR user_id = ANY(%s::int[]))
AND repo = %s AND branch = %s AND workflow_id = %s
AND DATE(created_at) >= TO_DATE(%s, 'YYYY-MM-DD') AND DATE(created_at) <= TO_DATE(%s, 'YYYY-MM-DD')
GROUP BY run_id
) SELECT
Expand All @@ -206,7 +209,8 @@ async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date,
COUNT(*)
FROM my_table;
'''
params = (repo, branch, workflow, str(start_date), str(end_date))

params = (user.is_super_user(), user.visible_users(), repo, branch, workflow, str(start_date), str(end_date))
totals_data = DB().fetch_one(query, params=params)

if totals_data is None or totals_data[0] is None: # aggregate query always returns row
Expand All @@ -224,11 +228,12 @@ async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date,
COUNT(*), label
FROM ci_measurements
WHERE
repo = %s AND branch = %s AND workflow_id = %s
(TRUE = %s OR user_id = ANY(%s::int[]))
AND repo = %s AND branch = %s AND workflow_id = %s
AND DATE(created_at) >= TO_DATE(%s, 'YYYY-MM-DD') AND DATE(created_at) <= TO_DATE(%s, 'YYYY-MM-DD')
GROUP BY label
'''
params = (repo, branch, workflow, str(start_date), str(end_date))
params = (user.is_super_user(), user.visible_users(), repo, branch, workflow, str(start_date), str(end_date))
per_label_data = DB().fetch_all(query, params=params)

if per_label_data is None or per_label_data[0] is None:
Expand All @@ -238,14 +243,15 @@ async def get_ci_stats(repo: str, branch: str, workflow: str, start_date: date,


@router.get('/v1/ci/repositories')
async def get_ci_repositories(repo: str | None = None, sort_by: str = 'name'):
async def get_ci_repositories(repo: str | None = None, sort_by: str = 'name', user: User = Depends(authenticate)):

params = []
query = """
query = '''
SELECT repo, source, MAX(created_at) as last_run
FROM ci_measurements
WHERE 1=1
"""
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
'''
params = [user.is_super_user(), user.visible_users()]

if repo: # filter is currently not used, but may be a feature in the future
query = f"{query} AND ci_measurements.repo = %s \n"
Expand All @@ -266,10 +272,10 @@ async def get_ci_repositories(repo: str | None = None, sort_by: str = 'name'):


@router.get('/v1/ci/runs')
async def get_ci_runs(repo: str, sort_by: str = 'name'):
async def get_ci_runs(repo: str, sort_by: str = 'name', user: User = Depends(authenticate)):


params = []
query = """
query = '''
SELECT repo, branch, workflow_id, source, MAX(created_at) as last_run,
(SELECT workflow_name FROM ci_measurements AS latest_workflow
WHERE latest_workflow.repo = ci_measurements.repo
Expand All @@ -278,8 +284,11 @@ async def get_ci_runs(repo: str, sort_by: str = 'name'):
ORDER BY latest_workflow.created_at DESC
LIMIT 1) AS workflow_name
FROM ci_measurements
WHERE 1=1
"""
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
'''

params = [user.is_super_user(), user.visible_users()]

query = f"{query} AND ci_measurements.repo = %s \n"
params.append(repo)
Expand All @@ -296,8 +305,11 @@ async def get_ci_runs(repo: str, sort_by: str = 'name'):

return ORJSONResponse({'success': True, 'data': data}) # no escaping needed, as it happend on ingest

# Route to display a badge for a CI run
## A more complex setup that keeps the badge publicly visible while restricting everything else would be to have
## User 1 restricted to only this route, but with a fully populated 'visible_users' array
@router.get('/v1/ci/badge/get')
async def get_ci_badge_get(repo: str, branch: str, workflow:str, mode: str = 'last', metric: str = 'energy', duration_days: int | None = None):
async def get_ci_badge_get(repo: str, branch: str, workflow:str, mode: str = 'last', metric: str = 'energy', duration_days: int | None = None, user: User = Depends(authenticate)):
if metric == 'energy':
metric = 'energy_uj'
metric_unit = 'uJ'
Expand All @@ -316,14 +328,16 @@ async def get_ci_badge_get(repo: str, branch: str, workflow:str, mode: str = 'la
if duration_days and (duration_days < 1 or duration_days > 365):
raise RequestValidationError('Duration days must be between 1 and 365 days')

params = [repo, branch, workflow]


query = f"""
SELECT SUM({metric})
FROM ci_measurements
WHERE repo = %s AND branch = %s AND workflow_id = %s
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
AND repo = %s AND branch = %s AND workflow_id = %s
ArneTR marked this conversation as resolved.
Show resolved Hide resolved
"""
params = [user.is_super_user(), user.visible_users(), repo, branch, workflow]

if mode == 'avg':
if not duration_days:
Expand All @@ -332,7 +346,9 @@ async def get_ci_badge_get(repo: str, branch: str, workflow:str, mode: str = 'la
WITH my_table as (
SELECT SUM({metric}) my_sum
FROM ci_measurements
WHERE repo = %s AND branch = %s AND workflow_id = %s AND DATE(created_at) > NOW() - make_interval(days => %s)
WHERE
(TRUE = %s OR user_id = ANY(%s::int[]))
AND repo = %s AND branch = %s AND workflow_id = %s AND DATE(created_at) > NOW() - make_interval(days => %s)
GROUP BY run_id
) SELECT AVG(my_sum) FROM my_table;
"""
Expand Down
Loading
Loading