diff --git a/dipu/torch_dipu/csrc_dipu/profiler/profiler.cpp b/dipu/torch_dipu/csrc_dipu/profiler/profiler.cpp index ea23bf43f..4789b4984 100644 --- a/dipu/torch_dipu/csrc_dipu/profiler/profiler.cpp +++ b/dipu/torch_dipu/csrc_dipu/profiler/profiler.cpp @@ -1,12 +1,17 @@ #include "profiler.h" -#include #include -#include +#include +#include #include +#include #include +#include "csrc_dipu/profiler/CorrelationIDManager.h" + +#include "ThreadUtil.h" + namespace dipu { namespace profile { @@ -265,22 +270,20 @@ void abandonAllRecords() { resetId(); } -RecordCreator::RecordCreator(const string_t& name, size_t opId, +RecordCreator::RecordCreator(string_t name, size_t opId, uint64_t linkCorrelationId, - const ExtraRecordInfo& extraInfo) { + ExtraRecordInfo extraInfo) { if (isEnable()) { - name_ = name; + name_ = std::move(name); opId_ = opId; begin_ = torch::profiler::impl::getTime(); end_ = false; linkCorrelationId_ = linkCorrelationId; - extraInfo_ = extraInfo; + extraInfo_ = std::move(extraInfo); } } -RecordCreator::~RecordCreator() { end(); } - -void RecordCreator::end() { +void RecordCreator::end() noexcept { if (!end_) { RecordsImpl::get().addRecord( Record{name_, opId_, begin_, @@ -295,12 +298,12 @@ void RecordCreator::end() { DeviceRecordCreator::DeviceRecordCreator(string_t name, deviceStream_t stream, int streamId, size_t opId, uint64_t linkCorrelationId, - const ExtraRecordInfo& extraInfo) { + ExtraRecordInfo extraInfo) { if (isEnable()) { DeviceRecordsImpl::get().ensureSetup(stream); - name_ = name; + name_ = std::move(name); opId_ = opId; - extraInfo_ = extraInfo; + extraInfo_ = std::move(extraInfo); stream_ = stream; streamId_ = streamId; pStart_.reset(new DeviceEvent()); @@ -311,9 +314,7 @@ DeviceRecordCreator::DeviceRecordCreator(string_t name, deviceStream_t stream, } } -DeviceRecordCreator::~DeviceRecordCreator() { end(); } - -void DeviceRecordCreator::end() { +void DeviceRecordCreator::end() noexcept { if (!end_) { TORCH_CHECK(pStart_, "dipu profiler error with pStart_ is not inited"); TORCH_CHECK(pStop_, "dipu profiler error with pStop_ is not inited"); @@ -329,12 +330,12 @@ void DeviceRecordCreator::end() { } static std::string extraceFunction(const std::string& functionName) { - auto start = functionName.find_first_not_of(":"); + auto start = functionName.find_first_not_of(':'); if (start == std::string::npos) { return ""; } - auto end = functionName.find_first_of("("); + auto end = functionName.find_first_of('('); if (end == std::string::npos) { end = functionName.size(); } @@ -345,32 +346,18 @@ static std::string extraceFunction(const std::string& functionName) { return functionName.substr(start, end - start); } -RecordBlockCreator::RecordBlockCreator(string_t name, - const ExtraRecordInfo& extraInfo, - deviceStream_t stream, int streamId, - bool enProfile) { - if (enProfile && isEnable()) { - size_t opId = generateId(); - uint64_t correlationId = - CorrelationIDManager::instance().getCorrelationID(); - name = extraceFunction(name); - pHostRecord_.reset(new RecordCreator("LaunchKernel_" + name, opId, - correlationId, extraInfo)); - pDeviceRecord_.reset(new DeviceRecordCreator(name, stream, streamId, opId, - correlationId, extraInfo)); - } -} - -void RecordBlockCreator::end() { - if (!finish_) { - pHostRecord_.reset(); - pDeviceRecord_.reset(); - } - finish_ = true; +void RecordBlockCreator::initialize(string_t name, ExtraRecordInfo extraInfo, + deviceStream_t stream, + c10::StreamId streamId) { + size_t opId = generateId(); + uint64_t correlationId = CorrelationIDManager::instance().getCorrelationID(); + name = extraceFunction(name); + pHostRecord_ = std::make_unique("LaunchKernel_" + name, opId, + correlationId, extraInfo); + pDeviceRecord_ = std::make_unique( + std::move(name), stream, streamId, opId, correlationId, + std::move(extraInfo)); } - -RecordBlockCreator::~RecordBlockCreator() { end(); } - } // namespace profile } // namespace dipu diff --git a/dipu/torch_dipu/csrc_dipu/profiler/profiler.h b/dipu/torch_dipu/csrc_dipu/profiler/profiler.h index eed733567..7cb5a750d 100644 --- a/dipu/torch_dipu/csrc_dipu/profiler/profiler.h +++ b/dipu/torch_dipu/csrc_dipu/profiler/profiler.h @@ -1,23 +1,23 @@ #pragma once -#include -#include -#include +#include #include #include #include #include -#include #include -#include #include #include -#include +#include +#include +#include + +#include "csrc_dipu/vendor/vendorapi.h" #include #include -#include "CorrelationIDManager.h" +#include "IActivityProfiler.h" namespace dipu { namespace profile { @@ -40,11 +40,9 @@ void abandonAllRecords(); struct ExtraRecordInfo { string_t scope; - size_t opSeqId; + size_t opSeqId{}; string_t attrs; - ExtraRecordInfo() : scope(""), opSeqId(0), attrs("") {} - ExtraRecordInfo& setScope(const string_t& scopeName) { scope = scopeName; return *this; @@ -86,7 +84,6 @@ class RecordsImpl final { std::map, libkineto::ResourceInfo> resourceInfo_; - private: RecordsImpl() = default; public: @@ -112,14 +109,13 @@ class RecordCreator final { ExtraRecordInfo extraInfo_; public: - explicit RecordCreator(const string_t& name, size_t opId, - uint64_t linkCorrelationId, - const ExtraRecordInfo& extraInfo = ExtraRecordInfo()); + explicit RecordCreator(string_t name, size_t opId, uint64_t linkCorrelationId, + ExtraRecordInfo extraInfo = ExtraRecordInfo()); - ~RecordCreator(); + ~RecordCreator() { end(); } private: - void end(); + void end() noexcept; }; class DeviceEvent; @@ -148,27 +144,52 @@ class DeviceRecordCreator final { public: DeviceRecordCreator(string_t name, deviceStream_t stream, int streamId, size_t opId, uint64_t linkCorrelationId, - const ExtraRecordInfo& extraInfo = ExtraRecordInfo()); + ExtraRecordInfo extraInfo = ExtraRecordInfo()); - ~DeviceRecordCreator(); + ~DeviceRecordCreator() { end(); } private: - void end(); + void end() noexcept; }; class RecordBlockCreator { public: + // TODO(lljbash): maybe use std::string_view and std::optional after c++17 explicit RecordBlockCreator( - string_t name, const ExtraRecordInfo& extraInfo = ExtraRecordInfo(), - deviceStream_t stream = dipu::getCurrentDIPUStream(), - int streamId = dipu::getCurrentDIPUStream().id(), - bool enProfile = isEnable()); + c10::string_view name, + c10::optional extraInfo = c10::nullopt, + c10::optional stream = c10::nullopt, + c10::optional streamId = c10::nullopt, + c10::optional enProfile = c10::nullopt) { + if (enProfile.value_or(isEnable())) { + if (!extraInfo) { + extraInfo.emplace(); + } + if (!stream) { + auto dipu_stream = getCurrentDIPUStream(); + if (!streamId) { + streamId = dipu_stream.id(); + } + stream = static_cast(dipu_stream); + } + initialize(string_t(name), std::move(*extraInfo), *stream, *streamId); + } + } - void end(); + void end() noexcept { + if (!finish_) { + pHostRecord_.reset(); + pDeviceRecord_.reset(); + finish_ = true; + } + } - ~RecordBlockCreator(); + ~RecordBlockCreator() { end(); } private: + void initialize(string_t name, ExtraRecordInfo extraInfo, + deviceStream_t stream, c10::StreamId streamId); + std::unique_ptr pHostRecord_ = nullptr; std::unique_ptr pDeviceRecord_ = nullptr; bool finish_ = false;