Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check authorization in Basics Station discovery #6734

Merged
merged 2 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ For details about compatibility between different releases, see the **Commitment
- Server side events replaced with single socket connection using the native WebSocket API.
- Gateways now disconnect if the Gateway Server address has changed.
- This enables CUPS-enabled gateways to change their LNS before the periodic CUPS lookup occurs.
- The LoRa Basics Station discovery endpoint now verifies the authorization credentials of the caller.
- This enables the gateways to migrate to another instance gracefully while using CUPS.

### Deprecated

Expand Down
7 changes: 6 additions & 1 deletion pkg/gatewayserver/gatewayserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ type connectionEntry struct {
tasksDone *sync.WaitGroup
}

// AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers.
func (gs *GatewayServer) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, rights ...ttnpb.Right) error {
return gs.entityRegistry.AssertGatewayRights(ctx, ids, rights...)
}

// Connect connects a gateway by its identifiers to the Gateway Server, and returns a io.Connection for traffic and
// control.
func (gs *GatewayServer) Connect(
Expand All @@ -459,7 +464,7 @@ func (gs *GatewayServer) Connect(
addr *ttnpb.GatewayRemoteAddress,
opts ...io.ConnectionOption,
) (*io.Connection, error) {
if err := gs.entityRegistry.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
if err := gs.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
return nil, err
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/gatewayserver/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ type Server interface {
GetBaseConfig(ctx context.Context) config.ServiceBase
// FillGatewayContext fills the given context and identifiers.
// This method should only be used for request contexts.
FillGatewayContext(ctx context.Context,
ids *ttnpb.GatewayIdentifiers) (context.Context, *ttnpb.GatewayIdentifiers, error)
FillGatewayContext(
ctx context.Context, ids *ttnpb.GatewayIdentifiers,
) (context.Context, *ttnpb.GatewayIdentifiers, error)
// AssertGatewayRights checks that the caller has the required rights over the provided gateway identifiers.
AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error
// Connect connects a gateway by its identifiers to the Gateway Server, and returns a Connection for traffic and
// control.
Connect(
Expand Down
7 changes: 6 additions & 1 deletion pkg/gatewayserver/io/mock/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ func (s *server) FillGatewayContext(ctx context.Context, ids *ttnpb.GatewayIdent
return ctx, ids, nil
}

// AssertRights implements io.Server.
func (*server) AssertGatewayRights(ctx context.Context, ids *ttnpb.GatewayIdentifiers, required ...ttnpb.Right) error {
return rights.RequireGateway(ctx, ids, required...)
}

// Connect implements io.Server.
func (s *server) Connect(
ctx context.Context,
Expand All @@ -84,7 +89,7 @@ func (s *server) Connect(
addr *ttnpb.GatewayRemoteAddress,
opts ...io.ConnectionOption,
) (*io.Connection, error) {
if err := rights.RequireGateway(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
if err := s.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK); err != nil {
return nil, err
}
gtw, err := s.identityStore.GatewayRegistry().Get(ctx, &ttnpb.GetGatewayRequest{GatewayIds: ids})
Expand Down
16 changes: 13 additions & 3 deletions pkg/gatewayserver/io/ws/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ type Formatter interface {
Endpoints() Endpoints
// HandleConnectionInfo handles connection information requests from web socket based protocols.
// This function returns a byte stream that contains connection information (ex: scheme, host, port etc) or an error if applicable.
HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, serverInfo ServerInfo, receivedAt time.Time) []byte
HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
serverInfo ServerInfo,
assertAuth func(context.Context, *ttnpb.GatewayIdentifiers) error,
) []byte
// HandleUp handles upstream messages from web socket based gateways.
// This function optionally returns a byte stream to be sent as response to the upstream message.
HandleUp(ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time) ([]byte, error)
HandleUp(
ctx context.Context, raw []byte, ids *ttnpb.GatewayIdentifiers, conn *io.Connection, receivedAt time.Time,
) ([]byte, error)
// FromDownlink generates a downlink byte stream that can be sent over the WS connection.
FromDownlink(ctx context.Context, down *ttnpb.DownlinkMessage, bandID string, dlTime time.Time) ([]byte, error)
// TransferTime generates a spurious time transfer message for a particular server time.
TransferTime(ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime) ([]byte, error)
TransferTime(
ctx context.Context, serverTime time.Time, gpsTime *time.Time, concentratorTime *scheduling.ConcentratorTime,
) ([]byte, error)
}
13 changes: 11 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"fmt"
"time"

"go.thethings.network/lorawan-stack/v3/pkg/errors"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io"
Expand All @@ -32,7 +31,13 @@ import (
var errEmptyGatewayEUI = errors.DefineFailedPrecondition("empty_gateway_eui", "empty gateway EUI")

// HandleConnectionInfo implements Formatter.
func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io.Server, info ws.ServerInfo, receivedAt time.Time) []byte {
func (f *lbsLNS) HandleConnectionInfo(
ctx context.Context,
raw []byte,
server io.Server,
info ws.ServerInfo,
assertAuth func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error,
) []byte {
var req DiscoverQuery

if err := json.Unmarshal(raw, &req); err != nil {
Expand All @@ -52,6 +57,10 @@ func (f *lbsLNS) HandleConnectionInfo(ctx context.Context, raw []byte, server io
}
ctx = filledCtx

if err := assertAuth(ctx, ids); err != nil {
return logAndWrapDiscoverError(ctx, err, fmt.Sprintf("Unauthorized"))
}

euiWithPrefix := fmt.Sprintf("eui-%s", types.MustEUI64(ids.Eui).OrZero().String())
res := DiscoverResponse{
EUI: req.EUI,
Expand Down
5 changes: 3 additions & 2 deletions pkg/gatewayserver/io/ws/lbslns/discover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"testing"
"time"

"github.com/smarty/assertions"
"go.thethings.network/lorawan-stack/v3/pkg/gatewayserver/io/ws"
Expand Down Expand Up @@ -69,9 +68,11 @@ func TestDiscover(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
msg, err := json.Marshal(tc.Query)
a.So(err, should.BeNil)
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, time.Now())
resp := lbsLNS.HandleConnectionInfo(ctx, msg, mockServer, info, noopAssertRights)
expected, _ := json.Marshal(tc.ExpectedResponse)
a.So(string(resp), should.Equal, string(expected))
})
}
}

func noopAssertRights(context.Context, *ttnpb.GatewayIdentifiers) error { return nil }
4 changes: 4 additions & 0 deletions pkg/gatewayserver/io/ws/lbslns/discover_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func (srv mockServer) FillGatewayContext(ctx context.Context, ids *ttnpb.Gateway
return ctx, srv.ids, nil
}

func (mockServer) AssertGatewayRights(context.Context, *ttnpb.GatewayIdentifiers, ...ttnpb.Right) error {
return nil
}

func (mockServer) Connect(
context.Context, io.Frontend, *ttnpb.GatewayIdentifiers, *ttnpb.GatewayRemoteAddress, ...io.ConnectionOption,
) (*io.Connection, error) {
Expand Down
72 changes: 35 additions & 37 deletions pkg/gatewayserver/io/ws/ws.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2019 The Things Network Foundation, The Things Industries B.V.

Check warning on line 1 in pkg/gatewayserver/io/ws/ws.go

View workflow job for this annotation

GitHub Actions / Check Mergeability

pkg/gatewayserver/io/ws/ws.go has a conflict when merging TheThingsIndustries/lorawan-stack:v3.28.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -117,6 +117,18 @@
"remote_addr", r.RemoteAddr,
))
logger := log.FromContext(ctx)

assertAuth := func(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error {
ctx, hasAuth := withForwardedAuth(ctx, ids, r.Header.Get("Authorization"))
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
return errNoAuthProvided.WithAttributes("uid", unique.ID(ctx, ids))
}
return nil
}
return s.server.AssertGatewayRights(ctx, ids, ttnpb.Right_RIGHT_GATEWAY_LINK)
}

ws, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.WithError(err).Debug("Failed to upgrade request to websocket connection")
Expand Down Expand Up @@ -154,7 +166,7 @@
Address: net.JoinHostPort(host, port),
}

resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, time.Now())
resp := s.formatter.HandleConnectionInfo(ctx, data, s.server, info, assertAuth)
if err := ws.WriteMessage(websocket.TextMessage, resp); err != nil {
logger.WithError(err).Warn("Failed to write connection info response message")
return
Expand Down Expand Up @@ -202,47 +214,12 @@
uid := unique.ID(ctx, ids)
ctx = log.NewContextWithField(ctx, "gateway_uid", uid)

var md metadata.MD

if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"id": ids.GatewayId,
"authorization": auth,
})
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
md = metadata.New(map[string]string{
"authorization": auth,
})
hasAuth = true
}

if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
// If a fallback frequency is defined in the server context, inject it into local the context.
if fallback, ok := frequencyplans.FallbackIDFromContext(s.ctx); ok {
ctx = frequencyplans.WithFallbackID(ctx, fallback)
}

ctx, hasAuth := withForwardedAuth(ctx, ids, auth)
if !hasAuth {
if !s.cfg.AllowUnauthenticated {
// We error here directly as there is no auth.
Expand Down Expand Up @@ -416,3 +393,24 @@
}
}
}

func withForwardedAuth(ctx context.Context, ids *ttnpb.GatewayIdentifiers, auth string) (context.Context, bool) {
var md metadata.MD
var hasAuth bool
if auth != "" {
if !strings.HasPrefix(auth, "Bearer ") {
auth = fmt.Sprintf("Bearer %s", auth)
}
m := map[string]string{"authorization": auth}
if ids != nil {
m["id"] = ids.GatewayId
}
md = metadata.New(m)
if ctxMd, ok := metadata.FromIncomingContext(ctx); ok {
md = metadata.Join(ctxMd, md)
}
ctx = metadata.NewIncomingContext(ctx, md)
hasAuth = true
}
return ctx, hasAuth
}
Loading