diff --git a/local.go b/local.go index e3c4e0d..44a7cc8 100644 --- a/local.go +++ b/local.go @@ -20,6 +20,7 @@ package main import ( + "context" "crypto/tls" "net" "sync" @@ -57,7 +58,8 @@ func doLocal() { wssURL = "wss://" + *serverName + *path } - listener, err := net.Listen("tcp", *bindAddr) + listenConfig := net.ListenConfig{Control: getControlFunc(defaultLeftTCPConfig)} + listener, err := listenConfig.Listen(context.Background(), "tcp", *bindAddr) if err != nil { logrus.Fatalf("net.Listen: %v", err) } @@ -79,7 +81,7 @@ func doLocal() { func newRightConn() (net.Conn, error) { var rightConn net.Conn d := &net.Dialer{ - Control: getControlFunc(), + Control: getControlFunc(defaultRightTCPConfig), Timeout: handShakeTimeout, } diff --git a/main.go b/main.go index c52409f..0316a7e 100644 --- a/main.go +++ b/main.go @@ -81,6 +81,10 @@ var ( //mux config defaultSmuxConfig *smux.Config + + //tcp config + defaultLeftTCPConfig *tcpConfig + defaultRightTCPConfig *tcpConfig ) const ( @@ -167,6 +171,44 @@ func main() { MaxStreamBuffer: 512 * 1024, } + localTCPConfig := &tcpConfig{ + tfo: false, + noDelay: false, + mss: 0, + sndBuf: 64 * 1024, + rcvBuf: 64 * 1024, + } + + defaultTCPConfig := &tcpConfig{ + tfo: *enableTFO, + noDelay: *enableTCPNoDelay, + mss: *mss, + sndBuf: tcp_SO_SNDBUF, + rcvBuf: tcp_SO_RCVBUF, + } + + // bind + addr, err := net.ResolveTCPAddr("tcp", *bindAddr) + if err != nil { + logrus.Fatalf("bind addr invalid, %v", err) + } + if addr.IP.IsLoopback() { + defaultLeftTCPConfig = localTCPConfig + } else { + defaultLeftTCPConfig = defaultTCPConfig + } + + // remote + addr, err = net.ResolveTCPAddr("tcp", *remoteAddr) + if err != nil { + logrus.Fatalf("remote addr invalid, %v", err) + } + if addr.IP.IsLoopback() { + defaultRightTCPConfig = localTCPConfig + } else { + defaultRightTCPConfig = defaultTCPConfig + } + buffPool = &sync.Pool{New: func() interface{} { return make([]byte, ioCopyBuffSize) }} @@ -186,7 +228,7 @@ func main() { net.DefaultResolver.PreferGo = true net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { d := net.Dialer{} - d.Control = getControlFunc() + d.Control = getControlFunc(defaultRightTCPConfig) return d.DialContext(ctx, "tcp", *fallbackDNS) } } diff --git a/server.go b/server.go index 196a8f4..8dc9039 100644 --- a/server.go +++ b/server.go @@ -54,7 +54,7 @@ func doServer() { tlsConfig.Certificates = []tls.Certificate{cer} } - listenConfig := net.ListenConfig{Control: getControlFunc()} + listenConfig := net.ListenConfig{Control: getControlFunc(defaultLeftTCPConfig)} innerListener, err := listenConfig.Listen(context.Background(), "tcp", *bindAddr) if err != nil { logrus.Fatalf("tls inner Listener: %v", err) @@ -94,7 +94,8 @@ func doServer() { func handleLeftConn(leftConn net.Conn) { defer leftConn.Close() - rightConn, err := net.Dial("tcp", *remoteAddr) + d := net.Dialer{Control: getControlFunc(defaultRightTCPConfig)} + rightConn, err := d.Dial("tcp", *remoteAddr) if err != nil { logrus.Errorf("tcp failed to dial, %v", err) return diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..629e60d --- /dev/null +++ b/tcp.go @@ -0,0 +1,9 @@ +package main + +type tcpConfig struct { + tfo bool + noDelay bool + mss int + sndBuf int + rcvBuf int +} diff --git a/tcp_android.go b/tcp_android.go index 1af1db6..66b60a2 100644 --- a/tcp_android.go +++ b/tcp_android.go @@ -28,14 +28,14 @@ import ( "golang.org/x/sys/unix" ) -func getControlFunc() func(network, address string, c syscall.RawConn) error { +func getControlFunc(conf *tcpConfig) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { if *vpnMode { if err := c.Control(sendFdToBypass); err != nil { return err } } - return c.Control(setSockOpt) + return c.Control(conf.setSockOpt) } } diff --git a/tcp_android_unix.go b/tcp_android_unix.go index 73fd612..778be42 100644 --- a/tcp_android_unix.go +++ b/tcp_android_unix.go @@ -27,37 +27,37 @@ import ( ) //TCP_MAXSEG TCP_NODELAY SO_SND/RCVBUF etc.. -func setSockOpt(uintFd uintptr) { +func (c *tcpConfig) setSockOpt(uintFd uintptr) { fd := int(uintFd) - if *enableTFO { + if c.tfo { err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_FASTOPEN_CONNECT, 1) if err != nil { logrus.Errorf("setsockopt TCP_FASTOPEN_CONNECT, %v", err) } } - if *enableTCPNoDelay { + if c.noDelay { err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) if err != nil { logrus.Errorf("setsockopt TCP_NODELAY, %v", err) } } - if *mss > 0 { + if c.mss > 0 { err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_MAXSEG, *mss) if err != nil { logrus.Errorf("setsockopt TCP_MAXSEG, %v", err) } } - if tcp_SO_SNDBUF > 0 { - err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, tcp_SO_SNDBUF) + if c.sndBuf > 0 { + err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, c.sndBuf) if err != nil { logrus.Errorf("setsockopt SO_SNDBUF, %v", err) } } - if tcp_SO_RCVBUF > 0 { - err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, tcp_SO_RCVBUF) + if c.rcvBuf > 0 { + err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, c.rcvBuf) if err != nil { logrus.Errorf("setsockopt SO_RCVBUF, %v", err) } diff --git a/tcp_other.go b/tcp_other.go index fe1cac5..f67c1d8 100644 --- a/tcp_other.go +++ b/tcp_other.go @@ -25,6 +25,6 @@ import ( "syscall" ) -func getControlFunc() func(network, address string, c syscall.RawConn) error { +func getControlFunc(conf *tcpConfig) func(network, address string, c syscall.RawConn) error { return nil } diff --git a/tcp_unix.go b/tcp_unix.go index 7fc11e6..5f5155f 100644 --- a/tcp_unix.go +++ b/tcp_unix.go @@ -23,8 +23,8 @@ package main import "syscall" -func getControlFunc() func(network, address string, c syscall.RawConn) error { +func getControlFunc(conf *tcpConfig) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { - return c.Control(setSockOpt) + return c.Control(conf.setSockOpt) } } diff --git a/tcp_windows.go b/tcp_windows.go index cf44c61..2920c6d 100644 --- a/tcp_windows.go +++ b/tcp_windows.go @@ -28,38 +28,32 @@ import ( "golang.org/x/sys/windows" ) -func getControlFunc() func(network, address string, c syscall.RawConn) error { +func getControlFunc(conf *tcpConfig) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { - return c.Control(setSockOpt) + return c.Control(conf.setSockOpt) } } //TCP_MAXSEG TCP_NODELAY SO_SND/RCVBUF etc.. -func setSockOpt(uintptrFd uintptr) { +func (c *tcpConfig) setSockOpt(uintptrFd uintptr) { fd := windows.Handle(uintptrFd) var err error - if *enableTCPNoDelay { + if c.noDelay { err = windows.SetsockoptInt(fd, windows.IPPROTO_TCP, windows.TCP_NODELAY, 1) if err != nil { logrus.Errorf("setsockopt TCP_NODELAY, %v", err) } } - // can't set TCP_MAXSEG on windows - - // if *mss > 0 { - // windows.SetsockoptInt(fd, windows.IPPROTO_TCP, windows.TCP_MAXSEG, *mss) - // } - - if tcp_SO_SNDBUF > 0 { - err := windows.SetsockoptInt(fd, windows.SOL_SOCKET, windows.SO_SNDBUF, tcp_SO_SNDBUF) + if c.sndBuf > 0 { + err := windows.SetsockoptInt(fd, windows.SOL_SOCKET, windows.SO_SNDBUF, c.sndBuf) if err != nil { logrus.Errorf("setsockopt SO_SNDBUF, %v", err) } } - if tcp_SO_RCVBUF > 0 { - err := windows.SetsockoptInt(fd, windows.SOL_SOCKET, windows.SO_RCVBUF, tcp_SO_RCVBUF) + if c.rcvBuf > 0 { + err := windows.SetsockoptInt(fd, windows.SOL_SOCKET, windows.SO_RCVBUF, c.rcvBuf) if err != nil { logrus.Errorf("setsockopt SO_RCVBUF, %v", err) }