Skip to content

Commit

Permalink
refactor(controller): use an atomic.Value for storing routes (#1486)
Browse files Browse the repository at this point in the history
This avoids having to acquire the lock to access the routing table,
which can be easily forgotten.
  • Loading branch information
alecthomas authored May 15, 2024
1 parent 242ad90 commit 16fb60c
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"connectrpc.com/connect"
"github.com/alecthomas/atomic"
"github.com/alecthomas/concurrency"
"github.com/alecthomas/kong"
"github.com/alecthomas/types/optional"
Expand Down Expand Up @@ -164,8 +165,7 @@ type Service struct {
// Map from endpoint to client.
clients *ttlcache.Cache[string, clients]

routesMu sync.RWMutex
routes map[string][]dal.Route
routes atomic.Value[map[string][]dal.Route]
config Config
runnerScaling scaling.RunnerScaling

Expand All @@ -192,11 +192,11 @@ func New(ctx context.Context, db *dal.DAL, config Config, runnerScaling scaling.
key: key,
deploymentLogsSink: newDeploymentLogsSink(ctx, db),
clients: ttlcache.New[string, clients](ttlcache.WithTTL[string, clients](time.Minute)),
routes: map[string][]dal.Route{},
config: config,
runnerScaling: runnerScaling,
increaseReplicaFailures: map[string]int{},
}
svc.routes.Store(map[string][]dal.Route{})

cronSvc := cronjobs.New(ctx, key, svc.config.Advertise.Host, cronjobs.Config{Timeout: config.CronJobTimeout}, db, svc.tasks, svc.callWithRequest)
svc.cronJobs = cronSvc
Expand Down Expand Up @@ -303,8 +303,8 @@ func (s *Service) Status(ctx context.Context, req *connect.Request[ftlv1.StatusR
if err != nil {
return nil, fmt.Errorf("could not get status: %w", err)
}
s.routesMu.RLock()
routes := slices.FlatMap(maps.Values(s.routes), func(routes []dal.Route) (out []*ftlv1.StatusResponse_Route) {
sroutes := s.routes.Load()
routes := slices.FlatMap(maps.Values(sroutes), func(routes []dal.Route) (out []*ftlv1.StatusResponse_Route) {
out = make([]*ftlv1.StatusResponse_Route, len(routes))
for i, route := range routes {
out[i] = &ftlv1.StatusResponse_Route{
Expand All @@ -316,7 +316,6 @@ func (s *Service) Status(ctx context.Context, req *connect.Request[ftlv1.StatusR
}
return out
})
s.routesMu.RUnlock()
replicas := map[string]int32{}
protoRunners, err := slices.MapErr(status.Runners, func(r dal.Runner) (*ftlv1.StatusResponse_Runner, error) {
var deployment *string
Expand Down Expand Up @@ -543,9 +542,7 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre
} else if err != nil {
return nil, err
}
s.routesMu.Lock()
s.routes = routes
s.routesMu.Unlock()
s.routes.Store(routes)
}
if stream.Err() != nil {
return nil, stream.Err()
Expand Down Expand Up @@ -743,9 +740,7 @@ func (s *Service) callWithRequest(
}

module := verbRef.Module
s.routesMu.RLock()
routes, ok := s.routes[module]
s.routesMu.RUnlock()
routes, ok := s.routes.Load()[module]
if !ok {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("no routes for module %q", module))
}
Expand Down Expand Up @@ -1342,9 +1337,7 @@ func (s *Service) syncRoutes(ctx context.Context) (time.Duration, error) {
} else if err != nil {
return 0, err
}
s.routesMu.Lock()
s.routes = routes
s.routesMu.Unlock()
s.routes.Store(routes)
return time.Second, nil
}

Expand Down

0 comments on commit 16fb60c

Please sign in to comment.