Skip to content

Commit

Permalink
Transform image
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashwin Vaidya committed Oct 10, 2023
1 parent 65253d2 commit f048c2b
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/otx/algorithms/anomaly/tasks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,23 @@
logger = get_logger(__name__)


class OTXOpenVINOAnomalyDataloader:
"""Dataloader for loading OTX dataset into OTX OpenVINO Inferencer.
class OTXNNCFAnomalyDataloader:
"""Dataloader for loading OTX dataset for NNCF optimization.
Args:
dataset (DatasetEntity): OTX dataset entity
model: (AnomalyDetection) The modelAPI model used for fetching the transforms.
shuffle (bool, optional): Shuffle dataset. Defaults to True.
"""

def __init__(
self,
dataset: DatasetEntity,
model: AnomalyDetection,
shuffle: bool = True,
):
self.dataset = dataset
self.model = model
self.shuffler = None
if shuffle:
self.shuffler = list(range(len(dataset)))
Expand All @@ -109,7 +112,11 @@ def __getitem__(self, index: int):
image = self.dataset[index].numpy
annotation = self.dataset[index].annotation_scene

return (index, annotation), image
resized_image = self.model.resize(image, (self.model.w, self.model.h))
resized_image = self.model.input_transform(resized_image)
resized_image = self.model._change_layout(resized_image)

return (index, annotation), resized_image

def __len__(self) -> int:
"""Get size of the dataset.
Expand Down Expand Up @@ -315,7 +322,7 @@ def optimize(
)

logger.info("Starting PTQ optimization.")
data_loader = OTXOpenVINOAnomalyDataloader(dataset=dataset)
data_loader = OTXNNCFAnomalyDataloader(dataset=dataset, model=self.inference_model)
quantization_dataset = nncf.Dataset(data_loader, lambda data: data[1])

with tempfile.TemporaryDirectory() as tempdir:
Expand Down

0 comments on commit f048c2b

Please sign in to comment.