From 43338b3743db909b34952f0a17250eb80b10f04d Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 14:32:10 -0400 Subject: [PATCH 01/52] Construct table for project lookup by various ids Add utility to backfill lookup table --- api/main/models.py | 16 ++++++++++++++++ api/main/util.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index e65794a03..aa4378f94 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2463,6 +2463,22 @@ class Meta: ] +class ProjectLookup(Model): + """This Table defines an easy way to look up the project associated with a given object""" + + media = ForeignKey(Media, on_delete=CASCADE, null=True, blank=True) + localization = ForeignKey(Localization, on_delete=CASCADE, null=True, blank=True) + state = ForeignKey(State, on_delete=CASCADE, null=True, blank=True) + project = ForeignKey(Project, on_delete=CASCADE, null=True, blank=True) + class Meta: + constraints = [ + UniqueConstraint( + fields=["project", "media", "localization", "state"], + name="lookup_uniqueness_check", + ) + ] + + # Structure to handle identifying columns with project-scoped indices # e.g. Not relaying solely on `db_index=True` in django. 
BUILT_IN_INDICES = { diff --git a/api/main/util.py b/api/main/util.py index 703b720bb..998a71a98 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1361,3 +1361,34 @@ def cull_low_used_indices(project_id, dry_run=True, population_limit=10000): continue print(f"Deleting index for {t.name} {attr['name']}/{attr['dtype']}") ts.delete_index(t, attr) + + +def fill_lookup_table(project_id, dry_run=False): + unhandled_media = Media.objects.filter(project=project_id, projectlookup__isnull=True) + unhandled_localizations = Localization.objects.filter( + project=project_id, projectlookup__isnull=True + ) + unhandled_states = State.objects.filter(project=project_id, projectlookup__isnull=True) + + print( + "For {project_id}, need to add:\n\t{unhandled_media.count()} media to lookup table\n\t{unhandled_localizations.count()} localizations to lookup table\n\t{unhandled_states.count()} states to lookup table" + ) + + if dry_run: + return + + # Break into 500-element chunks + for chunk in unhandled_media.iterator(chunk_size=500): + ProjectLookup.objects.bulk_create( + [ProjectLookUp(project=project_id, media_id=m.id) for m in chunk] + ) + + for chunk in unhandled_localizations.iterator(chunk_size=500): + ProjectLookup.objects.bulk_create( + [ProjectLookUp(project=project_id, localization_id=l.id) for l in chunk] + ) + + for chunk in unhandled_states.iterator(chunk_size=500): + ProjectLookup.objects.bulk_create( + [ProjectLookUp(project=project_id, state_id=s.id) for s in chunk] + ) From 22e90bf8036a75be6a9f6a4307f16bc29c70c541 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 15:40:41 -0400 Subject: [PATCH 02/52] Handle backporting existing elements into global lookup table --- api/main/util.py | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/api/main/util.py b/api/main/util.py index 998a71a98..efa6c4bcf 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1371,24 +1371,40 @@ def 
fill_lookup_table(project_id, dry_run=False): unhandled_states = State.objects.filter(project=project_id, projectlookup__isnull=True) print( - "For {project_id}, need to add:\n\t{unhandled_media.count()} media to lookup table\n\t{unhandled_localizations.count()} localizations to lookup table\n\t{unhandled_states.count()} states to lookup table" + f"For {project_id}, need to add:\n\t{unhandled_media.count()} media to lookup table\n\t{unhandled_localizations.count()} localizations to lookup table\n\t{unhandled_states.count()} states to lookup table" ) if dry_run: return # Break into 500-element chunks - for chunk in unhandled_media.iterator(chunk_size=500): - ProjectLookup.objects.bulk_create( - [ProjectLookUp(project=project_id, media_id=m.id) for m in chunk] - ) - - for chunk in unhandled_localizations.iterator(chunk_size=500): - ProjectLookup.objects.bulk_create( - [ProjectLookUp(project=project_id, localization_id=l.id) for l in chunk] - ) - - for chunk in unhandled_states.iterator(chunk_size=500): - ProjectLookup.objects.bulk_create( - [ProjectLookUp(project=project_id, state_id=s.id) for s in chunk] - ) + data = [] + for m in unhandled_media.iterator(chunk_size=500): + data.append(ProjectLookup(project_id=project_id, media_id=m.id)) + if len(data) == 500: + ProjectLookup.objects.bulk_create(data) + data = [] + if data: + ProjectLookup.objects.bulk_create(data) + data = [] + + print(f"Completed media for {project_id}") + for l in unhandled_localizations.iterator(chunk_size=500): + data.append(ProjectLookup(project_id=project_id, media_id=l.id)) + if len(data) == 500: + ProjectLookup.objects.bulk_create(data) + data = [] + if data: + ProjectLookup.objects.bulk_create(data) + data = [] + print(f"Completed localizations for {project_id}") + + for s in unhandled_states.iterator(chunk_size=500): + data.append(ProjectLookup(project_id=project_id, state_id=s.id)) + if len(data) == 500: + ProjectLookup.objects.bulk_create(data) + data = [] + if data: + 
ProjectLookup.objects.bulk_create(data) + print(f"Completed states for {project_id}") + data = [] From 0d486565fa822f8bc74fe335d8c8c32ad87efe48 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 16:29:26 -0400 Subject: [PATCH 03/52] Add prepared statements for trigger usage --- api/main/models.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index aa4378f94..ddc54e018 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -155,6 +155,18 @@ def create_prepared_statements(cursor): "PREPARE update_latest_mark_state(UUID, INT) AS UPDATE main_state SET latest_mark=(SELECT COALESCE(MAX(mark),0) FROM main_state WHERE elemental_id=$1 AND version=$2 AND deleted=FALSE) WHERE elemental_id=$1 AND version=$2;" ) + cursor.execute( + "PREPARE update_lookup_media(INT, INT) AS INSERT INTO main_projectlookup (project_id, media_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + + cursor.execute( + "PREPARE update_lookup_localization(INT, INT) AS INSERT INTO main_projectlookup (project_id, localization_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + + cursor.execute( + "PREPARE update_lookup_state(INT, INT) AS INSERT INTO main_projectlookup (project_id, state_id) VALUES ($1, $2) ON CONFLICT DO NOTHING;" + ) + class ModelDiffMixin(object): """ From c06344bae70a020a4bfab3fa43580dccf717f15f Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 16:39:41 -0400 Subject: [PATCH 04/52] Hook up triggers to keep lookup table consistent --- api/main/models.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index ddc54e018..5593cef57 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -115,6 +115,12 @@ RETURN NEW; """ +UPDATE_LOOKUP_TRIGGER_FUNC = """ +SET plan_cache_mode=force_generic_plan; +EXECUTE format('EXECUTE update_lookup_{0}(%s,%s)', NEW.project::integer, NEW.id::integer); +SET plan_cache_mode=auto; +RETURN NEW; +""" # 
Register prepared statements for the triggers to optimize performance on creation of a database connection @receiver(connection_created) @@ -1401,6 +1407,15 @@ class Media(Model, ModelDiffMixin): """ + class Meta: + triggers = [ + pgtrigger.Trigger( + name="post_media_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("media"), + ), + ] project = ForeignKey( Project, @@ -1873,6 +1888,12 @@ class Meta: declare=[("_var", "integer")], func=AFTER_MARK_TRIGGER_FUNC.format("localization"), ), + pgtrigger.Trigger( + name="post_localization_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("localization"), + ), ] project = ForeignKey(Project, on_delete=SET_NULL, null=True, blank=True, db_column="project") @@ -1978,6 +1999,12 @@ class Meta: declare=[("_var", "integer")], func=AFTER_MARK_TRIGGER_FUNC.format("state"), ), + pgtrigger.Trigger( + name="post_state_update_lookup", + operation=pgtrigger.Insert, + when=pgtrigger.After, + func=UPDATE_LOOKUP_TRIGGER_FUNC.format("state"), + ), ] project = ForeignKey(Project, on_delete=SET_NULL, null=True, blank=True, db_column="project") From 312eca9b68fe2a0836b2d82eae7c0594cb5e50d2 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 17:28:26 -0400 Subject: [PATCH 05/52] Move this setup to an earlier point to get proper insertion point --- api/main/tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/main/tests.py b/api/main/tests.py index 274389b8a..07de7d5be 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -44,8 +44,9 @@ class TatorTransactionTest(APITransactionTestCase): """Handle cases when test runner flushes DB and indices are still being made.""" - def setUp(self): - # Need to do this for first test in a test db instance + @classmethod + def setUpClass(cls): + super().setUpClass() with connection.cursor() as cursor: create_prepared_statements(cursor) From 
88093be7f9724518930ec50daa96e28e75883935 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 17:33:46 -0400 Subject: [PATCH 06/52] Add lookup checks to unit tests (Media) --- api/main/tests.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/api/main/tests.py b/api/main/tests.py index 07de7d5be..6f23cf557 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -1382,7 +1382,6 @@ def test_detail_delete_permissions(self): if expected_status == status.HTTP_200_OK: del self.entities[0] - class AttributeMediaTestMixin: def test_media_with_attr(self): response = self.client.get( @@ -2498,6 +2497,12 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, media=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.media.pk == e.pk + self.media_entities = self.entities self.list_uri = "Medias" self.detail_uri = "Media" @@ -2988,6 +2993,12 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, media=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.media.pk == e.pk + self.media_entities = self.entities self.list_uri = "Medias" self.detail_uri = "Media" From caadcc70683980e12cfb0a8e86f1716244bacc96 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 20:33:25 -0400 Subject: [PATCH 07/52] Revert "Move this setup to an earlier point to get proper insertion point" This reverts commit 69623cb340e993dc1f86343bd308b9f9d28d57b7. 
--- api/main/tests.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/main/tests.py b/api/main/tests.py index 6f23cf557..1bff1e836 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -44,9 +44,8 @@ class TatorTransactionTest(APITransactionTestCase): """Handle cases when test runner flushes DB and indices are still being made.""" - @classmethod - def setUpClass(cls): - super().setUpClass() + def setUp(self): + # Need to do this for first test in a test db instance with connection.cursor() as cursor: create_prepared_statements(cursor) From b7a0f96b528dbbbaa7d30883352e4ee67368f374 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 20:37:31 -0400 Subject: [PATCH 08/52] Call parent super function --- api/main/tests.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/api/main/tests.py b/api/main/tests.py index 1bff1e836..ec810a24e 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -2177,6 +2177,7 @@ def _generate_key(self): class CurrentUserTestCase(TatorTransactionTest): def setUp(self): + super().setUp() logging.disable(logging.CRITICAL) self.user = create_test_user() self.client.force_authenticate(self.user) @@ -2301,6 +2302,7 @@ def test_avatar(self): class ProjectDeleteTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2368,6 +2370,7 @@ def setUp(self): class AlgorithmTestCase(TatorTransactionTest, PermissionListMembershipTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2385,6 +2388,7 @@ def setUp(self): class AnonymousAccessTestCase(TatorTransactionTest): def setUp(self): + super().setUp() logging.disable(logging.CRITICAL) self.user = create_test_user() self.random_user = create_test_user() @@ -2471,6 
+2475,7 @@ class VideoTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -2969,6 +2974,7 @@ class ImageTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -3109,6 +3115,7 @@ class LocalizationLineTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3193,6 +3200,7 @@ class LocalizationDotTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3275,6 +3283,7 @@ class LocalizationPolyTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3356,6 +3365,7 @@ class StateTestCase( AttributeRenameMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) # logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() @@ -3469,6 +3479,7 @@ def test_frame_association(self): class LocalizationMediaDeleteCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -3749,6 +3760,7 @@ def test_multiple_media_delete(self): class StateMediaDeleteCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = 
create_test_user() @@ -3965,6 +3977,7 @@ class LeafTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4088,6 +4101,7 @@ class LeafTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4115,6 +4129,7 @@ class StateTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4162,6 +4177,7 @@ class MediaTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4205,6 +4221,7 @@ class LocalizationTypeTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4277,6 +4294,7 @@ class MembershipTestCase( TatorTransactionTest, PermissionListMembershipTestMixin, PermissionDetailTestMixin ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4300,6 +4318,7 @@ def setUp(self): class ProjectTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4464,6 +4483,7 @@ def test_delete_non_creator(self): class TranscodeTestCase(TatorTransactionTest, PermissionCreateTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) 
self.user = create_test_user() @@ -4499,6 +4519,7 @@ class VersionTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4595,6 +4616,7 @@ class FavoriteStateTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) self.user = create_test_user() self.client.force_authenticate(self.user) @@ -4644,6 +4666,7 @@ class FavoriteLocalizationTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4691,6 +4714,7 @@ class BookmarkTestCase( PermissionDetailTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4725,6 +4749,7 @@ class AffiliationTestCase( PermissionDetailAffiliationTestMixin, ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4758,6 +4783,7 @@ def get_organization(self): class OrganizationTestCase(TatorTransactionTest, PermissionDetailAffiliationTestMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user(is_staff=True) @@ -4855,6 +4881,7 @@ class BucketTestCase( TatorTransactionTest, PermissionListAffiliationTestMixin, PermissionDetailAffiliationTestMixin ): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4903,6 +4930,7 @@ def test_create_no_affiliation(self): class ImageFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + 
super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4937,6 +4965,7 @@ def test_thumbnail_gif(self): class VideoFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -4978,6 +5007,7 @@ def test_archival(self): class AudioFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5006,6 +5036,7 @@ def test_audio(self): class AuxiliaryFileTestCase(TatorTransactionTest, FileMixin): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5044,6 +5075,7 @@ class ResourceTestCase(TatorTransactionTest): } def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5611,6 +5643,7 @@ class ResourceWithBackupTestCase(ResourceTestCase): """This runs the same tests as `ResourceTestCase` but adds project-specific buckets""" def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5643,6 +5676,7 @@ def setUp(self): class AttributeTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -5858,6 +5892,7 @@ class MutateAliasTestCase(TatorTransactionTest): """Tests alias mutation.""" def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ 
-6038,6 +6073,7 @@ def get_affiliation(self, organization, user): return Affiliation.objects.filter(organization=organization, user=user)[0] def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -6115,6 +6151,7 @@ def get_affiliation(self, organization, user): return Affiliation.objects.filter(organization=organization, user=user)[0] def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() @@ -6179,6 +6216,7 @@ def test_detail_is_a_member_permissions(self): class UsernameTestCase(TatorTransactionTest): def setUp(self): + super().setUp() self.list_uri = "Users" self.detail_uri = "User" @@ -6221,6 +6259,7 @@ def test_create_case_insensitive_username(self): class SectionTestCase(TatorTransactionTest): def setUp(self): + super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) logging.disable(logging.CRITICAL) self.user = create_test_user() From a3078ac7963cb1c55404313ab518136f9b2ee826 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 3 Sep 2024 20:43:26 -0400 Subject: [PATCH 09/52] Add lookup check to locals + states --- api/main/tests.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/api/main/tests.py b/api/main/tests.py index ec810a24e..434f2d31c 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -3068,6 +3068,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3153,6 +3157,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = 
ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3238,6 +3246,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3321,6 +3333,10 @@ def setUp(self): ) for idx in range(random.randint(6, 10)) ] + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, localization=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.localization.pk == e.pk self.list_uri = "Localizations" self.detail_uri = "Localization" self.create_entity = functools.partial( @@ -3406,6 +3422,10 @@ def setUp(self): for media in random.choices(self.media_entities): state.media.add(media) self.entities.append(state) + for e in self.entities: + lookup = ProjectLookup.objects.get(project=e.project, state=e.pk) + assert lookup.project.pk == e.project.pk + assert lookup.state.pk == e.pk self.list_uri = "States" self.detail_uri = "State" self.create_entity = functools.partial( From 93be9fe698f83c48c1bfdc934110cb36250eb224 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 4 Sep 2024 14:47:48 -0400 Subject: [PATCH 10/52] Encorporate project into individual detail accessors [Conflict] # Conflicts: # api/main/rest/localization.py # api/main/rest/media.py # api/main/rest/state.py --- api/main/rest/localization.py | 12 ++++++++++-- api/main/rest/media.py | 7 ++++++- api/main/rest/state.py | 16 ++++++++++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/api/main/rest/localization.py b/api/main/rest/localization.py index 6fa141c93..185cf19e5 100644 
--- a/api/main/rest/localization.py +++ b/api/main/rest/localization.py @@ -12,6 +12,7 @@ from ..models import Project from ..models import Version from ..models import Section +from ..models import ProjectLookup from ..schema import LocalizationListSchema from ..schema import LocalizationDetailSchema, LocalizationByElementalIdSchema from ..schema.components import localization as localization_schema @@ -576,7 +577,11 @@ def get_permissions(self): def get_queryset(self, **kwargs): return self.filter_only_viewables( - Localization.objects.filter(pk=self.params["id"], deleted=False) + Localization.objects.filter( + project=ProjectLookup.objects.get(localization=self.params["id"]).project, + pk=self.params["id"], + deleted=False, + ) ) def _get(self, params): @@ -608,8 +613,11 @@ def get_queryset(self, **kwargs): include_deleted = False if params.get("prune", None) == 1: include_deleted = True + version_obj = Version.objects.get(pk=params["version"]) qs = Localization.objects.filter( - elemental_id=params["elemental_id"], version=params["version"] + project=version_obj.project, + elemental_id=params["elemental_id"], + version=params["version"], ) if include_deleted is False: qs = qs.filter(deleted=False) diff --git a/api/main/rest/media.py b/api/main/rest/media.py index 3d4638b03..449d72bb1 100644 --- a/api/main/rest/media.py +++ b/api/main/rest/media.py @@ -28,6 +28,7 @@ database_qs, database_query_ids, Version, + ProjectLookup, ) from .._permission_util import PermissionMask @@ -919,4 +920,8 @@ def _delete(self, params): return {"message": f'Media {params["id"]} successfully deleted!'} def get_queryset(self, **kwargs): - return Media.objects.filter(pk=self.params["id"], deleted=False) + return Media.objects.filter( + project=ProjectLookup.objects.get(media=self.params["id"]).project, + pk=self.params["id"], + deleted=False, + ) diff --git a/api/main/rest/state.py b/api/main/rest/state.py index 587b0ba67..13f60be97 100644 --- a/api/main/rest/state.py +++ 
b/api/main/rest/state.py @@ -17,6 +17,7 @@ from ..models import Project from ..models import Membership from ..models import Version +from ..models import ProjectLookup from ..models import User from ..models import InterpolationMethods from ..models import Section @@ -798,7 +799,13 @@ def get_permissions(self): return super().get_permissions() def get_queryset(self, **kwargs): - return self.filter_only_viewables(State.objects.filter(pk=self.params["id"], deleted=False)) + return self.filter_only_viewables( + State.objects.filter( + project=ProjectLookup.objects.get(state=self.params["id"]).project, + pk=self.params["id"], + deleted=False, + ) + ) def _get(self, params): return self.get_qs(params, self.get_queryset()) @@ -839,7 +846,12 @@ def get_queryset(self, **kwargs): include_deleted = False if params.get("prune", None) == 1: include_deleted = True - qs = State.objects.filter(elemental_id=params["elemental_id"], version=params["version"]) + version_obj = Version.objects.get(pk=params["version"]) + qs = State.objects.filter( + project=version_obj.project, + elemental_id=params["elemental_id"], + version=params["version"], + ) if include_deleted is False: qs = qs.filter(deleted=False) From d12285cb0076f8bf38fa8f0ae46b6e126ca81534 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 4 Sep 2024 15:02:16 -0400 Subject: [PATCH 11/52] Fix test issue --- api/main/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/main/tests.py b/api/main/tests.py index 434f2d31c..da27ed36b 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -6801,6 +6801,7 @@ def test_multi_section_lookup(self): class AdvancedPermissionTestCase(TatorTransactionTest): def setUp(self): + super().setUp() logging.disable(logging.CRITICAL) # Add 9 users names = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Hank", "Ivy"] From 0d8e920990bc38ee49caf7e951dddf2ec0cefce9 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 5 Sep 2024 10:22:48 -0400 Subject: [PATCH 12/52] Fix 
typo --- api/main/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/util.py b/api/main/util.py index efa6c4bcf..9b9a8959c 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1390,7 +1390,7 @@ def fill_lookup_table(project_id, dry_run=False): print(f"Completed media for {project_id}") for l in unhandled_localizations.iterator(chunk_size=500): - data.append(ProjectLookup(project_id=project_id, media_id=l.id)) + data.append(ProjectLookup(project_id=project_id, localization_id=l.id)) if len(data) == 500: ProjectLookup.objects.bulk_create(data) data = [] From ef216b95b667833ab684372d6eb5ca9107b278f1 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 5 Sep 2024 14:15:47 -0400 Subject: [PATCH 13/52] Import model extensions. --- api/main/models.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index 5593cef57..c87033b70 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2551,3 +2551,12 @@ class Meta: {"name": "$path", "dtype": "upper_string"}, ], } + +if os.getenv("TATOR_EXT_MODELS", None): + import importlib + for module in os.getenv("TATOR_EXT_MODELS").split(","): + try: + importlib.import_module(module) + except Exception as e: + logger.error(f"Failed to import module {module}: {e}") + assert False \ No newline at end of file From ce6030ee65b6cbcbfb58013e6681f83863aa8c18 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Fri, 6 Sep 2024 10:33:06 -0400 Subject: [PATCH 14/52] Add foreign object import --- api/main/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/main/models.py b/api/main/models.py index c87033b70..f7f17068c 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -9,6 +9,7 @@ from django.contrib.contenttypes.models import ContentType from django.contrib.gis.db.models import Model from django.contrib.gis.db.models import ForeignKey +from django.contrib.gis.db.models import ForeignObject from django.contrib.gis.db.models import 
ManyToManyField from django.contrib.gis.db.models import OneToOneField from django.contrib.gis.db.models import CharField From fdc54e0f585c6079e7b9fcebaae36b7cf3c08a06 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Fri, 6 Sep 2024 10:54:07 -0400 Subject: [PATCH 15/52] Add foreign object to replace foreign key --- api/main/models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index f7f17068c..a72abfda7 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1924,7 +1924,14 @@ class Meta: db_column="modified_by", ) user = ForeignKey(User, on_delete=PROTECT, db_column="user") - media = ForeignKey(Media, on_delete=SET_NULL, null=True, blank=True, db_column="media") + media_id = IntegerField(null=True, blank=True, db_column="media") + media = ForeignObject( + Media, + on_delete=CASCADE, + to_fields=["project", "id"], + from_fields=["project", "media"], + null=True, + ) frame = PositiveIntegerField(null=True, blank=True) thumbnail_image = ForeignKey( Media, @@ -2560,4 +2567,4 @@ class Meta: importlib.import_module(module) except Exception as e: logger.error(f"Failed to import module {module}: {e}") - assert False \ No newline at end of file + assert False From e2fcd354b8f525255425644b1f3c1a16bdd33a77 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Fri, 6 Sep 2024 11:08:31 -0400 Subject: [PATCH 16/52] Fix missing keyword argument --- api/main/models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index a72abfda7..8cce173c8 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1924,13 +1924,14 @@ class Meta: db_column="modified_by", ) user = ForeignKey(User, on_delete=PROTECT, db_column="user") - media_id = IntegerField(null=True, blank=True, db_column="media") - media = ForeignObject( - Media, + media = IntegerField(null=True, blank=True, db_column="media") + media_proj = ForeignObject( + to=Media, on_delete=CASCADE, 
to_fields=["project", "id"], from_fields=["project", "media"], null=True, + name="media_proj", ) frame = PositiveIntegerField(null=True, blank=True) thumbnail_image = ForeignKey( From 37b4e99bb248caa28f12bb166a3b6d666013aad2 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Fri, 6 Sep 2024 12:42:36 -0400 Subject: [PATCH 17/52] Fix new usage of media field --- api/main/rest/localization.py | 2 +- api/main/tests.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/main/rest/localization.py b/api/main/rest/localization.py index 185cf19e5..814aa2468 100644 --- a/api/main/rest/localization.py +++ b/api/main/rest/localization.py @@ -212,7 +212,7 @@ def _post(self, params): Localization( project=project, type=metas[loc_spec["type"]], - media=medias[loc_spec["media_id"]], + media=medias[loc_spec["media_id"]].pk, user=compute_user( project, self.request.user, loc_spec.get("user_elemental_id", None) ), diff --git a/api/main/tests.py b/api/main/tests.py index da27ed36b..12fb3313e 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -294,7 +294,7 @@ def create_test_box(user, entity_type, project, media, frame, attributes={}, ver type=entity_type, project=project, version=version, - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -316,7 +316,7 @@ def make_box_obj(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -345,7 +345,7 @@ def create_test_line(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x0, y=y0, @@ -365,7 +365,7 @@ def create_test_dot(user, entity_type, project, media, frame, attributes={}): type=entity_type, project=project, version=project.version_set.all()[0], - media=media, + media=media.pk, frame=frame, x=x, y=y, @@ -3032,7 +3032,7 @@ class 
LocalizationBoxTestCase( def setUp(self): super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) - # logging.disable(logging.CRITICAL) + logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() self.user = create_test_user() self.user_two = create_test_user() From 9499c8369851e12680f39959561ca12a4484c83d Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 10:54:50 -0400 Subject: [PATCH 18/52] Switch these to tuples vs lists --- api/main/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 8cce173c8..106cbf8ad 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1928,8 +1928,8 @@ class Meta: media_proj = ForeignObject( to=Media, on_delete=CASCADE, - to_fields=["project", "id"], - from_fields=["project", "media"], + to_fields=("project", "id"), + from_fields=("project", "media"), null=True, name="media_proj", ) From 05ff0a1c78f4bda82a80b5b6b0705d08e9f1316c Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 11:58:47 -0400 Subject: [PATCH 19/52] Utilize new composite key --- api/main/_permission_util.py | 6 +++--- api/main/tests.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/api/main/_permission_util.py b/api/main/_permission_util.py index be2e09c6d..940541129 100644 --- a/api/main/_permission_util.py +++ b/api/main/_permission_util.py @@ -499,7 +499,7 @@ def augment_permission(user, qs): # if model == Localization: - qs = qs.annotate(section=F("media__primary_section__pk")) + qs = qs.annotate(section=F("media_proj__primary_section__pk")) elif model == State: sb = Subquery( Media.objects.filter(state__pk=OuterRef("pk")).values("primary_section__pk")[:1] @@ -507,7 +507,7 @@ def augment_permission(user, qs): qs = qs.annotate(section=sb) # Calculate a dictionary for permissions by section and version in this set - effected_media = qs.values("media__pk") + effected_media = 
qs.values("media_proj__pk") effected_sections = ( Section.objects.filter(project=project, media__in=effected_media) .values("pk") @@ -538,7 +538,7 @@ def augment_permission(user, qs): } section_cases = [ - When(media__primary_section=section, then=Value(perm)) + When(media_proj__primary_section=section, then=Value(perm)) for section, perm in section_perm_dict.items() ] version_cases = [ diff --git a/api/main/tests.py b/api/main/tests.py index 12fb3313e..97667440d 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -7109,7 +7109,9 @@ def test_permission_augmentation(self): # Test effective permission for boxes in media match expected result for media in media_qs: - localization_qs = Localization.objects.filter(project=self.project, media=media) + localization_qs = Localization.objects.filter( + project=self.project, media_proj=media + ) localization_qs = augment_permission(user, localization_qs) if media.primary_section: media_primary_section_pk = media.primary_section.pk From c68d8e6c880697b18b6eca245c351e187596a3ee Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 12:35:34 -0400 Subject: [PATCH 20/52] Add prepared statements for delete calls too --- api/main/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/models.py b/api/main/models.py index 106cbf8ad..78d1a971c 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -127,7 +127,7 @@ @receiver(connection_created) def on_connection_created(sender, connection, **kwargs): http_method = get_http_method() - if http_method in ["PATCH", "POST"]: + if http_method in ["PATCH", "POST", "DELETE"]: logger.info( f"{http_method} detected, creating prepared statements." 
) # useful for testing purposes From ec5fd6c236fd9fdedf5fd5e580cb45656ba99655 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 12:35:42 -0400 Subject: [PATCH 21/52] Use new composite key --- api/main/rest/localization_graphic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/rest/localization_graphic.py b/api/main/rest/localization_graphic.py index 2cacaf4de..93b7c77b8 100644 --- a/api/main/rest/localization_graphic.py +++ b/api/main/rest/localization_graphic.py @@ -256,7 +256,7 @@ def _get(self, params: dict): # By reaching here, it's expected that the graphics mode is to create a new # thumbnail using the provided parameters. That new thumbnail is returned with tempfile.TemporaryDirectory() as temp_dir: - media_util = MediaUtil(video=obj.media, temp_dir=temp_dir) + media_util = MediaUtil(video=obj.media_proj, temp_dir=temp_dir) roi = self._getRoi( obj=obj, From c08d02f2b163f29a399c645d09eda305802e1e4c Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 13:07:47 -0400 Subject: [PATCH 22/52] Switch to composite key ManyToMany --- api/main/models.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 78d1a971c..4ca08cebb 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1680,10 +1680,14 @@ class File(Model, ModelDiffMixin): ) """ Unique ID for a to facilitate cross-cluster sync operations """ - class Resource(Model): path = CharField(db_index=True, max_length=256) - media = ManyToManyField(Media, related_name="resource_media") + media = ManyToManyField( + Media, + related_name="resource_media", + through="ResourceMedia", + through_fields=("resource", "media_proj"), + ) generic_files = ManyToManyField(File, related_name="resource_files") bucket = ForeignKey(Bucket, on_delete=PROTECT, null=True, blank=True, related_name="bucket") backup_bucket = ForeignKey( @@ -1800,6 +1804,19 @@ def restore_resource(path, domain): 
return TatorBackupManager().finish_restore_resource(path, project, domain) +class ResourceMedia(Model): + resource = ForeignKey(Resource, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="media_proj", + ) + + @receiver(post_save, sender=Media) def media_save(sender, instance, created, **kwargs): if instance.media_files and created: From 4115e9ec81eb58c8e4eea8d0d6bcb5ab6af66746 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 13:12:25 -0400 Subject: [PATCH 23/52] Add new composite-key many to many field to Resource objects --- api/main/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 4ca08cebb..439332984 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1682,9 +1682,10 @@ class File(Model, ModelDiffMixin): class Resource(Model): path = CharField(db_index=True, max_length=256) - media = ManyToManyField( + media = ManyToManyField(Media, related_name="resource_media") + media_proj = ManyToManyField( Media, - related_name="resource_media", + related_name="resource_media_proj", through="ResourceMedia", through_fields=("resource", "media_proj"), ) From 49c90879f048f0489bb95fe47fd553999e6fb7bd Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 13:38:09 -0400 Subject: [PATCH 24/52] Switch to composite key for ResourceMedia --- api/main/models.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 439332984..c5a397e27 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1682,7 +1682,8 @@ class File(Model, ModelDiffMixin): class Resource(Model): path = CharField(db_index=True, max_length=256) - media = ManyToManyField(Media, related_name="resource_media") + # Comment this 
out to find all usages + # media = ManyToManyField(Media, related_name="resource_media") media_proj = ManyToManyField( Media, related_name="resource_media_proj", @@ -1720,7 +1721,7 @@ def add_resource(path_or_link, media, generic_file=None): if created: obj.bucket = media.project.bucket obj.save() - obj.media.add(media) + ResourceMedia.objects.create(resource=obj, media=media, project=media.project) @staticmethod @transaction.atomic @@ -1733,7 +1734,7 @@ def delete_resource(path_or_link, project_id): obj = Resource.objects.get(path=path) # If any media or generic files still reference this resource, don't delete it - if obj.media.all().count() > 0 or obj.generic_files.all().count() > 0: + if obj.media_proj.all().count() > 0 or obj.generic_files.all().count() > 0: return logger.info(f"Deleting object {path}") @@ -1856,7 +1857,8 @@ def drop_media_from_resource(path, media): try: logger.info(f"Dropping media {media} from resource {path}") obj = Resource.objects.get(path=path) - obj.media.remove(media) + matches = ResourceMedia.objects.filter(resource=obj, media=media) + matches.delete() except: logger.warning(f"Could not remove {media} from {path}", exc_info=True) From 33b0ab47f791def051ac68dac79c2c4079aaedb3 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 13:43:44 -0400 Subject: [PATCH 25/52] Fix up tests for new lookup strategy --- api/main/tests.py | 12 ++++++------ api/main/util.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/main/tests.py b/api/main/tests.py index 97667440d..590c10e6b 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -5436,8 +5436,8 @@ def test_thumbnails(self): self.assertTrue(self._store_obj_exists(image_key)) self.assertTrue(self._store_obj_exists(thumb_key)) - self.assertEqual(Resource.objects.get(path=image_key).media.all()[0].pk, image_id) - self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, image_id) + 
self.assertEqual(Resource.objects.get(path=image_key).media_proj.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, image_id) # Delete the media and verify the files are gone. response = self.client.delete(f"/rest/Media/{image_id}", format="json") @@ -5468,8 +5468,8 @@ def test_thumbnails(self): self.assertTrue(self._store_obj_exists(image_key)) self.assertTrue(self._store_obj_exists(thumb_key)) - self.assertEqual(Resource.objects.get(path=image_key).media.all()[0].pk, image_id) - self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=image_key).media_proj.all()[0].pk, image_id) + self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, image_id) # Delete the media and verify the files are gone. response = self.client.delete(f"/rest/Media/{image_id}", format="json") @@ -5499,8 +5499,8 @@ def test_thumbnails(self): gif_key = video.media_files["thumbnail_gif"][0]["path"] self.assertTrue(self._store_obj_exists(thumb_key)) self.assertTrue(self._store_obj_exists(gif_key)) - self.assertEqual(Resource.objects.get(path=thumb_key).media.all()[0].pk, video_id) - self.assertEqual(Resource.objects.get(path=gif_key).media.all()[0].pk, video_id) + self.assertEqual(Resource.objects.get(path=thumb_key).media_proj.all()[0].pk, video_id) + self.assertEqual(Resource.objects.get(path=gif_key).media_proj.all()[0].pk, video_id) # Delete the media and verify the files are gone. 
response = self.client.delete(f"/rest/Media/{video_id}", format="json") diff --git a/api/main/util.py b/api/main/util.py index 9b9a8959c..be4fa3fbe 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -853,7 +853,7 @@ def get_clone_info(media: Media) -> dict: media_dict["original"]["media"] = Media.objects.get(pk=media_id) # Shared base queryset - media_qs = Media.objects.filter(resource_media__path__in=paths) + media_qs = Media.objects.filter(resource_media_proj__path__in=paths) media_dict["clones"].update(ele for ele in media_qs.values_list("id", flat=True)) media_dict["clones"].remove(media_dict["original"]["media"].id) else: From 463d0ce11b895d8933bca8adf6031548af22d1a5 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 14:07:35 -0400 Subject: [PATCH 26/52] Fix resource usage in permalink presign --- api/main/rest/permalink.py | 5 +++-- api/main/tests.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/api/main/rest/permalink.py b/api/main/rest/permalink.py index 8a620f6ff..d95797e29 100644 --- a/api/main/rest/permalink.py +++ b/api/main/rest/permalink.py @@ -46,8 +46,9 @@ def _presign(expiration, medias, fields=None): "thumbnail_gif", "attachment", ] - media_ids = [media["id"] for media in medias] - resources = Resource.objects.filter(media__in=media_ids) + media_objs = Media.objects.filter(pk__in=[media["id"] for media in medias]) + resources = Resource.objects.filter(media_proj__in=media_objs) + logger.info(f"resources = {resources}") storage_lookup = get_storage_lookup(resources) # Get replace all keys with presigned urls. 
diff --git a/api/main/tests.py b/api/main/tests.py index 590c10e6b..ff938c8ad 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -2429,8 +2429,12 @@ def setUp(self): self.test_bucket = create_test_bucket(None) resource = Resource(path="fake_key.txt", bucket=self.test_bucket) resource.save() - resource.media.add(self.public_video) - resource.media.add(self.private_video) + ResourceMedia.objects.create( + resource=resource, media=self.public_video, project=self.public_video.project + ) + ResourceMedia.objects.create( + resource=resource, media=self.private_video, project=self.private_video.project + ) resource.save() def test_random_user(self): From 3ab4af43fc4ab730bbedfc04c56eeffcb89cef97 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 14:27:55 -0400 Subject: [PATCH 27/52] Fix rest of m2m usages --- api/main/management/commands/backupresources.py | 2 +- api/main/models.py | 2 +- api/main/rest/_media_util.py | 2 +- api/main/rest/media.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/api/main/management/commands/backupresources.py b/api/main/management/commands/backupresources.py index 1e9e7a8df..67ad47647 100644 --- a/api/main/management/commands/backupresources.py +++ b/api/main/management/commands/backupresources.py @@ -20,7 +20,7 @@ class Command(BaseCommand): help = "Backs up any resource objects with `backed_up==False`." 
def handle(self, **options): - resource_qs = Resource.objects.filter(media__deleted=False, backed_up=False) + resource_qs = Resource.objects.filter(media_proj__deleted=False, backed_up=False) # Check for existence of default backup store default_backup_store = get_tator_store(backup=True) diff --git a/api/main/models.py b/api/main/models.py index c5a397e27..c7aef32c8 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1547,7 +1547,7 @@ def is_backed_up(self): media_qs = Media.objects.filter(pk__in=self.media_files["ids"]) return all(media.is_backed_up() for media in media_qs.iterator()) - resource_qs = Resource.objects.filter(media=self) + resource_qs = Resource.objects.filter(media_proj=self) return all(resource.backed_up for resource in resource_qs.iterator()) def media_def_iterator(self, keys: List[str] = None) -> Generator[Tuple[str, dict], None, None]: diff --git a/api/main/rest/_media_util.py b/api/main/rest/_media_util.py index 06c37efa3..6ce865003 100644 --- a/api/main/rest/_media_util.py +++ b/api/main/rest/_media_util.py @@ -26,7 +26,7 @@ def __init__(self, video, temp_dir, quality=None): # If available we only attempt to fetch # the part of the file we need to self._segment_info = None - resources = Resource.objects.filter(media__in=[video]) + resources = Resource.objects.filter(media_proj__in=[video]) store_lookup = get_storage_lookup(resources) if "streaming" in video.media_files: diff --git a/api/main/rest/media.py b/api/main/rest/media.py index 449d72bb1..e6f8faebf 100644 --- a/api/main/rest/media.py +++ b/api/main/rest/media.py @@ -107,7 +107,8 @@ def _presign(user_id, expiration, medias, fields=None, no_cache=False): "attachment", ] media_ids = set([media["id"] for media in medias]) - resources = Resource.objects.filter(media__in=media_ids) + media_objs = Media.objects.filter(pk__in=media_ids) + resources = Resource.objects.filter(media_proj__in=media_objs) store_lookup = get_storage_lookup(resources) cache = TatorCache() ttl = expiration 
- 3600 From 98af363319cfca7bc3886004e65df98fd946678e Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 14:28:03 -0400 Subject: [PATCH 28/52] Add migration from old to new m2m --- api/main/util.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/api/main/util.py b/api/main/util.py index be4fa3fbe..571320cf0 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1408,3 +1408,15 @@ def fill_lookup_table(project_id, dry_run=False): ProjectLookup.objects.bulk_create(data) print(f"Completed states for {project_id}") data = [] + + +def migrate_old_many_to_many(): + from django.db import connection + + with connection.cursor() as cursor: + + # Handle resource many to many first + cursor.execute("DELETE * FROM main_resourcemedia") + cursor.execute( + 'INSERT INTO main_resourcemedia (resource_id, media_id, project_id) SELECT "main_resource".id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id") LEFT OUTER JOIN "main_resource" ON ("main_resource_media"."resource_id" = "main_resource"."id")' + ) From b0bec172641504c1f44d95dfb06f5e34cb78b7cc Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 14:41:07 -0400 Subject: [PATCH 29/52] Add back in old field to make migration non-destructive --- api/main/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/models.py b/api/main/models.py index c7aef32c8..5acd0a2ff 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1683,7 +1683,7 @@ class File(Model, ModelDiffMixin): class Resource(Model): path = CharField(db_index=True, max_length=256) # Comment this out to find all usages - # media = ManyToManyField(Media, related_name="resource_media") + media = ManyToManyField(Media, related_name="resource_media") media_proj = ManyToManyField( Media, related_name="resource_media_proj", From 39986d2603c5a36dc5f637700940eb3b09519b4d Mon Sep 17 00:00:00 2001 From: Brian 
Tate Date: Tue, 10 Sep 2024 14:50:19 -0400 Subject: [PATCH 30/52] Fix migration utility --- api/main/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/main/util.py b/api/main/util.py index 571320cf0..f47274a2f 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1414,9 +1414,8 @@ def migrate_old_many_to_many(): from django.db import connection with connection.cursor() as cursor: - # Handle resource many to many first - cursor.execute("DELETE * FROM main_resourcemedia") + cursor.execute("DELETE FROM main_resourcemedia") cursor.execute( 'INSERT INTO main_resourcemedia (resource_id, media_id, project_id) SELECT "main_resource".id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id") LEFT OUTER JOIN "main_resource" ON ("main_resource_media"."resource_id" = "main_resource"."id")' ) From 45973736819dff5ce2d796bf8d67522106484d79 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 14:55:49 -0400 Subject: [PATCH 31/52] Fix migration utility syntax a bit --- api/main/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/util.py b/api/main/util.py index f47274a2f..c568cf5bc 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1417,5 +1417,5 @@ def migrate_old_many_to_many(): # Handle resource many to many first cursor.execute("DELETE FROM main_resourcemedia") cursor.execute( - 'INSERT INTO main_resourcemedia (resource_id, media_id, project_id) SELECT "main_resource".id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id") LEFT OUTER JOIN "main_resource" ON ("main_resource_media"."resource_id" = "main_resource"."id")' + 'INSERT INTO main_resourcemedia (resource_id, media_id, project_id) SELECT resource_id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON 
("main_resource_media"."media_id" = "main_media"."id")' ) From a36f6723f169f7dd7b95dfbbeafc58efa5a547f3 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 15:32:54 -0400 Subject: [PATCH 32/52] Clean up naming + add unique constraint --- api/main/models.py | 8 ++++++-- api/main/util.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 5acd0a2ff..875f5bb7e 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1687,7 +1687,7 @@ class Resource(Model): media_proj = ManyToManyField( Media, related_name="resource_media_proj", - through="ResourceMedia", + through="ResourceMediaM2M", through_fields=("resource", "media_proj"), ) generic_files = ManyToManyField(File, related_name="resource_files") @@ -1806,7 +1806,7 @@ def restore_resource(path, domain): return TatorBackupManager().finish_restore_resource(path, project, domain) -class ResourceMedia(Model): +class ResourceMediaM2M(Model): resource = ForeignKey(Resource, on_delete=CASCADE) media = ForeignKey(Media, on_delete=CASCADE) project = ForeignKey(Project, on_delete=CASCADE) @@ -1817,6 +1817,10 @@ class ResourceMedia(Model): to_fields=("project", "id"), related_name="media_proj", ) + class Meta: + constraints = [ + UniqueConstraint(name="resourcem2m", fields=["resource", "project", "media"]) + ] @receiver(post_save, sender=Media) diff --git a/api/main/util.py b/api/main/util.py index c568cf5bc..06b270054 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1415,7 +1415,7 @@ def migrate_old_many_to_many(): with connection.cursor() as cursor: # Handle resource many to many first - cursor.execute("DELETE FROM main_resourcemedia") + cursor.execute("DELETE FROM main_resourcemediam2m") cursor.execute( - 'INSERT INTO main_resourcemedia (resource_id, media_id, project_id) SELECT resource_id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = 
"main_media"."id")' + 'INSERT INTO main_resourcemediam2m (resource_id, media_id, project_id) SELECT resource_id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id")' ) From 5a66713ae1d104d8bdc93d645321adb5726abc71 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 15:46:00 -0400 Subject: [PATCH 33/52] Add custom M2M layer + disable legacy field to find usages --- api/main/models.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/api/main/models.py b/api/main/models.py index 875f5bb7e..7322f018d 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2243,7 +2243,29 @@ class Section(Model): attributes = JSONField(null=True, blank=True, default=dict) explicit_listing = BooleanField(default=False, null=True, blank=True) - media = ManyToManyField(Media) + # media = ManyToManyField(Media) + media_proj = ManyToManyField( + Media, + related_name="section_media_proj", + through="SectionMediaM2M", + through_fields=("section", "media_proj"), + ) + + +class SectionMediaM2M(Model): + section = ForeignKey(Section, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="media_proj", + ) + + class Meta: + constraints = [UniqueConstraint(name="sectionm2m", fields=["section", "project", "media"])] class Favorite(Model): From 298732d42f3268fde2713c34603db780941f8ec2 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 15:48:22 -0400 Subject: [PATCH 34/52] Switch to custom M2M usage --- api/main/rest/section.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/api/main/rest/section.py b/api/main/rest/section.py index 8b1433df5..3bc2e17e2 100644 --- a/api/main/rest/section.py +++ 
b/api/main/rest/section.py @@ -10,6 +10,7 @@ from ..models import Project from ..models import database_qs from ..models import RowProtection +from ..models import SectionMediaM2M from ..schema import SectionListSchema from ..schema import SectionDetailSchema from ..schema.components import section @@ -32,7 +33,7 @@ def _fill_m2m(response_data): section_ids = [section["id"] for section in response_data] media = { obj["section_id"]: obj["media"] - for obj in Section.media.through.objects.filter(section__in=section_ids) + for obj in Section.media_proj.through.objects.filter(section__in=section_ids) .values("section_id") .order_by("section_id") .annotate(media=ArrayAgg("media_id", default=[])) @@ -137,7 +138,9 @@ def _post(self, params): ) if media_list: for media_id in media_list: - section.media.add(media_id) + SectionMediaM2M.objects.create( + section=section, media_id=media_id, project_id=project.id + ) section.save() # Automatically create row protection for newly created section based on the creator RowProtection.objects.create( @@ -221,9 +224,14 @@ def _patch(self, params): media_add = params.get("media_add", []) media_del = params.get("media_del", []) for m in media_add: - section.media.add(m) + SectionMediaM2M.objects.create( + section=section, media_id=media_id, project_id=project.id + ) for m in media_del: - section.media.remove(m) + matching = SectionMediaM2M.objects.filter( + section=section, media_id=m, project_id=section.project.id + ) + matching.delete() # Handle attributes new_attrs = validate_attributes(params, section, section.project.attribute_types) From d2a71e0b3fbd74796fe1aef8c68a12936c46577f Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 16:51:41 -0400 Subject: [PATCH 35/52] Fix new many to many section capability --- api/main/_permission_util.py | 4 ++-- api/main/models.py | 8 +++++--- api/main/rest/_media_query.py | 2 +- api/main/tests.py | 15 ++++++++++----- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git 
a/api/main/_permission_util.py b/api/main/_permission_util.py index 940541129..5bc729e56 100644 --- a/api/main/_permission_util.py +++ b/api/main/_permission_util.py @@ -509,8 +509,8 @@ def augment_permission(user, qs): # Calculate a dictionary for permissions by section and version in this set effected_media = qs.values("media_proj__pk") effected_sections = ( - Section.objects.filter(project=project, media__in=effected_media) - .values("pk") + SectionMediaM2M.objects.filter(project=project, media__in=effected_media) + .values("section") .distinct() ) effected_versions = qs.values("version__pk") diff --git a/api/main/models.py b/api/main/models.py index 7322f018d..ff246bba7 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1721,7 +1721,7 @@ def add_resource(path_or_link, media, generic_file=None): if created: obj.bucket = media.project.bucket obj.save() - ResourceMedia.objects.create(resource=obj, media=media, project=media.project) + ResourceMediaM2M.objects.create(resource=obj, media=media, project=media.project) @staticmethod @transaction.atomic @@ -1816,6 +1816,7 @@ class ResourceMediaM2M(Model): from_fields=("project", "media"), to_fields=("project", "id"), related_name="media_proj", + null=True, ) class Meta: constraints = [ @@ -1861,7 +1862,7 @@ def drop_media_from_resource(path, media): try: logger.info(f"Dropping media {media} from resource {path}") obj = Resource.objects.get(path=path) - matches = ResourceMedia.objects.filter(resource=obj, media=media) + matches = ResourceMediaM2M.objects.filter(resource=obj, media=media) matches.delete() except: logger.warning(f"Could not remove {media} from {path}", exc_info=True) @@ -2261,7 +2262,8 @@ class SectionMediaM2M(Model): on_delete=CASCADE, from_fields=("project", "media"), to_fields=("project", "id"), - related_name="media_proj", + related_name="sm_media_proj", + null=True, ) class Meta: diff --git a/api/main/rest/_media_query.py b/api/main/rest/_media_query.py index 86973b300..738ae49eb 100644 --- 
a/api/main/rest/_media_query.py +++ b/api/main/rest/_media_query.py @@ -212,7 +212,7 @@ def _get_media_psql_queryset(project, filter_ops, params): section_uuid = section.tator_user_sections if section.explicit_listing: - match_qs = qs.filter(pk__in=section.media.all()) + match_qs = qs.filter(pk__in=section.media_proj.all()) elif section_uuid: match_qs = _look_for_section_uuid(qs, section_uuid) diff --git a/api/main/tests.py b/api/main/tests.py index ff938c8ad..b3c6bab97 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -2429,10 +2429,10 @@ def setUp(self): self.test_bucket = create_test_bucket(None) resource = Resource(path="fake_key.txt", bucket=self.test_bucket) resource.save() - ResourceMedia.objects.create( + ResourceMediaM2M.objects.create( resource=resource, media=self.public_video, project=self.public_video.project ) - ResourceMedia.objects.create( + ResourceMediaM2M.objects.create( resource=resource, media=self.private_video, project=self.private_video.project ) resource.save() @@ -6806,7 +6806,7 @@ def test_multi_section_lookup(self): class AdvancedPermissionTestCase(TatorTransactionTest): def setUp(self): super().setUp() - logging.disable(logging.CRITICAL) + # logging.disable(logging.CRITICAL) # Add 9 users names = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Hank", "Ivy"] self.users = [create_test_user(is_staff=False, username=name) for name in names] @@ -6895,13 +6895,17 @@ def setUp(self): self.private_media = [v.pk for v in self.videos[3:6]] for media in self.videos[:3]: - self.public_section.media.add(media) + SectionMediaM2M.objects.create( + section=self.public_section, media=media, project=media.project + ) self.public_section.save() media.primary_section = self.public_section media.save() for media in self.videos[3:6]: - self.private_section.media.add(media) + SectionMediaM2M.objects.create( + section=self.private_section, media=media, project=media.project + ) self.private_section.save() media.primary_section = 
self.private_section media.save() @@ -7116,6 +7120,7 @@ def test_permission_augmentation(self): localization_qs = Localization.objects.filter( project=self.project, media_proj=media ) + logger.info(f"Checking localizations for {media.pk}") localization_qs = augment_permission(user, localization_qs) if media.primary_section: media_primary_section_pk = media.primary_section.pk From 20edf3f5e5a85e9add5fdfbb27c497223780c6bb Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 16:51:47 -0400 Subject: [PATCH 36/52] Add migration capability --- api/main/util.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/main/util.py b/api/main/util.py index 06b270054..7dd371124 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1419,3 +1419,10 @@ def migrate_old_many_to_many(): cursor.execute( 'INSERT INTO main_resourcemediam2m (resource_id, media_id, project_id) SELECT resource_id, "main_media".id, "main_media".project FROM main_resource_media LEFT OUTER JOIN "main_media" ON ("main_resource_media"."media_id" = "main_media"."id")' ) + + # Now sections + # Handle resource many to many first + cursor.execute("DELETE FROM main_sectionmediam2m") + cursor.execute( + 'INSERT INTO main_sectionmediam2m (section_id, media_id, project_id) SELECT section_id, "main_media".id, "main_media".project FROM main_section_media LEFT OUTER JOIN "main_media" ON ("main_section_media"."media_id" = "main_media"."id")' + ) From c8374e8fd31d97e7327f5dbdbd309acb45137431 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 16:53:53 -0400 Subject: [PATCH 37/52] Fix missing media_proj --- api/main/rest/_annotation_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/rest/_annotation_query.py b/api/main/rest/_annotation_query.py index 979447b6f..9b5263c56 100644 --- a/api/main/rest/_annotation_query.py +++ b/api/main/rest/_annotation_query.py @@ -179,7 +179,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): 
related_object_search = section.related_object_search media_qs = Media.objects.filter(project=project, type=media_type_id) if section.explicit_listing: - media_qs = Media.objects.filter(pk__in=section.media.values("id")) + media_qs = Media.objects.filter(pk__in=section.media_proj.values("id")) logger.info(f"Explicit listing: {media_ids}") elif section_uuid: media_qs = _look_for_section_uuid(media_qs, section_uuid) From a3f2697dc64dc071faaafe874f619e8a4426bf3e Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 17:09:58 -0400 Subject: [PATCH 38/52] Fix logic to work with cloning --- api/main/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index ff246bba7..4ad036d7b 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1721,6 +1721,10 @@ def add_resource(path_or_link, media, generic_file=None): if created: obj.bucket = media.project.bucket obj.save() + existing = ResourceMediaM2M.objects.filter( + resource=obj, media=media, project=media.project + ) + if not existing: ResourceMediaM2M.objects.create(resource=obj, media=media, project=media.project) @staticmethod From c2e540138b97efd829dd50931025ea9b93b4faa2 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Tue, 10 Sep 2024 17:10:10 -0400 Subject: [PATCH 39/52] Restore old field to make migration non-destructive --- api/main/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/models.py b/api/main/models.py index 4ad036d7b..c0080319b 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2248,7 +2248,7 @@ class Section(Model): attributes = JSONField(null=True, blank=True, default=dict) explicit_listing = BooleanField(default=False, null=True, blank=True) - # media = ManyToManyField(Media) + media = ManyToManyField(Media) media_proj = ManyToManyField( Media, related_name="section_media_proj", From 765da7d3dc7a3f4c8046077abef2122193b0eee5 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 
09:25:46 -0400 Subject: [PATCH 40/52] Add new custom M2M tables --- api/main/models.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/api/main/models.py b/api/main/models.py index c0080319b..5c01ecb43 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2103,6 +2103,42 @@ def selectOnMedia(media_id): return State.objects.filter(media__in=media_id) +class StateMediaM2M(Model): + state = ForeignKey(State, on_delete=CASCADE) + media = ForeignKey(Media, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + media_proj = ForeignObject( + to=Media, + on_delete=CASCADE, + from_fields=("project", "media"), + to_fields=("project", "id"), + related_name="stm_media_proj", + null=True, + ) + + class Meta: + constraints = [UniqueConstraint(name="sectionm2m", fields=["state", "project", "media"])] + + +class StateLocalizationM2M(Model): + state = ForeignKey(State, on_delete=CASCADE) + localization = ForeignKey(Localization, on_delete=CASCADE) + project = ForeignKey(Project, on_delete=CASCADE) + localization_proj = ForeignObject( + to=Localization, + on_delete=CASCADE, + from_fields=("project", "localization"), + to_fields=("project", "id"), + related_name="stm_localization_proj", + null=True, + ) + + class Meta: + constraints = [ + UniqueConstraint(name="sectionm2m", fields=["state", "project", "localization"]) + ] + + @receiver(m2m_changed, sender=State.localizations.through) def calc_segments(sender, **kwargs): instance = kwargs["instance"] From 4844669d8b87b4e598caa4ae3907015cb1cdbceb Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 09:28:46 -0400 Subject: [PATCH 41/52] Add to migration utility --- api/main/util.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/api/main/util.py b/api/main/util.py index 7dd371124..463382a4f 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1421,8 +1421,19 @@ def migrate_old_many_to_many(): ) # Now sections - # Handle 
resource many to many first cursor.execute("DELETE FROM main_sectionmediam2m") cursor.execute( 'INSERT INTO main_sectionmediam2m (section_id, media_id, project_id) SELECT section_id, "main_media".id, "main_media".project FROM main_section_media LEFT OUTER JOIN "main_media" ON ("main_section_media"."media_id" = "main_media"."id")' ) + + # state-media + cursor.execute("DELETE FROM main_statemediam2m") + cursor.execute( + 'INSERT INTO main_statemediam2m (state_id, media_id, project_id) SELECT state_id, "main_media".id, "main_media".project FROM main_state_media LEFT OUTER JOIN "main_media" ON ("main_state_media"."media_id" = "main_media"."id")' + ) + + # state-localization + cursor.execute("DELETE FROM main_statelocalizationm2m") + cursor.execute( + 'INSERT INTO main_statelocalizationm2m (state_id, localization_id, project_id) SELECT state_id, "main_localization".id, "main_localization".project FROM main_localization_media LEFT OUTER JOIN "main_localization" ON ("main_state_localization"."localization_id" = "main_localization"."id")' + ) From 547c19d7dab267f5fca7f13eaa7fdbcee3a01235 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 09:35:52 -0400 Subject: [PATCH 42/52] Fix constraint names --- api/main/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 5c01ecb43..0822075db 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2117,7 +2117,9 @@ class StateMediaM2M(Model): ) class Meta: - constraints = [UniqueConstraint(name="sectionm2m", fields=["state", "project", "media"])] + constraints = [ + UniqueConstraint(name="state_media_m2m", fields=["state", "project", "media"]) + ] class StateLocalizationM2M(Model): @@ -2135,7 +2137,9 @@ class StateLocalizationM2M(Model): class Meta: constraints = [ - UniqueConstraint(name="sectionm2m", fields=["state", "project", "localization"]) + UniqueConstraint( + name="state_localization_m2m", fields=["state", "project", 
"localization"] + ) ] From b3ed2523f1822e6af398327677288763381bf15c Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 09:38:36 -0400 Subject: [PATCH 43/52] Fix migration utility SQL (table name was plural) --- api/main/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/util.py b/api/main/util.py index 463382a4f..0478c14ce 100644 --- a/api/main/util.py +++ b/api/main/util.py @@ -1435,5 +1435,5 @@ def migrate_old_many_to_many(): # state-localization cursor.execute("DELETE FROM main_statelocalizationm2m") cursor.execute( - 'INSERT INTO main_statelocalizationm2m (state_id, localization_id, project_id) SELECT state_id, "main_localization".id, "main_localization".project FROM main_localization_media LEFT OUTER JOIN "main_localization" ON ("main_state_localization"."localization_id" = "main_localization"."id")' + 'INSERT INTO main_statelocalizationm2m (state_id, localization_id, project_id) SELECT state_id, "main_localization".id, "main_localization".project FROM main_state_localizations LEFT OUTER JOIN "main_localization" ON ("main_state_localizations"."localization_id" = "main_localization"."id")' ) From 7dc0a566164b9b6c6058c95453a1bf4bc7ed0b05 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 09:41:27 -0400 Subject: [PATCH 44/52] Switch to new custom M2M fields --- api/main/models.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 0822075db..7f53f86b3 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2074,8 +2074,20 @@ class Meta: version = ForeignKey(Version, on_delete=CASCADE, null=True, blank=False, db_column="version") parent = ForeignKey("self", on_delete=SET_NULL, null=True, blank=True, db_column="parent") """ Pointer to localization in which this one was generated from """ - media = ManyToManyField(Media, related_name="state") - localizations = ManyToManyField(Localization) + # media = ManyToManyField(Media, 
related_name="state") + # localizations = ManyToManyField(Localization) + media_proj = ManyToManyField( + Media, + related_name="state_media_proj", + through="StateMediaM2M", + through_fields=("state", "media_proj"), + ) + localization_proj = ManyToManyField( + Localization, + related_name="state_localization_proj", + through="StateLocalizationM2M", + through_fields=("state", "localization_proj"), + ) segments = JSONField(null=True, blank=True) color = CharField(null=True, blank=True, max_length=8) frame = PositiveIntegerField(null=True, blank=True) @@ -2143,7 +2155,7 @@ class Meta: ] -@receiver(m2m_changed, sender=State.localizations.through) +@receiver(m2m_changed, sender=State.localization_proj.through) def calc_segments(sender, **kwargs): instance = kwargs["instance"] sortedLocalizations = Localization.objects.filter(pk__in=instance.localizations.all()).order_by( From e9703649fa74307dbc8948fcfd3ea7210308bfcd Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Wed, 11 Sep 2024 21:29:11 -0400 Subject: [PATCH 45/52] Clean up usage of new custom M2M fields for state model --- api/main/models.py | 76 +++++++++++++++++++++-- api/main/rest/_annotation_query.py | 8 ++- api/main/rest/_attribute_query.py | 4 +- api/main/rest/media.py | 12 ++-- api/main/rest/section.py | 6 +- api/main/rest/state.py | 96 +++++++++++++----------------- api/main/tests.py | 26 +++----- 7 files changed, 139 insertions(+), 89 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 7f53f86b3..b5d61b363 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -1721,11 +1721,7 @@ def add_resource(path_or_link, media, generic_file=None): if created: obj.bucket = media.project.bucket obj.save() - existing = ResourceMediaM2M.objects.filter( - resource=obj, media=media, project=media.project - ) - if not existing: - ResourceMediaM2M.objects.create(resource=obj, media=media, project=media.project) + add_media_to_resource(obj, media) @staticmethod @transaction.atomic @@ -1828,6 +1824,13 
@@ class Meta: ] +@transaction.atomic +def add_media_to_resource(resource, media): + obj, created = ResourceMediaM2M.objects.get_or_create( + resource=resource, media=media, project=media.project + ) + + @receiver(post_save, sender=Media) def media_save(sender, instance, created, **kwargs): if instance.media_files and created: @@ -2134,6 +2137,31 @@ class Meta: ] +@transaction.atomic +def add_media_id_to_state(state, media_id, project_id): + if type(media_id) == int: + obj, created = StateMediaM2M.objects.get_or_create( + state=state, media_id=media_id, project_id=project_id + ) + else: + existing = list( + StateMediaM2M.objects.filter( + state=state, media_id__in=media_id, project=project_id + ).values_list("media_id", flat=True) + ) + blk = [] + for media in media_id: + if media not in existing: + blk.append(StateMediaM2M(state=state, media_id=media, project_id=project_id)) + + if blk: + StateMediaM2M.objects.bulk_create(blk) + + +def add_media_to_state(state, media): + return add_media_id_to_state(state, media.id, media.project.id) + + class StateLocalizationM2M(Model): state = ForeignKey(State, on_delete=CASCADE) localization = ForeignKey(Localization, on_delete=CASCADE) @@ -2155,6 +2183,33 @@ class Meta: ] +@transaction.atomic +def add_localization_id_to_state(state, localization_id, project_id): + if type(localization_id) == int: + obj, created = StateLocalizationM2M.objects.get_or_create( + state=state, localization_id=localization_id, project_id=project_id + ) + else: + existing = list( + StateLocalizationM2M.objects.filter( + state=state, localization_id__in=localization_id, project=project_id + ).values_list("localization_id", flat=True) + ) + blk = [] + for local in localization_id: + if local not in existing: + blk.append( + StateLocalizationM2M(state=state, localization_id=local, project_id=project_id) + ) + + if blk: + StateLocalizationM2M.objects.bulk_create(blk) + + +def add_localization_to_state(state, localization): + return 
add_localization_id_to_state(state, localization.id, localization.project.id) + + @receiver(m2m_changed, sender=State.localization_proj.through) def calc_segments(sender, **kwargs): instance = kwargs["instance"] @@ -2326,6 +2381,17 @@ class Meta: constraints = [UniqueConstraint(name="sectionm2m", fields=["section", "project", "media"])] +@transaction.atomic +def add_media_id_to_section(section, media_id, project_id): + obj, created = SectionMediaM2M.objects.get_or_create( + section=section, media_id=media, project_id=project_id + ) + + +def add_media_to_section(section, media): + return add_media_id_to_section(section, media.id, media.project.id) + + class Favorite(Model): """Stores an annotation saved by a user.""" diff --git a/api/main/rest/_annotation_query.py b/api/main/rest/_annotation_query.py index 9b5263c56..cb84ba802 100644 --- a/api/main/rest/_annotation_query.py +++ b/api/main/rest/_annotation_query.py @@ -74,7 +74,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if media_id is not None: media_ids += media_id if media_ids: - qs = qs.filter(media__in=set(media_ids)) + qs = qs.filter(media_proj__pk__in=set(media_ids)) if len(media_ids) > 1: qs = qs.distinct() @@ -99,13 +99,15 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if frame_state_ids and (annotation_type == "localization"): # Combine media and frame from states then find localizations that match expression = ExpressionWrapper( - Cast("media", output_field=BigIntegerField()).bitleftshift(32).bitor(F("frame")), + Cast("media_proj__pk", output_field=BigIntegerField()) + .bitleftshift(32) + .bitor(F("frame")), output_field=BigIntegerField(), ) media_frames = ( State.objects.filter( pk__in=set(frame_state_ids), - media__isnull=False, + media_proj__isnull=False, frame__isnull=False, variant_deleted=False, ) diff --git a/api/main/rest/_attribute_query.py b/api/main/rest/_attribute_query.py index cc27bd249..b75ee1802 100644 --- 
a/api/main/rest/_attribute_query.py +++ b/api/main/rest/_attribute_query.py @@ -77,7 +77,9 @@ class MediaFieldExpression: def get_wrapper(): return ExpressionWrapper( - Cast("media", output_field=BigIntegerField()).bitleftshift(32).bitor(F("frame")), + Cast("media_proj__pk", output_field=BigIntegerField()) + .bitleftshift(32) + .bitor(F("frame")), output_field=BigIntegerField(), ) diff --git a/api/main/rest/media.py b/api/main/rest/media.py index e6f8faebf..5ad4a72a3 100644 --- a/api/main/rest/media.py +++ b/api/main/rest/media.py @@ -521,10 +521,10 @@ def _delete(self, params): # Any states that are only associated to deleted media should also be marked # for deletion. - not_deleted = State.objects.filter(project=project, media__deleted=False).values_list( - "id", flat=True - ) - deleted = State.objects.filter(project=project, media__deleted=True).values_list( + not_deleted = State.objects.filter( + project=project, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = State.objects.filter(project=project, media_proj__deleted=True).values_list( "id", flat=True ) all_deleted = set(deleted) - set(not_deleted) @@ -904,10 +904,10 @@ def _delete(self, params): # Any states that are only associated to deleted media should also be marked # for deletion. 
- not_deleted = State.objects.filter(project=project, media__deleted=False).values_list( + not_deleted = State.objects.filter(project=project, media_proj__deleted=False).values_list( "id", flat=True ) - deleted = State.objects.filter(project=project, media__deleted=True).values_list( + deleted = State.objects.filter(project=project, media_proj__deleted=True).values_list( "id", flat=True ) all_deleted = set(deleted) - set(not_deleted) diff --git a/api/main/rest/section.py b/api/main/rest/section.py index 3bc2e17e2..c27b08eee 100644 --- a/api/main/rest/section.py +++ b/api/main/rest/section.py @@ -11,6 +11,7 @@ from ..models import database_qs from ..models import RowProtection from ..models import SectionMediaM2M +from ..models import add_media_to_section from ..schema import SectionListSchema from ..schema import SectionDetailSchema from ..schema.components import section @@ -138,9 +139,7 @@ def _post(self, params): ) if media_list: for media_id in media_list: - SectionMediaM2M.objects.create( - section=section, media_id=media_id, project_id=project.id - ) + add_media_id_to_section(section, media_id, project.pk) section.save() # Automatically create row protection for newly created section based on the creator RowProtection.objects.create( @@ -223,6 +222,7 @@ def _patch(self, params): # Handle removing/adding media media_add = params.get("media_add", []) media_del = params.get("media_del", []) + # This is already in an atomic block so we are good to go. 
for m in media_add: SectionMediaM2M.objects.create( section=section, media_id=media_id, project_id=project.id diff --git a/api/main/rest/state.py b/api/main/rest/state.py index 13f60be97..3969f9b1d 100644 --- a/api/main/rest/state.py +++ b/api/main/rest/state.py @@ -21,6 +21,10 @@ from ..models import User from ..models import InterpolationMethods from ..models import Section +from ..models import StateMediaM2M +from ..models import StateLocalizationM2M +from ..models import add_media_to_state, add_media_id_to_state +from ..models import add_localization_to_state, add_localization_id_to_state from ..schema import StateListSchema from ..schema import StateDetailSchema, StateByElementalIdSchema from ..schema import MergeStatesSchema @@ -61,7 +65,7 @@ def _fill_m2m(response_data): state_ids = set([state["id"] for state in response_data]) localizations = { obj["state_id"]: obj["localizations"] - for obj in State.localizations.through.objects.filter(state__in=state_ids) + for obj in State.localization_proj.through.objects.filter(state__in=state_ids) .values("state_id") .order_by("state_id") .annotate(localizations=ArrayAgg("localization_id", default=[])) @@ -69,7 +73,7 @@ def _fill_m2m(response_data): } media = { obj["state_id"]: obj["media"] - for obj in State.media.through.objects.filter(state__in=state_ids) + for obj in State.media_proj.through.objects.filter(state__in=state_ids) .values("state_id") .order_by("state_id") .annotate(media=ArrayAgg("media_id", default=[])) @@ -307,33 +311,14 @@ def _post(self, params): # Create media relations. 
media_relations = [] for state, state_spec in zip(states, state_specs): - for media_id in state_spec["media_ids"]: - media_states = State.media.through( - state_id=state.id, - media_id=media_id, - ) - media_relations.append(media_states) - if len(media_relations) > 1000: - State.media.through.objects.bulk_create(media_relations, ignore_conflicts=True) - media_relations = [] - State.media.through.objects.bulk_create(media_relations, ignore_conflicts=True) + if "media_ids" in state_spec: + add_media_id_to_state(state, state_spec["media_ids"], project.pk) # Create localization relations. loc_relations = [] for state, state_spec in zip(states, state_specs): if "localization_ids" in state_spec: - for localization_id in state_spec["localization_ids"]: - loc_states = State.localizations.through( - state_id=state.id, - localization_id=localization_id, - ) - loc_relations.append(loc_states) - if len(loc_relations) > 1000: - State.localizations.through.objects.bulk_create( - loc_relations, ignore_conflicts=True - ) - loc_relations = [] - State.localizations.through.objects.bulk_create(loc_relations, ignore_conflicts=True) + add_localization_id_to_state(state, state_spec["localization_ids"], project.pk) # Calculate segments (this is not triggered for bulk created m2m). 
localization_ids = set( @@ -435,7 +420,9 @@ def _patch(self, params): origin_datetimes = [] for original in qs.iterator(): - many_to_many.append((original.media.all(), original.localizations.all())) + many_to_many.append( + (original.media_proj.all(), original.localization_proj.all()) + ) original.pk = None original.id = None for key, value in update_kwargs.items(): @@ -445,8 +432,12 @@ def _patch(self, params): origin_datetimes.append(original.created_datetime) new_objs = State.objects.bulk_create(objs) for p_obj, m2m, origin_datetime in zip(new_objs, many_to_many, origin_datetimes): - p_obj.media.set(m2m[0]) - p_obj.localizations.set(m2m[1]) + add_localization_id_to_state( + p_obj, list(m2m[1].values_list("id", flat=True)), p_obj.project.pk + ) + add_media_id_to_state( + p_obj, list(m2m[0].values_list("id", flat=True)), p_obj.project.pk + ) # Django doesn't let you fix created_datetime unless you fetch the object again found_it = State.objects.get(pk=p_obj.pk) @@ -494,17 +485,11 @@ def get_qs(self, params, qs): if not qs.exists(): raise Http404 state = qs.values(*STATE_PROPERTIES)[0] + local_qs = StateLocalizationM2M.objects.filter(state=state["id"]) + media_qs = StateMediaM2M.objects.filter(state=state["id"]) # Get many to many fields. 
- state["localizations"] = list( - State.localizations.through.objects.filter(state_id=state["id"]).aggregate( - localizations=ArrayAgg("localization_id", default=[]) - )["localizations"] - ) - state["media"] = list( - State.media.through.objects.filter(state_id=state["id"]).aggregate( - media=ArrayAgg("media_id", default=[]) - )["media"] - ) + state["localizations"] = list(local_qs.values_list("localization_id", flat=True)) + state["media"] = list(media_qs.values_list("media_id", flat=True)) return state def patch_qs(self, params, qs): @@ -537,40 +522,41 @@ def patch_qs(self, params, qs): obj.frame = params["frame"] if "media_ids" in params: - media_elements = Media.objects.filter(pk__in=params["media_ids"]) - obj.media.set(media_elements) if association_type != "Media": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + obj.media_proj.all().delete() + + add_media_id_to_state(obj, params["media_ids"], obj.project.pk) + if "localization_ids" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids"]) - obj.localizations.set(localizations) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + obj.localization_proj.all().delete() + add_localization_id_to_state(obj, params["localization_ids"], obj.project.pk) if "localization_ids_add" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids_add"]) - obj.localizations.add(*list(localizations)) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." 
) + add_localization_id_to_state(obj, params["localization_ids_add"], obj.project.pk) if "localization_ids_remove" in params: - localizations = Localization.objects.filter(pk__in=params["localization_ids_remove"]) - obj.localizations.remove(*list(localizations)) if association_type != "Localization": logger.warning( f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." ) + local_qs = obj.localization_proj.filter(pk__in=params["localization_ids_remove"]) + local_qs.delete() if params.get("user_elemental_id", None): params["in_place"] = 1 @@ -580,11 +566,11 @@ def patch_qs(self, params, qs): obj.created_by = computed_author # Make sure media and localizations are part of this project. - media_qs = Media.objects.filter(pk__in=obj.media.all()) - localization_qs = Localization.objects.filter(pk__in=obj.localizations.all()) + media_qs = Media.objects.filter(pk__in=obj.media_proj.all()) + localization_qs = obj.localization_proj.all() media_projects = list(media_qs.values_list("project", flat=True).distinct()) localization_projects = list(localization_qs.values_list("project", flat=True).distinct()) - if obj.localizations.count() > 0: + if obj.localization_proj.count() > 0: if len(localization_projects) != 1: raise Exception( f"Localizations must be part of project {obj.project.id}, got projects " @@ -595,7 +581,7 @@ def patch_qs(self, params, qs): f"Localizations must be part of project {obj.project.id}, got project " f"{localization_projects[0]}!" 
) - if obj.media.count() > 0: + if obj.media_proj.count() > 0: if len(media_projects) != 1: raise Exception( f"Media must be part of project {obj.project.id}, got projects " @@ -625,8 +611,8 @@ def patch_qs(self, params, qs): f"Object is mark {obj.mark} of {obj.latest_mark} for {obj.version.name}/{obj.elemental_id}" ) - old_media = obj.media.all() - old_localizations = obj.localizations.all() + old_media = list(obj.media_proj.all().values_list("id", flat=True)) + old_localizations = list(obj.localization_proj.all().values_list("id", flat=True)) # Save edits as new object, mark is calculated in trigger obj.id = None obj.pk = None @@ -636,8 +622,10 @@ def patch_qs(self, params, qs): # Keep original creation time found_it.created_datetime = origin_datetime found_it.save() - found_it.media.set(old_media) - found_it.localizations.set(old_localizations) + + # Add by list + add_media_id_to_state(found_it, old_media, obj.project.pk) + add_localization_id_to_state(found_it, old_localizations, obj.project.pk) return { "message": f"State {obj.elemental_id}@{obj.version.id}/{obj.mark} successfully updated!", @@ -738,7 +726,7 @@ class TrimStateEndAPI(BaseDetailView): @transaction.atomic def _patch(self, params: dict) -> dict: obj = State.objects.get(pk=params["id"], deleted=False) - localizations = obj.localizations.order_by("frame") + localizations = obj.localization_proj.order_by("frame") if params["endpoint"] == "start": keep_localization = lambda frame: frame >= params["frame"] diff --git a/api/main/tests.py b/api/main/tests.py index b3c6bab97..42454228a 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -2339,7 +2339,7 @@ def setUp(self): ] for state in self.states: for media in random.choices(self.videos): - state.media.add(media) + add_media_to_state(state, media) def test_delete(self): self.client.delete(f"/rest/Project/{self.project.pk}") @@ -2429,12 +2429,8 @@ def setUp(self): self.test_bucket = create_test_bucket(None) resource = Resource(path="fake_key.txt", 
bucket=self.test_bucket) resource.save() - ResourceMediaM2M.objects.create( - resource=resource, media=self.public_video, project=self.public_video.project - ) - ResourceMediaM2M.objects.create( - resource=resource, media=self.private_video, project=self.private_video.project - ) + add_media_to_resource(resource, self.public_video) + add_media_to_resource(resource, self.private_video) resource.save() def test_random_user(self): @@ -2656,7 +2652,7 @@ def test_search(self): frame=10, version=self.project.version_set.all()[0], ) - state.media.add(self.entities[1]) + add_media_to_state(state, self.entities[1]) state.save() state_2 = State.objects.create( @@ -2667,7 +2663,7 @@ def test_search(self): frame=100, version=self.project.version_set.all()[0], ) - state_2.media.add(self.entities[1]) + add_media_to_state(state_2, self.entities[1]) state_2.save() # Do a frame_state lookup and find the localization at @@ -3387,7 +3383,7 @@ class StateTestCase( def setUp(self): super().setUp() print(f"\n{self.__class__.__name__}=", end="", flush=True) - # logging.disable(logging.CRITICAL) + logging.disable(logging.CRITICAL) BurstableThrottle.apply_monkey_patching_for_test() self.user = create_test_user() self.user_two = create_test_user() @@ -3424,7 +3420,7 @@ def setUp(self): attributes={"Float Test": random.random() * 1000}, ) for media in random.choices(self.media_entities): - state.media.add(media) + add_media_to_state(state, media) self.entities.append(state) for e in self.entities: lookup = ProjectLookup.objects.get(project=e.project, state=e.pk) @@ -6895,17 +6891,13 @@ def setUp(self): self.private_media = [v.pk for v in self.videos[3:6]] for media in self.videos[:3]: - SectionMediaM2M.objects.create( - section=self.public_section, media=media, project=media.project - ) + add_media_to_section(self.public_section, media) self.public_section.save() media.primary_section = self.public_section media.save() for media in self.videos[3:6]: - SectionMediaM2M.objects.create( - 
section=self.private_section, media=media, project=media.project - ) + add_media_to_section(self.private_section, media) self.private_section.save() media.primary_section = self.private_section media.save() From 2f43addc2c7283234ff50900578b39641639c05c Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 08:48:32 -0400 Subject: [PATCH 46/52] Clean up usage of media related searches to use media__proj --- api/main/models.py | 2 +- api/main/rest/_annotation_query.py | 14 +++++++------- api/main/rest/section.py | 2 +- api/main/tests.py | 16 ++++++++-------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index b5d61b363..12bfcf313 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2384,7 +2384,7 @@ class Meta: @transaction.atomic def add_media_id_to_section(section, media_id, project_id): obj, created = SectionMediaM2M.objects.get_or_create( - section=section, media_id=media, project_id=project_id + section=section, media_id=media_id, project_id=project_id ) diff --git a/api/main/rest/_annotation_query.py b/api/main/rest/_annotation_query.py index cb84ba802..9e1eddd2e 100644 --- a/api/main/rest/_annotation_query.py +++ b/api/main/rest/_annotation_query.py @@ -115,7 +115,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): .distinct() ) qs = ( - qs.filter(media__isnull=False, frame__isnull=False) + qs.filter(media_proj__isnull=False, frame__isnull=False) .alias(media_frame=expression) .filter(media_frame__in=media_frames) ) @@ -201,9 +201,9 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): raise ValueError(f"Invalid Section value pk={section.pk}") media_ids.append(media_qs) - query = Q(media__in=media_ids.pop()) + query = Q(media_proj__pk__in=media_ids.pop()) for m in media_ids: - query = query | Q(media__in=m) + query = query | Q(media_proj__pk__in=m) qs = qs.filter(query) # Do a related query @@ -226,9 +226,9 @@ def 
_get_annotation_psql_queryset(project, filter_ops, params, annotation_type): ) if related_matches: related_match = related_matches.pop() - query = Q(media__in=related_match) + query = Q(media_proj__in=related_match) for r in related_matches: - query = query | Q(media__in=r) + query = query | Q(media_proj__in=r) qs = qs.filter(query).distinct() if params.get("encoded_related_search"): @@ -244,9 +244,9 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): related_matches.append(media_qs) if related_matches: related_match = related_matches.pop() - query = Q(media__in=related_match) + query = Q(media_proj__in=related_match) for r in related_matches: - query = query | Q(media__in=r) + query = query | Q(media_proj__in=r) qs = qs.filter(query).distinct() else: qs = qs.filter(pk=-1) diff --git a/api/main/rest/section.py b/api/main/rest/section.py index c27b08eee..f935fc5c9 100644 --- a/api/main/rest/section.py +++ b/api/main/rest/section.py @@ -11,7 +11,7 @@ from ..models import database_qs from ..models import RowProtection from ..models import SectionMediaM2M -from ..models import add_media_to_section +from ..models import add_media_to_section, add_media_id_to_section from ..schema import SectionListSchema from ..schema import SectionDetailSchema from ..schema.components import section diff --git a/api/main/tests.py b/api/main/tests.py index 42454228a..03dfbff11 100644 --- a/api/main/tests.py +++ b/api/main/tests.py @@ -3945,11 +3945,11 @@ def test_multiple_media_delete(self): self.assertEqual(len(response.data), 2) not_deleted = State.objects.filter( - project=self.project.pk, media__deleted=False + project=self.project.pk, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = State.objects.filter( + project=self.project.pk, media_proj__deleted=True ).values_list("id", flat=True) - deleted = State.objects.filter(project=self.project.pk, media__deleted=True).values_list( - "id", flat=True - ) response = self.client.delete( 
f"/rest/Medias/{self.project.pk}?attribute={attr_search}", format="json" @@ -3968,11 +3968,11 @@ def test_multiple_media_delete(self): "id", flat=True ) not_deleted = State.objects.filter( - project=self.project.pk, media__deleted=False + project=self.project.pk, media_proj__deleted=False + ).values_list("id", flat=True) + deleted = State.objects.filter( + project=self.project.pk, media_proj__deleted=True ).values_list("id", flat=True) - deleted = State.objects.filter(project=self.project.pk, media__deleted=True).values_list( - "id", flat=True - ) response = self.client.get(f"/rest/States/{self.project.pk}?attribute={attr_search}") self.assertEqual(len(response.data), 0) From 3387f49108c6e2f669d890353c1c2bb4e3bcd76f Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 12:24:32 -0400 Subject: [PATCH 47/52] Clean up state_media interactions --- api/main/models.py | 22 ++++++++++++---------- api/main/rest/state.py | 10 +++++----- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 12bfcf313..20db5dbad 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2138,17 +2138,18 @@ class Meta: @transaction.atomic -def add_media_id_to_state(state, media_id, project_id): +def add_media_id_to_state(state, media_id, project_id, clear=False): if type(media_id) == int: obj, created = StateMediaM2M.objects.get_or_create( state=state, media_id=media_id, project_id=project_id ) else: - existing = list( - StateMediaM2M.objects.filter( - state=state, media_id__in=media_id, project=project_id - ).values_list("media_id", flat=True) + existing = StateMediaM2M.objects.filter( + state=state, media_id__in=media_id, project=project_id ) + if clear: + existing.delete() + existing = list(existing.values_list("media_id", flat=True)) blk = [] for media in media_id: if media not in existing: @@ -2184,17 +2185,18 @@ class Meta: @transaction.atomic -def add_localization_id_to_state(state, localization_id, project_id): +def 
add_localization_id_to_state(state, localization_id, project_id, clear=False): if type(localization_id) == int: obj, created = StateLocalizationM2M.objects.get_or_create( state=state, localization_id=localization_id, project_id=project_id ) else: - existing = list( - StateLocalizationM2M.objects.filter( - state=state, localization_id__in=localization_id, project=project_id - ).values_list("localization_id", flat=True) + existing = StateLocalizationM2M.objects.filter( + state=state, localization_id__in=localization_id, project=project_id ) + if clear: + existing.delete() + existing = list(existing.values_list("localization_id", flat=True)) blk = [] for local in localization_id: if local not in existing: diff --git a/api/main/rest/state.py b/api/main/rest/state.py index 3969f9b1d..4527fa693 100644 --- a/api/main/rest/state.py +++ b/api/main/rest/state.py @@ -528,9 +528,7 @@ def patch_qs(self, params, qs): "This is not a Media type state." ) - obj.media_proj.all().delete() - - add_media_id_to_state(obj, params["media_ids"], obj.project.pk) + add_media_id_to_state(obj, params["media_ids"], obj.project.pk, clear=True) if "localization_ids" in params: if association_type != "Localization": @@ -538,8 +536,10 @@ def patch_qs(self, params, qs): f"Media set on state {obj.id} of type {association_type}." "This is not a Media type state." 
) - obj.localization_proj.all().delete() - add_localization_id_to_state(obj, params["localization_ids"], obj.project.pk) + + add_localization_id_to_state( + obj, params["localization_ids"], obj.project.pk, clear=True + ) if "localization_ids_add" in params: if association_type != "Localization": From d9a65d081b5c427eed05c03fb5325fa2dd8640ef Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 12:24:39 -0400 Subject: [PATCH 48/52] Fix state graphic endpoint --- api/main/rest/state_graphic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/main/rest/state_graphic.py b/api/main/rest/state_graphic.py index ee3de02a3..7ef802083 100644 --- a/api/main/rest/state_graphic.py +++ b/api/main/rest/state_graphic.py @@ -63,8 +63,8 @@ def _get(self, params): if typeObj.association != "Localization": raise Exception("Not a localization association state") - video = state.media.all()[0] - localizations = state.localizations.order_by("frame")[offset : offset + length] + video = state.media_proj.all()[0] + localizations = state.localization_proj.order_by("frame")[offset : offset + length] frames = [l.frame for l in localizations] roi = [(l.width, l.height, l.x, l.y) for l in localizations] with tempfile.TemporaryDirectory() as temp_dir: From efbcd69f2dfd075c5ee7c34cb731c114ee592288 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 12:24:54 -0400 Subject: [PATCH 49/52] Make this less fragile --- api/main/rest/section.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/main/rest/section.py b/api/main/rest/section.py index f935fc5c9..2a519e092 100644 --- a/api/main/rest/section.py +++ b/api/main/rest/section.py @@ -224,7 +224,7 @@ def _patch(self, params): media_del = params.get("media_del", []) # This is already in an atomic block so we are good to go. 
for m in media_add: - SectionMediaM2M.objects.create( + SectionMediaM2M.objects.get_or_create( section=section, media_id=media_id, project_id=project.id ) for m in media_del: From 0a86ef418909217001ded2e7cc2108a1d8bdb667 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 16:08:18 -0400 Subject: [PATCH 50/52] Fix up rest of query logic --- api/main/rest/_annotation_query.py | 18 ++++++++++++++---- api/main/rest/_attribute_query.py | 14 ++++++++------ api/main/rest/_media_query.py | 21 +++++++++++++++------ 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/api/main/rest/_annotation_query.py b/api/main/rest/_annotation_query.py index 9e1eddd2e..5d788be5a 100644 --- a/api/main/rest/_annotation_query.py +++ b/api/main/rest/_annotation_query.py @@ -10,7 +10,17 @@ from django.db.models import Q, F from django.db.models import BigIntegerField, ExpressionWrapper -from ..models import Localization, LocalizationType, Media, MediaType, Section, State, StateType +from ..models import ( + Localization, + LocalizationType, + Media, + MediaType, + Section, + State, + StateType, + StateLocalizationM2M, + StateMediaM2M, +) from ..schema._attributes import related_keys @@ -83,7 +93,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): localization_ids += localization_id_put if state_ids and (annotation_type == "localization"): localization_ids += list( - State.localizations.through.objects.filter(state__in=set(state_ids)) + StateLocalizationM2M.objects.filter(state__in=set(state_ids), project=project) .values_list("localization_id", flat=True) .distinct() ) @@ -91,7 +101,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): if annotation_type == "localization": qs = qs.filter(pk__in=set(localization_ids)) elif annotation_type == "state": - qs = qs.filter(localizations__in=set(localization_ids)).distinct() + qs = qs.filter(localization_proj__pk__in=set(localization_ids)).distinct() if state_ids 
and (annotation_type == "state"): qs = qs.filter(pk__in=set(state_ids)) @@ -257,7 +267,7 @@ def _get_annotation_psql_queryset(project, filter_ops, params, annotation_type): state_qs = State.objects.filter(pk__in=set(params.get("related_id"))) qs = qs.filter(pk__in=state_qs.values("localizations").distinct()) elif annotation_type == "state": - qs = qs.filter(localizations__in=set(params.get("related_id"))) + qs = qs.filter(localization_proj__pk__in=set(params.get("related_id"))) if apply_merge: # parent_set = ANNOTATION_LOOKUP[annotation_type].objects.filter(pk__in=Subquery()) diff --git a/api/main/rest/_attribute_query.py b/api/main/rest/_attribute_query.py index b75ee1802..9b2cd5e94 100644 --- a/api/main/rest/_attribute_query.py +++ b/api/main/rest/_attribute_query.py @@ -146,9 +146,9 @@ def _related_search( orig_list = [*related_matches] related_match = related_matches.pop() # Pop and process the list - media_vals = related_match.values("media") + media_vals = related_match.values("media_proj__pk") for related_match in related_matches: - this_vals = related_match.values("media") + this_vals = related_match.values("media_proj__pk") media_vals = media_vals.union(this_vals) # We now have all the matching media, but lost the score information @@ -157,14 +157,16 @@ def _related_search( # list comp didn't play nice here, but this is easier to read anyway score = [] for x in orig_list: - annotated_x = x.values("media").annotate(count=Count("media")) - filtered_x = annotated_x.filter(media=OuterRef("id")) + annotated_x = x.values("media_proj__pk").annotate(count=Count("media_proj__pk")) + filtered_x = annotated_x.filter(media_proj__pk=OuterRef("id")) values_x = filtered_x.values("count").order_by("-count")[:1] score.append(Subquery(values_x)) if len(score) > 1: - qs = qs.filter(pk__in=media_vals.values("media")).annotate(incident=Greatest(*score)) + qs = qs.filter(pk__in=media_vals.values("media_proj__pk")).annotate( + incident=Greatest(*score) + ) else: - qs = 
qs.filter(pk__in=media_vals.values("media")).annotate(incident=score[0]) + qs = qs.filter(pk__in=media_vals.values("media_proj__pk")).annotate(incident=score[0]) else: qs = qs.filter(pk=-1).annotate(incident=Value(0)) return qs diff --git a/api/main/rest/_media_query.py b/api/main/rest/_media_query.py index 738ae49eb..80d9548b6 100644 --- a/api/main/rest/_media_query.py +++ b/api/main/rest/_media_query.py @@ -11,7 +11,16 @@ from django.db.models.functions import Cast from django.db.models import UUIDField, TextField, F -from ..models import LocalizationType, Media, MediaType, Localization, Section, State, StateType +from ..models import ( + LocalizationType, + Media, + MediaType, + Localization, + Section, + State, + StateType, + StateMediaM2M, +) from ..schema._attributes import related_keys from ._attribute_query import ( @@ -71,7 +80,7 @@ def _get_media_psql_queryset(project, filter_ops, params): media_ids += media_id if state_ids is not None: media_ids += list( - State.media.through.objects.filter(state__in=set(state_ids)) + StateMediaM2M.objects.filter(state__in=set(state_ids), project=project) .values_list("media_id", flat=True) .distinct() ) @@ -174,9 +183,9 @@ def _get_media_psql_queryset(project, filter_ops, params): ) if related_matches: related_match = related_matches.pop() - query = Q(pk__in=related_match.values("media")) + query = Q(pk__in=related_match.values("media_proj__pk")) for r in related_matches: - query = query | Q(pk__in=r.values("media")) + query = query | Q(pk__in=r.values("media_proj__pk")) qs = qs.filter(query).distinct() if section_id: @@ -187,7 +196,7 @@ def _get_media_psql_queryset(project, filter_ops, params): section_uuid = section[0].tator_user_sections if section[0].explicit_listing: - qs = qs.filter(pk__in=section[0].media.all()) + qs = qs.filter(pk__in=section[0].media_proj.all().values("pk")) elif section_uuid: qs = _look_for_section_uuid(qs, section_uuid) elif section[0].object_search: @@ -212,7 +221,7 @@ def 
_get_media_psql_queryset(project, filter_ops, params): section_uuid = section.tator_user_sections if section.explicit_listing: - match_qs = qs.filter(pk__in=section.media_proj.all()) + match_qs = qs.filter(pk__in=section.media_proj.all().values("pk")) elif section_uuid: match_qs = _look_for_section_uuid(qs, section_uuid) From 105ad6163a769d610f7c616913f2c39a311b1bec Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 16:08:37 -0400 Subject: [PATCH 51/52] Fix permission check for Merge/Trim state apis --- api/main/rest/state.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/main/rest/state.py b/api/main/rest/state.py index 4527fa693..e87849d70 100644 --- a/api/main/rest/state.py +++ b/api/main/rest/state.py @@ -712,7 +712,9 @@ def _patch(self, params: dict) -> dict: } def get_queryset(self): - return State.objects.all() + return self.filter_only_viewables( + State.objects.filter(pk__in=[self.params["id"], self.params["merge_state_id"]]) + ) class TrimStateEndAPI(BaseDetailView): @@ -760,7 +762,7 @@ def _patch(self, params: dict) -> dict: } def get_queryset(self): - return State.objects.all() + return self.filter_only_viewables(State.objects.filter(pk=self.params["id"])) class StateDetailAPI(StateDetailBaseAPI): From 7a190b90380fd7e3fbc60518f56f9d3b3026f780 Mon Sep 17 00:00:00 2001 From: Brian Tate Date: Thu, 12 Sep 2024 16:11:26 -0400 Subject: [PATCH 52/52] Make this not a destructive migration --- api/main/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/main/models.py b/api/main/models.py index 20db5dbad..f4f195709 100644 --- a/api/main/models.py +++ b/api/main/models.py @@ -2077,8 +2077,8 @@ class Meta: version = ForeignKey(Version, on_delete=CASCADE, null=True, blank=False, db_column="version") parent = ForeignKey("self", on_delete=SET_NULL, null=True, blank=True, db_column="parent") """ Pointer to localization in which this one was generated from """ - media = 
ManyToManyField(Media, related_name="state") - # localizations = ManyToManyField(Localization) + media = ManyToManyField(Media, related_name="state") + localizations = ManyToManyField(Localization) media_proj = ManyToManyField( Media, related_name="state_media_proj",