diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 9cc9e356..9611a050 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -14,6 +14,7 @@ // use std::cell::RefCell; use std::convert::TryInto; +use std::num::Wrapping; use log::{debug, error, info, trace, warn}; use russh_cryptovec::CryptoVec; @@ -26,7 +27,8 @@ use crate::negotiation::{Named, Select}; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; use crate::session::{Encrypted, EncryptedState, Kex, KexInit}; use crate::{ - auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, Sig, + auth, msg, negotiation, strict_kex_violation, Channel, ChannelId, ChannelMsg, + ChannelOpenFailure, ChannelParams, Sig, }; thread_local! { @@ -37,6 +39,7 @@ impl Session { pub(crate) async fn client_read_encrypted( mut self, mut client: H, + seqn: &mut Wrapping, buf: &[u8], ) -> Result<(H, Self), H::Error> { #[allow(clippy::indexing_slicing)] // length checked @@ -65,6 +68,12 @@ impl Session { }; if let Some(kexinit) = kexinit { + if let Some(ref algo) = kexinit.algo { + if self.common.strict_kex && !algo.strict_kex { + return Err(strict_kex_violation(msg::KEXINIT, 0).into()); + } + } + let dhdone = kexinit.client_parse( self.common.config.as_ref(), &mut *self.common.cipher.local_to_remote, @@ -100,6 +109,7 @@ impl Session { .local_to_remote .write(&[msg::NEWKEYS], &mut self.common.write_buffer); self.flush()?; + self.common.maybe_reset_seqn(); Ok((client, self)) } else { error!("Wrong packet received"); @@ -125,6 +135,11 @@ impl Session { self.pending_len = 0; self.common.newkeys(newkeys); self.flush()?; + + if self.common.strict_kex { + *seqn = Wrapping(0); + } + return Ok((client, self)); } Some(Kex::Init(k)) => { diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index ddbd4b8a..4128857c 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -77,6 +77,7 @@ use std::cell::RefCell; use std::collections::HashMap; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; @@ -104,7 +105,8 @@ use crate::session::{CommonSession, EncryptedState, Exchange, Kex, KexDhDone, Ke use crate::ssh_read::SshRead; use crate::sshbuffer::{SSHBuffer, SshId}; use crate::{ - auth, msg, negotiation, timeout, ChannelId, ChannelOpenFailure, Disconnect, Limits, Sig, + auth, msg, negotiation, strict_kex_violation, timeout, ChannelId, ChannelOpenFailure, + Disconnect, Limits, Sig, }; mod encrypted; @@ -128,6 +130,8 @@ pub struct Session { inbound_channel_receiver: Receiver, } +const STRICT_KEX_MSG_ORDER: &[u8] = &[msg::KEXINIT, msg::KEX_ECDH_REPLY, msg::NEWKEYS]; + impl Drop for Session { fn drop(&mut self) { debug!("drop session") @@ -693,6 +697,7 @@ where wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + strict_kex: false, }, session_receiver, session_sender, @@ -784,7 +789,7 @@ impl Session { self.send_keepalive(true); } r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; @@ -813,8 +818,8 @@ impl Session { #[allow(clippy::indexing_slicing)] // length checked if buf[0] == crate::msg::DISCONNECT { break; - } else if buf[0] > 4 { - let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?; + } else { + let (h, s) = reply(self, handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?; handler = h; self = s; } @@ -1176,8 +1181,24 @@ async fn reply( mut session: Session, mut handler: H, sender: &mut Option>, + seqn: &mut Wrapping, buf: &[u8], ) -> Result<(H, Session), H::Error> { + if let Some(message_type) = buf.first() { + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = seqn.0 - 1; // was incremented after read() + if let Some(expected) = STRICT_KEX_MSG_ORDER.get(seqno as usize) { + if message_type != expected { + return Err(strict_kex_violation(*message_type, seqno as usize).into()); + } + } + } + + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok((handler, session)); + } + } + match session.common.kex.take() { Some(Kex::Init(kexinit)) => { if kexinit.algo.is_some() @@ -1191,6 +1212,11 @@ async fn reply( &mut session.common.write_buffer, )?; + // seqno has already been incremented after read() + if done.names.strict_kex && seqn.0 != 1 { + return Err(strict_kex_violation(msg::KEXINIT, seqn.0 as usize - 1).into()); + } + if done.kex.skip_exchange() { session.common.encrypted( initial_encrypted_state(&session), @@ -1216,6 +1242,7 @@ async fn reply( // We've sent ECDH_INIT, waiting for ECDH_REPLY let (kex, h) = kexdhdone.server_key_check(false, handler, buf).await?; handler = h; + session.common.strict_kex = session.common.strict_kex || kex.names.strict_kex; session.common.kex = Some(Kex::Keys(kex)); session .common @@ -1223,6 +1250,7 @@ async fn reply( .local_to_remote .write(&[msg::NEWKEYS], &mut session.common.write_buffer); session.flush()?; + session.common.maybe_reset_seqn(); Ok((handler, session)) } else { error!("Wrong packet received"); @@ -1241,13 +1269,16 @@ async fn reply( .common .encrypted(initial_encrypted_state(&session), newkeys); // Ok, NEWKEYS received, now encrypted. + if session.common.strict_kex { + *seqn = Wrapping(0); + } Ok((handler, session)) } Some(kex) => { session.common.kex = Some(kex); Ok((handler, session)) } - None => session.client_read_encrypted(handler, buf).await, + None => session.client_read_encrypted(handler, seqn, buf).await, } } diff --git a/russh/src/kex/mod.rs b/russh/src/kex/mod.rs index cc413d65..a3d25561 100644 --- a/russh/src/kex/mod.rs +++ b/russh/src/kex/mod.rs @@ -99,6 +99,10 @@ pub const NONE: Name = Name("none"); pub const EXTENSION_SUPPORT_AS_CLIENT: Name = Name("ext-info-c"); /// `ext-info-s` pub const EXTENSION_SUPPORT_AS_SERVER: Name = Name("ext-info-s"); +/// `kex-strict-c-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT: Name = Name("kex-strict-c-v00@openssh.com"); +/// `kex-strict-s-v00@openssh.com` +pub const EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER: Name = Name("kex-strict-s-v00@openssh.com"); const _CURVE25519: Curve25519KexType = Curve25519KexType {}; const _DH_G1_SHA1: DhGroup1Sha1KexType = DhGroup1Sha1KexType {}; diff --git a/russh/src/lib.rs b/russh/src/lib.rs index a9eed8bf..dc5fe92f 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -96,6 +96,7 @@ use std::fmt::{Debug, Display, Formatter}; +use log::debug; use parsing::ChannelOpenConfirmation; pub use russh_cryptovec::CryptoVec; use thiserror::Error; @@ -285,6 +286,23 @@ pub enum Error { #[error(transparent)] Elapsed(#[from] tokio::time::error::Elapsed), + + #[error("Violation detected during strict key exchange, message {message_type} at seq no {sequence_number}")] + StrictKeyExchangeViolation { + message_type: u8, + sequence_number: usize, + }, +} + +pub(crate) fn strict_kex_violation(message_type: u8, sequence_number: usize) -> crate::Error { + debug!( + "strict kex violated at sequence no. {:?}, message type: {:?}", + sequence_number, message_type + ); + crate::Error::StrictKeyExchangeViolation { + message_type, + sequence_number, + } } #[derive(Debug, Error)] diff --git a/russh/src/negotiation.rs b/russh/src/negotiation.rs index 5b0cf0af..eb143b72 100644 --- a/russh/src/negotiation.rs +++ b/russh/src/negotiation.rs @@ -23,6 +23,7 @@ use russh_keys::key::{KeyPair, PublicKey}; use crate::cipher::CIPHERS; use crate::compression::*; +use crate::kex::{EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER}; use crate::{cipher, kex, mac, msg, Error}; #[derive(Debug)] @@ -35,6 +36,7 @@ pub struct Names { pub server_compression: Compression, pub client_compression: Compression, pub ignore_guessed: bool, + pub strict_kex: bool, } /// Lists of preferred algorithms. This is normally hard-coded into implementations. @@ -56,6 +58,10 @@ const SAFE_KEX_ORDER: &[kex::Name] = &[ kex::CURVE25519, kex::CURVE25519_PRE_RFC_8731, kex::DH_G14_SHA256, + kex::EXTENSION_SUPPORT_AS_CLIENT, + kex::EXTENSION_SUPPORT_AS_SERVER, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, ]; const CIPHER_ORDER: &[cipher::Name] = &[ @@ -143,7 +149,9 @@ impl Named for KeyPair { } } -pub trait Select { +pub(crate) trait Select { + fn is_server() -> bool; + fn select + Copy>(a: &[S], b: &[u8]) -> Option<(bool, S)>; fn read_kex(buffer: &[u8], pref: &Preferred) -> Result { @@ -160,6 +168,24 @@ pub trait Select { return Err(Error::NoCommonKexAlgo); }; + let strict_kex_requested = pref.kex.contains(if Self::is_server() { + &EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER + } else { + &EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + }); + let strict_kex_provided = Self::select( + &[if Self::is_server() { + EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT + } else { + EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER + }], + kex_string, + ) + .is_some(); + if strict_kex_requested && strict_kex_provided { + debug!("strict kex enabled") + } + let key_string = r.read_string()?; let (key_both_first, key_algorithm) = if let Some(x) = Self::select(pref.key, key_string) { x @@ -238,6 +264,7 @@ pub trait Select { server_compression, // Ignore the next packet if (1) it follows and (2) it's not the correct guess. ignore_guessed: fol && !(kex_both_first && key_both_first), + strict_kex: strict_kex_requested && strict_kex_provided, }) } _ => Err(Error::KexInit), @@ -249,6 +276,10 @@ pub struct Server; pub struct Client; impl Select for Server { + fn is_server() -> bool { + true + } + fn select + Copy>(server_list: &[S], client_list: &[u8]) -> Option<(bool, S)> { let mut both_first_choice = true; for c in client_list.split(|&x| x == b',') { @@ -264,6 +295,10 @@ impl Select for Server { } impl Select for Client { + fn is_server() -> bool { + false + } + fn select + Copy>(client_list: &[S], server_list: &[u8]) -> Option<(bool, S)> { let mut both_first_choice = true; for &c in client_list { @@ -287,11 +322,18 @@ pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec, as_server: bool) -> Res buf.extend(&cookie); // cookie buf.extend_list(prefs.kex.iter().filter(|k| { - **k != if as_server { - crate::kex::EXTENSION_SUPPORT_AS_CLIENT + !(if as_server { + [ + crate::kex::EXTENSION_SUPPORT_AS_CLIENT, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + ] } else { - crate::kex::EXTENSION_SUPPORT_AS_SERVER - } + [ + crate::kex::EXTENSION_SUPPORT_AS_SERVER, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, + ] + }) + .contains(*k) })); // kex algo buf.extend_list(prefs.key.iter()); diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index 354eddbd..a78984b2 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -34,6 +34,7 @@ impl Session { pub(crate) async fn server_read_encrypted( mut self, mut handler: H, + seqn: &mut Wrapping, buf: &[u8], ) -> Result<(H, Self), H::Error> { #[allow(clippy::indexing_slicing)] // length checked @@ -70,6 +71,9 @@ impl Session { &mut self.common.write_buffer, )?); } + if let Some(Kex::Dh(KexDh { ref names, .. })) = enc.rekey { + self.common.strict_kex = self.common.strict_kex || names.strict_kex; + } self.flush()?; return Ok((handler, self)); } @@ -82,6 +86,10 @@ impl Session { buf, &mut self.common.write_buffer, )?); + if let Some(Kex::Keys(_)) = enc.rekey { + // just sent NEWKEYS + self.common.maybe_reset_seqn(); + } self.flush()?; return Ok((handler, self)); } @@ -103,11 +111,21 @@ impl Session { self.pending_reads = pending; self.pending_len = 0; self.common.newkeys(newkeys); + if self.common.strict_kex { + *seqn = Wrapping(0); + } self.flush()?; return Ok((handler, self)); } Some(Kex::Init(k)) => { + if let Some(ref algo) = k.algo { + if self.common.strict_kex && !algo.strict_kex { + return Err(strict_kex_violation(msg::KEXINIT, 0).into()); + } + } + enc.rekey = Some(Kex::Init(k)); + self.pending_len += buf.len() as u32; if self.pending_len > 2 * self.target_window_size { return Err(Error::Pending.into()); diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index ab0c9ec8..70060c94 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -112,6 +112,7 @@ use std; use std::collections::HashMap; +use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -805,14 +806,33 @@ async fn read_ssh_id( wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + strict_kex: false, }) } +const STRICT_KEX_MSG_ORDER: &[u8] = &[msg::KEXINIT, msg::KEX_ECDH_INIT, msg::NEWKEYS]; + async fn reply( mut session: Session, handler: H, + seqn: &mut Wrapping, buf: &[u8], ) -> Result<(H, Session), H::Error> { + if let Some(message_type) = buf.first() { + if session.common.strict_kex && session.common.encrypted.is_none() { + let seqno = seqn.0 - 1; // was incremented after read() + if let Some(expected) = STRICT_KEX_MSG_ORDER.get(seqno as usize) { + if message_type != expected { + return Err(strict_kex_violation(*message_type, seqno as usize).into()); + } + } + } + + if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) { + return Ok((handler, session)); + } + } + // Handle key exchange/re-exchange. if session.common.encrypted.is_none() { match session.common.kex.take() { @@ -824,6 +844,13 @@ async fn reply( buf, &mut session.common.write_buffer, )?); + if let Some(Kex::Dh(KexDh { ref names, .. })) = session.common.kex { + session.common.strict_kex = names.strict_kex; + } + // seqno has already been incremented after read() + if session.common.strict_kex && seqn.0 != 1 { + return Err(strict_kex_violation(msg::KEXINIT, seqn.0 as usize - 1).into()); + } return Ok((handler, session)); } else { // Else, i.e. if the other side has not started @@ -839,6 +866,10 @@ async fn reply( buf, &mut session.common.write_buffer, )?); + if let Some(Kex::Keys(_)) = session.common.kex { + // just sent NEWKEYS + session.common.maybe_reset_seqn(); + } return Ok((handler, session)); } Some(Kex::Keys(newkeys)) => { @@ -854,6 +885,9 @@ async fn reply( newkeys, ); session.maybe_send_ext_info(); + if session.common.strict_kex { + *seqn = Wrapping(0); + } return Ok((handler, session)); } Some(kex) => { @@ -864,6 +898,6 @@ async fn reply( } Ok((handler, session)) } else { - Ok(session.server_read_encrypted(handler, buf).await?) + Ok(session.server_read_encrypted(handler, seqn, buf).await?) } } diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 9361e579..591e0a44 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -360,7 +360,7 @@ impl Session { while !self.common.disconnected { tokio::select! { r = &mut reading => { - let (stream_read, buffer, mut opening_cipher) = match r { + let (stream_read, mut buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), Err(e) => return Err(e.into()) }; @@ -390,10 +390,10 @@ impl Session { debug!("break"); is_reading = Some((stream_read, buffer, opening_cipher)); break; - } else if buf[0] > 4 { + } else { std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); // TODO it'd be cleaner to just pass cipher to reply() - match reply(self, handler, buf).await { + match reply(self, handler, &mut buffer.seqn, buf).await { Ok((h, s)) => { handler = h; self = s; diff --git a/russh/src/session.rs b/russh/src/session.rs index 09afa95a..44cb8827 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -63,6 +63,7 @@ pub(crate) struct CommonSession { pub wants_reply: bool, pub disconnected: bool, pub buffer: CryptoVec, + pub strict_kex: bool, } impl CommonSession { @@ -74,6 +75,7 @@ impl CommonSession { enc.client_mac = newkeys.names.client_mac; enc.server_mac = newkeys.names.server_mac; self.cipher = newkeys.cipher; + self.strict_kex = self.strict_kex || newkeys.names.strict_kex; } } @@ -99,6 +101,7 @@ impl CommonSession { decompress: crate::compression::Decompress::None, }); self.cipher = newkeys.cipher; + self.strict_kex = newkeys.names.strict_kex; } /// Send a disconnect message. @@ -127,6 +130,12 @@ impl CommonSession { enc.byte(channel, msg) } } + + pub(crate) fn maybe_reset_seqn(&mut self) { + if self.strict_kex { + self.write_buffer.seqn = Wrapping(0); + } + } } impl Encrypted {