Skip to content

Commit

Permalink
feat(ws): activity service auth scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
hbomb79 committed Jul 20, 2024
1 parent eca3fde commit eb53b0b
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 62 deletions.
84 changes: 71 additions & 13 deletions internal/api/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package api

import (
"errors"
"slices"
"sync"

"github.com/google/uuid"
"github.com/hbomb79/Thea/internal/api/controllers/ingests"
"github.com/hbomb79/Thea/internal/api/controllers/transcodes"
"github.com/hbomb79/Thea/internal/http/websocket"
"github.com/hbomb79/Thea/internal/user/permissions"
)

const (
Expand All @@ -21,6 +24,9 @@ type broadcaster struct {
ingestService ingests.IngestService
transcodeService TranscodeService
store Store

clientScopes map[authScope][]uuid.UUID
clientMutex *sync.Mutex
}

func newBroadcaster(
Expand All @@ -29,12 +35,72 @@ func newBroadcaster(
transcodeService TranscodeService,
store Store,
) *broadcaster {
return &broadcaster{socketHub, ingestService, transcodeService, store}
return &broadcaster{socketHub, ingestService, transcodeService, store, make(map[authScope][]uuid.UUID, 0), &sync.Mutex{}}
}

type authScope int

const (
mediaScope authScope = iota
transcodeScope
ingestScope
)

var scopePerms = map[authScope][]string{
mediaScope: {permissions.AccessMediaPermission},
transcodeScope: {permissions.AccessTargetPermission},
ingestScope: {permissions.AccessIngestsPermission},
}

// sliceContainsAll returns true if the slice 'a' contains
// ALL the elements inside of 'b'.
func sliceContainsAll[T comparable](a, b []T) bool {
for _, v := range b {
if !slices.Contains(a, v) {
return false
}
}

return true
}

func (hub *broadcaster) RegisterClient(clientID uuid.UUID, permissions []string) {
hub.clientMutex.Lock()
defer hub.clientMutex.Unlock()

for scope, requiredPerms := range scopePerms {
if sliceContainsAll(permissions, requiredPerms) {
hub.clientScopes[scope] = append(hub.clientScopes[scope], clientID)
}
}
}

func (hub *broadcaster) DeregisterClient(clientID uuid.UUID) {
hub.clientMutex.Lock()
defer hub.clientMutex.Unlock()

for k, clients := range hub.clientScopes {
hub.clientScopes[k] = slices.DeleteFunc(clients, func(id uuid.UUID) bool { return id == clientID })
}
}

func (hub *broadcaster) protectedSend(scope authScope, title string, body map[string]interface{}) {
clients := hub.clientScopes[scope]
for _, client := range clients {
// TODO: this could cause quite the number of messages to be sent. Probably fine for
// now, but maybe a queue + worker pool might make sense?
hub.socketHub.Send(&websocket.SocketMessage{
Target: &client,
Title: title,
Body: body,
Type: websocket.Update,
})
}
}

func (hub *broadcaster) BroadcastTranscodeUpdate(id uuid.UUID) error {
item := hub.transcodeService.Task(id)
hub.broadcast(TitleTranscodeUpdate, map[string]interface{}{
hub.protectedSend(transcodeScope, TitleTranscodeUpdate, map[string]interface{}{
"id": id,
"transcode": nullsafeNewDto(item, transcodes.NewDtoFromTask),
})
Expand All @@ -47,7 +113,7 @@ func (hub *broadcaster) BroadcastTaskProgressUpdate(id uuid.UUID) error {
return nil
}

hub.broadcast(TitleTranscodeProgressUpdate, map[string]interface{}{
hub.protectedSend(transcodeScope, TitleTranscodeProgressUpdate, map[string]interface{}{
"transcode_id": id,
"progress": item.LastProgress(),
})
Expand All @@ -56,28 +122,20 @@ func (hub *broadcaster) BroadcastTaskProgressUpdate(id uuid.UUID) error {

func (hub *broadcaster) BroadcastIngestUpdate(id uuid.UUID) error {
item := hub.ingestService.GetIngest(id)
hub.broadcast(TitleIngestUpdate, map[string]interface{}{
hub.protectedSend(ingestScope, TitleIngestUpdate, map[string]interface{}{
"ingest_id": id,
"ingest": nullsafeNewDto(item, ingests.NewDto),
})
return nil
}

func (hub *broadcaster) broadcast(title string, update map[string]interface{}) {
hub.socketHub.Send(&websocket.SocketMessage{
Title: title,
Body: update,
Type: websocket.Update,
})
}

func (hub *broadcaster) BroadcastWorkflowUpdate(id uuid.UUID) error {
return errors.New("not yet implemented")
}

func (hub *broadcaster) BroadcastMediaUpdate(id uuid.UUID) error {
media := hub.store.GetMedia(id)
hub.broadcast(TitleMediaUpdate, map[string]interface{}{
hub.protectedSend(mediaScope, TitleMediaUpdate, map[string]interface{}{
"media_id": id,
"media": media,
})
Expand Down
43 changes: 43 additions & 0 deletions internal/api/jwt/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,49 @@ func (auth *jwtAuthProvider) validateTokenFromAuthInput(ctx context.Context, aut
return nil
}

// validateTokenFromRequest is a simpler version of validateTokenFromAuthInput,
// which acts only on an HTTP request. This is useful in times where the request URI
// is not documented by our OpenAPI spec, and as such poses a huge annoyance.
//
// This function behaves very similarly to the aforementioned counterpart, however
// permission 'scope' validation is NOT performed, so endpoints utilizing this form
// of manual authentication should consider checking this manually.
func (auth *jwtAuthProvider) ValidateTokenFromRequest(ec echo.Context, request *http.Request) (*AuthenticatedUser, error) {
tokenCookie, err := request.Cookie(AuthTokenCookieName)
if err != nil {
return nil, ErrAuthTokenMissing
}

token, err := auth.validateJWT(tokenCookie.Value, auth.authTokenSecret)
if err != nil {
return nil, fmt.Errorf("validation of auth token failed: %w", err)
}

claims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return nil, errors.New("failed to cast JWT claims to MapClaims")
}

// Extract user information (ID and permissions) from JWT
userID, err := auth.getUserIDFromClaims(*claims)
if err != nil {
return nil, err
}

// Grab user permissions so we can store them in the context
userPermissions, err := auth.getPermissionsFromClaims(*claims)
if err != nil {
return nil, err
}

// Insert user info inside of request context to allow for
// endpoint handlers to extract user information
authUser := &AuthenticatedUser{UserID: *userID, Permissions: userPermissions}
ec.Set("user", authUser)

return authUser, nil
}

func (auth *jwtAuthProvider) getPermissionsFromClaims(claims jwt.MapClaims) ([]string, error) {
if permissions, ok := claims["permissions"]; ok {
perms, ok := permissions.([]interface{})
Expand Down
32 changes: 29 additions & 3 deletions internal/api/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,39 @@ func NewRestGateway(

// -- Setup gateway --
socket := websocket.New()
broadcaster := newBroadcaster(socket, ingestService, transcodeService, store)

// The activity service endpoint is not documented in the OpenAPI spec, so it
// has a unique setup because:
// - The code gen does not know about it, and so we must define the endpoint manually
// - The JWT authentication cannot be done leveraging the OpenAPI validator, as this request
// breaches the spec. Therefore, we validate it manually from the request. This is fine
// for this endpoint as we base what information flows through the websocket using the permissions,
// so there's no permission specifically-required to access this endpoint (the only requirement is
// that you're authenticated).
ec.GET(apiBasePath+"/activity/ws", func(c echo.Context) error {
// TODO authentication things here pls
socket.UpgradeToSocket(c.Response(), c.Request())
user, err := authProvider.ValidateTokenFromRequest(c, c.Request())
if err != nil {
// TODO: ensure this error doesn't leak information. We may need to log this
// error and return a simple HTTP Forbidden.
return err
}

socket.UpgradeToSocket(c.Response(), c.Request(), func(client websocket.SocketClient, event websocket.ClientEvent) {
//exhaustive:enforce
switch event {
case websocket.OPENED:
broadcaster.RegisterClient(client.ID, user.Permissions)
case websocket.CLOSED:
broadcaster.DeregisterClient(client.ID)
}
})

return nil
})

gateway := &RestGateway{
broadcaster: newBroadcaster(socket, ingestService, transcodeService, store),
broadcaster: broadcaster,
config: config,
ec: ec,
socket: socket,
Expand Down
12 changes: 6 additions & 6 deletions internal/http/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"github.com/gorilla/websocket"
)

type socketClient struct {
id *uuid.UUID
type SocketClient struct {
ID uuid.UUID
socket *websocket.Conn
}

func (client *socketClient) SendMessage(message *SocketMessage) error {
func (client *SocketClient) SendMessage(message *SocketMessage) error {
return client.socket.WriteJSON(message)
}

Expand All @@ -19,20 +19,20 @@ func (client *socketClient) SendMessage(message *SocketMessage) error {
// experiences an error, or the JSON marshalling fails, this error will be returned
// and consequently the read loop will close. It is the responsibility of the caller
// to de-register the client once the connection closes.
func (client *socketClient) Read(receiveCh chan *SocketMessage) error {
func (client *SocketClient) Read(receiveCh chan *SocketMessage) error {
for {
var recv SocketMessage
if err := client.socket.ReadJSON(&recv); err != nil {
return err
}

// Set the message origin to point to this clients uuid
recv.Origin = client.id
recv.Origin = &client.ID
receiveCh <- &recv
}
}

// Close will close this clients socket.
func (client *socketClient) Close() {
func (client *SocketClient) Close() {
client.socket.Close()
}
Loading

0 comments on commit eb53b0b

Please sign in to comment.