Skip to content

Commit

Permalink
preparations for new modes (rust-lang#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Oct 6, 2021
1 parent e97c0d1 commit 6e45ead
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 128 deletions.
27 changes: 19 additions & 8 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,15 @@ class AdjointGenerator
Value *mask = nullptr, Value *orig_maskInit = nullptr) {
auto &DL = gutils->newFunc->getParent()->getDataLayout();

assert(gutils->can_modref_map);
assert(gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
bool can_modref = gutils->can_modref_map->find(&I)->second;
assert(Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeVector || gutils->can_modref_map);
assert(Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeVector ||
gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
bool can_modref = Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeVector
? false
: gutils->can_modref_map->find(&I)->second;

constantval |= gutils->isConstantValue(&I);

Expand Down Expand Up @@ -5726,14 +5732,18 @@ class AdjointGenerator
IRBuilder<> BuilderZ(newCall);
BuilderZ.setFastMathFlags(getFast());

if (uncacheable_args_map.find(&call) == uncacheable_args_map.end()) {
if (uncacheable_args_map.find(&call) == uncacheable_args_map.end() &&
Mode != DerivativeMode::ForwardMode &&
Mode != DerivativeMode::ForwardModeVector) {
llvm::errs() << " call: " << call << "\n";
for (auto &pair : uncacheable_args_map) {
llvm::errs() << " + " << *pair.first << "\n";
}
}

assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end());
assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end() ||
Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeVector);
const std::map<Argument *, bool> &uncacheable_args =
uncacheable_args_map.find(&call)->second;

Expand Down Expand Up @@ -7613,7 +7623,9 @@ class AdjointGenerator
// If we need this value and it is illegal to recompute it (it writes or
// may load uncacheable data)
// Store and reload it
if (Mode != DerivativeMode::ReverseModeCombined && subretused &&
if (Mode != DerivativeMode::ReverseModeCombined &&
Mode != DerivativeMode::ForwardMode &&
Mode != DerivativeMode::ForwardModeVector && subretused &&
(orig->mayWriteToMemory() ||
!gutils->legalRecompute(orig, ValueToValueMapTy(), nullptr))) {
if (!gutils->unnecessaryIntermediates.count(orig)) {
Expand Down Expand Up @@ -7719,8 +7731,7 @@ class AdjointGenerator
cast<Function>(called), subretType, argsInverted, gutils->TLI,
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
nextTypeInfo, uncacheable_args,
/*AtomicAdd*/ gutils->AtomicAdd);
nextTypeInfo, {});

assert(newcalled);
FunctionType *FT = cast<FunctionType>(
Expand Down
5 changes: 2 additions & 3 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ LLVMValueRef EnzymeCreateForwardDiff(
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t AtomicAdd,
uint8_t PostOpt) {
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t PostOpt) {
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
constant_args_size);
Expand All @@ -352,7 +351,7 @@ LLVMValueRef EnzymeCreateForwardDiff(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
uncacheable_args, AtomicAdd, PostOpt));
uncacheable_args, PostOpt));
}
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
Expand Down
15 changes: 8 additions & 7 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ typedef enum {
DEM_ReverseModeCombined = 3,
} CDerivativeMode;

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t AtomicAdd, uint8_t PostOpt);
LLVMValueRef
EnzymeCreateForwardDiff(EnzymeLogicRef, LLVMValueRef todiff,
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t PostOpt);

LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ class Enzyme : public ModulePass {
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TLI, TA,
/*should return*/ false, /*dretPtr*/ false, mode,
/*addedType*/ nullptr, type_args, volatile_args, AtomicAdd, PostOpt);
/*addedType*/ nullptr, type_args, volatile_args, PostOpt);
break;
case DerivativeMode::ReverseModeCombined:
newFunc = Logic.CreatePrimalAndGradient(
Expand Down
Loading

0 comments on commit 6e45ead

Please sign in to comment.