Skip to content

Commit

Permalink
all storage parts support gid for requests and data
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Dec 31, 2024
1 parent 16c193e commit 1c1ff49
Show file tree
Hide file tree
Showing 10 changed files with 397 additions and 305 deletions.
7 changes: 4 additions & 3 deletions app/storage/approved_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ var approvedUsersSchema = `
uid TEXT,
gid TEXT DEFAULT '',
name TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(gid, uid)
);
CREATE INDEX IF NOT EXISTS idx_approved_users_uid ON approved_users(uid);
CREATE INDEX IF NOT EXISTS idx_approved_users_gid ON approved_users(gid);
Expand Down Expand Up @@ -68,12 +69,12 @@ func NewApprovedUsers(ctx context.Context, db *Engine) (*ApprovedUsers, error) {
}

// Read returns a list of all approved users
func (au *ApprovedUsers) Read(ctx context.Context) ([]approved.UserInfo, error) {
func (au *ApprovedUsers) Read(ctx context.Context, gid string) ([]approved.UserInfo, error) {
au.db.RLock()
defer au.db.RUnlock()

users := []approvedUsersInfo{}
err := au.db.SelectContext(ctx, &users, "SELECT uid, gid, name, timestamp FROM approved_users ORDER BY uid ASC")
err := au.db.SelectContext(ctx, &users, "SELECT uid, gid, name, timestamp FROM approved_users WHERE gid=? ORDER BY uid ASC", gid)
if err != nil {
return nil, fmt.Errorf("failed to get approved users: %w", err)
}
Expand Down
77 changes: 30 additions & 47 deletions app/storage/approved_users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,52 +129,35 @@ func TestApprovedUsers_Read(t *testing.T) {

testTime := time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)

tests := []struct {
name string
setup func(t *testing.T, db *sqlx.DB)
expected []approved.UserInfo
}{
{
name: "read users with groups",
setup: func(t *testing.T, db *sqlx.DB) {
_, err := db.Exec("DELETE FROM approved_users")
require.NoError(t, err)

users := []approved.UserInfo{
{UserID: "123", UserName: "John", GroupID: "admin", Timestamp: testTime},
{UserID: "456", UserName: "Jane", GroupID: "user", Timestamp: testTime},
}
for _, u := range users {
err := au.Write(ctx, u)
require.NoError(t, err)
}
},
expected: []approved.UserInfo{
{UserID: "123", UserName: "John", GroupID: "admin", Timestamp: testTime},
{UserID: "456", UserName: "Jane", GroupID: "user", Timestamp: testTime},
},
},
{
name: "empty table",
setup: func(t *testing.T, db *sqlx.DB) {
_, err := db.Exec("DELETE FROM approved_users")
require.NoError(t, err)
},
expected: []approved.UserInfo{},
},
// write test data
users := []approved.UserInfo{
{UserID: "123", UserName: "John", GroupID: "gr1", Timestamp: testTime},
{UserID: "456", UserName: "Jane", GroupID: "gr2", Timestamp: testTime},
}
for _, u := range users {
err := au.Write(ctx, u)
require.NoError(t, err)
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup(t, &db.DB)
}
t.Run("read users with gr1", func(t *testing.T) {
users, err := au.Read(ctx, "gr1")
require.NoError(t, err)
require.Len(t, users, 1)
assert.Equal(t, users[0].UserID, "123")
})

users, err := au.Read(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expected, users)
})
}
t.Run("read users with gr2", func(t *testing.T) {
users, err := au.Read(ctx, "gr2")
require.NoError(t, err)
require.Len(t, users, 1)
assert.Equal(t, users[0].UserID, "456")
})

t.Run("read user from non-existing group", func(t *testing.T) {
users, err := au.Read(ctx, "non-existing")
require.NoError(t, err)
require.Len(t, users, 0)
})
}

func TestApprovedUsers_Delete(t *testing.T) {
Expand Down Expand Up @@ -242,11 +225,11 @@ func TestApprovedUsers_StoreAndRead(t *testing.T) {
require.NoError(t, err)

for _, id := range tt.ids {
err = au.Write(ctx, approved.UserInfo{UserID: id, UserName: "name_" + id})
err = au.Write(ctx, approved.UserInfo{UserID: id, UserName: "name_" + id, GroupID: "gr1"})
require.NoError(t, err)
}

res, err := au.Read(ctx)
res, err := au.Read(ctx, "gr1")
require.NoError(t, err)
assert.Equal(t, len(tt.expected), len(res))
})
Expand All @@ -272,13 +255,13 @@ func TestApprovedUsers_ContextCancellation(t *testing.T) {

t.Run("read with cancelled context", func(t *testing.T) {
// prepare data
err := au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "test"})
err := au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "test", GroupID: "gr1"})
require.NoError(t, err)

ctxCanceled, cancel := context.WithCancel(context.Background())
cancel()

_, err = au.Read(ctxCanceled)
_, err = au.Read(ctxCanceled, "gr1")
require.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
})
Expand Down
2 changes: 1 addition & 1 deletion app/storage/detected_spam.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type DetectedSpamInfo struct {
Timestamp time.Time `db:"timestamp"`
Added bool `db:"added"` // added to samples
ChecksJSON string `db:"checks"` // Store as JSON
Checks []spamcheck.Response `db:"-"` // Don't store in DB
Checks []spamcheck.Response `db:"-"` // Don't store in DB directly, for db it uses ChecksJSON
}

// NewDetectedSpam creates a new DetectedSpam storage
Expand Down
15 changes: 8 additions & 7 deletions app/storage/dictionary.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func NewDictionary(ctx context.Context, db *Engine) (*Dictionary, error) {
gid TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
type TEXT CHECK (type IN ('stop_phrase', 'ignored_word')),
data TEXT NOT NULL UNIQUE
data TEXT NOT NULL,
UNIQUE(gid, data)
);
CREATE INDEX IF NOT EXISTS idx_dictionary_timestamp ON dictionary(timestamp);
CREATE INDEX IF NOT EXISTS idx_dictionary_type ON dictionary(type);
Expand All @@ -62,7 +63,7 @@ func NewDictionary(ctx context.Context, db *Engine) (*Dictionary, error) {
}

// Add adds a stop phrase or ignored word to the dictionary
func (d *Dictionary) Add(ctx context.Context, t DictionaryType, data, gid string) error {
func (d *Dictionary) Add(ctx context.Context, gid string, t DictionaryType, data string) error {
if err := t.Validate(); err != nil {
return err
}
Expand Down Expand Up @@ -113,7 +114,7 @@ func (d *Dictionary) Delete(ctx context.Context, id int64) error {
}

// Read reads all entries from the dictionary by type
func (d *Dictionary) Read(ctx context.Context, t DictionaryType, gid string) ([]string, error) {
func (d *Dictionary) Read(ctx context.Context, gid string, t DictionaryType) ([]string, error) {
d.db.RLock()
defer d.db.RUnlock()

Expand All @@ -130,11 +131,11 @@ func (d *Dictionary) Read(ctx context.Context, t DictionaryType, gid string) ([]
}

// Reader returns a reader for phrases by type
func (d *Dictionary) Reader(ctx context.Context, t DictionaryType, gid string) (io.ReadCloser, error) {
func (d *Dictionary) Reader(ctx context.Context, gid string, t DictionaryType) (io.ReadCloser, error) {
if err := t.Validate(); err != nil {
return nil, err
}
recs, err := d.Read(ctx, t, gid)
recs, err := d.Read(ctx, gid, t)
if err != nil {
return nil, fmt.Errorf("failed to read phrases: %w", err)
}
Expand All @@ -143,7 +144,7 @@ func (d *Dictionary) Reader(ctx context.Context, t DictionaryType, gid string) (
}

// Iterator returns an iterator for phrases by type
func (d *Dictionary) Iterator(ctx context.Context, t DictionaryType, gid string) (iter.Seq[string], error) {
func (d *Dictionary) Iterator(ctx context.Context, gid string, t DictionaryType) (iter.Seq[string], error) {
if err := t.Validate(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,7 +175,7 @@ func (d *Dictionary) Iterator(ctx context.Context, t DictionaryType, gid string)
// Import reads phrases from the reader and imports them into the storage.
// If withCleanup is true removes all entries with the same type before import.
// Input format is either a single phrase per line or a CSV file with multiple phrases.
func (d *Dictionary) Import(ctx context.Context, t DictionaryType, gid string, r io.Reader, withCleanup bool) (*DictionaryStats, error) {
func (d *Dictionary) Import(ctx context.Context, gid string, t DictionaryType, r io.Reader, withCleanup bool) (*DictionaryStats, error) {
if err := t.Validate(); err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 1c1ff49

Please sign in to comment.