Skip to content

Commit

Permalink
⚗️ Maintenance: make upgrade/downgrade a module fixture (#4222)
Browse files Browse the repository at this point in the history
sanderegg authored May 11, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 65e6eb8 commit 5b4b9b3
Showing 25 changed files with 450 additions and 380 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[flake8]
ignore=E501
ignore=E501,W503
Original file line number Diff line number Diff line change
@@ -21,15 +21,17 @@
def upgrade():
# Reassign items from two_factor_enabled -> LOGIN_2FA_REQUIRED
conn = op.get_bind()
rows = conn.execute("SELECT name, login_settings FROM products").fetchall()
rows = conn.execute(sa.DDL("SELECT name, login_settings FROM products")).fetchall()
for row in rows:
data = row["login_settings"] or {}
if "two_factor_enabled" in data:
data["LOGIN_2FA_REQUIRED"] = data.pop("two_factor_enabled")
data = json.dumps(data)
conn.execute(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
sa.DDL(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
)
)
)

@@ -45,16 +47,18 @@ def upgrade():
def downgrade():
# Reassign items from LOGIN_2FA_REQUIRED -> two_factor_enabled=false
conn = op.get_bind()
rows = conn.execute("SELECT name, login_settings FROM products").fetchall()
rows = conn.execute(sa.DDL("SELECT name, login_settings FROM products")).fetchall()
for row in rows:
data = row["login_settings"] or {}
data["two_factor_enabled"] = data.pop(
"LOGIN_2FA_REQUIRED", False
) # back to default
data = json.dumps(data)
conn.execute(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
sa.DDL(
"UPDATE products SET login_settings = '{}' WHERE name = '{}'".format( # nosec
data, row["name"]
)
)
)

Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ def upgrade():

conn = op.get_bind()
default_product_name = conn.scalar(
"SELECT name from products ORDER BY priority LIMIT 1"
sa.DDL("SELECT name from products ORDER BY priority LIMIT 1")
)

op.add_column(
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
"""

from typing import Any, Optional, Protocol
from typing import Any, Protocol

import sqlalchemy as sa

@@ -39,7 +39,7 @@ async def get_default_product_name(conn: _DBConnection) -> str:
:: raises ValueError if undefined
"""
product_name = await conn.scalar(
sa.select([products.c.name]).order_by(products.c.priority)
sa.select(products.c.name).order_by(products.c.priority)
)
if not product_name:
raise ValueError("No product defined in database")
@@ -50,9 +50,9 @@ async def get_default_product_name(conn: _DBConnection) -> str:

async def get_product_group_id(
connection: _DBConnection, product_name: str
) -> Optional[_GroupID]:
) -> _GroupID | None:
group_id = await connection.scalar(
sa.select([products.c.group_id]).where(products.c.name == product_name)
sa.select(products.c.group_id).where(products.c.name == product_name)
)
return None if group_id is None else _GroupID(group_id)

