Skip to content

Commit

Permalink
Process Level Guard for IPTables (#978)
Browse files Browse the repository at this point in the history
closes #955
  • Loading branch information
DmitryDodzin committed Jan 24, 2023
1 parent 94d50e0 commit eaadde4
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 18 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how
- Support for Go's `os.ReadDir` on Linux (by hooking the `getdents64` syscall). Part of
[#120](https://github.com/metalbear-co/mirrord/issues/120).

### Changed

- mirrord-agent: Wrap agent with a parent proccess to doublecheck the clearing of iptables. See [#955](https://github.com/metalbear-co/mirrord/issues/955)

### Fixed

- mirrord-agent: Handle HTTP upgrade requests when the stealer feature is enabled
Expand Down
59 changes: 57 additions & 2 deletions mirrord/agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
use crate::{
cli::Args,
runtime::{get_container, Container, ContainerRuntime},
steal::{connection::TcpConnectionStealer, StealerCommand},
steal::{
connection::TcpConnectionStealer,
ip_tables::{IPTableFormatter, SafeIpTables, MIRRORD_IPTABLE_CHAIN_ENV},
StealerCommand,
},
util::{run_thread_in_namespace, ClientId, IndexAllocator},
};

Expand Down Expand Up @@ -465,6 +469,50 @@ async fn start_agent() -> Result<()> {
Ok(())
}

async fn clear_iptable_chain(chain_name: String) -> Result<()> {
let ipt = iptables::new(false).unwrap();
let formatter = IPTableFormatter::detect(&ipt)?;

SafeIpTables::remove_chain(&ipt, &formatter, &chain_name)
}

fn spawn_child_agent() -> Result<()> {
let command_args = std::env::args().collect::<Vec<_>>();

let mut child_agent = std::process::Command::new(&command_args[0])
.args(&command_args[1..])
.spawn()?;

let _ = child_agent.wait();

Ok(())
}

async fn start_iptable_guard() -> Result<()> {
debug!("start_iptable_guard -> Initializing iptable-guard.");

let args = parse_args();
let state = State::new(&args).await?;
let pid = state.get_container_pid().await?;

let chain_name = SafeIpTables::<iptables::IPTables>::get_chain_name();

std::env::set_var(MIRRORD_IPTABLE_CHAIN_ENV, &chain_name);

let result = spawn_child_agent();

run_thread_in_namespace(
clear_iptable_chain(chain_name),
"clear iptables".to_owned(),
pid,
"net",
)
.join()
.map_err(|_| AgentError::JoinTask)??;

result
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
Expand All @@ -479,7 +527,13 @@ async fn main() -> Result<()> {

debug!("main -> Initializing mirrord-agent.");

match start_agent().await {
let agent_result = if std::env::var(MIRRORD_IPTABLE_CHAIN_ENV).is_ok() {
start_agent().await
} else {
start_iptable_guard().await
};

match agent_result {
Ok(_) => {
info!("main -> mirrord-agent `start` exiting successfully.")
}
Expand All @@ -490,5 +544,6 @@ async fn main() -> Result<()> {
)
}
}

Ok(())
}
2 changes: 1 addition & 1 deletion mirrord/agent/src/steal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
pub(super) mod api;
pub(super) mod connection;
pub(crate) mod http;
mod ip_tables;
pub(super) mod ip_tables;
mod orig_dst;

/// Commands from the agent that are passed down to the stealer worker, through [`TcpStealerApi`].
Expand Down
45 changes: 30 additions & 15 deletions mirrord/agent/src/steal/ip_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use rand::distributions::{Alphanumeric, DistString};

use crate::error::{AgentError, Result};

pub(crate) static MIRRORD_IPTABLE_CHAIN_ENV: &str = "MIRRORD_IPTABLE_CHAIN_NAME";

#[cfg_attr(test, mockall::automock)]
pub(super) trait IPTables {
pub(crate) trait IPTables {
fn create_chain(&self, name: &str) -> Result<()>;
fn remove_chain(&self, name: &str) -> Result<()>;

Expand Down Expand Up @@ -56,7 +58,7 @@ impl IPTables for iptables::IPTables {
}

/// Wrapper struct for IPTables so it flushes on drop.
pub(super) struct SafeIpTables<IPT: IPTables> {
pub(crate) struct SafeIpTables<IPT: IPTables> {
inner: IPT,
chain_name: String,
formatter: IPTableFormatter,
Expand All @@ -75,12 +77,11 @@ where
pub(super) fn new(ipt: IPT) -> Result<Self> {
let formatter = IPTableFormatter::detect(&ipt)?;

let random_string = Alphanumeric.sample_string(&mut rand::thread_rng(), 5);
let chain_name = format!("MIRRORD_REDIRECT_{}", random_string);
let chain_name = Self::get_chain_name();

ipt.create_chain(&chain_name)?;

ipt.add_rule(formatter.entrypoint(), &format!("-j {}", chain_name))?;
ipt.add_rule(formatter.entrypoint(), &format!("-j {}", &chain_name))?;

Ok(Self {
inner: ipt,
Expand All @@ -89,6 +90,27 @@ where
})
}

pub(crate) fn get_chain_name() -> String {
std::env::var(MIRRORD_IPTABLE_CHAIN_ENV).unwrap_or_else(|_| {
format!(
"MIRRORD_REDIRECT_{}",
Alphanumeric.sample_string(&mut rand::thread_rng(), 5)
)
})
}

pub(crate) fn remove_chain(
ipt: &IPT,
formatter: &IPTableFormatter,
chain_name: &str,
) -> Result<()> {
ipt.remove_rule(formatter.entrypoint(), &format!("-j {}", chain_name))?;

ipt.remove_chain(chain_name)?;

Ok(())
}

/// Helper function that lists all the iptables' rules belonging to [`Self::chain_name`].
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn list_rules(&self) -> Result<Vec<String>> {
Expand Down Expand Up @@ -171,18 +193,11 @@ where
IPT: IPTables,
{
fn drop(&mut self) {
self.inner
.remove_rule(
self.formatter.entrypoint(),
&format!("-j {}", self.chain_name),
)
.unwrap();

self.inner.remove_chain(&self.chain_name).unwrap();
Self::remove_chain(&self.inner, &self.formatter, &self.chain_name).unwrap();
}
}

enum IPTableFormatter {
pub(crate) enum IPTableFormatter {
Normal,
Mesh,
}
Expand All @@ -191,7 +206,7 @@ impl IPTableFormatter {
const MESH_OUTPUTS: [&'static str; 2] = ["-j PROXY_INIT_OUTPUT", "-j ISTIO_OUTPUT"];

#[tracing::instrument(level = "trace", skip_all)]
fn detect<IPT: IPTables>(ipt: &IPT) -> Result<Self> {
pub(crate) fn detect<IPT: IPTables>(ipt: &IPT) -> Result<Self> {
let output = ipt.list_rules("OUTPUT")?;

if output.iter().any(|rule| {
Expand Down

0 comments on commit eaadde4

Please sign in to comment.