From 6dff04b942d5e0cc2e2da06cd5e1f499de183a41 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Fri, 1 Apr 2022 17:45:29 -0400 Subject: [PATCH 1/3] Abstract-out chunked read logic (#11616) Creates a general vchan Client that can read chunked messages. rdpdr and cliprdr::Client's have the vchan::Client as a field. --- lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs | 148 +++++++------------ lib/srv/desktop/rdp/rdpclient/src/lib.rs | 4 +- lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs | 84 ++++++----- lib/srv/desktop/rdp/rdpclient/src/vchan.rs | 51 +++++++ 4 files changed, 149 insertions(+), 138 deletions(-) diff --git a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs index 4c053bc495e53..dee745039eaa8 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs @@ -14,42 +14,26 @@ use super::util; use crate::errors::invalid_data_error; -use crate::vchan::ChannelPDUFlags; use crate::{vchan, Payload}; use bitflags::bitflags; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use num_traits::FromPrimitive; use rdp::core::{mcs, tpkt}; use rdp::model::error::*; -use rdp::try_let; use std::collections::HashMap; use std::collections::VecDeque; -use std::io::{Cursor, Read, Write}; +use std::io::{Read, Write}; pub const CHANNEL_NAME: &str = "cliprdr"; -struct PendingData { - data: Vec, - total_length: u32, - clipboard_header: Option, -} - -impl PendingData { - fn reset(&mut self, length: u32) { - self.data.clear(); - self.total_length = length; - self.clipboard_header = None; - } -} - /// Client implements a client for the clipboard virtual channel /// (CLIPRDR) extension, as defined in: /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpeclip/fb9b7e0b-6db4-41c2-b83c-f889c1ee7688 pub struct Client { clipboard: HashMap>, - pending: PendingData, on_remote_copy: Box)>, incoming_paste_formats: VecDeque, + vchan: vchan::Client, } impl Default for Client { @@ -62,91 +46,58 @@ impl Client { pub fn new(on_remote_copy: Box)>) -> Self { Client { clipboard: HashMap::new(), - pending: PendingData { - data: Vec::new(), - total_length: 0, - clipboard_header: None, - }, on_remote_copy, incoming_paste_formats: VecDeque::new(), + vchan: vchan::Client::new(), } } - - pub fn read( + /// Reads raw RDP messages sent on the cliprdr virtual channel and replies as necessary. + pub fn read_and_reply( &mut self, payload: tpkt::Payload, mcs: &mut mcs::Client, ) -> RdpResult<()> { - let mut payload = try_let!(tpkt::Payload::Raw, payload)?; - let pdu_header = vchan::ChannelPDUHeader::decode(&mut payload)?; - - // TODO(zmb3): this logic is the same for all virtual channels, and should - // be moved to vchan.rs and reused for the rdpdr client as well - if pdu_header - .flags - .contains(ChannelPDUFlags::CHANNEL_FLAG_FIRST) - { - self.pending.reset(pdu_header.length); - self.pending.clipboard_header = Some(ClipboardPDUHeader::decode(&mut payload)?); - } - - payload.read_to_end(&mut self.pending.data)?; - - if pdu_header - .flags - .contains(ChannelPDUFlags::CHANNEL_FLAG_LAST) - && self.pending.clipboard_header.is_some() - { - let full_msg = self.pending.data.split_off(0); - let mut payload = Cursor::new(full_msg); - let header = self.pending.clipboard_header.take().unwrap(); - return self.handle_message(header, &mut payload, mcs); - } + if let Some(mut payload) = self.vchan.read(payload)? { + let header = ClipboardPDUHeader::decode(&mut payload)?; - Ok(()) - } + debug!("received {:?}", header.msg_type); - fn handle_message( - &mut self, - header: ClipboardPDUHeader, - payload: &mut Payload, - mcs: &mut mcs::Client, - ) -> RdpResult<()> { - debug!("received {:?}", header.msg_type); - - let responses = match header.msg_type { - ClipboardPDUType::CB_CLIP_CAPS => self.handle_server_caps(payload)?, - ClipboardPDUType::CB_MONITOR_READY => self.handle_monitor_ready(payload)?, - ClipboardPDUType::CB_FORMAT_LIST => { - self.handle_format_list(payload, header.data_len)? - } - ClipboardPDUType::CB_FORMAT_LIST_RESPONSE => { - self.handle_format_list_response(header.msg_flags)? - } - ClipboardPDUType::CB_FORMAT_DATA_REQUEST => self.handle_format_data_request(payload)?, - ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => { - if header - .msg_flags - .contains(ClipboardHeaderFlags::CB_RESPONSE_OK) - { - self.handle_format_data_response(payload, header.data_len)? - } else { - warn!("RDP server failed to process format data request"); + let responses = match header.msg_type { + ClipboardPDUType::CB_CLIP_CAPS => self.handle_server_caps(&mut payload)?, + ClipboardPDUType::CB_MONITOR_READY => self.handle_monitor_ready(&mut payload)?, + ClipboardPDUType::CB_FORMAT_LIST => { + self.handle_format_list(&mut payload, header.data_len)? + } + ClipboardPDUType::CB_FORMAT_LIST_RESPONSE => { + self.handle_format_list_response(header.msg_flags)? + } + ClipboardPDUType::CB_FORMAT_DATA_REQUEST => { + self.handle_format_data_request(&mut payload)? + } + ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => { + if header + .msg_flags + .contains(ClipboardHeaderFlags::CB_RESPONSE_OK) + { + self.handle_format_data_response(&mut payload, header.data_len)? + } else { + warn!("RDP server failed to process format data request"); + vec![] + } + } + _ => { + warn!( + "CLIPRDR message {:?} not implemented, ignoring", + header.msg_type + ); vec![] } - } - _ => { - warn!( - "CLIPRDR message {:?} not implemented, ignoring", - header.msg_type - ); - vec![] - } - }; + }; - let chan = &CHANNEL_NAME.to_string(); - for resp in responses { - mcs.write(chan, resp)?; + let chan = &CHANNEL_NAME.to_string(); + for resp in responses { + mcs.write(chan, resp)?; + } } Ok(()) @@ -375,6 +326,14 @@ impl ClipboardPDUHeader { } } + fn encode(&self) -> RdpResult> { + let mut w = vec![]; + w.write_u16::(self.msg_type as u16)?; + w.write_u16::(self.msg_flags.bits())?; + w.write_u32::(self.data_len)?; + Ok(w) + } + fn decode(payload: &mut Payload) -> RdpResult { let typ = payload.read_u16::()?; Ok(Self { @@ -385,15 +344,8 @@ impl ClipboardPDUHeader { data_len: payload.read_u32::()?, }) } - - fn encode(&self) -> RdpResult> { - let mut w = vec![]; - w.write_u16::(self.msg_type as u16)?; - w.write_u16::(self.msg_flags.bits())?; - w.write_u32::(self.data_len)?; - Ok(w) - } } + #[derive(Clone, Copy, Debug, Eq, PartialEq, FromPrimitive, ToPrimitive)] #[allow(non_camel_case_types)] enum ClipboardPDUType { diff --git a/lib/srv/desktop/rdp/rdpclient/src/lib.rs b/lib/srv/desktop/rdp/rdpclient/src/lib.rs index 8d3d8ba301565..efd5d6cf51b5c 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/lib.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/lib.rs @@ -287,9 +287,9 @@ impl RdpClient { // name. match channel_name.as_str() { "global" => self.global.read(message, &mut self.mcs, callback), - rdpdr::CHANNEL_NAME => self.rdpdr.read(message, &mut self.mcs), + rdpdr::CHANNEL_NAME => self.rdpdr.read_and_reply(message, &mut self.mcs), cliprdr::CHANNEL_NAME => match self.cliprdr { - Some(ref mut clip) => clip.read(message, &mut self.mcs), + Some(ref mut clip) => clip.read_and_reply(message, &mut self.mcs), None => Ok(()), }, RDPSND_CHANNEL_NAME => { diff --git a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs index abbbe3530fd5e..0d78147db106b 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs @@ -21,7 +21,6 @@ use rdp::core::mcs; use rdp::core::tpkt; use rdp::model::data::Message; use rdp::model::error::*; -use rdp::try_let; use std::io::{Read, Write}; pub const CHANNEL_NAME: &str = "rdpdr"; @@ -31,56 +30,61 @@ pub const CHANNEL_NAME: &str = "rdpdr"; /// /// This client only supports a single smartcard device. pub struct Client { + vchan: vchan::Client, scard: scard::Client, } impl Client { pub fn new(cert_der: Vec, key_der: Vec, pin: String) -> Self { Client { + vchan: vchan::Client::new(), scard: scard::Client::new(cert_der, key_der, pin), } } - pub fn read( + /// Reads raw RDP messages sent on the rdpdr virtual channel and replies as necessary. + pub fn read_and_reply( &mut self, payload: tpkt::Payload, mcs: &mut mcs::Client, ) -> RdpResult<()> { - let mut payload = try_let!(tpkt::Payload::Raw, payload)?; - - // Ignore this, we don't need anything from this header. - let _pdu_header = vchan::ChannelPDUHeader::decode(&mut payload)?; - - let header = Header::decode(&mut payload)?; - if let Component::RDPDR_CTYP_PRN = header.component { - warn!("got {:?} RDPDR header from RDP server, ignoring because we're not redirecting any printers", header); - return Ok(()); - } - let resp = match header.packet_id { - PacketId::PAKID_CORE_SERVER_ANNOUNCE => self.handle_server_announce(&mut payload)?, - PacketId::PAKID_CORE_SERVER_CAPABILITY => { - self.handle_server_capability(&mut payload)? + if let Some(mut payload) = self.vchan.read(payload)? { + let header = SharedHeader::decode(&mut payload)?; + if let Component::RDPDR_CTYP_PRN = header.component { + warn!("got {:?} RDPDR header from RDP server, ignoring because we're not redirecting any printers", header); + return Ok(()); } - PacketId::PAKID_CORE_CLIENTID_CONFIRM => self.handle_client_id_confirm(&mut payload)?, - PacketId::PAKID_CORE_DEVICE_REPLY => self.handle_device_reply(&mut payload)?, - // Device IO request is where communication with the smartcard actually happens. - // Everything up to this point was negotiation and smartcard device registration. - PacketId::PAKID_CORE_DEVICE_IOREQUEST => self.handle_device_io_request(&mut payload)?, - _ => { - // We don't implement the full set of messages. Only the ones necessary for initial - // negotiation and registration of a smartcard device. - error!( - "RDPDR packets {:?} are not implemented yet, ignoring", - header.packet_id - ); - None + let resp = match header.packet_id { + PacketId::PAKID_CORE_SERVER_ANNOUNCE => { + self.handle_server_announce(&mut payload)? + } + PacketId::PAKID_CORE_SERVER_CAPABILITY => { + self.handle_server_capability(&mut payload)? + } + PacketId::PAKID_CORE_CLIENTID_CONFIRM => { + self.handle_client_id_confirm(&mut payload)? + } + PacketId::PAKID_CORE_DEVICE_REPLY => self.handle_device_reply(&mut payload)?, + // Device IO request is where communication with the smartcard actually happens. + // Everything up to this point was negotiation and smartcard device registration. + PacketId::PAKID_CORE_DEVICE_IOREQUEST => { + self.handle_device_io_request(&mut payload)? + } + _ => { + // We don't implement the full set of messages. Only the ones necessary for initial + // negotiation and registration of a smartcard device. + error!( + "RDPDR packets {:?} are not implemented yet, ignoring", + header.packet_id + ); + None + } + }; + + if let Some(resp) = resp { + return mcs.write(&CHANNEL_NAME.to_string(), resp); } - }; - - if let Some(resp) = resp { - Ok(mcs.write(&CHANNEL_NAME.to_string(), resp)?) - } else { - Ok(()) } + Ok(()) } fn handle_server_announce(&self, payload: &mut Payload) -> RdpResult>> { @@ -166,7 +170,7 @@ impl Client { } fn encode_message(packet_id: PacketId, payload: Vec) -> RdpResult> { - let mut inner = Header::new(Component::RDPDR_CTYP_CORE, packet_id).encode()?; + let mut inner = SharedHeader::new(Component::RDPDR_CTYP_CORE, packet_id).encode()?; inner.extend_from_slice(&payload); let mut outer = vchan::ChannelPDUHeader::new( inner.length() as u32, @@ -177,13 +181,17 @@ fn encode_message(packet_id: PacketId, payload: Vec) -> RdpResult> { Ok(outer) } +/// 2.2.1.1 Shared Header (RDPDR_HEADER) +/// This header is present at the beginning of every message in this protocol. +/// The purpose of this header is to describe the type of the message. +/// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpefs/29d4108f-8163-4a67-8271-e48c4b9c2a7c #[derive(Debug)] -struct Header { +struct SharedHeader { component: Component, packet_id: PacketId, } -impl Header { +impl SharedHeader { fn new(component: Component, packet_id: PacketId) -> Self { Self { component, diff --git a/lib/srv/desktop/rdp/rdpclient/src/vchan.rs b/lib/srv/desktop/rdp/rdpclient/src/vchan.rs index 0ab8ce6ecb878..45f0d189c8759 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/vchan.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/vchan.rs @@ -16,7 +16,57 @@ use crate::errors::invalid_data_error; use crate::Payload; use bitflags::bitflags; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use rdp::core::tpkt; use rdp::model::error::*; +use rdp::try_let; +use std::io::{Cursor, Read}; + +/// Client is a general client for handling virtual channel payloads. +/// Its read method can read an RDP message sent in multiple chunks +/// (or a single chunk) over a virtual channel. +/// See https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/343e4888-4c48-4054-b0e3-4e0762d1993c +/// for more information about chunks. +pub struct Client { + data: Vec, +} + +impl Default for Client { + fn default() -> Self { + Self::new() + } +} + +impl Client { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + /// Callers can call read() to process RDP messages (PDUs) sent over a virtual channel. + /// + /// For chunked PDUs, the Client will piece the full PDU together in Client.data over multiple calls, + /// and will only return an Ok(Some(Payload)) once a full message has been pieced together. + /// + /// The Payload will be the raw bytes of the PDU, starting at the channel specific header. + /// For example, if handling a cliprdr PDU, Payload will be a full PDU starting with the + /// CLIPRDR_HEADER structure that's is present in all clipboard PDUs. + /// + /// Returns Ok(None) on interim chunks. + pub fn read(&mut self, raw_payload: tpkt::Payload) -> RdpResult> { + let mut raw_payload = try_let!(tpkt::Payload::Raw, raw_payload)?; + let channel_pdu_header = ChannelPDUHeader::decode(&mut raw_payload)?; + + raw_payload.read_to_end(&mut self.data)?; + + if channel_pdu_header + .flags + .contains(ChannelPDUFlags::CHANNEL_FLAG_LAST) + { + return Ok(Some(Cursor::new(self.data.split_off(0)))); + } + + Ok(None) + } +} /// The default maximum chunk size for virtual channel data. /// @@ -52,6 +102,7 @@ bitflags! { /// transmitted between an RDP client and server. /// /// It is specified in section 2.2.6.1.1 of MS-RDPBCGR. +/// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/f125c65e-6901-43c3-8071-d7d5aaee7ae4 #[derive(Debug)] pub struct ChannelPDUHeader { /// The total length of the uncompressed PDU data, From 495de0ffee647d2c4980a78a93ec9e8ab2fbb421 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Wed, 13 Apr 2022 17:22:08 -0400 Subject: [PATCH 2/3] reuse `vchan::Client` to add the general header and break messages into chunks (#11714) --- lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs | 174 +++++++++---------- lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs | 64 +++---- lib/srv/desktop/rdp/rdpclient/src/util.rs | 10 +- lib/srv/desktop/rdp/rdpclient/src/vchan.rs | 43 +++++ 4 files changed, 164 insertions(+), 127 deletions(-) diff --git a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs index dee745039eaa8..556936d6c99b4 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::util; use crate::errors::invalid_data_error; +use crate::util; use crate::{vchan, Payload}; use bitflags::bitflags; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -132,7 +132,7 @@ impl Client { let (data, format) = encode_clipboard(converted); self.clipboard.insert(format as u32, data); - encode_message( + self.add_headers_and_chunkify( ClipboardPDUType::CB_FORMAT_LIST, FormatListPDU { format_names: vec![LongFormatName::id(format as u32)], @@ -166,7 +166,7 @@ impl Client { // 1. Send our clipboard capabilities // 2. Mimic a "copy" operation by sending a format list PDU // This completes the initialization process. - let mut result = encode_message( + let mut result = self.add_headers_and_chunkify( ClipboardPDUType::CB_CLIP_CAPS, ClipboardCapabilitiesPDU { general: Some(GeneralClipboardCapabilitySet { @@ -176,13 +176,15 @@ impl Client { } .encode()?, )?; - result.extend(encode_message( - ClipboardPDUType::CB_FORMAT_LIST, - FormatListPDU:: { - format_names: vec![LongFormatName::id(0)], - } - .encode()?, - )?); + result.extend( + self.add_headers_and_chunkify( + ClipboardPDUType::CB_FORMAT_LIST, + FormatListPDU:: { + format_names: vec![LongFormatName::id(0)], + } + .encode()?, + )?, + ); Ok(result) } @@ -202,7 +204,8 @@ impl Client { .collect::>(); debug!("{:?} data was copied on the RDP server", formats); - let mut result = encode_message(ClipboardPDUType::CB_FORMAT_LIST_RESPONSE, vec![])?; + let mut result = + self.add_headers_and_chunkify(ClipboardPDUType::CB_FORMAT_LIST_RESPONSE, vec![])?; let request_format; if formats.contains(&(ClipboardFormat::CF_UNICODETEXT as u32)) { @@ -225,7 +228,7 @@ impl Client { self.incoming_paste_formats.push_back(request_format); // request the data by imitating a paste event. - result.extend(encode_message( + result.extend(self.add_headers_and_chunkify( ClipboardPDUType::CB_FORMAT_DATA_REQUEST, FormatDataRequestPDU::for_id(request_format as u32).encode()?, )?); @@ -261,7 +264,7 @@ impl Client { } }; - encode_message( + self.add_headers_and_chunkify( ClipboardPDUType::CB_FORMAT_DATA_RESPONSE, FormatDataResponsePDU { data }.encode()?, ) @@ -291,6 +294,50 @@ impl Client { (self.on_remote_copy)(decoded); Ok(vec![]) } + + /// add_headers_and_chunkify takes an encoded PDU ready to be sent over a virtual channel (payload), + /// adds on the Clipboard PDU Header based the passed msg_type, adds the appropriate (virtual) Channel PDU Header, + /// and splits the entire payload into chunks if the payload exceeds the maximum size. + fn add_headers_and_chunkify( + &self, + msg_type: ClipboardPDUType, + payload: Vec, + ) -> RdpResult>> { + let msg_flags = match msg_type { + // the spec requires 0 for these messages + ClipboardPDUType::CB_CLIP_CAPS => ClipboardHeaderFlags::from_bits_truncate(0), + ClipboardPDUType::CB_TEMP_DIRECTORY => ClipboardHeaderFlags::from_bits_truncate(0), + ClipboardPDUType::CB_LOCK_CLIPDATA => ClipboardHeaderFlags::from_bits_truncate(0), + ClipboardPDUType::CB_UNLOCK_CLIPDATA => ClipboardHeaderFlags::from_bits_truncate(0), + ClipboardPDUType::CB_FORMAT_DATA_REQUEST => ClipboardHeaderFlags::from_bits_truncate(0), + + // assume success for now + ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_OK, + ClipboardPDUType::CB_FORMAT_LIST_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_OK, + + // we don't advertise support for file transfers, so the server should never send this, + // but if it does, ensure the response indicates a failure + ClipboardPDUType::CB_FILECONTENTS_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_FAIL, + + _ => ClipboardHeaderFlags::from_bits_truncate(0), + }; + + let channel_flags = match msg_type { + ClipboardPDUType::CB_FORMAT_LIST + | ClipboardPDUType::CB_CLIP_CAPS + | ClipboardPDUType::CB_FORMAT_DATA_REQUEST + | ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => { + Some(vchan::ChannelPDUFlags::CHANNEL_FLAG_SHOW_PROTOCOL) + } + _ => None, + }; + + let mut inner = + ClipboardPDUHeader::new(msg_type, msg_flags, payload.len() as u32).encode()?; + inner.extend(payload); + + self.vchan.add_header_and_chunkify(channel_flags, inner) + } } bitflags! { @@ -519,7 +566,7 @@ fn encode_clipboard(mut data: String) -> (Vec, ClipboardFormat) { (data.into_bytes(), ClipboardFormat::CF_TEXT) } else { - let encoded = util::to_nul_terminated_utf16le(&data); + let encoded = util::to_unicode(&data); (encoded, ClipboardFormat::CF_UNICODETEXT) } } @@ -631,7 +678,7 @@ impl FormatName for LongFormatName { // must be encoded as a single Unicode null character (two zero bytes) None => w.write_u16::(0)?, Some(name) => { - w.append(&mut util::to_nul_terminated_utf16le(name)); + w.append(&mut util::to_unicode(name)); } }; @@ -758,69 +805,6 @@ impl FormatDataResponsePDU { } } -/// encode_message encodes a message by wrapping it in the appropriate -/// channel header. If the payload exceeds the maximum size, the message -/// is split into multiple messages. -fn encode_message(msg_type: ClipboardPDUType, payload: Vec) -> RdpResult>> { - let msg_flags = match msg_type { - // the spec requires 0 for these messages - ClipboardPDUType::CB_CLIP_CAPS => ClipboardHeaderFlags::from_bits_truncate(0), - ClipboardPDUType::CB_TEMP_DIRECTORY => ClipboardHeaderFlags::from_bits_truncate(0), - ClipboardPDUType::CB_LOCK_CLIPDATA => ClipboardHeaderFlags::from_bits_truncate(0), - ClipboardPDUType::CB_UNLOCK_CLIPDATA => ClipboardHeaderFlags::from_bits_truncate(0), - ClipboardPDUType::CB_FORMAT_DATA_REQUEST => ClipboardHeaderFlags::from_bits_truncate(0), - - // assume success for now - ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_OK, - ClipboardPDUType::CB_FORMAT_LIST_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_OK, - - // we don't advertise support for file transfers, so the server should never send this, - // but if it does, ensure the response indicates a failure - ClipboardPDUType::CB_FILECONTENTS_RESPONSE => ClipboardHeaderFlags::CB_RESPONSE_FAIL, - - _ => ClipboardHeaderFlags::from_bits_truncate(0), - }; - let mut inner = ClipboardPDUHeader::new(msg_type, msg_flags, payload.len() as u32).encode()?; - inner.extend(payload); - let total_len = inner.len() as u32; - - let mut result = Vec::new(); - let mut first = true; - while !inner.is_empty() { - let i = std::cmp::min(inner.len(), vchan::CHANNEL_CHUNK_LEGNTH); - let leftover = inner.split_off(i); - - let mut channel_flags = match msg_type { - ClipboardPDUType::CB_FORMAT_LIST - | ClipboardPDUType::CB_CLIP_CAPS - | ClipboardPDUType::CB_FORMAT_DATA_REQUEST - | ClipboardPDUType::CB_FORMAT_DATA_RESPONSE => { - vchan::ChannelPDUFlags::CHANNEL_FLAG_SHOW_PROTOCOL - } - _ => vchan::ChannelPDUFlags::from_bits_truncate(0), - }; - - if first { - channel_flags.set(vchan::ChannelPDUFlags::CHANNEL_FLAG_FIRST, true); - first = false; - } - if leftover.is_empty() { - channel_flags.set(vchan::ChannelPDUFlags::CHANNEL_FLAG_LAST, true); - } - - // the Channel PDU Header always specifies the *total length* of the PDU, - // even if it has to be split into multpile chunks: - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a542bf19-1c86-4c80-ab3e-61449653abf6 - let mut outer = vchan::ChannelPDUHeader::new(total_len, channel_flags).encode()?; - outer.extend(inner); - result.push(outer); - - inner = leftover; - } - - Ok(result) -} - #[cfg(test)] mod tests { use crate::vchan::ChannelPDUFlags; @@ -831,15 +815,17 @@ mod tests { #[test] fn encode_format_list_short() { - let msg = encode_message( - ClipboardPDUType::CB_FORMAT_LIST, - FormatListPDU { - format_names: vec![ShortFormatName::id(ClipboardFormat::CF_TEXT as u32)], - } - .encode() - .unwrap(), - ) - .unwrap(); + let client = Client::default(); + let msg = client + .add_headers_and_chunkify( + ClipboardPDUType::CB_FORMAT_LIST, + FormatListPDU { + format_names: vec![ShortFormatName::id(ClipboardFormat::CF_TEXT as u32)], + } + .encode() + .unwrap(), + ) + .unwrap(); assert_eq!( msg[0], @@ -867,8 +853,11 @@ mod tests { format_names: vec![LongFormatName::id(0)], }; - let encoded = - encode_message(ClipboardPDUType::CB_FORMAT_LIST, empty.encode().unwrap()).unwrap(); + let client = Client::default(); + + let encoded = client + .add_headers_and_chunkify(ClipboardPDUType::CB_FORMAT_LIST, empty.encode().unwrap()) + .unwrap(); assert_eq!( encoded[0], @@ -1027,7 +1016,10 @@ mod tests { } let pdu = FormatDataResponsePDU { data }; let encoded = pdu.encode().unwrap(); - let messages = encode_message(ClipboardPDUType::CB_FORMAT_DATA_RESPONSE, encoded).unwrap(); + let client = Client::default(); + let messages = client + .add_headers_and_chunkify(ClipboardPDUType::CB_FORMAT_DATA_RESPONSE, encoded) + .unwrap(); assert_eq!(2, messages.len()); let header0 = @@ -1047,7 +1039,7 @@ mod tests { #[test] fn responds_to_format_data_request_hasdata() { // a null-terminated utf-16 string, represented as a Vec - let test_data = util::to_nul_terminated_utf16le("test"); + let test_data = util::to_unicode("test"); let mut c: Client = Default::default(); c.clipboard diff --git a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs index 0d78147db106b..6c3e37195e8d6 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs @@ -53,7 +53,7 @@ impl Client { warn!("got {:?} RDPDR header from RDP server, ignoring because we're not redirecting any printers", header); return Ok(()); } - let resp = match header.packet_id { + let responses = match header.packet_id { PacketId::PAKID_CORE_SERVER_ANNOUNCE => { self.handle_server_announce(&mut payload)? } @@ -76,54 +76,55 @@ impl Client { "RDPDR packets {:?} are not implemented yet, ignoring", header.packet_id ); - None + vec![] } }; - if let Some(resp) = resp { - return mcs.write(&CHANNEL_NAME.to_string(), resp); + let chan = &CHANNEL_NAME.to_string(); + for resp in responses { + mcs.write(chan, resp)?; } } Ok(()) } - fn handle_server_announce(&self, payload: &mut Payload) -> RdpResult>> { + fn handle_server_announce(&self, payload: &mut Payload) -> RdpResult>> { let req = ServerAnnounceRequest::decode(payload)?; debug!("got ServerAnnounceRequest {:?}", req); - let resp = encode_message( + let resp = self.add_headers_and_chunkify( PacketId::PAKID_CORE_CLIENTID_CONFIRM, ClientAnnounceReply::new(req).encode()?, )?; debug!("sending client announce reply"); - Ok(Some(resp)) + Ok(resp) } - fn handle_server_capability(&self, payload: &mut Payload) -> RdpResult>> { + fn handle_server_capability(&self, payload: &mut Payload) -> RdpResult>> { let req = ServerCoreCapabilityRequest::decode(payload)?; debug!("got {:?}", req); - let resp = encode_message( + let resp = self.add_headers_and_chunkify( PacketId::PAKID_CORE_CLIENT_CAPABILITY, ClientCoreCapabilityResponse::new_response().encode()?, )?; debug!("sending client core capability response"); - Ok(Some(resp)) + Ok(resp) } - fn handle_client_id_confirm(&self, payload: &mut Payload) -> RdpResult>> { + fn handle_client_id_confirm(&self, payload: &mut Payload) -> RdpResult>> { let req = ServerClientIdConfirm::decode(payload)?; debug!("got ServerClientIdConfirm {:?}", req); - let resp = encode_message( + let resp = self.add_headers_and_chunkify( PacketId::PAKID_CORE_DEVICELIST_ANNOUNCE, ClientDeviceListAnnounceRequest::new_smartcard().encode()?, )?; debug!("sending client device list announce request"); - Ok(Some(resp)) + Ok(resp) } - fn handle_device_reply(&self, payload: &mut Payload) -> RdpResult>> { + fn handle_device_reply(&self, payload: &mut Payload) -> RdpResult>> { let req = ServerDeviceAnnounceResponse::decode(payload)?; debug!("got {:?}", req); @@ -138,11 +139,11 @@ impl Client { &req.result_code ))) } else { - Ok(None) + Ok(vec![]) } } - fn handle_device_io_request(&mut self, payload: &mut Payload) -> RdpResult>> { + fn handle_device_io_request(&mut self, payload: &mut Payload) -> RdpResult>> { let req = DeviceIoRequest::decode(payload)?; debug!("got {:?}", req); @@ -152,14 +153,14 @@ impl Client { let (code, res) = self.scard.ioctl(ioctl.io_control_code, payload)?; if code == SPECIAL_NO_RESPONSE { - return Ok(None); + return Ok(vec![]); } - let resp = encode_message( + let resp = self.add_headers_and_chunkify( PacketId::PAKID_CORE_DEVICE_IOCOMPLETION, DeviceControlResponse::new(&ioctl, code, res).encode()?, )?; debug!("sending device IO response"); - Ok(Some(resp)) + Ok(resp) } else { Err(invalid_data_error(&format!( "got unsupported major_function in DeviceIoRequest: {:?}", @@ -167,22 +168,23 @@ impl Client { ))) } } -} -fn encode_message(packet_id: PacketId, payload: Vec) -> RdpResult> { - let mut inner = SharedHeader::new(Component::RDPDR_CTYP_CORE, packet_id).encode()?; - inner.extend_from_slice(&payload); - let mut outer = vchan::ChannelPDUHeader::new( - inner.length() as u32, - vchan::ChannelPDUFlags::CHANNEL_FLAG_ONLY, - ) - .encode()?; - outer.extend_from_slice(&inner); - Ok(outer) + /// add_headers_and_chunkify takes an encoded PDU ready to be sent over a virtual channel (payload), + /// adds on the Shared Header based the passed packet_id, adds the appropriate (virtual) Channel PDU Header, + /// and splits the entire payload into chunks if the payload exceeds the maximum size. + fn add_headers_and_chunkify( + &self, + packet_id: PacketId, + payload: Vec, + ) -> RdpResult>> { + let mut inner = SharedHeader::new(Component::RDPDR_CTYP_CORE, packet_id).encode()?; + inner.extend_from_slice(&payload); + self.vchan.add_header_and_chunkify(None, inner) + } } /// 2.2.1.1 Shared Header (RDPDR_HEADER) -/// This header is present at the beginning of every message in this protocol. +/// This header is present at the beginning of every message in sent over the rdpdr virtual channel. /// The purpose of this header is to describe the type of the message. /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpefs/29d4108f-8163-4a67-8271-e48c4b9c2a7c #[derive(Debug)] diff --git a/lib/srv/desktop/rdp/rdpclient/src/util.rs b/lib/srv/desktop/rdp/rdpclient/src/util.rs index 087220ea1cd24..1f412b8721988 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/util.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/util.rs @@ -20,9 +20,9 @@ /// UTF-16LE encoded Vec, which is useful in cases where we want /// to handle some data in the code as a &str (or String), and later /// convert it to RDP's preferred format and send it over the wire. -pub fn to_nul_terminated_utf16le(s: &str) -> Vec { - s.encode_utf16() - .chain([0]) - .flat_map(|v| v.to_le_bytes()) - .collect() +pub fn to_unicode(s: &str) -> Vec { + let mut buf: Vec = s.encode_utf16().flat_map(|v| v.to_le_bytes()).collect(); + let mut null_terminator: Vec = vec![0, 0]; + buf.append(&mut null_terminator); + buf } diff --git a/lib/srv/desktop/rdp/rdpclient/src/vchan.rs b/lib/srv/desktop/rdp/rdpclient/src/vchan.rs index 45f0d189c8759..a22f9b7bcff91 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/vchan.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/vchan.rs @@ -66,6 +66,49 @@ impl Client { Ok(None) } + + /// add_header_and_chunkify takes an encoded PDU ready to be sent over a virtual channel (payload), + /// adds the appropriate (virtual) Channel PDU Header, and splits it into chunks if the payload exceeds + /// the maximum size. The caller may optionally provide any any non-chunk-related Channel PDU Header + /// flags that should be set. "Non-chunk-related" means any flags besides CHANNEL_FLAG_FIRST and CHANNEL_FLAG_LAST, which + /// are handled by this function automatically. + pub fn add_header_and_chunkify( + &self, + channel_flags: Option, + payload: Vec, + ) -> RdpResult>> { + let mut inner = payload; + let total_len = inner.len() as u32; + + let mut result = Vec::new(); + let mut first = true; + while !inner.is_empty() { + let i = std::cmp::min(inner.len(), CHANNEL_CHUNK_LEGNTH); + let leftover = inner.split_off(i); + + let mut channel_flags = + channel_flags.unwrap_or_else(|| ChannelPDUFlags::from_bits_truncate(0)); + + if first { + channel_flags.set(ChannelPDUFlags::CHANNEL_FLAG_FIRST, true); + first = false; + } + if leftover.is_empty() { + channel_flags.set(ChannelPDUFlags::CHANNEL_FLAG_LAST, true); + } + + // the Channel PDU Header always specifies the *total length* of the PDU, + // even if it has to be split into multpile chunks: + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a542bf19-1c86-4c80-ab3e-61449653abf6 + let mut outer = ChannelPDUHeader::new(total_len, channel_flags).encode()?; + outer.extend(inner); + result.push(outer); + + inner = leftover; + } + + Ok(result) + } } /// The default maximum chunk size for virtual channel data. From 8725ce03f72610da19d540e86f3a06ae07170c82 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Mon, 9 May 2022 17:02:25 -0400 Subject: [PATCH 3/3] `CGOError` --> `CGOErrCode` (#12499) --- lib/srv/desktop/rdp/rdpclient/client.go | 71 +++++------- lib/srv/desktop/rdp/rdpclient/librdprs.h | 28 +++-- lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs | 9 +- lib/srv/desktop/rdp/rdpclient/src/errors.rs | 4 + lib/srv/desktop/rdp/rdpclient/src/lib.rs | 116 ++++++++++--------- 5 files changed, 107 insertions(+), 121 deletions(-) diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index 9d1187c05ff13..9d7a147ee112b 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -66,7 +66,6 @@ import "C" import ( "context" "errors" - "fmt" "image" "io" "os" @@ -242,8 +241,8 @@ func (c *Client) connect(ctx context.Context) error { C.uint16_t(c.clientHeight), C.bool(c.cfg.AllowClipboard), ) - if err := cgoError(res.err); err != nil { - return trace.Wrap(err) + if res.err != C.ErrCodeSuccess { + return trace.ConnectionProblem(nil, "RDP connection failed") } c.rustClient = res.client return nil @@ -260,7 +259,7 @@ func (c *Client) start() { // C.read_rdp_output blocks for the duration of the RDP connection and // calls handle_bitmap repeatedly with the incoming bitmaps. - if err := cgoError(C.read_rdp_output(c.rustClient)); err != nil { + if err := C.read_rdp_output(c.rustClient); err != C.ErrCodeSuccess { c.cfg.Log.Warningf("Failed reading RDP output frame: %v", err) // close the TDP connection to the browser @@ -297,7 +296,7 @@ func (c *Client) start() { switch m := msg.(type) { case tdp.MouseMove: mouseX, mouseY = m.X, m.Y - if err := cgoError(C.write_rdp_pointer( + if err := C.write_rdp_pointer( c.rustClient, C.CGOMousePointerEvent{ x: C.uint16_t(m.X), @@ -305,8 +304,7 @@ func (c *Client) start() { button: C.PointerButtonNone, wheel: C.PointerWheelNone, }, - )); err != nil { - c.cfg.Log.Warningf("Failed forwarding RDP mouse pointer: %v", err) + ); err != C.ErrCodeSuccess { return } case tdp.MouseButton: @@ -322,7 +320,7 @@ func (c *Client) start() { default: button = C.PointerButtonNone } - if err := cgoError(C.write_rdp_pointer( + if err := C.write_rdp_pointer( c.rustClient, C.CGOMousePointerEvent{ x: C.uint16_t(mouseX), @@ -331,8 +329,7 @@ func (c *Client) start() { down: m.State == tdp.ButtonPressed, wheel: C.PointerWheelNone, }, - )); err != nil { - c.cfg.Log.Warningf("Failed forwarding RDP mouse button: %v", err) + ); err != C.ErrCodeSuccess { return } case tdp.MouseWheel: @@ -351,7 +348,7 @@ func (c *Client) start() { default: wheel = C.PointerWheelNone } - if err := cgoError(C.write_rdp_pointer( + if err := C.write_rdp_pointer( c.rustClient, C.CGOMousePointerEvent{ x: C.uint16_t(mouseX), @@ -360,29 +357,26 @@ func (c *Client) start() { wheel: uint32(wheel), wheel_delta: C.int16_t(m.Delta), }, - )); err != nil { - c.cfg.Log.Warningf("Failed forwarding RDP mouse wheel: %v", err) + ); err != C.ErrCodeSuccess { return } case tdp.KeyboardButton: - if err := cgoError(C.write_rdp_keyboard( + if err := C.write_rdp_keyboard( c.rustClient, C.CGOKeyboardEvent{ code: C.uint16_t(m.KeyCode), down: m.State == tdp.ButtonPressed, }, - )); err != nil { - c.cfg.Log.Warningf("Failed forwarding RDP key press: %v", err) + ); err != C.ErrCodeSuccess { return } case tdp.ClipboardData: if len(m) > 0 { - if err := cgoError(C.update_clipboard( + if err := C.update_clipboard( c.rustClient, (*C.uint8_t)(unsafe.Pointer(&m[0])), C.uint32_t(len(m)), - )); err != nil { - c.cfg.Log.Warningf("Failed forwarding RDP clipboard data: %v", err) + ); err != C.ErrCodeSuccess { return } } else { @@ -396,11 +390,11 @@ func (c *Client) start() { } //export handle_bitmap -func handle_bitmap(handle C.uintptr_t, cb *C.CGOBitmap) C.CGOError { +func handle_bitmap(handle C.uintptr_t, cb *C.CGOBitmap) C.CGOErrCode { return cgo.Handle(handle).Value().(*Client).handleBitmap(cb) } -func (c *Client) handleBitmap(cb *C.CGOBitmap) C.CGOError { +func (c *Client) handleBitmap(cb *C.CGOBitmap) C.CGOErrCode { // Notify the input forwarding goroutine that we're ready for input. // Input can only be sent after connection was established, which we infer // from the fact that a bitmap was sent. @@ -429,26 +423,28 @@ func (c *Client) handleBitmap(cb *C.CGOBitmap) C.CGOError { copy(img.Pix, data) if err := c.cfg.Conn.OutputMessage(tdp.NewPNG(img, c.cfg.Encoder)); err != nil { - return C.CString(fmt.Sprintf("failed to send PNG frame %v: %v", img.Rect, err)) + c.cfg.Log.Errorf("failed to send PNG frame %v: %v", img.Rect, err) + return C.ErrCodeFailure } - return nil + return C.ErrCodeSuccess } //export handle_remote_copy -func handle_remote_copy(handle C.uintptr_t, data *C.uint8_t, length C.uint32_t) C.CGOError { +func handle_remote_copy(handle C.uintptr_t, data *C.uint8_t, length C.uint32_t) C.CGOErrCode { goData := C.GoBytes(unsafe.Pointer(data), C.int(length)) return cgo.Handle(handle).Value().(*Client).handleRemoteCopy(goData) } // handleRemoteCopy is called from Rust when data is copied // on the remote desktop -func (c *Client) handleRemoteCopy(data []byte) C.CGOError { +func (c *Client) handleRemoteCopy(data []byte) C.CGOErrCode { c.cfg.Log.Debugf("Received %d bytes of clipboard data from Windows desktop", len(data)) if err := c.cfg.Conn.OutputMessage(tdp.ClipboardData(data)); err != nil { - return C.CString(fmt.Sprintf("failed to send clipboard data: %v", err)) + c.cfg.Log.Errorf("failed handling remote copy: %v", err) + return C.ErrCodeFailure } - return nil + return C.ErrCodeSuccess } // close frees the memory of the cgo.Handle, @@ -456,8 +452,9 @@ func (c *Client) handleRemoteCopy(data []byte) C.CGOError { // and frees the Rust client. func (c *Client) close() { c.closeOnce.Do(func() { - // Close the RDP client - if err := cgoError(C.close_rdp(c.rustClient)); err != nil { + c.handle.Delete() + + if err := C.close_rdp(c.rustClient); err != C.ErrCodeSuccess { c.cfg.Log.Warningf("failed to close the RDP client") } @@ -484,19 +481,3 @@ func (c *Client) UpdateClientActivity() { c.clientLastActive = time.Now().UTC() c.clientActivityMu.Unlock() } - -// cgoError converts from a CGO-originated error to a Go error, copying the -// error string and releasing the CGO data. -func cgoError(s C.CGOError) error { - if s == nil { - return nil - } - gs := C.GoString(s) - C.free_rust_string(s) - return errors.New(gs) -} - -//export free_go_string -func free_go_string(s *C.char) { - C.free(unsafe.Pointer(s)) -} diff --git a/lib/srv/desktop/rdp/rdpclient/librdprs.h b/lib/srv/desktop/rdp/rdpclient/librdprs.h index e31462e98dc69..a8c6278d77a18 100644 --- a/lib/srv/desktop/rdp/rdpclient/librdprs.h +++ b/lib/srv/desktop/rdp/rdpclient/librdprs.h @@ -20,6 +20,11 @@ */ #define CHANNEL_CHUNK_LEGNTH 1600 +typedef enum CGOErrCode { + ErrCodeSuccess = 0, + ErrCodeFailure = 1, +} CGOErrCode; + typedef enum CGOPointerButton { PointerButtonNone, PointerButtonLeft, @@ -46,14 +51,9 @@ typedef enum CGOPointerWheel { */ typedef struct Client Client; -/** - * CGOError is an alias for a C string pointer, for C API clarity. - */ -typedef char *CGOError; - typedef struct ClientOrError { struct Client *client; - CGOError err; + enum CGOErrCode err; } ClientOrError; /** @@ -126,7 +126,7 @@ struct ClientOrError connect_rdp(uintptr_t go_ref, * * `client_ptr` must be a valid pointer to a Client. */ -CGOError update_clipboard(struct Client *client_ptr, uint8_t *data, uint32_t len); +enum CGOErrCode update_clipboard(struct Client *client_ptr, uint8_t *data, uint32_t len); /** * `read_rdp_output` reads incoming RDP bitmap frames from client at client_ref and forwards them to @@ -137,28 +137,28 @@ CGOError update_clipboard(struct Client *client_ptr, uint8_t *data, uint32_t len * `client_ptr` must be a valid pointer to a Client. * `handle_bitmap` *must not* free the memory of CGOBitmap. */ -CGOError read_rdp_output(struct Client *client_ptr); +enum CGOErrCode read_rdp_output(struct Client *client_ptr); /** * # Safety * * client_ptr must be a valid pointer to a Client. */ -CGOError write_rdp_pointer(struct Client *client_ptr, struct CGOMousePointerEvent pointer); +enum CGOErrCode write_rdp_pointer(struct Client *client_ptr, struct CGOMousePointerEvent pointer); /** * # Safety * * client_ptr must be a valid pointer to a Client. */ -CGOError write_rdp_keyboard(struct Client *client_ptr, struct CGOKeyboardEvent key); +enum CGOErrCode write_rdp_keyboard(struct Client *client_ptr, struct CGOKeyboardEvent key); /** * # Safety * * client_ptr must be a valid pointer to a Client. */ -CGOError close_rdp(struct Client *client_ptr); +enum CGOErrCode close_rdp(struct Client *client_ptr); /** * free_rdp lets the Go side inform us when it's done with Client and it can be dropped. @@ -176,8 +176,6 @@ void free_rdp(struct Client *client_ptr); */ void free_rust_string(char *s); -extern void free_go_string(char *s); - -extern CGOError handle_bitmap(uintptr_t client_ref, struct CGOBitmap *b); +extern enum CGOErrCode handle_bitmap(uintptr_t client_ref, struct CGOBitmap *b); -extern CGOError handle_remote_copy(uintptr_t client_ref, uint8_t *data, uint32_t len); +extern enum CGOErrCode handle_remote_copy(uintptr_t client_ref, uint8_t *data, uint32_t len); diff --git a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs index 556936d6c99b4..3f046b6e87e4b 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/cliprdr.rs @@ -31,19 +31,19 @@ pub const CHANNEL_NAME: &str = "cliprdr"; /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpeclip/fb9b7e0b-6db4-41c2-b83c-f889c1ee7688 pub struct Client { clipboard: HashMap>, - on_remote_copy: Box)>, incoming_paste_formats: VecDeque, + on_remote_copy: Box) -> RdpResult<()>>, vchan: vchan::Client, } impl Default for Client { fn default() -> Self { - Self::new(Box::new(|_| {})) + Self::new(Box::new(|_| Ok(()))) } } impl Client { - pub fn new(on_remote_copy: Box)>) -> Self { + pub fn new(on_remote_copy: Box) -> RdpResult<()>>) -> Self { Client { clipboard: HashMap::new(), on_remote_copy, @@ -291,7 +291,7 @@ impl Client { ); let decoded = decode_clipboard(resp.data, format)?; - (self.on_remote_copy)(decoded); + (self.on_remote_copy)(decoded)?; Ok(vec![]) } @@ -1068,6 +1068,7 @@ mod tests { let mut c = Client::new(Box::new(move |vec| { send.send(vec).unwrap(); + Ok(()) })); let data_format_list = FormatListPDU { diff --git a/lib/srv/desktop/rdp/rdpclient/src/errors.rs b/lib/srv/desktop/rdp/rdpclient/src/errors.rs index ac18c442a4cc9..7e1ec566a19bf 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/errors.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/errors.rs @@ -19,6 +19,10 @@ pub fn invalid_data_error(msg: &str) -> Error { Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, msg)) } +pub fn try_error(msg: &str) -> Error { + Error::TryError(msg.to_string()) +} + // NTSTATUS_OK is a Windows NTStatus value that means "success". pub const NTSTATUS_OK: u32 = 0; // SPECIAL_NO_RESPONSE is our custom (not defined by Windows) NTStatus value that means "don't send diff --git a/lib/srv/desktop/rdp/rdpclient/src/lib.rs b/lib/srv/desktop/rdp/rdpclient/src/lib.rs index efd5d6cf51b5c..178a33112be91 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/lib.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/lib.rs @@ -72,10 +72,13 @@ impl Client { fn into_raw(self: Box) -> *mut Self { Box::into_raw(self) } - unsafe fn from_ptr<'a>(ptr: *const Self) -> Result<&'a Client, CGOError> { + unsafe fn from_ptr<'a>(ptr: *const Self) -> Result<&'a Client, CGOErrCode> { match ptr.as_ref() { Some(c) => Ok(c), - None => Err(to_cgo_error("invalid Rust client pointer".to_string())), + None => { + error!("invalid Rust client pointer"); + Err(CGOErrCode::ErrCodeFailure) + } } } unsafe fn from_raw(ptr: *mut Self) -> Box { @@ -86,7 +89,7 @@ impl Client { #[repr(C)] pub struct ClientOrError { client: *mut Client, - err: CGOError, + err: CGOErrCode, } impl From> for ClientOrError { @@ -94,12 +97,15 @@ impl From> for ClientOrError { match r { Ok(client) => ClientOrError { client: Box::new(client).into_raw(), - err: CGO_OK, - }, - Err(e) => ClientOrError { - client: ptr::null_mut(), - err: to_cgo_error(format!("{:?}", e)), + err: CGOErrCode::ErrCodeSuccess, }, + Err(e) => { + error!("{:?}", e); + ClientOrError { + client: ptr::null_mut(), + err: CGOErrCode::ErrCodeFailure, + } + } } } } @@ -248,8 +254,15 @@ fn connect_rdp_inner( // Client for the "cliprdr" channel - clipboard sharing. let cliprdr = if params.allow_clipboard { - Some(cliprdr::Client::new(Box::new(move |v| unsafe { - handle_remote_copy(go_ref, v.as_ptr() as _, v.len() as u32); + Some(cliprdr::Client::new(Box::new(move |v| -> RdpResult<()> { + unsafe { + if handle_remote_copy(go_ref, v.as_ptr() as _, v.len() as u32) + != CGOErrCode::ErrCodeSuccess + { + return Err(errors::try_error("failed to handle remote copy")); + } + } + Ok(()) }))) } else { None @@ -406,7 +419,7 @@ pub unsafe extern "C" fn update_clipboard( client_ptr: *mut Client, data: *mut u8, len: u32, -) -> CGOError { +) -> CGOErrCode { let client = match Client::from_ptr(client_ptr) { Ok(client) => client, Err(cgo_error) => { @@ -423,17 +436,18 @@ pub unsafe extern "C" fn update_clipboard( Ok(messages) => { for message in messages { if let Err(e) = lock.mcs.write(&cliprdr::CHANNEL_NAME.to_string(), message) { - return to_cgo_error(format!( - "failed writing cliprdr format list: {:?}", - e - )); + error!("failed writing cliprdr format list: {:?}", e); + return CGOErrCode::ErrCodeFailure; } } - CGO_OK + CGOErrCode::ErrCodeSuccess + } + Err(e) => { + error!("failed updating clipboard: {:?}", e); + CGOErrCode::ErrCodeFailure } - Err(e) => to_cgo_error(format!("failed updating clipboard: {:?}", e)), }, - None => CGO_OK, + None => CGOErrCode::ErrCodeSuccess, } } @@ -445,7 +459,7 @@ pub unsafe extern "C" fn update_clipboard( /// `client_ptr` must be a valid pointer to a Client. /// `handle_bitmap` *must not* free the memory of CGOBitmap. #[no_mangle] -pub unsafe extern "C" fn read_rdp_output(client_ptr: *mut Client) -> CGOError { +pub unsafe extern "C" fn read_rdp_output(client_ptr: *mut Client) -> CGOErrCode { let client = match Client::from_ptr(client_ptr) { Ok(client) => client, Err(cgo_error) => { @@ -453,9 +467,10 @@ pub unsafe extern "C" fn read_rdp_output(client_ptr: *mut Client) -> CGOError { } }; if let Some(err) = read_rdp_output_inner(client) { - to_cgo_error(err) + error!("{}", err); + CGOErrCode::ErrCodeFailure } else { - CGO_OK + CGOErrCode::ErrCodeSuccess } } @@ -467,7 +482,7 @@ fn read_rdp_output_inner(client: &Client) -> Option { // Wait for some data to be available on the TCP socket FD before consuming it. This prevents // us from locking the mutex in Client permanently while no data is available. while wait_for_fd(tcp_fd as usize) { - let mut err = CGO_OK; + let mut err = CGOErrCode::ErrCodeSuccess; let res = client .rdp_client .lock() @@ -485,7 +500,7 @@ fn read_rdp_output_inner(client: &Client) -> Option { } }; unsafe { - err = handle_bitmap(client_ref, &mut cbitmap) as CGOError; + err = handle_bitmap(client_ref, &mut cbitmap) as CGOErrCode; }; } // These should never really be sent by the server to us. @@ -503,9 +518,8 @@ fn read_rdp_output_inner(client: &Client) -> Option { } _ => {} } - if err != CGO_OK { - let err_str = unsafe { from_cgo_error(err) }; - return Some(format!("failed forwarding RDP bitmap frame: {}", err_str)); + if err != CGOErrCode::ErrCodeSuccess { + return Some("failed forwarding RDP bitmap frame".to_string()); } } None @@ -570,7 +584,7 @@ impl From for PointerEvent { pub unsafe extern "C" fn write_rdp_pointer( client_ptr: *mut Client, pointer: CGOMousePointerEvent, -) -> CGOError { +) -> CGOErrCode { let client = match Client::from_ptr(client_ptr) { Ok(client) => client, Err(cgo_error) => { @@ -584,9 +598,10 @@ pub unsafe extern "C" fn write_rdp_pointer( .write(RdpEvent::Pointer(pointer.into())); if let Err(e) = res { - to_cgo_error(format!("failed writing RDP pointer event: {:?}", e)) + error!("failed writing RDP pointer event: {:?}", e); + CGOErrCode::ErrCodeFailure } else { - CGO_OK + CGOErrCode::ErrCodeSuccess } } @@ -618,7 +633,7 @@ impl From for KeyboardEvent { pub unsafe extern "C" fn write_rdp_keyboard( client_ptr: *mut Client, key: CGOKeyboardEvent, -) -> CGOError { +) -> CGOErrCode { let client = match Client::from_ptr(client_ptr) { Ok(client) => client, Err(cgo_error) => { @@ -631,9 +646,10 @@ pub unsafe extern "C" fn write_rdp_keyboard( .unwrap() .write(RdpEvent::Key(key.into())); if let Err(e) = res { - to_cgo_error(format!("failed writing RDP keyboard event: {:?}", e)) + error!("failed writing RDP keyboard event: {:?}", e); + CGOErrCode::ErrCodeFailure } else { - CGO_OK + CGOErrCode::ErrCodeSuccess } } @@ -641,7 +657,7 @@ pub unsafe extern "C" fn write_rdp_keyboard( /// /// client_ptr must be a valid pointer to a Client. #[no_mangle] -pub unsafe extern "C" fn close_rdp(client_ptr: *mut Client) -> CGOError { +pub unsafe extern "C" fn close_rdp(client_ptr: *mut Client) -> CGOErrCode { let client = match Client::from_ptr(client_ptr) { Ok(client) => client, Err(cgo_error) => { @@ -649,9 +665,10 @@ pub unsafe extern "C" fn close_rdp(client_ptr: *mut Client) -> CGOError { } }; if let Err(e) = client.rdp_client.lock().unwrap().shutdown() { - to_cgo_error(format!("failed writing RDP keyboard event: {:?}", e)) + error!("failed writing RDP keyboard event: {:?}", e); + CGOErrCode::ErrCodeFailure } else { - CGO_OK + CGOErrCode::ErrCodeSuccess } } @@ -687,33 +704,18 @@ unsafe fn from_go_array(len: u32, ptr: *mut u8) -> Vec { slice::from_raw_parts(ptr, len as usize).to_vec() } -/// CGOError is an alias for a C string pointer, for C API clarity. -pub type CGOError = *mut c_char; - -/// CGO_OK is a CGOError value that means "success". -const CGO_OK: CGOError = ptr::null_mut(); - -fn to_cgo_error(s: String) -> CGOError { - CString::new(s).expect("CString::new failed").into_raw() -} - -/// from_cgo_error copies CGOError into a String and frees the underlying Go memory. -/// -/// # Safety -/// -/// The pointer inside the CGOError must point to a valid null terminated Go string. -unsafe fn from_cgo_error(e: CGOError) -> String { - let s = from_go_string(e); - free_go_string(e); - s +#[repr(C)] +#[derive(Copy, Clone, PartialEq)] +pub enum CGOErrCode { + ErrCodeSuccess = 0, + ErrCodeFailure = 1, } // These functions are defined on the Go side. Look for functions with '//export funcname' // comments. extern "C" { - fn free_go_string(s: *mut c_char); - fn handle_bitmap(client_ref: usize, b: *mut CGOBitmap) -> CGOError; - fn handle_remote_copy(client_ref: usize, data: *mut u8, len: u32) -> CGOError; + fn handle_bitmap(client_ref: usize, b: *mut CGOBitmap) -> CGOErrCode; + fn handle_remote_copy(client_ref: usize, data: *mut u8, len: u32) -> CGOErrCode; } /// Payload is a generic type used to represent raw incoming RDP messages for parsing.