Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SMF code for readability, conciseness, and modularity #108

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion internal/context/charging.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package context

import (
"github.com/google/uuid"

"github.com/free5gc/openapi/models"
)

Expand Down Expand Up @@ -29,5 +31,5 @@ type ChargingInfo struct {
EventLimitExpiryTimer *Timer
ChargingLevel ChargingLevel
RatingGroup int32
UpfId string
UpfUUID uuid.UUID
}
99 changes: 97 additions & 2 deletions internal/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"net"
"os"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -66,6 +67,14 @@ type SMFContext struct {
PFCPCancelFunc context.CancelFunc
PfcpHeartbeatInterval time.Duration

PfcpHeartbeatRetries int
PfcpHeartbeatTolerance int
PfcpHeartbeatTimeout time.Duration

SmContextPool sync.Map
CanonicalRef sync.Map
SeidSMContextMap sync.Map

// Now only "IPv4" supported
// TODO: support "IPv6", "IPv4v6", "Ethernet"
SupportedPDUSessionType string
Expand All @@ -81,6 +90,92 @@ type SMFContext struct {
ChargingIDGenerator *idgenerator.IDGenerator
}

/*
func (smfContext *SMFContext) ProcEachSMContext(procFunc func(*SMContext) bool) {
smfContext.SmContextPool.Range(func(key, value interface{}) bool {
smContext := value.(*SMContext)
return procFunc(smContext) // processing function determines if loop continues
})
}*/

func canonicalName(id string, pduSessID int32) string {
return fmt.Sprintf("%s-%d", id, pduSessID)
}

func (smfContext *SMFContext) ResolveRef(id string, pduSessID int32) (string, error) {
if value, ok := smfContext.CanonicalRef.Load(canonicalName(id, pduSessID)); ok {
ref := value.(string)
return ref, nil
} else {
return "", fmt.Errorf("UE[%s] - PDUSessionID[%d] not found in SMFContext", id, pduSessID)
}
}

// *** add unit test ***//
func (smfContext *SMFContext) GetSMContextByRef(ref string) *SMContext {
// TODO: neu schreiben, ProcEachSMContext nutzen
var smCtx *SMContext
if value, ok := smfContext.SmContextPool.Load(ref); ok {
smCtx = value.(*SMContext)
}
return smCtx
}

func (smfContext *SMFContext) GetSMContextById(id string, pduSessID int32) *SMContext {
// TODO: neu schreiben, ProcEachSMContext nutzen
var smCtx *SMContext
ref, err := smfContext.ResolveRef(id, pduSessID)
if err != nil {
return nil
}
if value, ok := smfContext.SmContextPool.Load(ref); ok {
smCtx = value.(*SMContext)
}
return smCtx
}

// *** add unit test ***//
func (smfContext *SMFContext) RemoveSMContext(smContext *SMContext) {
logger.CtxLog.Traceln("In RemoveSMContext")

for _, dataPath := range smContext.Tunnel.DataPathPool {
// TODO: free PDR IDs?
dataPath.DeactivateTunnelAndPDR(smContext)
}

// free UE IP
if smContext.SelectedUPF != nil && smContext.PDUAddress != nil {
logger.PduSessLog.Infof("UE[%s] PDUSessionID[%d] Release IP[%s]",
smContext.Supi, smContext.PDUSessionID, smContext.PDUAddress.String())
GetUserPlaneInformation().
ReleaseUEIP(smContext.SelectedUPF, smContext.PDUAddress, smContext.UseStaticIP)
smContext.SelectedUPF = nil
}

// TODO: what about PFCP session rules?

// TODO: still required or done elsewhere?
for _, pfcpSessionContext := range smContext.PFCPSessionContexts {
smfContext.SeidSMContextMap.Delete(pfcpSessionContext.LocalSEID)
}

ReleaseTEID(smContext.LocalULTeid)
ReleaseTEID(smContext.LocalDLTeid)

smfContext.SmContextPool.Delete(smContext.Ref)
smfContext.CanonicalRef.Delete(canonicalName(smContext.Supi, smContext.PDUSessionID))
smContext.Log.Infof("smContext[%s] is deleted from pool", smContext.Ref)
}

// *** add unit test ***//
func (smfContext *SMFContext) GetSMContextBySEID(seid uint64) *SMContext {
if value, ok := smfContext.SeidSMContextMap.Load(seid); ok {
smContext := value.(*SMContext)
return smContext
}
return nil
}

func GenerateChargingID() int32 {
if smfContext.ChargingIDGenerator != nil {
if id, err := smfContext.ChargingIDGenerator.Allocate(); err == nil {
Expand Down Expand Up @@ -116,8 +211,8 @@ func RetrieveDnnInformation(snssai *models.Snssai, dnn string) *SnssaiSmfDnnInfo
return nil
}

func AllocateLocalSEID() uint64 {
return atomic.AddUint64(&smfContext.LocalSEIDCount, 1)
func (s *SMFContext) AllocateLocalSEID() uint64 {
return atomic.AddUint64(&s.LocalSEIDCount, 1)
}

func InitSmfContext(config *factory.Config) {
Expand Down
Loading
Loading