Skip to content

Commit

Permalink
fix(types): Make admin classes generic as their super classes (#737)
Browse files Browse the repository at this point in the history
Co-authored-by: Serhii Tereshchenko <[email protected]>
  • Loading branch information
mschoettle and last-partizan authored Jun 1, 2024
1 parent 6768a26 commit d2c16fe
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ jobs:
if [[ $DB == postgres ]]; then
pip install -q psycopg2-binary
fi
pip install typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }})
pip install django_stubs_ext typing-extensions coverage pytest pytest-django pytest-cov parameterized $(./get-django-version.py ${{ matrix.django }})
- name: Run tests
run: |
pytest --cov-report term
5 changes: 5 additions & 0 deletions modeltranslation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from modeltranslation._typing import monkeypatch

try:
from django import VERSION as _django_version

if _django_version < (3, 2):
default_app_config = "modeltranslation.apps.ModeltranslationConfig"
except ImportError:
pass

# monkeypatch generic classes at runtime
monkeypatch()
20 changes: 20 additions & 0 deletions modeltranslation/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import sys
from typing import Literal, TypeVar

from django.contrib import admin
from django.contrib.admin.options import BaseModelAdmin

if sys.version_info >= (3, 11):
from typing import Self, TypeAlias # noqa: F401
else:
Expand All @@ -14,3 +17,20 @@

# See https://github.com/typeddjango/django-stubs/blob/082955/django-stubs/utils/datastructures.pyi#L12-L14
_ListOrTuple: TypeAlias = "list[_K] | tuple[_K, ...]"


# https://github.com/typeddjango/django-stubs/tree/master/django_stubs_ext
# For generic classes to work at runtime we need to define `__class_getitem__`.
# We're defining it here, instead of relying on django_stubs_ext, because
# we don't want every user setting up django_stubs_ext just for this feature.
def monkeypatch() -> None:
classes = [
admin.ModelAdmin,
BaseModelAdmin,
]

for class_ in classes:
if not hasattr(class_, "__class_getitem__"):
class_.__class_getitem__ = classmethod( # type: ignore[attr-defined]
lambda cls, *args, **kwargs: cls
)
66 changes: 42 additions & 24 deletions modeltranslation/admin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Iterable, Sequence
from typing import Any, Iterable, Sequence, TypeVar

from django import forms
from django.db.models import Field
from django.db.models import Field, Model
from django.contrib import admin
from django.contrib.admin.options import BaseModelAdmin, InlineModelAdmin, flatten_fieldsets
from django.contrib.contenttypes.admin import GenericStackedInline, GenericTabularInline
Expand All @@ -25,8 +25,10 @@
from modeltranslation.widgets import ClearableWidgetWrapper
from modeltranslation._typing import _ListOrTuple

_ModelT = TypeVar("_ModelT", bound=Model)

class TranslationBaseModelAdmin(BaseModelAdmin):

class TranslationBaseModelAdmin(BaseModelAdmin[_ModelT]):
_orig_was_required: dict[str, bool] = {}
both_empty_values_fields = ()

Expand All @@ -36,7 +38,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._patch_prepopulated_fields()

def _get_declared_fieldsets(
self, request: HttpRequest, obj: Any | None = None
self, request: HttpRequest, obj: _ModelT | None = None
) -> _ListOrTuple[tuple[str | None, dict[str, Any]]] | None:
# Take custom modelform fields option into account
if not self.fields and hasattr(self.form, "_meta") and self.form._meta.fields:
Expand Down Expand Up @@ -216,17 +218,17 @@ def append_lang(source: str) -> str:
self.prepopulated_fields = prepopulated_fields

