Skip to content

Commit

Permalink
Cabs calling convention (rust-lang#749)
Browse files Browse the repository at this point in the history
* Handle array types in TypeAnalysis

* Handle array calling convention of cabs

* Add tests
  • Loading branch information
tgymnich authored Jul 28, 2022
1 parent 1d9d047 commit 8abdb9f
Show file tree
Hide file tree
Showing 11 changed files with 424 additions and 13 deletions.
84 changes: 72 additions & 12 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9945,26 +9945,65 @@ class AdjointGenerator
Value *d = Builder2.CreateCall(called, args);

if (args.size() == 2) {
Value *op0 = diffe(orig->getArgOperand(0), Builder2);

Value *op1 = diffe(orig->getArgOperand(1), Builder2);
Value *op0 = gutils->isConstantValue(orig->getArgOperand(0))
? nullptr
: diffe(orig->getArgOperand(0), Builder2);
Value *op1 = gutils->isConstantValue(orig->getArgOperand(1))
? nullptr
: diffe(orig->getArgOperand(1), Builder2);

auto rule1 = [&](Value *op) {
return Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op, d));
};

auto rule = [&](Value *op0, Value *op1) {
auto rule2 = [&](Value *op0, Value *op1) {
Value *dif1 =
Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op0, d));
Value *dif2 =
Builder2.CreateFMul(args[1], Builder2.CreateFDiv(op1, d));
return Builder2.CreateFAdd(dif1, dif2);
};

Value *dif =
applyChainRule(call.getType(), Builder2, rule, op0, op1);
Value *dif;
if (op0 && op1)
dif = applyChainRule(call.getType(), Builder2, rule2, op0, op1);
else if (op0)
dif = applyChainRule(call.getType(), Builder2, rule1, op0);
else if (op1)
dif = applyChainRule(call.getType(), Builder2, rule1, op1);
else
llvm_unreachable(
"trying to differentiate a constant instruction");

setDiffe(orig, dif, Builder2);
return;
} else {
llvm::errs() << *orig << "\n";
llvm_unreachable("unknown calling convention found for cabs");
} else if (args.size() == 1) {
if (auto AT = dyn_cast<ArrayType>(args[0]->getType())) {
if (AT->getNumElements() == 2) {
Value *op = diffe(orig->getArgOperand(0), Builder2);
Value *args0 = Builder2.CreateExtractValue(args[0], 0);
Value *args1 = Builder2.CreateExtractValue(args[0], 1);

auto rule = [&](Value *op) {
Value *op0 = Builder2.CreateExtractValue(op, 0);
Value *op1 = Builder2.CreateExtractValue(op, 1);

Value *dif1 =
Builder2.CreateFMul(args0, Builder2.CreateFDiv(op0, d));
Value *dif2 =
Builder2.CreateFMul(args1, Builder2.CreateFDiv(op1, d));
return Builder2.CreateFAdd(dif1, dif2);
};

Value *dif =
applyChainRule(call.getType(), Builder2, rule, op);
setDiffe(orig, dif, Builder2);
return;
}
}
}
llvm::errs() << *orig << "\n";
llvm_unreachable("unknown calling convention found for cabs");
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
Expand Down Expand Up @@ -9998,10 +10037,31 @@ class AdjointGenerator
Builder2.CreateFMul(args[i], div), Builder2,
orig->getType());
return;
} else {
llvm::errs() << *orig << "\n";
llvm_unreachable("unknown calling convention found for cabs");
} else if (args.size() == 1) {
if (auto AT = dyn_cast<ArrayType>(args[0]->getType())) {
if (AT->getNumElements() == 2) {
if (!gutils->isConstantValue(orig->getArgOperand(0))) {
Value *agg = UndefValue::get(args[0]->getType());
agg = Builder2.CreateInsertValue(
agg,
Builder2.CreateFMul(
Builder2.CreateExtractValue(args[0], 0), div),
0);
agg = Builder2.CreateInsertValue(
agg,
Builder2.CreateFMul(
Builder2.CreateExtractValue(args[0], 1), div),
1);

addToDiffe(orig->getArgOperand(0), agg, Builder2,
orig->getType());
return;
}
}
}
}
llvm::errs() << *orig << "\n";
llvm_unreachable("unknown calling convention found for cabs");
}
case DerivativeMode::ReverseModePrimal: {
return;
Expand Down
16 changes: 15 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4240,7 +4240,21 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
llvm::errs() << *T << " - " << call << "\n";
llvm_unreachable("Unknown type for libm");
}

} else if (auto AT = dyn_cast<ArrayType>(T)) {
assert(AT->getNumElements() >= 1);
if (AT->getElementType()->isFloatingPointTy())
updateAnalysis(
call.getArgOperand(i),
TypeTree(ConcreteType(AT->getElementType()->getScalarType()))
.Only(-1),
&call);
else if (AT->getElementType()->isIntegerTy()) {
updateAnalysis(call.getArgOperand(i),
TypeTree(BaseType::Integer).Only(-1), &call);
} else {
llvm::errs() << *T << " - " << call << "\n";
llvm_unreachable("Unknown type for libm");
}
} else {
llvm::errs() << *T << " - " << call << "\n";
llvm_unreachable("Unknown type for libm");
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/cabs-const.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%call = call double @cabs(double %x, double %y)
ret double %call
}

