diff --git a/cmd/main.go b/cmd/main.go index 74fb5ce1..db97fcb7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -86,7 +86,7 @@ func action(cliCtx *cli.Context) error { } factory.UERoutingConfig = ueRoutingCfg - pfcpStart, pfcpTerminate := utils.InitPFCPFunc() + pfcpStart, pfcpTerminate := utils.InitPFCPFunc(ctx) smf, err := service.NewApp(ctx, cfg, tlsKeyLogPath, pfcpStart, pfcpTerminate) if err != nil { sigCh <- nil diff --git a/internal/context/context.go b/internal/context/context.go index 8f29ed31..1c97d6eb 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -62,8 +62,8 @@ type SMFContext struct { OAuth2Required bool UserPlaneInformation *UserPlaneInformation - Ctx context.Context - PFCPCancelFunc context.CancelFunc + PfcpContext context.Context + PfcpCancelFunc context.CancelFunc PfcpHeartbeatInterval time.Duration // Now only "IPv4" supported diff --git a/internal/context/datapath.go b/internal/context/datapath.go index 84bfc89a..820f8be9 100644 --- a/internal/context/datapath.go +++ b/internal/context/datapath.go @@ -823,7 +823,7 @@ func (p *DataPath) AddChargingRules(smContext *SMContext, chgLevel ChargingLevel // nolint nodeId := node.GetUPFID() logger.PduSessLog.Tracef("DownLinkTunnel add URR for node %s %+v", - nodeId, node.UpLinkTunnel.PDR) + nodeId, node.DownLinkTunnel.PDR) } } } diff --git a/internal/context/sm_context_policy_test.go b/internal/context/sm_context_policy_test.go index f8e1f4a7..1db64cb8 100644 --- a/internal/context/sm_context_policy_test.go +++ b/internal/context/sm_context_policy_test.go @@ -1,12 +1,13 @@ package context_test import ( + "context" "testing" "github.com/stretchr/testify/require" "github.com/free5gc/openapi/models" - "github.com/free5gc/smf/internal/context" + smf_context "github.com/free5gc/smf/internal/context" "github.com/free5gc/smf/pkg/factory" ) @@ -114,7 +115,7 @@ var testConfig = factory.Config{ } func initConfig() { - context.InitSmfContext(&testConfig) + smf_context.InitSmfContext(&testConfig) factory.SmfConfig = &testConfig } @@ -125,7 +126,7 @@ func TestApplySessionRules(t *testing.T) { name string decision *models.SmPolicyDecision noErr bool - expectedSessRules map[string]*context.SessionRule + expectedSessRules map[string]*smf_context.SessionRule }{ { name: "nil decision", @@ -151,7 +152,7 @@ func TestApplySessionRules(t *testing.T) { }, }, }, - expectedSessRules: map[string]*context.SessionRule{ + expectedSessRules: map[string]*smf_context.SessionRule{ "SessRuleId-1": { SessionRule: &models.SessionRule{ AuthSessAmbr: &models.Ambr{ @@ -192,7 +193,7 @@ func TestApplySessionRules(t *testing.T) { }, }, }, - expectedSessRules: map[string]*context.SessionRule{ + expectedSessRules: map[string]*smf_context.SessionRule{ "SessRuleId-1": { SessionRule: &models.SessionRule{ AuthSessAmbr: &models.Ambr{ @@ -250,7 +251,7 @@ func TestApplySessionRules(t *testing.T) { }, }, }, - expectedSessRules: map[string]*context.SessionRule{ + expectedSessRules: map[string]*smf_context.SessionRule{ "SessRuleId-1": { SessionRule: &models.SessionRule{ AuthSessAmbr: &models.Ambr{ @@ -295,7 +296,7 @@ func TestApplySessionRules(t *testing.T) { "SessRuleId-1": nil, }, }, - expectedSessRules: map[string]*context.SessionRule{ + expectedSessRules: map[string]*smf_context.SessionRule{ "SessRuleId-2": { SessionRule: &models.SessionRule{ AuthSessAmbr: &models.Ambr{ @@ -328,7 +329,7 @@ func TestApplySessionRules(t *testing.T) { }, } - smctx := context.NewSMContext("imsi-208930000000001", 10) + smctx := smf_context.NewSMContext("imsi-208930000000001", 10) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -350,9 +351,9 @@ func TestApplyPccRules(t *testing.T) { name string decision *models.SmPolicyDecision noErr bool - expectedPCCRules map[string]*context.PCCRule + expectedPCCRules map[string]*smf_context.PCCRule expectedQosDatas map[string]*models.QosData - expectedTcDatas map[string]*context.TrafficControlData + expectedTcDatas map[string]*smf_context.TrafficControlData }{ { name: "nil decision", @@ -390,7 +391,7 @@ func TestApplyPccRules(t *testing.T) { }, }, }, - expectedPCCRules: map[string]*context.PCCRule{ + expectedPCCRules: map[string]*smf_context.PCCRule{ "PccRuleId-1": { PccRule: &models.PccRule{ FlowInfos: []models.FlowInformation{ @@ -410,7 +411,7 @@ func TestApplyPccRules(t *testing.T) { QosId: "QosId-1", }, }, - expectedTcDatas: map[string]*context.TrafficControlData{ + expectedTcDatas: map[string]*smf_context.TrafficControlData{ "TcId-1": { TrafficControlData: &models.TrafficControlData{ TcId: "TcId-1", @@ -446,7 +447,7 @@ func TestApplyPccRules(t *testing.T) { }, }, }, - expectedPCCRules: map[string]*context.PCCRule{ + expectedPCCRules: map[string]*smf_context.PCCRule{ "PccRuleId-1": { PccRule: &models.PccRule{ FlowInfos: []models.FlowInformation{ @@ -482,7 +483,7 @@ func TestApplyPccRules(t *testing.T) { QosId: "QosId-2", }, }, - expectedTcDatas: map[string]*context.TrafficControlData{ + expectedTcDatas: map[string]*smf_context.TrafficControlData{ "TcId-1": { TrafficControlData: &models.TrafficControlData{ TcId: "TcId-1", @@ -518,7 +519,7 @@ func TestApplyPccRules(t *testing.T) { }, }, }, - expectedPCCRules: map[string]*context.PCCRule{ + expectedPCCRules: map[string]*smf_context.PCCRule{ "PccRuleId-1": { PccRule: &models.PccRule{ FlowInfos: []models.FlowInformation{ @@ -554,7 +555,7 @@ func TestApplyPccRules(t *testing.T) { QosId: "QosId-3", }, }, - expectedTcDatas: map[string]*context.TrafficControlData{ + expectedTcDatas: map[string]*smf_context.TrafficControlData{ "TcId-1": { TrafficControlData: &models.TrafficControlData{ TcId: "TcId-1", @@ -575,7 +576,7 @@ func TestApplyPccRules(t *testing.T) { "PccRuleId-2": nil, }, }, - expectedPCCRules: map[string]*context.PCCRule{ + expectedPCCRules: map[string]*smf_context.PCCRule{ "PccRuleId-1": { PccRule: &models.PccRule{ FlowInfos: []models.FlowInformation{ @@ -595,7 +596,7 @@ func TestApplyPccRules(t *testing.T) { QosId: "QosId-3", }, }, - expectedTcDatas: map[string]*context.TrafficControlData{ + expectedTcDatas: map[string]*smf_context.TrafficControlData{ "TcId-1": { TrafficControlData: &models.TrafficControlData{ TcId: "TcId-1", @@ -616,20 +617,20 @@ func TestApplyPccRules(t *testing.T) { "PccRuleId-1": nil, }, }, - expectedPCCRules: map[string]*context.PCCRule{}, + expectedPCCRules: map[string]*smf_context.PCCRule{}, expectedQosDatas: map[string]*models.QosData{}, - expectedTcDatas: map[string]*context.TrafficControlData{}, + expectedTcDatas: map[string]*smf_context.TrafficControlData{}, noErr: true, }, } - smfContext := context.GetSelf() - smfContext.UserPlaneInformation = context.NewUserPlaneInformation(userPlaneConfig) - for _, n := range smfContext.UserPlaneInformation.UPFs { - n.UPFStatus = context.AssociatedSetUpSuccess + smfContext := smf_context.GetSelf() + smfContext.UserPlaneInformation = smf_context.NewUserPlaneInformation(userPlaneConfig) + for _, upf := range smfContext.UserPlaneInformation.UPFs { + upf.AssociationContext = context.Background() } - smctx := context.NewSMContext("imsi-208930000000002", 10) + smctx := smf_context.NewSMContext("imsi-208930000000002", 10) smctx.SMLock.Lock() defer smctx.SMLock.Unlock() @@ -659,7 +660,7 @@ func TestApplyPccRules(t *testing.T) { }, } smctx.SelectedPDUSessionType = 1 - smctx.SessionRules["SessRuleId-1"] = &context.SessionRule{ + smctx.SessionRules["SessRuleId-1"] = &smf_context.SessionRule{ SessionRule: &models.SessionRule{ AuthSessAmbr: &models.Ambr{ Uplink: "1000 Kbps", diff --git a/internal/context/upf.go b/internal/context/upf.go index 1ad0a73f..79cc3c02 100644 --- a/internal/context/upf.go +++ b/internal/context/upf.go @@ -70,11 +70,10 @@ type UPF struct { NodeID pfcpType.NodeID UPIPInfo pfcpType.UserPlaneIPResourceInformation - UPFStatus UPFStatus RecoveryTimeStamp time.Time - Ctx context.Context - CancelFunc context.CancelFunc + AssociationContext context.Context + CancelAssociation context.CancelFunc SNssaiInfos []*SnssaiUPFInfo N3Interfaces []*UPFInterfaceInfo @@ -328,7 +327,10 @@ func NewUPF( upfPool.Store(upf.GetID(), upf) // Initialize context - upf.UPFStatus = NotAssociated + upf.AssociationContext, upf.CancelAssociation = context.WithCancel(context.Background()) + upf.CancelAssociation() // necessary to avoid nil pointer for checks of AssociationContext before UPF is associated + + upf.NodeID = *nodeID upf.pdrIDGenerator = idgenerator.NewGenerator(1, math.MaxUint16) upf.farIDGenerator = idgenerator.NewGenerator(1, math.MaxUint32) upf.barIDGenerator = idgenerator.NewGenerator(1, math.MaxUint8) @@ -448,162 +450,155 @@ func SelectUPFByDnn(dnn string) *UPF { return upf } -func (upf *UPF) pdrID() (uint16, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return 0, err +func (upf *UPF) pdrID() (pdrID uint16, err error) { + if err = upf.IsAssociated(); err != nil { + return } - var pdrID uint16 - if tmpID, err := upf.pdrIDGenerator.Allocate(); err != nil { + tmpID, err := upf.pdrIDGenerator.Allocate() + if err != nil { return 0, err - } else { - pdrID = uint16(tmpID) } - - return pdrID, nil + pdrID = uint16(tmpID) + return } -func (upf *UPF) farID() (uint32, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return 0, err +func (upf *UPF) farID() (farID uint32, err error) { + if err = upf.IsAssociated(); err != nil { + return } - var farID uint32 - if tmpID, err := upf.farIDGenerator.Allocate(); err != nil { + tmpID, err := upf.farIDGenerator.Allocate() + if err != nil { return 0, err - } else { - farID = uint32(tmpID) } - - return farID, nil + farID = uint32(tmpID) + return } -func (upf *UPF) barID() (uint8, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return 0, err +func (upf *UPF) barID() (barID uint8, err error) { + if err = upf.IsAssociated(); err != nil { + return } - var barID uint8 - if tmpID, err := upf.barIDGenerator.Allocate(); err != nil { + tmpID, err := upf.barIDGenerator.Allocate() + if err != nil { return 0, err - } else { - barID = uint8(tmpID) } - - return barID, nil + barID = uint8(tmpID) + return } -func (upf *UPF) qerID() (uint32, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return 0, err +func (upf *UPF) qerID() (qerID uint32, err error) { + if err = upf.IsAssociated(); err != nil { + return } - var qerID uint32 - if tmpID, err := upf.qerIDGenerator.Allocate(); err != nil { + tmpID, err := upf.qerIDGenerator.Allocate() + if err != nil { return 0, err - } else { - qerID = uint32(tmpID) } - - return qerID, nil + qerID = uint32(tmpID) + return } -func (upf *UPF) urrID() (uint32, error) { - var urrID uint32 - if tmpID, err := upf.urrIDGenerator.Allocate(); err != nil { +func (upf *UPF) urrID() (urrID uint32, err error) { + tmpID, err := upf.urrIDGenerator.Allocate() + if err != nil { return 0, err - } else { - urrID = uint32(tmpID) } - - return urrID, nil + urrID = uint32(tmpID) + return } -func (upf *UPF) AddPDR() (*PDR, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return nil, err +func (upf *UPF) AddPDR() (pdr *PDR, err error) { + if err = upf.IsAssociated(); err != nil { + return } - pdr := new(PDR) - if PDRID, err := upf.pdrID(); err != nil { - return nil, err - } else { - pdr.PDRID = PDRID - upf.pdrPool.Store(pdr.PDRID, pdr) + pdrID, err := upf.pdrID() + if err != nil { + return } - if newFAR, err := upf.AddFAR(); err != nil { - return nil, err - } else { - pdr.FAR = newFAR + newFAR, err := upf.AddFAR() + if err != nil { + return } - return pdr, nil + pdr = &PDR{ + PDRID: pdrID, + FAR: newFAR, + } + upf.pdrPool.Store(pdr.PDRID, pdr) + return } -func (upf *UPF) AddFAR() (*FAR, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return nil, err +func (upf *UPF) AddFAR() (far *FAR, err error) { + if err = upf.IsAssociated(); err != nil { + return } - far := new(FAR) - if FARID, err := upf.farID(); err != nil { - return nil, err - } else { - far.FARID = FARID - upf.farPool.Store(far.FARID, far) + farID, err := upf.farID() + if err != nil { + return } - + far = &FAR{ + FARID: farID, + } + upf.farPool.Store(far.FARID, far) return far, nil } -func (upf *UPF) AddBAR() (*BAR, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return nil, err +func (upf *UPF) AddBAR() (bar *BAR, err error) { + if err = upf.IsAssociated(); err != nil { + return } - bar := new(BAR) - if BARID, err := upf.barID(); err != nil { - } else { - bar.BARID = BARID - upf.barPool.Store(bar.BARID, bar) + barID, err := upf.barID() + if err != nil { + return } - - return bar, nil + bar = &BAR{ + BARID: barID, + } + upf.barPool.Store(bar.BARID, bar) + return } -func (upf *UPF) AddQER() (*QER, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return nil, err +func (upf *UPF) AddQER() (qer *QER, err error) { + if err = upf.IsAssociated(); err != nil { + return } - qer := new(QER) - if QERID, err := upf.qerID(); err != nil { - } else { - qer.QERID = QERID - upf.qerPool.Store(qer.QERID, qer) + qerID, err := upf.qerID() + if err != nil { + return } - - return qer, nil + qer = &QER{ + QERID: qerID, + } + upf.qerPool.Store(qer.QERID, qer) + return } -func (upf *UPF) AddURR(urrId uint32, opts ...UrrOpt) (*URR, error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err := fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return nil, err +func (upf *UPF) AddURR(urrID uint32, opts ...UrrOpt) (urr *URR, err error) { + if err = upf.IsAssociated(); err != nil { + return } - urr := new(URR) - urr.MeasureMethod = MesureMethodVol - urr.MeasurementInformation = MeasureInformation(true, false) + if urrID == 0 { + urrID, err = upf.urrID() + if err != nil { + return + } + } + + urr = &URR{ + URRID: urrID, + MeasureMethod: MesureMethodVol, + MeasurementInformation: MeasureInformation(true, false), + } for _, opt := range opts { opt(urr) @@ -622,6 +617,10 @@ func (upf *UPF) AddURR(urrId uint32, opts ...UrrOpt) (*URR, error) { return urr, nil } +func (upf *UPF) GetUUID() uuid.UUID { + return upf.uuid +} + func (upf *UPF) GetQERById(qerId uint32) *QER { qer, ok := upf.qerPool.Load(qerId) if ok { @@ -632,50 +631,46 @@ func (upf *UPF) GetQERById(qerId uint32) *QER { // *** add unit test ***// func (upf *UPF) RemovePDR(pdr *PDR) (err error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err = fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return err + if err = upf.IsAssociated(); err != nil { + return } upf.pdrIDGenerator.FreeID(int64(pdr.PDRID)) upf.pdrPool.Delete(pdr.PDRID) - return nil + return } // *** add unit test ***// func (upf *UPF) RemoveFAR(far *FAR) (err error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err = fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return err + if err = upf.IsAssociated(); err != nil { + return } upf.farIDGenerator.FreeID(int64(far.FARID)) upf.farPool.Delete(far.FARID) - return nil + return } // *** add unit test ***// func (upf *UPF) RemoveBAR(bar *BAR) (err error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err = fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return err + if err = upf.IsAssociated(); err != nil { + return } upf.barIDGenerator.FreeID(int64(bar.BARID)) upf.barPool.Delete(bar.BARID) - return nil + return } // *** add unit test ***// func (upf *UPF) RemoveQER(qer *QER) (err error) { - if upf.UPFStatus != AssociatedSetUpSuccess { - err = fmt.Errorf("UPF[%s] not Associate with SMF", upf.GetNodeIDString()) - return err + if err = upf.IsAssociated(); err != nil { + return } upf.qerIDGenerator.FreeID(int64(qer.QERID)) upf.qerPool.Delete(qer.QERID) - return nil + return } func (upf *UPF) isSupportSnssai(snssai *SNssai) bool { @@ -697,20 +692,12 @@ func (upf *UPF) ProcEachSMContext(procFunc func(*SMContext)) { }) } -func (upf *UPF) MatchedSelection(selection *UPFSelectionParams) bool { - for _, snssaiInfo := range upf.SNssaiInfos { - currentSnssai := snssaiInfo.SNssai - if currentSnssai.Equal(selection.SNssai) { - for _, dnnInfo := range snssaiInfo.DnnList { - if dnnInfo.Dnn == selection.Dnn { - if selection.Dnai == "" { - return true - } else if dnnInfo.ContainsDNAI(selection.Dnai) { - return true - } - } - } - } +func (upf *UPF) IsAssociated() error { + select { + case <-upf.AssociationContext.Done(): + return fmt.Errorf("UPF[%s] not associated with SMF", + upf.NodeID.ResolveNodeIdToIp().String()) + default: + return nil } - return false } diff --git a/internal/context/upf_test.go b/internal/context/upf_test.go index c4f49d46..56429216 100644 --- a/internal/context/upf_test.go +++ b/internal/context/upf_test.go @@ -1,6 +1,7 @@ package context_test import ( + "context" "fmt" "net" "testing" @@ -10,7 +11,6 @@ import ( "github.com/free5gc/nas/nasMessage" "github.com/free5gc/pfcp/pfcpType" - "github.com/free5gc/smf/internal/context" smf_context "github.com/free5gc/smf/internal/context" "github.com/free5gc/smf/pkg/factory" ) @@ -158,11 +158,11 @@ func TestAddPDR(t *testing.T) { { upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddPDR should fail", - expectedError: fmt.Errorf("UPF[127.0.0.1] not Associate with SMF"), + expectedError: fmt.Errorf("UPF[127.0.0.1] not associated with SMF"), }, } - testCases[0].upf.UPFStatus = context.AssociatedSetUpSuccess + testCases[0].upf.AssociationContext = context.Background() Convey("AddPDR should indeed add PDR and report error appropiately", t, func() { for i, testcase := range testCases { @@ -194,18 +194,18 @@ func TestAddFAR(t *testing.T) { expectedError error }{ { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddFAR should success", expectedError: nil, }, { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddFAR should fail", - expectedError: fmt.Errorf("UPF[127.0.0.1] not Associate with SMF"), + expectedError: fmt.Errorf("UPF[127.0.0.1] not associated with SMF"), }, } - testCases[0].upf.UPFStatus = context.AssociatedSetUpSuccess + testCases[0].upf.AssociationContext = context.Background() Convey("AddFAR should indeed add FAR and report error appropiately", t, func() { for i, testcase := range testCases { @@ -237,18 +237,18 @@ func TestAddQER(t *testing.T) { expectedError error }{ { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddQER should success", expectedError: nil, }, { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddQER should fail", - expectedError: fmt.Errorf("UPF[127.0.0.1] not Associate with SMF"), + expectedError: fmt.Errorf("UPF[127.0.0.1] not associated with SMF"), }, } - testCases[0].upf.UPFStatus = context.AssociatedSetUpSuccess + testCases[0].upf.AssociationContext = context.Background() Convey("AddQER should indeed add QER and report error appropiately", t, func() { for i, testcase := range testCases { @@ -280,18 +280,18 @@ func TestAddBAR(t *testing.T) { expectedError error }{ { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddBAR should success", expectedError: nil, }, { - upf: context.NewUPF(mockUPNode, mockNodeID, mockIfaces), + upf: smf_context.NewUPF(mockUPNode, mockNodeID, mockIfaces), resultStr: "AddBAR should fail", - expectedError: fmt.Errorf("UPF[127.0.0.1] not Associate with SMF"), + expectedError: fmt.Errorf("UPF[127.0.0.1] not associated with SMF"), }, } - testCases[0].upf.UPFStatus = context.AssociatedSetUpSuccess + testCases[0].upf.AssociationContext = context.Background() Convey("AddBAR should indeed add BAR and report error appropiately", t, func() { for i, testcase := range testCases { diff --git a/internal/context/user_plane_information.go b/internal/context/user_plane_information.go index 55b31d70..260d3e7b 100644 --- a/internal/context/user_plane_information.go +++ b/internal/context/user_plane_information.go @@ -773,14 +773,10 @@ func (upi *UserPlaneInformation) selectMatchUPF(selection *UPFSelectionParams) [ if currentSnssai.Equal(targetSnssai) { for _, dnnInfo := range snssaiInfo.DnnList { - if dnnInfo.Dnn != selection.Dnn { - continue - } - if selection.Dnai != "" && !dnnInfo.ContainsDNAI(selection.Dnai) { - continue + if dnnInfo.Dnn == selection.Dnn && dnnInfo.ContainsDNAI(selection.Dnai) { + upfList = append(upfList, upf) + break } - upfList = append(upfList, upf) - break } } } @@ -912,12 +908,12 @@ func (upi *UserPlaneInformation) SelectUPFAndAllocUEIP(selection *UPFSelectionPa sortedUPFList := createUPFListForSelection(upfList) for _, upf := range sortedUPFList { logger.CtxLog.Debugf("check start UPF: %s", - upi.GetUPFNameByIp(upf.GetNodeIDString())) - if upf.UPFStatus != AssociatedSetUpSuccess { - logger.CtxLog.Infof("PFCP Association not yet Established with: %s", - upi.GetUPFNameByIp(upf.GetNodeIDString())) + upi.GetUPFNameByIp(upf.NodeID.ResolveNodeIdToIp().String())) + if err = upf.IsAssociated(); err != nil { + logger.CtxLog.Infoln(err) continue } + pools, useStaticIPPool := getUEIPPool(upf, selection) if len(pools) == 0 { continue diff --git a/internal/context/user_plane_information_test.go b/internal/context/user_plane_information_test.go index 1564653b..08940c10 100644 --- a/internal/context/user_plane_information_test.go +++ b/internal/context/user_plane_information_test.go @@ -1,6 +1,7 @@ package context_test import ( + "context" "fmt" "net" "testing" @@ -280,7 +281,7 @@ func TestSelectUPFAndAllocUEIP(t *testing.T) { userplaneInformation := smf_context.NewUserPlaneInformation(configuration) for _, upf := range userplaneInformation.UPFs { - upf.UPFStatus = smf_context.AssociatedSetUpSuccess + upf.AssociationContext = context.Background() } for i := 0; i <= 100; i++ { @@ -504,7 +505,7 @@ var testCasesOfGetUEIPPool = []struct { func TestGetUEIPPool(t *testing.T) { userplaneInformation := smf_context.NewUserPlaneInformation(configForIPPoolAllocate) for _, upf := range userplaneInformation.UPFs { - upf.UPFStatus = smf_context.AssociatedSetUpSuccess + upf.AssociationContext = context.Background() } for ci, tc := range testCasesOfGetUEIPPool { diff --git a/internal/pfcp/handler/handler.go b/internal/pfcp/handler/handler.go index ea53c001..7448072c 100644 --- a/internal/pfcp/handler/handler.go +++ b/internal/pfcp/handler/handler.go @@ -124,12 +124,10 @@ func HandlePfcpSessionReportRequest(msg *pfcpUdp.Message) { pfcp_message.SendPfcpSessionReportResponse(msg.RemoteAddr, cause, seqFromUPF, 0) return } - if upf.UPFStatus != smf_context.AssociatedSetUpSuccess { - logger.PfcpLog.Warnf("PFCP Session Report Request : Not Associated with UPF[%s], Request Rejected", - upfNodeIDtoIPStr) + if err := upf.IsAssociated(); err != nil { + logger.PfcpLog.Warnf("PFCP Session Report Request rejected: %+v", err) cause.CauseValue = pfcpType.CauseNoEstablishedPfcpAssociation pfcp_message.SendPfcpSessionReportResponse(msg.RemoteAddr, cause, seqFromUPF, 0) - return } if smContext.UpCnxState == models.UpCnxState_DEACTIVATED { diff --git a/internal/pfcp/message/build.go b/internal/pfcp/message/build.go index 5a81fd80..d20af340 100644 --- a/internal/pfcp/message/build.go +++ b/internal/pfcp/message/build.go @@ -435,8 +435,9 @@ func BuildPfcpSessionEstablishmentRequest( urrMap[urr.URRID] = urr } for _, filteredURR := range urrMap { - if filteredURR.State == context.RULE_INITIAL { - msg.CreateURR = append(msg.CreateURR, urrToCreateURR(filteredURR)) + msg.CreateURR = append(msg.CreateURR, urrToCreateURR(filteredURR)) + if filteredURR.State == context.RULE_CREATE { + smContext.Log.Warn("Duplicate URR creation") } filteredURR.State = context.RULE_CREATE } @@ -564,6 +565,9 @@ func BuildPfcpSessionModificationRequest( for _, urr := range urrList { switch urr.State { + case context.RULE_CREATE: + smContext.Log.Warn("Duplicate URR creation") + fallthrough case context.RULE_INITIAL: msg.CreateURR = append(msg.CreateURR, urrToCreateURR(urr)) case context.RULE_UPDATE: diff --git a/internal/pfcp/message/send.go b/internal/pfcp/message/send.go index 0b083114..f65f1546 100644 --- a/internal/pfcp/message/send.go +++ b/internal/pfcp/message/send.go @@ -140,8 +140,8 @@ func SendPfcpSessionEstablishmentRequest( urrList []*context.URR, ) (resMsg *pfcpUdp.Message, err error) { nodeIDtoIP := upf.GetNodeIDString() - if upf.UPFStatus != context.AssociatedSetUpSuccess { - return nil, fmt.Errorf("Not Associated with UPF[%s]", nodeIDtoIP) + if err = upf.IsAssociated(); err != nil { + return nil, err } pfcpMsg, err := BuildPfcpSessionEstablishmentRequest(upf.NodeID, nodeIDtoIP, @@ -223,8 +223,8 @@ func SendPfcpSessionModificationRequest( urrList []*context.URR, ) (resMsg *pfcpUdp.Message, err error) { nodeIDtoIP := upf.GetNodeIDString() - if upf.UPFStatus != context.AssociatedSetUpSuccess { - return nil, fmt.Errorf("Not Associated with UPF[%s]", nodeIDtoIP) + if err = upf.IsAssociated(); err != nil { + return nil, err } pfcpMsg, err := BuildPfcpSessionModificationRequest(upf.NodeID, nodeIDtoIP, @@ -296,10 +296,13 @@ func SendPfcpSessionModificationResponse(addr *net.UDPAddr) { udp.SendPfcpResponse(message, addr) } -func SendPfcpSessionDeletionRequest(upf *context.UPF, ctx *context.SMContext) (resMsg *pfcpUdp.Message, err error) { +func SendPfcpSessionDeletionRequest( + upf *context.UPF, + ctx *context.SMContext, +) (resMsg *pfcpUdp.Message, err error) { nodeIDtoIP := upf.GetNodeIDString() - if upf.UPFStatus != context.AssociatedSetUpSuccess { - return nil, fmt.Errorf("Not Associated with UPF[%s]", nodeIDtoIP) + if err = upf.IsAssociated(); err != nil { + return nil, err } pfcpMsg, err := BuildPfcpSessionDeletionRequest() diff --git a/internal/pfcp/message/send_test.go b/internal/pfcp/message/send_test.go index cf6cd33c..a114d0c3 100644 --- a/internal/pfcp/message/send_test.go +++ b/internal/pfcp/message/send_test.go @@ -24,9 +24,8 @@ func TestSendPfcpSessionEstablishmentRequest(t *testing.T) { } func TestSendHeartbeatResponse(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - smf_context.GetSelf().Ctx = ctx - smf_context.GetSelf().PFCPCancelFunc = cancel + smfContext := smf_context.GetSelf() + smfContext.PfcpContext, smfContext.PfcpCancelFunc = context.WithCancel(context.Background()) udp.Run(smf_pfcp.Dispatch) udp.ServerStartTime = time.Now() diff --git a/internal/pfcp/udp/udp.go b/internal/pfcp/udp/udp.go index dcdf12ab..3e1e8685 100644 --- a/internal/pfcp/udp/udp.go +++ b/internal/pfcp/udp/udp.go @@ -1,7 +1,6 @@ package udp import ( - "context" "errors" "net" "runtime/debug" @@ -18,8 +17,6 @@ const MaxPfcpUdpDataSize = 1024 var Server *pfcpUdp.PfcpServer -var cancelFunc *context.CancelFunc - var ServerStartTime time.Time func Run(dispatch func(*pfcpUdp.Message)) { @@ -30,10 +27,9 @@ func Run(dispatch func(*pfcpUdp.Message)) { } }() - newCtx, newCancelFunc := context.WithCancel(smf_context.GetSelf().Ctx) - cancelFunc = &newCancelFunc + smfContext := smf_context.GetSelf() - serverIP := smf_context.GetSelf().ListenIP().To4() + serverIP := smfContext.ListenIP().To4() Server = pfcpUdp.NewPfcpServer(serverIP.String()) err := Server.Listen() @@ -61,7 +57,7 @@ func Run(dispatch func(*pfcpUdp.Message)) { } else { logger.PfcpLog.Warnf("Read PFCP error: %v, msg: [%v]", errReadFrom, msg) select { - case <-newCtx.Done(): + case <-smfContext.PfcpContext.Done(): // PFCP is closing return default: @@ -94,7 +90,8 @@ func SendPfcpRequest(sndMsg *pfcp.Message, addr *net.UDPAddr) (rsvMsg *pfcpUdp.M } func ClosePfcp() error { - (*cancelFunc)() + smf_context.GetSelf().PfcpCancelFunc() + closeErr := Server.Close() if closeErr != nil { logger.PfcpLog.Errorf("Pfcp close err: %+v", closeErr) diff --git a/internal/pfcp/udp/udp_test.go b/internal/pfcp/udp/udp_test.go index 5e98b8eb..07f084d5 100644 --- a/internal/pfcp/udp/udp_test.go +++ b/internal/pfcp/udp/udp_test.go @@ -21,16 +21,16 @@ const testPfcpClientPort = 12345 func TestRun(t *testing.T) { // Set SMF Node ID - smf_context.GetSelf().CPNodeID = pfcpType.NodeID{ + smfContext := smf_context.GetSelf() + + smfContext.CPNodeID = pfcpType.NodeID{ NodeIdType: pfcpType.NodeIdTypeIpv4Address, IP: net.ParseIP("127.0.0.1").To4(), } - smf_context.GetSelf().ExternalAddr = "127.0.0.1" - smf_context.GetSelf().ListenAddr = "127.0.0.1" + smfContext.ExternalAddr = "127.0.0.1" + smfContext.ListenAddr = "127.0.0.1" - ctx, cancel := context.WithCancel(context.Background()) - smf_context.GetSelf().Ctx = ctx - smf_context.GetSelf().PFCPCancelFunc = cancel + smfContext.PfcpContext, smfContext.PfcpCancelFunc = context.WithCancel(context.Background()) udp.Run(smf_pfcp.Dispatch) testPfcpReq := pfcp.Message{ diff --git a/internal/sbi/api_upi.go b/internal/sbi/api_upi.go index e25ba3ce..d3049eba 100644 --- a/internal/sbi/api_upi.go +++ b/internal/sbi/api_upi.go @@ -1,7 +1,6 @@ package sbi import ( - "context" "net/http" "github.com/gin-gonic/gin" @@ -69,9 +68,8 @@ func (s *Server) PostUpNodesLinks(c *gin.Context) { for _, upf := range upi.UPFs { // only associate new ones - if upf.UPFStatus == smf_context.NotAssociated { - upf.Ctx, upf.CancelFunc = context.WithCancel(context.Background()) - go s.Processor().ToBeAssociatedWithUPF(smf_context.GetSelf().Ctx, upf) + if err := upf.IsAssociated(); err != nil { + go s.Processor().ToBeAssociatedWithUPF(smf_context.GetSelf().PfcpContext, upf) } } c.JSON(http.StatusOK, gin.H{"status": "OK"}) @@ -88,8 +86,8 @@ func (s *Server) DeleteUpNodeLink(c *gin.Context) { defer upi.Mu.Unlock() if upNode, ok := upi.UPNodes[upNodeRef]; ok { if upNode.GetType() == smf_context.UPNODE_UPF { + upNode.(*smf_context.UPF).CancelAssociation() go s.Processor().ReleaseAllResourcesOfUPF(upNode.(*smf_context.UPF)) - upNode.(*smf_context.UPF).CancelFunc() } upi.UpNodeDelete(upNodeRef) c.JSON(http.StatusOK, gin.H{"status": "OK"}) diff --git a/internal/sbi/processor/association.go b/internal/sbi/processor/association.go index f3ff7a89..67313127 100644 --- a/internal/sbi/processor/association.go +++ b/internal/sbi/processor/association.go @@ -15,7 +15,7 @@ import ( "github.com/free5gc/smf/internal/pfcp/message" ) -func (p *Processor) ToBeAssociatedWithUPF(ctx context.Context, upf *smf_context.UPF) { +func (p *Processor) ToBeAssociatedWithUPF(smfPfcpContext context.Context, upf *smf_context.UPF) { var upfStr string if upf.NodeID.NodeIdType == pfcpType.NodeIdTypeFqdn { upfStr = fmt.Sprintf("[%s](%s)", upf.NodeID.FQDN, upf.GetNodeIDString()) @@ -24,24 +24,21 @@ func (p *Processor) ToBeAssociatedWithUPF(ctx context.Context, upf *smf_context. } for { - ensureSetupPfcpAssociation(ctx, upf, upfStr) - if isDone(ctx, upf) { - break - } - - if smf_context.GetSelf().PfcpHeartbeatInterval == 0 { + // check if SMF PFCP context (parent) was canceled + // note: UPF AssociationContexts are children of smfPfcpContext + select { + case <-smfPfcpContext.Done(): + logger.MainLog.Infoln("Canceled SMF PFCP context") return - } - - keepHeartbeatTo(ctx, upf, upfStr) - // return when UPF heartbeat lost is detected or association is canceled - if isDone(ctx, upf) { - break - } + default: + ensureSetupPfcpAssociation(smfPfcpContext, upf, upfStr) + if smf_context.GetSelf().PfcpHeartbeatInterval == 0 { + return + } + keepHeartbeatTo(upf, upfStr) + // returns when UPF heartbeat loss is detected or association is canceled - p.releaseAllResourcesOfUPF(upf, upfStr) - if isDone(ctx, upf) { - break + p.releaseAllResourcesOfUPF(upf, upfStr) } } } @@ -56,41 +53,30 @@ func (p *Processor) ReleaseAllResourcesOfUPF(upf *smf_context.UPF) { p.releaseAllResourcesOfUPF(upf, upfStr) } -func isDone(ctx context.Context, upf *smf_context.UPF) bool { - select { - case <-ctx.Done(): - return true - case <-upf.Ctx.Done(): - return true - default: - return false - } -} - -func ensureSetupPfcpAssociation(ctx context.Context, upf *smf_context.UPF, upfStr string) { +func ensureSetupPfcpAssociation(parentContext context.Context, upf *smf_context.UPF, upfStr string) { alertTime := time.Now() alertInterval := smf_context.GetSelf().AssocFailAlertInterval retryInterval := smf_context.GetSelf().AssocFailRetryInterval for { - timer := time.After(retryInterval) err := setupPfcpAssociation(upf, upfStr) if err == nil { + // success + // assign UPF an AssociationContext, with SMF PFCP Context as parent + upf.AssociationContext, upf.CancelAssociation = context.WithCancel(parentContext) return } - logger.MainLog.Warnf("Failed to setup an association with UPF%s, error:%+v", upfStr, err) + logger.MainLog.Warnf("Failed to setup an association with UPF[%s], error:%+v", upfStr, err) now := time.Now() logger.MainLog.Debugf("now %+v, alertTime %+v", now, alertTime) if now.After(alertTime.Add(alertInterval)) { - logger.MainLog.Errorf("ALERT for UPF%s", upfStr) + logger.MainLog.Errorf("ALERT for UPF[%s]", upfStr) alertTime = now } - logger.MainLog.Debugf("Wait %+v (or less) until next retry attempt", retryInterval) - select { - case <-ctx.Done(): - logger.MainLog.Infof("Canceled association request to UPF%s", upfStr) - return - case <-upf.Ctx.Done(): - logger.MainLog.Infof("Canceled association request to this UPF%s only", upfStr) + logger.MainLog.Debugf("Wait %+v until next retry attempt", retryInterval) + timer := time.After(retryInterval) + select { // no default case, either case needs to be true to continue + case <-parentContext.Done(): + logger.MainLog.Infoln("Canceled SMF PFCP context") return case <-timer: continue @@ -119,8 +105,6 @@ func setupPfcpAssociation(upf *smf_context.UPF, upfStr string) error { logger.MainLog.Infof("Received PFCP Association Setup Accepted Response from UPF%s", upfStr) - upf.UPFStatus = smf_context.AssociatedSetUpSuccess - if rsp.UserPlaneIPResourceInformation != nil { upf.UPIPInfo = *rsp.UserPlaneIPResourceInformation @@ -131,7 +115,7 @@ func setupPfcpAssociation(upf *smf_context.UPF, upfStr string) error { return nil } -func keepHeartbeatTo(ctx context.Context, upf *smf_context.UPF, upfStr string) { +func keepHeartbeatTo(upf *smf_context.UPF, upfStr string) { for { err := doPfcpHeartbeat(upf, upfStr) if err != nil { @@ -141,11 +125,8 @@ func keepHeartbeatTo(ctx context.Context, upf *smf_context.UPF, upfStr string) { timer := time.After(smf_context.GetSelf().PfcpHeartbeatInterval) select { - case <-ctx.Done(): - logger.MainLog.Infof("Canceled Heartbeat with UPF%s", upfStr) - return - case <-upf.Ctx.Done(): - logger.MainLog.Infof("Canceled Heartbeat to this UPF%s only", upfStr) + case <-upf.AssociationContext.Done(): + logger.MainLog.Infof("Canceled association to UPF[%s]", upfStr) return case <-timer: continue @@ -154,15 +135,15 @@ func keepHeartbeatTo(ctx context.Context, upf *smf_context.UPF, upfStr string) { } func doPfcpHeartbeat(upf *smf_context.UPF, upfStr string) error { - if upf.UPFStatus != smf_context.AssociatedSetUpSuccess { - return fmt.Errorf("invalid status of UPF%s: %d", upfStr, upf.UPFStatus) + if err := upf.IsAssociated(); err != nil { + return fmt.Errorf("Cancel heartbeat: %+v", err) } logger.MainLog.Debugf("Sending PFCP Heartbeat Request to UPF%s", upfStr) resMsg, err := message.SendPfcpHeartbeatRequest(upf) if err != nil { - upf.UPFStatus = smf_context.NotAssociated + upf.CancelAssociation() upf.RecoveryTimeStamp = time.Time{} return fmt.Errorf("SendPfcpHeartbeatRequest error: %w", err) } @@ -179,7 +160,7 @@ func doPfcpHeartbeat(upf *smf_context.UPF, upfStr string) error { upf.RecoveryTimeStamp = rsp.RecoveryTimeStamp.RecoveryTimeStamp } else if upf.RecoveryTimeStamp.Before(rsp.RecoveryTimeStamp.RecoveryTimeStamp) { // received a newer recovery timestamp - upf.UPFStatus = smf_context.NotAssociated + upf.CancelAssociation() upf.RecoveryTimeStamp = time.Time{} return fmt.Errorf("received PFCP Heartbeat Response RecoveryTimeStamp has been updated") } diff --git a/internal/sbi/processor/pdu_session_test.go b/internal/sbi/processor/pdu_session_test.go index 738fc8cd..bfb0e838 100644 --- a/internal/sbi/processor/pdu_session_test.go +++ b/internal/sbi/processor/pdu_session_test.go @@ -370,9 +370,8 @@ func initDiscAMFStubNRF() { } func initStubPFCP() { - ctx, cancel := context.WithCancel(context.Background()) - smf_context.GetSelf().Ctx = ctx - smf_context.GetSelf().PFCPCancelFunc = cancel + smfContext := smf_context.GetSelf() + smfContext.PfcpContext, smfContext.PfcpCancelFunc = context.WithCancel(context.Background()) udp.Run(pfcp.Dispatch) } @@ -451,9 +450,8 @@ func TestHandlePDUSessionSMContextCreate(t *testing.T) { initStubPFCP() // modify associate setup status - allUPFs := smf_context.GetSelf().UserPlaneInformation.UPFs - for _, upfNode := range allUPFs { - upfNode.UPFStatus = smf_context.AssociatedSetUpSuccess + for _, upf := range smf_context.GetSelf().UserPlaneInformation.UPFs { + upf.AssociationContext = context.Background() } testCases := []struct { diff --git a/pkg/service/init.go b/pkg/service/init.go index 0a4341a1..00ec52cb 100644 --- a/pkg/service/init.go +++ b/pkg/service/init.go @@ -88,9 +88,8 @@ func NewApp( smf.ctx, smf.cancel = context.WithCancel(ctx) // for PFCP - ctx, cancel := context.WithCancel(smf.ctx) - smf_context.GetSelf().Ctx = ctx - smf_context.GetSelf().PFCPCancelFunc = cancel + smfContext := smf_context.GetSelf() + smfContext.PfcpContext, smfContext.PfcpCancelFunc = context.WithCancel(smf.ctx) SMF = smf diff --git a/pkg/utils/pfcp_util.go b/pkg/utils/pfcp_util.go index dfb07f4b..224558f6 100644 --- a/pkg/utils/pfcp_util.go +++ b/pkg/utils/pfcp_util.go @@ -11,17 +11,12 @@ import ( "github.com/free5gc/smf/pkg/service" ) -var ( - pfcpStart func(a *service.SmfApp) - pfcpStop func() -) +func InitPFCPFunc(pCtx context.Context) (func(a *service.SmfApp), func()) { + smfContext := smf_context.GetSelf() -func InitPFCPFunc() (func(a *service.SmfApp), func()) { - pfcpStart = func(a *service.SmfApp) { + pfcpStart := func(a *service.SmfApp) { // Initialize PFCP server - ctx, cancel := context.WithCancel(context.Background()) - smf_context.GetSelf().Ctx = ctx - smf_context.GetSelf().PFCPCancelFunc = cancel + smfContext.PfcpContext, smfContext.PfcpCancelFunc = context.WithCancel(pCtx) udp.Run(pfcp.Dispatch) @@ -29,13 +24,12 @@ func InitPFCPFunc() (func(a *service.SmfApp), func()) { time.Sleep(1000 * time.Millisecond) for _, upf := range smf_context.GetSelf().UserPlaneInformation.UPFs { - upf.Ctx, upf.CancelFunc = context.WithCancel(ctx) - go a.Processor().ToBeAssociatedWithUPF(ctx, upf) + go a.Processor().ToBeAssociatedWithUPF(smfContext.PfcpContext, upf) } } - pfcpStop = func() { - smf_context.GetSelf().PFCPCancelFunc() + pfcpStop := func() { + smfContext.PfcpCancelFunc() err := udp.Server.Close() if err != nil { logger.Log.Errorf("udp server close failed %+v", err)