-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtls.go
101 lines (90 loc) · 2.24 KB
/
tls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package verify
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"strings"
"sync"
"time"
"github.com/rs/zerolog/log"
)
// maxRemoteWait is used as both the connect timeout and the TCP keep-alive
// period for outbound TLS connections.
const maxRemoteWait = 5 * time.Second

type (
	// tlsVerifierOptions configures a tlsVerifier.
	tlsVerifierOptions struct {
		// rootCAs is the pool certificate chains are verified against. When it
		// is nil, crypto/x509 falls back to the system roots.
		rootCAs *x509.CertPool
	}

	// tlsVerifier dials TLS endpoints and records, per server name, whether
	// the most recently seen certificate chain failed verification.
	tlsVerifier struct {
		tlsVerifierOptions

		// mu guards errors.
		mu sync.Mutex
		// errors maps a server name to its last recorded verification
		// failure; successful verifications remove the entry.
		errors map[string]error
	}
)
// newTLSVerifier returns a tlsVerifier configured with opts and an empty,
// ready-to-use error table.
func newTLSVerifier(opts tlsVerifierOptions) *tlsVerifier {
	verifier := new(tlsVerifier)
	verifier.tlsVerifierOptions = opts
	verifier.errors = map[string]error{}
	return verifier
}
// DialTLSContext connects to addr over network and performs a TLS handshake,
// honoring ctx for cancellation of the dial and handshake. Its signature
// matches the DialTLSContext hook of net/http.Transport.
//
// Standard chain verification is disabled in favor of v.VerifyPeerCertificate.
// That method records verification failures (retrievable via GetTLSError)
// instead of failing the handshake. It only aborts the handshake when a
// presented certificate cannot be parsed.
func (v *tlsVerifier) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
	serverName := tlsHost(addr)
	dialer := &tls.Dialer{
		// DualStack is intentionally not set: it has been deprecated and
		// ignored since Go 1.12, where Fast Fallback is on by default.
		NetDialer: &net.Dialer{
			Timeout:   maxRemoteWait,
			KeepAlive: maxRemoteWait,
		},
		Config: &tls.Config{
			// Verification is delegated to VerifyPeerCertificate below so that
			// failures are recorded rather than aborting the connection.
			InsecureSkipVerify: true,
			ServerName:         serverName,
			VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
				return v.VerifyPeerCertificate(serverName, rawCerts)
			},
		},
	}

	log.Info().Str("addr", addr).Msg("dialing")
	// Unlike tls.DialWithDialer, DialContext observes ctx, so callers can
	// cancel a slow connect or handshake.
	conn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		log.Error().Err(err).Str("addr", addr).Msg("error dialing TLS")
		return nil, err
	}
	return conn, nil
}
// noPeerCertificatesError is returned by tlsVerifier.VerifyPeerCertificate
// when it is given an empty certificate chain.
type noPeerCertificatesError struct{}

func (noPeerCertificatesError) Error() string {
	return "tls: no peer certificates to verify"
}

// VerifyPeerCertificate verifies the DER-encoded chain rawCerts (leaf first,
// followed by intermediates) for serverName against v.rootCAs.
//
// The outcome is recorded for GetTLSError. A verification failure is stored
// under serverName; a success clears any previously stored failure. Either
// way it returns nil, so a handshake using this as its callback still
// completes.
//
// It returns a non-nil error only when the chain cannot be evaluated at all:
// when rawCerts is empty (noPeerCertificatesError) or when a certificate fails
// to parse (the x509 parse error). Nothing is recorded in those cases.
func (v *tlsVerifier) VerifyPeerCertificate(serverName string, rawCerts [][]byte) error {
	if len(rawCerts) == 0 {
		// Without a leaf there is nothing to verify, and indexing certs[0]
		// below would panic.
		return noPeerCertificatesError{}
	}

	certs := make([]*x509.Certificate, len(rawCerts))
	for i, rawCert := range rawCerts {
		cert, err := x509.ParseCertificate(rawCert)
		if err != nil {
			log.Error().Err(err).Msg("error parsing TLS certificate")
			return err
		}
		certs[i] = cert
	}

	opts := x509.VerifyOptions{
		DNSName:       serverName,
		Roots:         v.rootCAs,
		Intermediates: x509.NewCertPool(),
	}
	for _, cert := range certs[1:] {
		opts.Intermediates.AddCert(cert)
	}
	_, verifyErr := certs[0].Verify(opts)

	v.mu.Lock()
	if verifyErr == nil {
		delete(v.errors, serverName)
	} else {
		v.errors[serverName] = verifyErr
	}
	v.mu.Unlock()

	// Log outside the critical section so concurrent handshakes are not
	// serialized behind log I/O.
	if verifyErr != nil {
		log.Error().
			Err(verifyErr).
			Str("server-name", serverName).
			Msg("invalid TLS certificate")
	}
	return nil
}
// GetTLSError reports the verification failure last recorded for serverName
// by VerifyPeerCertificate. It returns nil if the most recent verification
// succeeded or no verification has been recorded.
func (v *tlsVerifier) GetTLSError(serverName string) error {
	v.mu.Lock()
	defer v.mu.Unlock()
	return v.errors[serverName]
}
// tlsHost returns targetAddr with any trailing ":port" removed.
//
// A colon counts as a port separator only when it comes after the last "]".
// As a result, "[::1]:443" yields "[::1]" (brackets are kept) and "[::1]" is
// returned unchanged. An unbracketed IPv6 literal such as "::1" has its last
// colon treated as a port separator.
func tlsHost(targetAddr string) string {
	portSep := strings.LastIndex(targetAddr, ":")
	if portSep <= strings.LastIndex(targetAddr, "]") {
		// No colon, or the only colons sit inside a bracketed IPv6 literal.
		return targetAddr
	}
	return targetAddr[:portSep]
}