diff --git a/controller/handler_edge_ctrl/common.go b/controller/handler_edge_ctrl/common.go index 64110b0bd..b55ed7134 100644 --- a/controller/handler_edge_ctrl/common.go +++ b/controller/handler_edge_ctrl/common.go @@ -229,7 +229,8 @@ func (self *baseSessionRequestContext) checkSessionFingerprints(fingerprints []s func (self *baseSessionRequestContext) verifyEdgeRouterAccess() { if self.err == nil { // validate edge router - result, err := self.handler.getAppEnv().Managers.EdgeRouter.ListForSession(self.session.Id) + erMgr := self.handler.getAppEnv().Managers.EdgeRouter + edgeRouterAllowed, err := erMgr.IsAccessToEdgeRouterAllowed(self.session.IdentityId, self.session.ServiceId, self.sourceRouter.Id) if err != nil { self.err = internalError(err) logrus. @@ -239,14 +240,6 @@ func (self *baseSessionRequestContext) verifyEdgeRouterAccess() { return } - edgeRouterAllowed := false - for _, er := range result.EdgeRouters { - if er.Id == self.sourceRouter.Id { - edgeRouterAllowed = true - break - } - } - if !edgeRouterAllowed { self.err = InvalidEdgeRouterForSessionError{} } diff --git a/controller/internal/routes/session_api_model.go b/controller/internal/routes/session_api_model.go index 07f6109b8..0690fd545 100644 --- a/controller/internal/routes/session_api_model.go +++ b/controller/internal/routes/session_api_model.go @@ -166,7 +166,7 @@ func MapSessionsToRestEntities(ae *env.AppEnv, rc *response.RequestContext, sess func getSessionEdgeRouters(ae *env.AppEnv, ns *model.Session) ([]*rest_model.SessionEdgeRouter, error) { var edgeRouters []*rest_model.SessionEdgeRouter - edgeRoutersForSession, err := ae.Managers.EdgeRouter.ListForSession(ns.Id) + edgeRoutersForSession, err := ae.Managers.EdgeRouter.ListForIdentityAndService(ns.IdentityId, ns.ServiceId, nil) if err != nil { return nil, err } diff --git a/controller/model/edge_router_manager.go b/controller/model/edge_router_manager.go index ba5fd2629..c8fe499c7 100644 --- a/controller/model/edge_router_manager.go +++ b/controller/model/edge_router_manager.go @@ -188,27 +188,6 @@ func (self *EdgeRouterManager) Query(query string) (*EdgeRouterListResult, error return result, nil } -func (self *EdgeRouterManager) ListForSession(sessionId string) (*EdgeRouterListResult, error) { - var result *EdgeRouterListResult - - err := self.env.GetDbProvider().GetDb().View(func(tx *bbolt.Tx) error { - session, err := self.env.GetStores().Session.LoadOneById(tx, sessionId) - if err != nil { - return err - } - apiSession, err := self.env.GetStores().ApiSession.LoadOneById(tx, session.ApiSessionId) - if err != nil { - return err - } - - limit := -1 - - result, err = self.ListForIdentityAndServiceWithTx(tx, apiSession.IdentityId, session.ServiceId, &limit) - return err - }) - return result, err -} - func (self *EdgeRouterManager) ListForIdentityAndService(identityId, serviceId string, limit *int) (*EdgeRouterListResult, error) { var list *EdgeRouterListResult var err error @@ -223,27 +202,58 @@ func (self *EdgeRouterManager) ListForIdentityAndService(identityId, serviceId s } func (self *EdgeRouterManager) ListForIdentityAndServiceWithTx(tx *bbolt.Tx, identityId, serviceId string, limit *int) (*EdgeRouterListResult, error) { - service, err := self.env.GetStores().EdgeService.LoadOneById(tx, serviceId) - if err != nil { - return nil, err - } - if service == nil { - return nil, errors.Errorf("no service with id %v found", serviceId) - } - - query := fmt.Sprintf(`anyOf(identities) = "%v" and anyOf(services) = "%v"`, identityId, service.Id) + query := fmt.Sprintf(`anyOf(identities) = "%v" and anyOf(services) = "%v"`, identityId, serviceId) if limit != nil { query += " limit " + strconv.Itoa(*limit) } result := &EdgeRouterListResult{manager: self} - if err = self.ListWithTx(tx, query, result.collect); err != nil { + if err := self.ListWithTx(tx, query, result.collect); err != nil { return nil, err } return result, nil } +func (self *EdgeRouterManager) IsAccessToEdgeRouterAllowed(identityId, serviceId, edgeRouterId string) (bool, error) { + var result bool + err := self.GetDb().View(func(tx *bbolt.Tx) error { + identityEdgeRouters := self.env.GetStores().Identity.GetRefCountedLinkCollection(db.EntityTypeRouters) + serviceEdgeRouters := self.env.GetStores().EdgeService.GetRefCountedLinkCollection(persistence.FieldEdgeRouters) + + identityCount := identityEdgeRouters.GetLinkCount(tx, []byte(identityId), []byte(edgeRouterId)) + serviceCount := serviceEdgeRouters.GetLinkCount(tx, []byte(serviceId), []byte(edgeRouterId)) + result = identityCount != nil && *identityCount > 0 && serviceCount != nil && *serviceCount > 0 + return nil + }) + if err != nil { + return false, nil + } + return result, nil +} + +func (self *EdgeRouterManager) IsSharedEdgeRouterPresent(identityId, serviceId string) (bool, error) { + var result bool + err := self.GetDb().View(func(tx *bbolt.Tx) error { + identityEdgeRouters := self.env.GetStores().Identity.GetRefCountedLinkCollection(db.EntityTypeRouters) + serviceEdgeRouters := self.env.GetStores().EdgeService.GetRefCountedLinkCollection(persistence.FieldEdgeRouters) + + cursor := identityEdgeRouters.IterateLinks(tx, []byte(identityId), true) + for cursor.IsValid() { + serviceCount := serviceEdgeRouters.GetLinkCount(tx, []byte(serviceId), cursor.Current()) + if result = serviceCount != nil && *serviceCount > 0; result { + return nil + } + cursor.Next() + } + return nil + }) + if err != nil { + return false, nil + } + return result, nil +} + func (self *EdgeRouterManager) QueryRoleAttributes(queryString string) ([]string, *models.QueryMetaData, error) { index := self.env.GetStores().EdgeRouter.GetRoleAttributesIndex() return self.queryRoleAttributes(index, queryString) diff --git a/controller/model/edge_router_manager_test.go b/controller/model/edge_router_manager_test.go index 43afec38d..4d569bb89 100644 --- a/controller/model/edge_router_manager_test.go +++ b/controller/model/edge_router_manager_test.go @@ -27,12 +27,14 @@ func (ctx *TestContext) testGetEdgeRoutersForServiceAndIdentity(*testing.T) { // test default case, with no limits on service ctx.False(ctx.isEdgeRouterAccessible(edgeRouter.Id, identity.Id, service.Id)) ctx.False(ctx.isEdgeRouterAccessible(edgeRouter2.Id, identity.Id, service.Id)) + ctx.False(ctx.managers.EdgeRouter.IsSharedEdgeRouterPresent(identity.Id, service.Id)) serp := ctx.requireNewServiceNewEdgeRouterPolicy(ss("@"+service.Id), ss("#"+eid.New())) // should not be accessible if we limit to a role no one has ctx.False(ctx.isEdgeRouterAccessible(edgeRouter.Id, identity.Id, service.Id)) ctx.False(ctx.isEdgeRouterAccessible(edgeRouter2.Id, identity.Id, service.Id)) + ctx.False(ctx.managers.EdgeRouter.IsSharedEdgeRouterPresent(identity.Id, service.Id)) serp.EdgeRouterRoles = []string{"@" + edgeRouter.Id} ctx.NoError(ctx.managers.ServiceEdgeRouterPolicy.Update(serp, nil)) @@ -40,6 +42,8 @@ func (ctx *TestContext) testGetEdgeRoutersForServiceAndIdentity(*testing.T) { // should be accessible if we limit to our specific router ctx.True(ctx.isEdgeRouterAccessible(edgeRouter.Id, identity.Id, service.Id)) ctx.False(ctx.isEdgeRouterAccessible(edgeRouter2.Id, identity.Id, service.Id)) + ctx.True(ctx.managers.EdgeRouter.IsSharedEdgeRouterPresent(identity.Id, service.Id)) + } func (ctx *TestContext) isEdgeRouterAccessible(edgeRouterId, identityId, serviceId string) bool { @@ -58,5 +62,10 @@ func (ctx *TestContext) isEdgeRouterAccessible(edgeRouterId, identityId, service return nil }) ctx.NoError(err) + + accessAllowed, err := ctx.managers.EdgeRouter.IsAccessToEdgeRouterAllowed(identityId, serviceId, edgeRouterId) + ctx.NoError(err) + ctx.Equal(found, accessAllowed) + return found } diff --git a/controller/model/session_manager.go b/controller/model/session_manager.go index 61eb87f9e..24ab1f6e2 100644 --- a/controller/model/session_manager.go +++ b/controller/model/session_manager.go @@ -189,12 +189,11 @@ func (self *SessionManager) Create(entity *Session) (string, error) { return "", apierror.NewInvalidPosture(policyResult.Cause) } - maxRows := 1 - result, err := self.GetEnv().GetManagers().EdgeRouter.ListForIdentityAndService(apiSession.IdentityId, entity.ServiceId, &maxRows) + edgeRouterAvailable, err := self.GetEnv().GetManagers().EdgeRouter.IsSharedEdgeRouterPresent(apiSession.IdentityId, entity.ServiceId) if err != nil { return "", err } - if result.Count < 1 { + if !edgeRouterAvailable { return "", apierror.NewNoEdgeRoutersAvailable() }