Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Jul 15, 2024
1 parent 27db02e commit a290896
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 320 deletions.
7 changes: 4 additions & 3 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use chai::config::Config;
use chai::encoder::generic::GenericEncoder;
use criterion::{criterion_group, criterion_main, Criterion};

use chai::cli::{Cli, Command};
Expand Down Expand Up @@ -32,9 +33,9 @@ fn process_cli_input(
b: &mut Criterion,
) -> Result<(), Error> {
let representation = Representation::new(config)?;
let encoder = Encoder::new(&representation, elements, &assets)?;
let buffer = Buffer::new(&encoder);
let objective = Objective::new(&representation, encoder, assets)?;
let encoder = GenericEncoder::new(&representation, elements, &assets)?;
let buffer = Buffer::new(&encoder.encodables, encoder.get_space());
let objective = Objective::new(&representation, Box::new(encoder), assets)?;
let constraints = Constraints::new(&representation)?;
let mut problem = ElementPlacementProblem::new(representation, constraints, objective, buffer)?;
let mut candidate = problem.generate_candidate();
Expand Down
298 changes: 46 additions & 252 deletions src/encoder.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
//! 编码引擎
use rustc_hash::FxHashMap;
use rustc_hash::FxHashSet;

use crate::config::EncoderConfig;
use crate::error::Error;
use crate::representation::{
Assemble, AssembleList, Assets, AutoSelect, Buffer, Code, Entry, Frequency, Key, KeyMap,
Occupation, Representation, Sequence, MAX_COMBINATION_LENGTH, MAX_WORD_LENGTH,
};
use std::collections::HashSet;
use std::{cmp::Reverse, fmt::Debug, iter::zip};
use crate::representation::{Buffer, Frequency, KeyMap, Sequence};

pub mod generic;

/// 一个可编码对象
#[derive(Debug, Clone)]
Expand All @@ -18,265 +13,64 @@ pub struct Encodable {
pub length: usize,
pub sequence: Sequence,
pub frequency: u64,
pub level: i64,
pub level: u64,
pub hash: u16,
pub index: usize,
}

/// Concrete encoding engine: the objects to encode plus everything needed to
/// turn a key map into full and short codes.
pub struct Encoder {
/// Objects to encode; `new` sorts them by descending frequency.
pub encodables: Vec<Encodable>,
/// Row `i` lists `(j, weight)` pairs for transitions from encodable `i` to
/// encodable `j`; `new` sorts each row by `j`.
pub transition_matrix: Vec<Vec<(usize, u64)>>,
/// Copy of the encoder section of the configuration.
pub config: EncoderConfig,
/// Looked up by code in `get_actual_code`; a `true` entry (or a missing one)
/// lets a rank-0 code stand without a selection key appended.
auto_select: AutoSelect,
/// Numeric base in which key sequences are packed into a single `u64` code.
pub radix: u64,
/// Selection keys indexed by the absolute value of a candidate's rank;
/// index 0 is the fallback when the rank is out of range.
select_keys: Vec<Key>,
/// Compiled short-code schemes, indexed by word length minus one;
/// `None` when the configuration defines no short-code rules.
pub short_code: Option<[Vec<CompiledScheme>; MAX_WORD_LENGTH]>,
}

/// One short-code rule in the form consumed by `encode_short`.
#[derive(Debug)]
pub struct CompiledScheme {
/// Number of low-order base-`radix` digits of the full code that are kept
/// as the short code (`full % radix.pow(prefix)`).
pub prefix: usize,
/// Its length is used as the number of candidates allowed on the shortened
/// code. NOTE(review): presumably the keys that pick each candidate —
/// confirm against `transform_short_code`.
pub select_keys: Vec<usize>,
}

impl Encoder {
/// Re-projects a raw frequency table onto the vocabulary `words`.
///
/// Entries of `frequency` that are already in `words` keep their weight.
/// Every other entry is segmented from its end towards its start: a
/// character that is not in `words` is dropped, otherwise the match is
/// extended leftwards for as long as the longer slice is still in `words`.
/// Each segment found receives the weight of the entry it came from.
///
/// Returns the adapted frequency table and a list of
/// `(earlier_segment, later_segment, weight)` triples, one for each pair of
/// consecutively matched segments inside a segmented entry.
pub fn adapt(
    frequency: &Frequency,
    words: &HashSet<String>,
) -> (Frequency, Vec<(String, String, u64)>) {
    let mut new_frequency = Frequency::new();
    let mut transition_pairs = Vec::new();
    for (word, value) in frequency {
        if words.contains(word) {
            // Accumulate in place with a single lookup, exactly like the
            // segmentation branch below does.
            *new_frequency.entry(word.clone()).or_default() += *value;
            continue;
        }
        // Segment with a backward (right-to-left) matching pass.
        let chars: Vec<char> = word.chars().collect();
        let mut end = chars.len();
        let mut last_match: Option<String> = None;
        while end > 0 {
            let mut start = end - 1;
            // A trailing character outside the vocabulary is simply skipped.
            if !words.contains(&chars[start].to_string()) {
                end -= 1;
                continue;
            }
            // Keep extending to the left while the longer slice is a word.
            while start > 0
                && words.contains(&chars[(start - 1)..end].iter().collect::<String>())
            {
                start -= 1;
            }
            // The match is settled: credit it and link it to the segment
            // that follows it in the text (matched on the previous round).
            let sub_word: String = chars[start..end].iter().collect();
            *new_frequency.entry(sub_word.clone()).or_default() += *value;
            if let Some(last) = last_match.take() {
                transition_pairs.push((sub_word.clone(), last, *value));
            }
            last_match = Some(sub_word);
            end = start;
        }
    }
    (new_frequency, transition_pairs)
}

/// Creates an encoding engine from the configuration representation, the
/// assembly list and the shared assets.
///
/// Single characters must come with their decomposition; words only need to
/// be listed, their element sequence is derived from the characters.
///
/// # Errors
///
/// Fails when the configured maximum code length is 8 or more, and
/// propagates any error from `transform_elements`, `transform_auto_select`
/// or `transform_short_code`.
pub fn new(
    representation: &Representation,
    resource: AssembleList,
    assets: &Assets,
) -> Result<Encoder, Error> {
    let encoder = &representation.config.encoder;
    let max_length = encoder.max_length;
    if max_length >= 8 {
        return Err("目前暂不支持最大码长大于等于 8 的方案计算!".into());
    }

    // Pre-process the frequency table against the names we can encode.
    let all_words: HashSet<_> = resource.iter().map(|x| x.name.clone()).collect();
    let (frequency, transition_pairs) = Self::adapt(&assets.frequency, &all_words);

    // Turn every assembly into an encodable, remembering its input position.
    let mut encodables = Vec::new();
    for (index, assemble) in resource.into_iter().enumerate() {
        // Read only the fields needed here; the assembly itself is handed
        // over to `transform_elements` without being cloned first.
        let name = assemble.name.clone();
        let importance = assemble.importance;
        let level = assemble.level;
        let sequence = representation.transform_elements(assemble)?;
        let char_frequency = *frequency.get(&name).unwrap_or(&0);
        // `importance` is a percentage applied to the raw frequency.
        let weighted = char_frequency * importance / 100;
        let hash: u16 = (name.chars().map(|x| x as u32).sum::<u32>()) as u16;
        encodables.push(Encodable {
            length: name.chars().count(),
            name,
            sequence,
            frequency: weighted,
            level,
            hash,
            index,
        });
    }

    // Most frequent first (stable, so ties keep their input order).
    encodables.sort_by_key(|x| Reverse(x.frequency));

    let map_word_to_index: FxHashMap<String, usize> = encodables
        .iter()
        .enumerate()
        .map(|(index, x)| (x.name.clone(), index))
        .collect();
    let mut transition_matrix = vec![vec![]; encodables.len()];
    for (from, to, value) in transition_pairs {
        // `adapt` only emits segments taken from `all_words`, and every one
        // of those names has an encodable, so both lookups must succeed.
        let from = *map_word_to_index
            .get(&from)
            .expect("transition source must be one of the encodable names");
        let to = *map_word_to_index
            .get(&to)
            .expect("transition target must be one of the encodable names");
        transition_matrix[from].push((to, value));
    }
    for row in transition_matrix.iter_mut() {
        row.sort_by_key(|x| x.0);
    }

    // Automatic selection table.
    let auto_select = representation.transform_auto_select()?;

    // Short-code rules, when the configuration has any.
    let short_code = match &encoder.short_code {
        Some(configs) => Some(representation.transform_short_code(configs.clone())?),
        None => None,
    };

    Ok(Encoder {
        encodables,
        transition_matrix,
        auto_select,
        config: encoder.clone(),
        radix: representation.radix,
        select_keys: representation.select_keys.clone(),
        short_code,
    })
}

/// Returns the code actually typed for a candidate, with its length.
///
/// A rank-0 candidate whose code is marked for automatic selection (or has
/// no entry in the table) is returned unchanged. Otherwise the selection key
/// for `|rank|` — or the first selection key when there is none for that
/// rank — is appended as one extra digit at position `length`.
pub fn get_actual_code(&self, code: u64, rank: i8, length: u32) -> (u64, u32) {
    let stands_alone = rank == 0 && *self.auto_select.get(code as usize).unwrap_or(&true);
    if stands_alone {
        return (code, length);
    }
    let fallback = &self.select_keys[0];
    let position = rank.unsigned_abs() as usize;
    let key = *self.select_keys.get(position).unwrap_or(fallback) as u64;
    let appended = key * self.radix.pow(length);
    (code + appended, length + 1)
}

/// Computes the full code of every encodable under `keymap`.
///
/// Each element of an encodable's sequence is mapped to its key and packed
/// as one base-`radix` digit, first element in the lowest position. The
/// code and its rank (how many entries already occupy that code) are written
/// to `buffer.full`, then the code is registered in `occupation`.
pub fn encode_full(&self, keymap: &KeyMap, buffer: &mut Buffer, occupation: &mut Occupation) {
    // Positional weights radix^0 ..= radix^max_length, computed once.
    let mut weights = Vec::new();
    for position in 0..=self.config.max_length {
        weights.push(self.radix.pow(position as u32));
    }
    for (encodable, slot) in zip(&self.encodables, &mut buffer.full) {
        let code = zip(&encodable.sequence, &weights)
            .map(|(element, weight)| keymap[*element] as u64 * weight)
            .sum::<u64>();
        slot.code = code;
        // Rank must be read before this entry is inserted.
        slot.rank = occupation.rank(code) as i8;
        occupation.insert(code, encodable.hash);
    }
}

pub fn encode_short(
&self,
buffer: &mut Buffer,
full_occupation: &Occupation,
short_occupation: &mut Occupation,
) {
if self.short_code.is_none() {
return;
}
let short_code = self.short_code.as_ref().unwrap();
// 优先简码
for ((code, pointer), encodable) in
zip(zip(&buffer.full, &mut buffer.short), &self.encodables)
{
if encodable.level == -1 {
continue;
}
let modulo = self.radix.pow(encodable.level as u32);
let short = code.code % modulo;
pointer.code = short;
pointer.rank = 0;
short_occupation.insert(short, encodable.hash);
}
// 常规简码
for ((code, pointer), encodable) in
zip(zip(&buffer.full, &mut buffer.short), &self.encodables)
{
let schemes = &short_code[encodable.length - 1];
if schemes.is_empty() || encodable.level >= 0 {
continue;
}
let full = &code.code;
let mut has_reduced = false;
let hash = encodable.hash;
for scheme in schemes {
let CompiledScheme {
prefix,
select_keys,
} = scheme;
// 如果根本没有这么多码,就放弃
if *full < self.radix.pow((*prefix - 1) as u32) {
pub fn adapt(
frequency: &Frequency,
words: &FxHashSet<String>,
) -> (Frequency, Vec<(String, String, u64)>) {
let mut new_frequency = Frequency::new();
let mut transition_pairs = Vec::new();
for (word, value) in frequency {
if words.contains(word) {
new_frequency.insert(word.clone(), new_frequency.get(word).unwrap_or(&0) + *value);
} else {
// 使用逆向最大匹配算法来分词
let chars: Vec<_> = word.chars().collect();
let mut end = chars.len();
let mut last_match: Option<String> = None;
while end > 0 {
let mut start = end - 1;
// 如果最后一个字不在词表里,就不要了
if !words.contains(&chars[start].to_string()) {
end -= 1;
continue;
}
// 首先将全码截取一部分出来
let modulo = self.radix.pow(*prefix as u32);
let short = full % modulo;
let capacity = select_keys.len() as u8;
if full_occupation.rank(short) + short_occupation.rank_hash(short, hash) >= capacity
// 继续向前匹配,看看是否能匹配到更长的词
while start > 0
&& words.contains(&chars[(start - 1)..end].iter().collect::<String>())
{
continue;
start -= 1;
}
pointer.code = short;
pointer.rank = short_occupation.rank_hash(short, hash) as i8;
short_occupation.insert(short, hash);
has_reduced = true;
break;
}
if !has_reduced {
pointer.code = *full;
pointer.rank = short_occupation.rank_hash(*full, hash) as i8;
short_occupation.insert(*full, hash);
// 确定最大匹配
let sub_word: String = chars[start..end].iter().collect();
*new_frequency.entry(sub_word.clone()).or_default() += *value;
if let Some(last) = last_match {
transition_pairs.push((sub_word.clone(), last, *value));
}
last_match = Some(sub_word);
end = start;
}
}
}

/// Encodes every entry under `keymap` and returns the results in the order
/// the entries were originally supplied (by `Encodable::index`).
///
/// Runs the full-code pass and then the short-code pass on a fresh buffer,
/// and renders each numeric code back to text via `representation`.
pub fn encode(&self, keymap: &KeyMap, representation: &Representation) -> Vec<Entry> {
    let mut buffer = Buffer::new(self);
    let mut full_occupation = Occupation::new(representation.get_space());
    let mut short_occupation = Occupation::new(representation.get_space());
    self.encode_full(keymap, &mut buffer, &mut full_occupation);
    self.encode_short(&mut buffer, &full_occupation, &mut short_occupation);

    let recover = |code: Code| representation.repr_code(code).iter().collect();
    let mut indexed: Vec<(usize, Entry)> = self
        .encodables
        .iter()
        .enumerate()
        .map(|(position, encodable)| {
            let full = &buffer.full[position];
            let short = &buffer.short[position];
            let entry = Entry {
                name: encodable.name.clone(),
                full: recover(full.code),
                full_rank: full.rank,
                short: recover(short.code),
                short_rank: short.rank,
            };
            (encodable.index, entry)
        })
        .collect();
    // Undo the frequency ordering used internally.
    indexed.sort_by_key(|(original_index, _)| *original_index);
    indexed.into_iter().map(|(_, entry)| entry).collect()
}

/// Number of distinct codes of the tracked length: `radix` raised to the
/// configured maximum length, capped at `MAX_COMBINATION_LENGTH`.
pub fn get_space(&self) -> usize {
    let capped_length = std::cmp::min(self.config.max_length, MAX_COMBINATION_LENGTH);
    let space = self.radix.pow(capped_length as u32);
    space as usize
}
(new_frequency, transition_pairs)
}

/// Common interface of encoder implementations.
///
/// NOTE(review): only the signatures are visible here; the per-method notes
/// below are inferred from them and should be confirmed against the
/// implementors (e.g. `generic::GenericEncoder`).
pub trait Encoder {
/// Writes full codes for the given `keymap` into `buffer` (presumably).
fn encode_full(&self, keymap: &KeyMap, buffer: &mut Buffer);
/// Writes short codes into `buffer`; presumably expects `encode_full` to
/// have run on the same buffer first — confirm.
fn encode_short(&self, buffer: &mut Buffer);
/// Returns the numeric base used to pack keys into a code (presumably).
fn get_radix(&self) -> u64;
/// Returns the size of the code space as a `usize`.
fn get_space(&self) -> usize;
/// Maps a `(code, rank, length)` triple to a `(code, length)` pair;
/// presumably appends a selection key when one is required — confirm.
fn get_actual_code(&self, code: u64, rank: i8, length: u32) -> (u64, u32);
/// Returns the `(index, weight)` pairs associated with entry `index`;
/// presumably its outgoing transitions — confirm.
fn get_transitions(&self, index: usize) -> &[(usize, u64)];
}
Loading

0 comments on commit a290896

Please sign in to comment.