Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

policy: add REFUSE #114

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ const (
// Note: an example usage can be found in the SkipProxyHeaderForCIDR
// function.
SKIP
// REFUSE is the same as REJECT if a proxy header is set and the same as
// REQUIRE if a proxy header is not set.
REFUSE
)

// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
Expand Down Expand Up @@ -117,7 +120,7 @@ func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) {
return nil, err
}

return whitelistPolicy(allowFrom, REJECT), nil
return whitelistPolicy(allowFrom, REFUSE), nil
}

// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic
Expand Down
4 changes: 2 additions & 2 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t
t.Fatalf("err: %v", err)
}

if policy != REJECT {
t.Fatalf("Expected policy REJECT, got %v", policy)
if policy != REFUSE {
t.Fatalf("Expected policy REFUSE, got %v", policy)
}
}

Expand Down
4 changes: 2 additions & 2 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (p *Conn) readHeader() error {
// let's act as if there was no error when PROXY protocol is not present.
if err == ErrNoProxyProtocol {
// but not if it is required that the connection has one
if p.ProxyHeaderPolicy == REQUIRE {
if p.ProxyHeaderPolicy == REQUIRE || p.ProxyHeaderPolicy == REFUSE {
return err
}

Expand All @@ -298,7 +298,7 @@ func (p *Conn) readHeader() error {
// proxy protocol header was found
if err == nil && header != nil {
switch p.ProxyHeaderPolicy {
case REJECT:
case REJECT, REFUSE:
// this connection is not allowed to send one
return ErrSuperfluousProxyHeader
case USE, REQUIRE:
Expand Down
152 changes: 80 additions & 72 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,100 +752,108 @@ func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) {
}

func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
policyFuncs := []PolicyFunc{
func(upstream net.Addr) (Policy, error) { return REQUIRE, nil },
func(upstream net.Addr) (Policy, error) { return REFUSE, nil },
}
for _, policyFunc := range policyFuncs {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }
pl := &Listener{Listener: l, Policy: policyFunc}

pl := &Listener{Listener: l, Policy: policyFunc}
cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()

cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if _, err := conn.Write([]byte("ping")); err != nil {
cliResult <- err
return
}

close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
cliResult <- err
return
t.Fatalf("err: %v", err)
}
defer conn.Close()

if _, err := conn.Write([]byte("ping")); err != nil {
cliResult <- err
return
recv := make([]byte, 4)
if _, err = conn.Read(recv); err != ErrNoProxyProtocol {
t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err)
}
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}

close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

recv := make([]byte, 4)
if _, err = conn.Read(recv); err != ErrNoProxyProtocol {
t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err)
}
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}

func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
policyFuncs := []PolicyFunc{
func(upstream net.Addr) (Policy, error) { return REJECT, nil },
func(upstream net.Addr) (Policy, error) { return REFUSE, nil },
}
for _, policyFunc := range policyFuncs {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil }
pl := &Listener{Listener: l, Policy: policyFunc}

pl := &Listener{Listener: l, Policy: policyFunc}
cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()
header := &Header{
Version: 2,
Command: PROXY,
TransportProtocol: TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},
}
if _, err := header.WriteTo(conn); err != nil {
cliResult <- err
return
}

cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
cliResult <- err
return
t.Fatalf("err: %v", err)
}
defer conn.Close()
header := &Header{
Version: 2,
Command: PROXY,
TransportProtocol: TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},

recv := make([]byte, 4)
if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader {
t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err)
}
if _, err := header.WriteTo(conn); err != nil {
cliResult <- err
return
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}

close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

recv := make([]byte, 4)
if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader {
t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err)
}
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}
func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {
Expand Down
Loading