update (#44418)
jiweibo authored Jul 19, 2022
1 parent 130c108 commit d5f0ed4
Showing 3 changed files with 4 additions and 547 deletions.
17 changes: 4 additions & 13 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -302,11 +302,8 @@ void AnalysisPredictor::InitPlace() {
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (config_.thread_local_stream_enabled()) {
-      auto *ctx = static_cast<platform::CUDADeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(place_));
-      VLOG(3) << "The prediction process will be completed using a separate "
-                 "normal-priority stream on each thread.";
-      ctx->ResetThreadContext(platform::stream::Priority::kNormal);
+      LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
+                                 "Please use config.SetExecStream instead.";
     }
 #endif
   } else if (config_.use_xpu()) {
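
As the new warning states, the thread-local stream mode is being phased out in favor of binding an explicit execution stream on the config. Below is a minimal migration sketch, not part of this diff: it assumes the paddle_infer C++ API, that SetExecStream accepts the raw CUDA stream, that EnableGpuMultiStream is the switch behind thread_local_stream_enabled(), and that the include path, model path, and GPU memory size are placeholders.

// Migration sketch (assumptions noted above): replace the deprecated
// thread-local stream mode with an explicit execution stream on the config.
#include <cuda_runtime.h>

#include "paddle_inference_api.h"  // include path is an assumption

int main() {
  cudaStream_t stream = nullptr;
  cudaStreamCreate(&stream);

  paddle_infer::Config config;
  config.SetModel("./model_dir");            // placeholder model path
  config.EnableUseGpu(100 /*MB*/, 0 /*gpu id*/);
  // Before: config.EnableGpuMultiStream();  // thread_local_stream, now deprecated
  config.SetExecStream(stream);              // run this predictor on our stream

  auto predictor = paddle_infer::CreatePredictor(config);
  predictor->Run();  // inputs would be set via GetInputHandle(...) first

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
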
@@ -1621,14 +1618,8 @@ bool AnalysisPredictor::ZeroCopyRun() {
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
-  if (stream != nullptr) {
-    paddle::platform::DeviceContextPool &pool =
-        paddle::platform::DeviceContextPool::Instance();
-    auto gpu_place = place_;
-    auto *dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext *>(
-        pool.Get(gpu_place));
-    dev_ctx->SetThreadLocalStream(stream);
-  }
+  LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
+                             "Please use config.SetExecStream instead.";
   return ZeroCopyRun();
 }
 #endif
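
After this change, the gpuStream_t overload of ExpRunWithExternalStream no longer installs the passed stream as a thread-local stream; it only emits the one-time deprecation warning and forwards to ZeroCopyRun(). For context, a caller-side sketch of the deprecated pattern follows; the header path and the direct use of a raw AnalysisPredictor pointer are assumptions, not shown in this diff.

// Sketch of a caller still using the deprecated overload (assumption, not in
// this diff). The stream argument is now effectively ignored; prefer
// config.SetExecStream(stream) before the predictor is created.
#include <cuda_runtime.h>

#include "paddle/fluid/inference/api/analysis_predictor.h"  // path is an assumption

bool RunOnStream(paddle::AnalysisPredictor* predictor, cudaStream_t stream) {
  // Warns once (LOG_FIRST_N), then behaves exactly like ZeroCopyRun().
  return predictor->ExpRunWithExternalStream(stream);
}
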
185 changes: 0 additions & 185 deletions paddle/fluid/platform/device_context.cc
@@ -534,198 +534,13 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}

thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&RawStream(), place_);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
const stream::Priority& priority,
const stream::StreamFlag& flag) {
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority, flag));
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSparseContext();
InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
if (stream_->raw_stream() != stream) {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
DestoryCuBlasLtContext();
#endif
DestoryCuSolverContext();
#endif

stream_->SetStream(stream);

InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSolverContext();
#endif
}
}

CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
DestoryCuSparseContext();
DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
phi::GPUContext::PartialInitWithoutAllocator();
cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
if (thread_ctx_.count(this)) {
return context()->EigenDevice().get();
}
return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
VLOG(4) << "CUDA context(" << this << ") Wait";
if (thread_ctx_.count(this)) {
context()->Stream()->Wait();
return;
}
phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
if (thread_ctx_.count(this)) {
return context()->CudnnHandle();
}
return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasHandle()->GetCublasHandle();
}
return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasHandle()->GetCublasHandle();
}
return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasLtHandle()->GetCublasLtHandle();
}
return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
if (thread_ctx_.count(this)) {
return context()->CusparseHandle()->GetCusparseHandle();
}
return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
if (thread_ctx_.count(this)) {
return context()->CusolverDnHandle();
}
return phi::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
gpuEvent_t ev, const std::function<void()>& callback) const {
if (thread_ctx_.count(this)) {
context()->Stream()->RecordEvent(ev, callback);
return;
}
phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
const std::function<void()>& callback) const {
if (thread_ctx_.count(this)) {
context()->Stream()->AddCallback(callback);
return;
}
phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
if (thread_ctx_.count(this)) {
context()->Stream()->WaitCallback();
return;
}
phi::GPUContext::WaitStreamCallback();
}

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
if (thread_ctx_.count(this)) {
// return workspace_.get();
return phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace())
.get(),
stream());
}
return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
if (thread_ctx_.count(this)) {
return context()->RawStream();
}
return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
if (!thread_ctx_.count(this)) {
PADDLE_THROW(platform::errors::PermissionDenied(
"CUDADeviceContext call context() failed, make sure in the "
"thread_local semantic."));
}
return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
return cuda_stream_.get();
}
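
The code removed above implemented a per-thread, per-device-context cache: each thread kept its own CUDAContext (stream plus cuBLAS/cuDNN/cuSolver handles) keyed by the owning CUDADeviceContext, and every accessor fell back to phi::GPUContext when no thread-local entry existed. A stripped-down, generic sketch of that idiom is shown below; ThreadLocalCtx and Owner are illustrative names only, not Paddle types.

// Generic sketch of the thread-local context-map idiom removed above.
// Illustrative names only; this is not Paddle code.
#include <memory>
#include <unordered_map>

struct ThreadLocalCtx {
  // Per-thread resources (a stream, library handles, ...) would live here.
};

class Owner {
 public:
  // Lazily create one ThreadLocalCtx per (thread, Owner) pair. Because the
  // map itself is thread_local, the lookup path needs no locking.
  std::shared_ptr<ThreadLocalCtx> context() const {
    auto it = thread_ctx_.find(this);
    if (it == thread_ctx_.end()) {
      it = thread_ctx_.emplace(this, std::make_shared<ThreadLocalCtx>()).first;
    }
    return it->second;
  }

  bool has_thread_ctx() const { return thread_ctx_.count(this) > 0; }

 private:
  static thread_local std::unordered_map<const Owner*,
                                         std::shared_ptr<ThreadLocalCtx>>
      thread_ctx_;
};

thread_local std::unordered_map<const Owner*, std::shared_ptr<ThreadLocalCtx>>
    Owner::thread_ctx_;

int main() {
  Owner owner;
  auto ctx = owner.context();  // created on first use in this thread
  return owner.has_thread_ctx() && ctx ? 0 : 1;
}
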
