From 667ee195b6a737a979c5aeb965c3989fe407dba8 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Thu, 30 Nov 2023 20:11:37 +0100 Subject: [PATCH 1/2] gs: Allow individual right checks --- pkg/gatewayserver/gatewayserver.go | 7 ++++++- pkg/gatewayserver/io/io.go | 7 +++++-- pkg/gatewayserver/io/mock/server.go | 7 ++++++- pkg/gatewayserver/io/ws/lbslns/discover_util_test.go | 4 ++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/pkg/gatewayserver/gatewayserver.go b/pkg/gatewayserver/gatewayserver.go index 63ca1790ec..804cef400f 100644 --- a/pkg/gatewayserver/gatewayserver.go +++ b/pkg/gatewayserver/gatewayserver.go @@ -450,6 +450,11 @@ type connectionEntry struct { tasksDone *sync.WaitGroup } +// AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers. +func (gs *GatewayServer) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, rights ...ttnpb.Right) error { + return gs.entityRegistry.AssertGatewayRights(ctx, ids, rights...) +} + // Connect connects a gateway by its identifiers to the Gateway Server, and returns a io.Connection for traffic and // control. func (gs *GatewayServer) Connect( @@ -459,7 +464,7 @@ func (gs *GatewayServer) Connect( addr *ttnpb.GatewayRemoteAddress, opts ...io.ConnectionOption, ) (*io.Connection, error) { - if err := gs.entityRegistry.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil { + if err := gs.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil { return nil, err } diff --git a/pkg/gatewayserver/io/io.go b/pkg/gatewayserver/io/io.go index 3cf734b30a..3659ba02c4 100644 --- a/pkg/gatewayserver/io/io.go +++ b/pkg/gatewayserver/io/io.go @@ -60,8 +60,11 @@ type Server interface { GetBaseConfig(ctx context.Context) config.ServiceBase // FillGatewayContext fills the given context and identifiers. // This method should only be used for request contexts. - FillGatewayContext(ctx context.Context, - ids *ttnpb.GatewayIdentifiers) (context.Context, *ttnpb.GatewayIdentifiers, error) + FillGatewayContext( + ctx context.Context, ids *ttnpb.GatewayIdentifiers, + ) (context.Context, *ttnpb.GatewayIdentifiers, error) + // AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers. + AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error // Connect connects a gateway by its identifiers to the Gateway Server, and returns a Connection for traffic and // control. Connect( diff --git a/pkg/gatewayserver/io/mock/server.go b/pkg/gatewayserver/io/mock/server.go index 083d05ff09..080e1da0cc 100644 --- a/pkg/gatewayserver/io/mock/server.go +++ b/pkg/gatewayserver/io/mock/server.go @@ -76,6 +76,11 @@ func (s *server) FillGatewayContext(ctx context.Context, ids *ttnpb.GatewayIdent return ctx, ids, nil } +// AssertRights implements io.Server. +func (*server) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error { + return rights.RequireGateway(ctx, ids, required...) +} + // Connect implements io.Server. func (s *server) Connect( ctx context.Context, @@ -84,7 +89,7 @@ func (s *server) Connect( addr *ttnpb.GatewayRemoteAddress, opts ...io.ConnectionOption, ) (*io.Connection, error) { - if err := rights.RequireGateway(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil { + if err := s.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil { return nil, err } gtw, err := s.identityStore.GatewayRegistry().Get(ctx, &ttnpb.GetGatewayRequest{GatewayIds: ids}) diff --git a/pkg/gatewayserver/io/ws/lbslns/discover_util_test.go b/pkg/gatewayserver/io/ws/lbslns/discover_util_test.go index 19743b5693..e8e62e96e8 100644 --- a/pkg/gatewayserver/io/ws/lbslns/discover_util_test.go +++ b/pkg/gatewayserver/io/ws/lbslns/discover_util_test.go @@ -37,6 +37,10 @@ func (srv mockServer) FillGatewayContext(ctx context.Context, ids *ttnpb.Gateway return ctx, srv.ids, nil } +func (mockServer) AssertGatewayRights(context.Context, *ttnpb.GatewayIdentifiers, ...ttnpb.Right) error { + return nil +} + func (mockServer) Connect( context.Context, io.Frontend, *ttnpb.GatewayIdentifiers, *ttnpb.GatewayRemoteAddress, ...io.ConnectionOption, ) (*io.Connection, error) { From da3af9a634de871827c4314099cbe1d5fa2ee4ba Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Thu, 30 Nov 2023 20:29:47 +0100 Subject: [PATCH 2/2] gs: Assert gateway rights on discovery --- CHANGELOG.md | 2 + pkg/gatewayserver/io/ws/format.go | 16 ++++- pkg/gatewayserver/io/ws/lbslns/discover.go | 13 +++- .../io/ws/lbslns/discover_test.go | 5 +- pkg/gatewayserver/io/ws/ws.go | 72 +++++++++---------- 5 files changed, 64 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a69b719fc0..cea2227c35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ For details about compatibility between different releases, see the **Commitment - Server side events replaced with single socket connection using the native WebSocket API. - Gateways now disconnect if the Gateway Server address has changed. - This enables CUPS-enabled gateways to change their LNS before the periodic CUPS lookup occurs. +- The LoRa Basics Station discovery endpoint now verifies the authorization credentials of the caller. + - This enables the gateways to migrate to another instance gracefully while using CUPS. ### Deprecated diff --git a/pkg/gatewayserver/io/ws/format.go b/pkg/gatewayserver/io/ws/format.go index 34ba32ddc2..d8c59a9162 100644 --- a/pkg/gatewayserver/io/ws/format.go +++ b/pkg/gatewayserver/io/ws/format.go @@ -41,12 +41,22 @@ type Formatter interface { Endpoints() Endpoints // HandleConnectionInfo handles connection information requests from web socket based protocols. // This function returns a byte stream that contains connection information (ex: scheme, host, port etc) or an error if applicable. - HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, serverInfo ServerInfo, receivedAt time.Time) []byte + HandleConnectionInfo( + ctx context.Context, + raw []byte, + server io.Server, + serverInfo ServerInfo, + assertAuth func(context.Context, *ttnpb.GatewayIdentifiers) error, + ) []byte // HandleUp handles upstream messages from web socket based gateways. // This function optionally returns a byte stream to be sent as response to the upstream message. - HandleUp(ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time) ([]byte, error) + HandleUp( + ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time, + ) ([]byte, error) // FromDownlink generates a downlink byte stream that can be sent over the WS connection. FromDownlink(ctx context.Context, down *ttnpb.DownlinkMessage, bandID string, dlTime time.Time) ([]byte, error) // TransferTime generates a spurious time transfer message for a particular server time. - TransferTime(ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime) ([]byte, error) + TransferTime( + ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime, + ) ([]byte, error) } diff --git a/pkg/gatewayserver/io/ws/lbslns/discover.go b/pkg/gatewayserver/io/ws/lbslns/discover.go index c2601930dc..5400eea70c 100644 --- a/pkg/gatewayserver/io/ws/lbslns/discover.go +++ b/pkg/gatewayserver/io/ws/lbslns/discover.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "time" "go.thethings.network/lorawan-stack/v3/pkg/errors" "go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io" @@ -32,7 +31,13 @@ import ( var errEmptyGatewayEUI = errors.DefineFailedPrecondition("empty_gateway_eui", "empty gateway EUI") // HandleConnectionInfo implements Formatter. -func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, info ws.ServerInfo, receivedAt time.Time) []byte { +func (f *lbsLNS) HandleConnectionInfo( + ctx context.Context, + raw []byte, + server io.Server, + info ws.ServerInfo, + assertAuth func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error, +) []byte { var req DiscoverQuery if err := json.Unmarshal(raw, &req); err != nil { @@ -52,6 +57,10 @@ func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io } ctx = filledCtx + if err := assertAuth(ctx, ids); err != nil { + return logAndWrapDiscoverError(ctx, err, fmt.Sprintf("Unauthorized")) + } + euiWithPrefix := fmt.Sprintf("eui-%s", types.MustEUI64(ids.Eui).OrZero().String()) res := DiscoverResponse{ EUI: req.EUI, diff --git a/pkg/gatewayserver/io/ws/lbslns/discover_test.go b/pkg/gatewayserver/io/ws/lbslns/discover_test.go index 12c3f6a27b..bc7e4c998b 100644 --- a/pkg/gatewayserver/io/ws/lbslns/discover_test.go +++ b/pkg/gatewayserver/io/ws/lbslns/discover_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "testing" - "time" "github.com/smarty/assertions" "go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io/ws" @@ -69,9 +68,11 @@ func TestDiscover(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { msg, err := json.Marshal(tc.Query) a.So(err, should.BeNil) - resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, time.Now()) + resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, noopAssertRights) expected, _ := json.Marshal(tc.ExpectedResponse) a.So(string(resp), should.Equal, string(expected)) }) } } + +func noopAssertRights(context.Context, *ttnpb.GatewayIdentifiers) error { return nil } diff --git a/pkg/gatewayserver/io/ws/ws.go b/pkg/gatewayserver/io/ws/ws.go index 41834ef38f..225ae1fe01 100644 --- a/pkg/gatewayserver/io/ws/ws.go +++ b/pkg/gatewayserver/io/ws/ws.go @@ -117,6 +117,18 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) { "remote_addr", r.RemoteAddr, )) logger := log.FromContext(ctx) + + assertAuth := func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error { + ctx, hasAuth := withForwardedAuth(ctx, ids, r.Header.Get("Authorization")) + if !hasAuth { + if !s.cfg.AllowUnauthenticated { + return errNoAuthProvided.WithAttributes("uid", unique.ID(ctx, ids)) + } + return nil + } + return s.server.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK) + } + ws, err := s.upgrader.Upgrade(w, r, nil) if err != nil { logger.WithError(err).Debug("Failed to upgrade request to websocket connection") @@ -154,7 +166,7 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) { Address: net.JoinHostPort(host, port), } - resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, time.Now()) + resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, assertAuth) if err := ws.WriteMessage(websocket.TextMessage, resp); err != nil { logger.WithError(err).Warn("Failed to write connection info response message") return @@ -202,47 +214,12 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error) uid := unique.ID(ctx, ids) ctx = log.NewContextWithField(ctx, "gateway_uid", uid) - var md metadata.MD - - if auth != "" { - if !strings.HasPrefix(auth, "Bearer ") { - auth = fmt.Sprintf("Bearer %s", auth) - } - md = metadata.New(map[string]string{ - "id": ids.GatewayId, - "authorization": auth, - }) - } - - if ctxMd, ok := metadata.FromIncomingContext(ctx); ok { - md = metadata.Join(ctxMd, md) - } - ctx = metadata.NewIncomingContext(ctx, md) - // If a fallback frequency is defined in the server context, inject it into local the context. - if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok { - ctx = frequencyplans.WithFallbackID(ctx, fallback) - } - - var hasAuth bool - if auth != "" { - if !strings.HasPrefix(auth, "Bearer ") { - auth = fmt.Sprintf("Bearer %s", auth) - } - md = metadata.New(map[string]string{ - "authorization": auth, - }) - hasAuth = true - } - - if ctxMd, ok := metadata.FromIncomingContext(ctx); ok { - md = metadata.Join(ctxMd, md) - } - ctx = metadata.NewIncomingContext(ctx, md) // If a fallback frequency is defined in the server context, inject it into local the context. if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok { ctx = frequencyplans.WithFallbackID(ctx, fallback) } + ctx, hasAuth := withForwardedAuth(ctx, ids, auth) if !hasAuth { if !s.cfg.AllowUnauthenticated { // We error here directly as there is no auth. @@ -416,3 +393,24 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error) } } } + +func withForwardedAuth(ctx context.Context, ids *ttnpb.GatewayIdentifiers, auth string) (context.Context, bool) { + var md metadata.MD + var hasAuth bool + if auth != "" { + if !strings.HasPrefix(auth, "Bearer ") { + auth = fmt.Sprintf("Bearer %s", auth) + } + m := map[string]string{"authorization": auth} + if ids != nil { + m["id"] = ids.GatewayId + } + md = metadata.New(m) + if ctxMd, ok := metadata.FromIncomingContext(ctx); ok { + md = metadata.Join(ctxMd, md) + } + ctx = metadata.NewIncomingContext(ctx, md) + hasAuth = true + } + return ctx, hasAuth +}