From 5a0bfae8c5318295c2f8ab37d38b25a87fa87539 Mon Sep 17 00:00:00 2001 From: Charlie Voiselle <464492+angrycub@users.noreply.github.com> Date: Wed, 11 Oct 2023 09:38:20 -0400 Subject: [PATCH] Add unix domain socket support to API - Expose internal HTTP client's Do() via Raw - Use URL parser to identify scheme - Align more with curl output - Add changelog - Fix test failure; add tests for socket envvars - Apply review feedback for tests - Consolidate address parsing - Address feedback from code reviews Co-authored-by: Tim Gross --- .changelog/16872.txt | 3 ++ api/api.go | 81 ++++++++++++++++++++++++++++++++---- api/raw.go | 10 ++++- command/operator_api.go | 54 +++++++++++++----------- command/operator_api_test.go | 74 ++++++++++++++++++++++++++++++++ 5 files changed, 188 insertions(+), 34 deletions(-) create mode 100644 .changelog/16872.txt diff --git a/.changelog/16872.txt b/.changelog/16872.txt new file mode 100644 index 00000000000..955c76adccc --- /dev/null +++ b/.changelog/16872.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Added support for Unix domain sockets +``` diff --git a/api/api.go b/api/api.go index 7f9b41bd37b..8f2a07b71f1 100644 --- a/api/api.go +++ b/api/api.go @@ -205,6 +205,16 @@ type Config struct { // retryOptions holds the configuration necessary to perform retries // on put calls. retryOptions *retryOptions + + // url is populated with the initial parsed address and is not modified in the + // case of a unix:// URL, as opposed to Address. + url *url.URL +} + +// URL returns a copy of the initial parsed address and is not modified in the +// case of a `unix://` URL, as opposed to Address. +func (c *Config) URL() *url.URL { + return c.url } // ClientConfig copies the configuration with a new client address, region, and @@ -214,6 +224,7 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config { if tlsEnabled { scheme = "https" } + config := &Config{ Address: fmt.Sprintf("%s://%s", scheme, address), Region: region, @@ -223,6 +234,7 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config { HttpAuth: c.HttpAuth, WaitTime: c.WaitTime, TLSConfig: c.TLSConfig.Copy(), + url: copyURL(c.url), } // Update the tls server name for connecting to a client @@ -278,9 +290,30 @@ func (t *TLSConfig) Copy() *TLSConfig { return nt } +// defaultUDSClient creates a unix domain socket client. Errors return a nil +// http.Client, which is tested for in ConfigureTLS. This function expects that +// the Address has already been parsed into the config.url value. +func defaultUDSClient(config *Config) *http.Client { + + config.Address = "http://127.0.0.1" + + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", config.url.EscapedPath()) + }, + }, + } + return defaultClient(httpClient) +} + func defaultHttpClient() *http.Client { httpClient := cleanhttp.DefaultPooledClient() - transport := httpClient.Transport.(*http.Transport) + return defaultClient(httpClient) +} + +func defaultClient(c *http.Client) *http.Client { + transport := c.Transport.(*http.Transport) transport.TLSHandshakeTimeout = 10 * time.Second transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, @@ -290,7 +323,7 @@ func defaultHttpClient() *http.Client { // well yet: https://github.com/gorilla/websocket/issues/417 transport.ForceAttemptHTTP2 = false - return httpClient + return c } // DefaultConfig returns a default configuration for the client @@ -467,18 +500,29 @@ type Client struct { // NewClient returns a new client func NewClient(config *Config) (*Client, error) { + var err error // bootstrap the config defConfig := DefaultConfig() if config.Address == "" { config.Address = defConfig.Address - } else if _, err := url.Parse(config.Address); err != nil { + } + + // we have to test the address that comes from DefaultConfig, because it + // could be the value of NOMAD_ADDR which is applied without testing + if config.url, err = url.Parse(config.Address); err != nil { return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err) } httpClient := config.HttpClient if httpClient == nil { - httpClient = defaultHttpClient() + switch { + case config.url.Scheme == "unix": + httpClient = defaultUDSClient(config) // mutates config + default: + httpClient = defaultHttpClient() + } + if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil { return nil, err } @@ -760,24 +804,32 @@ func (r *request) toHTTP() (*http.Request, error) { // newRequest is used to create a new request func (c *Client) newRequest(method, path string) (*request, error) { - base, _ := url.Parse(c.config.Address) + u, err := url.Parse(path) if err != nil { return nil, err } + r := &request{ config: &c.config, method: method, url: &url.URL{ - Scheme: base.Scheme, - User: base.User, - Host: base.Host, + Scheme: c.config.url.Scheme, + User: c.config.url.User, + Host: c.config.url.Host, Path: u.Path, RawPath: u.RawPath, }, header: make(http.Header), params: make(map[string][]string), } + + // fixup socket paths + if r.url.Scheme == "unix" { + r.url.Scheme = "http" + r.url.Host = "127.0.0.1" + } + if c.config.Region != "" { r.params.Set("region", c.config.Region) } @@ -1210,3 +1262,16 @@ func (o *WriteOptions) WithContext(ctx context.Context) *WriteOptions { o2.ctx = ctx return o2 } + +// copyURL makes a deep copy of a net/url.URL +func copyURL(u1 *url.URL) *url.URL { + if u1 == nil { + return nil + } + o := *u1 + if o.User != nil { + ou := *u1.User + o.User = &ou + } + return &o +} diff --git a/api/raw.go b/api/raw.go index 87f8a9c5eb1..73e2a529992 100644 --- a/api/raw.go +++ b/api/raw.go @@ -3,7 +3,10 @@ package api -import "io" +import ( + "io" + "net/http" +) // Raw can be used to do raw queries against custom endpoints type Raw struct { @@ -39,3 +42,8 @@ func (raw *Raw) Write(endpoint string, in, out interface{}, q *WriteOptions) (*W func (raw *Raw) Delete(endpoint string, out interface{}, q *WriteOptions) (*WriteMeta, error) { return raw.c.delete(endpoint, nil, out, q) } + +// Do uses the raw client's internal httpClient to process the request +func (raw *Raw) Do(req *http.Request) (*http.Response, error) { + return raw.c.httpClient.Do(req) +} diff --git a/command/operator_api.go b/command/operator_api.go index 985e20123de..50d45e00fcc 100644 --- a/command/operator_api.go +++ b/command/operator_api.go @@ -5,7 +5,6 @@ package command import ( "bytes" - "crypto/tls" "fmt" "io" "net" @@ -13,9 +12,7 @@ import ( "net/url" "os" "strings" - "time" - "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/nomad/api" "github.com/posener/complete" ) @@ -138,11 +135,18 @@ func (c *OperatorAPICommand) Run(args []string) int { // By default verbose func is a noop verbose := func(string, ...interface{}) {} + verboseSocket := func(*api.Config, string, ...interface{}) {} + if c.verboseFlag { verbose = func(format string, a ...interface{}) { // Use Warn instead of Info because Info goes to stdout c.Ui.Warn(fmt.Sprintf(format, a...)) } + verboseSocket = func(cfg *api.Config, format string, a ...interface{}) { + if cfg.URL() != nil && cfg.URL().Scheme == "unix" { + c.Ui.Warn(fmt.Sprintf(format, a...)) + } + } } // Opportunistically read from stdin and POST unless method has been @@ -166,11 +170,13 @@ func (c *OperatorAPICommand) Run(args []string) int { c.method = "GET" } + // NewClient mutates or validates Config.Address, so call it to match + // the behavior of other commands. Typically these are called as a combination + // using c.Client(); however, we need access to the client configuration + // to build the corresponding curl output. config := c.clientConfig() + apiC, err := api.NewClient(config) - // NewClient mutates or validates Config.Address, so call it to match - // the behavior of other commands. - _, err := api.NewClient(config) if err != nil { c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err)) return 1 @@ -198,23 +204,10 @@ func (c *OperatorAPICommand) Run(args []string) int { c.Ui.Output(out) return 0 } - - // Re-implement a big chunk of api/api.go since we don't export it. - client := cleanhttp.DefaultClient() - transport := client.Transport.(*http.Transport) - transport.TLSHandshakeTimeout = 10 * time.Second - transport.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - - if err := api.ConfigureTLS(client, config.TLSConfig); err != nil { - c.Ui.Error(fmt.Sprintf("Error configuring TLS: %v", err)) - return 1 - } + apiR := apiC.Raw() setQueryParams(config, path) - - verbose("> %s %s", c.method, path) + verboseSocket(config, fmt.Sprintf("* Trying %s...", config.URL().EscapedPath())) req, err := http.NewRequest(c.method, path.String(), c.body) if err != nil { @@ -222,6 +215,10 @@ func (c *OperatorAPICommand) Run(args []string) int { return 1 } + h := req.URL.Hostname() + verboseSocket(config, fmt.Sprintf("* Connected to %s (%s)", h, config.URL().EscapedPath())) + verbose("> %s %s %s", c.method, req.URL.Path, req.Proto) + // Set headers from command line req.Header = headerFlags.headers @@ -244,11 +241,11 @@ func (c *OperatorAPICommand) Run(args []string) int { verbose("> %s: %s", k, v) } } - + verbose(">") verbose("* Sending request and receiving response...") // Do the request! - resp, err := client.Do(req) + resp, err := apiR.Do(req) if err != nil { c.Ui.Error(fmt.Sprintf("Error performing request: %v", err)) return 1 @@ -310,7 +307,8 @@ func (c *OperatorAPICommand) apiToCurl(config *api.Config, headers http.Header, parts = append(parts, "--verbose") } - if c.method != "" { + // add method flags. Note: curl output complains about `-X GET` + if c.method != "" && c.method != http.MethodGet { parts = append(parts, "-X "+c.method) } @@ -318,6 +316,10 @@ func (c *OperatorAPICommand) apiToCurl(config *api.Config, headers http.Header, parts = append(parts, "--data-binary @-") } + if config.URL().EscapedPath() != "" { + parts = append(parts, fmt.Sprintf("--unix-socket %q", config.URL().EscapedPath())) + } + if config.TLSConfig != nil { parts = tlsToCurl(parts, config.TLSConfig) @@ -412,7 +414,9 @@ func pathToURL(config *api.Config, path string) (*url.URL, error) { // If the scheme is missing from the path, it likely means the path is just // the HTTP handler path. Attempt to infer this. - if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") { + if !strings.HasPrefix(path, "http://") && + !strings.HasPrefix(path, "https://") && + !strings.HasPrefix(path, "unix://") { scheme := "http" // If the user has set any TLS configuration value, this is a good sign diff --git a/command/operator_api_test.go b/command/operator_api_test.go index de20eac21d4..4d097b0d687 100644 --- a/command/operator_api_test.go +++ b/command/operator_api_test.go @@ -5,10 +5,13 @@ package command import ( "bytes" + "fmt" + "net" "net/http" "net/http/httptest" "net/url" "os" + "path" "testing" "time" @@ -220,3 +223,74 @@ func TestOperatorAPICommand_ContentLength(t *testing.T) { t.Fatalf("timed out waiting for request") } } + +func makeSocketListener(t *testing.T) (net.Listener, string) { + td := os.TempDir() // testing.TempDir() on macOS makes paths that are too long + sPath := path.Join(td, t.Name()+".sock") + os.Remove(sPath) // git rid of stale ones now. + + t.Cleanup(func() { os.Remove(sPath) }) + + // Create a Unix domain socket and listen for incoming connections. + socket, err := net.Listen("unix", sPath) + must.NoError(t, err) + return socket, sPath +} + +// TestOperatorAPICommand_Socket tests that requests can be routed over a unix +// domain socket +// +// Can not be run in parallel as it modifies the environment. +func TestOperatorAPICommand_Socket(t *testing.T) { + + ping := make(chan struct{}, 1) + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ping <- struct{}{} + })) + sock, sockPath := makeSocketListener(t) + ts.Listener = sock + ts.Start() + defer ts.Close() + + // Setup command. + ui := cli.NewMockUi() + cmd := &OperatorAPICommand{Meta: Meta{Ui: ui}} + + tcs := []struct { + name string + env map[string]string + args []string + exitCode int + }{ + { + name: "nomad_addr", + env: map[string]string{"NOMAD_ADDR": "unix://" + sockPath}, + args: []string{"/v1/jobs"}, + exitCode: 0, + }, + { + name: "nomad_addr opaques host", + env: map[string]string{"NOMAD_ADDR": "unix://" + sockPath}, + args: []string{"http://example.com/v1/jobs"}, + exitCode: 0, + }, + } + for i, tc := range tcs { + t.Run(fmt.Sprintf("%v_%s", i+1, t.Name()), func(t *testing.T) { + tc := tc + for k, v := range tc.env { + t.Setenv(k, v) + } + + exitCode := cmd.Run(tc.args) + must.Eq(t, tc.exitCode, exitCode, must.Sprint(ui.ErrorWriter.String())) + + select { + case l := <-ping: + must.Eq(t, struct{}{}, l) + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for request") + } + }) + } +}