Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle early returns in #[heap_neutral] #871

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/light-system-programs-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
- "circuit-lib/verifier/**"
- "merkle-tree/**"
- ".github/workflows/light-system-programs-tests.yml"
- "heap/**"
pull_request:
branches:
- "*"
Expand All @@ -17,6 +18,7 @@ on:
- "circuit-lib/verifier/**"
- "merkle-tree/**"
- ".github/workflows/light-system-programs-tests.yml"
- "heap/**"
types:
- opened
- synchronize
Expand Down
90 changes: 76 additions & 14 deletions heap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{alloc::Layout, mem::size_of, ptr::null_mut};
pub mod bench;

#[cfg(target_os = "solana")]
use anchor_lang::{
prelude::*,
solana_program::entrypoint::{HEAP_LENGTH, HEAP_START_ADDRESS},
Expand All @@ -14,12 +13,12 @@ pub static GLOBAL_ALLOCATOR: BumpAllocator = BumpAllocator {
len: HEAP_LENGTH,
};

#[cfg(target_os = "solana")]
#[error_code]
pub enum HeapError {
#[msg("The provided position to free is invalid.")]
InvalidHeapPos,
}

pub struct BumpAllocator {
pub start: usize,
pub len: usize,
Expand All @@ -28,7 +27,6 @@ pub struct BumpAllocator {
impl BumpAllocator {
const RESERVED_MEM: usize = size_of::<*mut u8>();

#[cfg(target_os = "solana")]
pub fn new() -> Self {
Self {
start: HEAP_START_ADDRESS as usize,
Expand Down Expand Up @@ -56,23 +54,22 @@ impl BumpAllocator {
*pos_ptr = pos;
}

#[cfg(target_os = "solana")]
pub fn log_total_heap(&self, msg: &str) -> u64 {
const HEAP_END_ADDRESS: u64 = HEAP_START_ADDRESS as u64 + HEAP_LENGTH as u64;

pub fn total_heap(&self) -> u64 {
const HEAP_END_ADDRESS: u64 = HEAP_START_ADDRESS + HEAP_LENGTH as u64;
let heap_start = unsafe { self.pos() } as u64;
let heap_used = HEAP_END_ADDRESS - heap_start;
msg!("{}: total heap used: {}", msg, heap_used);
heap_used
HEAP_END_ADDRESS - heap_start
}

pub fn log_total_heap(&self, msg: &str) -> u64 {
let total_heap = self.total_heap();
msg!("{}: total heap used: {}", msg, total_heap);
total_heap
}

#[cfg(target_os = "solana")]
pub fn get_heap_pos(&self) -> usize {
let heap_start = unsafe { self.pos() } as usize;
heap_start
unsafe { self.pos() }
}

#[cfg(target_os = "solana")]
pub fn free_heap(&self, pos: usize) -> Result<()> {
if pos < self.start + BumpAllocator::RESERVED_MEM || pos > self.start + self.len {
return err!(HeapError::InvalidHeapPos);
Expand All @@ -81,6 +78,28 @@ impl BumpAllocator {
unsafe { self.move_cursor(pos) };
Ok(())
}

#[allow(unused_variables)]
pub fn guard(&self, msg: String) -> HeapNeutralGuard {
#[cfg(feature = "mem-profiling")]
self.allocator.log_total_heap(format!("pre: {}", self.msg));
let pos = self.get_heap_pos();
HeapNeutralGuard {
allocator: self,
#[cfg(feature = "mem-profiling")]
msg,
pos,
}
}
}

impl Default for BumpAllocator {
fn default() -> Self {
Self {
start: HEAP_START_ADDRESS as usize,
len: HEAP_LENGTH,
}
}
}

unsafe impl std::alloc::GlobalAlloc for BumpAllocator {
Expand All @@ -107,6 +126,21 @@ unsafe impl std::alloc::GlobalAlloc for BumpAllocator {
}
}

pub struct HeapNeutralGuard<'a> {
allocator: &'a BumpAllocator,
#[cfg(feature = "mem-profiling")]
msg: String,
pos: usize,
}

impl<'a> Drop for HeapNeutralGuard<'a> {
fn drop(&mut self) {
#[cfg(feature = "mem-profiling")]
self.allocator.log_total_heap(format!("post: {}", self.msg));
let _ = self.allocator.free_heap(self.pos);
}
}

#[cfg(test)]
mod test {
use std::{
Expand Down Expand Up @@ -229,4 +263,32 @@ mod test {
assert_eq!(0, ptr.align_offset(size_of::<u64>()));
}
}

#[test]
fn test_heap_neutral_guard() {
let heap = [0u8; 128];
let allocator = BumpAllocator {
start: heap.as_ptr() as *const _ as usize,
len: heap.len(),
};

let layout = Layout::from_size_align(1, size_of::<u8>()).unwrap();
let _ptr_1 = unsafe { allocator.alloc(layout) };

let old_pos = allocator.get_heap_pos();

// With an explicit `drop`.
let guard = allocator.guard("ayylmao".to_string());
let _ptr_2 = unsafe { allocator.alloc(layout) };

drop(guard);
assert_eq!(allocator.get_heap_pos(), old_pos);

// In a scope, which should drop the guard implicitly.
{
let _guard = allocator.guard("ayylmao".to_string());
let _ptr_3 = unsafe { allocator.alloc(layout) };
}
assert_eq!(allocator.get_heap_pos(), old_pos);
}
}
36 changes: 11 additions & 25 deletions macros/light/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate proc_macro;
use accounts::process_light_accounts;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, parse_quote, DeriveInput, ItemFn};
use syn::{parse_macro_input, DeriveInput, ItemFn};
use traits::process_light_traits;
mod accounts;
mod pubkey;
Expand All @@ -19,34 +19,20 @@ pub fn pubkey(input: TokenStream) -> TokenStream {

#[proc_macro_attribute]
pub fn heap_neutral(_: TokenStream, input: TokenStream) -> TokenStream {
#[allow(unused_mut)]
let mut function = parse_macro_input!(input as ItemFn);

// Insert memory management code at the beginning of the function
let init_code: syn::Stmt = parse_quote! {
#[cfg(target_os = "solana")]
let pos = light_heap::GLOBAL_ALLOCATOR.get_heap_pos();
};
let msg = format!("pre: {}", function.sig.ident);
let log_pre: syn::Stmt = parse_quote! {
#[cfg(all(target_os = "solana", feature = "mem-profiling"))]
light_heap::GLOBAL_ALLOCATOR.log_total_heap(#msg);
};
function.block.stmts.insert(0, init_code);
function.block.stmts.insert(1, log_pre);
#[cfg(target_os = "solana")]
{
use syn::parse_quote;

// Insert memory management code at the end of the function
let msg = format!("post: {}", function.sig.ident);
let log_post: syn::Stmt = parse_quote! {
#[cfg(all(target_os = "solana", feature = "mem-profiling"))]
light_heap::GLOBAL_ALLOCATOR.log_total_heap(#msg);
};
let cleanup_code: syn::Stmt = parse_quote! {
#[cfg(target_os = "solana")]
light_heap::GLOBAL_ALLOCATOR.free_heap(pos)?;
};
let len = function.block.stmts.len();
function.block.stmts.insert(len - 1, log_post);
function.block.stmts.insert(len - 1, cleanup_code);
let msg = function.sig.ident.clone().to_string();
let init_code: syn::Stmt = parse_quote! {
let _guard = light_heap::GLOBAL_ALLOCATOR.guard(#msg.to_string());
};
function.block.stmts.insert(0, init_code);
}
TokenStream::from(quote! { #function })
}

Expand Down
Loading