Skip to content

Commit

Permalink
Make the list of cidrs and hosts configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
hamstah committed Dec 19, 2018
1 parent a47e0c3 commit 649d737
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 48 deletions.
104 changes: 60 additions & 44 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,77 @@ func mustParseCIDR(addr string) *net.IPNet {
return ipnet
}

var (
netPrivateClassA = mustParseCIDR("10.0.0.0/8")
netPrivateClassB = mustParseCIDR("172.16.0.0/12")
netPrivateClassC = mustParseCIDR("192.168.0.0/16")
netTestNet = mustParseCIDR("192.0.2.0/24")
net6To4Relay = mustParseCIDR("192.88.99.0/24")
)
// Config stores the rules for allowing IP/hosts
type Config struct {
ForbiddenCIDRs []*net.IPNet
ForbiddenHosts []*regexp.Regexp
}

// IsHostForbidden checks whether a hostname is forbidden by the Config
func (c *Config) IsHostForbidden(host string) bool {
for _, forbiddenHost := range c.ForbiddenHosts {
if forbiddenHost.MatchString(host) {
return true
}
}
return false
}

// IsIPForbidden checks whether an IP address is forbidden by the Config
func (c *Config) IsIPForbidden(ip net.IP) bool {
if ip.To4() == nil {
panic("cannot be called for IPv6")
}

if ip.Equal(net.IPv4bcast) || !ip.IsGlobalUnicast() {
return true
}

for _, forbiddenCIDR := range c.ForbiddenCIDRs {
if forbiddenCIDR.Contains(ip) {
return true
}
}
return false
}

// DefaultConfig contains the most common hosts and IPs to be blocked
func DefaultConfig() *Config {
return &Config{
ForbiddenCIDRs: []*net.IPNet{
mustParseCIDR("10.0.0.0/8"), // private class A
mustParseCIDR("172.16.0.0/12"), // private class B
mustParseCIDR("192.168.0.0/16"), // private class C
mustParseCIDR("192.0.2.0/24"), // test net 1
mustParseCIDR("192.88.99.0/24"), // 6to4 relay
},

ForbiddenHosts: []*regexp.Regexp{
regexp.MustCompile("(?i)^localhost$"),
regexp.MustCompile("(?i)\\s+"),
},
}
}

func init() {
DefaultClient, _, _ = NewClient()
DefaultClient, _, _ = NewClient(DefaultConfig())
}

func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (string, error) {
func safeAddr(ctx context.Context, resolver *net.Resolver, config *Config, hostport string) (string, error) {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
return "", err
}

ip := net.ParseIP(host)
if ip != nil {
if ip.To4() != nil && isBadIPv4(ip) {
if ip.To4() != nil && config.IsIPForbidden(ip) {
return "", fmt.Errorf("bad ip is detected: %v", ip)
}
return net.JoinHostPort(ip.String(), port), nil
}

if isBadHost(host) {
if config.IsHostForbidden(host) {
return "", fmt.Errorf("bad host is detected: %v", host)
}

Expand All @@ -66,7 +110,7 @@ func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (str
if addr.IP.To4() == nil {
continue
}
if isBadIPv4(addr.IP) {
if config.IsIPForbidden(addr.IP) {
return "", fmt.Errorf("bad ip is detected: %v", addr.IP)
}
safeAddrs = append(safeAddrs, addr)
Expand All @@ -81,11 +125,11 @@ func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (str
//
// This is used to create a new paranoid http.Client,
// because I'm not sure about a paranoid behavior for IPv6 connections :(
func NewDialer(dialer *net.Dialer) func(ctx context.Context, network, addr string) (net.Conn, error) {
func NewDialer(dialer *net.Dialer, config *Config) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, hostport string) (net.Conn, error) {
switch network {
case "tcp", "tcp4":
addr, err := safeAddr(ctx, dialer.Resolver, hostport)
addr, err := safeAddr(ctx, dialer.Resolver, config, hostport)
if err != nil {
return nil, err
}
Expand All @@ -99,46 +143,18 @@ func NewDialer(dialer *net.Dialer) func(ctx context.Context, network, addr strin
// NewClient returns a new http.Client configured to be paranoid for attackers.
//
// This also returns http.Tranport and net.Dialer so that you can customize those behavior.
func NewClient() (*http.Client, *http.Transport, *net.Dialer) {
func NewClient(config *Config) (*http.Client, *http.Transport, *net.Dialer) {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: NewDialer(dialer),
DialContext: NewDialer(dialer, config),
TLSHandshakeTimeout: 10 * time.Second,
}
return &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}, transport, dialer
}

var regLocalhost = regexp.MustCompile("(?i)^localhost$")
var regHasSpace = regexp.MustCompile("(?i)\\s+")

func isBadHost(host string) bool {
if regLocalhost.MatchString(host) {
return true
}
if regHasSpace.MatchString(host) {
return true
}

return false
}

func isBadIPv4(ip net.IP) bool {
if ip.To4() == nil {
panic("cannot be called for IPv6")
}

if ip.Equal(net.IPv4bcast) || !ip.IsGlobalUnicast() ||
netPrivateClassA.Contains(ip) || netPrivateClassB.Contains(ip) || netPrivateClassC.Contains(ip) ||
netTestNet.Contains(ip) || net6To4Relay.Contains(ip) {
return true
}

return false
}
11 changes: 7 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ func TestRequest(t *testing.T) {
}

func TestIsBadHost(t *testing.T) {
config := DefaultConfig()

badHosts := []string{
"localhost",
"host has space",
}

for _, h := range badHosts {
if !isBadHost(h) {
if !config.IsHostForbidden(h) {
t.Errorf("%s should be bad", h)
}
}
Expand All @@ -34,13 +36,14 @@ func TestIsBadHost(t *testing.T) {
}

for _, h := range notBadHosts {
if isBadHost(h) {
if config.IsHostForbidden(h) {
t.Errorf("%s should not be bad", h)
}
}
}

func TestIsBadIPv4(t *testing.T) {
config := DefaultConfig()
badIPs := []string{
"0.0.0.0", // Unspecified
"127.0.0.0", "127.255.255.255", // Loopback
Expand All @@ -54,7 +57,7 @@ func TestIsBadIPv4(t *testing.T) {
}

for _, ip := range badIPs {
if !isBadIPv4(net.ParseIP(ip)) {
if !config.IsIPForbidden(net.ParseIP(ip)) {
t.Errorf("%s should be bad", ip)
}
}
Expand All @@ -71,7 +74,7 @@ func TestIsBadIPv4(t *testing.T) {
}

for _, ip := range notBadIPs {
if isBadIPv4(net.ParseIP(ip)) {
if config.IsIPForbidden(net.ParseIP(ip)) {
t.Errorf("%s should not be bad", ip)
}
}
Expand Down

0 comments on commit 649d737

Please sign in to comment.