Skip to content

Commit

Permalink
pgcli/vpn: add a proxy server to access the PG network (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkonfj committed Dec 28, 2024
1 parent cdbb6eb commit 4ce25bb
Show file tree
Hide file tree
Showing 6 changed files with 845 additions and 196 deletions.
102 changes: 102 additions & 0 deletions cmd/pgcli/vpn/rootless/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package rootless

import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"time"

N "github.com/sigcn/pg/net"
"github.com/sigcn/pg/socks5"
"github.com/sigcn/pg/vpn/nic/gvisor"
)

type ProxyConfig struct {
Listen string
}

type ProxyServer struct {
Config ProxyConfig
GvisorCard *gvisor.GvisorCard

udpListener *N.UDPListener
}

func (s *ProxyServer) Start(ctx context.Context, wg *sync.WaitGroup) error {
tcpListener, err := net.Listen("tcp", s.Config.Listen)
if err != nil {
return err
}
udpPacketConn, err := net.ListenPacket("udp", s.Config.Listen)
if err != nil {
tcpListener.Close()
return err
}
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
tcpListener.Close()
udpPacketConn.Close()
}()
s.udpListener = &N.UDPListener{PacketConn: udpPacketConn}
slog.Info("[Proxy] Server started", "listen", fmt.Sprintf("tcp+udp://%s", tcpListener.Addr().String()))
go s.run(tcpListener)
return nil
}

func (s *ProxyServer) run(tcp net.Listener) {
for {
c, err := tcp.Accept()
if err != nil {
return
}
addr, cmd, err := socks5.ServerHandshake(c, nil)
if err != nil {
slog.Error("[Proxy] SOCKS5 handshake", "err", err)
continue
}
if cmd == socks5.CmdConnect {
if err := s.proxyTCP(c, addr); err != nil {
slog.Error("[Proxy] SOCKS5 tcp", "err", err)
}
continue
}
if cmd == socks5.CmdUDPAssociate {
go func() {
if err := s.proxyUDP(addr); err != nil {
slog.Error("[Proxy] SOCKS5 udp", "err", err)
}
}()
continue
}
}
}

func (s *ProxyServer) proxyTCP(rw net.Conn, addr socks5.Addr) error {
c, err := s.GvisorCard.DialContext(context.TODO(), "tcp", addr.String())
if err != nil {
rw.Close()
return err
}
go relay(rw, c)
return nil
}

func (s *ProxyServer) proxyUDP(addr socks5.Addr) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, err := s.udpListener.AcceptContext(ctx)
if err != nil {
return err
}
c1, err := s.GvisorCard.DialContext(context.TODO(), "udp", addr.String())
if err != nil {
c.Close()
return err
}
go relay(c, c1)
return nil
}
11 changes: 11 additions & 0 deletions cmd/pgcli/vpn/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func usage(flagSet *flag.FlagSet) {
logLevel := flagSet.Lookup("loglevel")
mtu := flagSet.Lookup("mtu")
peers := flagSet.Lookup("peers")
proxyListen := flagSet.Lookup("proxy-listen")
server := flagSet.Lookup("s")
tun := flagSet.Lookup("tun")
udpPort := flagSet.Lookup("udp-port")
Expand All @@ -119,6 +120,7 @@ func usage(flagSet *flag.FlagSet) {
fmt.Printf(" --key string\n\t%s\n", key.Usage)
fmt.Printf(" --loglevel int\n\t%s (default %s)\n", logLevel.Usage, logLevel.DefValue)
fmt.Printf(" --mtu int\n\t%s (default %s)\n", mtu.Usage, mtu.DefValue)
fmt.Printf(" --proxy-listen string\n\t%s\n", proxyListen.Usage)
fmt.Printf(" -s, --server string\n\t%s\n", server.Usage)
fmt.Printf(" --tun string\n\t%s (default %s)\n", tun.Usage, tun.DefValue)
fmt.Printf(" --udp-crypto\n\t%s (default %s)\n", cryptoAlgo.Usage, cryptoAlgo.DefValue)
Expand Down Expand Up @@ -154,6 +156,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)
flagSet.IntVar(&cfg.NICConfig.MTU, "mtu", 1411, "nic mtu")
flagSet.StringVar(&cfg.NICConfig.Name, "tun", defaultTunName, "nic name")
flagSet.Var(&forwards, "forward", "start in rootless mode and create a port forward (e.g. tcp://127.0.0.1:80)")
flagSet.StringVar(&cfg.ProxyConfig.Listen, "proxy-listen", "", "start a proxy server to access the PG network (e.g. 127.0.0.1:4090)")

flagSet.StringVar(&cfg.PrivateKey, "key", "", "curve25519 private key in base58 format (default generate a new one)")
flagSet.StringVar(&cfg.SecretFile, "secret-file", "", "")
Expand Down Expand Up @@ -213,6 +216,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)

