Skip to content

Commit

Permalink
adds tests for http server
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhang93 committed May 11, 2024
1 parent cfb57a7 commit e86c126
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 24 deletions.
8 changes: 8 additions & 0 deletions cmd/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/shubhang93/tplagent/internal/agent"
"github.com/shubhang93/tplagent/internal/config"
"github.com/shubhang93/tplagent/internal/fatal"
"github.com/shubhang93/tplagent/internal/httplis"
"log/slog"
"os"
"os/signal"
Expand Down Expand Up @@ -107,13 +108,20 @@ func spawn(ctx context.Context, processMaker func(logger *slog.Logger) agentProc
logger.Info("starting agent")
}

wait := make(chan struct{})
go func() {
defer close(wait)
httplis.Start(ctx, conf.Agent.HTTPListenerAddr, logger)
}()

proc := processMaker(logger)
err := proc.Start(ctx, conf)
if err != nil {
logger.Error("agent exited with error", slog.String("error", err.Error()))
return err
}
logger.Info("agent exited without errors")
<-wait
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func Read(rr io.Reader) (TPLAgent, error) {
func Validate(c *TPLAgent) error {
var valErrs []error
if _, ok := allowedLogFmts[c.Agent.LogFmt]; !ok {
valErrs = append(valErrs, fmt.Errorf("validate:invalid log level"))
valErrs = append(valErrs, fmt.Errorf("validate:invalid log format"))
}

for tmplName, tmplConfig := range c.TemplateSpecs {
Expand Down
117 changes: 94 additions & 23 deletions internal/httplis/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"errors"
"fmt"
"github.com/shubhang93/tplagent/internal/config"
"io"
"log/slog"
"net/http"
"os"
"strings"
"syscall"
)

type reloadRequest struct {
Expand All @@ -18,35 +21,15 @@ type reloadRequest struct {

func Start(ctx context.Context, addr string, l *slog.Logger) {
mux := http.NewServeMux()
mux.HandleFunc("POST /config/reload", func(writer http.ResponseWriter, request *http.Request) {
reloadReq := reloadRequest{}
err := json.NewDecoder(request.Body).Decode(&reloadReq)
if err != nil {
writeJSON(writer, http.StatusBadRequest, map[string]string{"error": err.Error()})
return
}

_, err = os.Stat(reloadReq.ConfigPath)
if errors.Is(err, os.ErrNotExist) {
writeJSON(writer, http.StatusInternalServerError, map[string]string{
"error": fmt.Sprintf("file not found at %s", reloadReq.ConfigPath),
})
return
}

err = config.Validate(&reloadReq.Config)
if err != nil {
writeJSON(writer, http.StatusBadRequest, map[string]string{"error": err.Error()})
return
}

})

mux.HandleFunc("POST /config/reload", reloadConfig)

s := http.Server{Handler: mux, Addr: addr}

wait := make(chan struct{})
go func() {
defer close(wait)
<-ctx.Done()
_ = s.Shutdown(ctx)
}()

Expand All @@ -64,3 +47,91 @@ func writeJSON(writer http.ResponseWriter, status int, data any) {
_ = json.NewEncoder(writer).Encode(data)
return
}

func reloadConfig(writer http.ResponseWriter, request *http.Request) {

proc, err := os.FindProcess(os.Getpid())
if err != nil {
writeJSON(writer, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}

reloadReq := reloadRequest{}
err = json.NewDecoder(request.Body).Decode(&reloadReq)
if err != nil {
writeJSON(writer, http.StatusBadRequest, map[string]string{"error": err.Error()})
return
}

configFilePath := reloadReq.ConfigPath
_, err = os.Stat(configFilePath)
if errors.Is(err, os.ErrNotExist) {
writeJSON(writer, http.StatusNotFound, map[string]string{
"error": fmt.Sprintf("file not found at %s", configFilePath),
})
return
}

err = config.Validate(&reloadReq.Config)
if err != nil {
writeJSON(writer, http.StatusBadRequest, map[string]string{"error": err.Error()})
return
}

err = backupAndReplace(configFilePath, reloadReq.Config)
if err != nil {
writeJSON(writer, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}

err = proc.Signal(syscall.SIGHUP)
if err != nil {
writeJSON(writer, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}

writeJSON(writer, http.StatusOK, map[string]bool{"success": true})

}

func backupAndReplace(path string, newConfig config.TPLAgent) error {
bakFilename := fmt.Sprintf("%s.%s", path, "bak")
bakFile, err := os.Create(bakFilename)
if err != nil {
return err
}

oldFile, err := os.Open(path)
if err != nil {
return err
}

_, err = io.Copy(bakFile, oldFile)
if err != nil {
return err
}
_ = bakFile.Close()

tempFilename := fmt.Sprintf("%s.%s", path, "temp")
tempFile, err := os.Create(tempFilename)
if err != nil {
_ = os.Remove(bakFilename)
return err
}

jd := json.NewEncoder(tempFile)
jd.SetIndent("", strings.Repeat(" ", 2))
err = jd.Encode(newConfig)
if err != nil {
_ = os.Remove(tempFilename)
return err
}

err = os.Rename(tempFilename, path)
if err != nil {
_ = os.Remove(tempFilename)
return err
}
return nil

}
145 changes: 145 additions & 0 deletions internal/httplis/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package httplis

import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"testing"
"time"
)

func TestStart(t *testing.T) {

type reloadTest struct {
name string
wantStatus int
jsonBody func(string) string
beforeFunc func(string) error
wantSIGHUP bool
}

reloadTests := []reloadTest{
{
name: "config path does not exist",
wantStatus: http.StatusNotFound,
jsonBody: func(_ string) string {
return `{
"config_path": "/some/path"
}`
},
},
{
name: "invalid config",
beforeFunc: func(tmp string) error {
_, err := os.Create(tmp + "/config.json")
return err
},
wantStatus: http.StatusBadRequest,
jsonBody: func(tmp string) string {
return fmt.Sprintf(`{
"config_path": "%s",
"config": {
"agent": {
"log_fmt": "invalid"
}
}
}`, tmp+"/config.json")
},
},
{
name: "valid config",
wantStatus: http.StatusOK,
jsonBody: func(tmp string) string {
return fmt.Sprintf(`{
"config_path": "%s",
"config": {
"agent": {
"log_fmt": "text",
"log_level": "INFO",
"http_listener": "localhost:5000"
},
"templates": {
"server-conf": {
"raw": "hello {{.name}}"
}
}
}
}`, tmp+"/config.json")
},

beforeFunc: func(tmp string) error {
_, err := os.Create(tmp + "/config.json")
return err
},
wantSIGHUP: true,
},
}

for _, rt := range reloadTests {
t.Run(rt.name, func(t *testing.T) {

tmp := t.TempDir()
if rt.beforeFunc != nil {
if err := rt.beforeFunc(tmp); err != nil {
t.Errorf("before func run error:%v", err)
return
}
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

var wg sync.WaitGroup
sighup := make(chan os.Signal)
signal.Notify(sighup, syscall.SIGHUP)
sighupRcvd := false

wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ctx.Done():
case <-sighup:
sighupRcvd = true
}
}()

wg.Add(1)
go func() {
defer wg.Done()
Start(ctx, "localhost:6000", newLogger())
}()

rdr := strings.NewReader(rt.jsonBody(tmp))
resp, err := http.Post("http://localhost:6000/config/reload", "application/json", rdr)
if err != nil {
t.Errorf("POST error:%v", err)
return
}
respBody, _ := io.ReadAll(resp.Body)
t.Log(string(respBody))
if resp.StatusCode != rt.wantStatus {
t.Errorf("expected status to be %d got %d", rt.wantStatus, resp.StatusCode)
return
}

wg.Wait()
if rt.wantSIGHUP != sighupRcvd {
t.Error("SIGHUP not received")
}

})
}

}

func newLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(os.Stdout, nil))
}

0 comments on commit e86c126

Please sign in to comment.