Skip to content

Commit

Permalink
Refactor encryptor impl
Browse files Browse the repository at this point in the history
Signed-off-by: kexuan.yang <[email protected]>
  • Loading branch information
yangkx1024 committed Jan 20, 2024
1 parent 2ce087d commit 9cf8049
Show file tree
Hide file tree
Showing 14 changed files with 80 additions and 106 deletions.
Binary file modified android/library-encrypt/src/main/jniLibs/arm64-v8a/libmmkv.so
Binary file not shown.
Binary file modified android/library-encrypt/src/main/jniLibs/armeabi-v7a/libmmkv.so
Binary file not shown.
Binary file modified android/library-encrypt/src/main/jniLibs/x86_64/libmmkv.so
Binary file not shown.
Binary file modified android/library/src/main/jniLibs/arm64-v8a/libmmkv.so
Binary file not shown.
Binary file modified android/library/src/main/jniLibs/armeabi-v7a/libmmkv.so
Binary file not shown.
Binary file modified android/library/src/main/jniLibs/x86_64/libmmkv.so
Binary file not shown.
Binary file modified ios/MMKV/RustMMKV.xcframework.zip
Binary file not shown.
5 changes: 3 additions & 2 deletions src/core/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ use kv::{Types, KV};
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));

#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct Buffer(KV);

pub trait Encoder: Send {
fn encode_to_bytes(&self, raw_buffer: &Buffer) -> Result<Vec<u8>>;
fn encode_to_bytes(&self, raw_buffer: &Buffer, position: u32) -> Result<Vec<u8>>;
}

