Skip to content

Commit

Permalink
Fix: use models serviceName
Browse files Browse the repository at this point in the history
  • Loading branch information
andy89923 committed Feb 6, 2024
1 parent ae0a724 commit b68ecce
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 23 deletions.
10 changes: 5 additions & 5 deletions internal/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func Init() {
}

type NFContext interface {
AuthorizationCheck(token, serviceName string) error
AuthorizationCheck(token string, serviceName models.ServiceName) error
}

var _ NFContext = &AUSFContext{}
Expand Down Expand Up @@ -166,22 +166,22 @@ func (a *AUSFContext) GetSelfID() string {
return a.NfId
}

func (c *AUSFContext) GetTokenCtx(scope string, targetNF models.NfType) (
func (c *AUSFContext) GetTokenCtx(serviceName models.ServiceName, targetNF models.NfType) (
context.Context, *models.ProblemDetails, error,
) {
if !c.OAuth2Required {
return context.TODO(), nil, nil
}
return oauth.GetTokenCtx(models.NfType_AUSF, targetNF,
c.NfId, c.NrfUri, scope)
c.NfId, c.NrfUri, string(serviceName))
}

func (c *AUSFContext) AuthorizationCheck(token, serviceName string) error {
func (c *AUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error {
if !c.OAuth2Required {
logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: OAuth2 not required\n")
return nil
}

logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: token[%s] serviceName[%s]\n", token, serviceName)
return oauth.VerifyOAuth(token, serviceName, c.NrfCertPem)
return oauth.VerifyOAuth(token, string(serviceName), c.NrfCertPem)
}
2 changes: 1 addition & 1 deletion internal/sbi/consumer/nf_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfType,
param Nnrf_NFDiscovery.SearchNFInstancesParamOpts,
) (*models.SearchResult, error) {
ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nnrf-disc", models.NfType_NRF)
ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_DISC, models.NfType_NRF)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/sbi/consumer/nf_management.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil
configuration.SetBasePath(nrfUri)
client := Nnrf_NFManagement.NewAPIClient(configuration)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nnrf-nfm", models.NfType_NRF)
ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF)
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -94,7 +94,7 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil
func SendDeregisterNFInstance() (*models.ProblemDetails, error) {
logger.ConsumerLog.Infof("Send Deregister NFInstance")

ctx, pd, err := ausf_context.GetSelf().GetTokenCtx("nnrf-nfm", models.NfType_NRF)
ctx, pd, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF)
if err != nil {
return pd, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/sbi/producer/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func sendAuthResultToUDM(id string, authType models.AuthType, success bool, serv

client := createClientToUdmUeau(udmUrl)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nudm-ueau", models.NfType_UDM)
ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/sbi/producer/ue_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func UeAuthPostRequestProcedure(updateAuthenticationInfo models.AuthenticationIn
udmUrl := getUdmUrl(self.NrfUri)
client := createClientToUdmUeau(udmUrl)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nudm-ueau", models.NfType_UDM)
ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM)
if err != nil {
return nil, "", nil
}
Expand Down
4 changes: 1 addition & 3 deletions internal/sbi/sorprotection/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
logger_util "github.com/free5gc/util/logger"
)

const serviceName string = string(models.ServiceName_NAUSF_SORPROTECTION)

// Route is the information for every URI.
type Route struct {
// Name is the name of this Route.
Expand All @@ -50,7 +48,7 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfSorprotectionResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(serviceName)
routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_SORPROTECTION)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})
Expand Down
4 changes: 1 addition & 3 deletions internal/sbi/ueauthentication/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
logger_util "github.com/free5gc/util/logger"
)

const serviceName string = string(models.ServiceName_NAUSF_AUTH)

// Route is the information for every URI.
type Route struct {
// Name is the name of this Route.
Expand All @@ -50,7 +48,7 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfAuthResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(serviceName)
routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_AUTH)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})
Expand Down
4 changes: 1 addition & 3 deletions internal/sbi/upuprotection/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
logger_util "github.com/free5gc/util/logger"
)

const serviceName string = string(models.ServiceName_NAUSF_UPUPROTECTION)

// Route is the information for every URI.
type Route struct {
// Name is the name of this Route.
Expand All @@ -50,7 +48,7 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfAuthResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(serviceName)
routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_UPUPROTECTION)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})
Expand Down
5 changes: 3 additions & 2 deletions internal/util/router_auth_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (

ausf_context "github.com/free5gc/ausf/internal/context"
"github.com/free5gc/ausf/internal/logger"
"github.com/free5gc/openapi/models"
)

type RouterAuthorizationCheck struct {
serviceName string
serviceName models.ServiceName
}

func NewRouterAuthorizationCheck(serviceName string) *RouterAuthorizationCheck {
func NewRouterAuthorizationCheck(serviceName models.ServiceName) *RouterAuthorizationCheck {
return &RouterAuthorizationCheck{
serviceName: serviceName,
}
Expand Down
6 changes: 4 additions & 2 deletions internal/util/router_auth_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

"github.com/gin-gonic/gin"
"github.com/pkg/errors"

"github.com/free5gc/openapi/models"
)

const (
Expand All @@ -20,7 +22,7 @@ func newMockAUSFContext() *mockAUSFContext {
return &mockAUSFContext{}
}

func (m *mockAUSFContext) AuthorizationCheck(token string, serviceName string) error {
func (m *mockAUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error {
if token == Valid {
return nil
}
Expand Down Expand Up @@ -81,7 +83,7 @@ func TestRouterAuthorizationCheck_Check(t *testing.T) {
}
c.Request.Header.Set("Authorization", tt.args.token)

rac := NewRouterAuthorizationCheck("testService")
rac := NewRouterAuthorizationCheck(models.ServiceName("testService"))
rac.Check(c, newMockAUSFContext())
if w.Code != tt.want.statusCode {
t.Errorf("StatusCode should be %d, but got %d", tt.want.statusCode, w.Code)
Expand Down

0 comments on commit b68ecce

Please sign in to comment.