Skip to content

Commit

Permalink
add PNCalculator and PNTypeCalculator (PaddlePaddle#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
WorgenZhang authored Feb 10, 2023
1 parent eabc67d commit 9d34771
Show file tree
Hide file tree
Showing 6 changed files with 577 additions and 17 deletions.
278 changes: 274 additions & 4 deletions paddle/fluid/framework/fleet/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void BasicAucCalculator::calculate_bucket_error() {
#endif
}

void BasicAucCalculator::reset_records() {
void BasicAucCalculator::reset_wuauc_records() {
// reset wuauc_records_
wuauc_records_.clear();
_user_cnt = 0;
Expand All @@ -249,7 +249,7 @@ void BasicAucCalculator::reset_records() {
}

// add uid data
void BasicAucCalculator::add_uid_data(const float* d_pred,
void BasicAucCalculator::add_uid_data_wuauc(const float* d_pred,
const int64_t* d_label,
const int64_t* d_uid, int batch_size,
const paddle::platform::Place& place) {
Expand All @@ -266,11 +266,11 @@ void BasicAucCalculator::add_uid_data(const float* d_pred,

std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_uid_unlock_data(h_pred[i], h_label[i], static_cast<uint64_t>(h_uid[i]));
add_uid_unlock_data_wuauc(h_pred[i], h_label[i], static_cast<uint64_t>(h_uid[i]));
}
}

void BasicAucCalculator::add_uid_unlock_data(double pred, int label,
void BasicAucCalculator::add_uid_unlock_data_wuauc(double pred, int label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
Expand Down Expand Up @@ -375,6 +375,276 @@ BasicAucCalculator::WuaucRocData BasicAucCalculator::computeSingelUserAuc(
return {tp, fp, auc};
}

// Drops every buffered PN record and zeroes the PN statistics derived from
// them (_final_pn, _count, _positive_num, _negtive_num) so a new evaluation
// pass starts from an empty state. No lock is taken here.
void BasicAucCalculator::reset_pn_records() {
  // Derived scalars first, then the raw per-sample records.
  _negtive_num = 0;
  _positive_num = 0;
  _count = 0;
  _final_pn = 0;
  pn_records_.clear();
}

// Copies one batch of (pred, label, uid) samples from device memory to the
// host and appends them to the PN record table.
//
// d_pred, d_label, d_uid: source buffers with `batch_size` elements each. They
//   are read with cudaMemcpyDeviceToHost, so they are expected to be device
//   pointers.
// batch_size: number of samples in the batch; a non-positive value is a no-op.
// place: not used by this function.
//
// Raises PreconditionNotMet if a device-to-host copy fails. The staging
// buffers are thread_local and reused between calls, so an unchecked failed
// copy would otherwise feed the previous batch's data into the metric again.
void BasicAucCalculator::add_uid_data_pn(const float* d_pred,
                                         const float* d_label,
                                         const int64_t* d_uid, int batch_size,
                                         const paddle::platform::Place& place) {
  // A negative size would be converted to a huge size_t by resize() below.
  if (batch_size <= 0) {
    return;
  }
  const size_t num = static_cast<size_t>(batch_size);

  thread_local std::vector<float> h_pred;
  thread_local std::vector<float> h_label;
  thread_local std::vector<uint64_t> h_uid;
  h_pred.resize(num);
  h_label.resize(num);
  h_uid.resize(num);

  // Synchronous device-to-host copy that fails loudly instead of leaving
  // stale data in the destination buffer.
  const auto copy_to_host = [](void* dst, const void* src, size_t bytes) {
    const cudaError_t status =
        cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
    PADDLE_ENFORCE_EQ(static_cast<int>(status), static_cast<int>(cudaSuccess),
                      platform::errors::PreconditionNotMet(
                          "cudaMemcpy from device to host failed: %s",
                          cudaGetErrorString(status)));
  };
  copy_to_host(h_pred.data(), d_pred, sizeof(float) * num);
  copy_to_host(h_label.data(), d_label, sizeof(float) * num);
  // The 64-bit uids are copied bit-for-bit into an unsigned host buffer.
  copy_to_host(h_uid.data(), d_uid, sizeof(uint64_t) * num);

  // The record table is shared; hold the table mutex for the whole batch.
  std::lock_guard<std::mutex> lock(_table_mutex);
  for (size_t i = 0; i < num; ++i) {
    add_uid_unlock_data_pn(h_pred[i], h_label[i], h_uid[i]);
  }
}

void BasicAucCalculator::add_uid_unlock_data_pn(float pred, float label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
// PADDLE_ENFORCE_EQ(
// label * label, label,
// platform::errors::PreconditionNotMet(
// "label must be equal to 0 or 1, but its value is: %d", label));

PNRecord record;
record.uid_ = uid;
record.label_ = label;
record.pred_ = pred;
const float EPSINON = -0.00001;
if (label >= EPSINON) {
pn_records_.emplace_back(std::move(record));
_count += 1;
}
}

// Computes the positive/negative pair ratio over all buffered PN records.
//
// Records are sorted by uid (descending), then pred (descending), then label
// (ascending). Each run of records sharing a uid is handed to count_pn_pairs,
// which accumulates into _positive_num / _negtive_num. The result is stored
// in _final_pn, or FLT_MAX when no negative pair was counted.
// Single-machine only: no shuffle by uid and no all-reduce is performed.
void BasicAucCalculator::computePN() {
  _positive_num = 0;
  _negtive_num = 0;

  std::sort(pn_records_.begin(), pn_records_.end(),
            [](const PNRecord& a, const PNRecord& b) {
              if (a.uid_ != b.uid_) {
                return a.uid_ > b.uid_;
              }
              if (a.pred_ != b.pred_) {
                return a.pred_ > b.pred_;
              }
              return a.label_ < b.label_;
            });

  // Walk the sorted records one uid group [group_begin, group_end) at a time.
  const size_t total = pn_records_.size();
  size_t group_begin = 0;
  while (group_begin < total) {
    size_t group_end = group_begin + 1;
    while (group_end < total &&
           pn_records_[group_end].uid_ == pn_records_[group_begin].uid_) {
      ++group_end;
    }
    count_pn_pairs(pn_records_, group_begin, group_end, _positive_num,
                   _negtive_num);
    group_begin = group_end;
  }

  // Without any negative pair the ratio is undefined; report FLT_MAX.
  _final_pn = (_negtive_num == 0)
                  ? FLT_MAX
                  : static_cast<float>(static_cast<double>(_positive_num) /
                                       static_cast<double>(_negtive_num));
}

// Counts ordered record pairs (i, j) with start <= i < j < end.
//
// recs: records to inspect; only label_ is read.
// start, end: half-open index range; end is clamped to recs.size().
// positive_num, negtive_num: accumulators, incremented in place. A pair adds
//   to negtive_num when the earlier record's label is smaller than the later
//   one's, and to positive_num otherwise (equal labels count as positive).
//
// The clamp happens before the emptiness check so that `end - 1` can never
// wrap around when recs is empty or the range holds no element.
void BasicAucCalculator::count_pn_pairs(const std::vector<PNRecord>& recs,
                                        size_t start, size_t end,
                                        double& positive_num,
                                        double& negtive_num) {
  end = std::min(end, recs.size());
  if (end == 0 || start >= end) {
    return;
  }
  for (size_t i = start; i < end - 1; ++i) {
    const float label = recs[i].label_;
    for (size_t j = i + 1; j < end; ++j) {
      if (label < recs[j].label_) {
        negtive_num += 1;
      } else {
        positive_num += 1;
      }
    }
  }
}

void BasicAucCalculator::init_pn(int type_num) {
set_pn_type_num(type_num);
}


// Clears the typed PN state: the buffered typed records, the overall
// statistics in _pn_info, and the per-type statistics in _pn_infos, which is
// rebuilt with _type_num fresh entries.
void BasicAucCalculator::reset_type_pn_records() {
  pn_type_records_.clear();
  _pn_info.reset();
  // clear() followed by resize() leaves _type_num newly constructed entries.
  _pn_infos.clear();
  _pn_infos.resize(_type_num);
}

// Copies one batch of (pred, label, uid, type) samples from device memory to
// the host and appends them to the typed PN record table.
//
// d_pred, d_label, d_uid, d_type: source buffers with `batch_size` elements
//   each. They are read with cudaMemcpyDeviceToHost, so they are expected to
//   be device pointers.
// batch_size: number of samples in the batch; a non-positive value is a no-op.
// place: not used by this function.
//
// Raises PreconditionNotMet if a device-to-host copy fails. The staging
// buffers are thread_local and reused between calls, so an unchecked failed
// copy would otherwise feed the previous batch's data into the metric again.
void BasicAucCalculator::add_uid_data_type_pn(
    const float* d_pred, const float* d_label, const int64_t* d_uid,
    const int64_t* d_type, int batch_size,
    const paddle::platform::Place& place) {
  // A negative size would be converted to a huge size_t by resize() below.
  if (batch_size <= 0) {
    return;
  }
  const size_t num = static_cast<size_t>(batch_size);

  thread_local std::vector<float> h_pred;
  thread_local std::vector<float> h_label;
  thread_local std::vector<uint64_t> h_uid;
  thread_local std::vector<int64_t> h_type;
  h_pred.resize(num);
  h_label.resize(num);
  h_uid.resize(num);
  h_type.resize(num);

  // Synchronous device-to-host copy that fails loudly instead of leaving
  // stale data in the destination buffer.
  const auto copy_to_host = [](void* dst, const void* src, size_t bytes) {
    const cudaError_t status =
        cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
    PADDLE_ENFORCE_EQ(static_cast<int>(status), static_cast<int>(cudaSuccess),
                      platform::errors::PreconditionNotMet(
                          "cudaMemcpy from device to host failed: %s",
                          cudaGetErrorString(status)));
  };
  copy_to_host(h_pred.data(), d_pred, sizeof(float) * num);
  copy_to_host(h_label.data(), d_label, sizeof(float) * num);
  // The 64-bit uids are copied bit-for-bit into an unsigned host buffer.
  copy_to_host(h_uid.data(), d_uid, sizeof(uint64_t) * num);
  copy_to_host(h_type.data(), d_type, sizeof(int64_t) * num);

  // The record table is shared; hold the table mutex for the whole batch.
  std::lock_guard<std::mutex> lock(_table_mutex);
  for (size_t i = 0; i < num; ++i) {
    add_uid_unlock_data_type_pn(h_pred[i], h_label[i], h_uid[i], h_type[i]);
  }
}

// Validates a single typed sample and folds it into the typed PN state.
//
// pred must lie in [0, 1]; otherwise a PreconditionNotMet error is raised.
// label is not restricted to {0, 1}; samples whose label is below a small
// negative tolerance are silently dropped (and their type is not checked).
// For kept samples, type must satisfy 0 <= type < _type_num (CHECK-enforced).
// The range is validated on the full 64-bit value before narrowing to int,
// so an out-of-range value cannot wrap around into a valid bucket index.
//
// Kept samples are appended to pn_type_records_ and added to the count and
// label/pred sums of both the overall _pn_info and the per-type _pn_infos
// entry. This function takes no lock itself; add_uid_data_type_pn calls it
// while holding _table_mutex.
void BasicAucCalculator::add_uid_unlock_data_type_pn(float pred, float label,
                                                     uint64_t uid,
                                                     int64_t type) {
  PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
                                   "pred should be greater than 0"));
  PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
                                   "pred should be lower than 1"));

  // Lowest label value that is still accepted.
  const float label_floor = -0.00001;
  if (!(label >= label_floor)) {
    return;
  }

  // Range-check before the narrowing conversion below.
  CHECK(type >= 0);
  CHECK(type < static_cast<int64_t>(_type_num));
  const int type_idx = static_cast<int>(type);

  PNTypeRecord record;
  record.uid = uid;
  record.label = label;
  record.pred = pred;
  record.rtype = type_idx;
  pn_type_records_.emplace_back(std::move(record));

  const double label_d = static_cast<double>(label);
  const double pred_d = static_cast<double>(pred);

  // Overall statistics.
  _pn_info.count += 1;
  _pn_info.label_sum += label_d;
  _pn_info.pred_sum += pred_d;

  // Statistics of this sample's type.
  _pn_infos[type_idx].count += 1;
  _pn_infos[type_idx].label_sum += label_d;
  _pn_infos[type_idx].pred_sum += pred_d;
}

void BasicAucCalculator::computeTypePN() {
// no need shuffle by uid when single machine

std::sort(pn_type_records_.begin(), pn_type_records_.end(),
[](const PNTypeRecord& lhs, const PNTypeRecord& rhs) {
if (lhs.uid == rhs.uid) {
if (lhs.pred == rhs.pred) {
return lhs.label < rhs.label;
} else {
return lhs.pred > rhs.pred;
}
} else {
return lhs.uid > rhs.uid;
}
});

uint64_t prev_uid = 0;
size_t prev_pos = 0;
for (size_t i = 0; i < pn_type_records_.size(); ++i) {
if (pn_type_records_[i].uid != prev_uid) {
count_pn_type_pairs(pn_type_records_, prev_pos, i);
prev_uid = pn_type_records_[i].uid;
prev_pos = i;
}
}
count_pn_type_pairs(pn_type_records_, prev_pos, pn_type_records_.size());
compute_pn_info(_pn_info);
for (auto& sub_pn_info : _pn_infos) {
compute_pn_info(sub_pn_info);
}

}

// Counts ordered record pairs (i, j) with start <= i < j < end and adds them
// to the overall (_pn_info) and per-type (_pn_infos) pair statistics.
//
// recs: records to inspect; label and rtype are read.
// start, end: half-open index range; end is clamped to recs.size().
//
// A pair is "negative" when the earlier record's label is smaller than the
// later one's, and "positive" otherwise (equal labels count as positive).
// Every pair updates _pn_info; it additionally updates _pn_infos[rtype] only
// when both records share the same rtype. The weighted counters grow by the
// absolute difference of the two labels, each capped at 1440 beforehand.
// NOTE(review): 1440 is presumably minutes per day -- confirm with the
// label definition.
//
// The clamp happens before the emptiness check so that `end - 1` can never
// wrap around when recs is empty or the range holds no element.
void BasicAucCalculator::count_pn_type_pairs(
    const std::vector<PNTypeRecord>& recs, size_t start, size_t end) {
  end = std::min(end, recs.size());
  if (end == 0 || start >= end) {
    return;
  }
  for (size_t i = start; i < end - 1; ++i) {
    const float label = recs[i].label;
    const int rtype = recs[i].rtype;
    // Loop-invariant: capped label of the earlier record.
    const float capped_label = std::min(label, static_cast<float>(1440.0));
    for (size_t j = i + 1; j < end; ++j) {
      const double diff =
          fabs(capped_label -
               std::min(recs[j].label, static_cast<float>(1440.0)));
      const bool same_type = (rtype == recs[j].rtype);
      if (label < recs[j].label) {
        _pn_info.negtive_num += 1;
        _pn_info.negtive_wnum += diff;
        if (same_type) {
          _pn_infos[rtype].negtive_num += 1;
          _pn_infos[rtype].negtive_wnum += diff;
        }
      } else {
        _pn_info.positive_num += 1;
        _pn_info.positive_wnum += diff;
        if (same_type) {
          _pn_infos[rtype].positive_num += 1;
          _pn_infos[rtype].positive_wnum += diff;
        }
      }
    }
  }
}

// Finalizes one PNInfo from its accumulated counters.
//
// final_pn  = positive_num / negtive_num, or FLT_MAX when negtive_num is 0.
// final_wpn = positive_wnum / negtive_wnum, or FLT_MAX when negtive_wnum <= 0.
// final_pred_avg / final_label_avg are the per-sample means; they are left
// untouched when count is not positive.
void BasicAucCalculator::compute_pn_info(PNInfo& info) {
  info.final_pn =
      (info.negtive_num == 0)
          ? FLT_MAX
          : static_cast<float>(static_cast<double>(info.positive_num) /
                               static_cast<double>(info.negtive_num));

  info.final_wpn =
      (info.negtive_wnum <= 0)
          ? FLT_MAX
          : static_cast<float>(info.positive_wnum / info.negtive_wnum);

  // No samples: the averages are undefined, keep whatever is stored.
  if (!(info.count > 0)) {
    return;
  }
  const double samples = static_cast<double>(info.count);
  info.final_pred_avg = info.pred_sum / samples;
  info.final_label_avg = info.label_sum / samples;
}

} // namespace framework
} // namespace paddle
#endif
Loading

0 comments on commit 9d34771

Please sign in to comment.