-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/is nan #7068
Feature/is nan #7068
Changes from all commits
fd2bf55
a5e1cf5
8b877dd
42062c3
516967e
b711870
15309fd
4518252
e54bb6c
3d282ec
a5291f9
837da79
16a8432
71157b3
e2be6dd
003917d
878d2e9
a9a44e0
3158b4b
f97205e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/tensor_util.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
// Data-type visitor: views `tensor_` as a flat Eigen vector of T, applies | ||
// `predicate_` to it element-wise and writes "did any element match?" into | ||
// the boolean scalar tensor `out_`, evaluating on the Eigen device of `ctx_`. | ||
template <typename Predicate, typename DevCtx> | ||
struct AnyDTypeVisitor { | ||
  Predicate predicate_; | ||
  const Tensor& tensor_; | ||
  const DevCtx& ctx_; | ||
  Tensor* out_; | ||
|
||
  AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, | ||
                  Tensor* out) | ||
      : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} | ||
|
||
  // T is the element type the input tensor is flattened as. | ||
  template <typename T> | ||
  void operator()() const { | ||
    auto flat_input = EigenVector<T>::Flatten(tensor_); | ||
    auto result = EigenScalar<bool>::From(*out_); | ||
    // Reduce the element-wise predicate: true if it holds for any element. | ||
    result.device(*ctx_.eigen_device()) = predicate_(flat_input).any(); | ||
  } | ||
}; | ||
|
||
// Builds an AnyDTypeVisitor for the given arguments and dispatches it on the | ||
// data type of `tensor`; the boolean result is stored into `out`. | ||
template <typename Predicate, typename DevCtx> | ||
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, | ||
                    const DevCtx& ctx, framework::Tensor* out) { | ||
  AnyDTypeVisitor<Predicate, DevCtx> dtype_visitor(predicate, tensor, ctx, out); | ||
  VisitDataType(ToDataType(tensor.type()), dtype_visitor); | ||
} | ||
|
||
// Place visitor: runs the predicate reduction for `tensor_` with the device | ||
// context of the visited place and returns the resulting flag as a host bool. | ||
template <typename Predicate> | ||
struct AnyVisitor : public boost::static_visitor<bool> { | ||
  const framework::Tensor& tensor_; | ||
  Predicate predicate_; | ||
|
||
  AnyVisitor(const framework::Tensor& tensor, Predicate predicate) | ||
      : tensor_(tensor), predicate_(std::move(predicate)) {} | ||
|
||
  // CPU result: the one-element bool tensor can be read directly. | ||
  bool GetResult(const framework::Tensor& out, | ||
                 const platform::CPUPlace& cpu) const { | ||
    const bool* flag = out.data<bool>(); | ||
    return *flag; | ||
  } | ||
|
||
  // CUDA result: copy the one-element bool tensor into a CPU tensor, waiting | ||
  // on the device context before and after the copy, then read it there. | ||
  bool GetResult(const framework::Tensor& out, | ||
                 const platform::CUDAPlace& gpu) const { | ||
    platform::CPUPlace host_place; | ||
    framework::Tensor host_flag; | ||
    host_flag.Resize({1}); | ||
    host_flag.mutable_data<bool>(host_place); | ||
    auto dev_ctx = platform::DeviceContextPool::Instance().Get(gpu); | ||
    dev_ctx->Wait(); | ||
    CopyFrom(out, host_place, *dev_ctx, &host_flag); | ||
    dev_ctx->Wait(); | ||
    return GetResult(host_flag, host_place); | ||
  } | ||
|
||
  // Allocates a one-element bool tensor on `place`, fills it via AnyImpl and | ||
  // hands it to the GetResult overload matching the place type. | ||
  template <typename Place> | ||
  bool operator()(const Place& place) const { | ||
    framework::Tensor flag; | ||
    flag.Resize({1}); | ||
    flag.mutable_data<bool>(place); | ||
    auto* dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place); | ||
    AnyImpl(predicate_, tensor_, *dev_ctx, &flag); | ||
    return this->GetResult(flag, place); | ||
  } | ||
}; | ||
|
||
// Returns the result of visiting the tensor's place with an AnyVisitor built | ||
// from `tensor` and `predicate`. | ||
template <typename Predicate> | ||
inline bool Any(const framework::Tensor& tensor, Predicate predicate) { | ||
  AnyVisitor<Predicate> place_visitor(tensor, predicate); | ||
  return platform::VisitPlace(tensor.place(), place_visitor); | ||
} | ||
|
||
// Predicate passed to Any() by HasNAN: element-wise NaN test. | ||
struct HasNANPredicate { | ||
template <typename T> | ||
auto operator()(const T& eigen_vec) const | ||
-> decltype(std::declval<T>().isnan()) { | ||
// Cast eigen_vector to vector of bool. true if is NaN. | ||
return eigen_vec.isnan(); | ||
} | ||
}; | ||
|
||
// Returns true if `tensor` contains at least one NaN element. | ||
bool HasNAN(const framework::Tensor& tensor) { | ||
  return Any(tensor, HasNANPredicate()); | ||
} | ||
|
||
// Predicate passed to Any() by HasInf: element-wise infinity test. | ||
struct HasInfPredicate { | ||
template <typename T> | ||
auto operator()(const T& eigen_vec) const | ||
-> decltype(std::declval<T>().isinf()) { | ||
// Cast eigen_vector to vector of bool. true if is inf. | ||
return eigen_vec.isinf(); | ||
} | ||
}; | ||
|
||
// Returns true if `tensor` contains at least one infinite element. | ||
bool HasInf(const framework::Tensor& tensor) { | ||
  return Any(tensor, HasInfPredicate()); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
./tensor_util.cc | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,10 @@ limitations under the License. */ | |
|
||
#pragma once | ||
#include "paddle/framework/data_type.h" | ||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/framework.pb.h" | ||
#include "paddle/framework/tensor.h" | ||
#include "paddle/platform/device_context.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
@@ -207,6 +209,12 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) { | |
src_ptr, size); | ||
} | ||
|
||
// Returns true if a tensor contains NAN, i.e., Not A Number. | ||
extern bool HasNAN(const framework::Tensor& tensor); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// Returns true if a tensor contains Inf, i.e., Infinity. | ||
extern bool HasInf(const framework::Tensor& tensor); | ||
|
||
inline void SerializeToStream(std::ostream& os, const Tensor& tensor, | ||
const platform::DeviceContext& dev_ctx) { | ||
// TODO(typhoonzero): serialize to ostream | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
|
||
#include "paddle/framework/tensor_util.h" | ||
#include <gtest/gtest.h> | ||
#include <cmath> | ||
#include <string> | ||
|
||
namespace paddle { | ||
|
@@ -230,6 +231,29 @@ TEST(CopyToVector, Tensor) { | |
#endif | ||
} | ||
|
||
// HasNAN must report false for an all-finite CPU tensor and true once a NaN | ||
// is written into it. The suite is named after the function under test. | ||
TEST(HasNAN, CPU) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add GPU unit tests at the same time. |
||
  using namespace paddle::framework; | ||
  using namespace paddle::platform; | ||
  Tensor src; | ||
  float* buf = src.mutable_data<float>({3}, CPUPlace()); | ||
  buf[0] = 0.0; | ||
  buf[1] = 0.0; | ||
  buf[2] = 0.0; | ||
  // Negative case first: guards against a check that always returns true. | ||
  ASSERT_FALSE(HasNAN(src)); | ||
|
||
  buf[1] = NAN; | ||
  ASSERT_TRUE(HasNAN(src)); | ||
} | ||
|
||
// HasInf must report false for an all-finite CPU tensor and true once an | ||
// infinity is written into it. The suite is named after the function under | ||
// test. | ||
TEST(HasInf, CPU) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IsInf --> HasInf |
||
  using namespace paddle::framework; | ||
  using namespace paddle::platform; | ||
  Tensor src; | ||
  double* buf = src.mutable_data<double>({3}, CPUPlace()); | ||
  buf[0] = 1.0; | ||
  buf[1] = 0.5; | ||
  buf[2] = 0.0; | ||
  // Negative case first: guards against a check that always returns true. | ||
  ASSERT_FALSE(HasInf(src)); | ||
|
||
  buf[1] = INFINITY; | ||
  ASSERT_TRUE(HasInf(src)); | ||
} | ||
|
||
TEST(Tensor, SerializeAndDeserialize) { | ||
framework::Tensor src_tensor; | ||
int array[6] = {1, 2, 3, 4, 5, 6}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "gtest/gtest.h" | ||
#include "paddle/framework/tensor_util.h" | ||
#include "paddle/platform/device_context.h" | ||
#include "paddle/platform/place.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// Kernel that fills buf[0..2] with two finite values followed by a NaN. | ||
static __global__ void FillNAN(float* buf) { | ||
  buf[2] = NAN; | ||
  buf[1] = 0.1; | ||
  buf[0] = 0.0; | ||
} | ||
// Kernel that fills buf[0..2] with finite values and an infinity at index 1. | ||
static __global__ void FillInf(float* buf) { | ||
  buf[2] = 0.5; | ||
  buf[1] = INFINITY; | ||
  buf[0] = 0.0; | ||
} | ||
|
||
// Fills a 3-element float tensor on GPU 0 with a NaN via FillNAN and checks | ||
// that HasNAN detects it. | ||
TEST(HasNAN, GPU) { | ||
  Tensor tensor; | ||
  platform::CUDAPlace gpu(0); | ||
  auto* cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu); | ||
  float* dev_buf = tensor.mutable_data<float>({3}, gpu); | ||
  FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(dev_buf); | ||
  // Wait on the context the kernel was launched on before inspecting. | ||
  cuda_ctx->Wait(); | ||
  ASSERT_TRUE(HasNAN(tensor)); | ||
} | ||
|
||
// Fills a 3-element float tensor on GPU 0 with an infinity via FillInf and | ||
// checks that HasInf detects it. | ||
TEST(HasInf, GPU) { | ||
  Tensor tensor; | ||
  platform::CUDAPlace gpu(0); | ||
  auto* cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu); | ||
  float* dev_buf = tensor.mutable_data<float>({3}, gpu); | ||
  FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(dev_buf); | ||
  // Wait on the context the kernel was launched on before inspecting. | ||
  cuda_ctx->Wait(); | ||
  ASSERT_TRUE(HasInf(tensor)); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext { | |
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; | ||
}; | ||
|
||
// Maps a place type to the DeviceContext subclass used for it by default. | ||
// Only specializations are defined; DeviceContextPool::GetByPlace uses this | ||
// mapping to choose its return type. | ||
template <typename Place> | ||
struct DefaultDeviceContextType; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is used for In the future, we could add a library type to this method. |
||
|
||
// CPUPlace is served by CPUDeviceContext by default. | ||
template <> | ||
struct DefaultDeviceContextType<platform::CPUPlace> { | ||
using TYPE = CPUDeviceContext; | ||
}; | ||
|
||
#ifdef PADDLE_WITH_CUDA | ||
|
||
class EigenCudaStreamDevice; | ||
|
@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext { | |
cublasHandle_t cublas_handle_; | ||
}; | ||
|
||
// CUDAPlace is served by CUDADeviceContext by default. | ||
template <> | ||
struct DefaultDeviceContextType<platform::CUDAPlace> { | ||
using TYPE = CUDADeviceContext; | ||
}; | ||
|
||
class CUDNNDeviceContext : public CUDADeviceContext { | ||
public: | ||
explicit CUDNNDeviceContext(CUDAPlace place); | ||
|
@@ -125,6 +138,13 @@ class DeviceContextPool { | |
/*! \brief Return handle of single device context. */ | ||
const platform::DeviceContext* Get(const platform::Place& place); | ||
|
||
// Returns the device context registered for `place`, typed as the default | ||
// context class for that place type (DefaultDeviceContextType<Place>::TYPE). | ||
template <typename Place> | ||
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace( | ||
    const Place& place) { | ||
  using ContextType = typename DefaultDeviceContextType<Place>::TYPE; | ||
  // Get() returns a pointer to the DeviceContext base class. static_cast is | ||
  // the proper cast for a base-to-derived conversion: it applies any needed | ||
  // pointer adjustment and fails to compile for unrelated types, which | ||
  // reinterpret_cast does not. | ||
  return static_cast<const ContextType*>(Get(place)); | ||
} | ||
|
||
private: | ||
static DeviceContextPool* pool; | ||
constexpr static int LEFT_SHIFT = 8; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not move the code in tensor_util.cc to tensor_util.h? Or we could have a tensor_util_impl.h. It's a little strange to have a symbolic link.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@QiJune Reyoung probably had an offline discussion with you already. But it looks like a `.cu` file is necessary to make nvcc pass the compilation...