Skip to content

Commit

Permalink
⚗️ Maintenance: make upgrade/downgrade a module fixture (#4222)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg authored May 11, 2023
1 parent 65e6eb8 commit 5b4b9b3
Show file tree
Hide file tree
Showing 25 changed files with 450 additions and 380 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[flake8]
ignore=E501
ignore=E501,W503
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
def upgrade():
# Reassign items from two_factor_enabled -> LOGIN_2FA_REQUIRED
conn = op.get_bind()
rows = conn.execute("SELECT name, login_settings FROM products").fetchall()
rows = conn.execute(sa.DDL("SELECT name, login_settings FROM products")).fetchall()
for row in rows:
data = row["login_settings"] or {}
if "two_factor_enabled" in data:
data["LOGIN_2FA_REQUIRED"] = data.pop("two_factor_enabled")
data = json.dumps(data)
conn.execute(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
sa.DDL(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
)
)
)

Expand All @@ -45,16 +47,18 @@ def upgrade():
def downgrade():
# Reassign items from LOGIN_2FA_REQUIRED -> two_factor_enabled=false
conn = op.get_bind()
rows = conn.execute("SELECT name, login_settings FROM products").fetchall()
rows = conn.execute(sa.DDL("SELECT name, login_settings FROM products")).fetchall()
for row in rows:
data = row["login_settings"] or {}
data["two_factor_enabled"] = data.pop(
"LOGIN_2FA_REQUIRED", False
) # back to default
data = json.dumps(data)
conn.execute(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
sa.DDL(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
)
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def upgrade():

conn = op.get_bind()
default_product_name = conn.scalar(
"SELECT name from products ORDER BY priority LIMIT 1"
sa.DDL("SELECT name from products ORDER BY priority LIMIT 1")
)

op.add_column(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

from typing import Any, Optional, Protocol
from typing import Any, Protocol

import sqlalchemy as sa

Expand Down Expand Up @@ -39,7 +39,7 @@ async def get_default_product_name(conn: _DBConnection) -> str:
:: raises ValueError if undefined
"""
product_name = await conn.scalar(
sa.select([products.c.name]).order_by(products.c.priority)
sa.select(products.c.name).order_by(products.c.priority)
)
if not product_name:
raise ValueError("No product defined in database")
Expand All @@ -50,9 +50,9 @@ async def get_default_product_name(conn: _DBConnection) -> str:

async def get_product_group_id(
connection: _DBConnection, product_name: str
) -> Optional[_GroupID]:
) -> _GroupID | None:
group_id = await connection.scalar(
sa.select([products.c.group_id]).where(products.c.name == product_name)
sa.select(products.c.group_id).where(products.c.name == product_name)
)
return None if group_id is None else _GroupID(group_id)

Expand All @@ -65,7 +65,7 @@ async def get_or_create_product_group(
"""
async with connection.begin():
group_id = await connection.scalar(
sa.select([products.c.group_id])
sa.select(products.c.group_id)
.where(products.c.name == product_name)
.with_for_update(read=True)
# a `FOR SHARE` lock: locks changes in the product until transaction is done.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import itertools
from dataclasses import dataclass
from typing import Optional, TypedDict
from typing import TypedDict

import sqlalchemy as sa
from aiopg.sa.connection import SAConnection
Expand Down Expand Up @@ -103,9 +103,9 @@ async def access_count(
conn: SAConnection,
tag_id: int,
*,
read: Optional[bool] = None,
write: Optional[bool] = None,
delete: Optional[bool] = None,
read: bool | None = None,
write: bool | None = None,
delete: bool | None = None,
) -> int:
"""
Returns 0 if tag does not match access
Expand All @@ -129,7 +129,7 @@ async def access_count(
stmt = sa.select(sa.func.count(user_to_groups.c.uid)).select_from(j)

# The number of occurrences of the user_id = how many groups are giving this access permission
permissions_count: Optional[int] = await conn.scalar(stmt)
permissions_count: int | None = await conn.scalar(stmt)
return permissions_count if permissions_count else 0

#
Expand All @@ -142,7 +142,7 @@ async def create(
*,
name: str,
color: str,
description: Optional[str] = None, # =nullable
description: str | None = None, # =nullable
read: bool = True,
write: bool = True,
delete: bool = True,
Expand Down Expand Up @@ -182,15 +182,15 @@ async def create(

async def list(self, conn: SAConnection) -> list[TagDict]:
select_stmt = (
sa.select(_COLUMNS)
sa.select(*_COLUMNS)
.select_from(self._join_user_to_tags(tags_to_groups.c.read == True))
.order_by(tags.c.id)
)

return [TagDict(row.items()) async for row in conn.execute(select_stmt)] # type: ignore

async def get(self, conn: SAConnection, tag_id: int) -> TagDict:
select_stmt = sa.select(_COLUMNS).select_from(
select_stmt = sa.select(*_COLUMNS).select_from(
self._join_user_to_given_tag(tags_to_groups.c.read == True, tag_id=tag_id)
)

Expand All @@ -208,7 +208,6 @@ async def update(
tag_id: int,
**fields,
) -> TagDict:

updates = {
name: value
for name, value in fields.items()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from datetime import datetime
from typing import Optional, TypedDict
from typing import TypedDict

from aiohttp import web
from aiohttp.test_utils import TestClient
Expand Down Expand Up @@ -92,13 +92,13 @@ async def log_client_in(


class NewUser:
def __init__(self, params=None, app: Optional[web.Application] = None):
def __init__(self, params=None, app: web.Application | None = None):
self.params = params
self.user = None
assert app
self.db = get_plugin_storage(app)

async def __aenter__(self):
async def __aenter__(self) -> UserInfoDict:
self.user = await create_fake_user(self.db, self.params)
return self.user

Expand All @@ -123,9 +123,9 @@ class NewInvitation(NewUser):
def __init__(
self,
client: TestClient,
guest_email: Optional[str] = None,
host: Optional[dict] = None,
trial_days: Optional[int] = None,
guest_email: str | None = None,
host: dict | None = None,
trial_days: int | None = None,
):
assert client.app
super().__init__(params=host, app=client.app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __get_validators__(cls):
def validate_docker_version(cls, docker_version: str) -> str:
try:
search_result = re.search(r"^\d\d.(\d\d|\d).(\d\d|\d)", docker_version)
assert search_result # nosec
return search_result.group()
except AttributeError:
raise ValueError( # pylint: disable=raise-missing-from
Expand Down
10 changes: 3 additions & 7 deletions services/director-v2/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ServiceState,
)
from simcore_service_director_v2.modules.dynamic_sidecar.docker_service_specs.volume_remover import (
DIND_VERSION,
DockerVersion,
)
from yarl import URL
Expand Down Expand Up @@ -451,10 +452,5 @@ async def async_docker_client() -> AsyncIterable[aiodocker.Docker]:


@pytest.fixture
async def docker_version(async_docker_client: aiodocker.Docker) -> DockerVersion:
version_request = (
await async_docker_client._query_json( # pylint: disable=protected-access
"version", versioned_api=False
)
)
return parse_obj_as(DockerVersion, version_request["Version"])
async def docker_version() -> DockerVersion:
return parse_obj_as(DockerVersion, DIND_VERSION)
Original file line number Diff line number Diff line change
Expand Up @@ -840,13 +840,11 @@ async def named_volumes(
async def is_volume_present(
async_docker_client: aiodocker.Docker, volume_name: str
) -> bool:
docker_volume = DockerVolume(async_docker_client, volume_name)
try:
await docker_volume.show()
return True
except aiodocker.DockerError as e:
assert e.message == f"get {volume_name}: no such volume"
return False
list_of_volumes = await async_docker_client.volumes.list()
for volume in list_of_volumes.get("Volumes", []):
if volume["Name"] == volume_name:
return True
return False


async def test_remove_volume_from_node_ok(
Expand Down
17 changes: 8 additions & 9 deletions services/web/server/src/simcore_service_webserver/groups_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def list_user_groups(

async with engine.acquire() as conn:
query = (
sa.select([groups, user_to_groups.c.access_rights])
sa.select(groups, user_to_groups.c.access_rights)
.select_from(
user_to_groups.join(groups, user_to_groups.c.gid == groups.c.gid),
)
Expand All @@ -88,7 +88,7 @@ async def list_user_groups(

async def _get_user_group(conn: SAConnection, user_id: int, gid: int) -> RowProxy:
result = await conn.execute(
sa.select([groups, user_to_groups.c.access_rights])
sa.select(groups, user_to_groups.c.access_rights)
.select_from(user_to_groups.join(groups, user_to_groups.c.gid == groups.c.gid))
.where(and_(user_to_groups.c.uid == user_id, user_to_groups.c.gid == gid))
)
Expand All @@ -101,7 +101,7 @@ async def _get_user_group(conn: SAConnection, user_id: int, gid: int) -> RowProx
async def _get_user_from_email(app: web.Application, email: str) -> RowProxy:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(sa.select([users]).where(users.c.email == email))
result = await conn.execute(sa.select(users).where(users.c.email == email))
user: RowProxy = await result.fetchone()
if not user:
raise UserNotFoundError(email=email)
Expand Down Expand Up @@ -143,7 +143,7 @@ async def create_user_group(
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(
sa.select([users.c.primary_gid]).where(users.c.id == user_id)
sa.select(users.c.primary_gid).where(users.c.id == user_id)
)
user: RowProxy = await result.fetchone()
if not user:
Expand Down Expand Up @@ -217,7 +217,7 @@ async def list_users_in_group(
check_group_permissions(group, user_id, gid, "read")
# now get the list
query = (
sa.select([users, user_to_groups.c.access_rights])
sa.select(users, user_to_groups.c.access_rights)
.select_from(users.join(user_to_groups))
.where(user_to_groups.c.gid == gid)
)
Expand All @@ -234,7 +234,7 @@ async def auto_add_user_to_groups(app: web.Application, user_id: int) -> None:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
# get the groups where there are inclusion rules and see if they apply
query = sa.select([groups]).where(groups.c.inclusion_rules != {})
query = sa.select(groups).where(groups.c.inclusion_rules != {})
possible_group_ids = set()
async for row in conn.execute(query):
inclusion_rules = row[groups.c.inclusion_rules]
Expand Down Expand Up @@ -303,8 +303,7 @@ async def add_user_in_group(
check_group_permissions(group, user_id, gid, "write")
# now check the new user exists
users_count = await conn.scalar(
# pylint: disable=no-value-for-parameter
sa.select([sa.func.count()]).where(users.c.id == new_user_id)
sa.select(sa.func.count()).where(users.c.id == new_user_id)
)
if not users_count:
raise UserInGroupNotFoundError(new_user_id, gid) # type: ignore
Expand All @@ -325,7 +324,7 @@ async def _get_user_in_group_permissions(
) -> RowProxy:
# now get the user
result = await conn.execute(
sa.select([users, user_to_groups.c.access_rights])
sa.select(users, user_to_groups.c.access_rights)
.select_from(users.join(user_to_groups, users.c.id == user_to_groups.c.uid))
.where(and_(user_to_groups.c.gid == gid, users.c.id == the_user_id_in_group))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
async def iter_products(conn: SAConnection) -> AsyncIterator[ResultProxy]:
"""Iterates on products sorted by priority i.e. the first is considered the default"""
async for row in conn.execute(
sa.select(_COLUMNS_IN_MODEL).order_by(products.c.priority)
sa.select(*_COLUMNS_IN_MODEL).order_by(products.c.priority)
):
assert row # nosec
yield row
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class BaseProjectDB:
@classmethod
async def _get_everyone_group(cls, conn: SAConnection) -> RowProxy:
result = await conn.execute(
sa.select([groups]).where(groups.c.type == GroupType.EVERYONE)
sa.select(groups).where(groups.c.type == GroupType.EVERYONE)
)
row = await result.first()
return row
Expand All @@ -197,7 +197,7 @@ async def _list_user_groups(
user_groups.append(everyone_group)
else:
result = await conn.execute(
select([groups])
select(groups)
.select_from(groups.join(user_to_groups))
.where(user_to_groups.c.uid == user_id)
)
Expand All @@ -208,16 +208,14 @@ async def _list_user_groups(
async def _get_user_email(conn: SAConnection, user_id: int | None) -> str:
if not user_id:
return "[email protected]"
email = await conn.scalar(
sa.select([users.c.email]).where(users.c.id == user_id)
)
email = await conn.scalar(sa.select(users.c.email).where(users.c.id == user_id))
assert isinstance(email, str) or email is None # nosec
return email or "Unknown"

@staticmethod
async def _get_user_primary_group_gid(conn: SAConnection, user_id: int) -> int:
primary_gid = await conn.scalar(
sa.select([users.c.primary_gid]).where(users.c.id == str(user_id))
sa.select(users.c.primary_gid).where(users.c.id == str(user_id))
)
if not primary_gid:
raise UserNotFoundError(uid=user_id)
Expand All @@ -226,7 +224,7 @@ async def _get_user_primary_group_gid(conn: SAConnection, user_id: int) -> int:

@staticmethod
async def _get_tags_by_project(conn: SAConnection, project_id: str) -> list:
query = sa.select([study_tags.c.tag_id]).where(
query = sa.select(study_tags.c.tag_id).where(
study_tags.c.study_id == project_id
)
return [row.tag_id async for row in conn.execute(query)]
Expand Down Expand Up @@ -342,7 +340,7 @@ async def _get_project(
if only_published:
conditions &= projects.c.published == "true"

query = select([projects]).where(conditions)
query = select(projects).where(conditions)
if for_update:
query = query.with_for_update()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def _get_active_user_with(self, identity: str) -> _UserIdentity | None:
async with self.engine.acquire() as conn:
# NOTE: sometimes it raises psycopg2.DatabaseError in #880 and #1160
result: ResultProxy = await conn.execute(
sa.select([users.c.id, users.c.role]).where(
sa.select(users.c.id, users.c.role).where(
(users.c.email == identity)
& (users.c.status == UserStatus.ACTIVE)
)
Expand Down
Loading

0 comments on commit 5b4b9b3

Please sign in to comment.