Fix GEMM importer (#828)

Signed-off-by: Kevin Chen <[email protected]>
kevinch-nv authored Apr 11, 2022
1 parent 0031a42 commit 5f27e45
Showing 2 changed files with 7 additions and 114 deletions.
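For reference, ONNX Gemm computes the following (per the ONNX operator spec); the diff below reworks how the importer lowers it to TensorRT:

    Y = \alpha \cdot \mathrm{op}(A) \cdot \mathrm{op}(B) + \beta \cdot C, \qquad
    \mathrm{op}(X) = \begin{cases} X^{\top} & \text{if the corresponding trans attribute is 1} \\ X & \text{otherwise} \end{cases}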
96 changes: 7 additions & 89 deletions builtin_op_importers.cpp
@@ -1523,87 +1523,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
     bool transA = attrs.get("transA", false);
     bool transB = attrs.get("transB", false);
     nvinfer1::ITensor& inputA = convertToTensor(inputs.at(0), ctx);
+    nvinfer1::ITensor& inputB = convertToTensor(inputs.at(1), ctx);
     // Validate inputs
-    ASSERT(inputs.at(0).shape().nbDims == 2 && inputs.at(1).shape().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
+    ASSERT(inputA.getDimensions().nbDims == 2 && inputB.getDimensions().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
+    // TRT does not support INT32 input types for this node
+    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32()
+        && "TensorRT doesn't support INT32 inputs for GEMM!", ErrorCode::kUNSUPPORTED_NODE);
-    // Use FC if it is likely to be faster - which is usually when no Shuffles are required.
-    bool canUseFC = inputs.at(0).is_tensor() && inputs.at(1).is_weights() && alpha == 1.f
-        && beta == 1.f && inputs.at(0).tensor().getDimensions().nbDims == 2 && inputs.at(1).weights().shape.nbDims == 2;
-    canUseFC &= inputs.size() < 3 || (inputs.at(2).is_weights() && inputs.at(2).weights().shape.nbDims == 1);
-    if (canUseFC)
-    {
-        LOG_VERBOSE("GEMM: using FC layer instead of MM because all criteria were met.");
-        const std::vector<int> axesInput{2, 3};
-        nvinfer1::ITensor* inputAExtendDim = unsqueezeTensor(ctx, node, inputA, axesInput);
-
-        ShapedWeights weights = inputs.at(1).weights();
-        if (!transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx), ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-        }
-        ShapedWeights biases{};
-        if (inputs.size() > 2)
-        {
-            biases = inputs.at(2).weights();
-        }
-        nvinfer1::IFullyConnectedLayer* fc = ctx->network()->addFullyConnected(*inputAExtendDim, biases.shape.d[0], weights, biases);
-        // Register layer, along with refittable kernel weights and bias weights (if any)
-        ctx->registerLayer(fc, getNodeName(node));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        if (inputs.size() == 3)
-        {
-            ctx->network()->setWeightsName(biases, inputs.at(2).weights().getName());
-        }
-        const std::vector<int> axesOutput{2, 3};
-        return {{squeezeTensor(ctx, node, *fc->getOutput(0), axesOutput)}};
-    }
-
-    nvinfer1::ITensor* inputB {nullptr};
-
-    // If input B is a constant, we transpose at parse time if necessary,
-    // because in some cases, A * Bt is much slower than A * B.
-    if (inputs.at(1).is_weights())
-    {
-        ShapedWeights weights = inputs.at(1).weights();
-        if (transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx) && "Failed to transpose input tensor B.", ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-            // Since we've already transposed now, we can set transpose to false.
-            transB = false;
-        }
-        nvinfer1::IConstantLayer* weightsLayer
-            = ctx->network()->addConstant(weights.shape, static_cast<nvinfer1::Weights>(weights));
-        // Map the constant layer to the weights name.
-        ctx->registerLayer(weightsLayer, node.input(1));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        inputB = weightsLayer->getOutput(0);
-    }
-    else
-    {
-        inputB = &inputs.at(1).tensor();
-    }
-
-    nvinfer1::ITensor* inputASqueezed = &inputA;
-    nvinfer1::Dims newDims = squeeze_trailing_dims(inputA.getDimensions());
-    // When A has more than 2 dimensions, it needs to be flattened.
-    if (newDims.nbDims > 2)
-    {
-        newDims = nvinfer1::Dims{1, {-1}};
-    }
-    // Due to other TRT layers, inputA may sometimes have trailing 1s that need to be removed.
-    if (newDims.nbDims < inputA.getDimensions().nbDims)
-    {
-        nvinfer1::IShuffleLayer* squeeze = ctx->network()->addShuffle(inputA);
-        squeeze->setReshapeDimensions(newDims);
-        squeeze->setZeroIsPlaceholder(false);
-        inputASqueezed = squeeze->getOutput(0);
-    }
-    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() && "TensorRT doesn't support INT32 inputs for GEMM!",
-        ErrorCode::kUNSUPPORTED_NODE);
 
     const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) {
         if (input.getDimensions().nbDims == 1)
@@ -1617,13 +1542,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
         return nvinfer1::MatrixOperation::kNONE;
     };
 
-    nvinfer1::MatrixOperation opA = getMatrixOp(*inputASqueezed, transA);
-    nvinfer1::MatrixOperation opB = getMatrixOp(*inputB, transB);
+    nvinfer1::MatrixOperation opA = getMatrixOp(inputA, transA);
+    nvinfer1::MatrixOperation opB = getMatrixOp(inputB, transB);
 
     LOG_VERBOSE("Using opA: " << static_cast<int>(opA) << " opB: " << static_cast<int>(opB));
-    LOG_VERBOSE("GEMM: A, after squeezing: " << inputASqueezed->getDimensions());
 
-    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(*inputASqueezed, opA, *inputB, opB);
+    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB);
     ctx->registerLayer(matmul, getNodeName(node));
     nvinfer1::ITensor* matmulTensor = matmul->getOutput(0);
 
@@ -1655,12 +1579,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
             *betaConstantTensor, *biasTensor, nvinfer1::ElementWiseOperation::kPROD);
         biasTensor = scaledBias->getOutput(0);
     }
-    // A*B may be lower rank than C in TRT, so need to squeeze C.
-    if (ctx->getOpsetVersion() < 7 && !attrs.get("broadcast", false))
-    {
-        nvinfer1::Dims squeezeDims = squeeze_leading_dims(biasTensor->getDimensions());
-        biasTensor = reshapeTensor(ctx, *biasTensor, squeezeDims);
-    }
     CHECK(broadcastTensors(ctx, matmulTensor, biasTensor));
     nvinfer1::IElementWiseLayer* biasAdd
         = ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM);
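In sum, the importer now feeds A and B straight into an IMatrixMultiplyLayer and lets the layer apply transA/transB, instead of special-casing FC, transposing constant weights at parse time, and squeezing A. A condensed sketch of that flow (the free-standing helper is hypothetical, not the literal file: ctx/node bookkeeping and the alpha/beta/bias handling are omitted, and the collapsed middle of the getMatrixOp lambda is assumed to treat 1-D operands as vectors):

#include <NvInfer.h>

// Sketch of the fixed import path: pick a MatrixOperation per operand,
// then let IMatrixMultiplyLayer handle any transposes itself.
nvinfer1::ITensor* gemmMatMulSketch(nvinfer1::INetworkDefinition& network,
    nvinfer1::ITensor& inputA, nvinfer1::ITensor& inputB, bool transA, bool transB)
{
    const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) {
        if (input.getDimensions().nbDims == 1)
        {
            // Assumption: the branch collapsed in the diff returns kVECTOR here.
            return nvinfer1::MatrixOperation::kVECTOR;
        }
        return transpose ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
    };

    nvinfer1::IMatrixMultiplyLayer* matmul = network.addMatrixMultiply(
        inputA, getMatrixOp(inputA, transA), inputB, getMatrixOp(inputB, transB));
    return matmul->getOutput(0);
}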
25 changes: 0 additions & 25 deletions trt_utils.hpp
@@ -102,31 +102,6 @@ inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm)
     return new_perm;
 }
 
-inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims)
-{
-    nvinfer1::Dims new_dims = dims;
-    // Note: TRT requires at least one dimension, so we don't squeeze [1]->[]
-    while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1)
-    {
-        --new_dims.nbDims;
-    }
-    return new_dims;
-}
-
-inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims)
-{
-    nvinfer1::Dims newDims;
-    // Copy dims only if a non-1 has been seen already.
-    bool non1Seen{false};
-    newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d,
-        [&non1Seen](int x) {
-            non1Seen = (x != 1) ? true : non1Seen;
-            return non1Seen;
-        })
-        - newDims.d;
-    return newDims;
-}
-
 inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims)
 {
     return nvinfer1::DimsHW(-dims.h(), -dims.w());
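The two deleted helpers had subtly different semantics that are easy to check in isolation. A standalone sketch mirroring their logic with a stand-in Dims struct (hypothetical; the real code operates on nvinfer1::Dims):

#include <algorithm>
#include <cstdio>

// Minimal stand-in for nvinfer1::Dims, just enough to exercise the helpers.
struct Dims
{
    int nbDims;
    int d[8];
};

// Mirrors squeeze_trailing_dims: drop trailing 1s, but keep at least one dim.
Dims squeezeTrailing(Dims dims)
{
    while (dims.nbDims > 1 && dims.d[dims.nbDims - 1] == 1)
    {
        --dims.nbDims;
    }
    return dims;
}

// Mirrors squeeze_leading_dims: drop leading 1s only; later 1s survive.
Dims squeezeLeading(const Dims& dims)
{
    Dims out{};
    bool non1Seen{false};
    out.nbDims = static_cast<int>(std::copy_if(dims.d, dims.d + dims.nbDims, out.d,
                     [&non1Seen](int x) {
                         non1Seen = (x != 1) ? true : non1Seen;
                         return non1Seen;
                     })
        - out.d);
    return out;
}

int main()
{
    Dims t = squeezeTrailing({4, {2, 3, 1, 1}}); // -> nbDims=2: {2, 3}
    Dims l = squeezeLeading({4, {1, 1, 2, 1}});  // -> nbDims=2: {2, 1}, trailing 1 kept
    std::printf("trailing: %d dims, leading: %d dims\n", t.nbDims, l.nbDims);
    return 0;
}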
