diff --git a/replications/internal/store.go b/replications/internal/store.go new file mode 100644 index 00000000000..8f528e813c6 --- /dev/null +++ b/replications/internal/store.go @@ -0,0 +1,255 @@ +package internal + +import ( + "context" + "database/sql" + "errors" + "fmt" + + sq "github.com/Masterminds/squirrel" + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/platform" + ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" + "github.com/influxdata/influxdb/v2/sqlite" + "github.com/mattn/go-sqlite3" +) + +var errReplicationNotFound = &ierrors.Error{ + Code: ierrors.ENotFound, + Msg: "replication not found", +} + +func errRemoteNotFound(id platform.ID, cause error) error { + return &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: fmt.Sprintf("remote %q not found", id), + Err: cause, + } +} + +type Store struct { + sqlStore *sqlite.SqlStore +} + +func NewStore(sqlStore *sqlite.SqlStore) *Store { + return &Store{ + sqlStore: sqlStore, + } +} + +func (s *Store) Lock() { + s.sqlStore.Mu.Lock() +} + +func (s *Store) Unlock() { + s.sqlStore.Mu.Unlock() +} + +// ListReplications returns a list of replications matching the provided filter. +func (s *Store) ListReplications(ctx context.Context, filter influxdb.ReplicationListFilter) (*influxdb.Replications, error) { + q := sq.Select( + "id", "org_id", "name", "description", "remote_id", "local_bucket_id", "remote_bucket_id", + "max_queue_size_bytes", "latest_response_code", "latest_error_message", "drop_non_retryable_data"). + From("replications") + + if filter.OrgID.Valid() { + q = q.Where(sq.Eq{"org_id": filter.OrgID}) + } + if filter.Name != nil { + q = q.Where(sq.Eq{"name": *filter.Name}) + } + if filter.RemoteID != nil { + q = q.Where(sq.Eq{"remote_id": *filter.RemoteID}) + } + if filter.LocalBucketID != nil { + q = q.Where(sq.Eq{"local_bucket_id": *filter.LocalBucketID}) + } + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var rs influxdb.Replications + if err := s.sqlStore.DB.SelectContext(ctx, &rs.Replications, query, args...); err != nil { + return nil, err + } + + return &rs, nil +} + +// CreateReplication persists a new replication in the database. Caller is responsible for managing locks. +func (s *Store) CreateReplication(ctx context.Context, newID platform.ID, request influxdb.CreateReplicationRequest) (*influxdb.Replication, error) { + q := sq.Insert("replications"). + SetMap(sq.Eq{ + "id": newID, + "org_id": request.OrgID, + "name": request.Name, + "description": request.Description, + "remote_id": request.RemoteID, + "local_bucket_id": request.LocalBucketID, + "remote_bucket_id": request.RemoteBucketID, + "max_queue_size_bytes": request.MaxQueueSizeBytes, + "drop_non_retryable_data": request.DropNonRetryableData, + "created_at": "datetime('now')", + "updated_at": "datetime('now')", + }). + Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, drop_non_retryable_data") + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + + if err := s.sqlStore.DB.GetContext(ctx, &r, query, args...); err != nil { + if sqlErr, ok := err.(sqlite3.Error); ok && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { + return nil, errRemoteNotFound(request.RemoteID, err) + } + return nil, err + } + + return &r, nil +} + +// GetReplication gets a replication by ID from the database. +func (s *Store) GetReplication(ctx context.Context, id platform.ID) (*influxdb.Replication, error) { + q := sq.Select( + "id", "org_id", "name", "description", "remote_id", "local_bucket_id", "remote_bucket_id", + "max_queue_size_bytes", "latest_response_code", "latest_error_message", "drop_non_retryable_data"). + From("replications"). + Where(sq.Eq{"id": id}) + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + if err := s.sqlStore.DB.GetContext(ctx, &r, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + return nil, err + } + + return &r, nil +} + +// UpdateReplication updates a replication by ID. Caller is responsible for managing locks. +func (s *Store) UpdateReplication(ctx context.Context, id platform.ID, request influxdb.UpdateReplicationRequest) (*influxdb.Replication, error) { + updates := sq.Eq{"updated_at": sq.Expr("datetime('now')")} + if request.Name != nil { + updates["name"] = *request.Name + } + if request.Description != nil { + updates["description"] = *request.Description + } + if request.RemoteID != nil { + updates["remote_id"] = *request.RemoteID + } + if request.RemoteBucketID != nil { + updates["remote_bucket_id"] = *request.RemoteBucketID + } + if request.MaxQueueSizeBytes != nil { + updates["max_queue_size_bytes"] = *request.MaxQueueSizeBytes + } + if request.DropNonRetryableData != nil { + updates["drop_non_retryable_data"] = *request.DropNonRetryableData + } + + q := sq.Update("replications").SetMap(updates).Where(sq.Eq{"id": id}). + Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, drop_non_retryable_data") + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var r influxdb.Replication + if err := s.sqlStore.DB.GetContext(ctx, &r, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + if sqlErr, ok := err.(sqlite3.Error); ok && request.RemoteID != nil && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { + return nil, errRemoteNotFound(*request.RemoteID, err) + } + return nil, err + } + + return &r, nil +} + +// DeleteReplication deletes a replication by ID from the database. Caller is responsible for managing locks. +func (s *Store) DeleteReplication(ctx context.Context, id platform.ID) error { + q := sq.Delete("replications").Where(sq.Eq{"id": id}).Suffix("RETURNING id") + query, args, err := q.ToSql() + if err != nil { + return err + } + + var d platform.ID + if err := s.sqlStore.DB.GetContext(ctx, &d, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errReplicationNotFound + } + return err + } + + return nil +} + +// DeleteBucketReplications deletes the replications for the provided localBucketID from the database. Caller is +// responsible for managing locks. A list of deleted IDs is returned for further processing by the caller. +func (s *Store) DeleteBucketReplications(ctx context.Context, localBucketID platform.ID) ([]platform.ID, error) { + q := sq.Delete("replications").Where(sq.Eq{"local_bucket_id": localBucketID}).Suffix("RETURNING id") + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var deleted []platform.ID + if err := s.sqlStore.DB.SelectContext(ctx, &deleted, query, args...); err != nil { + return nil, err + } + + return deleted, nil +} + +func (s *Store) GetFullHTTPConfig(ctx context.Context, id platform.ID) (*ReplicationHTTPConfig, error) { + q := sq.Select("c.remote_url", "c.remote_api_token", "c.remote_org_id", "c.allow_insecure_tls", "r.remote_bucket_id"). + From("replications r").InnerJoin("remotes c ON r.remote_id = c.id AND r.id = ?", id) + + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + + var rc ReplicationHTTPConfig + if err := s.sqlStore.DB.GetContext(ctx, &rc, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errReplicationNotFound + } + return nil, err + } + return &rc, nil +} + +func (s *Store) PopulateRemoteHTTPConfig(ctx context.Context, id platform.ID, target *ReplicationHTTPConfig) error { + q := sq.Select("remote_url", "remote_api_token", "remote_org_id", "allow_insecure_tls"). + From("remotes").Where(sq.Eq{"id": id}) + query, args, err := q.ToSql() + if err != nil { + return err + } + + if err := s.sqlStore.DB.GetContext(ctx, target, query, args...); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errRemoteNotFound(id, nil) + } + return err + } + + return nil +} diff --git a/replications/internal/store_test.go b/replications/internal/store_test.go new file mode 100644 index 00000000000..1c5bf1de680 --- /dev/null +++ b/replications/internal/store_test.go @@ -0,0 +1,435 @@ +package internal + +import ( + "context" + "fmt" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/influxdata/influxdb/v2" + "github.com/influxdata/influxdb/v2/kit/platform" + "github.com/influxdata/influxdb/v2/snowflake" + "github.com/influxdata/influxdb/v2/sqlite" + "github.com/influxdata/influxdb/v2/sqlite/migrations" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +var ( + ctx = context.Background() + initID = platform.ID(1) + desc = "testing testing" + replication = influxdb.Replication{ + ID: initID, + OrgID: platform.ID(10), + Name: "test", + Description: &desc, + RemoteID: platform.ID(100), + LocalBucketID: platform.ID(1000), + RemoteBucketID: platform.ID(99999), + MaxQueueSizeBytes: 3 * influxdb.DefaultReplicationMaxQueueSizeBytes, + } + createReq = influxdb.CreateReplicationRequest{ + OrgID: replication.OrgID, + Name: replication.Name, + Description: replication.Description, + RemoteID: replication.RemoteID, + LocalBucketID: replication.LocalBucketID, + RemoteBucketID: replication.RemoteBucketID, + MaxQueueSizeBytes: replication.MaxQueueSizeBytes, + } + httpConfig = ReplicationHTTPConfig{ + RemoteURL: fmt.Sprintf("http://%s.cloud", replication.RemoteID), + RemoteToken: replication.RemoteID.String(), + RemoteOrgID: platform.ID(888888), + AllowInsecureTLS: true, + RemoteBucketID: replication.RemoteBucketID, + } + newRemoteID = platform.ID(200) + newQueueSize = influxdb.MinReplicationMaxQueueSizeBytes + updateReq = influxdb.UpdateReplicationRequest{ + RemoteID: &newRemoteID, + MaxQueueSizeBytes: &newQueueSize, + DropNonRetryableData: boolPointer(true), + } + updatedReplication = influxdb.Replication{ + ID: replication.ID, + OrgID: replication.OrgID, + Name: replication.Name, + Description: replication.Description, + RemoteID: *updateReq.RemoteID, + LocalBucketID: replication.LocalBucketID, + RemoteBucketID: replication.RemoteBucketID, + MaxQueueSizeBytes: *updateReq.MaxQueueSizeBytes, + DropNonRetryableData: true, + } +) + +func TestCreateAndGetReplication(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + insertRemote(t, testStore, replication.RemoteID) + + // Getting an invalid ID should return an error. + got, err := testStore.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) + + // Create a replication, check the results. + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Read the created replication and assert it matches the creation response. + got, err = testStore.GetReplication(ctx, created.ID) + require.NoError(t, err) + require.Equal(t, replication, *got) +} + +func TestCreateMissingRemote(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", createReq.RemoteID)) + require.Nil(t, created) + + // Make sure nothing was persisted. + got, err := testStore.GetReplication(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) +} + +func TestUpdateAndGetReplication(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + insertRemote(t, testStore, replication.RemoteID) + insertRemote(t, testStore, updatedReplication.RemoteID) + + // Updating a nonexistent ID fails. + updated, err := testStore.UpdateReplication(ctx, initID, updateReq) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, updated) + + // Create a replication. + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Update the replication. + updated, err = testStore.UpdateReplication(ctx, created.ID, updateReq) + require.NoError(t, err) + require.Equal(t, updatedReplication, *updated) +} + +func TestUpdateMissingRemote(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + insertRemote(t, testStore, replication.RemoteID) + + // Create a replication. + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Attempt to update the replication to point at a nonexistent remote. + updated, err := testStore.UpdateReplication(ctx, created.ID, updateReq) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", *updateReq.RemoteID)) + require.Nil(t, updated) + + // Make sure nothing changed in the DB. + got, err := testStore.GetReplication(ctx, created.ID) + require.NoError(t, err) + require.Equal(t, replication, *got) +} + +func TestUpdateNoop(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + insertRemote(t, testStore, replication.RemoteID) + + // Create a replication. + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + // Send a no-op update, assert nothing changed. + updated, err := testStore.UpdateReplication(ctx, created.ID, influxdb.UpdateReplicationRequest{}) + require.NoError(t, err) + require.Equal(t, replication, *updated) +} + +func TestDeleteReplication(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + insertRemote(t, testStore, replication.RemoteID) + + // Deleting a nonexistent ID should return an error. + require.Equal(t, errReplicationNotFound, testStore.DeleteReplication(ctx, initID)) + + // Create a replication, then delete it. + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + require.NoError(t, testStore.DeleteReplication(ctx, created.ID)) + + // Looking up the ID should again produce an error. + got, err := testStore.GetReplication(ctx, created.ID) + require.Equal(t, errReplicationNotFound, err) + require.Nil(t, got) +} + +func TestDeleteReplications(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + // Deleting when there is no bucket is OK. + _, err := testStore.DeleteBucketReplications(ctx, replication.LocalBucketID) + require.NoError(t, err) + + // Register a handful of replications. + createReq2, createReq3 := createReq, createReq + createReq2.Name, createReq3.Name = "test2", "test3" + createReq2.LocalBucketID = platform.ID(77777) + createReq3.RemoteID = updatedReplication.RemoteID + insertRemote(t, testStore, createReq.RemoteID) + insertRemote(t, testStore, createReq3.RemoteID) + + for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { + _, err := testStore.CreateReplication(ctx, snowflake.NewIDGenerator().ID(), req) + require.NoError(t, err) + } + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: replication.OrgID}) + require.NoError(t, err) + require.Len(t, listed.Replications, 3) + + // Delete 2/3 by bucket ID. + deleted, err := testStore.DeleteBucketReplications(ctx, createReq.LocalBucketID) + require.NoError(t, err) + require.Len(t, deleted, 2) + + // Ensure they were deleted. + listed, err = testStore.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: replication.OrgID}) + require.NoError(t, err) + require.Len(t, listed.Replications, 1) + require.Equal(t, createReq2.LocalBucketID, listed.Replications[0].LocalBucketID) +} + +func TestListReplications(t *testing.T) { + t.Parallel() + + createReq2, createReq3 := createReq, createReq + createReq2.Name, createReq3.Name = "test2", "test3" + createReq2.LocalBucketID = platform.ID(77777) + createReq3.RemoteID = updatedReplication.RemoteID + + setup := func(t *testing.T, testStore *Store) []influxdb.Replication { + insertRemote(t, testStore, createReq.RemoteID) + insertRemote(t, testStore, createReq3.RemoteID) + + var allReplications []influxdb.Replication + for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { + created, err := testStore.CreateReplication(ctx, snowflake.NewIDGenerator().ID(), req) + require.NoError(t, err) + allReplications = append(allReplications, *created) + } + return allReplications + } + + t.Run("list all for org", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + allRepls := setup(t, testStore) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: createReq.OrgID}) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls}, *listed) + }) + + t.Run("list all with empty filter", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + allRepls := setup(t, testStore) + + otherOrgReq := createReq + otherOrgReq.OrgID = platform.ID(12345) + created, err := testStore.CreateReplication(ctx, snowflake.NewIDGenerator().ID(), otherOrgReq) + require.NoError(t, err) + allRepls = append(allRepls, *created) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{}) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls}, *listed) + }) + + t.Run("list by name", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + allRepls := setup(t, testStore) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + Name: &createReq2.Name, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls[1:2]}, *listed) + }) + + t.Run("list by remote ID", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + allRepls := setup(t, testStore) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + RemoteID: &createReq.RemoteID, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: allRepls[0:2]}, *listed) + }) + + t.Run("list by bucket ID", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + allRepls := setup(t, testStore) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: createReq.OrgID, + LocalBucketID: &createReq.LocalBucketID, + }) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{Replications: append(allRepls[0:1], allRepls[2:]...)}, *listed) + }) + + t.Run("list by other org ID", func(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + listed, err := testStore.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: platform.ID(2)}) + require.NoError(t, err) + require.Equal(t, influxdb.Replications{}, *listed) + }) +} + +func TestGetFullHTTPConfig(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + // Does not exist returns the appropriate error + _, err := testStore.GetFullHTTPConfig(ctx, initID) + require.Equal(t, errReplicationNotFound, err) + + // Valid result + insertRemote(t, testStore, replication.RemoteID) + created, err := testStore.CreateReplication(ctx, initID, createReq) + require.NoError(t, err) + require.Equal(t, replication, *created) + + conf, err := testStore.GetFullHTTPConfig(ctx, initID) + require.NoError(t, err) + require.Equal(t, httpConfig, *conf) +} + +func TestPopulateRemoteHTTPConfig(t *testing.T) { + t.Parallel() + + testStore, clean := newTestStore(t) + defer clean(t) + + emptyConfig := &ReplicationHTTPConfig{} + + // Remote not found returns the appropriate error + target := &ReplicationHTTPConfig{} + err := testStore.PopulateRemoteHTTPConfig(ctx, replication.RemoteID, target) + require.Equal(t, errRemoteNotFound(replication.RemoteID, nil), err) + require.Equal(t, emptyConfig, target) + + // Valid result + want := ReplicationHTTPConfig{ + RemoteURL: httpConfig.RemoteURL, + RemoteToken: httpConfig.RemoteToken, + RemoteOrgID: httpConfig.RemoteOrgID, + AllowInsecureTLS: httpConfig.AllowInsecureTLS, + } + insertRemote(t, testStore, replication.RemoteID) + err = testStore.PopulateRemoteHTTPConfig(ctx, replication.RemoteID, target) + require.NoError(t, err) + require.Equal(t, want, *target) +} + +func newTestStore(t *testing.T) (*Store, func(t *testing.T)) { + sqlStore, clean := sqlite.NewTestStore(t) + logger := zaptest.NewLogger(t) + sqliteMigrator := sqlite.NewMigrator(sqlStore, logger) + require.NoError(t, sqliteMigrator.Up(ctx, migrations.AllUp)) + + // Make sure foreign-key checking is enabled. + _, err := sqlStore.DB.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + + return NewStore(sqlStore), clean +} + +func insertRemote(t *testing.T, store *Store, id platform.ID) { + sqlStore := store.sqlStore + + sqlStore.Mu.Lock() + defer sqlStore.Mu.Unlock() + + q := sq.Insert("remotes").SetMap(sq.Eq{ + "id": id, + "org_id": replication.OrgID, + "name": fmt.Sprintf("foo-%s", id), + "remote_url": fmt.Sprintf("http://%s.cloud", id), + "remote_api_token": id.String(), + "remote_org_id": platform.ID(888888), + "allow_insecure_tls": true, + "created_at": "datetime('now')", + "updated_at": "datetime('now')", + }) + query, args, err := q.ToSql() + require.NoError(t, err) + + _, err = sqlStore.DB.Exec(query, args...) + require.NoError(t, err) +} + +func boolPointer(b bool) *bool { + return &b +} diff --git a/replications/mock/service_store.go b/replications/mock/service_store.go new file mode 100644 index 00000000000..023d0b02bc8 --- /dev/null +++ b/replications/mock/service_store.go @@ -0,0 +1,180 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/influxdata/influxdb/v2/replications (interfaces: ServiceStore) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + influxdb "github.com/influxdata/influxdb/v2" + platform "github.com/influxdata/influxdb/v2/kit/platform" + internal "github.com/influxdata/influxdb/v2/replications/internal" +) + +// MockServiceStore is a mock of ServiceStore interface. +type MockServiceStore struct { + ctrl *gomock.Controller + recorder *MockServiceStoreMockRecorder +} + +// MockServiceStoreMockRecorder is the mock recorder for MockServiceStore. +type MockServiceStoreMockRecorder struct { + mock *MockServiceStore +} + +// NewMockServiceStore creates a new mock instance. +func NewMockServiceStore(ctrl *gomock.Controller) *MockServiceStore { + mock := &MockServiceStore{ctrl: ctrl} + mock.recorder = &MockServiceStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockServiceStore) EXPECT() *MockServiceStoreMockRecorder { + return m.recorder +} + +// CreateReplication mocks base method. +func (m *MockServiceStore) CreateReplication(arg0 context.Context, arg1 platform.ID, arg2 influxdb.CreateReplicationRequest) (*influxdb.Replication, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateReplication", arg0, arg1, arg2) + ret0, _ := ret[0].(*influxdb.Replication) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateReplication indicates an expected call of CreateReplication. +func (mr *MockServiceStoreMockRecorder) CreateReplication(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateReplication", reflect.TypeOf((*MockServiceStore)(nil).CreateReplication), arg0, arg1, arg2) +} + +// DeleteBucketReplications mocks base method. +func (m *MockServiceStore) DeleteBucketReplications(arg0 context.Context, arg1 platform.ID) ([]platform.ID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteBucketReplications", arg0, arg1) + ret0, _ := ret[0].([]platform.ID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteBucketReplications indicates an expected call of DeleteBucketReplications. +func (mr *MockServiceStoreMockRecorder) DeleteBucketReplications(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteBucketReplications", reflect.TypeOf((*MockServiceStore)(nil).DeleteBucketReplications), arg0, arg1) +} + +// DeleteReplication mocks base method. +func (m *MockServiceStore) DeleteReplication(arg0 context.Context, arg1 platform.ID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteReplication", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteReplication indicates an expected call of DeleteReplication. +func (mr *MockServiceStoreMockRecorder) DeleteReplication(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplication", reflect.TypeOf((*MockServiceStore)(nil).DeleteReplication), arg0, arg1) +} + +// GetFullHTTPConfig mocks base method. +func (m *MockServiceStore) GetFullHTTPConfig(arg0 context.Context, arg1 platform.ID) (*internal.ReplicationHTTPConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetFullHTTPConfig", arg0, arg1) + ret0, _ := ret[0].(*internal.ReplicationHTTPConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetFullHTTPConfig indicates an expected call of GetFullHTTPConfig. +func (mr *MockServiceStoreMockRecorder) GetFullHTTPConfig(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFullHTTPConfig", reflect.TypeOf((*MockServiceStore)(nil).GetFullHTTPConfig), arg0, arg1) +} + +// GetReplication mocks base method. +func (m *MockServiceStore) GetReplication(arg0 context.Context, arg1 platform.ID) (*influxdb.Replication, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetReplication", arg0, arg1) + ret0, _ := ret[0].(*influxdb.Replication) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetReplication indicates an expected call of GetReplication. +func (mr *MockServiceStoreMockRecorder) GetReplication(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReplication", reflect.TypeOf((*MockServiceStore)(nil).GetReplication), arg0, arg1) +} + +// ListReplications mocks base method. +func (m *MockServiceStore) ListReplications(arg0 context.Context, arg1 influxdb.ReplicationListFilter) (*influxdb.Replications, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListReplications", arg0, arg1) + ret0, _ := ret[0].(*influxdb.Replications) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListReplications indicates an expected call of ListReplications. +func (mr *MockServiceStoreMockRecorder) ListReplications(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListReplications", reflect.TypeOf((*MockServiceStore)(nil).ListReplications), arg0, arg1) +} + +// Lock mocks base method. +func (m *MockServiceStore) Lock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Lock") +} + +// Lock indicates an expected call of Lock. +func (mr *MockServiceStoreMockRecorder) Lock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockServiceStore)(nil).Lock)) +} + +// PopulateRemoteHTTPConfig mocks base method. +func (m *MockServiceStore) PopulateRemoteHTTPConfig(arg0 context.Context, arg1 platform.ID, arg2 *internal.ReplicationHTTPConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PopulateRemoteHTTPConfig", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// PopulateRemoteHTTPConfig indicates an expected call of PopulateRemoteHTTPConfig. +func (mr *MockServiceStoreMockRecorder) PopulateRemoteHTTPConfig(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopulateRemoteHTTPConfig", reflect.TypeOf((*MockServiceStore)(nil).PopulateRemoteHTTPConfig), arg0, arg1, arg2) +} + +// Unlock mocks base method. +func (m *MockServiceStore) Unlock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Unlock") +} + +// Unlock indicates an expected call of Unlock. +func (mr *MockServiceStoreMockRecorder) Unlock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockServiceStore)(nil).Unlock)) +} + +// UpdateReplication mocks base method. +func (m *MockServiceStore) UpdateReplication(arg0 context.Context, arg1 platform.ID, arg2 influxdb.UpdateReplicationRequest) (*influxdb.Replication, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateReplication", arg0, arg1, arg2) + ret0, _ := ret[0].(*influxdb.Replication) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateReplication indicates an expected call of UpdateReplication. +func (mr *MockServiceStoreMockRecorder) UpdateReplication(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateReplication", reflect.TypeOf((*MockServiceStore)(nil).UpdateReplication), arg0, arg1, arg2) +} diff --git a/replications/service.go b/replications/service.go index 00a0409ac6a..877330c7f9e 100644 --- a/replications/service.go +++ b/replications/service.go @@ -4,13 +4,10 @@ import ( "bytes" "compress/gzip" "context" - "database/sql" - "errors" "fmt" "path/filepath" "sync" - sq "github.com/Masterminds/squirrel" "github.com/influxdata/influxdb/v2" "github.com/influxdata/influxdb/v2/kit/platform" ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" @@ -20,24 +17,10 @@ import ( "github.com/influxdata/influxdb/v2/snowflake" "github.com/influxdata/influxdb/v2/sqlite" "github.com/influxdata/influxdb/v2/storage" - "github.com/mattn/go-sqlite3" "go.uber.org/zap" "golang.org/x/sync/errgroup" ) -var errReplicationNotFound = &ierrors.Error{ - Code: ierrors.ENotFound, - Msg: "replication not found", -} - -func errRemoteNotFound(id platform.ID, cause error) error { - return &ierrors.Error{ - Code: ierrors.EInvalid, - Msg: fmt.Sprintf("remote %q not found", id), - Err: cause, - } -} - func errLocalBucketNotFound(id platform.ID, cause error) error { return &ierrors.Error{ Code: ierrors.EInvalid, @@ -46,11 +29,11 @@ func errLocalBucketNotFound(id platform.ID, cause error) error { } } -func NewService(store *sqlite.SqlStore, bktSvc BucketService, localWriter storage.PointsWriter, log *zap.Logger, enginePath string) (*service, *metrics.ReplicationsMetrics) { +func NewService(sqlStore *sqlite.SqlStore, bktSvc BucketService, localWriter storage.PointsWriter, log *zap.Logger, enginePath string) (*service, *metrics.ReplicationsMetrics) { metrs := metrics.NewReplicationsMetrics() return &service{ - store: store, + store: internal.NewStore(sqlStore), idGenerator: snowflake.NewIDGenerator(), bucketService: bktSvc, localWriter: localWriter, @@ -85,8 +68,21 @@ type DurableQueueManager interface { EnqueueData(replicationID platform.ID, data []byte, numPoints int) error } +type ServiceStore interface { + Lock() + Unlock() + ListReplications(context.Context, influxdb.ReplicationListFilter) (*influxdb.Replications, error) + CreateReplication(context.Context, platform.ID, influxdb.CreateReplicationRequest) (*influxdb.Replication, error) + GetReplication(context.Context, platform.ID) (*influxdb.Replication, error) + UpdateReplication(context.Context, platform.ID, influxdb.UpdateReplicationRequest) (*influxdb.Replication, error) + DeleteReplication(context.Context, platform.ID) error + PopulateRemoteHTTPConfig(context.Context, platform.ID, *internal.ReplicationHTTPConfig) error + GetFullHTTPConfig(context.Context, platform.ID) (*internal.ReplicationHTTPConfig, error) + DeleteBucketReplications(context.Context, platform.ID) ([]platform.ID, error) +} + type service struct { - store *sqlite.SqlStore + store ServiceStore idGenerator platform.IDGenerator bucketService BucketService validator ReplicationValidator @@ -96,34 +92,13 @@ type service struct { } func (s service) ListReplications(ctx context.Context, filter influxdb.ReplicationListFilter) (*influxdb.Replications, error) { - q := sq.Select( - "id", "org_id", "name", "description", "remote_id", "local_bucket_id", "remote_bucket_id", - "max_queue_size_bytes", "latest_response_code", "latest_error_message", "drop_non_retryable_data"). - From("replications"). - Where(sq.Eq{"org_id": filter.OrgID}) - - if filter.Name != nil { - q = q.Where(sq.Eq{"name": *filter.Name}) - } - if filter.RemoteID != nil { - q = q.Where(sq.Eq{"remote_id": *filter.RemoteID}) - } - if filter.LocalBucketID != nil { - q = q.Where(sq.Eq{"local_bucket_id": *filter.LocalBucketID}) - } - - query, args, err := q.ToSql() + rs, err := s.store.ListReplications(ctx, filter) if err != nil { return nil, err } - var rs influxdb.Replications - if err := s.store.DB.SelectContext(ctx, &rs.Replications, query, args...); err != nil { - return nil, err - } - if len(rs.Replications) == 0 { - return &rs, nil + return rs, nil } ids := make([]platform.ID, len(rs.Replications)) @@ -138,15 +113,15 @@ func (s service) ListReplications(ctx context.Context, filter influxdb.Replicati rs.Replications[i].CurrentQueueSizeBytes = sizes[rs.Replications[i].ID] } - return &rs, nil + return rs, nil } func (s service) CreateReplication(ctx context.Context, request influxdb.CreateReplicationRequest) (*influxdb.Replication, error) { s.bucketService.RLock() defer s.bucketService.RUnlock() - s.store.Mu.Lock() - defer s.store.Mu.Unlock() + s.store.Lock() + defer s.store.Unlock() if _, err := s.bucketService.FindBucketByID(ctx, request.LocalBucketID); err != nil { return nil, errLocalBucketNotFound(request.LocalBucketID, err) @@ -157,46 +132,16 @@ func (s service) CreateReplication(ctx context.Context, request influxdb.CreateR return nil, err } - q := sq.Insert("replications"). - SetMap(sq.Eq{ - "id": newID, - "org_id": request.OrgID, - "name": request.Name, - "description": request.Description, - "remote_id": request.RemoteID, - "local_bucket_id": request.LocalBucketID, - "remote_bucket_id": request.RemoteBucketID, - "max_queue_size_bytes": request.MaxQueueSizeBytes, - "drop_non_retryable_data": request.DropNonRetryableData, - "created_at": "datetime('now')", - "updated_at": "datetime('now')", - }). - Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, drop_non_retryable_data") - - cleanupQueue := func() { + r, err := s.store.CreateReplication(ctx, newID, request) + if err != nil { if cleanupErr := s.durableQueueManager.DeleteQueue(newID); cleanupErr != nil { s.log.Warn("durable queue remaining on disk after initialization failure", zap.Error(cleanupErr), zap.String("id", newID.String())) } - } - - query, args, err := q.ToSql() - if err != nil { - cleanupQueue() - return nil, err - } - - var r influxdb.Replication - if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { - if sqlErr, ok := err.(sqlite3.Error); ok && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { - cleanupQueue() - return nil, errRemoteNotFound(request.RemoteID, err) - } - cleanupQueue() return nil, err } - return &r, nil + return r, nil } func (s service) ValidateNewReplication(ctx context.Context, request influxdb.CreateReplicationRequest) error { @@ -205,7 +150,7 @@ func (s service) ValidateNewReplication(ctx context.Context, request influxdb.Cr } config := internal.ReplicationHTTPConfig{RemoteBucketID: request.RemoteBucketID} - if err := s.populateRemoteHTTPConfig(ctx, request.RemoteID, &config); err != nil { + if err := s.store.PopulateRemoteHTTPConfig(ctx, request.RemoteID, &config); err != nil { return err } @@ -220,77 +165,29 @@ func (s service) ValidateNewReplication(ctx context.Context, request influxdb.Cr } func (s service) GetReplication(ctx context.Context, id platform.ID) (*influxdb.Replication, error) { - q := sq.Select( - "id", "org_id", "name", "description", "remote_id", "local_bucket_id", "remote_bucket_id", - "max_queue_size_bytes", "latest_response_code", "latest_error_message", "drop_non_retryable_data"). - From("replications"). - Where(sq.Eq{"id": id}) - - query, args, err := q.ToSql() + r, err := s.store.GetReplication(ctx, id) if err != nil { return nil, err } - var r influxdb.Replication - if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errReplicationNotFound - } - return nil, err - } - sizes, err := s.durableQueueManager.CurrentQueueSizes([]platform.ID{r.ID}) if err != nil { return nil, err } r.CurrentQueueSizeBytes = sizes[r.ID] - return &r, nil + return r, nil } func (s service) UpdateReplication(ctx context.Context, id platform.ID, request influxdb.UpdateReplicationRequest) (*influxdb.Replication, error) { - s.store.Mu.Lock() - defer s.store.Mu.Unlock() - - updates := sq.Eq{"updated_at": sq.Expr("datetime('now')")} - if request.Name != nil { - updates["name"] = *request.Name - } - if request.Description != nil { - updates["description"] = *request.Description - } - if request.RemoteID != nil { - updates["remote_id"] = *request.RemoteID - } - if request.RemoteBucketID != nil { - updates["remote_bucket_id"] = *request.RemoteBucketID - } - if request.MaxQueueSizeBytes != nil { - updates["max_queue_size_bytes"] = *request.MaxQueueSizeBytes - } - if request.DropNonRetryableData != nil { - updates["drop_non_retryable_data"] = *request.DropNonRetryableData - } - - q := sq.Update("replications").SetMap(updates).Where(sq.Eq{"id": id}). - Suffix("RETURNING id, org_id, name, description, remote_id, local_bucket_id, remote_bucket_id, max_queue_size_bytes, drop_non_retryable_data") + s.store.Lock() + defer s.store.Unlock() - query, args, err := q.ToSql() + r, err := s.store.UpdateReplication(ctx, id, request) if err != nil { return nil, err } - var r influxdb.Replication - if err := s.store.DB.GetContext(ctx, &r, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errReplicationNotFound - } - if sqlErr, ok := err.(sqlite3.Error); ok && request.RemoteID != nil && sqlErr.ExtendedCode == sqlite3.ErrConstraintForeignKey { - return nil, errRemoteNotFound(*request.RemoteID, err) - } - return nil, err - } - if request.MaxQueueSizeBytes != nil { if err := s.durableQueueManager.UpdateMaxQueueSize(id, *request.MaxQueueSizeBytes); err != nil { s.log.Warn("actual max queue size does not match the max queue size recorded in database", zap.String("id", id.String())) @@ -304,11 +201,11 @@ func (s service) UpdateReplication(ctx context.Context, id platform.ID, request } r.CurrentQueueSizeBytes = sizes[r.ID] - return &r, nil + return r, nil } func (s service) ValidateUpdatedReplication(ctx context.Context, id platform.ID, request influxdb.UpdateReplicationRequest) error { - baseConfig, err := s.getFullHTTPConfig(ctx, id) + baseConfig, err := s.store.GetFullHTTPConfig(ctx, id) if err != nil { return err } @@ -317,7 +214,7 @@ func (s service) ValidateUpdatedReplication(ctx context.Context, id platform.ID, } if request.RemoteID != nil { - if err := s.populateRemoteHTTPConfig(ctx, *request.RemoteID, baseConfig); err != nil { + if err := s.store.PopulateRemoteHTTPConfig(ctx, *request.RemoteID, baseConfig); err != nil { return err } } @@ -333,20 +230,10 @@ func (s service) ValidateUpdatedReplication(ctx context.Context, id platform.ID, } func (s service) DeleteReplication(ctx context.Context, id platform.ID) error { - s.store.Mu.Lock() - defer s.store.Mu.Unlock() + s.store.Lock() + defer s.store.Unlock() - q := sq.Delete("replications").Where(sq.Eq{"id": id}).Suffix("RETURNING id") - query, args, err := q.ToSql() - if err != nil { - return err - } - - var d platform.ID - if err := s.store.DB.GetContext(ctx, &d, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errReplicationNotFound - } + if err := s.store.DeleteReplication(ctx, id); err != nil { return err } @@ -358,36 +245,27 @@ func (s service) DeleteReplication(ctx context.Context, id platform.ID) error { } func (s service) DeleteBucketReplications(ctx context.Context, localBucketID platform.ID) error { - s.store.Mu.Lock() - defer s.store.Mu.Unlock() + s.store.Lock() + defer s.store.Unlock() - q := sq.Delete("replications").Where(sq.Eq{"local_bucket_id": localBucketID}).Suffix("RETURNING id") - query, args, err := q.ToSql() + deletedIDs, err := s.store.DeleteBucketReplications(ctx, localBucketID) if err != nil { return err } - var deleted []string - if err := s.store.DB.SelectContext(ctx, &deleted, query, args...); err != nil { - return err - } - errOccurred := false - for _, replication := range deleted { - id, err := platform.IDFromString(replication) - if err != nil { - s.log.Error("durable queue remaining on disk after deletion failure", zap.Error(err), zap.String("id", replication)) + deletedStrings := make([]string, 0, len(deletedIDs)) + for _, id := range deletedIDs { + if err := s.durableQueueManager.DeleteQueue(id); err != nil { + s.log.Error("durable queue remaining on disk after deletion failure", zap.Error(err), zap.String("id", id.String())) errOccurred = true } - if err := s.durableQueueManager.DeleteQueue(*id); err != nil { - s.log.Error("durable queue remaining on disk after deletion failure", zap.Error(err), zap.String("id", replication)) - errOccurred = true - } + deletedStrings = append(deletedStrings, id.String()) } - s.log.Debug("Deleted all replications for local bucket", - zap.String("bucket_id", localBucketID.String()), zap.Strings("ids", deleted)) + s.log.Debug("deleted replications for local bucket", + zap.String("bucket_id", localBucketID.String()), zap.Strings("ids", deletedStrings)) if errOccurred { return fmt.Errorf("deleting replications for bucket %q failed, see server logs for details", localBucketID) @@ -397,7 +275,7 @@ func (s service) DeleteBucketReplications(ctx context.Context, localBucketID pla } func (s service) ValidateReplication(ctx context.Context, id platform.ID) error { - config, err := s.getFullHTTPConfig(ctx, id) + config, err := s.store.GetFullHTTPConfig(ctx, id) if err != nil { return err } @@ -412,19 +290,16 @@ func (s service) ValidateReplication(ctx context.Context, id platform.ID) error } func (s service) WritePoints(ctx context.Context, orgID platform.ID, bucketID platform.ID, points []models.Point) error { - q := sq.Select("id").From("replications").Where(sq.Eq{"org_id": orgID, "local_bucket_id": bucketID}) - query, args, err := q.ToSql() + repls, err := s.store.ListReplications(ctx, influxdb.ReplicationListFilter{ + OrgID: orgID, + LocalBucketID: &bucketID, + }) if err != nil { return err } - var ids []platform.ID - if err := s.store.DB.SelectContext(ctx, &ids, query, args...); err != nil { - return err - } - // If there are no registered replications, all we need to do is a local write. - if len(ids) == 0 { + if len(repls.Replications) == 0 { return s.localWriter.WritePoints(ctx, orgID, bucketID, points) } @@ -459,75 +334,27 @@ func (s service) WritePoints(ctx context.Context, orgID platform.ID, bucketID pl // Enqueue the data into all registered replications. var wg sync.WaitGroup - wg.Add(len(ids)) - for _, id := range ids { + wg.Add(len(repls.Replications)) + for _, rep := range repls.Replications { go func(id platform.ID) { defer wg.Done() if err := s.durableQueueManager.EnqueueData(id, buf.Bytes(), len(points)); err != nil { s.log.Error("Failed to enqueue points for replication", zap.String("id", id.String()), zap.Error(err)) } - }(id) + }(rep.ID) } wg.Wait() return nil } -func (s service) getFullHTTPConfig(ctx context.Context, id platform.ID) (*internal.ReplicationHTTPConfig, error) { - q := sq.Select("c.remote_url", "c.remote_api_token", "c.remote_org_id", "c.allow_insecure_tls", "r.remote_bucket_id"). - From("replications r").InnerJoin("remotes c ON r.remote_id = c.id AND r.id = ?", id) - - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - - var rc internal.ReplicationHTTPConfig - if err := s.store.DB.GetContext(ctx, &rc, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errReplicationNotFound - } - return nil, err - } - return &rc, nil -} - -func (s service) populateRemoteHTTPConfig(ctx context.Context, id platform.ID, target *internal.ReplicationHTTPConfig) error { - q := sq.Select("remote_url", "remote_api_token", "remote_org_id", "allow_insecure_tls"). - From("remotes").Where(sq.Eq{"id": id}) - query, args, err := q.ToSql() - if err != nil { - return err - } - - if err := s.store.DB.GetContext(ctx, target, query, args...); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errRemoteNotFound(id, nil) - } - return err - } - - return nil -} - func (s service) Open(ctx context.Context) error { - var trackedReplications influxdb.Replications - - // Get replications from sqlite - q := sq.Select( - "id", "max_queue_size_bytes"). - From("replications") - - query, args, err := q.ToSql() + trackedReplications, err := s.store.ListReplications(ctx, influxdb.ReplicationListFilter{}) if err != nil { return err } - if err := s.store.DB.SelectContext(ctx, &trackedReplications.Replications, query, args...); err != nil { - return err - } - trackedReplicationsMap := make(map[platform.ID]int64) for _, r := range trackedReplications.Replications { trackedReplicationsMap[r.ID] = r.MaxQueueSizeBytes diff --git a/replications/service_test.go b/replications/service_test.go index 63e5f36777c..8646f2feec1 100644 --- a/replications/service_test.go +++ b/replications/service_test.go @@ -8,16 +8,14 @@ import ( "fmt" "testing" - sq "github.com/Masterminds/squirrel" "github.com/golang/mock/gomock" "github.com/influxdata/influxdb/v2" "github.com/influxdata/influxdb/v2/kit/platform" + ierrors "github.com/influxdata/influxdb/v2/kit/platform/errors" "github.com/influxdata/influxdb/v2/mock" "github.com/influxdata/influxdb/v2/models" "github.com/influxdata/influxdb/v2/replications/internal" replicationsMock "github.com/influxdata/influxdb/v2/replications/mock" - "github.com/influxdata/influxdb/v2/sqlite" - "github.com/influxdata/influxdb/v2/sqlite/migrations" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -26,14 +24,27 @@ import ( //go:generate go run github.com/golang/mock/mockgen -package mock -destination ./mock/bucket_service.go github.com/influxdata/influxdb/v2/replications BucketService //go:generate go run github.com/golang/mock/mockgen -package mock -destination ./mock/queue_management.go github.com/influxdata/influxdb/v2/replications DurableQueueManager //go:generate go run github.com/golang/mock/mockgen -package mock -destination ./mock/points_writer.go github.com/influxdata/influxdb/v2/storage PointsWriter +//go:generate go run github.com/golang/mock/mockgen -package mock -destination ./mock/service_store.go github.com/influxdata/influxdb/v2/replications ServiceStore var ( - ctx = context.Background() - initID = platform.ID(1) - desc = "testing testing" - replication = influxdb.Replication{ - ID: initID, - OrgID: platform.ID(10), + ctx = context.Background() + orgID = platform.ID(10) + id1 = platform.ID(1) + id2 = platform.ID(2) + desc = "testing testing" + replication1 = influxdb.Replication{ + ID: id1, + OrgID: orgID, + Name: "test", + Description: &desc, + RemoteID: platform.ID(100), + LocalBucketID: platform.ID(1000), + RemoteBucketID: platform.ID(99999), + MaxQueueSizeBytes: 3 * influxdb.DefaultReplicationMaxQueueSizeBytes, + } + replication2 = influxdb.Replication{ + ID: id2, + OrgID: orgID, Name: "test", Description: &desc, RemoteID: platform.ID(100), @@ -42,592 +53,612 @@ var ( MaxQueueSizeBytes: 3 * influxdb.DefaultReplicationMaxQueueSizeBytes, } createReq = influxdb.CreateReplicationRequest{ - OrgID: replication.OrgID, - Name: replication.Name, - Description: replication.Description, - RemoteID: replication.RemoteID, - LocalBucketID: replication.LocalBucketID, - RemoteBucketID: replication.RemoteBucketID, - MaxQueueSizeBytes: replication.MaxQueueSizeBytes, + OrgID: replication1.OrgID, + Name: replication1.Name, + Description: replication1.Description, + RemoteID: replication1.RemoteID, + LocalBucketID: replication1.LocalBucketID, + RemoteBucketID: replication1.RemoteBucketID, + MaxQueueSizeBytes: replication1.MaxQueueSizeBytes, + } + newRemoteID = platform.ID(200) + newQueueSize = influxdb.MinReplicationMaxQueueSizeBytes + updateReqWithNewSize = influxdb.UpdateReplicationRequest{ + RemoteID: &newRemoteID, + MaxQueueSizeBytes: &newQueueSize, + } + updatedReplicationWithNewSize = influxdb.Replication{ + ID: replication1.ID, + OrgID: replication1.OrgID, + Name: replication1.Name, + Description: replication1.Description, + RemoteID: *updateReqWithNewSize.RemoteID, + LocalBucketID: replication1.LocalBucketID, + RemoteBucketID: replication1.RemoteBucketID, + MaxQueueSizeBytes: *updateReqWithNewSize.MaxQueueSizeBytes, + } + updateReqWithNoNewSize = influxdb.UpdateReplicationRequest{ + RemoteID: &newRemoteID, + } + updatedReplicationWithNoNewSize = influxdb.Replication{ + ID: replication1.ID, + OrgID: replication1.OrgID, + Name: replication1.Name, + Description: replication1.Description, + RemoteID: *updateReqWithNewSize.RemoteID, + LocalBucketID: replication1.LocalBucketID, + RemoteBucketID: replication1.RemoteBucketID, + MaxQueueSizeBytes: replication1.MaxQueueSizeBytes, } httpConfig = internal.ReplicationHTTPConfig{ - RemoteURL: fmt.Sprintf("http://%s.cloud", replication.RemoteID), - RemoteToken: replication.RemoteID.String(), - RemoteOrgID: platform.ID(888888), - AllowInsecureTLS: true, - RemoteBucketID: replication.RemoteBucketID, - } - newRemoteID = platform.ID(200) - newQueueSize = influxdb.MinReplicationMaxQueueSizeBytes - updateReq = influxdb.UpdateReplicationRequest{ - RemoteID: &newRemoteID, - MaxQueueSizeBytes: &newQueueSize, - DropNonRetryableData: boolPointer(true), - } - updatedReplication = influxdb.Replication{ - ID: replication.ID, - OrgID: replication.OrgID, - Name: replication.Name, - Description: replication.Description, - RemoteID: *updateReq.RemoteID, - LocalBucketID: replication.LocalBucketID, - RemoteBucketID: replication.RemoteBucketID, - MaxQueueSizeBytes: *updateReq.MaxQueueSizeBytes, - DropNonRetryableData: true, - } - updatedHttpConfig = internal.ReplicationHTTPConfig{ - RemoteURL: fmt.Sprintf("http://%s.cloud", updatedReplication.RemoteID), - RemoteToken: updatedReplication.RemoteID.String(), + RemoteURL: fmt.Sprintf("http://%s.cloud", replication1.RemoteID), + RemoteToken: replication1.RemoteID.String(), RemoteOrgID: platform.ID(888888), AllowInsecureTLS: true, - RemoteBucketID: updatedReplication.RemoteBucketID, + RemoteBucketID: replication1.RemoteBucketID, } ) -func TestCreateAndGetReplication(t *testing.T) { +func TestListReplications(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + filter := influxdb.ReplicationListFilter{} + + tests := []struct { + name string + list influxdb.Replications + ids []platform.ID + sizes map[platform.ID]int64 + storeErr error + queueManagerErr error + }{ + { + name: "matches multiple", + list: influxdb.Replications{ + Replications: []influxdb.Replication{replication1, replication2}, + }, + ids: []platform.ID{replication1.ID, replication2.ID}, + sizes: map[platform.ID]int64{replication1.ID: 1000, replication2.ID: 2000}, + }, + { + name: "matches one", + list: influxdb.Replications{ + Replications: []influxdb.Replication{replication1}, + }, + ids: []platform.ID{replication1.ID}, + sizes: map[platform.ID]int64{replication1.ID: 1000}, + }, + { + name: "matches none", + list: influxdb.Replications{}, + }, + { + name: "store error", + storeErr: errors.New("error from store"), + }, + { + name: "queue manager error", + list: influxdb.Replications{ + Replications: []influxdb.Replication{replication1}, + }, + ids: []platform.ID{replication1.ID}, + queueManagerErr: errors.New("error from queue manager"), + }, + } - insertRemote(t, svc.store, replication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - // Getting or validating an invalid ID should return an error. - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) - require.Equal(t, errReplicationNotFound, svc.ValidateReplication(ctx, initID)) + mocks.serviceStore.EXPECT().ListReplications(gomock.Any(), filter).Return(&tt.list, tt.storeErr) - // Create a replication, check the results. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) + if tt.storeErr == nil && len(tt.list.Replications) > 0 { + mocks.durableQueueManager.EXPECT().CurrentQueueSizes(tt.ids).Return(tt.sizes, tt.queueManagerErr) + } - // Read the created replication and assert it matches the creation response. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - got, err = svc.GetReplication(ctx, initID) - require.NoError(t, err) - require.Equal(t, replication, *got) + got, err := svc.ListReplications(ctx, filter) + + var wantErr error + if tt.storeErr != nil { + wantErr = tt.storeErr + } else if tt.queueManagerErr != nil { + wantErr = tt.queueManagerErr + } + + require.Equal(t, wantErr, err) + + if wantErr != nil { + require.Nil(t, got) + return + } - // Validate the replication; this is mostly a no-op for this test, but it allows - // us to check that our sql for extracting the linked remote's parameters is correct. - fakeErr := errors.New("O NO") - mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(fakeErr) - require.Contains(t, svc.ValidateReplication(ctx, initID).Error(), fakeErr.Error()) + for _, r := range got.Replications { + require.Equal(t, tt.sizes[r.ID], r.CurrentQueueSizeBytes) + } + }) + } } -func TestCreateMissingBucket(t *testing.T) { +func TestCreateReplication(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + tests := []struct { + name string + create influxdb.CreateReplicationRequest + storeErr error + bucketErr error + queueManagerErr error + want *influxdb.Replication + wantErr error + }{ + { + name: "success", + create: createReq, + want: &replication1, + }, + { + name: "bucket service error", + create: createReq, + bucketErr: errors.New("bucket service error"), + wantErr: errLocalBucketNotFound(createReq.LocalBucketID, errors.New("bucket service error")), + }, + { + name: "initialize queue error", + create: createReq, + queueManagerErr: errors.New("queue manager error"), + wantErr: errors.New("queue manager error"), + }, + { + name: "store create error", + create: createReq, + storeErr: errors.New("store create error"), + wantErr: errors.New("store create error"), + }, + } - insertRemote(t, svc.store, replication.RemoteID) - bucketNotFound := errors.New("bucket not found") - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(nil, bucketNotFound) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - created, err := svc.CreateReplication(ctx, createReq) - require.Equal(t, errLocalBucketNotFound(createReq.LocalBucketID, bucketNotFound), err) - require.Nil(t, created) + mocks.bucketSvc.EXPECT().RLock() + mocks.bucketSvc.EXPECT().RUnlock() + mocks.serviceStore.EXPECT().Lock() + mocks.serviceStore.EXPECT().Unlock() - // Make sure nothing was persisted. - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) -} + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), tt.create.LocalBucketID).Return(nil, tt.bucketErr) -func TestCreateMissingRemote(t *testing.T) { - t.Parallel() + if tt.bucketErr == nil { + mocks.durableQueueManager.EXPECT().InitializeQueue(id1, tt.create.MaxQueueSizeBytes).Return(tt.queueManagerErr) + } - svc, mocks, clean := newTestService(t) - defer clean(t) - - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) - - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - mocks.durableQueueManager.EXPECT().DeleteQueue(initID) - created, err := svc.CreateReplication(ctx, createReq) - require.Error(t, err) - require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", createReq.RemoteID)) - require.Nil(t, created) - - // Make sure nothing was persisted. - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) -} + if tt.queueManagerErr == nil && tt.bucketErr == nil { + mocks.serviceStore.EXPECT().CreateReplication(gomock.Any(), id1, tt.create).Return(tt.want, tt.storeErr) + } -func TestValidateReplicationWithoutPersisting(t *testing.T) { - t.Parallel() + if tt.storeErr != nil { + mocks.durableQueueManager.EXPECT().DeleteQueue(id1).Return(nil) + } - t.Run("missing bucket", func(t *testing.T) { - svc, mocks, clean := newTestService(t) - defer clean(t) + got, err := svc.CreateReplication(ctx, tt.create) + require.Equal(t, tt.want, got) + require.Equal(t, tt.wantErr, err) + }) + } +} - bucketNotFound := errors.New("bucket not found") - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(nil, bucketNotFound) +func TestValidateNewReplication(t *testing.T) { + t.Parallel() - require.Equal(t, errLocalBucketNotFound(createReq.LocalBucketID, bucketNotFound), - svc.ValidateNewReplication(ctx, createReq)) + tests := []struct { + name string + req influxdb.CreateReplicationRequest + storeErr error + bucketErr error + validatorErr error + wantErr error + }{ + { + name: "valid", + req: createReq, + }, + { + name: "bucket service error", + req: createReq, + bucketErr: errors.New("bucket service error"), + wantErr: errLocalBucketNotFound(createReq.LocalBucketID, errors.New("bucket service error")), + }, + { + name: "store populate error", + req: createReq, + storeErr: errors.New("store populate error"), + wantErr: errors.New("store populate error"), + }, + { + name: "validation error - invalid replication", + req: createReq, + validatorErr: errors.New("validation error"), + wantErr: &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "replication parameters fail validation", + Err: errors.New("validation error"), + }, + }, + } - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - t.Run("missing remote", func(t *testing.T) { - svc, mocks, clean := newTestService(t) - defer clean(t) + mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), tt.req.LocalBucketID).Return(nil, tt.bucketErr) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) + testConfig := &internal.ReplicationHTTPConfig{RemoteBucketID: tt.req.RemoteBucketID} + if tt.bucketErr == nil { + mocks.serviceStore.EXPECT().PopulateRemoteHTTPConfig(gomock.Any(), tt.req.RemoteID, testConfig).Return(tt.storeErr) + } - require.Contains(t, svc.ValidateNewReplication(ctx, createReq).Error(), - fmt.Sprintf("remote %q not found", createReq.RemoteID)) + if tt.bucketErr == nil && tt.storeErr == nil { + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), testConfig).Return(tt.validatorErr) + } - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) - }) + err := svc.ValidateNewReplication(ctx, tt.req) + require.Equal(t, tt.wantErr, err) + }) + } +} - t.Run("validation error", func(t *testing.T) { - svc, mocks, clean := newTestService(t) - defer clean(t) +func TestGetReplication(t *testing.T) { + t.Parallel() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) + tests := []struct { + name string + sizes map[platform.ID]int64 + storeErr error + queueManagerErr error + storeWant influxdb.Replication + want influxdb.Replication + }{ + { + name: "success", + sizes: map[platform.ID]int64{replication1.ID: 1000}, + storeWant: replication1, + want: replication1, + }, + { + name: "store error", + storeErr: errors.New("store error"), + }, + { + name: "queue manager error", + storeWant: replication1, + queueManagerErr: errors.New("queue manager error"), + }, + } - fakeErr := errors.New("O NO") - mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(fakeErr) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - require.Contains(t, svc.ValidateNewReplication(ctx, createReq).Error(), fakeErr.Error()) + mocks.serviceStore.EXPECT().GetReplication(gomock.Any(), id1).Return(&tt.storeWant, tt.storeErr) - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) - }) + if tt.storeErr == nil { + mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{id1}).Return(tt.sizes, tt.queueManagerErr) + } - t.Run("no error", func(t *testing.T) { - svc, mocks, clean := newTestService(t) - defer clean(t) + got, err := svc.GetReplication(ctx, id1) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) + var wantErr error + if tt.storeErr != nil { + wantErr = tt.storeErr + } else if tt.queueManagerErr != nil { + wantErr = tt.queueManagerErr + } - mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(nil) + require.Equal(t, wantErr, err) - require.NoError(t, svc.ValidateNewReplication(ctx, createReq)) + if wantErr != nil { + require.Nil(t, got) + return + } - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) - }) + require.Equal(t, tt.sizes[got.ID], got.CurrentQueueSizeBytes) + }) + } } -func TestUpdateAndGetReplication(t *testing.T) { +func TestUpdateReplication(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) - - insertRemote(t, svc.store, replication.RemoteID) - insertRemote(t, svc.store, updatedReplication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) + tests := []struct { + name string + request influxdb.UpdateReplicationRequest + sizes map[platform.ID]int64 + storeErr error + queueManagerUpdateSizeErr error + queueManagerCurrentSizesErr error + storeUpdate *influxdb.Replication + want *influxdb.Replication + wantErr error + }{ + { + name: "success with new max queue size", + request: updateReqWithNewSize, + sizes: map[platform.ID]int64{replication1.ID: *updateReqWithNewSize.MaxQueueSizeBytes}, + storeUpdate: &updatedReplicationWithNewSize, + want: &updatedReplicationWithNewSize, + }, + { + name: "success with no new max queue size", + request: updateReqWithNoNewSize, + sizes: map[platform.ID]int64{replication1.ID: updatedReplicationWithNoNewSize.MaxQueueSizeBytes}, + storeUpdate: &updatedReplicationWithNoNewSize, + want: &updatedReplicationWithNoNewSize, + }, + { + name: "store error", + request: updateReqWithNoNewSize, + storeErr: errors.New("store error"), + wantErr: errors.New("store error"), + }, + { + name: "queue manager error - update max queue size", + request: updateReqWithNewSize, + queueManagerUpdateSizeErr: errors.New("update max size err"), + wantErr: errors.New("update max size err"), + }, + { + name: "queue manager error - current queue size", + request: updateReqWithNoNewSize, + queueManagerCurrentSizesErr: errors.New("current size err"), + storeUpdate: &updatedReplicationWithNoNewSize, + wantErr: errors.New("current size err"), + }, + } - // Updating a nonexistent ID fails. - updated, err := svc.UpdateReplication(ctx, initID, updateReq) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, updated) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) + mocks.serviceStore.EXPECT().Lock() + mocks.serviceStore.EXPECT().Unlock() - // Update the replication. - mocks.durableQueueManager.EXPECT().UpdateMaxQueueSize(initID, *updateReq.MaxQueueSizeBytes) - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - updated, err = svc.UpdateReplication(ctx, initID, updateReq) - require.NoError(t, err) - require.Equal(t, updatedReplication, *updated) -} - -func TestUpdateMissingRemote(t *testing.T) { - t.Parallel() + mocks.serviceStore.EXPECT().UpdateReplication(gomock.Any(), id1, tt.request).Return(tt.storeUpdate, tt.storeErr) - svc, mocks, clean := newTestService(t) - defer clean(t) + if tt.storeErr == nil && tt.request.MaxQueueSizeBytes != nil { + mocks.durableQueueManager.EXPECT().UpdateMaxQueueSize(id1, *tt.request.MaxQueueSizeBytes).Return(tt.queueManagerUpdateSizeErr) + } - insertRemote(t, svc.store, replication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) + if tt.storeErr == nil && tt.queueManagerUpdateSizeErr == nil { + mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{id1}).Return(tt.sizes, tt.queueManagerCurrentSizesErr) + } - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) - - // Attempt to update the replication to point at a nonexistent remote. - updated, err := svc.UpdateReplication(ctx, initID, updateReq) - require.Error(t, err) - require.Contains(t, err.Error(), fmt.Sprintf("remote %q not found", *updateReq.RemoteID)) - require.Nil(t, updated) - - // Make sure nothing changed in the DB. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - got, err := svc.GetReplication(ctx, initID) - require.NoError(t, err) - require.Equal(t, replication, *got) + got, err := svc.UpdateReplication(ctx, id1, tt.request) + require.Equal(t, tt.want, got) + require.Equal(t, tt.wantErr, err) + }) + } } -func TestUpdateNoop(t *testing.T) { +func TestValidateUpdatedReplication(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + tests := []struct { + name string + request influxdb.UpdateReplicationRequest + baseConfig *internal.ReplicationHTTPConfig + storeGetConfigErr error + storePopulateConfigErr error + validatorErr error + want error + }{ + { + name: "success", + request: updateReqWithNoNewSize, + baseConfig: &httpConfig, + }, + { + name: "store get full http config error", + storeGetConfigErr: errors.New("store get full http config error"), + want: errors.New("store get full http config error"), + }, + { + name: "store get populate remote config error", + request: updateReqWithNoNewSize, + storePopulateConfigErr: errors.New("store populate http config error"), + want: errors.New("store populate http config error"), + }, + { + name: "invalid update", + request: updateReqWithNoNewSize, + validatorErr: errors.New("invalid"), + want: &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "validation fails after applying update", + Err: errors.New("invalid"), + }, + }, + } - insertRemote(t, svc.store, replication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) + mocks.serviceStore.EXPECT().GetFullHTTPConfig(gomock.Any(), id1).Return(tt.baseConfig, tt.storeGetConfigErr) - // Send a no-op update, assert nothing changed. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - updated, err := svc.UpdateReplication(ctx, initID, influxdb.UpdateReplicationRequest{}) - require.NoError(t, err) - require.Equal(t, replication, *updated) -} + if tt.storeGetConfigErr == nil { + mocks.serviceStore.EXPECT().PopulateRemoteHTTPConfig(gomock.Any(), *tt.request.RemoteID, tt.baseConfig).Return(tt.storePopulateConfigErr) + } -func TestValidateUpdatedReplicationWithoutPersisting(t *testing.T) { - t.Parallel() + if tt.storeGetConfigErr == nil && tt.storePopulateConfigErr == nil { + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), tt.baseConfig).Return(tt.validatorErr) + } - t.Run("bad remote", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - - insertRemote(t, svc.store, replication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) - - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) - - // Attempt to update the replication to point at a nonexistent remote. - require.Contains(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq).Error(), - fmt.Sprintf("remote %q not found", *updateReq.RemoteID)) - - // Make sure nothing changed in the DB. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - got, err := svc.GetReplication(ctx, initID) - require.NoError(t, err) - require.Equal(t, replication, *got) - }) - - t.Run("validation error", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - - insertRemote(t, svc.store, replication.RemoteID) - insertRemote(t, svc.store, updatedReplication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) - - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) - - // Check updating to a failing remote, assert error is returned. - fakeErr := errors.New("O NO") - mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &updatedHttpConfig).Return(fakeErr) - - require.Contains(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq).Error(), fakeErr.Error()) - - // Make sure nothing changed in the DB. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - got, err := svc.GetReplication(ctx, initID) - require.NoError(t, err) - require.Equal(t, replication, *got) - }) - - t.Run("no error", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - - insertRemote(t, svc.store, replication.RemoteID) - insertRemote(t, svc.store, updatedReplication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) - - // Create a replication. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) - - // Check updating to a remote that passes validation, assert no error. - mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &updatedHttpConfig).Return(nil) - - require.NoError(t, svc.ValidateUpdatedReplication(ctx, initID, updateReq)) - - // Make sure nothing changed in the DB. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID}). - Return(map[platform.ID]int64{initID: replication.CurrentQueueSizeBytes}, nil) - got, err := svc.GetReplication(ctx, initID) - require.NoError(t, err) - require.Equal(t, replication, *got) - }) + err := svc.ValidateUpdatedReplication(ctx, id1, tt.request) + require.Equal(t, tt.want, err) + }) + } } func TestDeleteReplication(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + tests := []struct { + name string + storeErr error + queueManagerErr error + }{ + { + name: "success", + }, + { + name: "store error", + storeErr: errors.New("store error"), + }, + { + name: "queue manager error", + queueManagerErr: errors.New("queue manager error"), + }, + } - insertRemote(t, svc.store, replication.RemoteID) - mocks.bucketSvc.EXPECT().RLock() - mocks.bucketSvc.EXPECT().RUnlock() - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID). - Return(&influxdb.Bucket{}, nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - // Deleting a nonexistent ID should return an error. - require.Equal(t, errReplicationNotFound, svc.DeleteReplication(ctx, initID)) + mocks.serviceStore.EXPECT().Lock() + mocks.serviceStore.EXPECT().Unlock() - // Create a replication, then delete it. - mocks.durableQueueManager.EXPECT().InitializeQueue(initID, createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, createReq) - require.NoError(t, err) - require.Equal(t, replication, *created) - mocks.durableQueueManager.EXPECT().DeleteQueue(initID) - require.NoError(t, svc.DeleteReplication(ctx, initID)) - - // Looking up the ID should again produce an error. - got, err := svc.GetReplication(ctx, initID) - require.Equal(t, errReplicationNotFound, err) - require.Nil(t, got) + mocks.serviceStore.EXPECT().DeleteReplication(gomock.Any(), id1).Return(tt.storeErr) + + if tt.storeErr == nil { + mocks.durableQueueManager.EXPECT().DeleteQueue(id1).Return(tt.queueManagerErr) + } + + err := svc.DeleteReplication(ctx, id1) + + var wantErr error + if tt.storeErr != nil { + wantErr = tt.storeErr + } else if tt.queueManagerErr != nil { + wantErr = tt.queueManagerErr + } + + require.Equal(t, wantErr, err) + }) + } } -func TestDeleteReplications(t *testing.T) { +func TestDeleteBucketReplications(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) - - // Deleting when there is no bucket is OK. - require.NoError(t, svc.DeleteBucketReplications(ctx, replication.LocalBucketID)) - - // Register a handful of replications. - createReq2, createReq3 := createReq, createReq - createReq2.Name, createReq3.Name = "test2", "test3" - createReq2.LocalBucketID = platform.ID(77777) - createReq3.RemoteID = updatedReplication.RemoteID - mocks.bucketSvc.EXPECT().RLock().Times(3) - mocks.bucketSvc.EXPECT().RUnlock().Times(3) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil).Times(2) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq2.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) - insertRemote(t, svc.store, createReq3.RemoteID) - - for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { - mocks.durableQueueManager.EXPECT().InitializeQueue(gomock.Any(), req.MaxQueueSizeBytes) - _, err := svc.CreateReplication(ctx, req) - require.NoError(t, err) - } - - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID, initID + 1, initID + 2}). - Return(map[platform.ID]int64{initID: 0, initID + 1: 0, initID + 2: 0}, nil) - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: replication.OrgID}) - require.NoError(t, err) - require.Len(t, listed.Replications, 3) + tests := []struct { + name string + storeErr error + storeIDs []platform.ID + queueManagerErr error + wantErr error + }{ + { + name: "success - single replication IDs match bucket ID", + storeIDs: []platform.ID{id1}, + }, + { + name: "success - multiple replication IDs match bucket ID", + storeIDs: []platform.ID{id1, id2}, + }, + { + name: "zero replication IDs match bucket ID", + storeIDs: []platform.ID{}, + }, + { + name: "store error", + storeErr: errors.New("store error"), + wantErr: errors.New("store error"), + }, + { + name: "queue manager delete queue error", + storeIDs: []platform.ID{id1}, + queueManagerErr: errors.New("queue manager error"), + wantErr: fmt.Errorf("deleting replications for bucket %q failed, see server logs for details", id1), + }, + } - // Delete 2/3 by bucket ID. - mocks.durableQueueManager.EXPECT().DeleteQueue(gomock.Any()).Times(2) - require.NoError(t, svc.DeleteBucketReplications(ctx, createReq.LocalBucketID)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - // Ensure they were deleted. - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID + 1}). - Return(map[platform.ID]int64{initID + 1: 0}, nil) - listed, err = svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: replication.OrgID}) - require.NoError(t, err) - require.Len(t, listed.Replications, 1) - require.Equal(t, createReq2.LocalBucketID, listed.Replications[0].LocalBucketID) + mocks.serviceStore.EXPECT().Lock() + mocks.serviceStore.EXPECT().Unlock() + + mocks.serviceStore.EXPECT().DeleteBucketReplications(gomock.Any(), id1).Return(tt.storeIDs, tt.storeErr) + + if tt.storeErr == nil { + for _, id := range tt.storeIDs { + mocks.durableQueueManager.EXPECT().DeleteQueue(id).Return(tt.queueManagerErr) + } + } + + err := svc.DeleteBucketReplications(ctx, id1) + require.Equal(t, tt.wantErr, err) + }) + } } -func TestListReplications(t *testing.T) { +func TestValidateReplication(t *testing.T) { t.Parallel() - createReq2, createReq3 := createReq, createReq - createReq2.Name, createReq3.Name = "test2", "test3" - createReq2.LocalBucketID = platform.ID(77777) - createReq3.RemoteID = updatedReplication.RemoteID - - setup := func(t *testing.T, svc *service, mocks mocks) []influxdb.Replication { - mocks.bucketSvc.EXPECT().RLock().Times(3) - mocks.bucketSvc.EXPECT().RUnlock().Times(3) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil).Times(2) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq2.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) - insertRemote(t, svc.store, createReq3.RemoteID) - - var allReplications []influxdb.Replication - for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { - mocks.durableQueueManager.EXPECT().InitializeQueue(gomock.Any(), createReq.MaxQueueSizeBytes) - created, err := svc.CreateReplication(ctx, req) - require.NoError(t, err) - allReplications = append(allReplications, *created) - } - return allReplications - } - - t.Run("list all", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - allRepls := setup(t, svc, mocks) - - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID, initID + 1, initID + 2}). - Return(map[platform.ID]int64{initID: 0, initID + 1: 0, initID + 2: 0}, nil) - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: createReq.OrgID}) - require.NoError(t, err) - require.Equal(t, influxdb.Replications{Replications: allRepls}, *listed) - }) - - t.Run("list by name", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - allRepls := setup(t, svc, mocks) - - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID + 1}). - Return(map[platform.ID]int64{initID + 1: 0}, nil) - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ - OrgID: createReq.OrgID, - Name: &createReq2.Name, - }) - require.NoError(t, err) - require.Equal(t, influxdb.Replications{Replications: allRepls[1:2]}, *listed) - }) - - t.Run("list by remote ID", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - allRepls := setup(t, svc, mocks) - - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID, initID + 1}). - Return(map[platform.ID]int64{initID: 0, initID + 1: 0}, nil) - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ - OrgID: createReq.OrgID, - RemoteID: &createReq.RemoteID, - }) - require.NoError(t, err) - require.Equal(t, influxdb.Replications{Replications: allRepls[0:2]}, *listed) - }) - - t.Run("list by bucket ID", func(t *testing.T) { - t.Parallel() - - svc, mocks, clean := newTestService(t) - defer clean(t) - allRepls := setup(t, svc, mocks) - - mocks.durableQueueManager.EXPECT().CurrentQueueSizes([]platform.ID{initID, initID + 2}). - Return(map[platform.ID]int64{initID: 0, initID + 2: 0}, nil) - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{ - OrgID: createReq.OrgID, - LocalBucketID: &createReq.LocalBucketID, - }) - require.NoError(t, err) - require.Equal(t, influxdb.Replications{Replications: append(allRepls[0:1], allRepls[2:]...)}, *listed) - }) + tests := []struct { + name string + storeErr error + validatorErr error + wantErr error + }{ + { + name: "valid", + }, + { + name: "store error", + storeErr: errors.New("store error"), + wantErr: errors.New("store error"), + }, + { + name: "validation error - invalid replication", + validatorErr: errors.New("validation error"), + wantErr: &ierrors.Error{ + Code: ierrors.EInvalid, + Msg: "replication failed validation", + Err: errors.New("validation error"), + }, + }, + } - t.Run("list by other org ID", func(t *testing.T) { - t.Parallel() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) - svc, mocks, clean := newTestService(t) - defer clean(t) - setup(t, svc, mocks) + mocks.serviceStore.EXPECT().GetFullHTTPConfig(gomock.Any(), id1).Return(&httpConfig, tt.storeErr) + if tt.storeErr == nil { + mocks.validator.EXPECT().ValidateReplication(gomock.Any(), &httpConfig).Return(tt.validatorErr) + } - listed, err := svc.ListReplications(ctx, influxdb.ReplicationListFilter{OrgID: platform.ID(2)}) - require.NoError(t, err) - require.Equal(t, influxdb.Replications{}, *listed) - }) + err := svc.ValidateReplication(ctx, id1) + require.Equal(t, tt.wantErr, err) + }) + } } func TestWritePoints(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + svc, mocks := newTestService(t) - // Register a handful of replications. - createReq2, createReq3 := createReq, createReq - createReq2.Name, createReq3.Name = "test2", "test3" - createReq2.LocalBucketID = platform.ID(77777) - createReq3.RemoteID = updatedReplication.RemoteID - mocks.bucketSvc.EXPECT().RLock().Times(3) - mocks.bucketSvc.EXPECT().RUnlock().Times(3) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil).Times(2) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq2.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) - insertRemote(t, svc.store, createReq3.RemoteID) - - for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { - mocks.durableQueueManager.EXPECT().InitializeQueue(gomock.Any(), req.MaxQueueSizeBytes) - _, err := svc.CreateReplication(ctx, req) - require.NoError(t, err) + list := &influxdb.Replications{ + Replications: []influxdb.Replication{replication1, replication2}, } + mocks.serviceStore.EXPECT().ListReplications(gomock.Any(), influxdb.ReplicationListFilter{ + OrgID: orgID, + LocalBucketID: &id1, + }).Return(list, nil) + points, err := models.ParsePointsString(` cpu,host=0 value=1.1 6000000000 cpu,host=A value=1.2 2000000000 @@ -640,10 +671,10 @@ disk,host=C value=1.3 1000000000`) require.NoError(t, err) // Points should successfully write to local TSM. - mocks.pointWriter.EXPECT().WritePoints(gomock.Any(), replication.OrgID, replication.LocalBucketID, points).Return(nil) + mocks.pointWriter.EXPECT().WritePoints(gomock.Any(), orgID, id1, points).Return(nil) // Points should successfully be enqueued in the 2 replications associated with the local bucket. - for _, id := range []platform.ID{initID, initID + 2} { + for _, id := range []platform.ID{replication1.ID, replication2.ID} { mocks.durableQueueManager.EXPECT(). EnqueueData(id, gomock.Any(), len(points)). DoAndReturn(func(_ platform.ID, data []byte, numPoints int) error { @@ -666,33 +697,23 @@ disk,host=C value=1.3 1000000000`) }) } - require.NoError(t, svc.WritePoints(ctx, replication.OrgID, replication.LocalBucketID, points)) + require.NoError(t, svc.WritePoints(ctx, orgID, id1, points)) } func TestWritePoints_LocalFailure(t *testing.T) { t.Parallel() - svc, mocks, clean := newTestService(t) - defer clean(t) + svc, mocks := newTestService(t) - // Register a handful of replications. - createReq2, createReq3 := createReq, createReq - createReq2.Name, createReq3.Name = "test2", "test3" - createReq2.LocalBucketID = platform.ID(77777) - createReq3.RemoteID = updatedReplication.RemoteID - mocks.bucketSvc.EXPECT().RLock().Times(3) - mocks.bucketSvc.EXPECT().RUnlock().Times(3) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq.LocalBucketID).Return(&influxdb.Bucket{}, nil).Times(2) - mocks.bucketSvc.EXPECT().FindBucketByID(gomock.Any(), createReq2.LocalBucketID).Return(&influxdb.Bucket{}, nil) - insertRemote(t, svc.store, createReq.RemoteID) - insertRemote(t, svc.store, createReq3.RemoteID) - - for _, req := range []influxdb.CreateReplicationRequest{createReq, createReq2, createReq3} { - mocks.durableQueueManager.EXPECT().InitializeQueue(gomock.Any(), req.MaxQueueSizeBytes) - _, err := svc.CreateReplication(ctx, req) - require.NoError(t, err) + list := &influxdb.Replications{ + Replications: []influxdb.Replication{replication1, replication2}, } + mocks.serviceStore.EXPECT().ListReplications(gomock.Any(), influxdb.ReplicationListFilter{ + OrgID: orgID, + LocalBucketID: &id1, + }).Return(list, nil) + points, err := models.ParsePointsString(` cpu,host=0 value=1.1 6000000000 cpu,host=A value=1.2 2000000000 @@ -706,9 +727,78 @@ disk,host=C value=1.3 1000000000`) // Points should fail to write to local TSM. writeErr := errors.New("O NO") - mocks.pointWriter.EXPECT().WritePoints(gomock.Any(), replication.OrgID, replication.LocalBucketID, points).Return(writeErr) + mocks.pointWriter.EXPECT().WritePoints(gomock.Any(), orgID, id1, points).Return(writeErr) // Don't expect any calls to enqueue points. - require.Equal(t, writeErr, svc.WritePoints(ctx, replication.OrgID, replication.LocalBucketID, points)) + require.Equal(t, writeErr, svc.WritePoints(ctx, orgID, id1, points)) +} + +func TestOpen(t *testing.T) { + t.Parallel() + + filter := influxdb.ReplicationListFilter{} + + tests := []struct { + name string + storeErr error + queueManagerErr error + replicationsMap map[platform.ID]int64 + list *influxdb.Replications + }{ + { + name: "no error, multiple replications from storage", + replicationsMap: map[platform.ID]int64{ + replication1.ID: replication1.MaxQueueSizeBytes, + replication2.ID: replication2.MaxQueueSizeBytes, + }, + list: &influxdb.Replications{ + Replications: []influxdb.Replication{replication1, replication2}, + }, + }, + { + name: "no error, one stored replication", + replicationsMap: map[platform.ID]int64{ + replication1.ID: replication1.MaxQueueSizeBytes, + }, + list: &influxdb.Replications{ + Replications: []influxdb.Replication{replication1}, + }, + }, + { + name: "store error", + storeErr: errors.New("store error"), + }, + { + name: "queue manager error", + replicationsMap: map[platform.ID]int64{ + replication1.ID: replication1.MaxQueueSizeBytes, + }, + list: &influxdb.Replications{ + Replications: []influxdb.Replication{replication1}, + }, + queueManagerErr: errors.New("queue manager error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, mocks := newTestService(t) + + mocks.serviceStore.EXPECT().ListReplications(gomock.Any(), filter).Return(tt.list, tt.storeErr) + if tt.storeErr == nil { + mocks.durableQueueManager.EXPECT().StartReplicationQueues(tt.replicationsMap).Return(tt.queueManagerErr) + } + + var wantErr error + if tt.storeErr != nil { + wantErr = tt.storeErr + } else if tt.queueManagerErr != nil { + wantErr = tt.queueManagerErr + } + + err := svc.Open(ctx) + require.Equal(t, wantErr, err) + }) + } } type mocks struct { @@ -716,17 +806,11 @@ type mocks struct { validator *replicationsMock.MockReplicationValidator durableQueueManager *replicationsMock.MockDurableQueueManager pointWriter *replicationsMock.MockPointsWriter + serviceStore *replicationsMock.MockServiceStore } -func newTestService(t *testing.T) (*service, mocks, func(t *testing.T)) { - store, clean := sqlite.NewTestStore(t) +func newTestService(t *testing.T) (*service, mocks) { logger := zaptest.NewLogger(t) - sqliteMigrator := sqlite.NewMigrator(store, logger) - require.NoError(t, sqliteMigrator.Up(ctx, migrations.AllUp)) - - // Make sure foreign-key checking is enabled. - _, err := store.DB.Exec("PRAGMA foreign_keys = ON;") - require.NoError(t, err) ctrl := gomock.NewController(t) mocks := mocks{ @@ -734,10 +818,11 @@ func newTestService(t *testing.T) (*service, mocks, func(t *testing.T)) { validator: replicationsMock.NewMockReplicationValidator(ctrl), durableQueueManager: replicationsMock.NewMockDurableQueueManager(ctrl), pointWriter: replicationsMock.NewMockPointsWriter(ctrl), + serviceStore: replicationsMock.NewMockServiceStore(ctrl), } svc := service{ - store: store, - idGenerator: mock.NewIncrementingIDGenerator(initID), + store: mocks.serviceStore, + idGenerator: mock.NewIncrementingIDGenerator(id1), bucketService: mocks.bucketSvc, validator: mocks.validator, log: logger, @@ -745,31 +830,5 @@ func newTestService(t *testing.T) (*service, mocks, func(t *testing.T)) { localWriter: mocks.pointWriter, } - return &svc, mocks, clean -} - -func insertRemote(t *testing.T, store *sqlite.SqlStore, id platform.ID) { - store.Mu.Lock() - defer store.Mu.Unlock() - - q := sq.Insert("remotes").SetMap(sq.Eq{ - "id": id, - "org_id": replication.OrgID, - "name": fmt.Sprintf("foo-%s", id), - "remote_url": fmt.Sprintf("http://%s.cloud", id), - "remote_api_token": id.String(), - "remote_org_id": platform.ID(888888), - "allow_insecure_tls": true, - "created_at": "datetime('now')", - "updated_at": "datetime('now')", - }) - query, args, err := q.ToSql() - require.NoError(t, err) - - _, err = store.DB.Exec(query, args...) - require.NoError(t, err) -} - -func boolPointer(b bool) *bool { - return &b + return &svc, mocks }