Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the memory consumption when creating a task #2582

Merged
merged 4 commits into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Memory consumption for the task creation process (<https://github.com/openvinotoolkit/cvat/pull/2582>)

### Security

Expand Down
54 changes: 34 additions & 20 deletions cvat/apps/engine/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,35 @@ def __init__(self, **kwargs):
raise Exception('No sourse path')
self.source_path = kwargs.get('source_path')

def _open_video_container(self, sourse_path, mode, options=None):
@staticmethod
def _open_video_container(sourse_path, mode, options=None):
return av.open(sourse_path, mode=mode, options=options)

def _close_video_container(self, container):
@staticmethod
def _close_video_container(container):
container.close()

def _get_video_stream(self, container):
@staticmethod
def _get_video_stream(container):
video_stream = next(stream for stream in container.streams if stream.type == 'video')
video_stream.thread_type = 'AUTO'
return video_stream

@staticmethod
def _get_frame_size(container):
video_stream = WorkWithVideo._get_video_stream(container)
for packet in container.demux(video_stream):
for frame in packet.decode():
if video_stream.metadata.get('rotate'):
frame = av.VideoFrame().from_ndarray(
rotate_image(
frame.to_ndarray(format='bgr24'),
360 - int(container.streams.video[0].metadata.get('rotate')),
),
format ='bgr24',
)
return frame.width, frame.height

class AnalyzeVideo(WorkWithVideo):
def check_type_first_frame(self):
container = self._open_video_container(self.source_path, mode='r')
Expand Down Expand Up @@ -71,28 +89,21 @@ def __init__(self, **kwargs):
self.key_frames = {}
self.frames = 0

container = self._open_video_container(self.source_path, 'r')
self.width, self.height = self._get_frame_size(container)
self._close_video_container(container)

def get_task_size(self):
return self.frames

@property
def frame_sizes(self):
container = self._open_video_container(self.source_path, 'r')
frame = next(iter(self.key_frames.values()))
if container.streams.video[0].metadata.get('rotate'):
frame = av.VideoFrame().from_ndarray(
rotate_image(
frame.to_ndarray(format='bgr24'),
360 - int(container.streams.video[0].metadata.get('rotate'))
),
format ='bgr24'
)
self._close_video_container(container)
return (frame.width, frame.height)
return (self.width, self.height)

def check_key_frame(self, container, video_stream, key_frame):
for packet in container.demux(video_stream):
for frame in packet.decode():
if md5_hash(frame) != md5_hash(key_frame[1]) or frame.pts != key_frame[1].pts:
if md5_hash(frame) != key_frame[1]['md5'] or frame.pts != key_frame[1]['pts']:
self.key_frames.pop(key_frame[0])
return

Expand All @@ -103,7 +114,7 @@ def check_seek_key_frames(self):
key_frames_copy = self.key_frames.copy()

for key_frame in key_frames_copy.items():
container.seek(offset=key_frame[1].pts, stream=video_stream)
container.seek(offset=key_frame[1]['pts'], stream=video_stream)
self.check_key_frame(container, video_stream, key_frame)

def check_frames_ratio(self, chunk_size):
Expand All @@ -114,10 +125,13 @@ def save_key_frames(self):
video_stream = self._get_video_stream(container)
frame_number = 0

for packet in container.demux(video_stream):
for packet in container.demux(video_stream):
for frame in packet.decode():
if frame.key_frame:
self.key_frames[frame_number] = frame
self.key_frames[frame_number] = {
'pts': frame.pts,
'md5': md5_hash(frame),
}
frame_number += 1

self.frames = frame_number
Expand All @@ -126,7 +140,7 @@ def save_key_frames(self):
def save_meta_info(self):
with open(self.meta_path, 'w') as meta_file:
for index, frame in self.key_frames.items():
meta_file.write('{} {}\n'.format(index, frame.pts))
meta_file.write('{} {}\n'.format(index, frame['pts']))

def get_nearest_left_key_frame(self, start_chunk_frame_number):
start_decode_frame_number = 0
Expand Down