This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Expose common dataset interface
This will allow new types of datasets to be created without making changes in other parts of DIGITS.
gheinrich committed May 11, 2016
1 parent 8d75238 · commit bec92de
Showing 14 changed files with 363 additions and 272 deletions.
124 changes: 123 additions & 1 deletion digits/dataset/images/classification/job.py
@@ -1,10 +1,12 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os

from ..job import ImageDatasetJob
from digits.dataset import tasks
from digits.status import Status
from digits.utils import subclass, override
from digits.utils import subclass, override, constants

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 2
@@ -54,11 +56,121 @@ def __setstate__(self, state):

self.pickver_job_dataset_image_classification = PICKLE_VERSION

def create_db_tasks(self):
"""
Return all CreateDbTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.CreateDbTask)]

@override
def get_backend(self):
"""
Return the DB backend used to create this dataset
"""
return self.train_db_task().backend

def get_encoding(self):
"""
Return the DB encoding used to create this dataset
"""
return self.train_db_task().encoding

def get_compression(self):
"""
Return the DB compression used to create this dataset
"""
return self.train_db_task().compression

@override
def get_entry_count(self, stage):
"""
Return the number of entries in the DB matching the specified stage
"""
if stage == constants.TRAIN_DB:
db = self.train_db_task()
elif stage == constants.VAL_DB:
db = self.val_db_task()
elif stage == constants.TEST_DB:
db = self.test_db_task()
else:
return 0
return db.entries_count if db is not None else 0

@override
def get_feature_dims(self):
"""
Return the shape of the feature N-D array
"""
return self.image_dims

@override
def get_feature_db_path(self, stage):
"""
Return the absolute feature DB path for the specified stage
"""
path = self.path(stage)
return path if os.path.exists(path) else None

@override
def get_label_db_path(self, stage):
"""
Return the absolute label DB path for the specified stage
"""
# classification datasets don't have label DBs
return None

@override
def get_mean_file(self):
"""
Return the mean file
"""
return self.train_db_task().mean_file

@override
def job_type(self):
return 'Image Classification Dataset'

@override
def json_dict(self, verbose=False):
d = super(ImageClassificationDatasetJob, self).json_dict(verbose)

if verbose:
d.update({
'ParseFolderTasks': [{
"name": t.name(),
"label_count": t.label_count,
"train_count": t.train_count,
"val_count": t.val_count,
"test_count": t.test_count,
} for t in self.parse_folder_tasks()],
'CreateDbTasks': [{
"name": t.name(),
"entries": t.entries_count,
"image_width": t.image_dims[0],
"image_height": t.image_dims[1],
"image_channels": t.image_dims[2],
"backend": t.backend,
"encoding": t.encoding,
"compression": t.compression,
} for t in self.create_db_tasks()],
})
return d

def parse_folder_tasks(self):
"""
Return all ParseFolderTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.ParseFolderTask)]

def test_db_task(self):
"""
Return the task that creates the test set
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'test' in t.name().lower():
return t
return None

def train_db_task(self):
"""
Return the task that creates the training set
@@ -68,3 +180,13 @@ def train_db_task(self):
return t
return None

def val_db_task(self):
"""
Return the task that creates the validation set
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'val' in t.name().lower():
return t
return None
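
With these accessors in place, other parts of DIGITS can query a classification dataset through the common interface instead of reaching into its CreateDbTask objects directly. A rough usage sketch follows; the helper name and structure are hypothetical and not code from this commit:

from digits.utils import constants

def dataset_summary(dataset):
    # `dataset` may be any DatasetJob subclass that implements the
    # common interface introduced by this commit.
    return {
        'backend': dataset.get_backend(),
        'train_entries': dataset.get_entry_count(constants.TRAIN_DB),
        'val_entries': dataset.get_entry_count(constants.VAL_DB),
        'feature_dims': dataset.get_feature_dims(),
        'train_feature_db': dataset.get_feature_db_path(constants.TRAIN_DB),
        'train_label_db': dataset.get_label_db_path(constants.TRAIN_DB),
        'mean_file': dataset.get_mean_file(),
    }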


91 changes: 82 additions & 9 deletions digits/dataset/images/generic/job.py
@@ -3,7 +3,7 @@

from ..job import ImageDatasetJob
from digits.dataset import tasks
from digits.utils import subclass, override
from digits.utils import subclass, override, constants

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 1
@@ -23,18 +23,91 @@ def __setstate__(self, state):
super(GenericImageDatasetJob, self).__setstate__(state)
self.pickver_job_dataset_image_generic = PICKLE_VERSION

@override
def job_type(self):
return 'Generic Image Dataset'

@override
def train_db_task(self):
def analyze_db_task(self, stage):
"""
Return the task that creates the training set
Return AnalyzeDbTask for this stage
"""
if stage == constants.TRAIN_DB:
s = 'train'
elif stage == constants.VAL_DB:
s = 'val'
else:
return None
for t in self.tasks:
if isinstance(t, tasks.AnalyzeDbTask) and 'train' in t.name().lower():
if isinstance(t, tasks.AnalyzeDbTask) and s in t.name().lower():
return t
return None

def analyze_db_tasks(self):
"""
Return all AnalyzeDbTasks for this job
"""
return [t for t in self.tasks if isinstance(t, tasks.AnalyzeDbTask)]

@override
def get_backend(self):
"""
Return the DB backend used to create this dataset
"""
return self.analyze_db_task(constants.TRAIN_DB).backend

@override
def get_entry_count(self, stage):
"""
Return the number of entries in the DB matching the specified stage
"""
task = self.analyze_db_task(stage)
return task.image_count if task is not None else 0

@override
def get_feature_db_path(self, stage):
"""
Return the absolute feature DB path for the specified stage
"""
db = None
if stage == constants.TRAIN_DB:
s = 'Training'
elif stage == constants.VAL_DB:
s = 'Validation'
else:
return None
for task in self.tasks:
if task.purpose == '%s Images' % s:
db = task
return self.path(db.database) if db else None

@override
def get_feature_dims(self):
"""
Return the shape of the feature N-D array
"""
db_task = self.analyze_db_task(constants.TRAIN_DB)
return [db_task.image_height, db_task.image_width, db_task.image_channels]

@override
def get_label_db_path(self, stage):
"""
Return the absolute label DB path for the specified stage
"""
db = None
if stage == constants.TRAIN_DB:
s = 'Training'
elif stage == constants.VAL_DB:
s = 'Validation'
else:
return None
for task in self.tasks:
if task.purpose == '%s Labels' % s:
db = task
return self.path(db.database) if db else None

@override
def get_mean_file(self):
"""
Return the mean file
"""
return self.mean_file

@override
def job_type(self):
return 'Generic Image Dataset'
67 changes: 17 additions & 50 deletions digits/dataset/job.py
@@ -20,71 +20,38 @@ def __init__(self, **kwargs):
super(DatasetJob, self).__init__(**kwargs)
self.pickver_job_dataset = PICKLE_VERSION

@override
def json_dict(self, verbose=False):
d = super(DatasetJob, self).json_dict(verbose)

if verbose:
d.update({
'ParseFolderTasks': [{
"name": t.name(),
"label_count": t.label_count,
"train_count": t.train_count,
"val_count": t.val_count,
"test_count": t.test_count,
} for t in self.parse_folder_tasks()],
'CreateDbTasks': [{
"name": t.name(),
"entries": t.entries_count,
"image_width": t.image_dims[0],
"image_height": t.image_dims[1],
"image_channels": t.image_dims[2],
"backend": t.backend,
"encoding": t.encoding,
"compression": t.compression,
} for t in self.create_db_tasks()],
})
return d

def parse_folder_tasks(self):
def get_backend(self):
"""
Return all ParseFolderTasks for this job
Return the DB backend used to create this dataset
"""
return [t for t in self.tasks if isinstance(t, tasks.ParseFolderTask)]
raise NotImplementedError('Please implement me')

def create_db_tasks(self):
def get_entry_count(self, stage):
"""
Return all CreateDbTasks for this job
Return the number of entries in the DB matching the specified stage
"""
return [t for t in self.tasks if isinstance(t, tasks.CreateDbTask)]
raise NotImplementedError('Please implement me')

def train_db_task(self):
def get_feature_db_path(self, stage):
"""
Return the task that creates the training set
Return the absolute feature DB path for the specified stage
"""
raise NotImplementedError('Please implement me')

def val_db_task(self):
def get_feature_dims(self):
"""
Return the task that creates the validation set
Return the shape of the feature N-D array
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'val' in t.name().lower():
return t
return None
raise NotImplementedError('Please implement me')

def test_db_task(self):
def get_label_db_path(self, stage):
"""
Return the task that creates the test set
Return the absolute label DB path for the specified stage
"""
for t in self.tasks:
if isinstance(t, tasks.CreateDbTask) and 'test' in t.name().lower():
return t
return None
raise NotImplementedError('Please implement me')

def analyze_db_tasks(self):
def get_mean_file(self):
"""
Return all AnalyzeDbTasks for this job
Return the mean file
"""
return [t for t in self.tasks if isinstance(t, tasks.AnalyzeDbTask)]

raise NotImplementedError('Please implement me')
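
The commit message says new dataset types can now be added without changing other parts of DIGITS. A minimal sketch of what such a subclass might look like, assuming the DatasetJob base class lives at digits.dataset.job as in this diff (hypothetical example, not part of this commit):

from digits.dataset.job import DatasetJob
from digits.utils import subclass, override, constants

@subclass
class CSVDatasetJob(DatasetJob):
    """
    Hypothetical dataset type: only the common interface has to be
    implemented for the rest of DIGITS to work with it.
    """

    @override
    def job_type(self):
        return 'CSV Dataset'

    @override
    def get_backend(self):
        return 'lmdb'

    @override
    def get_entry_count(self, stage):
        # entry_counts is an assumed attribute of this sketch,
        # e.g. {constants.TRAIN_DB: 60000, constants.VAL_DB: 10000}
        return self.entry_counts.get(stage, 0)

    @override
    def get_feature_dims(self):
        # height, width, channels, as in the implementations above
        return [28, 28, 1]

    @override
    def get_feature_db_path(self, stage):
        if stage in (constants.TRAIN_DB, constants.VAL_DB):
            return self.path(stage)
        return None

    @override
    def get_label_db_path(self, stage):
        # features and labels are stored together in this sketch
        return None

    @override
    def get_mean_file(self):
        return 'mean.binaryproto'
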
4 changes: 2 additions & 2 deletions digits/model/images/classification/job.py
@@ -48,8 +48,8 @@ def download_files(self, epoch=-1):
download_files.extend([
(task.dataset.path(task.dataset.labels_file),
os.path.basename(task.dataset.labels_file)),
(task.dataset.path(task.dataset.train_db_task().mean_file),
os.path.basename(task.dataset.train_db_task().mean_file)),
(task.dataset.path(task.dataset.get_mean_file()),
os.path.basename(task.dataset.get_mean_file())),
(snapshot_filename,
os.path.basename(snapshot_filename)),
])
2 changes: 1 addition & 1 deletion digits/model/images/classification/views.py
@@ -253,7 +253,7 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
2 changes: 1 addition & 1 deletion digits/model/images/generic/views.py
@@ -205,7 +205,7 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,