diff --git a/Makefile b/Makefile index a5e79b3..50ef69a 100644 --- a/Makefile +++ b/Makefile @@ -95,7 +95,7 @@ golangci-lint: ## Download golangci-lint locally if necessary. MOCKERY = $(shell pwd)/bin/mockery mockery: ## Download mockery if necessary. - $(call go-install-tool,$(MOCKERY),github.com/vektra/mockery/v2@v2.12.3) + $(call go-install-tool,$(MOCKERY),github.com/vektra/mockery/v2@v2.14.0) # go-get-tool will 'go get' any package $2 and install it to $1. PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST)))) diff --git a/go.mod b/go.mod index 587af6a..e340b1f 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.5.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.7.0 k8s.io/api v0.24.2 k8s.io/apimachinery v0.24.2 k8s.io/client-go v0.24.2 @@ -45,7 +46,9 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/afero v1.6.0 // indirect + github.com/stretchr/objx v0.2.0 // indirect github.com/vishvananda/netlink v1.1.1-0.20211101163509-b10eb8fe5cf6 // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect diff --git a/pkg/tc/actuator_file_writer_test.go b/pkg/tc/actuator_file_writer_test.go new file mode 100644 index 0000000..bd444c8 --- /dev/null +++ b/pkg/tc/actuator_file_writer_test.go @@ -0,0 +1,140 @@ +package tc_test + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "time" + + "k8s.io/klog/v2" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc" + "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc/types" + "github.com/Mellanox/multi-networkpolicy-tc/pkg/utils" +) + +func getLastModifiedTime(path string) time.Time { + fInfo, err := os.Lstat(path) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return fInfo.ModTime() +} + +var _ = Describe("Actuator file writer tests", Ordered, func() { + var tempDir string + var logger klog.Logger + var actuator tc.Actuator + + BeforeAll(func() { + // init logger + fs := flag.NewFlagSet("test-flag-set", flag.PanicOnError) + klog.InitFlags(fs) + Expect(fs.Set("v", "8")).ToNot(HaveOccurred()) + logger = klog.NewKlogr().WithName("actuator-file-writer-test") + DeferCleanup(klog.Flush) + By("Logger initialized") + + // create temp dir + tempDir = GinkgoT().TempDir() + By(fmt.Sprintf("Generated temp dir for test: %s", tempDir)) + }) + + Context("Actuator file writer with bad path", func() { + It("fails to actuate on non existent path", func() { + nonExistentPath := filepath.Join(tempDir, "does", "not", "exist") + actuator = tc.NewActuatorFileWriterImpl(nonExistentPath, logger) + objs := &tc.TCObjects{ + QDisc: types.NewIngressQdisc(), + } + err := actuator.Actuate(objs) + Expect(err).To(HaveOccurred()) + }) + + It("fails to actuate on invalid path", func() { + invalidPath := "" + actuator = tc.NewActuatorFileWriterImpl(invalidPath, logger) + objs := &tc.TCObjects{ + QDisc: types.NewIngressQdisc(), + } + err := actuator.Actuate(objs) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("Actuator file writer with valid path", func() { + var tmpFilePath string + objs := &tc.TCObjects{ + QDisc: types.NewIngressQdisc(), + Filters: []types.Filter{ + types.NewFlowerFilterBuilder().WithProtocol(types.FilterProtocolIP).WithPriority(100).Build(), + }, + } + expectedFileContent := `qdisc: ingress +filters: +protocol ip pref 100 flower +` + + BeforeEach(func() { + tmpFilePath = filepath.Join(tempDir, "test-file") + exist, err := utils.PathExists(tmpFilePath) + Expect(err).ToNot(HaveOccurred()) + Expect(exist).To(BeFalse()) + actuator = tc.NewActuatorFileWriterImpl(tmpFilePath, logger) + }) + + AfterEach(func() { + exist, err := utils.PathExists(tmpFilePath) + Expect(err).ToNot(HaveOccurred()) + if exist { + Expect(os.Remove(tmpFilePath)).ToNot(HaveOccurred()) + } + }) + + It("Writes objects to file when file does not exist", func() { + err := actuator.Actuate(objs) + Expect(err).ToNot(HaveOccurred()) + + content, err := os.ReadFile(tmpFilePath) + Expect(err).ToNot(HaveOccurred()) + Expect(string(content)).To(BeEquivalentTo(expectedFileContent)) + }) + + It("updates objects in file when file exist", func() { + err := actuator.Actuate(objs) + Expect(err).ToNot(HaveOccurred()) + + objs.Filters = append( + objs.Filters, + types.NewFlowerFilterBuilder().WithProtocol(types.FilterProtocolAll).WithPriority(200).Build()) + + err = actuator.Actuate(objs) + Expect(err).ToNot(HaveOccurred()) + + content, err := os.ReadFile(tmpFilePath) + Expect(err).ToNot(HaveOccurred()) + expectedFileContent = `qdisc: ingress +filters: +protocol ip pref 100 flower +protocol all pref 200 flower +` + Expect(string(content)).To(BeEquivalentTo(expectedFileContent)) + }) + + It("does not update file if same objects provided", func() { + err := actuator.Actuate(objs) + Expect(err).ToNot(HaveOccurred()) + + firstModified := getLastModifiedTime(tmpFilePath) + + err = actuator.Actuate(objs) + Expect(err).ToNot(HaveOccurred()) + + lastModified := getLastModifiedTime(tmpFilePath) + + Expect(firstModified.Equal(lastModified)).To(BeTrue()) + }) + }) +}) diff --git a/pkg/tc/actuator_tc.go b/pkg/tc/actuator_tc.go index c0f8afa..52c09e2 100644 --- a/pkg/tc/actuator_tc.go +++ b/pkg/tc/actuator_tc.go @@ -21,6 +21,10 @@ type ActuatorTCImpl struct { // Actuate is an implementation of Actuator interface. it applies TCObjects on the representor // Note: it assumes all filters are in Chain 0 func (a *ActuatorTCImpl) Actuate(objects *TCObjects) error { + if objects.QDisc == nil && len(objects.Filters) > 0 { + return errors.New("Qdisc cannot be nil if Filters are provided") + } + // list qdiscs currentQDiscs, err := a.tcApi.QDiscList() if err != nil { diff --git a/pkg/tc/actuator_tc_test.go b/pkg/tc/actuator_tc_test.go new file mode 100644 index 0000000..d36820f --- /dev/null +++ b/pkg/tc/actuator_tc_test.go @@ -0,0 +1,280 @@ +package tc_test + +import ( + "flag" + + "github.com/pkg/errors" + "k8s.io/klog/v2" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/stretchr/testify/mock" + + "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc" + tctypes "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc/types" + + tcmocks "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc/mocks" +) + +func ingressQdiscMatch() func(q tctypes.QDisc) bool { + return func(q tctypes.QDisc) bool { + return q.Type() == tctypes.QDiscIngressType + } +} + +func filterMatch(filter tctypes.Filter) func(f tctypes.Filter) bool { + return func(f tctypes.Filter) bool { + return filter.Equals(f) + } +} + +func filterAttrMatch(filterAttr *tctypes.FilterAttrs) func(f *tctypes.FilterAttrs) bool { + return func(f *tctypes.FilterAttrs) bool { + return filterAttr.Equals(f) + } +} + +func chainMatch(chain uint16) func(c tctypes.Chain) bool { + // Note(adrianc): ATM we are not needed to match on parent. it may change... + return func(c tctypes.Chain) bool { + return chain == *c.Attrs().Chain + } +} + +var _ = Describe("Actuator TC tests", func() { + var actuator tc.Actuator + var tcMock *tcmocks.TC + var logger klog.Logger + + BeforeEach(func() { + // init logger + fs := flag.NewFlagSet("test-flag-set", flag.PanicOnError) + klog.InitFlags(fs) + Expect(fs.Set("v", "8")).ToNot(HaveOccurred()) + logger = klog.NewKlogr().WithName("actuator-tc-test") + DeferCleanup(klog.Flush) + By("Logger initialized") + + tcMock = tcmocks.NewTC(GinkgoT()) + actuator = tc.NewActuatorTCImpl(tcMock, logger) + }) + + Context("Actuate Qdisc Only", func() { + It("fails if listing qdisc fails", func() { + tcObj := &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + } + + tcMock.On("QDiscList").Return(nil, errors.New("test error!")) + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + + It("fails if delete qdisc fails", func() { + tcObj := &tc.TCObjects{} + + tcMock.On("QDiscList").Return([]tctypes.QDisc{tctypes.NewIngressQdisc()}, nil) + tcMock.On("QDiscDel", mock.Anything).Return(errors.New("test error!")) + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + + When("TCObjects does not contain Qdisc", func() { + It("deletes ingress Qdisc when exists", func() { + tcObj := &tc.TCObjects{} + + tcMock.On("QDiscList").Return([]tctypes.QDisc{tctypes.NewIngressQdisc()}, nil) + tcMock.On("QDiscDel", mock.MatchedBy(ingressQdiscMatch())).Return(nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("does nothing if ingress Qdisc does not exists", func() { + tcObj := &tc.TCObjects{} + + tcMock.On("QDiscList").Return([]tctypes.QDisc{}, nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + When("TCObjects contain ingress Qdisc", func() { + It("does nothing if ingress Qdisc exist without chain 0", func() { + tcObj := &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + } + + tcMock.On("QDiscList").Return([]tctypes.QDisc{}, nil) + tcMock.On("ChainList", mock.Anything).Return([]tctypes.Chain{ + tctypes.NewChainBuilder().WithParent(0xfffffff1).WithChain(1).Build()}, nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("deletes chain 0 on ingress qdisc when exists", func() { + tcObj := &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + } + + tcMock.On("QDiscList").Return([]tctypes.QDisc{}, nil) + tcMock.On("ChainList", mock.Anything).Return([]tctypes.Chain{ + tctypes.NewChainBuilder().WithParent(0xfffffff1).WithChain(0).Build()}, nil) + tcMock.On("ChainDel", + mock.MatchedBy(ingressQdiscMatch()), + mock.MatchedBy(chainMatch(0))). + Return(nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("does nothing if ingress Qdisc exists, chain 0 does not exist", func() { + tcObj := &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + } + + tcMock.On("QDiscList").Return([]tctypes.QDisc{tctypes.NewIngressQdisc()}, nil) + tcMock.On("ChainList", mock.Anything).Return([]tctypes.Chain{}, nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("fails if delete chain fails", func() { + tcObj := &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + } + + tcMock.On("QDiscList").Return([]tctypes.QDisc{}, nil) + tcMock.On("ChainList", mock.Anything).Return([]tctypes.Chain{ + tctypes.NewChainBuilder().WithParent(0xfffffff1).WithChain(0).Build()}, nil) + tcMock.On("ChainDel", mock.Anything, mock.Anything).Return(errors.New("test error!")) + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + }) + }) + + Context("Actuate with filters", func() { + neededFilters := []tctypes.Filter{ + tctypes.NewFlowerFilterBuilder(). + WithProtocol(tctypes.FilterProtocolIP). + WithPriority(100). + WithMatchKeyDstIP("10.100.0.0/24"). + WithAction(tctypes.NewGenericActionBuiler().WithDrop().Build()). + Build(), + tctypes.NewFlowerFilterBuilder(). + WithProtocol(tctypes.FilterProtocolIP). + WithPriority(200). + WithMatchKeyDstIP("10.100.0.0/16"). + WithAction(tctypes.NewGenericActionBuiler().WithPass().Build()). + Build(), + } + existingFilters := []tctypes.Filter{ + tctypes.NewFlowerFilterBuilder(). + WithProtocol(tctypes.FilterProtocolIP). + WithPriority(100). + WithMatchKeyDstIP("10.100.1.0/24"). + WithAction(tctypes.NewGenericActionBuiler().WithDrop().Build()). + Build(), + tctypes.NewFlowerFilterBuilder(). + WithProtocol(tctypes.FilterProtocolIP). + WithPriority(200). + WithMatchKeyDstIP("10.100.0.0/16"). + WithAction(tctypes.NewGenericActionBuiler().WithPass().Build()). + Build(), + } + var tcObj *tc.TCObjects + + BeforeEach(func() { + tcObj = &tc.TCObjects{ + QDisc: tctypes.NewIngressQdisc(), + Filters: neededFilters, + } + }) + + When("no ingress qdisc in TCObjects", func() { + It("fails", func() { + tcObj.QDisc = nil + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + }) + + When("filters provided in TCObjects, no filters set on ingress qdisc", func() { + BeforeEach(func() { + tcMock.On("QDiscList").Return([]tctypes.QDisc{tctypes.NewIngressQdisc()}, nil) + }) + + It("adds them to qdisc", func() { + tcMock.On("FilterList", mock.MatchedBy(ingressQdiscMatch())).Return([]tctypes.Filter{}, nil) + for i := range neededFilters { + tcMock.On( + "FilterAdd", + mock.MatchedBy(ingressQdiscMatch()), + mock.MatchedBy(filterMatch(neededFilters[i]))). + Return(nil) + } + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("fails if listing filter on qdisc fails", func() { + tcMock.On("FilterList", mock.Anything). + Return(nil, errors.New("test error!")) + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + + It("fails if adding filter to qdisc fails", func() { + tcMock.On("FilterList", mock.MatchedBy(ingressQdiscMatch())).Return([]tctypes.Filter{}, nil) + tcMock.On("FilterAdd", mock.Anything, mock.Anything).Return(errors.New("test error!")) + + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + }) + + When("filters provided in TCObjects, and filters set on ingress qdisc", func() { + BeforeEach(func() { + tcMock.On("QDiscList").Return([]tctypes.QDisc{tctypes.NewIngressQdisc()}, nil) + tcMock.On("FilterList", mock.MatchedBy(ingressQdiscMatch())).Return(existingFilters, nil) + }) + + It("removes un-needed filters and adds needed filters", func() { + tcMock.On( + "FilterDel", + mock.MatchedBy(ingressQdiscMatch()), + mock.MatchedBy(filterAttrMatch(existingFilters[0].Attrs()))). + Return(nil) + tcMock.On( + "FilterAdd", + mock.MatchedBy(ingressQdiscMatch()), + mock.MatchedBy(filterMatch(neededFilters[0]))). + Return(nil) + + err := actuator.Actuate(tcObj) + Expect(err).ToNot(HaveOccurred()) + }) + + It("fails if removing filter from qdisc fails", func() { + tcMock.On( + "FilterDel", + mock.MatchedBy(ingressQdiscMatch()), + mock.MatchedBy(filterAttrMatch(existingFilters[0].Attrs()))). + Return(errors.New("test error!")) + err := actuator.Actuate(tcObj) + Expect(err).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/pkg/tc/mocks/TC.go b/pkg/tc/mocks/TC.go new file mode 100644 index 0000000..9941318 --- /dev/null +++ b/pkg/tc/mocks/TC.go @@ -0,0 +1,182 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + types "github.com/Mellanox/multi-networkpolicy-tc/pkg/tc/types" +) + +// TC is an autogenerated mock type for the TC type +type TC struct { + mock.Mock +} + +// ChainAdd provides a mock function with given fields: qdisc, chain +func (_m *TC) ChainAdd(qdisc types.QDisc, chain types.Chain) error { + ret := _m.Called(qdisc, chain) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc, types.Chain) error); ok { + r0 = rf(qdisc, chain) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ChainDel provides a mock function with given fields: qdisc, chain +func (_m *TC) ChainDel(qdisc types.QDisc, chain types.Chain) error { + ret := _m.Called(qdisc, chain) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc, types.Chain) error); ok { + r0 = rf(qdisc, chain) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ChainList provides a mock function with given fields: qdisc +func (_m *TC) ChainList(qdisc types.QDisc) ([]types.Chain, error) { + ret := _m.Called(qdisc) + + var r0 []types.Chain + if rf, ok := ret.Get(0).(func(types.QDisc) []types.Chain); ok { + r0 = rf(qdisc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Chain) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(types.QDisc) error); ok { + r1 = rf(qdisc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FilterAdd provides a mock function with given fields: qdisc, filter +func (_m *TC) FilterAdd(qdisc types.QDisc, filter types.Filter) error { + ret := _m.Called(qdisc, filter) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc, types.Filter) error); ok { + r0 = rf(qdisc, filter) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// FilterDel provides a mock function with given fields: qdisc, filterAttr +func (_m *TC) FilterDel(qdisc types.QDisc, filterAttr *types.FilterAttrs) error { + ret := _m.Called(qdisc, filterAttr) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc, *types.FilterAttrs) error); ok { + r0 = rf(qdisc, filterAttr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// FilterList provides a mock function with given fields: qdisc +func (_m *TC) FilterList(qdisc types.QDisc) ([]types.Filter, error) { + ret := _m.Called(qdisc) + + var r0 []types.Filter + if rf, ok := ret.Get(0).(func(types.QDisc) []types.Filter); ok { + r0 = rf(qdisc) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Filter) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(types.QDisc) error); ok { + r1 = rf(qdisc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// QDiscAdd provides a mock function with given fields: qdisc +func (_m *TC) QDiscAdd(qdisc types.QDisc) error { + ret := _m.Called(qdisc) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc) error); ok { + r0 = rf(qdisc) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QDiscDel provides a mock function with given fields: qdisc +func (_m *TC) QDiscDel(qdisc types.QDisc) error { + ret := _m.Called(qdisc) + + var r0 error + if rf, ok := ret.Get(0).(func(types.QDisc) error); ok { + r0 = rf(qdisc) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QDiscList provides a mock function with given fields: +func (_m *TC) QDiscList() ([]types.QDisc, error) { + ret := _m.Called() + + var r0 []types.QDisc + if rf, ok := ret.Get(0).(func() []types.QDisc); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.QDisc) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewTC interface { + mock.TestingT + Cleanup(func()) +} + +// NewTC creates a new instance of TC. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewTC(t mockConstructorTestingTNewTC) *TC { + mock := &TC{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}