From a44b165b203b0a057d03b248ac9b83e11d12b021 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 18 Apr 2023 12:20:39 -0600 Subject: [PATCH] fix: pass dial options to FUSE mounts This is a port of https://github.com/GoogleCloudPlatform/cloud-sql-proxy/pull/1737. --- internal/proxy/proxy.go | 27 +++++------- internal/proxy/proxy_other.go | 7 ++- tests/connection_test.go | 7 +++ tests/fuse_test.go | 82 +++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 18 deletions(-) create mode 100644 tests/fuse_test.go diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4a626e3e..61268c4d 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -362,19 +362,13 @@ type Client struct { // all AlloyDB instances. connCount uint64 - // maxConns is the maximum number of allowed connections tracked by - // connCount. If not set, there is no limit. - maxConns uint64 + conf *Config dialer alloydb.Dialer // mnts is a list of all mounted sockets for this client mnts []*socketMount - // waitOnClose is the maximum duration to wait for open connections to close - // when shutting down. - waitOnClose time.Duration - logger alloydb.Logger fuseMount @@ -396,10 +390,9 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co } c := &Client{ - logger: l, - dialer: d, - maxConns: conf.MaxConnections, - waitOnClose: conf.WaitOnClose, + logger: l, + dialer: d, + conf: conf, } if conf.FUSEDir != "" { @@ -485,7 +478,7 @@ func (c *Client) CheckConnections(ctx context.Context) (int, error) { // ConnCount returns the number of open connections and the maximum allowed // connections. Returns 0 when the maximum allowed connections have not been set. func (c *Client) ConnCount() (uint64, uint64) { - return atomic.LoadUint64(&c.connCount), c.maxConns + return atomic.LoadUint64(&c.connCount), c.conf.MaxConnections } // Serve starts proxying connections for all configured instances using the @@ -565,13 +558,13 @@ func (c *Client) Close() error { if cErr != nil { mErr = append(mErr, cErr) } - if c.waitOnClose == 0 { + if c.conf.WaitOnClose == 0 { if len(mErr) > 0 { return mErr } return nil } - timeout := time.After(c.waitOnClose) + timeout := time.After(c.conf.WaitOnClose) t := time.NewTicker(100 * time.Millisecond) defer t.Stop() for { @@ -586,7 +579,7 @@ func (c *Client) Close() error { } open := atomic.LoadUint64(&c.connCount) if open > 0 { - mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose)) + mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.conf.WaitOnClose)) } if len(mErr) > 0 { return mErr @@ -619,8 +612,8 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error { count := atomic.AddUint64(&c.connCount, 1) defer atomic.AddUint64(&c.connCount, ^uint64(0)) - if c.maxConns > 0 && count > c.maxConns { - c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.maxConns) + if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { + c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) _ = cConn.Close() return } diff --git a/internal/proxy/proxy_other.go b/internal/proxy/proxy_other.go index 3f64ec60..7da711d0 100644 --- a/internal/proxy/proxy_other.go +++ b/internal/proxy/proxy_other.go @@ -116,7 +116,7 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut) } s, err := newSocketMount( - ctx, &Config{UnixSocket: c.fuseTempDir}, + ctx, withUnixSocket(*c.conf, c.fuseTempDir), nil, InstanceConnConfig{Name: instanceURI}, ) if err != nil { @@ -148,6 +148,11 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut) ), fs.OK } +func withUnixSocket(c Config, tmpDir string) *Config { + c.UnixSocket = tmpDir + return &c +} + func (c *Client) serveFuse(ctx context.Context, notify func()) error { srv, err := fs.Mount(c.fuseDir, c, &fs.Options{ MountOptions: fuse.MountOptions{AllowOther: true}, diff --git a/tests/connection_test.go b/tests/connection_test.go index b1eda492..1f321b5b 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -70,6 +70,10 @@ func keyfile(t *testing.T) string { // proxyConnTest is a test helper to verify the proxy works with a basic connectivity test. func proxyConnTest(t *testing.T, args []string, driver, dsn string) { + proxyConnTestWithReady(t, args, driver, dsn, func() error { return nil }) +} + +func proxyConnTestWithReady(t *testing.T, args []string, driver, dsn string, ready func() error) { ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout) defer cancel() // Start the proxy @@ -82,6 +86,9 @@ func proxyConnTest(t *testing.T, args []string, driver, dsn string) { if err != nil { t.Fatalf("unable to verify proxy was serving: %s \n %s", err, output) } + if err := ready(); err != nil { + t.Fatalf("proxy was not ready: %v", err) + } // Connect to the instance db, err := sql.Open(driver, dsn) diff --git a/tests/fuse_test.go b/tests/fuse_test.go new file mode 100644 index 00000000..d281c77d --- /dev/null +++ b/tests/fuse_test.go @@ -0,0 +1,82 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows && !darwin + +package tests + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" +) + +func TestPostgresFUSEConnect(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + tmpDir, cleanup := createTempDir(t) + defer cleanup() + + host := proxy.UnixAddress(tmpDir, *alloydbConnName) + dsn := fmt.Sprintf( + "host=%s user=%s password=%s database=%s sslmode=disable", + host, *alloydbUser, *alloydbPass, *alloydbDB, + ) + testFUSE(t, tmpDir, host, dsn) +} + +func testFUSE(t *testing.T, tmpDir, host string, dsn string) { + tmpDir2, cleanup2 := createTempDir(t) + defer cleanup2() + + waitForFUSE := func() error { + var err error + for i := 0; i < 10; i++ { + _, err = os.Stat(host) + if err == nil { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return fmt.Errorf("failed to find FUSE mounted Unix socket: %v", err) + } + + tcs := []struct { + desc string + dbUser string + args []string + }{ + { + desc: "using default fuse", + args: []string{fmt.Sprintf("--fuse=%s", tmpDir), fmt.Sprintf("--fuse-tmp-dir=%s", tmpDir2)}, + }, + { + desc: "using fuse with auto-iam-authn", + args: []string{fmt.Sprintf("--fuse=%s", tmpDir), "--auto-iam-authn"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTestWithReady(t, tc.args, "pgx", dsn, waitForFUSE) + // given the kernel some time to unmount the fuse + time.Sleep(100 * time.Millisecond) + }) + } + +}