diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index c85eaa3c7..b55f5b0af 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -3,3 +3,5 @@ # Enforce frontend styling and remove dead code fefc45c2cdc4d3107369c4d70210894d098a775c +# Ignore backend ruff format/lint +a35a4c7c4de78f0b38502882382f07aeac815b3c diff --git a/backend/dataall/base/utils/json_utils.py b/backend/dataall/base/utils/json_utils.py index 1a57af3df..9b7c9697b 100644 --- a/backend/dataall/base/utils/json_utils.py +++ b/backend/dataall/base/utils/json_utils.py @@ -37,7 +37,7 @@ def to_json(record): elif isinstance(record, type({'a': 'dict'})): return json.loads(json.dumps(record, default=json_decoder)) elif type(record) in [str, 'unicode']: - return record + return json.dumps(record) elif type(record) in [int, float]: return json.dumps(record) elif isinstance(record, bool): diff --git a/backend/dataall/core/organizations/api/resolvers.py b/backend/dataall/core/organizations/api/resolvers.py index 01d87e61e..a43933b3d 100644 --- a/backend/dataall/core/organizations/api/resolvers.py +++ b/backend/dataall/core/organizations/api/resolvers.py @@ -113,7 +113,7 @@ def send_query_chatbot(context, source, queryString): def list_group_organization_permissions(context, source, organizationUri, groupUri): - return OrganizationService.list_group_organization_permissions(organizationUri, groupUri) + return OrganizationService.list_group_organization_permissions(uri=organizationUri, groupUri=groupUri) def list_invited_organization_permissions_with_descriptions(context, source): diff --git a/backend/dataall/core/resource_lock/db/resource_lock_models.py b/backend/dataall/core/resource_lock/db/resource_lock_models.py index d98f30ffe..7e478758e 100644 --- a/backend/dataall/core/resource_lock/db/resource_lock_models.py +++ b/backend/dataall/core/resource_lock/db/resource_lock_models.py @@ -9,7 +9,6 @@ class ResourceLock(Base): resourceUri = Column(String, nullable=False, primary_key=True) resourceType = Column(String, nullable=False, primary_key=True) - isLocked = Column(Boolean, default=False) acquiredByUri = Column(String, nullable=True) acquiredByType = Column(String, nullable=True) @@ -17,12 +16,10 @@ def __init__( self, resourceUri: str, resourceType: str, - isLocked: bool = False, acquiredByUri: Optional[str] = None, acquiredByType: Optional[str] = None, ): self.resourceUri = resourceUri self.resourceType = resourceType - self.isLocked = isLocked self.acquiredByUri = acquiredByUri self.acquiredByType = acquiredByType diff --git a/backend/dataall/core/resource_lock/db/resource_lock_repositories.py b/backend/dataall/core/resource_lock/db/resource_lock_repositories.py index e242853a1..25d4e24c8 100644 --- a/backend/dataall/core/resource_lock/db/resource_lock_repositories.py +++ b/backend/dataall/core/resource_lock/db/resource_lock_repositories.py @@ -2,35 +2,23 @@ from dataall.core.resource_lock.db.resource_lock_models import ResourceLock from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session +from time import sleep +from typing import List, Tuple +from contextlib import contextmanager +from dataall.base.db.exceptions import ResourceLockTimeout log = logging.getLogger(__name__) +MAX_RETRIES = 10 +RETRY_INTERVAL = 60 -class ResourceLockRepository: - @staticmethod - def create_resource_lock( - session, resource_uri, resource_type, is_locked=False, acquired_by_uri=None, acquired_by_type=None - ): - resource_lock = ResourceLock( - resourceUri=resource_uri, - resourceType=resource_type, - 
isLocked=is_locked,
-            acquiredByUri=acquired_by_uri,
-            acquiredByType=acquired_by_type,
-        )
-        session.add(resource_lock)
-        session.commit()
-
-    @staticmethod
-    def delete_resource_lock(session, resource_uri):
-        resource_lock = session.query(ResourceLock).filter(ResourceLock.resourceUri == resource_uri).first()
-        session.delete(resource_lock)
-        session.commit()
+class ResourceLockRepository:
     @staticmethod
-    def acquire_locks(resources, session, acquired_by_uri, acquired_by_type):
+    def _acquire_locks(resources, session, acquired_by_uri, acquired_by_type):
         """
-        Attempts to acquire one or more locks on the resources identified by resourceUri and resourceType.
+        Attempts to acquire/create one or more locks on the resources identified by resourceUri and resourceType.

         Args:
             resources: List of resource tuples (resourceUri, resourceType) to acquire locks for.
@@ -47,19 +35,22 @@
                 and_(
                     ResourceLock.resourceUri == resource[0],
                     ResourceLock.resourceType == resource[1],
-                    ~ResourceLock.isLocked,
                 )
                 for resource in resources
             ]
-            resource_locks = session.query(ResourceLock).filter(or_(*filter_conditions)).with_for_update().all()
-            # Ensure lock record found for each resource
-            if len(resource_locks) == len(resources):
-                # Update the attributes of the ResourceLock object
-                for resource_lock in resource_locks:
-                    resource_lock.isLocked = True
-                    resource_lock.acquiredByUri = acquired_by_uri
-                    resource_lock.acquiredByType = acquired_by_type
+            if not session.query(ResourceLock).filter(or_(*filter_conditions)).first():
+                records = []
+                for resource in resources:
+                    records.append(
+                        ResourceLock(
+                            resourceUri=resource[0],
+                            resourceType=resource[1],
+                            acquiredByUri=acquired_by_uri,
+                            acquiredByType=acquired_by_type,
+                        )
+                    )
+                session.add_all(records)
                 session.commit()
                 return True
             else:
@@ -74,9 +65,9 @@
             return False

     @staticmethod
-    def release_lock(session, resource_uri, resource_type, share_uri):
+    def _release_lock(session, resource_uri, resource_type, share_uri):
         """
-        Releases the lock on the resource identified by resource_uri, resource_type.
+        Releases/deletes the lock on the resource identified by resource_uri, resource_type.

         Args:
             session (sqlalchemy.orm.Session): The SQLAlchemy session object used for interacting with the database.
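With the isLocked column removed, a row in ResourceLock is itself the lock: _acquire_locks above inserts one row per resource only when none of the requested (resourceUri, resourceType) pairs are present, and the renamed _release_lock below deletes the row again. A minimal sketch of that invariant (names follow the model and repository in this diff; the session is assumed to be an ordinary SQLAlchemy session):

```python
from sqlalchemy.orm import Session

from dataall.core.resource_lock.db.resource_lock_models import ResourceLock


def is_locked(session: Session, resource_uri: str, resource_type: str) -> bool:
    # Lock state is now row existence: no row means unlocked, so there is
    # no flag to reset and no stale "unlocked" rows left behind.
    return (
        session.query(ResourceLock)
        .filter(
            ResourceLock.resourceUri == resource_uri,
            ResourceLock.resourceType == resource_type,
        )
        .first()
        is not None
    )
```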
@@ -96,7 +87,6 @@ def release_lock(session, resource_uri, resource_type, share_uri): and_( ResourceLock.resourceUri == resource_uri, ResourceLock.resourceType == resource_type, - ResourceLock.isLocked, ResourceLock.acquiredByUri == share_uri, ) ) @@ -105,10 +95,7 @@ def release_lock(session, resource_uri, resource_type, share_uri): ) if resource_lock: - resource_lock.isLocked = False - resource_lock.acquiredByUri = '' - resource_lock.acquiredByType = '' - + session.delete(resource_lock) session.commit() return True else: @@ -120,3 +107,31 @@ def release_lock(session, resource_uri, resource_type, share_uri): session.rollback() log.error('Error occurred while releasing lock:', e) return False + + @staticmethod + @contextmanager + def acquire_lock_with_retry( + resources: List[Tuple[str, str]], session: Session, acquired_by_uri: str, acquired_by_type: str + ): + retries_remaining = MAX_RETRIES + log.info(f'Attempting to acquire lock for resources {resources} by share {acquired_by_uri}...') + while not ( + lock_acquired := ResourceLockRepository._acquire_locks( + resources, session, acquired_by_uri, acquired_by_type + ) + ): + log.info( + f'Lock for one or more resources {resources} already acquired. Retrying in {RETRY_INTERVAL} seconds...' + ) + sleep(RETRY_INTERVAL) + retries_remaining -= 1 + if retries_remaining <= 0: + raise ResourceLockTimeout( + 'process shares', + f'Failed to acquire lock for one or more of {resources=}', + ) + try: + yield lock_acquired + finally: + for resource in resources: + ResourceLockRepository._release_lock(session, resource[0], resource[1], acquired_by_uri) diff --git a/backend/dataall/modules/catalog/__init__.py b/backend/dataall/modules/catalog/__init__.py index 1444359ed..95e76f0a6 100644 --- a/backend/dataall/modules/catalog/__init__.py +++ b/backend/dataall/modules/catalog/__init__.py @@ -17,7 +17,7 @@ def __init__(self): class CatalogAsyncHandlersModuleInterface(ModuleInterface): - """Implements ModuleInterface for datapipelines async lambda""" + """Implements ModuleInterface for catalog async lambda""" @staticmethod def is_supported(modes: Set[ImportMode]): diff --git a/backend/dataall/modules/catalog/api/types.py b/backend/dataall/modules/catalog/api/types.py index 78703dee0..f8aa049e0 100644 --- a/backend/dataall/modules/catalog/api/types.py +++ b/backend/dataall/modules/catalog/api/types.py @@ -42,6 +42,7 @@ fields=[ gql.Field(name='nodeUri', type=gql.ID), gql.Field(name='parentUri', type=gql.NonNullableType(gql.String)), + gql.Field(name='status', type=gql.NonNullableType(gql.String)), gql.Field(name='owner', type=gql.NonNullableType(gql.String)), gql.Field(name='path', type=gql.NonNullableType(gql.String)), gql.Field(name='label', type=gql.NonNullableType(gql.String)), diff --git a/backend/dataall/modules/catalog/handlers/ecs_catalog_handlers.py b/backend/dataall/modules/catalog/handlers/ecs_catalog_handlers.py index b215a933f..430a2d409 100644 --- a/backend/dataall/modules/catalog/handlers/ecs_catalog_handlers.py +++ b/backend/dataall/modules/catalog/handlers/ecs_catalog_handlers.py @@ -18,8 +18,8 @@ def run_ecs_reindex_catalog_task(engine, task: Task): CatalogIndexerTask.index_objects(engine, str(task.payload.get('with_deletes', False))) else: ecs_task_arn = Ecs.run_ecs_task( - task_definition_param='ecs/task_def_arn/share_management', - container_name_param='ecs/container/share_management', + task_definition_param='ecs/task_def_arn/catalog_indexer', + container_name_param='ecs/container/catalog_indexer', context=[ {'name': 'with_deletes', 
'value': str(task.payload.get('with_deletes', False))},
            ],
diff --git a/backend/dataall/modules/catalog/services/catalog_service.py b/backend/dataall/modules/catalog/services/catalog_service.py
index 92ae33d67..d750cde3b 100644
--- a/backend/dataall/modules/catalog/services/catalog_service.py
+++ b/backend/dataall/modules/catalog/services/catalog_service.py
@@ -1,12 +1,6 @@
 import logging

 from dataall.base.context import get_context
-from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
-
-from dataall.modules.catalog.db.glossary_repositories import GlossaryRepository
-from dataall.modules.catalog.db.glossary_models import GlossaryNode
-from dataall.modules.catalog.services.glossaries_permissions import MANAGE_GLOSSARIES
-from dataall.modules.catalog.indexers.registry import GlossaryRegistry
 from dataall.core.permissions.services.tenant_policy_service import TenantPolicyValidationService
 from dataall.core.tasks.db.task_models import Task
 from dataall.core.tasks.service_handlers import Worker
@@ -15,10 +9,6 @@
 logger = logging.getLogger(__name__)

-def _session():
-    return get_context().db_engine.scoped_session()
-
-
 class CatalogService:
     @staticmethod
     def start_reindex_catalog(with_deletes: bool) -> bool:
diff --git a/backend/dataall/modules/dashboards/__init__.py b/backend/dataall/modules/dashboards/__init__.py
index a47dd7f26..ffbc8e92d 100644
--- a/backend/dataall/modules/dashboards/__init__.py
+++ b/backend/dataall/modules/dashboards/__init__.py
@@ -75,3 +75,20 @@ def __init__(self):
         DashboardCatalogIndexer()

         log.info('Dashboard catalog indexer task has been loaded')
+
+
+class DashboardAsyncHandlersModuleInterface(ModuleInterface):
+    """Implements ModuleInterface for dashboard async lambda"""
+
+    @staticmethod
+    def is_supported(modes: Set[ImportMode]):
+        return ImportMode.HANDLERS in modes
+
+    @staticmethod
+    def depends_on() -> List[Type['ModuleInterface']]:
+        from dataall.modules.catalog import CatalogAsyncHandlersModuleInterface
+
+        return [CatalogAsyncHandlersModuleInterface]
+
+    def __init__(self):
+        log.info('Dashboard handlers have been imported')
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_service.py b/backend/dataall/modules/s3_datasets/services/dataset_service.py
index 5f4dbf175..58561112a 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_service.py
@@ -165,9 +165,6 @@ def create_dataset(uri, admin_group, data: dict):
             DatasetService.check_imported_resources(dataset)

             dataset = DatasetRepository.create_dataset(session=session, env=environment, dataset=dataset, data=data)
-            ResourceLockRepository.create_resource_lock(
-                session=session, resource_uri=dataset.datasetUri, resource_type=dataset.__tablename__
-            )
             DatasetBucketRepository.create_dataset_bucket(session, dataset, data)

             ResourcePolicyService.attach_resource_policy(
@@ -413,7 +410,6 @@ def delete_dataset(uri: str, delete_from_aws: bool = False):
             ResourcePolicyService.delete_resource_policy(session=session, resource_uri=uri, group=env.SamlGroupName)
             if dataset.stewards:
                 ResourcePolicyService.delete_resource_policy(session=session, resource_uri=uri, group=dataset.stewards)
-            ResourceLockRepository.delete_resource_lock(session=session, resource_uri=dataset.datasetUri)
             DatasetRepository.delete_dataset(session, dataset)

             if delete_from_aws:
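Dataset creation and deletion no longer provision or tear down lock rows, because locks now exist only while a share is actually being processed. A hedged usage sketch of the acquire_lock_with_retry context manager added earlier in this diff (the session, dataset, and share objects are illustrative placeholders):

```python
from dataall.core.resource_lock.db.resource_lock_repositories import ResourceLockRepository


def process_with_lock(session, dataset, share):
    # Retries every RETRY_INTERVAL seconds, up to MAX_RETRIES times, until the
    # lock rows can be created; they are deleted again when the block exits,
    # even if the body raises.
    with ResourceLockRepository.acquire_lock_with_retry(
        resources=[(dataset.datasetUri, dataset.__tablename__)],
        session=session,
        acquired_by_uri=share.shareUri,
        acquired_by_type=share.__tablename__,
    ):
        ...  # process the share while holding the lock
```

diff --git a/backend/dataall/modules/s3_datasets_shares/__init__.py b/backend/dataall/modules/s3_datasets_shares/__init__.py
index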
779f4a4ff..fcf520992 100644 --- a/backend/dataall/modules/s3_datasets_shares/__init__.py +++ b/backend/dataall/modules/s3_datasets_shares/__init__.py @@ -23,15 +23,37 @@ def depends_on() -> List[Type['ModuleInterface']]: def __init__(self): from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager from dataall.modules.s3_datasets_shares import api - from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import SharePolicyService + from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import S3SharePolicyService from dataall.modules.s3_datasets.services.dataset_service import DatasetService from dataall.modules.datasets_base.services.dataset_list_service import DatasetListService - from dataall.modules.s3_datasets_shares.services.dataset_sharing_service import DatasetSharingService - from dataall.modules.s3_datasets_shares.db.share_object_repositories import ShareEnvironmentResource + from dataall.modules.s3_datasets_shares.services.s3_share_dataset_service import S3ShareDatasetService + from dataall.modules.s3_datasets_shares.db.s3_share_object_repositories import S3ShareEnvironmentResource + from dataall.modules.shares_base.services.share_processor_manager import ( + ShareProcessorManager, + ShareProcessorDefinition, + ) + from dataall.modules.shares_base.services.shares_enums import ShareableType + from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, DatasetBucket, DatasetStorageLocation + + EnvironmentResourceManager.register(S3ShareEnvironmentResource()) + DatasetService.register(S3ShareDatasetService()) + DatasetListService.register(S3ShareDatasetService()) + + ShareProcessorManager.register_processor( + ShareProcessorDefinition(ShareableType.Table, None, DatasetTable, DatasetTable.tableUri) + ) + ShareProcessorManager.register_processor( + ShareProcessorDefinition(ShareableType.S3Bucket, None, DatasetBucket, DatasetBucket.bucketUri) + ) + ShareProcessorManager.register_processor( + ShareProcessorDefinition( + ShareableType.StorageLocation, + None, + DatasetStorageLocation, + DatasetStorageLocation.locationUri, + ) + ) - EnvironmentResourceManager.register(ShareEnvironmentResource()) - DatasetService.register(DatasetSharingService()) - DatasetListService.register(DatasetSharingService()) log.info('API of dataset sharing has been imported') @@ -55,7 +77,7 @@ def depends_on() -> List[Type['ModuleInterface']]: ] def __init__(self): - log.info('S3 Sharing handlers have been imported') + log.info('s3_datasets_shares handlers have been imported') class S3DatasetsSharesCdkModuleInterface(ModuleInterface): @@ -67,9 +89,9 @@ def is_supported(modes): def __init__(self): import dataall.modules.s3_datasets_shares.cdk - from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import SharePolicyService + from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import S3SharePolicyService - log.info('CDK module data_sharing has been imported') + log.info('CDK module s3_datasets_shares has been imported') class S3DatasetsSharesECSShareModuleInterface(ModuleInterface): @@ -122,4 +144,4 @@ def __init__(self): ) ) - log.info('ECS Share module s3_data_sharing has been imported') + log.info('ECS Share module s3_datasets_shares has been imported') diff --git a/backend/dataall/modules/s3_datasets_shares/api/mutations.py b/backend/dataall/modules/s3_datasets_shares/api/mutations.py index 0af2eb548..3f95f0ca3 100644 --- 
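The ShareProcessorManager registrations in s3_datasets_shares/__init__.py above are the extension point for shareable item types: the API module registers each type with processor=None, since the API only needs to enumerate items, while the ECS share module interface in the same file appears to supply concrete processor classes (those registrations fall outside the excerpt shown). A sketch of what one registration carries, assuming the positional fields used above (type enum, processor class, item model, URI column):

```python
from dataall.modules.shares_base.services.share_processor_manager import (
    ShareProcessorDefinition,
    ShareProcessorManager,
)
from dataall.modules.shares_base.services.shares_enums import ShareableType
from dataall.modules.s3_datasets.db.dataset_models import DatasetTable

# Mirrors the API-mode call in this diff: no processor class is attached, so
# the definition only tells the share machinery how to enumerate tables and
# resolve them by their URI column.
ShareProcessorManager.register_processor(
    ShareProcessorDefinition(ShareableType.Table, None, DatasetTable, DatasetTable.tableUri)
)
```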
a/backend/dataall/modules/s3_datasets_shares/api/mutations.py +++ b/backend/dataall/modules/s3_datasets_shares/api/mutations.py @@ -1,6 +1,7 @@ from dataall.base.api import gql from dataall.modules.s3_datasets_shares.api.resolvers import ( verify_dataset_share_objects, + reapply_share_items_share_object_for_dataset, ) @@ -10,3 +11,10 @@ type=gql.Boolean, resolver=verify_dataset_share_objects, ) + +reApplyShareObjectItemsOnDataset = gql.MutationField( + name='reApplyShareObjectItemsOnDataset', + args=[gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String))], + type=gql.Boolean, + resolver=reapply_share_items_share_object_for_dataset, +) diff --git a/backend/dataall/modules/s3_datasets_shares/api/resolvers.py b/backend/dataall/modules/s3_datasets_shares/api/resolvers.py index 5b07c159b..f77076a7a 100644 --- a/backend/dataall/modules/s3_datasets_shares/api/resolvers.py +++ b/backend/dataall/modules/s3_datasets_shares/api/resolvers.py @@ -3,7 +3,7 @@ from dataall.base.api.context import Context from dataall.base.db.exceptions import RequiredParameter from dataall.base.feature_toggle_checker import is_feature_enabled -from dataall.modules.s3_datasets_shares.services.dataset_sharing_service import DatasetSharingService +from dataall.modules.s3_datasets_shares.services.s3_share_service import S3ShareService log = logging.getLogger(__name__) @@ -41,32 +41,34 @@ def validate_dataset_share_selector_input(data): def list_shared_tables_by_env_dataset(context: Context, source, datasetUri: str, envUri: str): - return DatasetSharingService.list_shared_tables_by_env_dataset(datasetUri, envUri) + return S3ShareService.list_shared_tables_by_env_dataset(datasetUri, envUri) @is_feature_enabled('modules.s3_datasets.features.aws_actions') def get_dataset_shared_assume_role_url(context: Context, source, datasetUri: str = None): - return DatasetSharingService.get_dataset_shared_assume_role_url(uri=datasetUri) + return S3ShareService.get_dataset_shared_assume_role_url(uri=datasetUri) def verify_dataset_share_objects(context: Context, source, input): RequestValidator.validate_dataset_share_selector_input(input) dataset_uri = input.get('datasetUri') verify_share_uris = input.get('shareUris') - return DatasetSharingService.verify_dataset_share_objects(uri=dataset_uri, share_uris=verify_share_uris) + return S3ShareService.verify_dataset_share_objects(uri=dataset_uri, share_uris=verify_share_uris) + + +def reapply_share_items_share_object_for_dataset(context: Context, source, datasetUri: str): + return S3ShareService.reapply_share_items_for_dataset(uri=datasetUri) def get_s3_consumption_data(context: Context, source, shareUri: str): - return DatasetSharingService.get_s3_consumption_data(uri=shareUri) + return S3ShareService.get_s3_consumption_data(uri=shareUri) def list_shared_databases_tables_with_env_group(context: Context, source, environmentUri: str, groupUri: str): - return DatasetSharingService.list_shared_databases_tables_with_env_group( - environmentUri=environmentUri, groupUri=groupUri - ) + return S3ShareService.list_shared_databases_tables_with_env_group(environmentUri=environmentUri, groupUri=groupUri) def resolve_shared_db_name(context: Context, source, **kwargs): - return DatasetSharingService.resolve_shared_db_name( + return S3ShareService.resolve_shared_db_name( source.GlueDatabaseName, source.shareUri, source.targetEnvAwsAccountId, source.targetEnvRegion ) diff --git a/backend/dataall/modules/s3_datasets_shares/aws/glue_client.py 
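The new mutation wires straight through to S3ShareService.reapply_share_items_for_dataset, which (further down in this diff) queues an 'ecs.dataset.share.reapply' task rather than doing the work in the request path, and returns a plain Boolean. A hedged sketch of invoking it from a GraphQL client (the graphql_client object and dataset URI are illustrative, not part of this diff):

```python
# Hypothetical client-side call; only the mutation document below comes from
# this diff. Substitute whatever GraphQL client your deployment uses.
REAPPLY_SHARES = """
mutation reApplyShareObjectItemsOnDataset($datasetUri: String!) {
  reApplyShareObjectItemsOnDataset(datasetUri: $datasetUri)
}
"""

response = graphql_client.execute(REAPPLY_SHARES, variables={'datasetUri': 'dataset-uri-123'})
# The resolver returns True once the ECS reapply task has been queued.
```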
b/backend/dataall/modules/s3_datasets_shares/aws/glue_client.py index 61795b199..ef694d8c0 100644 --- a/backend/dataall/modules/s3_datasets_shares/aws/glue_client.py +++ b/backend/dataall/modules/s3_datasets_shares/aws/glue_client.py @@ -175,6 +175,18 @@ def get_source_catalog(self): raise e return None + def get_glue_database_from_catalog(self): + # Check if a catalog account exists and return database accordingly + try: + catalog_dict = self.get_source_catalog() + + if catalog_dict is not None: + return catalog_dict.get('database_name') + else: + return self._database + except Exception as e: + raise e + def get_database_tags(self): # Get tags from the glue database account_id = self._account_id diff --git a/backend/dataall/modules/s3_datasets_shares/db/share_object_repositories.py b/backend/dataall/modules/s3_datasets_shares/db/s3_share_object_repositories.py similarity index 74% rename from backend/dataall/modules/s3_datasets_shares/db/share_object_repositories.py rename to backend/dataall/modules/s3_datasets_shares/db/s3_share_object_repositories.py index c21e04a3d..dac3ae8f0 100644 --- a/backend/dataall/modules/s3_datasets_shares/db/share_object_repositories.py +++ b/backend/dataall/modules/s3_datasets_shares/db/s3_share_object_repositories.py @@ -2,16 +2,13 @@ from warnings import warn from typing import List -from sqlalchemy import and_, or_, func, case +from sqlalchemy import and_, or_ from sqlalchemy.orm import Query -from dataall.core.environment.db.environment_models import Environment, EnvironmentGroup +from dataall.core.environment.db.environment_models import Environment from dataall.core.environment.services.environment_resource_manager import EnvironmentResource -from dataall.base.db import exceptions, paginate from dataall.modules.shares_base.services.shares_enums import ( - ShareItemHealthStatus, ShareObjectStatus, - ShareItemStatus, ShareableType, PrincipalType, ) @@ -19,13 +16,13 @@ from dataall.modules.shares_base.db.share_object_models import ShareObjectItem, ShareObject from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository -from dataall.modules.s3_datasets.db.dataset_models import DatasetStorageLocation, DatasetTable, S3Dataset, DatasetBucket +from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, S3Dataset from dataall.modules.datasets_base.db.dataset_models import DatasetBase logger = logging.getLogger(__name__) -class ShareEnvironmentResource(EnvironmentResource): +class S3ShareEnvironmentResource(EnvironmentResource): @staticmethod def count_resources(session, environment, group_uri) -> int: return S3ShareObjectRepository.count_S3_principal_shares( @@ -395,124 +392,6 @@ def list_s3_dataset_shares_with_existing_shared_items( query = query.filter(ShareObjectItem.itemType == item_type) return query.all() - @staticmethod # TODO!!! 
- def list_shareable_items(session, share, states, data): # TODO - # All tables from dataset with a column isShared - # marking the table as part of the shareObject - tables = ( - session.query( - DatasetTable.tableUri.label('itemUri'), - func.coalesce('DatasetTable').label('itemType'), - DatasetTable.GlueTableName.label('itemName'), - DatasetTable.description.label('description'), - ShareObjectItem.shareItemUri.label('shareItemUri'), - ShareObjectItem.status.label('status'), - ShareObjectItem.healthStatus.label('healthStatus'), - ShareObjectItem.healthMessage.label('healthMessage'), - ShareObjectItem.lastVerificationTime.label('lastVerificationTime'), - case( - [(ShareObjectItem.shareItemUri.isnot(None), True)], - else_=False, - ).label('isShared'), - ) - .outerjoin( - ShareObjectItem, - and_( - ShareObjectItem.shareUri == share.shareUri, - DatasetTable.tableUri == ShareObjectItem.itemUri, - ), - ) - .filter(DatasetTable.datasetUri == share.datasetUri) - ) - if states: - tables = tables.filter(ShareObjectItem.status.in_(states)) - - # All folders from the dataset with a column isShared - # marking the folder as part of the shareObject - locations = ( - session.query( - DatasetStorageLocation.locationUri.label('itemUri'), - func.coalesce('DatasetStorageLocation').label('itemType'), - DatasetStorageLocation.S3Prefix.label('itemName'), - DatasetStorageLocation.description.label('description'), - ShareObjectItem.shareItemUri.label('shareItemUri'), - ShareObjectItem.status.label('status'), - ShareObjectItem.healthStatus.label('healthStatus'), - ShareObjectItem.healthMessage.label('healthMessage'), - ShareObjectItem.lastVerificationTime.label('lastVerificationTime'), - case( - [(ShareObjectItem.shareItemUri.isnot(None), True)], - else_=False, - ).label('isShared'), - ) - .outerjoin( - ShareObjectItem, - and_( - ShareObjectItem.shareUri == share.shareUri, - DatasetStorageLocation.locationUri == ShareObjectItem.itemUri, - ), - ) - .filter(DatasetStorageLocation.datasetUri == share.datasetUri) - ) - if states: - locations = locations.filter(ShareObjectItem.status.in_(states)) - - s3_buckets = ( - session.query( - DatasetBucket.bucketUri.label('itemUri'), - func.coalesce('S3Bucket').label('itemType'), - DatasetBucket.S3BucketName.label('itemName'), - DatasetBucket.description.label('description'), - ShareObjectItem.shareItemUri.label('shareItemUri'), - ShareObjectItem.status.label('status'), - ShareObjectItem.healthStatus.label('healthStatus'), - ShareObjectItem.healthMessage.label('healthMessage'), - ShareObjectItem.lastVerificationTime.label('lastVerificationTime'), - case( - [(ShareObjectItem.shareItemUri.isnot(None), True)], - else_=False, - ).label('isShared'), - ) - .outerjoin( - ShareObjectItem, - and_( - ShareObjectItem.shareUri == share.shareUri, - DatasetBucket.bucketUri == ShareObjectItem.itemUri, - ), - ) - .filter(DatasetBucket.datasetUri == share.datasetUri) - ) - if states: - s3_buckets = s3_buckets.filter(ShareObjectItem.status.in_(states)) - - shareable_objects = tables.union(locations, s3_buckets).subquery('shareable_objects') - query = session.query(shareable_objects) - - if data: - if data.get('term'): - term = data.get('term') - query = query.filter( - or_( - shareable_objects.c.itemName.ilike(term + '%'), - shareable_objects.c.description.ilike(term + '%'), - ) - ) - if 'isShared' in data: - is_shared = data.get('isShared') - query = query.filter(shareable_objects.c.isShared == is_shared) - - if 'isHealthy' in data: - # healthy_status = ShareItemHealthStatus.Healthy.value - 
query = ( - query.filter(shareable_objects.c.healthStatus == ShareItemHealthStatus.Healthy.value) - if data.get('isHealthy') - else query.filter(shareable_objects.c.healthStatus != ShareItemHealthStatus.Healthy.value) - ) - - return paginate( - query.order_by(shareable_objects.c.itemName).distinct(), data.get('page', 1), data.get('pageSize', 10) - ).to_dict() - # the next 2 methods are used in subscription task @staticmethod def find_share_items_by_item_uri(session, item_uri): diff --git a/backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_alarm_service.py b/backend/dataall/modules/s3_datasets_shares/services/s3_share_alarm_service.py similarity index 99% rename from backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_alarm_service.py rename to backend/dataall/modules/s3_datasets_shares/services/s3_share_alarm_service.py index 5a46fea2e..e9c6a5e2f 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_alarm_service.py +++ b/backend/dataall/modules/s3_datasets_shares/services/s3_share_alarm_service.py @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) -class DatasetSharingAlarmService(AlarmService): +class S3ShareAlarmService(AlarmService): """Contains set of alarms for datasets""" def trigger_table_sharing_failure_alarm( diff --git a/backend/dataall/modules/s3_datasets_shares/services/s3_share_dataset_service.py b/backend/dataall/modules/s3_datasets_shares/services/s3_share_dataset_service.py new file mode 100644 index 000000000..435ed4e8e --- /dev/null +++ b/backend/dataall/modules/s3_datasets_shares/services/s3_share_dataset_service.py @@ -0,0 +1,105 @@ +from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService +from dataall.base.db import exceptions +from dataall.modules.shares_base.db.share_object_models import ShareObject +from dataall.modules.s3_datasets_shares.db.s3_share_object_repositories import S3ShareObjectRepository +from dataall.modules.shares_base.services.share_permissions import SHARE_OBJECT_APPROVER +from dataall.modules.s3_datasets.services.dataset_permissions import ( + DELETE_DATASET, + DELETE_DATASET_TABLE, + DELETE_DATASET_FOLDER, +) +from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, DatasetTypes +from dataall.modules.datasets_base.services.dataset_service_interface import DatasetServiceInterface + + +import logging + +log = logging.getLogger(__name__) + + +class S3ShareDatasetService(DatasetServiceInterface): + @property + def dataset_type(self): + return DatasetTypes.S3 + + @staticmethod + def resolve_additional_dataset_user_role(session, uri, username, groups): + """Implemented as part of the DatasetServiceInterface""" + share = S3ShareObjectRepository.get_share_by_dataset_attributes(session, uri, username, groups) + if share is not None: + return DatasetRole.Shared.value + return None + + @staticmethod + def check_before_delete(session, uri, **kwargs): + """Implemented as part of the DatasetServiceInterface""" + action = kwargs.get('action') + if action in [DELETE_DATASET_FOLDER, DELETE_DATASET_TABLE]: + existing_s3_shared_items = S3ShareObjectRepository.check_existing_s3_shared_items(session, uri) + if existing_s3_shared_items: + raise exceptions.ResourceShared( + action=action, + message='Revoke all shares for this item before deletion', + ) + elif action in [DELETE_DATASET]: + shares = S3ShareObjectRepository.list_s3_dataset_shares_with_existing_shared_items( + session=session, dataset_uri=uri + ) + if shares: + raise 
exceptions.ResourceShared( + action=DELETE_DATASET, + message='Revoke all dataset shares before deletion.', + ) + else: + raise exceptions.RequiredParameter('Delete action') + return True + + @staticmethod + def execute_on_delete(session, uri, **kwargs): + """Implemented as part of the DatasetServiceInterface""" + action = kwargs.get('action') + if action in [DELETE_DATASET_FOLDER, DELETE_DATASET_TABLE]: + S3ShareObjectRepository.delete_s3_share_item(session, uri) + elif action in [DELETE_DATASET]: + S3ShareObjectRepository.delete_s3_shares_with_no_shared_items(session, uri) + else: + raise exceptions.RequiredParameter('Delete action') + return True + + @staticmethod + def append_to_list_user_datasets(session, username, groups): + """Implemented as part of the DatasetServiceInterface""" + return S3ShareObjectRepository.list_user_s3_shared_datasets(session, username, groups) + + @staticmethod + def extend_attach_steward_permissions(session, dataset, new_stewards, **kwargs): + """Implemented as part of the DatasetServiceInterface""" + dataset_shares = S3ShareObjectRepository.find_s3_dataset_shares(session, dataset.datasetUri) + if dataset_shares: + for share in dataset_shares: + ResourcePolicyService.attach_resource_policy( + session=session, + group=new_stewards, + permissions=SHARE_OBJECT_APPROVER, + resource_uri=share.shareUri, + resource_type=ShareObject.__name__, + ) + if dataset.stewards != dataset.SamlAdminGroupName: + ResourcePolicyService.delete_resource_policy( + session=session, + group=dataset.stewards, + resource_uri=share.shareUri, + ) + + @staticmethod + def extend_delete_steward_permissions(session, dataset, **kwargs): + """Implemented as part of the DatasetServiceInterface""" + dataset_shares = S3ShareObjectRepository.find_s3_dataset_shares(session, dataset.datasetUri) + if dataset_shares: + for share in dataset_shares: + if dataset.stewards != dataset.SamlAdminGroupName: + ResourcePolicyService.delete_resource_policy( + session=session, + group=dataset.stewards, + resource_uri=share.shareUri, + ) diff --git a/backend/dataall/modules/s3_datasets_shares/services/managed_share_policy_service.py b/backend/dataall/modules/s3_datasets_shares/services/s3_share_managed_policy_service.py similarity index 98% rename from backend/dataall/modules/s3_datasets_shares/services/managed_share_policy_service.py rename to backend/dataall/modules/s3_datasets_shares/services/s3_share_managed_policy_service.py index e53a94944..2738e8b38 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/managed_share_policy_service.py +++ b/backend/dataall/modules/s3_datasets_shares/services/s3_share_managed_policy_service.py @@ -18,7 +18,7 @@ EMPTY_STATEMENT_SID = 'EmptyStatement' -class SharePolicyService(ManagedPolicy): +class S3SharePolicyService(ManagedPolicy): def __init__(self, role_name, account, region, environmentUri, resource_prefix): self.role_name = role_name self.account = account @@ -48,7 +48,7 @@ def generate_empty_policy(self) -> dict: @staticmethod def remove_empty_statement(policy_doc: dict, statement_sid: str) -> dict: - statement_index = SharePolicyService._get_statement_by_sid(policy_doc, statement_sid) + statement_index = S3SharePolicyService._get_statement_by_sid(policy_doc, statement_sid) if statement_index is not None: policy_doc['Statement'].pop(statement_index) return policy_doc diff --git a/backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_service.py b/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py similarity index 55% rename 
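S3ShareDatasetService above implements DatasetServiceInterface so the core dataset service can consult share state without importing sharing code directly. A simplified sketch of the dispatch this enables, assuming a registry of interface implementations behind the DatasetService.register() calls shown earlier in this diff (the registry internals here are illustrative):

```python
from typing import List

from dataall.modules.datasets_base.services.dataset_service_interface import DatasetServiceInterface

_interfaces: List[DatasetServiceInterface] = []


def register(interface: DatasetServiceInterface) -> None:
    # DatasetService.register(S3ShareDatasetService()) conceptually lands here.
    _interfaces.append(interface)


def check_before_delete(session, uri: str, action: str) -> bool:
    # Each registered module may veto deletion; S3ShareDatasetService raises
    # ResourceShared if shared items still exist for the dataset.
    return all(i.check_before_delete(session, uri, action=action) for i in _interfaces)
```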
from backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_service.py rename to backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py index 207ab1c65..255544bc7 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/dataset_sharing_service.py +++ b/backend/dataall/modules/s3_datasets_shares/services/s3_share_service.py @@ -1,123 +1,141 @@ +import logging from warnings import warn + from dataall.base.db import utils +from dataall.base.context import get_context +from dataall.base.aws.sts import SessionHelper from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.base.context import get_context -from dataall.base.db import exceptions -from dataall.base.aws.sts import SessionHelper -from dataall.modules.shares_base.db.share_object_models import ShareObject -from dataall.modules.s3_datasets_shares.db.share_object_repositories import S3ShareObjectRepository +from dataall.core.tasks.db.task_models import Task +from dataall.core.tasks.service_handlers import Worker from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository from dataall.modules.shares_base.db.share_state_machines_repositories import ShareStatusRepository -from dataall.modules.shares_base.services.share_permissions import SHARE_OBJECT_APPROVER, GET_SHARE_OBJECT -from dataall.modules.s3_datasets_shares.services.share_item_service import S3ShareItemService from dataall.modules.shares_base.services.share_item_service import ShareItemService +from dataall.modules.shares_base.services.share_permissions import GET_SHARE_OBJECT +from dataall.modules.shares_base.services.shares_enums import ( + ShareableType, + ShareItemStatus, +) +from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, DatasetStorageLocation from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository from dataall.modules.s3_datasets.services.dataset_permissions import ( MANAGE_DATASETS, UPDATE_DATASET, - DELETE_DATASET, - DELETE_DATASET_TABLE, - DELETE_DATASET_FOLDER, CREDENTIALS_DATASET, + DATASET_TABLE_READ, + DATASET_FOLDER_READ, ) -from dataall.modules.datasets_base.services.datasets_enums import DatasetRole, DatasetTypes -from dataall.modules.datasets_base.services.dataset_service_interface import DatasetServiceInterface +from dataall.modules.s3_datasets_shares.db.s3_share_object_repositories import S3ShareObjectRepository from dataall.modules.s3_datasets_shares.aws.glue_client import GlueClient -import logging - log = logging.getLogger(__name__) -class DatasetSharingService(DatasetServiceInterface): - @property - def dataset_type(self): - return DatasetTypes.S3 - +class S3ShareService: @staticmethod - def resolve_additional_dataset_user_role(session, uri, username, groups): - """Implemented as part of the DatasetServiceInterface""" - share = S3ShareObjectRepository.get_share_by_dataset_attributes(session, uri, username, groups) - if share is not None: - return DatasetRole.Shared.value - return None + def delete_dataset_table_read_permission(session, share, tableUri): + """ + Delete Table permissions to share groups + """ + other_shares = S3ShareObjectRepository.find_all_other_share_items( + session, + not_this_share_uri=share.shareUri, + item_uri=tableUri, + share_type=ShareableType.Table.value, + principal_type='GROUP', + 
principal_uri=share.groupUri, + item_status=[ShareItemStatus.Share_Succeeded.value], + ) + log.info(f'Table {tableUri} has been shared with group {share.groupUri} in {len(other_shares)} more shares') + if len(other_shares) == 0: + log.info('Delete permissions...') + ResourcePolicyService.delete_resource_policy(session=session, group=share.groupUri, resource_uri=tableUri) @staticmethod - def check_before_delete(session, uri, **kwargs): - """Implemented as part of the DatasetServiceInterface""" - action = kwargs.get('action') - if action in [DELETE_DATASET_FOLDER, DELETE_DATASET_TABLE]: - existing_s3_shared_items = S3ShareObjectRepository.check_existing_s3_shared_items(session, uri) - if existing_s3_shared_items: - raise exceptions.ResourceShared( - action=action, - message='Revoke all shares for this item before deletion', - ) - elif action in [DELETE_DATASET]: - shares = S3ShareObjectRepository.list_s3_dataset_shares_with_existing_shared_items( - session=session, dataset_uri=uri + def delete_dataset_folder_read_permission(session, share, locationUri): + """ + Delete Folder permissions to share groups + """ + other_shares = S3ShareObjectRepository.find_all_other_share_items( + session, + not_this_share_uri=share.shareUri, + item_uri=locationUri, + share_type=ShareableType.StorageLocation.value, + principal_type='GROUP', + principal_uri=share.groupUri, + item_status=[ShareItemStatus.Share_Succeeded.value], + ) + log.info( + f'Location {locationUri} has been shared with group {share.groupUri} in {len(other_shares)} more shares' + ) + if len(other_shares) == 0: + log.info('Delete permissions...') + ResourcePolicyService.delete_resource_policy( + session=session, + group=share.groupUri, + resource_uri=locationUri, ) - if shares: - raise exceptions.ResourceShared( - action=DELETE_DATASET, - message='Revoke all dataset shares before deletion.', - ) - else: - raise exceptions.RequiredParameter('Delete action') - return True @staticmethod - def execute_on_delete(session, uri, **kwargs): - """Implemented as part of the DatasetServiceInterface""" - action = kwargs.get('action') - if action in [DELETE_DATASET_FOLDER, DELETE_DATASET_TABLE]: - S3ShareObjectRepository.delete_s3_share_item(session, uri) - elif action in [DELETE_DATASET]: - S3ShareObjectRepository.delete_s3_shares_with_no_shared_items(session, uri) + def attach_dataset_table_read_permission(session, share, tableUri): + """ + Attach Table permissions to share groups + """ + existing_policy = ResourcePolicyService.find_resource_policies( + session, + group=share.groupUri, + resource_uri=tableUri, + resource_type=DatasetTable.__name__, + permissions=DATASET_TABLE_READ, + ) + # toDo: separate policies from list DATASET_TABLE_READ, because in future only one of them can be granted (Now they are always granted together) + if len(existing_policy) == 0: + log.info( + f'Attaching new resource permission policy {DATASET_TABLE_READ} to table {tableUri} for group {share.groupUri}' + ) + ResourcePolicyService.attach_resource_policy( + session=session, + group=share.groupUri, + permissions=DATASET_TABLE_READ, + resource_uri=tableUri, + resource_type=DatasetTable.__name__, + ) else: - raise exceptions.RequiredParameter('Delete action') - return True - - @staticmethod - def append_to_list_user_datasets(session, username, groups): - """Implemented as part of the DatasetServiceInterface""" - return S3ShareObjectRepository.list_user_s3_shared_datasets(session, username, groups) + log.info( + f'Resource permission policy {DATASET_TABLE_READ} to table 
{tableUri} for group {share.groupUri} already exists. Skip... '
+            )

     @staticmethod
-    def extend_attach_steward_permissions(session, dataset, new_stewards, **kwargs):
-        """Implemented as part of the DatasetServiceInterface"""
-        dataset_shares = S3ShareObjectRepository.find_s3_dataset_shares(session, dataset.datasetUri)
-        if dataset_shares:
-            for share in dataset_shares:
-                ResourcePolicyService.attach_resource_policy(
-                    session=session,
-                    group=new_stewards,
-                    permissions=SHARE_OBJECT_APPROVER,
-                    resource_uri=share.shareUri,
-                    resource_type=ShareObject.__name__,
-                )
-                if dataset.stewards != dataset.SamlAdminGroupName:
-                    ResourcePolicyService.delete_resource_policy(
-                        session=session,
-                        group=dataset.stewards,
-                        resource_uri=share.shareUri,
-                    )
+    def attach_dataset_folder_read_permission(session, share, locationUri):
+        """
+        Attach Folder permissions to share groups
+        """
+        existing_policy = ResourcePolicyService.find_resource_policies(
+            session,
+            group=share.groupUri,
+            resource_uri=locationUri,
+            resource_type=DatasetStorageLocation.__name__,
+            permissions=DATASET_FOLDER_READ,
+        )
+        # toDo: separate policies from list DATASET_TABLE_READ, because in future only one of them can be granted (Now they are always granted together)
+        if len(existing_policy) == 0:
+            log.info(
+                f'Attaching new resource permission policy {DATASET_FOLDER_READ} to folder {locationUri} for group {share.groupUri}'
+            )

-    @staticmethod
-    def extend_delete_steward_permissions(session, dataset, **kwargs):
-        """Implemented as part of the DatasetServiceInterface"""
-        dataset_shares = S3ShareObjectRepository.find_s3_dataset_shares(session, dataset.datasetUri)
-        if dataset_shares:
-            for share in dataset_shares:
-                if dataset.stewards != dataset.SamlAdminGroupName:
-                    ResourcePolicyService.delete_resource_policy(
-                        session=session,
-                        group=dataset.stewards,
-                        resource_uri=share.shareUri,
-                    )
+            ResourcePolicyService.attach_resource_policy(
+                session=session,
+                group=share.groupUri,
+                permissions=DATASET_FOLDER_READ,
+                resource_uri=locationUri,
+                resource_type=DatasetStorageLocation.__name__,
+            )
+        else:
+            log.info(
+                f'Resource permission policy {DATASET_FOLDER_READ} to folder {locationUri} for group {share.groupUri} already exists. Skip... 
' + ) @staticmethod @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS) @@ -134,6 +152,17 @@ def verify_dataset_share_objects(uri: str, share_uris: list): ShareItemService.verify_items_share_object(uri=share_uri, item_uris=item_uris) return True + @staticmethod + @TenantPolicyService.has_tenant_permission(MANAGE_DATASETS) + @ResourcePolicyService.has_resource_permission(UPDATE_DATASET) + def reapply_share_items_for_dataset(uri: str): + context = get_context() + with context.db_engine.scoped_session() as session: + reapply_share_items_task: Task = Task(action='ecs.dataset.share.reapply', targetUri=uri) + session.add(reapply_share_items_task) + Worker.queue(engine=context.db_engine, task_ids=[reapply_share_items_task.taskUri]) + return True + @staticmethod def list_shared_tables_by_env_dataset(dataset_uri: str, env_uri: str): context = get_context() @@ -195,9 +224,10 @@ def get_s3_consumption_data(uri): separator='-', ) # Check if the share was made with a Glue Database - datasetGlueDatabase = S3ShareItemService.get_glue_database_for_share( - dataset.GlueDatabaseName, dataset.AwsAccountId, dataset.region - ) + datasetGlueDatabase = GlueClient( + account_id=dataset.AwsAccountId, region=dataset.region, database=dataset.GlueDatabaseName + ).get_glue_database_from_catalog() + old_shared_db_name = f'{datasetGlueDatabase}_shared_{uri}'[:254] database = GlueClient( account_id=environment.AwsAccountId, region=environment.region, database=old_shared_db_name diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_item_service.py b/backend/dataall/modules/s3_datasets_shares/services/share_item_service.py deleted file mode 100644 index 695b9ccd4..000000000 --- a/backend/dataall/modules/s3_datasets_shares/services/share_item_service.py +++ /dev/null @@ -1,135 +0,0 @@ -import logging - -from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService -from dataall.modules.shares_base.services.shares_enums import ( - ShareableType, - ShareItemStatus, -) -from dataall.modules.s3_datasets_shares.aws.glue_client import GlueClient -from dataall.modules.s3_datasets_shares.db.share_object_repositories import S3ShareObjectRepository -from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, DatasetStorageLocation -from dataall.modules.s3_datasets.services.dataset_permissions import DATASET_TABLE_READ, DATASET_FOLDER_READ - -log = logging.getLogger(__name__) - - -class S3ShareItemService: - @staticmethod - def get_glue_database_for_share(glueDatabase, account_id, region): # TODO: IN S3_DATASETS_SHARES - # Check if a catalog account exists and return database accordingly - try: - catalog_dict = GlueClient( - account_id=account_id, - region=region, - database=glueDatabase, - ).get_source_catalog() - - if catalog_dict is not None: - return catalog_dict.get('database_name') - else: - return glueDatabase - except Exception as e: - raise e - - @staticmethod - def delete_dataset_table_read_permission(session, share, tableUri): - """ - Delete Table permissions to share groups - """ - other_shares = S3ShareObjectRepository.find_all_other_share_items( - session, - not_this_share_uri=share.shareUri, - item_uri=tableUri, - share_type=ShareableType.Table.value, - principal_type='GROUP', - principal_uri=share.groupUri, - item_status=[ShareItemStatus.Share_Succeeded.value], - ) - log.info(f'Table {tableUri} has been shared with group {share.groupUri} in {len(other_shares)} more shares') - if len(other_shares) == 0: - log.info('Delete permissions...') - 
ResourcePolicyService.delete_resource_policy(session=session, group=share.groupUri, resource_uri=tableUri) - - @staticmethod - def delete_dataset_folder_read_permission(session, share, locationUri): - """ - Delete Folder permissions to share groups - """ - other_shares = S3ShareObjectRepository.find_all_other_share_items( - session, - not_this_share_uri=share.shareUri, - item_uri=locationUri, - share_type=ShareableType.StorageLocation.value, - principal_type='GROUP', - principal_uri=share.groupUri, - item_status=[ShareItemStatus.Share_Succeeded.value], - ) - log.info( - f'Location {locationUri} has been shared with group {share.groupUri} in {len(other_shares)} more shares' - ) - if len(other_shares) == 0: - log.info('Delete permissions...') - ResourcePolicyService.delete_resource_policy( - session=session, - group=share.groupUri, - resource_uri=locationUri, - ) - - @staticmethod - def attach_dataset_table_read_permission(session, share, tableUri): - """ - Attach Table permissions to share groups - """ - existing_policy = ResourcePolicyService.find_resource_policies( - session, - group=share.groupUri, - resource_uri=tableUri, - resource_type=DatasetTable.__name__, - permissions=DATASET_TABLE_READ, - ) - # toDo: separate policies from list DATASET_TABLE_READ, because in future only one of them can be granted (Now they are always granted together) - if len(existing_policy) == 0: - log.info( - f'Attaching new resource permission policy {DATASET_TABLE_READ} to table {tableUri} for group {share.groupUri}' - ) - ResourcePolicyService.attach_resource_policy( - session=session, - group=share.groupUri, - permissions=DATASET_TABLE_READ, - resource_uri=tableUri, - resource_type=DatasetTable.__name__, - ) - else: - log.info( - f'Resource permission policy {DATASET_TABLE_READ} to table {tableUri} for group {share.groupUri} already exists. Skip... ' - ) - - @staticmethod - def attach_dataset_folder_read_permission(session, share, locationUri): - """ - Attach Folder permissions to share groups - """ - existing_policy = ResourcePolicyService.find_resource_policies( - session, - group=share.groupUri, - resource_uri=locationUri, - resource_type=DatasetStorageLocation.__name__, - permissions=DATASET_FOLDER_READ, - ) - # toDo: separate policies from list DATASET_TABLE_READ, because in future only one of them can be granted (Now they are always granted together) - if len(existing_policy) == 0: - log.info( - f'Attaching new resource permission policy {DATASET_FOLDER_READ} to folder {locationUri} for group {share.groupUri}' - ) - - ResourcePolicyService.attach_resource_policy( - session=session, - group=share.groupUri, - permissions=DATASET_FOLDER_READ, - resource_uri=locationUri, - resource_type=DatasetStorageLocation.__name__, - ) - else: - log.info( - f'Resource permission policy {DATASET_FOLDER_READ} to table {locationUri} for group {share.groupUri} already exists. Skip... 
' - ) diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_managers/lf_share_manager.py b/backend/dataall/modules/s3_datasets_shares/services/share_managers/lf_share_manager.py index 8d12fb40f..b2ddef07a 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/share_managers/lf_share_manager.py +++ b/backend/dataall/modules/s3_datasets_shares/services/share_managers/lf_share_manager.py @@ -19,7 +19,7 @@ ShareItemHealthStatus, ) from dataall.modules.s3_datasets.db.dataset_models import DatasetTable -from dataall.modules.s3_datasets_shares.services.dataset_sharing_alarm_service import DatasetSharingAlarmService +from dataall.modules.s3_datasets_shares.services.s3_share_alarm_service import S3ShareAlarmService from dataall.modules.shares_base.db.share_object_models import ShareObjectItem from dataall.modules.s3_datasets_shares.services.share_managers.share_manager_utils import ShareErrorFormatter from dataall.modules.shares_base.services.sharing_service import ShareData @@ -569,7 +569,7 @@ def handle_share_failure( f'due to: {error}' ) - DatasetSharingAlarmService().trigger_table_sharing_failure_alarm(table, self.share, self.target_environment) + S3ShareAlarmService().trigger_table_sharing_failure_alarm(table, self.share, self.target_environment) return True def handle_revoke_failure( @@ -589,9 +589,7 @@ def handle_revoke_failure( f'with target account {self.target_environment.AwsAccountId}/{self.target_environment.region} ' f'due to: {error}' ) - DatasetSharingAlarmService().trigger_revoke_table_sharing_failure_alarm( - table, self.share, self.target_environment - ) + S3ShareAlarmService().trigger_revoke_table_sharing_failure_alarm(table, self.share, self.target_environment) return True def handle_share_failure_for_all_tables(self, tables, error, share_item_status, reapply=False): diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_access_point_share_manager.py b/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_access_point_share_manager.py index e3663e3ee..8aaa9a7ea 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_access_point_share_manager.py +++ b/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_access_point_share_manager.py @@ -18,12 +18,12 @@ DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID, ) from dataall.base.aws.iam import IAM -from dataall.modules.s3_datasets_shares.services.dataset_sharing_alarm_service import DatasetSharingAlarmService +from dataall.modules.s3_datasets_shares.services.s3_share_alarm_service import S3ShareAlarmService from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository from dataall.modules.shares_base.services.share_exceptions import PrincipalRoleNotFound from dataall.modules.s3_datasets_shares.services.share_managers.share_manager_utils import ShareErrorFormatter -from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import ( - SharePolicyService, +from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import ( + S3SharePolicyService, IAM_S3_ACCESS_POINTS_STATEMENT_SID, EMPTY_STATEMENT_SID, ) @@ -159,7 +159,7 @@ def check_target_role_access_policy(self) -> None: key_alias = f'alias/{self.dataset.KmsAlias}' kms_client = KmsClient(self.dataset_account_id, self.source_environment.region) kms_key_id = kms_client.get_key_id(key_alias) - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( 
environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, region=self.target_environment.region, @@ -194,7 +194,7 @@ def check_target_role_access_policy(self) -> None: ) logger.info(f'Policy... {policy_document}') - s3_statement_index = SharePolicyService._get_statement_by_sid( + s3_statement_index = S3SharePolicyService._get_statement_by_sid( policy_document, f'{IAM_S3_ACCESS_POINTS_STATEMENT_SID}S3' ) @@ -228,7 +228,7 @@ def check_target_role_access_policy(self) -> None: ) if kms_key_id: - kms_statement_index = SharePolicyService._get_statement_by_sid( + kms_statement_index = S3SharePolicyService._get_statement_by_sid( policy_document, f'{IAM_S3_ACCESS_POINTS_STATEMENT_SID}KMS' ) kms_target_resources = [f'arn:aws:kms:{self.dataset_region}:{self.dataset_account_id}:key/{kms_key_id}'] @@ -268,7 +268,7 @@ def grant_target_role_access_policy(self): """ logger.info(f'Grant target role {self.target_requester_IAMRoleName} access policy') - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, region=self.target_environment.region, @@ -619,7 +619,7 @@ def delete_access_point(self): def revoke_target_role_access_policy(self): logger.info('Deleting target role IAM statements...') - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, region=self.target_environment.region, @@ -718,7 +718,7 @@ def handle_share_failure(self, error: Exception) -> None: f'with target account {self.target_environment.AwsAccountId}/{self.target_environment.region} ' f'due to: {error}' ) - DatasetSharingAlarmService().trigger_folder_sharing_failure_alarm( + S3ShareAlarmService().trigger_folder_sharing_failure_alarm( self.target_folder, self.share, self.target_environment ) @@ -735,7 +735,7 @@ def handle_revoke_failure(self, error: Exception) -> bool: f'with target account {self.target_environment.AwsAccountId}/{self.target_environment.region} ' f'due to: {error}' ) - DatasetSharingAlarmService().trigger_revoke_folder_sharing_failure_alarm( + S3ShareAlarmService().trigger_revoke_folder_sharing_failure_alarm( self.target_folder, self.share, self.target_environment ) return True diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_bucket_share_manager.py b/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_bucket_share_manager.py index 609b50a7b..aafc932e7 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_bucket_share_manager.py +++ b/backend/dataall/modules/s3_datasets_shares/services/share_managers/s3_bucket_share_manager.py @@ -15,9 +15,9 @@ from dataall.modules.shares_base.db.share_object_models import ShareObject from dataall.modules.shares_base.services.share_exceptions import PrincipalRoleNotFound from dataall.modules.s3_datasets_shares.services.share_managers.share_manager_utils import ShareErrorFormatter -from dataall.modules.s3_datasets_shares.services.dataset_sharing_alarm_service import DatasetSharingAlarmService -from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import ( - SharePolicyService, +from dataall.modules.s3_datasets_shares.services.s3_share_alarm_service import S3ShareAlarmService +from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import ( + 
S3SharePolicyService, IAM_S3_BUCKETS_STATEMENT_SID, EMPTY_STATEMENT_SID, ) @@ -70,7 +70,7 @@ def check_s3_iam_access(self) -> None: kms_client = KmsClient(self.source_account_id, self.source_environment.region) kms_key_id = kms_client.get_key_id(key_alias) - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, region=self.target_environment.region, @@ -98,7 +98,7 @@ def check_s3_iam_access(self) -> None: version_id, policy_document = IAM.get_managed_policy_default_version( self.target_environment.AwsAccountId, self.target_environment.region, share_resource_policy_name ) - s3_statement_index = SharePolicyService._get_statement_by_sid( + s3_statement_index = S3SharePolicyService._get_statement_by_sid( policy_document, f'{IAM_S3_BUCKETS_STATEMENT_SID}S3' ) @@ -131,7 +131,7 @@ def check_s3_iam_access(self) -> None: ) if kms_key_id: - kms_statement_index = SharePolicyService._get_statement_by_sid( + kms_statement_index = S3SharePolicyService._get_statement_by_sid( policy_document, f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS' ) kms_target_resources = [f'arn:aws:kms:{self.bucket_region}:{self.source_account_id}:key/{kms_key_id}'] @@ -172,7 +172,7 @@ def grant_s3_iam_access(self): """ logger.info(f'Grant target role {self.target_requester_IAMRoleName} access policy') - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, region=self.target_environment.region, @@ -484,7 +484,7 @@ def delete_target_role_access_policy( ): logger.info('Deleting target role IAM statements...') - share_policy_service = SharePolicyService( + share_policy_service = S3SharePolicyService( role_name=share.principalIAMRoleName, account=target_environment.AwsAccountId, region=self.target_environment.region, @@ -574,7 +574,7 @@ def handle_share_failure(self, error: Exception) -> bool: f'with target account {self.target_environment.AwsAccountId}/{self.target_environment.region} ' f'due to: {error}' ) - DatasetSharingAlarmService().trigger_s3_bucket_sharing_failure_alarm( + S3ShareAlarmService().trigger_s3_bucket_sharing_failure_alarm( self.target_bucket, self.share, self.target_environment ) return True @@ -592,7 +592,7 @@ def handle_revoke_failure(self, error: Exception) -> bool: f'with target account {self.target_environment.AwsAccountId}/{self.target_environment.region} ' f'due to: {error}' ) - DatasetSharingAlarmService().trigger_revoke_s3_bucket_sharing_failure_alarm( + S3ShareAlarmService().trigger_revoke_s3_bucket_sharing_failure_alarm( self.target_bucket, self.share, self.target_environment ) return True diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_processors/glue_table_share_processor.py b/backend/dataall/modules/s3_datasets_shares/services/share_processors/glue_table_share_processor.py index 8c5bc0f42..3091f7041 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/share_processors/glue_table_share_processor.py +++ b/backend/dataall/modules/s3_datasets_shares/services/share_processors/glue_table_share_processor.py @@ -16,10 +16,10 @@ from dataall.modules.s3_datasets_shares.services.share_managers import LFShareManager from dataall.modules.s3_datasets_shares.aws.ram_client import RamClient from dataall.modules.shares_base.services.share_object_service import ShareObjectService -from 
dataall.modules.s3_datasets_shares.services.share_item_service import S3ShareItemService +from dataall.modules.s3_datasets_shares.services.s3_share_service import S3ShareService from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository from dataall.modules.shares_base.db.share_state_machines_repositories import ShareStatusRepository -from dataall.modules.s3_datasets_shares.db.share_object_repositories import S3ShareObjectRepository +from dataall.modules.s3_datasets_shares.db.s3_share_object_repositories import S3ShareObjectRepository from dataall.modules.shares_base.db.share_object_state_machines import ShareItemSM from dataall.modules.s3_datasets_shares.services.share_managers.share_manager_utils import ShareErrorFormatter @@ -154,7 +154,7 @@ def process_approved_shares(self) -> bool: manager.grant_principals_permissions_to_resource_link_table(table) log.info('Attaching TABLE READ permissions...') - S3ShareItemService.attach_dataset_table_read_permission( + S3ShareService.attach_dataset_table_read_permission( self.session, self.share_data.share, table.tableUri ) @@ -276,7 +276,7 @@ def process_revoked_shares(self) -> bool: and self.share_data.share.groupUri != self.share_data.dataset.stewards ): log.info('Deleting TABLE READ permissions...') - S3ShareItemService.delete_dataset_table_read_permission( + S3ShareService.delete_dataset_table_read_permission( self.session, self.share_data.share, table.tableUri ) diff --git a/backend/dataall/modules/s3_datasets_shares/services/share_processors/s3_access_point_share_processor.py b/backend/dataall/modules/s3_datasets_shares/services/share_processors/s3_access_point_share_processor.py index ab464f4dc..87ec4f6c0 100644 --- a/backend/dataall/modules/s3_datasets_shares/services/share_processors/s3_access_point_share_processor.py +++ b/backend/dataall/modules/s3_datasets_shares/services/share_processors/s3_access_point_share_processor.py @@ -5,7 +5,7 @@ from dataall.modules.shares_base.services.share_exceptions import PrincipalRoleNotFound from dataall.modules.s3_datasets_shares.services.share_managers import S3AccessPointShareManager from dataall.modules.shares_base.services.share_object_service import ShareObjectService -from dataall.modules.s3_datasets_shares.services.share_item_service import S3ShareItemService +from dataall.modules.s3_datasets_shares.services.s3_share_service import S3ShareService from dataall.modules.shares_base.services.shares_enums import ( ShareItemHealthStatus, ShareItemStatus, @@ -76,7 +76,7 @@ def process_approved_shares(self) -> bool: manager.update_dataset_bucket_key_policy() log.info('Attaching FOLDER READ permissions...') - S3ShareItemService.attach_dataset_folder_read_permission( + S3ShareService.attach_dataset_folder_read_permission( self.session, self.share_data.share, folder.locationUri ) @@ -145,7 +145,7 @@ def process_revoked_shares(self) -> bool: and self.share_data.share.groupUri != self.share_data.dataset.stewards ): log.info(f'Deleting FOLDER READ permissions from {folder.locationUri}...') - S3ShareItemService.delete_dataset_folder_read_permission( + S3ShareService.delete_dataset_folder_read_permission( self.session, manager.share, folder.locationUri ) diff --git a/backend/dataall/modules/s3_datasets_shares/tasks/dataset_subscription_task.py b/backend/dataall/modules/s3_datasets_shares/tasks/dataset_subscription_task.py index e382b05ef..e5a29c904 100644 --- a/backend/dataall/modules/s3_datasets_shares/tasks/dataset_subscription_task.py +++ 
b/backend/dataall/modules/s3_datasets_shares/tasks/dataset_subscription_task.py @@ -10,7 +10,7 @@ from dataall.core.environment.services.environment_service import EnvironmentService from dataall.base.db import get_engine from dataall.modules.shares_base.db.share_object_models import ShareObjectItem -from dataall.modules.s3_datasets_shares.db.share_object_repositories import S3ShareObjectRepository +from dataall.modules.s3_datasets_shares.db.s3_share_object_repositories import S3ShareObjectRepository from dataall.modules.shares_base.services.share_notification_service import ShareNotificationService from dataall.modules.s3_datasets.aws.sns_dataset_client import SnsDatasetClient from dataall.modules.s3_datasets.db.dataset_location_repositories import DatasetLocationRepository diff --git a/backend/dataall/modules/shares_base/db/share_object_repositories.py b/backend/dataall/modules/shares_base/db/share_object_repositories.py index 0a9e2eafb..596a5771a 100644 --- a/backend/dataall/modules/shares_base/db/share_object_repositories.py +++ b/backend/dataall/modules/shares_base/db/share_object_repositories.py @@ -4,11 +4,14 @@ from typing import List from dataall.base.db import exceptions, paginate +from dataall.base.db.paginator import Page from dataall.core.organizations.db.organization_models import Organization from dataall.core.environment.db.environment_models import Environment, EnvironmentGroup from dataall.modules.datasets_base.db.dataset_models import DatasetBase from dataall.modules.datasets_base.db.dataset_repositories import DatasetBaseRepository +from dataall.modules.notifications.db.notification_models import Notification from dataall.modules.shares_base.db.share_object_models import ShareObjectItem, ShareObject + from dataall.modules.shares_base.services.shares_enums import ( ShareItemHealthStatus, PrincipalType, @@ -349,7 +352,9 @@ def list_shareable_items_of_type(session, share, type, share_type_model, share_t @staticmethod def paginated_list_shareable_items(session, subqueries: List[Query], data: dict = None): - if len(subqueries) == 1: + if len(subqueries) == 0: + return Page([], 1, 1, 0) # empty page. 
All modules are turned off
+        elif len(subqueries) == 1:
             shareable_objects = subqueries[0].subquery('shareable_objects')
         else:
             shareable_objects = subqueries[0].union(*subqueries[1:]).subquery('shareable_objects')
@@ -377,3 +382,32 @@ def paginated_list_shareable_items(session, subqueries: List[Query], data: dict
         return paginate(
             query.order_by(shareable_objects.c.itemName).distinct(), data.get('page', 1), data.get('pageSize', 10)
         ).to_dict()
+
+    @staticmethod
+    def list_active_share_object_for_dataset(session, dataset_uri):
+        share_objects = (
+            session.query(ShareObject)
+            .filter(and_(ShareObject.datasetUri == dataset_uri, ShareObject.deleted.is_(None)))
+            .all()
+        )
+        return share_objects
+
+    @staticmethod
+    def fetch_submitted_shares_with_notifications(session):
+        """
+        A method used by the scheduled ECS Task to fetch all share objects in data.all that are still in
+        'Submitted' state and have a pending 'SHARE_OBJECT_SUBMITTED' notification
+        """
+        pending_shares = (
+            session.query(ShareObject)
+            .join(
+                Notification,
+                and_(
+                    ShareObject.shareUri == func.split_part(Notification.target_uri, '|', 1),
+                    ShareObject.datasetUri == func.split_part(Notification.target_uri, '|', 2),
+                ),
+            )
+            .filter(and_(Notification.type == 'SHARE_OBJECT_SUBMITTED', ShareObject.status == 'Submitted'))
+            .all()
+        )
+        return pending_shares
diff --git a/backend/dataall/modules/shares_base/handlers/ecs_share_handler.py b/backend/dataall/modules/shares_base/handlers/ecs_share_handler.py
index 0bf633a66..b0308dd47 100644
--- a/backend/dataall/modules/shares_base/handlers/ecs_share_handler.py
+++ b/backend/dataall/modules/shares_base/handlers/ecs_share_handler.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
@@ -5,6 +6,7 @@
 from dataall.core.stacks.aws.ecs import Ecs
 from dataall.core.tasks.db.task_models import Task
 from dataall.modules.shares_base.services.sharing_service import SharingService
+from dataall.modules.shares_base.tasks.share_reapplier_task import EcsBulkShareRepplyService

 log = logging.getLogger(__name__)
@@ -30,21 +32,43 @@ def verify_share(engine, task: Task):
     def reapply_share(engine, task: Task):
         return EcsShareHandler._manage_share(engine, task, SharingService.reapply_share, 'reapply_share')

+    @staticmethod
+    @Worker.handler(path='ecs.dataset.share.reapply')
+    def reapply_shares_of_dataset(engine, task: Task):
+        envname = os.environ.get('envname', 'local')
+        if envname in ['local', 'dkrcompose']:
+            EcsBulkShareRepplyService.process_reapply_shares_for_dataset(engine, task.targetUri)
+        else:
+            context = [
+                {'name': 'datasetUri', 'value': task.targetUri},
+            ]
+            return EcsShareHandler._run_share_management_ecs_task(
+                task_definition_param_str='ecs/task_def_arn/share_reapplier',
+                container_name_param_str='ecs/container/share_reapplier',
+                context=context,
+            )
+
     @staticmethod
     def _manage_share(engine, task: Task, local_handler, ecs_handler: str):
         envname = os.environ.get('envname', 'local')
         if envname in ['local', 'dkrcompose']:
             return local_handler(engine, task.targetUri)
         else:
-            return EcsShareHandler._run_share_management_ecs_task(share_uri=task.targetUri, handler=ecs_handler)
+            share_management_context = [
+                {'name': 'shareUri', 'value': task.targetUri},
+                {'name': 'handler', 'value': ecs_handler},
+            ]
+            return EcsShareHandler._run_share_management_ecs_task(
+                task_definition_param_str='ecs/task_def_arn/share_management',
+                container_name_param_str='ecs/container/share_management',
+                context=share_management_context,
+            )

     @staticmethod
-    def 
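The `func.split_part` join above implies that share notifications are keyed as `target_uri = f'{shareUri}|{datasetUri}'`. A minimal sketch of that assumed contract; the helper names below are illustrative, not part of the codebase:

```python
# Assumed contract: share notifications store target_uri = f'{shareUri}|{datasetUri}',
# which is what split_part(Notification.target_uri, '|', 1) and (..., '|', 2) decompose above.

def encode_notification_target_uri(share_uri: str, dataset_uri: str) -> str:
    return f'{share_uri}|{dataset_uri}'

def decode_notification_target_uri(target_uri: str) -> tuple:
    # Python-side equivalent of the two split_part() calls in the join condition
    share_uri, _, dataset_uri = target_uri.partition('|')
    return share_uri, dataset_uri

assert decode_notification_target_uri(encode_notification_target_uri('shareUri123', 'datasetUri456')) == (
    'shareUri123',
    'datasetUri456',
)
```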
_run_share_management_ecs_task(share_uri, handler): - return Ecs.run_ecs_task( - task_definition_param='ecs/task_def_arn/share_management', - container_name_param='ecs/container/share_management', - context=[ - {'name': 'shareUri', 'value': share_uri}, - {'name': 'handler', 'value': handler}, - ], + def _run_share_management_ecs_task(task_definition_param_str, container_name_param_str, context): + ecs_task_arn = Ecs.run_ecs_task( + task_definition_param=task_definition_param_str, + container_name_param=container_name_param_str, + context=context, ) + return {'task_arn': ecs_task_arn} diff --git a/backend/dataall/modules/shares_base/services/share_item_service.py b/backend/dataall/modules/shares_base/services/share_item_service.py index 5ad7c1472..7237e4f4d 100644 --- a/backend/dataall/modules/shares_base/services/share_item_service.py +++ b/backend/dataall/modules/shares_base/services/share_item_service.py @@ -3,12 +3,10 @@ from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService from dataall.core.tasks.service_handlers import Worker from dataall.base.context import get_context -from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.tasks.db.task_models import Task from dataall.base.db.exceptions import ObjectNotFound, UnauthorizedOperation from dataall.modules.shares_base.services.shares_enums import ( ShareObjectActions, - ShareableType, ShareItemStatus, ShareItemActions, ShareItemHealthStatus, diff --git a/backend/dataall/modules/shares_base/services/share_notification_service.py b/backend/dataall/modules/shares_base/services/share_notification_service.py index 765138af9..e5b66d383 100644 --- a/backend/dataall/modules/shares_base/services/share_notification_service.py +++ b/backend/dataall/modules/shares_base/services/share_notification_service.py @@ -9,6 +9,7 @@ from dataall.base.context import get_context from dataall.modules.shares_base.services.shares_enums import ShareObjectStatus from dataall.modules.notifications.db.notification_repositories import NotificationRepository +from dataall.modules.notifications.services.ses_email_notification_service import SESEmailNotificationService from dataall.modules.datasets_base.db.dataset_models import DatasetBase log = logging.getLogger(__name__) @@ -56,11 +57,47 @@ def notify_share_object_submission(self, email_id: str): self._create_notification_task(subject=subject, msg=email_notification_msg) return notifications + def notify_persistent_email_reminder(self, email_id: str): + share_link_text = '' + if os.environ.get('frontend_domain_url'): + share_link_text = ( + f'
Please visit data.all share link ' + f'to review and take appropriate action or view more details.' + ) + + msg_intro = f"""Dear User, + This is a reminder that a share request for the dataset "{self.dataset.label}" submitted by {email_id} + on behalf of principal "{self.share.principalId}" is still pending and has not been addressed. + """ + + msg_end = """Your prompt attention to this matter is greatly appreciated. + Best regards, + The Data.all Team + """ + + subject = f'URGENT REMINDER: Data.all | Action Required on Pending Share Request for {self.dataset.label}' + email_notification_msg = msg_intro + share_link_text + msg_end + + notifications = self.register_notifications( + notification_type=DataSharingNotificationType.SHARE_OBJECT_SUBMITTED.value, msg=msg_intro + ) + + self._create_persistent_reminder_notification_task(subject=subject, msg=email_notification_msg) + return notifications + def notify_share_object_approval(self, email_id: str): share_link_text = '' if os.environ.get('frontend_domain_url'): - share_link_text = f'
Please visit data.all share link to take action or view more details' - msg = f'User {email_id} APPROVED share request for dataset {self.dataset.label} for principal {self.share.principalId}' + share_link_text = ( + f'
Please visit data.all share link ' + f'to take action or view more details' + ) + msg = ( + f'User {email_id} APPROVED share request for dataset {self.dataset.label} ' + f'for principal {self.share.principalId}' + ) subject = f'Data.all | Share Request Approved for {self.dataset.label}' email_notification_msg = msg + share_link_text @@ -167,3 +204,42 @@ def _create_notification_task(self, subject, msg): log.info(f'Notification type : {share_notification_config_type} is not active') else: log.info('Notifications are not active') + + def _create_persistent_reminder_notification_task(self, subject, msg): + """ + At the moment just for notification_config_type = email, but designed for additional notification types + Emails sent to: + - dataset.SamlAdminGroupName + - dataset.stewards + """ + share_notification_config = config.get_property( + 'modules.datasets_base.features.share_notifications', default=None + ) + if share_notification_config: + for share_notification_config_type in share_notification_config.keys(): + n_config = share_notification_config[share_notification_config_type] + if n_config.get('active', False) == True: + notification_recipient_groups_list = [self.dataset.SamlAdminGroupName, self.dataset.stewards] + + if share_notification_config_type == 'email': + notification_task: Task = Task( + action='notification.service', + targetUri=self.share.shareUri, + payload={ + 'notificationType': share_notification_config_type, + 'subject': subject, + 'message': msg, + 'recipientGroupsList': notification_recipient_groups_list, + 'recipientEmailList': [], + }, + ) + self.session.add(notification_task) + self.session.commit() + + SESEmailNotificationService.send_email_task( + subject, msg, notification_recipient_groups_list, [] + ) + else: + log.info(f'Notification type : {share_notification_config_type} is not active') + else: + log.info('Notifications are not active') diff --git a/backend/dataall/modules/shares_base/services/sharing_service.py b/backend/dataall/modules/shares_base/services/sharing_service.py index 552732be2..ce35623e7 100644 --- a/backend/dataall/modules/shares_base/services/sharing_service.py +++ b/backend/dataall/modules/shares_base/services/sharing_service.py @@ -1,9 +1,7 @@ import logging -from typing import Any, List, Tuple from dataclasses import dataclass -from time import sleep -from sqlalchemy.orm import Session +from typing import Any from dataall.core.resource_lock.db.resource_lock_repositories import ResourceLockRepository from dataall.base.db import Engine from dataall.core.environment.db.environment_models import ConsumptionRole, Environment, EnvironmentGroup @@ -24,15 +22,12 @@ from dataall.modules.shares_base.services.share_processor_manager import ShareProcessorManager from dataall.modules.shares_base.services.share_object_service import ( ShareObjectService, -) # TODO move to shares_base in following PR +) from dataall.modules.shares_base.services.share_exceptions import PrincipalRoleNotFound from dataall.base.db.exceptions import ResourceLockTimeout log = logging.getLogger(__name__) -MAX_RETRIES = 10 -RETRY_INTERVAL = 60 - @dataclass class ShareData: @@ -94,44 +89,45 @@ def approve_share(cls, engine: Engine, share_uri: str) -> bool: f'Principal role {share_data.share.principalIAMRoleName} is not found.', ) - if not cls.acquire_lock_with_retry( - resources, session, share_data.share.shareUri, share_data.share.__tablename__ + with ResourceLockRepository.acquire_lock_with_retry( + resources=resources, + session=session, + 
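The persistent-reminder task above is gated on the `share_notifications` config block, which config.json (later in this diff) extends with a `persistent_reminders` flag. A sketch of that gating logic against the assumed config shape, with values flipped on for illustration:

```python
# Stand-in for config.get_property('modules.datasets_base.features.share_notifications');
# the dict mirrors the config.json shape in this diff.
share_notification_config = {
    'email': {'active': True, 'persistent_reminders': True, 'parameters': {'group_notifications': True}},
}

for config_type, n_config in (share_notification_config or {}).items():
    if n_config.get('active', False):  # truthiness check; equivalent to the `== True` comparison above
        if config_type == 'email':
            # recipients mirror the service: dataset admin group plus stewards (placeholder values here)
            recipient_groups = ['SamlAdminGroupName', 'stewards']
            print(f'would queue a {config_type} reminder for groups {recipient_groups}')
    else:
        print(f'Notification type : {config_type} is not active')
```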
acquired_by_uri=share_data.share.shareUri, + acquired_by_type=share_data.share.__tablename__, ): - raise ResourceLockTimeout( - 'process approved shares', - f'Failed to acquire lock for dataset {share_data.dataset.datasetUri}', - ) - for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): - try: - log.info(f'Granting permissions of {type.value}') - shareable_items = ShareObjectRepository.get_share_data_items_by_type( - session, - share_data.share, - processor.shareable_type, - processor.shareable_uri, - status=ShareItemStatus.Share_Approved.value, - ) - success = processor.Processor(session, share_data, shareable_items).process_approved_shares() - log.info(f'Sharing {type.value} succeeded = {success}') - if not success: + for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): + try: + log.info(f'Granting permissions of {type.value}') + shareable_items = ShareObjectRepository.get_share_data_items_by_type( + session, + share_data.share, + processor.shareable_type, + processor.shareable_uri, + status=ShareItemStatus.Share_Approved.value, + ) + success = processor.Processor( + session, share_data, shareable_items + ).process_approved_shares() + log.info(f'Sharing {type.value} succeeded = {success}') + if not success: + share_successful = False + except Exception as e: + log.error(f'Error occurred during sharing of {type.value}: {e}') + ShareStatusRepository.update_share_item_status_batch( + session, + share_uri, + old_status=ShareItemStatus.Share_Approved.value, + new_status=ShareItemStatus.Share_Failed.value, + share_item_type=processor.type.value, + ) + ShareStatusRepository.update_share_item_status_batch( + session, + share_uri, + old_status=ShareItemStatus.Share_In_Progress.value, + new_status=ShareItemStatus.Share_Failed.value, + share_item_type=processor.type.value, + ) share_successful = False - except Exception as e: - log.error(f'Error occurred during sharing of {type.value}: {e}') - ShareStatusRepository.update_share_item_status_batch( - session, - share_uri, - old_status=ShareItemStatus.Share_Approved.value, - new_status=ShareItemStatus.Share_Failed.value, - share_item_type=processor.type.value, - ) - ShareStatusRepository.update_share_item_status_batch( - session, - share_uri, - old_status=ShareItemStatus.Share_In_Progress.value, - new_status=ShareItemStatus.Share_Failed.value, - share_item_type=processor.type.value, - ) - share_successful = False return share_successful except Exception as e: @@ -143,13 +139,6 @@ def approve_share(cls, engine: Engine, share_uri: str) -> bool: finally: new_share_state = share_object_sm.run_transition(ShareObjectActions.Finish.value) share_object_sm.update_state(session, share_data.share, new_share_state) - for resource in resources: - if not ResourceLockRepository.release_lock( - session, resource[0], resource[1], share_data.share.shareUri - ): - log.error( - f'Failed to release lock for resource: resource_uri={resource[0]}, resource_type={resource[1]}' - ) @classmethod def revoke_share(cls, engine: Engine, share_uri: str) -> bool: @@ -200,45 +189,43 @@ def revoke_share(cls, engine: Engine, share_uri: str) -> bool: f'Principal role {share_data.share.principalIAMRoleName} is not found.', ) - if not cls.acquire_lock_with_retry( - resources, session, share_data.share.shareUri, share_data.share.__tablename__ + with ResourceLockRepository.acquire_lock_with_retry( + resources=resources, + session=session, + acquired_by_uri=share_data.share.shareUri, + acquired_by_type=share_data.share.__tablename__, ): - raise 
ResourceLockTimeout( - 'process revoked shares', - f'Failed to acquire lock for dataset {share_data.dataset.datasetUri}', - ) - - for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): - try: - shareable_items = ShareObjectRepository.get_share_data_items_by_type( - session, - share_data.share, - processor.shareable_type, - processor.shareable_uri, - status=ShareItemStatus.Revoke_Approved.value, - ) - log.info(f'Revoking permissions with {type.value}') - success = processor.Processor(session, share_data, shareable_items).process_revoked_shares() - log.info(f'Revoking {type.value} succeeded = {success}') - if not success: + for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): + try: + shareable_items = ShareObjectRepository.get_share_data_items_by_type( + session, + share_data.share, + processor.shareable_type, + processor.shareable_uri, + status=ShareItemStatus.Revoke_Approved.value, + ) + log.info(f'Revoking permissions with {type.value}') + success = processor.Processor(session, share_data, shareable_items).process_revoked_shares() + log.info(f'Revoking {type.value} succeeded = {success}') + if not success: + revoke_successful = False + except Exception as e: + log.error(f'Error occurred during share revoking of {type.value}: {e}') + ShareStatusRepository.update_share_item_status_batch( + session, + share_uri, + old_status=ShareItemStatus.Revoke_Approved.value, + new_status=ShareItemStatus.Revoke_Failed.value, + share_item_type=processor.type.value, + ) + ShareStatusRepository.update_share_item_status_batch( + session, + share_uri, + old_status=ShareItemStatus.Revoke_In_Progress.value, + new_status=ShareItemStatus.Revoke_Failed.value, + share_item_type=processor.type.value, + ) revoke_successful = False - except Exception as e: - log.error(f'Error occurred during share revoking of {type.value}: {e}') - ShareStatusRepository.update_share_item_status_batch( - session, - share_uri, - old_status=ShareItemStatus.Revoke_Approved.value, - new_status=ShareItemStatus.Revoke_Failed.value, - share_item_type=processor.type.value, - ) - ShareStatusRepository.update_share_item_status_batch( - session, - share_uri, - old_status=ShareItemStatus.Revoke_In_Progress.value, - new_status=ShareItemStatus.Revoke_Failed.value, - share_item_type=processor.type.value, - ) - revoke_successful = False return revoke_successful except Exception as e: @@ -255,14 +242,6 @@ def revoke_share(cls, engine: Engine, share_uri: str) -> bool: new_share_state = share_sm.run_transition(ShareObjectActions.Finish.value) share_sm.update_state(session, share_data.share, new_share_state) - for resource in resources: - if not ResourceLockRepository.release_lock( - session, resource[0], resource[1], share_data.share.shareUri - ): - log.error( - f'Failed to release lock for resource: resource_uri={resource[0]}, resource_type={resource[1]}' - ) - @classmethod def verify_share( cls, @@ -356,57 +335,47 @@ def reapply_share(cls, engine: Engine, share_uri: str) -> bool: log.error(f'Failed to get Principal IAM Role {share_data.share.principalIAMRoleName}, exiting...') return False else: - lock_acquired = cls.acquire_lock_with_retry( - resources, session, share_data.share.shareUri, share_data.share.__tablename__ - ) + with ResourceLockRepository.acquire_lock_with_retry( + resources=resources, + session=session, + acquired_by_uri=share_data.share.shareUri, + acquired_by_type=share_data.share.__tablename__, + ): + for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): + try: + 
log.info(f'Reapplying permissions to {type.value}') + shareable_items = ShareObjectRepository.get_share_data_items_by_type( + session, + share_data.share, + processor.shareable_type, + processor.shareable_uri, + None, + ShareItemHealthStatus.PendingReApply.value, + ) + success = processor.Processor( + session, share_data, shareable_items + ).process_approved_shares() + log.info(f'Reapplying {type.value} succeeded = {success}') + if not success: + reapply_successful = False + except Exception as e: + log.error(f'Error occurred during share reapplying of {type.value}: {e}') - if not lock_acquired: - log.error(f'Failed to acquire lock for dataset {share_data.dataset.datasetUri}, exiting...') - error_message = f'SHARING PROCESS TIMEOUT: Failed to acquire lock for dataset {share_data.dataset.datasetUri}' - ShareStatusRepository.update_share_item_health_status_batch( - session, - share_uri, - old_status=ShareItemHealthStatus.PendingReApply.value, - new_status=ShareItemHealthStatus.Unhealthy.value, - message=error_message, - ) - return False + return reapply_successful - for type, processor in ShareProcessorManager.SHARING_PROCESSORS.items(): - try: - log.info(f'Reapplying permissions to {type.value}') - shareable_items = ShareObjectRepository.get_share_data_items_by_type( - session, - share_data.share, - processor.shareable_type, - processor.shareable_uri, - None, - ShareItemHealthStatus.PendingReApply.value, - ) - success = processor.Processor( - session, share_data, shareable_items - ).process_approved_shares() - log.info(f'Reapplying {type.value} succeeded = {success}') - if not success: - reapply_successful = False - except Exception as e: - log.error(f'Error occurred during share reapplying of {type.value}: {e}') + except ResourceLockTimeout as e: + ShareStatusRepository.update_share_item_health_status_batch( + session, + share_uri, + old_status=ShareItemHealthStatus.PendingReApply.value, + new_status=ShareItemHealthStatus.Unhealthy.value, + message=str(e), + ) - return reapply_successful except Exception as e: log.error(f'Error occurred during share approval: {e}') return False - finally: - with engine.scoped_session() as session: - for resource in resources: - if not ResourceLockRepository.release_lock( - session, resource[0], resource[1], share_data.share.shareUri - ): - log.error( - f'Failed to release lock for resource: resource_uri={resource[0]}, resource_type={resource[1]}' - ) - @staticmethod def _get_share_data_and_items(session, share_uri, status, healthStatus=None): data = ShareObjectRepository.get_share_data(session, share_uri) @@ -422,28 +391,3 @@ def _get_share_data_and_items(session, share_uri, status, healthStatus=None): session=session, share_uri=share_uri, status=[status], healthStatus=[healthStatus] ) return share_data, share_items - - @staticmethod - def acquire_lock_with_retry( - resources: List[Tuple[str, str]], session: Session, acquired_by_uri: str, acquired_by_type: str - ): - for attempt in range(MAX_RETRIES): - try: - log.info(f'Attempting to acquire lock for resources {resources} by share {acquired_by_uri}...') - lock_acquired = ResourceLockRepository.acquire_locks( - resources, session, acquired_by_uri, acquired_by_type - ) - if lock_acquired: - return True - - log.info( - f'Lock for one or more resources {resources} already acquired. Retrying in {RETRY_INTERVAL} seconds...' - ) - sleep(RETRY_INTERVAL) - - except Exception as e: - log.error('Error occurred while retrying acquiring lock:', e) - return False - - log.info(f'Max retries reached. 
Failed to acquire lock for one or more resources {resources}')
-        return False
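The retry loop removed above now lives behind `ResourceLockRepository.acquire_lock_with_retry`, which the three sharing flows consume as a context manager. A plausible shape for it — a sketch assuming it wraps the repository's private `_acquire_locks`/`_release_lock` helpers, retries on contention, and always releases on exit; the real implementation may differ:

```python
from contextlib import contextmanager
from time import sleep

from dataall.core.resource_lock.db.resource_lock_repositories import ResourceLockRepository
from dataall.base.db.exceptions import ResourceLockTimeout

MAX_RETRIES = 10    # assumed defaults, mirroring the constants dropped from sharing_service
RETRY_INTERVAL = 60

@contextmanager
def acquire_lock_with_retry(resources, session, acquired_by_uri, acquired_by_type):
    for _ in range(MAX_RETRIES):
        # _acquire_locks creates one ResourceLock row per resource, or returns False on contention
        if ResourceLockRepository._acquire_locks(resources, session, acquired_by_uri, acquired_by_type):
            try:
                yield True
            finally:
                # release unconditionally, replacing the per-resource release loops removed above
                for resource_uri, resource_type in resources:
                    ResourceLockRepository._release_lock(session, resource_uri, resource_type, acquired_by_uri)
            return
        sleep(RETRY_INTERVAL)
    raise ResourceLockTimeout('acquire lock', f'Failed to acquire lock for one or more resources {resources}')
```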
diff --git a/backend/dataall/modules/shares_base/tasks/persistent_email_reminders_task.py b/backend/dataall/modules/shares_base/tasks/persistent_email_reminders_task.py
new file mode 100644
index 000000000..67a77145f
--- /dev/null
+++ b/backend/dataall/modules/shares_base/tasks/persistent_email_reminders_task.py
@@ -0,0 +1,44 @@
+import logging
+import os
+import sys
+from dataall.modules.shares_base.db.share_object_models import ShareObject
+from dataall.base.db import get_engine
+from dataall.base.aws.sqs import SqsQueue
+from dataall.core.tasks.service_handlers import Worker
+from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository
+from dataall.modules.shares_base.services.share_notification_service import ShareNotificationService
+from dataall.modules.datasets_base.db.dataset_repositories import DatasetBaseRepository
+
+
+root = logging.getLogger()
+root.setLevel(logging.INFO)
+if not root.hasHandlers():
+    root.addHandler(logging.StreamHandler(sys.stdout))
+log = logging.getLogger(__name__)
+
+
+def persistent_email_reminders(engine):
+    """
+    A method used by the scheduled ECS Task to find ALL share objects in data.all that are still pending
+    and send a reminder email for each of them.
+    """
+    with engine.scoped_session() as session:
+        log.info('Running Persistent Email Reminders Task')
+        pending_shares = ShareObjectRepository.fetch_submitted_shares_with_notifications(session=session)
+        log.info(f'Found {len(pending_shares)} pending shares')
+        pending_share: ShareObject
+        for pending_share in pending_shares:
+            log.info(f'Sending Email Reminder for Share: {pending_share.shareUri}')
+            share = ShareObjectRepository.get_share_by_uri(session, pending_share.shareUri)
+            dataset = DatasetBaseRepository.get_dataset_by_uri(session, share.datasetUri)
+            ShareNotificationService(session=session, dataset=dataset, share=share).notify_persistent_email_reminder(
+                email_id=share.owner
+            )
+            log.info(f'Email reminder sent for share {share.shareUri}')
+        log.info('Completed Persistent Email Reminders Task')
+
+
+if __name__ == '__main__':
+    ENVNAME = os.environ.get('envname', 'local')
+    ENGINE = get_engine(envname=ENVNAME)
+    persistent_email_reminders(engine=ENGINE)
diff --git a/backend/dataall/modules/shares_base/tasks/share_reapplier_task.py b/backend/dataall/modules/shares_base/tasks/share_reapplier_task.py
index 872095bce..053d53dd2 100644
--- a/backend/dataall/modules/shares_base/tasks/share_reapplier_task.py
+++ b/backend/dataall/modules/shares_base/tasks/share_reapplier_task.py
@@ -18,32 +18,68 @@
 log = logging.getLogger(__name__)

-def reapply_shares(engine):
+class EcsBulkShareRepplyService:
+    @classmethod
+    def process_reapply_shares_for_dataset(cls, engine, dataset_uri):
+        with engine.scoped_session() as session:
+            processed_share_objects = []
+            share_objects_for_dataset = ShareObjectRepository.list_active_share_object_for_dataset(
+                session=session, dataset_uri=dataset_uri
+            )
+            log.info(f'Found {len(share_objects_for_dataset)} active share objects on dataset with uri: {dataset_uri}')
+            share_object: ShareObject
+            for share_object in share_objects_for_dataset:
+                log.info(
+                    f'Re-applying Share Items for Share Object (Share URI: {share_object.shareUri}) with Requestor: {share_object.principalId} on Target Dataset: {share_object.datasetUri}'
+                )
+                processed_share_objects.append(share_object.shareUri)
+                ShareStatusRepository.update_share_item_health_status_batch(
+                    session=session,
+                    share_uri=share_object.shareUri,
+                    old_status=ShareItemHealthStatus.Unhealthy.value,
+                    new_status=ShareItemHealthStatus.PendingReApply.value,
+                )
+                SharingService.reapply_share(engine, share_uri=share_object.shareUri)
+            return processed_share_objects
+
+    @classmethod
+    def process_reapply_shares(cls, engine):
+        with engine.scoped_session() as session:
+            processed_share_objects = []
+            all_share_objects: [ShareObject] = ShareObjectRepository.list_all_active_share_objects(session)
+            log.info(f'Found {len(all_share_objects)} share objects')
+            share_object: ShareObject
+            for share_object in all_share_objects:
+                log.info(
+                    f'Re-applying Share Items for Share Object with Requestor: {share_object.principalId} on Target Dataset: {share_object.datasetUri}'
+                )
+                processed_share_objects.append(share_object.shareUri)
+                ShareStatusRepository.update_share_item_health_status_batch(
+                    session=session,
+                    share_uri=share_object.shareUri,
+                    old_status=ShareItemHealthStatus.Unhealthy.value,
+                    new_status=ShareItemHealthStatus.PendingReApply.value,
+                )
+                SharingService.reapply_share(engine, share_uri=share_object.shareUri)
+            return processed_share_objects
+
+
+def reapply_shares(engine, dataset_uri):
     """
     A method used by the scheduled ECS Task to re-apply_share() on all data.all active shares
+    If dataset_uri is provided, this ECS task reapplies shares on all unhealthy share items belonging to that dataset;
+    otherwise it reapplies on all active unhealthy shares in data.all.
     """
-    with engine.scoped_session() as session:
-        processed_share_objects = []
-        all_share_objects: [ShareObject] = ShareObjectRepository.list_all_active_share_objects(session)
-        log.info(f'Found {len(all_share_objects)} share objects ')
-        share_object: ShareObject
-        for share_object in all_share_objects:
-            log.info(
-                f'Re-applying Share Items for Share Object with Requestor: {share_object.principalId} on Target Dataset: {share_object.datasetUri}'
-            )
-            processed_share_objects.append(share_object.shareUri)
-            ShareStatusRepository.update_share_item_health_status_batch(
-                session=session,
-                share_uri=share_object.shareUri,
-                old_status=ShareItemHealthStatus.Unhealthy.value,
-                new_status=ShareItemHealthStatus.PendingReApply.value,
-            )
-            SharingService.reapply_share(engine, share_uri=share_object.shareUri)
-        return processed_share_objects
+    if dataset_uri:
+        return EcsBulkShareRepplyService.process_reapply_shares_for_dataset(engine, dataset_uri)
+    else:
+        return EcsBulkShareRepplyService.process_reapply_shares(engine)

 if __name__ == '__main__':
     load_modules(modes={ImportMode.SHARES_TASK})
     ENVNAME = os.environ.get('envname', 'local')
     ENGINE = get_engine(envname=ENVNAME)
-    reapply_shares(engine=ENGINE)
+    dataset_uri = os.environ.get('datasetUri', '')
+    processed_shares = reapply_shares(engine=ENGINE, dataset_uri=dataset_uri)
+    log.info(f'Finished processing {len(processed_shares)} shares')
diff --git a/backend/migrations/versions/328e35e39e1e_invite_env_groups_as_readers.py b/backend/migrations/versions/328e35e39e1e_invite_env_groups_as_readers.py
index 008abdf7c..5f243a8ff 100644
--- a/backend/migrations/versions/328e35e39e1e_invite_env_groups_as_readers.py
+++ b/backend/migrations/versions/328e35e39e1e_invite_env_groups_as_readers.py
@@ -10,9 +10,9 @@
 from sqlalchemy import orm
 from dataall.core.environment.db.environment_models import EnvironmentGroup, Environment
 from dataall.core.organizations.db.organization_repositories import OrganizationRepository
-from 
dataall.core.organizations.services.organization_service import OrganizationService from dataall.core.permissions.services.organization_permissions import GET_ORGANIZATION - +from dataall.core.organizations.db import organization_models as models +from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService # revision identifiers, used by Alembic. revision = '328e35e39e1e' @@ -37,15 +37,24 @@ def upgrade(): env_org[e.environmentUri] = e.organizationUri for group in all_env_groups: - group_membership = OrganizationRepository.find_group_membership( - session, [group.groupUri], env_org[group.environmentUri] - ) + organization_uri = env_org[group.environmentUri] + + group_membership = OrganizationRepository.find_group_membership(session, [group.groupUri], organization_uri) + if group_membership is None: - data = { - 'groupUri': group.groupUri, - 'permissions': [GET_ORGANIZATION], - } - OrganizationService.invite_group(env_org[group.environmentUri], data) + # 1. Add Organization Group + org_group = models.OrganizationGroup(organizationUri=organization_uri, groupUri=group.groupUri) + session.add(org_group) + + # 2. Add Resource Policy Permissions + permissions = [GET_ORGANIZATION] + ResourcePolicyService.attach_resource_policy( + session=session, + group=group.groupUri, + resource_uri=organization_uri, + permissions=permissions, + resource_type=models.Organization.__name__, + ) def downgrade(): diff --git a/backend/migrations/versions/797dd1012be1_resource_lock_table.py b/backend/migrations/versions/797dd1012be1_resource_lock_table.py index 085c35672..de5f27351 100644 --- a/backend/migrations/versions/797dd1012be1_resource_lock_table.py +++ b/backend/migrations/versions/797dd1012be1_resource_lock_table.py @@ -1,7 +1,7 @@ """resource_lock_table Revision ID: 797dd1012be1 -Revises: 6adce90ab470 +Revises: 18da23d3ba44 Create Date: 2024-06-17 19:06:51.569471 """ @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy import orm, Column, String, Boolean, ForeignKey import sqlalchemy as sa -from typing import Optional +from typing import Optional, List from sqlalchemy.ext.declarative import declarative_base from dataall.base.db import utils @@ -27,7 +27,6 @@ class ResourceLock(Base): resourceUri = Column(String, nullable=False, primary_key=True) resourceType = Column(String, nullable=False, primary_key=True) - isLocked = Column(Boolean, default=False) acquiredByUri = Column(String, nullable=True) acquiredByType = Column(String, nullable=True) @@ -35,13 +34,11 @@ def __init__( self, resourceUri: str, resourceType: str, - isLocked: bool = False, acquiredByUri: Optional[str] = None, acquiredByType: Optional[str] = None, ): self.resourceUri = resourceUri self.resourceType = resourceType - self.isLocked = isLocked self.acquiredByUri = acquiredByUri self.acquiredByType = acquiredByType @@ -58,126 +55,63 @@ class S3Dataset(DatasetBase): datasetUri = Column(String, ForeignKey('dataset.datasetUri'), primary_key=True) -class EnvironmentGroup(Base): - __tablename__ = 'environment_group_permission' - groupUri = Column(String, primary_key=True) - environmentUri = Column(String, primary_key=True) +class DatasetLock(Base): + __tablename__ = 'dataset_lock' + datasetUri = Column(String, nullable=False, primary_key=True) + isLocked = Column(Boolean, default=False, nullable=False) + acquiredBy = Column(String, nullable=True) - -class ConsumptionRole(Base): - __tablename__ = 'consumptionrole' - consumptionRoleUri = Column(String, primary_key=True, default=utils.uuid('group')) + @classmethod + 
def uri(cls): + return cls.datasetUri def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - # Drop Foregin Key - op.drop_constraint('dataset_lock_datasetUri_fkey', 'dataset_lock', type_='foreignkey') - - # Rename Table to Resource Lock - op.rename_table('dataset_lock', 'resource_lock') + ## drop dataset_lock table + op.drop_table('dataset_lock') - # Rename Columns - op.alter_column( + ## create resource_lock table + op.create_table( 'resource_lock', - 'datasetUri', - nullable=False, - new_column_name='resourceUri', - existing_type=String, - primary_key=True, + sa.Column('resourceUri', sa.String(), nullable=False, primary_key=True), + sa.Column('resourceType', sa.String(), nullable=False, primary_key=True), + sa.Column('acquiredByUri', sa.String(), nullable=True), + sa.Column('acquiredByType', sa.String(), nullable=True), ) - op.alter_column( - 'resource_lock', - 'acquiredBy', - nullable=True, - new_column_name='acquiredByUri', - existing_type=String, - ) - - # Add New Columns - op.add_column('resource_lock', sa.Column('resourceType', sa.String())) - op.add_column('resource_lock', sa.Column('acquiredByType', sa.String(), nullable=True)) - - session = orm.Session(bind=op.get_bind()) - - # Backfill Dataset Locks - session.query(ResourceLock).update( - { - ResourceLock.resourceType: S3Dataset.__tablename__, - } - ) - session.commit() - - # Add resourceType as primary key after backfilling - op.alter_column('resource_lock', 'resourceType', primary_key=True) - - # Backfill Locks for Env Groups - env_groups = session.query(EnvironmentGroup).all() - for group in env_groups: - lock = ResourceLock( - resourceUri=f'{group.groupUri}-{group.environmentUri}', - resourceType=EnvironmentGroup.__tablename__, - ) - session.add(lock) - session.commit() - - # Backfill Locks for Consumption Roles - consumption_roles = session.query(ConsumptionRole).all() - for role in consumption_roles: - lock = ResourceLock(resourceUri=role.consumptionRoleUri, resourceType=ConsumptionRole.__tablename__) - session.add(lock) - session.commit() # ### end Alembic commands ### def downgrade(): - session = orm.Session(bind=op.get_bind()) - # Deleting Locks for Env Groups - env_groups = session.query(EnvironmentGroup).all() - for group in env_groups: - lock = session.query(ResourceLock).get( - (f'{group.groupUri}-{group.environmentUri}', EnvironmentGroup.__tablename__) - ) - if lock: - print('YES LOCK') - session.delete(lock) - - # Deleting Locks for Consumption Roles - consumption_roles = session.query(ConsumptionRole).all() - for role in consumption_roles: - print('CR ROLE', role.consumptionRoleUri) - lock = session.query(ResourceLock).get((role.consumptionRoleUri, ConsumptionRole.__tablename__)) - if lock: - print('YES LOCK') - session.delete(lock) - session.commit() + # Drop resource_lock table + op.drop_table('resource_lock') - # Drop Columns - op.drop_column('resource_lock', 'resourceType') - op.drop_column('resource_lock', 'acquiredByType') + bind = op.get_bind() + session = orm.Session(bind=bind) + datasets: List[S3Dataset] = session.query(S3Dataset).all() - # Rename Columns - op.alter_column( - 'resource_lock', - 'resourceUri', - nullable=False, - new_column_name='datasetUri', - existing_type=String, - primary_key=True, - ) - op.alter_column( - 'resource_lock', - 'acquiredByUri', - nullable=True, - new_column_name='acquiredBy', - existing_type=String, - ) + print('Creating dataset_lock table') - # Rename Table to Dataset Lock - op.rename_table('resource_lock', 'dataset_lock') + op.create_table( + 
'dataset_lock', + sa.Column('datasetUri', sa.String(), primary_key=True), + sa.Column('isLocked', sa.Boolean(), nullable=False), + sa.Column('acquiredBy', sa.String(), nullable=True), + ) - # Add Foregin Key - op.create_foreign_key('dataset_lock_datasetUri_fkey', 'dataset_lock', 'dataset', ['datasetUri'], ['datasetUri']) + op.create_foreign_key( + 'fk_dataset_lock_datasetUri', # Constraint name + 'dataset_lock', + 'dataset', + ['datasetUri'], + ['datasetUri'], + ) + print('Creating a new row for each existing dataset in dataset_lock table') + for dataset in datasets: + dataset_lock = DatasetLock(datasetUri=dataset.datasetUri, isLocked=False, acquiredBy='') + session.add(dataset_lock) + session.flush() # flush to get the datasetUri + session.commit() # ### end Alembic commands ### diff --git a/backend/migrations/versions/d05f9a5b215e_backfill_dataset_table_permissions.py b/backend/migrations/versions/d05f9a5b215e_backfill_dataset_table_permissions.py index 180483888..0474c86f4 100644 --- a/backend/migrations/versions/d05f9a5b215e_backfill_dataset_table_permissions.py +++ b/backend/migrations/versions/d05f9a5b215e_backfill_dataset_table_permissions.py @@ -24,7 +24,7 @@ ShareableType, ShareItemStatus, ) -from dataall.modules.s3_datasets_shares.db.share_object_repositories import ShareObjectRepository +from dataall.modules.shares_base.db.share_object_repositories import ShareObjectRepository from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository from dataall.modules.s3_datasets.services.dataset_permissions import DATASET_TABLE_READ diff --git a/cdk.json b/cdk.json index d4b8d9543..c11c3c050 100644 --- a/cdk.json +++ b/cdk.json @@ -14,7 +14,8 @@ "region": "us-east-1", "prod_sizing": false, "vpc_restricted_nacl": true, - "enable_pivot_role_auto_create": true + "enable_pivot_role_auto_create": true, + "with_approval_tests": true } ] } diff --git a/config.json b/config.json index c9d516c56..d5cd97f89 100644 --- a/config.json +++ b/config.json @@ -18,6 +18,7 @@ "share_notifications": { "email": { "active": false, + "persistent_reminders": false, "parameters": { "group_notifications": true } diff --git a/deploy/custom_resources/cognito_config/cognito_urls.py b/deploy/custom_resources/cognito_config/cognito_urls.py new file mode 100644 index 000000000..43be35c1f --- /dev/null +++ b/deploy/custom_resources/cognito_config/cognito_urls.py @@ -0,0 +1,82 @@ +import logging +import os + +import boto3 + +logger = logging.getLogger() +logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO')) +log = logging.getLogger(__name__) + + +def setup_cognito( + region, + envname, + custom_domain='False', +): + ssm = boto3.client('ssm', region_name=region) + user_pool_id = ssm.get_parameter(Name=f'/dataall/{envname}/cognito/userpool')['Parameter']['Value'] + log.info(f'Cognito Pool ID: {user_pool_id}') + app_client = ssm.get_parameter(Name=f'/dataall/{envname}/cognito/appclient')['Parameter']['Value'] + + if custom_domain == 'False': + log.info('Switching to us-east-1 region...') + ssm = boto3.client('ssm', region_name='us-east-1') + signin_singout_link = ssm.get_parameter(Name=f'/dataall/{envname}/CloudfrontDistributionDomainName')[ + 'Parameter' + ]['Value'] + user_guide_link = ssm.get_parameter( + Name=f'/dataall/{envname}/cloudfront/docs/user/CloudfrontDistributionDomainName' + )['Parameter']['Value'] + else: + signin_singout_link = ssm.get_parameter(Name=f'/dataall/{envname}/frontend/custom_domain_name')['Parameter'][ + 'Value' + ] + user_guide_link = 
ssm.get_parameter(Name=f'/dataall/{envname}/userguide/custom_domain_name')['Parameter'][ + 'Value' + ] + + log.info(f'UI: {signin_singout_link}') + log.info(f'USERGUIDE: {user_guide_link}') + + cognito = boto3.client('cognito-idp', region_name=region) + user_pool = cognito.describe_user_pool_client(UserPoolId=user_pool_id, ClientId=app_client) + + del user_pool['UserPoolClient']['CreationDate'] + del user_pool['UserPoolClient']['LastModifiedDate'] + + config_callbacks = [ + f'https://{signin_singout_link}', + f'https://{user_guide_link}/parseauth', + ] + existing_callbacks = user_pool['UserPoolClient'].get('CallbackURLs', []) + if 'https://example.com' in existing_callbacks: + existing_callbacks.remove('https://example.com') + updated_callbacks = existing_callbacks + list(set(config_callbacks) - set(existing_callbacks)) + log.info(f'Updated CallBackUrls: {updated_callbacks}') + + config_logout_urls = [f'https://{signin_singout_link}'] + existing_logout_urls = user_pool['UserPoolClient'].get('LogoutURLs', []) + updated_logout_urls = existing_logout_urls + list(set(config_logout_urls) - set(existing_logout_urls)) + log.info(f'Updated LogOutUrls: {updated_logout_urls}') + + user_pool['UserPoolClient']['CallbackURLs'] = updated_callbacks + user_pool['UserPoolClient']['LogoutURLs'] = updated_logout_urls + + response = cognito.update_user_pool_client( + **user_pool['UserPoolClient'], + ) + + log.info(f'CallbackUrls and LogOutUrls updated successfully: {response}') + + +def handler(event, context) -> None: + log.info('Starting Cognito Configuration...') + envname = os.environ.get('envname') + region = os.environ.get('deployment_region') + custom_domain = os.environ.get('custom_domain') + setup_cognito( + region, + envname, + custom_domain, + ) + log.info('Cognito Configuration Finished Successfully') diff --git a/deploy/custom_resources/cognito_config/cognito_urls_config.py b/deploy/custom_resources/cognito_config/cognito_users.py similarity index 63% rename from deploy/custom_resources/cognito_config/cognito_urls_config.py rename to deploy/custom_resources/cognito_config/cognito_users.py index 836616822..dc17c8517 100644 --- a/deploy/custom_resources/cognito_config/cognito_urls_config.py +++ b/deploy/custom_resources/cognito_config/cognito_users.py @@ -22,65 +22,13 @@ def setup_cognito( region, resource_prefix, envname, - internet_facing='True', - custom_domain='False', enable_cw_canaries='False', with_approval_tests='False', ): ssm = boto3.client('ssm', region_name=region) user_pool_id = ssm.get_parameter(Name=f'/dataall/{envname}/cognito/userpool')['Parameter']['Value'] log.info(f'Cognito Pool ID: {user_pool_id}') - app_client = ssm.get_parameter(Name=f'/dataall/{envname}/cognito/appclient')['Parameter']['Value'] - - if custom_domain == 'False' and internet_facing == 'True': - log.info('Switching to us-east-1 region...') - ssm = boto3.client('ssm', region_name='us-east-1') - signin_singout_link = ssm.get_parameter(Name=f'/dataall/{envname}/CloudfrontDistributionDomainName')[ - 'Parameter' - ]['Value'] - user_guide_link = ssm.get_parameter( - Name=f'/dataall/{envname}/cloudfront/docs/user/CloudfrontDistributionDomainName' - )['Parameter']['Value'] - else: - signin_singout_link = ssm.get_parameter(Name=f'/dataall/{envname}/frontend/custom_domain_name')['Parameter'][ - 'Value' - ] - user_guide_link = ssm.get_parameter(Name=f'/dataall/{envname}/userguide/custom_domain_name')['Parameter'][ - 'Value' - ] - - log.info(f'UI: {signin_singout_link}') - log.info(f'USERGUIDE: {user_guide_link}') - cognito 
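The callback and logout updates in the new cognito_urls.py above (and the code being removed from cognito_users.py here) rely on the same order-preserving merge: existing Cognito entries keep their position and only missing configured URLs are appended. The idiom in isolation — a sketch with illustrative URLs; note that when several entries are missing, their relative order is not guaranteed because they pass through a set:

```python
def merge_urls(existing: list, configured: list) -> list:
    # keep Cognito's current order, append only the configured URLs that are missing
    return existing + list(set(configured) - set(existing))

assert merge_urls(
    ['https://ui.example.com'],
    ['https://ui.example.com', 'https://docs.example.com/parseauth'],
) == ['https://ui.example.com', 'https://docs.example.com/parseauth']
```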
= boto3.client('cognito-idp', region_name=region) - user_pool = cognito.describe_user_pool_client(UserPoolId=user_pool_id, ClientId=app_client) - - del user_pool['UserPoolClient']['CreationDate'] - del user_pool['UserPoolClient']['LastModifiedDate'] - - config_callbacks = [ - f'https://{signin_singout_link}', - f'https://{user_guide_link}/parseauth', - ] - existing_callbacks = user_pool['UserPoolClient'].get('CallbackURLs', []) - if 'https://example.com' in existing_callbacks: - existing_callbacks.remove('https://example.com') - updated_callbacks = existing_callbacks + list(set(config_callbacks) - set(existing_callbacks)) - log.info(f'Updated CallBackUrls: {updated_callbacks}') - - config_logout_urls = [f'https://{signin_singout_link}'] - existing_logout_urls = user_pool['UserPoolClient'].get('LogoutURLs', []) - updated_logout_urls = existing_logout_urls + list(set(config_logout_urls) - set(existing_logout_urls)) - log.info(f'Updated LogOutUrls: {updated_logout_urls}') - - user_pool['UserPoolClient']['CallbackURLs'] = updated_callbacks - user_pool['UserPoolClient']['LogoutURLs'] = updated_logout_urls - - response = cognito.update_user_pool_client( - **user_pool['UserPoolClient'], - ) - - log.info(f'CallbackUrls and LogOutUrls updated successfully: {response}') try: response = cognito.create_group( @@ -162,8 +110,6 @@ def handler(event, context) -> None: log.info('Starting Cognito Configuration...') envname = os.environ.get('envname') region = os.environ.get('deployment_region') - internet_facing = os.environ.get('internet_facing') - custom_domain = os.environ.get('custom_domain') enable_cw_canaries = os.environ.get('enable_cw_canaries') resource_prefix = os.environ.get('resource_prefix') with_approval_tests = os.environ.get('with_approval_tests') @@ -171,8 +117,6 @@ def handler(event, context) -> None: region, resource_prefix, envname, - internet_facing, - custom_domain, enable_cw_canaries, with_approval_tests, ) diff --git a/deploy/custom_resources/utils.py b/deploy/custom_resources/utils.py new file mode 100644 index 000000000..899b8a6cc --- /dev/null +++ b/deploy/custom_resources/utils.py @@ -0,0 +1,23 @@ +import os + +from aws_cdk import aws_lambda, BundlingOptions +from aws_cdk.aws_lambda import AssetCode + +from stacks.solution_bundling import SolutionBundling + + +def get_lambda_code(path, image=aws_lambda.Runtime.PYTHON_3_9.bundling_image) -> AssetCode: + assets_path = os.path.realpath( + os.path.join( + os.path.dirname(__file__), + path, + ) + ) + + return aws_lambda.Code.from_asset( + path=assets_path, + bundling=BundlingOptions( + image=image, + local=SolutionBundling(source_path=assets_path), + ), + ) diff --git a/deploy/stacks/backend_stack.py b/deploy/stacks/backend_stack.py index c3dbe5cca..64dd7fd83 100644 --- a/deploy/stacks/backend_stack.py +++ b/deploy/stacks/backend_stack.py @@ -219,6 +219,9 @@ def __init__( self.lambda_api_stack.api_handler, self.lambda_api_stack.elasticsearch_proxy_handler, ], + email_custom_domain=ses_stack.ses_identity.email_identity_name if ses_stack is not None else None, + ses_configuration_set=ses_stack.configuration_set.configuration_set_name if ses_stack is not None else None, + custom_domain=custom_domain, **kwargs, ) diff --git a/deploy/stacks/cloudfront.py b/deploy/stacks/cloudfront.py index 20c004bb8..2a9da100c 100644 --- a/deploy/stacks/cloudfront.py +++ b/deploy/stacks/cloudfront.py @@ -15,8 +15,11 @@ RemovalPolicy, CfnOutput, BundlingOptions, + aws_kms, ) +from aws_cdk.triggers import TriggerFunction +from custom_resources.utils import 
get_lambda_code from .pyNestedStack import pyNestedClass from .solution_bundling import SolutionBundling from .waf_rules import get_waf_rules @@ -34,6 +37,7 @@ def __init__( custom_waf_rules=None, tooling_account_id=None, custom_auth=None, + backend_region=None, **kwargs, ): super().__init__(scope, id, **kwargs) @@ -269,6 +273,9 @@ def __init__( ) ) + if not custom_auth: + self.cognito_urls_config(resource_prefix, envname, backend_region, custom_domain, [cloudfront_distribution]) + CfnOutput( self, 'OutputCfnFrontDistribution', @@ -489,3 +496,94 @@ def error_responses(): response_page_path='/index.html', ), ] + + def cognito_urls_config(self, resource_prefix, envname, backend_region, custom_domain, execute_after): + lambda_env_key = aws_kms.Key( + self, + f'{resource_prefix}-{envname}-cogn-urls-lambda-env-var-key', + removal_policy=RemovalPolicy.DESTROY, + alias=f'{resource_prefix}-{envname}-cogn-urls-lambda-env-var-key', + enable_key_rotation=True, + policy=iam.PolicyDocument( + statements=[ + iam.PolicyStatement( + resources=['*'], + effect=iam.Effect.ALLOW, + principals=[ + iam.AccountPrincipal(account_id=self.account), + ], + actions=['kms:*'], + ), + iam.PolicyStatement( + resources=['*'], + effect=iam.Effect.ALLOW, + principals=[ + iam.ServicePrincipal(service='lambda.amazonaws.com'), + ], + actions=['kms:GenerateDataKey*', 'kms:Decrypt'], + ), + ], + ), + ) + + cognito_config_code = get_lambda_code('cognito_config') + + TriggerFunction( + self, + 'TriggerFunction-CognitoUrlsConfig', + function_name=f'{resource_prefix}-{envname}-cognito_urls_config', + description='dataall CognitoUrlsConfig trigger function', + initial_policy=[ + iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + 'secretsmanager:DescribeSecret', + 'secretsmanager:GetSecretValue', + 'ssm:GetParameterHistory', + 'ssm:GetParameters', + 'ssm:GetParameter', + 'ssm:GetParametersByPath', + 'kms:Decrypt', + 'kms:GenerateDataKey', + 'kms:DescribeKey', + 'rum:GetAppMonitor', + ], + resources=[ + f'arn:aws:kms:{self.region}:{self.account}:key/*', + f'arn:aws:ssm:*:{self.account}:parameter/*dataall*', + f'arn:aws:secretsmanager:{self.region}:{self.account}:secret:*dataall*', + f'arn:aws:rum:{self.region}:{self.account}:appmonitor/*dataall*', + ], + ), + iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + 'cognito-idp:AddCustomAttributes', + 'cognito-idp:UpdateUserPool', + 'cognito-idp:DescribeUserPoolClient', + 'cognito-idp:CreateGroup', + 'cognito-idp:UpdateUserPoolClient', + 'cognito-idp:AdminSetUserPassword', + 'cognito-idp:AdminCreateUser', + 'cognito-idp:DescribeUserPool', + 'cognito-idp:AdminAddUserToGroup', + ], + resources=[f'arn:aws:cognito-idp:{backend_region}:{self.account}:userpool/*'], + ), + ], + code=cognito_config_code, + memory_size=256, + timeout=Duration.minutes(15), + environment={ + 'envname': envname, + 'deployment_region': backend_region, + 'custom_domain': str(bool(custom_domain)), + }, + environment_encryption=lambda_env_key, + tracing=_lambda.Tracing.ACTIVE, + retry_attempts=0, + runtime=_lambda.Runtime.PYTHON_3_9, + handler='cognito_urls.handler', + execute_after=execute_after, + execute_on_handler_change=True, + ) diff --git a/deploy/stacks/cloudfront_stack.py b/deploy/stacks/cloudfront_stack.py index 412a64975..bd1bd4512 100644 --- a/deploy/stacks/cloudfront_stack.py +++ b/deploy/stacks/cloudfront_stack.py @@ -16,6 +16,7 @@ def __init__( custom_domain=None, custom_waf_rules=None, custom_auth=None, + backend_region=None, **kwargs, ): super().__init__(scope, id, **kwargs) @@ -42,5 
+43,6 @@ def __init__( custom_domain=custom_domain, custom_waf_rules=custom_waf_rules, custom_auth=custom_auth, + backend_region=backend_region, **kwargs, ) diff --git a/deploy/stacks/cloudfront_stage.py b/deploy/stacks/cloudfront_stage.py index 19a79c2e6..329816834 100644 --- a/deploy/stacks/cloudfront_stage.py +++ b/deploy/stacks/cloudfront_stage.py @@ -16,6 +16,7 @@ def __init__( custom_domain=None, custom_waf_rules=None, custom_auth=None, + backend_region=None, **kwargs, ): super().__init__(scope, id, **kwargs) @@ -29,6 +30,7 @@ def __init__( custom_domain=custom_domain, custom_waf_rules=custom_waf_rules, custom_auth=custom_auth, + backend_region=backend_region, ) Tags.of(cloudfront_stack).add('Application', f'{resource_prefix}-{envname}') diff --git a/deploy/stacks/cognito.py b/deploy/stacks/cognito.py index 532295c29..69299a2ad 100644 --- a/deploy/stacks/cognito.py +++ b/deploy/stacks/cognito.py @@ -17,6 +17,7 @@ from aws_cdk.aws_cognito import AuthFlow from aws_cdk.triggers import TriggerFunction +from custom_resources.utils import get_lambda_code from .pyNestedStack import pyNestedClass from .solution_bundling import SolutionBundling from .waf_rules import get_waf_rules @@ -240,6 +241,34 @@ def __init__( string_value=cross_account_frontend_config_role.role_name, ) + lambda_env_key = kms.Key( + self, + f'{resource_prefix}-{envname}-cogn-config-lambda-env-var-key', + removal_policy=RemovalPolicy.DESTROY, + alias=f'{resource_prefix}-{envname}-cogn-config-lambda-env-var-key', + enable_key_rotation=True, + policy=iam.PolicyDocument( + statements=[ + iam.PolicyStatement( + resources=['*'], + effect=iam.Effect.ALLOW, + principals=[ + iam.AccountPrincipal(account_id=self.account), + ], + actions=['kms:*'], + ), + iam.PolicyStatement( + resources=['*'], + effect=iam.Effect.ALLOW, + principals=[ + iam.ServicePrincipal(service='lambda.amazonaws.com'), + ], + actions=['kms:GenerateDataKey*', 'kms:Decrypt'], + ), + ], + ), + ) + if internet_facing: role_inline_policy = iam.Policy( self, @@ -281,33 +310,6 @@ def __init__( 'sync_congito_params', ) ) - lambda_env_key = kms.Key( - self, - f'{resource_prefix}-cogn-lambda-env-var-key', - removal_policy=RemovalPolicy.DESTROY, - alias=f'{resource_prefix}-cogn-lambda-env-var-key', - enable_key_rotation=True, - policy=iam.PolicyDocument( - statements=[ - iam.PolicyStatement( - resources=['*'], - effect=iam.Effect.ALLOW, - principals=[ - iam.AccountPrincipal(account_id=self.account), - ], - actions=['kms:*'], - ), - iam.PolicyStatement( - resources=['*'], - effect=iam.Effect.ALLOW, - principals=[ - iam.ServicePrincipal(service='lambda.amazonaws.com'), - ], - actions=['kms:GenerateDataKey*', 'kms:Decrypt'], - ), - ], - ), - ) cognito_sync_handler = _lambda.Function( self, f'CognitoParamsSyncHandler{envname}', @@ -350,22 +352,7 @@ def __init__( sync_cr.node.add_dependency(domain_name) sync_cr.node.add_dependency(pool_arn) - cognito_config_assets = os.path.realpath( - os.path.join( - os.path.dirname(__file__), - '..', - 'custom_resources', - 'cognito_config', - ) - ) - - cognito_config_code = _lambda.Code.from_asset( - path=cognito_config_assets, - bundling=BundlingOptions( - image=_lambda.Runtime.PYTHON_3_9.bundling_image, - local=SolutionBundling(source_path=cognito_config_assets), - ), - ) + cognito_config_code = get_lambda_code('cognito_config') TriggerFunction( self, @@ -412,8 +399,6 @@ def __init__( environment={ 'envname': envname, 'deployment_region': self.region, - 'internet_facing': str(internet_facing), - 'custom_domain': str(not domain_name), 
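The custom-domain switch travels between CDK and the Lambda as a string: cloudfront.py sets `'custom_domain': str(bool(custom_domain))`, while cognito_urls.py branches on the literal `'False'`. A sketch of both sides of that contract, with illustrative values:

```python
# CDK side: str(bool(...)) renders 'True' or 'False' into the Lambda environment.
def render_env(custom_domain) -> dict:
    return {'custom_domain': str(bool(custom_domain))}

# Lambda side: the literal string 'False' selects the CloudFront SSM parameters.
def uses_cloudfront_domain(env: dict) -> bool:
    return env.get('custom_domain') == 'False'

assert uses_cloudfront_domain(render_env(None)) is True
assert uses_cloudfront_domain(render_env({'hosted_zone_name': 'dataall.example.com'})) is False
```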
'enable_cw_canaries': str(enable_cw_rum), 'resource_prefix': resource_prefix, 'with_approval_tests': str(with_approval_tests), @@ -422,7 +407,7 @@ def __init__( tracing=_lambda.Tracing.ACTIVE, retry_attempts=0, runtime=_lambda.Runtime.PYTHON_3_9, - handler='cognito_urls_config.handler', + handler='cognito_users.handler', execute_after=[self.client], execute_on_handler_change=True, ) diff --git a/deploy/stacks/container.py b/deploy/stacks/container.py index a9981b621..bbad9efa5 100644 --- a/deploy/stacks/container.py +++ b/deploy/stacks/container.py @@ -31,6 +31,9 @@ def __init__( tooling_account_id=None, s3_prefix_list=None, lambdas=None, + email_custom_domain=None, + ses_configuration_set=None, + custom_domain=None, **kwargs, ): super().__init__(scope, id, **kwargs) @@ -49,6 +52,13 @@ def __init__( envname, resource_prefix, vpc, vpce_connection, s3_prefix_list, lambdas ) self.ecs_security_groups: [aws_ec2.SecurityGroup] = [self.scheduled_tasks_sg, self.share_manager_sg] + self.env_vars = self._create_env('INFO') + + # Check if custom domain exists and if it exists email notifications could be enabled. + # Create an env variable which stores the domain URL. + # This is used for sending data.all share weblinks in the email notifications. + if custom_domain and custom_domain.get('hosted_zone_name'): + self.env_vars.update({'frontend_domain_url': f'https://{custom_domain["hosted_zone_name"]}'}) cluster = ecs.Cluster( self, @@ -58,7 +68,10 @@ def __init__( container_insights=True, ) - self.task_role = self.create_task_role(envname, resource_prefix, pivot_role_name) + self.task_role = self.create_task_role( + envname, resource_prefix, pivot_role_name, email_custom_domain, ses_configuration_set + ) + self.cicd_stacks_updater_role = self.create_cicd_stacks_updater_role( envname, resource_prefix, tooling_account_id ) @@ -178,6 +191,7 @@ def __init__( self.add_share_verifier_task() self.add_share_reapplier_task() self.add_omics_fetch_workflows_task() + self.add_persistent_email_reminders_task() @run_if(['modules.s3_datasets.active', 'modules.dashboards.active']) def add_catalog_indexer_task(self): @@ -299,8 +313,49 @@ def add_share_reapplier_task(self): ), readonly_root_filesystem=True, ) + + ssm.StringParameter( + self, + f'ShareReapplierTaskARNSSM{self._envname}', + parameter_name=f'/dataall/{self._envname}/ecs/task_def_arn/share_reapplier', + string_value=share_reapplier_task_definition.task_definition_arn, + ) + + ssm.StringParameter( + self, + f'ShareReapplierTaskContainerSSM{self._envname}', + parameter_name=f'/dataall/{self._envname}/ecs/container/share_reapplier', + string_value=share_reapplier_container.container_name, + ) + self.ecs_task_definitions_families.append(share_reapplier_task_definition.family) + @run_if(['modules.dataset_base.features.share_notifications.email.persistent_reminders']) + def add_persistent_email_reminders_task(self): + persistent_email_reminders_task, persistent_email_reminders_task_def = self.set_scheduled_task( + cluster=self.ecs_cluster, + command=[ + 'python3.9', + '-m', + 'dataall.modules.shares_base.tasks.persistent_email_reminders_task', + ], + container_id='container', + ecr_repository=self._ecr_repository, + environment=self.env_vars, + image_tag=self._cdkproxy_image_tag, + log_group=self.create_log_group( + self._envname, self._resource_prefix, log_group_name='persistent-email-reminders' + ), + schedule_expression=Schedule.expression('cron(0 9 ? 
* 2 *)'), # Run at 9:00 AM UTC every Monday + scheduled_task_id=f'{self._resource_prefix}-{self._envname}-persistent-email-reminders-schedule', + task_id=f'{self._resource_prefix}-{self._envname}-persistent-email-reminders', + task_role=self.task_role, + vpc=self._vpc, + security_group=self.scheduled_tasks_sg, + prod_sizing=self._prod_sizing, + ) + self.ecs_task_definitions_families.append(persistent_email_reminders_task.task_definition.family) + @run_if(['modules.s3_datasets.active']) def add_subscription_task(self): subscriptions_task, subscription_task_def = self.set_scheduled_task( @@ -468,7 +523,9 @@ def create_cicd_stacks_updater_role(self, envname, resource_prefix, tooling_acco ) return cicd_stacks_updater_role - def create_task_role(self, envname, resource_prefix, pivot_role_name): + def create_task_role( + self, envname, resource_prefix, pivot_role_name, email_custom_domain=None, ses_configuration_set=None + ): role_inline_policy = iam.Policy( self, f'ECSRolePolicy{envname}', @@ -557,6 +614,18 @@ def create_task_role(self, envname, resource_prefix, pivot_role_name): ), ], ) + + if email_custom_domain and ses_configuration_set: + role_inline_policy.document.add_statements( + iam.PolicyStatement( + actions=['ses:SendEmail'], + resources=[ + f'arn:aws:ses:{self.region}:{self.account}:identity/{email_custom_domain}', + f'arn:aws:ses:{self.region}:{self.account}:configuration-set/{ses_configuration_set}', + ], + ) + ) + task_role = iam.Role( self, f'ECSTaskRole{envname}', @@ -564,6 +633,7 @@ def create_task_role(self, envname, resource_prefix, pivot_role_name): inline_policies={f'ECSRoleInlinePolicy{envname}': role_inline_policy.document}, assumed_by=iam.ServicePrincipal('ecs-tasks.amazonaws.com'), ) + task_role.grant_pass_role(task_role) return task_role diff --git a/deploy/stacks/pipeline.py b/deploy/stacks/pipeline.py index 761df7aa0..1d89e067c 100644 --- a/deploy/stacks/pipeline.py +++ b/deploy/stacks/pipeline.py @@ -757,6 +757,7 @@ def set_cloudfront_stage(self, target_env): custom_domain=target_env.get('custom_domain'), custom_auth=target_env.get('custom_auth', None), custom_waf_rules=target_env.get('custom_waf_rules', None), + backend_region=target_env.get('region', self.region), ) ) front_stage_actions = ( diff --git a/frontend/src/modules/Catalog/components/RequestDashboardAccessModal.js b/frontend/src/modules/Catalog/components/RequestDashboardAccessModal.js index b9c16ed38..2ce4e88fb 100644 --- a/frontend/src/modules/Catalog/components/RequestDashboardAccessModal.js +++ b/frontend/src/modules/Catalog/components/RequestDashboardAccessModal.js @@ -23,9 +23,6 @@ export const RequestDashboardAccessModal = (props) => { const dispatch = useDispatch(); const client = useClient(); const groups = useGroups(); - const idpGroupOptions = groups - ? 
groups.map((g) => ({ value: g, label: g })) - : []; async function submit(values, setStatus, setSubmitting, setErrors) { try { @@ -117,20 +114,24 @@ export const RequestDashboardAccessModal = (props) => { option.value)} + disablePortal + options={groups} onChange={(event, value) => { - setFieldValue('groupUri', value); + if (value) { + setFieldValue('groupUri', value); + } else { + setFieldValue('groupUri', ''); + } }} - renderInput={(renderParams) => ( + inputValue={values.groupUri} + renderInput={(params) => ( )} diff --git a/frontend/src/modules/Environments/components/EnvironmentRoleAddForm.js b/frontend/src/modules/Environments/components/EnvironmentRoleAddForm.js index 3e7bf4d25..8926de4dc 100644 --- a/frontend/src/modules/Environments/components/EnvironmentRoleAddForm.js +++ b/frontend/src/modules/Environments/components/EnvironmentRoleAddForm.js @@ -1,12 +1,12 @@ import { GroupAddOutlined } from '@mui/icons-material'; import { LoadingButton } from '@mui/lab'; import { + Autocomplete, Box, CardContent, CircularProgress, Dialog, FormControlLabel, - MenuItem, Switch, TextField, Typography @@ -154,23 +154,31 @@ export const EnvironmentRoleAddForm = (props) => { /> - - {groupOptions.map((group) => ( - - {group.label} - - ))} - + option)} + onChange={(event, value) => { + if (value && value.value) { + setFieldValue('groupUri', value.value); + } else { + setFieldValue('groupUri', ''); + } + }} + noOptionsText="No teams found for this environment" + renderInput={(params) => ( + + )} + /> { freeSolo options={groupOptions.map((option) => option.value)} onChange={(event, value) => { - setFieldValue('groupUri', value); + if (value) { + setFieldValue('groupUri', value); + } else { + setFieldValue('groupUri', ''); + } }} + inputValue={values.groupUri} renderInput={(params) => ( { error={Boolean(touched.groupUri && errors.groupUri)} helperText={touched.groupUri && errors.groupUri} onChange={handleChange} - value={values.groupUri} variant="outlined" /> )} diff --git a/frontend/src/modules/Environments/components/NetworkCreateModal.js b/frontend/src/modules/Environments/components/NetworkCreateModal.js index c0cd01e5b..3b3154cfc 100644 --- a/frontend/src/modules/Environments/components/NetworkCreateModal.js +++ b/frontend/src/modules/Environments/components/NetworkCreateModal.js @@ -1,12 +1,12 @@ import { LoadingButton } from '@mui/lab'; import { + Autocomplete, Box, CardContent, CardHeader, Dialog, FormHelperText, Grid, - MenuItem, TextField, Typography } from '@mui/material'; @@ -60,7 +60,7 @@ export const NetworkCreateModal = (props) => { description: values.description, label: values.label, vpcId: values.vpcId, - SamlGroupName: values.SamlGroupName, + SamlGroupName: values.SamlAdminGroupName, privateSubnetIds: values.privateSubnetIds, publicSubnetIds: values.publicSubnetIds }) @@ -124,7 +124,7 @@ export const NetworkCreateModal = (props) => { initialValues={{ label: '', vpcId: '', - SamlGroupName: '', + SamlAdminGroupName: '', privateSubnetIds: [], publicSubnetIds: [], tags: [] @@ -132,7 +132,7 @@ export const NetworkCreateModal = (props) => { validationSchema={Yup.object().shape({ label: Yup.string().max(255).required('*VPC name is required'), vpcId: Yup.string().max(255).required('*VPC ID is required'), - SamlGroupName: Yup.string() + SamlAdminGroupName: Yup.string() .max(255) .required('*Team is required'), privateSubnetIds: Yup.array().nullable(), @@ -229,27 +229,37 @@ export const NetworkCreateModal = (props) => { - option)} + noOptionsText="No teams found for this environment" + 
onChange={(event, value) => { + if (value && value.value) { + setFieldValue('SamlAdminGroupName', value.value); + } else { + setFieldValue('SamlAdminGroupName', ''); + } + }} + renderInput={(params) => ( + )} - helperText={ - touched.SamlGroupName && errors.SamlGroupName - } - label="Team" - name="SamlGroupName" - onChange={handleChange} - select - value={values.SamlGroupName} - variant="outlined" - > - {groupOptions.map((group) => ( - - {group.label} - - ))} - + /> { const { enqueueSnackbar } = useSnackbar(); const [withDelete, setWithDelete] = useState(false); - const handlePopUpModal = async () => { + const handleReindexStart = async () => { setUpdatingReIndex(true); if (!client) { dispatch({ @@ -212,35 +217,70 @@ export const ReIndexConfirmationPopUp = (props) => { }; return ( - - - - - Are you sure you want to start re-indexing the ENTIRE data.all - Catalog? - - } - /> - - { - setWithDelete(!withDelete); - }} - edge="start" - name="withDelete" - /> - + + + + Start Data.all Catalog Reindexing Task? + + + + + + Starting a reindexing job will update all catalog objects in + data.all with the latest information found in RDS. + + + + + + { + setWithDelete(!withDelete); + }} + edge="start" + name="withDelete" + /> + } + label={ +
+ With Deletes + + Specifying withDeletes will identify catalog objects + that no longer exist in data.all's DB (if any) and attempt to + delete / clean up the catalog. + 
+ } + /> +
+
+ + + + + Please confirm if you want to start the reindexing task: + - - +
+
); @@ -263,7 +303,6 @@ export const ReIndexConfirmationPopUp = (props) => { export const MaintenanceViewer = () => { const client = useClient(); const [refreshing, setRefreshing] = useState(false); - // const [refreshingReIndex, setRefreshingReIndex] = useState(false); const refreshingReIndex = false; const [updatingReIndex, setUpdatingReIndex] = useState(false); const [updating, setUpdating] = useState(false); @@ -444,29 +483,33 @@ export const MaintenanceViewer = () => { {refreshingReIndex ? ( ) : ( - - - Re-Index Data.all Catalog} /> - - - setPopUpReIndex(true)} - startIcon={} - sx={{ m: 1 }} - variant="contained" - > - Start Re-Index Catalog Task - +
+ {isModuleEnabled(ModuleNames.CATALOG) && ( + + + Re-Index Data.all Catalog} /> + + + setPopUpReIndex(true)} + startIcon={} + sx={{ m: 1 }} + variant="contained" + > + Start Re-Index Catalog Task + + + + - - - + )} +
)} {refreshing ? ( diff --git a/frontend/src/modules/Organizations/services/listOrganizationGroupPermissions.js b/frontend/src/modules/Organizations/services/listOrganizationGroupPermissions.js index 7de6dbb23..b74020a06 100644 --- a/frontend/src/modules/Organizations/services/listOrganizationGroupPermissions.js +++ b/frontend/src/modules/Organizations/services/listOrganizationGroupPermissions.js @@ -10,8 +10,8 @@ export const listOrganizationGroupPermissions = ({ }, query: gql` query listOrganizationGroupPermissions( - $organizationUri: String - $groupUri: String + $organizationUri: String! + $groupUri: String! ) { listOrganizationGroupPermissions( organizationUri: $organizationUri diff --git a/frontend/src/modules/Pipelines/views/PipelineCreateForm.js b/frontend/src/modules/Pipelines/views/PipelineCreateForm.js index e4a1ecb0e..006beae67 100644 --- a/frontend/src/modules/Pipelines/views/PipelineCreateForm.js +++ b/frontend/src/modules/Pipelines/views/PipelineCreateForm.js @@ -29,13 +29,10 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { - listEnvironmentGroups, - listValidEnvironments, - useClient -} from 'services'; +import { listValidEnvironments, useClient } from 'services'; import { createDataPipeline } from '../services'; import { PipelineEnvironmentCreateForm } from '../components'; +import { EnvironmentTeamDropdown } from 'modules/Shared'; const PipelineCreateForm = (props) => { const navigate = useNavigate(); @@ -44,7 +41,6 @@ const PipelineCreateForm = (props) => { const client = useClient(); const { settings } = useSettings(); const [loading, setLoading] = useState(true); - const [groupOptions, setGroupOptions] = useState([]); const [environmentOptions, setEnvironmentOptions] = useState([]); const devOptions = [ { value: 'cdk-trunk', label: 'CDK Pipelines - Trunk-based' }, @@ -78,29 +74,6 @@ const PipelineCreateForm = (props) => { setLoading(false); }, [client, dispatch]); - const fetchGroups = async (environmentUri) => { - try { - const response = await client.query( - listEnvironmentGroups({ - filter: Defaults.selectListFilter, - environmentUri - }) - ); - if (!response.errors) { - setGroupOptions( - response.data.listEnvironmentGroups.nodes.map((g) => ({ - value: g.groupUri, - label: g.groupUri - })) - ); - } else { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } catch (e) { - dispatch({ type: SET_ERROR, error: e.message }); - } - }; - useEffect(() => { if (client) { fetchEnvironments().catch((e) => @@ -330,96 +303,13 @@ const PipelineCreateForm = (props) => { - - { - setFieldValue('SamlGroupName', ''); - fetchGroups( - event.target.value.environmentUri - ).catch((e) => - dispatch({ type: SET_ERROR, error: e.message }) - ); - setFieldValue('environment', event.target.value); - }} - select - value={values.environment} - variant="outlined" - > - {environmentOptions.map((environment) => ( - - {environment.label} - - ))} - - - - { - setFieldValue( - 'SamlGroupName', - event.target.value - ); - }} - select - value={values.SamlGroupName} - variant="outlined" - > - {groupOptions.map((group) => ( - - {group.label} - - ))} - - - - - - - - + { option)} onChange={(event, value) => { if (value && value.value) { @@ -570,37 +570,17 @@ const DatasetCreateForm = (props) => { }} inputValue={values.stewards} renderInput={(params) => ( - - {groupOptions.length > 0 ? 
( - - ) : ( - + + helperText={touched.stewards && errors.stewards} + label="Stewards" + onChange={handleChange} + variant="outlined" + /> )} /> diff --git a/frontend/src/modules/S3_Datasets/views/DatasetEditForm.js b/frontend/src/modules/S3_Datasets/views/DatasetEditForm.js index 6177a24d6..b621ac504 100644 --- a/frontend/src/modules/S3_Datasets/views/DatasetEditForm.js +++ b/frontend/src/modules/S3_Datasets/views/DatasetEditForm.js @@ -581,18 +581,25 @@ const DatasetEditForm = (props) => { option.value)} + options={groupOptions.map((option) => option)} onChange={(event, value) => { - setFieldValue('stewards', value); + if (value && value.value) { + setFieldValue('stewards', value.value); + } else { + setFieldValue('stewards', ''); + } }} - renderInput={(renderParams) => ( + inputValue={values.stewards} + renderInput={(params) => ( )} diff --git a/frontend/src/modules/Shares/components/ShareBoxList.js b/frontend/src/modules/Shares/components/ShareBoxList.js index adc48099c..a8611e43e 100644 --- a/frontend/src/modules/Shares/components/ShareBoxList.js +++ b/frontend/src/modules/Shares/components/ShareBoxList.js @@ -11,7 +11,7 @@ import CircularProgress from '@mui/material/CircularProgress'; import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank'; import CheckBoxIcon from '@mui/icons-material/CheckBox'; import PropTypes from 'prop-types'; -import { useCallback, useEffect, useState } from 'react'; +import React, { useCallback, useEffect, useState } from 'react'; import { Helmet } from 'react-helmet-async'; import { Defaults, Pager, ShareStatus, useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; @@ -29,6 +29,9 @@ import { ShareBoxListItem } from './ShareBoxListItem'; import { ShareObjectSelectorModal } from './ShareObjectSelectorModal'; import { NavigateShareViewModal } from './NavigateShareViewModal'; import { ShareStatusList } from '../constants'; +import { RefreshRounded } from '@mui/icons-material'; +import { reApplyShareObjectItemsOnDataset } from '../services/reApplyShareObjectItemsOnDataset'; +import { useSnackbar } from 'notistack'; const icon = ; const checkedIcon = ; @@ -53,7 +56,10 @@ export const ShareBoxList = (props) => { useState(false); const [isNavigateShareViewModalOpen, setIsNavigateShareViewModalOpen] = useState(false); + const [reApplyButtonLoadingState, setreApplyButtonLoadingState] = + useState(false); const statusOptions = ShareStatusList; + const { enqueueSnackbar } = useSnackbar(); const handleVerifyObjectItemsModalOpen = () => { setIsVerifyObjectItemsModalOpen(true); @@ -256,6 +262,33 @@ export const ShareBoxList = (props) => { .finally(() => setLoading(false)); }, [client, dispatch]); + const reapplyShares = async (datasetUri) => { + try { + setreApplyButtonLoadingState(true); + const response = await client.mutate( + reApplyShareObjectItemsOnDataset({ datasetUri: datasetUri }) + ); + if (response && !response.errors) { + setreApplyButtonLoadingState(false); + enqueueSnackbar( + `Reapplying process for all unhealthy shares on dataset with uri: ${datasetUri} has started. 
Please check each individual share for share item health status`, + { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + } + ); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (error) { + setreApplyButtonLoadingState(false); + dispatch({ type: SET_ERROR, error: error?.message }); + } + }; + useEffect(() => { setLoading(true); setFilter({ page: 1, pageSize: 10, term: '' }); @@ -337,6 +370,23 @@ export const ShareBoxList = (props) => { )} + {dataset && ( + } + sx={{ m: 1 }} + onClick={(event) => { + reapplyShares(dataset.datasetUri); + }} + type="button" + variant="outlined" + > + Re-apply Share Item(s) for Dataset + + )} + ({ + variables: { + datasetUri + }, + mutation: gql` + mutation reApplyShareObjectItemsOnDataset($datasetUri: String!) { + reApplyShareObjectItemsOnDataset(datasetUri: $datasetUri) + } + ` +}); diff --git a/frontend/src/modules/Worksheets/services/listS3DatasetsSharedWithEnvGroup.js b/frontend/src/modules/Worksheets/services/listS3DatasetsSharedWithEnvGroup.js index a3944d0a1..49dbb37f1 100644 --- a/frontend/src/modules/Worksheets/services/listS3DatasetsSharedWithEnvGroup.js +++ b/frontend/src/modules/Worksheets/services/listS3DatasetsSharedWithEnvGroup.js @@ -9,8 +9,8 @@ export const listS3DatasetsSharedWithEnvGroup = ({ }, query: gql` query listS3DatasetsSharedWithEnvGroup( - $environmentUri: String - $groupUri: String + $environmentUri: String! + $groupUri: String! ) { listS3DatasetsSharedWithEnvGroup( environmentUri: $environmentUri diff --git a/tests/modules/s3_datasets/tasks/test_dataset_catalog_indexer.py b/tests/modules/s3_datasets/tasks/test_dataset_catalog_indexer.py index 259a05d1c..8629f149b 100644 --- a/tests/modules/s3_datasets/tasks/test_dataset_catalog_indexer.py +++ b/tests/modules/s3_datasets/tasks/test_dataset_catalog_indexer.py @@ -57,3 +57,27 @@ def test_catalog_indexer(db, org, env, sync_dataset, table, mocker): indexed_objects_counter = CatalogIndexerTask.index_objects(engine=db) # Count should be One table + One Dataset = 2 assert indexed_objects_counter == 2 + + +def test_catalog_indexer_with_deletes(db, org, env, sync_dataset, table, mocker): + # When Table no longer exists + mocker.patch('dataall.modules.s3_datasets.indexers.table_indexer.DatasetTableIndexer.upsert_all', return_value=[]) + mocker.patch( + 'dataall.modules.s3_datasets.indexers.dataset_indexer.DatasetIndexer.upsert', return_value=sync_dataset + ) + mocker.patch( + 'dataall.modules.catalog.indexers.base_indexer.BaseIndexer.search', + return_value={'hits': {'hits': [{'_id': table.tableUri}]}}, + ) + delete_doc_path = mocker.patch( + 'dataall.modules.catalog.indexers.base_indexer.BaseIndexer.delete_doc', return_value=True + ) + + # And with_deletes 'True' for index_objects + indexed_objects_counter = CatalogIndexerTask.index_objects(engine=db, with_deletes='True') + + # Index Objects Should call Delete Doc 1 time for Table + assert delete_doc_path.call_count == 1 + + # Count should be One Dataset = 1 + assert indexed_objects_counter == 1 diff --git a/tests/modules/s3_datasets/test_dataset_resource_found.py b/tests/modules/s3_datasets/test_dataset_resource_found.py index 231f7ea8f..45c2267d4 100644 --- a/tests/modules/s3_datasets/test_dataset_resource_found.py +++ b/tests/modules/s3_datasets/test_dataset_resource_found.py @@ -1,5 +1,4 @@ from dataall.modules.s3_datasets.db.dataset_models import S3Dataset -from dataall.core.resource_lock.db.resource_lock_models import ResourceLock from 
dataall.modules.s3_datasets.services.dataset_permissions import CREATE_DATASET @@ -124,8 +123,6 @@ def test_dataset_resource_found(db, client, env_fixture, org_fixture, group2, us assert 'EnvironmentResourcesFound' in response.errors[0].message with db.scoped_session() as session: - dataset_lock = session.query(ResourceLock).filter(ResourceLock.resourceUri == dataset.datasetUri).first() - session.delete(dataset_lock) dataset = session.query(S3Dataset).get(dataset.datasetUri) session.delete(dataset) session.commit() diff --git a/tests/modules/s3_datasets_shares/tasks/test_lf_share_manager.py b/tests/modules/s3_datasets_shares/tasks/test_lf_share_manager.py index 2f9c95b97..451eec950 100644 --- a/tests/modules/s3_datasets_shares/tasks/test_lf_share_manager.py +++ b/tests/modules/s3_datasets_shares/tasks/test_lf_share_manager.py @@ -17,7 +17,7 @@ from dataall.modules.shares_base.services.shares_enums import ShareItemStatus from dataall.modules.shares_base.db.share_object_models import ShareObject, ShareObjectItem from dataall.modules.s3_datasets.db.dataset_models import DatasetTable, S3Dataset -from dataall.modules.s3_datasets_shares.services.dataset_sharing_alarm_service import DatasetSharingAlarmService +from dataall.modules.s3_datasets_shares.services.s3_share_alarm_service import S3ShareAlarmService from dataall.modules.s3_datasets_shares.services.share_processors.glue_table_share_processor import ( ProcessLakeFormationShare, ) @@ -870,7 +870,7 @@ def test_check_catalog_account_exists_and_update_processor_with_catalog_doesnt_e def test_handle_share_failure(manager_with_mocked_clients, table1: DatasetTable, mocker): # Given - alarm_service_mock = mocker.patch.object(DatasetSharingAlarmService, 'trigger_table_sharing_failure_alarm') + alarm_service_mock = mocker.patch.object(S3ShareAlarmService, 'trigger_table_sharing_failure_alarm') error = Exception() manager, lf_client, glue_client, mock_glue_client = manager_with_mocked_clients @@ -887,7 +887,7 @@ def test_handle_revoke_failure( mocker, ): # Given - alarm_service_mock = mocker.patch.object(DatasetSharingAlarmService, 'trigger_revoke_table_sharing_failure_alarm') + alarm_service_mock = mocker.patch.object(S3ShareAlarmService, 'trigger_revoke_table_sharing_failure_alarm') error = Exception() manager, lf_client, glue_client, mock_glue_client = manager_with_mocked_clients diff --git a/tests/modules/s3_datasets_shares/tasks/test_s3_access_point_share_manager.py b/tests/modules/s3_datasets_shares/tasks/test_s3_access_point_share_manager.py index 7f8cd7ca2..b6ac04b86 100644 --- a/tests/modules/s3_datasets_shares/tasks/test_s3_access_point_share_manager.py +++ b/tests/modules/s3_datasets_shares/tasks/test_s3_access_point_share_manager.py @@ -10,7 +10,7 @@ from dataall.core.organizations.db.organization_models import Organization from dataall.modules.s3_datasets_shares.aws.s3_client import S3ControlClient from dataall.modules.shares_base.db.share_object_models import ShareObject, ShareObjectItem -from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import SharePolicyService +from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import S3SharePolicyService from dataall.modules.s3_datasets_shares.services.share_managers import S3AccessPointShareManager from dataall.modules.s3_datasets.db.dataset_models import DatasetStorageLocation, S3Dataset from dataall.modules.shares_base.services.sharing_service import ShareData @@ -334,11 +334,11 @@ def test_grant_target_role_access_policy_test_empty_policy( 
return_value=None, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -373,7 +373,7 @@ def test_grant_target_role_access_policy_test_empty_policy( # When share_manager.grant_target_role_access_policy() - expected_policy_name = SharePolicyService( + expected_policy_name = S3SharePolicyService( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, @@ -400,11 +400,11 @@ def test_grant_target_role_access_policy_existing_policy_bucket_not_included( mocker.patch('dataall.base.aws.iam.IAM.get_managed_policy_default_version', return_value=('v1', iam_policy)) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -424,10 +424,10 @@ def test_grant_target_role_access_policy_existing_policy_bucket_not_included( # Iam function is called with str from object so we transform back to object policy_object = iam_policy - s3_index = SharePolicyService._get_statement_by_sid( + s3_index = S3SharePolicyService._get_statement_by_sid( policy=policy_object, sid=f'{IAM_S3_ACCESS_POINTS_STATEMENT_SID}S3' ) - kms_index = SharePolicyService._get_statement_by_sid( + kms_index = S3SharePolicyService._get_statement_by_sid( policy=policy_object, sid=f'{IAM_S3_ACCESS_POINTS_STATEMENT_SID}KMS' ) @@ -440,7 +440,7 @@ def test_grant_target_role_access_policy_existing_policy_bucket_not_included( ) # Assert that statements for S3 bucket sharing are unaffected - s3_index = SharePolicyService._get_statement_by_sid(policy=policy_object, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=policy_object, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') def test_grant_target_role_access_policy_existing_policy_bucket_included(mocker, share_manager): @@ -452,12 +452,12 @@ def test_grant_target_role_access_policy_existing_policy_bucket_included(mocker, mocker.patch('dataall.base.aws.iam.IAM.get_managed_policy_default_version', return_value=('v1', iam_policy)) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, 
) @@ -903,12 +903,12 @@ def test_delete_target_role_access_policy_no_remaining_statement( 'dataall.base.aws.iam.IAM.get_managed_policy_default_version', return_value=('v1', existing_target_role_policy) ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -928,7 +928,7 @@ def test_delete_target_role_access_policy_no_remaining_statement( # When we revoke IAM access to the target IAM role share_manager.revoke_target_role_access_policy() - expected_policy_name = SharePolicyService( + expected_policy_name = S3SharePolicyService( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, @@ -1006,12 +1006,12 @@ def test_delete_target_role_access_policy_with_remaining_statement( ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1032,7 +1032,7 @@ def test_delete_target_role_access_policy_with_remaining_statement( share_manager.revoke_target_role_access_policy() # Then - expected_policy_name = SharePolicyService( + expected_policy_name = S3SharePolicyService( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, @@ -1199,12 +1199,12 @@ def test_check_bucket_policy_missing_sid(mocker, base_bucket_policy, share_manag def test_check_target_role_access_policy(mocker, share_manager): # Given mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1235,12 +1235,12 @@ def test_check_target_role_access_policy(mocker, share_manager): def test_check_target_role_access_policy_existing_policy_bucket_and_key_not_included(mocker, share_manager): # Given mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 
'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1272,12 +1272,12 @@ def test_check_target_role_access_policy_existing_policy_bucket_and_key_not_incl def test_check_target_role_access_policy_test_no_policy(mocker, share_manager): # Given mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=False, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1293,12 +1293,12 @@ def test_check_target_role_access_policy_test_no_policy(mocker, share_manager): def test_check_target_role_access_policy_test_policy_not_attached(mocker, share_manager): # Given mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) # Policy is not attached mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=False, ) diff --git a/tests/modules/s3_datasets_shares/tasks/test_s3_bucket_share_manager.py b/tests/modules/s3_datasets_shares/tasks/test_s3_bucket_share_manager.py index 3e560717a..840c20568 100644 --- a/tests/modules/s3_datasets_shares/tasks/test_s3_bucket_share_manager.py +++ b/tests/modules/s3_datasets_shares/tasks/test_s3_bucket_share_manager.py @@ -9,7 +9,7 @@ from dataall.core.organizations.db.organization_models import Organization from dataall.modules.shares_base.db.share_object_models import ShareObject from dataall.modules.s3_datasets_shares.services.share_managers import S3BucketShareManager -from dataall.modules.s3_datasets_shares.services.managed_share_policy_service import SharePolicyService +from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import S3SharePolicyService from dataall.modules.s3_datasets.db.dataset_models import S3Dataset, DatasetBucket from dataall.modules.shares_base.services.sharing_service import ShareData @@ -418,7 +418,7 @@ def test_grant_s3_iam_access_with_no_policy(mocker, dataset2, share2_manager): # Check if the get and update_role_policy func are called and policy statements are added mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=False, ) kms_client = mock_kms_client(mocker) @@ -429,15 +429,15 @@ def test_grant_s3_iam_access_with_no_policy(mocker, dataset2, share2_manager): 'Statement': [{'Sid': EMPTY_STATEMENT_SID, 'Effect': 'Allow', 'Action': 'none:null', 'Resource': '*'}], } share_policy_service_mock_1 = mocker.patch( - 
'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.create_managed_policy_from_inline_and_delete_inline', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.create_managed_policy_from_inline_and_delete_inline', return_value='arn:iam::someArn', ) share_policy_service_mock_2 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.attach_policy', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.attach_policy', return_value=True, ) share_policy_service_mock_3 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=False, ) iam_update_role_policy_mock_1 = mocker.patch( @@ -457,8 +457,8 @@ def test_grant_s3_iam_access_with_no_policy(mocker, dataset2, share2_manager): iam_policy = empty_policy_document - s3_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') # Assert if the IAM role policy with S3 and KMS permissions was created assert len(iam_policy['Statement']) == 2 @@ -484,11 +484,11 @@ def test_grant_s3_iam_access_with_empty_policy(mocker, dataset2, share2_manager) # Check if the get and update_role_policy func are called and policy statements are added mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) kms_client = mock_kms_client(mocker) @@ -513,8 +513,8 @@ def test_grant_s3_iam_access_with_empty_policy(mocker, dataset2, share2_manager) iam_policy = initial_policy_document - s3_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') # Assert if the IAM role policy with S3 and KMS permissions was created assert len(iam_policy['Statement']) == 2 @@ -558,15 +558,15 @@ def test_grant_s3_iam_access_with_policy_and_target_resources_not_present(mocker } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 
'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) - s3_index = SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') assert len(policy['Statement']) == 2 assert len(policy['Statement'][s3_index]['Resource']) == 2 @@ -630,11 +630,11 @@ def test_grant_s3_iam_access_with_complete_policy_present(mocker, dataset2, shar } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) kms_client = mock_kms_client(mocker) @@ -654,8 +654,8 @@ def test_grant_s3_iam_access_with_complete_policy_present(mocker, dataset2, shar created_iam_policy = policy - s3_index = SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') assert len(created_iam_policy['Statement']) == 2 assert ( @@ -920,7 +920,7 @@ def test_delete_target_role_access_no_policy_no_other_resources_shared( } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=False, ) @@ -928,11 +928,11 @@ def test_delete_target_role_access_no_policy_no_other_resources_shared( kms_client().get_key_id.return_value = 'kms-key' share_policy_service_mock_1 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.create_managed_policy_from_inline_and_delete_inline', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.create_managed_policy_from_inline_and_delete_inline', return_value='arn:iam::someArn', ) share_policy_service_mock_2 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.attach_policy', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.attach_policy', return_value=True, ) @@ -955,10 +955,10 @@ def 
test_delete_target_role_access_no_policy_no_other_resources_shared( # Get the updated IAM policy and compare it with the existing one updated_iam_policy = policy_document - s3_index = SharePolicyService._get_statement_by_sid( + s3_index = S3SharePolicyService._get_statement_by_sid( policy=updated_iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3' ) - kms_index = SharePolicyService._get_statement_by_sid( + kms_index = S3SharePolicyService._get_statement_by_sid( policy=updated_iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS' ) @@ -992,11 +992,11 @@ def test_delete_target_role_access_policy_no_resource_of_datasets_s3_bucket( } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1019,8 +1019,8 @@ def test_delete_target_role_access_policy_no_resource_of_datasets_s3_bucket( # Get the updated IAM policy and compare it with the existing one updated_iam_policy = iam_policy - s3_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') assert len(updated_iam_policy['Statement']) == 2 assert 'arn:aws:s3:::someOtherBucket,arn:aws:s3:::someOtherBucket/*' == ','.join( @@ -1065,11 +1065,11 @@ def test_delete_target_role_access_policy_with_multiple_s3_buckets_in_policy( } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1092,8 +1092,8 @@ def test_delete_target_role_access_policy_with_multiple_s3_buckets_in_policy( updated_iam_policy = iam_policy - s3_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') - kms_index = SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') + s3_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}S3') + kms_index = S3SharePolicyService._get_statement_by_sid(policy=iam_policy, sid=f'{IAM_S3_BUCKETS_STATEMENT_SID}KMS') assert f'arn:aws:s3:::{dataset2.S3BucketName}' not in updated_iam_policy['Statement'][s3_index]['Resource'] assert f'arn:aws:s3:::{dataset2.S3BucketName}/*' not in updated_iam_policy['Statement'][s3_index]['Resource'] @@ -1139,11 +1139,11 @@ def 
test_delete_target_role_access_policy_with_one_s3_bucket_and_one_kms_resourc } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) @@ -1397,11 +1397,11 @@ def test_check_s3_iam_access(mocker, dataset2, share2_manager): ], } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) # Gets policy with S3 and KMS @@ -1425,7 +1425,7 @@ def test_check_s3_iam_access_no_policy(mocker, dataset2, share2_manager): # When policy does not exist iam_update_role_policy_mock_1 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=False, ) @@ -1446,11 +1446,11 @@ def test_check_s3_iam_access_policy_not_attached(mocker, dataset2, share2_manage # When policy does not exist iam_update_role_policy_mock_1 = mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=False, ) @@ -1481,11 +1481,11 @@ def test_check_s3_iam_access_missing_policy_statement(mocker, dataset2, share2_m ], } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists', return_value=True, ) mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached', + 'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached', return_value=True, ) # Gets policy with other S3 and KMS @@ -1529,11 +1529,11 @@ def test_check_s3_iam_access_missing_target_resource(mocker, dataset2, share2_ma ], } mocker.patch( - 'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists', + 
        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
        return_value=True,
    )
    mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=True,
     )
     # Gets policy with other S3 and KMS
diff --git a/tests/modules/s3_datasets_shares/test_share.py b/tests/modules/s3_datasets_shares/test_share.py
index 988b14f5e..68670a65b 100644
--- a/tests/modules/s3_datasets_shares/test_share.py
+++ b/tests/modules/s3_datasets_shares/test_share.py
@@ -25,8 +25,8 @@
 @pytest.fixture(scope='function')
 def mock_glue_client(mocker):
     glue_client = MagicMock()
-    mocker.patch('dataall.modules.s3_datasets_shares.services.share_item_service.GlueClient', return_value=glue_client)
-    glue_client.get_source_catalog.return_value = None
+    mocker.patch('dataall.modules.s3_datasets_shares.aws.glue_client.GlueClient', return_value=glue_client)
+    glue_client.get_glue_database_from_catalog.return_value = None


 def random_table_name():
@@ -431,7 +431,7 @@ def create_share_object(
     }
     """
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.create_managed_policy_from_inline_and_delete_inline',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.create_managed_policy_from_inline_and_delete_inline',
         return_value=True,
     )
@@ -864,11 +864,11 @@ def test_create_share_object_as_requester(mocker, client, user2, group2, env2gro
     # SharePolicy exists and is attached
     # When a user that belongs to environment and group creates request
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=True,
     )
     create_share_object_response = create_share_object(
@@ -893,11 +893,11 @@ def test_create_share_object_as_approver_and_requester(mocker, client, user, gro
     # SharePolicy exists and is attached
     # When a user that belongs to environment and group creates request
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=True,
     )
     create_share_object_response = create_share_object(
@@ -924,11 +924,11 @@ def test_create_share_object_with_item_authorized(
     # SharePolicy exists and is attached
     # When a user that belongs to environment and group creates request with table in the request
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=True,
     )
     create_share_object_response = create_share_object(
@@ -969,15 +969,15 @@ def test_create_share_object_share_policy_not_attached_attachMissingPolicies_ena
     # SharePolicy exists and is NOT attached, attachMissingPolicies=True
     # When a correct user creates request, data.all attaches the policy and the share creates successfully
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=False,
     )
     attach_mocker = mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.attach_policy',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.attach_policy',
         return_value=True,
     )
     create_share_object_response = create_share_object(
@@ -1006,15 +1006,15 @@ def test_create_share_object_share_policy_not_attached_attachMissingPolicies_dis
     # SharePolicy exists and is NOT attached, attachMissingPolicies=True but principal=Group so managed=True
     # When a correct user creates request, data.all attaches the policy and the share creates successfully
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=False,
     )
     attach_mocker = mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.attach_policy',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.attach_policy',
         return_value=True,
     )
     create_share_object_response = create_share_object(
@@ -1043,11 +1043,11 @@ def test_create_share_object_share_policy_not_attached_attachMissingPolicies_dis
     # SharePolicy exists and is NOT attached, attachMissingPolicies=True
     # When a correct user creates request, data.all attaches the policy and the share creates successfully
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_exists',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_exists',
         return_value=True,
     )
     mocker.patch(
-        'dataall.modules.s3_datasets_shares.services.managed_share_policy_service.SharePolicyService.check_if_policy_attached',
+        'dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service.S3SharePolicyService.check_if_policy_attached',
         return_value=False,
     )
     consumption_role = type('consumption_role', (object,), {})()
diff --git a/tests_new/integration_tests/README.md b/tests_new/integration_tests/README.md
index ef8b1fd1e..9bb9f5575 100644
--- a/tests_new/integration_tests/README.md
+++ b/tests_new/integration_tests/README.md
@@ -2,6 +2,12 @@
 The purpose of these tests is to automatically validate functionalities of data.all on a real deployment.

+🚨🚨🚨
+
+Currently **we support only Cognito-based deployments**, but support for any IdP is planned
+
+🚨🚨🚨
+
 ## Pre-requisites

 - A real deployment of data.all in AWS
@@ -57,32 +63,25 @@ The purpose of these tests is to automatically validate functionalities of data.
   }
 }
 ```
-- If you are not using Cognito then you must manually create the users/groups
-- If you are using Cognito the pipeline will create the users/groups
+- The pipeline will create the users/groups

 ## Run tests

 The tests are executed in CodeBuild as part of the CICD pipeline if the cdk.json parameter `with_approval_tests` is set to True.

-But you can also run the tests locally with deployment account credentials:
-
-```bash
-export ENVNAME = "Introduce deployment environment name"
-export AWS_REGION = "Introduce backend region"
-make integration-tests
-```
+You can also run the tests locally by...

-or run the tests locally without credentials
+* Authenticating to your data.all environment account (you might want to set the `AWS_PROFILE` env variable)

-```bash
-export ENVNAME = "Introduce deployment environment name"
-export AWS_REGION = "Introduce backend region"
-export COGNITO_CLIENT = "Introduce Cognito client id"
-export API_ENDPOINT = "Introduce API endpoint url"
-echo "add your testdata here" > testdata.json
-make integration-tests
-```
+* ```bash
+  export ENVNAME="Introduce deployment environment name"
+  export AWS_REGION="Introduce backend region"
+  export COGNITO_CLIENT="Introduce Cognito client id"
+  export API_ENDPOINT="Introduce API endpoint url"
+  echo "add your testdata here" > testdata.json
+  make integration-tests
+  ```

 ## Coverage
diff --git a/tests_new/integration_tests/core/environment/global_conftest.py b/tests_new/integration_tests/core/environment/global_conftest.py
index 2b0de0e7f..086fa2304 100644
--- a/tests_new/integration_tests/core/environment/global_conftest.py
+++ b/tests_new/integration_tests/core/environment/global_conftest.py
@@ -1,33 +1,33 @@
 import logging
-import re

 import pytest

 from integration_tests.client import GqlError
-from integration_tests.core.environment.queries import create_environment, get_environment, delete_environment
-from integration_tests.utils import poller
+from integration_tests.core.environment.queries import (
+    create_environment,
+    get_environment,
+    delete_environment,
+    list_environments,
+    invite_group_on_env,
+)
+from integration_tests.core.organizations.queries import create_organization
+from integration_tests.core.stack.utils import check_stack_ready

 log = logging.getLogger(__name__)


-@poller(check_success=lambda env: not re.match(r'.*IN_PROGRESS|PENDING', env.stack.status, re.IGNORECASE), timeout=600)
-def check_env_ready(client, env_uri):
-    env = get_environment(client, env_uri)
-    log.info(f'polling {env_uri=}, new {env.stack.status=}')
-    return env
+def create_env(client, group, org_uri, account_id, region, tags=[]):
+    env = create_environment(
+        client, name='testEnvA', group=group, organizationUri=org_uri, awsAccountId=account_id, region=region, tags=tags
+    )
+    check_stack_ready(client, env.environmentUri, env.stack.stackUri)
+    return get_environment(client, env.environmentUri)


-def create_env(client, group, org_uri, account_id, region):
-    new_env_uri = create_environment(
-        client, name='testEnvA', group=group, organizationUri=org_uri, awsAccountId=account_id, region=region
-    )['environmentUri']
-    return check_env_ready(client, new_env_uri)
-
-
-def delete_env(client, env_uri):
-    check_env_ready(client, env_uri)
+def delete_env(client, env):
+    check_stack_ready(client, env.environmentUri, env.stack.stackUri)
     try:
-        return delete_environment(client, env_uri)
+        return delete_environment(client, env.environmentUri)
     except GqlError:
         log.exception('unexpected error when deleting environment')
         return False
@@ -40,27 +40,28 @@
 @pytest.fixture(scope='session')
-def session_env1(client1, group1, org1, testdata):
+def session_env1(client1, group1, org1, session_id, testdata):
     envdata = testdata.envs['session_env1']
     env = None
     try:
-        env = create_env(client1, group1, org1['organizationUri'], envdata.accountId, envdata.region)
+        env = create_env(client1, group1, org1.organizationUri, envdata.accountId, envdata.region, tags=[session_id])
         yield env
     finally:
         if env:
-            delete_env(client1, env['environmentUri'])
+            delete_env(client1, env)


 @pytest.fixture(scope='session')
-def session_env2(client1, group1, org1, testdata):
+def session_env2(client1, group1, group2, org2, session_id, testdata):
     envdata = testdata.envs['session_env2']
     env = None
     try:
-        env = create_env(client1, group1, org1['organizationUri'], envdata.accountId, envdata.region)
+        env = create_env(client1, group1, org2.organizationUri, envdata.accountId, envdata.region, tags=[session_id])
+        invite_group_on_env(client1, env.environmentUri, group2, ['CREATE_DATASET'])
         yield env
     finally:
         if env:
-            delete_env(client1, env['environmentUri'])
+            delete_env(client1, env)

 """
@@ -74,11 +75,11 @@ def temp_env1(client1, group1, org1, testdata):
     envdata = testdata.envs['temp_env1']
     env = None
     try:
-        env = create_env(client1, group1, org1['organizationUri'], envdata.accountId, envdata.region)
+        env = create_env(client1, group1, org1.organizationUri, envdata.accountId, envdata.region)
         yield env
     finally:
         if env:
-            delete_env(client1, env['environmentUri'])
+            delete_env(client1, env)
 """
@@ -87,5 +88,21 @@
 """
-@pytest.fixture(scope='function')
-def persistent_env1(client1, group1, org1, testdata): ...  # TODO
+def get_or_create_persistent_env(env_name, client, group, testdata):
+    envs = list_environments(client, term=env_name).nodes
+    if envs:
+        return envs[0]
+    else:
+        envdata = testdata.envs[env_name]
+        org = create_organization(client, f'org_{env_name}', group)
+        env = create_env(client, group, org.organizationUri, envdata.accountId, envdata.region, tags=[env_name])
+        if env.stack.status in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']:
+            return env
+        else:
+            delete_env(client, env)
+            raise RuntimeError(f'failed to create {env_name=} {env=}')
+
+
+@pytest.fixture(scope='session')
+def persistent_env1(client1, group1, testdata):
+    return get_or_create_persistent_env('persistent_env1', client1, group1, testdata)
diff --git a/tests_new/integration_tests/core/environment/queries.py b/tests_new/integration_tests/core/environment/queries.py
index 36dd0060c..98135d89d 100644
--- a/tests_new/integration_tests/core/environment/queries.py
+++ b/tests_new/integration_tests/core/environment/queries.py
@@ -1,4 +1,55 @@
-def create_environment(client, name, group, organizationUri, awsAccountId, region):
+ENV_TYPE = """
+environmentUri
+created
+userRoleInEnvironment
+description
+name
+label
+AwsAccountId
+region
+owner
+tags
+SamlGroupName
+EnvironmentDefaultBucketName
+EnvironmentDefaultIAMRoleArn
+EnvironmentDefaultIAMRoleName
+EnvironmentDefaultIAMRoleImported
+resourcePrefix
+subscriptionsEnabled
+subscriptionsProducersTopicImported
+subscriptionsConsumersTopicImported
+subscriptionsConsumersTopicName
+subscriptionsProducersTopicName
+organization {
+    organizationUri
+    label
+    name
+}
+stack {
+    stack
+    status
+    stackUri
+    targetUri
+    accountid
+    region
+    stackid
+    link
+    outputs
+    resources
+}
+networks {
+    VpcId
+    privateSubnetIds
+    publicSubnetIds
+}
+parameters {
+    key
+    value
+}
+"""
+
+
+def create_environment(client, name, group, organizationUri, awsAccountId, region, tags):
     query = {
         'operationName': 'CreateEnvironment',
         'variables': {
@@ -9,24 +60,15 @@ def create_environment(client, name, group, organizationUri, awsAccountId, regio
             'AwsAccountId': awsAccountId,
             'region': region,
             'description': 'Created for integration testing',
-            'tags': [],
+            'tags': tags,
             }
         },
-        'query': """
-            mutation CreateEnvironment($input: NewEnvironmentInput!) {
-                createEnvironment(input: $input) {
-                    environmentUri
-                    label
-                    userRoleInEnvironment
-                    SamlGroupName
-                    AwsAccountId
-                    created
-                    parameters {
-                        key
-                        value
-                    }
-                }
-            }
+        'query': f"""
+            mutation CreateEnvironment($input: NewEnvironmentInput!) {{
+                createEnvironment(input: $input) {{
+                    {ENV_TYPE}
+                }}
+            }}
         """,
     }
     response = client.query(query=query)
@@ -37,58 +79,12 @@ def get_environment(client, environmentUri):
     query = {
         'operationName': 'GetEnvironment',
         'variables': {'environmentUri': environmentUri},
-        'query': """
-            query GetEnvironment($environmentUri: String!) {
-                getEnvironment(environmentUri: $environmentUri) {
-                    environmentUri
-                    created
-                    userRoleInEnvironment
-                    description
-                    name
-                    label
-                    AwsAccountId
-                    region
-                    owner
-                    tags
-                    SamlGroupName
-                    EnvironmentDefaultBucketName
-                    EnvironmentDefaultIAMRoleArn
-                    EnvironmentDefaultIAMRoleName
-                    EnvironmentDefaultIAMRoleImported
-                    resourcePrefix
-                    subscriptionsEnabled
-                    subscriptionsProducersTopicImported
-                    subscriptionsConsumersTopicImported
-                    subscriptionsConsumersTopicName
-                    subscriptionsProducersTopicName
-                    organization {
-                        organizationUri
-                        label
-                        name
-                    }
-                    stack {
-                        stack
-                        status
-                        stackUri
-                        targetUri
-                        accountid
-                        region
-                        stackid
-                        link
-                        outputs
-                        resources
-                    }
-                    networks {
-                        VpcId
-                        privateSubnetIds
-                        publicSubnetIds
-                    }
-                    parameters {
-                        key
-                        value
-                    }
-                }
-            }
+        'query': f"""
+            query GetEnvironment($environmentUri: String!) {{
+                getEnvironment(environmentUri: $environmentUri) {{
+                    {ENV_TYPE}
+                }}
+            }}
         """,
     }
     response = client.query(query=query)
@@ -125,26 +121,139 @@ def update_environment(client, environmentUri, input: dict):
             'environmentUri': environmentUri,
             'input': input,
         },
-        'query': """
+        'query': f"""
             mutation UpdateEnvironment(
                 $environmentUri: String!
                 $input: ModifyEnvironmentInput!
-            ) {
-                updateEnvironment(environmentUri: $environmentUri, input: $input) {
+            ) {{
+                updateEnvironment(environmentUri: $environmentUri, input: $input) {{
+                    {ENV_TYPE}
+                }}
+            }}
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.updateEnvironment
+
+
+def list_environments(client, term=''):
+    query = {
+        'operationName': 'ListEnvironments',
+        'variables': {
+            'filter': {'term': term},
+        },
+        'query': f"""
+            query ListEnvironments($filter: EnvironmentFilter) {{
+                listEnvironments(filter: $filter) {{
+                    count
+                    page
+                    pages
+                    hasNext
+                    hasPrevious
+                    nodes {{
+                        {ENV_TYPE}
+                    }}
+                }}
+            }}
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.listEnvironments
+
+
+def invite_group_on_env(client, env_uri, group_uri, perms, iam_role_arn=None):
+    query = {
+        'operationName': 'inviteGroupOnEnvironment',
+        'variables': {
+            'input': {
+                'environmentUri': env_uri,
+                'groupUri': group_uri,
+                'permissions': perms,
+                'environmentIAMRoleArn': iam_role_arn,
+            },
+        },
+        'query': """
+            mutation inviteGroupOnEnvironment($input: InviteGroupOnEnvironmentInput!) {
+                inviteGroupOnEnvironment(input: $input) {
                     environmentUri
-                    label
-                    userRoleInEnvironment
-                    SamlGroupName
-                    AwsAccountId
-                    description
-                    created
-                    parameters {
-                        key
-                        value
-                    }
                 }
             }
         """,
     }
     response = client.query(query=query)
-    return response.data.updateEnvironment
+    return response.data.inviteGroupOnEnvironment
+
+
+def remove_group_from_env(client, env_uri, group_uri):
+    query = {
+        'operationName': 'removeGroupFromEnvironment',
+        'variables': {'environmentUri': env_uri, 'groupUri': group_uri},
+        'query': """
+            mutation removeGroupFromEnvironment(
+                $environmentUri: String!
+                $groupUri: String!
+            ) {
+                removeGroupFromEnvironment(
+                    environmentUri: $environmentUri
+                    groupUri: $groupUri
+                ) {
+                    environmentUri
+                }
+            }
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.removeGroupFromEnvironment
+
+
+def add_consumption_role(client, env_uri, group_uri, consumption_role_name, iam_role_arn, is_managed=True):
+    query = {
+        'operationName': 'addConsumptionRoleToEnvironment',
+        'variables': {
+            'input': {
+                'environmentUri': env_uri,
+                'groupUri': group_uri,
+                'consumptionRoleName': consumption_role_name,
+                'IAMRoleArn': iam_role_arn,
+                'dataallManaged': is_managed,
+            },
+        },
+        'query': """
+            mutation addConsumptionRoleToEnvironment(
+                $input: AddConsumptionRoleToEnvironmentInput!
+            ) {
+                addConsumptionRoleToEnvironment(input: $input) {
+                    consumptionRoleUri
+                    consumptionRoleName
+                    environmentUri
+                    groupUri
+                    IAMRoleArn
+                }
+            }
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.addConsumptionRoleToEnvironment
+
+
+def remove_consumption_role(client, env_uri, consumption_role_uri):
+    query = {
+        'operationName': 'removeConsumptionRoleFromEnvironment',
+        'variables': {
+            'environmentUri': env_uri,
+            'consumptionRoleUri': consumption_role_uri,
+        },
+        'query': """
+            mutation removeConsumptionRoleFromEnvironment(
+                $environmentUri: String!
+                $consumptionRoleUri: String!
+            ) {
+                removeConsumptionRoleFromEnvironment(
+                    environmentUri: $environmentUri
+                    consumptionRoleUri: $consumptionRoleUri
+                )
+            }
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.removeConsumptionRoleFromEnvironment
diff --git a/tests_new/integration_tests/core/environment/test_environment.py b/tests_new/integration_tests/core/environment/test_environment.py
index 648c22eef..55d44188a 100644
--- a/tests_new/integration_tests/core/environment/test_environment.py
+++ b/tests_new/integration_tests/core/environment/test_environment.py
@@ -1,22 +1,35 @@
+import logging
 from datetime import datetime

 from assertpy import assert_that

-from integration_tests.core.environment.queries import update_environment, get_environment
+from integration_tests.core.environment.queries import (
+    get_environment,
+    update_environment,
+    list_environments,
+    invite_group_on_env,
+    add_consumption_role,
+    remove_consumption_role,
+    remove_group_from_env,
+)
+from integration_tests.core.stack.queries import update_stack
+from integration_tests.core.stack.utils import check_stack_in_progress, check_stack_ready
 from integration_tests.errors import GqlError

+log = logging.getLogger(__name__)
+

 def test_create_env(session_env1):
-    assert_that(session_env1.stack.status).is_equal_to('CREATE_COMPLETE')
+    assert_that(session_env1.stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE')


 def test_modify_env(client1, session_env1):
     test_description = f'a test description {datetime.utcnow().isoformat()}'
     env_uri = session_env1.environmentUri
     updated_env = update_environment(client1, env_uri, {'description': test_description})
-    assert_that(updated_env).contains_entry({'environmentUri': env_uri}, {'description': test_description})
+    assert_that(updated_env).contains_entry(environmentUri=env_uri, description=test_description)
     env = get_environment(client1, env_uri)
-    assert_that(env).contains_entry({'environmentUri': env_uri}, {'description': test_description})
+    assert_that(env).contains_entry(environmentUri=env_uri, description=test_description)


 def test_modify_env_unauthorized(client1, client2, session_env1):
@@ -26,6 +39,81 @@ def test_modify_env_unauthorized(client1, client2, session_env1):
         client2, env_uri, {'description': test_description}
     ).contains('UnauthorizedOperation', env_uri)
     env = get_environment(client1, env_uri)
-    assert_that(env).contains_entry({'environmentUri': env_uri}).does_not_contain_entry(
-        {'description': test_description}
-    )
+    assert_that(env).contains_entry(environmentUri=env_uri).does_not_contain_entry(description=test_description)
+
+
+def test_list_envs_authorized(client1, session_env1, session_env2, session_id):
+    assert_that(list_environments(client1, term=session_id).nodes).is_length(2)
+
+
+def test_list_envs_invited(client2, session_env1, session_env2, session_id):
+    assert_that(list_environments(client2, term=session_id).nodes).is_length(1)
+
+
+def test_persistent_env_update(client1, persistent_env1):
+    # wait for stack to get to a final state before triggering an update
+    stack_uri = persistent_env1.stack.stackUri
+    env_uri = persistent_env1.environmentUri
+    check_stack_ready(client1, env_uri, stack_uri)
+    update_stack(client1, env_uri, 'environment')
+    # wait for stack to move to "in_progress" state
+    check_stack_in_progress(client1, env_uri, stack_uri)
+    stack = check_stack_ready(client1, env_uri, stack_uri)
+    assert_that(stack.status).is_equal_to('UPDATE_COMPLETE')
+
+
+def test_invite_group_on_env_no_org(client1, session_env2, group4):
+    assert_that(invite_group_on_env).raises(GqlError).when_called_with(
+        client1, session_env2.environmentUri, group4, ['CREATE_DATASET']
+    ).contains(group4, 'is not a member of the organization')
+
+
+def test_invite_group_on_env_unauthorized(client2, session_env2, group2):
+    assert_that(invite_group_on_env).raises(GqlError).when_called_with(
+        client2, session_env2.environmentUri, group2, ['CREATE_DATASET']
+    ).contains('UnauthorizedOperation', 'INVITE_ENVIRONMENT_GROUP', session_env2.environmentUri)
+
+
+def test_invite_group_on_env(client1, client2, session_env2, group2):
+    env_uri = session_env2.environmentUri
+    assert_that(list_environments(client2).nodes).extracting('environmentUri').contains(env_uri)
+    # assert that client2 can get the environment
+    assert_that(get_environment(client2, env_uri)).contains_entry(userRoleInEnvironment='Invited')
+
+
+def test_invite_remove_group_on_env(client1, client3, session_env2, group3):
+    env_uri = session_env2.environmentUri
+    try:
+        assert_that(list_environments(client3).nodes).extracting('environmentUri').does_not_contain(env_uri)
+        assert_that(invite_group_on_env(client1, env_uri, group3, ['CREATE_DATASET'])).contains_entry(
+            environmentUri=env_uri
+        )
+        # assert that client3 can get the environment
+        assert_that(get_environment(client3, env_uri)).contains_entry(userRoleInEnvironment='Invited')
+    finally:
+        assert_that(remove_group_from_env(client1, env_uri, group3)).contains_entry(environmentUri=env_uri)
+        assert_that(get_environment).raises(GqlError).when_called_with(client3, env_uri).contains(
+            'UnauthorizedOperation', 'GET_ENVIRONMENT', env_uri
+        )
+
+
+def test_add_remove_consumption_role(client1, session_env2, group1):
+    env_uri = session_env2.environmentUri
+    consumption_role = None
+    try:
+        consumption_role = add_consumption_role(
+            client1, env_uri, group1, 'TestConsumptionRole', f'arn:aws:iam::{session_env2.AwsAccountId}:role/Admin'
+        )
+        assert_that(consumption_role).contains_key(
+            'consumptionRoleUri', 'consumptionRoleName', 'environmentUri', 'groupUri', 'IAMRoleArn'
+        )
+    finally:
+        if consumption_role:
+            assert_that(remove_consumption_role(client1, env_uri, consumption_role.consumptionRoleUri)).is_true()
+
+
+def test_add_consumption_role_unauthorized(client2, session_env2, group1):
+    env_uri = session_env2.environmentUri
+    assert_that(add_consumption_role).raises(GqlError).when_called_with(
+        client2, env_uri, group1, 'TestConsumptionRole', f'arn:aws:iam::{session_env2.AwsAccountId}:role/Admin'
+    ).contains('UnauthorizedOperation', 'ADD_ENVIRONMENT_CONSUMPTION_ROLES', env_uri)
diff --git a/tests_new/integration_tests/core/organizations/global_conftest.py b/tests_new/integration_tests/core/organizations/global_conftest.py
index 1177a8121..40dded5ca 100644
--- a/tests_new/integration_tests/core/organizations/global_conftest.py
+++ b/tests_new/integration_tests/core/organizations/global_conftest.py
@@ -13,12 +13,13 @@ def org1(client1, group1, session_id):


 @pytest.fixture(scope='session')
-def org2(client1, group1, group2, session_id):
+def org2(client1, group1, group2, group3, session_id):
     """
     Session org owned by group1 and invite group2
     """
     org = create_organization(client1, 'organization2', group1, tags=[session_id])
     invite_team_to_organization(client=client1, organizationUri=org.organizationUri, group=group2)
+    invite_team_to_organization(client=client1, organizationUri=org.organizationUri, group=group3)
     yield org
     archive_organization(client1, org.organizationUri)
diff --git a/tests_new/integration_tests/core/organizations/test_organization.py b/tests_new/integration_tests/core/organizations/test_organization.py
index c602d39c8..dcdae19a6 100644
--- a/tests_new/integration_tests/core/organizations/test_organization.py
+++ b/tests_new/integration_tests/core/organizations/test_organization.py
@@ -44,7 +44,7 @@ def test_get_organization_organization_with_admin_team(client1, org1):
     assert_that(response.stats.groups).is_equal_to(0)


-def test_get_organization_organization_with_invited_team(client2, org2):
+def test_get_organization_with_invited_team(client2, org2):
     # Given an organization
     organization = org2
     # When an invited team (client2) gets the organization
@@ -52,7 +52,7 @@
     # Then
     assert_that(response.organizationUri).is_equal_to(organization.organizationUri)
     assert_that(response.userRoleInOrganization).is_equal_to('Invited')
-    assert_that(response.stats.groups).is_equal_to(1)
+    assert_that(response.stats.groups).is_equal_to(2)


 def test_get_organization_with_unauthorized_team(client3, org1):
@@ -82,10 +82,10 @@ def test_list_organizations_with_invited_team(client2, org1, org2, session_id):
     assert_that(response.count).is_equal_to(1)


-def test_list_organizations_with_unauthorized_team(client3, org1, org2, session_id):
+def test_list_organizations_with_unauthorized_team(client4, org1, org2, session_id):
     # Given 2 organizations
     # When a non-invited user
-    response = list_organizations(client3, term=session_id)
+    response = list_organizations(client4, term=session_id)
     # Then
     assert_that(response.count).is_equal_to(0)
diff --git a/tests_new/integration_tests/core/stack/queries.py b/tests_new/integration_tests/core/stack/queries.py
new file mode 100644
index 000000000..f40ac253f
--- /dev/null
+++ b/tests_new/integration_tests/core/stack/queries.py
@@ -0,0 +1,58 @@
+def update_stack(client, target_uri, target_type):
+    query = {
+        'operationName': 'updateStack',
+        'variables': {'targetUri': target_uri, 'targetType': target_type},
+        'query': """
+            mutation updateStack($targetUri: String!, $targetType: String!) {
+                updateStack(targetUri: $targetUri, targetType: $targetType) {
+                    stackUri
+                    targetUri
+                    name
+                }
+            }
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.updateStack
+
+
+def get_stack(client, env_uri, stack_uri, target_uri, target_type):
+    query = {
+        'operationName': 'getStack',
+        'variables': {
+            'environmentUri': env_uri,
+            'stackUri': stack_uri,
+            'targetUri': target_uri,
+            'targetType': target_type,
+        },
+        'query': """
+            query getStack(
+                $environmentUri: String!
+                $stackUri: String!
+                $targetUri: String!
+                $targetType: String!
+            ) {
+                getStack(
+                    environmentUri: $environmentUri
+                    stackUri: $stackUri
+                    targetUri: $targetUri
+                    targetType: $targetType
+                ) {
+                    status
+                    stackUri
+                    targetUri
+                    accountid
+                    region
+                    stackid
+                    link
+                    outputs
+                    resources
+                    error
+                    events
+                    name
+                }
+            }
+        """,
+    }
+    response = client.query(query=query)
+    return response.data.getStack
diff --git a/tests_new/integration_tests/core/stack/utils.py b/tests_new/integration_tests/core/stack/utils.py
new file mode 100644
index 000000000..3608244ae
--- /dev/null
+++ b/tests_new/integration_tests/core/stack/utils.py
@@ -0,0 +1,18 @@
+import re
+
+from integration_tests.core.stack.queries import get_stack
+from integration_tests.utils import poller
+
+
+def is_stack_in_progress(stack):
+    return re.match(r'.*IN_PROGRESS|PENDING', stack.status, re.IGNORECASE)
+
+
+@poller(check_success=is_stack_in_progress, timeout=600)
+def check_stack_in_progress(client, env_uri, stack_uri, target_uri=None, target_type='environment'):
+    return get_stack(client, env_uri, stack_uri, target_uri or env_uri, target_type)
+
+
+@poller(check_success=lambda stack: not is_stack_in_progress(stack), timeout=600)
+def check_stack_ready(client, env_uri, stack_uri, target_uri=None, target_type='environment'):
+    return get_stack(client, env_uri, stack_uri, target_uri or env_uri, target_type)
diff --git a/tests_new/integration_tests/utils.py b/tests_new/integration_tests/utils.py
index a48786f94..4c7856829 100644
--- a/tests_new/integration_tests/utils.py
+++ b/tests_new/integration_tests/utils.py
@@ -11,12 +11,12 @@ def poller(
     timeout: Optional[float] = float('inf'),
     sleep_time: Optional[float] = 10.0,
 ):
-    def decorator(function):
-        @wraps(function)
+    def decorator(func):
+        @wraps(func)
         def wrapper(*args, **kwargs):
             current_timeout = timeout
-            while not check_success(retval := function(*args, **kwargs)):
-                log.debug(f'polling {current_timeout} {retval}')
+            while not check_success(retval := func(*args, **kwargs)):
+                log.debug(f'{func.__name__=} polling {current_timeout=} {retval=}')
                 time.sleep(sleep_time)
                 current_timeout -= sleep_time
                 if current_timeout <= 0:
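+                    # Editor's note -- a hedged usage sketch, not part of the original
+                    # patch: `poller` wraps a function that is re-run every `sleep_time`
+                    # seconds until `check_success(retval)` is truthy or the `timeout`
+                    # budget tracked in `current_timeout` is exhausted, e.g. the stack
+                    # helpers added in core/stack/utils.py above:
+                    #
+                    #     @poller(check_success=lambda s: not is_stack_in_progress(s), timeout=600)
+                    #     def check_stack_ready(client, env_uri, stack_uri):
+                    #         return get_stack(client, env_uri, stack_uri, env_uri, 'environment')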