Skip to content

Commit

Permalink
Raise Http{Request,Response}BodyError when too much is written (#7591)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottt authored Nov 29, 2023
1 parent ffdac62 commit 8cf0d42
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,7 @@ fn main() {
http_types::OutgoingBody::finish(outgoing_body, None).expect_err("finish should fail");

assert!(
matches!(
&e,
http_types::ErrorCode::InternalError(Some(s))
if s == "not enough written to body stream",
),
matches!(&e, http_types::ErrorCode::HttpRequestBodySize(Some(3))),
"unexpected error: {e:#?}"
);
}
Expand All @@ -75,25 +71,26 @@ fn main() {
.expect_err("write should fail");

let e = match e {
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e,
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => {
http_types::http_error_code(&e)
}
test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"),
};

assert!(matches!(
http_types::http_error_code(&e),
Some(http_types::ErrorCode::InternalError(Some(msg)))
if msg == "too much written to output stream"));
assert!(
matches!(
e,
Some(http_types::ErrorCode::HttpRequestBodySize(Some(18)))
),
"unexpected error {e:?}"
);
}

let e =
http_types::OutgoingBody::finish(outgoing_body, None).expect_err("finish should fail");

assert!(
matches!(
&e,
http_types::ErrorCode::InternalError(Some(s))
if s == "too much written to body stream",
),
matches!(&e, http_types::ErrorCode::HttpRequestBodySize(Some(18))),
"unexpected error: {e:#?}"
);
}
Expand Down
58 changes: 36 additions & 22 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::internal_error;
use crate::{bindings::http::types, types::FieldMap};
use anyhow::anyhow;
use bytes::Bytes;
Expand Down Expand Up @@ -386,6 +385,11 @@ impl WrittenState {
}
}

/// The number of bytes that have been written so far.
fn written(&self) -> u64 {
self.written.load(std::sync::atomic::Ordering::Relaxed)
}

/// Add `len` to the total number of bytes written. Returns `false` if the new total exceeds
/// the number of bytes expected to be written.
fn update(&self, len: usize) -> bool {
Expand All @@ -395,22 +399,17 @@ impl WrittenState {
.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
old + len <= self.expected
}

/// Return a comparison of total bytes written to the number of bytes expected to be written.
fn finish(self) -> std::cmp::Ordering {
let written = self.written.load(std::sync::atomic::Ordering::Relaxed);
written.cmp(&self.expected)
}
}

pub struct HostOutgoingBody {
pub body_output_stream: Option<Box<dyn HostOutputStream>>,
context: StreamContext,
written: Option<WrittenState>,
finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}

impl HostOutgoingBody {
pub fn new(size: Option<u64>) -> (Self, HyperOutgoingBody) {
pub fn new(context: StreamContext, size: Option<u64>) -> (Self, HyperOutgoingBody) {
let written = size.map(WrittenState::new);

use tokio::sync::oneshot::error::RecvError;
Expand Down Expand Up @@ -465,11 +464,13 @@ impl HostOutgoingBody {
.boxed();

// TODO: this capacity constant is arbitrary, and should be configurable
let output_stream = BodyWriteStream::new(1024 * 1024, body_sender, written.clone());
let output_stream =
BodyWriteStream::new(context, 1024 * 1024, body_sender, written.clone());

(
Self {
body_output_stream: Some(Box::new(output_stream)),
context,
written,
finish_sender: Some(finish_sender),
},
Expand All @@ -488,17 +489,10 @@ impl HostOutgoingBody {
.expect("outgoing-body trailer_sender consumed by a non-owning function");

if let Some(w) = self.written {
use std::cmp::Ordering;
let res = w.finish();
if res != Ordering::Equal {
let msg = match res {
Ordering::Less => "not enough",
Ordering::Greater => "too much",
Ordering::Equal => unreachable!(),
};

let written = w.written();
if written != w.expected {
let _ = sender.send(FinishMessage::Abort);
return Err(internal_error(format!("{msg} written to body stream")));
return Err(self.context.as_body_error(written));
}
}

Expand Down Expand Up @@ -528,8 +522,25 @@ impl HostOutgoingBody {
}
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
Request,
Response,
}

impl StreamContext {
/// Construct an http request or response body size error.
pub fn as_body_error(&self, size: u64) -> types::ErrorCode {
match self {
StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
}
}
}

/// Provides a [`HostOutputStream`] impl from a [`tokio::sync::mpsc::Sender`].
struct BodyWriteStream {
context: StreamContext,
writer: mpsc::Sender<Bytes>,
write_budget: usize,
written: Option<WrittenState>,
Expand All @@ -538,13 +549,15 @@ struct BodyWriteStream {
impl BodyWriteStream {
/// Create a [`BodyWriteStream`].
fn new(
context: StreamContext,
write_budget: usize,
writer: mpsc::Sender<Bytes>,
written: Option<WrittenState>,
) -> Self {
// at least one capacity is required to send a message
assert!(writer.max_capacity() >= 1);
BodyWriteStream {
context,
writer,
write_budget,
written,
Expand All @@ -562,9 +575,10 @@ impl HostOutputStream for BodyWriteStream {
Ok(()) => {
if let Some(written) = self.written.as_ref() {
if !written.update(len) {
return Err(StreamError::LastOperationFailed(anyhow!(internal_error(
"too much written to output stream".to_owned()
))));
let total = written.written();
return Err(StreamError::LastOperationFailed(anyhow!(self
.context
.as_body_error(total))));
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody, StreamContext},
types::{
is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields,
HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest,
Expand Down Expand Up @@ -384,7 +384,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingRequest for T {
Err(e) => return Ok(Err(e)),
};

let (host_body, hyper_body) = HostOutgoingBody::new(size);
let (host_body, hyper_body) = HostOutgoingBody::new(StreamContext::Request, size);

req.body = Some(hyper_body);

Expand Down Expand Up @@ -727,7 +727,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingResponse for T {
Err(e) => return Ok(Err(e)),
};

let (host, body) = HostOutgoingBody::new(size);
let (host, body) = HostOutgoingBody::new(StreamContext::Response, size);

resp.body.replace(body);

Expand Down

0 comments on commit 8cf0d42

Please sign in to comment.