diff --git a/testing/proxytest/https.go b/testing/proxytest/https.go new file mode 100644 index 00000000000..1a302394f1c --- /dev/null +++ b/testing/proxytest/https.go @@ -0,0 +1,207 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package proxytest + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "strings" + + "github.com/elastic/elastic-agent-libs/testing/certutil" +) + +func (p *Proxy) serveHTTPS(w http.ResponseWriter, r *http.Request) { + log := loggerFromReqCtx(r) + log.Debug("handling CONNECT") + + clientCon, err := hijack(w) + if err != nil { + p.http500Error(clientCon, "cannot handle request", err, log) + return + } + defer clientCon.Close() + + // Hijack successful, w is now useless, let's make sure it isn't used by + // mistake ;) + w = nil //nolint:ineffassign,wastedassign // w is now useless, let's make sure it isn't used by mistake ;) + log.Debug("hijacked request") + + // ==================== CONNECT accepted, let the client know + _, err = clientCon.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + if err != nil { + p.http500Error(clientCon, "failed to send 200-OK after CONNECT", err, log) + return + } + + // ==================== TLS handshake + // client will proceed to perform the TLS handshake with the "target", + // which we're impersonating. + + // generate a TLS certificate matching the target's host + cert, err := p.newTLSCert(r.URL) + if err != nil { + p.http500Error(clientCon, "failed generating certificate", err, log) + return + } + + tlscfg := p.TLS.Clone() + tlscfg.Certificates = []tls.Certificate{*cert} + clientTLSConn := tls.Server(clientCon, tlscfg) + defer clientTLSConn.Close() + err = clientTLSConn.Handshake() + if err != nil { + p.http500Error(clientCon, "failed TLS handshake with client", err, log) + return + } + + clientTLSReader := bufio.NewReader(clientTLSConn) + + notEOF := func(r *bufio.Reader) bool { + _, err = r.Peek(1) + return !errors.Is(err, io.EOF) + } + // ==================== Handle the actual request + for notEOF(clientTLSReader) { + // read request from the client sent after the 1s CONNECT request + req, err := http.ReadRequest(clientTLSReader) + if err != nil { + p.http500Error(clientTLSConn, "failed reading client request", err, log) + return + } + + // carry over the original remote addr + req.RemoteAddr = r.RemoteAddr + + // the read request is relative to the host from the original CONNECT + // request and without scheme. Therefore, set them in the new request. + req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) + if err != nil { + p.http500Error(clientTLSConn, "failed reading request URL from client", err, log) + return + } + cleanUpHeaders(req.Header) + + // now the request is ready, it can be altered and sent just as it's + // done for an HTTP request. + resp, err := p.processRequest(req) + if err != nil { + p.httpError(clientTLSConn, + http.StatusBadGateway, + "failed performing request to target", err, log) + return + } + + clientResp := http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: resp.StatusCode, + TransferEncoding: append([]string{}, resp.TransferEncoding...), + Trailer: resp.Trailer.Clone(), + Body: resp.Body, + ContentLength: resp.ContentLength, + Header: resp.Header.Clone(), + } + + err = clientResp.Write(clientTLSConn) + if err != nil { + p.http500Error(clientTLSConn, "failed writing response body", err, log) + return + } + + _ = resp.Body.Close() + } + + log.Debug("EOF reached, finishing HTTPS handler") +} + +func (p *Proxy) newTLSCert(u *url.URL) (*tls.Certificate, error) { + // generate the certificate key - it needs to be RSA because Elastic Defend + // do not support EC :/ + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("could not create RSA private key: %w", err) + } + host := u.Hostname() + + var name string + var ips []net.IP + ip := net.ParseIP(host) + if ip == nil { // host isn't an IP, therefore it must be an DNS + name = host + } else { + ips = append(ips, ip) + } + + cert, _, err := certutil.GenerateGenericChildCert( + name, + ips, + priv, + &priv.PublicKey, + p.ca.capriv, + p.ca.cacert) + if err != nil { + return nil, fmt.Errorf("could not generate TLS certificate for %s: %w", + host, err) + } + + return cert, nil +} + +func (p *Proxy) http500Error(clientCon net.Conn, msg string, err error, log *slog.Logger) { + p.httpError(clientCon, http.StatusInternalServerError, msg, err, log) +} + +func (p *Proxy) httpError(clientCon net.Conn, status int, msg string, err error, log *slog.Logger) { + log.Error(msg, "err", err) + + resp := http.Response{ + StatusCode: status, + ProtoMajor: 1, + ProtoMinor: 1, + Body: io.NopCloser(strings.NewReader(msg)), + Header: http.Header{}, + } + resp.Header.Set("Content-Type", "text/html; charset=utf-8") + + err = resp.Write(clientCon) + if err != nil { + log.Error("failed writing response", "err", err) + } +} + +func hijack(w http.ResponseWriter) (net.Conn, error) { + hijacker, ok := w.(http.Hijacker) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprint(w, "cannot handle request") + return nil, errors.New("http.ResponseWriter does not support hijacking") + } + + clientCon, _, err := hijacker.Hijack() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, err = fmt.Fprint(w, "cannot handle request") + + return nil, fmt.Errorf("could not Hijack HTTPS CONNECT request: %w", err) + } + + return clientCon, err +} + +func cleanUpHeaders(h http.Header) { + h.Del("Proxy-Connection") + h.Del("Proxy-Authenticate") + h.Del("Proxy-Authorization") + h.Del("Connection") +} diff --git a/testing/proxytest/proxytest.go b/testing/proxytest/proxytest.go index 5676c5bcc02..2a9c5f040f3 100644 --- a/testing/proxytest/proxytest.go +++ b/testing/proxytest/proxytest.go @@ -5,10 +5,15 @@ package proxytest import ( + "bufio" + "context" + "crypto" "crypto/tls" + "crypto/x509" "fmt" "io" "log" + "log/slog" "net" "net/http" "net/http/httptest" @@ -27,6 +32,7 @@ type Proxy struct { Port string // LocalhostURL is the server URL as "http(s)://localhost:PORT". + // Deprecated. Use Proxy.URL instead. LocalhostURL string // proxiedRequests is a "request log" for every request the proxy receives. @@ -34,6 +40,10 @@ type Proxy struct { proxiedRequestsMu sync.Mutex opts options + log *slog.Logger + + ca ca + client *http.Client } type Option func(o *options) @@ -46,6 +56,14 @@ type options struct { logFn func(format string, a ...any) verbose bool serverTLSConfig *tls.Config + capriv crypto.PrivateKey + cacert *x509.Certificate + client *http.Client +} + +type ca struct { + capriv crypto.PrivateKey + cacert *x509.Certificate } // WithAddress will set the address the server will listen on. The format is as @@ -56,6 +74,24 @@ func WithAddress(addr string) Option { } } +// WithHTTPClient sets http.Client used to proxy requests to the target host. +func WithHTTPClient(c *http.Client) Option { + return func(o *options) { + o.client = c + } +} + +// WithMITMCA sets the CA used for MITM (men in the middle) when proxying HTTPS +// requests. It's used to generate TLS certificates matching the target host. +// Ideally the CA is the same as the one issuing the TLS certificate for the +// proxy set by WithServerTLSConfig. +func WithMITMCA(priv crypto.PrivateKey, cert *x509.Certificate) func(o *options) { + return func(o *options) { + o.capriv = priv + o.cacert = cert + } +} + // WithRequestLog sets the proxy to log every request using logFn. It uses name // as a prefix to the log. func WithRequestLog(name string, logFn func(format string, a ...any)) Option { @@ -75,14 +111,6 @@ func WithRewrite(old, new string) Option { } } -// WithVerboseLog sets the proxy to log every request verbosely. WithRequestLog -// must be used as well, otherwise WithVerboseLog will not take effect. -func WithVerboseLog() Option { - return func(o *options) { - o.verbose = true - } -} - // WithRewriteFn calls f on the request *url.URL before forwarding it. // It takes precedence over WithRewrite. Use if more control over the rewrite // is needed. @@ -92,12 +120,22 @@ func WithRewriteFn(f func(u *url.URL)) Option { } } +// WithServerTLSConfig sets the TLS config for the server. func WithServerTLSConfig(tc *tls.Config) Option { return func(o *options) { o.serverTLSConfig = tc } } +// WithVerboseLog sets the proxy to log every request verbosely and enables +// debug level logging. WithRequestLog must be used as well, otherwise +// WithVerboseLog will not take effect. +func WithVerboseLog() Option { + return func(o *options) { + o.verbose = true + } +} + // New returns a new Proxy ready for use. Use: // - WithAddress to set the proxy's address, // - WithRewrite or WithRewriteFn to rewrite the URL before forwarding the request. @@ -106,7 +144,7 @@ func WithServerTLSConfig(tc *tls.Config) Option { func New(t *testing.T, optns ...Option) *Proxy { t.Helper() - opts := options{addr: ":0"} + opts := options{addr: "127.0.0.1:0", client: &http.Client{}} for _, o := range optns { o(&opts) } @@ -120,17 +158,35 @@ func New(t *testing.T, optns ...Option) *Proxy { t.Fatalf("NewServer failed to create a net.Listener: %v", err) } - p := Proxy{opts: opts} + // Create a text handler that writes to standard output + lv := slog.LevelInfo + if opts.verbose { + lv = slog.LevelDebug + } + p := Proxy{ + opts: opts, + client: opts.client, + log: slog.New(slog.NewTextHandler(logfWriter(opts.logFn), &slog.HandlerOptions{ + AddSource: true, + Level: lv, + })), + } + if opts.capriv != nil && opts.cacert != nil { + p.ca = ca{capriv: opts.capriv, cacert: opts.cacert} + } p.Server = httptest.NewUnstartedServer( http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) { - w := &statusResponseWriter{w: ww} + w := &proxyResponseWriter{w: ww} requestID := uuid.Must(uuid.NewV4()).String() - opts.logFn("[%s] STARTING - %s %s %s %s\n", - requestID, r.Method, r.URL, r.Proto, r.RemoteAddr) + p.log.Info(fmt.Sprintf("STARTING - %s '%s' %s %s", + r.Method, r.URL, r.Proto, r.RemoteAddr)) + + rr := addIDToReqCtx(r, requestID) + rrr := addLoggerReqCtx(rr, p.log.With("req_id", requestID)) - p.ServeHTTP(w, r) + p.ServeHTTP(w, rrr) opts.logFn(fmt.Sprintf("[%s] DONE %d - %s %s %s %s\n", requestID, w.statusCode, r.Method, r.URL, r.Proto, r.RemoteAddr)) @@ -150,7 +206,6 @@ func New(t *testing.T, optns ...Option) *Proxy { p.Port = u.Port() p.LocalhostURL = "http://localhost:" + p.Port - opts.logFn("running on %s -> %s", p.URL, p.LocalhostURL) return &p } @@ -164,7 +219,7 @@ func (p *Proxy) Start() error { p.Port = u.Port() p.LocalhostURL = "http://localhost:" + p.Port - p.opts.logFn("running on %s -> %s", p.URL, p.LocalhostURL) + p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) return nil } @@ -178,34 +233,21 @@ func (p *Proxy) StartTLS() error { p.Port = u.Port() p.LocalhostURL = "https://localhost:" + p.Port - p.opts.logFn("running on %s -> %s", p.URL, p.LocalhostURL) + p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) return nil } func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - origURL := r.URL.String() - - switch { - case p.opts.rewriteURL != nil: - p.opts.rewriteURL(r.URL) - case p.opts.rewriteHost != nil: - r.URL.Host = p.opts.rewriteHost(r.URL.Host) - } - - if p.opts.verbose { - p.opts.logFn("original URL: %s, new URL: %s", - origURL, r.URL.String()) + if r.Method == http.MethodConnect { + p.serveHTTPS(w, r) + return } - p.proxiedRequestsMu.Lock() - p.proxiedRequests = append(p.proxiedRequests, - fmt.Sprintf("%s - %s %s %s", - r.Method, r.URL.Scheme, r.URL.Host, r.URL.String())) - p.proxiedRequestsMu.Unlock() - - r.RequestURI = "" + p.serveHTTP(w, r) +} - resp, err := http.DefaultClient.Do(r) +func (p *Proxy) serveHTTP(w http.ResponseWriter, r *http.Request) { + resp, err := p.processRequest(r) if err != nil { w.WriteHeader(http.StatusInternalServerError) msg := fmt.Sprintf("could not make request: %#v", err.Error()) @@ -225,6 +267,34 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +// processRequest executes the configured request manipulation and perform the +// request. +func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { + origURL := r.URL.String() + + switch { + case p.opts.rewriteURL != nil: + p.opts.rewriteURL(r.URL) + case p.opts.rewriteHost != nil: + r.URL.Host = p.opts.rewriteHost(r.URL.Host) + } + + p.log.Debug(fmt.Sprintf("original URL: %s, new URL: %s", + origURL, r.URL.String())) + + p.proxiedRequestsMu.Lock() + p.proxiedRequests = append(p.proxiedRequests, + fmt.Sprintf("%s - %s %s %s", + r.Method, r.URL.Scheme, r.URL.Host, r.URL.String())) + p.proxiedRequestsMu.Unlock() + + // when modifying the request, RequestURI isn't updated, and it isn't + // needed anyway, so remove it. + r.RequestURI = "" + + return p.client.Do(r) +} + // ProxiedRequests returns a slice with the "request log" with every request the // proxy received. func (p *Proxy) ProxiedRequests() []string { @@ -236,26 +306,67 @@ func (p *Proxy) ProxiedRequests() []string { return rs } -// statusResponseWriter wraps a http.ResponseWriter to expose the status code -// through statusResponseWriter.statusCode -type statusResponseWriter struct { +var _ http.Hijacker = &proxyResponseWriter{} + +// proxyResponseWriter wraps a http.ResponseWriter to expose the status code +// through proxyResponseWriter.statusCode +type proxyResponseWriter struct { w http.ResponseWriter statusCode int } -func (s *statusResponseWriter) Header() http.Header { +func (s *proxyResponseWriter) Header() http.Header { return s.w.Header() } -func (s *statusResponseWriter) Write(bs []byte) (int, error) { +func (s *proxyResponseWriter) Write(bs []byte) (int, error) { return s.w.Write(bs) } -func (s *statusResponseWriter) WriteHeader(statusCode int) { +func (s *proxyResponseWriter) WriteHeader(statusCode int) { s.statusCode = statusCode s.w.WriteHeader(statusCode) } -func (s *statusResponseWriter) StatusCode() int { +func (s *proxyResponseWriter) StatusCode() int { return s.statusCode } + +func (s *proxyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := s.w.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("%T does not support hijacking", s.w) + } + + return hijacker.Hijack() +} + +type ctxKeyRecID struct{} +type ctxKeyLogger struct{} + +func addIDToReqCtx(r *http.Request, id string) *http.Request { + return r.WithContext(context.WithValue(r.Context(), ctxKeyRecID{}, id)) +} + +func idFromReqCtx(r *http.Request) string { //nolint:unused // kept for completeness + return r.Context().Value(ctxKeyRecID{}).(string) +} + +func addLoggerReqCtx(r *http.Request, log *slog.Logger) *http.Request { + return r.WithContext(context.WithValue(r.Context(), ctxKeyLogger{}, log)) +} + +func loggerFromReqCtx(r *http.Request) *slog.Logger { + l, ok := r.Context().Value(ctxKeyLogger{}).(*slog.Logger) + if !ok { + return slog.Default() + } + return l +} + +type logfWriter func(format string, a ...any) + +func (w logfWriter) Write(p []byte) (n int, err error) { + w(string(p)) + return len(p), nil +} diff --git a/testing/proxytest/proxytest_example_test.go b/testing/proxytest/proxytest_example_test.go new file mode 100644 index 00000000000..4e62bbe72f4 --- /dev/null +++ b/testing/proxytest/proxytest_example_test.go @@ -0,0 +1,177 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +//go:build example + +package proxytest + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/testing/certutil" +) + +// TestRunHTTPSProxy is an example of how to use the proxytest outside tests, +// and it instructs how to perform a request through the proxy using cURL. +// From the repo's root, run this test with: +// go test -tags example -v -run TestRunHTTPSProxy$ ./testing/proxytest +func TestRunHTTPSProxy(t *testing.T) { + // Create a temporary directory to store certificates + tmpDir := t.TempDir() + + // ========================= generate certificates ========================= + serverCAKey, serverCACert, serverCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("server")) + require.NoError(t, err, "error creating root CA") + + serverCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + serverCAKey, + serverCACert, certutil.WithCNPrefix("server")) + require.NoError(t, err, "error creating server certificate") + serverCACertPool := x509.NewCertPool() + serverCACertPool.AddCert(serverCACert) + + proxyCAKey, proxyCACert, proxyCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("proxy")) + require.NoError(t, err, "error creating root CA") + + proxyCert, proxyCertPair, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert, + certutil.WithCNPrefix("proxy")) + require.NoError(t, err, "error creating server certificate") + + clientCAKey, clientCACert, clientCAPair, err := certutil.NewRootCA( + certutil.WithCNPrefix("client")) + require.NoError(t, err, "error creating root CA") + clientCACertPool := x509.NewCertPool() + clientCACertPool.AddCert(clientCACert) + + _, clientCertPair, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + clientCAKey, + clientCACert, + certutil.WithCNPrefix("client")) + require.NoError(t, err, "error creating server certificate") + + // =========================== save certificates =========================== + serverCACertFile := filepath.Join(tmpDir, "serverCA.crt") + if err := os.WriteFile(serverCACertFile, serverCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + serverCAKeyFile := filepath.Join(tmpDir, "serverCA.key") + if err := os.WriteFile(serverCAKeyFile, serverCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + + proxyCACertFile := filepath.Join(tmpDir, "proxyCA.crt") + if err := os.WriteFile(proxyCACertFile, proxyCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + proxyCAKeyFile := filepath.Join(tmpDir, "proxyCA.key") + if err := os.WriteFile(proxyCAKeyFile, proxyCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + proxyCertFile := filepath.Join(tmpDir, "proxyCert.crt") + if err := os.WriteFile(proxyCertFile, proxyCertPair.Cert, 0644); err != nil { + t.Fatal(err) + } + proxyKeyFile := filepath.Join(tmpDir, "proxyCert.key") + if err := os.WriteFile(proxyKeyFile, proxyCertPair.Key, 0644); err != nil { + t.Fatal(err) + } + + clientCACertFile := filepath.Join(tmpDir, "clientCA.crt") + if err := os.WriteFile(clientCACertFile, clientCAPair.Cert, 0644); err != nil { + t.Fatal(err) + } + clientCAKeyFile := filepath.Join(tmpDir, "clientCA.key") + if err := os.WriteFile(clientCAKeyFile, clientCAPair.Key, 0644); err != nil { + t.Fatal(err) + } + clientCertCertFile := filepath.Join(tmpDir, "clientCert.crt") + if err := os.WriteFile(clientCertCertFile, clientCertPair.Cert, 0644); err != nil { + t.Fatal(err) + } + clientCertKeyFile := filepath.Join(tmpDir, "clientCert.key") + if err := os.WriteFile(clientCertKeyFile, clientCertPair.Key, 0644); err != nil { + t.Fatal(err) + } + + // ========================== create target server ========================= + targetHost := "not-a-server.co" + server := httptest.NewUnstartedServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("It works!")) + })) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + MinVersion: tls.VersionTLS13, + } + server.StartTLS() + t.Logf("target server running on %s", server.URL) + + // ============================== create proxy ============================= + proxy := New(t, + WithVerboseLog(), + WithRequestLog("https", t.Logf), + WithRewrite(targetHost+":443", server.URL[8:]), + WithMITMCA(proxyCAKey, proxyCACert), + WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: serverCACertPool, + MinVersion: tls.VersionTLS13, + }, + }, + }), + WithServerTLSConfig(&tls.Config{ + Certificates: []tls.Certificate{*proxyCert}, + ClientCAs: clientCACertPool, + ClientAuth: tls.VerifyClientCertIfGiven, + MinVersion: tls.VersionTLS13, + })) + err = proxy.StartTLS() + require.NoError(t, err, "error starting proxy") + t.Logf("proxy running on %s", proxy.LocalhostURL) + defer proxy.Close() + + // ============================ test instructions ========================== + + u := "https://" + targetHost + t.Logf("make request to %q using proxy %q", u, proxy.LocalhostURL) + + t.Logf(`curl \ +--proxy-cacert %s \ +--proxy-cert %s \ +--proxy-key %s \ +--cacert %s \ +--proxy %s \ +%s`, + proxyCACertFile, + clientCertCertFile, + clientCertKeyFile, + proxyCACertFile, + proxy.URL, + u, + ) + + t.Log("CTRL+C to stop") + <-context.Background().Done() +} diff --git a/testing/proxytest/proxytest_test.go b/testing/proxytest/proxytest_test.go index 324a0faa49d..6b2ad0007c7 100644 --- a/testing/proxytest/proxytest_test.go +++ b/testing/proxytest/proxytest_test.go @@ -6,14 +6,9 @@ package proxytest import ( "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" - "fmt" "io" - "math/big" "net" "net/http" "net/http/httptest" @@ -23,19 +18,23 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/testing/certutil" ) func TestProxy(t *testing.T) { - - // Eagerly create objects for TLS setup with CA and server certificate - caCertificate, _, caPrivateKey, err := createCaCertificate() - require.NoError(t, err, "error creating CA cert and key") - - serverCert, serverCertBytes, serverPrivateKey, err := createCertificateSignedByCA(caCertificate, caPrivateKey) + proxyCAKey, proxyCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + + proxyCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) require.NoError(t, err, "error creating server certificate") - caCertPool := x509.NewCertPool() - caCertPool.AddCert(caCertificate) + proxyCACertPool := x509.NewCertPool() + proxyCACertPool.AddCert(proxyCACert) type setup struct { fakeBackendServer *httptest.Server @@ -91,7 +90,7 @@ func TestProxy(t *testing.T) { Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), TLSClientConfig: &tls.Config{ - RootCAs: caCertPool, + RootCAs: proxyCACertPool, MinVersion: tls.VersionTLS12, }, }, @@ -100,13 +99,9 @@ func TestProxy(t *testing.T) { }, proxyOptions: []Option{ WithServerTLSConfig(&tls.Config{ - ClientCAs: caCertPool, - Certificates: []tls.Certificate{{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverPrivateKey, - Leaf: serverCert, - }}, - MinVersion: tls.VersionTLS12, + ClientCAs: proxyCACertPool, + Certificates: []tls.Certificate{*proxyCert}, + MinVersion: tls.VersionTLS12, }), }, proxyStartTLS: true, @@ -132,21 +127,21 @@ func TestProxy(t *testing.T) { require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) // Client certificate - clientCert, clientCertBytes, clientPrivateKey, err := createCertificateSignedByCA(caCertificate, caPrivateKey) - require.NoError(t, err, "error creating client certificate") + tlsCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) + require.NoError(t, err, "failed generating client certificate") // Client with its own certificate and trusting the proxy cert CA return &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), TLSClientConfig: &tls.Config{ - RootCAs: caCertPool, + RootCAs: proxyCACertPool, Certificates: []tls.Certificate{ - { - Certificate: [][]byte{clientCertBytes}, - PrivateKey: clientPrivateKey, - Leaf: clientCert, - }, + *tlsCert, }, MinVersion: tls.VersionTLS12, }, @@ -156,14 +151,10 @@ func TestProxy(t *testing.T) { proxyOptions: []Option{ // require client authentication and verify cert WithServerTLSConfig(&tls.Config{ - ClientCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - Certificates: []tls.Certificate{{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverPrivateKey, - Leaf: serverCert, - }}, - MinVersion: tls.VersionTLS12, + ClientCAs: proxyCACertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{*proxyCert}, + MinVersion: tls.VersionTLS12, }), }, proxyStartTLS: true, @@ -206,7 +197,6 @@ func TestProxy(t *testing.T) { t.Log("Starting proxytest without TLS") err = proxy.Start() } - require.NoError(t, err, "error starting proxytest") defer proxy.Close() @@ -240,93 +230,156 @@ func TestProxy(t *testing.T) { } -func createFakeBackendServer() *httptest.Server { - handlerF := func(writer http.ResponseWriter, request *http.Request) { - // always return HTTP 200 - writer.WriteHeader(http.StatusOK) - } +func TestHTTPSProxy(t *testing.T) { + targetHost := "not-a-server.co" + proxy, client, target := prepareMTLSProxyAndTargetServer(t, targetHost) + defer proxy.Close() + defer target.Close() + + tcs := []struct { + name string + target string + // assertFn should not close the response body + assertFn func(*testing.T, *http.Response, error) + }{ + { + name: "successful_request", + target: "https://" + targetHost, + assertFn: func(t *testing.T, got *http.Response, err error) { + if !assert.Equal(t, http.StatusOK, got.StatusCode, "unexpected status code") { + body, err := io.ReadAll(got.Body) + if err != nil { + t.Logf("could not read response body") + t.FailNow() + } + _ = got.Body.Close() - fakeBackendHTTPServer := httptest.NewServer(http.HandlerFunc(handlerF)) - return fakeBackendHTTPServer -} + t.Logf("request body: %s", string(body)) + } -// utility function to create a CA cert and related key for tests. It returns certificate and key as PEM-encoded blocks -func createCaCertificate() (cert *x509.Certificate, certBytes []byte, privateKey *rsa.PrivateKey, err error) { - // set up our CA certificate - ca := &x509.Certificate{ - SerialNumber: big.NewInt(2023), - Subject: pkix.Name{ - Organization: []string{"Elastic Agent Testing Company, INC."}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"San Francisco"}, - StreetAddress: []string{"Golden Gate Bridge"}, - PostalCode: []string{"94016"}, + }, + }, + { + name: "request_failure", + target: "https://any.not.target.will.do", + assertFn: func(t *testing.T, got *http.Response, err error) { + assert.Equal(t, http.StatusBadGateway, got.StatusCode) + + body, err := io.ReadAll(got.Body) + require.NoError(t, err, "failed reading response body") + assert.Contains(t, string(body), "failed performing request to target") + }, }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, } - // create our private and public key - caPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, nil, err - } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Logf("making request to %q using proxy %q", tc.target, proxy.URL) - // create the CA - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey) - if err != nil { - return nil, nil, nil, err - } + got, err := client.Get(tc.target) //nolint:noctx // it's a test + require.NoError(t, err, "request should have succeeded") + defer got.Body.Close() - cert, err = x509.ParseCertificate(caBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing generated CA cert: %w", err) + // assertFn should not close the response body + tc.assertFn(t, got, err) + }) } - - return cert, caBytes, caPrivateKey, nil } -// utility function to create a new certificate signed by a CA and related key for tests. -// Both paramenters and returned certificate and key are PEM-encoded blocks. -func createCertificateSignedByCA(caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey) (cert *x509.Certificate, certBytes []byte, privateKey *rsa.PrivateKey, err error) { - certTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(2024), - Subject: pkix.Name{ - Organization: []string{"Elastic Agent Testing Company, INC."}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"San Francisco"}, - StreetAddress: []string{"Golden Gate Bridge"}, - PostalCode: []string{"94016"}, - }, - IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, - DNSNames: []string{"localhost"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } +func prepareMTLSProxyAndTargetServer(t *testing.T, targetHost string) (*Proxy, http.Client, *httptest.Server) { + serverCAKey, serverCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") - certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, nil, err - } + serverCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + serverCAKey, + serverCACert) + require.NoError(t, err, "error creating server certificate") + serverCACertPool := x509.NewCertPool() + serverCACertPool.AddCert(serverCACert) + + proxyCAKey, proxyCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") - certBytes, err = x509.CreateCertificate(rand.Reader, certTemplate, caCert, &certPrivKey.PublicKey, caPrivateKey) - if err != nil { - return nil, nil, nil, fmt.Errorf("error creating new certificate: %w", err) + proxyCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + proxyCAKey, + proxyCACert) + require.NoError(t, err, "error creating server certificate") + proxyCACertPool := x509.NewCertPool() + proxyCACertPool.AddCert(proxyCACert) + + clientCAKey, clientCACert, _, err := certutil.NewRootCA() + require.NoError(t, err, "error creating root CA") + clientCACertPool := x509.NewCertPool() + clientCACertPool.AddCert(clientCACert) + + clientCert, _, err := certutil.GenerateChildCert( + "localhost", + []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback, net.IPv6zero}, + clientCAKey, + clientCACert) + require.NoError(t, err, "error creating server certificate") + + server := httptest.NewUnstartedServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("It works!")) + })) + server.TLS = &tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{*serverCert}, + } + server.StartTLS() + t.Logf("target server running on %s", server.URL) + + proxy := New(t, + WithVerboseLog(), + WithRequestLog("https", t.Logf), + WithRewrite(targetHost+":443", server.URL[8:]), + WithMITMCA(proxyCAKey, proxyCACert), + WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS13, + RootCAs: serverCACertPool, + }, + }, + }), + WithServerTLSConfig(&tls.Config{ + Certificates: []tls.Certificate{*proxyCert}, + ClientCAs: clientCACertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS13, + })) + err = proxy.StartTLS() + require.NoError(t, err, "error starting proxy") + t.Logf("proxy running on %s", proxy.URL) + + proxyURL, err := url.Parse(proxy.URL) + require.NoErrorf(t, err, "failed to parse proxy URL %q", proxy.URL) + + client := http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: proxyCACertPool, + Certificates: []tls.Certificate{*clientCert}, + MinVersion: tls.VersionTLS12, + }, + }, } - cert, err = x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing new certificate: %w", err) + return proxy, client, server +} + +func createFakeBackendServer() *httptest.Server { + handlerF := func(writer http.ResponseWriter, request *http.Request) { + // always return HTTP 200 + writer.WriteHeader(http.StatusOK) } - return cert, certBytes, certPrivKey, nil + fakeBackendHTTPServer := httptest.NewServer(http.HandlerFunc(handlerF)) + return fakeBackendHTTPServer }