Skip to content

Commit

Permalink
Handle HTTP upgrade request when steal + HTTP traffic filter are enab…
Browse files Browse the repository at this point in the history
…led. (#973)

Issue: #928 

Depends on [test-images #14](metalbear-co/test-images#14).

Implements a pass-through mechanism to handle HTTP/1 upgrade requests from a client.

## Motivation

Currently, if you run mirrord with steal traffic, and HTTP filter enabled, we end up not serving HTTP upgrade requests appropriately. Our implementation of [`Service`](https://docs.rs/hyper/1.0.0-rc.2/hyper/service/trait.Service.html) would either:

- Forward the HTTP request to the local app (if a filter matched), send back a response and **not** upgrade the connection, or;
- Pass the request to its original destination, get a response back from it, and, again, **not** upgrade the connection.

This would prevent the user from handling websockets (our main motivator behind this feature), that start up as HTTP requests.

## Implementation

We're manually handling connection upgrades, by first checking if an HTTP request contains the "Upgrade" header, and if that's the case, then we:

- Current `unmacthed_request` flow:

1. Create a connection to the original destination (let's call this the `agent_original` connection);
2. Forward the HTTP request to it, awaiting its response;
3. Send the response back through our hyper connection (the `agent_client` connection we handle in hyper);

- The new steps:

4. Send the `agent_original` connection to the `filter_task` (through the `upgrade_tx` channel);
5. Retrieve the `agent_client` connection from hyper, as a normal `TcpStream`;
6. Copy whatever bytes were leftover from these connections (bytes that were read, but not processed by hyper);
7. Use [`copy_bidirectional`](https://docs.rs/tokio/latest/tokio/io/fn.copy_bidirectional.html) to keep copying from one connection to the other.

### Tests

The feature is being tested by `test_websocket_upgrade`, which is very similar to the `test_complete_passthrough` test. The main difference is that we use [`tokio-tugstenite`](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/) to create a websocket connection (instead of a normal `TcpStream`).

## Alternatives

### Give a [`DuplexStream`](https://docs.rs/tokio/latest/tokio/io/struct.DuplexStream.html) to hyper

We could feed hyper a `DuplexStream` rather than the actual `TcpStream` connection, and deal with the upgrade ourselves, by parsing the bytes into an `Header`, checking for the "Upgrade: " header, or even use the [`Response`] as a way to identify that this was an upgrade request.

I've explored this approach, but this seemed a bit too much, especially when @t4lz asked me about "why not use the `unmatched_request`?"

### Use [`hyper::upgrade::on`]

hyper has a proper handler for dealing with HTTP upgrade requests, where you give it a [`Request`](https://docs.rs/hyper/1.0.0-rc.2/hyper/struct.Request.html) that contains an [`OnUpgrade`](https://docs.rs/hyper/1.0.0-rc.2/hyper/upgrade/struct.OnUpgrade.html) extension, and it'll give you back its connection (`agent_client`) and the leftover bytes.

I could not make this approach work, as we need to give the request to [`SendRequest`](https://docs.rs/hyper/1.0.0-rc.2/hyper/client/conn/http1/struct.SendRequest.html) by value, making it impossible to also feed it to the `upgrade::on` function. This was the case because we need to poll the connection (`agent_original`) in a separate task, thus requiring us to also move the request into it.

The request cannot be copied/cloned, as it contains a `Receiver` channel for the `OnUpgrade` feature, and if you try to manually copy it, the upgrade is never triggered, because we're dealing with different channels.

Co-authored-by: meowjesty <[email protected]>
  • Loading branch information
meowjesty and meowjesty committed Jan 23, 2023
1 parent 3b26d74 commit 7dccbd7
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion mirrord/agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ edition.workspace = true

[dependencies]
containerd-client = {git = "https://github.com/containerd/rust-extensions", rev="6bc49c007cf93869e7d83fca4818b6aae1145b45"}
tokio = { version = "1", features = ["rt", "rt-multi-thread", "net", "macros", "fs"] }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "net", "macros", "fs"] }
serde.workspace = true
serde_json.workspace = true
pnet = "0.31"
Expand Down
5 changes: 5 additions & 0 deletions mirrord/agent/src/steal/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ impl HttpFilterManager {
///
/// You can't create just an empty [`HttpFilterManager`], as we don't steal traffic on ports
/// that no client has registered interest in.
#[tracing::instrument(level = "trace", skip(matched_tx))]
pub(super) fn new(
client_id: ClientId,
filter: Regex,
Expand All @@ -87,6 +88,7 @@ impl HttpFilterManager {
///
/// [`HttpFilterManager::client_filters`] are shared between hyper tasks, so adding a new one
/// here will impact the tasks as well.
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn add_client(&mut self, client_id: ClientId, filter: Regex) -> Option<Regex> {
self.client_filters.insert(client_id, filter)
}
Expand All @@ -95,10 +97,12 @@ impl HttpFilterManager {
///
/// [`HttpFilterManager::client_filters`] are shared between hyper tasks, so removing a client
/// here will impact the tasks as well.
#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn remove_client(&mut self, client_id: &ClientId) -> Option<(ClientId, Regex)> {
self.client_filters.remove(client_id)
}

#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn contains_client(&self, client_id: &ClientId) -> bool {
self.client_filters.contains_key(client_id)
}
Expand All @@ -121,6 +125,7 @@ impl HttpFilterManager {
));
}

#[tracing::instrument(level = "trace", skip(self))]
pub(super) fn is_empty(&self) -> bool {
self.client_filters.is_empty()
}
Expand Down
7 changes: 7 additions & 0 deletions mirrord/agent/src/steal/http/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use mirrord_protocol::ConnectionId;
use thiserror::Error;

use crate::steal::HandlerHttpRequest;
Expand All @@ -19,4 +20,10 @@ pub enum HttpTrafficError {

#[error("Failed with Captured `{0}`!")]
ResponseReceiver(#[from] tokio::sync::oneshot::error::RecvError),

#[error("Failed hyper HTTP `{0}`!")]
HyperHttp(#[from] hyper::http::Error),

#[error("Failed closing connection with `{0}`!")]
CloseSender(#[from] tokio::sync::mpsc::error::SendError<ConnectionId>),
}
92 changes: 77 additions & 15 deletions mirrord/agent/src/steal/http/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@ use std::{net::SocketAddr, sync::Arc, time::Duration};

use dashmap::DashMap;
use fancy_regex::Regex;
use hyper::server::conn::http1;
use hyper::server::{self, conn::http1};
use mirrord_protocol::ConnectionId;
use tokio::{io::copy_bidirectional, net::TcpStream, sync::mpsc::Sender};
use tokio::{
io::{copy_bidirectional, AsyncWriteExt},
net::TcpStream,
sync::{mpsc::Sender, oneshot},
};
use tracing::{error, trace};

use super::{hyper_handler::HyperHandler, DefaultReversibleStream, HttpVersion};
use super::{
error::HttpTrafficError,
hyper_handler::{HyperHandler, RawHyperConnection},
DefaultReversibleStream, HttpVersion,
};
use crate::{steal::HandlerHttpRequest, util::ClientId};

const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0";
Expand All @@ -20,9 +28,27 @@ const DEFAULT_HTTP_VERSION_DETECTION_TIMEOUT: Duration = Duration::from_secs(10)
/// requests.
pub(super) const MINIMAL_HEADER_SIZE: usize = 10;

/// Read the start of the TCP stream, decide if it's HTTP (of a supported version), if it is, serve
/// the connection with a [`HyperHandler`]. If it isn't, just forward the whole TCP connection to
/// the original destination.
/// Reads the start of the [`TcpStream`], and decides if it's HTTP (we currently only support
/// HTTP/1) or not,
///
/// ## HTTP/1
///
/// If the stream is identified as HTTP/1 by our check in [`HttpVersion::new`], then we serve the
/// connection with [`HyperHandler`].
///
/// ### Upgrade
///
/// If an upgrade request is detected in the [`HyperHandler`], then we take the HTTP connection
/// that's being served (after HTTP processing is done), and use [`copy_bidirectional`] to copy the
/// data from the upgraded connection to its original destination (similar to the Not HTTP/1
/// handling).
///
/// ## Not HTTP/1
///
/// Forwards the whole TCP connection to the original destination with [`copy_bidirectional`].
///
/// It's important to note that, we don't lose the bytes read from the original stream, due to us
/// converting it into a [`DefaultReversibleStream`].
#[tracing::instrument(
level = "trace",
skip(stolen_stream, matched_tx, connection_close_sender)
Expand All @@ -34,8 +60,9 @@ pub(super) async fn filter_task(
filters: Arc<DashMap<ClientId, Regex>>,
matched_tx: Sender<HandlerHttpRequest>,
connection_close_sender: Sender<ConnectionId>,
) {
) -> Result<(), HttpTrafficError> {
let port = original_destination.port();

match DefaultReversibleStream::read_header(
stolen_stream,
DEFAULT_HTTP_VERSION_DETECTION_TIMEOUT,
Expand All @@ -48,8 +75,17 @@ pub(super) async fn filter_task(
&H2_PREFACE[..MINIMAL_HEADER_SIZE],
) {
HttpVersion::V1 => {
// TODO: do we need to do something with this result?
let _res = http1::Builder::new()
// Contains the upgraded interceptor connection, if any.
let (upgrade_tx, upgrade_rx) = oneshot::channel::<RawHyperConnection>();

// We have to keep the connection alive to handle a possible upgrade request
// manually.
let server::conn::http1::Parts {
io: mut client_agent, // i.e. browser-agent connection
read_buf: agent_unprocessed,
..
} = http1::Builder::new()
.preserve_header_case(true)
.serve_connection(
reversible_stream,
HyperHandler {
Expand All @@ -59,17 +95,37 @@ pub(super) async fn filter_task(
port,
original_destination,
request_id: 0,
upgrade_tx: Some(upgrade_tx),
},
)
.await;
.without_shutdown()
.await?;

if let Ok(RawHyperConnection {
stream: mut agent_remote, // i.e. agent-original destination connection
unprocessed_bytes: client_unprocessed,
}) = upgrade_rx.await
{
// Send the data we received from the client, and have not processed as
// HTTP, to the original destination.
agent_remote.write_all(&agent_unprocessed).await?;

let _res = connection_close_sender
// Send the data we received from the original destination, and have not
// processed as HTTP, to the client.
client_agent.write_all(&client_unprocessed).await?;

// Now both the client and original destinations should be in sync, so we
// can just copy the bytes from one into the other.
copy_bidirectional(&mut client_agent, &mut agent_remote).await?;
}

connection_close_sender
.send(connection_id)
.await
.inspect_err(|connection_id| {
error!("Main TcpConnectionStealer dropped connection close channel while HTTP filter is still running. \
Cannot report the closing of connection {connection_id}.");
});
}).map_err(From::from)
}

// TODO(alex): hyper handling of HTTP/2 requires a bit more work, as it takes an
Expand All @@ -92,22 +148,27 @@ pub(super) async fn filter_task(
trace!(
"Forwarded {incoming} incoming bytes and {outgoing} \
outgoing bytes in passthrough connection"
)
);

Ok(())
}
Err(err) => {
error!(
"Encountered error while forwarding unsupported \
connection to its original destination: {err:?}"
)
);

Err(err)?
}
};
}
}
Err(err) => {
error!(
"Could not connect to original destination {original_destination:?}\
. Received a connection with an unsupported protocol version to a \
filtered HTTP port, but cannot forward the connection because of \
the connection error: {err:?}");
Err(err)?
}
}
}
Expand All @@ -116,6 +177,7 @@ pub(super) async fn filter_task(

Err(read_error) => {
error!("Got error while trying to read first bytes of TCP stream: {read_error:?}");
Err(read_error)
}
}
}
Loading

0 comments on commit 7dccbd7

Please sign in to comment.