Skip to content

Commit

Permalink
Support assigning relationss via custom lookup fields
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Aug 25, 2022
1 parent 48b9d58 commit 8682852
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 10 deletions.
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def pytest_configure():
RoleFactory,
SkillFactory,
TagFactory,
TaskFactory,
TeamFactory,
UserFactory,
)
Expand All @@ -54,6 +55,7 @@ def pytest_configure():
register(RoleFactory, "role")
register(SkillFactory, "skill")
register(TagFactory, "tag")
register(TaskFactory, "task")
register(TeamFactory, "team")
register(UserFactory, "user")

Expand Down
9 changes: 8 additions & 1 deletion tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.contrib.auth.models import User

from tests.models import Profile, Role, Skill, Tag, Team
from tests.models import Profile, Role, Skill, Tag, Task, Team


class ProfileFactory(DjangoModelFactory):
Expand Down Expand Up @@ -41,6 +41,13 @@ class Meta:
model = Tag


class TaskFactory(DjangoModelFactory):
name = factory.Sequence(lambda i: f"Task {i}")

class Meta:
model = Task


class TeamFactory(DjangoModelFactory):
name = factory.Sequence(lambda i: f"Team {i}")

Expand Down
9 changes: 9 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Profile(models.Model):
role = models.ForeignKey("Role", on_delete=models.CASCADE)
team = models.ForeignKey("Team", blank=True, null=True, on_delete=models.SET_NULL)
skills = models.ManyToManyField("Skill", through="RatedSkill")
tasks = models.ManyToManyField("Task")
tags = models.ManyToManyField("Tag")

recovery_email = models.EmailField(blank=True, max_length=320, null=True)
Expand Down Expand Up @@ -75,6 +76,14 @@ class Tag(models.Model):
name = models.CharField(max_length=200)


class Task(models.Model):
custom_id = models.UUIDField(default=uuid4)
name = models.CharField(max_length=200)

class Api:
lookup_field = "custom_id"


class Team(models.Model):
name = models.CharField(max_length=100)

Expand Down
10 changes: 10 additions & 0 deletions tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ProfileSerializer(Serializer):
skills = fields.Nested("SkillSerializer", attribute="ratedskill_set", many=True)
team = fields.Nested("TeamSerializer")
tags = fields.Pluck("TagSerializer", "name", many=True)
tasks = fields.Pluck("TaskSerializer", "name", many=True)
user = fields.Nested("UserSerializer")

class Meta:
Expand All @@ -44,6 +45,7 @@ class Meta:
"skills",
"team",
"tags",
"tasks",
"user",
"last_active",
"created_at",
Expand All @@ -65,6 +67,7 @@ class Meta:
"skills",
"team",
"tags",
"tasks",
"user",
"last_active",
"created_at",
Expand All @@ -89,6 +92,13 @@ class Meta:
fields = ["id", "name"]


class TaskSerializer(Serializer):
id = fields.UUID(attribute="custom_id")

class Meta:
fields = ["id", "name"]


class TeamSerializer(Serializer):
class Meta:
fields = ["id", "name"]
10 changes: 10 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,16 @@ def test_profile_update_m2m_can_be_empty(client, db, method, profile, tag):
assert len(result["tags"]) == 0


@pytest.mark.parametrize("method", ["PATCH", "PUT"])
def test_profile_update_m2m_lookup_field(client, db, method, profile, task):
payload = dict(tasks=[task.custom_id])
response = client.generic(method, f"/profiles/{profile.pk}/", payload)
result = response.json()
assert response.status_code == 200, result
assert len(result["tasks"]) == 1
assert result["tasks"][0] == task.name


@pytest.mark.parametrize("method", ["PATCH", "PUT"])
def test_profile_update_m2m_is_not_nullable(client, db, method, profile, tag):
response = client.generic(method, f"/profiles/{profile.pk}/", dict(tags=None))
Expand Down
27 changes: 18 additions & 9 deletions worf/assigns.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,27 @@ def save(self, instance, bundle):

def set_foreign_key(self, instance, key, value):
related_model = self.get_related_model(key)
try:
related_instance = (
related_model.objects.get(pk=value) if value is not None else None
)
except related_model.DoesNotExist as e:
raise ValidationError(f"Invalid {self.keymap[key]}") from e
setattr(instance, key, related_instance)
related_model_meta = getattr(related_model, "Api", None)
lookup_field = getattr(related_model_meta, "lookup_field", "pk")
if value is not None:
try:
value = related_model.objects.get(**{lookup_field: value})
except related_model.DoesNotExist as e:
raise ValidationError(f"Invalid {self.keymap[key]}") from e
setattr(instance, key, value)

def set_many_to_many(self, instance, key, value):
related_manager = getattr(instance, key)
related_model = related_manager.model
related_model_meta = getattr(related_model, "Api", None)
lookup_field = getattr(related_model_meta, "lookup_field", "pk")
try:
getattr(instance, key).set(value)
except (IntegrityError, ValueError) as e:
if lookup_field != "pk":
results = related_model.objects.filter(**{f"{lookup_field}__in": value})
assert len(results) == len(value)
value = results
related_manager.set(value)
except (AssertionError, IntegrityError, ValueError) as e:
raise ValidationError(f"Invalid {self.keymap[key]}") from e

def set_many_to_many_with_through(self, instance, key, value):
Expand Down

0 comments on commit 8682852

Please sign in to comment.