Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests written for django #135

Merged
merged 1 commit into from
Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions .github/workflows/black-linter.yml

This file was deleted.

39 changes: 39 additions & 0 deletions .github/workflows/server-checks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Server code check

on:
push:
paths:
- "src/**/*.py"
branches:
- develop
pull_request:
paths:
- "src/**/*.py"
branches:
- develop

jobs:
black_linter:
name: Black formatter
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: rickstaa/action-black@v1
with:
black_args: ". --check"
django_tests:
name: Django tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ./src
- name: Run test
run: |
python run.py migrate
python run.py test main rest
4 changes: 2 additions & 2 deletions src/server/main/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def load_user_fixture(sender, *args, **kwargs):
):
if not kwargs["kwargs"].get("id"):
kwargs["kwargs"]["id"] = (
models.User.objects.aggregate(Max("id")).get("id__max", 0) + 1
)
models.User.objects.aggregate(Max("id")).get("id__max", 0) or 0
) + 1

if not kwargs["kwargs"].get("password"):
password = "".join(
Expand Down
84 changes: 83 additions & 1 deletion src/server/main/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,85 @@
from django.test import TestCase
from django.utils import timezone
from main import models

# Create your tests here.

class MainAppTest(TestCase):
def setUp(self):
return super().setUp()

def test_user_creation(self):
"""
Ensure we can create new users using emails.
"""
models.User.objects.create(email="[email protected]")
user_queryset = models.User.objects.filter(
email="[email protected]"
)
self.assertEqual(user_queryset.count(), 1)

user, result = models.User.objects.update_or_create(
email="[email protected]",
defaults={"forename": "Test", "surname": "User"},
)
self.assertFalse(result)
self.assertEqual(str(user), "Test User")

def test_portion_items(self):
"""
Ensure we can create new portion items.
"""
queryset = [
models.PortionItem.objects.create(name=item, is_default=True)
for item in ["Milk", "Honey", "Sugar"]
]
self.assertQuerysetEqual(
queryset, models.PortionItem.objects.filter(is_default=True)
)

non_default_item = models.PortionItem.objects.create(name="Coco")
self.assertFalse(non_default_item.is_default)

def test_logs(self):
"""
Ensure we can log portions for a user and filter.
"""
user = models.User.objects.create(email="[email protected]")

portion_items = {
item["name"]: models.PortionItem.objects.create(**item)
for item in [
{"name": "A", "is_default": True},
{"name": "B", "is_default": False},
{"name": "C", "is_default": True},
{"name": "D", "is_default": False},
]
}

track_items = {
item: models.TrackItem.objects.create(user=user, item=portion_items[item])
for item in portion_items
}

logs = [
models.UserLog.objects.create(item=track_items[item])
for item in track_items
]

self.assertQuerysetEqual(
logs,
models.UserLog.objects.filter(
item__user=user, timestamp__lte=timezone.now()
),
)

def test_signals(self):
"""
Ensuring signals are working as desired, like generating passwords.
"""
user = models.User.objects.create(email="[email protected]")
self.assertIsNotNone(user.password)

self.assertQuerysetEqual(
models.PortionItem.objects.filter(trackitem__user=user),
models.PortionItem.objects.filter(is_default=True),
)
2 changes: 1 addition & 1 deletion src/server/rest/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def create(self, validated_data):
return instance

def update(self, instance, validated_data):
password = validated_data.pop("password")
password = validated_data.pop("password", None)
old_password = validated_data.pop("old_password", None)
instance = super().update(instance, validated_data)

Expand Down
156 changes: 154 additions & 2 deletions src/server/rest/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,155 @@
from django.test import TestCase
from django.urls import reverse
from django.utils import timezone
from rest_framework import status, test
from main import models
from rest import serializers

# Create your tests here.

class AccountTests(test.APITestCase):
def setUp(self):
_ = [
models.PortionItem.objects.create(**item)
for item in [
{"name": "System Chips", "is_default": True},
{"name": "Thermal Paste", "is_default": True},
{"name": "Lithium", "is_default": True},
{"name": "Gigawatts", "is_default": False},
{"name": "Electrons", "is_default": False},
]
]
return super().setUp()

def test_secure_access(self):
"""
Ensure endpoints are secured and require authorisation.
"""
_ = [
self.assertEqual(
self.client.get(reverse(endpoint["name"]), format="json").status_code,
endpoint.get("code", status.HTTP_200_OK),
)
for endpoint in [
{"name": "user-list", "code": status.HTTP_401_UNAUTHORIZED},
{"name": "portionitem-list"},
{"name": "trackitem-list", "code": status.HTTP_401_UNAUTHORIZED},
{"name": "userlog-list", "code": status.HTTP_401_UNAUTHORIZED},
{"name": "resource-list"},
{"name": "journal-list", "code": status.HTTP_401_UNAUTHORIZED},
]
]

def test_user_creation(self):
"""
Ensure we can create new users as desired through API.
"""
url = reverse("user-list")

response = self.client.post(url, data={"email": "test"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictContainsSubset(
{
"email": ["Enter a valid email address."],
"password": ["This field is required."],
},
response.data,
)

response = self.client.post(
url, data={"email": "[email protected]", "password": "abcd"}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictContainsSubset(
{
"password": [
"This password is too short. It must contain at least 8 characters.",
"This password is too common.",
]
},
response.data,
)

response = self.client.post(
url, data={"email": "[email protected]", "password": "TestUser#0001"}
)
user = models.User.objects.get(email="[email protected]")
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertDictEqual(
serializers.UserSerializer(user).data,
response.data,
)

result = self.client.login(
email="[email protected]", password="TestUser#0001"
)
self.assertTrue(result)

response = self.client.patch(
f"{url}{user.id}/",
data={
"email": "[email protected]",
"forename": "Test",
"surname": "User",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictContainsSubset(
{"forename": "Test", "surname": "User"},
response.data,
)

def test_portion_items(self):
"""
Ensure we can that portion items work on the endpoint.
"""
url = reverse("portionitem-list")

response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 4)
self.assertListEqual(
response.data.get("results"),
serializers.PortionItemSerializer(
models.PortionItem.objects.filter(is_default=True),
many=True,
).data,
)

response = self.client.post(url, data={"name": "Capacitors"})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_logging(self):
"""
Ensure that users are able to log portions (through the API).
"""
user_credentials = {
"email": "[email protected]",
"password": "TestUser#0001",
}
self.client.post(reverse("user-list"), data=user_credentials)
self.client.login(**user_credentials)

response = self.client.get(reverse("trackitem-list"))
self.assertEqual(response.status_code, status.HTTP_200_OK)
track_items = models.TrackItem.objects.filter(
user__email=user_credentials["email"]
)
self.assertListEqual(
serializers.TrackItemSerializer(track_items, many=True).data, response.data
)

url = reverse("userlog-list")
_ = [
self.assertEqual(
self.client.post(url, data={"item": item.id}).status_code,
status.HTTP_201_CREATED,
)
for item in track_items
]

self.assertEqual(
models.UserLog.objects.filter(
item__user__email=user_credentials["email"],
timestamp__lte=timezone.now(),
).count(),
track_items.count(),
)
5 changes: 5 additions & 0 deletions src/server/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@
"HOST": env("DB_HOST"),
"PORT": env("DB_PORT"),
}
if env("DB_HOST", default=None)
else {
"ENGINE": "django.db.backends.sqlite3",
"NAME": BASE_DIR / "db.sqlite3",
},
}


Expand Down
2 changes: 1 addition & 1 deletion src/server/server/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

urlpatterns = [
path("", include("main.urls")),
path("api/", include("rest.urls")),
path("api/", include("rest.urls")), # api/v1/
path("admin/", admin.site.urls),
] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)