diff --git a/cmd/pgcli/vpn/rootless/proxy.go b/cmd/pgcli/vpn/rootless/proxy.go new file mode 100644 index 0000000..0f9d247 --- /dev/null +++ b/cmd/pgcli/vpn/rootless/proxy.go @@ -0,0 +1,102 @@ +package rootless + +import ( + "context" + "fmt" + "log/slog" + "net" + "sync" + "time" + + N "github.com/sigcn/pg/net" + "github.com/sigcn/pg/socks5" + "github.com/sigcn/pg/vpn/nic/gvisor" +) + +type ProxyConfig struct { + Listen string +} + +type ProxyServer struct { + Config ProxyConfig + GvisorCard *gvisor.GvisorCard + + udpListener *N.UDPListener +} + +func (s *ProxyServer) Start(ctx context.Context, wg *sync.WaitGroup) error { + tcpListener, err := net.Listen("tcp", s.Config.Listen) + if err != nil { + return err + } + udpPacketConn, err := net.ListenPacket("udp", s.Config.Listen) + if err != nil { + tcpListener.Close() + return err + } + wg.Add(1) + go func() { + defer wg.Done() + <-ctx.Done() + tcpListener.Close() + udpPacketConn.Close() + }() + s.udpListener = &N.UDPListener{PacketConn: udpPacketConn} + slog.Info("[Proxy] Server started", "listen", fmt.Sprintf("tcp+udp://%s", tcpListener.Addr().String())) + go s.run(tcpListener) + return nil +} + +func (s *ProxyServer) run(tcp net.Listener) { + for { + c, err := tcp.Accept() + if err != nil { + return + } + addr, cmd, err := socks5.ServerHandshake(c, nil) + if err != nil { + slog.Error("[Proxy] SOCKS5 handshake", "err", err) + continue + } + if cmd == socks5.CmdConnect { + if err := s.proxyTCP(c, addr); err != nil { + slog.Error("[Proxy] SOCKS5 tcp", "err", err) + } + continue + } + if cmd == socks5.CmdUDPAssociate { + go func() { + if err := s.proxyUDP(addr); err != nil { + slog.Error("[Proxy] SOCKS5 udp", "err", err) + } + }() + continue + } + } +} + +func (s *ProxyServer) proxyTCP(rw net.Conn, addr socks5.Addr) error { + c, err := s.GvisorCard.DialContext(context.TODO(), "tcp", addr.String()) + if err != nil { + rw.Close() + return err + } + go relay(rw, c) + return nil +} + +func (s *ProxyServer) proxyUDP(addr socks5.Addr) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c, err := s.udpListener.AcceptContext(ctx) + if err != nil { + return err + } + c1, err := s.GvisorCard.DialContext(context.TODO(), "udp", addr.String()) + if err != nil { + c.Close() + return err + } + go relay(c, c1) + return nil +} diff --git a/cmd/pgcli/vpn/vpn.go b/cmd/pgcli/vpn/vpn.go index 4fb6923..4621a07 100644 --- a/cmd/pgcli/vpn/vpn.go +++ b/cmd/pgcli/vpn/vpn.go @@ -94,6 +94,7 @@ func usage(flagSet *flag.FlagSet) { logLevel := flagSet.Lookup("loglevel") mtu := flagSet.Lookup("mtu") peers := flagSet.Lookup("peers") + proxyListen := flagSet.Lookup("proxy-listen") server := flagSet.Lookup("s") tun := flagSet.Lookup("tun") udpPort := flagSet.Lookup("udp-port") @@ -119,6 +120,7 @@ func usage(flagSet *flag.FlagSet) { fmt.Printf(" --key string\n\t%s\n", key.Usage) fmt.Printf(" --loglevel int\n\t%s (default %s)\n", logLevel.Usage, logLevel.DefValue) fmt.Printf(" --mtu int\n\t%s (default %s)\n", mtu.Usage, mtu.DefValue) + fmt.Printf(" --proxy-listen string\n\t%s\n", proxyListen.Usage) fmt.Printf(" -s, --server string\n\t%s\n", server.Usage) fmt.Printf(" --tun string\n\t%s (default %s)\n", tun.Usage, tun.DefValue) fmt.Printf(" --udp-crypto\n\t%s (default %s)\n", cryptoAlgo.Usage, cryptoAlgo.DefValue) @@ -154,6 +156,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error) flagSet.IntVar(&cfg.NICConfig.MTU, "mtu", 1411, "nic mtu") flagSet.StringVar(&cfg.NICConfig.Name, "tun", defaultTunName, "nic name") flagSet.Var(&forwards, "forward", "start in rootless mode and create a port forward (e.g. tcp://127.0.0.1:80)") + flagSet.StringVar(&cfg.ProxyConfig.Listen, "proxy-listen", "", "start a proxy server to access the PG network (e.g. 127.0.0.1:4090)") flagSet.StringVar(&cfg.PrivateKey, "key", "", "curve25519 private key in base58 format (default generate a new one)") flagSet.StringVar(&cfg.SecretFile, "secret-file", "", "") @@ -213,6 +216,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error) type Config struct { NICConfig nic.Config + ProxyConfig rootless.ProxyConfig DiscoPortScanOffset int DiscoPortScanCount int DiscoPortScanDuration time.Duration @@ -265,6 +269,13 @@ func (v *P2PVPN) Run(ctx context.Context) (err error) { return err } } + if v.Config.ProxyConfig.Listen != "" { + if err := (&rootless.ProxyServer{ + GvisorCard: card.(*gvisor.GvisorCard), + Config: v.Config.ProxyConfig}).Start(ctx, &wg); err != nil { + return err + } + } if err := (&server.Server{ Vnic: v.nic, PeerStore: c.PeerStore(), diff --git a/net/udp.go b/net/udp.go new file mode 100644 index 0000000..7e673c4 --- /dev/null +++ b/net/udp.go @@ -0,0 +1,183 @@ +package net + +import ( + "context" + "errors" + "log/slog" + "net" + "sync" + "sync/atomic" + "time" +) + +var _ net.Conn = (*UDPConn)(nil) + +type UDPConn struct { + removeConn func() + remoteAddr net.Addr + c net.PacketConn + + closeOnce sync.Once + inbound chan []byte + closeChan chan struct{} + lastActiveTime atomic.Value +} + +func (c *UDPConn) init() { + c.inbound = make(chan []byte, 512) + c.closeChan = make(chan struct{}) + c.lastActiveTime.Store(time.Now()) + ticker := time.NewTicker(6 * time.Second) + go func() { // create a timer to trace timeout udp conn, and close it + defer ticker.Stop() + for range ticker.C { + if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second { + c.Close() + break + } + } + }() +} + +func (c *UDPConn) Read(p []byte) (int, error) { + select { + case b := <-c.inbound: + c.lastActiveTime.Store(time.Now()) + return copy(p, b), nil + case <-c.closeChan: + return 0, net.ErrClosed + } +} + +func (c *UDPConn) Write(p []byte) (int, error) { + c.lastActiveTime.Store(time.Now()) + return c.c.WriteTo(p, c.remoteAddr) +} + +func (c *UDPConn) LocalAddr() net.Addr { + return c.c.LocalAddr() +} + +func (c *UDPConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *UDPConn) Close() error { + c.closeOnce.Do(func() { + close(c.closeChan) + close(c.inbound) + c.removeConn() + slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr) + }) + return nil +} + +func (c *UDPConn) SetDeadline(t time.Time) error { + return errors.ErrUnsupported +} + +func (c *UDPConn) SetReadDeadline(t time.Time) error { + return errors.ErrUnsupported +} + +func (c *UDPConn) SetWriteDeadline(t time.Time) error { + return errors.ErrUnsupported +} + +type UDPListener struct { + PacketConn net.PacketConn + + buf []byte + initOnce sync.Once + closeOnce sync.Once + udpChan chan *UDPConn + + connMap map[string]*UDPConn + connMapMu sync.RWMutex +} + +func (l *UDPListener) init() { + l.initOnce.Do(func() { + l.buf = make([]byte, 65535) + l.udpChan = make(chan *UDPConn, 8) + l.connMap = make(map[string]*UDPConn) + go l.readUDP() + }) +} + +func (l *UDPListener) readUDP() { + read := func() error { + read: + n, peerAddr, err := l.PacketConn.ReadFrom(l.buf) + if err != nil { + return err + } + l.connMapMu.RLock() + conn, ok := l.connMap[peerAddr.String()] + l.connMapMu.RUnlock() + if ok { + conn.inbound <- append([]byte(nil), l.buf[:n]...) + goto read + } + l.connMapMu.Lock() + conn, ok = l.connMap[peerAddr.String()] + if ok { + l.connMapMu.Unlock() + conn.inbound <- append([]byte(nil), l.buf[:n]...) + goto read + } + defer l.connMapMu.Unlock() + conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() { + l.connMapMu.Lock() + defer l.connMapMu.Unlock() + delete(l.connMap, peerAddr.String()) + }} + conn.init() + l.connMap[peerAddr.String()] = conn + conn.inbound <- append([]byte(nil), l.buf[:n]...) + l.udpChan <- conn + return nil + } + for { + if err := read(); err != nil { + return + } + } +} + +func (l *UDPListener) Accept() (net.Conn, error) { + return l.AcceptContext(context.Background()) +} + +func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) { + l.init() + select { + case c := <-l.udpChan: + return c, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (l *UDPListener) Close() error { + if l.PacketConn == nil { + return nil + } + l.closeOnce.Do(func() { + l.PacketConn.Close() + l.connMapMu.Lock() + defer l.connMapMu.Unlock() + for _, c := range l.connMap { + go c.Close() + } + }) + return nil +} + +func (l *UDPListener) Addr() net.Addr { + l.init() + if l.PacketConn == nil { + return nil + } + return l.PacketConn.LocalAddr() +} diff --git a/socks5/socks5.go b/socks5/socks5.go new file mode 100644 index 0000000..6f3ef7c --- /dev/null +++ b/socks5/socks5.go @@ -0,0 +1,457 @@ +// Copyright (c) 2024 sigcn/pg +// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3. +// Copyright (c) 2021-2024 clash +// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3. +package socks5 + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "net" + "net/netip" + "strconv" +) + +type Authenticator interface { + Verify(user string, pass string) bool +} + +// Error represents a SOCKS error +type Error byte + +func (err Error) Error() string { + return "SOCKS error: " + strconv.Itoa(int(err)) +} + +// Command is request commands as defined in RFC 1928 section 4. +type Command = uint8 + +const Version = 5 + +// SOCKS request commands as defined in RFC 1928 section 4. +const ( + CmdConnect Command = 1 + CmdBind Command = 2 + CmdUDPAssociate Command = 3 +) + +// SOCKS address types as defined in RFC 1928 section 5. +const ( + AtypIPv4 = 1 + AtypDomainName = 3 + AtypIPv6 = 4 +) + +// MaxAddrLen is the maximum size of SOCKS address in bytes. +const MaxAddrLen = 1 + 1 + 255 + 2 + +// MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth +const MaxAuthLen = 255 + +// Addr represents a SOCKS address as defined in RFC 1928 section 5. +type Addr []byte + +func (a Addr) String() string { + var host, port string + + switch a[0] { + case AtypDomainName: + hostLen := uint16(a[1]) + host = string(a[2 : 2+hostLen]) + port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) + case AtypIPv4: + host = net.IP(a[1 : 1+net.IPv4len]).String() + port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) + case AtypIPv6: + host = net.IP(a[1 : 1+net.IPv6len]).String() + port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) + } + + return net.JoinHostPort(host, port) +} + +// UDPAddr converts a socks5.Addr to *net.UDPAddr +func (a Addr) UDPAddr() *net.UDPAddr { + if len(a) == 0 { + return nil + } + switch a[0] { + case AtypIPv4: + var ip [net.IPv4len]byte + copy(ip[0:], a[1:1+net.IPv4len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} + case AtypIPv6: + var ip [net.IPv6len]byte + copy(ip[0:], a[1:1+net.IPv6len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} + } + // Other Atyp + return nil +} + +// SOCKS errors as defined in RFC 1928 section 6. +const ( + ErrGeneralFailure = Error(1) + ErrConnectionNotAllowed = Error(2) + ErrNetworkUnreachable = Error(3) + ErrHostUnreachable = Error(4) + ErrConnectionRefused = Error(5) + ErrTTLExpired = Error(6) + ErrCommandNotSupported = Error(7) + ErrAddressNotSupported = Error(8) +) + +// Auth errors used to return a specific "Auth failed" error +var ErrAuth = errors.New("auth failed") + +type User struct { + Username string + Password string +} + +// ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. +func ServerHandshake(rw net.Conn, authenticator Authenticator) (addr Addr, command Command, err error) { + // Read RFC 1928 for request and reply structure and sizes. + buf := make([]byte, MaxAddrLen) + // read VER, NMETHODS, METHODS + if _, err = io.ReadFull(rw, buf[:2]); err != nil { + return + } + nmethods := buf[1] + if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { + return + } + + // write VER METHOD + if authenticator != nil { + if _, err = rw.Write([]byte{5, 2}); err != nil { + return + } + + // Get header + header := make([]byte, 2) + if _, err = io.ReadFull(rw, header); err != nil { + return + } + + authBuf := make([]byte, MaxAuthLen) + // Get username + userLen := int(header[1]) + if userLen <= 0 { + rw.Write([]byte{1, 1}) + err = ErrAuth + return + } + if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { + return + } + user := string(authBuf[:userLen]) + + // Get password + if _, err = rw.Read(header[:1]); err != nil { + return + } + passLen := int(header[0]) + if passLen <= 0 { + rw.Write([]byte{1, 1}) + err = ErrAuth + return + } + if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil { + return + } + pass := string(authBuf[:passLen]) + + // Verify + if ok := authenticator.Verify(user, pass); !ok { + rw.Write([]byte{1, 1}) + err = ErrAuth + return + } + + // Response auth state + if _, err = rw.Write([]byte{1, 0}); err != nil { + return + } + } else { + if _, err = rw.Write([]byte{5, 0}); err != nil { + return + } + } + + // read VER CMD RSV ATYP DST.ADDR DST.PORT + if _, err = io.ReadFull(rw, buf[:3]); err != nil { + return + } + + command = buf[1] + addr, err = ReadAddr(rw, buf) + if err != nil { + return + } + + switch command { + case CmdConnect, CmdUDPAssociate: + // Acquire server listened address info + localAddr := ParseAddr(rw.LocalAddr().String()) + if localAddr == nil { + err = ErrAddressNotSupported + } else { + // write VER REP RSV ATYP BND.ADDR BND.PORT + _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) + } + case CmdBind: + fallthrough + default: + err = ErrCommandNotSupported + } + + return +} + +// ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. +func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { + buf := make([]byte, MaxAddrLen) + var err error + + // VER, NMETHODS, METHODS + if user != nil { + _, err = rw.Write([]byte{5, 1, 2}) + } else { + _, err = rw.Write([]byte{5, 1, 0}) + } + if err != nil { + return nil, err + } + + // VER, METHOD + if _, err := io.ReadFull(rw, buf[:2]); err != nil { + return nil, err + } + + if buf[0] != 5 { + return nil, errors.New("SOCKS version error") + } + + if buf[1] == 2 { + if user == nil { + return nil, ErrAuth + } + + // password protocol version + authMsg := &bytes.Buffer{} + authMsg.WriteByte(1) + authMsg.WriteByte(uint8(len(user.Username))) + authMsg.WriteString(user.Username) + authMsg.WriteByte(uint8(len(user.Password))) + authMsg.WriteString(user.Password) + + if _, err := rw.Write(authMsg.Bytes()); err != nil { + return nil, err + } + + if _, err := io.ReadFull(rw, buf[:2]); err != nil { + return nil, err + } + + if buf[1] != 0 { + return nil, errors.New("rejected username/password") + } + } else if buf[1] != 0 { + return nil, errors.New("SOCKS need auth") + } + + // VER, CMD, RSV, ADDR + if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { + return nil, err + } + + // VER, REP, RSV + if _, err := io.ReadFull(rw, buf[:3]); err != nil { + return nil, err + } + + return ReadAddr(rw, buf) +} + +func ReadAddr(r io.Reader, b []byte) (Addr, error) { + if len(b) < MaxAddrLen { + return nil, io.ErrShortBuffer + } + _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type + if err != nil { + return nil, err + } + + switch b[0] { + case AtypDomainName: + _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length + if err != nil { + return nil, err + } + domainLength := uint16(b[1]) + _, err = io.ReadFull(r, b[2:2+domainLength+2]) + return b[:1+1+domainLength+2], err + case AtypIPv4: + _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) + return b[:1+net.IPv4len+2], err + case AtypIPv6: + _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) + return b[:1+net.IPv6len+2], err + } + + return nil, ErrAddressNotSupported +} + +// SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. +func SplitAddr(b []byte) Addr { + addrLen := 1 + if len(b) < addrLen { + return nil + } + + switch b[0] { + case AtypDomainName: + if len(b) < 2 { + return nil + } + addrLen = 1 + 1 + int(b[1]) + 2 + case AtypIPv4: + addrLen = 1 + net.IPv4len + 2 + case AtypIPv6: + addrLen = 1 + net.IPv6len + 2 + default: + return nil + + } + + if len(b) < addrLen { + return nil + } + + return b[:addrLen] +} + +// ParseAddr parses the address in string s. Returns nil if failed. +func ParseAddr(s string) Addr { + var addr Addr + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + addr = make([]byte, 1+net.IPv4len+2) + addr[0] = AtypIPv4 + copy(addr[1:], ip4) + } else { + addr = make([]byte, 1+net.IPv6len+2) + addr[0] = AtypIPv6 + copy(addr[1:], ip) + } + } else { + if len(host) > 255 { + return nil + } + addr = make([]byte, 1+1+len(host)+2) + addr[0] = AtypDomainName + addr[1] = byte(len(host)) + copy(addr[2:], host) + } + + portnum, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil + } + + addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum) + + return addr +} + +// ParseAddrToSocksAddr parse a socks addr from net.addr +// This is a fast path of ParseAddr(addr.String()) +func ParseAddrToSocksAddr(addr net.Addr) Addr { + var hostip net.IP + var port int + if udpaddr, ok := addr.(*net.UDPAddr); ok { + hostip = udpaddr.IP + port = udpaddr.Port + } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { + hostip = tcpaddr.IP + port = tcpaddr.Port + } + + // fallback parse + if hostip == nil { + return ParseAddr(addr.String()) + } + + var parsed Addr + if ip4 := hostip.To4(); ip4.DefaultMask() != nil { + parsed = make([]byte, 1+net.IPv4len+2) + parsed[0] = AtypIPv4 + copy(parsed[1:], ip4) + binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port)) + + } else { + parsed = make([]byte, 1+net.IPv6len+2) + parsed[0] = AtypIPv6 + copy(parsed[1:], hostip) + binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port)) + } + return parsed +} + +func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr { + addr := addrPort.Addr() + if addr.Is4() { + ip4 := addr.As4() + return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())} + } + + buf := make([]byte, 1+net.IPv6len+2) + buf[0] = AtypIPv6 + copy(buf[1:], addr.AsSlice()) + buf[1+net.IPv6len] = byte(addrPort.Port() >> 8) + buf[1+net.IPv6len+1] = byte(addrPort.Port()) + return buf +} + +// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` +func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { + if len(packet) < 5 { + err = errors.New("insufficient length of packet") + return + } + + // packet[0] and packet[1] are reserved + if !bytes.Equal(packet[:2], []byte{0, 0}) { + err = errors.New("reserved fields should be zero") + return + } + + if packet[2] != 0 /* fragments */ { + err = errors.New("discarding fragmented payload") + return + } + + addr = SplitAddr(packet[3:]) + if addr == nil { + err = errors.New("failed to read UDP header") + } + + payload = packet[3+len(addr):] + return +} + +func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { + if addr == nil { + err = errors.New("address is invalid") + return + } + packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{}) + return +} diff --git a/vpn/nic/gvisor/gvisor.go b/vpn/nic/gvisor/gvisor.go index f2707a6..09a5028 100644 --- a/vpn/nic/gvisor/gvisor.go +++ b/vpn/nic/gvisor/gvisor.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + N "github.com/sigcn/pg/net" "github.com/sigcn/pg/vpn/nic" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -17,6 +18,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) var ( @@ -122,6 +125,75 @@ func (g *GvisorCard) Close() error { return nil } +func (g *GvisorCard) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + g.init() + if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") { + return nil, errors.New("only tcp/udp is supported") + } + + if strings.HasPrefix(network, "tcp") { + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + var add tcpip.Address + var protocol tcpip.NetworkProtocolNumber + if tcpAddr.IP.To4() != nil { + add = tcpip.AddrFrom4(tcpAddr.AddrPort().Addr().As4()) + protocol = ipv4.ProtocolNumber + } else { + add = tcpip.AddrFrom16(tcpAddr.AddrPort().Addr().As16()) + protocol = ipv6.ProtocolNumber + } + addr := tcpip.FullAddress{ + NIC: g.nicID, + Addr: add, + Port: tcpAddr.AddrPort().Port()} + return gonet.DialContextTCP(ctx, g.Stack, addr, protocol) + } + + if strings.HasPrefix(network, "udp") { + udpAddr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + var add tcpip.Address + var protocol tcpip.NetworkProtocolNumber + if udpAddr.IP.To4() != nil { + add = tcpip.AddrFrom4(udpAddr.AddrPort().Addr().As4()) + protocol = ipv4.ProtocolNumber + } else { + add = tcpip.AddrFrom16(udpAddr.AddrPort().Addr().As16()) + protocol = ipv6.ProtocolNumber + } + addr := &tcpip.FullAddress{ + NIC: g.nicID, + Addr: add, + Port: udpAddr.AddrPort().Port()} + return gonet.DialUDP(g.Stack, nil, addr, protocol) + } + return nil, nil +} + +func (g *GvisorCard) listenUDP(addr tcpip.FullAddress) (net.PacketConn, error) { + var wq waiter.Queue + var ep tcpip.Endpoint + var err tcpip.Error + if net.IP(addr.Addr.AsSlice()).To4() != nil { + ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + } else { + ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + } + if err != nil { + return nil, errors.New(err.String()) + } + err = ep.Bind(addr) + if err != nil { + return nil, errors.New(err.String()) + } + return gonet.NewUDPConn(&wq, ep), nil +} + func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l net.Listener, err error) { g.init() if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") { @@ -140,12 +212,20 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l if network == "udp4" { addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port} - return &udpListener{s: g.Stack, addr: addr}, nil + pc, err := g.listenUDP(addr) + if err != nil { + return nil, err + } + return &N.UDPListener{PacketConn: pc}, nil } if network == "udp6" { addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port} - return &udpListener{s: g.Stack, addr: addr}, nil + pc, err := g.listenUDP(addr) + if err != nil { + return nil, err + } + return &N.UDPListener{PacketConn: pc}, nil } var listeners []net.Listener @@ -160,11 +240,19 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l if network == "udp" { if g.addr4.Len() > 0 { addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port} - listeners = append(listeners, &udpListener{s: g.Stack, addr: addr}) + pc, err := g.listenUDP(addr) + if err != nil { + return nil, err + } + listeners = append(listeners, &N.UDPListener{PacketConn: pc}) } if g.addr6.Len() > 0 { addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port} - listeners = append(listeners, &udpListener{s: g.Stack, addr: addr}) + pc, err := g.listenUDP(addr) + if err != nil { + return nil, err + } + listeners = append(listeners, &N.UDPListener{PacketConn: pc}) } return &combinedListeners{listeners: listeners}, nil } diff --git a/vpn/nic/gvisor/udp.go b/vpn/nic/gvisor/udp.go deleted file mode 100644 index dae7560..0000000 --- a/vpn/nic/gvisor/udp.go +++ /dev/null @@ -1,192 +0,0 @@ -package gvisor - -import ( - "context" - "errors" - "log/slog" - "net" - "sync" - "sync/atomic" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var _ net.Listener = (*udpListener)(nil) -var _ net.Conn = (*udpConn)(nil) - -type udpConn struct { - removeConn func() - remoteAddr net.Addr - c *gonet.UDPConn - - closeOnce sync.Once - inbound chan []byte - closeChan chan struct{} - lastActiveTime atomic.Value -} - -func (c *udpConn) init() { - c.inbound = make(chan []byte, 512) - c.closeChan = make(chan struct{}) - c.lastActiveTime.Store(time.Now()) - ticker := time.NewTicker(6 * time.Second) - go func() { // create a timer to trace timeout udp conn, and close it - defer ticker.Stop() - for range ticker.C { - if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second { - c.Close() - break - } - } - }() -} - -func (c *udpConn) Read(p []byte) (int, error) { - select { - case b := <-c.inbound: - c.lastActiveTime.Store(time.Now()) - return copy(p, b), nil - case <-c.closeChan: - return 0, net.ErrClosed - } -} - -func (c *udpConn) Write(p []byte) (int, error) { - c.lastActiveTime.Store(time.Now()) - return c.c.WriteTo(p, c.remoteAddr) -} - -func (c *udpConn) LocalAddr() net.Addr { - return c.c.LocalAddr() -} - -func (c *udpConn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *udpConn) Close() error { - c.closeOnce.Do(func() { - close(c.closeChan) - close(c.inbound) - c.removeConn() - slog.Log(context.Background(), -2, "[gVisor] UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr) - }) - return nil -} - -func (c *udpConn) SetDeadline(t time.Time) error { - return errors.ErrUnsupported -} - -func (c *udpConn) SetReadDeadline(t time.Time) error { - return errors.ErrUnsupported -} - -func (c *udpConn) SetWriteDeadline(t time.Time) error { - return errors.ErrUnsupported -} - -type udpListener struct { - addr tcpip.FullAddress - s *stack.Stack - - buf []byte - c *gonet.UDPConn - initErr error - initOnce sync.Once - closeOnce sync.Once - - connMap map[net.Addr]*udpConn - connMapMu sync.RWMutex -} - -func (l *udpListener) init() { - l.initOnce.Do(func() { - var wq waiter.Queue - var ep tcpip.Endpoint - var err tcpip.Error - if net.IP(l.addr.Addr.AsSlice()).To4() != nil { - ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - } else { - ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - } - if err != nil { - l.initErr = errors.New(err.String()) - return - } - err = ep.Bind(l.addr) - if err != nil { - l.initErr = errors.New(err.String()) - return - } - l.buf = make([]byte, 65535) - l.connMap = make(map[net.Addr]*udpConn) - l.c = gonet.NewUDPConn(&wq, ep) - }) -} - -func (l *udpListener) Accept() (net.Conn, error) { - l.init() - if l.initErr != nil { - return nil, l.initErr - } -read: - n, peerAddr, err := l.c.ReadFrom(l.buf) - if err != nil { - return nil, err - } - l.connMapMu.RLock() - conn, ok := l.connMap[peerAddr] - l.connMapMu.RUnlock() - if ok { - conn.inbound <- append([]byte(nil), l.buf[:n]...) - goto read - } - l.connMapMu.Lock() - conn, ok = l.connMap[peerAddr] - if ok { - l.connMapMu.Unlock() - conn.inbound <- append([]byte(nil), l.buf[:n]...) - goto read - } - defer l.connMapMu.Unlock() - conn = &udpConn{remoteAddr: peerAddr, c: l.c, removeConn: func() { - l.connMapMu.Lock() - defer l.connMapMu.Unlock() - delete(l.connMap, peerAddr) - }} - conn.init() - l.connMap[peerAddr] = conn - conn.inbound <- append([]byte(nil), l.buf[:n]...) - return conn, nil -} - -func (l *udpListener) Close() error { - if l.c == nil { - return nil - } - l.closeOnce.Do(func() { - l.c.Close() - l.connMapMu.Lock() - defer l.connMapMu.Unlock() - for _, c := range l.connMap { - go c.Close() - } - }) - return nil -} - -func (l *udpListener) Addr() net.Addr { - l.init() - if l.c == nil { - return nil - } - return l.c.LocalAddr() -}