Skip to content

Commit

Permalink
Report AD checkpoint contexts (#5058)
Browse files Browse the repository at this point in the history
* Transferring source locations when creating phi instructions

* Tracking for simple variables

* Deriving source locations for loop counters

* Printing checkpoint structure breakdown

* More readable output format

* Special behavior for loop counters

* Writing report to file

* Add slangc option to enable checkpoint reports

* Display types of checkpointed fields

* Message in case there are no checkpointing contexts

* Catch source locations for function calls

* Source cleanup

* Fix compilation warnings

* Remove stray dump()

* Provide the report through diagnostic notes

* Add missing path for sourceLoc during unzip pass

* Add tests for reporting intermediates

* Include more transfer cases for source locations

* Fix ordering in address elimination

* Fill in more holes with source location transfer

* Remove debugging line

* Reverting changes to diagnostic sink

* Simplify address elimination using source location RAII contexts

* Eliminating manual source loc transfers in forward transcription

* Fix local var adaptation to use RAII location setter

* Simplify primal hoisting logic for source location transfer

* Simplify unzipping with RAII location scopes

* Simplify transpose logic

* Cleaning up for rev.cpp

* Reverting spacing changes

* Fix mistake with source loc RAII instantiation

* Fix formatting issues
  • Loading branch information
venkataram-nv authored Sep 19, 2024
1 parent 3240799 commit b808aa4
Show file tree
Hide file tree
Showing 33 changed files with 264 additions and 33 deletions.
1 change: 1 addition & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ extern "C"
EmitIr, // bool
ReportDownstreamTime, // bool
ReportPerfBenchmark, // bool
ReportCheckpointIntermediates, // bool
SkipSPIRVValidation, // bool
SourceEmbedStyle,
SourceEmbedName,
Expand Down
1 change: 1 addition & 0 deletions source/slang-record-replay/util/emum-to-string.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ namespace SlangRecord
CASE(EmitIr);
CASE(ReportDownstreamTime);
CASE(ReportPerfBenchmark);
CASE(ReportCheckpointIntermediates);
CASE(SkipSPIRVValidation);
CASE(SourceEmbedStyle);
CASE(SourceEmbedName);
Expand Down
6 changes: 5 additions & 1 deletion source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2451,12 +2451,16 @@ namespace Slang
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}

/// Whether the autodiff checkpoint-intermediates report was requested for
/// the current target program (the `ReportCheckpointIntermediates` option).
bool CodeGenContext::shouldReportCheckpointIntermediates()
{
    auto& targetOptions = getTargetProgram()->getOptionSet();
    return targetOptions.getBoolOption(CompilerOptionName::ReportCheckpointIntermediates);
}

/// Whether the `DumpIntermediates` option is set on the current target
/// program's option set.
bool CodeGenContext::shouldDumpIntermediates()
{
    auto& targetOptions = getTargetProgram()->getOptionSet();
    return targetOptions.getBoolOption(CompilerOptionName::DumpIntermediates);
}


bool CodeGenContext::shouldTrackLiveness()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness);
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -2728,6 +2728,7 @@ namespace Slang

bool shouldValidateIR();
bool shouldDumpIR();
bool shouldReportCheckpointIntermediates();

bool shouldTrackLiveness();

Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B

DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.")

// Autodiff checkpoint reporting
//
// Notes emitted by `reportCheckpointIntermediates` in slang-emit.cpp.
// Placeholder arguments, as passed at the `diagnose` call sites there:
//   reportCheckpointIntermediates: $0 = source function the context belongs to,
//                                  $1 = size of the checkpoint struct in bytes.
//   reportCheckpointVariable:      $0 = size of the field in bytes,
//                                  $1 = the field's type rendered as text.
//   reportCheckpointCounter:       same arguments as reportCheckpointVariable, used
//                                  for fields carrying a loop-counter decoration.
//   reportCheckpointNone:          no arguments; emitted when no non-empty
//                                  checkpoint struct was reported.
// NOTE(review): the -1 in the id column presumably means "no numeric code" — confirm.
DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'")
DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:")
DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:")
DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report")

//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
//
Expand Down
67 changes: 67 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "slang-ir-wgsl-legalize.h"
#include "slang-ir-insts.h"
#include "slang-ir-inline.h"
#include "slang-ir-layout.h"
#include "slang-ir-legalize-array-return-type.h"
#include "slang-ir-legalize-mesh-outputs.h"
#include "slang-ir-legalize-varying-params.h"
Expand Down Expand Up @@ -214,6 +215,68 @@ static void dumpIRIfEnabled(
}
}

/// Emit diagnostic notes describing every autodiff checkpoint context in `irModule`.
///
/// Walks the module's global instructions looking for struct types that carry an
/// `IRCheckpointIntermediateDecoration`. For each such struct with a non-zero
/// natural size, one note reports the total size and the associated source
/// function, followed by one note per non-zero-sized field giving its size and
/// its type (rendered through the C++ emitter). If nothing was reported, a single
/// "no checkpoint contexts" note is emitted instead.
///
/// @param codeGenContext  Supplies the target option set and backs the type emitter.
/// @param sink            Receives the notes; also provides the source manager.
/// @param irModule        Module whose global instructions are scanned.
static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule)
{
    CompilerOptionSet& options = codeGenContext->getTargetProgram()->getOptionSet();

    // A scratch writer plus a C++ emitter, used only to turn field types into text.
    SourceWriter typeText(sink->getSourceManager(), LineDirectiveMode::None, nullptr);

    CLikeSourceEmitter::Desc emitterDesc;
    emitterDesc.codeGenContext = codeGenContext;
    emitterDesc.sourceWriter = &typeText;

    CPPSourceEmitter typeEmitter(emitterDesc);

    bool reportedAnyContext = false;
    for (auto globalInst : irModule->getGlobalInsts())
    {
        // Only struct types tagged as checkpoint contexts are of interest.
        auto contextType = as<IRStructType>(globalInst);
        if (!contextType)
            continue;

        auto contextDecoration = contextType->findDecoration<IRCheckpointIntermediateDecoration>();
        if (!contextDecoration)
            continue;

        IRSizeAndAlignment contextLayout;
        getNaturalSizeAndAlignment(options, contextType, &contextLayout);

        // Empty context structs have not been optimized away at this point
        // (and their decorations are intentionally kept), so zero-sized
        // ones have to be filtered out here.
        if (contextLayout.size == 0)
            continue;

        auto sourceFunc = contextDecoration->getSourceFunction();
        sink->diagnose(contextType, Diagnostics::reportCheckpointIntermediates, sourceFunc, contextLayout.size);
        reportedAnyContext = true;

        for (auto contextField : contextType->getFields())
        {
            IRType* storedType = contextField->getFieldType();

            IRSizeAndAlignment fieldLayout;
            getNaturalSizeAndAlignment(options, storedType, &fieldLayout);

            if (fieldLayout.size != 0)
            {
                // Render the field's type into the (freshly cleared) scratch writer.
                typeText.clearContent();
                typeEmitter.emitType(storedType);

                // Loop counters get their own message; everything else is a
                // regular checkpointed item.
                const bool isLoopCounter =
                    contextField->findDecoration<IRLoopCounterDecoration>() != nullptr;
                const auto& fieldNote = isLoopCounter
                    ? Diagnostics::reportCheckpointCounter
                    : Diagnostics::reportCheckpointVariable;

                sink->diagnose(contextField->sourceLoc,
                    fieldNote,
                    fieldLayout.size,
                    typeText.getContent());
            }
        }
    }

    if (!reportedAnyContext)
        sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone);
}

struct LinkingAndOptimizationOptions
{
bool shouldLegalizeExistentialAndResourceTypes = true;
Expand Down Expand Up @@ -767,6 +830,10 @@ Result linkAndOptimizeIR(
break;
}

// Report checkpointing information
if (codeGenContext->shouldReportCheckpointIntermediates())
reportCheckpointIntermediates(codeGenContext, sink, irModule);

if (requiredLoweringPassSet.autodiff)
finalizeAutoDiffPass(targetProgram, irModule);

Expand Down
18 changes: 9 additions & 9 deletions source/slang/slang-ir-addr-inst-elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,28 @@ struct AddressInstEliminationContext
}
}

