Skip to content

Commit

Permalink
feat: Fixed conv1d converter when weights are Tensor (#2542)
Browse files Browse the repository at this point in the history
Signed-off-by: Anurag Dixit <[email protected]>
  • Loading branch information
andi4191 authored Feb 27, 2024
1 parent afd5abe commit 9a100b6
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 4 deletions.
32 changes: 28 additions & 4 deletions core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,26 +131,43 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)

// Make a new Dims with only the spatial dimensions.
nvinfer1::Dims filter_dim;
nvinfer1::Dims original_dim = in->getDimensions();
int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
TORCHTRT_CHECK(
nbSpatialDims = kernel_dims.nbDims - 2,
"Number of input spatial dimensions should match the kernel spatial dimensions");
filter_dim.nbDims = nbSpatialDims;
filter_dim.d[0] = kernel_dims.d[2];
filter_dim.d[1] = kernel_dims.d[3];
int32_t num_output_maps = kernel_dims.d[0];
bool expand_dims = nbSpatialDims == 1;
if (expand_dims) {
// In case of Conv1D -> map it to 2D version
// TensorRT expects nbSpatialDims = 2 or 3
filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false);
// Reshape input dimensions
in = addPadding(ctx, n, in, 4);
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
kernel = addPadding(ctx, n, kernel, 4);
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
if (transposed) {
num_output_maps = kernel_dims.d[1];
}
}

// Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};

nvinfer1::ILayer* layer = nullptr;
nvinfer1::ITensor* out = nullptr;
if (transposed) {
// Fix padding based on output_padding provided
nvinfer1::Dims begPadding = padding;
bool hasOutputPadding = false;
add_output_padding(padding, out_padding, hasOutputPadding);

nvinfer1::IDeconvolutionLayer* deconvLayer = ctx->net->addDeconvolutionNd(
*in, kernel_dims.d[0], filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
*in, num_output_maps, filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
deconvLayer->setStrideNd(stride);
deconvLayer->setDilationNd(dilation);
deconvLayer->setNbGroups(groups);
Expand All @@ -161,15 +178,21 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
deconvLayer->setInput(1, *kernel);
TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
layer = deconvLayer;
out = deconvLayer->getOutput(0);
if (hasOutputPadding) {
LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
auto dims = in->getDimensions();
layer = add_bias_layer(ctx, tensorPtr, dims, out_padding, bias);
out = layer->getOutput(0);
}
if (expand_dims) {
// Un-expand the expanded dimension
out = addUnpadding(ctx, n, out, original_dim.nbDims);
}
} else {
nvinfer1::IConvolutionLayer* convLayer =
ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
convLayer->setStrideNd(stride);
convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
convLayer->setPaddingNd(padding);
Expand All @@ -180,10 +203,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
// Set conv kernel weights
convLayer->setInput(1, *kernel);
layer = convLayer;
out = layer->getOutput(0);
}

ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
ctx->AssociateValueAndTensor(n->outputs()[0], out);
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}

Expand Down
102 changes: 102 additions & 0 deletions tests/core/conversion/converters/test_conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,57 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 5, 3, strides=[15, 3, 1])):
%2 : int = prim::Constant[value=-128]()
%3 : float = prim::Constant[value=3.5]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=127]()
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
%6 : int = prim::Constant[value=6]()
%7 : int = prim::Constant[value=5]()
%8 : Device = prim::Constant[value="cuda:0"]()
%9 : None = prim::Constant()
%10 : int[] = prim::ListConstruct(%7)
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
%12 : int[] = prim::ListConstruct(%7)
%13 : int = prim::Constant[value=1]()
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
%15 : None = prim::Constant()
%16 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=0]()
%18 : int = prim::Constant[value=1]()
%19 : int = prim::Constant[value=0]()
%20 : bool = prim::Constant[value=0]()
%21 : int[] = prim::ListConstruct(%16)
%22 : int[] = prim::ListConstruct(%17)
%23 : int[] = prim::ListConstruct(%18)
%24 : int[] = prim::ListConstruct(%19)
%25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %20, %24, %16, %20, %20, %20, %20)
return (%25))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
auto w = at::randint(1, 2, {4, 5, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
Expand Down Expand Up @@ -609,6 +660,57 @@ TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConv1dTransposeWithWeightTensorsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 5, 3, strides=[15, 3, 1])):
%2 : int = prim::Constant[value=-128]()
%3 : float = prim::Constant[value=3.5]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=127]()
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
%6 : int = prim::Constant[value=6]()
%7 : int = prim::Constant[value=4]()
%8 : Device = prim::Constant[value="cuda:0"]()
%9 : None = prim::Constant()
%10 : int[] = prim::ListConstruct(%7)
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
%12 : int[] = prim::ListConstruct(%7)
%13 : int = prim::Constant[value=1]()
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
%15 : None = prim::Constant()
%16 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=0]()
%18 : int = prim::Constant[value=1]()
%19 : int = prim::Constant[value=0]()
%20 : bool = prim::Constant[value=0]()
%21 : int[] = prim::ListConstruct(%16)
%22 : int[] = prim::ListConstruct(%17)
%23 : int[] = prim::ListConstruct(%18)
%24 : int[] = prim::ListConstruct(%19)
%25 : bool = prim::Constant[value=1]()
%26 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %21, %22, %23, %25, %24, %18, %20, %20, %20, %20)
return (%26))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {4, 5, 3}, {at::kCUDA});
auto w = at::randint(1, 2, {5, 4, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, nvinfer1::DataType::kINT8);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
Expand Down

0 comments on commit 9a100b6

Please sign in to comment.