Skip to content

Commit

Permalink
fix ctc edit distance in v2 API.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Sep 4, 2017
1 parent c1feb27 commit faa4da4
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions paddle/gserver/evaluators/CTCErrorEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CTCErrorEvaluator : public NotGetableEvaluator {
int numTimes_, numClasses_, numSequences_, blank_;
real deletions_, insertions_, substitutions_;
int seqClassficationError_;
mutable std::unordered_map<std::string, real> evalResults_;

std::vector<int> path2String(const std::vector<int>& path) {
std::vector<int> str;
Expand Down Expand Up @@ -183,6 +184,18 @@ class CTCErrorEvaluator : public NotGetableEvaluator {
return stringAlignment(gtStr, recogStr);
}

void storeLocalValues() const {
evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0;
evalResults_["deletion_error"] =
numSequences_ ? deletions_ / numSequences_ : 0;
evalResults_["insertion_error"] =
numSequences_ ? insertions_ / numSequences_ : 0;
evalResults_["substitution_error"] =
numSequences_ ? substitutions_ / numSequences_ : 0;
evalResults_["sequence_error"] =
(real)seqClassficationError_ / numSequences_;
}

public:
CTCErrorEvaluator()
: numTimes_(0),
Expand Down Expand Up @@ -245,16 +258,12 @@ class CTCErrorEvaluator : public NotGetableEvaluator {
}

virtual void printStats(std::ostream& os) const {
os << config_.name() << "="
<< (numSequences_ ? totalScore_ / numSequences_ : 0);
os << " deletions error"
<< "=" << (numSequences_ ? deletions_ / numSequences_ : 0);
os << " insertions error"
<< "=" << (numSequences_ ? insertions_ / numSequences_ : 0);
os << " substitutions error"
<< "=" << (numSequences_ ? substitutions_ / numSequences_ : 0);
os << " sequences error"
<< "=" << (real)seqClassficationError_ / numSequences_;
storeLocalValues();
os << config_.name() << "=" << evalResults_["error"];
os << " deletions error = " << evalResults_["deletion_error"];
os << " insertions error = " << evalResults_["insertion_error"];
os << " substitution error = " << evalResults_["substitution_error"];
os << " sequence error = " << evalResults_["sequence_error"];
}

virtual void distributeEval(ParameterClient2* client) {
Expand All @@ -272,6 +281,33 @@ class CTCErrorEvaluator : public NotGetableEvaluator {
seqClassficationError_ = (int)buf[4];
numSequences_ = (int)buf[5];
}

void getNames(std::vector<std::string>* names) {
storeLocalValues();
names->reserve(names->size() + evalResults_.size());
for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) {
names->push_back(config_.name() + "." + it->first);
}
}

real getValue(const std::string& name, Error* err) const {
storeLocalValues();

const std::string delimiter(".");
std::string::size_type foundPos = name.find(delimiter, 0);
CHECK(foundPos != std::string::npos);

auto it = evalResults_.find(
name.substr(foundPos + delimiter.size(), name.length()));
if (it == evalResults_.end()) {
*err = Error("Evaluator does not have the key %s", name.c_str());
return 0.0f;
}

return it->second;
}

std::string getTypeImpl() const { return "ctc_edit_distance"; }
};

REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
Expand Down

0 comments on commit faa4da4

Please sign in to comment.