void transformLoadAddr(IRUse* use)
void transformLoadAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto load = as<IRLoad>(use->getUser());

IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
auto value = getValue(builder, addr);
load->replaceUsesWith(value);
load->removeAndDeallocate();
}

void transformStoreAddr(IRUse* use)
void transformStoreAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto store = as<IRStore>(use->getUser());

IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
storeValue(builder, addr, store->getVal());
store->removeAndDeallocate();
}

void transformCallAddr(IRUse* use)
void transformCallAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto call = as<IRCall>(use->getUser());
Expand All @@ -103,7 +101,6 @@ struct AddressInstEliminationContext
return;
}

IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());

Expand Down Expand Up @@ -155,17 +152,20 @@ struct AddressInstEliminationContext
use = nextUse;
continue;
}

IRBuilder transformBuilder(module);
IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc);

switch (use->getUser()->getOp())
{
case kIROp_Load:
transformLoadAddr(use);
transformLoadAddr(transformBuilder, use);
break;
case kIROp_Store:
transformStoreAddr(use);
transformStoreAddr(transformBuilder, use);
break;
case kIROp_Call:
transformCallAddr(use);
transformCallAddr(transformBuilder, use);
break;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
Expand Down
42 changes: 30 additions & 12 deletions source/slang/slang-ir-autodiff-primal-hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
#include "../core/slang-func-ptr.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
#include "../core/slang-func-ptr.h"

namespace Slang
{
Expand Down Expand Up @@ -1092,7 +1093,8 @@ IRType* getTypeForLocalStorage(
IRVar* emitIndexedLocalVar(
IRBlock* varBlock,
IRType* baseType,
const List<IndexTrackingInfo>& defBlockIndices)
const List<IndexTrackingInfo>& defBlockIndices,
SourceLoc location)
{
// Cannot store pointers. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
Expand All @@ -1101,6 +1103,8 @@ IRVar* emitIndexedLocalVar(
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));

IRBuilder varBuilder(varBlock->getModule());
IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location);

varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());

IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices);
Expand Down Expand Up @@ -1179,9 +1183,14 @@ IRVar* storeIndexedValue(
IRInst* instToStore,
const List<IndexTrackingInfo>& defBlockIndices)
{
IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
IRVar* localVar = emitIndexedLocalVar(defaultVarBlock,
instToStore->getDataType(),
defBlockIndices,
instToStore->sourceLoc);

IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices);
IRInst* addr = emitIndexedStoreAddressForVar(builder,
localVar,
defBlockIndices);

builder->emitStore(addr, instToStore);

Expand Down Expand Up @@ -1574,12 +1583,16 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
// region, that means there's no need to allocate a fully indexed var.
//
defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses);

IRVar* localVar = storeIndexedValue(
&builder,
varBlock,
builder.emitLoad(varToStore),
defBlockIndices);

IRVar* localVar = nullptr;
{
IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc);
localVar = storeIndexedValue(
&builder,
varBlock,
builder.emitLoad(varToStore),
defBlockIndices);
}

for (auto use : outOfScopeUses)
{
Expand Down Expand Up @@ -1626,6 +1639,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
}
else
{
IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc);

// Handle the special case of loop counters.
// The only case where there will be a reference of primal loop counter from rev blocks
// is the start of a loop in the reverse code. Since loop counters are not considered a
Expand All @@ -1643,6 +1658,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(

setInsertAfterOrdinaryInst(&builder, instToStore);
auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices);
if (isLoopCounter)
builder.addLoopCounterDecoration(localVar);

for (auto use : outOfScopeUses)
{
Expand Down Expand Up @@ -1728,6 +1745,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop)
void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam)
{
IRBuilder builder(primalLoop);
IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc);

primalCountParam = nullptr;

// Grab first primal block.
Expand Down Expand Up @@ -1899,8 +1918,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
//
primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
return primalsInfo;
return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
}

void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
Expand Down
Loading

0 comments on commit b808aa4

Please sign in to comment.