Skip to content

Commit

Permalink
Merge pull request #6 from azhavoro/az/fix_unittests
Browse files Browse the repository at this point in the history
fixed unit tests
  • Loading branch information
ygnn123 authored Nov 8, 2019
2 parents 6da262b + 033b635 commit 41ca049
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 76 deletions.
2 changes: 1 addition & 1 deletion cvat/apps/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _init_frame_info(self):
"path": db_image.path,
"width": db_image.width,
"height": db_image.height,
} for db_image in self._db_task.data.image_set.all()}
} for db_image in self._db_task.data.images.all()}

self._frame_mapping = {
self._get_filename(info["path"]): frame for frame, info in self._frame_info.items()
Expand Down
3 changes: 2 additions & 1 deletion cvat/apps/engine/media_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def save_as_chunks(self, chunk_size, compressed_chunk_path, original_chunk_path,
original_chunk = original_chunk_path(counter)
with zipfile.ZipFile(original_chunk, 'x') as zip_chunk:
for idx, image_file in enumerate(chunk_data):
zip_chunk.write(filename=image_file, arcname=os.path.basename(image_file))
arcname = '{:06d}.{}'.format(idx, os.path.basename(image_file)[1])
zip_chunk.write(filename=image_file, arcname=arcname)

counter += 1
if progress_callback:
Expand Down
2 changes: 1 addition & 1 deletion cvat/apps/engine/migrations/0023_auto_20191023_1025.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def fix_path(path):
# compressed images
old_task_data_dir = os.path.join(old_db_task_dir, 'data')
if os.path.isdir(old_task_data_dir):
shutil.copytree(old_task_data_dir, compressed_cache_dir, symlinks=False)
shutil.copytree(old_task_data_dir, compressed_cache_dir, symlinks=False, ignore_dangling_symlinks=True)

# prepare *.list chunks
for chunk_idx, start_frame in enumerate(range(0, db_data.size, db_data.chunk_size)):
Expand Down
39 changes: 10 additions & 29 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ class DataSerializer(serializers.ModelSerializer):

class Meta:
model = models.Data
fields = ('chunk_size', 'size', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter', 'type', 'client_files', 'server_files', 'remote_files')
# write_once_fields = ('source', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter', 'type', 'chunk_size')
# ordering = []
fields = ('chunk_size', 'size', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter', 'type',
'client_files', 'server_files', 'remote_files')

def validate_frame_filter(self, value):
match = re.search("step\s*=\s*([1-9]\d*)", value)
Expand All @@ -190,49 +189,31 @@ def create(self, validated_data):
client_file.save()

for f in server_files:
server_file = models.ServerFile(task=db_data, **f)
server_file = models.ServerFile(data=db_data, **f)
server_file.save()

for f in remote_files:
remote_file = models.RemoteFile(task=db_data, **f)
remote_file = models.RemoteFile(data=db_data, **f)
remote_file.save()

db_data.save()
return db_data

def update(self, instance, validated_data):
# TODO
client_files = validated_data.pop('clientfile_set')
server_files = validated_data.pop('serverfile_set')
remote_files = validated_data.pop('remotefile_set')

for file in client_files:
client_file = models.ClientFile(data=instance, **file)
client_file.save()

for file in server_files:
server_file = models.ServerFile(data=instance, **file)
server_file.save()

for file in remote_files:
remote_file = models.RemoteFile(data=instance, **file)
remote_file.save()
instance.save()
return instance

class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True)
segments = SegmentSerializer(many=True, source='segment_set', read_only=True)
data_chunk_size = serializers.ReadOnlyField(source="data.chunk_size")
data_chunk_type = serializers.ReadOnlyField(source="data.type")
data_chunk_size = serializers.ReadOnlyField(source='data.chunk_size')
data_chunk_type = serializers.ReadOnlyField(source='data.type')
size = serializers.ReadOnlyField(source='data.size')
image_quality = serializers.ReadOnlyField(source='data.image_quality')

class Meta:
model = models.Task
fields = ('url', 'id', 'name', 'mode', 'owner', 'assignee',
'bug_tracker', 'created_date', 'updated_date', 'overlap',
'segment_size', 'z_order', 'status', 'labels', 'segments',
'project', 'data_chunk_size', 'data_chunk_type')
read_only_fields = ('mode', 'created_date', 'updated_date', 'status', 'data_chunk_size', 'data_chunk_type')
'project', 'data_chunk_size', 'data_chunk_type', 'size', 'image_quality')
read_only_fields = ('mode', 'created_date', 'updated_date', 'status', 'data_chunk_size', 'data_chunk_type', 'size', 'image_quality')
write_once_fields = ('overlap', 'segment_size')
ordering = ['-id']

Expand Down
52 changes: 27 additions & 25 deletions cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.conf import settings
from django.contrib.auth.models import User, Group
from cvat.apps.engine.models import (Task, Segment, Job, StatusChoice,
AttributeType, Project)
AttributeType, Project, Data)
from cvat.apps.annotation.models import AnnotationFormat
from unittest import mock
import io
Expand Down Expand Up @@ -50,14 +50,28 @@ def create_db_users(cls):
cls.user = cls.user5 = user_dummy

