Skip to content

Commit

Permalink
Create ServerExecutionContext to reuse in BEGIN/COMMIT/ROLLBACK trans…
Browse files Browse the repository at this point in the history
…action blocks
  • Loading branch information
miguelff committed Jan 17, 2023
1 parent f7b55d6 commit bd123ba
Showing 1 changed file with 108 additions and 68 deletions.
176 changes: 108 additions & 68 deletions query-engine/query-engine/src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::state::State;
use crate::{opt::PrismaOpt, PrismaResult};
use hyper::http::HeaderValue;
use hyper::service::{make_service_fn, service_fn};
use hyper::{header::CONTENT_TYPE, Body, HeaderMap, Method, Request, Response, Server, StatusCode};
use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
Expand All @@ -17,9 +16,6 @@ use std::time::Instant;
use tracing::{field, Instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

const TRANSACTION_ID_HEADER: &str = "X-transaction-id";
const TRACE_CAPTURE_HEADER: &str = "X-capture-telemetry";

/// Starts up the graphql query engine server
pub async fn listen(opts: &PrismaOpt, state: State) -> PrismaResult<()> {
let query_engine = make_service_fn(move |_| {
Expand Down Expand Up @@ -117,7 +113,14 @@ async fn graphql_handler(state: State, req: Request<Body>) -> Result<Response<Bo
return Ok(handle_debug_headers(&req));
}

let (tx_id, span, capture_config, trace_id) = process_gql_req_headers(&req);
let ServerExecutionContext {
span,
tx_id,
capture_config,
trace_id,
} = ServerExecutionContext::builder(req.headers())
.with_span(info_span!("prisma:engine", user_facing = true))
.build();

if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config.clone() {
capturer.start_capturing().await;
Expand All @@ -131,7 +134,10 @@ async fn graphql_handler(state: State, req: Request<Body>) -> Result<Response<Bo
match serde_json::from_slice(full_body.as_ref()) {
Ok(body) => {
let handler = GraphQlHandler::new(&*state.cx.executor, state.cx.query_schema());
let mut result = handler.handle(body, tx_id, trace_id).instrument(span).await;
let mut result = handler
.handle(body, tx_id, Some(trace_id.to_string()))
.instrument(span)
.await;

if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config {
let telemetry = capturer.fetch_captures().await;
Expand Down Expand Up @@ -272,20 +278,26 @@ async fn handle_transaction(state: State, req: Request<Body>) -> Result<Response
}

async fn transaction_start_handler(state: State, req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let cx = get_parent_span_context(&req);
let headers = req.headers().to_owned();

let body_start = req.into_body();
// block and buffer request until the request has completed
let full_body = hyper::body::to_bytes(body_start).await?;
let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap();
let tx_id = tx_opts.with_predefined_transaction_id();

let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty);
if let Some(context) = cx {
span.set_parent(context);
} else {
span.set_parent(tx_id.into());
}
let ServerExecutionContext {
tx_id,
span,
trace_id,
capture_config,
} = ServerExecutionContext::builder(&headers)
.with_span(info_span!(
"prisma:engine:itx_runner",
user_facing = true,
itx_id = field::Empty
))
.with_tx_id(tx_opts.with_predefined_transaction_id())
.build();

match state
.cx
Expand All @@ -310,18 +322,6 @@ async fn transaction_start_handler(state: State, req: Request<Body>) -> Result<R
}
}

fn get_transaction_id_from_header(req: &Request<Body>) -> Option<TxId> {
match req.headers().get(TRANSACTION_ID_HEADER) {
Some(id_header) => {
let msg = format!("{} has not been correctly set.", TRANSACTION_ID_HEADER);
let id = id_header.to_str().unwrap_or(msg.as_str());
Some(TxId::from(id))
}

None => None,
}
}

/// Handle debug headers inside the main GraphQL endpoint.
fn handle_debug_headers(req: &Request<Body>) -> Response<Body> {
/// Debug header that triggers a panic in the request thread.
Expand Down Expand Up @@ -357,24 +357,6 @@ impl<'a> Extractor for HeaderExtractor<'a> {
}
}

/// If the client sends us a trace and span id, extracting a new context if the
/// headers are set. If not, returns None.
fn get_parent_span_context(req: &Request<Body>) -> Option<Context> {
let extractor = HeaderExtractor(req.headers());
let context = global::get_text_map_propagator(|propagator| propagator.extract(&extractor));

// because getting the context is infallible, we can be returning a context that's not
// useful for our purposes, for that reason we validate it and return None in case
// it's set with an invalid TraceId
let trace_id = telemetry::helpers::get_trace_id_from_context(&context);
let span_id = context.span().span_context().span_id();
if trace_id == TraceId::INVALID || span_id == SpanId::INVALID {
None
} else {
Some(context)
}
}

fn err_to_http_resp(err: query_core::CoreError) -> Response<Body> {
let status = match err {
query_core::CoreError::TransactionError(ref err) => match err {
Expand All @@ -395,33 +377,91 @@ fn err_to_http_resp(err: query_core::CoreError) -> Response<Body> {
Response::builder().status(status).body(body).unwrap()
}

pub(crate) fn process_gql_req_headers(
req: &Request<Body>,
) -> (Option<TxId>, Span, telemetry::capturing::Capturer, Option<String>) {
let tx_id = get_transaction_id_from_header(req);

let span = info_span!("prisma:engine", user_facing = true);
let cx = get_parent_span_context(req);
if let Some(context) = cx {
span.set_parent(context);
} else if let Some(tx_id) = tx_id.clone() {
span.set_parent(tx_id.into());
/// Encapsulates the data relevant to tweak the execution of a query, particularly, its
/// tracing configuration, and transaction scope.
struct ServerExecutionContext {
pub(crate) tx_id: Option<TxId>,
pub(crate) span: Span,
pub(crate) trace_id: TraceId,
pub(crate) capture_config: telemetry::capturing::Capturer,
}

impl ServerExecutionContext {
pub(crate) fn builder(headers: &HeaderMap) -> ServerExecutionContextBuilder {
ServerExecutionContextBuilder {
headers,
root_span: None,
tx_id: None,
}
}
}

let context = span.context();
struct ServerExecutionContextBuilder<'req> {
headers: &'req HeaderMap,
root_span: Option<Span>,
tx_id: Option<TxId>,
}

let trace_id = telemetry::helpers::get_trace_id_from_context(&context);
let trace_capture_header = req.headers().get(TRACE_CAPTURE_HEADER);
let trace_capture = create_capture_config(trace_capture_header, trace_id);
impl ServerExecutionContextBuilder<'_> {
pub(crate) fn with_tx_id(mut self, tx_id: TxId) -> Self {
self.tx_id = Some(tx_id);
self
}

(tx_id, span, trace_capture, Some(trace_id.to_string()))
}
pub(crate) fn with_span(mut self, span: Span) -> Self {
self.root_span = Some(span);
self
}

pub fn create_capture_config(header: Option<&HeaderValue>, trace_id: TraceId) -> telemetry::capturing::Capturer {
let mut settings = if let Some(h) = header {
h.to_str().unwrap_or("")
} else {
""
};
telemetry::capturing::capturer(trace_id, settings)
pub(crate) fn build(self) -> ServerExecutionContext {
const TRACE_CAPTURE_HEADER: &str = "X-capture-telemetry";
const TRANSACTION_ID_HEADER: &str = "X-transaction-id";

// TxId. Either set explicitly or inferred from the header or
let tx_id = if self.tx_id.is_some() {
self.tx_id
} else if let Some(id_header) = self.headers.get(TRANSACTION_ID_HEADER) {
let msg = format!("{} has not been correctly set.", TRANSACTION_ID_HEADER);
let id = id_header.to_str().unwrap_or(msg.as_str());
Some(TxId::from(id))
} else {
None
};

// Span. Either set explicitly or default to current
let span = if let Some(span) = self.root_span {
span
} else {
Span::current()
};

// Parent tracing context, either propagated in the headers, or if not propagated and
// in the scope of a transaction, inferred from its id.
let extractor = HeaderExtractor(self.headers);
let context = global::get_text_map_propagator(|propagator| propagator.extract(&extractor));

let parent_trace_id = context.span().span_context().trace_id();
let parent_span_id = context.span().span_context().span_id();
if parent_trace_id != TraceId::INVALID && parent_span_id != SpanId::INVALID {
span.set_parent(context);
} else if let Some(tx_id) = tx_id.clone() {
span.set_parent(tx_id.into());
};

// Capturing, configured in a header
let settings = if let Some(h) = self.headers.get(TRACE_CAPTURE_HEADER) {
h.to_str().unwrap_or("")
} else {
""
};
let trace_id = span.context().span().span_context().trace_id();
let capture_config = telemetry::capturing::capturer(trace_id, settings);

ServerExecutionContext {
tx_id,
span,
trace_id,
capture_config,
}
}
}

0 comments on commit bd123ba

Please sign in to comment.