Skip to content

Commit

Permalink
Cleanup julia api usage (#2179)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Nov 28, 2024
1 parent 32a2118 commit b59ab66
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
8 changes: 5 additions & 3 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,11 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
return unwrap(
AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils));
};
shadowErasers[Name] = [=](IRBuilder<> &B, Value *ToFree) -> llvm::CallInst * {
return cast_or_null<CallInst>(unwrap(FHandle(wrap(&B), wrap(ToFree))));
};
if (FHandle)
shadowErasers[Name] = [=](IRBuilder<> &B,
Value *ToFree) -> llvm::CallInst * {
return cast_or_null<CallInst>(unwrap(FHandle(wrap(&B), wrap(ToFree))));
};
}

void EnzymeRegisterCallHandler(char *Name,
Expand Down
11 changes: 10 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9325,7 +9325,16 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
}
if (allocationfn == "julia.gc_alloc_obj" ||
allocationfn == "jl_gc_alloc_typed" ||
allocationfn == "ijl_gc_alloc_typed")
allocationfn == "ijl_gc_alloc_typed" ||
allocationfn == "jl_alloc_array_1d" ||
allocationfn == "ijl_alloc_array_1d" ||
allocationfn == "jl_alloc_array_2d" ||
allocationfn == "ijl_alloc_array_2d" ||
allocationfn == "jl_alloc_array_3d" ||
allocationfn == "ijl_alloc_array_3d" || allocationfn == "jl_new_array" ||
allocationfn == "ijl_new_array" ||
allocationfn == "jl_alloc_genericmemory" ||
allocationfn == "ijl_alloc_genericmemory")
return nullptr;

if (allocationfn == "enzyme_allocator") {
Expand Down
16 changes: 16 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5243,6 +5243,22 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
if (funcName == "julia.except_enter" || funcName == "ijl_excstack_state" ||
funcName == "jl_excstack_state") {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" ||
funcName == "jl_inactive_inout" ||
funcName == "jl_genericmemory_copy_slice" ||
funcName == "ijl_genericmemory_copy_slice") {
if (direction & DOWN)
updateAnalysis(&call, getAnalysis(call.getOperand(0)), &call);
if (direction & UP)
updateAnalysis(call.getOperand(0), getAnalysis(&call), &call);
return;
}

if (isAllocationFunction(funcName, TLI)) {
size_t Idx = 0;
for (auto &Arg : ci->args()) {
Expand Down

0 comments on commit b59ab66

Please sign in to comment.