diff --git a/nano/core_test/message.cpp b/nano/core_test/message.cpp index 946be3bbf4..6c570a133f 100644 --- a/nano/core_test/message.cpp +++ b/nano/core_test/message.cpp @@ -464,4 +464,173 @@ TEST (message, asc_pull_ack_serialization_account_info) ASSERT_EQ (original_payload.account_conf_height, message_payload.account_conf_height); ASSERT_TRUE (nano::at_end (stream)); +} + +TEST (message, node_id_handshake_query_serialization) +{ + nano::node_id_handshake::query_payload query{}; + query.cookie = 7; + nano::node_id_handshake original{ nano::dev::network_params.network, query }; + + // Serialize + std::vector bytes; + { + nano::vectorstream stream{ bytes }; + original.serialize (stream); + } + nano::bufferstream stream{ bytes.data (), bytes.size () }; + + // Header + bool error = false; + nano::message_header header (error, stream); + ASSERT_FALSE (error); + ASSERT_EQ (nano::message_type::node_id_handshake, header.type); + + // Message + nano::node_id_handshake message{ error, stream, header }; + ASSERT_FALSE (error); + ASSERT_TRUE (message.query); + ASSERT_FALSE (message.response); + + ASSERT_EQ (original.query->cookie, message.query->cookie); + + ASSERT_TRUE (nano::at_end (stream)); +} + +TEST (message, node_id_handshake_response_serialization) +{ + nano::node_id_handshake::response_payload response{}; + response.node_id = nano::account{ 7 }; + response.signature = nano::signature{ 11 }; + nano::node_id_handshake original{ nano::dev::network_params.network, std::nullopt, response }; + + // Serialize + std::vector bytes; + { + nano::vectorstream stream{ bytes }; + original.serialize (stream); + } + nano::bufferstream stream{ bytes.data (), bytes.size () }; + + // Header + bool error = false; + nano::message_header header (error, stream); + ASSERT_FALSE (error); + ASSERT_EQ (nano::message_type::node_id_handshake, header.type); + + // Message + nano::node_id_handshake message{ error, stream, header }; + ASSERT_FALSE (error); + ASSERT_FALSE (message.query); + ASSERT_TRUE (message.response); + ASSERT_FALSE (message.response->v2); + + ASSERT_EQ (original.response->node_id, message.response->node_id); + ASSERT_EQ (original.response->signature, message.response->signature); + + ASSERT_TRUE (nano::at_end (stream)); +} + +TEST (message, node_id_handshake_response_v2_serialization) +{ + nano::node_id_handshake::response_payload response{}; + response.node_id = nano::account{ 7 }; + response.signature = nano::signature{ 11 }; + nano::node_id_handshake::response_payload::v2_payload v2_pld{}; + v2_pld.salt = 17; + v2_pld.genesis = nano::block_hash{ 13 }; + response.v2 = v2_pld; + + nano::node_id_handshake original{ nano::dev::network_params.network, std::nullopt, response }; + + // Serialize + std::vector bytes; + { + nano::vectorstream stream{ bytes }; + original.serialize (stream); + } + nano::bufferstream stream{ bytes.data (), bytes.size () }; + + // Header + bool error = false; + nano::message_header header (error, stream); + ASSERT_FALSE (error); + ASSERT_EQ (nano::message_type::node_id_handshake, header.type); + + // Message + nano::node_id_handshake message{ error, stream, header }; + ASSERT_FALSE (error); + ASSERT_FALSE (message.query); + ASSERT_TRUE (message.response); + ASSERT_TRUE (message.response->v2); + + ASSERT_EQ (original.response->node_id, message.response->node_id); + ASSERT_EQ (original.response->signature, message.response->signature); + ASSERT_EQ (original.response->v2->salt, message.response->v2->salt); + ASSERT_EQ (original.response->v2->genesis, message.response->v2->genesis); + + ASSERT_TRUE (nano::at_end (stream)); +} + +TEST (handshake, signature) +{ + nano::keypair node_id{}; + nano::keypair node_id_2{}; + auto cookie = nano::random_pool::generate (); + auto cookie_2 = nano::random_pool::generate (); + + nano::node_id_handshake::response_payload response{}; + response.node_id = node_id.pub; + response.sign (cookie, node_id); + ASSERT_TRUE (response.validate (cookie)); + + // Invalid cookie + ASSERT_FALSE (response.validate (cookie_2)); + + // Invalid node id + response.node_id = node_id_2.pub; + ASSERT_FALSE (response.validate (cookie)); +} + +TEST (handshake, signature_v2) +{ + nano::keypair node_id{}; + nano::keypair node_id_2{}; + auto cookie = nano::random_pool::generate (); + auto cookie_2 = nano::random_pool::generate (); + + nano::node_id_handshake::response_payload original{}; + original.node_id = node_id.pub; + original.v2 = nano::node_id_handshake::response_payload::v2_payload{}; + original.v2->genesis = nano::test::random_hash (); + original.v2->salt = nano::random_pool::generate (); + original.sign (cookie, node_id); + ASSERT_TRUE (original.validate (cookie)); + + // Invalid cookie + ASSERT_FALSE (original.validate (cookie_2)); + + // Invalid node id + { + auto message = original; + ASSERT_TRUE (message.validate (cookie)); + message.node_id = node_id_2.pub; + ASSERT_FALSE (message.validate (cookie)); + } + + // Invalid genesis + { + auto message = original; + ASSERT_TRUE (message.validate (cookie)); + message.v2->genesis = nano::test::random_hash (); + ASSERT_FALSE (message.validate (cookie)); + } + + // Invalid salt + { + auto message = original; + ASSERT_TRUE (message.validate (cookie)); + message.v2->salt = nano::random_pool::generate (); + ASSERT_FALSE (message.validate (cookie)); + } } \ No newline at end of file diff --git a/nano/node/messages.cpp b/nano/node/messages.cpp index 4da575720f..4ed4281dd2 100644 --- a/nano/node/messages.cpp +++ b/nano/node/messages.cpp @@ -192,11 +192,11 @@ void nano::message_header::count_set (uint8_t count_a) extensions |= std::bitset<16> (static_cast (count_a) << 12); } -void nano::message_header::flag_set (uint8_t flag_a) +void nano::message_header::flag_set (uint8_t flag_a, bool enable) { // Flags from 8 are block_type & count debug_assert (flag_a < 8); - extensions.set (flag_a, true); + extensions.set (flag_a, enable); } bool nano::message_header::bulk_pull_is_count_present () const @@ -1599,10 +1599,12 @@ nano::node_id_handshake::node_id_handshake (nano::network_constants const & cons if (query) { header.flag_set (query_flag); + header.flag_set (v2_flag); // Always indicate support for V2 handshake when querying, old peers will just ignore it } if (response) { header.flag_set (response_flag); + header.flag_set (v2_flag, response->v2.has_value ()); // We only use V2 handshake when replying to peers that indicated support for it } } @@ -1635,7 +1637,7 @@ bool nano::node_id_handshake::deserialize (nano::stream & stream) if (is_response (header)) { response_payload pld{}; - pld.deserialize (stream); + pld.deserialize (stream, header); response = pld; } } @@ -1660,6 +1662,18 @@ bool nano::node_id_handshake::is_response (nano::message_header const & header) return result; } +bool nano::node_id_handshake::is_v2 (nano::message_header const & header) +{ + debug_assert (header.type == nano::message_type::node_id_handshake); + bool result = header.extensions.test (v2_flag); + return result; +} + +bool nano::node_id_handshake::is_v2 () const +{ + return is_v2 (header); +} + void nano::node_id_handshake::visit (nano::message_visitor & visitor_a) const { visitor_a.node_id_handshake (*this); @@ -1679,7 +1693,7 @@ std::size_t nano::node_id_handshake::size (nano::message_header const & header) } if (is_response (header)) { - result += response_payload::size; + result += response_payload::size (header); } return result; } @@ -1719,14 +1733,81 @@ void nano::node_id_handshake::query_payload::deserialize (nano::stream & stream) void nano::node_id_handshake::response_payload::serialize (nano::stream & stream) const { - nano::write (stream, node_id); - nano::write (stream, signature); + if (v2) + { + nano::write (stream, node_id); + nano::write (stream, v2->salt); + nano::write (stream, v2->genesis); + nano::write (stream, signature); + } + // TODO: Remove legacy handshake + else + { + nano::write (stream, node_id); + nano::write (stream, signature); + } +} + +void nano::node_id_handshake::response_payload::deserialize (nano::stream & stream, nano::message_header const & header) +{ + if (is_v2 (header)) + { + nano::read (stream, node_id); + v2_payload pld{}; + nano::read (stream, pld.salt); + nano::read (stream, pld.genesis); + v2 = pld; + nano::read (stream, signature); + } + else + { + nano::read (stream, node_id); + nano::read (stream, signature); + } +} + +std::size_t nano::node_id_handshake::response_payload::size (const nano::message_header & header) +{ + return is_v2 (header) ? size_v2 : size_v1; +} + +std::vector nano::node_id_handshake::response_payload::data_to_sign (const nano::uint256_union & cookie) const +{ + std::vector bytes; + { + nano::vectorstream stream{ bytes }; + + if (v2) + { + nano::write (stream, cookie); + nano::write (stream, v2->salt); + nano::write (stream, v2->genesis); + } + // TODO: Remove legacy handshake + else + { + nano::write (stream, cookie); + } + } + return bytes; +} + +void nano::node_id_handshake::response_payload::sign (const nano::uint256_union & cookie, nano::keypair const & key) +{ + debug_assert (key.pub == node_id); + auto data = data_to_sign (cookie); + signature = nano::sign_message (key.prv, key.pub, data.data (), data.size ()); + debug_assert (validate (cookie)); } -void nano::node_id_handshake::response_payload::deserialize (nano::stream & stream) +bool nano::node_id_handshake::response_payload::validate (const nano::uint256_union & cookie) const { - nano::read (stream, node_id); - nano::read (stream, signature); + auto data = data_to_sign (cookie); + if (nano::validate_message (node_id, data.data (), data.size (), signature)) // true => error + { + return false; // Fail + } + return true; // OK } /* @@ -2048,4 +2129,4 @@ void nano::asc_pull_ack::account_info_payload::deserialize (nano::stream & strea nano::read_big_endian (stream, account_block_count); nano::read (stream, account_conf_frontier); nano::read_big_endian (stream, account_conf_height); -} \ No newline at end of file +} diff --git a/nano/node/messages.hpp b/nano/node/messages.hpp index 0d52baa0eb..271da96b0a 100644 --- a/nano/node/messages.hpp +++ b/nano/node/messages.hpp @@ -74,7 +74,7 @@ class message_header final std::bitset<16> extensions; static std::size_t constexpr size = sizeof (nano::networks) + sizeof (version_max) + sizeof (version_using) + sizeof (version_min) + sizeof (type) + sizeof (/* extensions */ uint16_t); - void flag_set (uint8_t); + void flag_set (uint8_t, bool enable = true); static uint8_t constexpr bulk_pull_count_present_flag = 0; static uint8_t constexpr bulk_pull_ascending_flag = 1; bool bulk_pull_is_count_present () const; @@ -364,13 +364,30 @@ class node_id_handshake final : public message { public: void serialize (nano::stream &) const; - void deserialize (nano::stream &); + void deserialize (nano::stream &, nano::message_header const &); + + void sign (nano::uint256_union const & cookie, nano::keypair const &); + bool validate (nano::uint256_union const & cookie) const; + + private: + std::vector data_to_sign (nano::uint256_union const & cookie) const; - static std::size_t constexpr size = sizeof (nano::account) + sizeof (nano::signature); + public: + struct v2_payload + { + nano::uint256_union salt; + nano::block_hash genesis; + }; public: nano::account node_id; nano::signature signature; + std::optional v2; + + public: + static std::size_t constexpr size_v1 = sizeof (nano::account) + sizeof (nano::signature); + static std::size_t constexpr size_v2 = sizeof (nano::account) + sizeof (nano::signature) + sizeof (v2_payload); + static std::size_t size (nano::message_header const &); }; public: @@ -388,9 +405,12 @@ class node_id_handshake final : public message public: // Header static uint8_t constexpr query_flag = 0; static uint8_t constexpr response_flag = 1; + static uint8_t constexpr v2_flag = 2; static bool is_query (nano::message_header const &); static bool is_response (nano::message_header const &); + static bool is_v2 (nano::message_header const &); + bool is_v2 () const; public: // Payload std::optional query;