Skip to content

Commit

Permalink
Save vocabulary files in the JSON format (#1293)
Browse files Browse the repository at this point in the history
* Save vocabulary files in the JSON format

* Refactor change

* Fix typo in error message
  • Loading branch information
guillaumekln authored Jun 15, 2023
1 parent 7f358d2 commit c8d5ba6
Show file tree
Hide file tree
Showing 14 changed files with 94 additions and 67 deletions.
1 change: 0 additions & 1 deletion include/ctranslate2/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "ctranslate2/encoding.h"
#include "ctranslate2/generation.h"
#include "ctranslate2/scoring.h"
#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {
Expand Down
8 changes: 8 additions & 0 deletions include/ctranslate2/models/model_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <string>
#include <unordered_map>

#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {

Expand Down Expand Up @@ -50,5 +52,11 @@ namespace ctranslate2 {
std::unordered_map<std::string, std::string> _files;
};


std::shared_ptr<Vocabulary>
load_vocabulary(ModelReader& model_reader,
const std::string& filename,
VocabularyInfo vocab_info);

}
}
1 change: 0 additions & 1 deletion include/ctranslate2/models/sequence_to_sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "ctranslate2/models/model.h"
#include "ctranslate2/scoring.h"
#include "ctranslate2/translation.h"
#include "ctranslate2/vocabulary.h"
#include "ctranslate2/vocabulary_map.h"

namespace ctranslate2 {
Expand Down
1 change: 0 additions & 1 deletion include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "ctranslate2/layers/whisper.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"
#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {
Expand Down
5 changes: 4 additions & 1 deletion include/ctranslate2/vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ namespace ctranslate2 {
class Vocabulary
{
public:
Vocabulary(std::istream& in, VocabularyInfo info = VocabularyInfo());
static Vocabulary from_text_file(std::istream& in, VocabularyInfo info = VocabularyInfo());
static Vocabulary from_json_file(std::istream& in, VocabularyInfo info = VocabularyInfo());

Vocabulary(std::vector<std::string> tokens, VocabularyInfo info = VocabularyInfo());

bool contains(const std::string& token) const;
const std::string& to_token(size_t id) const;
Expand Down
16 changes: 7 additions & 9 deletions python/ctranslate2/specs/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,7 @@ def save(self, output_dir: str) -> None:
vocabularies = {"shared": all_vocabularies[0]}

for name, tokens in vocabularies.items():
path = os.path.join(output_dir, "%s_vocabulary.txt" % name)
_save_lines(path, tokens)
_save_vocabulary(output_dir, "%s_vocabulary" % name, tokens)

# Save the rest of the model.
super().save(output_dir)
Expand Down Expand Up @@ -566,15 +565,14 @@ def validate(self) -> None:

def save(self, output_dir: str) -> None:
# Save the vocabulary.
vocabulary_path = os.path.join(output_dir, "vocabulary.txt")
_save_lines(vocabulary_path, self._vocabulary)
_save_vocabulary(output_dir, "vocabulary", self._vocabulary)

# Save the rest of the model.
super().save(output_dir)


def _save_lines(path, lines):
with open(path, "w", encoding="utf-8", newline="") as f:
for line in lines:
f.write(line)
f.write("\n")
def _save_vocabulary(output_dir, name, tokens):
vocabulary_path = os.path.join(output_dir, "%s.json" % name)

with open(vocabulary_path, "w", encoding="utf-8") as vocabulary_file:
json.dump(tokens, vocabulary_file, indent=2)
4 changes: 2 additions & 2 deletions python/tests/test_opennmt_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_opennmt_py_source_features(tmp_dir, filename):
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json"))

source = [
["آ", "ت", "ز", "م", "و", "ن"],
Expand Down
17 changes: 4 additions & 13 deletions python/tests/test_opennmt_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def test_opennmt_tf_model_conversion(tmp_dir, model_path):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

src_vocab_path = os.path.join(output_dir, "source_vocabulary.txt")
tgt_vocab_path = os.path.join(output_dir, "target_vocabulary.txt")

# Check lines end with \n on all platforms.
with open(src_vocab_path, encoding="utf-8", newline="") as vocab_file:
assert vocab_file.readline() == "<blank>\n"
with open(tgt_vocab_path, encoding="utf-8", newline="") as vocab_file:
assert vocab_file.readline() == "<blank>\n"

translator = ctranslate2.Translator(output_dir)
output = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]])
assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"]
Expand Down Expand Up @@ -145,7 +136,7 @@ def test_opennmt_tf_shared_embeddings_conversion(tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.json"))

# Check that the translation runs.
translator = ctranslate2.Translator(output_dir)
Expand Down Expand Up @@ -192,7 +183,7 @@ def test_opennmt_tf_gpt_conversion(tmp_dir):
converter = ctranslate2.converters.OpenNMTTFConverter(model)
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "vocabulary.json"))


