diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index 856364556c35..1fd6af6fbf19 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -43,6 +43,8 @@ pub(crate) use messages::*; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub struct TxId(String); +const MINIMUM_TX_ID_LENGTH: usize = 24; + impl Default for TxId { fn default() -> Self { Self(cuid::cuid().unwrap()) @@ -54,7 +56,15 @@ where T: Into, { fn from(s: T) -> Self { - Self(s.into()) + let contents = s.into(); + assert!( + contents.len() >= MINIMUM_TX_ID_LENGTH, + "minimum length for a TxId ({}) is {}, but was {}", + contents, + MINIMUM_TX_ID_LENGTH, + contents.len() + ); + Self(contents) } } @@ -102,6 +112,7 @@ impl Into for TxId { let mut buffer = [0; 16]; let tx_id_bytes = self.0.as_bytes(); let len = tx_id_bytes.len(); + // bytes [len-20 to len-12): least significative 4 bytes of the timestamp + 4 bytes counter for (i, source_idx) in (len - 20..len - 12).enumerate() { buffer[i] = tx_id_bytes[source_idx]; diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 84c784d44832..32203f6fd760 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -41,15 +41,15 @@ pub async fn listen(opts: &PrismaOpt, state: State) -> PrismaResult<()> { pub async fn routes(state: State, req: Request) -> Result, hyper::Error> { let start = Instant::now(); - if req.method() == Method::POST && req.uri().path().contains("transaction") { - return handle_transaction(state, req).await; + if req.method() == Method::POST && req.uri().path().starts_with("/transaction") { + return transaction_handler(state, req).await; } if [Method::POST, Method::GET].contains(req.method()) - && req.uri().path().contains("metrics") + && req.uri().path().starts_with("/metrics") && state.enable_metrics { - return handle_metrics(state, req).await; + return metrics_handler(state, req).await; } let mut res = match (req.method(), req.uri().path()) { @@ -188,7 +188,7 @@ fn playground_handler() -> Response { .unwrap() } -async fn handle_metrics(state: State, req: Request) -> Result, hyper::Error> { +async fn metrics_handler(state: State, req: Request) -> Result, hyper::Error> { let format = if let Some(query) = req.uri().query() { if query.contains("format=json") { MetricFormat::Json @@ -235,47 +235,27 @@ async fn handle_metrics(state: State, req: Request) -> Result start a transaction /// POST /transaction/{tx_id}/commit -> commit a transaction /// POST /transaction/{tx_id}/rollback -> rollback a transaction -async fn handle_transaction(state: State, req: Request) -> Result, hyper::Error> { - let path = req.uri().path(); +async fn transaction_handler(state: State, req: Request) -> Result, hyper::Error> { + let path = req.uri().path().to_owned(); + let sections: Vec<&str> = path.split('/').collect(); - if path.contains("start") { + if sections.len() == 3 && sections[2] == "start" { return transaction_start_handler(state, req).await; } - let sections: Vec<&str> = path.split('/').collect(); - - if sections.len() < 2 { - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from("Request does not contain the transaction id")) - .unwrap()); + if sections.len() == 4 && sections[3] == "commit" { + return transaction_commit_handler(state, req, sections[2].into()).await; } - let tx_id: TxId = sections[2].into(); + if sections.len() == 4 && sections[3] == "rollback" { + return transaction_rollback_handler(state, req, sections[2].into()).await; + } - let succuss_resp = Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, "application/json") - .body(Body::from(r#"{}"#)) + let res = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(format!("wrong transaction handler path: {}", &path))) .unwrap(); - - if path.contains("commit") { - match state.cx.executor.commit_tx(tx_id).await { - Ok(_) => Ok(succuss_resp), - Err(err) => Ok(err_to_http_resp(err, None)), - } - } else if path.contains("rollback") { - match state.cx.executor.rollback_tx(tx_id).await { - Ok(_) => Ok(succuss_resp), - Err(err) => Ok(err_to_http_resp(err, None)), - } - } else { - let res = Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(); - Ok(res) - } + Ok(res) } async fn transaction_start_handler(state: State, req: Request) -> Result, hyper::Error> { @@ -318,7 +298,100 @@ async fn transaction_start_handler(state: State, req: Request) -> Result Ok(tx_id_to_http_resp(tx_id, telemetry)), + Ok(tx_id) => { + let result = if let Some(telemetry) = telemetry { + json!({ "id": tx_id.to_string(), "extensions": { "logs": json!(telemetry.logs), "traces": json!(telemetry.traces) } }) + } else { + json!({ "id": tx_id.to_string() }) + }; + let result_bytes = serde_json::to_vec(&result).unwrap(); + + let res = Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(result_bytes)) + .unwrap(); + Ok(res) + } + Err(err) => Ok(err_to_http_resp(err, telemetry)), + } +} + +async fn transaction_commit_handler( + state: State, + req: Request, + tx_id: TxId, +) -> Result, hyper::Error> { + let ServerExecutionContext { + tx_id: _, + span: _, + trace_id: _, + capture_config, + } = ServerExecutionContext::builder(&req.headers()) + .with_tx_id(tx_id.clone()) + .build(); + + if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config.clone() { + capturer.start_capturing().await; + } + + let result = state.cx.executor.rollback_tx(tx_id).await; + + let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config { + capturer.fetch_captures().await + } else { + None + }; + + match result { + Ok(_) => Ok(empty_json_to_http_resp(telemetry)), + Err(err) => Ok(err_to_http_resp(err, telemetry)), + } +} + +async fn transaction_rollback_handler( + state: State, + req: Request, + tx_id: TxId, +) -> Result, hyper::Error> { + let ServerExecutionContext { + tx_id: _, + span: _, + trace_id: _, + capture_config, + } = ServerExecutionContext::builder(&req.headers()) + .with_tx_id(tx_id.clone()) + .build(); + + if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config.clone() { + capturer.start_capturing().await; + } + + let result = state.cx.executor.rollback_tx(tx_id).await; + + let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config { + capturer.fetch_captures().await + } else { + None + }; + + match result { + Ok(_) => { + let result = if let Some(telemetry) = telemetry { + json!({ "extensions": { "logs": json!(telemetry.logs), "traces": json!(telemetry.traces) } }) + } else { + json!({}) + }; + let result_bytes = serde_json::to_vec(&result).unwrap(); + + let res = Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(result_bytes)) + .unwrap(); + + Ok(res) + } Err(err) => Ok(err_to_http_resp(err, telemetry)), } } @@ -358,21 +431,18 @@ impl<'a> Extractor for HeaderExtractor<'a> { } } -fn tx_id_to_http_resp( - tx_id: TxId, - captured_telemetry: Option, -) -> Response { +fn empty_json_to_http_resp(captured_telemetry: Option) -> Response { let result = if let Some(telemetry) = captured_telemetry { - json!({ "id": tx_id.to_string(), "extensions": { "logs": json!(telemetry.logs), "traces": json!(telemetry.traces) } }) + json!({ "extensions": { "logs": json!(telemetry.logs), "traces": json!(telemetry.traces) } }) } else { - json!({ "id": tx_id.to_string() }) + json!({}) }; let result_bytes = serde_json::to_vec(&result).unwrap(); Response::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json") - .body(dbg!(Body::from(result_bytes))) + .body(Body::from(result_bytes)) .unwrap() }