Skip to content

Commit

Permalink
feat: use context
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed May 2, 2023
1 parent cbc155d commit 3d7eb7b
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 52 deletions.
5 changes: 3 additions & 2 deletions cmd/soft/migrate_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ var (
migrateConfig = &cobra.Command{
Use: "migrate-config",
Short: "Migrate config to new format",
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
ctx := cmd.Context()
cfg := config.DefaultConfig()
sb, err := sqlite.NewSqliteBackend(cfg)
sb, err := sqlite.NewSqliteBackend(ctx, cfg)
if err != nil {
return fmt.Errorf("failed to create sqlite backend: %w", err)
}
Expand Down
14 changes: 13 additions & 1 deletion cmd/soft/root.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package main

import (
"context"
"os"
"runtime/debug"

"github.com/charmbracelet/log"
_ "github.com/charmbracelet/soft-serve/log"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -49,7 +51,17 @@ func init() {
}

func main() {
if err := rootCmd.Execute(); err != nil {
logger := log.NewWithOptions(os.Stderr, log.Options{
ReportTimestamp: true,
TimeFormat: "2006-01-02",
})
if os.Getenv("SOFT_SERVE_DEBUG") == "true" {
logger.SetLevel(log.DebugLevel)
}

ctx := context.Background()
ctx = log.WithContext(ctx, logger)
if err := rootCmd.ExecuteContext(ctx); err != nil {
os.Exit(1)
}
}
7 changes: 4 additions & 3 deletions cmd/soft/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"syscall"
"time"

_ "github.com/charmbracelet/soft-serve/log"
"github.com/charmbracelet/soft-serve/server"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/spf13/cobra"
Expand All @@ -19,19 +20,19 @@ var (
Long: "Start the server",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
cfg := config.DefaultConfig()
s, err := server.NewServer(cfg)
s, err := server.NewServer(ctx, cfg)
if err != nil {
return err
}

ctx := cmd.Context()
done := make(chan os.Signal, 1)
lch := make(chan error, 1)
go func() {
defer close(lch)
defer close(done)
lch <- s.Start(ctx)
lch <- s.Start()
}()

signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
Expand Down
5 changes: 3 additions & 2 deletions examples/setuid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ func main() {
if err := syscall.Setuid(*uid); err != nil {
log.Fatal("Setuid error", "err", err)
}
ctx := context.Background()
cfg := config.DefaultConfig()
cfg.SSH.ListenAddr = fmt.Sprintf(":%d", *port)
s, err := server.NewServer(cfg)
s, err := server.NewServer(ctx, cfg)
if err != nil {
log.Fatal(err)
}
Expand All @@ -64,7 +65,7 @@ func main() {
<-done

log.Print("Stopping SSH server", "addr", cfg.SSH.ListenAddr)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer func() { cancel() }()
if err := s.Shutdown(ctx); err != nil {
log.Fatal(err)
Expand Down
48 changes: 25 additions & 23 deletions server/backend/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var (
// backend.
type SqliteBackend struct {
cfg *config.Config
ctx context.Context
dp string
db *sqlx.DB
}
Expand All @@ -37,7 +38,7 @@ func (d *SqliteBackend) reposPath() string {
}

// NewSqliteBackend creates a new SqliteBackend.
func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
func NewSqliteBackend(ctx context.Context, cfg *config.Config) (*SqliteBackend, error) {
dataPath := cfg.DataPath
if err := os.MkdirAll(dataPath, 0755); err != nil {
return nil, err
Expand All @@ -51,6 +52,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {

d := &SqliteBackend{
cfg: cfg,
ctx: ctx,
dp: dataPath,
db: db,
}
Expand All @@ -71,7 +73,7 @@ func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
// It implements backend.Backend.
func (d *SqliteBackend) AllowKeyless() bool {
var allow bool
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless")
}); err != nil {
return false
Expand All @@ -85,7 +87,7 @@ func (d *SqliteBackend) AllowKeyless() bool {
// It implements backend.Backend.
func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
var level string
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access")
}); err != nil {
return backend.NoAccess
Expand All @@ -99,7 +101,7 @@ func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
// It implements backend.Backend.
func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
return wrapDbErr(
wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless")
return err
}),
Expand All @@ -111,7 +113,7 @@ func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
// It implements backend.Backend.
func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error {
return wrapDbErr(
wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access")
return err
}),
Expand Down Expand Up @@ -147,7 +149,7 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
return nil, err
}

if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at)
VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`,
name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden)
Expand Down Expand Up @@ -210,7 +212,7 @@ func (d *SqliteBackend) DeleteRepository(name string) error {
return os.ErrNotExist
}

if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name)
return err
}); err != nil {
Expand Down Expand Up @@ -245,7 +247,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
return fmt.Errorf("repository %s already exists", newName)
}

if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName)
return err
}); err != nil {
Expand All @@ -260,7 +262,7 @@ func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
// It implements backend.Backend.
func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
repos := make([]backend.Repository, 0)
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
rows, err := tx.Query("SELECT name FROM repo")
if err != nil {
return err
Expand Down Expand Up @@ -299,7 +301,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
}

var count int
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo)
}); err != nil {
return nil, wrapDbErr(err)
Expand All @@ -323,7 +325,7 @@ func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
func (d *SqliteBackend) Description(repo string) (string, error) {
repo = utils.SanitizeRepo(repo)
var desc string
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", repo)
}); err != nil {
return "", wrapDbErr(err)
Expand All @@ -338,7 +340,7 @@ func (d *SqliteBackend) Description(repo string) (string, error) {
func (d *SqliteBackend) IsMirror(repo string) (bool, error) {
repo = utils.SanitizeRepo(repo)
var mirror bool
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", repo)
}); err != nil {
return false, wrapDbErr(err)
Expand All @@ -353,7 +355,7 @@ func (d *SqliteBackend) IsMirror(repo string) (bool, error) {
func (d *SqliteBackend) IsPrivate(repo string) (bool, error) {
repo = utils.SanitizeRepo(repo)
var private bool
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", repo)
}); err != nil {
return false, wrapDbErr(err)
Expand All @@ -368,7 +370,7 @@ func (d *SqliteBackend) IsPrivate(repo string) (bool, error) {
func (d *SqliteBackend) IsHidden(repo string) (bool, error) {
repo = utils.SanitizeRepo(repo)
var hidden bool
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", repo)
}); err != nil {
return false, wrapDbErr(err)
Expand All @@ -382,7 +384,7 @@ func (d *SqliteBackend) IsHidden(repo string) (bool, error) {
// It implements backend.Backend.
func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
repo = utils.SanitizeRepo(repo)
return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE repo SET hidden = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", hidden, repo)
return err
}))
Expand All @@ -394,7 +396,7 @@ func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
func (d *SqliteBackend) ProjectName(repo string) (string, error) {
repo = utils.SanitizeRepo(repo)
var name string
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", repo)
}); err != nil {
return "", wrapDbErr(err)
Expand All @@ -408,7 +410,7 @@ func (d *SqliteBackend) ProjectName(repo string) (string, error) {
// It implements backend.Backend.
func (d *SqliteBackend) SetDescription(repo string, desc string) error {
repo = utils.SanitizeRepo(repo)
return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE repo SET description = ? WHERE name = ?", desc, repo)
return err
})
Expand All @@ -420,7 +422,7 @@ func (d *SqliteBackend) SetDescription(repo string, desc string) error {
func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
repo = utils.SanitizeRepo(repo)
return wrapDbErr(
wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE repo SET private = ? WHERE name = ?", private, repo)
return err
}),
Expand All @@ -433,7 +435,7 @@ func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
func (d *SqliteBackend) SetProjectName(repo string, name string) error {
repo = utils.SanitizeRepo(repo)
return wrapDbErr(
wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec("UPDATE repo SET project_name = ? WHERE name = ?", name, repo)
return err
}),
Expand All @@ -450,7 +452,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
}

repo = utils.SanitizeRepo(repo)
return wrapDbErr(wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at)
VALUES (
(SELECT id FROM user WHERE username = ?),
Expand All @@ -468,7 +470,7 @@ func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
repo = utils.SanitizeRepo(repo)
var users []string
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Select(&users, `SELECT name FROM user
INNER JOIN collab ON user.id = collab.user_id
INNER JOIN repo ON repo.id = collab.repo_id
Expand All @@ -486,7 +488,7 @@ func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) {
repo = utils.SanitizeRepo(repo)
var count int
if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
return tx.Get(&count, `SELECT COUNT(*) FROM user
INNER JOIN collab ON user.id = collab.user_id
INNER JOIN repo ON repo.id = collab.repo_id
Expand All @@ -504,7 +506,7 @@ func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, erro
func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error {
repo = utils.SanitizeRepo(repo)
return wrapDbErr(
wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`DELETE FROM collab
WHERE user_id = (SELECT id FROM user WHERE username = ?)
AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo)
Expand Down
4 changes: 0 additions & 4 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ import (
"gopkg.in/yaml.v3"
)

var (
logger = log.WithPrefix("server.config")
)

// SSHConfig is the configuration for the SSH server.
type SSHConfig struct {
// ListenAddr is the address on which the SSH server will listen.
Expand Down
4 changes: 2 additions & 2 deletions server/cron/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func (l cronLogger) Error(err error, msg string, keysAndValues ...interface{}) {
}

// NewCronScheduler returns a new Cron.
func NewCronScheduler() *CronScheduler {
logger := cronLogger{log.WithPrefix("server.cron")}
func NewCronScheduler(ctx context.Context) *CronScheduler {
logger := cronLogger{log.FromContext(ctx).WithPrefix("server.cron")}
return &CronScheduler{
Cron: cron.New(cron.WithLogger(logger)),
}
Expand Down
4 changes: 3 additions & 1 deletion server/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -34,7 +35,8 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatal(err)
}
fb, err := sqlite.NewSqliteBackend(cfg)
ctx := context.TODO()
fb, err := sqlite.NewSqliteBackend(ctx, cfg)
if err != nil {
log.Fatal(err)
}
Expand Down
Loading

0 comments on commit 3d7eb7b

Please sign in to comment.