fix "\n" tokenization + phi-2 new layer names
vince62s committed Jan 18, 2024
1 parent 8045a86 commit 0d88622
Showing 15 changed files with 126 additions and 85 deletions.
6 changes: 3 additions & 3 deletions eval_llm/MMLU/run_mmlu_opennmt.py
@@ -167,12 +167,12 @@ def evaluate(opt):
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, task, k)
prompt = train_prompt + prompt_end
"""
while len(prompt.split()) > 768:

while len(prompt.split(" ")) > 768:
prompt_split = prompt.split("\n\n")
prompt_split.pop(1)
prompt = "\n\n".join(prompt_split)
"""

label = test_df.iloc[i, test_df.shape[1] - 1]
records.append({"prompt": prompt, "answer": label})
src.append(prompt.replace("\n", "⦅newline⦆"))
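
The change above re-enables the prompt-trimming loop (previously fenced off by the triple-quote markers) and counts space-separated chunks with split(" "). A minimal standalone sketch of the same idea (the helper name is made up; the 768 budget mirrors the script):

    def trim_prompt(prompt: str, budget: int = 768) -> str:
        # Few-shot examples are separated by blank lines ("\n\n");
        # dropping index 1 removes the oldest example while keeping the
        # instruction header (index 0) and the test question at the end.
        # Assumes the prompt fits once enough examples have been removed.
        while len(prompt.split(" ")) > budget:
            parts = prompt.split("\n\n")
            parts.pop(1)
            prompt = "\n\n".join(parts)
        return prompt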
5 changes: 2 additions & 3 deletions eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
@@ -119,7 +119,7 @@ def evaluate(opt):
engine = InferenceEnginePY(engine_opt)

# Tokenize the dataset.
opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokenize_dataset(opt, context_length=512)

# Score the tokenized dataset
@@ -140,8 +140,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_wikitext-2_benchmark.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


5 changes: 4 additions & 1 deletion onmt/decoders/ensemble.py
@@ -65,7 +65,10 @@ def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
for i, model_decoder in enumerate(self.model_decoders)
]
)
mean_attns = self.combine_attns(attns)
if attns[0]["std"] is not None:
mean_attns = self.combine_attns(attns)
else:
mean_attns = attns
return EnsembleDecoderOutput(dec_outs), mean_attns

def combine_attns(self, attns):
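
For context, combine_attns averages each attention entry element-wise across the ensemble members, roughly like the sketch below (a simplified illustration, not the repository's exact code). Language-model decoders can return {"std": None}, in which case stacking None values would fail, hence the new guard that passes the attention dicts through untouched.

    import torch

    def combine_attns_sketch(attns):
        # attns: one dict per ensemble member, e.g. [{"std": t1}, {"std": t2}]
        # Stack each entry across members and average element-wise.
        combined = {}
        for key in attns[0].keys():
            combined[key] = torch.stack([attn[key] for attn in attns]).mean(dim=0)
        return combined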
12 changes: 6 additions & 6 deletions onmt/inputters/text_corpus.py
@@ -174,20 +174,20 @@ def __init__(

def _process(self, stream):
for i, example in enumerate(stream):
example["src"] = example["src"].strip("\n").split()
example["src_original"] = example["src_original"].strip("\n").split()
example["src"] = example["src"].strip().split(" ")
example["src_original"] = example["src_original"].strip().split(" ")
if "src_feats" in example:
example["src_feats"] = [
feat.strip("\n").split() for feat in example["src_feats"]
feat.strip().split(" ") for feat in example["src_feats"]
]
line_number = i * self.stride + self.offset
example["cid_line_number"] = line_number
example["cid"] = self.cid
if "align" in example:
example["align"] = example["align"].strip("\n").split()
example["align"] = example["align"].strip().split(" ")
if example["tgt"] is not None:
example["tgt"] = example["tgt"].strip("\n").split()
example["tgt_original"] = example["tgt_original"].strip("\n").split()
example["tgt"] = example["tgt"].strip().split(" ")
example["tgt_original"] = example["tgt_original"].strip().split(" ")
if (
len(example["src"]) == 0
or len(example["tgt"]) == 0
20 changes: 11 additions & 9 deletions onmt/inputters/text_utils.py
@@ -121,31 +121,33 @@ def numericalize(vocabs, example):
numeric = example
numeric["src"]["src_ids"] = []
if vocabs["data_task"] == ModelTask.SEQ2SEQ:
src_text = example["src"]["src"].split()
src_text = example["src"]["src"].split(" ")
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split()
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
[decoder_start_token] + tgt_text + [DefaultTokens.EOS]
)

elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
src_text = example["src"]["src"].split()
src_text = example["src"]["src"].split(" ")
if decoder_start_token != "":
src_text = [decoder_start_token] + src_text
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split()
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](tgt_text + [DefaultTokens.EOS])
if decoder_start_token == "":
numeric["tgt"]["tgt_ids"] = numeric["tgt"]["tgt_ids"][1:]
else:
raise ValueError(f"Something went wrong with task {vocabs['data_task']}")

if "feats" in example["src"]:
numeric_feats = []
for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]):
numeric_feats.append(fv(feat.split()))
numeric_feats.append(fv(feat.split(" ")))
numeric["src"]["feats"] = numeric_feats

return numeric
@@ -329,7 +331,7 @@ def textbatch_to_tensor(vocabs, batch, device, is_train=False):
infer_iter = []
for i, ex in enumerate(batch):
# Keep it consistent with dynamic data
ex["srclen"] = len(ex["src"]["src"].split())
ex["srclen"] = len(ex["src"]["src"].split(" "))
ex["in_in_bucket"] = i
ex["cid"] = "text"
ex["cid_line_number"] = i
@@ -354,7 +356,7 @@ def _addcopykeys(vocabs, example):
Returns:
``example``, changed as described.
"""
src = example["src"]["src"].split()
src = example["src"]["src"].split(" ")
src_ex_vocab = pyonmttok.build_vocab_from_tokens(
Counter(src),
maximum_size=0,
@@ -377,10 +379,10 @@ def _addcopykeys(vocabs, example):
if vocabs["data_task"] == ModelTask.SEQ2SEQ:
tgt = (
[DefaultTokens.UNK]
+ example["tgt"]["tgt"].split()
+ example["tgt"]["tgt"].split(" ")
+ [DefaultTokens.UNK]
)
elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
tgt = example["tgt"]["tgt"].split() + [DefaultTokens.UNK]
tgt = example["tgt"]["tgt"].split(" ") + [DefaultTokens.UNK]
example["alignment"] = src_ex_vocab(tgt)
return example
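
The two lines added in the language-model branch pair with the add_bos_token handling in tools/convert_HF.py further down: when the tokenizer does not add a BOS token, decoder_start_token is the empty string, nothing is prepended to the source, and the first target position is dropped so that source and target stay shifted by one for next-token prediction. A toy illustration (token strings are made up):

    tokens = ["▁The", "▁cat", "▁sat"]
    BOS, EOS = "<s>", "</s>"

    # With a BOS-adding tokenizer (decoder_start_token == "<s>"):
    src_with_bos = [BOS] + tokens          # ['<s>', '▁The', '▁cat', '▁sat']
    tgt_with_bos = tokens + [EOS]          # ['▁The', '▁cat', '▁sat', '</s>']

    # Without BOS (decoder_start_token == ""):
    src_no_bos = tokens                    # ['▁The', '▁cat', '▁sat']
    tgt_no_bos = (tokens + [EOS])[1:]      # ['▁cat', '▁sat', '</s>']

    # In both cases the sequences have equal length and tgt[i] is the
    # next-token target for src[i].
    assert len(src_with_bos) == len(tgt_with_bos)
    assert len(src_no_bos) == len(tgt_no_bos)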
2 changes: 1 addition & 1 deletion onmt/transforms/fuzzymatch.py
@@ -216,6 +216,6 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
assert len(src_segments) == len(fuzzied_src)
for idx, (example, _, _) in enumerate(batch):
if fuzzied_src[idx] != "":
example["src"] = fuzzied_src[idx].split()
example["src"] = fuzzied_src[idx].split(" ")

return batch
14 changes: 7 additions & 7 deletions onmt/transforms/inlinetags.py
@@ -73,8 +73,8 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:
maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None
)

tokenized_source_string = source_only.split()
tokenized_target_string = tgt_example.split()
tokenized_source_string = source_only.split(" ")
tokenized_target_string = tgt_example.split(" ")

src_offset, tgt_offset = 0, 0
src_with_tags, tgt_with_tags = list(), list()
@@ -140,12 +140,12 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:

src_term = " ".join(
tokenized_source_string[
source_index : source_index + len(pair[0].split())
source_index : source_index + len(pair[0].split(" "))
]
)
tgt_term = " ".join(
tokenized_target_string[
target_index : target_index + len(pair[1].split())
target_index : target_index + len(pair[1].split(" "))
]
)

@@ -210,11 +210,11 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:
tgt_with_tags.append(tgt_example[tgt_offset:])

return (
"".join(src_with_tags).replace("∥", " ").split(),
"".join(tgt_with_tags).replace("∥", " ").split(),
"".join(src_with_tags).replace("∥", " ").split(" "),
"".join(tgt_with_tags).replace("∥", " ").split(" "),
), is_match
else:
return (src_example.split(), tgt_example.split()), is_match
return (src_example.split(" "), tgt_example.split(" ")), is_match


@register_transform(name="inlinetags")
16 changes: 8 additions & 8 deletions onmt/transforms/misc.py
@@ -136,8 +136,8 @@ def get_specials(cls, opts):
prefix_dict = cls.get_prefix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, prefix in prefix_dict.items():
src_specials.update(prefix["src"].split())
tgt_specials.update(prefix["tgt"].split())
src_specials.update(prefix["src"].split(" "))
tgt_specials.update(prefix["tgt"].split(" "))
return (src_specials, tgt_specials)

def warm_up(self, vocabs=None):
@@ -149,9 +149,9 @@ def _prepend(self, example, prefix):
"""Prepend `prefix` to `tokens`."""
for side, side_prefix in prefix.items():
if example.get(side) is not None:
example[side] = side_prefix.split() + example[side]
example[side] = side_prefix.split(" ") + example[side]
elif len(side_prefix) > 0:
example[side] = side_prefix.split()
example[side] = side_prefix.split(" ")
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
@@ -250,8 +250,8 @@ def get_specials(cls, opts):
suffix_dict = cls.get_suffix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, suffix in suffix_dict.items():
src_specials.update(suffix["src"].split())
tgt_specials.update(suffix["tgt"].split())
src_specials.update(suffix["src"].split(" "))
tgt_specials.update(suffix["tgt"].split(" "))
return (src_specials, tgt_specials)

def warm_up(self, vocabs=None):
@@ -263,9 +263,9 @@ def _append(self, example, suffix):
"""Prepend `suffix` to `tokens`."""
for side, side_suffix in suffix.items():
if example.get(side) is not None:
example[side] = example[side] + side_suffix.split()
example[side] = example[side] + side_suffix.split(" ")
elif len(side_suffix) > 0:
example[side] = side_suffix.split()
example[side] = side_suffix.split(" ")
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
4 changes: 2 additions & 2 deletions onmt/transforms/normalize.py
@@ -329,7 +329,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
self.pre_dict[corpus_name],
self.post_dict[corpus_name],
)
example["src"] = src_str.split()
example["src"] = src_str.split(" ")

if example["tgt"] is not None:
tgt_str = self.mpn.normalize(
@@ -341,6 +341,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
self.pre_dict[corpus_name],
self.post_dict[corpus_name],
)
example["tgt"] = tgt_str.split()
example["tgt"] = tgt_str.split(" ")

return example
16 changes: 8 additions & 8 deletions onmt/transforms/terminology.py
@@ -57,7 +57,7 @@ def _create_internal_termbase(self, termbase_path):
for pair in pairs:
src_term, tgt_term = map(str, pair.split("\t"))
src_lemma = " ".join(
"∥".join(tok.lemma_.split()) for tok in self.src_nlp(src_term)
"∥".join(tok.lemma_.split(" ")) for tok in self.src_nlp(src_term)
).strip()
tgt_lemma = " ".join(
tok.lemma_ for tok in self.tgt_nlp(tgt_term)
@@ -93,7 +93,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:

# Perform tokenization with spacy for consistency.
tokenized_source = [tok.text for tok in doc_src]
lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src]
lemmatized_source = ["∥".join(tok.lemma_.lower().split(" ")) for tok in doc_src]
lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt]

lemmatized_source_string = " ".join(lemmatized_source)
@@ -143,7 +143,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
lemma_list_index += len(w) + 1

# We need to know if the term is multiword
num_words_in_src_term = len(src_entry.split())
num_words_in_src_term = len(src_entry.split(" "))
src_term = " ".join(
tokenized_source[
lemma_list_index : lemma_list_index + num_words_in_src_term
@@ -164,7 +164,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:

if is_match:
source_with_terms.append(lemmatized_source_string[offset:])
tokenized_source_with_terms = "".join(source_with_terms).split()
tokenized_source_with_terms = "".join(source_with_terms).split(" ")

if not (
len(tokenized_source)
@@ -173,7 +173,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
):
final_string = " ".join(tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), not is_match
return fixed_punct.split(" "), not is_match

# Construct the final source from the lemmatized list
# that contains the terms. We compare the tokens in the
@@ -195,17 +195,17 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
final_string = " ".join(
completed_tokenized_source
+ [self.delimiter]
+ augmented_part.split()
+ augmented_part.split(" ")
)
else:
final_string = " ".join(completed_tokenized_source)

fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), is_match
return fixed_punct.split(" "), is_match
else:
final_string = " ".join(tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), not is_match
return fixed_punct.split(" "), not is_match


@register_transform(name="terminology")
15 changes: 11 additions & 4 deletions onmt/transforms/tokenize.py
@@ -283,7 +283,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")

def _repr_args(self):
"""Return str represent key arguments for class."""
@@ -353,7 +353,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")


@register_transform(name="onmt_tokenize")
@@ -550,7 +550,14 @@ def tokenize_string(self, sentence, side="src", is_train=False):
self.maptable[b]
for b in sentence.replace(DefaultTokens.SEP, "\n").encode("utf-8")
)
segmented = tokenizer(sentence)
segmented1 = tokenizer(sentence)
segmented = []
# ugly patch to make sure "\n\n" is split into two items
for s in segmented1:
if s == "ĊĊ":
segmented.extend(["Ċ", "Ċ"])
else:
segmented.append(s)
else:
segmented = tokenizer(sentence)
return segmented
@@ -572,7 +579,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")

def _repr_args(self):
"""Return str represent key arguments for class."""
4 changes: 2 additions & 2 deletions onmt/transforms/uppercase.py
@@ -47,7 +47,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
for c in unicodedata.normalize("NFD", src_str.upper())
if unicodedata.category(c) != "Mn"
)
example["src"] = src_str.split()
example["src"] = src_str.split(" ")

if example["tgt"] is not None:
tgt_str = " ".join(example["tgt"])
@@ -56,6 +56,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
for c in unicodedata.normalize("NFD", tgt_str.upper())
if unicodedata.category(c) != "Mn"
)
example["tgt"] = tgt_str.split()
example["tgt"] = tgt_str.split(" ")

return example
8 changes: 4 additions & 4 deletions onmt/translate/translation_server.py
@@ -937,7 +937,7 @@ def maybe_detokenize(self, sequence, side="tgt"):
"""De-tokenize the sequence (or not)
Same args/returns as :func:``tokenize()``"""

if self.tokenizers_opt is not None and "".join(sequence.split()) != "":
if self.tokenizers_opt is not None and "".join(sequence.split(" ")) != "":
return self.detokenize(sequence, side)
return sequence

@@ -950,9 +950,9 @@ def detokenize(self, sequence, side="tgt"):
raise ValueError("No tokenizer loaded")

if self.tokenizers_opt[side]["type"] == "sentencepiece":
detok = self.tokenizers[side].DecodePieces(sequence.split())
detok = self.tokenizers[side].DecodePieces(sequence.split(" "))
elif self.tokenizers_opt[side]["type"] == "pyonmttok":
detok = self.tokenizers[side].detokenize(sequence.split())
detok = self.tokenizers[side].detokenize(sequence.split(" "))

return detok

@@ -976,7 +976,7 @@ def maybe_convert_align(self, src, tgt, align, align_scores):
"To get decoded alignment, joiner/spacer "
"should be used in both side's tokenizer."
)
elif "".join(tgt.split()) != "":
elif "".join(tgt.split(" ")) != "":
align = to_word_align(
src, tgt, align, align_scores, src_marker, tgt_marker
)
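
A small sanity check on the emptiness guards above: joining the space-split pieces of a whitespace-only sequence still yields the empty string, so "".join(sequence.split(" ")) != "" behaves the same as the old argument-less variant (plain Python, not repository code):

    print("".join("   ".split(" ")))   # ''   -> treated as empty, detokenization skipped
    print("".join("a b".split(" ")))   # 'ab' -> non-empty, detokenization runs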
6 changes: 3 additions & 3 deletions onmt/utils/alignment.py
@@ -115,14 +115,14 @@ def to_word_align(
assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"

src, tgt = src.strip().split(), tgt.strip().split()
src, tgt = src.strip().split(" "), tgt.strip().split(" ")
subword_align = {
(int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split())
(int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split(" "))
}

subword_align_scores = dict(
(int(a), float(b))
for a, b in (x.split("-") for x in subword_align_scores.split())
for a, b in (x.split("-") for x in subword_align_scores.split(" "))
)

src_map = (
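
The alignment strings here are in Pharaoh format: space-separated "srcIdx-tgtIdx" (and "srcIdx-score") pairs, so a space-only split parses them unchanged. A short worked example with made-up values:

    subword_align = "0-0 1-1 2-1"
    subword_align_scores = "0-0.9 1-0.8 2-0.7"

    pairs = {
        (int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split(" "))
    }
    scores = dict(
        (int(a), float(b))
        for a, b in (x.split("-") for x in subword_align_scores.split(" "))
    )

    print(pairs)    # {(0, 0), (1, 1), (2, 1)}  (set order may vary)
    print(scores)   # {0: 0.9, 1: 0.8, 2: 0.7}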
78 changes: 54 additions & 24 deletions tools/convert_HF.py
@@ -74,29 +74,20 @@
".feed_forward.experts.7.layer_norm.weight": ".post_attention_layernorm.weight",
}
key_maps["PhiForCausalLM"] = {
"layer_prefix": "transformer.h.",
"decoder.embeddings.make_embedding.emb_luts.0.weight": "transformer.embd.wte.weight",
"decoder.layer_norm.weight": "lm_head.ln.weight",
"decoder.layer_norm.bias": "lm_head.ln.bias",
"generator.weight": "lm_head.linear.weight",
"generator.bias": "lm_head.linear.bias",
".self_attn.linear_query.": (
".mixer.Wqkv.",
"[:hidden_size]", # noqa E501
),
".self_attn.linear_keys.": (
".mixer.Wqkv.",
"[hidden_size:2*hidden_size]", # noqa E501
),
".self_attn.linear_values.": (
".mixer.Wqkv.",
"[-hidden_size:]", # noqa E501
),
".self_attn.final_linear.": ".mixer.out_proj.",
"layer_prefix": "model.layers.",
"decoder.embeddings.make_embedding.emb_luts.0.weight": "model.embed_tokens.weight",
"decoder.layer_norm.weight": "model.final_layernorm.weight",
"decoder.layer_norm.bias": "model.final_layernorm.bias",
"generator.weight": "lm_head.weight",
"generator.bias": "lm_head.bias",
".self_attn.linear_query.": ".self_attn.q_proj.",
".self_attn.linear_keys.": ".self_attn.k_proj.",
".self_attn.linear_values.": ".self_attn.v_proj.",
".self_attn.final_linear.": ".self_attn.dense.",
".feed_forward.w_1.": ".mlp.fc1.",
".feed_forward.w_2.": ".mlp.fc2.",
".layer_norm_1.weight": (".ln.weight", ""),
".layer_norm_1.bias": (".ln.bias", ""),
".layer_norm_1.weight": (".input_layernorm.weight", ""),
".layer_norm_1.bias": (".input_layernorm.bias", ""),
}
ln_table = {
"LlamaForCausalLM": "rms",
@@ -190,6 +181,10 @@ def __init__(self, model_path: str):
"You used a local directory but tokenizer.model",
" and/or tokenizer.json are missing",
)
if os.path.exists(os.path.join(opt.model_dir, "tokenizer_config.json")):
tokenizer_config_json = os.path.join(opt.model_dir, "tokenizer_config.json")
else:
tokenizer_config_json = None
else:
directory_path, _ = os.path.split(opt.output)
os.makedirs(directory_path, exist_ok=True)
@@ -224,6 +219,17 @@ def __init__(self, model_path: str):
raise huggingface_hub.utils.EntryNotFoundError(
"Something went wrong the repo does not contain any config.json file"
)
try:
tokenizer_config_json = huggingface_hub.hf_hub_download(
repo_id=opt.model_dir,
filename="tokenizer_config.json",
local_dir=directory_path,
token=opt.token,
)
except huggingface_hub.utils.EntryNotFoundError:
raise huggingface_hub.utils.EntryNotFoundError(
"Something went wrong the repo does not contain any tokenizer_config.json file"
)
try:
wmap_path = huggingface_hub.hf_hub_download(
repo_id=opt.model_dir,
@@ -325,6 +331,8 @@ def __init__(self, model_path: str):
norm_eps = config["rms_norm_eps"]
elif "layer_norm_epsilon" in config.keys():
norm_eps = config["layer_norm_epsilon"]
elif "layer_norm_eps" in config.keys():
norm_eps = config["layer_norm_eps"]
else:
norm_eps = 1e-6
if "rope_theta" in config.keys():
@@ -333,6 +341,8 @@ def __init__(self, model_path: str):
rope_theta = 1e4
if "rotary_dim" in config.keys():
rotary_dim = config["rotary_dim"]
elif "partial_rotary_factor" in config.keys():
rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads))
else:
rotary_dim = 0
if "sliding_window" in config.keys():
@@ -404,7 +414,7 @@ def __init__(self, model_path: str):
params = ["weight", "bias"]

add_qkvbias = False
aff_ffnbias = False
add_ffnbias = False
rotary_interleave = False
if arch == "PhiForCausalLM":
parallel_residual = True
@@ -689,11 +699,28 @@ def get_weight(checkpoint, tensor_name):

directory_path, _ = os.path.split(opt.output)
os.makedirs(directory_path, exist_ok=True)
if tokenizer_config_json is not None:
with open(tokenizer_config_json, encoding="utf-8") as f:
data = json.load(f)
if "add_bos_token" in data.keys():
add_bos_token = data["add_bos_token"]
else:
add_bos_token = False
else:
add_bos_token = True
vocabs = {}
if tokenizer_model is not None:
tokenizer = Tokenizer(model_path=tokenizer_model)
vocab = tokenizer.vocab
vocab[3] = DefaultTokens.PAD
if "<|startoftext|>" in vocab:
index = vocab.index("<|startoftext|>")
vocab[index] = DefaultTokens.BOS
if "<|endoftext|>" in vocab:
index = vocab.index("<|endoftext|>")
vocab[index] = DefaultTokens.EOS
if "<0x00>" in vocab:
index = vocab.index("<0x00>")
vocab[index] = DefaultTokens.PAD
src_vocab = pyonmttok.build_vocab_from_tokens(
vocab,
maximum_size=tokenizer.n_words,
@@ -722,7 +749,10 @@ def get_weight(checkpoint, tensor_name):
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["data_task"] = "lm"
vocabs["decoder_start_token"] = decoder_start_table[arch]
if add_bos_token:
vocabs["decoder_start_token"] = decoder_start_table[arch]
else:
vocabs["decoder_start_token"] = ""
onmt_cp["vocab"] = {}
onmt_cp["vocab"] = vocabs_to_dict(vocabs)
