diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 5612bb3cea..439c8698fa 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -75,6 +75,8 @@ enum DecryptionError { EncryptedMessageNoDestination, #[error("Decryption failed: {0}")] DecryptionFailedMalformedCipher(#[from] DhtEncryptError), + #[error("Encrypted message must have a non-empty body")] + EncryptedMessageEmptyBody, } /// This layer is responsible for attempting to decrypt inbound messages. @@ -346,6 +348,10 @@ where S: Service /// Performs message validation that should be performed by all nodes. If an error is encountered, the message is /// invalid and should never have been sent. fn initial_validation(message: DhtInboundMessage) -> Result { + if message.body.is_empty() { + return Err(DecryptionError::EncryptedMessageEmptyBody); + } + if message.dht_header.flags.is_encrypted() { // Check if there is no destination specified and discard if message.dht_header.destination.is_unknown() { @@ -572,6 +578,33 @@ mod test { assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); } + #[test] + fn decrypt_inbound_fail_empty_contents() { + let service = service_fn( + move |_msg: DecryptedDhtMessage| -> future::Ready> { + panic!("Should not be called") + }, + ); + let node_identity = make_node_identity(); + let (connectivity, _) = create_connectivity_mock(); + let mut service = DecryptionService::new(Default::default(), node_identity, connectivity, service); + + let some_other_node_identity = make_node_identity(); + let mut inbound_msg = make_dht_inbound_message( + &some_other_node_identity, + &Vec::new(), + DhtMessageFlags::ENCRYPTED, + true, + true, + ) + .unwrap(); + inbound_msg.body = Vec::new(); + + let err = block_on(service.call(inbound_msg.clone())).unwrap_err(); + let err = err.downcast::().unwrap(); + unpack_enum!(DecryptionError::EncryptedMessageEmptyBody = err); + } + #[runtime::test] async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock();