Skip to content

Commit

Permalink
Option to change how websocket data frames are returned to the client
Browse files Browse the repository at this point in the history
The websocket specification defines 2 ways how application-layer data transmitted over the socket can be represented: `Text` or `Binary`.
Currently the default data frame representation returned by the RPC server is `Binary` but this is counterintuitive since JSONRPC already is in a text representation.

This PR let the client, per connection, decide in which representation data should be returned by the server. This is achived by adding support for a `frame` query parameter on the `/ws` endpoint.
Examples:
- ws://127.0.0.1:8000/ws		--> data frames are returned as `Binary` (default)
- ws://127.0.0.1:8000/ws?frame=text	--> data frames are returned as `Text`
- ws://127.0.0.1:8000/ws?frame=binary	--> data frames are returned as `Binary`
- ws://127.0.0.1:8000/ws?frame=foo   	--> data frames are returned as `Binary` (Fallback for unsupported values)

To accomplish backwards compatibility `Binary` is still the default returned representation but in a major release this should be changed to `Text`
  • Loading branch information
Eligioo authored and jsdanielh committed Jan 27, 2025
1 parent d0439d1 commit ae65618
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 18 deletions.
20 changes: 20 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ pub enum Error {
InvalidSubscriptionId(Value),
}

/// Indicate if a websocket frame response should be in Binary or Text
#[derive(Copy, Clone, Default)]
pub enum FrameType {
/// Binary frame type
#[default]
Binary,
/// Text frame type
Text,
}

impl From<&String> for FrameType {
fn from(value: &String) -> Self {
match value.as_str() {
"text" => FrameType::Text,
"binary" => FrameType::Binary,
_ => FrameType::Binary,
}
}
}

/// A JSON-RPC request or response can either be a single request or response, or a list of the former. This `enum`
/// matches either for serialization and deserialization.
///
Expand Down
2 changes: 1 addition & 1 deletion derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<'a> RpcMethod<'a> {
let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new());
let listener = notifier.clone();

