Implement `update_config` to support remote update
- Removed `reload_config` because it has been deprecated.

The code uses patterns like `writer.lock().unwrap_or_else(...)` that
ignore `PoisonError`, on the assumption that the log writer should not
panic even if another thread has poisoned the lock.
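
For context, a small self-contained demonstration of this idiom (not part of the commit): recover the guard from a poisoned mutex with `PoisonError::into_inner` so work can continue instead of propagating the panic.

```rust
use std::sync::{Arc, Mutex, PoisonError};
use std::thread;

fn main() {
    let data = Arc::new(Mutex::new(String::from("log line")));

    // Poison the mutex by panicking while holding the lock.
    let poisoner = Arc::clone(&data);
    let _ = thread::spawn(move || {
        let _guard = poisoner.lock().unwrap();
        panic!("simulated panic while holding the lock");
    })
    .join();

    // `lock()` now returns Err(PoisonError), but the guard inside is still usable;
    // recovering it keeps the writer working rather than panicking again.
    let guard = data.lock().unwrap_or_else(PoisonError::into_inner);
    println!("recovered contents: {guard}");
}
```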

If `log_dir` changes when the configuration is updated, logs should be
written to the new directory. This commit implements that by wrapping the
file writer so its destination can be swapped at runtime.
Alternatively, `tracing_subscriber::reload` could be used, but it cannot
handle cases where the log directory is added or removed, as that requires
reassigning the writer itself, which is not possible due to internal
limitations.
Related issue: tokio-rs/tracing#1629
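
A minimal sketch of the wrapping approach (hypothetical names; the commit's actual implementation is `DynamicLogFileWriter` in the diff below): the subscriber writes through a cloneable handle, and only the handle's inner destination is replaced when `log_dir` changes.

```rust
use std::fs::{create_dir_all, OpenOptions};
use std::io::{self, Write};
use std::path::Path;
use std::sync::{Arc, Mutex, PoisonError};

/// Cloneable handle whose destination can be swapped at runtime.
#[derive(Clone)]
struct SwappableWriter(Arc<Mutex<Box<dyn Write + Send>>>);

impl SwappableWriter {
    /// Point the handle at `<log_dir>/<pkg>.log`, or at `io::sink()` if no directory is set.
    fn set_log_dir(&self, log_dir: Option<&Path>) -> io::Result<()> {
        let dest: Box<dyn Write + Send> = match log_dir {
            Some(dir) => {
                create_dir_all(dir)?;
                let file_name = format!("{}.log", env!("CARGO_PKG_NAME"));
                Box::new(
                    OpenOptions::new()
                        .create(true)
                        .append(true)
                        .open(dir.join(file_name))?,
                )
            }
            None => Box::new(io::sink()),
        };
        let mut inner = self.0.lock().unwrap_or_else(PoisonError::into_inner);
        inner.flush()?;
        *inner = dest;
        Ok(())
    }
}

impl Write for SwappableWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut inner = self.0.lock().unwrap_or_else(PoisonError::into_inner);
        inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        let mut inner = self.0.lock().unwrap_or_else(PoisonError::into_inner);
        inner.flush()
    }
}
```

A clone of such a handle is what gets passed to `tracing_appender::non_blocking(...)` when the subscriber is first initialized, so the subscriber never needs to be rebuilt when the directory changes.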

Close: #112
danbi2990 committed Nov 15, 2024
1 parent 7d89b3a commit 1aabd74
Showing 4 changed files with 184 additions and 92 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Support `update_config` from the Manager server only if the local configuration
is not specified.

### Changed

- Configuration options required for establishing a connection with the central
@@ -32,6 +37,7 @@ Versioning](https://semver.org/spec/v2.0.0.html).
- Removed OS-specific configuration directory.
- Linux: $HOME/.config/crusher/config.toml
- macOS: $HOME/Library/Application Support/com.cluml.crusher/config.toml
- Removed `reload_config` functionality.

## [0.4.1] - 2024-10-04

204 changes: 136 additions & 68 deletions src/main.rs
@@ -4,23 +4,26 @@ mod request;
mod settings;
mod subscribe;

use std::fs::{create_dir_all, File, OpenOptions};
use std::io::Write;
use std::net::SocketAddr;
use std::path::Path;
use std::str::FromStr;
use std::sync::Mutex;
use std::{collections::HashMap, env, fs, sync::Arc};

use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use client::Certs;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use settings::Settings;
pub use settings::TEMP_TOML_POST_FIX;
use tokio::{
sync::{Notify, RwLock},
sync::{mpsc, Notify, RwLock},
task,
};
use tracing::metadata::LevelFilter;
use tracing::{error, warn};
use tracing::{error, info, warn};
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{
fmt, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
@@ -29,7 +32,6 @@ use tracing_subscriber::{
use crate::{request::RequestedPolicy, subscribe::read_last_timestamp};

const REQUESTED_POLICY_CHANNEL_SIZE: usize = 1;
const DEFAULT_TOML: &str = "/usr/local/aice/conf/crusher.toml";

#[derive(Debug, Clone)]
pub struct ManagerServer {
@@ -80,20 +82,11 @@ pub struct CmdLineArgs {
#[tokio::main]
async fn main() -> Result<()> {
let args = CmdLineArgs::parse();

let config_path = args
.config
.clone()
.unwrap_or_else(|| DEFAULT_TOML.to_string());

let mut settings = Settings::from_args(args.clone())?;

let temp_path = format!("{config_path}{TEMP_TOML_POST_FIX}");

let _guards = init_tracing(settings.log_dir.as_deref());
let mut log_manager = init_tracing(settings.log_dir.as_deref())?;
let (config_tx, mut config_rx) = mpsc::channel::<String>(1);

loop {
let config_reload = Arc::new(Notify::new());
let notify_shutdown = Arc::new(Notify::new());

let cert_pem = fs::read(&args.cert)
@@ -134,8 +127,8 @@ async fn main() -> Result<()> {
task::spawn(request_client.run(
Arc::clone(&runtime_policy_list),
Arc::clone(&delete_policy_ids),
config_reload.clone(),
notify_shutdown.clone(),
config_tx.clone(),
));

let subscribe_client = subscribe::Client::new(
@@ -152,26 +145,25 @@
notify_shutdown.clone(),
));
loop {
config_reload.notified().await;
match Settings::from_file(&temp_path) {
Ok(new_settings) => {
settings = new_settings;
notify_shutdown.notify_waiters();
notify_shutdown.notified().await;
fs::rename(&temp_path, &config_path).unwrap_or_else(|e| {
error!("Failed to rename the new configuration file: {e}");
});
break;
}
Err(e) => {
error!("Failed to load the new configuration: {:?}", e);
warn!("Run Crusher with the previous config");
fs::remove_file(&temp_path).unwrap_or_else(|e| {
error!("Failed to remove the temporary file: {e}");
});
if let Some(config) = config_rx.recv().await {
if args.config.is_some() {
warn!("Cannot update the configuration from the Manager server because a local configuration file is specified");
continue;
}
}
let Ok(new_settings) = Settings::from_str(&config) else {
error!("Failed to parse the configuration from Manager server");
continue;
};
log_manager
.dynamic_log_file_writer
.change_log_dir(settings.log_dir.as_deref(), new_settings.log_dir.as_deref())?;
settings = new_settings;
notify_shutdown.notify_waiters();
notify_shutdown.notified().await;
info!("Updated the configuration from the Manager server");
break;
};
info!("No new configuration received from the Manager server");
}
}
}
@@ -209,20 +201,112 @@ fn to_ca_certs(ca_certs_pem: &Vec<Vec<u8>>) -> Result<rustls::RootCertStore> {
Ok(root_cert)
}

/// Initializes the tracing subscriber.
/// Manages the log file and guards.
///
/// If `log_dir` is `None` or the runtime is in debug mode, logs will be printed to stdout.
/// `_guards` will flush the logs when they are dropped.
///
/// Returns a vector of `WorkerGuard` that flushes the log when dropped.
fn init_tracing(log_dir: Option<&Path>) -> Vec<WorkerGuard> {
let mut guards = vec![];
let subscriber = tracing_subscriber::Registry::default();
let file_name = format!("{}.log", env!("CARGO_PKG_NAME"));
/// `dynamic_log_file_writer` wraps the log file to allow changing its path dynamically.
/// If the log file is not provided, logs will be ignored by using `std::io::sink()`.
struct LogManager {
_guards: Vec<WorkerGuard>,
dynamic_log_file_writer: DynamicLogFileWriter,
}

#[derive(Clone)]
struct DynamicLogFileWriter {
writer: Arc<Mutex<Box<dyn Write + Send>>>,
}

impl DynamicLogFileWriter {
fn try_new(dir_path: Option<&Path>) -> Result<Self> {
Ok(Self {
writer: Arc::new(Mutex::new(DynamicLogFileWriter::create_writer(dir_path)?)),
})
}

fn create_log_file(dir_path: &Path) -> Result<File> {
if let Err(e) = create_dir_all(dir_path) {
bail!("Cannot create directory recursively for {dir_path:?}: {e}");
}

let file_name = format!("{}.log", env!("CARGO_PKG_NAME"));

let file = OpenOptions::new()
.create(true)
.append(true)
.open(dir_path.join(file_name))
.map_err(|e| anyhow!("Cannot create log file: {e}"));

let is_valid_file =
matches!(log_dir, Some(path) if std::fs::File::create(path.join(&file_name)).is_ok());
file
}

fn create_writer(log_dir: Option<&Path>) -> Result<Box<dyn Write + Send>> {
match log_dir {
Some(dir) => Ok(Box::new(DynamicLogFileWriter::create_log_file(dir)?)),
None => Ok(Box::new(std::io::sink())),
}
}

let stdout_layer = if !is_valid_file || cfg!(debug_assertions) {
fn change_log_dir(&mut self, old_dir: Option<&Path>, new_dir: Option<&Path>) -> Result<()> {
if old_dir.eq(&new_dir) {
info!("New directory is the same as the old directory");
return Ok(());
}
if let Some(dir) = new_dir {
info!("Log directory will change to {}", dir.display());
}
let new_writer = DynamicLogFileWriter::create_writer(new_dir)?;
{
let mut old_writer = self
.writer
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
old_writer.flush()?;
*old_writer = new_writer;
}
if let Some(dir) = old_dir {
info!("Previous logs are in {}", dir.display());
}
Ok(())
}
}

impl Write for DynamicLogFileWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut writer = self
.writer
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
writer.write(buf)
}

fn flush(&mut self) -> std::io::Result<()> {
let mut writer = self
.writer
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
writer.flush()
}
}

/// Initializes the tracing subscriber.
///
/// If `log_dir` is `None` or the runtime is in debug mode, logs will be printed to stdout.
fn init_tracing(log_dir: Option<&Path>) -> Result<LogManager> {
let dynamic_log_file_writer = DynamicLogFileWriter::try_new(log_dir)?;
let (file_writer, file_guard) = tracing_appender::non_blocking(dynamic_log_file_writer.clone());
let file_layer = fmt::Layer::default()
.with_ansi(false)
.with_target(false)
.with_writer(file_writer)
.with_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy(),
);
let mut guards = vec![file_guard];

let stdout_layer = if log_dir.is_none() || cfg!(debug_assertions) {
let (stdout_writer, stdout_guard) = tracing_appender::non_blocking(std::io::stdout());
guards.push(stdout_guard);
Some(
@@ -235,28 +319,12 @@ fn init_tracing(log_dir: Option<&Path>) -> Vec<WorkerGuard> {
None
};

let file_layer = if is_valid_file {
let file_appender = tracing_appender::rolling::never(
log_dir.expect("verified by is_valid_file"),
file_name,
);
let (file_writer, file_guard) = tracing_appender::non_blocking(file_appender);
guards.push(file_guard);
Some(
fmt::Layer::default()
.with_ansi(false)
.with_target(false)
.with_writer(file_writer)
.with_filter(
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy(),
),
)
} else {
None
};

subscriber.with(stdout_layer).with(file_layer).init();
guards
tracing_subscriber::Registry::default()
.with(stdout_layer)
.with(file_layer)
.init();
Ok(LogManager {
_guards: guards,
dynamic_log_file_writer,
})
}
37 changes: 25 additions & 12 deletions src/request.rs
@@ -112,16 +112,16 @@ impl Client {
self,
active_policy_list: Arc<RwLock<HashMap<u32, RequestedPolicy>>>,
delete_policy_ids: Arc<RwLock<Vec<u32>>>,
config_reload: Arc<Notify>,
wait_shutdown: Arc<Notify>,
config_send: tokio::sync::mpsc::Sender<String>,
) -> Result<()> {
loop {
match connect(
&self,
active_policy_list.clone(),
delete_policy_ids.clone(),
config_reload.clone(),
wait_shutdown.clone(),
config_send.clone(),
)
.await
{
@@ -158,8 +158,8 @@ async fn connect(
client: &Client,
active_policy_list: Arc<RwLock<HashMap<u32, RequestedPolicy>>>,
delete_policy_ids: Arc<RwLock<Vec<u32>>>,
config_reload: Arc<Notify>,
wait_shutdown: Arc<Notify>,
config_send: tokio::sync::mpsc::Sender<String>,
) -> Result<()> {
let mut conn_builder = ConnectionBuilder::new(
&client.server_name,
@@ -171,18 +171,19 @@
&client.key,
)?;
conn_builder.root_certs(&client.ca_certs)?;
let conn = conn_builder.connect().await?;
let connection = conn_builder.connect().await?;
info!("Connection established to server {}", client.server_address);

let request_handler = RequestHandler {
request_send: client.request_send.clone(),
active_policy_list,
delete_policy_ids,
config_reload: config_reload.clone(),
connection,
config_send,
};

tokio::select! {
res = handle_incoming(request_handler, &conn) => {
res = handle_incoming(request_handler) => {
if let Err(e) = res {
warn!("control channel failed: {}", e);
return Err(e);
Expand All @@ -196,9 +197,9 @@ async fn connect(
}
}

async fn handle_incoming(handler: RequestHandler, conn: &Connection) -> Result<()> {
async fn handle_incoming(handler: RequestHandler) -> Result<()> {
loop {
match conn.accept_bi().await {
match handler.connection.accept_bi().await {
Ok((mut send, mut recv)) => {
let mut hdl = handler.clone();
tokio::spawn(async move {
@@ -217,7 +218,8 @@ struct RequestHandler {
request_send: Sender<RequestedPolicy>,
active_policy_list: Arc<RwLock<HashMap<u32, RequestedPolicy>>>,
delete_policy_ids: Arc<RwLock<Vec<u32>>>,
config_reload: Arc<Notify>,
connection: Connection,
config_send: tokio::sync::mpsc::Sender<String>,
}

#[async_trait]
@@ -305,9 +307,20 @@ impl review_protocol::request::Handler for RequestHandler {
Ok(())
}

async fn reload_config(&mut self) -> Result<(), String> {
info!("start reloading configuration");
self.config_reload.notify_one();
async fn update_config(&mut self) -> Result<(), String> {
info!("Updating configuration");
match self.connection.get_config().await {
Ok(config) => {
self.config_send
.send(config)
.await
.map_err(|e| format!("Failed to send config: {e}"))?;
}
Err(e) => {
return Err(format!("Failed to get config: {e}"));
}
};

Ok(())
}

