
Commit

Merge pull request #253 from antodo/develop
Remove thrust for cross entropy
RParedesPalacios authored Feb 11, 2021
2 parents 83b6c3f + 673923d commit 4f010ca
Showing 3 changed files with 15 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ src/serialization/onnx/onnx.pb.cc
 *.DS_Store
 .idea
 .vscode
+*~

 # Build
 /[Bb]uild*
16 changes: 10 additions & 6 deletions src/hardware/gpu/nn/gpu_losses.cu
@@ -13,7 +13,7 @@
 #include <cublas_v2.h>

 //#include <thrust/transform.h>
-#include <thrust/reduce.h>
+/* #include <thrust/reduce.h>
 #include <thrust/functional.h>
 #include <thrust/extrema.h>
 #include <thrust/device_vector.h>
@@ -23,7 +23,7 @@
 #include <thrust/generate.h>
 #include <thrust/sort.h>
 #include <thrust/sequence.h>
-#include <thrust/copy.h>
+#include <thrust/copy.h> */

 #include "eddl/hardware/gpu/nn/gpu_tensor_nn.h"
 #include "eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h"
@@ -60,15 +60,17 @@ float gpu_categorical_cross_entropy(Tensor* y_true, Tensor* y_pred){

 float *sum_array;
 check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array");
+check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset");
 check_cuda(cudaDeviceSynchronize(), "create");

 // Calculate derivative of Softmax
 gpu_categorical_cross_entropy<<<numBlocks, blockSize>>>(y_true->ptr, y_pred->ptr, sum_array, n_batches, n_features);
 check_cuda(cudaDeviceSynchronize(),"gpu_categorical_cross_entropy");

 // Reduce sum and compute mean
-thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
-float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches);
+// thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
+float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches);
+check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy");
 float mean_ce = -sum_ce;//(float)n_batches; // Mean

 // Delete tmp array
@@ -102,15 +104,17 @@ float gpu_binary_cross_entropy(Tensor* y_true, Tensor* y_pred){

 float *sum_array;
 check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array");
+check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset");
 check_cuda(cudaDeviceSynchronize(), "create");

 // Calculate derivative of Softmax
 gpu_binary_cross_entropy<<<numBlocks, blockSize>>>(y_true->ptr, y_pred->ptr, sum_array, y_true->size);
 check_cuda(cudaDeviceSynchronize(),"gpu_binary_cross_entropy");

 // Reduce sum and compute mean
-thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
-float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches);
+// thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
+float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches);
+check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy");
 float mean_ce = -sum_ce;//(float)n_batches; // Mean

 // Delete tmp array
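Note on the pattern: the change replaces a thrust::reduce over an n_batches-long array of per-thread partial sums with a single device-side accumulator. sum_array[0] is zeroed with cudaMemset, the kernel atomicAdds every term into it, and one float is copied back with cudaMemcpy. A minimal self-contained sketch of this pattern follows; the names (sum_kernel, d_in, n) are illustrative, not taken from the EDDL sources.

// Minimal sketch of the atomicAdd-based reduction this commit switches to.
// Hypothetical names; not the EDDL code itself.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void sum_kernel(const float *in, float *sum, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(sum, in[i]);  // every thread adds into one slot
}

int main() {
    const int n = 1024;
    float h_in[n], h_sum = 0.0f;
    for (int i = 0; i < n; i++) h_in[i] = 1.0f;   // expected sum: 1024

    float *d_in, *d_sum;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_sum, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_sum, 0, sizeof(float));          // zero the accumulator, as the commit does

    sum_kernel<<<(n + 255) / 256, 256>>>(d_in, d_sum, n);
    cudaDeviceSynchronize();

    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);  // one float back
    printf("sum = %f\n", h_sum);

    cudaFree(d_in); cudaFree(d_sum);
    return 0;
}

Worth noting from the diff itself: the temporary buffer is still allocated with n_batches*sizeof(float), but after this change only its first element is ever zeroed, written, or read.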
6 changes: 4 additions & 2 deletions src/hardware/gpu/nn/gpu_losses_kernels.cu
@@ -50,7 +50,8 @@ __global__ void gpu_categorical_cross_entropy(float* y_true, float* y_pred, floa
 }

 // Store partial sums (later will be reduced)
-sum_array[thread_id_x] = bi_sum;
+// sum_array[thread_id_x] = bi_sum;
+atomicAdd(sum_array, bi_sum);
 }
 }

@@ -70,7 +71,8 @@ __global__ void gpu_binary_cross_entropy(float* y_true, float* y_pred, float* su
 float eps =10e-8;

 // Store sums (later will be reduced)
-sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps);
+// sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps);
+atomicAdd(sum_array, y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps));
 }

 }
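Funnelling every thread through atomicAdd on a single float serializes colliding updates and makes the result order-dependent (float addition is not associative, so sums can vary slightly between runs). A common refinement, sketched below as a hypothetical alternative and not something this commit does, is to reduce each block in shared memory first and issue one atomic per block; it assumes a power-of-two block size.

// Sketch of a block-level reduction: one atomicAdd per block instead of one
// per thread. Hypothetical refinement, not part of this commit.
__global__ void sum_kernel_block(const float *in, float *sum, int n) {
    extern __shared__ float cache[];              // sized at launch: blockDim.x floats
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    cache[threadIdx.x] = (i < n) ? in[i] : 0.0f;  // out-of-range threads contribute 0
    __syncthreads();

    // Tree reduction within the block (blockDim.x must be a power of two)
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
        if (threadIdx.x < stride) cache[threadIdx.x] += cache[threadIdx.x + stride];
        __syncthreads();
    }

    if (threadIdx.x == 0) atomicAdd(sum, cache[0]);  // one atomic per block
}

A launch passes the shared-memory size explicitly, e.g. sum_kernel_block<<<(n + 255) / 256, 256, 256 * sizeof(float)>>>(d_in, d_sum, n);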