let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener);
let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener, frame_type);

Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier)))
}
Expand Down
1 change: 1 addition & 0 deletions derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
frame_type: Option<::nimiq_jsonrpc_core::FrameType>,
) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> {
match request.method.as_str() {
#(#match_arms)*
Expand Down
66 changes: 49 additions & 17 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::{
use async_trait::async_trait;
use axum::{
body::{Body, Bytes},
extract::{DefaultBodyLimit, State, WebSocketUpgrade},
extract::{DefaultBodyLimit, Query, State, WebSocketUpgrade},
http::{header::CONTENT_TYPE, response::Builder, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse as _, Response as HttpResponse},
Expand Down Expand Up @@ -47,7 +47,8 @@ use tokio::{
};

use nimiq_jsonrpc_core::{
Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage,
FrameType, Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId,
SubscriptionMessage,
};

pub use axum::extract::ws::Message;
Expand Down Expand Up @@ -297,7 +298,7 @@ impl<D: Dispatcher> Server<D> {
let http_router = Router::new().route(
"/",
post(|body: Bytes| async move {
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
let data = Self::handle_raw_request(inner, &Message::binary(body), None, None)
.await
.unwrap_or(Message::Binary(Bytes::new()));

Expand All @@ -312,7 +313,11 @@ impl<D: Dispatcher> Server<D> {
let inner = Arc::clone(&self.inner);
let ws_router = Router::new().route(
"/ws",
any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }),
any(
|Query(params): Query<HashMap<String, String>>, ws: WebSocketUpgrade| async move {
Self::upgrade_to_ws(inner, ws, params)
},
),
);

let app = Router::new()
Expand Down Expand Up @@ -344,10 +349,18 @@ impl<D: Dispatcher> Server<D> {
///
/// # TODO:
///
/// - This sends stuff as binary websocket frames. It should really use text frames.
/// - Make the queue size configurable
///
fn upgrade_to_ws(inner: Arc<Inner<D>>, ws: WebSocketUpgrade) -> HttpResponse<Body> {
fn upgrade_to_ws(
inner: Arc<Inner<D>>,
ws: WebSocketUpgrade,
query_params: HashMap<String, String>,
) -> HttpResponse<Body> {
let frame_type: Option<FrameType> = query_params
.get("frame")
.map(|frame_type| Some(frame_type.into()))
.unwrap_or_default();

ws.on_upgrade(move |websocket| {
let (mut tx, mut rx) = websocket.split();

Expand Down Expand Up @@ -383,6 +396,7 @@ impl<D: Dispatcher> Server<D> {
Arc::clone(&inner),
&message,
Some(&multiplex_tx),
frame_type,
)
.await
{
Expand All @@ -409,14 +423,16 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The raw request data.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_raw_request(
inner: Arc<Inner<D>>,
request: &Message,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<Message> {
match serde_json::from_slice(request.clone().into_data().as_ref()) {
Ok(request) => Self::handle_request(inner, request, tx).await,
Ok(request) => Self::handle_request(inner, request, tx, frame_type).await,
Err(_e) => {
log::error!("Received invalid JSON from client");
Some(SingleOrBatch::Single(Response::new_error(
Expand Down Expand Up @@ -447,21 +463,27 @@ impl<D: Dispatcher> Server<D> {
/// - `request`: The request that was received.
/// - `tx`: If the request was received over websocket, this the message queue over which the called function can
/// send notifications to the client (used for subscriptions).
/// - `frame_type`: If the request was received over websocket, indicate whether notifications are send back as Text or Binary frames.
///
async fn handle_request(
inner: Arc<Inner<D>>,
request: SingleOrBatch<Request>,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<SingleOrBatch<Response>> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
.await
.map(|(response, _)| SingleOrBatch::Single(response)),
SingleOrBatch::Single(request) => {
Self::handle_single_request(inner, request, tx, frame_type)
.await
.map(|(response, _)| SingleOrBatch::Single(response))
}

SingleOrBatch::Batch(requests) => {
let futures = requests
.into_iter()
.map(|request| Self::handle_single_request(Arc::clone(&inner), request, tx))
.map(|request| {
Self::handle_single_request(Arc::clone(&inner), request, tx, frame_type)
})
.collect::<FuturesUnordered<_>>();

let responses = futures
Expand All @@ -479,6 +501,7 @@ impl<D: Dispatcher> Server<D> {
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Message>>,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
if request.method == "unsubscribe" {
return Self::handle_unsubscribe_stream(request, inner).await;
Expand All @@ -490,7 +513,7 @@ impl<D: Dispatcher> Server<D> {

log::debug!("request: {:#?}", request);

let response = dispatcher.dispatch(request, tx, id).await;
let response = dispatcher.dispatch(request, tx, id, frame_type).await;

log::debug!("response: {:#?}", response);

Expand Down Expand Up @@ -565,6 +588,7 @@ pub trait Dispatcher: Send + Sync + 'static {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier>;

/// Returns whether a method should be dispatched with this dispatcher.
Expand Down Expand Up @@ -605,13 +629,14 @@ impl Dispatcher for ModularDispatcher {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
for dispatcher in &mut self.dispatchers {
let m = dispatcher.match_method(&request.method);
log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
log::debug!("Methods: {:?}", dispatcher.method_names());
if m {
return dispatcher.dispatch(request, tx, id).await;
return dispatcher.dispatch(request, tx, id, frame_type).await;
}
}

Expand Down Expand Up @@ -674,10 +699,11 @@ where
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
frame_type: Option<FrameType>,
) -> Option<ResponseAndSubScriptionNotifier> {
if self.is_allowed(&request.method) {
log::debug!("Dispatching method: {}", request.method);
self.inner.dispatch(request, tx, id).await
self.inner.dispatch(request, tx, id, frame_type).await
} else {
log::debug!("Method not allowed: {}", request.method);
// If the method is not white-listed, pretend it doesn't exist.
Expand Down Expand Up @@ -833,6 +859,7 @@ async fn forward_notification<T>(
tx: &mut mpsc::Sender<Message>,
id: &SubscriptionId,
method: &str,
frame_type: Option<FrameType>,
) -> Result<(), Error>
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -846,8 +873,12 @@ where

log::debug!("Sending notification: {:?}", notification);

tx.send(Message::binary(serde_json::to_vec(&notification)?))
.await?;
let message = match frame_type {
Some(FrameType::Text) => Message::text(serde_json::to_string(&notification)?),
Some(FrameType::Binary) | None => Message::binary(serde_json::to_vec(&notification)?),
};

tx.send(message).await?;

Ok(())
}
Expand All @@ -871,6 +902,7 @@ pub fn connect_stream<T, S>(
stream_id: u64,
method: String,
notify_handler: Arc<Notify>,
frame_type: Option<FrameType>,
) -> SubscriptionId
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -892,7 +924,7 @@ where
item = stream.next() => {
match item {
Some(notification) => {
if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await {
if let Err(error) = forward_notification(notification, &mut tx, &id, &method, frame_type).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = error {
break;
Expand Down

0 comments on commit ae65618

Please sign in to comment.