diff --git a/README.md b/README.md index 3325193..9b38411 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ # Paranoidhttp -[![Build Status](https://travis-ci.org/hakobe/paranoidhttp.svg?branch=master)](https://travis-ci.org/hakobe/paranoidhttp) +[![Build Status](https://travis-ci.org/hakobe/paranoidhttp.svg?branch=master)][travis] +[![Coverage Status](https://coveralls.io/repos/hakobe/paranoidhttp/badge.svg?branch=master)][coveralls] +[![MIT License](http://img.shields.io/badge/license-MIT-blue.svg?style=flat-square)][license] +[![GoDoc](https://godoc.org/github.com/hakobe/paranoidhttp?status.svg)][godoc] + +[travis]: https://travis-ci.org/hakobe/paranoidhttp +[coveralls]: https://coveralls.io/r/hakobe/paranoidhttp?branch=master +[license]: https://github.com/hakobe/paranoidhttp/blob/master/LICENSE +[godoc]: https://godoc.org/github.com/hakobe/paranoidhttp Paranoidhttp provides a pre-configured http.Client that protects you from harm. @@ -23,6 +31,11 @@ client, transport, dialer := paranoidhttp.NewClient() client.Timeout = 10 * time.Second transport.DisableCompression = true dialer.KeepAlive = 60 * time.Second + +// Add an permitted ipnets with functional option +ipNet, _ := net.ParseCIDR("127.0.0.1/32") +client, _, _ := paranoidhttp.New( + paranoidhttp.PermittedIPNets(ipNet)) ``` ## Known Issues diff --git a/client.go b/client.go index d6fe148..4f80444 100644 --- a/client.go +++ b/client.go @@ -4,37 +4,116 @@ import ( "context" "errors" "fmt" - "log" "net" "net/http" "regexp" "time" ) +// Config stores the rules for allowing IP/hosts +type config struct { + ForbiddenIPNets []*net.IPNet + PermittedIPNets []*net.IPNet + ForbiddenHosts []*regexp.Regexp +} + // DefaultClient is the default Client whose setting is the same as http.DefaultClient. -var DefaultClient *http.Client +var ( + defaultConfig config + DefaultClient *http.Client +) func mustParseCIDR(addr string) *net.IPNet { _, ipnet, err := net.ParseCIDR(addr) if err != nil { - log.Fatalf("%s must be parsed", addr) + panic(`net: ParseCIDR("` + addr + `"): ` + err.Error()) } return ipnet } -var ( - netPrivateClassA = mustParseCIDR("10.0.0.0/8") - netPrivateClassB = mustParseCIDR("172.16.0.0/12") - netPrivateClassC = mustParseCIDR("192.168.0.0/16") - netTestNet = mustParseCIDR("192.0.2.0/24") - net6To4Relay = mustParseCIDR("192.88.99.0/24") -) - func init() { + defaultConfig = config{ + ForbiddenIPNets: []*net.IPNet{ + mustParseCIDR("10.0.0.0/8"), // private class A + mustParseCIDR("172.16.0.0/12"), // private class B + mustParseCIDR("192.168.0.0/16"), // private class C + mustParseCIDR("192.0.2.0/24"), // test net 1 + mustParseCIDR("192.88.99.0/24"), // 6to4 relay + }, + ForbiddenHosts: []*regexp.Regexp{ + regexp.MustCompile(`(?i)^localhost$`), + regexp.MustCompile(`(?i)\s+`), + }, + } DefaultClient, _, _ = NewClient() } -func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (string, error) { +// isHostForbidden checks whether a hostname is forbidden by the Config +func (c *config) isHostForbidden(host string) bool { + for _, forbiddenHost := range c.ForbiddenHosts { + if forbiddenHost.MatchString(host) { + return true + } + } + return false +} + +// isIPForbidden checks whether an IP address is forbidden by the Config +func (c *config) isIPForbidden(ip net.IP) bool { + for _, permittedIPNet := range c.PermittedIPNets { + if permittedIPNet.Contains(ip) { + return false + } + } + + if ip.Equal(net.IPv4bcast) || !ip.IsGlobalUnicast() { + return true + } + + for _, forbiddenIPNet := range c.ForbiddenIPNets { + if forbiddenIPNet.Contains(ip) { + return true + } + } + return false +} + +// BasicConfig contains the most common hosts and IPs to be blocked +func basicConfig() *config { + c := defaultConfig // copy to return clone + return &c +} + +// Option type of paranoidhttp +type Option func(*config) + +// ForbiddenIPNets sets forbidden IPNets +func ForbiddenIPNets(ips ...*net.IPNet) Option { + return func(c *config) { + c.ForbiddenIPNets = ips + } +} + +// PermittedIPNets sets permitted IPNets +// It takes priority over other forbidden rules. +func PermittedIPNets(ips ...*net.IPNet) Option { + return func(c *config) { + c.PermittedIPNets = ips + } +} + +// ForbiddenHosts set forbidden host rules by regexp +func ForbiddenHosts(hostRegs ...*regexp.Regexp) Option { + return func(c *config) { + c.ForbiddenHosts = hostRegs + } +} + +func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string, opts ...Option) (string, error) { + c := basicConfig() + for _, opt := range opts { + opt(c) + } host, port, err := net.SplitHostPort(hostport) if err != nil { return "", err @@ -42,13 +121,13 @@ func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (str ip := net.ParseIP(host) if ip != nil { - if ip.To4() != nil && isBadIPv4(ip) { + if ip.To4() != nil && c.isIPForbidden(ip) { return "", fmt.Errorf("bad ip is detected: %v", ip) } return net.JoinHostPort(ip.String(), port), nil } - if isBadHost(host) { + if c.isHostForbidden(host) { return "", fmt.Errorf("bad host is detected: %v", host) } @@ -66,7 +145,7 @@ func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (str if addr.IP.To4() == nil { continue } - if isBadIPv4(addr.IP) { + if c.isIPForbidden(addr.IP) { return "", fmt.Errorf("bad ip is detected: %v", addr.IP) } safeAddrs = append(safeAddrs, addr) @@ -77,15 +156,15 @@ func safeAddr(ctx context.Context, resolver *net.Resolver, hostport string) (str return net.JoinHostPort(safeAddrs[0].IP.String(), port), nil } -// NewDialer returns a dialer function which only allows IPv4 connections. +// NewDialer returns a dialer function which only accepts IPv4 connections. // // This is used to create a new paranoid http.Client, // because I'm not sure about a paranoid behavior for IPv6 connections :( -func NewDialer(dialer *net.Dialer) func(ctx context.Context, network, addr string) (net.Conn, error) { +func NewDialer(dialer *net.Dialer, opts ...Option) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, hostport string) (net.Conn, error) { switch network { case "tcp", "tcp4": - addr, err := safeAddr(ctx, dialer.Resolver, hostport) + addr, err := safeAddr(ctx, dialer.Resolver, hostport, opts...) if err != nil { return nil, err } @@ -99,14 +178,14 @@ func NewDialer(dialer *net.Dialer) func(ctx context.Context, network, addr strin // NewClient returns a new http.Client configured to be paranoid for attackers. // // This also returns http.Tranport and net.Dialer so that you can customize those behavior. -func NewClient() (*http.Client, *http.Transport, *net.Dialer) { +func NewClient(opts ...Option) (*http.Client, *http.Transport, *net.Dialer) { dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: NewDialer(dialer), + DialContext: NewDialer(dialer, opts...), TLSHandshakeTimeout: 10 * time.Second, } return &http.Client{ @@ -114,31 +193,3 @@ func NewClient() (*http.Client, *http.Transport, *net.Dialer) { Transport: transport, }, transport, dialer } - -var regLocalhost = regexp.MustCompile("(?i)^localhost$") -var regHasSpace = regexp.MustCompile("(?i)\\s+") - -func isBadHost(host string) bool { - if regLocalhost.MatchString(host) { - return true - } - if regHasSpace.MatchString(host) { - return true - } - - return false -} - -func isBadIPv4(ip net.IP) bool { - if ip.To4() == nil { - panic("cannot be called for IPv6") - } - - if ip.Equal(net.IPv4bcast) || !ip.IsGlobalUnicast() || - netPrivateClassA.Contains(ip) || netPrivateClassB.Contains(ip) || netPrivateClassC.Contains(ip) || - netTestNet.Contains(ip) || net6To4Relay.Contains(ip) { - return true - } - - return false -} diff --git a/client_test.go b/client_test.go index 81d43d2..7282dd6 100644 --- a/client_test.go +++ b/client_test.go @@ -15,17 +15,21 @@ func TestRequest(t *testing.T) { if err == nil { t.Errorf("The request for localhost should be fail") } + + if _, err := DefaultClient.Get("http://192.168.0.1"); err == nil { + t.Errorf("The request for localhost should be fail") + } } -func TestIsBadHost(t *testing.T) { +func TestIsHostForbidden(t *testing.T) { badHosts := []string{ "localhost", "host has space", } for _, h := range badHosts { - if !isBadHost(h) { - t.Errorf("%s should be bad", h) + if !basicConfig().isHostForbidden(h) { + t.Errorf("%s should be forbidden", h) } } @@ -36,13 +40,13 @@ func TestIsBadHost(t *testing.T) { } for _, h := range notBadHosts { - if isBadHost(h) { - t.Errorf("%s should not be bad", h) + if basicConfig().isHostForbidden(h) { + t.Errorf("%s should not be forbidden", h) } } } -func TestIsBadIPv4(t *testing.T) { +func TestIsIpForbidden(t *testing.T) { badIPs := []string{ "0.0.0.0", // Unspecified "127.0.0.0", "127.255.255.255", // Loopback @@ -56,8 +60,8 @@ func TestIsBadIPv4(t *testing.T) { } for _, ip := range badIPs { - if !isBadIPv4(net.ParseIP(ip)) { - t.Errorf("%s should be bad", ip) + if !basicConfig().isIPForbidden(net.ParseIP(ip)) { + t.Errorf("%s should be forbidden", ip) } } @@ -73,8 +77,19 @@ func TestIsBadIPv4(t *testing.T) { } for _, ip := range notBadIPs { - if isBadIPv4(net.ParseIP(ip)) { - t.Errorf("%s should not be bad", ip) + if basicConfig().isIPForbidden(net.ParseIP(ip)) { + t.Errorf("%s should not be forbidden", ip) } } + + c := basicConfig() + ip := "172.18.0.1" + if !c.isIPForbidden(net.ParseIP(ip)) { + t.Errorf("%s should be forbidden", ip) + } + + c.PermittedIPNets = append(c.PermittedIPNets, mustParseCIDR("172.18.0.1/32")) + if c.isIPForbidden(net.ParseIP(ip)) { + t.Errorf("%s should not be forbidden", ip) + } }