diff --git a/chain.go b/chain.go index bedeb6f2..9a973ff6 100644 --- a/chain.go +++ b/chain.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "syscall" "time" "github.com/go-log/log" @@ -18,6 +19,7 @@ var ( type Chain struct { isRoute bool Retries int + Mark int nodeGroups []*NodeGroup route []Node // nodes in the selected route } @@ -131,6 +133,10 @@ func (c *Chain) DialContext(ctx context.Context, network, address string, opts . return } +func setSocketMark(fd int, value int) (e error) { + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, value) +} + func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { if options == nil { options = &ChainOptions{} @@ -150,6 +156,20 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op timeout = DialTimeout } + var controlFunction func(_ string, _ string, c syscall.RawConn) error = nil + if c.Mark > 0 { + controlFunction = func(_, _ string, cc syscall.RawConn) error { + return cc.Control(func(fd uintptr) { + ex := setSocketMark(int(fd), c.Mark) + if ex != nil { + log.Logf("net dialer set mark %d error: %s", c.Mark, ex) + } else { + // log.Logf("net dialer set mark %d success", options.Mark) + } + }) + } + } + if route.IsEmpty() { switch network { case "udp", "udp4", "udp6": @@ -160,6 +180,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op } d := &net.Dialer{ Timeout: timeout, + Control: controlFunction, // LocalAddr: laddr, // TODO: optional local address } return d.DialContext(ctx, network, ipAddr) @@ -328,6 +349,7 @@ type ChainOptions struct { Timeout time.Duration Hosts *Hosts Resolver Resolver + Mark int } // ChainOption allows a common way to set chain options. diff --git a/cmd/gost/main.go b/cmd/gost/main.go index f08f4132..7aea15b9 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -31,6 +31,7 @@ func init() { flag.Var(&baseCfg.route.ChainNodes, "F", "forward address, can make a forward chain") flag.Var(&baseCfg.route.ServeNodes, "L", "listen address, can listen on multiple ports (required)") + flag.IntVar(&baseCfg.route.Mark, "M", 0, "Specify out connection mark") flag.StringVar(&configureFile, "C", "", "configure file") flag.BoolVar(&baseCfg.Debug, "D", false, "enable debug log") flag.BoolVar(&printVersion, "V", false, "print version") diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 0db5b2e7..ee1caa6e 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -30,11 +30,13 @@ type route struct { ServeNodes stringList ChainNodes stringList Retries int + Mark int } func (r *route) parseChain() (*gost.Chain, error) { chain := gost.NewChain() chain.Retries = r.Retries + chain.Mark = r.Mark gid := 1 // group ID for _, ns := range r.ChainNodes {