From 9cf4a5f005131d958bc74dfb7c259cbf1d673b9f Mon Sep 17 00:00:00 2001 From: Yael-Baron Date: Tue, 2 Apr 2024 10:38:55 +0300 Subject: [PATCH] Integration of RAFT Optical Flow model to SG. --- src/super_gradients/common/object_names.py | 3 + .../module_interfaces/__init__.py | 3 + .../exportable_optical_flow.py | 293 ++++++++++ .../arch_params/raft_l_arch_params.yaml | 41 ++ .../arch_params/raft_s_arch_params.yaml | 41 ++ .../training/models/__init__.py | 3 + .../models/optical_flow_models/__init__.py | 0 .../optical_flow_models/raft/__init__.py | 31 ++ .../optical_flow_models/raft/raft_base.py | 500 ++++++++++++++++++ .../optical_flow_models/raft/raft_variants.py | 140 +++++ .../export_optical_flow_model_test.py | 139 +++++ tests/unit_tests/raft_tests.py | 38 ++ 12 files changed, 1232 insertions(+) create mode 100644 src/super_gradients/module_interfaces/exportable_optical_flow.py create mode 100644 src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml create mode 100644 src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml create mode 100644 src/super_gradients/training/models/optical_flow_models/__init__.py create mode 100644 src/super_gradients/training/models/optical_flow_models/raft/__init__.py create mode 100644 src/super_gradients/training/models/optical_flow_models/raft/raft_base.py create mode 100644 src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py create mode 100644 tests/unit_tests/export_optical_flow_model_test.py create mode 100644 tests/unit_tests/raft_tests.py diff --git a/src/super_gradients/common/object_names.py b/src/super_gradients/common/object_names.py index db78546bad..c2c0672bf4 100644 --- a/src/super_gradients/common/object_names.py +++ b/src/super_gradients/common/object_names.py @@ -338,6 +338,9 @@ class Models: YOLO_NAS_POSE_M = "yolo_nas_pose_m" YOLO_NAS_POSE_L = "yolo_nas_pose_l" + RAFT_S = "raft_s" + RAFT_L = "raft_l" + class ConcatenatedTensorFormats: XYXY_LABEL = "XYXY_LABEL" diff --git a/src/super_gradients/module_interfaces/__init__.py b/src/super_gradients/module_interfaces/__init__.py index f9871c3825..9838aab106 100644 --- a/src/super_gradients/module_interfaces/__init__.py +++ b/src/super_gradients/module_interfaces/__init__.py @@ -12,6 +12,7 @@ SemanticSegmentationDecodingModule, BinarySegmentationDecodingModule, ) +from .exportable_optical_flow import ExportableOpticalFlowModel, OpticalFlowModelExportResult __all__ = [ "HasPredict", @@ -35,4 +36,6 @@ "AbstractSegmentationDecodingModule", "SemanticSegmentationDecodingModule", "BinarySegmentationDecodingModule", + "ExportableOpticalFlowModel", + "OpticalFlowModelExportResult", ] diff --git a/src/super_gradients/module_interfaces/exportable_optical_flow.py b/src/super_gradients/module_interfaces/exportable_optical_flow.py new file mode 100644 index 0000000000..98cecfb186 --- /dev/null +++ b/src/super_gradients/module_interfaces/exportable_optical_flow.py @@ -0,0 +1,293 @@ +import copy +import dataclasses +import gc +from typing import Union, Optional, List, Tuple + +import numpy as np +import onnx +import onnxsim +import torch +from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.conversion import ExportTargetBackend, ExportQuantizationMode +from super_gradients.conversion.conversion_utils import find_compatible_model_device_for_dtype +from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install +from super_gradients.import_utils import import_pytorch_quantization_or_install +from 
super_gradients.module_interfaces.supports_input_shape_check import SupportsInputShapeCheck
+from super_gradients.training.utils.export_utils import (
+    infer_image_shape_from_model,
+    infer_image_input_channels,
+)
+from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules
+from torch import nn
+from torch.utils.data import DataLoader
+
+logger = get_logger(__name__)
+
+__all__ = ["ExportableOpticalFlowModel", "OpticalFlowModelExportResult"]
+
+
+@dataclasses.dataclass
+class OpticalFlowModelExportResult:
+    """
+    A dataclass that holds the result of model export.
+    """
+
+    input_image_channels: int
+    input_image_dtype: torch.dtype
+    input_image_shape: Tuple[int, int]
+
+    engine: ExportTargetBackend
+    quantization_mode: Optional[ExportQuantizationMode]
+
+    output: str
+
+    usage_instructions: str = ""
+
+    def __repr__(self):
+        return self.usage_instructions
+
+
+class ExportableOpticalFlowModel:
+    """
+    A mixin class that adds export functionality to optical flow models.
+    Classes that inherit from this mixin must implement the following methods:
+    - get_decoding_module()
+    - get_preprocessing_callback()
+    Provided these methods are implemented correctly, the model can be exported to ONNX or TensorRT formats
+    using the model.export(...) method.
+    """
+
+    def export(
+        self,
+        output: str,
+        quantization_mode: Optional[ExportQuantizationMode] = None,
+        selective_quantizer: Optional["SelectiveQuantizer"] = None,  # noqa
+        calibration_loader: Optional[DataLoader] = None,
+        calibration_method: str = "percentile",
+        calibration_batches: int = 16,
+        calibration_percentile: float = 99.99,
+        batch_size: int = 1,
+        input_image_shape: Optional[Tuple[int, int]] = None,
+        input_image_channels: Optional[int] = None,
+        input_image_dtype: Optional[torch.dtype] = None,
+        onnx_export_kwargs: Optional[dict] = None,
+        onnx_simplify: bool = True,
+        device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Export the model to one of the supported formats. The format is inferred from the output file extension or can be
+        explicitly specified via the `format` argument.
+
+        :param output: Output file name of the exported model.
+        :param quantization_mode: (QuantizationMode) Sets the quantization mode for the exported model.
+            If None, the model is exported as-is without any changes to model weights.
+            If QuantizationMode.FP16, the model is exported with weights converted to half precision.
+            If QuantizationMode.INT8, the model is exported with weights quantized to INT8 (using PTQ).
+            For this mode you can use calibration_loader to specify a data loader for calibrating the model.
+        :param selective_quantizer: (SelectiveQuantizer) An optional quantizer for selectively quantizing model weights.
+        :param calibration_loader: (torch.utils.data.DataLoader) An optional data loader for calibrating a quantized model.
+        :param calibration_method: (str) Calibration method for quantized models. See QuantizationCalibrator for details.
+        :param calibration_batches: (int) Number of batches to use for calibration. Default is 16.
+        :param calibration_percentile: (float) Percentile for percentile calibration method. Default is 99.99.
+        :param batch_size: (int) Batch size for the exported model.
+        :param input_image_shape: (tuple) Input image shape (height, width) for the exported model.
+            If None, the function will infer the image shape from the model's preprocessing params.
+        :param input_image_channels: (int) Number of input image channels for the exported model.
+            If None, the function will infer the number of channels from the model itself
+            (not implemented yet; a hard-coded value of 3 is used for now).
+        :param input_image_dtype: (torch.dtype) Type of the input image for the exported model.
+            If None, the function will infer the dtype from the model's preprocessing and other parameters.
+            If preprocessing is True, dtype will default to torch.uint8.
+            If preprocessing is False and the requested quantization mode is FP16, torch.float16 will be used;
+            otherwise, the default torch.float32 dtype will be used.
+        :param device: (torch.device) Device to use for exporting the model. If not specified, the device is inferred from the model itself.
+        :param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
+        :param onnx_simplify: (bool) If True, apply onnx-simplifier to the exported model.
+        :return:
+        """
+
+        # Do imports here to avoid raising an error about a missing onnx_graphsurgeon package when it is not needed.
+        import_onnx_graphsurgeon_or_install()
+        if ExportQuantizationMode.INT8 == quantization_mode:
+            import_pytorch_quantization_or_install()
+        from super_gradients.conversion.conversion_utils import torch_dtype_to_numpy_dtype
+
+        usage_instructions = []
+
+        # Hard-coded for now
+        # Will be made a parameter if we decide to support CoreML/OpenVino/TRT export in the future
+        engine = ExportTargetBackend.ONNXRUNTIME
+
+        if not isinstance(self, nn.Module):
+            raise TypeError(f"Export is only supported for torch.nn.Module. Got type {type(self)}")
+
+        device: torch.device = device or infer_model_device(self)
+        if device is None:
+            raise ValueError(
+                "Device is not specified and cannot be inferred from the model. "
+                "Please specify the device explicitly: model.export(..., device=torch.device(...))"
+            )
+
+        # The following is a trick to infer the exact device index in order to make sure the model is using the right device.
+        # The user may pass device="cuda", which does not explicitly specify a device index.
+        # Using this trick, we can infer the correct device (cuda:3 for instance) and use it later for checking
+        # whether the model places all its parameters on the right device.
+        device = torch.zeros(1).to(device).device
+
+        logger.debug(f"Using device: {device} for exporting model {self.__class__.__name__}")
+
+        model: nn.Module = copy.deepcopy(self).eval()
+
+        # Infer the input image shape from the model
+        if input_image_shape is None:
+            input_image_shape = infer_image_shape_from_model(model)
+            logger.debug(f"Inferred input image shape: {input_image_shape} from model {model.__class__.__name__}")
+
+        if input_image_shape is None:
+            raise ValueError(
+                "Image shape is not specified and cannot be inferred from the model. "
+                "Please specify the image shape explicitly: model.export(..., input_image_shape=(height, width))"
+            )
+
+        try:
+            rows, cols = input_image_shape
+        except ValueError:
+            raise ValueError(f"Image shape must be a tuple of two integers (height, width), got {input_image_shape} instead")
+
+        # Infer the number of input channels from the model
+        if input_image_channels is None:
+            input_image_channels = infer_image_input_channels(model)
+            logger.debug(f"Inferred input image channels: {input_image_channels} from model {model.__class__.__name__}")
+
+        if input_image_channels is None:
+            raise ValueError(
+                "Number of input channels is not specified and cannot be inferred from the model. "
+                "Please specify the number of input channels explicitly: model.export(..., input_image_channels=NUM_CHANNELS_YOUR_MODEL_TAKES)"
+            )
+
+        input_shape = (batch_size, 2, input_image_channels, rows, cols)
+
+        if isinstance(model, SupportsInputShapeCheck):
+            model.validate_input_shape(input_shape)
+
+        prep_model_for_conversion_kwargs = {
+            "input_size": input_shape,
+        }
+
+        model_type = torch.half if quantization_mode == ExportQuantizationMode.FP16 else torch.float32
+        device = find_compatible_model_device_for_dtype(device, model_type)
+
+        # This variable holds the output names of the model.
+        # If postprocessing is enabled, it will be set to the output names of the postprocessing module.
+        output_names: Optional[List[str]] = None
+
+        if hasattr(model, "prep_model_for_conversion"):
+            model.prep_model_for_conversion(**prep_model_for_conversion_kwargs)
+
+        contains_quantized_modules = check_model_contains_quantized_modules(model)
+
+        if quantization_mode == ExportQuantizationMode.INT8:
+            from super_gradients.training.utils.quantization import ptq
+
+            model = ptq(
+                model,
+                selective_quantizer=selective_quantizer,
+                calibration_loader=calibration_loader,
+                calibration_method=calibration_method,
+                calibration_batches=calibration_batches,
+                calibration_percentile=calibration_percentile,
+            )
+        elif quantization_mode == ExportQuantizationMode.FP16:
+            if contains_quantized_modules:
+                raise RuntimeError("Model contains quantized modules for INT8 mode. FP16 quantization is not supported for such models.")
+        elif quantization_mode is None and contains_quantized_modules:
+            # If quantization_mode is None, but we have quantized modules in the model, we need to
+            # update the quantization_mode to INT8, so that we can correctly export the model.
+            quantization_mode = ExportQuantizationMode.INT8
+
+        from super_gradients.training.models.conversion import ConvertableCompletePipelineModel
+
+        # The model.prep_model_for_conversion will be called inside ConvertableCompletePipelineModel once more,
+        # but as long as the implementation of prep_model_for_conversion is idempotent, it should be fine.
+        complete_model = (
+            ConvertableCompletePipelineModel(model=model, pre_process=None, post_process=None, **prep_model_for_conversion_kwargs).to(device).eval()
+        )
+
+        if quantization_mode == ExportQuantizationMode.FP16:
+            # For FP16 quantization, we can simply convert the whole model to half precision
+            complete_model = complete_model.half()
+
+            if calibration_loader is not None:
+                logger.warning(
+                    "It seems you've passed calibration_loader to export function, but quantization_mode is set to FP16. "
+                    "FP16 quantization is done by calling model.half() so you don't need to pass calibration_loader, as it will be ignored."
+ ) + + if engine in {ExportTargetBackend.ONNXRUNTIME, ExportTargetBackend.TENSORRT}: + from super_gradients.conversion.onnx.export_to_onnx import export_to_onnx + + onnx_export_kwargs = onnx_export_kwargs or {} + onnx_input = torch.randn(input_shape).to(device=device, dtype=input_image_dtype) + + export_to_onnx( + model=complete_model, + model_input=onnx_input, + onnx_filename=output, + input_names=["input"], + output_names=output_names, + onnx_opset=onnx_export_kwargs.get("opset_version", None), + do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), + dynamic_axes=onnx_export_kwargs.get("dynamic_axes", None), + keep_initializers_as_inputs=onnx_export_kwargs.get("keep_initializers_as_inputs", False), + verbose=onnx_export_kwargs.get("verbose", False), + ) + + if onnx_simplify: + model_opt, simplify_successful = onnxsim.simplify(output) + if not simplify_successful: + raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.") + onnx.save(model_opt, output) + + logger.debug(f"Ran onnxsim.simplify on {output}") + else: + raise ValueError(f"Unsupported export format: {engine}. Supported formats: onnxruntime, tensorrt") + + # Cleanup memory, not sure whether it is necessary but just in case + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Add usage instructions + usage_instructions.append(f"Model exported successfully to {output}") + usage_instructions.append( + f"Model expects input image of shape [{batch_size}, {2}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}]" + ) + usage_instructions.append(f"Input image dtype is {input_image_dtype}") + + usage_instructions.append("Exported model is in ONNX format and can be used with ONNXRuntime") + usage_instructions.append("To run inference with ONNXRuntime, please use the following code snippet:") + usage_instructions.append("") + usage_instructions.append(" import onnxruntime") + usage_instructions.append(" import numpy as np") + + usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])') + usage_instructions.append(" inputs = [o.name for o in session.get_inputs()]") + usage_instructions.append(" outputs = [o.name for o in session.get_outputs()]") + + dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name + usage_instructions.append( + f" example_input_batch = np.zeros(({batch_size}, {2}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})" # noqa + ) + + usage_instructions.append(" flow_prediction = session.run(outputs, {inputs[0]: example_input_batch})") + usage_instructions.append("") + + return OpticalFlowModelExportResult( + input_image_channels=input_image_channels, + input_image_dtype=input_image_dtype, + input_image_shape=input_image_shape, + engine=engine, + quantization_mode=quantization_mode, + output=output, + usage_instructions="\n".join(usage_instructions), + ) diff --git a/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml b/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml new file mode 100644 index 0000000000..fff35ba0f9 --- /dev/null +++ b/src/super_gradients/recipes/arch_params/raft_l_arch_params.yaml @@ -0,0 +1,41 @@ +in_channels: 3 # the number of in_channels to fnet and cnet. 1 for greyscale image. 
+ +encoder_params: + in_planes: 64 + hidden_dim: 128 + context_dim: 128 + corr_levels: 4 + corr_radius: 4 + dropout: 0 + fnet: # Feature encoder + output_dim: 256 + norm_fn: 'instance' # 'instance' - for instance normalization + cnet: # Context encoder + norm_fn: 'batch' # 'batch' - for batch normalization + output_dim: 256 # context_dim + hidden_dim + update_block: + hidden_dim: ${..hidden_dim} + use_mask: True + motion_encoder: + num_corr_conv: 2 # limited to max 2 conv layers + convc1_output_dim: 256 + convc2_output_dim: 192 + convf1_output_dim: 128 + convf2_output_dim: 64 + conv_output_dim: 126 + gru: + block: SepConvGRU + hidden_dim: ${...hidden_dim} + input_dim: 256 + flow_head: + hidden_dim: 256 + input_dim: ${arch_params.encoder_params.update_block.gru.hidden_dim} + +corr_params: + alternate_corr: False + +flow_params: + iters: 12 # the number of iterations the optimization loop will run to refine the optical flow predictions over multiple iterations + upsample_mode: convex # if none, then using a predefined upsample function (upflow8) OR convex with basic model only + +#freeze_bn: True # True for FlyingChairs diff --git a/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml b/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml new file mode 100644 index 0000000000..f6968ec59d --- /dev/null +++ b/src/super_gradients/recipes/arch_params/raft_s_arch_params.yaml @@ -0,0 +1,41 @@ +in_channels: 3 # the number of in_channels to fnet and cnet. 1 for greyscale image. + +encoder_params: + in_planes: 32 + hidden_dim: 96 + context_dim: 64 + corr_levels: 4 + corr_radius: 3 + dropout: 0 + fnet: # Feature encoder + output_dim: 128 + norm_fn: 'instance' # 'instance' - for instance normalization + cnet: # Context encoder + norm_fn: 'none' # 'batch' - for batch normalization + output_dim: 160 # context_dim + hidden_dim + update_block: + hidden_dim: ${..hidden_dim} + use_mask: False + motion_encoder: + num_corr_conv: 1 # limited to max 2 conv layers + convc1_output_dim: 96 + convc2_output_dim: # convc2 layer is set only for raft_l + convf1_output_dim: 64 + convf2_output_dim: 32 + conv_output_dim: 80 + gru: + block: ConvGRU + hidden_dim: ${...hidden_dim} + input_dim: 146 + flow_head: + hidden_dim: 128 + input_dim: ${arch_params.encoder_params.update_block.gru.hidden_dim} + +corr_params: + alternate_corr: False + +flow_params: + iters: 12 # the number of iterations the optimization loop will run to refine the optical flow predictions over multiple iterations + upsample_mode: # if none, then using a predefined upsample function (upflow8) OR convex with basic model only + +#freeze_bn: True # True for FlyingChairs diff --git a/src/super_gradients/training/models/__init__.py b/src/super_gradients/training/models/__init__.py index abc6bcb349..2855b2f06f 100755 --- a/src/super_gradients/training/models/__init__.py +++ b/src/super_gradients/training/models/__init__.py @@ -125,6 +125,7 @@ from super_gradients.training.models.arch_params_factory import get_arch_params from super_gradients.training.models.conversion import convert_to_coreml, convert_to_onnx, convert_from_config +from super_gradients.training.models.optical_flow_models.raft.raft_variants import RAFT_S, RAFT_L from super_gradients.common.object_names import Models from super_gradients.common.registry.registry import ARCHITECTURES @@ -295,4 +296,6 @@ "SegFormerB5", "DDRNet39Backbone", "BasicResNetBlock", + "RAFT_S", + "RAFT_L", ] diff --git a/src/super_gradients/training/models/optical_flow_models/__init__.py 
b/src/super_gradients/training/models/optical_flow_models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/super_gradients/training/models/optical_flow_models/raft/__init__.py b/src/super_gradients/training/models/optical_flow_models/raft/__init__.py new file mode 100644 index 0000000000..a15c35a510 --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/__init__.py @@ -0,0 +1,31 @@ +from super_gradients.training.models.optical_flow_models.raft.raft_base import ( + BottleneckBlock, + Encoder, + ContextEncoder, + FlowHead, + SepConvGRU, + ConvGRU, + MotionEncoder, + UpdateBlock, + CorrBlock, + AlternateCorrBlock, + FlowIterativeBlock, +) + +from super_gradients.training.models.optical_flow_models.raft.raft_variants import RAFT_S, RAFT_L + +__all__ = [ + "BottleneckBlock", + "Encoder", + "ContextEncoder", + "FlowHead", + "SepConvGRU", + "ConvGRU", + "MotionEncoder", + "UpdateBlock", + "CorrBlock", + "AlternateCorrBlock", + "FlowIterativeBlock", + "RAFT_S", + "RAFT_L", +] diff --git a/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py b/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py new file mode 100644 index 0000000000..366fa79482 --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/raft_base.py @@ -0,0 +1,500 @@ +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from super_gradients.module_interfaces import SupportsReplaceInputChannels + + +__all__ = [ + "BottleneckBlock", + "Encoder", + "ContextEncoder", + "FlowHead", + "SepConvGRU", + "ConvGRU", + "MotionEncoder", + "UpdateBlock", + "CorrBlock", + "AlternateCorrBlock", + "FlowIterativeBlock", +] + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes: int, planes: int, norm_fn: str = "group", stride: int = 1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = 
self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Encoder(nn.Module, SupportsReplaceInputChannels): + def __init__( + self, + in_channels: int, + in_planes: int, + output_dim: int = 128, + norm_fn: str = "batch", + dropout: float = 0.0, + ): + super(Encoder, self).__init__() + self.norm_fn = norm_fn + self.in_planes = in_planes + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + dim = int(self.in_planes / 32) + self.layer1 = self._make_layer(dim * 32, stride=1) + self.layer2 = self._make_layer((dim + 1) * 32, stride=2) + self.layer3 = self._make_layer((dim + 2) * 32, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d((dim + 2) * 32, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim: int, stride: int = 1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None): + from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels + + self.conv1 = replace_conv2d_input_channels(conv=self.conv1, in_channels=in_channels, fn=compute_new_weights_fn) + + def get_input_channels(self) -> int: + return self.conv1.in_channels + + +class ContextEncoder(nn.Module): + def __init__( + self, + in_channels: int, + in_planes: int, + hidden_dim: int, + context_dim: int, + output_dim: int = 128, + norm_fn: str = "batch", + dropout: float = 0.0, + ): + super(ContextEncoder, self).__init__() + + self.cnet = Encoder(in_channels=in_channels, in_planes=in_planes, output_dim=output_dim, norm_fn=norm_fn, dropout=dropout) + + self.hidden_dim = hidden_dim + self.context_dim = context_dim + + def forward(self, x): + out = self.cnet(x) + net, inp = torch.split(out, [self.hidden_dim, self.context_dim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + return net, inp + + +class FlowHead(nn.Module): + def __init__(self, input_dim: int = 128, hidden_dim: int = 256): + super(FlowHead, self).__init__() 
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim: int = 128, input_dim: int = 192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim: int = 128, input_dim: int = 192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class MotionEncoder(nn.Module): + def __init__( + self, + corr_levels: int, + corr_radius: int, + num_corr_conv: int, + convc1_output_dim: int, + convc2_output_dim: int, + convf1_output_dim: int, + convf2_output_dim: int, + conv_output_dim: int, + ): + super(MotionEncoder, self).__init__() + self.num_corr_conv = num_corr_conv + + cor_planes = corr_levels * (2 * corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, convc1_output_dim, 1, padding=0) + if self.num_corr_conv == 2: + self.convc2 = nn.Conv2d(convc1_output_dim, convc2_output_dim, 3, padding=1) + conv_input_dim = convf2_output_dim + convc2_output_dim + else: + conv_input_dim = convf2_output_dim + convc1_output_dim + + self.convf1 = nn.Conv2d(2, convf1_output_dim, 7, padding=3) + self.convf2 = nn.Conv2d(convf1_output_dim, convf2_output_dim, 3, padding=1) + self.conv = nn.Conv2d(conv_input_dim, conv_output_dim, 3, padding=1) + + def forward(self, flow: Tensor, corr: Tensor): + cor = F.relu(self.convc1(corr)) + if self.num_corr_conv == 2: + cor = F.relu(self.convc2(cor)) + + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class UpdateBlock(nn.Module): + def __init__(self, encoder_params, hidden_dim: int = 128): + super(UpdateBlock, self).__init__() + self.use_mask = encoder_params.update_block.use_mask + + self.encoder = MotionEncoder( + encoder_params.corr_levels, + encoder_params.corr_radius, + encoder_params.update_block.motion_encoder.num_corr_conv, + 
encoder_params.update_block.motion_encoder.convc1_output_dim, + encoder_params.update_block.motion_encoder.convc2_output_dim, + encoder_params.update_block.motion_encoder.convf1_output_dim, + encoder_params.update_block.motion_encoder.convf2_output_dim, + encoder_params.update_block.motion_encoder.conv_output_dim, + ) + + if encoder_params.update_block.gru.block == "ConvGRU": + self.gru = ConvGRU(hidden_dim=encoder_params.update_block.gru.hidden_dim, input_dim=encoder_params.update_block.gru.input_dim) + elif encoder_params.update_block.gru.block == "SepConvGRU": + self.gru = SepConvGRU(hidden_dim=encoder_params.update_block.gru.hidden_dim, input_dim=encoder_params.update_block.gru.input_dim) + + self.flow_head = FlowHead(hidden_dim, hidden_dim=encoder_params.update_block.flow_head.hidden_dim) + + if self.use_mask: + self.mask = nn.Sequential(nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + if self.use_mask: + # scale mask to balance gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + else: + return net, None, delta_flow + + +class CorrBlock: + def __init__(self, num_levels: int = 4, radius: int = 4): + self.num_levels = num_levels + self.radius = radius + + def __call__(self, coords, fmap1, fmap2): + corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + corr_pyramid.append(corr) + + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_pyramid.append(corr) + + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = self.bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + @staticmethod + def bilinear_sampler(img, coords, mask: bool = False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +class AlternateCorrBlock: + def __init__(self, num_levels: int = 4, radius: int = 4): + self.num_levels = num_levels + self.radius = radius + + def __call__( + self, + coords, + fmap1, + fmap2, + ): + 
import alt_cuda_corr + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class FlowIterativeBlock(nn.Module): + def __init__(self, encoder_params, hidden_dim, flow_params, alternate_corr): + super(FlowIterativeBlock, self).__init__() + self.update_block = UpdateBlock(encoder_params, hidden_dim) + self.upsample_mode = flow_params.upsample_mode + self.iters = flow_params.iters + + if alternate_corr: + self.corr_fn = AlternateCorrBlock(radius=encoder_params.corr_radius) + else: + self.corr_fn = CorrBlock(radius=encoder_params.corr_radius) + + @staticmethod + def upsample_flow(flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + @staticmethod + def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + def forward(self, coords0, coords1, net, inp, fmap1, fmap2): + + flow_predictions = [] + + for itr in range(self.iters): + coords1 = coords1.detach() + corr = self.corr_fn(coords1, fmap1, fmap2) # index correlation volume + + flow = coords1 - coords0 + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + # update the coordinates based on the flow change + coords1 = coords1 + delta_flow + + # upsample flow predictions + if self.upsample_mode is None: + flow_up = self.upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + return flow_predictions, flow_up diff --git a/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py b/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py new file mode 100644 index 0000000000..a251466f5d --- /dev/null +++ b/src/super_gradients/training/models/optical_flow_models/raft/raft_variants.py @@ -0,0 +1,140 @@ +import copy +from typing import Union, Optional, Callable + +import torch +import torch.nn as nn +from omegaconf import DictConfig + +from super_gradients.common.object_names import Models + +from super_gradients.common.registry import register_model +from super_gradients.module_interfaces import SupportsReplaceInputChannels +from super_gradients.module_interfaces.exportable_optical_flow import ExportableOpticalFlowModel +from super_gradients.training.models import get_arch_params, SgModule +from super_gradients.training.utils.utils import HpmStruct, get_param + +from 
.raft_base import Encoder, ContextEncoder, FlowIterativeBlock + + +class RAFT(ExportableOpticalFlowModel, SgModule): + def __init__(self, in_channels, encoder_params, corr_params, flow_params, num_classes): + super().__init__() + + self.in_channels = in_channels + + self.feature_encoder = Encoder( + in_channels=self.in_channels, + in_planes=encoder_params.in_planes, + output_dim=encoder_params.fnet.output_dim, + norm_fn=encoder_params.fnet.norm_fn, + dropout=encoder_params.dropout, + ) + + self.context_encoder = ContextEncoder( + in_channels=self.in_channels, + in_planes=encoder_params.in_planes, + hidden_dim=encoder_params.hidden_dim, + context_dim=encoder_params.context_dim, + output_dim=encoder_params.cnet.output_dim, + norm_fn=encoder_params.cnet.norm_fn, + dropout=encoder_params.dropout, + ) + + self.flow_iterative_block = FlowIterativeBlock(encoder_params, encoder_params.update_block.hidden_dim, flow_params, corr_params.alternate_corr) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + @staticmethod + def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = self.coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = self.coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def forward(self, x, **kwargs): + """Estimate optical flow between pairs of frames""" + + image1 = x[:, 0] + image2 = x[:, 1] + + # run the feature network + fmap1, fmap2 = self.feature_encoder([image1, image2]) + + # run the context network + net, inp = self.context_encoder(image1) + + # initialize flow + coords0, coords1 = self.initialize_flow(image1) + + # run update block network + flow_predictions, flow_up = self.flow_iterative_block(coords0, coords1, net, inp, fmap1, fmap2) + + if not self.training: + return flow_up # removed 1st coords1 - coords0, + + return flow_predictions + + def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None): + if isinstance(self.feature_encoder, SupportsReplaceInputChannels) and isinstance(self.context_encoder, SupportsReplaceInputChannels): + self.feature_encoder.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn) + self.context_encoder.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn) + + self.in_channels = self.get_input_channels() + else: + raise NotImplementedError( + f"`{self.feature_encoder.__class__.__name__}` and `{self.context_encoder.__class__.__name__}` do not support `replace_input_channels`" + ) + + def get_input_channels(self) -> int: + if isinstance(self.feature_encoder, SupportsReplaceInputChannels) and isinstance(self.context_encoder, SupportsReplaceInputChannels): + return self.feature_encoder.get_input_channels() + else: + raise NotImplementedError( + f"`{self.feature_encoder.__class__.__name__}` and `{self.context_encoder.__class__.__name__}` do not support `replace_input_channels`" + ) + + def prep_model_for_conversion(self, input_size: Optional[Union[tuple, list]] = None, **kwargs): + for module in self.modules(): 
+ if module != self and hasattr(module, "prep_model_for_conversion"): + module.prep_model_for_conversion(input_size, **kwargs) + + +@register_model(Models.RAFT_S) +class RAFT_S(RAFT): + def __init__(self, arch_params: Union[HpmStruct, DictConfig]): + default_arch_params = get_arch_params("raft_s_arch_params") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + in_channels=merged_arch_params.in_channels, + encoder_params=merged_arch_params.encoder_params, + corr_params=merged_arch_params.corr_params, + flow_params=merged_arch_params.flow_params, + num_classes=get_param(merged_arch_params, "num_classes", None), + ) + + +@register_model(Models.RAFT_L) +class RAFT_L(RAFT): + def __init__(self, arch_params: Union[HpmStruct, DictConfig]): + default_arch_params = get_arch_params("raft_l_arch_params") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + in_channels=merged_arch_params.in_channels, + encoder_params=merged_arch_params.encoder_params, + corr_params=merged_arch_params.corr_params, + flow_params=merged_arch_params.flow_params, + num_classes=get_param(merged_arch_params, "num_classes", None), + ) diff --git a/tests/unit_tests/export_optical_flow_model_test.py b/tests/unit_tests/export_optical_flow_model_test.py new file mode 100644 index 0000000000..41825fec30 --- /dev/null +++ b/tests/unit_tests/export_optical_flow_model_test.py @@ -0,0 +1,139 @@ +import logging +import os +import tempfile +import unittest + +import numpy as np +import onnxruntime +import torch +from super_gradients.common.object_names import Models +from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install +from super_gradients.import_utils import import_pytorch_quantization_or_install +from super_gradients.module_interfaces import ExportableOpticalFlowModel, OpticalFlowModelExportResult +from super_gradients.training import models + + +gs = import_onnx_graphsurgeon_or_install() +import_pytorch_quantization_or_install() + + +class TestOpticalFlowModelExport(unittest.TestCase): + def setUp(self) -> None: + logging.getLogger().setLevel(logging.DEBUG) + + self.models_to_test = [ + Models.RAFT_S, + Models.RAFT_L, + ] + + # def test_infer_input_image_shape_from_model(self): + # assert infer_image_shape_from_model(models.get(Models.RAFT_S, num_classes=1)) is None + # assert infer_image_shape_from_model(models.get(Models.RAFT_L, num_classes=1)) is None + + # def test_infer_input_image_num_channels_from_model(self): + # assert infer_image_input_channels(models.get(Models.RAFT_S, num_classes=1)) == 3 + # assert infer_image_input_channels(models.get(Models.RAFT_L, num_classes=1)) == 3 + + def test_export_to_onnxruntime_and_run(self): + """ + Test export to ONNX + """ + + with tempfile.TemporaryDirectory() as tmpdirname: + for model_type in self.models_to_test: + with self.subTest(model_type=model_type): + model_name = str(model_type).lower().replace(".", "_") + out_path = os.path.join(tmpdirname, f"{model_name}_onnxruntime.onnx") + + model_arch: ExportableOpticalFlowModel = models.get(model_name, num_classes=1) + export_result = model_arch.export( + out_path, + input_image_shape=(640, 640), # Force .export() to infer image shape from the model itself + input_image_channels=3, + input_image_dtype=torch.float32, + ) + + [flow_prediction] = self._run_inference_with_onnx(export_result) + 
self.assertTrue(flow_prediction.shape[0] == 1) + self.assertTrue(flow_prediction.shape[1] == 2) + self.assertTrue(flow_prediction.shape[2] == 640) + self.assertTrue(flow_prediction.shape[3] == 640) + + # def test_export_int8_quantized_with_calibration(self): + # with tempfile.TemporaryDirectory() as tmpdirname: + # for model_type in self.models_to_test: + # with self.subTest(model_type=model_type): + # model_name = str(model_type).lower().replace(".", "_") + # out_path = os.path.join(tmpdirname, f"{model_name}.onnx") + # + # dummy_calibration_dataset = [torch.randn((3, 640, 640), dtype=torch.float32) for _ in range(32)] + # dummy_calibration_loader = DataLoader(dummy_calibration_dataset, batch_size=8) + # + # model_arch: ExportableOpticalFlowModel = models.get(model_name, num_classes=1) + # export_result = model_arch.export( + # out_path, + # input_image_shape=(640, 640), # Force .export() to infer image shape from the model itself + # quantization_mode=ExportQuantizationMode.INT8, + # calibration_loader=dummy_calibration_loader, + # ) + # + # [flow_prediction] = self._run_inference_with_onnx(export_result) + # self.assertTrue(flow_prediction.shape[0] == 1) + # self.assertTrue(flow_prediction.shape[1] == 2) + # self.assertTrue(flow_prediction.shape[2] == 640) + # self.assertTrue(flow_prediction.shape[3] == 640) + + def _run_inference_with_onnx(self, export_result: OpticalFlowModelExportResult): + # onnx_filename = out_path, input_shape = export_result.image_shape, output_predictions_format = output_predictions_format + + input = np.zeros((1, 2, 3, 640, 640)).astype(np.float32) + + session = onnxruntime.InferenceSession(export_result.output) + inputs = [o.name for o in session.get_inputs()] + outputs = [o.name for o in session.get_outputs()] + result = session.run(outputs, {inputs[0]: input}) + + return result + + # def test_export_already_quantized_model(self): + # from super_gradients.training.utils.quantization import SelectiveQuantizer + # + # for model_type in self.models_to_test: + # with self.subTest(model_type=model_type): + # model = models.get(model_type, num_classes=1) + # q_util = SelectiveQuantizer( + # default_quant_modules_calibrator_weights="max", + # default_quant_modules_calibrator_inputs="histogram", + # default_per_channel_quant_weights=True, + # default_learn_amax=False, + # verbose=True, + # ) + # q_util.quantize_module(model) + # + # with tempfile.TemporaryDirectory() as tmpdirname: + # output_model1 = os.path.join(tmpdirname, f"{model_type}_quantized_explicit_int8.onnx") + # output_model2 = os.path.join(tmpdirname, f"{model_type}_quantized.onnx") + # + # # If model is already quantized to int8, the export should be successful but model should not be quantized again + # model.export( + # output_model1, + # quantization_mode=ExportQuantizationMode.INT8, + # ) + # + # # If model is quantized but quantization mode is not specified, the export should be also successful + # # but model should not be quantized again + # model.export( + # output_model2, + # quantization_mode=None, + # ) + # + # # If model is already quantized to int8, we should not be able to export model to FP16 + # with self.assertRaises(RuntimeError): + # model.export( + # "raft_s_quantized.onnx", + # quantization_mode=ExportQuantizationMode.FP16, + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/raft_tests.py b/tests/unit_tests/raft_tests.py new file mode 100644 index 0000000000..014f8d2ddd --- /dev/null +++ b/tests/unit_tests/raft_tests.py @@ -0,0 +1,38 @@ +import 
unittest
+
+import torch
+
+from super_gradients.common.object_names import Models
+from super_gradients.training import models
+
+
+class TestRAFT(unittest.TestCase):
+    def setUp(self):
+        self.models_to_test = [
+            Models.RAFT_S,
+            Models.RAFT_L,
+        ]
+
+    def test_raft_custom_in_channels(self):
+        """
+        Validate that we can create a RAFT model with custom in_channels and run a forward pass on a pair of single-channel frames.
+        """
+        for model_type in self.models_to_test:
+            with self.subTest(model_type=model_type):
+                model_name = str(model_type).lower().replace(".", "_")
+                model = models.get(model_name, arch_params=dict(in_channels=1), num_classes=1).eval()
+                model(torch.rand(1, 2, 1, 640, 640))
+
+    def test_raft_forward(self):
+        """
+        Validate that a RAFT model with default arch params can run a forward pass on a pair of RGB frames.
+        """
+        for model_type in self.models_to_test:
+            with self.subTest(model_type=model_type):
+                model_name = str(model_type).lower().replace(".", "_")
+                model = models.get(model_name, num_classes=1).eval()
+                model(torch.rand(1, 2, 3, 640, 640))
+
+
+if __name__ == "__main__":
+    unittest.main()
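
Reviewer note (not part of the patch): below is a minimal end-to-end usage sketch of the export path added here, assembled from the unit tests and the usage_instructions string emitted by export(). The output path "raft_s.onnx" is an arbitrary example, and without trained weights the predicted flow values are meaningless; the sketch only illustrates the expected input/output layout ([batch, 2, channels, H, W] in, [batch, 2, H, W] out).

import numpy as np
import onnxruntime
import torch

from super_gradients.common.object_names import Models
from super_gradients.training import models

# Build RAFT-S and export it to ONNX with the same parameters the unit test uses.
model = models.get(Models.RAFT_S, num_classes=1)
export_result = model.export(
    "raft_s.onnx",  # arbitrary example output path
    input_image_shape=(640, 640),
    input_image_channels=3,
    input_image_dtype=torch.float32,
)

# Run the exported model with ONNXRuntime. The input packs two frames as
# [batch, 2, channels, height, width]; the output is the upsampled flow field.
session = onnxruntime.InferenceSession(export_result.output, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]

frame_pair = np.zeros((1, 2, 3, 640, 640), dtype=np.float32)
(flow,) = session.run(output_names, {input_name: frame_pair})
print(flow.shape)  # expected: (1, 2, 640, 640)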