From 3b10795994e1363fbcee45bdcd53cc330682092c Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:10:03 -0700 Subject: [PATCH] noise: use noise extension for early muxer selection. Tracking issue: https://github.com/libp2p/go-libp2p/issues/1789 obselets PR #1813 --- p2p/security/noise/benchmark_test.go | 2 +- p2p/security/noise/pb/payload.pb.go | 531 ++-------------------- p2p/security/noise/pb/payload.proto | 1 + p2p/security/noise/session.go | 12 +- p2p/security/noise/transport.go | 76 +++- p2p/security/noise/transport_test.go | 59 +++ p2p/transport/websocket/websocket_test.go | 2 +- p2p/transport/webtransport/transport.go | 2 +- 8 files changed, 183 insertions(+), 502 deletions(-) diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 52454f5959..836275b954 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -39,7 +39,7 @@ func makeTransport(b *testing.B) *Transport { if err != nil { b.Fatal(err) } - tpt, err := New(priv) + tpt, err := New(priv, nil) if err != nil { b.Fatalf("error constructing transport: %v", err) } diff --git a/p2p/security/noise/pb/payload.pb.go b/p2p/security/noise/pb/payload.pb.go index 84db783eff..1d8368de36 100644 --- a/p2p/security/noise/pb/payload.pb.go +++ b/p2p/security/noise/pb/payload.pb.go @@ -6,9 +6,7 @@ package pb import ( fmt "fmt" proto "github.com/gogo/protobuf/proto" - io "io" math "math" - math_bits "math/bits" ) // Reference imports to suppress errors if they are not otherwise used. @@ -24,6 +22,10 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type NoiseExtensions struct { WebtransportCerthashes [][]byte `protobuf:"bytes,1,rep,name=webtransport_certhashes,json=webtransportCerthashes" json:"webtransport_certhashes,omitempty"` + StreamMuxers []string `protobuf:"bytes,2,rep,name=stream_muxers,json=streamMuxers" json:"stream_muxers,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *NoiseExtensions) Reset() { *m = NoiseExtensions{} } @@ -33,25 +35,16 @@ func (*NoiseExtensions) Descriptor() ([]byte, []int) { return fileDescriptor_678c914f1bee6d56, []int{0} } func (m *NoiseExtensions) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) + return xxx_messageInfo_NoiseExtensions.Unmarshal(m, b) } func (m *NoiseExtensions) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NoiseExtensions.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } + return xxx_messageInfo_NoiseExtensions.Marshal(b, m, deterministic) } func (m *NoiseExtensions) XXX_Merge(src proto.Message) { xxx_messageInfo_NoiseExtensions.Merge(m, src) } func (m *NoiseExtensions) XXX_Size() int { - return m.Size() + return xxx_messageInfo_NoiseExtensions.Size(m) } func (m *NoiseExtensions) XXX_DiscardUnknown() { xxx_messageInfo_NoiseExtensions.DiscardUnknown(m) @@ -66,10 +59,20 @@ func (m *NoiseExtensions) GetWebtransportCerthashes() [][]byte { return nil } +func (m *NoiseExtensions) GetStreamMuxers() []string { + if m != nil { + return m.StreamMuxers + } + return nil +} + type NoiseHandshakePayload struct { - IdentityKey []byte `protobuf:"bytes,1,opt,name=identity_key,json=identityKey" json:"identity_key"` - IdentitySig []byte `protobuf:"bytes,2,opt,name=identity_sig,json=identitySig" json:"identity_sig"` - Extensions *NoiseExtensions `protobuf:"bytes,4,opt,name=extensions" json:"extensions,omitempty"` + IdentityKey []byte `protobuf:"bytes,1,opt,name=identity_key,json=identityKey" json:"identity_key,omitempty"` + IdentitySig []byte `protobuf:"bytes,2,opt,name=identity_sig,json=identitySig" json:"identity_sig,omitempty"` + Extensions *NoiseExtensions `protobuf:"bytes,4,opt,name=extensions" json:"extensions,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *NoiseHandshakePayload) Reset() { *m = NoiseHandshakePayload{} } @@ -79,25 +82,16 @@ func (*NoiseHandshakePayload) Descriptor() ([]byte, []int) { return fileDescriptor_678c914f1bee6d56, []int{1} } func (m *NoiseHandshakePayload) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) + return xxx_messageInfo_NoiseHandshakePayload.Unmarshal(m, b) } func (m *NoiseHandshakePayload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - if deterministic { - return xxx_messageInfo_NoiseHandshakePayload.Marshal(b, m, deterministic) - } else { - b = b[:cap(b)] - n, err := m.MarshalToSizedBuffer(b) - if err != nil { - return nil, err - } - return b[:n], nil - } + return xxx_messageInfo_NoiseHandshakePayload.Marshal(b, m, deterministic) } func (m *NoiseHandshakePayload) XXX_Merge(src proto.Message) { xxx_messageInfo_NoiseHandshakePayload.Merge(m, src) } func (m *NoiseHandshakePayload) XXX_Size() int { - return m.Size() + return xxx_messageInfo_NoiseHandshakePayload.Size(m) } func (m *NoiseHandshakePayload) XXX_DiscardUnknown() { xxx_messageInfo_NoiseHandshakePayload.DiscardUnknown(m) @@ -134,470 +128,21 @@ func init() { func init() { proto.RegisterFile("payload.proto", fileDescriptor_678c914f1bee6d56) } var fileDescriptor_678c914f1bee6d56 = []byte{ - // 221 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x48, 0xac, 0xcc, - 0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xf2, 0xe2, - 0xe2, 0xf7, 0xcb, 0xcf, 0x2c, 0x4e, 0x75, 0xad, 0x28, 0x49, 0xcd, 0x2b, 0xce, 0xcc, 0xcf, 0x2b, - 0x16, 0x32, 0xe7, 0x12, 0x2f, 0x4f, 0x4d, 0x2a, 0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, - 0x89, 0x4f, 0x4e, 0x2d, 0x2a, 0xc9, 0x48, 0x2c, 0xce, 0x48, 0x2d, 0x96, 0x60, 0x54, 0x60, 0xd6, - 0xe0, 0x09, 0x12, 0x43, 0x96, 0x76, 0x86, 0xcb, 0x2a, 0xcd, 0x63, 0xe4, 0x12, 0x05, 0x1b, 0xe6, - 0x91, 0x98, 0x97, 0x52, 0x9c, 0x91, 0x98, 0x9d, 0x1a, 0x00, 0xb1, 0x4f, 0x48, 0x9d, 0x8b, 0x27, - 0x33, 0x25, 0x35, 0xaf, 0x24, 0xb3, 0xa4, 0x32, 0x3e, 0x3b, 0xb5, 0x52, 0x82, 0x51, 0x81, 0x51, - 0x83, 0xc7, 0x89, 0xe5, 0xc4, 0x3d, 0x79, 0x86, 0x20, 0x6e, 0x98, 0x8c, 0x77, 0x6a, 0x25, 0x8a, - 0xc2, 0xe2, 0xcc, 0x74, 0x09, 0x26, 0x6c, 0x0a, 0x83, 0x33, 0xd3, 0x85, 0x8c, 0xb9, 0xb8, 0x52, - 0xe1, 0x4e, 0x96, 0x60, 0x51, 0x60, 0xd4, 0xe0, 0x36, 0x12, 0xd6, 0x2b, 0x48, 0xd2, 0x43, 0xf3, - 0x4d, 0x10, 0x92, 0x32, 0x27, 0x89, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, - 0x48, 0x8e, 0x71, 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, 0x96, 0x63, 0x00, - 0x04, 0x00, 0x00, 0xff, 0xff, 0xb2, 0xb0, 0x39, 0x45, 0x1a, 0x01, 0x00, 0x00, -} - -func (m *NoiseExtensions) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NoiseExtensions) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NoiseExtensions) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if len(m.WebtransportCerthashes) > 0 { - for iNdEx := len(m.WebtransportCerthashes) - 1; iNdEx >= 0; iNdEx-- { - i -= len(m.WebtransportCerthashes[iNdEx]) - copy(dAtA[i:], m.WebtransportCerthashes[iNdEx]) - i = encodeVarintPayload(dAtA, i, uint64(len(m.WebtransportCerthashes[iNdEx]))) - i-- - dAtA[i] = 0xa - } - } - return len(dAtA) - i, nil -} - -func (m *NoiseHandshakePayload) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalToSizedBuffer(dAtA[:size]) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *NoiseHandshakePayload) MarshalTo(dAtA []byte) (int, error) { - size := m.Size() - return m.MarshalToSizedBuffer(dAtA[:size]) -} - -func (m *NoiseHandshakePayload) MarshalToSizedBuffer(dAtA []byte) (int, error) { - i := len(dAtA) - _ = i - var l int - _ = l - if m.Extensions != nil { - { - size, err := m.Extensions.MarshalToSizedBuffer(dAtA[:i]) - if err != nil { - return 0, err - } - i -= size - i = encodeVarintPayload(dAtA, i, uint64(size)) - } - i-- - dAtA[i] = 0x22 - } - if m.IdentitySig != nil { - i -= len(m.IdentitySig) - copy(dAtA[i:], m.IdentitySig) - i = encodeVarintPayload(dAtA, i, uint64(len(m.IdentitySig))) - i-- - dAtA[i] = 0x12 - } - if m.IdentityKey != nil { - i -= len(m.IdentityKey) - copy(dAtA[i:], m.IdentityKey) - i = encodeVarintPayload(dAtA, i, uint64(len(m.IdentityKey))) - i-- - dAtA[i] = 0xa - } - return len(dAtA) - i, nil -} - -func encodeVarintPayload(dAtA []byte, offset int, v uint64) int { - offset -= sovPayload(v) - base := offset - for v >= 1<<7 { - dAtA[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - dAtA[offset] = uint8(v) - return base -} -func (m *NoiseExtensions) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if len(m.WebtransportCerthashes) > 0 { - for _, b := range m.WebtransportCerthashes { - l = len(b) - n += 1 + l + sovPayload(uint64(l)) - } - } - return n -} - -func (m *NoiseHandshakePayload) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.IdentityKey != nil { - l = len(m.IdentityKey) - n += 1 + l + sovPayload(uint64(l)) - } - if m.IdentitySig != nil { - l = len(m.IdentitySig) - n += 1 + l + sovPayload(uint64(l)) - } - if m.Extensions != nil { - l = m.Extensions.Size() - n += 1 + l + sovPayload(uint64(l)) - } - return n -} - -func sovPayload(x uint64) (n int) { - return (math_bits.Len64(x|1) + 6) / 7 -} -func sozPayload(x uint64) (n int) { - return sovPayload(uint64((x << 1) ^ uint64((int64(x) >> 63)))) -} -func (m *NoiseExtensions) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NoiseExtensions: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NoiseExtensions: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field WebtransportCerthashes", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthPayload - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthPayload - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.WebtransportCerthashes = append(m.WebtransportCerthashes, make([]byte, postIndex-iNdEx)) - copy(m.WebtransportCerthashes[len(m.WebtransportCerthashes)-1], dAtA[iNdEx:postIndex]) - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipPayload(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthPayload - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *NoiseHandshakePayload) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: NoiseHandshakePayload: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: NoiseHandshakePayload: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field IdentityKey", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthPayload - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthPayload - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.IdentityKey = append(m.IdentityKey[:0], dAtA[iNdEx:postIndex]...) - if m.IdentityKey == nil { - m.IdentityKey = []byte{} - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field IdentitySig", wireType) - } - var byteLen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - byteLen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if byteLen < 0 { - return ErrInvalidLengthPayload - } - postIndex := iNdEx + byteLen - if postIndex < 0 { - return ErrInvalidLengthPayload - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.IdentitySig = append(m.IdentitySig[:0], dAtA[iNdEx:postIndex]...) - if m.IdentitySig == nil { - m.IdentitySig = []byte{} - } - iNdEx = postIndex - case 4: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Extensions", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowPayload - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthPayload - } - postIndex := iNdEx + msglen - if postIndex < 0 { - return ErrInvalidLengthPayload - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.Extensions == nil { - m.Extensions = &NoiseExtensions{} - } - if err := m.Extensions.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipPayload(dAtA[iNdEx:]) - if err != nil { - return err - } - if (skippy < 0) || (iNdEx+skippy) < 0 { - return ErrInvalidLengthPayload - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func skipPayload(dAtA []byte) (n int, err error) { - l := len(dAtA) - iNdEx := 0 - depth := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowPayload - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowPayload - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if dAtA[iNdEx-1] < 0x80 { - break - } - } - case 1: - iNdEx += 8 - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowPayload - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if length < 0 { - return 0, ErrInvalidLengthPayload - } - iNdEx += length - case 3: - depth++ - case 4: - if depth == 0 { - return 0, ErrUnexpectedEndOfGroupPayload - } - depth-- - case 5: - iNdEx += 4 - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - if iNdEx < 0 { - return 0, ErrInvalidLengthPayload - } - if depth == 0 { - return iNdEx, nil - } - } - return 0, io.ErrUnexpectedEOF + // 213 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x8f, 0xc1, 0x4a, 0x03, 0x31, + 0x10, 0x86, 0xc9, 0xd6, 0x8b, 0xd3, 0x14, 0x21, 0xa2, 0xe6, 0x18, 0xeb, 0x25, 0xa7, 0x3d, 0xe8, + 0xc1, 0x07, 0x10, 0x41, 0x10, 0x45, 0xe2, 0x03, 0x2c, 0x59, 0x77, 0xe8, 0x86, 0xda, 0x24, 0x64, + 0x46, 0x6c, 0x5e, 0xc3, 0x27, 0x16, 0x5a, 0x2c, 0x5d, 0xaf, 0xdf, 0xff, 0xc1, 0xcc, 0x07, 0x8b, + 0xec, 0xeb, 0x67, 0xf2, 0x43, 0x9b, 0x4b, 0xe2, 0xa4, 0x9a, 0xdc, 0x2f, 0x13, 0x9c, 0xbd, 0xa6, + 0x40, 0xf8, 0xb8, 0x65, 0x8c, 0x14, 0x52, 0x24, 0x75, 0x0f, 0x57, 0xdf, 0xd8, 0x73, 0xf1, 0x91, + 0x72, 0x2a, 0xdc, 0x7d, 0x60, 0xe1, 0xd1, 0xd3, 0x88, 0xa4, 0x85, 0x99, 0x59, 0xe9, 0x2e, 0x8f, + 0xe7, 0x87, 0xc3, 0xaa, 0x6e, 0x60, 0x41, 0x5c, 0xd0, 0x6f, 0xba, 0xcd, 0xd7, 0x16, 0x0b, 0xe9, + 0xc6, 0xcc, 0xec, 0xa9, 0x93, 0x7b, 0xf8, 0xb2, 0x63, 0xcb, 0x1f, 0x01, 0x17, 0xbb, 0x8b, 0x4f, + 0x3e, 0x0e, 0x34, 0xfa, 0x35, 0xbe, 0xed, 0x9f, 0x52, 0xd7, 0x20, 0xc3, 0x80, 0x91, 0x03, 0xd7, + 0x6e, 0x8d, 0x55, 0x0b, 0x23, 0xac, 0x74, 0xf3, 0x3f, 0xf6, 0x8c, 0x75, 0xa2, 0x50, 0x58, 0xe9, + 0x66, 0xaa, 0xbc, 0x87, 0x95, 0xba, 0x03, 0xc0, 0x43, 0x8b, 0x3e, 0x31, 0xc2, 0xce, 0x6f, 0xcf, + 0xdb, 0xdc, 0xb7, 0xff, 0x32, 0xdd, 0x91, 0xf6, 0x1b, 0x00, 0x00, 0xff, 0xff, 0x9a, 0x51, 0xda, + 0x51, 0x19, 0x01, 0x00, 0x00, } var ( diff --git a/p2p/security/noise/pb/payload.proto b/p2p/security/noise/pb/payload.proto index 7c1b0bdcae..ff303b0daf 100644 --- a/p2p/security/noise/pb/payload.proto +++ b/p2p/security/noise/pb/payload.proto @@ -3,6 +3,7 @@ package pb; message NoiseExtensions { repeated bytes webtransport_certhashes = 1; + repeated string stream_muxers = 2; } message NoiseHandshakePayload { diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index f1286b9ffb..ce8d97cdc8 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -40,6 +40,9 @@ type secureSession struct { prologue []byte initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler + + // ConnectionState holds state information releated to the secureSession entity. + connectionState network.ConnectionState } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -110,7 +113,7 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { } func (s *secureSession) ConnState() network.ConnectionState { - return network.ConnectionState{} + return s.connectionState } func (s *secureSession) SetDeadline(t time.Time) error { @@ -128,3 +131,10 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error { func (s *secureSession) Close() error { return s.insecureConn.Close() } + +func SessionWithConnState(s *secureSession, muxer string) *secureSession { + if s != nil { + s.connectionState.NextProto = muxer + } + return s +} diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index c6923698cc..59e60578bb 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -7,13 +7,19 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" + "github.com/libp2p/go-libp2p/p2p/security/noise/pb" manet "github.com/multiformats/go-multiaddr/net" ) // ID is the protocol ID for noise -const ID = "/noise" +const ( + ID = "/noise" + MAX_EXTENSION_SIZE = 2048 + MAX_PROTO_NUM = 100 +) var _ sec.SecureTransport = &Transport{} @@ -22,38 +28,51 @@ var _ sec.SecureTransport = &Transport{} type Transport struct { localID peer.ID privateKey crypto.PrivKey + muxers []string } // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(privkey crypto.PrivKey) (*Transport, error) { +func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err } + smuxers := make([]string, 0, len(muxers)) + for _, muxer := range muxers { + smuxers = append(smuxers, string(muxer)) + } + return &Transport{ localID: localID, privateKey: privkey, + muxers: smuxers, }, nil } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) + responderEDH := NewTransportEDH(t) + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, responderEDH, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { canonicallog.LogPeerStatus(100, p, addr, "handshake_failure", "noise", "err", err.Error()) } } - return c, err + return SessionWithConnState(c, responderEDH.MatchMuxers(false)), err } // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) + initiatorEDH := NewTransportEDH(t) + c, err := newSecureSession(t, ctx, insecure, p, nil, initiatorEDH, nil, true) + if err != nil { + return c, err + } + return SessionWithConnState(c, initiatorEDH.MatchMuxers(true)), err } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { @@ -65,3 +84,50 @@ func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTranspo } return st, nil } + +func matchMuxers(initiatorMuxers, responderMuxers []string) string { + selectedMuxer := "" + for _, muxer := range initiatorMuxers { + for _, respMuxer := range responderMuxers { + if respMuxer == muxer { + selectedMuxer = muxer + break + } + } + if selectedMuxer != "" { + break + } + } + return selectedMuxer +} + +type transportEarlyDataHandler struct { + transport *Transport + receivedMuxers []string +} + +func NewTransportEDH(t *Transport) *transportEarlyDataHandler { + return &transportEarlyDataHandler{transport: t} +} + +func (i *transportEarlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { + return &pb.NoiseExtensions{ + WebtransportCerthashes: [][]byte{}, + StreamMuxers: i.transport.muxers, + } +} + +func (i *transportEarlyDataHandler) Received(_ context.Context, _ net.Conn, extension *pb.NoiseExtensions) error { + // Discard messages with size or the number of protocols exceeding extension limit for security. + if extension != nil && extension.XXX_Size() <= MAX_EXTENSION_SIZE && len(extension.StreamMuxers) <= MAX_PROTO_NUM { + i.receivedMuxers = extension.GetStreamMuxers() + } + return nil +} + +func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) string { + if isInitiator { + return matchMuxers(i.transport.muxers, i.receivedMuxers) + } + return matchMuxers(i.receivedMuxers, i.transport.muxers) +} diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 2fa90d06ef..411b1a34c8 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -37,6 +37,12 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { } } +func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []string) *Transport { + transport := newTestTransport(t, typ, bits) + transport.muxers = muxers + return transport +} + // Create a new pair of connected TCP sockets. func newConnPair(t *testing.T) (net.Conn, net.Conn) { lstnr, err := net.Listen("tcp", "localhost:0") @@ -586,3 +592,56 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { require.NoError(t, err) } } + +type noiseEarlyDataTestCase struct { + initProtos []string + respProtos []string + expectedResult string +} + +func TestHandshakeWithTransportEarlyData(t *testing.T) { + tests := []noiseEarlyDataTestCase{ + {initProtos: nil, respProtos: nil, expectedResult: ""}, + {[]string{"muxer1"}, []string{"muxer1"}, "muxer1"}, + {[]string{"muxer1"}, []string{}, ""}, + {[]string{}, []string{"muxer1"}, ""}, + {[]string{"muxer2"}, []string{"muxer1"}, ""}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.1"}, []string{"muxer2/1.0.1", "muxer1/1.0.0"}, "muxer1/1.0.0"}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.1", "muxer3/1.0.0"}, []string{"muxer2/1.0.1", "muxer1/1.0.1", "muxer3/1.0.0"}, "muxer2/1.0.1"}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.0"}, []string{"muxer3/1.0.0"}, ""}, + } + + noiseHandshake := func(t *testing.T, initProtos, respProtos []string, expectedProto string) { + initTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, initProtos) + respTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, respProtos) + + initConn, respConn := connect(t, initTransport, respTransport) + defer initConn.Close() + defer respConn.Close() + + require.Equal(t, expectedProto, initConn.connectionState.NextProto) + require.Equal(t, expectedProto, respConn.connectionState.NextProto) + + initData := []byte("Test data for noise transport") + _, err := initConn.Write(initData) + if err != nil { + t.Fatal(err) + } + + respData := make([]byte, len(initData)) + _, err = respConn.Read(respData) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(initData, respData) { + t.Errorf("Data transmitted mismatch over noise session. %v != %v", initData, respData) + } + } + + for _, test := range tests { + t.Run("Transport EarlyData Test", func(t *testing.T) { + noiseHandshake(t, test.initProtos, test.respProtos, test.expectedResult) + }) + } +} diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 1961e9cec9..e41a414a1c 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -79,7 +79,7 @@ func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - noiseTpt, err := noise.New(priv) + noiseTpt, err := noise.New(priv, nil) require.NoError(t, err) secMuxer.AddTransport(noise.ID, noiseTpt) return id, &secMuxer diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 5cd3a88170..c67bd960b6 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -102,7 +102,7 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } } - n, err := noise.New(key) + n, err := noise.New(key, nil) if err != nil { return nil, err }