define double @test_derivative(double %x, double %y) {
entry:
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0)
ret double %0
}

declare double @cabs(double, double)

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
; CHECK-NEXT: %1 = fdiv fast double %"y'", %0
; CHECK-NEXT: %2 = fmul fast double %x, %1
; CHECK-NEXT: ret double %2
; CHECK-NEXT:}
31 changes: 31 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/cabs.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%call = call double @cabs(double %x, double %y)
ret double %call
}

define double @test_derivative(double %x, double %y) {
entry:
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0)
ret double %0
}

declare double @cabs(double, double)

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
; CHECK-NEXT: %1 = fdiv fast double %"x'", %0
; CHECK-NEXT: %2 = fmul fast double %x, %1
; CHECK-NEXT: %3 = fdiv fast double %"y'", %0
; CHECK-NEXT: %4 = fmul fast double %y, %3
; CHECK-NEXT: %5 = fadd fast double %2, %4
; CHECK-NEXT: ret double %5
; CHECK-NEXT: }
38 changes: 38 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/cabs2-const.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone willreturn
declare double @cabs([2 x double])

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%agg0 = insertvalue [2 x double] undef, double %x, 0
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
%call = call double @cabs([2 x double] %agg1)
ret double %call
}

define double @test_derivative(double %x, double %y) {
entry:
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
; CHECK-NEXT: %"agg1'ipiv" = insertvalue [2 x double] zeroinitializer, double %"y'", 1
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1)
; CHECK-NEXT: %1 = extractvalue [2 x double] %"agg1'ipiv", 0
; CHECK-NEXT: %2 = fdiv fast double %1, %0
; CHECK-NEXT: %3 = fmul fast double %x, %2
; CHECK-NEXT: %4 = fdiv fast double %"y'", %0
; CHECK-NEXT: %5 = fmul fast double %y, %4
; CHECK-NEXT: %6 = fadd fast double %3, %5
; CHECK-NEXT: ret double %6
; CHECK-NEXT: }
36 changes: 36 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/cabs2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone willreturn
declare double @cabs([2 x double])

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%agg0 = insertvalue [2 x double] undef, double %x, 0
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
%call = call double @cabs([2 x double] %agg1)
ret double %call
}

define double @test_derivative(double %x, double %y) {
entry:
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1)
; CHECK-NEXT: %1 = fdiv fast double %"x'", %0
; CHECK-NEXT: %2 = fmul fast double %x, %1
; CHECK-NEXT: %3 = fdiv fast double %"y'", %0
; CHECK-NEXT: %4 = fmul fast double %y, %3
; CHECK-NEXT: %5 = fadd fast double %2, %4
; CHECK-NEXT: ret double %5
; CHECK-NEXT: }
50 changes: 50 additions & 0 deletions enzyme/test/Enzyme/ForwardModeVector/cabs.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%call = call double @cabs(double %x, double %y)
ret double %call
}

define [3 x double] @test_derivative(double %x, double %y) {
entry:
%0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0)
ret [3 x double] %0
}

declare double @cabs(double, double)

; Function Attrs: nounwind
declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: %2 = extractvalue [3 x double] %"y'", 0
; CHECK-NEXT: %3 = fdiv fast double %1, %0
; CHECK-NEXT: %4 = fmul fast double %x, %3
; CHECK-NEXT: %5 = fdiv fast double %2, %0
; CHECK-NEXT: %6 = fmul fast double %y, %5
; CHECK-NEXT: %7 = fadd fast double %4, %6
; CHECK-NEXT: %8 = insertvalue [3 x double] undef, double %7, 0
; CHECK-NEXT: %9 = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: %10 = extractvalue [3 x double] %"y'", 1
; CHECK-NEXT: %11 = fdiv fast double %9, %0
; CHECK-NEXT: %12 = fmul fast double %x, %11
; CHECK-NEXT: %13 = fdiv fast double %10, %0
; CHECK-NEXT: %14 = fmul fast double %y, %13
; CHECK-NEXT: %15 = fadd fast double %12, %14
; CHECK-NEXT: %16 = insertvalue [3 x double] %8, double %15, 1
; CHECK-NEXT: %17 = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: %18 = extractvalue [3 x double] %"y'", 2
; CHECK-NEXT: %19 = fdiv fast double %17, %0
; CHECK-NEXT: %20 = fmul fast double %x, %19
; CHECK-NEXT: %21 = fdiv fast double %18, %0
; CHECK-NEXT: %22 = fmul fast double %y, %21
; CHECK-NEXT: %23 = fadd fast double %20, %22
; CHECK-NEXT: %24 = insertvalue [3 x double] %16, double %23, 2
; CHECK-NEXT: ret [3 x double] %24
; CHECK-NEXT: }
55 changes: 55 additions & 0 deletions enzyme/test/Enzyme/ForwardModeVector/cabs2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s

; Function Attrs: nounwind readnone willreturn
declare double @cabs([2 x double]) #7

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y) {
entry:
%agg0 = insertvalue [2 x double] undef, double %x, 0
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
%call = call double @cabs([2 x double] %agg1)
ret double %call
}

define [3 x double] @test_derivative(double %x, double %y) {
entry:
%0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0)
ret [3 x double] %0
}

; Function Attrs: nounwind
declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...)


; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: %2 = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
; CHECK-NEXT: %3 = extractvalue [3 x double] %"y'", 0
; CHECK-NEXT: %4 = extractvalue [3 x double] %"y'", 1
; CHECK-NEXT: %5 = extractvalue [3 x double] %"y'", 2
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
; CHECK-NEXT: %6 = call fast double @cabs([2 x double] %agg1)
; CHECK-NEXT: %7 = fdiv fast double %0, %6
; CHECK-NEXT: %8 = fmul fast double %x, %7
; CHECK-NEXT: %9 = fdiv fast double %3, %6
; CHECK-NEXT: %10 = fmul fast double %y, %9
; CHECK-NEXT: %11 = fadd fast double %8, %10
; CHECK-NEXT: %12 = insertvalue [3 x double] undef, double %11, 0
; CHECK-NEXT: %13 = fdiv fast double %1, %6
; CHECK-NEXT: %14 = fmul fast double %x, %13
; CHECK-NEXT: %15 = fdiv fast double %4, %6
; CHECK-NEXT: %16 = fmul fast double %y, %15
; CHECK-NEXT: %17 = fadd fast double %14, %16
; CHECK-NEXT: %18 = insertvalue [3 x double] %12, double %17, 1
; CHECK-NEXT: %19 = fdiv fast double %2, %6
; CHECK-NEXT: %20 = fmul fast double %x, %19
; CHECK-NEXT: %21 = fdiv fast double %5, %6
; CHECK-NEXT: %22 = fmul fast double %y, %21
; CHECK-NEXT: %23 = fadd fast double %20, %22
; CHECK-NEXT: %24 = insertvalue [3 x double] %18, double %23, 2
; CHECK-NEXT: ret [3 x double] %24
; CHECK-NEXT: }
Loading

0 comments on commit 8abdb9f

Please sign in to comment.