Skip to content

Commit

Permalink
feat: PRG stream id (#121)
Browse files Browse the repository at this point in the history
* add stream_id to PRG and update benches

* persist counter state
  • Loading branch information
sinui0 authored May 3, 2024
1 parent 3fc9972 commit cfa42a5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 23 deletions.
22 changes: 15 additions & 7 deletions mpz-core/benches/prg.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};

use mpz_core::{block::Block, prg::Prg};
use rand_core::RngCore;

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("Prg::byte", move |bench| {
let mut group = c.benchmark_group("prg");

group.throughput(Throughput::Bytes(1));
group.bench_function("byte", move |bench| {
let mut prg = Prg::new();
let mut x = 0u8;
bench.iter(|| {
Expand All @@ -13,17 +16,20 @@ fn criterion_benchmark(c: &mut Criterion) {
});
});

c.bench_function("Prg::bytes", move |bench| {
const BYTES_PER: u64 = 16 * 1024;
group.throughput(Throughput::Bytes(BYTES_PER));
group.bench_function("bytes", move |bench| {
let mut prg = Prg::new();
let mut x = (0..16 * 1024)
let mut x = (0..BYTES_PER)
.map(|_| rand::random::<u8>())
.collect::<Vec<u8>>();
bench.iter(|| {
prg.fill_bytes(black_box(&mut x));
});
});

c.bench_function("Prg::block", move |bench| {
group.throughput(Throughput::Elements(1));
group.bench_function("block", move |bench| {
let mut prg = Prg::new();
let mut x = Block::ZERO;
bench.iter(|| {
Expand All @@ -32,9 +38,11 @@ fn criterion_benchmark(c: &mut Criterion) {
});
});

c.bench_function("Prg::blocks", move |bench| {
const BLOCKS_PER: u64 = 16 * 1024;
group.throughput(Throughput::Elements(BLOCKS_PER));
group.bench_function("blocks", move |bench| {
let mut prg = Prg::new();
let mut x = (0..16 * 1024)
let mut x = (0..BLOCKS_PER)
.map(|_| rand::random::<Block>())
.collect::<Vec<Block>>();
bench.iter(|| {
Expand Down
106 changes: 90 additions & 16 deletions mpz-core/src/prg.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,42 @@
//! Implement AES-based PRG.
use std::collections::HashMap;

use crate::{aes::AesEncryptor, Block};
use rand::Rng;
use rand_core::{
block::{BlockRng, BlockRngCore},
CryptoRng, RngCore, SeedableRng,
};

/// Struct of PRG Core
#[derive(Clone)]
struct PrgCore {
aes: AesEncryptor,
state: u64,
// Stores the counter for each stream id.
state: HashMap<u64, u64>,
stream_id: u64,
counter: u64,
}

// This implementation is somehow standard, and is adapted from Swanky.
impl BlockRngCore for PrgCore {
type Item = u32;
type Results = [u32; 4 * AesEncryptor::AES_BLOCK_COUNT];

// Compute [AES(state)..AES(state+8)]
// Compute 8 encrypted counter blocks at a time.
#[inline(always)]
fn generate(&mut self, results: &mut Self::Results) {
let mut states = [0; AesEncryptor::AES_BLOCK_COUNT].map(
#[inline(always)]
|_| {
let x = self.state;
self.state += 1;
Block::from(bytemuck::cast::<_, [u8; 16]>([x, 0u64]))
let mut block = [0u8; 16];
let counter = self.counter;
self.counter += 1;

block[..8].copy_from_slice(&counter.to_le_bytes());
block[8..].copy_from_slice(&self.stream_id.to_le_bytes());

Block::from(block)
},
);
self.aes.encrypt_many_blocks(&mut states);
Expand All @@ -40,13 +50,24 @@ impl SeedableRng for PrgCore {
#[inline(always)]
fn from_seed(seed: Self::Seed) -> Self {
let aes = AesEncryptor::new(seed);
Self { aes, state: 0u64 }
Self {
aes,
state: Default::default(),
stream_id: 0u64,
counter: 0u64,
}
}
}

impl CryptoRng for PrgCore {}

/// Struct of PRG
/// AES-based PRG.
///
/// This PRG is based on AES128 used in counter-mode to generate pseudo-random data streams.
///
/// # Stream ID
///
/// The PRG is configurable with a stream ID, which can be used to generate distinct streams using the same seed. See [`Prg::set_stream_id`].
#[derive(Clone)]
pub struct Prg(BlockRng<PrgCore>);

Expand Down Expand Up @@ -92,8 +113,28 @@ impl Prg {
/// New Prg with random seed.
#[inline(always)]
pub fn new() -> Self {
let seed = rand::random::<Block>();
Prg::from_seed(seed)
Prg::from_seed(rand::random::<Block>())
}

/// Returns the current counter.
pub fn counter(&self) -> u64 {
self.0.core.counter
}

/// Returns the stream id.
pub fn stream_id(&self) -> u64 {
self.0.core.stream_id
}

/// Sets the stream id.
pub fn set_stream_id(&mut self, stream_id: u64) {
let state = &mut self.0.core.state;
state.insert(self.0.core.stream_id, self.0.core.counter);

let counter = state.get(&stream_id).copied().unwrap_or(0);

self.0.core.stream_id = stream_id;
self.0.core.counter = counter;
}

/// Generate a random bool value.
Expand Down Expand Up @@ -141,10 +182,43 @@ impl Default for Prg {
}
}

#[test]
fn prg_test() {
let mut prg = Prg::new();
let mut x = vec![Block::ZERO; 2];
prg.random_blocks(&mut x);
assert_ne!(x[0], x[1]);
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_prg_ne() {
let mut prg = Prg::new();
let mut x = vec![Block::ZERO; 2];
prg.random_blocks(&mut x);
assert_ne!(x[0], x[1]);
}

#[test]
fn test_prg_streams_are_distinct() {
let mut prg = Prg::from_seed(Block::ZERO);
let mut x = vec![Block::ZERO; 2];
prg.random_blocks(&mut x);

let mut y = vec![Block::ZERO; 2];
prg.set_stream_id(1);
prg.random_blocks(&mut y);

assert_ne!(x[0], y[0]);
}

#[test]
fn test_prg_state_persisted() {
let mut prg = Prg::from_seed(Block::ZERO);
let mut x = vec![Block::ZERO; 2];
prg.random_blocks(&mut x);

let counter = prg.counter();
assert_ne!(counter, 0);

prg.set_stream_id(1);
prg.set_stream_id(0);

assert_eq!(prg.counter(), counter);
}
}

0 comments on commit cfa42a5

Please sign in to comment.