Skip to content

Commit

Permalink
Handle cbrt (rust-lang#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 17, 2021
1 parent e5d4699 commit 3aef65b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
38 changes: 35 additions & 3 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3845,6 +3845,38 @@ class AdjointGenerator
return;
}

if (funcName == "cbrt") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
if (!gutils->knownRecomputeHeuristic[orig]) {
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
getIndex(orig, CacheType::Self));
}
}
eraseIfUnused(*orig);
if (Mode == DerivativeMode::ReverseModePrimal ||
gutils->isConstantInstruction(orig))
return;

IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
Builder2);
Value *args[] = {x};
#if LLVM_VERSION_MAJOR >= 11
auto callval = orig->getCalledOperand();
#else
auto callval = orig->getCalledValue();
#endif
Value *dif0 = Builder2.CreateFDiv(
Builder2.CreateFMul(diffe(orig, Builder2), x),
Builder2.CreateFMul(
ConstantFP::get(x->getType(), 3),
Builder2.CreateCall(orig->getFunctionType(), callval, args)));
addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
return;
}

if (funcName == "tanhf" || funcName == "tanh") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
Expand Down Expand Up @@ -4201,9 +4233,9 @@ class AdjointGenerator
getIndex(orig, CacheType::Self));
}
}
eraseIfUnused(*orig);
if (Mode == DerivativeMode::ReverseModePrimal ||
gutils->isConstantInstruction(orig)) {
eraseIfUnused(*orig);
return;
}

Expand Down Expand Up @@ -4244,9 +4276,9 @@ class AdjointGenerator
getIndex(orig, CacheType::Self));
}
}
eraseIfUnused(*orig);
if (Mode == DerivativeMode::ReverseModePrimal ||
gutils->isConstantInstruction(orig)) {
eraseIfUnused(*orig);
return;
}

Expand Down Expand Up @@ -4288,9 +4320,9 @@ class AdjointGenerator
getIndex(orig, CacheType::Self));
}
}
eraseIfUnused(*orig);
if (Mode == DerivativeMode::ReverseModePrimal ||
gutils->isConstantInstruction(orig)) {
eraseIfUnused(*orig);
return;
}

Expand Down
3 changes: 1 addition & 2 deletions enzyme/test/Enzyme/ReverseMode/cabs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ declare double @__enzyme_autodiff(double (double, double)*, ...)

; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) {
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double @cabs(double %x, double %y)
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0
; CHECK-NEXT: %2 = fmul fast double %x, %1
; CHECK-NEXT: %3 = fmul fast double %y, %1
; CHECK-NEXT: %4 = insertvalue { double, double } undef, double %2, 0
; CHECK-NEXT: %5 = insertvalue { double, double } %4, double %3, 1
; CHECK-NEXT: ret { double, double } %5
; CHECK-NEXT: }
; CHECK-NEXT: }
29 changes: 29 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cbrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%call = call double @cbrt(double %x)
ret double %call
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

declare double @cbrt(double)

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn) {
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @cbrt(double %x)
; CHECK-NEXT: %1 = fmul fast double 3.000000e+00, %0
; CHECK-NEXT: %2 = fmul fast double %differeturn, %x
; CHECK-NEXT: %3 = fdiv fast double %2, %1
; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0
; CHECK-NEXT: ret { double } %4
; CHECK-NEXT: }

0 comments on commit 3aef65b

Please sign in to comment.