def test_opennmt_tf_multi_features(tmp_dir):
Expand Down Expand Up @@ -224,5 +215,5 @@ def test_opennmt_tf_multi_features(tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json"))
5 changes: 3 additions & 2 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)

with open(os.path.join(output_dir, "shared_vocabulary.txt")) as vocab_file:
vocab = list(line.rstrip("\n") for line in vocab_file)
vocabulary_path = os.path.join(output_dir, "shared_vocabulary.json")
with open(vocabulary_path, encoding="utf-8") as vocabulary_file:
vocab = json.load(vocabulary_file)

assert vocab[-1] != "<pad>"

Expand Down
5 changes: 3 additions & 2 deletions src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ namespace ctranslate2 {
vocab_info.bos_token = config["bos_token"];
vocab_info.eos_token = config["eos_token"];

_vocabulary = std::make_shared<Vocabulary>(*model_reader.get_required_file("vocabulary.txt"),
std::move(vocab_info));
_vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
if (!_vocabulary)
throw std::runtime_error("Cannot load the vocabulary from the model directory");
}


Expand Down
18 changes: 18 additions & 0 deletions src/models/model_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,23 @@ namespace ctranslate2 {
return std::make_unique<imemstream>(content.data(), content.size());
}


std::shared_ptr<Vocabulary>
load_vocabulary(ModelReader& model_reader,
                const std::string& filename,
                VocabularyInfo vocab_info) {
  // Prefer the JSON vocabulary format; fall back to the legacy text format.
  // Returns nullptr when neither "<filename>.json" nor "<filename>.txt"
  // exists in the model directory.
  if (auto json_stream = model_reader.get_file(filename + ".json"))
    return std::make_shared<Vocabulary>(
      Vocabulary::from_json_file(*json_stream, std::move(vocab_info)));

  if (auto text_stream = model_reader.get_file(filename + ".txt"))
    return std::make_shared<Vocabulary>(
      Vocabulary::from_text_file(*text_stream, std::move(vocab_info)));

  return nullptr;
}

}
}
51 changes: 23 additions & 28 deletions src/models/sequence_to_sequence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
namespace ctranslate2 {
namespace models {

static const std::string shared_vocabulary_file = "shared_vocabulary.txt";
static const std::string source_vocabulary_file = "source_vocabulary.txt";
static const std::string target_vocabulary_file = "target_vocabulary.txt";
static const std::string vmap_file = "vmap.txt";


Expand All @@ -20,37 +17,35 @@ namespace ctranslate2 {
vocab_info.bos_token = config["bos_token"];
vocab_info.eos_token = config["eos_token"];

auto shared_vocabulary = model_reader.get_file(shared_vocabulary_file);
auto shared_vocabulary = load_vocabulary(model_reader, "shared_vocabulary", vocab_info);

if (shared_vocabulary) {
_target_vocabulary = std::make_shared<Vocabulary>(*shared_vocabulary, vocab_info);
_source_vocabularies.emplace_back(_target_vocabulary);
_target_vocabulary = shared_vocabulary;
_source_vocabularies = {shared_vocabulary};

} else {
_target_vocabulary = load_vocabulary(model_reader, "target_vocabulary", vocab_info);
if (!_target_vocabulary)
throw std::runtime_error("Cannot load the target vocabulary from the model directory");

{
auto source_vocabulary = model_reader.get_file(source_vocabulary_file);
if (source_vocabulary)
_source_vocabularies.emplace_back(std::make_shared<Vocabulary>(*source_vocabulary,
vocab_info));
else {
for (size_t i = 1;; i++) {
const std::string filename = "source_" + std::to_string(i) + "_vocabulary.txt";
const auto vocabulary_file = model_reader.get_file(filename);
if (!vocabulary_file)
break;
_source_vocabularies.emplace_back(std::make_shared<Vocabulary>(*vocabulary_file,
vocab_info));
}
}
auto source_vocabulary = load_vocabulary(model_reader, "source_vocabulary", vocab_info);

// If no source vocabularies were loaded, raise an error for the first filename.
if (_source_vocabularies.empty())
model_reader.get_required_file(source_vocabulary_file);
}
if (source_vocabulary) {
_source_vocabularies = {source_vocabulary};
} else {
for (size_t i = 1;; i++) {
const std::string filename = "source_" + std::to_string(i) + "_vocabulary";
auto vocabulary = load_vocabulary(model_reader, filename, vocab_info);

{
auto target_vocabulary = model_reader.get_required_file(target_vocabulary_file);
_target_vocabulary = std::make_shared<Vocabulary>(*target_vocabulary, vocab_info);
if (!vocabulary)
break;

_source_vocabularies.emplace_back(vocabulary);
}
}

if (_source_vocabularies.empty())
throw std::runtime_error("Cannot load the source vocabulary from the model directory");
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ namespace ctranslate2 {
vocab_info.unk_token = "<|endoftext|>";
vocab_info.bos_token = "<|startoftranscript|>";
vocab_info.eos_token = "<|endoftext|>";
_vocabulary = std::make_shared<Vocabulary>(*model_reader.get_required_file("vocabulary.txt"),
std::move(vocab_info));

_vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
if (!_vocabulary)
throw std::runtime_error("Cannot load the vocabulary from the model directory");
}

bool WhisperModel::is_quantizable(const std::string& variable_name) const {
Expand Down
23 changes: 18 additions & 5 deletions src/vocabulary.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "ctranslate2/vocabulary.h"

#include <nlohmann/json.hpp>

#include "ctranslate2/utils.h"

namespace ctranslate2 {

Vocabulary::Vocabulary(std::istream& in, VocabularyInfo info)
: _info(std::move(info))
{
Vocabulary Vocabulary::from_text_file(std::istream& in, VocabularyInfo info) {
std::vector<std::string> tokens;
std::string line;

Expand All @@ -21,12 +21,25 @@ namespace ctranslate2 {
tokens.emplace_back(std::move(line));
}

if (remove_carriage_return) {
for (auto& token : tokens)
token.pop_back();
}

return Vocabulary(std::move(tokens), std::move(info));
}

Vocabulary Vocabulary::from_json_file(std::istream& in, VocabularyInfo info) {
return Vocabulary(nlohmann::json::parse(in).get<std::vector<std::string>>(), std::move(info));
}

Vocabulary::Vocabulary(std::vector<std::string> tokens, VocabularyInfo info)
: _info(std::move(info))
{
_token_to_id.reserve(tokens.size());
_id_to_token.reserve(tokens.size());

for (auto& token : tokens) {
if (remove_carriage_return)
token.pop_back();
add_token(std::move(token));
}

Expand Down

0 comments on commit c8d5ba6

Please sign in to comment.