From faa4da4835e0622387e2ab1f6481875c5243c5c9 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 4 Sep 2017 13:02:27 +0800 Subject: [PATCH 1/3] fix ctc edit distance in v2 API. --- .../gserver/evaluators/CTCErrorEvaluator.cpp | 56 +++++++++++++++---- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 132119015f967..8e2dc020cd848 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -26,6 +26,7 @@ class CTCErrorEvaluator : public NotGetableEvaluator { int numTimes_, numClasses_, numSequences_, blank_; real deletions_, insertions_, substitutions_; int seqClassficationError_; + mutable std::unordered_map evalResults_; std::vector path2String(const std::vector& path) { std::vector str; @@ -183,6 +184,18 @@ class CTCErrorEvaluator : public NotGetableEvaluator { return stringAlignment(gtStr, recogStr); } + void storeLocalValues() const { + evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0; + evalResults_["deletion_error"] = + numSequences_ ? deletions_ / numSequences_ : 0; + evalResults_["insertion_error"] = + numSequences_ ? insertions_ / numSequences_ : 0; + evalResults_["substitution_error"] = + numSequences_ ? substitutions_ / numSequences_ : 0; + evalResults_["sequence_error"] = + (real)seqClassficationError_ / numSequences_; + } + public: CTCErrorEvaluator() : numTimes_(0), @@ -245,16 +258,12 @@ class CTCErrorEvaluator : public NotGetableEvaluator { } virtual void printStats(std::ostream& os) const { - os << config_.name() << "=" - << (numSequences_ ? totalScore_ / numSequences_ : 0); - os << " deletions error" - << "=" << (numSequences_ ? deletions_ / numSequences_ : 0); - os << " insertions error" - << "=" << (numSequences_ ? insertions_ / numSequences_ : 0); - os << " substitutions error" - << "=" << (numSequences_ ? substitutions_ / numSequences_ : 0); - os << " sequences error" - << "=" << (real)seqClassficationError_ / numSequences_; + storeLocalValues(); + os << config_.name() << "=" << evalResults_["error"]; + os << " deletions error = " << evalResults_["deletion_error"]; + os << " insertions error = " << evalResults_["insertion_error"]; + os << " substitution error = " << evalResults_["substitution_error"]; + os << " sequence error = " << evalResults_["sequence_error"]; } virtual void distributeEval(ParameterClient2* client) { @@ -272,6 +281,33 @@ class CTCErrorEvaluator : public NotGetableEvaluator { seqClassficationError_ = (int)buf[4]; numSequences_ = (int)buf[5]; } + + void getNames(std::vector* names) { + storeLocalValues(); + names->reserve(names->size() + evalResults_.size()); + for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) { + names->push_back(config_.name() + "." + it->first); + } + } + + real getValue(const std::string& name, Error* err) const { + storeLocalValues(); + + const std::string delimiter("."); + std::string::size_type foundPos = name.find(delimiter, 0); + CHECK(foundPos != std::string::npos); + + auto it = evalResults_.find( + name.substr(foundPos + delimiter.size(), name.length())); + if (it == evalResults_.end()) { + *err = Error("Evaluator does not have the key %s", name.c_str()); + return 0.0f; + } + + return it->second; + } + + std::string getTypeImpl() const { return "ctc_edit_distance"; } }; REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); From 0b478e991c95f838da41bfcbf11f1e2b80ac17eb Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 4 Sep 2017 15:12:56 +0800 Subject: [PATCH 2/3] follow comments. --- paddle/gserver/evaluators/CTCErrorEvaluator.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 8e2dc020cd848..928c77a088ff1 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "Evaluator.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h" +#include "paddle/utils/StringUtil.h" namespace paddle { @@ -259,7 +260,7 @@ class CTCErrorEvaluator : public NotGetableEvaluator { virtual void printStats(std::ostream& os) const { storeLocalValues(); - os << config_.name() << "=" << evalResults_["error"]; + os << config_.name() << " error = " << evalResults_["error"]; os << " deletions error = " << evalResults_["deletion_error"]; os << " insertions error = " << evalResults_["insertion_error"]; os << " substitution error = " << evalResults_["substitution_error"]; @@ -293,12 +294,10 @@ class CTCErrorEvaluator : public NotGetableEvaluator { real getValue(const std::string& name, Error* err) const { storeLocalValues(); - const std::string delimiter("."); - std::string::size_type foundPos = name.find(delimiter, 0); - CHECK(foundPos != std::string::npos); + std::vector buffers; + paddle::str::split(name, '.', &buffers); + auto it = evalResults_.find(buffers[buffers.size() - 1]); - auto it = evalResults_.find( - name.substr(foundPos + delimiter.size(), name.length())); if (it == evalResults_.end()) { *err = Error("Evaluator does not have the key %s", name.c_str()); return 0.0f; @@ -307,7 +306,11 @@ class CTCErrorEvaluator : public NotGetableEvaluator { return it->second; } - std::string getTypeImpl() const { return "ctc_edit_distance"; } + std::string getType(const std::string& name, Error* err) const { + getValue(name, err); + if (!err->isOK()) return ""; + return "ctc_edit_distance"; + } }; REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); From a523bea8e585bd63f4167e012a05b03ad435b574 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 4 Sep 2017 20:40:33 +0800 Subject: [PATCH 3/3] fix getType. --- paddle/gserver/evaluators/CTCErrorEvaluator.cpp | 8 +++++--- paddle/gserver/evaluators/ChunkEvaluator.cpp | 8 +++++++- paddle/gserver/evaluators/Evaluator.h | 13 ++++++++----- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 928c77a088ff1..92087fa32b1e4 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -21,7 +21,7 @@ namespace paddle { /** * calculate sequence-to-sequence edit distance */ -class CTCErrorEvaluator : public NotGetableEvaluator { +class CTCErrorEvaluator : public Evaluator { private: MatrixPtr outActivations_; int numTimes_, numClasses_, numSequences_, blank_; @@ -307,8 +307,10 @@ class CTCErrorEvaluator : public NotGetableEvaluator { } std::string getType(const std::string& name, Error* err) const { - getValue(name, err); - if (!err->isOK()) return ""; + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } return "ctc_edit_distance"; } }; diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 1658282f3a5f7..a2ab15eedee4a 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -268,7 +268,13 @@ class ChunkEvaluator : public Evaluator { } // get type of evaluator - std::string getTypeImpl() const { return "chunk"; } + std::string getType(const std::string& name, Error* err) const { + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } + return "chunk"; + } private: void storeLocalValues() const { diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index b114500e2b7c1..90203553e0a5f 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -211,6 +211,7 @@ class NotGetableEvaluator : public Evaluator { *err = Error("Not implemented"); return .0f; } + std::string getType(const std::string& name, Error* err) const { *err = Error("Not implemented"); return ""; @@ -331,6 +332,7 @@ class RankAucEvaluator : public Evaluator { protected: std::string getTypeImpl() const; }; + /** * @brief precision, recall and f1 score Evaluator * \f[ @@ -358,6 +360,12 @@ class PrecisionRecallEvaluator : public Evaluator { virtual void distributeEval(ParameterClient2* client); + void getNames(std::vector* names); + + real getValue(const std::string& name, Error* err) const; + + std::string getType(const std::string& name, Error* err) const; + struct StatsInfo { /// numbers of true positives double TP; @@ -428,11 +436,6 @@ class PrecisionRecallEvaluator : public Evaluator { mutable std::unordered_map values_; void storeLocalValues() const; - // Evaluator interface -public: - void getNames(std::vector* names); - real getValue(const std::string& name, Error* err) const; - std::string getType(const std::string& name, Error* err) const; }; /*