Skip to content

Commit

Permalink
Merge branch 'main' into feat/tutorials/yolo_world_export
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay authored Oct 2, 2024
2 parents 767e795 + 12a9e8b commit 5d3185e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
23 changes: 13 additions & 10 deletions operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,17 +397,14 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
bool compute_offset_mapping,
std::list<OffsetMappingType>& offset_map) const {
std::vector<int64_t> res;

// Add BOS token to result
res.push_back(bos_token_id_);

size_t max_length = static_cast<size_t>(max_length_i64);
// Parse input
bool add_dummy_prefix = false;
if (ModelName() == kModel_Llama) {
add_dummy_prefix = true;
}
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
bool add_dummy_prefix = bpe_conf_.get().add_dummy_prefix_;

for (auto& seg_id : special_token_split_res) {
if (res.size() >= max_length) break;

Expand Down Expand Up @@ -637,7 +634,8 @@ static const auto kSpmConfiguration = BpeModelConf{
"<s>", // bos_token
"</s>", // eos_token
"", // pad_token
true};
true, // spm_model
true}; // add_dummy_prefix

// Default constructor: initializes the KernelBpeTokenizer base with the
// file-level kSpmConfiguration preset defined just above. No other setup is
// done here (the body is empty).
SpmTokenizer::SpmTokenizer()
: KernelBpeTokenizer(kSpmConfiguration) {}
Expand Down Expand Up @@ -700,7 +698,7 @@ OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_ex
}

// Helper methods (declared as private members of JsonFastTokenizer in bpe_kernels.h)
bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
auto decoder_node = tok_json.find("decoder");
if (decoder_node != tok_json.end()) {
auto decoders_node = decoder_node->find("decoders");
Expand All @@ -710,13 +708,18 @@ bool JsonFastTokenizer::CheckForSpmModel(const json& tok_json) {
if (type == "Replace") {
std::string target = step.value("/pattern/String"_json_pointer, "");
if (target == spm_escaped_space) {
return true;
json_conf_.spm_model_ = true;
}
}
else if (type == "Strip") {
std::string content = step.value("/content"_json_pointer, "");
if (content == " ") {
json_conf_.add_dummy_prefix_ = true;
}
}
}
}
}
return false;
}

void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
Expand Down Expand Up @@ -756,7 +759,7 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
bpe_conf_ = json_conf_;

// Check for SPM model
json_conf_.spm_model_ = CheckForSpmModel(tok_json);
LoadSpmModelParams(tok_json);

auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
Expand Down
8 changes: 4 additions & 4 deletions operators/tokenizer/bpe_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct BpeModelConf {
const char* pad_token_{nullptr};

bool spm_model_{};
bool add_dummy_prefix_{};
std::string GetSpecialTokens() const;
};

Expand All @@ -34,6 +35,7 @@ struct KernelBpeTokenizer {

// Returns a const reference to the stored model name (model_name_); no copy is made.
const std::string& ModelName() const { return model_name_; }
uint32_t GetTokenId(const std::string& token) const;
// Returns the add_dummy_prefix_ flag of the BpeModelConf that bpe_conf_ refers to.
// The flag is read through the reference wrapper on every call, so it reflects
// the current state of the referenced configuration object.
bool GetAddDummyPrefix() const { return bpe_conf_.get().add_dummy_prefix_; }

protected:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
Expand All @@ -50,8 +52,8 @@ struct KernelBpeTokenizer {
void CreateUnicodeByteEncoder();

protected:
std::reference_wrapper<BpeModelConf const> bpe_conf_;
std::string model_name_;
std::reference_wrapper<BpeModelConf const> bpe_conf_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;

int64_t padding_length_ = -1;
Expand Down Expand Up @@ -122,12 +124,10 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
// Returns a const reference to added_tokens_ (the member's declaration is not
// visible in this hunk, so its exact container type is not restated here).
const auto& GetAddedTokens() const { return added_tokens_; }
// Returns the BpeModel owned by bbpe_tokenizer_ by dereferencing the unique_ptr.
// NOTE(review): there is no null check here; presumably this is only called
// after the tokenizer has been loaded - confirm against callers.
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
// Returns json_conf_.spm_model_, the flag recorded on this tokenizer's own
// JSON-derived configuration.
bool IsSpmModel() const { return json_conf_.spm_model_; }
bool tiktoken_ = false;

private:
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
// template functions to avoid including the huge json header file
bool CheckForSpmModel(const json& tok_json);
void LoadSpmModelParams(const json& tok_json);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);

Expand Down
4 changes: 2 additions & 2 deletions operators/tokenizer/bpe_streaming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
// whitespace_token_ = tok_config.clean_up_tokenization_spaces_ ? 1 : 0;
skip_special_tokens_ = 1;
// en_normalization_ = 0;
add_dummy_prefix_ = tok_config.tokenizer_class_ == "LlamaTokenizer" ? 1 : 0;
add_dummy_prefix_ = encoder.GetAddDummyPrefix();
eos_token_id_ = encoder.GetEncoder().GetTokenId(tok_config.eos_token_);

tok_config_ = ptr_config;
Expand Down Expand Up @@ -249,7 +249,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
private:

extTokenId_t eos_token_id_{0};
bool add_dummy_prefix_ = false;
bool spm_model_{};
bool add_dummy_prefix_{};
std::shared_ptr<ort_extensions::TokenJsonConfig const> tok_config_;
};

0 comments on commit 5d3185e

Please sign in to comment.