Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Update poll delete/update db queries #2361

Merged
merged 2 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func init() {
Table("polls").
Column("expires_at_new").
Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")).
Where("1"). // bun gets angry performing update over all rows
Where("TRUE"). // bun gets angry performing update over all rows
NyaaaWhatsUpDoc marked this conversation as resolved.
Show resolved Hide resolved
Exec(ctx); err != nil {
return err
}
Expand Down
77 changes: 43 additions & 34 deletions internal/db/bundb/poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error

var poll gtsmodel.Poll

// Select poll counts from DB.
// Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
if err := tx.NewSelect().
Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), vote.PollID).
Scan(ctx); err != nil {
return err
Expand All @@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error

func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
// Delete all vote in poll,
// returning all vote choices.
switch _, err := tx.NewDelete().
// Delete all votes in poll.
res, err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
Exec(ctx); {
Exec(ctx)
if err != nil {
// irrecoverable
return err
}

case err == nil:
// no issue.
ra, err := res.RowsAffected()
if err != nil {
// irrecoverable
return err
}

case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
if ra == 0 {
// No poll votes deleted,
// nothing to update.
return nil

default:
// irrecoverable.
return err
}

// Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
var poll gtsmodel.Poll

// Select poll counts from DB.
switch err := tx.NewSelect().
Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {

Expand All @@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
poll.ResetVotes()

// Finally, update the poll entry.
_, err := tx.NewUpdate().
_, err = tx.NewUpdate().
Model(&poll).
Column("votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Expand All @@ -432,43 +439,45 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {

func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
var choices []int
// Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
var choicesSl [][]int

// Delete vote in poll by account,
// returning the ID + choices of the vote.
switch err := tx.NewDelete().
if err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("choices").
Scan(ctx, &choices); {

case err == nil:
// no issue.

case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil

default:
Returning("?", bun.Ident("choices")).
Scan(ctx, &choicesSl); err != nil {
// irrecoverable.
return err
}

var poll gtsmodel.Poll
if len(choicesSl) != 1 {
// No poll votes by this
// acct on this poll.
return nil
}
choices := choicesSl[0]

// Select poll counts from DB.
// Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
var poll gtsmodel.Poll
switch err := tx.NewSelect().
Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {

case err == nil:
// no issue.

case errors.Is(err, db.ErrNoEntries):
// no votes found,
// no poll found,
// return here.
return nil

Expand Down
54 changes: 52 additions & 2 deletions internal/db/bundb/poll_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
Expand Down Expand Up @@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() {
suite.NoError(err)

// Fetch latest version of poll from database.
poll, err = suite.db.GetPollByID(ctx, poll.ID)
poll, err = suite.db.GetPollByID(
gtscontext.SetBarebones(ctx),
poll.ID,
)
suite.NoError(err)

// Check that poll counts are all zero.
suite.Equal(*poll.Voters, 0)
suite.Equal(poll.Votes, make([]int, len(poll.Options)))
suite.Equal(make([]int, len(poll.Options)), poll.Votes)
}
}

func (suite *PollTestSuite) TestDeletePollVotesNoPoll() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()

// Try to delete votes of nonexistent poll.
nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB"

err := suite.db.DeletePollVotes(ctx, nonPollID)
suite.NoError(err)
}

func (suite *PollTestSuite) TestDeletePollVotesBy() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()

for _, vote := range suite.testPollVotes {
// Fetch before version of pollBefore from database.
pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID)
suite.NoError(err)

// Delete this poll vote.
err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID)
suite.NoError(err)

// Fetch after version of poll from database.
pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID)
suite.NoError(err)

// Voters count should be reduced by 1.
suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters)
}
}

func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()

// Try to delete a poll by nonexisting account.
pollID := suite.testPolls["local_account_1_status_6_poll"].ID
nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608"

err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID)
suite.NoError(err)
}

func TestPollTestSuite(t *testing.T) {
suite.Run(t, new(PollTestSuite))
}