Skip to content

Commit

Permalink
Fix unused-variables warnings about EmbeddingBag ops (#1220)
Browse files Browse the repository at this point in the history
According to the documentation for
`torch.embedding_bag` (https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html),
the default value for `scale_grad_by_freq` is False.
  • Loading branch information
ramiro050 authored Aug 15, 2022
1 parent c935795 commit 9d6ee48
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
14 changes: 13 additions & 1 deletion lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,23 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
Value weight = adaptor.weight();
Value indices = adaptor.indices();
Value offsets = adaptor.offsets();
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
Value scaleGradByFreq = op.scale_grad_by_freq();
Value mode = op.mode();
Value sparse = op.sparse();
Value includeLastOffset = op.include_last_offset();

bool scaleGradByFreqBool;
if (!matchPattern(scaleGradByFreq,
m_TorchConstantBool(&scaleGradByFreqBool))) {
return rewriter.notifyMatchFailure(
op, "scale_grad_by_freq is expected to be a constant boolean value.");
}

if (scaleGradByFreqBool) {
return rewriter.notifyMatchFailure(
op, "Unimplemented: scale_grad_by_freq=True.");
}

int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(
Expand Down
2 changes: 0 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2476,8 +2476,6 @@ class DecomposeAten_EmbeddingBagOp
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value weight = op.weight();
Value indices = op.indices();
Value offsets = op.offsets();
Expand Down

0 comments on commit 9d6ee48

Please sign in to comment.