diff --git a/Makefile b/Makefile index 4e670ff..d39170e 100644 --- a/Makefile +++ b/Makefile @@ -12,11 +12,17 @@ lint: cert: mkdir -p testdata + # server rm -f testdata/*.pem testdata/*.srl openssl req -x509 -newkey rsa:4096 -days 365 -nodes -sha256 -keyout testdata/cakey.pem -out testdata/cacert.pem -subj "/C=UK/ST=Test State/L=Test Location/O=Test Org/OU=Test Unit/CN=*.example.com/emailAddress=k1lowxb@gmail.com" openssl req -newkey rsa:4096 -nodes -keyout testdata/key.pem -out testdata/csr.pem -subj "/C=JP/ST=Test State/L=Test Location/O=Test Org/OU=Test Unit/CN=*.example.com/emailAddress=k1lowxb@gmail.com" openssl x509 -req -sha256 -in testdata/csr.pem -days 60 -CA testdata/cacert.pem -CAkey testdata/cakey.pem -CAcreateserial -out testdata/cert.pem -extfile testdata/openssl.cnf openssl verify -CAfile testdata/cacert.pem testdata/cert.pem + # client + openssl req -x509 -newkey rsa:4096 -days 365 -nodes -sha256 -keyout testdata/clientcakey.pem -out testdata/clientcacert.pem -subj "/C=UK/ST=Test State/L=Test Location/O=Test Org/OU=Test Unit/CN=client/emailAddress=k1lowxb@gmail.com" + openssl req -newkey rsa:4096 -nodes -keyout testdata/clientkey.pem -out testdata/clientcsr.pem -subj "/C=JP/ST=Test State/L=Test Location/O=Test Org/OU=Test Unit/CN=client/emailAddress=k1lowxb@gmail.com" + openssl x509 -req -sha256 -in testdata/clientcsr.pem -days 60 -CA testdata/clientcacert.pem -CAkey testdata/clientcakey.pem -CAcreateserial -out testdata/clientcert.pem + openssl verify -CAfile testdata/clientcacert.pem testdata/clientcert.pem depsdev: go install github.com/Songmu/ghch/cmd/ghch@latest diff --git a/httpstub.go b/httpstub.go index 9ef4330..e882d22 100644 --- a/httpstub.go +++ b/httpstub.go @@ -18,14 +18,15 @@ import ( var _ http.Handler = (*Router)(nil) type Router struct { - matchers []*matcher - server *httptest.Server - middlewares middlewareFuncs - requests []*http.Request - t *testing.T - useTLS bool - cacert, cert, key []byte - mu sync.RWMutex + matchers []*matcher + server *httptest.Server + middlewares middlewareFuncs + requests []*http.Request + t *testing.T + useTLS bool + cacert, cert, key []byte + clientCacert, clientCert, clientKey []byte + mu sync.RWMutex } type matcher struct { @@ -86,11 +87,14 @@ func NewRouter(t *testing.T, opts ...Option) *Router { } } return &Router{ - t: t, - useTLS: c.useTLS, - cacert: c.cacert, - cert: c.cert, - key: c.key, + t: t, + useTLS: c.useTLS, + cacert: c.cacert, + cert: c.cert, + key: c.key, + clientCacert: c.clientCacert, + clientCert: c.clientCert, + clientKey: c.clientKey, } } @@ -127,7 +131,9 @@ func (rt *Router) Server() *httptest.Server { } if rt.useTLS { rt.server = httptest.NewUnstartedServer(rt) - if len(rt.cert) > 0 && len(rt.key) > 0 { + + // server certificates + if rt.cert != nil && rt.key != nil { cert, err := tls.X509KeyPair(rt.cert, rt.key) if err != nil { panic(err) @@ -140,8 +146,31 @@ func (rt *Router) Server() *httptest.Server { } rt.server.TLS.Certificates = []tls.Certificate{cert} } + // client CA + if rt.clientCacert != nil { + certpool, err := x509.SystemCertPool() + if err != nil { + // FIXME for Windows + // ref: https://github.com/golang/go/issues/18609 + certpool = x509.NewCertPool() + } + if !certpool.AppendCertsFromPEM(rt.clientCacert) { + panic("failed to add cacert") + } + existingConfig := rt.server.TLS + if existingConfig != nil { + rt.server.TLS = existingConfig.Clone() + } else { + rt.server.TLS = new(tls.Config) + } + rt.server.TLS.ClientCAs = certpool + rt.server.TLS.ClientAuth = tls.RequireAndVerifyClientCert + } + rt.server.StartTLS() - if len(rt.cacert) > 0 { + + // server CA + if rt.cacert != nil { certpool, err := x509.SystemCertPool() if err != nil { // FIXME for Windows @@ -154,6 +183,15 @@ func (rt *Router) Server() *httptest.Server { client := rt.server.Client() client.Transport.(*http.Transport).TLSClientConfig.RootCAs = certpool } + // client certificates + if rt.clientCert != nil && rt.clientKey != nil { + cert, err := tls.X509KeyPair(rt.clientCert, rt.clientKey) + if err != nil { + panic(err) + } + client := rt.server.Client() + client.Transport.(*http.Transport).TLSClientConfig.Certificates = []tls.Certificate{cert} + } } else { rt.server = httptest.NewServer(rt) } diff --git a/httpstub_test.go b/httpstub_test.go index e348d3a..b4d5251 100644 --- a/httpstub_test.go +++ b/httpstub_test.go @@ -390,3 +390,58 @@ func TestUseTLSWithCertificates(t *testing.T) { } } } + +func TestClientCertififaces(t *testing.T) { + clientCacert, err := os.ReadFile("testdata/clientcacert.pem") + if err != nil { + t.Fatal(err) + } + clientCert, err := os.ReadFile("testdata/clientcert.pem") + if err != nil { + t.Fatal(err) + } + clientKey, err := os.ReadFile("testdata/clientkey.pem") + if err != nil { + t.Fatal(err) + } + r := NewRouter(t, UseTLS(), ClientCACert(clientCacert), ClientCertificates(clientCert, clientKey)) + r.Method(http.MethodGet).Path("/api/v1/users/1").Header("Content-Type", "application/json").ResponseString(http.StatusOK, `{"name":"alice"}`) + ts := r.Server() + t.Cleanup(func() { + ts.Close() + }) + tc := ts.Client() + res, err := tc.Get("http://example.com/api/v1/users/1") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + res.Body.Close() + }) + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + { + got := res.StatusCode + want := http.StatusOK + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + { + got := res.Header.Get("Content-Type") + want := "application/json" + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + { + got := string(body) + want := `{"name":"alice"}` + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +} diff --git a/option.go b/option.go index 1364c7f..b8304c3 100644 --- a/option.go +++ b/option.go @@ -1,8 +1,9 @@ package httpstub type config struct { - useTLS bool - cacert, cert, key []byte + useTLS bool + cacert, cert, key []byte + clientCacert, clientCert, clientKey []byte } type Option func(*config) error @@ -24,7 +25,7 @@ func Certificates(cert, key []byte) Option { } } -// CACert set CA certificate +// CACert set CA func CACert(cacert []byte) Option { return func(c *config) error { c.cacert = cacert @@ -41,3 +42,20 @@ func UseTLSWithCertificates(cert, key []byte) Option { return nil } } + +// ClientCertificates set client certificates ( cert, key ) +func ClientCertificates(cert, key []byte) Option { + return func(c *config) error { + c.clientCert = cert + c.clientKey = key + return nil + } +} + +// ClientCACert set client CA +func ClientCACert(cacert []byte) Option { + return func(c *config) error { + c.clientCacert = cacert + return nil + } +}