Skip to content

Commit

Permalink
Refactor formats tests (#1634)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored Jun 10, 2020
1 parent e0ef2cf commit c792c8c
Showing 1 changed file with 54 additions and 48 deletions.
102 changes: 54 additions & 48 deletions cvat/apps/dataset_manager/tests/_test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def _setUpModule():

from io import BytesIO
import os.path as osp
import random
import tempfile
import zipfile

Expand All @@ -74,25 +73,14 @@ def _setUpModule():
_setUpModule()


def generate_image_file(filename):
def generate_image_file(filename, size=(100, 50)):
f = BytesIO()
width = random.randint(10, 200)
height = random.randint(10, 200)
image = Image.new('RGB', size=(width, height))
image = Image.new('RGB', size=size)
image.save(f, 'jpeg')
f.name = filename
f.seek(0)

return f

def create_db_users(cls):
group_user, _ = Group.objects.get_or_create(name="user")

user_dummy = User.objects.create_superuser(username="test", password="test", email="")
user_dummy.groups.add(group_user)

cls.user = user_dummy

class ForceLogin:
def __init__(self, user, client):
self.user = user
Expand All @@ -109,14 +97,47 @@ def __exit__(self, exception_type, exception_value, traceback):
if self.user:
self.client.logout()

class TaskExportTest(APITestCase):
class _DbTestBase(APITestCase):
def setUp(self):
self.client = APIClient()

@classmethod
def setUpTestData(cls):
create_db_users(cls)
cls.create_db_users()

@classmethod
def create_db_users(cls):
group, _ = Group.objects.get_or_create(name="adm")

admin = User.objects.create_superuser(
username="test", password="test", email="")
admin.groups.add(group)

cls.user = admin

def _put_api_v1_task_id_annotations(self, tid, data):
with ForceLogin(self.user, self.client):
response = self.client.put("/api/v1/tasks/%s/annotations" % tid,
data=data, format="json")

return response

def _create_task(self, data, image_data):
with ForceLogin(self.user, self.client):
response = self.client.post('/api/v1/tasks', data=data, format="json")
assert response.status_code == status.HTTP_201_CREATED, response.status_code
tid = response.data["id"]

response = self.client.post("/api/v1/tasks/%s/data" % tid,
data=image_data)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code

response = self.client.get("/api/v1/tasks/%s" % tid)
task = response.data

return task

class TaskExportTest(_DbTestBase):
def _generate_annotations(self, task):
annotations = {
"version": 0,
Expand Down Expand Up @@ -231,7 +252,15 @@ def _generate_annotations(self, task):
self._put_api_v1_task_id_annotations(task["id"], annotations)
return annotations

def _generate_task(self):
def _generate_task_images(self, count):
images = {
"client_files[%d]" % i: generate_image_file("image_%d.jpg" % i)
for i in range(count)
}
images["image_quality"] = 75
return images

def _generate_task(self, images):
task = {
"name": "my task #1",
"owner": '',
Expand Down Expand Up @@ -261,35 +290,10 @@ def _generate_task(self):
{"name": "person"},
]
}
return self._create_task(task, 3)

def _create_task(self, data, size):
with ForceLogin(self.user, self.client):
response = self.client.post('/api/v1/tasks', data=data, format="json")
assert response.status_code == status.HTTP_201_CREATED, response.status_code
tid = response.data["id"]

images = {
"client_files[%d]" % i: generate_image_file("image_%d.jpg" % i)
for i in range(size)
}
images["image_quality"] = 75
response = self.client.post("/api/v1/tasks/{}/data".format(tid), data=images)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code

response = self.client.get("/api/v1/tasks/{}".format(tid))
task = response.data

return task

def _put_api_v1_task_id_annotations(self, tid, data):
with ForceLogin(self.user, self.client):
response = self.client.put("/api/v1/tasks/{}/annotations".format(tid),
data=data, format="json")

return response
return self._create_task(task, images)

def _test_export(self, check, task, format_name, **export_args):
@staticmethod
def _test_export(check, task, format_name, **export_args):
with tempfile.TemporaryDirectory() as temp_dir:
file_path = osp.join(temp_dir, format_name)
dm.task.export_task(task["id"], file_path,
Expand Down Expand Up @@ -340,9 +344,10 @@ def check(file_path):

format_name = f.DISPLAY_NAME
for save_images in { True, False }:
images = self._generate_task_images(3)
task = self._generate_task(images)
self._generate_annotations(task)
with self.subTest(format=format_name, save_images=save_images):
task = self._generate_task()
self._generate_annotations(task)
self._test_export(check, task,
format_name, save_images=save_images)

Expand All @@ -365,7 +370,8 @@ def test_empty_images_are_exported(self):
if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED:
self.skipTest("Format is disabled")

task = self._generate_task()
images = self._generate_task_images(3)
task = self._generate_task(images)

def check(file_path):
def load_dataset(src):
Expand Down

0 comments on commit c792c8c

Please sign in to comment.