Skip to content

Commit

Permalink
Added test
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed May 20, 2024
1 parent 874119c commit d4a0e07
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def cxcywhr_to_poly(boxes: np.ndarray) -> np.ndarray:
if shape[-1] != 5:
raise ValueError(f"Expected last dimension to be 5, got {shape[-1]}")

flat_rboxes = boxes.reshape(-1, 5)
flat_rboxes = boxes.reshape(-1, 5).astype(np.float32)
polys = np.zeros((flat_rboxes.shape[0], 4, 2), dtype=np.float32)
for i, box in enumerate(flat_rboxes):
cx, cy, w, h, r = box
Expand Down
23 changes: 23 additions & 0 deletions tests/integration_tests/albumentations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from albumentations import Compose, HorizontalFlip, InvertImg

from super_gradients.training.datasets import Cifar10, Cifar100, ImageNetDataset, COCODetectionDataset, CoCoSegmentationDataSet, COCOPoseEstimationDataset
from super_gradients.training.samples import OBBSample
from super_gradients.training.transforms.pipeline_adaptors import AlbumentationsAdaptor
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy
from super_gradients.training.datasets.depth_estimation_datasets import NYUv2DepthEstimationDataset
Expand Down Expand Up @@ -338,6 +340,27 @@ def test_coco_pose_albumentations_intergration(self):

_ = next(iter(unsupported_ds))

def test_obb_support_albumentations(self):
import albumentations as A

adaptor = AlbumentationsAdaptor(
composed_transforms=A.Compose(
transforms=[A.ShiftScaleRotate(p=1), A.RandomBrightness(p=1), A.Transpose(p=1)], keypoint_params=A.KeypointParams(format="xy")
)
)

sample = OBBSample(
image=np.ones((256, 256, 3), dtype=np.uint8),
rboxes_cxcywhr=np.array([[128, 128, 100, 50, 0]]),
labels=np.array([1]),
is_crowd=np.array([0]),
additional_samples=None,
)
sample = adaptor.apply_to_sample(sample)
self.assertEqual(sample.image.shape, (256, 256, 3))
self.assertEqual(sample.rboxes_cxcywhr.shape, (1, 5))
self.assertEqual(sample.labels.shape, (1,))


if __name__ == "__main__":
unittest.main()

0 comments on commit d4a0e07

Please sign in to comment.