Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evalfused #615

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
30 changes: 30 additions & 0 deletions app/br/br.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@
* \endcode
*/

QVector<float> convertStringToFloatVector(const QString& str) {
QVector<float> result;
QStringList values = str.split(",");

for (const QString& value : values) {
bool ok;
float num = value.toFloat(&ok);
if (ok) {
result.append(num);
} else {
// Handle conversion error, e.g., by throwing an exception or returning an empty vector
qDebug() << "Error converting string to float:" << value;
}
}

return result;
}

class FakeMain : public QRunnable
{
int argc;
Expand Down Expand Up @@ -123,6 +141,17 @@ class FakeMain : public QRunnable
} else {
br_eval(parv[0], parv[1], parv[2], atoi(parv[3]));
}
} else if (!strcmp(fun, "evalfused")) {
check((parc == 4 || (parc == 6)), "Incorrect parameter count for 'evalfused'.");
if (parc == 4) {
const QStringList simmatList = QString(parv[0]).split(",");
const QVector<float> weights = convertStringToFloatVector(parv[3]);
br_eval_fused(simmatList, parv[1], parv[2], 0, weights, -1e6, 1e6);
} else {
const QStringList simmatList = QString(parv[0]).split(",");
const QVector<float> weights = convertStringToFloatVector(parv[3]);
br_eval_fused(simmatList, parv[1], parv[2], 0, weights, atof(parv[4]), atof(parv[5]));
}
} else if (!strcmp(fun, "plot")) {
check(parc >= 2, "Incorrect parameter count for 'plot'.");
br_plot(parc-1, parv, parv[parc-1], true);
Expand Down Expand Up @@ -278,6 +307,7 @@ class FakeMain : public QRunnable
"-enroll <input_gallery> ... <input_gallery> {output_gallery}\n"
"-compare <target_gallery> <query_gallery> [{output}]\n"
"-eval <simmat> [<mask>] [{csv}] [{matches}]\n"
"-evalfused <simmat> <simmat2> [<simmat3>] [<mask>] [{csv}] <w1> <w2> [<w3>] <lowerBound> <upperBound>\n"
"-plot <csv> ... <csv> {destination}\n"
"\n"
"==== Other Commands ====\n"
Expand Down
108 changes: 82 additions & 26 deletions openbr/core/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,49 +148,100 @@ static cv::Mat constructMatchingMask(const cv::Mat &scores, const FileList &targ

float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const File &csv, int partition)
{
return Evaluate(scores, constructMatchingMask(scores, target, query, partition), csv, QString(), QString(), 0);
return Evaluate(scores, constructMatchingMask(scores, target, query, partition), csv, QStringList(), QString(), 0);
}

float Evaluate(const QString &simmat, const QString &mask, const File &csv, unsigned int matches)
float Evaluate(const QStringList &simmats, const QString &mask, const File &csv, unsigned int matches, const QVector<float> &weights, float lowerBound, float upperBound)
{
qDebug("Evaluating %s%s%s",
qPrintable(simmat),
if (simmats.size() <= 0) {
throw std::invalid_argument("Must supply at least one simmat.");
}

if (!weights.isEmpty() && weights.size() != simmats.size()) {
throw std::invalid_argument("The size of weights must match the number of simmats.");
}

int n_simmats = simmats.size();
if (n_simmats == 1) {
qDebug("Evaluating %s%s%s",
qPrintable(simmats[0]),
mask.isEmpty() ? "" : qPrintable(" with " + mask),
csv.name.isEmpty() ? "" : qPrintable(" to " + csv));

// Read similarity matrix
QString target, query;
Mat scores;
if (simmat.endsWith(".mtx")) {
scores = BEE::readMatrix(simmat, &target, &query);
} else {
QScopedPointer<Format> format(Factory<Format>::make(simmat));
scores = format->read();
qDebug("Evaluating %d simmats%s%s",
n_simmats,
mask.isEmpty() ? "" : qPrintable(" with " + mask),
csv.name.isEmpty() ? "" : qPrintable(" to " + csv));
}

QVector<Mat> scores(n_simmats);
QStringList targets;
for (int i = 0; i < n_simmats; ++i) {
targets.append("");
}
QString query;

for (int i = 0; i < n_simmats; i++) {
if (simmats[i].endsWith(".mtx")) {
Mat score_mat = BEE::readMatrix(simmats[i], &targets[i], &query);
scores[i] = score_mat;
} else {
QScopedPointer<Format> format(Factory<Format>::make(simmats[i]));
Mat score_mat = format->read();
scores[i] = score_mat;
}
}

// Read mask matrix
Mat truth;
if (mask.isEmpty()) {
// Use the galleries specified in the similarity matrix
if (target.isEmpty()) qFatal("Unspecified target gallery.");
if (targets[0].isEmpty()) qFatal("Unspecified target gallery.");
if (query.isEmpty()) qFatal("Unspecified query gallery.");

truth = constructMatchingMask(scores, TemplateList::fromGallery(target).files(),
truth = constructMatchingMask(scores[0], TemplateList::fromGallery(targets[0]).files(),
TemplateList::fromGallery(query).files());
} else {
File maskFile(mask);
maskFile.set("rows", scores.rows);
maskFile.set("columns", scores.cols);
maskFile.set("rows", scores[0].rows);
maskFile.set("columns", scores[0].cols);
QScopedPointer<Format> format(Factory<Format>::make(maskFile));
truth = format->read();
}

return Evaluate(scores, truth, csv, target, query, matches);
if (n_simmats > 1) {
double count_total = 0;
double count_fused = 0;
for (int i = 0; i < scores[0].rows; i++) {
for (int j = 0; j < scores[0].cols; j++) {
const BEE::MaskValue mask_val = truth.at<BEE::MaskValue>(i,j);
if (mask_val == BEE::DontCare) continue;

BEE::SimmatValue simmat_val = scores[0].at<BEE::SimmatValue>(i,j);
if (simmat_val != simmat_val) continue;

if (simmat_val >= lowerBound && simmat_val <= upperBound) {
simmat_val *= weights[0];
for (int k = 1; k < n_simmats; k++)
simmat_val += weights[k]*scores[k].at<BEE::SimmatValue>(i,j);
scores[0].at<BEE::SimmatValue>(i,j) = simmat_val;
count_fused++;
}
count_total++;
}
}
qDebug("fused %lf percent of comparisons w/ scores falling between [%f, %f]",
count_fused/count_total*100,
lowerBound,
upperBound);
}

return Evaluate(scores[0], truth, csv, targets, query, matches);
}

float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QString &target, const QString &query, unsigned int matches)
float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStringList &targets, const QString &query, unsigned int matches)
{
if (target.isEmpty() || query.isEmpty()) matches = 0;
if (targets[0].isEmpty() || query.isEmpty()) matches = 0;
if (simmat.size() != mask.size())
qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
simmat.rows, simmat.cols, mask.rows, mask.cols);
Expand Down Expand Up @@ -338,7 +389,7 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin

QString filePath = Globals->path;
if (matches != 0 && EERIndex != 0) {
const FileList targetFiles = TemplateList::fromGallery(target).files();
const FileList targetFiles = TemplateList::fromGallery(targets[0]).files();
const FileList queryFiles = TemplateList::fromGallery(query).files();
unsigned int count = 0;
for (int i = EERIndex-1; i >= 0; i--) {
Expand Down Expand Up @@ -422,11 +473,16 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin
lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(result = getOperatingPoint(operatingPoints, "FAR", 0.001).TAR, 'f', 3))));

// Attempt to read template size from enrolled gallery and write to output CSV
size_t maxSize(0);
if (target.endsWith(".gal") && QFileInfo(target).exists()) {
foreach (const Template &t, TemplateList::fromGallery(target)) maxSize = max(maxSize, t.bytes());
lines.append(QString("TS,,%1").arg(QString::number(maxSize)));
size_t maxSizeTotal(0);

for (int i=0; i<targets.size(); i++) {
if (targets[i].endsWith(".gal") && QFileInfo(targets[i]).exists()) {
size_t maxSize(0);
foreach (const Template &t, TemplateList::fromGallery(targets[i])) maxSize = max(maxSize, t.bytes());
maxSizeTotal += maxSize;
}
}
lines.append(QString("TS,,%1").arg(QString::number(maxSizeTotal)));

// Write SD
int points = qMin(qMin((size_t)Max_Points, genuines.size()), impostors.size());
Expand All @@ -450,7 +506,7 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin
}

QtUtils::writeFile(csv, lines);
if (maxSize > 0) qDebug("Template Size: %i bytes", (int)maxSize);
if (maxSizeTotal > 0) qDebug("Template Size: %i bytes", (int)maxSizeTotal);
foreach (float FAR, QList<float>() << 1e-1 << 1e-2 << 1e-3 << 1e-4 << 1e-5 << 1e-6) {
const OperatingPoint op = getOperatingPoint(operatingPoints, "FAR", FAR);
printf("TAR & Similarity @ FAR = %.0e: %.4f %.3f\n", FAR, op.TAR, op.score);
Expand All @@ -469,7 +525,7 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin

void assertEval(const QString &simmat, const QString &mask, float accuracy)
{
float result = Evaluate(simmat, mask, "", 0);
float result = Evaluate({simmat}, mask, "", 0, QVector<float>(), -1e6, 1e6);
// Round result to nearest thousandth for comparison against input accuracy. Input is expected to be from previous
// results of br -eval.
result = floor(result*1000+0.5)/1000;
Expand Down
9 changes: 6 additions & 3 deletions openbr/core/eval.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// #endif // BR_EVAL_H

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
* Copyright 2012 The MITRE Corporation *
* *
Expand All @@ -23,12 +25,13 @@

namespace br
{
float Evaluate(const QString &simmat, const QString &mask = "", const File &csv = "", unsigned int matches = 0); // Returns TAR @ FAR = 0.001
float Evaluate(const QStringList &simmats, const QString &mask = "", const File &csv = "", unsigned int matches = 0, const QVector<float> &weights = QVector<float>(), float lowerBound = -1e6, float upperBound = 1e6); // Returns TAR @ FAR = 0.001
float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const File &csv = "", int parition = 0);
float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const File &csv = "", const QString &target = "", const QString &query = "", unsigned int matches = 0);
float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const File &csv = "", const QStringList &targets = QStringList(), const QString &query = "", unsigned int matches = 0);

void assertEval(const QString &simmat, const QString &mask, float accuracy); // Check to see if -eval achieves a given TAR @ FAR = 0.001
float InplaceEval(const QString &simmat, const QString &mask, const QString &csv);

void EvalClassification(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = "");
float EvalDetection(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", bool normalize = false, int minSize = 0, int maxSize = 0, float relativeMinSize = 0, const QString &label = "", const float true_positive_threshold = 0.5f); // Return average overlap
float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error
Expand Down
7 changes: 6 additions & 1 deletion openbr/openbr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ void br_project(const char *input, const char *gallery)

float br_eval(const char *simmat, const char *mask, const char *csv, int matches)
{
return Evaluate(simmat, mask, csv, matches);
return Evaluate({simmat}, mask, csv, matches, QVector<float>(), -1e6, 1e6);
}

float br_eval_fused(const QStringList &simmats, const char *mask, const char *csv, int matches, const QVector<float> &weights, float lowerBound, float upperBound)
{
return Evaluate(simmats, mask, csv, matches, weights, lowerBound, upperBound);
}

void br_assert_eval(const char *simmat, const char *mask, const float accuracy)
Expand Down
5 changes: 4 additions & 1 deletion openbr/openbr.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
#define OPENBR_H

#include <openbr/openbr_export.h>
#include <QStringList>
#include <QVector>

#ifdef __cplusplus
extern "C" {
#endif


BR_EXPORT const char *br_about();

BR_EXPORT void br_cat(int num_input_galleries, const char *input_galleries[], const char *output_gallery);
Expand All @@ -50,6 +51,8 @@ BR_EXPORT void br_project(const char *input, const char *output);

BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = "", int matches = 0);

BR_EXPORT float br_eval_fused(const QStringList &simmats, const char *mask, const char *csv = "", int matches = 0, const QVector<float> &weights = {}, float lowerBound = -1e6, float upperBound = 1e6);

BR_EXPORT void br_assert_eval(const char *simmat, const char *mask, const float accuracy);

BR_EXPORT float br_inplace_eval(const char * simmat, const char *mask, const char *csv);
Expand Down