Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save vocabulary files in the JSON format #1293

Merged
merged 3 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/ctranslate2/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "ctranslate2/encoding.h"
#include "ctranslate2/generation.h"
#include "ctranslate2/scoring.h"
#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {
Expand Down
8 changes: 8 additions & 0 deletions include/ctranslate2/models/model_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <string>
#include <unordered_map>

#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {

Expand Down Expand Up @@ -50,5 +52,11 @@ namespace ctranslate2 {
std::unordered_map<std::string, std::string> _files;
};


std::shared_ptr<Vocabulary>
load_vocabulary(ModelReader& model_reader,
const std::string& filename,
VocabularyInfo vocab_info);

}
}
1 change: 0 additions & 1 deletion include/ctranslate2/models/sequence_to_sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "ctranslate2/models/model.h"
#include "ctranslate2/scoring.h"
#include "ctranslate2/translation.h"
#include "ctranslate2/vocabulary.h"
#include "ctranslate2/vocabulary_map.h"

namespace ctranslate2 {
Expand Down
1 change: 0 additions & 1 deletion include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "ctranslate2/layers/whisper.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"
#include "ctranslate2/vocabulary.h"

namespace ctranslate2 {
namespace models {
Expand Down
5 changes: 4 additions & 1 deletion include/ctranslate2/vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ namespace ctranslate2 {
class Vocabulary
{
public:
Vocabulary(std::istream& in, VocabularyInfo info = VocabularyInfo());
static Vocabulary from_text_file(std::istream& in, VocabularyInfo info = VocabularyInfo());
static Vocabulary from_json_file(std::istream& in, VocabularyInfo info = VocabularyInfo());

Vocabulary(std::vector<std::string> tokens, VocabularyInfo info = VocabularyInfo());

bool contains(const std::string& token) const;
const std::string& to_token(size_t id) const;
Expand Down
16 changes: 7 additions & 9 deletions python/ctranslate2/specs/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,7 @@ def save(self, output_dir: str) -> None:
vocabularies = {"shared": all_vocabularies[0]}

for name, tokens in vocabularies.items():
path = os.path.join(output_dir, "%s_vocabulary.txt" % name)
_save_lines(path, tokens)
_save_vocabulary(output_dir, "%s_vocabulary" % name, tokens)

# Save the rest of the model.
super().save(output_dir)
Expand Down Expand Up @@ -566,15 +565,14 @@ def validate(self) -> None:

def save(self, output_dir: str) -> None:
# Save the vocabulary.
vocabulary_path = os.path.join(output_dir, "vocabulary.txt")
_save_lines(vocabulary_path, self._vocabulary)
_save_vocabulary(output_dir, "vocabulary", self._vocabulary)

# Save the rest of the model.
super().save(output_dir)


def _save_lines(path, lines):
with open(path, "w", encoding="utf-8", newline="") as f:
for line in lines:
f.write(line)
f.write("\n")
def _save_vocabulary(output_dir, name, tokens):
vocabulary_path = os.path.join(output_dir, "%s.json" % name)

with open(vocabulary_path, "w", encoding="utf-8") as vocabulary_file:
json.dump(tokens, vocabulary_file, indent=2)
4 changes: 2 additions & 2 deletions python/tests/test_opennmt_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_opennmt_py_source_features(tmp_dir, filename):
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json"))

source = [
["آ", "ت", "ز", "م", "و", "ن"],
Expand Down
17 changes: 4 additions & 13 deletions python/tests/test_opennmt_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def test_opennmt_tf_model_conversion(tmp_dir, model_path):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

src_vocab_path = os.path.join(output_dir, "source_vocabulary.txt")
tgt_vocab_path = os.path.join(output_dir, "target_vocabulary.txt")

# Check lines end with \n on all platforms.
with open(src_vocab_path, encoding="utf-8", newline="") as vocab_file:
assert vocab_file.readline() == "<blank>\n"
with open(tgt_vocab_path, encoding="utf-8", newline="") as vocab_file:
assert vocab_file.readline() == "<blank>\n"

translator = ctranslate2.Translator(output_dir)
output = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]])
assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"]
Expand Down Expand Up @@ -145,7 +136,7 @@ def test_opennmt_tf_shared_embeddings_conversion(tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "shared_vocabulary.json"))

# Check that the translation runs.
translator = ctranslate2.Translator(output_dir)
Expand Down Expand Up @@ -192,7 +183,7 @@ def test_opennmt_tf_gpt_conversion(tmp_dir):
converter = ctranslate2.converters.OpenNMTTFConverter(model)
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "vocabulary.json"))


def test_opennmt_tf_multi_features(tmp_dir):
Expand Down Expand Up @@ -224,5 +215,5 @@ def test_opennmt_tf_multi_features(tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)

assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.json"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.json"))
5 changes: 3 additions & 2 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ def test_transformers_marianmt_vocabulary(clear_transformers_cache, tmp_dir):
output_dir = str(tmp_dir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)

with open(os.path.join(output_dir, "shared_vocabulary.txt")) as vocab_file:
vocab = list(line.rstrip("\n") for line in vocab_file)
vocabulary_path = os.path.join(output_dir, "shared_vocabulary.json")
with open(vocabulary_path, encoding="utf-8") as vocabulary_file:
vocab = json.load(vocabulary_file)

assert vocab[-1] != "<pad>"

Expand Down
5 changes: 3 additions & 2 deletions src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ namespace ctranslate2 {
vocab_info.bos_token = config["bos_token"];
vocab_info.eos_token = config["eos_token"];

_vocabulary = std::make_shared<Vocabulary>(*model_reader.get_required_file("vocabulary.txt"),
std::move(vocab_info));
_vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
if (!_vocabulary)
throw std::runtime_error("Cannot load the vocabulary from the model directory");
}