def create_db_task(data):
data_settings = {
"size": data.pop("size"),
"image_quality": data.pop("image_quality"),
}

db_data = Data.objects.create(**data_settings)
shutil.rmtree(db_data.get_data_dirname(), ignore_errors=True)
os.makedirs(db_data.get_data_dirname())
os.makedirs(db_data.get_upload_dirname())

db_task = Task.objects.create(**data)
shutil.rmtree(db_task.get_task_dirname(), ignore_errors=True)
os.makedirs(db_task.get_upload_dirname())
os.makedirs(db_task.get_data_dirname())

for x in range(0, db_task.size, db_task.segment_size):
os.makedirs(db_task.get_task_dirname())
os.makedirs(db_task.get_task_logs_dirname())
os.makedirs(db_task.get_task_artifacts_dirname())
os.makedirs(db_task.get_task_datum_dirname())
db_task.data = db_data
db_task.save()

for x in range(0, db_task.data.size, db_task.segment_size):
start_frame = x
stop_frame = min(x + db_task.segment_size - 1, db_task.size - 1)
stop_frame = min(x + db_task.segment_size - 1, db_task.data.size - 1)

db_segment = Segment()
db_segment.task = db_task
Expand Down Expand Up @@ -1051,7 +1065,7 @@ def _run_api_v1_tasks_id(self, tid, user):
def _check_response(self, response, db_task):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["name"], db_task.name)
self.assertEqual(response.data["size"], db_task.size)
self.assertEqual(response.data["size"], db_task.data.size)
self.assertEqual(response.data["mode"], db_task.mode)
owner = db_task.owner.id if db_task.owner else None
self.assertEqual(response.data["owner"], owner)
Expand All @@ -1060,7 +1074,7 @@ def _check_response(self, response, db_task):
self.assertEqual(response.data["overlap"], db_task.overlap)
self.assertEqual(response.data["segment_size"], db_task.segment_size)
self.assertEqual(response.data["z_order"], db_task.z_order)
self.assertEqual(response.data["image_quality"], db_task.image_quality)
self.assertEqual(response.data["image_quality"], db_task.data.image_quality)
self.assertEqual(response.data["status"], db_task.status)
self.assertListEqual(
[label.name for label in db_task.label_set.all()],
Expand Down Expand Up @@ -1146,7 +1160,7 @@ def _check_response(self, response, db_task, data):
self.assertEqual(response.status_code, status.HTTP_200_OK)
name = data.get("name", db_task.name)
self.assertEqual(response.data["name"], name)
self.assertEqual(response.data["size"], db_task.size)
self.assertEqual(response.data["size"], db_task.data.size)
mode = data.get("mode", db_task.mode)
self.assertEqual(response.data["mode"], mode)
owner = db_task.owner.id if db_task.owner else None
Expand All @@ -1159,7 +1173,7 @@ def _check_response(self, response, db_task, data):
self.assertEqual(response.data["segment_size"], db_task.segment_size)
z_order = data.get("z_order", db_task.z_order)
self.assertEqual(response.data["z_order"], z_order)
image_quality = data.get("image_quality", db_task.image_quality)
image_quality = data.get("image_quality", db_task.data.image_quality)
self.assertEqual(response.data["image_quality"], image_quality)
self.assertEqual(response.data["status"], db_task.status)
if data.get("labels"):
Expand Down Expand Up @@ -1187,7 +1201,6 @@ def test_api_v1_tasks_id_admin(self):
data = {
"name": "new name for the task",
"owner": self.owner.id,
"image_quality": 60,
"labels": [{
"name": "non-vehicle",
"attributes": [{
Expand All @@ -1204,7 +1217,6 @@ def test_api_v1_tasks_id_user(self):
data = {
"name": "new name for the task",
"owner": self.assignee.id,
"image_quality": 63,
"labels": [{
"name": "car",
"attributes": [{
Expand All @@ -1221,7 +1233,6 @@ def test_api_v1_tasks_id_user(self):
def test_api_v1_tasks_id_observer(self):
data = {
"name": "new name for the task",
"image_quality": 61,
"labels": [{
"name": "test",
}]
Expand All @@ -1231,7 +1242,6 @@ def test_api_v1_tasks_id_observer(self):
def test_api_v1_tasks_id_no_auth(self):
data = {
"name": "new name for the task",
"image_quality": 59,
"labels": [{
"name": "test",
}]
Expand Down Expand Up @@ -1315,15 +1325,13 @@ def _run_api_v1_tasks(self, user, data):
def _check_response(self, response, user, data):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data["name"], data["name"])
self.assertEqual(response.data["size"], 0)
self.assertEqual(response.data["mode"], "")
self.assertEqual(response.data["owner"], data.get("owner", user.id))
self.assertEqual(response.data["assignee"], data.get("assignee"))
self.assertEqual(response.data["bug_tracker"], data.get("bug_tracker", ""))
self.assertEqual(response.data["overlap"], data.get("overlap", None))
self.assertEqual(response.data["segment_size"], data.get("segment_size", 0))
self.assertEqual(response.data["z_order"], data.get("z_order", False))
self.assertEqual(response.data["image_quality"], data.get("image_quality", 50))
self.assertEqual(response.data["status"], StatusChoice.ANNOTATION)
self.assertListEqual(
[label["name"] for label in data.get("labels")],
Expand All @@ -1342,7 +1350,6 @@ def _check_api_v1_tasks(self, user, data):
def test_api_v1_tasks_admin(self):
data = {
"name": "new name for the task",
"image_quality": 60,
"labels": [{
"name": "non-vehicle",
"attributes": [{
Expand All @@ -1359,7 +1366,6 @@ def test_api_v1_tasks_user(self):
data = {
"name": "new name for the task",
"owner": self.assignee.id,
"image_quality": 63,
"labels": [{
"name": "car",
"attributes": [{
Expand All @@ -1376,7 +1382,6 @@ def test_api_v1_tasks_user(self):
def test_api_v1_tasks_observer(self):
data = {
"name": "new name for the task",
"image_quality": 61,
"labels": [{
"name": "test",
}]
Expand All @@ -1386,7 +1391,6 @@ def test_api_v1_tasks_observer(self):
def test_api_v1_tasks_no_auth(self):
data = {
"name": "new name for the task",
"image_quality": 59,
"labels": [{
"name": "test",
}]
Expand Down Expand Up @@ -1474,7 +1478,6 @@ def _test_api_v1_tasks_id_data(self, user):
"overlap": 0,
"segment_size": 100,
"z_order": False,
"image_quality": 75,
"labels": [
{"name": "car"},
{"name": "person"},
Expand All @@ -1488,6 +1491,7 @@ def _test_api_v1_tasks_id_data(self, user):
"client_files[0]": generate_image_file("test_1.jpg"),
"client_files[1]": generate_image_file("test_2.jpg"),
"client_files[2]": generate_image_file("test_3.jpg"),
"image_quality": 75,
}

response = self._run_api_v1_tasks_id_data(task_id, user, data)
Expand All @@ -1497,7 +1501,6 @@ def _test_api_v1_tasks_id_data(self, user):
"name": "my task #2",
"overlap": 0,
"segment_size": 0,
"image_quality": 75,
"labels": [
{"name": "car"},
{"name": "person"},
Expand All @@ -1512,6 +1515,7 @@ def _test_api_v1_tasks_id_data(self, user):
"server_files[1]": "test_2.jpg",
"server_files[2]": "test_3.jpg",
"server_files[3]": "data/test_3.jpg",
"image_quality": 75,
}

response = self._run_api_v1_tasks_id_data(task_id, user, data)
Expand All @@ -1534,7 +1538,6 @@ def test_api_v1_tasks_id_data_no_auth(self):
"overlap": 0,
"segment_size": 100,
"z_order": False,
"image_quality": 75,
"labels": [
{"name": "car"},
{"name": "person"},
Expand Down Expand Up @@ -1577,7 +1580,6 @@ def _create_task(self, owner, assignee):
"overlap": 0,
"segment_size": 100,
"z_order": False,
"image_quality": 75,
"labels": [
{
"name": "car",
Expand Down Expand Up @@ -1610,6 +1612,7 @@ def _create_task(self, owner, assignee):
"client_files[0]": generate_image_file("test_1.jpg"),
"client_files[1]": generate_image_file("test_2.jpg"),
"client_files[2]": generate_image_file("test_3.jpg"),
"image_quality": 75,
}
response = self.client.post("/api/v1/tasks/{}/data".format(tid), data=images)
assert response.status_code == status.HTTP_202_ACCEPTED
Expand Down Expand Up @@ -2667,7 +2670,6 @@ def _get_initial_annotation(annotation_format):
response = self._get_annotation_formats(annotator)
self.assertEqual(response.status_code, HTTP_200_OK)


if annotator is not None:
supported_formats = response.data
else:
Expand Down
4 changes: 2 additions & 2 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def perform_destroy(self, instance):
super().perform_destroy(instance)
shutil.rmtree(task_dirname, ignore_errors=True)
if not instance.data.tasks.all():
shutil.rmtree(instance.data.get_data_dirname())
shutil.rmtree(instance.data.get_data_dirname(), ignore_errors=True)
instance.data.delete()

@action(detail=True, methods=['GET'], serializer_class=JobSerializer)
Expand Down Expand Up @@ -347,7 +347,7 @@ def dump(self, request, pk, filename):
raise serializers.ValidationError(
"Please specify a correct 'format' parameter for the request")

file_path = os.path.join(db_task.get_task_dirname(),
file_path = os.path.join(db_task.get_task_artifacts_dirname(),
"{}.{}.{}.{}".format(filename, username, timestamp, db_dumper.format.lower()))

queue = django_rq.get_queue("default")
Expand Down
11 changes: 11 additions & 0 deletions cvat/settings/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,20 @@

DATA_ROOT = os.path.join(_temp_dir.name, 'data')
os.makedirs(DATA_ROOT, exist_ok=True)

SHARE_ROOT = os.path.join(_temp_dir.name, 'share')
os.makedirs(SHARE_ROOT, exist_ok=True)

MEDIA_DATA_ROOT = os.path.join(DATA_ROOT, 'data')
os.makedirs(MEDIA_DATA_ROOT, exist_ok=True)

TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks')
os.makedirs(TASKS_ROOT, exist_ok=True)

MODELS_ROOT = os.path.join(DATA_ROOT, 'models')
os.makedirs(MODELS_ROOT, exist_ok=True)


# To avoid ERROR django.security.SuspiciousFileOperation:
# The joined path (...) is located outside of the base path component
MEDIA_ROOT = _temp_dir.name
Expand Down
Loading

0 comments on commit 41ca049

Please sign in to comment.