diff --git a/Cargo.lock b/Cargo.lock index bde0bfd..a223d15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -443,6 +443,16 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix 0.28.0", + "windows-sys 0.52.0", +] + [[package]] name = "dashmap" version = "6.0.1" @@ -1127,6 +1137,8 @@ dependencies = [ "blazesym", "chrono", "clap", + "crossbeam-channel", + "ctrlc", "data-encoding", "errno", "gimli 0.31.0", @@ -1155,6 +1167,7 @@ dependencies = [ "thiserror", "tracing", "tracing-subscriber", + "v", ] [[package]] @@ -2569,6 +2582,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "v" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23bf1cdfefda84a01134d6a5a9815fbf66c8df618f03b8c2296e6d06041a2e75" + [[package]] name = "valuable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5365f6a..a3d6615 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,9 @@ nix = { version = "0.29.0", features = ["user"] } prost = "0.12" # Needed to encode protocol buffers to bytes. reqwest = { version = "0.12", features = ["blocking"] } lightswitch-proto = { path = "./lightswitch-proto"} +ctrlc = "3.4.4" +v = "0.1.0" +crossbeam-channel = "0.5.13" [dev-dependencies] assert_cmd = { version = "2.0.14" } diff --git a/src/main.rs b/src/main.rs index 9414973..ca408f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,12 +9,13 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use clap::Parser; +use crossbeam_channel::bounded; use inferno::flamegraph; use lightswitch::collector::{AggregatorCollector, Collector, NullCollector, StreamingCollector}; use nix::unistd::Uid; use primal::is_prime; use prost::Message; -use tracing::{error, Level}; +use tracing::{error, info, Level}; use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::FmtSubscriber; @@ -227,11 +228,20 @@ fn main() -> Result<(), Box> { } })); + let (stop_signal_sender, stop_signal_receive) = bounded(1); + + ctrlc::set_handler(move || { + info!("received Ctrl+C, stopping..."); + let _ = stop_signal_sender.send(()); + }) + .expect("Error setting Ctrl-C handler"); + let mut p: Profiler<'_> = Profiler::new( args.libbpf_logs, args.bpf_logging, args.duration, args.sample_freq, + stop_signal_receive, ); p.profile_pids(args.pids); p.run(collector.clone()); diff --git a/src/profiler.rs b/src/profiler.rs index f18c274..d20c2d1 100644 --- a/src/profiler.rs +++ b/src/profiler.rs @@ -4,9 +4,12 @@ use std::fmt; use std::fs; use std::os::fd::{AsFd, AsRawFd}; use std::path::PathBuf; -use std::sync::{mpsc, Arc, Mutex}; +use std::sync::{Arc, Mutex}; + +use crossbeam_channel::{bounded, select, tick, unbounded, Receiver, Sender}; + use std::thread; -use std::time::{Duration, Instant}; +use std::time::Duration; use anyhow::anyhow; use libbpf_rs::num_possible_cpus; @@ -120,7 +123,6 @@ impl ExecutableMapping { } pub struct NativeUnwindState { dirty: bool, - last_persisted: Instant, live_shard: Vec, known_executables: HashSet, shard_index: u64, @@ -130,7 +132,6 @@ impl Default for NativeUnwindState { fn default() -> Self { Self { dirty: false, - last_persisted: Instant::now() - Duration::from_secs(1_000), // old enough to trigger it the first time live_shard: Vec::with_capacity(SHARD_CAPACITY), known_executables: HashSet::new(), shard_index: 0, @@ -147,18 +148,20 @@ pub struct Profiler<'bpf> { procs: Arc>>, object_files: Arc>>, // Channel for new process events. - chan_send: Arc>>, - chan_receive: Arc>>, - // Channel for munmaps events. - tracers_chan_send: Arc>>, - tracers_chan_receive: Arc>>, + new_proc_chan_send: Arc>, + new_proc_chan_receive: Arc>, + // Channel for tracer events such as munmaps and process exits. + tracers_chan_send: Arc>, + tracers_chan_receive: Arc>, + // Profiler stop channel. + stop_chan_receive: Receiver<()>, // Native unwinding state native_unwind_state: NativeUnwindState, // Debug options filter_pids: HashMap, // Profile channel - profile_send: Arc>>, - profile_receive: Arc>>, + profile_send: Arc>, + profile_receive: Arc>, // Duration of this profile duration: Duration, // Per-CPU Sampling Frequency of this profile in Hz @@ -287,7 +290,9 @@ pub type SymbolizedAggregatedProfile = Vec; impl Default for Profiler<'_> { fn default() -> Self { - Self::new(false, false, Duration::MAX, 19) + let (_stop_signal_send, stop_signal_receive) = bounded(1); + + Self::new(false, false, Duration::MAX, 19, stop_signal_receive) } } @@ -297,6 +302,7 @@ impl Profiler<'_> { bpf_logging: bool, duration: Duration, sample_freq: u16, + stop_signal_receive: Receiver<()>, ) -> Self { let mut skel_builder: ProfilerSkelBuilder = ProfilerSkelBuilder::default(); skel_builder.obj_builder.debug(libbpf_debug); @@ -326,19 +332,19 @@ impl Profiler<'_> { let procs = Arc::new(Mutex::new(HashMap::new())); let object_files = Arc::new(Mutex::new(HashMap::new())); - let (sender, receiver) = mpsc::channel(); - let chan_send = Arc::new(Mutex::new(sender)); - let chan_receive = Arc::new(Mutex::new(receiver)); + let (sender, receiver) = unbounded(); + let chan_send = Arc::new(sender); + let chan_receive = Arc::new(receiver); - let (sender, receiver) = mpsc::channel(); - let tracers_chan_send = Arc::new(Mutex::new(sender)); - let tracers_chan_receive = Arc::new(Mutex::new(receiver)); + let (sender, receiver) = unbounded(); + let tracers_chan_send = Arc::new(sender); + let tracers_chan_receive = Arc::new(receiver); let native_unwind_state = NativeUnwindState::default(); - let (sender, receiver) = mpsc::channel(); - let profile_send: Arc>> = Arc::new(Mutex::new(sender)); - let profile_receive = Arc::new(Mutex::new(receiver)); + let (sender, receiver) = unbounded(); + let profile_send = Arc::new(sender); + let profile_receive = Arc::new(receiver); let filter_pids = HashMap::new(); @@ -348,10 +354,11 @@ impl Profiler<'_> { tracers, procs, object_files, - chan_send, - chan_receive, + new_proc_chan_send: chan_send, + new_proc_chan_receive: chan_receive, tracers_chan_send, tracers_chan_receive, + stop_chan_receive: stop_signal_receive, native_unwind_state, filter_pids, profile_send, @@ -370,11 +377,7 @@ impl Profiler<'_> { } pub fn send_profile(&mut self, profile: RawAggregatedProfile) { - self.profile_send - .lock() - .expect("sender lock") - .send(profile) - .expect("handle send"); + self.profile_send.send(profile).expect("handle send"); } pub fn run(mut self, collector: ThreadSafeCollector) { @@ -393,7 +396,7 @@ impl Profiler<'_> { self.tracers.attach().expect("attach tracers"); // New process events. - let chan_send = self.chan_send.clone(); + let chan_send = self.new_proc_chan_send.clone(); let perf_buffer = PerfBufferBuilder::new(self.bpf.maps().events()) .pages(PERF_BUFFER_BYTES / page_size::get()) .sample_cb(move |_cpu: i32, data: &[u8]| { @@ -416,8 +419,6 @@ impl Profiler<'_> { let mut event = tracer_event_t::default(); plain::copy_from_bytes(&mut event, data).expect("serde tracers event"); tracers_send - .lock() - .expect("sender lock") .send(TracerEvent::from(event)) .expect("handle event send"); }) @@ -439,7 +440,7 @@ impl Profiler<'_> { let collector = collector.clone(); thread::spawn(move || loop { - match profile_receive.lock().unwrap().recv() { + match profile_receive.recv() { Ok(profile) => { collector.lock().unwrap().collect( profile, @@ -453,70 +454,56 @@ impl Profiler<'_> { } }); - let start_time: Instant = Instant::now(); - let mut time_since_last_scheduled_collection: Instant = Instant::now(); + let ticks = tick(self.duration); + let persist_ticks = tick(Duration::from_millis(100)); loop { - if start_time.elapsed() >= self.duration { - debug!("done after running for {:?}", start_time.elapsed()); - let profile = self.collect_profile(); - self.send_profile(profile); - break; - } - - if time_since_last_scheduled_collection.elapsed() >= self.session_duration { - debug!("collecting profiles on schedule"); - let profile = self.collect_profile(); - self.send_profile(profile); - time_since_last_scheduled_collection = Instant::now(); - } - - let read = self - .tracers_chan_receive - .lock() - .expect("receive lock") - .recv_timeout(Duration::from_millis(50)); - - match read { - Ok(TracerEvent::Munmap(pid, start_address)) => { - self.handle_munmap(pid, start_address); - } - Ok(TracerEvent::ProcessExit(pid)) => { - self.handle_process_exit(pid); - } - Err(_) => {} - } - - let read = self - .chan_receive - .lock() - .expect("receive lock") - .recv_timeout(Duration::from_millis(150)); - - if let Ok(event) = read { - if event.type_ == event_type_EVENT_NEW_PROCESS { - self.event_new_proc(event.pid); - // Ensure we only remove the rate limits only if the above works. - // This is probably suited for a batched operation. - // let _ = self - // .bpf - // .maps() - // .rate_limits() - // .delete(unsafe { plain::as_bytes(&event) }); - } else { - error!("unknown event type {}", event.type_); + select! { + recv(self.stop_chan_receive) -> _ => { + println!("aa ctrl+c"); + let profile = self.collect_profile(); + self.send_profile(profile); + break; + }, + recv(ticks) -> _ => { + debug!("collecting profiles on schedule"); + let profile = self.collect_profile(); + self.send_profile(profile); + break; } - } - - let due_to_persist = - self.native_unwind_state.last_persisted.elapsed() > Duration::from_millis(100); - - if self.native_unwind_state.dirty - && due_to_persist - && self.persist_unwind_info(&self.native_unwind_state.live_shard) - { - self.native_unwind_state.dirty = false; - self.native_unwind_state.last_persisted = Instant::now(); + recv(self.tracers_chan_receive) -> read => { + match read { + Ok(TracerEvent::Munmap(pid, start_address)) => { + self.handle_munmap(pid, start_address); + } + Ok(TracerEvent::ProcessExit(pid)) => { + self.handle_process_exit(pid); + } + Err(_) => {} + } + }, + recv(self.new_proc_chan_receive) -> read => { + if let Ok(event) = read { + if event.type_ == event_type_EVENT_NEW_PROCESS { + self.event_new_proc(event.pid); + // Ensure we only remove the rate limits only if the above works. + // This is probably suited for a batched operation. + // let _ = self + // .bpf + // .maps() + // .rate_limits() + // .delete(unsafe { plain::as_bytes(&event) }); + } else { + error!("unknown event type {}", event.type_); + } + } + }, + recv(persist_ticks) -> _ => { + if self.native_unwind_state.dirty && self.persist_unwind_info(&self.native_unwind_state.live_shard) { + self.native_unwind_state.dirty = false; + } + }, + default(Duration::from_millis(100)) => {}, } } } @@ -1256,13 +1243,9 @@ impl Profiler<'_> { Ok(()) } - fn handle_event(sender: &Arc>>, data: &[u8]) { + fn handle_event(sender: &Arc>, data: &[u8]) { let event = plain::from_bytes(data).expect("handle event serde"); - sender - .lock() - .expect("sender lock") - .send(*event) - .expect("handle event send"); + sender.send(*event).expect("handle event send"); } fn handle_lost_events(cpu: i32, count: u64) { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 7e8c4cb..0f3d2b4 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -4,6 +4,8 @@ use std::process::{Child, Command, Stdio}; use std::sync::{Arc, Mutex}; use std::time::Duration; +use crossbeam_channel::bounded; + use lightswitch::collector::{AggregatorCollector, Collector}; use lightswitch::profile::symbolize_profile; use lightswitch::profiler::Profiler; @@ -97,7 +99,14 @@ fn test_integration() { let collector = Arc::new(Mutex::new( Box::new(AggregatorCollector::new()) as Box )); - let mut p = Profiler::new(bpf_test_debug, bpf_test_debug, Duration::from_secs(5), 999); + let (_stop_signal_send, stop_signal_receive) = bounded(1); + let mut p = Profiler::new( + bpf_test_debug, + bpf_test_debug, + Duration::from_secs(5), + 999, + stop_signal_receive, + ); p.profile_pids(vec![cpp_proc.pid()]); p.run(collector.clone()); let collector = collector.lock().unwrap();