Skip to content

Commit

Permalink
Add ability to restart bot.
Browse files Browse the repository at this point in the history
  • Loading branch information
airforce270 committed Nov 21, 2023
1 parent e4ed3b5 commit bf261cd
Show file tree
Hide file tree
Showing 12 changed files with 357 additions and 86 deletions.
15 changes: 12 additions & 3 deletions apiclients/supinic/supinic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package supinic

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -31,10 +32,18 @@ const pingInterval = 15 * time.Minute
// StartPinging starts a background task to ping the Supinic API regularly
// to make sure the API knows the bot is still online.
// This function blocks and should be run within a goroutine.
func (c *Client) StartPinging() {
func (c *Client) StartPinging(ctx context.Context) {
pingTimer := time.NewTicker(pingInterval)
for {
go c.pingAPI()
time.Sleep(pingInterval)
select {
case <-ctx.Done():
log.Print("Stopping pinging Supinic API, context cancelled")
return
case <-pingTimer.C:
go c.pingAPI()
default:
}
pingTimer.Reset(pingInterval)
}
}

Expand Down
12 changes: 9 additions & 3 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ type Cache interface {
const (
// Cache key for the last sent Twitch message.
KeyLastSentTwitchMessage = "twitch_last_sent_message"
// Cache key for the platform that bot restart was requested from.
KeyRestartRequestedOnPlatform = "restart_requested_on_platform"
// Cache key for the channel that bot restart was requested from.
KeyRestartRequestedInChannel = "restart_requested_from_channel"
// Cache key for the ID of the message that requested the bot restart.
KeyRestartRequestedByMessageID = "restart_requested_by_message"
)

// GlobalSlowmodeKey returns the global slowmode cache key for a platform.
Expand All @@ -78,10 +84,10 @@ func (c *Redis) StoreExpiringBool(key string, value bool, expiration time.Durati
}
func (c *Redis) FetchBool(key string) (bool, error) {
resp, err := c.r.Get(context.Background(), key).Bool()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
if errors.Is(err, redis.Nil) {
return false, nil
}
return false, err
}
return resp, nil
Expand Down
23 changes: 23 additions & 0 deletions commands/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/airforce270/airbot/permission"
twitchplatform "github.com/airforce270/airbot/platforms/twitch"
"github.com/airforce270/airbot/utils"
"github.com/airforce270/airbot/utils/restart"
)

// Commands contains this package's commands.
Expand All @@ -30,6 +31,7 @@ var Commands = [...]basecommand.Command{
leaveCommand,
leaveOtherCommand,
reloadConfigCommand,
restartCommand,
setPrefixCommand,
}

Expand Down Expand Up @@ -138,6 +140,13 @@ var (
Handler: reloadConfig,
}

restartCommand = basecommand.Command{
Name: "restart",
Desc: "Restarts the bot. Does not restart the database, etc.",
Permission: permission.Admin,
Handler: restartBot,
}

setPrefixCommand = basecommand.Command{
Name: "setprefix",
Desc: "Sets the bot's prefix in the channel.",
Expand Down Expand Up @@ -346,6 +355,20 @@ func reloadConfig(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, e
}, nil
}

func restartBot(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
go restart.WriteRequester(msg.Platform.Name(), msg.Message.Channel, msg.Message.ID)

const delay = 500 * time.Millisecond
time.AfterFunc(delay, func() { restart.C <- true })

return []*base.Message{
{
Channel: msg.Message.Channel,
Text: "Restarting Airbot.",
},
}, nil
}

func setPrefix(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
prefixArg := args[0]
if !prefixArg.Present {
Expand Down
4 changes: 2 additions & 2 deletions commands/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2983,7 +2983,7 @@ func setFakes(url string, db *gorm.DB) {
kick.BaseURL = url
pastebin.FetchPasteURLOverride = url
seventv.BaseURL = url
twitch.Conn = twitch.NewForTesting(url, db)
twitch.SetInstance(twitch.NewForTesting(url, db))
}

func resetFakes() {
Expand All @@ -2995,7 +2995,7 @@ func resetFakes() {
kick.BaseURL = savedKickURL
pastebin.FetchPasteURLOverride = ""
seventv.BaseURL = saved7TVURL
twitch.Conn = twitch.NewForTesting(helix.DefaultAPIBaseURL, nil)
twitch.SetInstance(twitch.NewForTesting(helix.DefaultAPIBaseURL, nil))
}

func joinOtherUser1() error {
Expand Down
6 changes: 4 additions & 2 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package database

import (
"context"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -36,10 +37,10 @@ var (
)

// Connect creates a connection to the database.
func Connect(dbname, user, password string) (*gorm.DB, error) {
func Connect(ctx context.Context, dbName, user, password string) (*gorm.DB, error) {
settings := map[string]string{
"host": "database",
"dbname": dbname,
"dbname": dbName,
"user": user,
"password": password,
"port": "5432",
Expand All @@ -51,6 +52,7 @@ func Connect(dbname, user, password string) (*gorm.DB, error) {
if err != nil {
return nil, fmt.Errorf("failed to open DB connection: %w", err)
}
gormDB.WithContext(ctx)

db, err := gormDB.DB()
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions docs/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ If it's wrapped in `[square brackets]`, it's an **optional** parameter.
- > Usage: `$reloadconfig`
- > Minimum permission level: `Admin`
### $restart

- Restarts the bot. Does not restart the database, etc.
- > Usage: `$restart`
- > Minimum permission level: `Admin`
### $setprefix

- Sets the bot's prefix in the channel.
Expand Down
15 changes: 12 additions & 3 deletions gamba/gamba.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package gamba

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -23,10 +24,18 @@ var (

// StartGrantingPoints starts a loop to grant points to all chatters on an interval.
// This function blocks and should be run within a goroutine.
func StartGrantingPoints(ps map[string]base.Platform, db *gorm.DB) {
func StartGrantingPoints(ctx context.Context, ps map[string]base.Platform, db *gorm.DB) {
timer := time.NewTicker(grantInterval)
for {
go grantPoints(ps, db)
time.Sleep(grantInterval)
select {
case <-ctx.Done():
log.Print("Stopping point granting, context cancelled")
return
case <-timer.C:
go grantPoints(ps, db)
default:
}
timer.Reset(grantInterval)
}
}

Expand Down
125 changes: 81 additions & 44 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,61 @@
package main

import (
"context"
"fmt"
"log"
"os"
"os/signal"
"sync"
"time"

"github.com/airforce270/airbot/apiclients/supinic"
"github.com/airforce270/airbot/base"
"github.com/airforce270/airbot/cache"
"github.com/airforce270/airbot/config"
"github.com/airforce270/airbot/database"
"github.com/airforce270/airbot/gamba"
"github.com/airforce270/airbot/platforms"
"github.com/airforce270/airbot/utils/cleanup"
"github.com/airforce270/airbot/utils/restart"
)

// cleanupFunc is a function that should be called before program exit.
type cleanupFunc struct {
// name is the function's human-readable name.
name string
// f is the function to be called.
f func() error
}

// cleanupFuncs contains functions to be called to cleanup before program exit.
var cleanupFuncs []cleanupFunc
const (
waitForCancelFuncs = 100 * time.Millisecond
waitForContextCancellation = 100 * time.Millisecond
)

// startListeningForSigterm starts a goroutine that listens for SIGTERM (ctrl-c)
// and runs the cleanup functions before exiting.
func startListeningForSigterm() {
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt)
go func() {
<-c
for _, f := range cleanupFuncs {
if err := f.f(); err != nil {
log.Printf("cleanup function %s failed: %v", f.name, err)
}
}
os.Exit(1)
}()
func initialStart(ctx context.Context) (cleanup.Cleaner, error) {
cleaner, _, err := start(ctx)
return cleaner, err
}

// wait blocks the thread that calls it indefinitely.
func wait() {
wg := sync.WaitGroup{}
wg.Add(1)
wg.Wait()
func reStart(ctx context.Context) (cleanup.Cleaner, error) {
cleaner, ps, err := start(ctx)
if err != nil {
return nil, err
}
if err := restart.Notify(ps); err != nil {
log.Printf("Failed to notify restart: %v", err)
}
return cleaner, err
}

func main() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
startListeningForSigterm()
func start(ctx context.Context) (cleanup.Cleaner, map[string]base.Platform, error) {
cleaner := cleanup.NewCleaner()

log.Print("Reading config...")
cfg, err := config.Read()
if err != nil {
log.Fatalf("failed to read config: %v", err)
return nil, nil, fmt.Errorf("failed to read config: %v", err)
}

log.Print("Setting config values...")
config.StoreGlobals(cfg)

log.Printf("Connecting to database...")
db, err := database.Connect(os.Getenv("POSTGRES_DB"), os.Getenv("POSTGRES_USER"), os.Getenv("POSTGRES_PASSWORD"))
db, err := database.Connect(ctx, os.Getenv("POSTGRES_DB"), os.Getenv("POSTGRES_USER"), os.Getenv("POSTGRES_PASSWORD"))
if err != nil {
log.Fatalf("failed to connect to database: %v", err)
return nil, nil, fmt.Errorf("failed to connect to database: %v", err)
}
database.SetInstance(db)

Expand All @@ -75,34 +66,80 @@ func main() {

log.Printf("Performing database migrations...")
if err = database.Migrate(db); err != nil {
log.Fatalf("failed to perform database migrations: %v", err)
return nil, nil, fmt.Errorf("failed to perform database migrations: %v", err)
}

log.Printf("Preparing chat connections...")
ps, err := platforms.Build(cfg, db, &cdb)
if err != nil {
log.Fatalf("Failed to build platforms: %v", err)
return nil, nil, fmt.Errorf("failed to build platforms: %v", err)
}

for _, p := range ps {
log.Printf("Connecting to %s...", p.Name())
if err := p.Connect(); err != nil {
log.Fatalf("Failed to connect to %s: %v", p.Name(), err)
return cleaner, nil, fmt.Errorf("failed to connect to %s: %v", p.Name(), err)
}

log.Printf("Starting to handle messages on %s...", p.Name())
go platforms.StartHandling(p, db, &cdb, cfg.LogIncoming, cfg.LogOutgoing)
cleanupFuncs = append(cleanupFuncs, cleanupFunc{name: p.Name(), f: p.Disconnect})
go platforms.StartHandling(ctx, p, db, &cdb, cfg.LogIncoming, cfg.LogOutgoing)
cleaner.Register(cleanup.Func{Name: p.Name(), F: p.Disconnect})
}

go gamba.StartGrantingPoints(ps, db)
go gamba.StartGrantingPoints(ctx, ps, db)

if cfg.Supinic.IsConfigured() && cfg.Supinic.ShouldPingAPI {
log.Println("Starting to ping the Supinic API...")
supinicClient := supinic.NewClient(cfg.Supinic.UserID, cfg.Supinic.APIKey)
go supinicClient.StartPinging()
go supinicClient.StartPinging(ctx)
}

return cleaner, ps, nil
}

func main() {
ctx, cancel := context.WithCancel(context.Background())
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)

c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt)
go func() {
<-c
cancel()
time.Sleep(waitForContextCancellation)
os.Exit(1)
}()

cleaner, err := initialStart(ctx)
if err != nil {
log.Fatalf("Failed to start: %v", err)
}
log.Printf("Airbot is now running.")
wait()

for {
select {
case <-restart.C:
log.Printf("Restarting...")

if err := cleaner.Cleanup(); err != nil {
log.Printf("Cleanup failed: %v", err)
}
time.Sleep(waitForCancelFuncs)
cancel()
time.Sleep(waitForContextCancellation)

ctx, cancel = context.WithCancel(context.Background())

cleaner, err = reStart(ctx)
if err != nil {
log.Fatalf("Failed to start: %v", err)
}
log.Printf("Airbot is now running (restarted).")
case <-ctx.Done():
log.Printf("Context cancelled, Airbot shutting down.")
return
}
}
}

// send message that says "Restarted" once bot is restarted
Loading

0 comments on commit bf261cd

Please sign in to comment.