diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu index 44668e491eb29..f9f91680e3401 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -152,9 +152,7 @@ void TensorCheckerVisitor::apply( PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1, - cudaMemcpyHostToDevice, dev_ctx->stream()), - platform::errors::External( - "Async cudaMemcpy op_var info to gpu failed.")); + cudaMemcpyHostToDevice, dev_ctx->stream())); } else { // get auto iter = op_var2gpu_str.find(op_var); PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true, diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index 7a032acef676b..9eefb925d2061 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -124,12 +124,9 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, float const* input_ptr = reinterpret_cast(inputs[0]); float* const* h_odatas = reinterpret_cast(outputs); float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemcpyAsync(output_ptrs, h_odatas, - d_output_ptrs_.size() * sizeof(float*), - cudaMemcpyHostToDevice, stream), - platform::errors::External( - "CUDA Memcpy failed during split plugin run.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( + output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*), + cudaMemcpyHostToDevice, stream)); int outer_rows = outer_rows_ * batchSize; @@ -244,12 +241,9 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, float* const* h_odatas = reinterpret_cast(outputs); float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemcpyAsync(output_ptrs, h_odatas, - d_output_ptrs.size() * sizeof(float*), - cudaMemcpyHostToDevice, stream), - platform::errors::External( - "CUDA Memcpy failed during split plugin run.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( + output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*), + cudaMemcpyHostToDevice, stream)); split_kernel<<>>( d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs, @@ -263,12 +257,9 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, half* const* h_odatas = reinterpret_cast(outputs); half** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemcpyAsync(output_ptrs, h_odatas, - d_output_ptrs.size() * sizeof(half*), - cudaMemcpyHostToDevice, stream), - platform::errors::External( - "CUDA Memcpy failed during split plugin run.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( + output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(half*), + cudaMemcpyHostToDevice, stream)); split_kernel<<>>( d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs, diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h index 0997f575acc4e..2163562a6080b 100644 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h @@ -80,17 +80,13 @@ class CUDADeviceContextAllocator : public Allocator { : place_(place), default_stream_(default_stream) { platform::CUDADeviceGuard 
guard(place_.device); PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventCreate(&event_, cudaEventDisableTiming), - platform::errors::External( - "Create event failed in CUDADeviceContextAllocator")); + cudaEventCreate(&event_, cudaEventDisableTiming)); } ~CUDADeviceContextAllocator() { if (event_) { platform::CUDADeviceGuard guard(place_.device); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventDestroy(event_), - "Destory event failed in CUDADeviceContextAllocator destroctor"); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event_)); } } @@ -103,12 +99,9 @@ class CUDADeviceContextAllocator : public Allocator { auto allocation = new CUDADeviceContextAllocation(memory::Alloc(place_, size)); // Wait for the event on stream + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, default_stream_)); PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventRecord(event_, default_stream_), - "Failed to record event in CUDADeviceContextAllocator"); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaStreamWaitEvent(default_stream_, event_, 0), - "Failed to wait event in CUDADeviceContextAllocator"); + cudaStreamWaitEvent(default_stream_, event_, 0)); return allocation; } diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index 006bf559195aa..cbd7e33bc6b72 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -141,12 +141,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream); } - PADDLE_ENFORCE_CUDA_SUCCESS( - err, - "ArgSortOP failed as could not launch " - "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate" - "temp_storage_bytes, status:%s.", - temp_storage_bytes, cudaGetErrorString(err)); + PADDLE_ENFORCE_CUDA_SUCCESS(err); Tensor temp_storage; temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes); @@ -165,12 +160,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, cu_stream); } - PADDLE_ENFORCE_CUDA_SUCCESS( - err, - "ArgSortOP failed as could not launch " - "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, " - "temp_storage_bytes:%d status:%s.", - temp_storage_bytes, cudaGetErrorString(err)); + PADDLE_ENFORCE_CUDA_SUCCESS(err); } template diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.cu b/paddle/fluid/operators/fused/fused_bn_activation_op.cu index 2e308657936c0..32eaf1180977a 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.cu +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.cu @@ -108,32 +108,21 @@ class FusedBatchNormActKernel cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnCreateTensorDescriptor(&data_desc_).")); + platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnCreateTensorDescriptor(&bn_param_desc_).")); + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); VLOG(3) << "Setting descriptors."; std::vector dims = {N, C, H, W, D}; std::vector strides = {H * W * D * C, 1, W * D * C, D * C, C}; - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, CudnnDataType::type, - x_dims.size() > 3 ? 
x_dims.size() : 4, dims.data(), strides.data()), - platform::errors::External( - "The error has happened when calling cudnnSetTensorNdDescriptor.")); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_, - data_desc_, mode_), - platform::errors::External("The error has happened when calling " - "cudnnDeriveBNTensorDescriptor.")); + data_desc_, mode_)); double this_factor = 1. - momentum; cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; @@ -166,10 +155,7 @@ class FusedBatchNormActKernel /*yDesc=*/data_desc_, /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size), - platform::errors::External( - "The error has happened when calling " - "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize.")); + /*sizeInBytes=*/&workspace_size)); // -------------- cudnn batchnorm reserve space -------------- PADDLE_ENFORCE_CUDA_SUCCESS( @@ -179,10 +165,7 @@ class FusedBatchNormActKernel /*bnOps=*/bnOps_, /*activationDesc=*/activation_desc_, /*xDesc=*/data_desc_, - /*sizeInBytes=*/&reserve_space_size), - platform::errors::External( - "The error has happened when calling " - "cudnnGetBatchNormalizationTrainingExReserveSpaceSize.")); + /*sizeInBytes=*/&reserve_space_size)); reserve_space_ptr = reserve_space->mutable_data(ctx.GetPlace(), x->type(), reserve_space_size); @@ -204,22 +187,13 @@ class FusedBatchNormActKernel saved_variance->template mutable_data>( ctx.GetPlace()), activation_desc_, workspace_ptr, workspace_size, reserve_space_ptr, - reserve_space_size), - platform::errors::External( - "The error has happened when calling " - "cudnnBatchNormalizationForwardTrainingEx.")); + reserve_space_size)); // clean when exit. PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(data_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnDestroyTensorDescriptor(data_desc_).")); + platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnDestroyTensorDescriptor(bn_param_desc_).")); + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); } }; @@ -298,15 +272,9 @@ class FusedBatchNormActGradKernel cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnCreateTensorDescriptor(&data_desc_).")); + platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnCreateTensorDescriptor(&bn_param_desc_).")); + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { LOG(ERROR) << "Provided epsilon is smaller than " << "CUDNN_BN_MIN_EPSILON. 
Setting it to " @@ -314,17 +282,12 @@ class FusedBatchNormActGradKernel } epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, CudnnDataType::type, - x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()), - platform::errors::External( - "The error has happened when calling cudnnSetTensorNdDescriptor.")); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_, - data_desc_, mode_), - platform::errors::External("The error has happened when calling " - "cudnnDeriveBNTensorDescriptor.")); + data_desc_, mode_)); const auto *saved_mean = ctx.Input("SavedMean"); const auto *saved_var = ctx.Input("SavedVariance"); @@ -354,10 +317,7 @@ class FusedBatchNormActGradKernel /*dxDesc=*/data_desc_, /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size), - platform::errors::External( - "The error has happened when calling " - "cudnnGetBatchNormalizationBackwardExWorkspaceSize.")); + /*sizeInBytes=*/&workspace_size)); workspace_ptr = workspace_tensor.mutable_data(ctx.GetPlace(), x->type(), workspace_size); @@ -395,21 +355,13 @@ class FusedBatchNormActGradKernel /*workspace=*/workspace_ptr, /*workSpaceSizeInBytes=*/workspace_size, /*reserveSpace=*/const_cast(reserve_space->template data()), - /*reserveSpaceSizeInBytes=*/reserve_space_size), - platform::errors::External("The error has happened when calling " - "cudnnBatchNormalizationBackwardEx.")); + /*reserveSpaceSizeInBytes=*/reserve_space_size)); // clean when exit. 
PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(data_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnDestroyTensorDescriptor(data_desc_).")); + platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_), - platform::errors::External( - "The error has happened when calling " - "cudnnDestroyTensorDescriptor(bn_param_desc_).")); + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); } }; diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc index b61ef8e566b77..17cb4556d45ef 100644 --- a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc +++ b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc @@ -46,13 +46,9 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel { cudnnTensorDescriptor_t in_desc; cudnnTensorDescriptor_t out_desc; PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&in_desc), - platform::errors::External("Create cudnn tensor descriptor failed in " - "transpose_flatten_concat_fusion op.")); + platform::dynload::cudnnCreateTensorDescriptor(&in_desc)); PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&out_desc), - platform::errors::External("Create cudnn tensor descriptor failed in " - "transpose_flatten_concat_fusion op.")); + platform::dynload::cudnnCreateTensorDescriptor(&out_desc)); cudnnDataType_t cudnn_dtype = CudnnDataType::type; auto& dev_ctx = ctx.template device_context(); @@ -91,24 +87,15 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel { dims_y[i] = 1; } - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnSetTensorNdDescriptor( - in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()), - platform::errors::External("Create cudnn tensorNd descriptor failed " - "in transpose_flatten_concat op.")); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnSetTensorNdDescriptor( - out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()), - platform::errors::External("Create cudnn tensorNd descriptor failed " - "in transpose_flatten_concat op.")); - - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnTransformTensor( - handle, CudnnDataType::kOne(), in_desc, - static_cast(ins[k]->data()), - CudnnDataType::kZero(), out_desc, static_cast(odata)), - platform::errors::External("Create cudnn transform tensor failed in " - "transpose_flatten_concat op.")); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( + in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data())); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( + out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data())); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnTransformTensor( + handle, CudnnDataType::kOne(), in_desc, + static_cast(ins[k]->data()), + CudnnDataType::kZero(), out_desc, static_cast(odata))); if (concat_axis == 0) { odata += osize; } else { @@ -117,13 +104,9 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel { } } PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(in_desc), - platform::errors::External( - "Destory cudnn descriptor failed in transpose_flatten_concat op.")); + platform::dynload::cudnnDestroyTensorDescriptor(in_desc)); 
     PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cudnnDestroyTensorDescriptor(out_desc),
-        platform::errors::External(
-            "Destory cudnn descriptor failed in transpose_flatten_concat op."));
+        platform::dynload::cudnnDestroyTensorDescriptor(out_desc));
   }
 };
diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
index c266b0d32b14a..3bf34fc685ee8 100644
--- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
@@ -60,13 +60,10 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         DataLayout::kNCHW, framework::vectorize<int>(output->dims()));
 
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cudnnSpatialTfSamplerForward(
-            handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
-            input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
-            output_data),
-        platform::errors::InvalidArgument(
-            "cudnnSpatialTfSamplerForward in Op(grid_sampler) failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSpatialTfSamplerForward(
+        handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
+        input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
+        output_data));
   }
 };
 
@@ -122,9 +119,7 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
             input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
             input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
             output_grad_data, grid_data, CudnnDataType<T>::kZero(),
-            grid_grad_data),
-        platform::errors::InvalidArgument(
-            "cudnnSpatialTfSamplerBackward in Op(grid_sampler) failed"));
+            grid_grad_data));
   }
 };
 
diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h
index c0ab35b0e753c..8e903a4eccc74 100644
--- a/paddle/fluid/operators/math/blas_impl.cu.h
+++ b/paddle/fluid/operators/math/blas_impl.cu.h
@@ -41,16 +41,12 @@ struct CUBlas<float> {
 
   template <typename... ARGS>
   static void SCAL(ARGS... args) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cublasSscal(args...),
-        platform::errors::External("dynload cublasSscal lib failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSscal(args...));
   }
 
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cublasScopy(args...),
-        platform::errors::External("dynload cublasScopy lib failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasScopy(args...));
   }
 
   template <typename... ARGS>
@@ -108,16 +104,12 @@ struct CUBlas<double> {
 
   template <typename... ARGS>
   static void SCAL(ARGS... args) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cublasDscal(args...),
-        platform::errors::External("dynload cublasDscal lib failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDscal(args...));
  }
 
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cublasDcopy(args...),
-        platform::errors::External("dynload cublasDcopy lib failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDcopy(args...));
   }
 
   template <typename... ARGS>
diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu
index cdd138d7bdc99..7b7fa49f171c1 100644
--- a/paddle/fluid/operators/mean_op.cu
+++ b/paddle/fluid/operators/mean_op.cu
@@ -59,17 +59,14 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
     auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
                                       out_data, size_prob, stream);
-    PADDLE_ENFORCE_CUDA_SUCCESS(err,
-                                "MeanOP failed to get reduce workspace size",
-                                cudaGetErrorString(err));
+    PADDLE_ENFORCE_CUDA_SUCCESS(err);
     framework::Tensor tmp;
     auto* temp_storage = tmp.mutable_data<uint8_t>(
         framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
         context.GetPlace());
     err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
                                  out_data, size_prob, stream);
-    PADDLE_ENFORCE_CUDA_SUCCESS(err, "MeanOP failed to run reduce computation",
-                                cudaGetErrorString(err));
+    PADDLE_ENFORCE_CUDA_SUCCESS(err);
   }
 };
 
diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc
index b237df130abcc..e72820611d3a9 100644
--- a/paddle/fluid/operators/reader/buffered_reader.cc
+++ b/paddle/fluid/operators/reader/buffered_reader.cc
@@ -104,13 +104,9 @@ void BufferedReader::ReadAsync(size_t i) {
       // gpu memory immediately without waiting gpu kernel ends
       platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
       PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaEventRecord(events_[i].get(), compute_stream_),
-          platform::errors::Fatal(
-              "cudaEventRecord raises unexpected exception"));
+          cudaEventRecord(events_[i].get(), compute_stream_));
       PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0),
-          platform::errors::Fatal(
-              "cudaStreamWaitEvent raises unexpected exception"));
+          cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0));
 
       platform::RecordEvent record_event("BufferedReader:MemoryCopy");
       for (size_t i = 0; i < cpu.size(); ++i) {
@@ -138,17 +134,11 @@ void BufferedReader::ReadAsync(size_t i) {
                          size);
           memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
                        cuda_pinned_place, cuda_pinned_ptr, size, stream_.get());
-          PADDLE_ENFORCE_CUDA_SUCCESS(
-              cudaStreamSynchronize(stream_.get()),
-              platform::errors::Fatal(
-                  "cudaStreamSynchronize raises unexpected exception"));
+          PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
         }
         gpu[i].set_lod(cpu[i].lod());
       }
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaStreamSynchronize(stream_.get()),
-          platform::errors::Fatal(
-              "cudaStreamSynchronize raises unexpected exception"));
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
     }
 #endif
     return i;
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu.h b/paddle/fluid/operators/sync_batch_norm_op.cu.h
index 083d22aa2a38a..cfb9e16942c25 100644
--- a/paddle/fluid/operators/sync_batch_norm_op.cu.h
+++ b/paddle/fluid/operators/sync_batch_norm_op.cu.h
@@ -191,12 +191,9 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
   if (comm) {
     int dtype = platform::ToNCCLDataType(mean_out->type());
     // In-place operation
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::ncclAllReduce(stats, stats, 2 * C + 1,
-                                         static_cast<ncclDataType_t>(dtype),
-                                         ncclSum, comm, stream),
-        platform::errors::InvalidArgument(
-            "ncclAllReduce in Op(sync_batch_norm) failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
+        stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
+        comm, stream));
   }
 #endif
 
@@ -468,12 +465,9 @@ void SyncBatchNormGradFunctor(
   if (comm) {
     int dtype = platform::ToNCCLDataType(scale->type());
     // In-place operation
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::ncclAllReduce(stats, stats, 2 * C + 1,
-                                         static_cast<ncclDataType_t>(dtype),
-                                         ncclSum, comm, stream),
-        platform::errors::InvalidArgument(
-            "ncclAllReduce in Op(sync_batch_norm) failed"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
+        stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
+        comm, stream));
   }
 #endif
 
diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h
index 54f5e911e3d0f..74cf5545239f1 100644
--- a/paddle/fluid/platform/cuda_helper.h
+++ b/paddle/fluid/platform/cuda_helper.h
@@ -29,14 +29,7 @@ namespace platform {
 class CublasHandleHolder {
  public:
   CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::cublasCreate(&handle_),
-        platform::errors::External(
-            "The cuBLAS library was not initialized. This is usually caused by "
-            "an error in the CUDA Runtime API called by the cuBLAS routine, or "
-            "an error in the hardware setup.\n"
-            "To correct: check that the hardware, an appropriate version of "
-            "the driver, and the cuBLAS library are correctly installed."));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cublasCreate(&handle_));
     PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cublasSetStream(handle_, stream));
 #if CUDA_VERSION >= 9000
     if (math_type == CUBLAS_TENSOR_OP_MATH) {
diff --git a/paddle/fluid/platform/cuda_resource_pool.cc b/paddle/fluid/platform/cuda_resource_pool.cc
index 1828f0760a79a..65c8b96028ace 100644
--- a/paddle/fluid/platform/cuda_resource_pool.cc
+++ b/paddle/fluid/platform/cuda_resource_pool.cc
@@ -27,18 +27,13 @@ CudaStreamResourcePool::CudaStreamResourcePool() {
       platform::SetDeviceId(dev_idx);
       cudaStream_t stream;
       PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking),
-          platform::errors::Fatal(
-              "cudaStreamCreateWithFlags raises unexpected exception"));
+          cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
       return stream;
     };
 
     auto deleter = [dev_idx](cudaStream_t stream) {
       platform::SetDeviceId(dev_idx);
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaStreamDestroy(stream),
-          platform::errors::Fatal(
-              "cudaStreamDestroy raises unexpected exception"));
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream));
     };
 
     pool_.emplace_back(
@@ -72,18 +67,13 @@ CudaEventResourcePool::CudaEventResourcePool() {
       platform::SetDeviceId(dev_idx);
       cudaEvent_t event;
       PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaEventCreateWithFlags(&event, cudaEventDisableTiming),
-          platform::errors::Fatal(
-              "cudaEventCreateWithFlags raises unexpected exception"));
+          cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
       return event;
     };
 
     auto deleter = [dev_idx](cudaEvent_t event) {
       platform::SetDeviceId(dev_idx);
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          cudaEventDestroy(event),
-          platform::errors::Fatal(
-              "cudaEventDestroy raises unexpected exception"));
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event));
     };
 
     pool_.emplace_back(ResourcePool<CudaEventObject>::Create(creator, deleter));
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 322a32796787e..3bffa72bb927d 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -278,12 +278,9 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
                 << "Please recompile or reinstall Paddle with 
compatible CUDNN " "version."; } + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE_CUDA_SUCCESS( - dynload::cudnnCreate(&cudnn_handle_), - "Failed to create Cudnn handle in DeviceContext"); - PADDLE_ENFORCE_CUDA_SUCCESS( - dynload::cudnnSetStream(cudnn_handle_, stream_), - "Failed to set stream for Cudnn handle in DeviceContext"); + dynload::cudnnSetStream(cudnn_handle_, stream_)); } else { cudnn_handle_ = nullptr; } @@ -302,8 +299,7 @@ CUDADeviceContext::~CUDADeviceContext() { eigen_device_.reset(); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_)); if (cudnn_handle_) { - PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_), - "Failed to destory Cudnn handle"); + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_)); } #if defined(PADDLE_WITH_NCCL) if (nccl_comm_) { @@ -325,10 +321,7 @@ void CUDADeviceContext::Wait() const { } #endif - PADDLE_ENFORCE_CUDA_SUCCESS( - e_sync, platform::errors::Fatal( - "cudaStreamSynchronize raises error: %s, errono: %d", - cudaGetErrorString(e_sync), static_cast(e_sync))); + PADDLE_ENFORCE_CUDA_SUCCESS(e_sync); } int CUDADeviceContext::GetComputeCapability() const { diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 8aa42a4679df6..da5c207f767df 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -474,6 +474,9 @@ struct EOFException : public std::exception { /** CUDA PADDLE ENFORCE FUNCTIONS AND MACROS **/ #ifdef PADDLE_WITH_CUDA +/** cuda ERROR **/ +inline bool is_error(cudaError_t e) { return e != cudaSuccess; } + inline std::string GetCudaErrorWebsite(int32_t cuda_version) { std::ostringstream webstr; webstr << "https://docs.nvidia.com/cuda/"; @@ -486,7 +489,7 @@ inline std::string GetCudaErrorWebsite(int32_t cuda_version) { return webstr.str(); } -inline std::string GetCudaErrorMessage(cudaError_t e) { +inline std::string build_nvidia_error_msg(cudaError_t e) { #if CUDA_VERSION >= 10000 && CUDA_VERSION < 11000 int32_t cuda_version = 100; #elif CUDA_VERSION >= 9000 @@ -495,14 +498,14 @@ inline std::string GetCudaErrorMessage(cudaError_t e) { int32_t cuda_version = -1; #endif std::ostringstream sout; - sout << " CUDA runtime error(" << e << "), " << cudaGetErrorString(e) << "."; + sout << " Cuda error(" << e << "), " << cudaGetErrorString(e) << "."; static platform::proto::cudaerrorDesc cudaerror; static bool _initSucceed = false; if (cudaerror.ByteSizeLong() == 0) { std::string filePath; #if !defined(_WIN32) Dl_info info; - if (dladdr(reinterpret_cast(GetCudaErrorMessage), &info)) { + if (dladdr(reinterpret_cast(GetCudaErrorWebsite), &info)) { std::string strModule(info.dli_fname); const size_t last_slash_idx = strModule.find_last_of("/"); std::string compare_path = strModule.substr(strModule.length() - 6); @@ -521,7 +524,7 @@ inline std::string GetCudaErrorMessage(cudaError_t e) { char buf[100]; MEMORY_BASIC_INFORMATION mbi; HMODULE h_module = - (::VirtualQuery(GetCudaErrorMessage, &mbi, sizeof(mbi)) != 0) + (::VirtualQuery(GetCudaErrorWebsite, &mbi, sizeof(mbi)) != 0) ? 
(HMODULE)mbi.AllocationBase : NULL; GetModuleFileName(h_module, buf, 100); @@ -562,19 +565,6 @@ inline std::string GetCudaErrorMessage(cudaError_t e) { return sout.str(); } -inline bool is_error(cudaError_t e) { return e != cudaSuccess; } - -inline std::string build_ex_string(cudaError_t e, const std::string& msg) { - // note(zhouwei): the generated message when developer don't input error - // message, but it is not needed when CUDA ERROR; - // Better method is to refactor class ErrorSummary or - // PADDLE_ENFORCE_CUDA_SUCCESS. - if (msg.find("An error occurred here") != std::string::npos) { - return platform::errors::External(GetCudaErrorMessage(e)).ToString(); - } - return msg + GetCudaErrorMessage(e); -} - inline void throw_on_error(cudaError_t e, const std::string& msg) { #ifndef REPLACE_ENFORCE_GLOG throw std::runtime_error(msg); @@ -583,13 +573,47 @@ inline void throw_on_error(cudaError_t e, const std::string& msg) { #endif } +/** curand ERROR **/ inline bool is_error(curandStatus_t stat) { return stat != CURAND_STATUS_SUCCESS; } -inline std::string build_ex_string(curandStatus_t stat, - const std::string& msg) { - return msg; +inline const char* curandGetErrorString(curandStatus_t stat) { + switch (stat) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + default: + return "Unknown curand status"; + } +} + +inline std::string build_nvidia_error_msg(curandStatus_t stat) { + std::string msg(" Curand error, "); + return msg + curandGetErrorString(stat) + " "; } inline void throw_on_error(curandStatus_t stat, const std::string& msg) { @@ -601,13 +625,14 @@ inline void throw_on_error(curandStatus_t stat, const std::string& msg) { #endif } +/** cudnn ERROR **/ inline bool is_error(cudnnStatus_t stat) { return stat != CUDNN_STATUS_SUCCESS; } -inline std::string build_ex_string(cudnnStatus_t stat, const std::string& msg) { - return msg + "\n [Hint: " + platform::dynload::cudnnGetErrorString(stat) + - "]"; +inline std::string build_nvidia_error_msg(cudnnStatus_t stat) { + std::string msg(" Cudnn error, "); + return msg + platform::dynload::cudnnGetErrorString(stat) + " "; } inline void throw_on_error(cudnnStatus_t stat, const std::string& msg) { @@ -618,33 +643,39 @@ inline void throw_on_error(cudnnStatus_t stat, const std::string& msg) { #endif } +/** cublas ERROR **/ inline bool is_error(cublasStatus_t stat) { return stat != CUBLAS_STATUS_SUCCESS; } -inline std::string build_ex_string(cublasStatus_t stat, - const std::string& msg) { - std::string err; - if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { - 
err = "CUBLAS_STATUS_NOT_INITIALIZED"; - } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { - err = "CUBLAS_STATUS_ALLOC_FAILED"; - } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { - err = "CUBLAS_STATUS_INVALID_VALUE"; - } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { - err = "CUBLAS_STATUS_ARCH_MISMATCH"; - } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { - err = "CUBLAS_STATUS_MAPPING_ERROR"; - } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { - err = "CUBLAS_STATUS_EXECUTION_FAILED"; - } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { - err = "CUBLAS_STATUS_INTERNAL_ERROR"; - } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { - err = "CUBLAS_STATUS_NOT_SUPPORTED"; - } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { - err = "CUBLAS_STATUS_LICENSE_ERROR"; +inline const char* cublasGetErrorString(cublasStatus_t stat) { + switch (stat) { + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return "Unknown cublas status"; } - return msg + "\n [Hint: " + err + "]"; +} + +inline std::string build_nvidia_error_msg(cublasStatus_t stat) { + std::string msg(" Cublas error, "); + return msg + cublasGetErrorString(stat) + " "; } inline void throw_on_error(cublasStatus_t stat, const std::string& msg) { @@ -655,15 +686,15 @@ inline void throw_on_error(cublasStatus_t stat, const std::string& msg) { #endif } +/** nccl ERROR **/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { return nccl_result != ncclSuccess; } -inline std::string build_ex_string(ncclResult_t nccl_result, - const std::string& msg) { - return msg + "\n [" + platform::dynload::ncclGetErrorString(nccl_result) + - "]"; +inline std::string build_nvidia_error_msg(ncclResult_t nccl_result) { + std::string msg(" Nccl error, "); + return msg + platform::dynload::ncclGetErrorString(nccl_result) + " "; } inline void throw_on_error(ncclResult_t nccl_result, const std::string& msg) { @@ -673,11 +704,8 @@ inline void throw_on_error(ncclResult_t nccl_result, const std::string& msg) { LOG(FATAL) << msg; #endif } -#endif // __APPLE__ and windows +#endif // not(__APPLE__) and PADDLE_WITH_NCCL -#endif // PADDLE_WITH_CUDA - -#ifdef PADDLE_WITH_CUDA namespace details { template @@ -700,30 +728,28 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); #endif } // namespace details -#endif // PADDLE_WITH_CUDA -#ifdef PADDLE_WITH_CUDA -#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \ - do { \ - auto __cond__ = (COND); \ - using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ - constexpr auto __success_type__ = \ - ::paddle::platform::details::CudaStatusType< \ - __CUDA_STATUS_TYPE__>::kSuccess; \ - if (UNLIKELY(__cond__ != __success_type__)) { \ - try { \ - ::paddle::platform::throw_on_error( \ - __cond__, \ - ::paddle::platform::build_ex_string( \ - __cond__, \ - ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString())); \ - } catch (...) 
{ \ - HANDLE_THE_ERROR \ - throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ - __FILE__, __LINE__); \ - END_HANDLE_THE_ERROR \ - } \ - } \ +#define PADDLE_ENFORCE_CUDA_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::CudaStatusType< \ + __CUDA_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + try { \ + ::paddle::platform::throw_on_error( \ + __cond__, \ + ::paddle::platform::errors::External( \ + ::paddle::platform::build_nvidia_error_msg(__cond__)) \ + .ToString()); \ + } catch (...) { \ + HANDLE_THE_ERROR \ + throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ + __FILE__, __LINE__); \ + END_HANDLE_THE_ERROR \ + } \ + } \ } while (0) #undef DEFINE_CUDA_STATUS_TYPE diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 0057c784528c2..db77ba95856d9 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -261,15 +261,14 @@ TEST(EOF_EXCEPTION, THROW_EOF) { #ifdef PADDLE_WITH_CUDA template bool CheckCudaStatusSuccess(T value, const std::string& msg = "success") { - PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + PADDLE_ENFORCE_CUDA_SUCCESS(value); return true; } template -bool CheckCudaStatusFailure( - T value, const std::string& msg = "self-defined cuda status failed") { +bool CheckCudaStatusFailure(T value, const std::string& msg) { try { - PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + PADDLE_ENFORCE_CUDA_SUCCESS(value); return false; } catch (paddle::platform::EnforceNotMet& error) { std::string ex_msg = error.what(); @@ -279,24 +278,29 @@ bool CheckCudaStatusFailure( TEST(enforce, cuda_success) { EXPECT_TRUE(CheckCudaStatusSuccess(cudaSuccess)); - EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue)); - EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorMemoryAllocation)); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue, "Cuda error")); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorMemoryAllocation, "Cuda error")); EXPECT_TRUE(CheckCudaStatusSuccess(CURAND_STATUS_SUCCESS)); - EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_VERSION_MISMATCH)); - EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_NOT_INITIALIZED)); + EXPECT_TRUE( + CheckCudaStatusFailure(CURAND_STATUS_VERSION_MISMATCH, "Curand error")); + EXPECT_TRUE( + CheckCudaStatusFailure(CURAND_STATUS_NOT_INITIALIZED, "Curand error")); EXPECT_TRUE(CheckCudaStatusSuccess(CUDNN_STATUS_SUCCESS)); - EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_NOT_INITIALIZED)); - EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_ALLOC_FAILED)); + EXPECT_TRUE( + CheckCudaStatusFailure(CUDNN_STATUS_NOT_INITIALIZED, "Cudnn error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_ALLOC_FAILED, "Cudnn error")); EXPECT_TRUE(CheckCudaStatusSuccess(CUBLAS_STATUS_SUCCESS)); - EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_NOT_INITIALIZED)); - EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_INVALID_VALUE)); + EXPECT_TRUE( + CheckCudaStatusFailure(CUBLAS_STATUS_NOT_INITIALIZED, "Cublas error")); + EXPECT_TRUE( + CheckCudaStatusFailure(CUBLAS_STATUS_INVALID_VALUE, "Cublas error")); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); - EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError)); - EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Nccl error")); + 
  EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError, "Nccl error"));
 #endif
 }
 #endif
diff --git a/paddle/fluid/platform/errors.h b/paddle/fluid/platform/errors.h
index 09cbf4ce1a8ea..1b379e3bf6611 100644
--- a/paddle/fluid/platform/errors.h
+++ b/paddle/fluid/platform/errors.h
@@ -34,8 +34,8 @@ class ErrorSummary {
   // This constructor is only used to be compatible with
   // current existing no error message PADDLE_ENFORCE_*
   // Note(zhouwei): PADDLE_ENFORCE_CUDA_SUCCESS error message
-  // can be get from API or Nvidia official website, error
-  // message from developer is not necessary
+  // can be obtained automatically, so an error message from the
+  // developer is not necessary
   ErrorSummary() {
     code_ = paddle::platform::error::LEGACY;
     msg_ =
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index a03190c845fb5..c07abba9e8ef9 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -273,22 +273,12 @@ size_t GpuMaxChunkSize() {
 
 void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                     enum cudaMemcpyKind kind, cudaStream_t stream) {
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaMemcpyAsync(dst, src, count, kind, stream),
-      platform::errors::External(
-          "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
-          "(%p -> %p, length: %d). ",
-          src, dst, static_cast(count)));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
 }
 
 void GpuMemcpySync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind) {
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaMemcpy(dst, src, count, kind),
-      platform::errors::External(
-          "cudaMemcpy failed in paddle::platform::GpuMemcpySync "
-          "(%p -> %p, length: %d). ",
-          src, dst, static_cast(count)));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
 }
 
 void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
@@ -397,8 +387,7 @@ class RecordedCudaMallocHelper {
     CUDADeviceGuard guard(dev_id_);
     auto err = cudaFree(ptr);
     if (err != cudaErrorCudartUnloading) {
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          err, platform::errors::External("cudaFree raises unexpected error"));
+      PADDLE_ENFORCE_CUDA_SUCCESS(err);
       if (NeedRecord()) {
         std::lock_guard<std::mutex> guard(*mtx_);
         cur_size_ -= size;
diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h
index 41d5180ffaf5f..af27564b99f79 100644
--- a/paddle/fluid/platform/profiler_helper.h
+++ b/paddle/fluid/platform/profiler_helper.h
@@ -117,10 +117,7 @@ void SynchronizeAllDevice() {
   int count = GetCUDADeviceCount();
   for (int i = 0; i < count; i++) {
     SetDeviceId(i);
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        cudaDeviceSynchronize(),
-        platform::errors::External(
-            "Device synchronize failed in cudaDeviceSynchronize()"));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
   }
 #endif
 }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index ca3be55d4e5f5..1f523ef4a7dd4 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1239,6 +1239,22 @@ All parameter, weight, gradient are variables in Paddle.
.def("__init__", [](platform::CUDAPlace &self, int dev_id) { #ifdef PADDLE_WITH_CUDA + if (dev_id == 1) { + int count; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count)); + } else if (dev_id == 2) { + curandStatus_t a = CURAND_STATUS_OUT_OF_RANGE; + PADDLE_ENFORCE_CUDA_SUCCESS(a); + } else if (dev_id == 3) { + cudnnStatus_t a = CUDNN_STATUS_INTERNAL_ERROR; + PADDLE_ENFORCE_CUDA_SUCCESS(a); + } else if (dev_id == 4) { + cublasStatus_t a = CUBLAS_STATUS_LICENSE_ERROR; + PADDLE_ENFORCE_CUDA_SUCCESS(a); + } else if (dev_id == 5) { + ncclResult_t a = ncclSystemError; + PADDLE_ENFORCE_CUDA_SUCCESS(a); + } if (UNLIKELY(dev_id < 0)) { LOG(ERROR) << string::Sprintf( "Invalid CUDAPlace(%d), device id must be 0 or " diff --git a/tools/check_api_approvals.sh b/tools/check_api_approvals.sh index 51330bea8ea62..3e079d0433f87 100644 --- a/tools/check_api_approvals.sh +++ b/tools/check_api_approvals.sh @@ -172,8 +172,8 @@ if [ "${ALL_PADDLE_ENFORCE}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then check_approval 1 6836917 47554610 22561442 fi -ALL_PADDLE_CHECK=`git diff -U0 upstream/$BRANCH |grep "^+" |grep -zoE "(PADDLE_ENFORCE[A-Z_]*|PADDLE_THROW)\(.[^,\);]*.[^;]*\);\s" || true` -VALID_PADDLE_CHECK=`echo "$ALL_PADDLE_CHECK" | grep -zoE '(PADDLE_ENFORCE[A-Z_]*|PADDLE_THROW)\((.[^,;]+,)*.[^";]*(errors::).[^"]*".[^";]{20,}.[^;]*\);\s' || true` +ALL_PADDLE_CHECK=`git diff -U0 upstream/$BRANCH |grep "^+" |grep -zoE "(PADDLE_ENFORCE[A-Z_]{0,9}|PADDLE_THROW)\(.[^,\);]*.[^;]*\);\s" || true` +VALID_PADDLE_CHECK=`echo "$ALL_PADDLE_CHECK" | grep -zoE '(PADDLE_ENFORCE[A-Z_]{0,9}|PADDLE_THROW)\((.[^,;]+,)*.[^";]*(errors::).[^"]*".[^";]{20,}.[^;]*\);\s' || true` INVALID_PADDLE_CHECK=`echo "$ALL_PADDLE_CHECK" |grep -vxF "$VALID_PADDLE_CHECK" || true` if [ "${INVALID_PADDLE_CHECK}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then echo_line="The error message you wrote in PADDLE_ENFORCE{_**} or PADDLE_THROW does not meet our error message writing specification. Possible errors include 1. the error message is empty / 2. the error message is too short / 3. the error type is not specified. Please read the specification [ https://github.com/PaddlePaddle/Paddle/wiki/Paddle-Error-Message-Writing-Specification ], then refine the error message. If it is a mismatch, please specify chenwhql (Recommend), luotao1 or lanxianghit review and approve.\nThe PADDLE_ENFORCE{_**} or PADDLE_THROW entries that do not meet the specification are as follows:\n${INVALID_PADDLE_CHECK}\n" diff --git a/tools/count_invalid_enforce.sh b/tools/count_invalid_enforce.sh index 77d264370fdb0..46feadd56672f 100644 --- a/tools/count_invalid_enforce.sh +++ b/tools/count_invalid_enforce.sh @@ -1,7 +1,7 @@ #!/bin/bash -ALL_PADDLE_CHECK=`grep -r -zoE "(PADDLE_ENFORCE[A-Z_]*|PADDLE_THROW)\(.[^,\);]*.[^;]*\);\s" ../paddle/fluid || true` +ALL_PADDLE_CHECK=`grep -r -zoE "(PADDLE_ENFORCE[A-Z_]{0,9}|PADDLE_THROW)\(.[^,\);]*.[^;]*\);\s" ../paddle/fluid || true` ALL_PADDLE_CHECK_CNT=`echo "$ALL_PADDLE_CHECK" | grep -cE "(PADDLE_ENFORCE|PADDLE_THROW)" || true` -VALID_PADDLE_CHECK_CNT=`echo "$ALL_PADDLE_CHECK" | grep -zoE '(PADDLE_ENFORCE[A-Z_]*|PADDLE_THROW)\((.[^,;]+,)*.[^";]*(errors::).[^"]*".[^";]{20,}.[^;]*\);\s' | grep -cE "(PADDLE_ENFORCE|PADDLE_THROW)" || true` +VALID_PADDLE_CHECK_CNT=`echo "$ALL_PADDLE_CHECK" | grep -zoE '(PADDLE_ENFORCE[A-Z_]{0,9}|PADDLE_THROW)\((.[^,;]+,)*.[^";]*(errors::).[^"]*".[^";]{20,}.[^;]*\);\s' | grep -cE "(PADDLE_ENFORCE|PADDLE_THROW)" || true` echo "----------------------------" echo "PADDLE ENFORCE & THROW COUNT"
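
Taken together, the diff reduces PADDLE_ENFORCE_CUDA_SUCCESS to a single status argument: on failure the macro now builds the NVIDIA message itself via build_nvidia_error_msg() wrapped in platform::errors::External, so call sites no longer pass a hand-written string. A minimal sketch of a call site after this change is shown below; it is illustrative only, assuming a CUDA build of Paddle. The helper name CopyToDeviceAsync and its parameters are hypothetical and not taken from the diff; only the macro and the CUDA runtime calls are real.

    // Hypothetical call site sketch, assuming a CUDA build and that
    // "paddle/fluid/platform/enforce.h" is available on the include path.
    #include <cuda_runtime.h>
    #include "paddle/fluid/platform/enforce.h"

    void CopyToDeviceAsync(void* dst, const void* src, size_t n,
                           cudaStream_t stream) {
      // Old form (removed by this diff): a second, hand-written argument, e.g.
      //   PADDLE_ENFORCE_CUDA_SUCCESS(
      //       cudaMemcpyAsync(dst, src, n, cudaMemcpyHostToDevice, stream),
      //       platform::errors::External("cudaMemcpyAsync failed."));
      // New form: on failure the macro itself reports something like
      //   "Cuda error(77), an illegal memory access was encountered."
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaMemcpyAsync(dst, src, n, cudaMemcpyHostToDevice, stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
    }

The message prefixes the macro now produces for each status type ("Cuda error", "Curand error", "Cudnn error", "Cublas error", "Nccl error") are the ones exercised by the updated CheckCudaStatusFailure tests in enforce_test.cc above.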