diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 80d5a9893..810703f9b 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "log" + "net" "os" "github.com/sirupsen/logrus" @@ -56,18 +57,18 @@ pki: cert: /home/rice/Developer/nebula-config/app.crt key: /home/rice/Developer/nebula-config/app.key ` - var config config.C - if err := config.LoadString(configStr); err != nil { + var cfg config.C + if err := cfg.LoadString(configStr); err != nil { return err } l := logrus.New() l.Out = os.Stdout - service, err := service.New(&config, l) + svc, err := service.New(&cfg, l) if err != nil { return err } - ln, err := service.Listen("tcp", ":1234") + ln, err := svc.Listen("tcp", ":1234") if err != nil { return err } @@ -77,16 +78,24 @@ pki: log.Printf("accept error: %s", err) break } - defer conn.Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) log.Printf("got connection") - conn.Write([]byte("hello world\n")) + _, err = conn.Write([]byte("hello world\n")) + if err != nil { + log.Printf("write error: %s", err) + } scanner := bufio.NewScanner(conn) for scanner.Scan() { message := scanner.Text() - fmt.Fprintf(conn, "echo: %q\n", message) + _, err = fmt.Fprintf(conn, "echo: %q\n", message) + if err != nil { + log.Printf("write error: %s", err) + } log.Printf("got message %q", message) } @@ -96,8 +105,8 @@ pki: } } - service.Close() - if err := service.Wait(); err != nil { + _ = svc.Close() + if err := svc.Wait(); err != nil { return err } return nil diff --git a/handshake_manager.go b/handshake_manager.go index 217f11b7b..1df37bdbc 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -35,7 +35,7 @@ var ( type HandshakeConfig struct { tryInterval time.Duration - retries int + retries int64 triggerBuffer int useRelays bool @@ -69,7 +69,7 @@ type HandshakeHostInfo struct { startTime time.Time // Time that we first started trying with this handshake ready bool // Is the handshake ready - counter int // How many attempts have we made so far + counter int64 // How many attempts have we made so far lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes @@ -665,6 +665,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) { return index, nil } -func hsTimeout(tries int, interval time.Duration) time.Duration { - return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval))) +func hsTimeout(tries int64, interval time.Duration) time.Duration { + return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } diff --git a/main.go b/main.go index 248f329c6..c6edc9133 100644 --- a/main.go +++ b/main.go @@ -215,7 +215,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeConfig := HandshakeConfig{ tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), useRelays: useRelays, diff --git a/service/service.go b/service/service.go index 1e79dd081..aacb59591 100644 --- a/service/service.go +++ b/service/service.go @@ -9,6 +9,7 @@ import ( "log" "math" "net" + "net/netip" "os" "strings" "sync" @@ -144,24 +145,48 @@ func New(config *config.C, logger *logrus.Logger) (*Service, error) { return &s, nil } -// DialContext dials the provided address. Currently only TCP is supported. -func (s *Service) DialContext(ctx context.Context, network, address string) (*gonet.TCPConn, error) { - if network != "tcp" && network != "tcp4" { - return nil, errors.New("only tcp is supported") - } - - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err +func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber { + if addr.Is6() { + return ipv6.ProtocolNumber } + return ipv4.ProtocolNumber +} - fullAddr := tcpip.FullAddress{ - NIC: nicID, - Addr: tcpip.AddrFromSlice(addr.IP), - Port: uint16(addr.Port), +// DialContext dials the provided address. +func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialUDP(s.ipstack, nil, &fullAddr, num) + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + num := getProtocolNumber(addr.AddrPort().Addr()) + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num) + default: + return nil, fmt.Errorf("unknown network type: %s", network) } +} - return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +// Dial dials the provided address +func (s *Service) Dial(network, address string) (net.Conn, error) { + return s.DialContext(context.Background(), network, address) } func (s *Service) DialUDP(address string) (*gonet.UDPConn, error) { diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ef072436b..2eee76ee2 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -218,9 +218,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 rsa.Addr = ip.Addr().As16() - port := ip.Port() - // Little Endian -> Network Endian - rsa.Port = (port >> 8) | ((port & 0xff) << 8) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( @@ -251,9 +249,7 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET rsa.Addr = ip.Addr().As4() - port := ip.Port() - // Little Endian -> Network Endian - rsa.Port = (port >> 8) | ((port & 0xff) << 8) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6(