Skip to content

Commit

Permalink
Merge pull request #6734 from TheThingsNetwork/fix/lns-discovery-os
Browse files Browse the repository at this point in the history
Check authorization in Basics Station discovery
  • Loading branch information
adriansmares authored Dec 1, 2023
2 parents a8bc53d + da3af9a commit 28f39d3
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 48 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ For details about compatibility between different releases, see the **Commitment
- Server side events replaced with single socket connection using the native WebSocket API.
- Gateways now disconnect if the Gateway Server address has changed.
- This enables CUPS-enabled gateways to change their LNS before the periodic CUPS lookup occurs.
- The LoRa Basics Station discovery endpoint now verifies the authorization credentials of the caller.
- This enables the gateways to migrate to another instance gracefully while using CUPS.

### Deprecated

Expand Down
7 changes: 6 additions & 1 deletion pkg/gatewayserver/gatewayserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ type connectionEntry struct {
tasksDone *sync.WaitGroup
}

// AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers.
func (gs *GatewayServer) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, rights ...ttnpb.Right) error {
return gs.entityRegistry.AssertGatewayRights(ctx, ids, rights...)
}

// Connect connects a gateway by its identifiers to the Gateway Server, and returns a io.Connection for traffic and
// control.
func (gs *GatewayServer) Connect(
Expand All @@ -459,7 +464,7 @@ func (gs *GatewayServer) Connect(
addr *ttnpb.GatewayRemoteAddress,
opts ...io.ConnectionOption,
) (*io.Connection, error) {
if err := gs.entityRegistry.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
if err := gs.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
return nil, err
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/gatewayserver/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ type Server interface {
GetBaseConfig(ctx context.Context) config.ServiceBase
// FillGatewayContext fills the given context and identifiers.
// This method should only be used for request contexts.
FillGatewayContext(ctx context.Context,
ids *ttnpb.GatewayIdentifiers) (context.Context, *ttnpb.GatewayIdentifiers, error)
FillGatewayContext(
ctx context.Context, ids *ttnpb.GatewayIdentifiers,
) (context.Context, *ttnpb.GatewayIdentifiers, error)
// AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers.
AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error
// Connect connects a gateway by its identifiers to the Gateway Server, and returns a Connection for traffic and
// control.
Connect(
Expand Down
7 changes: 6 additions & 1 deletion pkg/gatewayserver/io/mock/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ func (s *server) FillGatewayContext(ctx context.Context, ids *ttnpb.GatewayIdent
return ctx, ids, nil
}

// AssertRights implements io.Server.
func (*server) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error {
return rights.RequireGateway(ctx, ids, required...)
}

// Connect implements io.Server.
func (s *server) Connect(
ctx context.Context,
Expand All @@ -84,7 +89,7 @@ func (s *server) Connect(
addr *ttnpb.GatewayRemoteAddress,
opts ...io.ConnectionOption,
) (*io.Connection, error) {
if err := rights.RequireGateway(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
if err := s.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
return nil, err
}
gtw, err := s.identityStore.GatewayRegistry().Get(ctx, &ttnpb.GetGatewayRequest{GatewayIds: ids})
Expand Down
16 changes: 13 additions & 3 deletions pkg/gatewayserver/io/ws/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ type Formatter interface {
Endpoints() Endpoints
// HandleConnectionInfo handles connection information requests from web socket based protocols.
// This function returns a byte stream that contains connection information (ex: scheme, host, port etc) or an error if applicable.
HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, serverInfo ServerInfo, receivedAt time.Time) []byte
HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
serverInfo ServerInfo,
assertAuth func(context.Context, *ttnpb.GatewayIdentifiers) error,
) []byte
// HandleUp handles upstream messages from web socket based gateways.
// This function optionally returns a byte stream to be sent as response to the upstream message.
HandleUp(ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time) ([]byte, error)
HandleUp(
ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time,
) ([]byte, error)
// FromDownlink generates a downlink byte stream that can be sent over the WS connection.
FromDownlink(ctx context.Context, down *ttnpb.DownlinkMessage, bandID string, dlTime time.Time) ([]byte, error)
// TransferTime generates a spurious time transfer message for a particular server time.
TransferTime(ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime) ([]byte, error)
TransferTime(
ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime,
) ([]byte, error)
}
13 changes: 11 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"fmt"
"time"

"go.thethings.network/lorawan-stack/v3/pkg/errors"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io"
Expand All @@ -32,7 +31,13 @@ import (
var errEmptyGatewayEUI = errors.DefineFailedPrecondition("empty_gateway_eui", "empty gateway EUI")

// HandleConnectionInfo implements Formatter.
func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, info ws.ServerInfo, receivedAt time.Time) []byte {
func (f *lbsLNS) HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
info ws.ServerInfo,
assertAuth func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error,
) []byte {
var req DiscoverQuery

if err := json.Unmarshal(raw, &req); err != nil {
Expand All @@ -52,6 +57,10 @@ func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io
}
ctx = filledCtx

if err := assertAuth(ctx, ids); err != nil {
return logAndWrapDiscoverError(ctx, err, fmt.Sprintf("Unauthorized"))
}

euiWithPrefix := fmt.Sprintf("eui-%s", types.MustEUI64(ids.Eui).OrZero().String())
res := DiscoverResponse{
EUI: req.EUI,
Expand Down
5 changes: 3 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"testing"
"time"

"github.com/smarty/assertions"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io/ws"
Expand Down Expand Up @@ -69,9 +68,11 @@ func TestDiscover(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
msg, err := json.Marshal(tc.Query)
a.So(err, should.BeNil)
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, time.Now())
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, noopAssertRights)
expected, _ := json.Marshal(tc.ExpectedResponse)
a.So(string(resp), should.Equal, string(expected))
})
}
}

func noopAssertRights(context.Context, *ttnpb.GatewayIdentifiers) error { return nil }
4 changes: 4 additions & 0 deletions pkg/gatewayserver/io/ws/lbslns/discover_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func (srv mockServer) FillGatewayContext(ctx context.Context, ids *ttnpb.Gateway
return ctx, srv.ids, nil
}

func (mockServer) AssertGatewayRights(context.Context, *ttnpb.GatewayIdentifiers, ...ttnpb.Right) error {
return nil
}

func (mockServer) Connect(
context.Context, io.Frontend, *ttnpb.GatewayIdentifiers, *ttnpb.GatewayRemoteAddress, ...io.ConnectionOption,
) (*io.Connection, error) {
Expand Down
72 changes: 35 additions & 37 deletions pkg/gatewayserver/io/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) {
"remote_addr", r.RemoteAddr,
))
logger := log.FromContext(ctx)

