From ac2c3e90ff2aed95e8da6f3fcea25d7023e99590 Mon Sep 17 00:00:00 2001 From: haller33 Date: Mon, 2 Dec 2024 09:55:14 -0300 Subject: [PATCH] feat: tags collection redesign --- api/routes/tags_test.go | 14 +- api/services/mocks/services.go | 289 +++++++++- api/services/sshkeys.go | 4 +- api/services/sshkeys_tags.go | 4 +- api/services/sshkeys_tags_test.go | 115 +++- api/services/sshkeys_test.go | 81 ++- api/services/tags.go | 14 +- api/services/tags_test.go | 139 ++++- api/services/utils.go | 18 +- api/store/mocks/store.go | 512 +++++++++++++++++- api/store/mongo/device_tags.go | 72 ++- api/store/mongo/device_tags_test.go | 12 +- api/store/mongo/fixtures/tags.json | 29 + api/store/mongo/migrations/migration_69.go | 1 - api/store/mongo/migrations/migration_86.go | 2 - .../mongo/migrations/migration_86_test.go | 6 +- api/store/mongo/publickey_tags.go | 39 +- api/store/mongo/store_test.go | 12 +- api/store/mongo/tags.go | 120 ++-- api/store/mongo/tags_test.go | 281 +++++++++- api/store/tags.go | 13 +- api/store/transaction.go | 4 +- 22 files changed, 1619 insertions(+), 162 deletions(-) create mode 100644 api/store/mongo/fixtures/tags.json diff --git a/api/routes/tags_test.go b/api/routes/tags_test.go index b9a636af1a4..7df44e08a16 100644 --- a/api/routes/tags_test.go +++ b/api/routes/tags_test.go @@ -10,6 +10,7 @@ import ( "github.com/shellhub-io/shellhub/api/services/mocks" "github.com/shellhub-io/shellhub/pkg/api/authorizer" + "github.com/shellhub-io/shellhub/pkg/models" "github.com/stretchr/testify/assert" gomock "github.com/stretchr/testify/mock" ) @@ -26,7 +27,18 @@ func TestGetTags(t *testing.T) { title: "success when try to get an existing tag", expectedStatus: http.StatusOK, requiredMocks: func() { - mock.On("GetTags", gomock.Anything, "").Return([]string{"tag1", "tag2"}, 2, nil) + mock.On("GetTags", gomock.Anything, "").Return([]models.Tags{ + { + Name: "tag-1", + Color: "#ff0000", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + Name: "tag-2", + Color: "green", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + }, 2, nil) }, }, } diff --git a/api/services/mocks/services.go b/api/services/mocks/services.go index 98a9ec997ce..e3e848b993f 100644 --- a/api/services/mocks/services.go +++ b/api/services/mocks/services.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package mocks @@ -28,6 +28,10 @@ type Service struct { func (_m *Service) AddNamespaceMember(ctx context.Context, req *requests.NamespaceAddMember) (*models.Namespace, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for AddNamespaceMember") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.NamespaceAddMember) (*models.Namespace, error)); ok { @@ -54,6 +58,10 @@ func (_m *Service) AddNamespaceMember(ctx context.Context, req *requests.Namespa func (_m *Service) AddPublicKeyTag(ctx context.Context, tenant string, fingerprint string, tag string) error { ret := _m.Called(ctx, tenant, fingerprint, tag) + if len(ret) == 0 { + panic("no return value specified for AddPublicKeyTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, fingerprint, tag) @@ -68,6 +76,10 @@ func (_m *Service) AddPublicKeyTag(ctx context.Context, tenant string, fingerpri func (_m *Service) AuthAPIKey(ctx context.Context, key string) (*models.APIKey, error) { ret := _m.Called(ctx, key) + if len(ret) == 0 { + panic("no return value specified for AuthAPIKey") + } + var r0 *models.APIKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.APIKey, error)); ok { @@ -94,6 +106,10 @@ func (_m *Service) AuthAPIKey(ctx context.Context, key string) (*models.APIKey, func (_m *Service) AuthCacheToken(ctx context.Context, tenant string, id string, token string) error { ret := _m.Called(ctx, tenant, id, token) + if len(ret) == 0 { + panic("no return value specified for AuthCacheToken") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, id, token) @@ -108,6 +124,10 @@ func (_m *Service) AuthCacheToken(ctx context.Context, tenant string, id string, func (_m *Service) AuthDevice(ctx context.Context, req requests.DeviceAuth, remoteAddr string) (*models.DeviceAuthResponse, error) { ret := _m.Called(ctx, req, remoteAddr) + if len(ret) == 0 { + panic("no return value specified for AuthDevice") + } + var r0 *models.DeviceAuthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, requests.DeviceAuth, string) (*models.DeviceAuthResponse, error)); ok { @@ -134,6 +154,10 @@ func (_m *Service) AuthDevice(ctx context.Context, req requests.DeviceAuth, remo func (_m *Service) AuthIsCacheToken(ctx context.Context, tenant string, id string) (bool, error) { ret := _m.Called(ctx, tenant, id) + if len(ret) == 0 { + panic("no return value specified for AuthIsCacheToken") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (bool, error)); ok { @@ -158,6 +182,10 @@ func (_m *Service) AuthIsCacheToken(ctx context.Context, tenant string, id strin func (_m *Service) AuthLocalUser(ctx context.Context, req *requests.AuthLocalUser, sourceIP string) (*models.UserAuthResponse, int64, string, error) { ret := _m.Called(ctx, req, sourceIP) + if len(ret) == 0 { + panic("no return value specified for AuthLocalUser") + } + var r0 *models.UserAuthResponse var r1 int64 var r2 string @@ -198,6 +226,10 @@ func (_m *Service) AuthLocalUser(ctx context.Context, req *requests.AuthLocalUse func (_m *Service) AuthPublicKey(ctx context.Context, req requests.PublicKeyAuth) (*models.PublicKeyAuthResponse, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for AuthPublicKey") + } + var r0 *models.PublicKeyAuthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, requests.PublicKeyAuth) (*models.PublicKeyAuthResponse, error)); ok { @@ -224,6 +256,10 @@ func (_m *Service) AuthPublicKey(ctx context.Context, req requests.PublicKeyAuth func (_m *Service) AuthUncacheToken(ctx context.Context, tenant string, id string) error { ret := _m.Called(ctx, tenant, id) + if len(ret) == 0 { + panic("no return value specified for AuthUncacheToken") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, tenant, id) @@ -238,6 +274,10 @@ func (_m *Service) AuthUncacheToken(ctx context.Context, tenant string, id strin func (_m *Service) BillingEvaluate(_a0 internalclient.Client, _a1 string) (bool, error) { ret := _m.Called(_a0, _a1) + if len(ret) == 0 { + panic("no return value specified for BillingEvaluate") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(internalclient.Client, string) (bool, error)); ok { @@ -262,6 +302,10 @@ func (_m *Service) BillingEvaluate(_a0 internalclient.Client, _a1 string) (bool, func (_m *Service) BillingReport(_a0 internalclient.Client, _a1 string, _a2 string) error { ret := _m.Called(_a0, _a1, _a2) + if len(ret) == 0 { + panic("no return value specified for BillingReport") + } + var r0 error if rf, ok := ret.Get(0).(func(internalclient.Client, string, string) error); ok { r0 = rf(_a0, _a1, _a2) @@ -276,6 +320,10 @@ func (_m *Service) BillingReport(_a0 internalclient.Client, _a1 string, _a2 stri func (_m *Service) CreateAPIKey(ctx context.Context, req *requests.CreateAPIKey) (*responses.CreateAPIKey, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for CreateAPIKey") + } + var r0 *responses.CreateAPIKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.CreateAPIKey) (*responses.CreateAPIKey, error)); ok { @@ -302,6 +350,10 @@ func (_m *Service) CreateAPIKey(ctx context.Context, req *requests.CreateAPIKey) func (_m *Service) CreateDeviceTag(ctx context.Context, uid models.UID, tag string) error { ret := _m.Called(ctx, uid, tag) + if len(ret) == 0 { + panic("no return value specified for CreateDeviceTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, tag) @@ -316,6 +368,10 @@ func (_m *Service) CreateDeviceTag(ctx context.Context, uid models.UID, tag stri func (_m *Service) CreateNamespace(ctx context.Context, namespace *requests.NamespaceCreate) (*models.Namespace, error) { ret := _m.Called(ctx, namespace) + if len(ret) == 0 { + panic("no return value specified for CreateNamespace") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.NamespaceCreate) (*models.Namespace, error)); ok { @@ -342,6 +398,10 @@ func (_m *Service) CreateNamespace(ctx context.Context, namespace *requests.Name func (_m *Service) CreatePrivateKey(ctx context.Context) (*models.PrivateKey, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for CreatePrivateKey") + } + var r0 *models.PrivateKey var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*models.PrivateKey, error)); ok { @@ -368,6 +428,10 @@ func (_m *Service) CreatePrivateKey(ctx context.Context) (*models.PrivateKey, er func (_m *Service) CreatePublicKey(ctx context.Context, req requests.PublicKeyCreate, tenant string) (*responses.PublicKeyCreate, error) { ret := _m.Called(ctx, req, tenant) + if len(ret) == 0 { + panic("no return value specified for CreatePublicKey") + } + var r0 *responses.PublicKeyCreate var r1 error if rf, ok := ret.Get(0).(func(context.Context, requests.PublicKeyCreate, string) (*responses.PublicKeyCreate, error)); ok { @@ -394,6 +458,10 @@ func (_m *Service) CreatePublicKey(ctx context.Context, req requests.PublicKeyCr func (_m *Service) CreateSession(ctx context.Context, session requests.SessionCreate) (*models.Session, error) { ret := _m.Called(ctx, session) + if len(ret) == 0 { + panic("no return value specified for CreateSession") + } + var r0 *models.Session var r1 error if rf, ok := ret.Get(0).(func(context.Context, requests.SessionCreate) (*models.Session, error)); ok { @@ -420,6 +488,10 @@ func (_m *Service) CreateSession(ctx context.Context, session requests.SessionCr func (_m *Service) CreateUserToken(ctx context.Context, req *requests.CreateUserToken) (*models.UserAuthResponse, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for CreateUserToken") + } + var r0 *models.UserAuthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.CreateUserToken) (*models.UserAuthResponse, error)); ok { @@ -446,6 +518,10 @@ func (_m *Service) CreateUserToken(ctx context.Context, req *requests.CreateUser func (_m *Service) DeactivateSession(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for DeactivateSession") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -460,6 +536,10 @@ func (_m *Service) DeactivateSession(ctx context.Context, uid models.UID) error func (_m *Service) DeleteAPIKey(ctx context.Context, req *requests.DeleteAPIKey) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for DeleteAPIKey") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *requests.DeleteAPIKey) error); ok { r0 = rf(ctx, req) @@ -474,6 +554,10 @@ func (_m *Service) DeleteAPIKey(ctx context.Context, req *requests.DeleteAPIKey) func (_m *Service) DeleteDevice(ctx context.Context, uid models.UID, tenant string) error { ret := _m.Called(ctx, uid, tenant) + if len(ret) == 0 { + panic("no return value specified for DeleteDevice") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, tenant) @@ -488,6 +572,10 @@ func (_m *Service) DeleteDevice(ctx context.Context, uid models.UID, tenant stri func (_m *Service) DeleteNamespace(ctx context.Context, tenantID string) error { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for DeleteNamespace") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, tenantID) @@ -502,6 +590,10 @@ func (_m *Service) DeleteNamespace(ctx context.Context, tenantID string) error { func (_m *Service) DeletePublicKey(ctx context.Context, fingerprint string, tenant string) error { ret := _m.Called(ctx, fingerprint, tenant) + if len(ret) == 0 { + panic("no return value specified for DeletePublicKey") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, fingerprint, tenant) @@ -516,6 +608,10 @@ func (_m *Service) DeletePublicKey(ctx context.Context, fingerprint string, tena func (_m *Service) DeleteTag(ctx context.Context, tenant string, tag string) error { ret := _m.Called(ctx, tenant, tag) + if len(ret) == 0 { + panic("no return value specified for DeleteTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, tenant, tag) @@ -530,6 +626,10 @@ func (_m *Service) DeleteTag(ctx context.Context, tenant string, tag string) err func (_m *Service) EditNamespace(ctx context.Context, req *requests.NamespaceEdit) (*models.Namespace, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for EditNamespace") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.NamespaceEdit) (*models.Namespace, error)); ok { @@ -556,6 +656,10 @@ func (_m *Service) EditNamespace(ctx context.Context, req *requests.NamespaceEdi func (_m *Service) EditSessionRecordStatus(ctx context.Context, sessionRecord bool, tenantID string) error { ret := _m.Called(ctx, sessionRecord, tenantID) + if len(ret) == 0 { + panic("no return value specified for EditSessionRecordStatus") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, bool, string) error); ok { r0 = rf(ctx, sessionRecord, tenantID) @@ -570,6 +674,10 @@ func (_m *Service) EditSessionRecordStatus(ctx context.Context, sessionRecord bo func (_m *Service) EvaluateKeyFilter(ctx context.Context, key *models.PublicKey, dev models.Device) (bool, error) { ret := _m.Called(ctx, key, dev) + if len(ret) == 0 { + panic("no return value specified for EvaluateKeyFilter") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context, *models.PublicKey, models.Device) (bool, error)); ok { @@ -594,6 +702,10 @@ func (_m *Service) EvaluateKeyFilter(ctx context.Context, key *models.PublicKey, func (_m *Service) EvaluateKeyUsername(ctx context.Context, key *models.PublicKey, username string) (bool, error) { ret := _m.Called(ctx, key, username) + if len(ret) == 0 { + panic("no return value specified for EvaluateKeyUsername") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context, *models.PublicKey, string) (bool, error)); ok { @@ -618,6 +730,10 @@ func (_m *Service) EvaluateKeyUsername(ctx context.Context, key *models.PublicKe func (_m *Service) EventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error { ret := _m.Called(ctx, uid, event) + if len(ret) == 0 { + panic("no return value specified for EventSession") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.SessionEvent) error); ok { r0 = rf(ctx, uid, event) @@ -632,6 +748,10 @@ func (_m *Service) EventSession(ctx context.Context, uid models.UID, event *mode func (_m *Service) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for GetDevice") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Device, error)); ok { @@ -658,6 +778,10 @@ func (_m *Service) GetDevice(ctx context.Context, uid models.UID) (*models.Devic func (_m *Service) GetDeviceByPublicURLAddress(ctx context.Context, address string) (*models.Device, error) { ret := _m.Called(ctx, address) + if len(ret) == 0 { + panic("no return value specified for GetDeviceByPublicURLAddress") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Device, error)); ok { @@ -684,6 +808,10 @@ func (_m *Service) GetDeviceByPublicURLAddress(ctx context.Context, address stri func (_m *Service) GetNamespace(ctx context.Context, tenantID string) (*models.Namespace, error) { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for GetNamespace") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Namespace, error)); ok { @@ -710,6 +838,10 @@ func (_m *Service) GetNamespace(ctx context.Context, tenantID string) (*models.N func (_m *Service) GetPublicKey(ctx context.Context, fingerprint string, tenant string) (*models.PublicKey, error) { ret := _m.Called(ctx, fingerprint, tenant) + if len(ret) == 0 { + panic("no return value specified for GetPublicKey") + } + var r0 *models.PublicKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.PublicKey, error)); ok { @@ -736,6 +868,10 @@ func (_m *Service) GetPublicKey(ctx context.Context, fingerprint string, tenant func (_m *Service) GetSession(ctx context.Context, uid models.UID) (*models.Session, error) { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for GetSession") + } + var r0 *models.Session var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Session, error)); ok { @@ -762,6 +898,10 @@ func (_m *Service) GetSession(ctx context.Context, uid models.UID) (*models.Sess func (_m *Service) GetSessionRecord(ctx context.Context, tenantID string) (bool, error) { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for GetSessionRecord") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { @@ -786,6 +926,10 @@ func (_m *Service) GetSessionRecord(ctx context.Context, tenantID string) (bool, func (_m *Service) GetStats(ctx context.Context) (*models.Stats, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetStats") + } + var r0 *models.Stats var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*models.Stats, error)); ok { @@ -809,20 +953,24 @@ func (_m *Service) GetStats(ctx context.Context) (*models.Stats, error) { } // GetTags provides a mock function with given fields: ctx, tenant -func (_m *Service) GetTags(ctx context.Context, tenant string) ([]string, int, error) { +func (_m *Service) GetTags(ctx context.Context, tenant string) ([]models.Tags, int, error) { ret := _m.Called(ctx, tenant) - var r0 []string + if len(ret) == 0 { + panic("no return value specified for GetTags") + } + + var r0 []models.Tags var r1 int var r2 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, int, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) ([]models.Tags, int, error)); ok { return rf(ctx, tenant) } - if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) []models.Tags); ok { r0 = rf(ctx, tenant) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) + r0 = ret.Get(0).([]models.Tags) } } @@ -845,6 +993,10 @@ func (_m *Service) GetTags(ctx context.Context, tenant string) ([]string, int, e func (_m *Service) GetUserRole(ctx context.Context, tenantID string, userID string) (string, error) { ret := _m.Called(ctx, tenantID, userID) + if len(ret) == 0 { + panic("no return value specified for GetUserRole") + } + var r0 string var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (string, error)); ok { @@ -869,6 +1021,10 @@ func (_m *Service) GetUserRole(ctx context.Context, tenantID string, userID stri func (_m *Service) KeepAliveSession(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for KeepAliveSession") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -883,6 +1039,10 @@ func (_m *Service) KeepAliveSession(ctx context.Context, uid models.UID) error { func (_m *Service) LeaveNamespace(ctx context.Context, req *requests.LeaveNamespace) (*models.UserAuthResponse, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for LeaveNamespace") + } + var r0 *models.UserAuthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.LeaveNamespace) (*models.UserAuthResponse, error)); ok { @@ -909,6 +1069,10 @@ func (_m *Service) LeaveNamespace(ctx context.Context, req *requests.LeaveNamesp func (_m *Service) ListAPIKeys(ctx context.Context, req *requests.ListAPIKey) ([]models.APIKey, int, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for ListAPIKeys") + } + var r0 []models.APIKey var r1 int var r2 error @@ -942,6 +1106,10 @@ func (_m *Service) ListAPIKeys(ctx context.Context, req *requests.ListAPIKey) ([ func (_m *Service) ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for ListDevices") + } + var r0 []models.Device var r1 int var r2 error @@ -975,6 +1143,10 @@ func (_m *Service) ListDevices(ctx context.Context, req *requests.DeviceList) ([ func (_m *Service) ListNamespaces(ctx context.Context, req *requests.NamespaceList) ([]models.Namespace, int, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for ListNamespaces") + } + var r0 []models.Namespace var r1 int var r2 error @@ -1008,6 +1180,10 @@ func (_m *Service) ListNamespaces(ctx context.Context, req *requests.NamespaceLi func (_m *Service) ListPublicKeys(ctx context.Context, paginator query.Paginator) ([]models.PublicKey, int, error) { ret := _m.Called(ctx, paginator) + if len(ret) == 0 { + panic("no return value specified for ListPublicKeys") + } + var r0 []models.PublicKey var r1 int var r2 error @@ -1041,6 +1217,10 @@ func (_m *Service) ListPublicKeys(ctx context.Context, paginator query.Paginator func (_m *Service) ListSessions(ctx context.Context, paginator query.Paginator) ([]models.Session, int, error) { ret := _m.Called(ctx, paginator) + if len(ret) == 0 { + panic("no return value specified for ListSessions") + } + var r0 []models.Session var r1 int var r2 error @@ -1074,6 +1254,10 @@ func (_m *Service) ListSessions(ctx context.Context, paginator query.Paginator) func (_m *Service) LookupDevice(ctx context.Context, namespace string, name string) (*models.Device, error) { ret := _m.Called(ctx, namespace, name) + if len(ret) == 0 { + panic("no return value specified for LookupDevice") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Device, error)); ok { @@ -1100,6 +1284,10 @@ func (_m *Service) LookupDevice(ctx context.Context, namespace string, name stri func (_m *Service) OfflineDevice(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for OfflineDevice") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -1114,6 +1302,10 @@ func (_m *Service) OfflineDevice(ctx context.Context, uid models.UID) error { func (_m *Service) PublicKey() *rsa.PublicKey { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for PublicKey") + } + var r0 *rsa.PublicKey if rf, ok := ret.Get(0).(func() *rsa.PublicKey); ok { r0 = rf() @@ -1130,6 +1322,10 @@ func (_m *Service) PublicKey() *rsa.PublicKey { func (_m *Service) RemoveDeviceTag(ctx context.Context, uid models.UID, tag string) error { ret := _m.Called(ctx, uid, tag) + if len(ret) == 0 { + panic("no return value specified for RemoveDeviceTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, tag) @@ -1144,6 +1340,10 @@ func (_m *Service) RemoveDeviceTag(ctx context.Context, uid models.UID, tag stri func (_m *Service) RemoveNamespaceMember(ctx context.Context, req *requests.NamespaceRemoveMember) (*models.Namespace, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for RemoveNamespaceMember") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.NamespaceRemoveMember) (*models.Namespace, error)); ok { @@ -1170,6 +1370,10 @@ func (_m *Service) RemoveNamespaceMember(ctx context.Context, req *requests.Name func (_m *Service) RemovePublicKeyTag(ctx context.Context, tenant string, fingerprint string, tag string) error { ret := _m.Called(ctx, tenant, fingerprint, tag) + if len(ret) == 0 { + panic("no return value specified for RemovePublicKeyTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, fingerprint, tag) @@ -1184,6 +1388,10 @@ func (_m *Service) RemovePublicKeyTag(ctx context.Context, tenant string, finger func (_m *Service) RenameDevice(ctx context.Context, uid models.UID, name string, tenant string) error { ret := _m.Called(ctx, uid, name, tenant) + if len(ret) == 0 { + panic("no return value specified for RenameDevice") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string, string) error); ok { r0 = rf(ctx, uid, name, tenant) @@ -1198,6 +1406,10 @@ func (_m *Service) RenameDevice(ctx context.Context, uid models.UID, name string func (_m *Service) RenameTag(ctx context.Context, tenant string, oldTag string, newTag string) error { ret := _m.Called(ctx, tenant, oldTag, newTag) + if len(ret) == 0 { + panic("no return value specified for RenameTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, oldTag, newTag) @@ -1212,6 +1424,10 @@ func (_m *Service) RenameTag(ctx context.Context, tenant string, oldTag string, func (_m *Service) Setup(ctx context.Context, req requests.Setup) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for Setup") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, requests.Setup) error); ok { r0 = rf(ctx, req) @@ -1226,6 +1442,10 @@ func (_m *Service) Setup(ctx context.Context, req requests.Setup) error { func (_m *Service) SetupVerify(ctx context.Context, sign string) error { ret := _m.Called(ctx, sign) + if len(ret) == 0 { + panic("no return value specified for SetupVerify") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, sign) @@ -1240,6 +1460,10 @@ func (_m *Service) SetupVerify(ctx context.Context, sign string) error { func (_m *Service) SystemDownloadInstallScript(ctx context.Context) (string, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for SystemDownloadInstallScript") + } + var r0 string var r1 error if rf, ok := ret.Get(0).(func(context.Context) (string, error)); ok { @@ -1264,6 +1488,10 @@ func (_m *Service) SystemDownloadInstallScript(ctx context.Context) (string, err func (_m *Service) SystemGetInfo(ctx context.Context, req requests.SystemGetInfo) (*models.SystemInfo, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for SystemGetInfo") + } + var r0 *models.SystemInfo var r1 error if rf, ok := ret.Get(0).(func(context.Context, requests.SystemGetInfo) (*models.SystemInfo, error)); ok { @@ -1290,6 +1518,10 @@ func (_m *Service) SystemGetInfo(ctx context.Context, req requests.SystemGetInfo func (_m *Service) UpdateAPIKey(ctx context.Context, req *requests.UpdateAPIKey) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for UpdateAPIKey") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *requests.UpdateAPIKey) error); ok { r0 = rf(ctx, req) @@ -1304,6 +1536,10 @@ func (_m *Service) UpdateAPIKey(ctx context.Context, req *requests.UpdateAPIKey) func (_m *Service) UpdateDevice(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error { ret := _m.Called(ctx, tenant, uid, name, publicURL) + if len(ret) == 0 { + panic("no return value specified for UpdateDevice") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, models.UID, *string, *bool) error); ok { r0 = rf(ctx, tenant, uid, name, publicURL) @@ -1318,6 +1554,10 @@ func (_m *Service) UpdateDevice(ctx context.Context, tenant string, uid models.U func (_m *Service) UpdateDeviceStatus(ctx context.Context, tenant string, uid models.UID, status models.DeviceStatus) error { ret := _m.Called(ctx, tenant, uid, status) + if len(ret) == 0 { + panic("no return value specified for UpdateDeviceStatus") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, models.UID, models.DeviceStatus) error); ok { r0 = rf(ctx, tenant, uid, status) @@ -1332,6 +1572,10 @@ func (_m *Service) UpdateDeviceStatus(ctx context.Context, tenant string, uid mo func (_m *Service) UpdateDeviceTag(ctx context.Context, uid models.UID, tags []string) error { ret := _m.Called(ctx, uid, tags) + if len(ret) == 0 { + panic("no return value specified for UpdateDeviceTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, []string) error); ok { r0 = rf(ctx, uid, tags) @@ -1346,6 +1590,10 @@ func (_m *Service) UpdateDeviceTag(ctx context.Context, uid models.UID, tags []s func (_m *Service) UpdateNamespaceMember(ctx context.Context, req *requests.NamespaceUpdateMember) error { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for UpdateNamespaceMember") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *requests.NamespaceUpdateMember) error); ok { r0 = rf(ctx, req) @@ -1360,6 +1608,10 @@ func (_m *Service) UpdateNamespaceMember(ctx context.Context, req *requests.Name func (_m *Service) UpdatePasswordUser(ctx context.Context, id string, currentPassword string, newPassword string) error { ret := _m.Called(ctx, id, currentPassword, newPassword) + if len(ret) == 0 { + panic("no return value specified for UpdatePasswordUser") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, id, currentPassword, newPassword) @@ -1374,6 +1626,10 @@ func (_m *Service) UpdatePasswordUser(ctx context.Context, id string, currentPas func (_m *Service) UpdatePublicKey(ctx context.Context, fingerprint string, tenant string, key requests.PublicKeyUpdate) (*models.PublicKey, error) { ret := _m.Called(ctx, fingerprint, tenant, key) + if len(ret) == 0 { + panic("no return value specified for UpdatePublicKey") + } + var r0 *models.PublicKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, requests.PublicKeyUpdate) (*models.PublicKey, error)); ok { @@ -1400,6 +1656,10 @@ func (_m *Service) UpdatePublicKey(ctx context.Context, fingerprint string, tena func (_m *Service) UpdatePublicKeyTags(ctx context.Context, tenant string, fingerprint string, tags []string) error { ret := _m.Called(ctx, tenant, fingerprint, tags) + if len(ret) == 0 { + panic("no return value specified for UpdatePublicKeyTags") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, []string) error); ok { r0 = rf(ctx, tenant, fingerprint, tags) @@ -1414,6 +1674,10 @@ func (_m *Service) UpdatePublicKeyTags(ctx context.Context, tenant string, finge func (_m *Service) UpdateSession(ctx context.Context, uid models.UID, model models.SessionUpdate) error { ret := _m.Called(ctx, uid, model) + if len(ret) == 0 { + panic("no return value specified for UpdateSession") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, models.SessionUpdate) error); ok { r0 = rf(ctx, uid, model) @@ -1428,6 +1692,10 @@ func (_m *Service) UpdateSession(ctx context.Context, uid models.UID, model mode func (_m *Service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([]string, error) { ret := _m.Called(ctx, req) + if len(ret) == 0 { + panic("no return value specified for UpdateUser") + } + var r0 []string var r1 error if rf, ok := ret.Get(0).(func(context.Context, *requests.UpdateUser) ([]string, error)); ok { @@ -1450,13 +1718,12 @@ func (_m *Service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([] return r0, r1 } -type mockConstructorTestingTNewService interface { +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { mock.TestingT Cleanup(func()) -} - -// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewService(t mockConstructorTestingTNewService) *Service { +}) *Service { mock := &Service{} mock.Mock.Test(t) diff --git a/api/services/sshkeys.go b/api/services/sshkeys.go index 0f7ba89b065..d16395a44fc 100644 --- a/api/services/sshkeys.go +++ b/api/services/sshkeys.go @@ -84,7 +84,7 @@ func (s *service) CreatePublicKey(ctx context.Context, req requests.PublicKeyCre } for _, tag := range req.Filter.Tags { - if !contains(tags, tag) { + if !containsTags(tags, tag) { return nil, NewErrTagNotFound(tag, nil) } } @@ -150,7 +150,7 @@ func (s *service) UpdatePublicKey(ctx context.Context, fingerprint, tenant strin } for _, tag := range key.Filter.Tags { - if !contains(tags, tag) { + if !containsTags(tags, tag) { return nil, NewErrTagNotFound(tag, nil) } } diff --git a/api/services/sshkeys_tags.go b/api/services/sshkeys_tags.go index 2c37b617e71..ff67edfb0bd 100644 --- a/api/services/sshkeys_tags.go +++ b/api/services/sshkeys_tags.go @@ -39,7 +39,7 @@ func (s *service) AddPublicKeyTag(ctx context.Context, tenant, fingerprint, tag return NewErrTagEmpty(tenant, err) } - if !contains(tags, tag) { + if !containsTags(tags, tag) { return NewErrTagNotFound(tag, nil) } @@ -129,7 +129,7 @@ func (s *service) UpdatePublicKeyTags(ctx context.Context, tenant, fingerprint s } for _, tag := range tags { - if !contains(allTags, tag) { + if !containsTags(allTags, tag) { return NewErrTagNotFound(tag, nil) } } diff --git a/api/services/sshkeys_tags_test.go b/api/services/sshkeys_tags_test.go index 806f8254a48..de285d6c87a 100644 --- a/api/services/sshkeys_tags_test.go +++ b/api/services/sshkeys_tags_test.go @@ -82,13 +82,23 @@ func TestAddPublicKeyTag(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag1", "tag2"} + tagsNames := []string{ + "tag-1", "tag-2", + } + tags := []models.Tags{ + { + Name: "tag-1", + }, + { + Name: "tag-2", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", PublicKeyFields: models.PublicKeyFields{ Filter: models.PublicKeyFilter{ - Tags: tags, + Tags: tagsNames, }, }, } @@ -108,7 +118,24 @@ func TestAddPublicKeyTag(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag", "tag3", "tag6"} + // tagsNames := []string{"tag", "tag3", "tag6"} + tags := []models.Tags{ + { + Name: "tag", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag3", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag6", + Color: "", + Tenant: "tenant", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", @@ -134,7 +161,23 @@ func TestAddPublicKeyTag(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag", "tag3", "tag6"} + tags := []models.Tags{ + { + Name: "tag", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag3", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag6", + Color: "", + Tenant: "tenant", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", @@ -360,7 +403,23 @@ func TestUpdatePublicKeyTags(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag4", "tag5", "tag7", "tag5"} + tags := []models.Tags{ + { + Name: "tag4", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag5", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag7", + Color: "", + Tenant: "tenant", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", @@ -386,7 +445,28 @@ func TestUpdatePublicKeyTags(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag1", "tag2", "tag3", "tag4"} + tags := []models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag2", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag3", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag4", + Color: "", + Tenant: "tenant", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", @@ -413,7 +493,28 @@ func TestUpdatePublicKeyTags(t *testing.T) { namespace := &models.Namespace{ TenantID: "tenant", } - tags := []string{"tag1", "tag2", "tag3", "tag4"} + tags := []models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag2", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag3", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag4", + Color: "", + Tenant: "tenant", + }, + } key := &models.PublicKey{ TenantID: "tenant", Fingerprint: "fingerprint", diff --git a/api/services/sshkeys_test.go b/api/services/sshkeys_test.go index 89b30769892..a92b7af3be6 100644 --- a/api/services/sshkeys_test.go +++ b/api/services/sshkeys_test.go @@ -295,7 +295,7 @@ func TestUpdatePublicKeys(t *testing.T) { }, }, requiredMocks: func() { - mock.On("TagsGet", ctx, "tenant").Return([]string{}, 0, errors.New("error", "", 0)).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{}, 0, errors.New("error", "", 0)).Once() }, expected: Expected{nil, NewErrTagEmpty("tenant", errors.New("error", "", 0))}, }, @@ -309,7 +309,19 @@ func TestUpdatePublicKeys(t *testing.T) { }, }, requiredMocks: func() { - mock.On("TagsGet", ctx, "tenant").Return([]string{"tag1", "tag4"}, 2, nil).Once() + mock.On("TagsGet", ctx, "tenant"). + Return([]models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag4", + Color: "", + Tenant: "tenant", + }, + }, 2, nil).Once() }, expected: Expected{nil, NewErrTagNotFound("tag2", nil)}, }, @@ -331,7 +343,19 @@ func TestUpdatePublicKeys(t *testing.T) { }, } - mock.On("TagsGet", ctx, "tenant").Return([]string{"tag1", "tag2"}, 2, nil).Once() + mock.On("TagsGet", ctx, "tenant"). + Return([]models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag2", + Color: "", + Tenant: "tenant", + }, + }, 2, nil).Once() mock.On("PublicKeyUpdate", ctx, "fingerprint", "tenant", &model).Return(nil, errors.New("error", "", 0)).Once() }, expected: Expected{nil, errors.New("error", "", 0)}, @@ -362,7 +386,15 @@ func TestUpdatePublicKeys(t *testing.T) { }, } - mock.On("TagsGet", ctx, "tenant").Return([]string{"tag1", "tag2"}, 2, nil).Once() + mock.On("TagsGet", ctx, "tenant"). + Return([]models.Tags{ + { + Name: "tag1", + }, + { + Name: "tag2", + }, + }, 2, nil).Once() mock.On("PublicKeyUpdate", ctx, "fingerprint", "tenant", &model).Return(keyUpdateWithTagsModel, nil).Once() }, expected: Expected{&models.PublicKey{ @@ -581,7 +613,7 @@ func TestCreatePublicKeys(t *testing.T) { }, }, requiredMocks: func() { - mock.On("TagsGet", ctx, "tenant").Return([]string{}, 0, errors.New("error", "", 0)).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{}, 0, errors.New("error", "", 0)).Once() }, expected: Expected{nil, NewErrTagEmpty("tenant", errors.New("error", "", 0))}, }, @@ -597,7 +629,18 @@ func TestCreatePublicKeys(t *testing.T) { }, }, requiredMocks: func() { - mock.On("TagsGet", ctx, "tenant").Return([]string{"tag1", "tag4"}, 2, nil).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag4", + Color: "", + Tenant: "tenant", + }, + }, 2, nil).Once() }, expected: Expected{nil, NewErrTagNotFound("tag2", nil)}, }, @@ -868,7 +911,18 @@ func TestCreatePublicKeys(t *testing.T) { }, } - mock.On("TagsGet", ctx, keyWithTags.TenantID).Return([]string{"tag1", "tag2"}, 2, nil).Once() + mock.On("TagsGet", ctx, keyWithTags.TenantID).Return([]models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag2", + Color: "", + Tenant: "tenant", + }, + }, 2, nil).Once() mock.On("PublicKeyGet", ctx, keyWithTags.Fingerprint, "tenant").Return(nil, nil).Once() mock.On("PublicKeyCreate", ctx, &keyWithTagsModel).Return(errors.New("error", "", 0)).Once() }, @@ -907,7 +961,18 @@ func TestCreatePublicKeys(t *testing.T) { }, } - mock.On("TagsGet", ctx, keyWithTags.TenantID).Return([]string{"tag1", "tag2"}, 2, nil).Once() + mock.On("TagsGet", ctx, keyWithTags.TenantID).Return([]models.Tags{ + { + Name: "tag1", + Color: "", + Tenant: "tenant", + }, + { + Name: "tag2", + Color: "", + Tenant: "tenant", + }, + }, 2, nil).Once() mock.On("PublicKeyGet", ctx, keyWithTags.Fingerprint, "tenant").Return(nil, nil).Once() mock.On("PublicKeyCreate", ctx, &keyWithTagsModel).Return(nil).Once() }, diff --git a/api/services/tags.go b/api/services/tags.go index d43ddafcaf8..e5e70d2048a 100644 --- a/api/services/tags.go +++ b/api/services/tags.go @@ -7,18 +7,20 @@ import ( ) type TagsService interface { - GetTags(ctx context.Context, tenant string) ([]string, int, error) + GetTags(ctx context.Context, tenant string) ([]models.Tags, int, error) RenameTag(ctx context.Context, tenant string, oldTag string, newTag string) error DeleteTag(ctx context.Context, tenant string, tag string) error } -func (s *service) GetTags(ctx context.Context, tenant string) ([]string, int, error) { +func (s *service) GetTags(ctx context.Context, tenant string) ([]models.Tags, int, error) { namespace, err := s.store.NamespaceGet(ctx, tenant) if err != nil || namespace == nil { return nil, 0, NewErrNamespaceNotFound(tenant, err) } - return s.store.TagsGet(ctx, namespace.TenantID) + tags, count, err := s.store.TagsGet(ctx, namespace.TenantID) + + return tags, int(count), err } func (s *service) RenameTag(ctx context.Context, tenant string, oldTag string, newTag string) error { @@ -31,11 +33,11 @@ func (s *service) RenameTag(ctx context.Context, tenant string, oldTag string, n return NewErrTagEmpty(tenant, err) } - if !contains(tags, oldTag) { + if !containsTags(tags, oldTag) { return NewErrTagNotFound(oldTag, nil) } - if contains(tags, newTag) { + if containsTags(tags, newTag) { return NewErrTagDuplicated(newTag, nil) } @@ -59,7 +61,7 @@ func (s *service) DeleteTag(ctx context.Context, tenant string, tag string) erro return NewErrTagEmpty(tenant, err) } - if !contains(tags, tag) { + if !containsTags(tags, tag) { return NewErrTagNotFound(tag, nil) } diff --git a/api/services/tags_test.go b/api/services/tags_test.go index 8b5303885ac..b6e9d083ee3 100644 --- a/api/services/tags_test.go +++ b/api/services/tags_test.go @@ -20,7 +20,7 @@ func TestGetTags(t *testing.T) { ctx := context.TODO() type Expected struct { - Tags []string + Tags []models.Tags Count int Error error } @@ -73,10 +73,32 @@ func TestGetTags(t *testing.T) { namespace := &models.Namespace{Name: "namespace", TenantID: "tenant"} mock.On("NamespaceGet", ctx, "tenant").Return(namespace, nil).Once() - mock.On("TagsGet", ctx, "tenant").Return(device.Tags, len(device.Tags), nil).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{ + { + Name: "device1", + Tenant: "tenant", + Color: "#efdaef", + }, + { + Name: "device2", + Tenant: "tenant", + Color: "#efdaef", + }, + }, len(device.Tags), nil).Once() }, expected: Expected{ - Tags: []string{"device1", "device2"}, + Tags: []models.Tags{ + { + Name: "device1", + Tenant: "tenant", + Color: "#efdaef", + }, + { + Name: "device2", + Tenant: "tenant", + Color: "#efdaef", + }, + }, Count: len([]string{"device1", "device2"}), Error: nil, }, @@ -148,7 +170,23 @@ func TestRenameTag(t *testing.T) { Tags: []string{"device3", "device4", "device5"}, } - mock.On("TagsGet", ctx, namespace.TenantID).Return(deviceWithTags.Tags, len(deviceWithTags.Tags), nil).Once() + mock.On("TagsGet", ctx, namespace.TenantID).Return([]models.Tags{ + { + Name: "device3", + Color: "", + Tenant: "tenant", + }, + { + Name: "device4", + Color: "", + Tenant: "tenant", + }, + { + Name: "device5", + Color: "", + Tenant: "tenant", + }, + }, len(deviceWithTags.Tags), nil).Once() }, expected: NewErrTagNotFound("device2", nil), }, @@ -171,7 +209,23 @@ func TestRenameTag(t *testing.T) { Tags: []string{"device3", "device4", "device5"}, } - mock.On("TagsGet", ctx, namespace.TenantID).Return(deviceWithTags.Tags, len(deviceWithTags.Tags), nil).Once() + mock.On("TagsGet", ctx, namespace.TenantID).Return([]models.Tags{ + { + Name: "device3", + Color: "", + Tenant: "tenant", + }, + { + Name: "device4", + Color: "", + Tenant: "tenant", + }, + { + Name: "device5", + Color: "", + Tenant: "tenant", + }, + }, len(deviceWithTags.Tags), nil).Once() }, expected: NewErrTagDuplicated("device5", nil), }, @@ -194,7 +248,23 @@ func TestRenameTag(t *testing.T) { Tags: []string{"device3", "device4", "device5"}, } - mock.On("TagsGet", ctx, namespace.TenantID).Return(deviceWithTags.Tags, len(deviceWithTags.Tags), nil).Once() + mock.On("TagsGet", ctx, namespace.TenantID).Return([]models.Tags{ + { + Name: "device3", + Color: "", + Tenant: "tenant", + }, + { + Name: "device4", + Color: "", + Tenant: "tenant", + }, + { + Name: "device5", + Color: "", + Tenant: "tenant", + }, + }, len(deviceWithTags.Tags), nil).Once() mock.On("TagsRename", ctx, namespace.TenantID, "device3", "device1").Return(int64(0), errors.New("error", "", 0)).Once() }, expected: errors.New("error", "", 0), @@ -218,7 +288,23 @@ func TestRenameTag(t *testing.T) { Tags: []string{"device3", "device4", "device5"}, } - mock.On("TagsGet", ctx, namespace.TenantID).Return(deviceWithTags.Tags, len(deviceWithTags.Tags), nil).Once() + mock.On("TagsGet", ctx, namespace.TenantID).Return([]models.Tags{ + { + Name: "device3", + Color: "", + Tenant: "tenant", + }, + { + Name: "device4", + Color: "", + Tenant: "tenant", + }, + { + Name: "device5", + Color: "", + Tenant: "tenant", + }, + }, len(deviceWithTags.Tags), nil).Once() mock.On("TagsRename", ctx, namespace.TenantID, "device3", "device1").Return(int64(1), nil).Once() }, expected: nil, @@ -296,7 +382,18 @@ func TestDeleteTag(t *testing.T) { } mock.On("NamespaceGet", ctx, "tenant").Return(namespace, nil).Once() - mock.On("TagsGet", ctx, "tenant").Return(device.Tags, len(device.Tags), nil).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{ + { + Name: "device1", + Color: "", + Tenant: "tenant", + }, + { + Name: "device2", + Color: "", + Tenant: "tenant", + }, + }, len(device.Tags), nil).Once() }, expected: NewErrTagNotFound("device3", nil), }, @@ -315,7 +412,18 @@ func TestDeleteTag(t *testing.T) { } mock.On("NamespaceGet", ctx, "tenant").Return(namespace, nil).Once() - mock.On("TagsGet", ctx, "tenant").Return(device.Tags, len(device.Tags), nil).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{ + { + Name: "device1", + Color: "", + Tenant: "tenant", + }, + { + Name: "device2", + Color: "", + Tenant: "tenant", + }, + }, len(device.Tags), nil).Once() mock.On("TagsDelete", ctx, "tenant", "device1").Return(int64(0), errors.New("error", "", 0)).Once() }, expected: errors.New("error", "", 0), @@ -335,7 +443,18 @@ func TestDeleteTag(t *testing.T) { } mock.On("NamespaceGet", ctx, "tenant").Return(namespace, nil).Once() - mock.On("TagsGet", ctx, "tenant").Return(device.Tags, len(device.Tags), nil).Once() + mock.On("TagsGet", ctx, "tenant").Return([]models.Tags{ + { + Name: "device1", + Color: "", + Tenant: "tenant", + }, + { + Name: "device2", + Color: "", + Tenant: "tenant", + }, + }, len(device.Tags), nil).Once() mock.On("TagsDelete", ctx, "tenant", "device1").Return(int64(1), nil).Once() }, expected: nil, diff --git a/api/services/utils.go b/api/services/utils.go index fb64061069c..b384f5f2314 100644 --- a/api/services/utils.go +++ b/api/services/utils.go @@ -3,8 +3,10 @@ package services import ( "crypto/rsa" "os" + "slices" jwt "github.com/golang-jwt/jwt/v4" + "github.com/shellhub-io/shellhub/pkg/models" ) func LoadKeys() (*rsa.PrivateKey, *rsa.PublicKey, error) { @@ -31,12 +33,14 @@ func LoadKeys() (*rsa.PrivateKey, *rsa.PublicKey, error) { return privKey, pubKey, nil } -func contains(list []string, item string) bool { - for _, i := range list { - if i == item { - return true - } - } +func containsTags(list []models.Tags, item string) bool { + return slices.ContainsFunc(list, func(n models.Tags) bool { + return n.Name == item + }) +} - return false +func contains(list []string, item string) bool { + return slices.ContainsFunc(list, func(n string) bool { + return n == item + }) } diff --git a/api/store/mocks/store.go b/api/store/mocks/store.go index ef863ef06a7..3d1bd1d405e 100644 --- a/api/store/mocks/store.go +++ b/api/store/mocks/store.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package mocks @@ -24,6 +24,10 @@ type Store struct { func (_m *Store) APIKeyConflicts(ctx context.Context, tenantID string, target *models.APIKeyConflicts) ([]string, bool, error) { ret := _m.Called(ctx, tenantID, target) + if len(ret) == 0 { + panic("no return value specified for APIKeyConflicts") + } + var r0 []string var r1 bool var r2 error @@ -57,6 +61,10 @@ func (_m *Store) APIKeyConflicts(ctx context.Context, tenantID string, target *m func (_m *Store) APIKeyCreate(ctx context.Context, APIKey *models.APIKey) (string, error) { ret := _m.Called(ctx, APIKey) + if len(ret) == 0 { + panic("no return value specified for APIKeyCreate") + } + var r0 string var r1 error if rf, ok := ret.Get(0).(func(context.Context, *models.APIKey) (string, error)); ok { @@ -81,6 +89,10 @@ func (_m *Store) APIKeyCreate(ctx context.Context, APIKey *models.APIKey) (strin func (_m *Store) APIKeyDelete(ctx context.Context, tenantID string, name string) error { ret := _m.Called(ctx, tenantID, name) + if len(ret) == 0 { + panic("no return value specified for APIKeyDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, tenantID, name) @@ -95,6 +107,10 @@ func (_m *Store) APIKeyDelete(ctx context.Context, tenantID string, name string) func (_m *Store) APIKeyGet(ctx context.Context, id string) (*models.APIKey, error) { ret := _m.Called(ctx, id) + if len(ret) == 0 { + panic("no return value specified for APIKeyGet") + } + var r0 *models.APIKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.APIKey, error)); ok { @@ -121,6 +137,10 @@ func (_m *Store) APIKeyGet(ctx context.Context, id string) (*models.APIKey, erro func (_m *Store) APIKeyGetByName(ctx context.Context, tenantID string, name string) (*models.APIKey, error) { ret := _m.Called(ctx, tenantID, name) + if len(ret) == 0 { + panic("no return value specified for APIKeyGetByName") + } + var r0 *models.APIKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.APIKey, error)); ok { @@ -147,6 +167,10 @@ func (_m *Store) APIKeyGetByName(ctx context.Context, tenantID string, name stri func (_m *Store) APIKeyList(ctx context.Context, tenantID string, paginator query.Paginator, sorter query.Sorter) ([]models.APIKey, int, error) { ret := _m.Called(ctx, tenantID, paginator, sorter) + if len(ret) == 0 { + panic("no return value specified for APIKeyList") + } + var r0 []models.APIKey var r1 int var r2 error @@ -180,6 +204,10 @@ func (_m *Store) APIKeyList(ctx context.Context, tenantID string, paginator quer func (_m *Store) APIKeyUpdate(ctx context.Context, tenantID string, name string, changes *models.APIKeyChanges) error { ret := _m.Called(ctx, tenantID, name, changes) + if len(ret) == 0 { + panic("no return value specified for APIKeyUpdate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, *models.APIKeyChanges) error); ok { r0 = rf(ctx, tenantID, name, changes) @@ -194,6 +222,10 @@ func (_m *Store) APIKeyUpdate(ctx context.Context, tenantID string, name string, func (_m *Store) DeviceBulkDeleteTag(ctx context.Context, tenant string, tag string) (int64, error) { ret := _m.Called(ctx, tenant, tag) + if len(ret) == 0 { + panic("no return value specified for DeviceBulkDeleteTag") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok { @@ -218,6 +250,10 @@ func (_m *Store) DeviceBulkDeleteTag(ctx context.Context, tenant string, tag str func (_m *Store) DeviceBulkRenameTag(ctx context.Context, tenant string, currentTag string, newTag string) (int64, error) { ret := _m.Called(ctx, tenant, currentTag, newTag) + if len(ret) == 0 { + panic("no return value specified for DeviceBulkRenameTag") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (int64, error)); ok { @@ -242,6 +278,10 @@ func (_m *Store) DeviceBulkRenameTag(ctx context.Context, tenant string, current func (_m *Store) DeviceChooser(ctx context.Context, tenantID string, chosen []string) error { ret := _m.Called(ctx, tenantID, chosen) + if len(ret) == 0 { + panic("no return value specified for DeviceChooser") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, []string) error); ok { r0 = rf(ctx, tenantID, chosen) @@ -256,6 +296,10 @@ func (_m *Store) DeviceChooser(ctx context.Context, tenantID string, chosen []st func (_m *Store) DeviceCreate(ctx context.Context, d models.Device, hostname string) error { ret := _m.Called(ctx, d, hostname) + if len(ret) == 0 { + panic("no return value specified for DeviceCreate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.Device, string) error); ok { r0 = rf(ctx, d, hostname) @@ -270,6 +314,10 @@ func (_m *Store) DeviceCreate(ctx context.Context, d models.Device, hostname str func (_m *Store) DeviceCreatePublicURLAddress(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceCreatePublicURLAddress") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -284,6 +332,10 @@ func (_m *Store) DeviceCreatePublicURLAddress(ctx context.Context, uid models.UI func (_m *Store) DeviceDelete(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -298,6 +350,10 @@ func (_m *Store) DeviceDelete(ctx context.Context, uid models.UID) error { func (_m *Store) DeviceGet(ctx context.Context, uid models.UID) (*models.Device, error) { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceGet") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Device, error)); ok { @@ -324,6 +380,10 @@ func (_m *Store) DeviceGet(ctx context.Context, uid models.UID) (*models.Device, func (_m *Store) DeviceGetByMac(ctx context.Context, mac string, tenantID string, status models.DeviceStatus) (*models.Device, error) { ret := _m.Called(ctx, mac, tenantID, status) + if len(ret) == 0 { + panic("no return value specified for DeviceGetByMac") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, models.DeviceStatus) (*models.Device, error)); ok { @@ -350,6 +410,10 @@ func (_m *Store) DeviceGetByMac(ctx context.Context, mac string, tenantID string func (_m *Store) DeviceGetByName(ctx context.Context, name string, tenantID string, status models.DeviceStatus) (*models.Device, error) { ret := _m.Called(ctx, name, tenantID, status) + if len(ret) == 0 { + panic("no return value specified for DeviceGetByName") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, models.DeviceStatus) (*models.Device, error)); ok { @@ -376,6 +440,10 @@ func (_m *Store) DeviceGetByName(ctx context.Context, name string, tenantID stri func (_m *Store) DeviceGetByPublicURLAddress(ctx context.Context, address string) (*models.Device, error) { ret := _m.Called(ctx, address) + if len(ret) == 0 { + panic("no return value specified for DeviceGetByPublicURLAddress") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Device, error)); ok { @@ -402,6 +470,10 @@ func (_m *Store) DeviceGetByPublicURLAddress(ctx context.Context, address string func (_m *Store) DeviceGetByUID(ctx context.Context, uid models.UID, tenantID string) (*models.Device, error) { ret := _m.Called(ctx, uid, tenantID) + if len(ret) == 0 { + panic("no return value specified for DeviceGetByUID") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) (*models.Device, error)); ok { @@ -428,6 +500,10 @@ func (_m *Store) DeviceGetByUID(ctx context.Context, uid models.UID, tenantID st func (_m *Store) DeviceGetTags(ctx context.Context, tenant string) ([]string, int, error) { ret := _m.Called(ctx, tenant) + if len(ret) == 0 { + panic("no return value specified for DeviceGetTags") + } + var r0 []string var r1 int var r2 error @@ -461,6 +537,10 @@ func (_m *Store) DeviceGetTags(ctx context.Context, tenant string) ([]string, in func (_m *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pagination query.Paginator, filters query.Filters, sorter query.Sorter, acceptable store.DeviceAcceptable) ([]models.Device, int, error) { ret := _m.Called(ctx, status, pagination, filters, sorter, acceptable) + if len(ret) == 0 { + panic("no return value specified for DeviceList") + } + var r0 []models.Device var r1 int var r2 error @@ -494,6 +574,10 @@ func (_m *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pag func (_m *Store) DeviceListByUsage(ctx context.Context, tenantID string) ([]models.UID, error) { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for DeviceListByUsage") + } + var r0 []models.UID var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) ([]models.UID, error)); ok { @@ -520,6 +604,10 @@ func (_m *Store) DeviceListByUsage(ctx context.Context, tenantID string) ([]mode func (_m *Store) DeviceLookup(ctx context.Context, namespace string, hostname string) (*models.Device, error) { ret := _m.Called(ctx, namespace, hostname) + if len(ret) == 0 { + panic("no return value specified for DeviceLookup") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Device, error)); ok { @@ -546,6 +634,10 @@ func (_m *Store) DeviceLookup(ctx context.Context, namespace string, hostname st func (_m *Store) DevicePullTag(ctx context.Context, uid models.UID, tag string) error { ret := _m.Called(ctx, uid, tag) + if len(ret) == 0 { + panic("no return value specified for DevicePullTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, tag) @@ -560,6 +652,10 @@ func (_m *Store) DevicePullTag(ctx context.Context, uid models.UID, tag string) func (_m *Store) DevicePushTag(ctx context.Context, uid models.UID, tag string) error { ret := _m.Called(ctx, uid, tag) + if len(ret) == 0 { + panic("no return value specified for DevicePushTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, tag) @@ -574,6 +670,10 @@ func (_m *Store) DevicePushTag(ctx context.Context, uid models.UID, tag string) func (_m *Store) DeviceRemovedCount(ctx context.Context, tenant string) (int64, error) { ret := _m.Called(ctx, tenant) + if len(ret) == 0 { + panic("no return value specified for DeviceRemovedCount") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { @@ -598,6 +698,10 @@ func (_m *Store) DeviceRemovedCount(ctx context.Context, tenant string) (int64, func (_m *Store) DeviceRemovedDelete(ctx context.Context, tenant string, uid models.UID) error { ret := _m.Called(ctx, tenant, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceRemovedDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, models.UID) error); ok { r0 = rf(ctx, tenant, uid) @@ -612,6 +716,10 @@ func (_m *Store) DeviceRemovedDelete(ctx context.Context, tenant string, uid mod func (_m *Store) DeviceRemovedGet(ctx context.Context, tenant string, uid models.UID) (*models.DeviceRemoved, error) { ret := _m.Called(ctx, tenant, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceRemovedGet") + } + var r0 *models.DeviceRemoved var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, models.UID) (*models.DeviceRemoved, error)); ok { @@ -638,6 +746,10 @@ func (_m *Store) DeviceRemovedGet(ctx context.Context, tenant string, uid models func (_m *Store) DeviceRemovedInsert(ctx context.Context, tenant string, device *models.Device) error { ret := _m.Called(ctx, tenant, device) + if len(ret) == 0 { + panic("no return value specified for DeviceRemovedInsert") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *models.Device) error); ok { r0 = rf(ctx, tenant, device) @@ -652,6 +764,10 @@ func (_m *Store) DeviceRemovedInsert(ctx context.Context, tenant string, device func (_m *Store) DeviceRemovedList(ctx context.Context, tenant string, pagination query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.DeviceRemoved, int, error) { ret := _m.Called(ctx, tenant, pagination, filters, sorter) + if len(ret) == 0 { + panic("no return value specified for DeviceRemovedList") + } + var r0 []models.DeviceRemoved var r1 int var r2 error @@ -685,6 +801,10 @@ func (_m *Store) DeviceRemovedList(ctx context.Context, tenant string, paginatio func (_m *Store) DeviceRename(ctx context.Context, uid models.UID, hostname string) error { ret := _m.Called(ctx, uid, hostname) + if len(ret) == 0 { + panic("no return value specified for DeviceRename") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) error); ok { r0 = rf(ctx, uid, hostname) @@ -699,6 +819,10 @@ func (_m *Store) DeviceRename(ctx context.Context, uid models.UID, hostname stri func (_m *Store) DeviceSetOffline(ctx context.Context, uid string) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for DeviceSetOffline") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, uid) @@ -713,6 +837,10 @@ func (_m *Store) DeviceSetOffline(ctx context.Context, uid string) error { func (_m *Store) DeviceSetOnline(ctx context.Context, connectedDevices []models.ConnectedDevice) error { ret := _m.Called(ctx, connectedDevices) + if len(ret) == 0 { + panic("no return value specified for DeviceSetOnline") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, []models.ConnectedDevice) error); ok { r0 = rf(ctx, connectedDevices) @@ -727,6 +855,10 @@ func (_m *Store) DeviceSetOnline(ctx context.Context, connectedDevices []models. func (_m *Store) DeviceSetPosition(ctx context.Context, uid models.UID, position models.DevicePosition) error { ret := _m.Called(ctx, uid, position) + if len(ret) == 0 { + panic("no return value specified for DeviceSetPosition") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, models.DevicePosition) error); ok { r0 = rf(ctx, uid, position) @@ -741,6 +873,10 @@ func (_m *Store) DeviceSetPosition(ctx context.Context, uid models.UID, position func (_m *Store) DeviceSetTags(ctx context.Context, uid models.UID, tags []string) (int64, int64, error) { ret := _m.Called(ctx, uid, tags) + if len(ret) == 0 { + panic("no return value specified for DeviceSetTags") + } + var r0 int64 var r1 int64 var r2 error @@ -772,6 +908,10 @@ func (_m *Store) DeviceSetTags(ctx context.Context, uid models.UID, tags []strin func (_m *Store) DeviceUpdate(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error { ret := _m.Called(ctx, tenant, uid, name, publicURL) + if len(ret) == 0 { + panic("no return value specified for DeviceUpdate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, models.UID, *string, *bool) error); ok { r0 = rf(ctx, tenant, uid, name, publicURL) @@ -786,6 +926,10 @@ func (_m *Store) DeviceUpdate(ctx context.Context, tenant string, uid models.UID func (_m *Store) DeviceUpdateLastSeen(ctx context.Context, uid models.UID, ts time.Time) error { ret := _m.Called(ctx, uid, ts) + if len(ret) == 0 { + panic("no return value specified for DeviceUpdateLastSeen") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, time.Time) error); ok { r0 = rf(ctx, uid, ts) @@ -800,6 +944,10 @@ func (_m *Store) DeviceUpdateLastSeen(ctx context.Context, uid models.UID, ts ti func (_m *Store) DeviceUpdateOnline(ctx context.Context, uid models.UID, online bool) error { ret := _m.Called(ctx, uid, online) + if len(ret) == 0 { + panic("no return value specified for DeviceUpdateOnline") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, bool) error); ok { r0 = rf(ctx, uid, online) @@ -814,6 +962,10 @@ func (_m *Store) DeviceUpdateOnline(ctx context.Context, uid models.UID, online func (_m *Store) DeviceUpdateStatus(ctx context.Context, uid models.UID, status models.DeviceStatus) error { ret := _m.Called(ctx, uid, status) + if len(ret) == 0 { + panic("no return value specified for DeviceUpdateStatus") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, models.DeviceStatus) error); ok { r0 = rf(ctx, uid, status) @@ -828,6 +980,10 @@ func (_m *Store) DeviceUpdateStatus(ctx context.Context, uid models.UID, status func (_m *Store) GetStats(ctx context.Context) (*models.Stats, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for GetStats") + } + var r0 *models.Stats var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*models.Stats, error)); ok { @@ -854,6 +1010,10 @@ func (_m *Store) GetStats(ctx context.Context) (*models.Stats, error) { func (_m *Store) NamespaceAddMember(ctx context.Context, tenantID string, member *models.Member) error { ret := _m.Called(ctx, tenantID, member) + if len(ret) == 0 { + panic("no return value specified for NamespaceAddMember") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *models.Member) error); ok { r0 = rf(ctx, tenantID, member) @@ -868,6 +1028,10 @@ func (_m *Store) NamespaceAddMember(ctx context.Context, tenantID string, member func (_m *Store) NamespaceCreate(ctx context.Context, namespace *models.Namespace) (*models.Namespace, error) { ret := _m.Called(ctx, namespace) + if len(ret) == 0 { + panic("no return value specified for NamespaceCreate") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, *models.Namespace) (*models.Namespace, error)); ok { @@ -894,6 +1058,10 @@ func (_m *Store) NamespaceCreate(ctx context.Context, namespace *models.Namespac func (_m *Store) NamespaceDelete(ctx context.Context, tenantID string) error { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for NamespaceDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, tenantID) @@ -908,6 +1076,10 @@ func (_m *Store) NamespaceDelete(ctx context.Context, tenantID string) error { func (_m *Store) NamespaceEdit(ctx context.Context, tenant string, changes *models.NamespaceChanges) error { ret := _m.Called(ctx, tenant, changes) + if len(ret) == 0 { + panic("no return value specified for NamespaceEdit") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *models.NamespaceChanges) error); ok { r0 = rf(ctx, tenant, changes) @@ -929,6 +1101,10 @@ func (_m *Store) NamespaceGet(ctx context.Context, tenantID string, opts ...stor _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for NamespaceGet") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, ...store.NamespaceQueryOption) (*models.Namespace, error)); ok { @@ -962,6 +1138,10 @@ func (_m *Store) NamespaceGetByName(ctx context.Context, name string, opts ...st _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for NamespaceGetByName") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, ...store.NamespaceQueryOption) (*models.Namespace, error)); ok { @@ -995,6 +1175,10 @@ func (_m *Store) NamespaceGetPreferred(ctx context.Context, userID string, opts _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for NamespaceGetPreferred") + } + var r0 *models.Namespace var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, ...store.NamespaceQueryOption) (*models.Namespace, error)); ok { @@ -1021,6 +1205,10 @@ func (_m *Store) NamespaceGetPreferred(ctx context.Context, userID string, opts func (_m *Store) NamespaceGetSessionRecord(ctx context.Context, tenantID string) (bool, error) { ret := _m.Called(ctx, tenantID) + if len(ret) == 0 { + panic("no return value specified for NamespaceGetSessionRecord") + } + var r0 bool var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { @@ -1052,6 +1240,10 @@ func (_m *Store) NamespaceList(ctx context.Context, paginator query.Paginator, f _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for NamespaceList") + } + var r0 []models.Namespace var r1 int var r2 error @@ -1085,6 +1277,10 @@ func (_m *Store) NamespaceList(ctx context.Context, paginator query.Paginator, f func (_m *Store) NamespaceRemoveMember(ctx context.Context, tenantID string, memberID string) error { ret := _m.Called(ctx, tenantID, memberID) + if len(ret) == 0 { + panic("no return value specified for NamespaceRemoveMember") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, tenantID, memberID) @@ -1099,6 +1295,10 @@ func (_m *Store) NamespaceRemoveMember(ctx context.Context, tenantID string, mem func (_m *Store) NamespaceSetSessionRecord(ctx context.Context, sessionRecord bool, tenantID string) error { ret := _m.Called(ctx, sessionRecord, tenantID) + if len(ret) == 0 { + panic("no return value specified for NamespaceSetSessionRecord") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, bool, string) error); ok { r0 = rf(ctx, sessionRecord, tenantID) @@ -1113,6 +1313,10 @@ func (_m *Store) NamespaceSetSessionRecord(ctx context.Context, sessionRecord bo func (_m *Store) NamespaceUpdate(ctx context.Context, tenantID string, namespace *models.Namespace) error { ret := _m.Called(ctx, tenantID, namespace) + if len(ret) == 0 { + panic("no return value specified for NamespaceUpdate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *models.Namespace) error); ok { r0 = rf(ctx, tenantID, namespace) @@ -1127,6 +1331,10 @@ func (_m *Store) NamespaceUpdate(ctx context.Context, tenantID string, namespace func (_m *Store) NamespaceUpdateMember(ctx context.Context, tenantID string, memberID string, changes *models.MemberChanges) error { ret := _m.Called(ctx, tenantID, memberID, changes) + if len(ret) == 0 { + panic("no return value specified for NamespaceUpdateMember") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, *models.MemberChanges) error); ok { r0 = rf(ctx, tenantID, memberID, changes) @@ -1141,6 +1349,10 @@ func (_m *Store) NamespaceUpdateMember(ctx context.Context, tenantID string, mem func (_m *Store) Options() store.QueryOptions { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Options") + } + var r0 store.QueryOptions if rf, ok := ret.Get(0).(func() store.QueryOptions); ok { r0 = rf() @@ -1157,6 +1369,10 @@ func (_m *Store) Options() store.QueryOptions { func (_m *Store) PrivateKeyCreate(ctx context.Context, key *models.PrivateKey) error { ret := _m.Called(ctx, key) + if len(ret) == 0 { + panic("no return value specified for PrivateKeyCreate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *models.PrivateKey) error); ok { r0 = rf(ctx, key) @@ -1171,6 +1387,10 @@ func (_m *Store) PrivateKeyCreate(ctx context.Context, key *models.PrivateKey) e func (_m *Store) PrivateKeyGet(ctx context.Context, fingerprint string) (*models.PrivateKey, error) { ret := _m.Called(ctx, fingerprint) + if len(ret) == 0 { + panic("no return value specified for PrivateKeyGet") + } + var r0 *models.PrivateKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.PrivateKey, error)); ok { @@ -1197,6 +1417,10 @@ func (_m *Store) PrivateKeyGet(ctx context.Context, fingerprint string) (*models func (_m *Store) PublicKeyBulkDeleteTag(ctx context.Context, tenant string, tag string) (int64, error) { ret := _m.Called(ctx, tenant, tag) + if len(ret) == 0 { + panic("no return value specified for PublicKeyBulkDeleteTag") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok { @@ -1221,6 +1445,10 @@ func (_m *Store) PublicKeyBulkDeleteTag(ctx context.Context, tenant string, tag func (_m *Store) PublicKeyBulkRenameTag(ctx context.Context, tenant string, currentTag string, newTag string) (int64, error) { ret := _m.Called(ctx, tenant, currentTag, newTag) + if len(ret) == 0 { + panic("no return value specified for PublicKeyBulkRenameTag") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (int64, error)); ok { @@ -1245,6 +1473,10 @@ func (_m *Store) PublicKeyBulkRenameTag(ctx context.Context, tenant string, curr func (_m *Store) PublicKeyCreate(ctx context.Context, key *models.PublicKey) error { ret := _m.Called(ctx, key) + if len(ret) == 0 { + panic("no return value specified for PublicKeyCreate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, *models.PublicKey) error); ok { r0 = rf(ctx, key) @@ -1259,6 +1491,10 @@ func (_m *Store) PublicKeyCreate(ctx context.Context, key *models.PublicKey) err func (_m *Store) PublicKeyDelete(ctx context.Context, fingerprint string, tenantID string) error { ret := _m.Called(ctx, fingerprint, tenantID) + if len(ret) == 0 { + panic("no return value specified for PublicKeyDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, fingerprint, tenantID) @@ -1273,6 +1509,10 @@ func (_m *Store) PublicKeyDelete(ctx context.Context, fingerprint string, tenant func (_m *Store) PublicKeyGet(ctx context.Context, fingerprint string, tenantID string) (*models.PublicKey, error) { ret := _m.Called(ctx, fingerprint, tenantID) + if len(ret) == 0 { + panic("no return value specified for PublicKeyGet") + } + var r0 *models.PublicKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.PublicKey, error)); ok { @@ -1299,6 +1539,10 @@ func (_m *Store) PublicKeyGet(ctx context.Context, fingerprint string, tenantID func (_m *Store) PublicKeyGetTags(ctx context.Context, tenant string) ([]string, int, error) { ret := _m.Called(ctx, tenant) + if len(ret) == 0 { + panic("no return value specified for PublicKeyGetTags") + } + var r0 []string var r1 int var r2 error @@ -1332,6 +1576,10 @@ func (_m *Store) PublicKeyGetTags(ctx context.Context, tenant string) ([]string, func (_m *Store) PublicKeyList(ctx context.Context, paginator query.Paginator) ([]models.PublicKey, int, error) { ret := _m.Called(ctx, paginator) + if len(ret) == 0 { + panic("no return value specified for PublicKeyList") + } + var r0 []models.PublicKey var r1 int var r2 error @@ -1365,6 +1613,10 @@ func (_m *Store) PublicKeyList(ctx context.Context, paginator query.Paginator) ( func (_m *Store) PublicKeyPullTag(ctx context.Context, tenant string, fingerprint string, tag string) error { ret := _m.Called(ctx, tenant, fingerprint, tag) + if len(ret) == 0 { + panic("no return value specified for PublicKeyPullTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, fingerprint, tag) @@ -1379,6 +1631,10 @@ func (_m *Store) PublicKeyPullTag(ctx context.Context, tenant string, fingerprin func (_m *Store) PublicKeyPushTag(ctx context.Context, tenant string, fingerprint string, tag string) error { ret := _m.Called(ctx, tenant, fingerprint, tag) + if len(ret) == 0 { + panic("no return value specified for PublicKeyPushTag") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { r0 = rf(ctx, tenant, fingerprint, tag) @@ -1393,6 +1649,10 @@ func (_m *Store) PublicKeyPushTag(ctx context.Context, tenant string, fingerprin func (_m *Store) PublicKeySetTags(ctx context.Context, tenant string, fingerprint string, tags []string) (int64, int64, error) { ret := _m.Called(ctx, tenant, fingerprint, tags) + if len(ret) == 0 { + panic("no return value specified for PublicKeySetTags") + } + var r0 int64 var r1 int64 var r2 error @@ -1424,6 +1684,10 @@ func (_m *Store) PublicKeySetTags(ctx context.Context, tenant string, fingerprin func (_m *Store) PublicKeyUpdate(ctx context.Context, fingerprint string, tenantID string, key *models.PublicKeyUpdate) (*models.PublicKey, error) { ret := _m.Called(ctx, fingerprint, tenantID, key) + if len(ret) == 0 { + panic("no return value specified for PublicKeyUpdate") + } + var r0 *models.PublicKey var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, *models.PublicKeyUpdate) (*models.PublicKey, error)); ok { @@ -1450,6 +1714,10 @@ func (_m *Store) PublicKeyUpdate(ctx context.Context, fingerprint string, tenant func (_m *Store) SessionActiveCreate(ctx context.Context, uid models.UID, session *models.Session) error { ret := _m.Called(ctx, uid, session) + if len(ret) == 0 { + panic("no return value specified for SessionActiveCreate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.Session) error); ok { r0 = rf(ctx, uid, session) @@ -1464,6 +1732,10 @@ func (_m *Store) SessionActiveCreate(ctx context.Context, uid models.UID, sessio func (_m *Store) SessionCreate(ctx context.Context, session models.Session) (*models.Session, error) { ret := _m.Called(ctx, session) + if len(ret) == 0 { + panic("no return value specified for SessionCreate") + } + var r0 *models.Session var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.Session) (*models.Session, error)); ok { @@ -1490,6 +1762,10 @@ func (_m *Store) SessionCreate(ctx context.Context, session models.Session) (*mo func (_m *Store) SessionDeleteActives(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for SessionDeleteActives") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -1504,6 +1780,10 @@ func (_m *Store) SessionDeleteActives(ctx context.Context, uid models.UID) error func (_m *Store) SessionEvent(ctx context.Context, uid models.UID, event *models.SessionEvent) error { ret := _m.Called(ctx, uid, event) + if len(ret) == 0 { + panic("no return value specified for SessionEvent") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.SessionEvent) error); ok { r0 = rf(ctx, uid, event) @@ -1518,6 +1798,10 @@ func (_m *Store) SessionEvent(ctx context.Context, uid models.UID, event *models func (_m *Store) SessionGet(ctx context.Context, uid models.UID) (*models.Session, error) { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for SessionGet") + } + var r0 *models.Session var r1 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Session, error)); ok { @@ -1544,6 +1828,10 @@ func (_m *Store) SessionGet(ctx context.Context, uid models.UID) (*models.Sessio func (_m *Store) SessionList(ctx context.Context, paginator query.Paginator) ([]models.Session, int, error) { ret := _m.Called(ctx, paginator) + if len(ret) == 0 { + panic("no return value specified for SessionList") + } + var r0 []models.Session var r1 int var r2 error @@ -1577,6 +1865,10 @@ func (_m *Store) SessionList(ctx context.Context, paginator query.Paginator) ([] func (_m *Store) SessionSetLastSeen(ctx context.Context, uid models.UID) error { ret := _m.Called(ctx, uid) + if len(ret) == 0 { + panic("no return value specified for SessionSetLastSeen") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID) error); ok { r0 = rf(ctx, uid) @@ -1591,6 +1883,10 @@ func (_m *Store) SessionSetLastSeen(ctx context.Context, uid models.UID) error { func (_m *Store) SessionSetRecorded(ctx context.Context, uid models.UID, recorded bool) error { ret := _m.Called(ctx, uid, recorded) + if len(ret) == 0 { + panic("no return value specified for SessionSetRecorded") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, bool) error); ok { r0 = rf(ctx, uid, recorded) @@ -1605,6 +1901,10 @@ func (_m *Store) SessionSetRecorded(ctx context.Context, uid models.UID, recorde func (_m *Store) SessionUpdate(ctx context.Context, uid models.UID, model *models.Session) error { ret := _m.Called(ctx, uid, model) + if len(ret) == 0 { + panic("no return value specified for SessionUpdate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.Session) error); ok { r0 = rf(ctx, uid, model) @@ -1619,6 +1919,10 @@ func (_m *Store) SessionUpdate(ctx context.Context, uid models.UID, model *model func (_m *Store) SessionUpdateDeviceUID(ctx context.Context, oldUID models.UID, newUID models.UID) error { ret := _m.Called(ctx, oldUID, newUID) + if len(ret) == 0 { + panic("no return value specified for SessionUpdateDeviceUID") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, models.UID, models.UID) error); ok { r0 = rf(ctx, oldUID, newUID) @@ -1633,6 +1937,10 @@ func (_m *Store) SessionUpdateDeviceUID(ctx context.Context, oldUID models.UID, func (_m *Store) SystemGet(ctx context.Context) (*models.System, error) { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for SystemGet") + } + var r0 *models.System var r1 error if rf, ok := ret.Get(0).(func(context.Context) (*models.System, error)); ok { @@ -1656,11 +1964,15 @@ func (_m *Store) SystemGet(ctx context.Context) (*models.System, error) { } // SystemSet provides a mock function with given fields: ctx, key, value -func (_m *Store) SystemSet(ctx context.Context, key string, value interface{}) error { +func (_m *Store) SystemSet(ctx context.Context, key string, value any) error { ret := _m.Called(ctx, key, value) + if len(ret) == 0 { + panic("no return value specified for SystemSet") + } + var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, any) error); ok { r0 = rf(ctx, key, value) } else { r0 = ret.Error(0) @@ -1669,10 +1981,72 @@ func (_m *Store) SystemSet(ctx context.Context, key string, value interface{}) e return r0 } +// TagGet provides a mock function with given fields: ctx, tagName, tenant +func (_m *Store) TagGet(ctx context.Context, tagName string, tenant string) (*models.Tags, error) { + ret := _m.Called(ctx, tagName, tenant) + + if len(ret) == 0 { + panic("no return value specified for TagGet") + } + + var r0 *models.Tags + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Tags, error)); ok { + return rf(ctx, tagName, tenant) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.Tags); ok { + r0 = rf(ctx, tagName, tenant) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Tags) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, tagName, tenant) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TagsBulkDeleteTag provides a mock function with given fields: ctx, tenant, tagName +func (_m *Store) TagsBulkDeleteTag(ctx context.Context, tenant string, tagName string) (int64, error) { + ret := _m.Called(ctx, tenant, tagName) + + if len(ret) == 0 { + panic("no return value specified for TagsBulkDeleteTag") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok { + return rf(ctx, tenant, tagName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) int64); ok { + r0 = rf(ctx, tenant, tagName) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, tenant, tagName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // TagsDelete provides a mock function with given fields: ctx, tenant, tag func (_m *Store) TagsDelete(ctx context.Context, tenant string, tag string) (int64, error) { ret := _m.Called(ctx, tenant, tag) + if len(ret) == 0 { + panic("no return value specified for TagsDelete") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string) (int64, error)); ok { @@ -1694,27 +2068,31 @@ func (_m *Store) TagsDelete(ctx context.Context, tenant string, tag string) (int } // TagsGet provides a mock function with given fields: ctx, tenant -func (_m *Store) TagsGet(ctx context.Context, tenant string) ([]string, int, error) { +func (_m *Store) TagsGet(ctx context.Context, tenant string) ([]models.Tags, int64, error) { ret := _m.Called(ctx, tenant) - var r0 []string - var r1 int + if len(ret) == 0 { + panic("no return value specified for TagsGet") + } + + var r0 []models.Tags + var r1 int64 var r2 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, int, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) ([]models.Tags, int64, error)); ok { return rf(ctx, tenant) } - if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) []models.Tags); ok { r0 = rf(ctx, tenant) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) + r0 = ret.Get(0).([]models.Tags) } } - if rf, ok := ret.Get(1).(func(context.Context, string) int); ok { + if rf, ok := ret.Get(1).(func(context.Context, string) int64); ok { r1 = rf(ctx, tenant) } else { - r1 = ret.Get(1).(int) + r1 = ret.Get(1).(int64) } if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { @@ -1726,10 +2104,69 @@ func (_m *Store) TagsGet(ctx context.Context, tenant string) ([]string, int, err return r0, r1, r2 } +// TagsGetTags provides a mock function with given fields: ctx, tenant +func (_m *Store) TagsGetTags(ctx context.Context, tenant string) ([]models.Tags, int64, error) { + ret := _m.Called(ctx, tenant) + + if len(ret) == 0 { + panic("no return value specified for TagsGetTags") + } + + var r0 []models.Tags + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]models.Tags, int64, error)); ok { + return rf(ctx, tenant) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []models.Tags); ok { + r0 = rf(ctx, tenant) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.Tags) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) int64); ok { + r1 = rf(ctx, tenant) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, tenant) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// TagsPushTag provides a mock function with given fields: ctx, tagName, tenantID +func (_m *Store) TagsPushTag(ctx context.Context, tagName string, tenantID string) error { + ret := _m.Called(ctx, tagName, tenantID) + + if len(ret) == 0 { + panic("no return value specified for TagsPushTag") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, tagName, tenantID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // TagsRename provides a mock function with given fields: ctx, tenant, oldTag, newTag func (_m *Store) TagsRename(ctx context.Context, tenant string, oldTag string, newTag string) (int64, error) { ret := _m.Called(ctx, tenant, oldTag, newTag) + if len(ret) == 0 { + panic("no return value specified for TagsRename") + } + var r0 int64 var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (int64, error)); ok { @@ -1754,6 +2191,10 @@ func (_m *Store) TagsRename(ctx context.Context, tenant string, oldTag string, n func (_m *Store) UserConflicts(ctx context.Context, target *models.UserConflicts) ([]string, bool, error) { ret := _m.Called(ctx, target) + if len(ret) == 0 { + panic("no return value specified for UserConflicts") + } + var r0 []string var r1 bool var r2 error @@ -1787,6 +2228,10 @@ func (_m *Store) UserConflicts(ctx context.Context, target *models.UserConflicts func (_m *Store) UserCreate(ctx context.Context, user *models.User) (string, error) { ret := _m.Called(ctx, user) + if len(ret) == 0 { + panic("no return value specified for UserCreate") + } + var r0 string var r1 error if rf, ok := ret.Get(0).(func(context.Context, *models.User) (string, error)); ok { @@ -1811,6 +2256,10 @@ func (_m *Store) UserCreate(ctx context.Context, user *models.User) (string, err func (_m *Store) UserCreateInvited(ctx context.Context, email string) (string, error) { ret := _m.Called(ctx, email) + if len(ret) == 0 { + panic("no return value specified for UserCreateInvited") + } + var r0 string var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { @@ -1835,6 +2284,10 @@ func (_m *Store) UserCreateInvited(ctx context.Context, email string) (string, e func (_m *Store) UserDelete(ctx context.Context, id string) error { ret := _m.Called(ctx, id) + if len(ret) == 0 { + panic("no return value specified for UserDelete") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, id) @@ -1849,6 +2302,10 @@ func (_m *Store) UserDelete(ctx context.Context, id string) error { func (_m *Store) UserGetByEmail(ctx context.Context, email string) (*models.User, error) { ret := _m.Called(ctx, email) + if len(ret) == 0 { + panic("no return value specified for UserGetByEmail") + } + var r0 *models.User var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.User, error)); ok { @@ -1875,6 +2332,10 @@ func (_m *Store) UserGetByEmail(ctx context.Context, email string) (*models.User func (_m *Store) UserGetByID(ctx context.Context, id string, ns bool) (*models.User, int, error) { ret := _m.Called(ctx, id, ns) + if len(ret) == 0 { + panic("no return value specified for UserGetByID") + } + var r0 *models.User var r1 int var r2 error @@ -1908,6 +2369,10 @@ func (_m *Store) UserGetByID(ctx context.Context, id string, ns bool) (*models.U func (_m *Store) UserGetByUsername(ctx context.Context, username string) (*models.User, error) { ret := _m.Called(ctx, username) + if len(ret) == 0 { + panic("no return value specified for UserGetByUsername") + } + var r0 *models.User var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.User, error)); ok { @@ -1934,6 +2399,10 @@ func (_m *Store) UserGetByUsername(ctx context.Context, username string) (*model func (_m *Store) UserGetInfo(ctx context.Context, id string) (*models.UserInfo, error) { ret := _m.Called(ctx, id) + if len(ret) == 0 { + panic("no return value specified for UserGetInfo") + } + var r0 *models.UserInfo var r1 error if rf, ok := ret.Get(0).(func(context.Context, string) (*models.UserInfo, error)); ok { @@ -1960,6 +2429,10 @@ func (_m *Store) UserGetInfo(ctx context.Context, id string) (*models.UserInfo, func (_m *Store) UserList(ctx context.Context, paginator query.Paginator, filters query.Filters) ([]models.User, int, error) { ret := _m.Called(ctx, paginator, filters) + if len(ret) == 0 { + panic("no return value specified for UserList") + } + var r0 []models.User var r1 int var r2 error @@ -1993,6 +2466,10 @@ func (_m *Store) UserList(ctx context.Context, paginator query.Paginator, filter func (_m *Store) UserUpdate(ctx context.Context, id string, changes *models.UserChanges) error { ret := _m.Called(ctx, id, changes) + if len(ret) == 0 { + panic("no return value specified for UserUpdate") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *models.UserChanges) error); ok { r0 = rf(ctx, id, changes) @@ -2007,6 +2484,10 @@ func (_m *Store) UserUpdate(ctx context.Context, id string, changes *models.User func (_m *Store) WithTransaction(ctx context.Context, cb store.TransactionCb) error { ret := _m.Called(ctx, cb) + if len(ret) == 0 { + panic("no return value specified for WithTransaction") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, store.TransactionCb) error); ok { r0 = rf(ctx, cb) @@ -2017,13 +2498,12 @@ func (_m *Store) WithTransaction(ctx context.Context, cb store.TransactionCb) er return r0 } -type mockConstructorTestingTNewStore interface { +// NewStore creates a new instance of Store. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewStore(t interface { mock.TestingT Cleanup(func()) -} - -// NewStore creates a new instance of Store. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewStore(t mockConstructorTestingTNewStore) *Store { +}) *Store { mock := &Store{} mock.Mock.Test(t) diff --git a/api/store/mongo/device_tags.go b/api/store/mongo/device_tags.go index 83ad2763c2e..c3b13b75e9a 100644 --- a/api/store/mongo/device_tags.go +++ b/api/store/mongo/device_tags.go @@ -2,23 +2,58 @@ package mongo import ( "context" + "errors" "github.com/shellhub-io/shellhub/api/store" "github.com/shellhub-io/shellhub/pkg/models" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" ) func (s *Store) DevicePushTag(ctx context.Context, uid models.UID, tag string) error { - t, err := s.db.Collection("devices").UpdateOne(ctx, bson.M{"uid": uid}, bson.M{"$push": bson.M{"tags": tag}}) + session, err := s.db.Client().StartSession() if err != nil { return FromMongoError(err) } - - if t.ModifiedCount < 1 { - return store.ErrNoDocuments - } - - return nil + defer session.EndSession(ctx) + + _, erro := session.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) { + device := new(models.Device) + + err := s.db.Collection("devices", options.Collection().SetWriteConcern(writeconcern.Majority())). + FindOne(sessCtx, bson.M{"uid": uid}). + Decode(device) + if err != nil { + return nil, FromMongoError(err) + } + + if _, err := s.TagGet(sessCtx, tag, device.TenantID); err != nil { + if errors.Is(err, store.ErrNoDocuments) { + err := s.TagsPushTag(sessCtx, tag, device.TenantID) + if err != nil { + return nil, FromMongoError(err) + } + } else if err != nil { + return nil, err + } + } + + t, err := s.db.Collection("devices", options.Collection().SetWriteConcern(writeconcern.Majority())). + UpdateOne(sessCtx, bson.M{"uid": uid}, bson.M{"$push": bson.M{"tags": tag}}) + if err != nil { + return nil, FromMongoError(err) + } + + if t.ModifiedCount < 1 { + return nil, store.ErrNoDocuments + } + + return nil, nil + }) + + return erro } func (s *Store) DevicePullTag(ctx context.Context, uid models.UID, tag string) error { @@ -35,19 +70,36 @@ func (s *Store) DevicePullTag(ctx context.Context, uid models.UID, tag string) e } func (s *Store) DeviceSetTags(ctx context.Context, uid models.UID, tags []string) (int64, int64, error) { - tag, err := s.db.Collection("devices").UpdateOne(ctx, bson.M{"uid": uid}, bson.M{"$set": bson.M{"tags": tags}}) + tag, err := s.db.Collection("devices"). + UpdateOne(ctx, bson.M{"uid": uid}, bson.M{"$set": bson.M{"tags": tags}}) + + if tag.MatchedCount < 1 { + return tag.MatchedCount, tag.ModifiedCount, store.ErrNoDocuments + } return tag.MatchedCount, tag.ModifiedCount, FromMongoError(err) } func (s *Store) DeviceBulkRenameTag(ctx context.Context, tenant, currentTag, newTag string) (int64, error) { - res, err := s.db.Collection("devices").UpdateMany(ctx, bson.M{"tenant_id": tenant, "tags": currentTag}, bson.M{"$set": bson.M{"tags.$": newTag}}) + res, err := s.db.Collection("devices"). + UpdateMany(ctx, bson.M{"tenant_id": tenant, "tags": currentTag}, bson.M{"$set": bson.M{"tags.$": newTag}}) return res.ModifiedCount, FromMongoError(err) } +func (s *Store) TagsBulkRenameTag(ctx context.Context, tenant, currentTag, newTag string) (int64, error) { + res, err := s.db.Collection("tags", options.Collection().SetWriteConcern(writeconcern.Majority())). + UpdateOne(ctx, bson.M{"tenant_id": tenant, "name": currentTag}, bson.M{"$set": bson.M{"name": newTag}}) + if err != nil { + return 0, FromMongoError(err) + } + + return res.ModifiedCount, err +} + func (s *Store) DeviceBulkDeleteTag(ctx context.Context, tenant, tag string) (int64, error) { - res, err := s.db.Collection("devices").UpdateMany(ctx, bson.M{"tenant_id": tenant}, bson.M{"$pull": bson.M{"tags": tag}}) + res, err := s.db.Collection("devices"). + UpdateMany(ctx, bson.M{"tenant_id": tenant}, bson.M{"$pull": bson.M{"tags": tag}}) return res.ModifiedCount, FromMongoError(err) } diff --git a/api/store/mongo/device_tags_test.go b/api/store/mongo/device_tags_test.go index 44e963c515e..bfe03aa2126 100644 --- a/api/store/mongo/device_tags_test.go +++ b/api/store/mongo/device_tags_test.go @@ -2,9 +2,11 @@ package mongo_test import ( "context" + // "errors" "testing" "github.com/shellhub-io/shellhub/api/store" + // mongo "github.com/shellhub-io/shellhub/api/store/mongo" "github.com/shellhub-io/shellhub/pkg/models" "github.com/stretchr/testify/assert" ) @@ -21,14 +23,14 @@ func TestDevicePushTag(t *testing.T) { description: "fails when device doesn't exist", uid: models.UID("nonexistent"), tag: "tag4", - fixtures: []string{fixtureDevices}, + fixtures: []string{fixtureTags, fixtureDevices}, expected: store.ErrNoDocuments, }, { description: "successfully creates single tag for an existing device", uid: models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c"), tag: "tag4", - fixtures: []string{fixtureDevices}, + fixtures: []string{fixtureTags, fixtureDevices}, expected: nil, }, } @@ -108,14 +110,14 @@ func TestDeviceSetTags(t *testing.T) { expected Expected }{ { - description: "successfully when device doesn't exist", + description: "fails when device doesn't exist", uid: models.UID("nonexistent"), - tags: []string{"new-tag"}, + tags: []string{"tag-1"}, fixtures: []string{fixtureDevices}, expected: Expected{ matchedCount: 0, updatedCount: 0, - err: nil, + err: store.ErrNoDocuments, }, }, { diff --git a/api/store/mongo/fixtures/tags.json b/api/store/mongo/fixtures/tags.json new file mode 100644 index 00000000000..e3e6b1cd7d6 --- /dev/null +++ b/api/store/mongo/fixtures/tags.json @@ -0,0 +1,29 @@ +{ + "tags": { + "67519c0c31490629a1fc612c": { + "name" : "red", + "tenant_id" : "00000000-0000-4000-0000-000000000000", + "color" : "" + }, + "67519e3531490629a1fc612f": { + "name" : "redone", + "tenant_id" : "ed369960-338b-4e8c-9ec4-e5fd31974763", + "color" : "#ff0000" + }, + "67519e4231490629a1fc6130": { + "name" : "blue", + "tenant_id" : "00000000-0000-4000-0000-000000000000", + "color" : "#0000ff" + }, + "6751a03431490629a1fc6131": { + "name" : "tag-1", + "tenant_id" : "00000000-0000-4000-0000-000000000000", + "color" : "#a25f36" + }, + "6751b1a93592db0deea3fd97": { + "name" : "green", + "tenant_id" : "00000000-0000-4000-0000-000000000000", + "color" : "green" + } + } +} diff --git a/api/store/mongo/migrations/migration_69.go b/api/store/mongo/migrations/migration_69.go index c693e216a0c..401382b3aa9 100644 --- a/api/store/mongo/migrations/migration_69.go +++ b/api/store/mongo/migrations/migration_69.go @@ -83,7 +83,6 @@ var migration69 = migrate.Migration{ return nil, err }) - if err != nil { return err } diff --git a/api/store/mongo/migrations/migration_86.go b/api/store/mongo/migrations/migration_86.go index b79f0ce5723..5b350dca398 100644 --- a/api/store/mongo/migrations/migration_86.go +++ b/api/store/mongo/migrations/migration_86.go @@ -42,7 +42,6 @@ var migration86 = migrate.Migration{ Options: options.Index().SetName("tenant_id").SetUnique(false), } - _, err2 := db.Collection("tags", options.Collection().SetWriteConcern(writeconcern.Majority()), ).Indexes().CreateOne(ctx, indexTenant) @@ -62,7 +61,6 @@ var migration86 = migrate.Migration{ _, err := db.Collection("tags", options.Collection().SetWriteConcern(writeconcern.Majority()), ).Indexes().DropOne(ctx, "names") - if err != nil { return err } diff --git a/api/store/mongo/migrations/migration_86_test.go b/api/store/mongo/migrations/migration_86_test.go index c6dfaa7ff85..cf82cb98eea 100644 --- a/api/store/mongo/migrations/migration_86_test.go +++ b/api/store/mongo/migrations/migration_86_test.go @@ -28,8 +28,8 @@ func TestMigration86(t *testing.T) { description: "Apply up on migration 83 when there is at least one user", setup: func(t *testing.T) { _, err := c.Database("test").Collection("tags").InsertOne(ctx, models.Tags{ - Name: "red", - Color: "#ff0000", + Name: "red", + Color: "#ff0000", Tenant: "00000000-0000-4000-0000-000000000000", }) require.NoError(t, err) @@ -56,7 +56,7 @@ func TestMigration86(t *testing.T) { assert.NoError(tt, srv.Reset()) }) - migrates := migrate.NewMigrate(c.Database("test"), GenerateMigrations()[86 - 1]) + migrates := migrate.NewMigrate(c.Database("test"), GenerateMigrations()[86-1]) require.NoError(tt, migrates.Up(context.Background(), migrate.AllAvailable)) test.setup(tt) diff --git a/api/store/mongo/publickey_tags.go b/api/store/mongo/publickey_tags.go index 8a11753d90b..635409b2a0c 100644 --- a/api/store/mongo/publickey_tags.go +++ b/api/store/mongo/publickey_tags.go @@ -2,19 +2,48 @@ package mongo import ( "context" + "errors" "github.com/shellhub-io/shellhub/api/store" "go.mongodb.org/mongo-driver/bson" + mongodriver "go.mongodb.org/mongo-driver/mongo" ) func (s *Store) PublicKeyPushTag(ctx context.Context, tenant, fingerprint, tag string) error { - result, err := s.db.Collection("public_keys").UpdateOne(ctx, bson.M{"tenant_id": tenant, "fingerprint": fingerprint}, bson.M{"$addToSet": bson.M{"filter.tags": tag}}) + session, err := s.db.Client().StartSession() if err != nil { - return err + return FromMongoError(err) } - - if result.ModifiedCount < 1 { - return store.ErrNoDocuments + defer session.EndSession(ctx) + + _, erro := session.WithTransaction(ctx, func(sessCtx mongodriver.SessionContext) (interface{}, error) { + if _, err := s.TagGet(sessCtx, tag, tenant); err != nil { + if errors.Is(err, store.ErrNoDocuments) { + err := s.TagsPushTag(sessCtx, tag, tenant) + if err != nil { + return nil, FromMongoError(err) + } + } else if err != nil { + return nil, err + } + } + + result, err := s.db.Collection("public_keys"). + UpdateOne(sessCtx, bson.M{"tenant_id": tenant, "fingerprint": fingerprint}, + bson.M{"$addToSet": bson.M{"filter.tags": tag}}) + if err != nil { + return nil, err + } + + if result.ModifiedCount < 1 { + return nil, store.ErrNoDocuments + } + + return nil, nil + }) + + if erro != nil { + return erro } return nil diff --git a/api/store/mongo/store_test.go b/api/store/mongo/store_test.go index d990ecc798e..102a3ccdb99 100644 --- a/api/store/mongo/store_test.go +++ b/api/store/mongo/store_test.go @@ -16,9 +16,11 @@ import ( mongodb "go.mongodb.org/mongo-driver/mongo" ) -var srv = &dbtest.Server{} -var db *mongodb.Database -var s store.Store +var ( + srv = &dbtest.Server{} + db *mongodb.Database + s store.Store +) const ( fixtureAPIKeys = "api-key" // Check "store.mongo.fixtures.api-keys" for fixture info @@ -29,7 +31,8 @@ const ( fixtureFirewallRules = "firewall_rules" // Check "store.mongo.fixtures.firewall_rules" for fixture info fixturePublicKeys = "public_keys" // Check "store.mongo.fixtures.public_keys" for fixture info fixturePrivateKeys = "private_keys" // Check "store.mongo.fixtures.private_keys" for fixture info - fixtureUsers = "users" // Check "store.mongo.fixtures.users" for fixture iefo + fixtureUsers = "users" // Check "store.mongo.fixtures.users" for fixture info + fixtureTags = "tags" // Check "store.mongo.fixtures.tags" for fixture info fixtureNamespaces = "namespaces" // Check "store.mongo.fixtures.namespaces" for fixture info fixtureRecoveryTokens = "recovery_tokens" // Check "store.mongo.fixtures.recovery_tokens" for fixture info ) @@ -65,6 +68,7 @@ func TestMain(m *testing.M) { mongotest.SimpleConvertTime("sessions", "last_seen"), mongotest.SimpleConvertObjID("active_sessions", "_id"), mongotest.SimpleConvertTime("active_sessions", "last_seen"), + mongotest.SimpleConvertObjID("tags", "_id"), } if err := srv.Up(ctx); err != nil { diff --git a/api/store/mongo/tags.go b/api/store/mongo/tags.go index 6ea7146aee2..90982d948c6 100644 --- a/api/store/mongo/tags.go +++ b/api/store/mongo/tags.go @@ -3,8 +3,11 @@ package mongo import ( "context" + "github.com/shellhub-io/shellhub/pkg/models" "go.mongodb.org/mongo-driver/bson" mongodriver "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" ) func (s *Store) FirewallRuleGetTags(ctx context.Context, tenant string) ([]string, int, error) { @@ -18,43 +21,6 @@ func (s *Store) FirewallRuleGetTags(ctx context.Context, tenant string) ([]strin return tags, len(tags), FromMongoError(err) } -func (s *Store) TagsGet(ctx context.Context, tenant string) ([]string, int, error) { - session, err := s.db.Client().StartSession() - if err != nil { - return nil, 0, err - } - defer session.EndSession(ctx) - - tags, err := session.WithTransaction(ctx, func(sessCtx mongodriver.SessionContext) (interface{}, error) { - deviceTags, _, err := s.DeviceGetTags(sessCtx, tenant) - if err != nil { - return nil, err - } - - keyTags, _, err := s.PublicKeyGetTags(sessCtx, tenant) - if err != nil { - return nil, err - } - - ruleTags, _, err := s.FirewallRuleGetTags(sessCtx, tenant) - if err != nil { - return nil, err - } - - tags := []string{} - tags = append(tags, deviceTags...) - tags = append(tags, keyTags...) - tags = append(tags, ruleTags...) - - return removeDuplicate[string](tags), nil - }) - if err != nil { - return nil, 0, FromMongoError(err) - } - - return tags.([]string), len(tags.([]string)), nil -} - func (s *Store) FirewallRuleBulkRenameTag(ctx context.Context, tenant, currentTag, newTag string) (int64, error) { res, err := s.db.Collection("firewall_rules").UpdateMany(ctx, bson.M{"tenant_id": tenant, "filter.tags": currentTag}, bson.M{"$set": bson.M{"filter.tags.$": newTag}}) @@ -84,7 +50,12 @@ func (s *Store) TagsRename(ctx context.Context, tenantID string, oldTag string, return int64(0), err } - return devCount + keyCount + rulCount, nil + tagsCount, err := s.TagsBulkRenameTag(sessCtx, tenantID, oldTag, newTag) + if err != nil { + return int64(0), err + } + + return devCount + keyCount + rulCount + tagsCount, nil }) if err != nil { return int64(0), FromMongoError(err) @@ -122,7 +93,12 @@ func (s *Store) TagsDelete(ctx context.Context, tenantID string, tag string) (in return int64(0), err } - return devCount + keyCount + rulCount, nil + tagCount, err := s.TagsBulkDeleteTag(sessCtx, tenantID, tag) + if err != nil { + return int64(0), err + } + + return devCount + keyCount + rulCount + tagCount, nil }) if err != nil { return int64(0), FromMongoError(err) @@ -130,3 +106,69 @@ func (s *Store) TagsDelete(ctx context.Context, tenantID string, tag string) (in return count.(int64), nil } + +func (s *Store) TagGet(ctx context.Context, tagName, tenant string) (*models.Tags, error) { + tag := new(models.Tags) + if err := s.db.Collection("tags").FindOne(ctx, bson.M{"name": tagName, "tenant_id": tenant}).Decode(tag); err != nil { + return nil, FromMongoError(err) + } + + return tag, nil +} + +func (s *Store) TagsGet(ctx context.Context, tenant string) ([]models.Tags, int64, error) { + tags, length, err := s.TagsGetTags(ctx, tenant) + if err != nil { + return nil, length, FromMongoError(err) + } + + return removeDuplicate[models.Tags](tags), length, nil +} + +func (s *Store) TagsPushTag(ctx context.Context, tagName, tenantID string) error { + tag := &models.Tags{ + Name: tagName, + Tenant: tenantID, + Color: "", + } + + _, err := s.db.Collection("tags", options.Collection().SetWriteConcern(writeconcern.Majority())). + InsertOne(ctx, tag) + if err != nil { + return FromMongoError(err) + } + + return nil +} + +func (s *Store) TagsBulkDeleteTag(ctx context.Context, tenant, tagName string) (int64, error) { + res, err := s.db.Collection("tags", options.Collection().SetWriteConcern(writeconcern.Majority())). + DeleteOne(ctx, bson.M{"tenant_id": tenant, "name": tagName}) + + return res.DeletedCount, FromMongoError(err) +} + +func (s *Store) TagsGetTags(ctx context.Context, tenant string) ([]models.Tags, int64, error) { + cursor, err := s.db.Collection("tags").Find(ctx, bson.M{"tenant_id": tenant}) + if err != nil { + return nil, 0, FromMongoError(err) + } + defer cursor.Close(ctx) + + tags := make([]models.Tags, cursor.RemainingBatchLength()) + i := 0 + + for cursor.Next(ctx) { + tg := new(models.Tags) + + if err := cursor.Decode(tg); err != nil { + return nil, 0, FromMongoError(err) + } + + tags[i] = *tg //nolint:forcetypeassert + + i++ + } + + return tags, int64(len(tags)), FromMongoError(err) +} diff --git a/api/store/mongo/tags_test.go b/api/store/mongo/tags_test.go index be9fd44d5b3..145e281d4b5 100644 --- a/api/store/mongo/tags_test.go +++ b/api/store/mongo/tags_test.go @@ -5,13 +5,23 @@ import ( "sort" "testing" + "github.com/shellhub-io/shellhub/pkg/models" "github.com/stretchr/testify/assert" ) +// sort tags model by tag name +// Due to the non-deterministic order of applying fixtures when dealing with multiple datasets, +// we ensure that both the expected and result arrays are correctly sorted. +func sortTags(tags []models.Tags) { + sort.Slice(tags, func(i, j int) bool { + return tags[i].Name < tags[j].Name + }) +} + func TestTagsGet(t *testing.T) { type Expected struct { - tags []string - len int + tags []models.Tags + len int64 err error } @@ -24,23 +34,40 @@ func TestTagsGet(t *testing.T) { { description: "succeeds when tag is found", tenant: "00000000-0000-4000-0000-000000000000", - fixtures: []string{fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, expected: Expected{ - tags: []string{"tag-1"}, - len: 1, - err: nil, + tags: []models.Tags{ + { + ID: "67519c0c31490629a1fc612c", + Name: "red", + Color: "", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "67519e4231490629a1fc6130", + Name: "blue", + Color: "#0000ff", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "6751a03431490629a1fc6131", + Name: "tag-1", + Color: "#a25f36", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "6751b1a93592db0deea3fd97", + Name: "green", + Tenant: "00000000-0000-4000-0000-000000000000", + Color: "green", + }, + }, + len: 4, + err: nil, }, }, } - // Due to the non-deterministic order of applying fixtures when dealing with multiple datasets, - // we ensure that both the expected and result arrays are correctly sorted. - sort := func(tags []string) { - sort.Slice(tags, func(i, j int) bool { - return tags[i] < tags[j] - }) - } - for _, tc := range cases { t.Run(tc.description, func(t *testing.T) { ctx := context.Background() @@ -52,8 +79,8 @@ func TestTagsGet(t *testing.T) { tags, count, err := s.TagsGet(ctx, tc.tenant) - sort(tc.expected.tags) - sort(tags) + sortTags(tc.expected.tags) + sortTags(tags) assert.Equal(t, tc.expected, Expected{tags: tags, len: count, err: err}) }) @@ -79,9 +106,9 @@ func TestTagsRename(t *testing.T) { tenant: "00000000-0000-4000-0000-000000000000", oldTag: "tag-1", newTag: "edited-tag", - fixtures: []string{fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, expected: Expected{ - count: 6, + count: 7, err: nil, }, }, @@ -141,3 +168,221 @@ func TestTagsDelete(t *testing.T) { }) } } + +func TestTagGet(t *testing.T) { + type Expected struct { + tag []models.Tags + count int64 + err error + } + + cases := []struct { + description string + tenant string + tag string + fixtures []string + expected Expected + }{ + { + description: "succeeds when tag is found", + tenant: "00000000-0000-4000-0000-000000000000", + tag: "tag-1", + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + expected: Expected{ + tag: []models.Tags{ + { + ID: "67519c0c31490629a1fc612c", + Name: "red", + Color: "", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "67519e4231490629a1fc6130", + Name: "blue", + Color: "#0000ff", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "6751b1a93592db0deea3fd97", + Name: "green", + Tenant: "00000000-0000-4000-0000-000000000000", + Color: "green", + }, + { + ID: "6751a03431490629a1fc6131", + Name: "tag-1", + Color: "#a25f36", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + }, + count: 4, + err: nil, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + tags, count, err := s.TagsGet(ctx, tc.tenant) + + sortTags(tags) + sortTags(tc.expected.tag) + + assert.Equal(t, tc.expected, Expected{tags, count, err}) + }) + } +} + +func TestTagsGetTags(t *testing.T) { + type Expected struct { + tag []models.Tags + count int64 + err error + } + + cases := []struct { + description string + tenant string + tag string + fixtures []string + expected Expected + }{ + { + description: "succeeds when tag is found", + tenant: "00000000-0000-4000-0000-000000000000", + tag: "tag-1", + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + expected: Expected{ + tag: []models.Tags{ + { + ID: "67519c0c31490629a1fc612c", + Name: "red", + Color: "", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "67519e4231490629a1fc6130", + Name: "blue", + Color: "#0000ff", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + { + ID: "6751b1a93592db0deea3fd97", + Name: "green", + Tenant: "00000000-0000-4000-0000-000000000000", + Color: "green", + }, + { + ID: "6751a03431490629a1fc6131", + Name: "tag-1", + Color: "#a25f36", + Tenant: "00000000-0000-4000-0000-000000000000", + }, + }, + count: 4, + err: nil, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + tags, count, err := s.TagsGetTags(ctx, tc.tenant) + + sortTags(tags) + sortTags(tc.expected.tag) + + assert.Equal(t, tc.expected, Expected{tags, count, err}) + }) + } +} + +func TestTagsPushTag(t *testing.T) { + cases := []struct { + description string + name string + tenant string + tag string + fixtures []string + expected error + }{ + { + description: "succeeds when tag is found", + tenant: "00000000-0000-4000-0000-000000000000", + name: "red-one", + tag: "tag-1", + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + expected: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + err := s.TagsPushTag(ctx, tc.name, tc.tenant) + assert.Equal(t, tc.expected, err) + }) + } +} + +func TestTagsBulkDeleteTag(t *testing.T) { + type Expected struct { + count int64 + err error + } + + cases := []struct { + description string + name string + tenant string + tag string + fixtures []string + expected Expected + }{ + { + description: "succeeds when tag is found", + tenant: "00000000-0000-4000-0000-000000000000", + name: "red-one", + tag: "tag-1", + fixtures: []string{fixtureTags, fixturePublicKeys, fixtureFirewallRules, fixtureDevices}, + expected: Expected{ + count: 0, + err: nil, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + count, err := s.TagsBulkDeleteTag(ctx, tc.tenant, tc.name) + assert.Equal(t, tc.expected, Expected{count, err}) + }) + } +} diff --git a/api/store/tags.go b/api/store/tags.go index 4894f572f3b..2fdc49f5b58 100644 --- a/api/store/tags.go +++ b/api/store/tags.go @@ -1,13 +1,17 @@ package store -import "context" +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/models" +) type TagsStore interface { // TagsGet retrieves all tags associated with the specified tenant. It functions by invoking "[document]GetTags" // for each document that implements tags. // Returns the tags, the count of unique tags, and an error if any issues arise. // It also filters the returned tags, removing any duplicates. - TagsGet(ctx context.Context, tenant string) (tags []string, n int, err error) + TagsGet(ctx context.Context, tenant string) (tags []models.Tags, n int64, err error) // TagsRename replaces all occurrences of the old tag with the new tag for all documents associated with the specified tenant. // It operates by invoking "[document]BulkRenameTag" for each document that implements tags. @@ -18,4 +22,9 @@ type TagsStore interface { // invoking "[document]BulkDeleteTag" for each document that implements tags. // Returns the count of documents updated and an error if any issues arise during the tag deletion. TagsDelete(ctx context.Context, tenant string, tag string) (updatedCount int64, err error) + + TagGet(ctx context.Context, tagName, tenant string) (*models.Tags, error) + TagsGetTags(ctx context.Context, tenant string) ([]models.Tags, int64, error) + TagsPushTag(ctx context.Context, tagName, tenantID string) error + TagsBulkDeleteTag(ctx context.Context, tenant, tagName string) (int64, error) } diff --git a/api/store/transaction.go b/api/store/transaction.go index 87c20466aa1..560189a1e48 100644 --- a/api/store/transaction.go +++ b/api/store/transaction.go @@ -5,9 +5,7 @@ import ( "errors" ) -var ( - ErrStartTransactionFailed = errors.New("start transaction failed") -) +var ErrStartTransactionFailed = errors.New("start transaction failed") // TransactionCb defines the function signature expected for transaction operations. // It typically encompasses a series of store method calls that must be executed within a transaction.