Skip to content

Commit

Permalink
Fix unused-variables warnings about EmbeddingBag ops
Browse files Browse the repository at this point in the history
According to the documentation for
`torch.embedding_bag` (https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html),
the default value for `scale_grad_by_freq` is False.
  • Loading branch information
ramiro050 committed Aug 12, 2022
1 parent 51bfe25 commit 549a5a4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 12 additions & 0 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
Value sparse = op.sparse();
Value includeLastOffset = op.include_last_offset();

bool scaleGradByFreqBool;
if (!matchPattern(scaleGradByFreq,
m_TorchConstantBool(&scaleGradByFreqBool))) {
return rewriter.notifyMatchFailure(
op, "scale_grad_by_freq is expected to be a constant boolean value.");
}

if (!scaleGradByFreqBool) {
return rewriter.notifyMatchFailure(
op, "Unimplemented: scale_grad_by_freq=True.");
}

int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(
Expand Down
2 changes: 0 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2435,8 +2435,6 @@ class DecomposeAten_EmbeddingBagOp
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value weight = op.weight();
Value indices = op.indices();
Value offsets = op.offsets();
Expand Down

0 comments on commit 549a5a4

Please sign in to comment.