assertAuth := func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error {
ctx, hasAuth := withForwardedAuth(ctx, ids, r.Header.Get("Authorization"))
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
return errNoAuthProvided.WithAttributes("uid", unique.ID(ctx, ids))
}
return nil
}
return s.server.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK)
}

ws, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.WithError(err).Debug("Failed to upgrade request to websocket connection")
Expand Down Expand Up @@ -154,7 +166,7 @@ func (s *srv) handleConnectionInfo(w http.ResponseWriter, r *http.Request) {
Address: net.JoinHostPort(host, port),
}

resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, time.Now())
resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, assertAuth)
if err := ws.WriteMessage(websocket.TextMessage, resp); err != nil {
logger.WithError(err).Warn("Failed to write connection info response message")
return
Expand Down Expand Up @@ -202,47 +214,12 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error)
uid := unique.ID(ctx, ids)
ctx = log.NewContextWithField(ctx, "gateway_uid", uid)

var md metadata.MD

if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"id": ids.GatewayId,
"authorization": auth,
})
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"authorization": auth,
})
hasAuth = true
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

ctx, hasAuth := withForwardedAuth(ctx, ids, auth)
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
// We error here directly as there is no auth.
Expand Down Expand Up @@ -416,3 +393,24 @@ func (s *srv) handleTraffic(w http.ResponseWriter, r *http.Request) (err error)
}
}
}

func withForwardedAuth(ctx context.Context, ids *ttnpb.GatewayIdentifiers, auth string) (context.Context, bool) {
var md metadata.MD
var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
m := map[string]string{"authorization": auth}
if ids != nil {
m["id"] = ids.GatewayId
}
md = metadata.New(m)
if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
hasAuth = true
}
return ctx, hasAuth
}

0 comments on commit 28f39d3

Please sign in to comment.