Skip to content

Commit

Permalink
finish db test
Browse files Browse the repository at this point in the history
  • Loading branch information
shampoobera committed Apr 30, 2024
1 parent 9af0630 commit b5a29e1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
18 changes: 11 additions & 7 deletions contrib/screener-api/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,32 @@ func (d *DBSuite) TestBlacklist() {
testAddress := gofakeit.BitcoinAddress()

blacklistBody := db.BlacklistedAddress{
TypeReq: "test",
Id: "test",
TypeReq: "create",
Id: "testId",
Address: testAddress,
Network: "bitcoin",
Tag: "test",
Remark: "test",
Tag: "testTag",
Remark: "testRemark",
}

// blacklist the address
err := testDB.PutBlacklistedAddress(d.GetTestContext(), blacklistBody)
d.Require().NoError(err)
blacklistedAddress, err := testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Address)
d.Require().NoError(err)
d.Require().NotNil(blacklistedAddress)

// update the address
blacklistBody.Remark = "updated"
blacklistBody.TypeReq = "update"
blacklistBody.Remark = "testRemarkUpdated"
err = testDB.UpdateBlacklistedAddress(d.GetTestContext(), blacklistBody.Id, blacklistBody)
d.Require().NoError(err)

// check to make sure it updated
blacklistedAddress, err := testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Id)
blacklistedAddress, err = testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Address)
d.Require().NoError(err)
d.Require().NotNil(blacklistedAddress)
d.Require().Equal("updated", blacklistedAddress.Remark)
d.Require().Equal("testRemarkUpdated", blacklistedAddress.Remark)

// check for non blacklisted address
res, err := testDB.GetBlacklistedAddress(d.GetTestContext(), gofakeit.BitcoinAddress())
Expand Down
22 changes: 8 additions & 14 deletions contrib/screener-api/db/sql/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ func GetAllModels() (allModels []interface{}) {
// GetBlacklistedAddress queries the db for the blacklisted address.
// Returns true if the address is blacklisted, false otherwise.
// Not used currently.
func (s *Store) GetBlacklistedAddress(ctx context.Context, id string) (*db.BlacklistedAddress, error) {
func (s *Store) GetBlacklistedAddress(ctx context.Context, address string) (*db.BlacklistedAddress, error) {
var blacklistedAddress db.BlacklistedAddress

if err := s.db.WithContext(ctx).Where(&db.BlacklistedAddress{
Id: id,
}).First(&blacklistedAddress).Error; err != nil {
if err := s.db.WithContext(ctx).Where("address = ?", address).
First(&blacklistedAddress).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
return nil, err
}
return nil, fmt.Errorf("failed to get blacklisted address: %w", err)
}
Expand Down Expand Up @@ -74,9 +73,7 @@ func (s *Store) PutBlacklistedAddress(ctx context.Context, body db.BlacklistedAd
// UpdateBlacklistedAddress updates the blacklisted address in the underlying db.
func (s *Store) UpdateBlacklistedAddress(ctx context.Context, id string, body db.BlacklistedAddress) error {
dbTx := s.db.WithContext(ctx).Model(&db.BlacklistedAddress{}).
Where(&db.BlacklistedAddress{
Id: id,
}).Updates(body)
Where("id = ?", id).Updates(body)
if dbTx.Error != nil {
return fmt.Errorf("failed to update blacklisted address: %w", dbTx.Error)
}
Expand All @@ -85,13 +82,10 @@ func (s *Store) UpdateBlacklistedAddress(ctx context.Context, id string, body db
}

func (s *Store) DeleteBlacklistedAddress(ctx context.Context, id string) error {
dbTx := s.db.WithContext(ctx).Where(&db.BlacklistedAddress{
Id: id,
}).Delete(&db.BlacklistedAddress{})
if dbTx.Error != nil {
return fmt.Errorf("failed to delete blacklisted address: %w", dbTx.Error)
if dbTx := s.db.WithContext(ctx).Where(
"id = ?", id).Delete(&db.BlacklistedAddress{}); dbTx.Error != nil || dbTx.RowsAffected == 0 {
return fmt.Errorf("failed to delete blacklisted address")
}

return nil
}

Expand Down

0 comments on commit b5a29e1

Please sign in to comment.