diff --git a/api/main/_permission_util.py b/api/main/_permission_util.py index be2e09c6d..5bc729e56 100644 --- a/api/main/_permission_util.py +++ b/api/main/_permission_util.py @@ -499,7 +499,7 @@ def augment_permission(user, qs): # if model == Localization: - qs = qs.annotate(section=F("media__primary_section__pk")) + qs = qs.annotate(section=F("media_proj__primary_section__pk")) elif model == State: sb = Subquery( Media.objects.filter(state__pk=OuterRef("pk")).values("primary_section__pk")[:1] @@ -507,10 +507,10 @@ def augment_permission(user, qs): qs = qs.annotate(section=sb) # Calculate a dictionary for permissions by section and version in this set - effected_media = qs.values("media__pk") + effected_media = qs.values("media_proj__pk") effected_sections = ( - Section.objects.filter(project=project, media__in=effected_media) - .values("pk") + SectionMediaM2M.objects.filter(project=project, media__in=effected_media) + .values("section") .distinct() ) effected_versions = qs.values("version__pk") @@ -538,7 +538,7 @@ def augment_permission(user, qs): } section_cases = [ - When(media__primary_section=section, then=Value(perm)) + When(media_proj__primary_section=section, then=Value(perm)) for section, perm in section_perm_dict.items() ] version_cases = [ diff --git a/api/main/management/commands/backupresources.py b/api/main/management/commands/backupresources.py index 1e9e7a8df..67ad47647 100644 --- a/api/main/management/commands/backupresources.py +++ b/api/main/management/commands/backupresources.py @@ -20,7 +20,7 @@ class Command(BaseCommand): help = "Backs up any resource objects with `backed_up==False`." def handle(self, **options): - resource_qs = Resource.objects.filter(media__deleted=False, backed_up=False) + resource_qs = Resource.objects.filter(media_proj__deleted=False, backed_up=False) # Check for existence of default backup store default_backup_store = get_tator_store(backup=True) diff --git a/api/main/models.py b/api/main/models.py index e65794a03..f4f195709 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -9,6 +9,7 @@ from django.contrib.contenttypes.models import ContentType from django.contrib.gis.db.models import Model from django.contrib.gis.db.models import ForeignKey +from django.contrib.gis.db.models import ForeignObject from django.contrib.gis.db.models import ManyToManyField from django.contrib.gis.db.models import OneToOneField from django.contrib.gis.db.models import CharField @@ -115,12 +116,18 @@ RETURN NEW; """ +UPDATE_LOOKUP_TRIGGER_FUNC = """ +SET plan_cache_mode=force_generic_plan; +EXECUTE format('EXECUTE update_lookup_{0}(%s,%s)', NEW.project::integer, NEW.id::integer); +SET plan_cache_mode=auto; +RETURN NEW; +""" # Register prepared statements for the triggers to optimize performance on creation of a database connection @receiver(connection_created) def on_connection_created(sender, connection, **kwargs): http_method = get_http_method() - if http_method in ["PATCH", "POST"]: + if http_method in ["PATCH", "POST", "DELETE"]: logger.info( f"{http_method} detected, creating prepared statements." 
) # useful for testing purposes @@ -155,6 +162,18 @@ def create_prepared_statements(cursor): "PREPARE update_latest_mark_state(UUID, INT) AS UPDATE main_state SET latest_mark=(SELECT COALESCE(MAX(mark),0) FROM main_state WHERE elemental_id=$1 AND version=$2 AND deleted=FALSE) WHERE elemental_id=$1 AND version=$2;" ) + cursor.execute( + "PREPARE update_lookup_media(INT, INT) AS INSERT INTO main_projectlookup (project_id, media_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + + cursor.execute( + "PREPARE update_lookup_localization(INT, INT) AS INSERT INTO main_projectlookup (project_id, localization_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + + cursor.execute( + "PREPARE update_lookup_state(INT, INT) AS INSERT INTO main_projectlookup (project_id, state_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + class ModelDiffMixin(object): """ @@ -1389,6 +1408,15 @@ class Media(Model, ModelDiffMixin): """ + class Meta: + triggers = [ + pgtrigger.Trigger( + name="post_media_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("media"), + ), + ] project = ForeignKey( Project, @@ -1519,7 +1547,7 @@ def is_backed_up(self): media_qs = Media.objects.filter(pk__in=self.media_files["ids"]) return all(media.is_backed_up() for media in media_qs.iterator()) - resource_qs = Resource.objects.filter(media=self) + resource_qs = Resource.objects.filter(media_proj=self) return all(resource.backed_up for resource in resource_qs.iterator()) def media_def_iterator(self, keys: List[str] = None) -> Generator[Tuple[str, dict], None, None]: @@ -1652,10 +1680,16 @@ class File(Model, ModelDiffMixin): ) """ Unique ID for a to facilitate cross-cluster sync operations """ - class Resource(Model): path = CharField(db_index=True, max_length=256) + # Comment this out to find all usages media = ManyToManyField(Media, related_name="resource_media") + media_proj = ManyToManyField( + Media, + related_name="resource_media_proj", + through="ResourceMediaM2M", + through_fields=("resource", "media_proj"), + ) generic_files = ManyToManyField(File, related_name="resource_files") bucket = ForeignKey(Bucket, on_delete=PROTECT, null=True, blank=True, related_name="bucket") backup_bucket = ForeignKey( @@ -1687,7 +1721,7 @@ def add_resource(path_or_link, media, generic_file=None): if created: obj.bucket = media.project.bucket obj.save() - obj.media.add(media) + add_media_to_resource(obj, media) @staticmethod @transaction.atomic @@ -1700,7 +1734,7 @@ def delete_resource(path_or_link, project_id): obj = Resource.objects.get(path=path) # If any media or generic files still reference this resource, don't delete it - if obj.media.all().count() > 0 or obj.generic_files.all().count() > 0: + if obj.media_proj.all().count() > 0 or obj.generic_files.all().count() > 0: return logger.info(f"Deleting object {path}") @@ -1772,6 +1806,31 @@ def restore_resource(path, domain): return TatorBackupManager().finish_restore_resource(path, project, domain) +class ResourceMediaM2M(Model): + resource = ForeignKey(Resource, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="media_proj", + null=True, + ) + class Meta: + constraints = [ + UniqueConstraint(name="resourcem2m", fields=["resource", "project", "media"]) + ] + + +@transaction.atomic +def add_media_to_resource(resource, media): + obj, 
created = ResourceMediaM2M.objects.get_or_create( + resource=resource, media=media, project=media.project + ) + + @receiver(post_save, sender=Media) def media_save(sender, instance, created, **kwargs): if instance.media_files and created: @@ -1810,7 +1869,8 @@ def drop_media_from_resource(path, media): try: logger.info(f"Dropping media {media} from resource {path}") obj = Resource.objects.get(path=path) - obj.media.remove(media) + matches = ResourceMediaM2M.objects.filter(resource=obj, media=media) + matches.delete() except: logger.warning(f"Could not remove {media} from {path}", exc_info=True) @@ -1861,6 +1921,12 @@ class Meta: declare=[("_var", "integer")], func=AFTER_MARK_TRIGGER_FUNC.format("localization"), ), + pgtrigger.Trigger( + name="post_localization_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("localization"), + ), ] project = ForeignKey(Project, on_delete=SET_NULL, null=True, blank=True, db_column="project") @@ -1890,7 +1956,15 @@ class Meta: db_column="modified_by", ) user = ForeignKey(User, on_delete=PROTECT, db_column="user") - media = ForeignKey(Media, on_delete=SET_NULL, null=True, blank=True, db_column="media") + media = IntegerField(null=True, blank=True, db_column="media") + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + to_fields=("project", "id"), + from_fields=("project", "media"), + null=True, + name="media_proj", + ) frame = PositiveIntegerField(null=True, blank=True) thumbnail_image = ForeignKey( Media, @@ -1966,6 +2040,12 @@ class Meta: declare=[("_var", "integer")], func=AFTER_MARK_TRIGGER_FUNC.format("state"), ), + pgtrigger.Trigger( + name="post_state_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("state"), + ), ] project = ForeignKey(Project, on_delete=SET_NULL, null=True, blank=True, db_column="project") @@ -1999,6 +2079,18 @@ class Meta: """ Pointer to localization in which this one was generated from """ media = ManyToManyField(Media, related_name="state") localizations = ManyToManyField(Localization) + media_proj = ManyToManyField( + Media, + related_name="state_media_proj", + through="StateMediaM2M", + through_fields=("state", "media_proj"), + ) + localization_proj = ManyToManyField( + Localization, + related_name="state_localization_proj", + through="StateLocalizationM2M", + through_fields=("state", "localization_proj"), + ) segments = JSONField(null=True, blank=True) color = CharField(null=True, blank=True, max_length=8) frame = PositiveIntegerField(null=True, blank=True) @@ -2026,7 +2118,101 @@ def selectOnMedia(media_id): return State.objects.filter(media__in=media_id) -@receiver(m2m_changed, sender=State.localizations.through) +class StateMediaM2M(Model): + state = ForeignKey(State, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="stm_media_proj", + null=True, + ) + + class Meta: + constraints = [ + UniqueConstraint(name="state_media_m2m", fields=["state", "project", "media"]) + ] + + +@transaction.atomic +def add_media_id_to_state(state, media_id, project_id, clear=False): + if type(media_id) == int: + obj, created = StateMediaM2M.objects.get_or_create( + state=state, media_id=media_id, project_id=project_id + ) + else: + existing = StateMediaM2M.objects.filter( + state=state, 
media_id__in=media_id, project=project_id + ) + if clear: + existing.delete() + existing = list(existing.values_list("media_id", flat=True)) + blk = [] + for media in media_id: + if media not in existing: + blk.append(StateMediaM2M(state=state, media_id=media, project_id=project_id)) + + if blk: + StateMediaM2M.objects.bulk_create(blk) + + +def add_media_to_state(state, media): + return add_media_id_to_state(state, media.id, media.project.id) + + +class StateLocalizationM2M(Model): + state = ForeignKey(State, on_delete=CASCADE) + localization = ForeignKey(Localization, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + localization_proj = ForeignObject( + to=Localization, + on_delete=CASCADE, + from_fields=("project", "localization"), + to_fields=("project", "id"), + related_name="stm_localization_proj", + null=True, + ) + + class Meta: + constraints = [ + UniqueConstraint( + name="state_localization_m2m", fields=["state", "project", "localization"] + ) + ] + + +@transaction.atomic +def add_localization_id_to_state(state, localization_id, project_id, clear=False): + if type(localization_id) == int: + obj, created = StateLocalizationM2M.objects.get_or_create( + state=state, localization_id=localization_id, project_id=project_id + ) + else: + existing = StateLocalizationM2M.objects.filter( + state=state, localization_id__in=localization_id, project=project_id + ) + if clear: + existing.delete() + existing = list(existing.values_list("localization_id", flat=True)) + blk = [] + for local in localization_id: + if local not in existing: + blk.append( + StateLocalizationM2M(state=state, localization_id=local, project_id=project_id) + ) + + if blk: + StateLocalizationM2M.objects.bulk_create(blk) + + +def add_localization_to_state(state, localization): + return add_localization_id_to_state(state, localization.id, localization.project.id) + + +@receiver(m2m_changed, sender=State.localization_proj.through) def calc_segments(sender, **kwargs): instance = kwargs["instance"] sortedLocalizations = Localization.objects.filter(pk__in=instance.localizations.all()).order_by( @@ -2172,6 +2358,40 @@ class Section(Model): explicit_listing = BooleanField(default=False, null=True, blank=True) media = ManyToManyField(Media) + media_proj = ManyToManyField( + Media, + related_name="section_media_proj", + through="SectionMediaM2M", + through_fields=("section", "media_proj"), + ) + + +class SectionMediaM2M(Model): + section = ForeignKey(Section, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="sm_media_proj", + null=True, + ) + + class Meta: + constraints = [UniqueConstraint(name="sectionm2m", fields=["section", "project", "media"])] + + +@transaction.atomic +def add_media_id_to_section(section, media_id, project_id): + obj, created = SectionMediaM2M.objects.get_or_create( + section=section, media_id=media_id, project_id=project_id + ) + + +def add_media_to_section(section, media): + return add_media_id_to_section(section, media.id, media.project.id) class Favorite(Model): @@ -2463,6 +2683,22 @@ class Meta: ] +class ProjectLookup(Model): + """This Table defines an easy way to look up the project associated with a given object""" + + media = ForeignKey(Media, on_delete=CASCADE, null=True, blank=True) + localization = ForeignKey(Localization, on_delete=CASCADE, null=True, blank=True) + 
state = ForeignKey(State, on_delete=CASCADE, null=True, blank=True) + project = ForeignKey(Project, on_delete=CASCADE, null=True, blank=True) + class Meta: + constraints = [ + UniqueConstraint( + fields=["project", "media", "localization", "state"], + name="lookup_uniqueness_check", + ) + ] + + # Structure to handle identifying columns with project-scoped indices # e.g. Not relaying solely on `db_index=True` in django. BUILT_IN_INDICES = { @@ -2496,3 +2732,12 @@ class Meta: {"name": "$path", "dtype": "upper_string"}, ], } + +if os.getenv("TATOR_EXT_MODELS", None): + import importlib + for module in os.getenv("TATOR_EXT_MODELS").split(","): + try: + importlib.import_module(module) + except Exception as e: + logger.error(f"Failed to import module {module}: {e}") + assert False diff --git a/api/main/rest/_annotation_query.py b/api/main/rest/_annotation_query.py index 979447b6f..5d788be5a 100644 --- a/api/main/rest/_annotation_query.py +++ b/api/main/rest/_annotation_query.py @@ -10,7 +10,17 @@ from django.db.models import Q, F from django.db.models import BigIntegerField, ExpressionWrapper -from ..models import Localization, LocalizationType, Media, MediaType, Section, State, StateType +from ..models import ( + Localization, + LocalizationType, + Media, + MediaType, + Section, + State, + StateType, + StateLocalizationM2M, + StateMediaM2M, +) from ..schema._attributes import related_keys @@ -74,7 +84,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if media_id is not None: media_ids += media_id if media_ids: - qs = qs.filter(media__in=set(media_ids)) + qs = qs.filter(media_proj__pk__in=set(media_ids)) if len(media_ids) > 1: qs = qs.distinct() @@ -83,7 +93,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): localization_ids += localization_id_put if state_ids and (annotation_type == "localization"): localization_ids += list( - State.localizations.through.objects.filter(state__in=set(state_ids)) + StateLocalizationM2M.objects.filter(state__in=set(state_ids), project=project) .values_list("localization_id", flat=True) .distinct() ) @@ -91,7 +101,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if annotation_type == "localization": qs = qs.filter(pk__in=set(localization_ids)) elif annotation_type == "state": - qs = qs.filter(localizations__in=set(localization_ids)).distinct() + qs = qs.filter(localization_proj__pk__in=set(localization_ids)).distinct() if state_ids and (annotation_type == "state"): qs = qs.filter(pk__in=set(state_ids)) @@ -99,13 +109,15 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if frame_state_ids and (annotation_type == "localization"): # Combine media and frame from states then find localizations that match expression = ExpressionWrapper( - Cast("media", output_field=BigIntegerField()).bitleftshift(32).bitor(F("frame")), + Cast("media_proj__pk", output_field=BigIntegerField()) + .bitleftshift(32) + .bitor(F("frame")), output_field=BigIntegerField(), ) media_frames = ( State.objects.filter( pk__in=set(frame_state_ids), - media__isnull=False, + media_proj__isnull=False, frame__isnull=False, variant_deleted=False, ) @@ -113,7 +125,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): .distinct() ) qs = ( - qs.filter(media__isnull=False, frame__isnull=False) + qs.filter(media_proj__isnull=False, frame__isnull=False) .alias(media_frame=expression) .filter(media_frame__in=media_frames) ) @@ -179,7 +191,7 @@ 
def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): related_object_search = section.related_object_search media_qs = Media.objects.filter(project=project, type=media_type_id) if section.explicit_listing: - media_qs = Media.objects.filter(pk__in=section.media.values("id")) + media_qs = Media.objects.filter(pk__in=section.media_proj.values("id")) logger.info(f"Explicit listing: {media_ids}") elif section_uuid: media_qs = _look_for_section_uuid(media_qs, section_uuid) @@ -199,9 +211,9 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): raise ValueError(f"Invalid Section value pk={section.pk}") media_ids.append(media_qs) - query = Q(media__in=media_ids.pop()) + query = Q(media_proj__pk__in=media_ids.pop()) for m in media_ids: - query = query | Q(media__in=m) + query = query | Q(media_proj__pk__in=m) qs = qs.filter(query) # Do a related query @@ -224,9 +236,9 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): ) if related_matches: related_match = related_matches.pop() - query = Q(media__in=related_match) + query = Q(media_proj__in=related_match) for r in related_matches: - query = query | Q(media__in=r) + query = query | Q(media_proj__in=r) qs = qs.filter(query).distinct() if params.get("encoded_related_search"): @@ -242,9 +254,9 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): related_matches.append(media_qs) if related_matches: related_match = related_matches.pop() - query = Q(media__in=related_match) + query = Q(media_proj__in=related_match) for r in related_matches: - query = query | Q(media__in=r) + query = query | Q(media_proj__in=r) qs = qs.filter(query).distinct() else: qs = qs.filter(pk=-1) @@ -255,7 +267,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): state_qs = State.objects.filter(pk__in=set(params.get("related_id"))) qs = qs.filter(pk__in=state_qs.values("localizations").distinct()) elif annotation_type == "state": - qs = qs.filter(localizations__in=set(params.get("related_id"))) + qs = qs.filter(localization_proj__pk__in=set(params.get("related_id"))) if apply_merge: # parent_set = ANNOTATION_LOOKUP[annotation_type].objects.filter(pk__in=Subquery()) diff --git a/api/main/rest/_attribute_query.py b/api/main/rest/_attribute_query.py index cc27bd249..9b2cd5e94 100644 --- a/api/main/rest/_attribute_query.py +++ b/api/main/rest/_attribute_query.py @@ -77,7 +77,9 @@ class MediaFieldExpression: def get_wrapper(): return ExpressionWrapper( - Cast("media", output_field=BigIntegerField()).bitleftshift(32).bitor(F("frame")), + Cast("media_proj__pk", output_field=BigIntegerField()) + .bitleftshift(32) + .bitor(F("frame")), output_field=BigIntegerField(), ) @@ -144,9 +146,9 @@ def _related_search( orig_list = [*related_matches] related_match = related_matches.pop() # Pop and process the list - media_vals = related_match.values("media") + media_vals = related_match.values("media_proj__pk") for related_match in related_matches: - this_vals = related_match.values("media") + this_vals = related_match.values("media_proj__pk") media_vals = media_vals.union(this_vals) # We now have all the matching media, but lost the score information @@ -155,14 +157,16 @@ def _related_search( # list comp didn't play nice here, but this is easier to read anyway score = [] for x in orig_list: - annotated_x = x.values("media").annotate(count=Count("media")) - filtered_x = annotated_x.filter(media=OuterRef("id")) + annotated_x = 
x.values("media_proj__pk").annotate(count=Count("media_proj__pk")) + filtered_x = annotated_x.filter(media_proj__pk=OuterRef("id")) values_x = filtered_x.values("count").order_by("-count")[:1] score.append(Subquery(values_x)) if len(score) > 1: - qs = qs.filter(pk__in=media_vals.values("media")).annotate(incident=Greatest(*score)) + qs = qs.filter(pk__in=media_vals.values("media_proj__pk")).annotate( + incident=Greatest(*score) + ) else: - qs = qs.filter(pk__in=media_vals.values("media")).annotate(incident=score[0]) + qs = qs.filter(pk__in=media_vals.values("media_proj__pk")).annotate(incident=score[0]) else: qs = qs.filter(pk=-1).annotate(incident=Value(0)) return qs diff --git a/api/main/rest/_media_query.py b/api/main/rest/_media_query.py index 86973b300..80d9548b6 100644 --- a/api/main/rest/_media_query.py +++ b/api/main/rest/_media_query.py @@ -11,7 +11,16 @@ from django.db.models.functions import Cast from django.db.models import UUIDField, TextField, F -from ..models import LocalizationType, Media, MediaType, Localization, Section, State, StateType +from ..models import ( + LocalizationType, + Media, + MediaType, + Localization, + Section, + State, + StateType, + StateMediaM2M, +) from ..schema._attributes import related_keys from ._attribute_query import ( @@ -71,7 +80,7 @@ def _get_media_psql_queryset(project, filter_ops, params): media_ids += media_id if state_ids is not None: media_ids += list( - State.media.through.objects.filter(state__in=set(state_ids)) + StateMediaM2M.objects.filter(state__in=set(state_ids), project=project) .values_list("media_id", flat=True) .distinct() ) @@ -174,9 +183,9 @@ def _get_media_psql_queryset(project, filter_ops, params): ) if related_matches: related_match = related_matches.pop() - query = Q(pk__in=related_match.values("media")) + query = Q(pk__in=related_match.values("media_proj__pk")) for r in related_matches: - query = query | Q(pk__in=r.values("media")) + query = query | Q(pk__in=r.values("media_proj__pk")) qs = qs.filter(query).distinct() if section_id: @@ -187,7 +196,7 @@ def _get_media_psql_queryset(project, filter_ops, params): section_uuid = section[0].tator_user_sections if section[0].explicit_listing: - qs = qs.filter(pk__in=section[0].media.all()) + qs = qs.filter(pk__in=section[0].media_proj.all().values("pk")) elif section_uuid: qs = _look_for_section_uuid(qs, section_uuid) elif section[0].object_search: @@ -212,7 +221,7 @@ def _get_media_psql_queryset(project, filter_ops, params): section_uuid = section.tator_user_sections if section.explicit_listing: - match_qs = qs.filter(pk__in=section.media.all()) + match_qs = qs.filter(pk__in=section.media_proj.all().values("pk")) elif section_uuid: match_qs = _look_for_section_uuid(qs, section_uuid) diff --git a/api/main/rest/_media_util.py b/api/main/rest/_media_util.py index 06c37efa3..6ce865003 100644 --- a/api/main/rest/_media_util.py +++ b/api/main/rest/_media_util.py @@ -26,7 +26,7 @@ def __init__(self, video, temp_dir, quality=None): # If available we only attempt to fetch # the part of the file we need to self._segment_info = None - resources = Resource.objects.filter(media__in=[video]) + resources = Resource.objects.filter(media_proj__in=[video]) store_lookup = get_storage_lookup(resources) if "streaming" in video.media_files: diff --git a/api/main/rest/localization.py b/api/main/rest/localization.py index 6fa141c93..814aa2468 100644 --- a/api/main/rest/localization.py +++ b/api/main/rest/localization.py @@ -12,6 +12,7 @@ from ..models import Project from ..models import Version 
from ..models import Section +from ..models import ProjectLookup from ..schema import LocalizationListSchema from ..schema import LocalizationDetailSchema, LocalizationByElementalIdSchema from ..schema.components import localization as localization_schema @@ -211,7 +212,7 @@ def _post(self, params): Localization( project=project, type=metas[loc_spec["type"]], - media=medias[loc_spec["media_id"]], + media=medias[loc_spec["media_id"]].pk, user=compute_user( project, self.request.user, loc_spec.get("user_elemental_id", None) ), @@ -576,7 +577,11 @@ def get_permissions(self): def get_queryset(self, **kwargs): return self.filter_only_viewables( - Localization.objects.filter(pk=self.params["id"], deleted=False) + Localization.objects.filter( + project=ProjectLookup.objects.get(localization=self.params["id"]).project, + pk=self.params["id"], + deleted=False, + ) ) def _get(self, params): @@ -608,8 +613,11 @@ def get_queryset(self, **kwargs): include_deleted = False if params.get("prune", None) == 1: include_deleted = True + version_obj = Version.objects.get(pk=params["version"]) qs = Localization.objects.filter( - elemental_id=params["elemental_id"], version=params["version"] + project=version_obj.project, + elemental_id=params["elemental_id"], + version=params["version"], ) if include_deleted is False: qs = qs.filter(deleted=False) diff --git a/api/main/rest/localization_graphic.py b/api/main/rest/localization_graphic.py index 2cacaf4de..93b7c77b8 100644 --- a/api/main/rest/localization_graphic.py +++ b/api/main/rest/localization_graphic.py @@ -256,7 +256,7 @@ def _get(self, params: dict): # By reaching here, it's expected that the graphics mode is to create a new # thumbnail using the provided parameters. That new thumbnail is returned with tempfile.TemporaryDirectory() as temp_dir: - media_util = MediaUtil(video=obj.media, temp_dir=temp_dir) + media_util = MediaUtil(video=obj.media_proj, temp_dir=temp_dir) roi = self._getRoi( obj=obj, diff --git a/api/main/rest/media.py b/api/main/rest/media.py index 3d4638b03..5ad4a72a3 100644 --- a/api/main/rest/media.py +++ b/api/main/rest/media.py @@ -28,6 +28,7 @@ database_qs, database_query_ids, Version, + ProjectLookup, ) from .._permission_util import PermissionMask @@ -106,7 +107,8 @@ def _presign(user_id, expiration, medias, fields=None, no_cache=False): "attachment", ] media_ids = set([media["id"] for media in medias]) - resources = Resource.objects.filter(media__in=media_ids) + media_objs = Media.objects.filter(pk__in=media_ids) + resources = Resource.objects.filter(media_proj__in=media_objs) store_lookup = get_storage_lookup(resources) cache = TatorCache() ttl = expiration - 3600 @@ -519,10 +521,10 @@ def _delete(self, params): # Any states that are only associated to deleted media should also be marked # for deletion. - not_deleted = State.objects.filter(project=project, media__deleted=False).values_list( - "id", flat=True - ) - deleted = State.objects.filter(project=project, media__deleted=True).values_list( + not_deleted = State.objects.filter( + project=project, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = State.objects.filter(project=project, media_proj__deleted=True).values_list( "id", flat=True ) all_deleted = set(deleted) - set(not_deleted) @@ -902,10 +904,10 @@ def _delete(self, params): # Any states that are only associated to deleted media should also be marked # for deletion. 
- not_deleted = State.objects.filter(project=project, media__deleted=False).values_list( + not_deleted = State.objects.filter(project=project, media_proj__deleted=False).values_list( "id", flat=True ) - deleted = State.objects.filter(project=project, media__deleted=True).values_list( + deleted = State.objects.filter(project=project, media_proj__deleted=True).values_list( "id", flat=True ) all_deleted = set(deleted) - set(not_deleted) @@ -919,4 +921,8 @@ def _delete(self, params): return {"message": f'Media {params["id"]} successfully deleted!'} def get_queryset(self, **kwargs): - return Media.objects.filter(pk=self.params["id"], deleted=False) + return Media.objects.filter( + project=ProjectLookup.objects.get(media=self.params["id"]).project, + pk=self.params["id"], + deleted=False, + ) diff --git a/api/main/rest/permalink.py b/api/main/rest/permalink.py index 8a620f6ff..d95797e29 100644 --- a/api/main/rest/permalink.py +++ b/api/main/rest/permalink.py @@ -46,8 +46,9 @@ def _presign(expiration, medias, fields=None): "thumbnail_gif", "attachment", ] - media_ids = [media["id"] for media in medias] - resources = Resource.objects.filter(media__in=media_ids) + media_objs = Media.objects.filter(pk__in=[media["id"] for media in medias]) + resources = Resource.objects.filter(media_proj__in=media_objs) + logger.info(f"resources = {resources}") storage_lookup = get_storage_lookup(resources) # Get replace all keys with presigned urls. diff --git a/api/main/rest/section.py b/api/main/rest/section.py index 8b1433df5..2a519e092 100644 --- a/api/main/rest/section.py +++ b/api/main/rest/section.py @@ -10,6 +10,8 @@ from ..models import Project from ..models import database_qs from ..models import RowProtection +from ..models import SectionMediaM2M +from ..models import add_media_to_section, add_media_id_to_section from ..schema import SectionListSchema from ..schema import SectionDetailSchema from ..schema.components import section @@ -32,7 +34,7 @@ def _fill_m2m(response_data): section_ids = [section["id"] for section in response_data] media = { obj["section_id"]: obj["media"] - for obj in Section.media.through.objects.filter(section__in=section_ids) + for obj in Section.media_proj.through.objects.filter(section__in=section_ids) .values("section_id") .order_by("section_id") .annotate(media=ArrayAgg("media_id", default=[])) @@ -137,7 +139,7 @@ def _post(self, params): ) if media_list: for media_id in media_list: - section.media.add(media_id) + add_media_id_to_section(section, media_id, project.pk) section.save() # Automatically create row protection for newly created section based on the creator RowProtection.objects.create( @@ -220,10 +222,16 @@ def _patch(self, params): # Handle removing/adding media media_add = params.get("media_add", []) media_del = params.get("media_del", []) + # This is already in an atomic block so we are good to go. 
         for m in media_add:
-            section.media.add(m)
+            SectionMediaM2M.objects.get_or_create(
+                section=section, media_id=m, project_id=section.project.id
+            )
         for m in media_del:
-            section.media.remove(m)
+            matching = SectionMediaM2M.objects.filter(
+                section=section, media_id=m, project_id=section.project.id
+            )
+            matching.delete()
 
         # Handle attributes
         new_attrs = validate_attributes(params, section, section.project.attribute_types)
diff --git a/api/main/rest/state.py b/api/main/rest/state.py
index 587b0ba67..e87849d70 100644
--- a/api/main/rest/state.py
+++ b/api/main/rest/state.py
@@ -17,9 +17,14 @@
 from ..models import Project
 from ..models import Membership
 from ..models import Version
+from ..models import ProjectLookup
 from ..models import User
 from ..models import InterpolationMethods
 from ..models import Section
+from ..models import StateMediaM2M
+from ..models import StateLocalizationM2M
+from ..models import add_media_to_state, add_media_id_to_state
+from ..models import add_localization_to_state, add_localization_id_to_state
 from ..schema import StateListSchema
 from ..schema import StateDetailSchema, StateByElementalIdSchema
 from ..schema import MergeStatesSchema
@@ -60,7 +65,7 @@ def _fill_m2m(response_data):
     state_ids = set([state["id"] for state in response_data])
     localizations = {
         obj["state_id"]: obj["localizations"]
-        for obj in State.localizations.through.objects.filter(state__in=state_ids)
+        for obj in State.localization_proj.through.objects.filter(state__in=state_ids)
         .values("state_id")
         .order_by("state_id")
         .annotate(localizations=ArrayAgg("localization_id", default=[]))
@@ -68,7 +73,7 @@ def _fill_m2m(response_data):
     }
     media = {
         obj["state_id"]: obj["media"]
-        for obj in State.media.through.objects.filter(state__in=state_ids)
+        for obj in State.media_proj.through.objects.filter(state__in=state_ids)
         .values("state_id")
         .order_by("state_id")
         .annotate(media=ArrayAgg("media_id", default=[]))
@@ -306,33 +311,14 @@ def _post(self, params):
         # Create media relations.
         media_relations = []
         for state, state_spec in zip(states, state_specs):
-            for media_id in state_spec["media_ids"]:
-                media_states = State.media.through(
-                    state_id=state.id,
-                    media_id=media_id,
-                )
-                media_relations.append(media_states)
-                if len(media_relations) > 1000:
-                    State.media.through.objects.bulk_create(media_relations, ignore_conflicts=True)
-                    media_relations = []
-        State.media.through.objects.bulk_create(media_relations, ignore_conflicts=True)
+            if "media_ids" in state_spec:
+                add_media_id_to_state(state, state_spec["media_ids"], project.pk)
 
         # Create localization relations.
         loc_relations = []
         for state, state_spec in zip(states, state_specs):
             if "localization_ids" in state_spec:
-                for localization_id in state_spec["localization_ids"]:
-                    loc_states = State.localizations.through(
-                        state_id=state.id,
-                        localization_id=localization_id,
-                    )
-                    loc_relations.append(loc_states)
-                    if len(loc_relations) > 1000:
-                        State.localizations.through.objects.bulk_create(
-                            loc_relations, ignore_conflicts=True
-                        )
-                        loc_relations = []
-        State.localizations.through.objects.bulk_create(loc_relations, ignore_conflicts=True)
+                add_localization_id_to_state(state, state_spec["localization_ids"], project.pk)
 
         # Calculate segments (this is not triggered for bulk created m2m).
localization_ids = set( @@ -434,7 +420,9 @@ def _patch(self, params): origin_datetimes = [] for original in qs.iterator(): - many_to_many.append((original.media.all(), original.localizations.all())) + many_to_many.append( + (original.media_proj.all(), original.localization_proj.all()) + ) original.pk = None original.id = None for key, value in update_kwargs.items(): @@ -444,8 +432,12 @@ def _patch(self, params): origin_datetimes.append(original.created_datetime) new_objs = State.objects.bulk_create(objs) for p_obj, m2m, origin_datetime in zip(new_objs, many_to_many, origin_datetimes): - p_obj.media.set(m2m[0]) - p_obj.localizations.set(m2m[1]) + add_localization_id_to_state( + p_obj, list(m2m[1].values_list("id", flat=True)), p_obj.project.pk + ) + add_media_id_to_state( + p_obj, list(m2m[0].values_list("id", flat=True)), p_obj.project.pk + ) # Django doesn't let you fix created_datetime unless you fetch the object again found_it = State.objects.get(pk=p_obj.pk) @@ -493,17 +485,11 @@ def get_qs(self, params, qs): if not qs.exists(): raise Http404 state = qs.values(*STATE_PROPERTIES)[0] + local_qs = StateLocalizationM2M.objects.filter(state=state["id"]) + media_qs = StateMediaM2M.objects.filter(state=state["id"]) # Get many to many fields. - state["localizations"] = list( - State.localizations.through.objects.filter(state_id=state["id"]).aggregate( - localizations=ArrayAgg("localization_id", default=[]) - )["localizations"] - ) - state["media"] = list( - State.media.through.objects.filter(state_id=state["id"]).aggregate( - media=ArrayAgg("media_id", default=[]) - )["media"] - ) + state["localizations"] = list(local_qs.values_list("localization_id", flat=True)) + state["media"] = list(media_qs.values_list("media_id", flat=True)) return state def patch_qs(self, params, qs): @@ -536,40 +522,41 @@ def patch_qs(self, params, qs): obj.frame = params["frame"] if "media_ids" in params: - media_elements = Media.objects.filter(pk__in=params["media_ids"]) - obj.media.set(media_elements) if association_type != "Media": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + add_media_id_to_state(obj, params["media_ids"], obj.project.pk, clear=True) + if "localization_ids" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids"]) - obj.localizations.set(localizations) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + add_localization_id_to_state( + obj, params["localization_ids"], obj.project.pk, clear=True + ) + if "localization_ids_add" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids_add"]) - obj.localizations.add(*list(localizations)) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + add_localization_id_to_state(obj, params["localization_ids_add"], obj.project.pk) if "localization_ids_remove" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids_remove"]) - obj.localizations.remove(*list(localizations)) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." 
                 )
+            local_qs = StateLocalizationM2M.objects.filter(state=obj, localization_id__in=params["localization_ids_remove"])
+            local_qs.delete()
 
         if params.get("user_elemental_id", None):
             params["in_place"] = 1
@@ -579,11 +566,11 @@ def patch_qs(self, params, qs):
             obj.created_by = computed_author
 
         # Make sure media and localizations are part of this project.
-        media_qs = Media.objects.filter(pk__in=obj.media.all())
-        localization_qs = Localization.objects.filter(pk__in=obj.localizations.all())
+        media_qs = Media.objects.filter(pk__in=obj.media_proj.all())
+        localization_qs = obj.localization_proj.all()
         media_projects = list(media_qs.values_list("project", flat=True).distinct())
         localization_projects = list(localization_qs.values_list("project", flat=True).distinct())
-        if obj.localizations.count() > 0:
+        if obj.localization_proj.count() > 0:
             if len(localization_projects) != 1:
                 raise Exception(
                     f"Localizations must be part of project {obj.project.id}, got projects "
@@ -594,7 +581,7 @@ def patch_qs(self, params, qs):
                     f"Localizations must be part of project {obj.project.id}, got project "
                     f"{localization_projects[0]}!"
                 )
-        if obj.media.count() > 0:
+        if obj.media_proj.count() > 0:
             if len(media_projects) != 1:
                 raise Exception(
                     f"Media must be part of project {obj.project.id}, got projects "
@@ -624,8 +611,8 @@ def patch_qs(self, params, qs):
                 f"Object is mark {obj.mark} of {obj.latest_mark} for {obj.version.name}/{obj.elemental_id}"
             )
 
-        old_media = obj.media.all()
-        old_localizations = obj.localizations.all()
+        old_media = list(obj.media_proj.all().values_list("id", flat=True))
+        old_localizations = list(obj.localization_proj.all().values_list("id", flat=True))
         # Save edits as new object, mark is calculated in trigger
         obj.id = None
         obj.pk = None
@@ -635,8 +622,10 @@ def patch_qs(self, params, qs):
         # Keep original creation time
         found_it.created_datetime = origin_datetime
         found_it.save()
-        found_it.media.set(old_media)
-        found_it.localizations.set(old_localizations)
+
+        # Add by list
+        add_media_id_to_state(found_it, old_media, obj.project.pk)
+        add_localization_id_to_state(found_it, old_localizations, obj.project.pk)
 
         return {
             "message": f"State {obj.elemental_id}@{obj.version.id}/{obj.mark} successfully updated!",
@@ -723,7 +712,9 @@ def _patch(self, params: dict) -> dict:
         }
 
     def get_queryset(self):
-        return State.objects.all()
+        return self.filter_only_viewables(
+            State.objects.filter(pk__in=[self.params["id"], self.params["merge_state_id"]])
+        )
 
 
 class TrimStateEndAPI(BaseDetailView):
@@ -737,7 +728,7 @@ class TrimStateEndAPI(BaseDetailView):
     @transaction.atomic
     def _patch(self, params: dict) -> dict:
         obj = State.objects.get(pk=params["id"], deleted=False)
-        localizations = obj.localizations.order_by("frame")
+        localizations = obj.localization_proj.order_by("frame")
 
         if params["endpoint"] == "start":
             keep_localization = lambda frame: frame >= params["frame"]
@@ -771,7 +762,7 @@ def _patch(self, params: dict) -> dict:
         }
 
     def get_queryset(self):
-        return State.objects.all()
+        return self.filter_only_viewables(State.objects.filter(pk=self.params["id"]))
 
 
 class StateDetailAPI(StateDetailBaseAPI):
@@ -798,7 +789,13 @@ def get_permissions(self):
         return super().get_permissions()
 
     def get_queryset(self, **kwargs):
-        return self.filter_only_viewables(State.objects.filter(pk=self.params["id"], deleted=False))
+        return self.filter_only_viewables(
+            State.objects.filter(
+                project=ProjectLookup.objects.get(state=self.params["id"]).project,
+                pk=self.params["id"],
+                deleted=False,
+            )
+        )
 
     def _get(self, params):
         return self.get_qs(params, 
self.get_queryset()) @@ -839,7 +836,12 @@ def get_queryset(self, **kwargs): include_deleted = False if params.get("prune", None) == 1: include_deleted = True - qs = State.objects.filter(elemental_id=params["elemental_id"], version=params["version"]) + version_obj = Version.objects.get(pk=params["version"]) + qs = State.objects.filter( + project=version_obj.project, + elemental_id=params["elemental_id"], + version=params["version"], + ) if include_deleted is False: qs = qs.filter(deleted=False) diff --git a/api/main/rest/state_graphic.py b/api/main/rest/state_graphic.py index ee3de02a3..7ef802083 100644 --- a/api/main/rest/state_graphic.py +++ b/api/main/rest/state_graphic.py @@ -63,8 +63,8 @@ def _get(self, params): if typeObj.association != "Localization": raise Exception("Not a localization association state") - video = state.media.all()[0] - localizations = state.localizations.order_by("frame")[offset : offset + length] + video = state.media_proj.all()[0] + localizations = state.localization_proj.order_by("frame")[offset : offset + length] frames = [l.frame for l in localizations] roi = [(l.width, l.height, l.x, l.y) for l in localizations] with tempfile.TemporaryDirectory() as temp_dir: diff --git a/api/main/tests.py b/api/main/tests.py index 274389b8a..03dfbff11 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -294,7 +294,7 @@ def create_test_box(user, entity_type, project, media, frame, attributes={}, ver type=entity_type, project=project, version=version, - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -316,7 +316,7 @@ def make_box_obj(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -345,7 +345,7 @@ def create_test_line(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x0, y=y0, @@ -365,7 +365,7 @@ def create_test_dot(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -1381,7 +1381,6 @@ def test_detail_delete_permissions(self): if expected_status == status.HTTP_200_OK: del self.entities[0] - class AttributeMediaTestMixin: def test_media_with_attr(self): response = self.client.get( @@ -2178,6 +2177,7 @@ def _generate_key(self): class CurrentUserTestCase(TatorTransactionTest): def setUp(self): + super().setUp() logging.disable(logging.CRITICAL) self.user = create_test_user() self.client.force_authenticate(self.user) @@ -2302,6 +2302,7 @@ def test_avatar(self): class ProjectDeleteTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2338,7 +2339,7 @@ def setUp(self): ] for state in self.states: for media in random.choices(self.videos): - state.media.add(media) + add_media_to_state(state, media) def test_delete(self): self.client.delete(f"/rest/Project/{self.project.pk}") @@ -2369,6 +2370,7 @@ def setUp(self): class AlgorithmTestCase(TatorTransactionTest, PermissionListMembershipTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2386,6 +2388,7 @@ def setUp(self): class 
AnonymousAccessTestCase(TatorTransactionTest): def setUp(self): + super().setUp() logging.disable(logging.CRITICAL) self.user = create_test_user() self.random_user = create_test_user() @@ -2426,8 +2429,8 @@ def setUp(self): self.test_bucket = create_test_bucket(None) resource = Resource(path="fake_key.txt", bucket=self.test_bucket) resource.save() - resource.media.add(self.public_video) - resource.media.add(self.private_video) + add_media_to_resource(resource, self.public_video) + add_media_to_resource(resource, self.private_video) resource.save() def test_random_user(self): @@ -2472,6 +2475,7 @@ class VideoTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2497,6 +2501,12 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, media=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.media.pk == e.pk + self.media_entities = self.entities self.list_uri = "Medias" self.detail_uri = "Media" @@ -2642,7 +2652,7 @@ def test_search(self): frame=10, version=self.project.version_set.all()[0], ) - state.media.add(self.entities[1]) + add_media_to_state(state, self.entities[1]) state.save() state_2 = State.objects.create( @@ -2653,7 +2663,7 @@ def test_search(self): frame=100, version=self.project.version_set.all()[0], ) - state_2.media.add(self.entities[1]) + add_media_to_state(state_2, self.entities[1]) state_2.save() # Do a frame_state lookup and find the localization at @@ -2964,6 +2974,7 @@ class ImageTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2987,6 +2998,12 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, media=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.media.pk == e.pk + self.media_entities = self.entities self.list_uri = "Medias" self.detail_uri = "Media" @@ -3015,7 +3032,7 @@ class LocalizationBoxTestCase( def setUp(self): super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) - # logging.disable(logging.CRITICAL) + logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() self.user = create_test_user() self.user_two = create_test_user() @@ -3051,6 +3068,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3098,6 +3119,7 @@ class LocalizationLineTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3135,6 +3157,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = 
functools.partial( @@ -3182,6 +3208,7 @@ class LocalizationDotTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3219,6 +3246,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3264,6 +3295,7 @@ class LocalizationPolyTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3301,6 +3333,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3345,8 +3381,9 @@ class StateTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) - # logging.disable(logging.CRITICAL) + logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() self.user = create_test_user() self.user_two = create_test_user() @@ -3383,8 +3420,12 @@ def setUp(self): attributes={"Float Test": random.random() * 1000}, ) for media in random.choices(self.media_entities): - state.media.add(media) + add_media_to_state(state, media) self.entities.append(state) + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, state=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.state.pk == e.pk self.list_uri = "States" self.detail_uri = "State" self.create_entity = functools.partial( @@ -3458,6 +3499,7 @@ def test_frame_association(self): class LocalizationMediaDeleteCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -3738,6 +3780,7 @@ def test_multiple_media_delete(self): class StateMediaDeleteCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -3902,11 +3945,11 @@ def test_multiple_media_delete(self): self.assertEqual(len(response.data), 2) not_deleted = State.objects.filter( - project=self.project.pk, media__deleted=False + project=self.project.pk, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = State.objects.filter( + project=self.project.pk, media_proj__deleted=True ).values_list("id", flat=True) - deleted = State.objects.filter(project=self.project.pk, media__deleted=True).values_list( - "id", flat=True - ) response = self.client.delete( f"/rest/Medias/{self.project.pk}?attribute={attr_search}", format="json" @@ -3925,11 +3968,11 @@ def test_multiple_media_delete(self): "id", flat=True ) not_deleted = State.objects.filter( - project=self.project.pk, media__deleted=False + project=self.project.pk, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = 
State.objects.filter( + project=self.project.pk, media_proj__deleted=True ).values_list("id", flat=True) - deleted = State.objects.filter(project=self.project.pk, media__deleted=True).values_list( - "id", flat=True - ) response = self.client.get(f"/rest/States/{self.project.pk}?attribute={attr_search}") self.assertEqual(len(response.data), 0) @@ -3954,6 +3997,7 @@ class LeafTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4077,6 +4121,7 @@ class LeafTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4104,6 +4149,7 @@ class StateTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4151,6 +4197,7 @@ class MediaTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4194,6 +4241,7 @@ class LocalizationTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4266,6 +4314,7 @@ class MembershipTestCase( TatorTransactionTest, PermissionListMembershipTestMixin, PermissionDetailTestMixin ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4289,6 +4338,7 @@ def setUp(self): class ProjectTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4453,6 +4503,7 @@ def test_delete_non_creator(self): class TranscodeTestCase(TatorTransactionTest, PermissionCreateTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4488,6 +4539,7 @@ class VersionTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4584,6 +4636,7 @@ class FavoriteStateTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) self.user = create_test_user() self.client.force_authenticate(self.user) @@ -4633,6 +4686,7 @@ class FavoriteLocalizationTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4680,6 +4734,7 @@ class BookmarkTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4714,6 +4769,7 @@ class AffiliationTestCase( PermissionDetailAffiliationTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = 
create_test_user() @@ -4747,6 +4803,7 @@ def get_organization(self): class OrganizationTestCase(TatorTransactionTest, PermissionDetailAffiliationTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user(is_staff=True) @@ -4844,6 +4901,7 @@ class BucketTestCase( TatorTransactionTest, PermissionListAffiliationTestMixin, PermissionDetailAffiliationTestMixin ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4892,6 +4950,7 @@ def test_create_no_affiliation(self): class ImageFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4926,6 +4985,7 @@ def test_thumbnail_gif(self): class VideoFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4967,6 +5027,7 @@ def test_archival(self): class AudioFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4995,6 +5056,7 @@ def test_audio(self): class AuxiliaryFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5033,6 +5095,7 @@ class ResourceTestCase(TatorTransactionTest): } def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5373,8 +5436,8 @@ def test_thumbnails(self): self.assertTrue(self._store_obj_exists(image_key)) self.assertTrue(self._store_obj_exists(thumb_key)) - self.assertEqual(Resource.objects.get(path=image_key).media.all()[0].pk, image_id) - self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=image_key).media_proj.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, image_id) # Delete the media and verify the files are gone. response = self.client.delete(f"/rest/Media/{image_id}", format="json") @@ -5405,8 +5468,8 @@ def test_thumbnails(self): self.assertTrue(self._store_obj_exists(image_key)) self.assertTrue(self._store_obj_exists(thumb_key)) - self.assertEqual(Resource.objects.get(path=image_key).media.all()[0].pk, image_id) - self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=image_key).media_proj.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, image_id) # Delete the media and verify the files are gone. 
@@ -5436,8 +5499,8 @@ def test_thumbnails(self):
         gif_key = video.media_files["thumbnail_gif"][0]["path"]
         self.assertTrue(self._store_obj_exists(thumb_key))
         self.assertTrue(self._store_obj_exists(gif_key))
-        self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, video_id)
-        self.assertEqual(Resource.objects.get(path=gif_key).media.all()[0].pk, video_id)
+        self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, video_id)
+        self.assertEqual(Resource.objects.get(path=gif_key).media_proj.all()[0].pk, video_id)
 
         # Delete the media and verify the files are gone.
         response = self.client.delete(f"/rest/Media/{video_id}", format="json")
@@ -5600,6 +5663,7 @@ class ResourceWithBackupTestCase(ResourceTestCase):
     """This runs the same tests as `ResourceTestCase` but adds project-specific buckets"""
 
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -5632,6 +5696,7 @@ def setUp(self):
 
 class AttributeTestCase(TatorTransactionTest):
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -5847,6 +5912,7 @@ class MutateAliasTestCase(TatorTransactionTest):
     """Tests alias mutation."""
 
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -6027,6 +6093,7 @@ def get_affiliation(self, organization, user):
         return Affiliation.objects.filter(organization=organization, user=user)[0]
 
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -6104,6 +6171,7 @@ def get_affiliation(self, organization, user):
         return Affiliation.objects.filter(organization=organization, user=user)[0]
 
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -6168,6 +6236,7 @@ def test_detail_is_a_member_permissions(self):
 
 class UsernameTestCase(TatorTransactionTest):
     def setUp(self):
+        super().setUp()
         self.list_uri = "Users"
         self.detail_uri = "User"
 
@@ -6210,6 +6279,7 @@ def test_create_case_insensitive_username(self):
 
 class SectionTestCase(TatorTransactionTest):
     def setUp(self):
+        super().setUp()
         print(f"\n{self.__class__.__name__}=", end="", flush=True)
         logging.disable(logging.CRITICAL)
         self.user = create_test_user()
@@ -6731,7 +6801,8 @@ def test_multi_section_lookup(self):
 
 class AdvancedPermissionTestCase(TatorTransactionTest):
     def setUp(self):
-        logging.disable(logging.CRITICAL)
+        super().setUp()
+        # logging.disable(logging.CRITICAL)
         # Add 9 users
         names = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Hank", "Ivy"]
         self.users = [create_test_user(is_staff=False, username=name) for name in names]
@@ -6820,13 +6891,13 @@ def setUp(self):
         self.private_media = [v.pk for v in self.videos[3:6]]
 
         for media in self.videos[:3]:
-            self.public_section.media.add(media)
+            add_media_to_section(self.public_section, media)
             self.public_section.save()
             media.primary_section = self.public_section
             media.save()
 
         for media in self.videos[3:6]:
-            self.private_section.media.add(media)
+            add_media_to_section(self.private_section, media)
             self.private_section.save()
             media.primary_section = self.private_section
             media.save()
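Note: the test setup swaps `section.media.add(media)` for `add_media_to_section(...)`. The helper itself is not shown in this diff; a hypothetical sketch, assuming it mirrors `add_media_to_resource` by writing a project-qualified `SectionMediaM2M` row (the field names and `get_or_create` semantics are assumptions):

```python
from django.db import transaction

from main.models import SectionMediaM2M


@transaction.atomic
def add_media_to_section(section, media):
    # Hypothetical: create the project-qualified through row so that
    # project-scoped joins (media_proj traversals, permission lookups) see
    # the membership; the real helper may also update the legacy relation.
    SectionMediaM2M.objects.get_or_create(
        section=section,
        media=media,
        project=media.project,
    )
```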
@@ -7038,7 +7109,10 @@ def test_permission_augmentation(self):
 
         # Test effective permission for boxes in media match expected result
         for media in media_qs:
-            localization_qs = Localization.objects.filter(project=self.project, media=media)
+            localization_qs = Localization.objects.filter(
+                project=self.project, media_proj=media
+            )
+            logger.info(f"Checking localizations for {media.pk}")
             localization_qs = augment_permission(user, localization_qs)
             if media.primary_section:
                 media_primary_section_pk = media.primary_section.pk
diff --git a/api/main/util.py b/api/main/util.py
index 703b720bb..0478c14ce 100644
--- a/api/main/util.py
+++ b/api/main/util.py
@@ -853,7 +853,7 @@ def get_clone_info(media: Media) -> dict:
             media_dict["original"]["media"] = Media.objects.get(pk=media_id)
 
             # Shared base queryset
-            media_qs = Media.objects.filter(resource_media__path__in=paths)
+            media_qs = Media.objects.filter(resource_media_proj__path__in=paths)
             media_dict["clones"].update(ele for ele in media_qs.values_list("id", flat=True))
             media_dict["clones"].remove(media_dict["original"]["media"].id)
         else:
@@ -1361,3 +1361,79 @@ def cull_low_used_indices(project_id, dry_run=True, population_limit=10000):
             continue
         print(f"Deleting index for {t.name} {attr['name']}/{attr['dtype']}")
         ts.delete_index(t, attr)
+
+
+def fill_lookup_table(project_id, dry_run=False):
+    unhandled_media = Media.objects.filter(project=project_id, projectlookup__isnull=True)
+    unhandled_localizations = Localization.objects.filter(
+        project=project_id, projectlookup__isnull=True
+    )
+    unhandled_states = State.objects.filter(project=project_id, projectlookup__isnull=True)
+
+    print(
+        f"For {project_id}, need to add:\n\t{unhandled_media.count()} media to lookup table\n\t{unhandled_localizations.count()} localizations to lookup table\n\t{unhandled_states.count()} states to lookup table"
+    )
+
+    if dry_run:
+        return
+
+    # Break into 500-element chunks
+    data = []
+    for m in unhandled_media.iterator(chunk_size=500):
+        data.append(ProjectLookup(project_id=project_id, media_id=m.id))
+        if len(data) == 500:
+            ProjectLookup.objects.bulk_create(data)
+            data = []
+    if data:
+        ProjectLookup.objects.bulk_create(data)
+        data = []
+
+    print(f"Completed media for {project_id}")
+    for l in unhandled_localizations.iterator(chunk_size=500):
+        data.append(ProjectLookup(project_id=project_id, localization_id=l.id))
+        if len(data) == 500:
+            ProjectLookup.objects.bulk_create(data)
+            data = []
+    if data:
+        ProjectLookup.objects.bulk_create(data)
+        data = []
+    print(f"Completed localizations for {project_id}")
+
+    for s in unhandled_states.iterator(chunk_size=500):
+        data.append(ProjectLookup(project_id=project_id, state_id=s.id))
+        if len(data) == 500:
+            ProjectLookup.objects.bulk_create(data)
+            data = []
+    if data:
+        ProjectLookup.objects.bulk_create(data)
+    print(f"Completed states for {project_id}")
+    data = []
+
+
+def migrate_old_many_to_many():
+    from django.db import connection
+
+    with connection.cursor() as cursor:
+        # Handle resource many to many first
+        cursor.execute("DELETE FROM main_resourcemediam2m")
+        cursor.execute(
+            'INSERT INTO main_resourcemediam2m (resource_id, media_id, project_id) SELECT resource_id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id")'
+        )
+
+        # Now sections
+        cursor.execute("DELETE FROM main_sectionmediam2m")
+        cursor.execute(
+            'INSERT INTO main_sectionmediam2m (section_id, media_id, project_id) SELECT section_id, "main_media".id, "main_media".project FROM main_section_media LEFT OUTER JOIN "main_media" ON ("main_section_media"."media_id" = "main_media"."id")'
+        )
"main_media".id, "main_media".project FROM main_section_media LEFT OUTER JOIN "main_media" ON ("main_section_media"."media_id" = "main_media"."id")' + ) + + # state-media + cursor.execute("DELETE FROM main_statemediam2m") + cursor.execute( + 'INSERT INTO main_statemediam2m (state_id, media_id, project_id) SELECT state_id, "main_media".id, "main_media".project FROM main_state_media LEFT OUTER JOIN "main_media" ON ("main_state_media"."media_id" = "main_media"."id")' + ) + + # state-localization + cursor.execute("DELETE FROM main_statelocalizationm2m") + cursor.execute( + 'INSERT INTO main_statelocalizationm2m (state_id, localization_id, project_id) SELECT state_id, "main_localization".id, "main_localization".project FROM main_state_localizations LEFT OUTER JOIN "main_localization" ON ("main_state_localizations"."localization_id" = "main_localization"."id")' + )