diff --git a/internal/cache/account.go b/internal/cache/account.go new file mode 100644 index 0000000000..f62d48140d --- /dev/null +++ b/internal/cache/account.go @@ -0,0 +1,157 @@ +package cache + +import ( + "sync" + + "github.com/ReneKroon/ttlcache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// AccountCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Account +type AccountCache struct { + cache *ttlcache.Cache // map of IDs -> cached accounts + urls map[string]string // map of account URLs -> IDs + uris map[string]string // map of account URIs -> IDs + mutex sync.Mutex +} + +// NewAccountCache returns a new instantiated AccountCache object +func NewAccountCache() *AccountCache { + c := AccountCache{ + cache: ttlcache.NewCache(), + urls: make(map[string]string, 100), + uris: make(map[string]string, 100), + mutex: sync.Mutex{}, + } + + // Set callback to purge lookup maps on expiration + c.cache.SetExpirationCallback(func(key string, value interface{}) { + account := value.(*gtsmodel.Account) + + c.mutex.Lock() + delete(c.urls, account.URL) + delete(c.uris, account.URI) + c.mutex.Unlock() + }) + + return &c +} + +// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety +func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) { + c.mutex.Lock() + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety +func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.urls[url] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt account lookup + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety +func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) { + // Perform safe ID lookup + c.mutex.Lock() + id, ok := c.uris[uri] + + // Not found, unlock early + if !ok { + c.mutex.Unlock() + return nil, false + } + + // Attempt account lookup + account, ok := c.getByID(id) + c.mutex.Unlock() + return account, ok +} + +// getByID performs an unsafe (no mutex locks) lookup of account by ID, returning a copy of account in cache +func (c *AccountCache) getByID(id string) (*gtsmodel.Account, bool) { + v, ok := c.cache.Get(id) + if !ok { + return nil, false + } + return copyAccount(v.(*gtsmodel.Account)), true +} + +// Put places a account in the cache, ensuring that the object place is a copy for thread-safety +func (c *AccountCache) Put(account *gtsmodel.Account) { + if account == nil || account.ID == "" { + panic("invalid account") + } + + c.mutex.Lock() + c.cache.Set(account.ID, copyAccount(account)) + if account.URL != "" { + c.urls[account.URL] = account.ID + } + if account.URI != "" { + c.uris[account.URI] = account.ID + } + c.mutex.Unlock() +} + +// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. +// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) +// this should be a relatively cheap process +func copyAccount(account *gtsmodel.Account) *gtsmodel.Account { + return >smodel.Account{ + ID: account.ID, + Username: account.Username, + Domain: account.Domain, + AvatarMediaAttachmentID: account.AvatarMediaAttachmentID, + AvatarMediaAttachment: nil, + AvatarRemoteURL: account.AvatarRemoteURL, + HeaderMediaAttachmentID: account.HeaderMediaAttachmentID, + HeaderMediaAttachment: nil, + HeaderRemoteURL: account.HeaderRemoteURL, + DisplayName: account.DisplayName, + Fields: account.Fields, + Note: account.Note, + Memorial: account.Memorial, + MovedToAccountID: account.MovedToAccountID, + CreatedAt: account.CreatedAt, + UpdatedAt: account.UpdatedAt, + Bot: account.Bot, + Reason: account.Reason, + Locked: account.Locked, + Discoverable: account.Discoverable, + Privacy: account.Privacy, + Sensitive: account.Sensitive, + Language: account.Language, + URI: account.URI, + URL: account.URL, + LastWebfingeredAt: account.LastWebfingeredAt, + InboxURI: account.InboxURI, + OutboxURI: account.OutboxURI, + FollowingURI: account.FollowingURI, + FollowersURI: account.FollowersURI, + FeaturedCollectionURI: account.FeaturedCollectionURI, + ActorType: account.ActorType, + AlsoKnownAs: account.AlsoKnownAs, + PrivateKey: nil, + PublicKey: account.PublicKey, + PublicKeyURI: account.PublicKeyURI, + SensitizedAt: account.SensitizedAt, + SilencedAt: account.SilencedAt, + SuspendedAt: account.SuspendedAt, + HideCollections: account.HideCollections, + SuspensionOrigin: account.SuspensionOrigin, + } +} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go new file mode 100644 index 0000000000..d65c3196d2 --- /dev/null +++ b/internal/cache/account_test.go @@ -0,0 +1,41 @@ +package cache_test + +import ( + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func TestAccountCache(t *testing.T) { + cache := cache.NewAccountCache() + + // Attempt to place an account + account := gtsmodel.Account{ + ID: "id", + URI: "uri", + URL: "url", + } + cache.Put(&account) + + var ok bool + var check *gtsmodel.Account + + // Check we can retrieve + check, ok = cache.GetByID(account.ID) + if !ok || !accountIs(&account, check) { + t.Fatal("Could not find expected status") + } + check, ok = cache.GetByURI(account.URI) + if !ok || !accountIs(&account, check) { + t.Fatal("Could not find expected status") + } + check, ok = cache.GetByURL(account.URL) + if !ok || !accountIs(&account, check) { + t.Fatal("Could not find expected status") + } +} + +func accountIs(account1, account2 *gtsmodel.Account) bool { + return account1.ID == account2.ID && account1.URI == account2.URI && account1.URL == account2.URL +} diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index d7d45a739f..177f43126d 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,6 +35,7 @@ import ( type accountDB struct { config *config.Config conn *DBConn + cache *cache.AccountCache } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -83,22 +85,45 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel. return account, nil } +func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { + // Attempt to fetch cached account + account, cached := cacheGet() + + if !cached { + account = >smodel.Account{} + + // Not cached! Perform database query + err := dbQuery(account) + if err != nil { + return nil, a.conn.ProcessError(err) + } + + // Place in the cache + a.cache.Put(account) + } + + return account, nil +} + func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { if strings.TrimSpace(account.ID) == "" { + // TODO: we should not need this check here return nil, errors.New("account had no ID") } + // Update the account's last-used account.UpdatedAt = time.Now() - q := a.conn. - NewUpdate(). - Model(account). - WherePK() - - _, err := q.Exec(ctx) + // Update the account model in the DB + _, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx) if err != nil { return nil, a.conn.ProcessError(err) } + + // Place updated account in cache + // (this will replace existing, i.e. invalidating) + a.cache.Put(account) + return account, nil } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index ba19614a26..6fcc56e51f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -130,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) conn.RegisterModel(t) } + accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()} + ps := &bunDBService{ - Account: &accountDB{ - config: c, - conn: conn, - }, + Account: accounts, Admin: &adminDB{ config: c, conn: conn, @@ -174,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) conn: conn, }, Status: &statusDB{ - config: c, - conn: conn, - cache: cache.NewStatusCache(), + config: c, + conn: conn, + cache: cache.NewStatusCache(), + accounts: accounts, }, Timeline: &timelineDB{ config: c, diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go index 9e3a0d2893..abaebcebd0 100644 --- a/internal/db/bundb/conn.go +++ b/internal/db/bundb/conn.go @@ -12,6 +12,8 @@ import ( // dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality type DBConn struct { + // TODO: move *Config here, no need to be in each struct type + errProc func(error) db.Error // errProc is the SQL-type specific error processor log *logrus.Logger // log is the logger passed with this DBConn *bun.DB // DB is the underlying bun.DB connection diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index e0b86fe3e7..9464cfadfd 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -34,6 +34,11 @@ type statusDB struct { config *config.Config conn *DBConn cache *cache.StatusCache + + // TODO: keep method definitions in same place but instead have receiver + // all point to one single "db" type, so they can all share methods + // and caches where necessary + accounts *accountDB } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -120,6 +125,14 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta s.cache.Put(status) } + // Set the status author account + author, err := s.accounts.GetAccountByID(ctx, status.AccountID) + if err != nil { + return nil, err + } + + // Return the prepared status + status.Account = author return status, nil }