From b5a29e1bfc8cc54e928f1a9b45bee73f56e82dc8 Mon Sep 17 00:00:00 2001 From: shampoobera Date: Tue, 30 Apr 2024 11:21:32 -0500 Subject: [PATCH] finish db test --- contrib/screener-api/db/db_test.go | 18 +++++++++++------- contrib/screener-api/db/sql/base/base.go | 22 ++++++++-------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/contrib/screener-api/db/db_test.go b/contrib/screener-api/db/db_test.go index 7dd81c26b8..5f4ef8d28b 100644 --- a/contrib/screener-api/db/db_test.go +++ b/contrib/screener-api/db/db_test.go @@ -73,28 +73,32 @@ func (d *DBSuite) TestBlacklist() { testAddress := gofakeit.BitcoinAddress() blacklistBody := db.BlacklistedAddress{ - TypeReq: "test", - Id: "test", + TypeReq: "create", + Id: "testId", Address: testAddress, Network: "bitcoin", - Tag: "test", - Remark: "test", + Tag: "testTag", + Remark: "testRemark", } // blacklist the address err := testDB.PutBlacklistedAddress(d.GetTestContext(), blacklistBody) d.Require().NoError(err) + blacklistedAddress, err := testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Address) + d.Require().NoError(err) + d.Require().NotNil(blacklistedAddress) // update the address - blacklistBody.Remark = "updated" + blacklistBody.TypeReq = "update" + blacklistBody.Remark = "testRemarkUpdated" err = testDB.UpdateBlacklistedAddress(d.GetTestContext(), blacklistBody.Id, blacklistBody) d.Require().NoError(err) // check to make sure it updated - blacklistedAddress, err := testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Id) + blacklistedAddress, err = testDB.GetBlacklistedAddress(d.GetTestContext(), blacklistBody.Address) d.Require().NoError(err) d.Require().NotNil(blacklistedAddress) - d.Require().Equal("updated", blacklistedAddress.Remark) + d.Require().Equal("testRemarkUpdated", blacklistedAddress.Remark) // check for non blacklisted address res, err := testDB.GetBlacklistedAddress(d.GetTestContext(), gofakeit.BitcoinAddress()) diff --git a/contrib/screener-api/db/sql/base/base.go b/contrib/screener-api/db/sql/base/base.go index 22a8aa3f6b..d75422d066 100644 --- a/contrib/screener-api/db/sql/base/base.go +++ b/contrib/screener-api/db/sql/base/base.go @@ -38,14 +38,13 @@ func GetAllModels() (allModels []interface{}) { // GetBlacklistedAddress queries the db for the blacklisted address. // Returns true if the address is blacklisted, false otherwise. // Not used currently. -func (s *Store) GetBlacklistedAddress(ctx context.Context, id string) (*db.BlacklistedAddress, error) { +func (s *Store) GetBlacklistedAddress(ctx context.Context, address string) (*db.BlacklistedAddress, error) { var blacklistedAddress db.BlacklistedAddress - if err := s.db.WithContext(ctx).Where(&db.BlacklistedAddress{ - Id: id, - }).First(&blacklistedAddress).Error; err != nil { + if err := s.db.WithContext(ctx).Where("address = ?", address). + First(&blacklistedAddress).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil + return nil, err } return nil, fmt.Errorf("failed to get blacklisted address: %w", err) } @@ -74,9 +73,7 @@ func (s *Store) PutBlacklistedAddress(ctx context.Context, body db.BlacklistedAd // UpdateBlacklistedAddress updates the blacklisted address in the underlying db. func (s *Store) UpdateBlacklistedAddress(ctx context.Context, id string, body db.BlacklistedAddress) error { dbTx := s.db.WithContext(ctx).Model(&db.BlacklistedAddress{}). - Where(&db.BlacklistedAddress{ - Id: id, - }).Updates(body) + Where("id = ?", id).Updates(body) if dbTx.Error != nil { return fmt.Errorf("failed to update blacklisted address: %w", dbTx.Error) } @@ -85,13 +82,10 @@ func (s *Store) UpdateBlacklistedAddress(ctx context.Context, id string, body db } func (s *Store) DeleteBlacklistedAddress(ctx context.Context, id string) error { - dbTx := s.db.WithContext(ctx).Where(&db.BlacklistedAddress{ - Id: id, - }).Delete(&db.BlacklistedAddress{}) - if dbTx.Error != nil { - return fmt.Errorf("failed to delete blacklisted address: %w", dbTx.Error) + if dbTx := s.db.WithContext(ctx).Where( + "id = ?", id).Delete(&db.BlacklistedAddress{}); dbTx.Error != nil || dbTx.RowsAffected == 0 { + return fmt.Errorf("failed to delete blacklisted address") } - return nil }