diff --git a/app/api_report.go b/app/api_report.go index 0aa046f043..54ad023523 100644 --- a/app/api_report.go +++ b/app/api_report.go @@ -9,6 +9,11 @@ import ( // Raw report handler func makeRawReportHandler(rep Reporter) CtxHandlerFunc { return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { - respondWith(w, http.StatusOK, rep.Report(ctx)) + report, err := rep.Report(ctx) + if err != nil { + respondWith(w, http.StatusInternalServerError, err.Error()) + return + } + respondWith(w, http.StatusOK, report) } } diff --git a/app/api_topologies.go b/app/api_topologies.go index 4d2e424e8b..6930aeaf4e 100644 --- a/app/api_topologies.go +++ b/app/api_topologies.go @@ -195,7 +195,12 @@ func (r *registry) walk(f func(APITopologyDesc)) { // makeTopologyList returns a handler that yields an APITopologyList. func (r *registry) makeTopologyList(rep Reporter) CtxHandlerFunc { return func(ctx context.Context, w http.ResponseWriter, req *http.Request) { - topologies := r.renderTopologies(rep.Report(ctx), req) + report, err := rep.Report(ctx) + if err != nil { + respondWith(w, http.StatusInternalServerError, err.Error()) + return + } + topologies := r.renderTopologies(report, req) respondWith(w, http.StatusOK, topologies) } } diff --git a/app/api_topology.go b/app/api_topology.go index 639ecfb5f7..e28df96193 100644 --- a/app/api_topology.go +++ b/app/api_topology.go @@ -28,8 +28,13 @@ type APINode struct { // Full topology. func handleTopology(ctx context.Context, rep Reporter, renderer render.Renderer, w http.ResponseWriter, r *http.Request) { + report, err := rep.Report(ctx) + if err != nil { + respondWith(w, http.StatusInternalServerError, err.Error()) + return + } respondWith(w, http.StatusOK, APITopology{ - Nodes: renderer.Render(rep.Report(ctx)).Prune(), + Nodes: renderer.Render(report).Prune(), }) } @@ -54,15 +59,19 @@ func handleWs(ctx context.Context, rep Reporter, renderer render.Renderer, w htt func handleNode(topologyID, nodeID string) func(context.Context, Reporter, render.Renderer, http.ResponseWriter, *http.Request) { return func(ctx context.Context, rep Reporter, renderer render.Renderer, w http.ResponseWriter, r *http.Request) { var ( - rpt = rep.Report(ctx) - rendered = renderer.Render(rep.Report(ctx)) - node, ok = rendered[nodeID] + report, err = rep.Report(ctx) + rendered = renderer.Render(report) + node, ok = rendered[nodeID] ) + if err != nil { + respondWith(w, http.StatusInternalServerError, err.Error()) + return + } if !ok { http.NotFound(w, r) return } - respondWith(w, http.StatusOK, APINode{Node: detailed.MakeNode(topologyID, rpt, rendered, node)}) + respondWith(w, http.StatusOK, APINode{Node: detailed.MakeNode(topologyID, report, rendered, node)}) } } @@ -103,7 +112,12 @@ func handleWebsocket( defer rep.UnWait(ctx, wait) for { - newTopo := renderer.Render(rep.Report(ctx)).Prune() + report, err := rep.Report(ctx) + if err != nil { + log.Errorf("Error generating report: %v", err) + return + } + newTopo := renderer.Render(report).Prune() diff := render.TopoDiff(previousTopo, newTopo) previousTopo = newTopo diff --git a/app/collector.go b/app/collector.go index 6c2f72730b..7269e1bffe 100644 --- a/app/collector.go +++ b/app/collector.go @@ -14,7 +14,7 @@ import ( // Reporter is something that can produce reports on demand. It's a convenient // interface for parts of the app, and several experimental components. type Reporter interface { - Report(context.Context) report.Report + Report(context.Context) (report.Report, error) WaitOn(context.Context, chan struct{}) UnWait(context.Context, chan struct{}) } @@ -22,7 +22,7 @@ type Reporter interface { // Adder is something that can accept reports. It's a convenient interface for // parts of the app, and several experimental components. type Adder interface { - Add(context.Context, report.Report) + Add(context.Context, report.Report) error } // A Collector is a Reporter and an Adder @@ -83,7 +83,7 @@ func NewCollector(window time.Duration) Collector { var now = time.Now // Add adds a report to the collector's internal state. It implements Adder. -func (c *collector) Add(_ context.Context, rpt report.Report) { +func (c *collector) Add(_ context.Context, rpt report.Report) error { c.mtx.Lock() defer c.mtx.Unlock() c.reports = append(c.reports, timestampReport{now(), rpt}) @@ -92,11 +92,12 @@ func (c *collector) Add(_ context.Context, rpt report.Report) { if rpt.Shortcut { c.Broadcast() } + return nil } // Report returns a merged report over all added reports. It implements // Reporter. -func (c *collector) Report(_ context.Context) report.Report { +func (c *collector) Report(_ context.Context) (report.Report, error) { c.mtx.Lock() defer c.mtx.Unlock() @@ -105,7 +106,7 @@ func (c *collector) Report(_ context.Context) report.Report { if c.cached != nil && len(c.reports) > 0 { oldest := now().Add(-c.window) if c.reports[0].timestamp.Before(oldest) { - return *c.cached + return *c.cached, nil } } c.reports = clean(c.reports, c.window) @@ -118,7 +119,7 @@ func (c *collector) Report(_ context.Context) report.Report { } rpt.ID = fmt.Sprintf("%x", id.Sum64()) c.cached = &rpt - return rpt + return rpt, nil } type timestampReport struct { diff --git a/app/collector_test.go b/app/collector_test.go index 81caf11724..4006b05e75 100644 --- a/app/collector_test.go +++ b/app/collector_test.go @@ -23,12 +23,20 @@ func TestCollector(t *testing.T) { r2 := report.MakeReport() r2.Endpoint.AddNode("bar", report.MakeNode()) - if want, have := report.MakeReport(), c.Report(ctx); !reflect.DeepEqual(want, have) { + have, err := c.Report(ctx) + if err != nil { + t.Error(err) + } + if want := report.MakeReport(); !reflect.DeepEqual(want, have) { t.Error(test.Diff(want, have)) } c.Add(ctx, r1) - if want, have := r1, c.Report(ctx); !reflect.DeepEqual(want, have) { + have, err = c.Report(ctx) + if err != nil { + t.Error(err) + } + if want := r1; !reflect.DeepEqual(want, have) { t.Error(test.Diff(want, have)) } @@ -36,7 +44,11 @@ func TestCollector(t *testing.T) { merged := report.MakeReport() merged = merged.Merge(r1) merged = merged.Merge(r2) - if want, have := merged, c.Report(ctx); !reflect.DeepEqual(want, have) { + have, err = c.Report(ctx) + if err != nil { + t.Error(err) + } + if want := merged; !reflect.DeepEqual(want, have) { t.Error(test.Diff(want, have)) } } diff --git a/app/mock_reporter_test.go b/app/mock_reporter_test.go index b5efe5e5b6..9903357e52 100644 --- a/app/mock_reporter_test.go +++ b/app/mock_reporter_test.go @@ -10,7 +10,7 @@ import ( // StaticReport is used as a fixture in tests. It emulates an xfer.Collector. type StaticReport struct{} -func (s StaticReport) Report(context.Context) report.Report { return fixture.Report } -func (s StaticReport) Add(context.Context, report.Report) {} -func (s StaticReport) WaitOn(context.Context, chan struct{}) {} -func (s StaticReport) UnWait(context.Context, chan struct{}) {} +func (s StaticReport) Report(context.Context) (report.Report, error) { return fixture.Report, nil } +func (s StaticReport) Add(context.Context, report.Report) error { return nil } +func (s StaticReport) WaitOn(context.Context, chan struct{}) {} +func (s StaticReport) UnWait(context.Context, chan struct{}) {} diff --git a/app/pipe_router.go b/app/pipe_router.go index 5bdfdf9c48..3db9813350 100644 --- a/app/pipe_router.go +++ b/app/pipe_router.go @@ -1,6 +1,7 @@ package app import ( + "fmt" "io" "sync" "time" @@ -29,9 +30,9 @@ const ( // PipeRouter stores pipes and allows you to connect to either end of them. type PipeRouter interface { - Get(context.Context, string, End) (xfer.Pipe, io.ReadWriter, bool) - Release(context.Context, string, End) - Delete(context.Context, string) + Get(context.Context, string, End) (xfer.Pipe, io.ReadWriter, error) + Release(context.Context, string, End) error + Delete(context.Context, string) error Stop() } @@ -77,7 +78,7 @@ func NewLocalPipeRouter() PipeRouter { return pipeRouter } -func (pr *localPipeRouter) Get(_ context.Context, id string, e End) (xfer.Pipe, io.ReadWriter, bool) { +func (pr *localPipeRouter) Get(_ context.Context, id string, e End) (xfer.Pipe, io.ReadWriter, error) { pr.Lock() defer pr.Unlock() p, ok := pr.pipes[id] @@ -91,43 +92,45 @@ func (pr *localPipeRouter) Get(_ context.Context, id string, e End) (xfer.Pipe, pr.pipes[id] = p } if p.Closed() { - return nil, nil, false + return nil, nil, fmt.Errorf("Pipe %s closed", id) } end, endIO := p.end(e) end.refCount++ - return p, endIO, true + return p, endIO, nil } -func (pr *localPipeRouter) Release(_ context.Context, id string, e End) { +func (pr *localPipeRouter) Release(_ context.Context, id string, e End) error { pr.Lock() defer pr.Unlock() p, ok := pr.pipes[id] if !ok { - // uh oh - return + return fmt.Errorf("Pipe %s not found", id) } end, _ := p.end(e) end.refCount-- if end.refCount > 0 { - return + return nil } if !p.Closed() { end.lastUsedTime = mtime.Now() } + + return nil } -func (pr *localPipeRouter) Delete(_ context.Context, id string) { +func (pr *localPipeRouter) Delete(_ context.Context, id string) error { pr.Lock() defer pr.Unlock() p, ok := pr.pipes[id] if !ok { - return + return nil } p.Close() p.tombstoneTime = mtime.Now() + return nil } func (pr *localPipeRouter) Stop() { diff --git a/app/pipes.go b/app/pipes.go index 6bb5742f66..7f98f65e8a 100644 --- a/app/pipes.go +++ b/app/pipes.go @@ -32,8 +32,8 @@ func RegisterPipeRoutes(router *mux.Router, pr PipeRouter) { func checkPipe(pr PipeRouter, end End) CtxHandlerFunc { return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { id := mux.Vars(r)["pipeID"] - _, _, ok := pr.Get(ctx, id, end) - if !ok { + _, _, err := pr.Get(ctx, id, end) + if err != nil { w.WriteHeader(http.StatusNoContent) return } @@ -44,8 +44,9 @@ func checkPipe(pr PipeRouter, end End) CtxHandlerFunc { func handlePipeWs(pr PipeRouter, end End) CtxHandlerFunc { return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { id := mux.Vars(r)["pipeID"] - pipe, endIO, ok := pr.Get(ctx, id, end) - if !ok { + pipe, endIO, err := pr.Get(ctx, id, end) + if err != nil { + log.Errorf("Error getting pipe %s: %v", id, err) http.NotFound(w, r) return } @@ -69,6 +70,8 @@ func deletePipe(pr PipeRouter) CtxHandlerFunc { return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { pipeID := mux.Vars(r)["pipeID"] log.Infof("Closing pipe %s", pipeID) - pr.Delete(ctx, pipeID) + if err := pr.Delete(ctx, pipeID); err != nil { + respondWith(w, http.StatusInternalServerError, err.Error()) + } } } diff --git a/app/pipes_internal_test.go b/app/pipes_internal_test.go index 1a6643933e..3089197555 100644 --- a/app/pipes_internal_test.go +++ b/app/pipes_internal_test.go @@ -33,9 +33,9 @@ func TestPipeTimeout(t *testing.T) { // create a new pipe. id := "foo" ctx := context.Background() - pipe, _, ok := pr.Get(ctx, id, UIEnd) - if !ok { - t.Fatalf("not ok") + pipe, _, err := pr.Get(ctx, id, UIEnd) + if err != nil { + t.Fatalf("not ok: %v", err) } // move time forward such that the new pipe should timeout diff --git a/app/router_test.go b/app/router_test.go index cc80b522cd..563320ca47 100644 --- a/app/router_test.go +++ b/app/router_test.go @@ -75,7 +75,11 @@ func TestReportPostHandler(t *testing.T) { } ctx := context.Background() - if want, have := fixture.Report.Endpoint.Nodes, c.Report(ctx).Endpoint.Nodes; len(have) == 0 || len(want) != len(have) { + report, err := c.Report(ctx) + if err != nil { + t.Error(err) + } + if want, have := fixture.Report.Endpoint.Nodes, report.Endpoint.Nodes; len(have) == 0 || len(want) != len(have) { t.Fatalf("Content-Type %s: %v", contentType, test.Diff(have, want)) } }