diff --git a/policy.go b/policy.go index ebef8b9..21b076c 100644 --- a/policy.go +++ b/policy.go @@ -51,6 +51,9 @@ const ( // Note: an example usage can be found in the SkipProxyHeaderForCIDR // function. SKIP + // REFUSE is the same as REJECT if a proxy header is set and the same as + // REQUIRE if a proxy header is not set. + REFUSE ) // SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a @@ -117,7 +120,7 @@ func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { return nil, err } - return whitelistPolicy(allowFrom, REJECT), nil + return whitelistPolicy(allowFrom, REFUSE), nil } // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic diff --git a/policy_test.go b/policy_test.go index a888bdd..d4c808e 100644 --- a/policy_test.go +++ b/policy_test.go @@ -42,8 +42,8 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t t.Fatalf("err: %v", err) } - if policy != REJECT { - t.Fatalf("Expected policy REJECT, got %v", policy) + if policy != REFUSE { + t.Fatalf("Expected policy REFUSE, got %v", policy) } } diff --git a/protocol.go b/protocol.go index 658900a..178bfdc 100644 --- a/protocol.go +++ b/protocol.go @@ -288,7 +288,7 @@ func (p *Conn) readHeader() error { // let's act as if there was no error when PROXY protocol is not present. if err == ErrNoProxyProtocol { // but not if it is required that the connection has one - if p.ProxyHeaderPolicy == REQUIRE { + if p.ProxyHeaderPolicy == REQUIRE || p.ProxyHeaderPolicy == REFUSE { return err } @@ -298,7 +298,7 @@ func (p *Conn) readHeader() error { // proxy protocol header was found if err == nil && header != nil { switch p.ProxyHeaderPolicy { - case REJECT: + case REJECT, REFUSE: // this connection is not allowed to send one return ErrSuperfluousProxyHeader case USE, REQUIRE: diff --git a/protocol_test.go b/protocol_test.go index fd976d1..0e9d1a9 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -752,100 +752,108 @@ func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) { } func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) + policyFuncs := []PolicyFunc{ + func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, + func(upstream net.Addr) (Policy, error) { return REFUSE, nil }, } + for _, policyFunc := range policyFuncs { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } - policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } + pl := &Listener{Listener: l, Policy: policyFunc} - pl := &Listener{Listener: l, Policy: policyFunc} + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + defer conn.Close() - cliResult := make(chan error) - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) + if _, err := conn.Write([]byte("ping")); err != nil { + cliResult <- err + return + } + + close(cliResult) + }() + + conn, err := pl.Accept() if err != nil { - cliResult <- err - return + t.Fatalf("err: %v", err) } defer conn.Close() - if _, err := conn.Write([]byte("ping")); err != nil { - cliResult <- err - return + recv := make([]byte, 4) + if _, err = conn.Read(recv); err != ErrNoProxyProtocol { + t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) + } + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) } - - close(cliResult) - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - if _, err = conn.Read(recv); err != ErrNoProxyProtocol { - t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) - } - err = <-cliResult - if err != nil { - t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) + policyFuncs := []PolicyFunc{ + func(upstream net.Addr) (Policy, error) { return REJECT, nil }, + func(upstream net.Addr) (Policy, error) { return REFUSE, nil }, } + for _, policyFunc := range policyFuncs { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } - policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil } + pl := &Listener{Listener: l, Policy: policyFunc} - pl := &Listener{Listener: l, Policy: policyFunc} + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + defer conn.Close() + header := &Header{ + Version: 2, + Command: PROXY, + TransportProtocol: TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP("20.2.2.2"), + Port: 2000, + }, + } + if _, err := header.WriteTo(conn); err != nil { + cliResult <- err + return + } - cliResult := make(chan error) - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) + close(cliResult) + }() + + conn, err := pl.Accept() if err != nil { - cliResult <- err - return + t.Fatalf("err: %v", err) } defer conn.Close() - header := &Header{ - Version: 2, - Command: PROXY, - TransportProtocol: TCPv4, - SourceAddr: &net.TCPAddr{ - IP: net.ParseIP("10.1.1.1"), - Port: 1000, - }, - DestinationAddr: &net.TCPAddr{ - IP: net.ParseIP("20.2.2.2"), - Port: 2000, - }, + + recv := make([]byte, 4) + if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { + t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) } - if _, err := header.WriteTo(conn); err != nil { - cliResult <- err - return + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) } - - close(cliResult) - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { - t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) - } - err = <-cliResult - if err != nil { - t.Fatalf("client error: %v", err) } } func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {