RuntimeError: expected scalar type Half but found Float: issue when I tried to run DCN in PyTorch without C++ #245

Open
AndywithCV opened this issue Oct 19, 2023 · 2 comments

Comments

@AndywithCV

Since I cannot compile the C++ extension in my environment because of package incompatibilities, I tried to use the pure-PyTorch version of DCN instead, but I encountered the error below:

  File "/root/data/andy/newyolov8/ultralytics/ultralytics/nn/modules/block.py", line 708, in dcnv3_core_pytorch
    sampling_input_ = F.grid_sample(
  File "/root/miniconda3/envs/newyolov8benchmark/lib/python3.8/site-packages/torch/nn/functional.py", line 4244, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
RuntimeError: expected scalar type Half but found Float

Please help me take a look!
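The traceback ends in `F.grid_sample`, which needs `input` and `grid` to share a dtype. The minimal sketch below reproduces that kind of mismatch with stand-in tensors (not the actual DCNv3 tensors). It assumes the real failure comes from a half-precision feature map meeting a float32 sampling grid. To run on a CPU, it uses float32 vs float64 instead of Half vs Float, since half-precision `grid_sample` may not be available on the CPU in every PyTorch version. The exact error text also differs between versions.

```python
import torch
import torch.nn.functional as F

# Stand-in tensors: an (N, C, H, W) feature map and an (N, H_out, W_out, 2)
# grid of normalized sampling coordinates in [-1, 1], deliberately in another dtype.
x = torch.randn(1, 3, 8, 8, dtype=torch.float32)
grid = torch.rand(1, 4, 4, 2, dtype=torch.float64) * 2 - 1

try:
    F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    mismatch_raised = False
except RuntimeError as err:
    # Wording varies by version, e.g. "expected scalar type ..." or
    # "expected input and grid to have same dtype ...".
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the grid to the input's dtype (and device) removes the mismatch.
out = F.grid_sample(x, grid.to(dtype=x.dtype, device=x.device),
                    mode="bilinear", padding_mode="zeros", align_corners=False)
print(out.shape)  # torch.Size([1, 3, 4, 4])
```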

@leij0318

leij0318 commented Oct 20, 2023

Change the functions in dcnv3_func.py, which is in the functions folder of the ops--_dcnv3 folder.

Replace dtype=torch.float32 in _get_reference_points and _generate_dilation_grids with dtype=torch.float16. The reference points (ref) and sampling grids (grid) used in the DCNv3 calculations are then 16-bit instead of 32-bit floating point. This matches the dtype of the half-precision input on the GPU, so F.grid_sample no longer receives a Half input together with a Float grid.
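One trade-off to consider before hard-coding `torch.float16`: float16 carries only about three significant decimal digits, so normalized coordinates such as `(i + 0.5) / H_` get rounded. The sketch below is an illustration with an arbitrarily chosen `H_ = 1000`. It builds such points in float32 and measures how much a cast to float16 changes them. An alternative to editing every `dtype=` argument is to build the points in float32 and cast them once to the input's dtype.

```python
import torch

H_ = 1000  # arbitrary feature-map height chosen for illustration
# Pixel-center coordinates normalized by the height, built in float32.
ref32 = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32) / H_
# The same values rounded to half precision, then compared back in float32.
ref16 = ref32.to(torch.float16)
max_err = (ref32 - ref16.float()).abs().max().item()
print(f"max abs rounding error: {max_err:.2e}")
```

For values in [0, 1), the rounding error stays below about 2.5e-4. Whether that matters depends on the feature-map size and the model.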

Modify the code as follows:
##################################################################################

def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0,
                          stride_h=1, stride_w=1):
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
    ref_y, ref_x = torch.meshgrid(
        torch.linspace(
            # pad_h + 0.5,
            # H_ - pad_h - 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
            H_out,
            dtype=torch.float16,  # torch.float32
            device=device),
        torch.linspace(
            # pad_w + 0.5,
            # W_ - pad_w - 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
            W_out,
            dtype=torch.float32,
            device=device))
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_

    ref = torch.stack((ref_x, ref_y), -1).reshape(
        1, H_out, W_out, 1, 2)

    return ref


def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
    _, H_, W_, _ = spatial_shapes
    points_list = []
    x, y = torch.meshgrid(
        torch.linspace(
            -((dilation_w * (kernel_w - 1)) // 2),
            -((dilation_w * (kernel_w - 1)) // 2) +
            (kernel_w - 1) * dilation_w, kernel_w,
            dtype=torch.float16,  # torch.float32
            device=device),
        torch.linspace(
            -((dilation_h * (kernel_h - 1)) // 2),
            -((dilation_h * (kernel_h - 1)) // 2) +
            (kernel_h - 1) * dilation_h, kernel_h,
            dtype=torch.float16,  # torch.float32
            device=device))

    points_list.extend([x / W_, y / H_])
    grid = torch.stack(points_list, -1).reshape(-1, 1, 2). \
        repeat(1, group, 1).permute(1, 0, 2)
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)

    return grid

###################################################################
I can't guarantee that this modification is correct, but you can try it!
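Instead of hard-coding `torch.float16`, a variant could take the dtype as an argument, so the same code works in both full and half precision. This is a minimal sketch, not the repository's code. The name `get_reference_points` and its `dtype` parameter are hypothetical. The points are built in float32 (so `torch.meshgrid` always sees matching dtypes) and cast once at the end. `indexing="ij"` (available from PyTorch 1.10) reproduces the default `torch.meshgrid` behaviour of the code above. The same pattern could be applied to `_generate_dilation_grids`.

```python
import torch


def get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w,
                         pad_h=0, pad_w=0, stride_h=1, stride_w=1, dtype=torch.float32):
    """Hypothetical variant of _get_reference_points with an explicit output dtype."""
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
    # pad_h / pad_w are unused here, as in the posted code (their lines are commented out).
    # Both axes are built in float32, so meshgrid never sees mixed dtypes.
    ys = torch.linspace(
        (dilation_h * (kernel_h - 1)) // 2 + 0.5,
        (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
        H_out, dtype=torch.float32, device=device)
    xs = torch.linspace(
        (dilation_w * (kernel_w - 1)) // 2 + 0.5,
        (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
        W_out, dtype=torch.float32, device=device)
    ref_y, ref_x = torch.meshgrid(ys, xs, indexing="ij")
    # Normalize by the input size, as in the original function.
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_
    ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2)
    # Cast once, at the end, to the dtype the caller's input uses.
    return ref.to(dtype)


# Usage: pass the dtype of the (N, H, W, C) input tensor.
x = torch.zeros(2, 16, 16, 8)  # stand-in feature map
ref = get_reference_points(x.shape, x.device, 3, 3, 1, 1, dtype=torch.float16)
print(ref.shape, ref.dtype)  # torch.Size([1, 14, 14, 1, 2]) torch.float16
```

At the call site, passing `dtype=input.dtype` would keep ref in step with whatever precision the model runs in.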

@GoblinCraftman

In the first function, _get_reference_points, one dtype=torch.float32 still remains. You need to change it to float16 as well. Otherwise you get an error, because torch.meshgrid needs its inputs to have the same dtype. Similarly, the dcnv3 code in the modules folder has two float dtypes that also need to be changed to float16.
Thank you for the idea.
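The meshgrid error mentioned in the comment above can be reproduced in a few lines. This is a sketch only: it uses float32 and float64 instead of float16 and float32 so it runs on a CPU, and the exact error text depends on the PyTorch version.

```python
import torch

ys = torch.linspace(0.5, 3.5, 4, dtype=torch.float32)
xs = torch.linspace(0.5, 5.5, 6, dtype=torch.float64)  # deliberately a different dtype

try:
    torch.meshgrid(ys, xs, indexing="ij")
    mixed_ok = True
except RuntimeError as err:
    # Typically reports that meshgrid expects all tensors to have the same dtype.
    mixed_ok = False
    print("RuntimeError:", err)

# Giving every linspace the same dtype (e.g. float16 everywhere) avoids the error.
ref_y, ref_x = torch.meshgrid(ys, xs.to(ys.dtype), indexing="ij")
print(ref_y.shape, ref_y.dtype)  # torch.Size([4, 6]) torch.float32
```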