type Config struct {
NICConfig nic.Config
ProxyConfig rootless.ProxyConfig
DiscoPortScanOffset int
DiscoPortScanCount int
DiscoPortScanDuration time.Duration
Expand Down Expand Up @@ -265,6 +269,13 @@ func (v *P2PVPN) Run(ctx context.Context) (err error) {
return err
}
}
if v.Config.ProxyConfig.Listen != "" {
if err := (&rootless.ProxyServer{
GvisorCard: card.(*gvisor.GvisorCard),
Config: v.Config.ProxyConfig}).Start(ctx, &wg); err != nil {
return err
}
}
if err := (&server.Server{
Vnic: v.nic,
PeerStore: c.PeerStore(),
Expand Down
183 changes: 183 additions & 0 deletions net/udp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package net

import (
"context"
"errors"
"log/slog"
"net"
"sync"
"sync/atomic"
"time"
)

var _ net.Conn = (*UDPConn)(nil)

type UDPConn struct {
removeConn func()
remoteAddr net.Addr
c net.PacketConn

closeOnce sync.Once
inbound chan []byte
closeChan chan struct{}
lastActiveTime atomic.Value
}

func (c *UDPConn) init() {
c.inbound = make(chan []byte, 512)
c.closeChan = make(chan struct{})
c.lastActiveTime.Store(time.Now())
ticker := time.NewTicker(6 * time.Second)
go func() { // create a timer to trace timeout udp conn, and close it
defer ticker.Stop()
for range ticker.C {
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
c.Close()
break
}
}
}()
}

func (c *UDPConn) Read(p []byte) (int, error) {
select {
case b := <-c.inbound:
c.lastActiveTime.Store(time.Now())
return copy(p, b), nil
case <-c.closeChan:
return 0, net.ErrClosed
}
}

func (c *UDPConn) Write(p []byte) (int, error) {
c.lastActiveTime.Store(time.Now())
return c.c.WriteTo(p, c.remoteAddr)
}

func (c *UDPConn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}

func (c *UDPConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

func (c *UDPConn) Close() error {
c.closeOnce.Do(func() {
close(c.closeChan)
close(c.inbound)
c.removeConn()
slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
})
return nil
}

func (c *UDPConn) SetDeadline(t time.Time) error {
return errors.ErrUnsupported
}

func (c *UDPConn) SetReadDeadline(t time.Time) error {
return errors.ErrUnsupported
}

func (c *UDPConn) SetWriteDeadline(t time.Time) error {
return errors.ErrUnsupported
}

type UDPListener struct {
PacketConn net.PacketConn

buf []byte
initOnce sync.Once
closeOnce sync.Once
udpChan chan *UDPConn

connMap map[string]*UDPConn
connMapMu sync.RWMutex
}

func (l *UDPListener) init() {
l.initOnce.Do(func() {
l.buf = make([]byte, 65535)
l.udpChan = make(chan *UDPConn, 8)
l.connMap = make(map[string]*UDPConn)
go l.readUDP()
})
}

func (l *UDPListener) readUDP() {
read := func() error {
read:
n, peerAddr, err := l.PacketConn.ReadFrom(l.buf)
if err != nil {
return err
}
l.connMapMu.RLock()
conn, ok := l.connMap[peerAddr.String()]
l.connMapMu.RUnlock()
if ok {
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
l.connMapMu.Lock()
conn, ok = l.connMap[peerAddr.String()]
if ok {
l.connMapMu.Unlock()
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
defer l.connMapMu.Unlock()
conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() {
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
delete(l.connMap, peerAddr.String())
}}
conn.init()
l.connMap[peerAddr.String()] = conn
conn.inbound <- append([]byte(nil), l.buf[:n]...)
l.udpChan <- conn
return nil
}
for {
if err := read(); err != nil {
return
}
}
}

func (l *UDPListener) Accept() (net.Conn, error) {
return l.AcceptContext(context.Background())
}

func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) {
l.init()
select {
case c := <-l.udpChan:
return c, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (l *UDPListener) Close() error {
if l.PacketConn == nil {
return nil
}
l.closeOnce.Do(func() {
l.PacketConn.Close()
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
for _, c := range l.connMap {
go c.Close()
}
})
return nil
}

func (l *UDPListener) Addr() net.Addr {
l.init()
if l.PacketConn == nil {
return nil
}
return l.PacketConn.LocalAddr()
}
Loading

0 comments on commit 4ce25bb

Please sign in to comment.