From c8d5ba63f912921a4c31960266c46f65c1af2a8b Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 15 Jun 2023 13:13:53 +0200 Subject: [PATCH] Save vocabulary files in the JSON format (#1293) * Save vocabulary files in the JSON format * Refactor change * Fix typo in error message --- include/ctranslate2/models/language_model.h | 1 - include/ctranslate2/models/model_reader.h | 8 +++ .../ctranslate2/models/sequence_to_sequence.h | 1 - include/ctranslate2/models/whisper.h | 1 - include/ctranslate2/vocabulary.h | 5 +- python/ctranslate2/specs/model_spec.py | 16 +++--- python/tests/test_opennmt_py.py | 4 +- python/tests/test_opennmt_tf.py | 17 ++----- python/tests/test_transformers.py | 5 +- src/models/language_model.cc | 5 +- src/models/model_reader.cc | 18 +++++++ src/models/sequence_to_sequence.cc | 51 +++++++++---------- src/models/whisper.cc | 6 ++- src/vocabulary.cc | 23 +++++++-- 14 files changed, 94 insertions(+), 67 deletions(-) diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 568a1037c..fdab65d88 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -6,7 +6,6 @@ #include "ctranslate2/encoding.h" #include "ctranslate2/generation.h" #include "ctranslate2/scoring.h" -#include "ctranslate2/vocabulary.h" namespace ctranslate2 { namespace models { diff --git a/include/ctranslate2/models/model_reader.h b/include/ctranslate2/models/model_reader.h index 5a9cf622b..6e755927c 100644 --- a/include/ctranslate2/models/model_reader.h +++ b/include/ctranslate2/models/model_reader.h @@ -5,6 +5,8 @@ #include #include +#include "ctranslate2/vocabulary.h" + namespace ctranslate2 { namespace models { @@ -50,5 +52,11 @@ namespace ctranslate2 { std::unordered_map _files; }; + + std::shared_ptr + load_vocabulary(ModelReader& model_reader, + const std::string& filename, + VocabularyInfo vocab_info); + } } diff --git a/include/ctranslate2/models/sequence_to_sequence.h b/include/ctranslate2/models/sequence_to_sequence.h index 530e1b382..e1d79327f 100644 --- a/include/ctranslate2/models/sequence_to_sequence.h +++ b/include/ctranslate2/models/sequence_to_sequence.h @@ -5,7 +5,6 @@ #include "ctranslate2/models/model.h" #include "ctranslate2/scoring.h" #include "ctranslate2/translation.h" -#include "ctranslate2/vocabulary.h" #include "ctranslate2/vocabulary_map.h" namespace ctranslate2 { diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h index 80782c22c..a9ade09e2 100644 --- a/include/ctranslate2/models/whisper.h +++ b/include/ctranslate2/models/whisper.h @@ -4,7 +4,6 @@ #include "ctranslate2/layers/whisper.h" #include "ctranslate2/models/model.h" #include "ctranslate2/replica_pool.h" -#include "ctranslate2/vocabulary.h" namespace ctranslate2 { namespace models { diff --git a/include/ctranslate2/vocabulary.h b/include/ctranslate2/vocabulary.h index 3d0da6751..fd9601d8e 100644 --- a/include/ctranslate2/vocabulary.h +++ b/include/ctranslate2/vocabulary.h @@ -17,7 +17,10 @@ namespace ctranslate2 { class Vocabulary { public: - Vocabulary(std::istream& in, VocabularyInfo info = VocabularyInfo()); + static Vocabulary from_text_file(std::istream& in, VocabularyInfo info = VocabularyInfo()); + static Vocabulary from_json_file(std::istream& in, VocabularyInfo info = VocabularyInfo()); + + Vocabulary(std::vector tokens, VocabularyInfo info = VocabularyInfo()); bool contains(const std::string& token) const; const std::string& to_token(size_t id) const; diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index b850c0700..c3ba5383f 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -497,8 +497,7 @@ def save(self, output_dir: str) -> None: vocabularies = {"shared": all_vocabularies[0]} for name, tokens in vocabularies.items(): - path = os.path.join(output_dir, "%s_vocabulary.txt" % name) - _save_lines(path, tokens) + _save_vocabulary(output_dir, "%s_vocabulary" % name, tokens) # Save the rest of the model. super().save(output_dir) @@ -566,15 +565,14 @@ def validate(self) -> None: def save(self, output_dir: str) -> None: # Save the vocabulary. - vocabulary_path = os.path.join(output_dir, "vocabulary.txt") - _save_lines(vocabulary_path, self._vocabulary) + _save_vocabulary(output_dir, "vocabulary", self._vocabulary) # Save the rest of the model. super().save(output_dir) -def _save_lines(path, lines): - with open(path, "w", encoding="utf-8", newline="") as f: - for line in lines: - f.write(line) - f.write("\n") +def _save_vocabulary(output_dir, name, tokens): + vocabulary_path = os.path.join(output_dir, "%s.json" % name) + + with open(vocabulary_path, "w", encoding="utf-8") as vocabulary_file: + json.dump(tokens, vocabulary_file, indent=2) diff --git a/python/tests/test_opennmt_py.py b/python/tests/test_opennmt_py.py index 9ed5056fb..32ccab542 100644 --- a/python/tests/test_opennmt_py.py +++ b/python/tests/test_opennmt_py.py @@ -58,8 +58,8 @@ def test_opennmt_py_source_features(tmp_dir, filename): converter = ctranslate2.converters.OpenNMTPyConverter(model_path) output_dir = str(tmp_dir.join("ctranslate2_model")) converter.convert(output_dir) - assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt")) - assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt")) + assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json")) + assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json")) source = [ ["آ", "ت", "ز", "م", "و", "ن"], diff --git a/python/tests/test_opennmt_tf.py b/python/tests/test_opennmt_tf.py index a02577340..1f11d926c 100644 --- a/python/tests/test_opennmt_tf.py +++ b/python/tests/test_opennmt_tf.py @@ -39,15 +39,6 @@ def test_opennmt_tf_model_conversion(tmp_dir, model_path): output_dir = str(tmp_dir.join("ctranslate2_model")) converter.convert(output_dir) - src_vocab_path = os.path.join(output_dir, "source_vocabulary.txt") - tgt_vocab_path = os.path.join(output_dir, "target_vocabulary.txt") - - # Check lines end with \n on all platforms. - with open(src_vocab_path, encoding="utf-8", newline="") as vocab_file: - assert vocab_file.readline() == "\n" - with open(tgt_vocab_path, encoding="utf-8", newline="") as vocab_file: - assert vocab_file.readline() == "\n" - translator = ctranslate2.Translator(output_dir) output = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]]) assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"] @@ -145,7 +136,7 @@ def test_opennmt_tf_shared_embeddings_conversion(tmp_dir): output_dir = str(tmp_dir.join("ctranslate2_model")) converter.convert(output_dir) - assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.txt")) + assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.json")) # Check that the translation runs. translator = ctranslate2.Translator(output_dir) @@ -192,7 +183,7 @@ def test_opennmt_tf_gpt_conversion(tmp_dir): converter = ctranslate2.converters.OpenNMTTFConverter(model) converter.convert(output_dir) - assert os.path.isfile(os.path.join(output_dir, "vocabulary.txt")) + assert os.path.isfile(os.path.join(output_dir, "vocabulary.json")) def test_opennmt_tf_multi_features(tmp_dir): @@ -224,5 +215,5 @@ def test_opennmt_tf_multi_features(tmp_dir): output_dir = str(tmp_dir.join("ctranslate2_model")) converter.convert(output_dir) - assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt")) - assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt")) + assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json")) + assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json")) diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index 8a328fd38..0779d2ca5 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -215,8 +215,9 @@ def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir): output_dir = str(tmp_dir.join("ctranslate2_model")) output_dir = converter.convert(output_dir) - with open(os.path.join(output_dir, "shared_vocabulary.txt")) as vocab_file: - vocab = list(line.rstrip("\n") for line in vocab_file) + vocabulary_path = os.path.join(output_dir, "shared_vocabulary.json") + with open(vocabulary_path, encoding="utf-8") as vocabulary_file: + vocab = json.load(vocabulary_file) assert vocab[-1] != "" diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 79438a7fa..fb09d18c7 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -30,8 +30,9 @@ namespace ctranslate2 { vocab_info.bos_token = config["bos_token"]; vocab_info.eos_token = config["eos_token"]; - _vocabulary = std::make_shared(*model_reader.get_required_file("vocabulary.txt"), - std::move(vocab_info)); + _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info)); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); } diff --git a/src/models/model_reader.cc b/src/models/model_reader.cc index ebcc76c0e..628f9f318 100644 --- a/src/models/model_reader.cc +++ b/src/models/model_reader.cc @@ -75,5 +75,23 @@ namespace ctranslate2 { return std::make_unique(content.data(), content.size()); } + + std::shared_ptr + load_vocabulary(ModelReader& model_reader, + const std::string& filename, + VocabularyInfo vocab_info) { + std::unique_ptr file; + + file = model_reader.get_file(filename + ".json"); + if (file) + return std::make_shared(Vocabulary::from_json_file(*file, std::move(vocab_info))); + + file = model_reader.get_file(filename + ".txt"); + if (file) + return std::make_shared(Vocabulary::from_text_file(*file, std::move(vocab_info))); + + return nullptr; + } + } } diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index 9349ea512..a7e64611f 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -7,9 +7,6 @@ namespace ctranslate2 { namespace models { - static const std::string shared_vocabulary_file = "shared_vocabulary.txt"; - static const std::string source_vocabulary_file = "source_vocabulary.txt"; - static const std::string target_vocabulary_file = "target_vocabulary.txt"; static const std::string vmap_file = "vmap.txt"; @@ -20,37 +17,35 @@ namespace ctranslate2 { vocab_info.bos_token = config["bos_token"]; vocab_info.eos_token = config["eos_token"]; - auto shared_vocabulary = model_reader.get_file(shared_vocabulary_file); + auto shared_vocabulary = load_vocabulary(model_reader, "shared_vocabulary", vocab_info); + if (shared_vocabulary) { - _target_vocabulary = std::make_shared(*shared_vocabulary, vocab_info); - _source_vocabularies.emplace_back(_target_vocabulary); + _target_vocabulary = shared_vocabulary; + _source_vocabularies = {shared_vocabulary}; + } else { + _target_vocabulary = load_vocabulary(model_reader, "target_vocabulary", vocab_info); + if (!_target_vocabulary) + throw std::runtime_error("Cannot load the target vocabulary from the model directory"); - { - auto source_vocabulary = model_reader.get_file(source_vocabulary_file); - if (source_vocabulary) - _source_vocabularies.emplace_back(std::make_shared(*source_vocabulary, - vocab_info)); - else { - for (size_t i = 1;; i++) { - const std::string filename = "source_" + std::to_string(i) + "_vocabulary.txt"; - const auto vocabulary_file = model_reader.get_file(filename); - if (!vocabulary_file) - break; - _source_vocabularies.emplace_back(std::make_shared(*vocabulary_file, - vocab_info)); - } - } + auto source_vocabulary = load_vocabulary(model_reader, "source_vocabulary", vocab_info); - // If no source vocabularies were loaded, raise an error for the first filename. - if (_source_vocabularies.empty()) - model_reader.get_required_file(source_vocabulary_file); - } + if (source_vocabulary) { + _source_vocabularies = {source_vocabulary}; + } else { + for (size_t i = 1;; i++) { + const std::string filename = "source_" + std::to_string(i) + "_vocabulary"; + auto vocabulary = load_vocabulary(model_reader, filename, vocab_info); - { - auto target_vocabulary = model_reader.get_required_file(target_vocabulary_file); - _target_vocabulary = std::make_shared(*target_vocabulary, vocab_info); + if (!vocabulary) + break; + + _source_vocabularies.emplace_back(vocabulary); + } } + + if (_source_vocabularies.empty()) + throw std::runtime_error("Cannot load the source vocabulary from the model directory"); } } diff --git a/src/models/whisper.cc b/src/models/whisper.cc index 1a1a15470..c87088cc6 100644 --- a/src/models/whisper.cc +++ b/src/models/whisper.cc @@ -27,8 +27,10 @@ namespace ctranslate2 { vocab_info.unk_token = "<|endoftext|>"; vocab_info.bos_token = "<|startoftranscript|>"; vocab_info.eos_token = "<|endoftext|>"; - _vocabulary = std::make_shared(*model_reader.get_required_file("vocabulary.txt"), - std::move(vocab_info)); + + _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info)); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); } bool WhisperModel::is_quantizable(const std::string& variable_name) const { diff --git a/src/vocabulary.cc b/src/vocabulary.cc index c603f528c..a244a72ba 100644 --- a/src/vocabulary.cc +++ b/src/vocabulary.cc @@ -1,12 +1,12 @@ #include "ctranslate2/vocabulary.h" +#include + #include "ctranslate2/utils.h" namespace ctranslate2 { - Vocabulary::Vocabulary(std::istream& in, VocabularyInfo info) - : _info(std::move(info)) - { + Vocabulary Vocabulary::from_text_file(std::istream& in, VocabularyInfo info) { std::vector tokens; std::string line; @@ -21,12 +21,25 @@ namespace ctranslate2 { tokens.emplace_back(std::move(line)); } + if (remove_carriage_return) { + for (auto& token : tokens) + token.pop_back(); + } + + return Vocabulary(std::move(tokens), std::move(info)); + } + + Vocabulary Vocabulary::from_json_file(std::istream& in, VocabularyInfo info) { + return Vocabulary(nlohmann::json::parse(in).get>(), std::move(info)); + } + + Vocabulary::Vocabulary(std::vector tokens, VocabularyInfo info) + : _info(std::move(info)) + { _token_to_id.reserve(tokens.size()); _id_to_token.reserve(tokens.size()); for (auto& token : tokens) { - if (remove_carriage_return) - token.pop_back(); add_token(std::move(token)); }