Expand Down
18 changes: 18 additions & 0 deletions src/models/model_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,23 @@ namespace ctranslate2 {
return std::make_unique<imemstream>(content.data(), content.size());
}


std::shared_ptr<Vocabulary>
load_vocabulary(ModelReader& model_reader,
const std::string& filename,
VocabularyInfo vocab_info) {
std::unique_ptr<std::istream> file;

file = model_reader.get_file(filename + ".json");
if (file)
return std::make_shared<Vocabulary>(Vocabulary::from_json_file(*file, std::move(vocab_info)));

file = model_reader.get_file(filename + ".txt");
if (file)
return std::make_shared<Vocabulary>(Vocabulary::from_text_file(*file, std::move(vocab_info)));

return nullptr;
}

}
}
51 changes: 23 additions & 28 deletions src/models/sequence_to_sequence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
namespace ctranslate2 {
namespace models {

static const std::string shared_vocabulary_file = "shared_vocabulary.txt";
static const std::string source_vocabulary_file = "source_vocabulary.txt";
static const std::string target_vocabulary_file = "target_vocabulary.txt";
static const std::string vmap_file = "vmap.txt";


Expand All @@ -20,37 +17,35 @@ namespace ctranslate2 {
vocab_info.bos_token = config["bos_token"];
vocab_info.eos_token = config["eos_token"];

auto shared_vocabulary = model_reader.get_file(shared_vocabulary_file);
auto shared_vocabulary = load_vocabulary(model_reader, "shared_vocabulary", vocab_info);

if (shared_vocabulary) {
_target_vocabulary = std::make_shared<Vocabulary>(*shared_vocabulary, vocab_info);
_source_vocabularies.emplace_back(_target_vocabulary);
_target_vocabulary = shared_vocabulary;
_source_vocabularies = {shared_vocabulary};

} else {
_target_vocabulary = load_vocabulary(model_reader, "target_vocabulary", vocab_info);
if (!_target_vocabulary)
throw std::runtime_error("Cannot load the target vocabulary from the model directory");

{
auto source_vocabulary = model_reader.get_file(source_vocabulary_file);
if (source_vocabulary)
_source_vocabularies.emplace_back(std::make_shared<Vocabulary>(*source_vocabulary,
vocab_info));
else {
for (size_t i = 1;; i++) {
const std::string filename = "source_" + std::to_string(i) + "_vocabulary.txt";
const auto vocabulary_file = model_reader.get_file(filename);
if (!vocabulary_file)
break;
_source_vocabularies.emplace_back(std::make_shared<Vocabulary>(*vocabulary_file,
vocab_info));
}
}
auto source_vocabulary = load_vocabulary(model_reader, "source_vocabulary", vocab_info);

// If no source vocabularies were loaded, raise an error for the first filename.
if (_source_vocabularies.empty())
model_reader.get_required_file(source_vocabulary_file);
}
if (source_vocabulary) {
_source_vocabularies = {source_vocabulary};
} else {
for (size_t i = 1;; i++) {
const std::string filename = "source_" + std::to_string(i) + "_vocabulary";
auto vocabulary = load_vocabulary(model_reader, filename, vocab_info);

{
auto target_vocabulary = model_reader.get_required_file(target_vocabulary_file);
_target_vocabulary = std::make_shared<Vocabulary>(*target_vocabulary, vocab_info);
if (!vocabulary)
break;

_source_vocabularies.emplace_back(vocabulary);
}
}

if (_source_vocabularies.empty())
throw std::runtime_error("Cannot load the source vocabulary from the model directory");
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ namespace ctranslate2 {
vocab_info.unk_token = "<|endoftext|>";
vocab_info.bos_token = "<|startoftranscript|>";
vocab_info.eos_token = "<|endoftext|>";
_vocabulary = std::make_shared<Vocabulary>(*model_reader.get_required_file("vocabulary.txt"),
std::move(vocab_info));

_vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
if (!_vocabulary)
throw std::runtime_error("Cannot load the vocabulary from the model directory");
}

bool WhisperModel::is_quantizable(const std::string& variable_name) const {
Expand Down
23 changes: 18 additions & 5 deletions src/vocabulary.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "ctranslate2/vocabulary.h"

#include <nlohmann/json.hpp>

#include "ctranslate2/utils.h"

namespace ctranslate2 {

Vocabulary::Vocabulary(std::istream& in, VocabularyInfo info)
: _info(std::move(info))
{
Vocabulary Vocabulary::from_text_file(std::istream& in, VocabularyInfo info) {
std::vector<std::string> tokens;
std::string line;

Expand All @@ -21,12 +21,25 @@ namespace ctranslate2 {
tokens.emplace_back(std::move(line));
}

if (remove_carriage_return) {
for (auto& token : tokens)
token.pop_back();
}

return Vocabulary(std::move(tokens), std::move(info));
}

Vocabulary Vocabulary::from_json_file(std::istream& in, VocabularyInfo info) {
return Vocabulary(nlohmann::json::parse(in).get<std::vector<std::string>>(), std::move(info));
}

Vocabulary::Vocabulary(std::vector<std::string> tokens, VocabularyInfo info)
: _info(std::move(info))
{
_token_to_id.reserve(tokens.size());
_id_to_token.reserve(tokens.size());

for (auto& token : tokens) {
if (remove_carriage_return)
token.pop_back();
add_token(std::move(token));
}

Expand Down