From d2c16feba9d9f00f16f9406e2a466cd0cc832433 Mon Sep 17 00:00:00 2001 From: Matthias Schoettle Date: Sat, 1 Jun 2024 03:18:58 -0400 Subject: [PATCH] fix(types): Make admin classes generic as their super classes (#737) Co-authored-by: Serhii Tereshchenko --- .github/workflows/test.yml | 2 +- modeltranslation/__init__.py | 5 ++ modeltranslation/_typing.py | 20 ++++++ modeltranslation/admin.py | 66 ++++++++++++------- modeltranslation/tests/test_runtime_typing.py | 12 ++++ 5 files changed, 80 insertions(+), 25 deletions(-) create mode 100644 modeltranslation/tests/test_runtime_typing.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 16c8d278..b6ef0872 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,7 +96,7 @@ jobs: if [[ $DB == postgres ]]; then pip install -q psycopg2-binary fi - pip install typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }}) + pip install django_stubs_ext typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }}) - name: Run tests run: | pytest --cov-report term diff --git a/modeltranslation/__init__.py b/modeltranslation/__init__.py index 23c03191..03810500 100644 --- a/modeltranslation/__init__.py +++ b/modeltranslation/__init__.py @@ -1,3 +1,5 @@ +from modeltranslation._typing import monkeypatch + try: from django import VERSION as _django_version @@ -5,3 +7,6 @@ default_app_config = "modeltranslation.apps.ModeltranslationConfig" except ImportError: pass + +# monkeypatch generic classes at runtime +monkeypatch() diff --git a/modeltranslation/_typing.py b/modeltranslation/_typing.py index 039bbe7f..30eb34af 100644 --- a/modeltranslation/_typing.py +++ b/modeltranslation/_typing.py @@ -3,6 +3,9 @@ import sys from typing import Literal, TypeVar +from django.contrib import admin +from django.contrib.admin.options import BaseModelAdmin + if sys.version_info >= (3, 11): from typing import Self, TypeAlias # noqa: F401 else: @@ -14,3 +17,20 @@ # See https://github.com/typeddjango/django-stubs/blob/082955/django-stubs/utils/datastructures.pyi#L12-L14 _ListOrTuple: TypeAlias = "list[_K] | tuple[_K, ...]" + + +# https://github.com/typeddjango/django-stubs/tree/master/django_stubs_ext +# For generic classes to work at runtime we need to define `__class_getitem__`. +# We're defining it here, instead of relying on django_stubs_ext, because +# we don't want every user setting up django_stubs_ext just for this feature. +def monkeypatch() -> None: + classes = [ + admin.ModelAdmin, + BaseModelAdmin, + ] + + for class_ in classes: + if not hasattr(class_, "__class_getitem__"): + class_.__class_getitem__ = classmethod( # type: ignore[attr-defined] + lambda cls, *args, **kwargs: cls + ) diff --git a/modeltranslation/admin.py b/modeltranslation/admin.py index 6c69ab84..0025cdf4 100644 --- a/modeltranslation/admin.py +++ b/modeltranslation/admin.py @@ -1,10 +1,10 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Iterable, Sequence +from typing import Any, Iterable, Sequence, TypeVar from django import forms -from django.db.models import Field +from django.db.models import Field, Model from django.contrib import admin from django.contrib.admin.options import BaseModelAdmin, InlineModelAdmin, flatten_fieldsets from django.contrib.contenttypes.admin import GenericStackedInline, GenericTabularInline @@ -25,8 +25,10 @@ from modeltranslation.widgets import ClearableWidgetWrapper from modeltranslation._typing import _ListOrTuple +_ModelT = TypeVar("_ModelT", bound=Model) -class TranslationBaseModelAdmin(BaseModelAdmin): + +class TranslationBaseModelAdmin(BaseModelAdmin[_ModelT]): _orig_was_required: dict[str, bool] = {} both_empty_values_fields = () @@ -36,7 +38,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._patch_prepopulated_fields() def _get_declared_fieldsets( - self, request: HttpRequest, obj: Any | None = None + self, request: HttpRequest, obj: _ModelT | None = None ) -> _ListOrTuple[tuple[str | None, dict[str, Any]]] | None: # Take custom modelform fields option into account if not self.fields and hasattr(self.form, "_meta") and self.form._meta.fields: @@ -216,17 +218,17 @@ def append_lang(source: str) -> str: self.prepopulated_fields = prepopulated_fields def _get_form_or_formset( - self, request: HttpRequest, obj: Any | None, **kwargs: Any + self, request: HttpRequest, obj: Model | None, **kwargs: Any ) -> dict[str, Any]: """ Generic code shared by get_form and get_formset. """ - exclude = self.get_exclude(request, obj) + exclude = self.get_exclude(request, obj) # type: ignore[arg-type] if exclude is None: exclude = [] else: exclude = list(exclude) - exclude.extend(self.get_readonly_fields(request, obj)) + exclude.extend(self.get_readonly_fields(request, obj)) # type: ignore[arg-type] if not exclude and hasattr(self.form, "_meta") and self.form._meta.exclude: # Take the custom ModelForm's Meta.exclude into account only if the # ModelAdmin doesn't define its own. @@ -240,7 +242,7 @@ def _get_form_or_formset( return kwargs def _get_fieldsets_pre_form_or_formset( - self, request: HttpRequest, obj: Any | None = None + self, request: HttpRequest, obj: _ModelT | None = None ) -> _ListOrTuple[tuple[str | None, dict[str, Any]]] | None: """ Generic get_fieldsets code, shared by @@ -249,7 +251,7 @@ def _get_fieldsets_pre_form_or_formset( return self._get_declared_fieldsets(request, obj) def _get_fieldsets_post_form_or_formset( - self, request: HttpRequest, form: type[forms.ModelForm], obj: Any | None = None + self, request: HttpRequest, form: type[forms.ModelForm], obj: _ModelT | None = None ) -> list: """ Generic get_fieldsets code, shared by @@ -280,7 +282,7 @@ def get_translation_field_excludes( return tuple(exclude) def get_readonly_fields( - self, request: HttpRequest, obj: Any | None = None + self, request: HttpRequest, obj: _ModelT | None = None ) -> _ListOrTuple[str]: """ Hook to specify custom readonly fields. @@ -288,7 +290,7 @@ def get_readonly_fields( return self.replace_orig_field(self.readonly_fields) -class TranslationAdmin(TranslationBaseModelAdmin, admin.ModelAdmin): +class TranslationAdmin(TranslationBaseModelAdmin[_ModelT], admin.ModelAdmin[_ModelT]): # TODO: Consider addition of a setting which allows to override the fallback to True group_fieldsets = False @@ -356,7 +358,7 @@ def _group_fieldsets(self, fieldsets: list) -> list: # Extract the original field's verbose_name for use as this # fieldset's label - using gettext_lazy in your model # declaration can make that translatable. - label = self.model._meta.get_field(orig_field).verbose_name.capitalize() + label = self.model._meta.get_field(orig_field).verbose_name.capitalize() # type: ignore[union-attr] temp_fieldsets[orig_field] = ( label, {"fields": trans_fieldnames, "classes": ("mt-fieldset",)}, @@ -374,13 +376,13 @@ def _group_fieldsets(self, fieldsets: list) -> list: return fieldsets def get_form( - self, request: HttpRequest, obj: Any | None = None, **kwargs: Any + self, request: HttpRequest, obj: _ModelT | None = None, **kwargs: Any ) -> type[forms.ModelForm]: kwargs = self._get_form_or_formset(request, obj, **kwargs) return super().get_form(request, obj, **kwargs) def get_fieldsets( - self, request: HttpRequest, obj: Any | None = None + self, request: HttpRequest, obj: _ModelT | None = None ) -> _ListOrTuple[tuple[str | None, dict[str, Any]]]: return self._get_fieldsets_pre_form_or_formset(request, obj) or self._group_fieldsets( self._get_fieldsets_post_form_or_formset( @@ -389,41 +391,57 @@ def get_fieldsets( ) -class TranslationInlineModelAdmin(TranslationBaseModelAdmin, InlineModelAdmin): +_ChildModelT = TypeVar("_ChildModelT", bound=Model) +_ParentModelT = TypeVar("_ParentModelT", bound=Model) + + +class TranslationInlineModelAdmin( + TranslationBaseModelAdmin[_ChildModelT], InlineModelAdmin[_ChildModelT, _ParentModelT] +): def get_formset( - self, request: HttpRequest, obj: Any | None = None, **kwargs: Any + self, request: HttpRequest, obj: _ParentModelT | None = None, **kwargs: Any ) -> type[BaseInlineFormSet]: kwargs = self._get_form_or_formset(request, obj, **kwargs) return super().get_formset(request, obj, **kwargs) - def get_fieldsets(self, request: HttpRequest, obj: Any | None = None): + def get_fieldsets(self, request: HttpRequest, obj: _ChildModelT | None = None): # FIXME: If fieldsets are declared on an inline some kind of ghost # fieldset line with just the original model verbose_name of the model # is displayed above the new fieldsets. declared_fieldsets = self._get_fieldsets_pre_form_or_formset(request, obj) if declared_fieldsets: return declared_fieldsets - form = self.get_formset(request, obj, fields=None).form + form = self.get_formset(request, obj, fields=None).form # type: ignore[arg-type] return self._get_fieldsets_post_form_or_formset(request, form, obj) -class TranslationTabularInline(TranslationInlineModelAdmin, admin.TabularInline): +class TranslationTabularInline( + TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], + admin.TabularInline[_ChildModelT, _ParentModelT], +): pass -class TranslationStackedInline(TranslationInlineModelAdmin, admin.StackedInline): +class TranslationStackedInline( + TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], + admin.StackedInline[_ChildModelT, _ParentModelT], +): pass -class TranslationGenericTabularInline(TranslationInlineModelAdmin, GenericTabularInline): +class TranslationGenericTabularInline( + TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], GenericTabularInline +): pass -class TranslationGenericStackedInline(TranslationInlineModelAdmin, GenericStackedInline): +class TranslationGenericStackedInline( + TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], GenericStackedInline +): pass -class TabbedDjangoJqueryTranslationAdmin(TranslationAdmin): +class TabbedDjangoJqueryTranslationAdmin(TranslationAdmin[_ModelT]): """ Convenience class which includes the necessary media files for tabbed translation fields. Reuses Django's internal jquery version. @@ -441,7 +459,7 @@ class Media: } -class TabbedExternalJqueryTranslationAdmin(TranslationAdmin): +class TabbedExternalJqueryTranslationAdmin(TranslationAdmin[_ModelT]): """ Convenience class which includes the necessary media files for tabbed translation fields. Loads recent jquery version from a cdn. diff --git a/modeltranslation/tests/test_runtime_typing.py b/modeltranslation/tests/test_runtime_typing.py new file mode 100644 index 00000000..9146c4bb --- /dev/null +++ b/modeltranslation/tests/test_runtime_typing.py @@ -0,0 +1,12 @@ +from modeltranslation import admin +from modeltranslation.tests import models + + +def test_translation_admin(): + class TestModelAdmin(admin.TranslationAdmin[models.TestModel]): + pass + + class TestInlineModelAdmin( + admin.TranslationInlineModelAdmin[models.ForeignKeyModel, models.TestModel] + ): + pass