Remove thrust for cross entropy #253

Merged · 36 commits · Feb 11, 2021
78374b4
prototype code for forward
antodo Jan 22, 2021
f7bef05
prototype code for forward
antodo Jan 25, 2021
605dce6
prototype code for forward
antodo Jan 25, 2021
d394b0d
block code for forward
antodo Jan 27, 2021
771ec62
wip
antodo Feb 1, 2021
132b921
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 2, 2021
1f352c2
first GPU version of the backward
antodo Feb 4, 2021
0a6aecb
second backward version
antodo Feb 4, 2021
3c17360
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 4, 2021
1d612fd
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 4, 2021
5c58d6c
final version of the backward
antodo Feb 4, 2021
a3e6917
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 4, 2021
5a58384
Merge branch 'develop' into develop
salvacarrion Feb 4, 2021
e39a3dd
New BatchNorm routines in tensorNN
antodo Feb 4, 2021
78cf0d7
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 4, 2021
c0e3d79
fix conditional compilation
antodo Feb 4, 2021
d43c00b
remove blank line changes
antodo Feb 4, 2021
b619f26
Merge branch 'develop' into develop
antodo Feb 4, 2021
f903b91
Merge branch 'develop' into develop
antodo Feb 4, 2021
b9edb0f
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 4, 2021
9147ba6
Fix gpu_batchnorm_forward
antodo Feb 8, 2021
f31c560
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 8, 2021
8930813
Merge branch 'develop' into develop
antodo Feb 8, 2021
f3ef2be
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 9, 2021
f8ad7ef
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 10, 2021
06fafc6
unit test for batchnorm
antodo Feb 10, 2021
69f351c
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 10, 2021
1e798a0
Merge branch 'develop' into develop
antodo Feb 10, 2021
c2b0c11
remove conflict marker
antodo Feb 10, 2021
8e16005
Move batchnorm code to NN
antodo Feb 10, 2021
5b2e8c3
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 10, 2021
116c955
remove thrust from cross entropy
antodo Feb 11, 2021
2022915
Merge remote-tracking branch 'upstream/develop' into develop
antodo Feb 11, 2021
58df03c
Merge branch 'develop' into develop
antodo Feb 11, 2021
840b842
fix some errors
antodo Feb 11, 2021
673923d
Merge branch 'develop' of https://github.com/antodo/eddl into develop
antodo Feb 11, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ src/serialization/onnx/onnx.pb.cc
*.DS_Store
.idea
.vscode
*~

# Build
/[Bb]uild*
16 changes: 10 additions & 6 deletions src/hardware/gpu/nn/gpu_losses.cu
@@ -13,7 +13,7 @@
#include <cublas_v2.h>

//#include <thrust/transform.h>
#include <thrust/reduce.h>
/* #include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
@@ -23,7 +23,7 @@
#include <thrust/generate.h>
#include <thrust/sort.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>
#include <thrust/copy.h> */

#include "eddl/hardware/gpu/nn/gpu_tensor_nn.h"
#include "eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h"
@@ -60,15 +60,17 @@ float gpu_categorical_cross_entropy(Tensor* y_true, Tensor* y_pred){

float *sum_array;
check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array");
check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset");
check_cuda(cudaDeviceSynchronize(), "create");

// Calculate derivative of Softmax
gpu_categorical_cross_entropy<<<numBlocks, blockSize>>>(y_true->ptr, y_pred->ptr, sum_array, n_batches, n_features);
check_cuda(cudaDeviceSynchronize(),"gpu_categorical_cross_entropy");

// Reduce sum and compute mean
thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches);
// thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches);
check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy");
float mean_ce = -sum_ce;//(float)n_batches; // Mean

// Delete tmp array
@@ -102,15 +104,17 @@ float gpu_binary_cross_entropy(Tensor* y_true, Tensor* y_pred){

float *sum_array;
check_cuda(cudaMalloc((void**)&(sum_array), n_batches*sizeof(float)),"create temp array");
check_cuda(cudaMemset(sum_array, 0, sizeof(float)), "memset");
check_cuda(cudaDeviceSynchronize(), "create");

// Calculate derivative of Softmax
gpu_binary_cross_entropy<<<numBlocks, blockSize>>>(y_true->ptr, y_pred->ptr, sum_array, y_true->size);
check_cuda(cudaDeviceSynchronize(),"gpu_binary_cross_entropy");

// Reduce sum and compute mean
thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
float sum_ce = thrust::reduce(dev_ptr, dev_ptr + n_batches);
// thrust::device_ptr<float> dev_ptr = thrust::device_pointer_cast(sum_array);
float sum_ce; // = thrust::reduce(dev_ptr, dev_ptr + n_batches);
check_cuda(cudaMemcpy(&sum_ce, sum_array, sizeof(float), cudaMemcpyDeviceToHost), "memcpy");
float mean_ce = -sum_ce;//(float)n_batches; // Mean

// Delete tmp array
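
Taken together, the hunks above replace the thrust::reduce call with a plain device-side accumulation: the host zeroes a single-float accumulator with cudaMemset (only the first element of sum_array is used now), the kernel folds its partial sums into it with atomicAdd (see the kernel diff below), and the result comes back through a one-float cudaMemcpy, which also synchronizes with the preceding kernel launch. The following is a minimal, self-contained sketch of that pattern, assuming an illustrative kernel named sum_kernel rather than the EDDL routines:

// reduce_sketch.cu -- illustrative names only, not the EDDL API
#include <cstdio>
#include <cuda_runtime.h>

// Every thread adds one element into a single global accumulator,
// the same pattern the patch uses in place of thrust::reduce.
__global__ void sum_kernel(const float* data, float* acc, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(acc, data[i]);   // float atomicAdd needs compute capability >= 2.0
}

int main() {
    const int n = 1 << 20;
    float* h = new float[n];
    for (int i = 0; i < n; ++i) h[i] = 1.0f;

    float *d_data, *d_acc;
    cudaMalloc(&d_data, n * sizeof(float));
    cudaMalloc(&d_acc, sizeof(float));
    cudaMemcpy(d_data, h, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_acc, 0, sizeof(float));        // zero the accumulator, as the patch does

    int block = 256;
    int grid = (n + block - 1) / block;
    sum_kernel<<<grid, block>>>(d_data, d_acc, n);

    float sum = 0.0f;
    cudaMemcpy(&sum, d_acc, sizeof(float), cudaMemcpyDeviceToHost);  // copies and synchronizes
    printf("sum = %f (expected %d)\n", sum, n);

    cudaFree(d_data); cudaFree(d_acc);
    delete[] h;
    return 0;
}

Because every thread targets the same address, this trades the thrust dependency for atomic contention, and the floating-point summation order becomes non-deterministic; the sketch after the next diff shows the usual kernel-side refinement.
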
6 changes: 4 additions & 2 deletions src/hardware/gpu/nn/gpu_losses_kernels.cu
@@ -50,7 +50,8 @@ __global__ void gpu_categorical_cross_entropy(float* y_true, float* y_pred, floa
}

// Store partial sums (later will be reduced)
sum_array[thread_id_x] = bi_sum;
// sum_array[thread_id_x] = bi_sum;
atomicAdd(sum_array, bi_sum);
}
}

@@ -70,7 +71,8 @@ __global__ void gpu_binary_cross_entropy(float* y_true, float* y_pred, float* su
float eps =10e-8;

// Store sums (later will be reduced)
sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps);
// sum_array[thread_id_x] = y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps);
atomicAdd(sum_array, y_true[thread_id_x] * logf(y_pred[thread_id_x]+eps) + (1.0-y_true[thread_id_x]) * logf(1.0f-y_pred[thread_id_x]+eps));
}

}
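
In the kernels, each thread used to write its partial sum into its own slot of sum_array and the reduction was left to thrust; after the change every thread adds its contribution directly into sum_array[0] with atomicAdd. One atomic per element is simple, but it serializes all threads on a single accumulator. The sketch below is a hypothetical refinement of the same idea, not the EDDL implementation: reduce within each block in shared memory first, then issue a single atomicAdd per block. The kernel name bce_partial_sum and the fixed block size of 256 are assumptions for illustration.

// Hypothetical refinement: block-level shared-memory reduction, then one atomicAdd per block.
__global__ void bce_partial_sum(const float* y_true, const float* y_pred,
                                float* acc, long size) {
    // Assumes blockDim.x == 256 (a power of two) so the tree reduction below is exact.
    __shared__ float cache[256];
    long i = (long)blockIdx.x * blockDim.x + threadIdx.x;

    const float eps = 1e-8f;
    float v = 0.0f;
    if (i < size) {
        v = y_true[i] * logf(y_pred[i] + eps)
          + (1.0f - y_true[i]) * logf(1.0f - y_pred[i] + eps);
    }
    cache[threadIdx.x] = v;
    __syncthreads();

    // Tree reduction within the block.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) cache[threadIdx.x] += cache[threadIdx.x + stride];
        __syncthreads();
    }

    // One atomic per block instead of one per thread.
    if (threadIdx.x == 0) atomicAdd(acc, cache[0]);
}

With one atomicAdd per block instead of one per element, contention on the accumulator drops by a factor of blockDim.x, while the host-side allocate/memset/memcpy flow shown earlier stays unchanged.
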