Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Expose common dataset interface #723

Merged
merged 1 commit into from
May 13, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 123 additions & 1 deletion digits/dataset/images/classification/job.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os

from ..job import ImageDatasetJob
from digits.dataset import tasks
from digits.status import Status
from digits.utils import subclass, override
from digits.utils import subclass, override, constants

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 2
Expand Down Expand Up @@ -54,11 +56,121 @@ def __setstate__(self, state):

self.pickver_job_dataset_image_classification = PICKLE_VERSION

def create_db_tasks(self):
"""
Return all CreateDbTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.CreateDbTask)]

@override
def get_backend(self):
"""
Return the DB backend used to create this dataset
"""
return self.train_db_task().backend

def get_encoding(self):
"""
Return the DB encoding used to create this dataset
"""
return self.train_db_task().encoding

def get_compression(self):
"""
Return the DB compression used to create this dataset
"""
return self.train_db_task().compression

@override
def get_entry_count(self, stage):
"""
Return the number of entries in the DB matching the specified stage
"""
if stage == constants.TRAIN_DB:
db = self.train_db_task()
elif stage == constants.VAL_DB:
db = self.val_db_task()
elif stage == constants.TEST_DB:
db = self.test_db_task()
else:
return 0
return db.entries_count if db is not None else 0

@override
def get_feature_dims(self):
"""
Return the shape of the feature N-D array
"""
return self.image_dims

@override
def get_feature_db_path(self, stage):
"""
Return the absolute feature DB path for the specified stage
"""
path = self.path(stage)
return path if os.path.exists(path) else None

@override
def get_label_db_path(self, stage):
"""
Return the absolute label DB path for the specified stage
"""
# classification datasets don't have label DBs
return None

@override
def get_mean_file(self):
"""
Return the mean file
"""
return self.train_db_task().mean_file

@override
def job_type(self):
return 'Image Classification Dataset'

@override
def json_dict(self, verbose=False):
d = super(ImageClassificationDatasetJob, self).json_dict(verbose)

if verbose:
d.update({
'ParseFolderTasks': [{
"name": t.name(),
"label_count": t.label_count,
"train_count": t.train_count,
"val_count": t.val_count,
"test_count": t.test_count,
} for t in self.parse_folder_tasks()],
'CreateDbTasks': [{
"name": t.name(),
"entries": t.entries_count,
"image_width": t.image_dims[0],
"image_height": t.image_dims[1],
"image_channels": t.image_dims[2],
"backend": t.backend,
"encoding": t.encoding,
"compression": t.compression,
} for t in self.create_db_tasks()],
})
return d

def parse_folder_tasks(self):
"""
Return all ParseFolderTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.ParseFolderTask)]

def test_db_task(self):
"""
Return the task that creates the test set
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'test' in t.name().lower():
return t
return None

def train_db_task(self):
"""
Return the task that creates the training set
Expand All @@ -68,3 +180,13 @@ def train_db_task(self):
return t
return None

def val_db_task(self):
"""
Return the task that creates the validation set
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'val' in t.name().lower():
return t
return None


91 changes: 82 additions & 9 deletions digits/dataset/images/generic/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from ..job import ImageDatasetJob
from digits.dataset import tasks
from digits.utils import subclass, override
from digits.utils import subclass, override, constants

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 1
Expand All @@ -23,18 +23,91 @@ def __setstate__(self, state):
super(GenericImageDatasetJob, self).__setstate__(state)
self.pickver_job_dataset_image_generic = PICKLE_VERSION

@override
def job_type(self):
return 'Generic Image Dataset'

@override
def train_db_task(self):
def analyze_db_task(self, stage):
"""
Return the task that creates the training set
Return AnalyzeDbTask for this stage
"""
if stage == constants.TRAIN_DB:
s = 'train'
elif stage == constants.VAL_DB:
s = 'val'
else:
return None
for t in self.tasks:
if isinstance(t, tasks.AnalyzeDbTask) and 'train' in t.name().lower():
if isinstance(t, tasks.AnalyzeDbTask) and s in t.name().lower():
return t
return None

def analyze_db_tasks(self):
"""
Return all AnalyzeDbTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.AnalyzeDbTask)]

@override
def get_backend(self):
"""
Return the DB backend used to create this dataset
"""
return self.analyze_db_task(constants.TRAIN_DB).backend

