Skip to content

Commit

Permalink
fix(db): send initial heartbeat when there is no static dbs (#11160)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielcorado committed Apr 18, 2022
1 parent f00ae81 commit 608a85a
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 0 deletions.
79 changes: 79 additions & 0 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package integration

import (
"context"
"fmt"
"net"
"net/http"
"testing"
"time"

Expand All @@ -31,6 +33,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
Expand Down Expand Up @@ -588,6 +591,49 @@ func TestDatabaseAccessMongoSeparateListener(t *testing.T) {
require.NoError(t, err)
}

func TestDatabaseAgentState(t *testing.T) {
tests := map[string]struct {
agentParams databaseAgentStartParams
}{
"WithStaticDatabases": {
agentParams: databaseAgentStartParams{
databases: []service.Database{
{Name: "mysql", Protocol: defaults.ProtocolMySQL, URI: "localhost:3306"},
{Name: "pg", Protocol: defaults.ProtocolPostgres, URI: "localhost:5432"},
},
},
},
"WithResourceMatchers": {
agentParams: databaseAgentStartParams{
resourceMatchers: []services.ResourceMatcher{
{Labels: types.Labels{"*": []string{"*"}}},
},
},
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
pack := setupDatabaseTest(t)

// Start also ensures that the database agent has the “ready” state.
// If the agent can’t make it, this function will fail the test.
agent, _ := pack.startRootDatabaseAgent(t, test.agentParams)

// In addition to the checks performed during the agent start,
// we’ll request the diagnostic server to ensure the readyz route
// is returning to the proper state.
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%v/readyz", agent.Config.DiagnosticAddr.Addr), nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)
})
}
}

func waitForAuditEventTypeWithBackoff(t *testing.T, cli *auth.Server, startTime time.Time, eventType string) []apievents.AuditEvent {
max := time.Second
timeout := time.After(max)
Expand Down Expand Up @@ -1015,6 +1061,39 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
}
}

// databaseAgentStartParams parameters used to configure a database agent.
type databaseAgentStartParams struct {
databases []service.Database
resourceMatchers []services.ResourceMatcher
}

// startRootDatabaseAgent starts a database agent with the provided
// configuration on the root cluster.
func (p *databasePack) startRootDatabaseAgent(t *testing.T, params databaseAgentStartParams) (*service.TeleportProcess, *auth.Client) {
conf := service.MakeDefaultConfig()
conf.DataDir = t.TempDir()
conf.Token = "static-token-value"
conf.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("localhost", ports.Pop())}
conf.AuthServers = []utils.NetAddr{
{
AddrNetwork: "tcp",
Addr: net.JoinHostPort(Loopback, p.root.cluster.GetPortWeb()),
},
}
conf.Clock = p.clock
conf.Databases.Enabled = true
conf.Databases.Databases = params.databases
conf.Databases.ResourceMatchers = params.resourceMatchers

server, authClient, err := p.root.cluster.StartDatabase(conf)
require.NoError(t, err)
t.Cleanup(func() {
server.Close()
})

return server, authClient
}

func containsDB(servers []types.DatabaseServer, name string) bool {
for _, server := range servers {
if server.GetDatabase().GetName() == name {
Expand Down
3 changes: 3 additions & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,8 @@ type agentParams struct {
NoStart bool
// GCPSQL defines the GCP Cloud SQL mock to use for GCP API calls.
GCPSQL *cloud.GCPSQLAdminClientMock
// OnHeartbeat defines a heartbeat function that generates heartbeat events.
OnHeartbeat func(error)
}

func (p *agentParams) setDefaults(c *testContext) {
Expand Down Expand Up @@ -1874,6 +1876,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a
Limiter: connLimiter,
Auth: testAuth,
Databases: p.Databases,
OnHeartbeat: p.OnHeartbeat,
ResourceMatchers: p.ResourceMatchers,
GetServerInfoFn: p.GetServerInfoFn,
GetRotation: func(types.SystemRole) (*types.Rotation, error) {
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,12 @@ func (s *Server) Start(ctx context.Context) (err error) {
return trace.Wrap(err)
}

// If the agent doesn’t have any static databases configured, send a
// heartbeat without error to make the component “ready”.
if len(s.cfg.Databases) == 0 && s.cfg.OnHeartbeat != nil {
s.cfg.OnHeartbeat(nil)
}

return nil
}

Expand Down
67 changes: 67 additions & 0 deletions lib/srv/db/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package db

import (
"context"
"sync/atomic"
"testing"
"time"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"

"github.com/jackc/pgconn"
Expand Down Expand Up @@ -177,3 +180,67 @@ func TestDatabaseServerLimiting(t *testing.T) {
require.FailNow(t, "we should exceed the connection limit by now")
})
}

func TestHeartbeatEvents(t *testing.T) {
ctx := context.Background()

dbOne, err := types.NewDatabaseV3(types.Metadata{
Name: "dbOne",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
})
require.NoError(t, err)

dbTwo, err := types.NewDatabaseV3(types.Metadata{
Name: "dbOne",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolMySQL,
URI: "localhost:3306",
})
require.NoError(t, err)

tests := map[string]struct {
staticDatabases types.Databases
heartbeatCount int64
}{
"SingleStaticDatabase": {
staticDatabases: types.Databases{dbOne},
heartbeatCount: 1,
},
"MultipleStaticDatabases": {
staticDatabases: types.Databases{dbOne, dbTwo},
heartbeatCount: 2,
},
"EmptyStaticDatabases": {
staticDatabases: types.Databases{},
heartbeatCount: 1,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
var heartbeatEvents int64
heartbeatRecorder := func(err error) {
require.NoError(t, err)
atomic.AddInt64(&heartbeatEvents, 1)
}

testCtx := setupTestContext(ctx, t)
server := testCtx.setupDatabaseServer(ctx, t, agentParams{
NoStart: true,
OnHeartbeat: heartbeatRecorder,
Databases: test.staticDatabases,
})
require.NoError(t, server.Start(ctx))
t.Cleanup(func() {
server.Close()
})

require.NotNil(t, server)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&heartbeatEvents) == test.heartbeatCount
}, 2*time.Second, 500*time.Millisecond)
})
}
}

0 comments on commit 608a85a

Please sign in to comment.