diff --git a/cmd/soft/root.go b/cmd/soft/root.go index 273c9d6a3..f78c22606 100644 --- a/cmd/soft/root.go +++ b/cmd/soft/root.go @@ -54,6 +54,9 @@ func init() { func main() { logger := NewDefaultLogger() + // Set global logger + log.SetDefault(logger) + // Set the max number of processes to the number of CPUs // This is useful when running soft serve in a container if _, err := maxprocs.Set(maxprocs.Logger(logger.Debugf)); err != nil { diff --git a/git/repo.go b/git/repo.go index 98b81b24c..8c0b87c43 100644 --- a/git/repo.go +++ b/git/repo.go @@ -200,13 +200,6 @@ func (r *Repository) CommitsByPage(ref *Reference, page, size int) (Commits, err return commits, nil } -// UpdateServerInfo updates the repository server info. -func (r *Repository) UpdateServerInfo() error { - cmd := git.NewCommand("update-server-info") - _, err := cmd.RunInDir(r.Path) - return err -} - // Config returns the config value for the given key. func (r *Repository) Config(key string, opts ...ConfigOptions) (string, error) { dir, err := gitDir(r.Repository) diff --git a/git/server.go b/git/server.go new file mode 100644 index 000000000..e868b1a29 --- /dev/null +++ b/git/server.go @@ -0,0 +1,18 @@ +package git + +import ( + "context" + + "github.com/gogs/git-module" +) + +// UpdateServerInfo updates the server info file for the given repo path. +func UpdateServerInfo(ctx context.Context, path string) error { + if !isGitDir(path) { + return ErrNotAGitRepository + } + + cmd := git.NewCommand("update-server-info").WithContext(ctx).WithTimeout(-1) + _, err := cmd.RunInDir(path) + return err +} diff --git a/git/utils.go b/git/utils.go index 2b3d28728..3710e172d 100644 --- a/git/utils.go +++ b/git/utils.go @@ -1,6 +1,7 @@ package git import ( + "os" "path/filepath" "github.com/gobwas/glob" @@ -49,3 +50,25 @@ func LatestFile(repo *Repository, pattern string) (string, string, error) { } return "", "", ErrFileNotFound } + +// Returns true if path is a directory containing an `objects` directory and a +// `HEAD` file. +func isGitDir(path string) bool { + stat, err := os.Stat(filepath.Join(path, "objects")) + if err != nil { + return false + } + if !stat.IsDir() { + return false + } + + stat, err = os.Stat(filepath.Join(path, "HEAD")) + if err != nil { + return false + } + if stat.IsDir() { + return false + } + + return true +} diff --git a/internal/log/log.go b/internal/log/log.go index 7389b80fb..b6c4b1443 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -32,6 +32,10 @@ func NewDefaultLogger() *log.Logger { if debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")); debug { logger.SetLevel(log.DebugLevel) + + if verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")); verbose { + logger.SetReportCaller(true) + } } logger.SetTimeFormat(cfg.Log.TimeFormat) diff --git a/server/backend/sqlite/hooks.go b/server/backend/sqlite/hooks.go index ff39046f4..972b3f31d 100644 --- a/server/backend/sqlite/hooks.go +++ b/server/backend/sqlite/hooks.go @@ -36,16 +36,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri var wg sync.WaitGroup - // Update server info - wg.Add(1) - go func() { - defer wg.Done() - if err := updateServerInfo(d, repo); err != nil { - d.logger.Error("error updating server-info", "repo", repo, "err", err) - return - } - }() - // Populate last-modified file. wg.Add(1) go func() { @@ -59,20 +49,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri wg.Wait() } -func updateServerInfo(d *SqliteBackend, repo string) error { - rr, err := d.Repository(repo) - if err != nil { - return err - } - - r, err := rr.Open() - if err != nil { - return err - } - - return r.UpdateServerInfo() -} - func populateLastModified(d *SqliteBackend, repo string) error { var rr *Repo _rr, err := d.Repository(repo) diff --git a/server/backend/sqlite/sqlite.go b/server/backend/sqlite/sqlite.go index 91527e3fc..3273373ea 100644 --- a/server/backend/sqlite/sqlite.go +++ b/server/backend/sqlite/sqlite.go @@ -151,17 +151,12 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt return err } - rr, err := git.Init(rp, true) + _, err := git.Init(rp, true) if err != nil { d.logger.Debug("failed to create repository", "err", err) return err } - if err := rr.UpdateServerInfo(); err != nil { - d.logger.Debug("failed to update server info", "err", err) - return err - } - return nil }); err != nil { d.logger.Debug("failed to create repository in database", "err", err) diff --git a/server/config/config.go b/server/config/config.go index 2dca4fada..c8b9fa9bf 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -114,6 +114,40 @@ type Config struct { Backend backend.Backend `yaml:"-"` } +// Environ returns the config as a list of environment variables. +func (c *Config) Environ() []string { + envs := []string{} + if c == nil { + return envs + } + + // TODO: do this dynamically + envs = append(envs, []string{ + fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name), + fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath), + fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")), + fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr), + fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL), + fmt.Sprintf("SOFT_SERVE_SSH_KEY_PATH=%s", c.SSH.KeyPath), + fmt.Sprintf("SOFT_SERVE_SSH_CLIENT_KEY_PATH=%s", c.SSH.ClientKeyPath), + fmt.Sprintf("SOFT_SERVE_SSH_MAX_TIMEOUT=%d", c.SSH.MaxTimeout), + fmt.Sprintf("SOFT_SERVE_SSH_IDLE_TIMEOUT=%d", c.SSH.IdleTimeout), + fmt.Sprintf("SOFT_SERVE_GIT_LISTEN_ADDR=%s", c.Git.ListenAddr), + fmt.Sprintf("SOFT_SERVE_GIT_MAX_TIMEOUT=%d", c.Git.MaxTimeout), + fmt.Sprintf("SOFT_SERVE_GIT_IDLE_TIMEOUT=%d", c.Git.IdleTimeout), + fmt.Sprintf("SOFT_SERVE_GIT_MAX_CONNECTIONS=%d", c.Git.MaxConnections), + fmt.Sprintf("SOFT_SERVE_HTTP_LISTEN_ADDR=%s", c.HTTP.ListenAddr), + fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath), + fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath), + fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL), + fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr), + fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format), + fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat), + }...) + + return envs +} + func parseConfig(path string) (*Config, error) { dataPath := filepath.Dir(path) cfg := &Config{ diff --git a/server/daemon/conn.go b/server/daemon/conn.go new file mode 100644 index 000000000..090d76aee --- /dev/null +++ b/server/daemon/conn.go @@ -0,0 +1,105 @@ +package daemon + +import ( + "context" + "errors" + "net" + "sync" + "time" +) + +// connections is a synchronizes access to to a net.Conn pool. +type connections struct { + m map[net.Conn]struct{} + mu sync.Mutex +} + +func (m *connections) Add(c net.Conn) { + m.mu.Lock() + defer m.mu.Unlock() + m.m[c] = struct{}{} +} + +func (m *connections) Close(c net.Conn) error { + m.mu.Lock() + defer m.mu.Unlock() + err := c.Close() + delete(m.m, c) + return err +} + +func (m *connections) Size() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.m) +} + +func (m *connections) CloseAll() error { + m.mu.Lock() + defer m.mu.Unlock() + var err error + for c := range m.m { + err = errors.Join(err, c.Close()) + delete(m.m, c) + } + + return err +} + +// serverConn is a wrapper around a net.Conn that closes the connection when +// the one of the timeouts is reached. +type serverConn struct { + net.Conn + + initTimeout time.Duration + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +var _ net.Conn = (*serverConn)(nil) + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + switch { + case c.initTimeout > 0: + initTimeout := time.Now().Add(c.initTimeout) + c.initTimeout = 0 + if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(initTimeout) + return + } + case c.idleTimeout > 0: + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(idleDeadline) + return + } + } + c.Conn.SetDeadline(c.maxDeadline) +} diff --git a/server/daemon/daemon.go b/server/daemon/daemon.go index 944f0eedb..1820c0e3e 100644 --- a/server/daemon/daemon.go +++ b/server/daemon/daemon.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "path/filepath" + "strings" "sync" "time" @@ -41,40 +42,6 @@ var ( ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed) ) -// connections synchronizes access to to a net.Conn pool. -type connections struct { - m map[net.Conn]struct{} - mu sync.Mutex -} - -func (m *connections) Add(c net.Conn) { - m.mu.Lock() - defer m.mu.Unlock() - m.m[c] = struct{}{} -} - -func (m *connections) Close(c net.Conn) { - m.mu.Lock() - defer m.mu.Unlock() - _ = c.Close() - delete(m.m, c) -} - -func (m *connections) Size() int { - m.mu.Lock() - defer m.mu.Unlock() - return len(m.m) -} - -func (m *connections) CloseAll() { - m.mu.Lock() - defer m.mu.Unlock() - for c := range m.m { - _ = c.Close() - delete(m.m, c) - } -} - // GitDaemon represents a Git daemon. type GitDaemon struct { ctx context.Context @@ -213,26 +180,53 @@ func (d *GitDaemon) handleClient(conn net.Conn) { return } - gitPack := git.UploadPack - counter := uploadPackGitCounter - cmd := string(split[0]) - switch cmd { - case git.UploadPackBin: - gitPack = git.UploadPack - case git.UploadArchiveBin: - gitPack = git.UploadArchive + var handler git.ServiceHandler + var counter *prometheus.CounterVec + service := git.Service(split[0]) + switch service { + case git.UploadPackService: + handler = git.UploadPack + counter = uploadPackGitCounter + case git.UploadArchiveService: + handler = git.UploadArchive counter = uploadArchiveGitCounter default: d.fatal(c, git.ErrInvalidRequest) return } - opts := bytes.Split(split[1], []byte{'\x00'}) - if len(opts) == 0 { - d.fatal(c, git.ErrInvalidRequest) + opts := bytes.SplitN(split[1], []byte{0}, 3) + if len(opts) < 2 { + d.fatal(c, git.ErrInvalidRequest) // nolint: errcheck return } + host := strings.TrimPrefix(string(opts[1]), "host=") + extraParams := map[string]string{} + + if len(opts) > 2 { + buf := bytes.TrimPrefix(opts[2], []byte{0}) + for _, o := range bytes.Split(buf, []byte{0}) { + opt := string(o) + if opt == "" { + continue + } + + kv := strings.SplitN(opt, "=", 2) + if len(kv) != 2 { + d.logger.Errorf("git: invalid option %q", opt) + continue + } + + extraParams[kv[0]] = kv[1] + } + + version := extraParams["version"] + if version != "" { + d.logger.Debugf("git: protocol version %s", version) + } + } + be := d.be.WithContext(ctx) if !be.AllowKeyless() { d.fatal(c, git.ErrNotAuthed) @@ -240,14 +234,21 @@ func (d *GitDaemon) handleClient(conn net.Conn) { } name := utils.SanitizeRepo(string(opts[0])) - d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), cmd, name) - defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, name) + d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), service, name) + defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), service, name) + // git bare repositories should end in ".git" // https://git-scm.com/docs/gitrepository-layout repo := name + ".git" reposDir := filepath.Join(d.cfg.DataPath, "repos") if err := git.EnsureWithin(reposDir, repo); err != nil { - d.fatal(c, err) + d.logger.Debugf("git: error ensuring repo path: %v", err) + d.fatal(c, git.ErrInvalidRepo) + return + } + + if _, err := d.be.Repository(repo); err != nil { + d.fatal(c, git.ErrInvalidRepo) return } @@ -261,9 +262,33 @@ func (d *GitDaemon) handleClient(conn net.Conn) { envs := []string{ "SOFT_SERVE_REPO_NAME=" + name, "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), + "SOFT_SERVE_HOST=" + host, } - if err := gitPack(ctx, c, c, c, filepath.Join(reposDir, repo), envs...); err != nil { + // Add git protocol environment variable. + if len(extraParams) > 0 { + var gitProto string + for k, v := range extraParams { + if len(gitProto) > 0 { + gitProto += ":" + } + gitProto += k + "=" + v + } + envs = append(envs, "GIT_PROTOCOL="+gitProto) + } + + envs = append(envs, d.cfg.Environ()...) + + cmd := git.ServiceCommand{ + Stdin: c, + Stdout: c, + Stderr: c, + Env: envs, + Dir: filepath.Join(reposDir, repo), + } + + if err := handler(ctx, cmd); err != nil { + d.logger.Debugf("git: error handling request: %v", err) d.fatal(c, err) return } @@ -296,51 +321,3 @@ func (d *GitDaemon) Shutdown(ctx context.Context) error { return err } } - -type serverConn struct { - net.Conn - - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Write(p) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Read(b) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Close() (err error) { - err = c.Conn.Close() - if c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: - idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return - } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) - } -} diff --git a/server/git/git.go b/server/git/git.go index ef8affe20..8f8ae3d7c 100644 --- a/server/git/git.go +++ b/server/git/git.go @@ -5,16 +5,12 @@ import ( "errors" "fmt" "io" - "os" - "os/exec" "path/filepath" "strings" "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/config" "github.com/go-git/go-git/v5/plumbing/format/pktline" - "golang.org/x/sync/errgroup" ) var ( @@ -38,112 +34,6 @@ var ( ErrTimeout = errors.New("I/O timeout reached") ) -// Git protocol commands. -const ( - ReceivePackBin = "git-receive-pack" - UploadPackBin = "git-upload-pack" - UploadArchiveBin = "git-upload-archive" -) - -// UploadPack runs the git upload-pack protocol against the provided repo. -func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error { - exists, err := fileExists(repoDir) - if !exists { - return ErrInvalidRepo - } - if err != nil { - return err - } - return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir) -} - -// UploadArchive runs the git upload-archive protocol against the provided repo. -func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error { - exists, err := fileExists(repoDir) - if !exists { - return ErrInvalidRepo - } - if err != nil { - return err - } - return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir) -} - -// ReceivePack runs the git receive-pack protocol against the provided repo. -func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error { - if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil { - return err - } - return EnsureDefaultBranch(ctx, in, out, er, repoDir) -} - -// RunGit runs a git command in the given repo. -func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error { - cfg := config.FromContext(ctx) - logger := log.FromContext(ctx).WithPrefix("rungit") - c := exec.CommandContext(ctx, "git", args...) - c.Dir = dir - c.Env = append(os.Environ(), envs...) - c.Env = append(c.Env, "PATH="+os.Getenv("PATH")) - c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG")) - if cfg != nil { - c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.Log.Format) - c.Env = append(c.Env, "SOFT_SERVE_LOG_TIME_FORMAT="+cfg.Log.TimeFormat) - } - - stdin, err := c.StdinPipe() - if err != nil { - logger.Error("failed to get stdin pipe", "err", err) - return err - } - - stdout, err := c.StdoutPipe() - if err != nil { - logger.Error("failed to get stdout pipe", "err", err) - return err - } - - stderr, err := c.StderrPipe() - if err != nil { - logger.Error("failed to get stderr pipe", "err", err) - return err - } - - if err := c.Start(); err != nil { - logger.Error("failed to start command", "err", err) - return err - } - - errg, ctx := errgroup.WithContext(ctx) - - // stdin - errg.Go(func() error { - defer stdin.Close() - - _, err := io.Copy(stdin, in) - return err - }) - - // stdout - errg.Go(func() error { - _, err := io.Copy(out, stdout) - return err - }) - - // stderr - errg.Go(func() error { - _, err := io.Copy(er, stderr) - return err - }) - - if err := errg.Wait(); err != nil { - logger.Error("while copying output", "err", err) - } - - // Wait for the command to finish - return c.Wait() -} - // WritePktline encodes and writes a pktline to the given writer. func WritePktline(w io.Writer, v ...interface{}) { msg := fmt.Sprintln(v...) @@ -179,19 +69,10 @@ func EnsureWithin(reposDir string, repo string) error { return nil } -func fileExists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return true, err -} - -func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error { - r, err := git.Open(repoPath) +// EnsureDefaultBranch ensures the repo has a default branch. +// It will prefer choosing "main" or "master" if available. +func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error { + r, err := git.Open(scmd.Dir) if err != nil { return err } @@ -205,8 +86,21 @@ func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io // Rename the default branch to the first branch available _, err = r.HEAD() if err == git.ErrReferenceNotExist { - err = RunGit(ctx, in, out, er, repoPath, []string{}, "branch", "-M", brs[0]) - if err != nil { + branch := brs[0] + // Prefer "main" or "master" as the default branch + for _, b := range brs { + if b == "main" || b == "master" { + branch = b + break + } + } + + cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx) + if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{ + Stdin: scmd.Stdin, + Stdout: scmd.Stdout, + Stderr: scmd.Stderr, + }); err != nil { return err } } diff --git a/server/git/service.go b/server/git/service.go new file mode 100644 index 000000000..073001840 --- /dev/null +++ b/server/git/service.go @@ -0,0 +1,186 @@ +package git + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "strings" + + "github.com/charmbracelet/log" + "golang.org/x/sync/errgroup" +) + +// Service is a Git daemon service. +type Service string + +const ( + // UploadPackService is the upload-pack service. + UploadPackService Service = "git-upload-pack" + // UploadArchiveService is the upload-archive service. + UploadArchiveService Service = "git-upload-archive" + // ReceivePackService is the receive-pack service. + ReceivePackService Service = "git-receive-pack" +) + +// String returns the string representation of the service. +func (s Service) String() string { + return string(s) +} + +// Name returns the name of the service. +func (s Service) Name() string { + return strings.TrimPrefix(s.String(), "git-") +} + +// Handler is the service handler. +func (s Service) Handler(ctx context.Context, cmd ServiceCommand) error { + switch s { + case UploadPackService, UploadArchiveService, ReceivePackService: + return gitServiceHandler(ctx, s, cmd) + default: + return fmt.Errorf("unsupported service: %s", s) + } +} + +// ServiceHandler is a git service command handler. +type ServiceHandler func(ctx context.Context, cmd ServiceCommand) error + +// gitServiceHandler is the default service handler using the git binary. +func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) error { + cmd := exec.CommandContext(ctx, "git", "-c", "uploadpack.allowFilter=true", svc.Name()) // nolint: gosec + cmd.Dir = scmd.Dir + if len(scmd.Args) > 0 { + cmd.Args = append(cmd.Args, scmd.Args...) + } + + cmd.Args = append(cmd.Args, ".") + + cmd.Env = os.Environ() + if len(scmd.Env) > 0 { + cmd.Env = append(cmd.Env, scmd.Env...) + } + + if scmd.CmdFunc != nil { + scmd.CmdFunc(cmd) + } + + var ( + err error + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + ) + + if scmd.Stdin != nil { + stdin, err = cmd.StdinPipe() + if err != nil { + return err + } + } + + if scmd.Stdout != nil { + stdout, err = cmd.StdoutPipe() + if err != nil { + return err + } + } + + if scmd.Stderr != nil { + stderr, err = cmd.StderrPipe() + if err != nil { + return err + } + } + + log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String()) + if err := cmd.Start(); err != nil { + return err + } + + errg, ctx := errgroup.WithContext(ctx) + + // stdin + if scmd.Stdin != nil { + errg.Go(func() error { + if scmd.StdinHandler != nil { + return scmd.StdinHandler(scmd.Stdin, stdin) + } else { + return defaultStdinHandler(scmd.Stdin, stdin) + } + }) + } + + // stdout + if scmd.Stdout != nil { + errg.Go(func() error { + if scmd.StdoutHandler != nil { + return scmd.StdoutHandler(scmd.Stdout, stdout) + } else { + return defaultStdoutHandler(scmd.Stdout, stdout) + } + }) + } + + // stderr + if scmd.Stderr != nil { + errg.Go(func() error { + if scmd.StderrHandler != nil { + return scmd.StderrHandler(scmd.Stderr, stderr) + } else { + return defaultStderrHandler(scmd.Stderr, stderr) + } + }) + } + + return errors.Join(errg.Wait(), cmd.Wait()) +} + +// ServiceCommand is used to run a git service command. +type ServiceCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer + Dir string + Env []string + Args []string + + // Modifier functions + CmdFunc func(*exec.Cmd) + StdinHandler func(io.Reader, io.WriteCloser) error + StdoutHandler func(io.Writer, io.ReadCloser) error + StderrHandler func(io.Writer, io.ReadCloser) error +} + +func defaultStdinHandler(in io.Reader, stdin io.WriteCloser) error { + defer stdin.Close() // nolint: errcheck + _, err := io.Copy(stdin, in) + return err +} + +func defaultStdoutHandler(out io.Writer, stdout io.ReadCloser) error { + _, err := io.Copy(out, stdout) + return err +} + +func defaultStderrHandler(err io.Writer, stderr io.ReadCloser) error { + _, erro := io.Copy(err, stderr) + return erro +} + +// UploadPack runs the git upload-pack protocol against the provided repo. +func UploadPack(ctx context.Context, cmd ServiceCommand) error { + return gitServiceHandler(ctx, UploadPackService, cmd) +} + +// UploadArchive runs the git upload-archive protocol against the provided repo. +func UploadArchive(ctx context.Context, cmd ServiceCommand) error { + return gitServiceHandler(ctx, UploadArchiveService, cmd) +} + +// ReceivePack runs the git receive-pack protocol against the provided repo. +func ReceivePack(ctx context.Context, cmd ServiceCommand) error { + return gitServiceHandler(ctx, ReceivePackService, cmd) +} diff --git a/server/ssh/ssh.go b/server/ssh/ssh.go index 42cd7f5f5..8e98d0ebc 100644 --- a/server/ssh/ssh.go +++ b/server/ssh/ssh.go @@ -216,13 +216,13 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware { return func(s ssh.Session) { func() { start := time.Now() - cmd := s.Command() + cmdLine := s.Command() ctx := s.Context() be := ss.be.WithContext(ctx) - if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") { - gc := cmd[0] + + if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") { // repo should be in the form of "repo.git" - name := utils.SanitizeRepo(cmd[1]) + name := utils.SanitizeRepo(cmdLine[1]) pk := s.PublicKey() ak := backend.MarshalAuthorizedKey(pk) access := cfg.Backend.AccessLevelByPublicKey(name, pk) @@ -240,12 +240,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware { "SOFT_SERVE_REPO_NAME=" + name, "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), "SOFT_SERVE_PUBLIC_KEY=" + ak, + "SOFT_SERVE_USERNAME=" + ctx.User(), } - ss.logger.Debug("git middleware", "cmd", gc, "access", access.String()) + // Add ssh session & config environ + envs = append(envs, s.Environ()...) + envs = append(envs, cfg.Environ()...) + repoDir := filepath.Join(reposDir, repo) - switch gc { - case git.ReceivePackBin: + service := git.Service(cmdLine[0]) + cmd := git.ServiceCommand{ + Stdin: s, + Stdout: s, + Stderr: s.Stderr(), + Env: envs, + Dir: repoDir, + } + + ss.logger.Debug("git middleware", "cmd", service, "access", access.String()) + + switch service { + case git.ReceivePackService: receivePackCounter.WithLabelValues(name).Inc() defer func() { receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) @@ -262,20 +277,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware { } createRepoCounter.WithLabelValues(name).Inc() } - if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil { + + if err := git.ReceivePack(ctx, cmd); err != nil { + sshFatal(s, git.ErrSystemMalfunction) + } + + if err := git.EnsureDefaultBranch(ctx, cmd); err != nil { sshFatal(s, git.ErrSystemMalfunction) } + + receivePackCounter.WithLabelValues(name).Inc() return - case git.UploadPackBin, git.UploadArchiveBin: + case git.UploadPackService, git.UploadArchiveService: if access < backend.ReadOnlyAccess { sshFatal(s, git.ErrNotAuthed) return } - gitPack := git.UploadPack - switch gc { - case git.UploadArchiveBin: - gitPack = git.UploadArchive + handler := git.UploadPack + switch service { + case git.UploadArchiveService: + handler = git.UploadArchive uploadArchiveCounter.WithLabelValues(name).Inc() defer func() { uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) @@ -285,10 +307,9 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware { defer func() { uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) }() - } - err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...) + err := handler(ctx, cmd) if errors.Is(err, git.ErrInvalidRepo) { sshFatal(s, git.ErrInvalidRepo) } else if err != nil { diff --git a/server/web/git.go b/server/web/git.go new file mode 100644 index 000000000..2ca926591 --- /dev/null +++ b/server/web/git.go @@ -0,0 +1,459 @@ +package web + +import ( + "bytes" + "compress/gzip" + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/charmbracelet/log" + gitb "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/utils" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "goji.io/pat" + "goji.io/pattern" +) + +// GitRoute is a route for git services. +type GitRoute struct { + method string + pattern *regexp.Regexp + handler http.HandlerFunc + + cfg *config.Config + be backend.Backend + logger *log.Logger +} + +var _ Route = GitRoute{} + +// Match implements goji.Pattern. +func (g GitRoute) Match(r *http.Request) *http.Request { + if g.method != r.Method { + return nil + } + + re := g.pattern + ctx := r.Context() + if m := re.FindStringSubmatch(r.URL.Path); m != nil { + file := strings.Replace(r.URL.Path, m[1]+"/", "", 1) + repo := utils.SanitizeRepo(m[1]) + ".git" + + var service git.Service + switch { + case strings.HasSuffix(r.URL.Path, git.UploadPackService.String()): + service = git.UploadPackService + case strings.HasSuffix(r.URL.Path, git.ReceivePackService.String()): + service = git.ReceivePackService + } + + ctx = context.WithValue(ctx, pattern.Variable("service"), service.String()) + ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(g.cfg.DataPath, "repos", repo)) + ctx = context.WithValue(ctx, pattern.Variable("repo"), repo) + ctx = context.WithValue(ctx, pattern.Variable("file"), file) + + if g.cfg != nil { + ctx = config.WithContext(ctx, g.cfg) + } + + if g.be != nil { + ctx = backend.WithContext(ctx, g.be.WithContext(ctx)) + } + + if g.logger != nil { + ctx = log.WithContext(ctx, g.logger) + } + + return r.WithContext(ctx) + } + + return nil +} + +// ServeHTTP implements http.Handler. +func (g GitRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { + g.handler(w, r) +} + +var ( + gitHttpReceiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "soft_serve", + Subsystem: "http", + Name: "git_receive_pack_total", + Help: "The total number of git push requests", + }, []string{"repo"}) + + gitHttpUploadCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "soft_serve", + Subsystem: "http", + Name: "git_upload_pack_total", + Help: "The total number of git fetch/pull requests", + }, []string{"repo", "file"}) +) + +func gitRoutes(ctx context.Context, logger *log.Logger) []Route { + routes := make([]Route, 0) + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + + // Git services + // These routes don't handle authentication/authorization. + // This is handled through wrapping the handlers for each route. + // See below (withAccess). + // TODO: add lfs support + for _, route := range []GitRoute{ + { + pattern: regexp.MustCompile("(.*?)/git-upload-pack$"), + method: http.MethodPost, + handler: serviceRpc, + }, + { + pattern: regexp.MustCompile("(.*?)/git-receive-pack$"), + method: http.MethodPost, + handler: serviceRpc, + }, + { + pattern: regexp.MustCompile("(.*?)/info/refs$"), + method: http.MethodGet, + handler: getInfoRefs, + }, + { + pattern: regexp.MustCompile("(.*?)/HEAD$"), + method: http.MethodGet, + handler: getTextFile, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/info/alternates$"), + method: http.MethodGet, + handler: getTextFile, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/info/http-alternates$"), + method: http.MethodGet, + handler: getTextFile, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/info/packs$"), + method: http.MethodGet, + handler: getInfoPacks, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/info/[^/]*$"), + method: http.MethodGet, + handler: getTextFile, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/[0-9a-f]{2}/[0-9a-f]{38}$"), + method: http.MethodGet, + handler: getLooseObject, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.pack$"), + method: http.MethodGet, + handler: getPackFile, + }, + { + pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$"), + method: http.MethodGet, + handler: getIdxFile, + }, + } { + route.cfg = cfg + route.be = be + route.logger = logger + route.handler = withAccess(route.handler) + routes = append(routes, route) + } + + return routes +} + +// withAccess handles auth. +func withAccess(fn http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + be := backend.FromContext(ctx) + logger := log.FromContext(ctx) + + if !be.AllowKeyless() { + renderForbidden(w) + return + } + + repo := pat.Param(r, "repo") + service := git.Service(pat.Param(r, "service")) + access := be.AccessLevel(repo, "") + + switch service { + case git.ReceivePackService: + if access < backend.ReadWriteAccess { + renderUnauthorized(w) + return + } + + // Create the repo if it doesn't exist. + if _, err := be.Repository(repo); err != nil { + if _, err := be.CreateRepository(repo, backend.RepositoryOptions{}); err != nil { + logger.Error("failed to create repository", "repo", repo, "err", err) + renderInternalServerError(w) + return + } + } + default: + if access < backend.ReadOnlyAccess { + renderUnauthorized(w) + return + } + } + + fn(w, r) + } +} + +func serviceRpc(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + logger := log.FromContext(ctx) + service, dir, repo := git.Service(pat.Param(r, "service")), pat.Param(r, "dir"), pat.Param(r, "repo") + + if !isSmart(r, service) { + renderForbidden(w) + return + } + + if service == git.ReceivePackService { + gitHttpReceiveCounter.WithLabelValues(repo) + } + + w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", service)) + w.Header().Set("Connection", "Keep-Alive") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + + version := r.Header.Get("Git-Protocol") + + cmd := git.ServiceCommand{ + Stdin: r.Body, + Stdout: w, + Dir: dir, + Args: []string{"--stateless-rpc"}, + } + + if len(version) != 0 { + cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version)) + } + + // Handle gzip encoding + cmd.StdinHandler = func(in io.Reader, stdin io.WriteCloser) (err error) { + // We know that `in` is an `io.ReadCloser` because it's `r.Body`. + reader := in.(io.ReadCloser) + defer reader.Close() // nolint: errcheck + switch r.Header.Get("Content-Encoding") { + case "gzip": + reader, err = gzip.NewReader(reader) + if err != nil { + return err + } + defer reader.Close() // nolint: errcheck + } + + _, err = io.Copy(stdin, reader) + return err + } + + // Handle buffered output + // Useful when using proxies + cmd.StdoutHandler = func(out io.Writer, stdout io.ReadCloser) error { + // We know that `out` is an `http.ResponseWriter`. + flusher, ok := out.(http.Flusher) + if !ok { + return fmt.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", out) + } + + p := make([]byte, 1024) + for { + nRead, err := stdout.Read(p) + if err == io.EOF { + break + } + nWrite, err := out.Write(p[:nRead]) + if err != nil { + return err + } + if nRead != nWrite { + return fmt.Errorf("failed to write data: %d read, %d written", nRead, nWrite) + } + flusher.Flush() + } + + return nil + } + + if err := service.Handler(ctx, cmd); err != nil { + logger.Errorf("error executing service: %s", err) + } +} + +func getInfoRefs(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + logger := log.FromContext(ctx) + dir, repo, file := pat.Param(r, "dir"), pat.Param(r, "repo"), pat.Param(r, "file") + service := getServiceType(r) + version := r.Header.Get("Git-Protocol") + + gitHttpUploadCounter.WithLabelValues(repo, file).Inc() + + if service != "" && (service == git.UploadPackService || service == git.ReceivePackService) { + // Smart HTTP + var refs bytes.Buffer + cmd := git.ServiceCommand{ + Stdout: &refs, + Dir: dir, + Args: []string{"--stateless-rpc", "--advertise-refs"}, + } + + if len(version) != 0 { + cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version)) + } + + if err := service.Handler(ctx, cmd); err != nil { + logger.Errorf("error executing service: %s", err) + renderNotFound(w) + return + } + + hdrNocache(w) + w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", service)) + w.WriteHeader(http.StatusOK) + if len(version) == 0 { + git.WritePktline(w, "# service="+service.String()) + } + + w.Write(refs.Bytes()) // nolint: errcheck + } else { + // Dumb HTTP + updateServerInfo(ctx, dir) // nolint: errcheck + hdrNocache(w) + sendFile("text/plain; charset=utf-8", w, r) + } +} + +func getInfoPacks(w http.ResponseWriter, r *http.Request) { + hdrCacheForever(w) + sendFile("text/plain; charset=utf-8", w, r) +} + +func getLooseObject(w http.ResponseWriter, r *http.Request) { + hdrCacheForever(w) + sendFile("application/x-git-loose-object", w, r) +} + +func getPackFile(w http.ResponseWriter, r *http.Request) { + hdrCacheForever(w) + sendFile("application/x-git-packed-objects", w, r) +} + +func getIdxFile(w http.ResponseWriter, r *http.Request) { + hdrCacheForever(w) + sendFile("application/x-git-packed-objects-toc", w, r) +} + +func getTextFile(w http.ResponseWriter, r *http.Request) { + hdrNocache(w) + sendFile("text/plain", w, r) +} + +func sendFile(contentType string, w http.ResponseWriter, r *http.Request) { + dir, file := pat.Param(r, "dir"), pat.Param(r, "file") + reqFile := filepath.Join(dir, file) + + f, err := os.Stat(reqFile) + if os.IsNotExist(err) { + renderNotFound(w) + return + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", fmt.Sprintf("%d", f.Size())) + w.Header().Set("Last-Modified", f.ModTime().Format(http.TimeFormat)) + http.ServeFile(w, r, reqFile) +} + +func getServiceType(r *http.Request) git.Service { + service := r.FormValue("service") + if !strings.HasPrefix(service, "git-") { + return "" + } + + return git.Service(service) +} + +func isSmart(r *http.Request, service git.Service) bool { + if r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) { + return true + } + return false +} + +func updateServerInfo(ctx context.Context, dir string) error { + return gitb.UpdateServerInfo(ctx, dir) +} + +// HTTP error response handling functions + +func renderMethodNotAllowed(w http.ResponseWriter, r *http.Request) { + if r.Proto == "HTTP/1.1" { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write([]byte("Method Not Allowed")) // nolint: errcheck + } else { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Bad Request")) // nolint: errcheck + } +} + +func renderNotFound(w http.ResponseWriter) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Not Found")) // nolint: errcheck +} + +func renderUnauthorized(w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) // nolint: errcheck +} + +func renderForbidden(w http.ResponseWriter) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("Forbidden")) // nolint: errcheck +} + +func renderInternalServerError(w http.ResponseWriter) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) // nolint: errcheck +} + +// Header writing functions + +func hdrNocache(w http.ResponseWriter) { + w.Header().Set("Expires", "Fri, 01 Jan 1980 00:00:00 GMT") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Cache-Control", "no-cache, max-age=0, must-revalidate") +} + +func hdrCacheForever(w http.ResponseWriter) { + now := time.Now().Unix() + expires := now + 31536000 + w.Header().Set("Date", fmt.Sprintf("%d", now)) + w.Header().Set("Expires", fmt.Sprintf("%d", expires)) + w.Header().Set("Cache-Control", "public, max-age=31536000") +} diff --git a/server/web/goget.go b/server/web/goget.go new file mode 100644 index 000000000..7e7c8c9d6 --- /dev/null +++ b/server/web/goget.go @@ -0,0 +1,94 @@ +package web + +import ( + "net/http" + "net/url" + "path" + "text/template" + + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/utils" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "goji.io/pattern" +) + +var goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "soft_serve", + Subsystem: "http", + Name: "go_get_total", + Help: "The total number of go get requests", +}, []string{"repo"}) + +var repoIndexHTMLTpl = template.Must(template.New("index").Parse(` + + + + + + + +Redirecting to docs at godoc.org/{{ .ImportRoot }}/{{ .Repo }}... + +`)) + +// GoGetHandler handles go get requests. +type GoGetHandler struct { + cfg *config.Config + be backend.Backend +} + +var _ http.Handler = (*GoGetHandler)(nil) + +func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + repo := pattern.Path(r.Context()) + repo = utils.SanitizeRepo(repo) + be := g.be.WithContext(r.Context()) + + // Handle go get requests. + // + // Always return a 200 status code, even if the repo doesn't exist. + // + // https://golang.org/cmd/go/#hdr-Remote_import_paths + // https://go.dev/ref/mod#vcs-branch + if r.URL.Query().Get("go-get") == "1" { + repo := repo + importRoot, err := url.Parse(g.cfg.HTTP.PublicURL) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // find the repo + for { + if _, err := be.Repository(repo); err == nil { + break + } + + if repo == "" || repo == "." || repo == "/" { + return + } + + repo = path.Dir(repo) + } + + if err := repoIndexHTMLTpl.Execute(w, struct { + Repo string + Config *config.Config + ImportRoot string + }{ + Repo: url.PathEscape(repo), + Config: g.cfg, + ImportRoot: importRoot.Host, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + goGetCounter.WithLabelValues(repo).Inc() + return + } + + http.NotFound(w, r) +} diff --git a/server/web/http.go b/server/web/http.go index b932f9416..7ff375cab 100644 --- a/server/web/http.go +++ b/server/web/http.go @@ -2,103 +2,31 @@ package web import ( "context" - "fmt" "net/http" - "net/url" - "path" - "path/filepath" - "regexp" - "strings" - "text/template" "time" - "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/utils" - "github.com/dustin/go-humanize" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "goji.io" - "goji.io/pat" - "goji.io/pattern" ) -var ( - gitHttpCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: "soft_serve", - Subsystem: "http", - Name: "git_fetch_pull_total", - Help: "The total number of git fetch/pull requests", - }, []string{"repo", "file"}) - - goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: "soft_serve", - Subsystem: "http", - Name: "go_get_total", - Help: "The total number of go get requests", - }, []string{"repo"}) -) - -// logWriter is a wrapper around http.ResponseWriter that allows us to capture -// the HTTP status code and bytes written to the response. -type logWriter struct { - http.ResponseWriter - code, bytes int -} - -func (r *logWriter) Write(p []byte) (int, error) { - written, err := r.ResponseWriter.Write(p) - r.bytes += written - return written, err -} - -// Note this is generally only called when sending an HTTP error, so it's -// important to set the `code` value to 200 as a default -func (r *logWriter) WriteHeader(code int) { - r.code = code - r.ResponseWriter.WriteHeader(code) -} - -func (s *HTTPServer) loggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - writer := &logWriter{code: http.StatusOK, ResponseWriter: w} - s.logger.Debug("request", - "method", r.Method, - "uri", r.RequestURI, - "addr", r.RemoteAddr) - next.ServeHTTP(writer, r) - elapsed := time.Since(start) - s.logger.Debug("response", - "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)), - "bytes", humanize.Bytes(uint64(writer.bytes)), - "time", elapsed) - }) -} - // HTTPServer is an http server. type HTTPServer struct { - ctx context.Context - cfg *config.Config - be backend.Backend - server *http.Server - dirHandler http.Handler - logger *log.Logger + ctx context.Context + cfg *config.Config + be backend.Backend + server *http.Server } +// NewHTTPServer creates a new HTTP server. func NewHTTPServer(ctx context.Context) (*HTTPServer, error) { cfg := config.FromContext(ctx) - mux := goji.NewMux() s := &HTTPServer{ - ctx: ctx, - cfg: cfg, - be: backend.FromContext(ctx), - logger: log.FromContext(ctx).WithPrefix("http"), - dirHandler: http.FileServer(http.Dir(filepath.Join(cfg.DataPath, "repos"))), + ctx: ctx, + cfg: cfg, + be: backend.FromContext(ctx), server: &http.Server{ Addr: cfg.HTTP.ListenAddr, - Handler: mux, + Handler: NewRouter(ctx), ReadHeaderTimeout: time.Second * 10, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, @@ -106,21 +34,6 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) { }, } - mux.Use(s.loggingMiddleware) - for _, m := range []Matcher{ - getInfoRefs, - getHead, - getAlternates, - getHTTPAlternates, - getInfoPacks, - getInfoFile, - getLooseObject, - getPackFile, - getIdxFile, - } { - mux.HandleFunc(NewPattern(m), s.handleGit) - } - mux.HandleFunc(pat.Get("/*"), s.handleIndex) return s, nil } @@ -141,193 +54,3 @@ func (s *HTTPServer) ListenAndServe() error { func (s *HTTPServer) Shutdown(ctx context.Context) error { return s.server.Shutdown(ctx) } - -// Pattern is a pattern for matching a URL. -// It matches against GET requests. -type Pattern struct { - match func(*url.URL) *match -} - -// NewPattern returns a new Pattern with the given matcher. -func NewPattern(m Matcher) *Pattern { - return &Pattern{ - match: m, - } -} - -// Match is a match for a URL. -// -// It implements goji.Pattern. -func (p *Pattern) Match(r *http.Request) *http.Request { - if r.Method != "GET" { - return nil - } - - if m := p.match(r.URL); m != nil { - ctx := context.WithValue(r.Context(), pattern.Variable("repo"), m.RepoPath) - ctx = context.WithValue(ctx, pattern.Variable("file"), m.FilePath) - return r.WithContext(ctx) - } - return nil -} - -// Matcher finds a match in a *url.URL. -type Matcher = func(*url.URL) *match - -var ( - getInfoRefs = func(u *url.URL) *match { - return matchSuffix(u.Path, "/info/refs") - } - - getHead = func(u *url.URL) *match { - return matchSuffix(u.Path, "/HEAD") - } - - getAlternates = func(u *url.URL) *match { - return matchSuffix(u.Path, "/objects/info/alternates") - } - - getHTTPAlternates = func(u *url.URL) *match { - return matchSuffix(u.Path, "/objects/info/http-alternates") - } - - getInfoPacks = func(u *url.URL) *match { - return matchSuffix(u.Path, "/objects/info/packs") - } - - getInfoFileRegexp = regexp.MustCompile(".*?(/objects/info/[^/]*)$") - getInfoFile = func(u *url.URL) *match { - return findStringSubmatch(u.Path, getInfoFileRegexp) - } - - getLooseObjectRegexp = regexp.MustCompile(".*?(/objects/[0-9a-f]{2}/[0-9a-f]{38})$") - getLooseObject = func(u *url.URL) *match { - return findStringSubmatch(u.Path, getLooseObjectRegexp) - } - - getPackFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.pack)$`) - getPackFile = func(u *url.URL) *match { - return findStringSubmatch(u.Path, getPackFileRegexp) - } - - getIdxFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.idx)$`) - getIdxFile = func(u *url.URL) *match { - return findStringSubmatch(u.Path, getIdxFileRegexp) - } -) - -// match represents a match for a URL. -type match struct { - RepoPath, FilePath string -} - -func matchSuffix(path, suffix string) *match { - if !strings.HasSuffix(path, suffix) { - return nil - } - repoPath := strings.Replace(path, suffix, "", 1) - filePath := strings.Replace(path, repoPath+"/", "", 1) - return &match{repoPath, filePath} -} - -func findStringSubmatch(path string, prefix *regexp.Regexp) *match { - m := prefix.FindStringSubmatch(path) - if m == nil { - return nil - } - suffix := m[1] - repoPath := strings.Replace(path, suffix, "", 1) - filePath := strings.Replace(path, repoPath+"/", "", 1) - return &match{repoPath, filePath} -} - -var repoIndexHTMLTpl = template.Must(template.New("index").Parse(` - - - - - - - -Redirecting to docs at godoc.org/{{ .ImportRoot }}/{{ .Repo }}... - -`)) - -func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) { - repo := pattern.Path(r.Context()) - repo = utils.SanitizeRepo(repo) - be := s.be.WithContext(r.Context()) - - // Handle go get requests. - // - // Always return a 200 status code, even if the repo doesn't exist. - // - // https://golang.org/cmd/go/#hdr-Remote_import_paths - // https://go.dev/ref/mod#vcs-branch - if r.URL.Query().Get("go-get") == "1" { - repo := repo - importRoot, err := url.Parse(s.cfg.HTTP.PublicURL) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // find the repo - for { - if _, err := be.Repository(repo); err == nil { - break - } - - if repo == "" || repo == "." || repo == "/" { - return - } - - repo = path.Dir(repo) - } - - if err := repoIndexHTMLTpl.Execute(w, struct { - Repo string - Config *config.Config - ImportRoot string - }{ - Repo: url.PathEscape(repo), - Config: s.cfg, - ImportRoot: importRoot.Host, - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - goGetCounter.WithLabelValues(repo).Inc() - return - } - - http.NotFound(w, r) -} - -func (s *HTTPServer) handleGit(w http.ResponseWriter, r *http.Request) { - repo := pat.Param(r, "repo") - repo = utils.SanitizeRepo(repo) + ".git" - be := s.be.WithContext(r.Context()) - if _, err := be.Repository(repo); err != nil { - s.logger.Debug("repository not found", "repo", repo, "err", err) - http.NotFound(w, r) - return - } - - if !s.cfg.Backend.AllowKeyless() { - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - access := s.cfg.Backend.AccessLevel(repo, "") - if access < backend.ReadOnlyAccess { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - - file := pat.Param(r, "file") - gitHttpCounter.WithLabelValues(repo, file).Inc() - r.URL.Path = fmt.Sprintf("/%s/%s", repo, file) - s.dirHandler.ServeHTTP(w, r) -} diff --git a/server/web/logging.go b/server/web/logging.go new file mode 100644 index 000000000..f0f43a05c --- /dev/null +++ b/server/web/logging.go @@ -0,0 +1,84 @@ +package web + +import ( + "bufio" + "fmt" + "net" + "net/http" + "time" + + "github.com/charmbracelet/log" + "github.com/dustin/go-humanize" +) + +// logWriter is a wrapper around http.ResponseWriter that allows us to capture +// the HTTP status code and bytes written to the response. +type logWriter struct { + http.ResponseWriter + code, bytes int +} + +var _ http.ResponseWriter = (*logWriter)(nil) + +var _ http.Flusher = (*logWriter)(nil) + +var _ http.Hijacker = (*logWriter)(nil) + +var _ http.CloseNotifier = (*logWriter)(nil) + +// Write implements http.ResponseWriter. +func (r *logWriter) Write(p []byte) (int, error) { + written, err := r.ResponseWriter.Write(p) + r.bytes += written + return written, err +} + +// Note this is generally only called when sending an HTTP error, so it's +// important to set the `code` value to 200 as a default. +func (r *logWriter) WriteHeader(code int) { + r.code = code + r.ResponseWriter.WriteHeader(code) +} + +// Flush implements http.Flusher. +func (r *logWriter) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify implements http.CloseNotifier. +func (r *logWriter) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Hijack implements http.Hijacker. +func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, fmt.Errorf("http.Hijacker not implemented") +} + +// NewLoggingMiddleware returns a new logging middleware. +func NewLoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + writer := &logWriter{code: http.StatusOK, ResponseWriter: w} + logger.Debug("request", + "method", r.Method, + "uri", r.RequestURI, + "addr", r.RemoteAddr) + next.ServeHTTP(writer, r) + elapsed := time.Since(start) + logger.Debug("response", + "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)), + "bytes", humanize.Bytes(uint64(writer.bytes)), + "time", elapsed) + }) + } +} diff --git a/server/web/server.go b/server/web/server.go new file mode 100644 index 000000000..ea15e778f --- /dev/null +++ b/server/web/server.go @@ -0,0 +1,40 @@ +// Package server is the reusable server +package web + +import ( + "context" + "net/http" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "goji.io" + "goji.io/pat" +) + +// Route is an interface for a route. +type Route interface { + http.Handler + goji.Pattern +} + +// NewRouter returns a new HTTP router. +func NewRouter(ctx context.Context) *goji.Mux { + mux := goji.NewMux() + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + logger := log.FromContext(ctx).WithPrefix("http") + + // Middlewares + mux.Use(NewLoggingMiddleware(logger)) + + // Git routes + for _, service := range gitRoutes(ctx, logger) { + mux.Handle(service, service) + } + + // go-get handler + mux.Handle(pat.Get("/*"), GoGetHandler{cfg, be}) + + return mux +}