@override
def get_entry_count(self, stage):
"""
Return the number of entries in the DB matching the specified stage
"""
task = self.analyze_db_task(stage)
return task.image_count if task is not None else 0

@override
def get_feature_db_path(self, stage):
"""
Return the absolute feature DB path for the specified stage
"""
db = None
if stage == constants.TRAIN_DB:
s = 'Training'
elif stage == constants.VAL_DB:
s = 'Validation'
else:
return None
for task in self.tasks:
if task.purpose == '%s Images' % s:
db = task
return self.path(db.database) if db else None

@override
def get_feature_dims(self):
"""
Return the shape of the feature N-D array
"""
db_task = self.analyze_db_task(constants.TRAIN_DB)
return [db_task.image_height, db_task.image_width, db_task.image_channels]

@override
def get_label_db_path(self, stage):
"""
Return the absolute label DB path for the specified stage
"""
db = None
if stage == constants.TRAIN_DB:
s = 'Training'
elif stage == constants.VAL_DB:
s = 'Validation'
else:
return None
for task in self.tasks:
if task.purpose == '%s Labels' % s:
db = task
return self.path(db.database) if db else None

@override
def get_mean_file(self):
"""
Return the mean file
"""
return self.mean_file

@override
def job_type(self):
return 'Generic Image Dataset'
67 changes: 17 additions & 50 deletions digits/dataset/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,71 +20,38 @@ def __init__(self, **kwargs):
super(DatasetJob, self).__init__(**kwargs)
self.pickver_job_dataset = PICKLE_VERSION

@override
def json_dict(self, verbose=False):
d = super(DatasetJob, self).json_dict(verbose)

if verbose:
d.update({
'ParseFolderTasks': [{
"name": t.name(),
"label_count": t.label_count,
"train_count": t.train_count,
"val_count": t.val_count,
"test_count": t.test_count,
} for t in self.parse_folder_tasks()],
'CreateDbTasks': [{
"name": t.name(),
"entries": t.entries_count,
"image_width": t.image_dims[0],
"image_height": t.image_dims[1],
"image_channels": t.image_dims[2],
"backend": t.backend,
"encoding": t.encoding,
"compression": t.compression,
} for t in self.create_db_tasks()],
})
return d

def parse_folder_tasks(self):
def get_backend(self):
"""
Return all ParseFolderTasks for this job
Return the DB backend used to create this dataset
"""
return [t for t in self.tasks if isinstance(t, tasks.ParseFolderTask)]
raise NotImplementedError('Please implement me')

def create_db_tasks(self):
def get_entry_count(self, stage):
"""
Return all CreateDbTasks for this job
Return the number of entries in the DB matching the specified stage
"""
return [t for t in self.tasks if isinstance(t, tasks.CreateDbTask)]
raise NotImplementedError('Please implement me')

def train_db_task(self):
def get_feature_db_path(self, stage):
"""
Return the task that creates the training set
Return the absolute feature DB path for the specified stage
"""
raise NotImplementedError('Please implement me')

def val_db_task(self):
def get_feature_dims(self):
"""
Return the task that creates the validation set
Return the shape of the feature N-D array
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'val' in t.name().lower():
return t
return None
raise NotImplementedError('Please implement me')

def test_db_task(self):
def get_label_db_path(self, stage):
"""
Return the task that creates the test set
Return the absolute label DB path for the specified stage
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'test' in t.name().lower():
return t
return None
raise NotImplementedError('Please implement me')

def analyze_db_tasks(self):
def get_mean_file(self):
"""
Return all AnalyzeDbTasks for this job
Return the mean file
"""
return [t for t in self.tasks if isinstance(t, tasks.AnalyzeDbTask)]

raise NotImplementedError('Please implement me')
4 changes: 2 additions & 2 deletions digits/model/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def download_files(self, epoch=-1):
download_files.extend([
(task.dataset.path(task.dataset.labels_file),
os.path.basename(task.dataset.labels_file)),
(task.dataset.path(task.dataset.train_db_task().mean_file),
os.path.basename(task.dataset.train_db_task().mean_file)),
(task.dataset.path(task.dataset.get_mean_file()),
os.path.basename(task.dataset.get_mean_file())),
(snapshot_filename,
os.path.basename(snapshot_filename)),
])
Expand Down
2 changes: 1 addition & 1 deletion digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
Expand Down
2 changes: 1 addition & 1 deletion digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
Expand Down
Loading