pub struct DecodeResult {
Expand All @@ -22,7 +23,7 @@ pub struct DecodeResult {
}

pub trait Decoder: Send {
fn decode_bytes(&self, data: &[u8]) -> Result<DecodeResult>;
fn decode_bytes(&self, data: &[u8], position: u32) -> Result<DecodeResult>;
}

macro_rules! impl_from_typed_array {
Expand Down
8 changes: 4 additions & 4 deletions src/core/crc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const CRC8: Crc<u8> = Crc::<u8>::new(&CRC_8_AUTOSAR);
pub struct CrcEncoderDecoder;

impl Encoder for CrcEncoderDecoder {
fn encode_to_bytes(&self, raw_buffer: &Buffer) -> Result<Vec<u8>> {
fn encode_to_bytes(&self, raw_buffer: &Buffer, _: u32) -> Result<Vec<u8>> {
let bytes_to_write = raw_buffer.to_bytes();
let sum = CRC8.checksum(bytes_to_write.as_slice());
let len = bytes_to_write.len() as u32 + 1;
Expand All @@ -23,7 +23,7 @@ impl Encoder for CrcEncoderDecoder {
}

impl Decoder for CrcEncoderDecoder {
fn decode_bytes(&self, data: &[u8]) -> Result<DecodeResult> {
fn decode_bytes(&self, data: &[u8], _: u32) -> Result<DecodeResult> {
let offset = size_of::<u32>();
let item_len = u32::from_be_bytes(data[0..offset].try_into().map_err(|_| DataInvalid)?);
let bytes_to_decode = &data[offset..(offset + item_len as usize - 1)];
Expand Down Expand Up @@ -56,8 +56,8 @@ mod tests {
#[test]
fn test_crc_buffer() {
let buffer = Buffer::from_i32("key", 1);
let bytes = CrcEncoderDecoder.encode_to_bytes(&buffer).unwrap();
let decode_result = CrcEncoderDecoder.decode_bytes(bytes.as_slice()).unwrap();
let bytes = CrcEncoderDecoder.encode_to_bytes(&buffer, 0).unwrap();
let decode_result = CrcEncoderDecoder.decode_bytes(bytes.as_slice(), 0).unwrap();
assert_eq!(decode_result.len, bytes.len() as u32);
assert_eq!(decode_result.buffer, Some(buffer));
}
Expand Down
98 changes: 34 additions & 64 deletions src/core/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ use eax::aead::rand_core::RngCore;
use eax::aead::stream::{NewStream, StreamBE32, StreamPrimitive};
use eax::aead::{generic_array::GenericArray, KeyInit, OsRng, Payload};
use eax::Eax;
use std::cell::RefCell;
use std::fs;
use std::fs::OpenOptions;
use std::io::{Read, Write};
use std::mem::size_of;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use crate::core::buffer::{Buffer, DecodeResult, Decoder, Encoder};
use crate::Error::{DataInvalid, DecryptFailed, EncryptFailed, LockError};
use crate::Error::{DataInvalid, DecryptFailed, EncryptFailed};
use crate::Result;

const LOG_TAG: &str = "MMKV:Encrypt";
Expand All @@ -24,26 +23,22 @@ type Stream = StreamBE32<Aes128Eax>;

#[derive(Clone)]
pub struct Encryptor {
pub meta_file_path: PathBuf,
encryptor: Arc<Mutex<EncryptorImpl>>,
key: Vec<u8>,
meta_file_path: PathBuf,
encryptor: Arc<StreamWrapper>,
}

struct EncryptorImpl {
stream: Stream,
position: RefCell<u32>,
}
#[repr(transparent)]
struct StreamWrapper(Stream);

impl Encryptor {
pub fn init(file_path: &Path, key: &str) -> Self {
let decoded_key = hex::decode(key).unwrap();
let meta_file_path = Encryptor::resolve_meta_file_path(file_path);
let encryptor_impl =
EncryptorImpl::init(decoded_key.as_slice().try_into().unwrap(), &meta_file_path);
let encryptor =
StreamWrapper::init(decoded_key.as_slice().try_into().unwrap(), &meta_file_path);
Encryptor {
meta_file_path,
encryptor: Arc::new(Mutex::new(encryptor_impl)),
key: decoded_key,
encryptor: Arc::new(encryptor),
}
}

Expand All @@ -59,21 +54,14 @@ impl Encryptor {
pub fn remove_file(&self) {
let _ = fs::remove_file(&self.meta_file_path);
}

pub fn reset(&mut self) {
*self.encryptor.lock().unwrap() = EncryptorImpl::init(
self.key.as_slice().try_into().unwrap(),
&self.meta_file_path,
);
}
}

impl EncryptorImpl {
impl StreamWrapper {
fn init(key: [u8; 16], meta_file_path: &PathBuf) -> Self {
if meta_file_path.exists() {
EncryptorImpl::new_with_nonce(key, meta_file_path)
StreamWrapper::new_with_nonce(key, meta_file_path)
} else {
EncryptorImpl::new(key, meta_file_path)
StreamWrapper::new(key, meta_file_path)
}
}

Expand All @@ -91,10 +79,7 @@ impl EncryptorImpl {
.expect("failed to write nonce file");
let cipher = Aes128Eax::new(generic_array);
let stream = StreamBE32::from_aead(cipher, &nonce);
EncryptorImpl {
stream,
position: Default::default(),
}
StreamWrapper(stream)
}

fn new_with_nonce(key: [u8; 16], meta_file_path: &PathBuf) -> Self {
Expand All @@ -104,7 +89,7 @@ impl EncryptorImpl {
error!(LOG_TAG, "filed to read nonce, reason: {:?}", reason);
warn!(LOG_TAG, "delete meta file due to previous reason, which may cause mmkv drop all encrypted data");
let _ = fs::remove_file(meta_file_path);
EncryptorImpl::new(key, meta_file_path)
StreamWrapper::new(key, meta_file_path)
};
match nonce_file.read_to_end(&mut nonce) {
Ok(len) if len != NONCE_LEN => {
Expand All @@ -117,47 +102,40 @@ impl EncryptorImpl {
let nonce = GenericArray::from_slice(nonce.as_slice());
let cipher = Aes128Eax::new(generic_array);
let stream = StreamBE32::from_aead(cipher, nonce);
EncryptorImpl {
stream,
position: Default::default(),
}
StreamWrapper(stream)
}

fn encrypt(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
let position = *self.position.borrow();
fn encrypt(&self, bytes: Vec<u8>, position: u32) -> Result<Vec<u8>> {
if position == Stream::COUNTER_MAX {
return Err(EncryptFailed(String::from("counter overflow")));
}

let result = self
.stream
.0
.encrypt(position, false, Payload::from(bytes.as_slice()))
.map_err(|e| EncryptFailed(e.to_string()))?;

*self.position.borrow_mut() = position + Stream::COUNTER_INCR;
Ok(result)
}

fn decrypt(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
let position = *self.position.borrow();
fn decrypt(&self, bytes: Vec<u8>, position: u32) -> Result<Vec<u8>> {
if position == Stream::COUNTER_MAX {
return Err(DecryptFailed(String::from("counter overflow")));
}

let result = self
.stream
.0
.decrypt(position, false, Payload::from(bytes.as_slice()))
.map_err(|e| DecryptFailed(e.to_string()))?;

*self.position.borrow_mut() = position + Stream::COUNTER_INCR;
Ok(result)
}
}

impl Encoder for Encryptor {
fn encode_to_bytes(&self, raw_buffer: &Buffer) -> Result<Vec<u8>> {
fn encode_to_bytes(&self, raw_buffer: &Buffer, position: u32) -> Result<Vec<u8>> {
let bytes_to_write = raw_buffer.to_bytes();
let crypt_bytes = self.encryptor.lock().unwrap().encrypt(bytes_to_write)?;
let crypt_bytes = self.encryptor.encrypt(bytes_to_write, position)?;
let len = crypt_bytes.len() as u32;
let mut data = len.to_be_bytes().to_vec();
data.extend_from_slice(crypt_bytes.as_slice());
Expand All @@ -166,17 +144,15 @@ impl Encoder for Encryptor {
}

impl Decoder for Encryptor {
fn decode_bytes(&self, data: &[u8]) -> Result<DecodeResult> {
fn decode_bytes(&self, data: &[u8], position: u32) -> Result<DecodeResult> {
let data_offset = size_of::<u32>();
let item_len =
u32::from_be_bytes(data[0..data_offset].try_into().map_err(|_| DataInvalid)?);
let bytes_to_decode = &data[data_offset..(data_offset + item_len as usize)];
let read_len = data_offset as u32 + item_len;
let result = self
.encryptor
.lock()
.map_err(|e| LockError(e.to_string()))
.and_then(|encryptor| encryptor.decrypt(bytes_to_decode.to_vec()))
.decrypt(bytes_to_decode.to_vec(), position)
.and_then(|vec| Buffer::from_encoded_bytes(vec.as_slice()));
let buffer = match result {
Ok(data) => Some(data),
Expand All @@ -203,32 +179,26 @@ mod tests {
#[test]
fn test_crypt_buffer() {
let path = Path::new("./mmkv");
let mut encoder = Encryptor::init(path, TEST_KEY);
let mut decoder = Encryptor::init(path, TEST_KEY);
let encryptor = Encryptor::init(path, TEST_KEY);
let buffer1 = Buffer::from_i32("key1", 1);
let bytes1 = encoder.encode_to_bytes(&buffer1).unwrap();
let decode_result1 = decoder.decode_bytes(bytes1.as_slice()).unwrap();
let bytes1 = encryptor.encode_to_bytes(&buffer1, 0).unwrap();
let decode_result1 = encryptor.decode_bytes(bytes1.as_slice(), 0).unwrap();
assert_eq!(decode_result1.len, bytes1.len() as u32);
assert_eq!(decode_result1.buffer, Some(buffer1.clone()));
let buffer2 = Buffer::from_i32("key2", 2);
let bytes2 = encoder.encode_to_bytes(&buffer2).unwrap();
let decode_result2 = decoder.decode_bytes(bytes2.as_slice()).unwrap();
let bytes2 = encryptor.encode_to_bytes(&buffer2, 1).unwrap();
let decode_result2 = encryptor.decode_bytes(bytes2.as_slice(), 1).unwrap();
assert_eq!(decode_result2.len, bytes2.len() as u32);
assert_eq!(decode_result2.buffer, Some(buffer2));
assert_eq!(*encoder.encryptor.lock().unwrap().position.borrow(), 2);
assert_eq!(*decoder.encryptor.lock().unwrap().position.borrow(), 2);
assert!(decoder
.decode_bytes(bytes1.as_slice())
assert!(encryptor
.decode_bytes(bytes1.as_slice(), 1)
.unwrap()
.buffer
.is_none());
encoder.reset();
decoder.reset();
assert_eq!(*encoder.encryptor.lock().unwrap().position.borrow(), 0);
assert_eq!(*decoder.encryptor.lock().unwrap().position.borrow(), 0);
let new_decode_result1 = decoder.decode_bytes(bytes1.as_slice()).unwrap();
let encryptor = Encryptor::init(path, TEST_KEY);
let new_decode_result1 = encryptor.decode_bytes(bytes1.as_slice(), 0).unwrap();
assert_eq!(new_decode_result1.buffer, Some(buffer1));
encoder.remove_file();
assert!(!encoder.meta_file_path.exists());
encryptor.remove_file();
assert!(!encryptor.meta_file_path.exists());
}
}
22 changes: 14 additions & 8 deletions src/core/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ const LOG_TAG: &str = "MMKV:MemoryMap";

pub struct Iter<'a, F>
where
F: Fn(&[u8]) -> crate::Result<DecodeResult>,
F: Fn(&[u8], u32) -> crate::Result<DecodeResult>,
{
mm: &'a MemoryMap,
pub position: u32,
start: usize,
end: usize,
decode: F,
Expand All @@ -16,12 +17,13 @@ where
impl MemoryMap {
pub fn iter<F>(&self, decode: F) -> Iter<F>
where
F: Fn(&[u8]) -> crate::Result<DecodeResult>,
F: Fn(&[u8], u32) -> crate::Result<DecodeResult>,
{
let start = LEN_OFFSET;
let end = self.offset();
Iter {
mm: self,
position: 0,
start,
end,
decode,
Expand All @@ -31,7 +33,7 @@ impl MemoryMap {

impl<'a, F> Iterator for Iter<'a, F>
where
F: Fn(&[u8]) -> crate::Result<DecodeResult>,
F: Fn(&[u8], u32) -> crate::Result<DecodeResult>,
{
type Item = Option<Buffer>;

Expand All @@ -40,7 +42,8 @@ where
return None;
}
let bytes = self.mm.read(self.start..self.end);
let decode_result = (self.decode)(bytes);
let decode_result = (self.decode)(bytes, self.position);
self.position += 1;
match decode_result {
Ok(result) => {
self.start += result.len as usize;
Expand Down Expand Up @@ -69,7 +72,7 @@ mod tests {

struct TestEncoderDecoder;
impl Encoder for TestEncoderDecoder {
fn encode_to_bytes(&self, raw_buffer: &Buffer) -> Result<Vec<u8>> {
fn encode_to_bytes(&self, raw_buffer: &Buffer, _: u32) -> Result<Vec<u8>> {
let bytes_to_write = raw_buffer.to_bytes();
let len = bytes_to_write.len() as u32;
let mut data = len.to_be_bytes().to_vec();
Expand All @@ -79,7 +82,7 @@ mod tests {
}

impl Decoder for TestEncoderDecoder {
fn decode_bytes(&self, data: &[u8]) -> Result<DecodeResult> {
fn decode_bytes(&self, data: &[u8], _: u32) -> Result<DecodeResult> {
let offset = size_of::<u32>();
let item_len = u32::from_be_bytes(data[0..offset].try_into().map_err(|_| DataInvalid)?);
let bytes_to_decode = &data[offset..(offset + item_len as usize)];
Expand Down Expand Up @@ -115,12 +118,15 @@ mod tests {
let test_encoder = &TestEncoderDecoder;
for i in 0..10 {
let buffer = Buffer::from_i32(&i.to_string(), i);
mm.append(test_encoder.encode_to_bytes(&buffer).unwrap())
mm.append(test_encoder.encode_to_bytes(&buffer, i as u32).unwrap())
.unwrap();
buffers.push(buffer);
}
let decoder = &TestEncoderDecoder;
for (index, i) in mm.iter(|bytes| decoder.decode_bytes(bytes)).enumerate() {
for (index, i) in mm
.iter(|bytes, position| decoder.decode_bytes(bytes, position))
.enumerate()
{
assert_eq!(buffers[index], i.unwrap());
}
let _ = fs::remove_file("test_mmap_iterator");
Expand Down
1 change: 1 addition & 0 deletions src/core/memory_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl DerefMut for RawMmap {
unsafe impl Send for RawMmap {}

#[derive(Debug)]
#[repr(transparent)]
pub struct MemoryMap(RawMmap);

impl MemoryMap {
Expand Down
Loading

0 comments on commit 9cf8049

Please sign in to comment.