def _get_form_or_formset(
self, request: HttpRequest, obj: Any | None, **kwargs: Any
self, request: HttpRequest, obj: Model | None, **kwargs: Any
) -> dict[str, Any]:
"""
Generic code shared by get_form and get_formset.
"""
exclude = self.get_exclude(request, obj)
exclude = self.get_exclude(request, obj) # type: ignore[arg-type]
if exclude is None:
exclude = []
else:
exclude = list(exclude)
exclude.extend(self.get_readonly_fields(request, obj))
exclude.extend(self.get_readonly_fields(request, obj)) # type: ignore[arg-type]
if not exclude and hasattr(self.form, "_meta") and self.form._meta.exclude:
# Take the custom ModelForm's Meta.exclude into account only if the
# ModelAdmin doesn't define its own.
Expand All @@ -240,7 +242,7 @@ def _get_form_or_formset(
return kwargs

def _get_fieldsets_pre_form_or_formset(
self, request: HttpRequest, obj: Any | None = None
self, request: HttpRequest, obj: _ModelT | None = None
) -> _ListOrTuple[tuple[str | None, dict[str, Any]]] | None:
"""
Generic get_fieldsets code, shared by
Expand All @@ -249,7 +251,7 @@ def _get_fieldsets_pre_form_or_formset(
return self._get_declared_fieldsets(request, obj)

def _get_fieldsets_post_form_or_formset(
self, request: HttpRequest, form: type[forms.ModelForm], obj: Any | None = None
self, request: HttpRequest, form: type[forms.ModelForm], obj: _ModelT | None = None
) -> list:
"""
Generic get_fieldsets code, shared by
Expand Down Expand Up @@ -280,15 +282,15 @@ def get_translation_field_excludes(
return tuple(exclude)

def get_readonly_fields(
self, request: HttpRequest, obj: Any | None = None
self, request: HttpRequest, obj: _ModelT | None = None
) -> _ListOrTuple[str]:
"""
Hook to specify custom readonly fields.
"""
return self.replace_orig_field(self.readonly_fields)


class TranslationAdmin(TranslationBaseModelAdmin, admin.ModelAdmin):
class TranslationAdmin(TranslationBaseModelAdmin[_ModelT], admin.ModelAdmin[_ModelT]):
# TODO: Consider addition of a setting which allows to override the fallback to True
group_fieldsets = False

Expand Down Expand Up @@ -356,7 +358,7 @@ def _group_fieldsets(self, fieldsets: list) -> list:
# Extract the original field's verbose_name for use as this
# fieldset's label - using gettext_lazy in your model
# declaration can make that translatable.
label = self.model._meta.get_field(orig_field).verbose_name.capitalize()
label = self.model._meta.get_field(orig_field).verbose_name.capitalize() # type: ignore[union-attr]
temp_fieldsets[orig_field] = (
label,
{"fields": trans_fieldnames, "classes": ("mt-fieldset",)},
Expand All @@ -374,13 +376,13 @@ def _group_fieldsets(self, fieldsets: list) -> list:
return fieldsets

def get_form(
self, request: HttpRequest, obj: Any | None = None, **kwargs: Any
self, request: HttpRequest, obj: _ModelT | None = None, **kwargs: Any
) -> type[forms.ModelForm]:
kwargs = self._get_form_or_formset(request, obj, **kwargs)
return super().get_form(request, obj, **kwargs)

def get_fieldsets(
self, request: HttpRequest, obj: Any | None = None
self, request: HttpRequest, obj: _ModelT | None = None
) -> _ListOrTuple[tuple[str | None, dict[str, Any]]]:
return self._get_fieldsets_pre_form_or_formset(request, obj) or self._group_fieldsets(
self._get_fieldsets_post_form_or_formset(
Expand All @@ -389,41 +391,57 @@ def get_fieldsets(
)


class TranslationInlineModelAdmin(TranslationBaseModelAdmin, InlineModelAdmin):
_ChildModelT = TypeVar("_ChildModelT", bound=Model)
_ParentModelT = TypeVar("_ParentModelT", bound=Model)


class TranslationInlineModelAdmin(
TranslationBaseModelAdmin[_ChildModelT], InlineModelAdmin[_ChildModelT, _ParentModelT]
):
def get_formset(
self, request: HttpRequest, obj: Any | None = None, **kwargs: Any
self, request: HttpRequest, obj: _ParentModelT | None = None, **kwargs: Any
) -> type[BaseInlineFormSet]:
kwargs = self._get_form_or_formset(request, obj, **kwargs)
return super().get_formset(request, obj, **kwargs)

def get_fieldsets(self, request: HttpRequest, obj: Any | None = None):
def get_fieldsets(self, request: HttpRequest, obj: _ChildModelT | None = None):
# FIXME: If fieldsets are declared on an inline some kind of ghost
# fieldset line with just the original model verbose_name of the model
# is displayed above the new fieldsets.
declared_fieldsets = self._get_fieldsets_pre_form_or_formset(request, obj)
if declared_fieldsets:
return declared_fieldsets
form = self.get_formset(request, obj, fields=None).form
form = self.get_formset(request, obj, fields=None).form # type: ignore[arg-type]
return self._get_fieldsets_post_form_or_formset(request, form, obj)


class TranslationTabularInline(TranslationInlineModelAdmin, admin.TabularInline):
class TranslationTabularInline(
TranslationInlineModelAdmin[_ChildModelT, _ParentModelT],
admin.TabularInline[_ChildModelT, _ParentModelT],
):
pass


class TranslationStackedInline(TranslationInlineModelAdmin, admin.StackedInline):
class TranslationStackedInline(
TranslationInlineModelAdmin[_ChildModelT, _ParentModelT],
admin.StackedInline[_ChildModelT, _ParentModelT],
):
pass


class TranslationGenericTabularInline(TranslationInlineModelAdmin, GenericTabularInline):
class TranslationGenericTabularInline(
TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], GenericTabularInline
):
pass


class TranslationGenericStackedInline(TranslationInlineModelAdmin, GenericStackedInline):
class TranslationGenericStackedInline(
TranslationInlineModelAdmin[_ChildModelT, _ParentModelT], GenericStackedInline
):
pass


class TabbedDjangoJqueryTranslationAdmin(TranslationAdmin):
class TabbedDjangoJqueryTranslationAdmin(TranslationAdmin[_ModelT]):
"""
Convenience class which includes the necessary media files for tabbed
translation fields. Reuses Django's internal jquery version.
Expand All @@ -441,7 +459,7 @@ class Media:
}


class TabbedExternalJqueryTranslationAdmin(TranslationAdmin):
class TabbedExternalJqueryTranslationAdmin(TranslationAdmin[_ModelT]):
"""
Convenience class which includes the necessary media files for tabbed
translation fields. Loads recent jquery version from a cdn.
Expand Down
12 changes: 12 additions & 0 deletions modeltranslation/tests/test_runtime_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from modeltranslation import admin
from modeltranslation.tests import models


def test_translation_admin():
class TestModelAdmin(admin.TranslationAdmin[models.TestModel]):
pass

class TestInlineModelAdmin(
admin.TranslationInlineModelAdmin[models.ForeignKeyModel, models.TestModel]
):
pass

0 comments on commit d2c16fe

Please sign in to comment.