@@ -65,7 +65,7 @@ async def get_or_create_product_group(
"""
async with connection.begin():
group_id = await connection.scalar(
sa.select([products.c.group_id])
sa.select(products.c.group_id)
.where(products.c.name == product_name)
.with_for_update(read=True)
# a `FOR SHARE` lock: locks changes in the product until transaction is done.
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import functools
import itertools
from dataclasses import dataclass
from typing import Optional, TypedDict
from typing import TypedDict

import sqlalchemy as sa
from aiopg.sa.connection import SAConnection
@@ -103,9 +103,9 @@ async def access_count(
conn: SAConnection,
tag_id: int,
*,
read: Optional[bool] = None,
write: Optional[bool] = None,
delete: Optional[bool] = None,
read: bool | None = None,
write: bool | None = None,
delete: bool | None = None,
) -> int:
"""
Returns 0 if tag does not match access
@@ -129,7 +129,7 @@ async def access_count(
stmt = sa.select(sa.func.count(user_to_groups.c.uid)).select_from(j)

# The number of occurrences of the user_id = how many groups are giving this access permission
permissions_count: Optional[int] = await conn.scalar(stmt)
permissions_count: int | None = await conn.scalar(stmt)
return permissions_count if permissions_count else 0

#
@@ -142,7 +142,7 @@ async def create(
*,
name: str,
color: str,
description: Optional[str] = None, # =nullable
description: str | None = None, # =nullable
read: bool = True,
write: bool = True,
delete: bool = True,
@@ -182,15 +182,15 @@ async def create(

async def list(self, conn: SAConnection) -> list[TagDict]:
select_stmt = (
sa.select(_COLUMNS)
sa.select(*_COLUMNS)
.select_from(self._join_user_to_tags(tags_to_groups.c.read == True))
.order_by(tags.c.id)
)

return [TagDict(row.items()) async for row in conn.execute(select_stmt)] # type: ignore

async def get(self, conn: SAConnection, tag_id: int) -> TagDict:
select_stmt = sa.select(_COLUMNS).select_from(
select_stmt = sa.select(*_COLUMNS).select_from(
self._join_user_to_given_tag(tags_to_groups.c.read == True, tag_id=tag_id)
)

@@ -208,7 +208,6 @@ async def update(
tag_id: int,
**fields,
) -> TagDict:

updates = {
name: value
for name, value in fields.items()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from datetime import datetime
from typing import Optional, TypedDict
from typing import TypedDict

from aiohttp import web
from aiohttp.test_utils import TestClient
@@ -92,13 +92,13 @@ async def log_client_in(


class NewUser:
def __init__(self, params=None, app: Optional[web.Application] = None):
def __init__(self, params=None, app: web.Application | None = None):
self.params = params
self.user = None
assert app
self.db = get_plugin_storage(app)

async def __aenter__(self):
async def __aenter__(self) -> UserInfoDict:
self.user = await create_fake_user(self.db, self.params)
return self.user

@@ -123,9 +123,9 @@ class NewInvitation(NewUser):
def __init__(
self,
client: TestClient,
guest_email: Optional[str] = None,
host: Optional[dict] = None,
trial_days: Optional[int] = None,
guest_email: str | None = None,
host: dict | None = None,
trial_days: int | None = None,
):
assert client.app
super().__init__(params=host, app=client.app)
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ def __get_validators__(cls):
def validate_docker_version(cls, docker_version: str) -> str:
try:
search_result = re.search(r"^\d\d.(\d\d|\d).(\d\d|\d)", docker_version)
assert search_result # nosec
return search_result.group()
except AttributeError:
raise ValueError( # pylint: disable=raise-missing-from
10 changes: 3 additions & 7 deletions services/director-v2/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -51,6 +51,7 @@
ServiceState,
)
from simcore_service_director_v2.modules.dynamic_sidecar.docker_service_specs.volume_remover import (
DIND_VERSION,
DockerVersion,
)
from yarl import URL
@@ -451,10 +452,5 @@ async def async_docker_client() -> AsyncIterable[aiodocker.Docker]:


@pytest.fixture
async def docker_version(async_docker_client: aiodocker.Docker) -> DockerVersion:
version_request = (
await async_docker_client._query_json( # pylint: disable=protected-access
"version", versioned_api=False
)
)
return parse_obj_as(DockerVersion, version_request["Version"])
async def docker_version() -> DockerVersion:
return parse_obj_as(DockerVersion, DIND_VERSION)
Original file line number Diff line number Diff line change
@@ -840,13 +840,11 @@ async def named_volumes(
async def is_volume_present(
async_docker_client: aiodocker.Docker, volume_name: str
) -> bool:
docker_volume = DockerVolume(async_docker_client, volume_name)
try:
await docker_volume.show()
return True
except aiodocker.DockerError as e:
assert e.message == f"get {volume_name}: no such volume"
return False
list_of_volumes = await async_docker_client.volumes.list()
for volume in list_of_volumes.get("Volumes", []):
if volume["Name"] == volume_name:
return True
return False


async def test_remove_volume_from_node_ok(
17 changes: 8 additions & 9 deletions services/web/server/src/simcore_service_webserver/groups_api.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ async def list_user_groups(

async with engine.acquire() as conn:
query = (
sa.select([groups, user_to_groups.c.access_rights])
sa.select(groups, user_to_groups.c.access_rights)
.select_from(
user_to_groups.join(groups, user_to_groups.c.gid == groups.c.gid),
)
@@ -88,7 +88,7 @@ async def list_user_groups(

async def _get_user_group(conn: SAConnection, user_id: int, gid: int) -> RowProxy:
result = await conn.execute(
sa.select([groups, user_to_groups.c.access_rights])
sa.select(groups, user_to_groups.c.access_rights)
.select_from(user_to_groups.join(groups, user_to_groups.c.gid == groups.c.gid))
.where(and_(user_to_groups.c.uid == user_id, user_to_groups.c.gid == gid))
)
@@ -101,7 +101,7 @@ async def _get_user_group(conn: SAConnection, user_id: int, gid: int) -> RowProx
async def _get_user_from_email(app: web.Application, email: str) -> RowProxy:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(sa.select([users]).where(users.c.email == email))
result = await conn.execute(sa.select(users).where(users.c.email == email))
user: RowProxy = await result.fetchone()
if not user:
raise UserNotFoundError(email=email)
@@ -143,7 +143,7 @@ async def create_user_group(
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(
sa.select([users.c.primary_gid]).where(users.c.id == user_id)
sa.select(users.c.primary_gid).where(users.c.id == user_id)
)
user: RowProxy = await result.fetchone()
if not user:
@@ -217,7 +217,7 @@ async def list_users_in_group(
check_group_permissions(group, user_id, gid, "read")
# now get the list
query = (
sa.select([users, user_to_groups.c.access_rights])
sa.select(users, user_to_groups.c.access_rights)
.select_from(users.join(user_to_groups))
.where(user_to_groups.c.gid == gid)
)
@@ -234,7 +234,7 @@ async def auto_add_user_to_groups(app: web.Application, user_id: int) -> None:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
# get the groups where there are inclusion rules and see if they apply
query = sa.select([groups]).where(groups.c.inclusion_rules != {})
query = sa.select(groups).where(groups.c.inclusion_rules != {})
possible_group_ids = set()
async for row in conn.execute(query):
inclusion_rules = row[groups.c.inclusion_rules]
@@ -303,8 +303,7 @@ async def add_user_in_group(
check_group_permissions(group, user_id, gid, "write")
# now check the new user exists
users_count = await conn.scalar(
# pylint: disable=no-value-for-parameter
sa.select([sa.func.count()]).where(users.c.id == new_user_id)
sa.select(sa.func.count()).where(users.c.id == new_user_id)
)
if not users_count:
raise UserInGroupNotFoundError(new_user_id, gid) # type: ignore
@@ -325,7 +324,7 @@ async def _get_user_in_group_permissions(
) -> RowProxy:
# now get the user
result = await conn.execute(
sa.select([users, user_to_groups.c.access_rights])
sa.select(users, user_to_groups.c.access_rights)
.select_from(users.join(user_to_groups, users.c.id == user_to_groups.c.uid))
.where(and_(user_to_groups.c.gid == gid, users.c.id == the_user_id_in_group))
)
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
async def iter_products(conn: SAConnection) -> AsyncIterator[ResultProxy]:
"""Iterates on products sorted by priority i.e. the first is considered the default"""
async for row in conn.execute(
sa.select(_COLUMNS_IN_MODEL).order_by(products.c.priority)
sa.select(*_COLUMNS_IN_MODEL).order_by(products.c.priority)
):
assert row # nosec
yield row
Original file line number Diff line number Diff line change
@@ -180,7 +180,7 @@ class BaseProjectDB:
@classmethod
async def _get_everyone_group(cls, conn: SAConnection) -> RowProxy:
result = await conn.execute(
sa.select([groups]).where(groups.c.type == GroupType.EVERYONE)
sa.select(groups).where(groups.c.type == GroupType.EVERYONE)
)
row = await result.first()
return row
@@ -197,7 +197,7 @@ async def _list_user_groups(
user_groups.append(everyone_group)
else:
result = await conn.execute(
select([groups])
select(groups)
.select_from(groups.join(user_to_groups))
.where(user_to_groups.c.uid == user_id)
)
@@ -208,16 +208,14 @@ async def _list_user_groups(
async def _get_user_email(conn: SAConnection, user_id: int | None) -> str:
if not user_id:
return "[email protected]"
email = await conn.scalar(
sa.select([users.c.email]).where(users.c.id == user_id)
)
email = await conn.scalar(sa.select(users.c.email).where(users.c.id == user_id))
assert isinstance(email, str) or email is None # nosec
return email or "Unknown"

@staticmethod
async def _get_user_primary_group_gid(conn: SAConnection, user_id: int) -> int:
primary_gid = await conn.scalar(
sa.select([users.c.primary_gid]).where(users.c.id == str(user_id))
sa.select(users.c.primary_gid).where(users.c.id == str(user_id))
)
if not primary_gid:
raise UserNotFoundError(uid=user_id)
@@ -226,7 +224,7 @@ async def _get_user_primary_group_gid(conn: SAConnection, user_id: int) -> int:

@staticmethod
async def _get_tags_by_project(conn: SAConnection, project_id: str) -> list:
query = sa.select([study_tags.c.tag_id]).where(
query = sa.select(study_tags.c.tag_id).where(
study_tags.c.study_id == project_id
)
return [row.tag_id async for row in conn.execute(query)]
@@ -342,7 +340,7 @@ async def _get_project(
if only_published:
conditions &= projects.c.published == "true"

query = select([projects]).where(conditions)
query = select(projects).where(conditions)
if for_update:
query = query.with_for_update()

Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ async def _get_active_user_with(self, identity: str) -> _UserIdentity | None:
async with self.engine.acquire() as conn:
# NOTE: sometimes it raises psycopg2.DatabaseError in #880 and #1160
result: ResultProxy = await conn.execute(
sa.select([users.c.id, users.c.role]).where(
sa.select(users.c.id, users.c.role).where(
(users.c.email == identity)
& (users.c.status == UserStatus.ACTIVE)
)
18 changes: 9 additions & 9 deletions services/web/server/src/simcore_service_webserver/users_api.py
Original file line number Diff line number Diff line change
@@ -125,7 +125,7 @@ async def update_user_profile(
last_name = profile_update.last_name
if not first_name or not last_name:
name = await conn.scalar(
sa.select([users.c.name]).where(users.c.id == user_id)
sa.select(users.c.name).where(users.c.id == user_id)
)
try:
first_name, last_name = name.rsplit(".", maxsplit=2)
@@ -155,7 +155,7 @@ async def get_user_role(app: web.Application, user_id: UserID) -> UserRole:
engine: Engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
user_role: RowProxy | None = await conn.scalar(
sa.select([users.c.role]).where(users.c.id == user_id)
sa.select(users.c.role).where(users.c.id == user_id)
)
if user_role is None:
raise UserNotFoundError(uid=user_id)
@@ -167,7 +167,7 @@ async def get_guest_user_ids_and_names(app: web.Application) -> list[tuple[int,
result = deque()
async with engine.acquire() as conn:
async for row in conn.execute(
sa.select([users.c.id, users.c.name]).where(users.c.role == UserRole.GUEST)
sa.select(users.c.id, users.c.name).where(users.c.role == UserRole.GUEST)
):
result.append(row.as_tuple())
return list(result)
@@ -207,7 +207,7 @@ async def get_user_name(app: web.Application, user_id: int) -> UserNameDict:
user_id = _parse_as_user(user_id)
async with engine.acquire() as conn:
user_name = await conn.scalar(
sa.select([users.c.name]).where(users.c.id == user_id)
sa.select(users.c.name).where(users.c.id == user_id)
)
if not user_name:
raise UserNotFoundError(uid=user_id)
@@ -226,7 +226,7 @@ async def get_user(app: web.Application, user_id: int) -> dict:
engine = app[APP_DB_ENGINE_KEY]
user_id = _parse_as_user(user_id)
async with engine.acquire() as conn:
result = await conn.execute(sa.select([users]).where(users.c.id == user_id))
result = await conn.execute(sa.select(users).where(users.c.id == user_id))
row: RowProxy = await result.fetchone()
if not row:
raise UserNotFoundError(uid=user_id)
@@ -237,7 +237,7 @@ async def get_user_id_from_gid(app: web.Application, primary_gid: int) -> int:
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
return await conn.scalar(
sa.select([users.c.id]).where(users.c.primary_gid == primary_gid)
sa.select(users.c.id).where(users.c.primary_gid == primary_gid)
)


@@ -265,7 +265,7 @@ async def list_tokens(app: web.Application, user_id: int) -> list[dict[str, str]
user_tokens = []
async with engine.acquire() as conn:
async for row in conn.execute(
sa.select([tokens.c.token_data]).where(tokens.c.user_id == user_id)
sa.select(tokens.c.token_data).where(tokens.c.user_id == user_id)
):
user_tokens.append(row["token_data"])
return user_tokens
@@ -277,7 +277,7 @@ async def get_token(
engine = app[APP_DB_ENGINE_KEY]
async with engine.acquire() as conn:
result = await conn.execute(
sa.select([tokens.c.token_data]).where(
sa.select(tokens.c.token_data).where(
and_(tokens.c.user_id == user_id, tokens.c.token_service == service_id)
)
)
@@ -292,7 +292,7 @@ async def update_token(
# TODO: optimize to a single call?
async with engine.acquire() as conn:
result = await conn.execute(
sa.select([tokens.c.token_data, tokens.c.token_id]).where(
sa.select(tokens.c.token_data, tokens.c.token_id).where(
and_(tokens.c.user_id == user_id, tokens.c.token_service == service_id)
)
)
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@
#


@pytest.fixture
@pytest.fixture(scope="module")
def postgres_db(postgres_db: sa.engine.Engine) -> sa.engine.Engine:
#
# Extends postgres_db fixture (called with web_server) to inject tables and start redis
471 changes: 241 additions & 230 deletions services/web/server/tests/unit/with_dbs/01/test_groups.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -167,15 +167,17 @@ def user_project_in_2_products(
faker: Faker,
) -> Iterator[dict[str, Any]]:
fake_product_name = faker.name()
postgres_db.execute(products.insert().values(name=fake_product_name, host_regex=""))
postgres_db.execute(
projects_to_products.insert().values(
project_uuid=user_project["uuid"], product_name=fake_product_name
with postgres_db.connect() as conn:
conn.execute(products.insert().values(name=fake_product_name, host_regex=""))
conn.execute(
projects_to_products.insert().values(
project_uuid=user_project["uuid"], product_name=fake_product_name
)
)
)
yield user_project
# cleanup
postgres_db.execute(products.delete().where(products.c.name == fake_product_name))
with postgres_db.connect() as conn:
conn.execute(products.delete().where(products.c.name == fake_product_name))


@pytest.mark.parametrize(*standard_role_response())
Original file line number Diff line number Diff line change
@@ -235,13 +235,15 @@ def s4l_product_name() -> str:
def s4l_products_db_name(
postgres_db: sa.engine.Engine, s4l_product_name: str
) -> Iterator[str]:
postgres_db.execute(
products.insert().values(
name=s4l_product_name, host_regex="pytest", display_name="pytest"
with postgres_db.connect() as conn:
conn.execute(
products.insert().values(
name=s4l_product_name, host_regex="pytest", display_name="pytest"
)
)
)
yield s4l_product_name
postgres_db.execute(products.delete().where(products.c.name == s4l_product_name))
with postgres_db.connect() as conn:
conn.execute(products.delete().where(products.c.name == s4l_product_name))


@pytest.fixture
@@ -277,7 +279,8 @@ async def test_list_projects_with_innaccessible_services(
assert len(data) == 0
# use-case 3: remove the links to products
# shall still return 0 because the user has no access to the services
postgres_db.execute(projects_to_products.delete())
with postgres_db.connect() as conn:
conn.execute(projects_to_products.delete())
data, *_ = await _list_projects(client, expected, headers=s4l_product_headers)
assert len(data) == 0
data, *_ = await _list_projects(client, expected)
78 changes: 45 additions & 33 deletions services/web/server/tests/unit/with_dbs/03/login/test_login_2fa.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
# pylint: disable=unused-variable

import asyncio
from contextlib import AsyncExitStack
from unittest.mock import Mock

import pytest
@@ -11,10 +12,9 @@
from aiohttp.test_utils import TestClient, make_mocked_request
from faker import Faker
from pytest import CaptureFixture, MonkeyPatch
from pytest_simcore.helpers import utils_login
from pytest_simcore.helpers.utils_assert import assert_status
from pytest_simcore.helpers.utils_envs import EnvVarsDict, setenvs_from_dict
from pytest_simcore.helpers.utils_login import parse_link, parse_test_marks
from pytest_simcore.helpers.utils_login import NewUser, parse_link, parse_test_marks
from servicelib.utils_secrets import generate_passcode
from simcore_postgres_database.models.products import ProductLoginSettingsDict, products
from simcore_service_webserver.application_settings import ApplicationSettings
@@ -59,7 +59,8 @@ def postgres_db(postgres_db: sa.engine.Engine):
)
.where(products.c.name == "osparc")
)
postgres_db.execute(stmt)
with postgres_db.connect() as conn:
conn.execute(stmt)
return postgres_db


@@ -234,10 +235,12 @@ def _get_confirmation_link_from_email():
assert user["phone"] == fake_user_phone_number
assert user["status"] == UserStatus.ACTIVE.value

# cleanup
await db.delete_user(user)


async def test_register_phone_fails_with_used_number(
client: TestClient,
db: AsyncpgStorage,
fake_user_email: str,
fake_user_password: str,
fake_user_phone_number: str,
@@ -247,37 +250,46 @@ async def test_register_phone_fails_with_used_number(
"""
assert client.app

# some user ALREADY registered with the same phone
await utils_login.create_fake_user(db, data={"phone": fake_user_phone_number})

# some registered user w/o phone
await utils_login.create_fake_user(
db,
data={"email": fake_user_email, "password": fake_user_password, "phone": None},
)
async with AsyncExitStack() as users_stack:
# some user ALREADY registered with the same phone
await users_stack.enter_async_context(
NewUser(params={"phone": fake_user_phone_number}, app=client.app)
)

# 1. login
url = client.app.router["auth_login"].url_for()
response = await client.post(
f"{url}",
json={
"email": fake_user_email,
"password": fake_user_password,
},
)
await assert_status(response, web.HTTPAccepted)
# some registered user w/o phone
await users_stack.enter_async_context(
NewUser(
params={
"email": fake_user_email,
"password": fake_user_password,
"phone": None,
},
app=client.app,
)
)

# 2. register existing phone
url = client.app.router["auth_register_phone"].url_for()
response = await client.post(
f"{url}",
json={
"email": fake_user_email,
"phone": fake_user_phone_number,
},
)
_, error = await assert_status(response, web.HTTPUnauthorized)
assert "phone" in error["message"]
# 1. login
url = client.app.router["auth_login"].url_for()
response = await client.post(
f"{url}",
json={
"email": fake_user_email,
"password": fake_user_password,
},
)
await assert_status(response, web.HTTPAccepted)

# 2. register existing phone
url = client.app.router["auth_register_phone"].url_for()
response = await client.post(
f"{url}",
json={
"email": fake_user_email,
"phone": fake_user_phone_number,
},
)
_, error = await assert_status(response, web.HTTPUnauthorized)
assert "phone" in error["message"]


async def test_send_email_code(
Original file line number Diff line number Diff line change
@@ -47,7 +47,8 @@ def postgres_db(postgres_db: sa.engine.Engine):
)
.where(products.c.name == "osparc")
)
postgres_db.execute(stmt)
with postgres_db.connect() as conn:
conn.execute(stmt)
return postgres_db


Original file line number Diff line number Diff line change
@@ -4,8 +4,10 @@
# pylint: disable=unused-variable

from datetime import timedelta
from typing import Iterator

import pytest
import sqlalchemy as sa
from aiohttp import web
from aiohttp.test_utils import TestClient
from faker import Faker
@@ -15,6 +17,7 @@
from pytest_simcore.helpers.utils_envs import EnvVarsDict, setenvs_from_dict
from pytest_simcore.helpers.utils_login import NewInvitation, NewUser, parse_link
from servicelib.aiohttp.rest_responses import unwrap_envelope
from simcore_postgres_database.models.users import users
from simcore_service_webserver.db_models import ConfirmationAction, UserStatus
from simcore_service_webserver.login._confirmation import _url_for_confirmation
from simcore_service_webserver.login._constants import (
@@ -50,8 +53,18 @@ def app_environment(
return app_environment | login_envs


@pytest.fixture
def _clean_user_table(postgres_db: sa.engine.Engine) -> Iterator[None]:
yield
with postgres_db.connect() as conn:
conn.execute(users.delete())


async def test_register_entrypoint(
client: TestClient, fake_user_email: str, fake_user_password: str
client: TestClient,
fake_user_email: str,
fake_user_password: str,
_clean_user_table: None,
):
assert client.app
url = client.app.router["auth_register"].url_for()
@@ -68,7 +81,9 @@ async def test_register_entrypoint(
assert fake_user_email in data["message"]


async def test_register_body_validation(client: TestClient, fake_user_password: str):
async def test_register_body_validation(
client: TestClient, fake_user_password: str, _clean_user_table: None
):
assert client.app
url = client.app.router["auth_register"].url_for()
response = await client.post(
@@ -111,7 +126,9 @@ async def test_regitration_is_not_get(client: TestClient):
await assert_error(response, web.HTTPMethodNotAllowed)


async def test_registration_with_existing_email(client: TestClient):
async def test_registration_with_existing_email(
client: TestClient, _clean_user_table: None
):
assert client.app

async with NewUser(app=client.app) as user:
@@ -133,6 +150,7 @@ async def test_registration_with_expired_confirmation(
client: TestClient,
db: AsyncpgStorage,
mocker: MockerFixture,
_clean_user_table: None,
):
assert client.app
mocker.patch(
@@ -171,6 +189,7 @@ async def test_registration_with_invalid_confirmation_code(
login_options: LoginOptions,
db: AsyncpgStorage,
mocker: MockerFixture,
_clean_user_table: None,
):
# Checks bug in https://github.com/ITISFoundation/osparc-simcore/pull/3356
assert client.app
@@ -201,6 +220,7 @@ async def test_registration_without_confirmation(
mocker: MockerFixture,
fake_user_email: str,
fake_user_password: str,
_clean_user_table: None,
):
assert client.app
mocker.patch(
@@ -240,6 +260,7 @@ async def test_registration_with_confirmation(
fake_user_email: str,
fake_user_password: str,
mocked_email_core_remove_comments: None,
_clean_user_table: None,
):
assert client.app
mocker.patch(
@@ -310,6 +331,7 @@ async def test_registration_with_invitation(
mocker: MockerFixture,
fake_user_email: str,
fake_user_password: str,
_clean_user_table: None,
):
assert client.app
mocker.patch(
@@ -371,6 +393,7 @@ async def test_registraton_with_invitation_for_trial_account(
mocker: MockerFixture,
fake_user_email: str,
fake_user_password: str,
_clean_user_table: None,
):
assert client.app
mocker.patch(
15 changes: 12 additions & 3 deletions services/web/server/tests/unit/with_dbs/03/tags/test_tags.py
Original file line number Diff line number Diff line change
@@ -4,9 +4,10 @@
# pylint: disable=too-many-arguments


from typing import Any, AsyncIterator, Callable
from typing import Any, AsyncIterator, Callable, Iterator

import pytest
import sqlalchemy as sa
from aiohttp import web
from aiohttp.test_utils import TestClient
from faker import Faker
@@ -23,20 +24,27 @@
from pytest_simcore.helpers.utils_login import UserInfoDict
from pytest_simcore.helpers.utils_projects import assert_get_same_project
from pytest_simcore.helpers.utils_tags import create_tag, delete_tag
from simcore_postgres_database.models.tags import tags
from simcore_service_webserver import tags_handlers
from simcore_service_webserver._meta import api_version_prefix
from simcore_service_webserver.db import get_database_engine
from simcore_service_webserver.db_models import UserRole
from simcore_service_webserver.projects.project_models import ProjectDict


@pytest.fixture
def _clean_tags_table(postgres_db: sa.engine.Engine) -> Iterator[None]:
yield
with postgres_db.connect() as conn:
conn.execute(tags.delete())


@pytest.mark.parametrize(
"route",
tags_handlers.routes,
ids=lambda r: f"{r.method.upper()} {r.path}",
)
def test_tags_route_against_openapi_specs(route, openapi_specs: OpenApiSpecs):

assert route.path.startswith(f"/{api_version_prefix}")
path = route.path.replace(f"/{api_version_prefix}", "")

@@ -185,6 +193,7 @@ async def test_create_and_update_tags(
logged_user: UserInfoDict,
user_role: UserRole,
everybody_tag_id: int,
_clean_tags_table: None,
):
assert client.app

@@ -204,7 +213,7 @@ async def test_create_and_update_tags(
"accessRights": {"read": True, "write": True, "delete": True},
}

url = client.app.router["update_tag"].url_for(tag_id="2")
url = client.app.router["update_tag"].url_for(tag_id=f"{created['id']}")
resp = await client.patch(
f"{url}",
json={"description": "This is my tag"},
21 changes: 12 additions & 9 deletions services/web/server/tests/unit/with_dbs/03/test_project_db.py
Original file line number Diff line number Diff line change
@@ -273,7 +273,8 @@ def db_api(client: TestClient, postgres_db: sa.engine.Engine) -> Iterator[Projec
yield db_api

# clean the projects
postgres_db.execute("DELETE FROM projects")
with postgres_db.connect() as conn:
conn.execute("DELETE FROM projects")


def _assert_added_project(
@@ -304,11 +305,12 @@ def _assert_added_project(
def _assert_projects_to_product_db_row(
postgres_db: sa.engine.Engine, project: dict[str, Any], product_name: str
):
rows = postgres_db.execute(
sa.select([projects_to_products]).where(
projects_to_products.c.project_uuid == f"{project['uuid']}"
)
).fetchall()
with postgres_db.connect() as conn:
rows = conn.execute(
sa.select([projects_to_products]).where(
projects_to_products.c.project_uuid == f"{project['uuid']}"
)
).fetchall()
assert rows
assert len(rows) == 1
assert rows[0][projects_to_products.c.product_name] == product_name
@@ -317,9 +319,10 @@ def _assert_projects_to_product_db_row(
def _assert_project_db_row(
postgres_db: sa.engine.Engine, project: dict[str, Any], **kwargs
):
row: Row | None = postgres_db.execute(
f"SELECT * FROM projects WHERE \"uuid\"='{project['uuid']}'"
).fetchone()
with postgres_db.connect() as conn:
row: Row | None = conn.execute(
f"SELECT * FROM projects WHERE \"uuid\"='{project['uuid']}'"
).fetchone()

expected_db_entries = {
"type": "STANDARD",
45 changes: 28 additions & 17 deletions services/web/server/tests/unit/with_dbs/03/test_users.py
Original file line number Diff line number Diff line change
@@ -9,10 +9,11 @@
from copy import deepcopy
from datetime import datetime, timezone
from itertools import repeat
from typing import Any, AsyncIterable, Callable
from typing import Any, AsyncIterable, AsyncIterator, Callable
from unittest.mock import MagicMock, Mock

import pytest
import redis.asyncio as aioredis
from aiohttp import web
from aiohttp.test_utils import TestClient
from aiopg.sa.connection import SAConnection
@@ -113,7 +114,7 @@ async def fake_tokens(logged_user: UserInfoDict, tokens_db, faker: Faker):
"token_key": faker.md5(raw_output=False),
"token_secret": faker.md5(raw_output=False),
}
row = await create_token_in_db(
await create_token_in_db(
tokens_db,
user_id=logged_user["id"],
token_service=data["service"],
@@ -239,10 +240,11 @@ async def test_create_token(
"token_secret": "my secret",
}

resp = await client.post(url, json=token)
resp = await client.post(f"{url}", json=token)
data, error = await assert_status(resp, expected)
if not error:
db_token = await get_token_from_db(tokens_db, token_data=token)
assert db_token
assert db_token["token_data"] == token
assert db_token["user_id"] == logged_user["id"]

@@ -317,7 +319,7 @@ async def test_update_token(
if not error:
# check in db
token_in_db = await get_token_from_db(tokens_db, token_service=sid)

assert token_in_db
assert token_in_db["token_data"]["token_secret"] == "some completely new secret"
assert token_in_db["token_data"]["token_secret"] != selected["token_secret"]

@@ -387,10 +389,11 @@ async def test_get_profile_with_failing_db_connection(
ISSUES: #880, #1160
"""
assert client.app
url = client.app.router["get_my_profile"].url_for()
assert str(url) == "/v0/me"

resp = await client.get(url)
resp = await client.get(f"{url}")

NUM_RETRY = 3
assert (
@@ -401,16 +404,19 @@ async def test_get_profile_with_failing_db_connection(


@pytest.fixture
async def notification_redis_client(client: TestClient) -> AsyncIterable[Redis]:
async def notification_redis_client(
client: TestClient,
) -> AsyncIterable[aioredis.Redis]:
assert client.app
redis_client = get_redis_user_notifications_client(client.app)
yield redis_client
await redis_client.flushall()


@asynccontextmanager
async def _create_notifications(
redis_client: Redis, logged_user: UserInfoDict, count: int
) -> AsyncIterable[list[UserNotification]]:
redis_client: aioredis.Redis, logged_user: UserInfoDict, count: int
) -> AsyncIterator[list[UserNotification]]:
user_id = logged_user["id"]
notification_categories = tuple(NotificationCategory)

@@ -450,17 +456,18 @@ async def _create_notifications(
)
async def test_get_user_notifications(
logged_user: UserInfoDict,
notification_redis_client: Redis,
notification_redis_client: aioredis.Redis,
client: TestClient,
notification_count: int,
):
assert client.app
url = client.app.router["get_user_notifications"].url_for()
assert str(url) == "/v0/me/notifications"

async with _create_notifications(
notification_redis_client, logged_user, notification_count
) as created_notifications:
response = await client.get(url)
response = await client.get(f"{url}")
json_response = await response.json()

result = parse_obj_as(list[UserNotification], json_response["data"])
@@ -502,13 +509,15 @@ async def test_get_user_notifications(
)
async def test_post_user_notification(
logged_user: UserInfoDict,
notification_redis_client: Redis,
notification_redis_client: aioredis.Redis,
client: TestClient,
notification_dict: dict[str, Any],
):
assert client.app
url = client.app.router["post_user_notification"].url_for()
assert str(url) == "/v0/me/notifications"
resp = await client.post(url, json=notification_dict)
notification_dict["user_id"] = logged_user["id"]
resp = await client.post(f"{url}", json=notification_dict)
assert resp.status == web.HTTPNoContent.status_code, await resp.text()

user_id = logged_user["id"]
@@ -518,7 +527,7 @@ async def test_post_user_notification(
assert len(user_notifications) == 1
# these are always generated and overwritten, even if provided by the user, since
# we do not want to overwrite existing ones
assert user_notifications[0].read == False
assert user_notifications[0].read is False
assert user_notifications[0].id != notification_dict.get("id", None)


@@ -535,17 +544,18 @@ async def test_post_user_notification(
)
async def test_post_user_notification_capped_list_length(
logged_user: UserInfoDict,
notification_redis_client: Redis,
notification_redis_client: aioredis.Redis,
client: TestClient,
notification_count: int,
):
assert client.app
url = client.app.router["post_user_notification"].url_for()
assert str(url) == "/v0/me/notifications"

notifications_create_results = await asyncio.gather(
*(
client.post(
url,
f"{url}",
json={
"user_id": "1",
"category": NotificationCategory.NEW_ORGANIZATION,
@@ -580,10 +590,11 @@ async def test_post_user_notification_capped_list_length(
)
async def test_update_user_notification_at_correct_index(
logged_user: UserInfoDict,
notification_redis_client: Redis,
notification_redis_client: aioredis.Redis,
client: TestClient,
notification_count: int,
):
assert client.app
user_id = logged_user["id"]

async def _get_stored_notifications() -> list[UserNotification]:
@@ -613,7 +624,7 @@ def _marked_as_read(
assert str(url) == f"/v0/me/notifications/{notification.id}"
assert notification.read is False

resp = await client.patch(url, json={"read": True})
resp = await client.patch(f"{url}", json={"read": True})
assert resp.status == web.HTTPNoContent.status_code

notifications_after_update = await _get_stored_notifications()
2 changes: 1 addition & 1 deletion services/web/server/tests/unit/with_dbs/conftest.py
Original file line number Diff line number Diff line change
@@ -442,7 +442,7 @@ def postgres_service(docker_services, postgres_dsn):
return url


@pytest.fixture(scope="function")
@pytest.fixture(scope="module")
def postgres_db(
postgres_dsn: dict, postgres_service: str
) -> Iterator[sa.engine.Engine]:

0 comments on commit 5b4b9b3

Please sign in to comment.