Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Jul 31, 2024
1 parent 84dd386 commit 468c41a
Show file tree
Hide file tree
Showing 18 changed files with 462 additions and 349 deletions.
37 changes: 26 additions & 11 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use chai::config::Config;
use chai::encoder::occupation::Occupation;
use chai::encoder::simple_occupation::SimpleOccupation;
use criterion::{criterion_group, criterion_main, Criterion};

use chai::cli::{Cli, Command};
use chai::constraints::Constraints;
use chai::encoder::Encoder;
use chai::metaheuristics::Metaheuristics;
use chai::encoder::{Driver, Encoder};
use chai::objectives::Objective;
use chai::problem::ElementPlacementProblem;
use chai::problem::Problem;
use chai::representation::{AssembleList, Assets};
use chai::{error::Error, representation::Representation};
use std::path::PathBuf;
Expand All @@ -25,27 +26,41 @@ fn simulate_cli_input(name: &str) -> (Config, AssembleList, Assets) {
cli.prepare_file()
}

fn process_cli_input(
config: Config,
fn do_benchmark(
representation: Representation,
elements: AssembleList,
assets: Assets,
driver: Box<dyn Driver>,
b: &mut Criterion,
) -> Result<(), Error> {
let representation = Representation::new(config)?;
let encoder = Encoder::new(&representation, elements, &assets, true)?;
let objective = Objective::new(&representation, encoder, assets)?;
let constraints = Constraints::new(&representation)?;
let mut problem = ElementPlacementProblem::new(representation, constraints, objective)?;
let mut candidate = problem.generate_candidate();
let encoder = Encoder::new(&representation, elements, &assets, driver)?;
let objective = Objective::new(&representation, encoder, assets)?;
let mut problem = Problem::new(representation, constraints, objective)?;
let candidate = problem.generate_candidate();
b.bench_function("Evaluation", |b| {
b.iter(|| {
candidate = problem.tweak_candidate(&candidate);
problem.rank_candidate(&candidate);
})
});
Ok(())
}

fn process_cli_input(
config: Config,
elements: AssembleList,
assets: Assets,
b: &mut Criterion,
) -> Result<(), Error> {
let representation = Representation::new(config)?;
let driver: Box<dyn Driver> = if representation.config.encoder.max_length <= 4 {
Box::new(SimpleOccupation::new(representation.get_space()))
} else {
Box::new(Occupation::new(representation.get_space()))
};
do_benchmark(representation, elements, assets, driver, b)
}

fn length_3(b: &mut Criterion) {
let (config, resource, assets) = simulate_cli_input("easy");
process_cli_input(config, resource, assets, b).unwrap();
Expand Down
22 changes: 7 additions & 15 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//!
use crate::data::{Glyph, PrimitiveRepertoire, Reading};
use crate::metaheuristics::simulated_annealing::Parameters;
use crate::metaheuristics::simulated_annealing::SimulatedAnnealing;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashMap;
Expand All @@ -24,6 +24,7 @@ pub struct Info {
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Data {
pub character_set: Option<String>,
pub repertoire: Option<PrimitiveRepertoire>,
pub glyph_customization: Option<HashMap<String, Glyph>>,
pub reading_customization: Option<HashMap<String, Vec<Reading>>>,
Expand All @@ -41,6 +42,7 @@ pub struct Analysis {
pub customize: Option<HashMap<String, Vec<String>>>,
pub strong: Option<Vec<String>>,
pub weak: Option<Vec<String>>,
pub serializer: Option<String>,
}

#[skip_serializing_none]
Expand Down Expand Up @@ -233,20 +235,10 @@ pub struct ConstraintsConfig {

#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
pub random_move: f64,
pub random_swap: f64,
pub random_full_key_swap: f64,
}

#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverConfig {
pub algorithm: String,
pub runtime: Option<u64>,
pub parameters: Option<Parameters>,
pub report_after: Option<f64>,
pub search_method: Option<SearchConfig>,
#[serde(tag = "algorithm")]
pub enum SolverConfig {
SimulatedAnnealing(SimulatedAnnealing),
// TODO: Add more algorithms
}

#[skip_serializing_none]
Expand Down
6 changes: 5 additions & 1 deletion src/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ impl Constraints {
let element_number = representation.element_repr.get(&x);
element_number.ok_or(format!("{x} 不存在于键盘映射中"))
};
let optimization = representation.config.optimization.as_ref().ok_or("优化配置不存在")?;
let optimization = representation
.config
.optimization
.as_ref()
.ok_or("优化配置不存在")?;
if let Some(constraints) = &optimization.constraints {
values.append(&mut constraints.elements.clone().unwrap_or_default());
values.append(&mut constraints.indices.clone().unwrap_or_default());
Expand Down
25 changes: 7 additions & 18 deletions src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
use crate::error::Error;
use crate::representation::{
Assemble, AssembleList, Assets, AutoSelect, Code, CodeInfo, CodeSubInfo, Codes, Entry,
Frequency, Key, KeyMap, Representation, Sequence, MAX_COMBINATION_LENGTH, MAX_WORD_LENGTH,
Frequency, Key, KeyMap, Representation, Sequence, MAX_WORD_LENGTH,
};
use occupation::Occupation;
use rustc_hash::{FxHashMap, FxHashSet};
use std::cmp::Reverse;

mod occupation;

mod simple_occupation;
use simple_occupation::SimpleOccupation;
pub mod c3;
pub mod occupation;
pub mod simple_occupation;

/// 一个可编码对象
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -74,8 +72,7 @@ pub fn adapt(
}

pub trait Driver {
fn encode_full(&mut self, keymap: &KeyMap, config: &EncoderConfig, buffer: &mut Codes);
fn encode_short(&mut self, config: &EncoderConfig, buffer: &mut Codes);
fn run(&mut self, keymap: &KeyMap, config: &EncoderConfig, buffer: &mut Codes);
}

pub struct EncoderConfig {
Expand Down Expand Up @@ -121,7 +118,7 @@ impl Encoder {
representation: &Representation,
resource: AssembleList,
assets: &Assets,
simple: bool,
driver: Box<dyn Driver>,
) -> Result<Self, Error> {
let encoder = &representation.config.encoder;
let max_length = encoder.max_length;
Expand Down Expand Up @@ -200,12 +197,6 @@ impl Encoder {
first_key: representation.select_keys[0],
short_code,
};
let space = representation.get_space();
let driver: Box<dyn Driver> = if simple && max_length <= MAX_COMBINATION_LENGTH {
Box::new(SimpleOccupation::new(space))
} else {
Box::new(Occupation::new(space))
};
let encoder = Self {
transition_matrix,
buffer,
Expand All @@ -216,9 +207,7 @@ impl Encoder {
}

pub fn prepare(&mut self, keymap: &KeyMap) {
self.driver
.encode_full(keymap, &self.config, &mut self.buffer);
self.driver.encode_short(&self.config, &mut self.buffer);
self.driver.run(keymap, &self.config, &mut self.buffer);
}

pub fn encode(&mut self, keymap: &KeyMap, representation: &Representation) -> Vec<Entry> {
Expand Down
44 changes: 44 additions & 0 deletions src/encoder/c3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use super::{Driver, EncoderConfig};
use crate::representation::{Codes, KeyMap};
use std::iter::zip;

/// 编码是否已被占据
/// 用一个数组和一个哈希集合来表示,数组用来表示四码以内的编码,哈希集合用来表示四码以上的编码
pub struct C3 {
pub full_space: Vec<u8>,
}

impl C3 {
pub fn new(length: usize) -> Self {
Self {
full_space: vec![0; length],
}
}

pub fn reset(&mut self) {
self.full_space.iter_mut().for_each(|x| {
*x = 0;
});
}
}

impl Driver for C3 {
fn run(&mut self, keymap: &KeyMap, config: &EncoderConfig, buffer: &mut Codes) {
self.reset();
// 1. 全码
for (encodable, pointer) in zip(&config.encodables, buffer.iter_mut()) {
let sequence = &encodable.sequence;
assert!(sequence.len() >= 3);
let code = keymap[sequence[0]] as u64 * config.radix * config.radix
+ keymap[sequence[1]] as u64 * config.radix
+ keymap[sequence[2]] as u64;
pointer.full.actual = code;
}

for pointer in buffer.iter_mut() {
let rank = self.full_space[pointer.full.actual as usize];
pointer.full.duplicate = rank > 0;
self.full_space[pointer.full.actual as usize] = rank.saturating_add(1);
}
}
}
10 changes: 4 additions & 6 deletions src/encoder/occupation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ impl Occupation {
}

impl Driver for Occupation {
fn encode_full(&mut self, keymap: &KeyMap, config: &EncoderConfig, full: &mut Codes) {
fn run(&mut self, keymap: &KeyMap, config: &EncoderConfig, buffer: &mut Codes) {
self.reset();
let weights: Vec<_> = (0..=config.max_length)
.map(|x| config.radix.pow(x as u32))
.collect();
for (encodable, pointer) in zip(&config.encodables, full) {
for (encodable, pointer) in zip(&config.encodables, buffer.iter_mut()) {
let sequence = &encodable.sequence;
let mut code = 0_u64;
for (element, weight) in zip(sequence, &weights) {
Expand All @@ -99,9 +99,6 @@ impl Driver for Occupation {
};
self.full_space.insert(code, encodable.hash);
}
}

fn encode_short(&mut self, config: &EncoderConfig, buffer: &mut Codes) {
if config.short_code.is_none() || config.short_code.as_ref().unwrap().is_empty() {
return;
}
Expand Down Expand Up @@ -146,7 +143,8 @@ impl Driver for Occupation {
}
// 首先将全码截取一部分出来
let short = full.code % weight;
let rank = self.full_space.rank_hash(short, hash) + self.short_space.rank_hash(short, hash);
let rank = self.full_space.rank_hash(short, hash)
+ self.short_space.rank_hash(short, hash);
if rank >= select_keys.len() as u8 {
continue;
}
Expand Down
12 changes: 5 additions & 7 deletions src/encoder/simple_occupation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ impl SimpleOccupation {
}

impl Driver for SimpleOccupation {
fn encode_full(&mut self, keymap: &KeyMap, config: &EncoderConfig, full: &mut Codes) {
fn run(&mut self, keymap: &KeyMap, config: &EncoderConfig, buffer: &mut Codes) {
self.reset();
// 1. 全码
let weights: Vec<_> = (0..=config.max_length)
.map(|x| config.radix.pow(x as u32))
.collect();
for (encodable, pointer) in zip(&config.encodables, full.iter_mut()) {
for (encodable, pointer) in zip(&config.encodables, buffer.iter_mut()) {
let sequence = &encodable.sequence;
let mut code = 0_u64;
for (element, weight) in zip(sequence, &weights) {
Expand All @@ -45,17 +46,14 @@ impl Driver for SimpleOccupation {
pointer.full.actual = config.wrap_actual(code, 0, weights[sequence.len()]);
self.full_space[code as usize] = rank.saturating_add(1);
}
}

fn encode_short(&mut self, config: &EncoderConfig, buffer: &mut Codes) {
if config.short_code.is_none() || config.short_code.as_ref().unwrap().is_empty() {
return;
}
let weights: Vec<_> = (0..=config.max_length)
.map(|x| config.radix.pow(x as u32))
.collect();
let short_code = config.short_code.as_ref().unwrap();
// 优先简码
// 2. 优先简码
for (pointer, encodable) in zip(buffer.iter_mut(), &config.encodables) {
if encodable.level == u64::MAX {
continue;
Expand All @@ -66,7 +64,7 @@ impl Driver for SimpleOccupation {
pointer.short.actual = config.wrap_actual(short, rank, encodable.level);
self.short_space[short as usize] = rank.saturating_add(1);
}
// 常规简码
// 3. 常规简码
for (pointer, encodable) in zip(buffer.iter_mut(), &config.encodables) {
if encodable.level != u64::MAX {
continue;
Expand Down
8 changes: 3 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ use wasm_bindgen::JsError;

#[derive(Debug, Clone)]
pub struct Error {
pub message: String
pub message: String,
}

impl From<String> for Error {
fn from(value: String) -> Self {
Self {
message: value
}
Self { message: value }
}
}

impl From<&str> for Error {
fn from(value: &str) -> Self {
Self {
message: value.to_string()
message: value.to_string(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! 输出接口的抽象层
//!
//!
//! 定义了一个特征,指定了所有在退火计算的过程中需要向用户反馈的数据。命令行界面、Web 界面只需要各自实现这些方法,就可向用户报告各种用户数据,实现方式可以很不一样。
use crate::config::Config;
Expand Down
Loading

0 comments on commit 468c41a

Please sign in to comment.