From 20caff09a8395b35b28c3246104ee2303d820d78 Mon Sep 17 00:00:00 2001 From: Vincent Boulineau Date: Fri, 25 Mar 2022 12:36:50 +0100 Subject: [PATCH] Fix AD port name interpolation by properly merging slices in workloadmeta --- pkg/workloadmeta/merge.go | 83 +++++++++++++++++++++++----- pkg/workloadmeta/merge_test.go | 98 ++++++++++++++++++++++++++++++++-- 2 files changed, 162 insertions(+), 19 deletions(-) diff --git a/pkg/workloadmeta/merge.go b/pkg/workloadmeta/merge.go index 8aa98c3392ec11..2e4e59dba43912 100644 --- a/pkg/workloadmeta/merge.go +++ b/pkg/workloadmeta/merge.go @@ -7,35 +7,90 @@ package workloadmeta import ( "reflect" + "strconv" "time" "github.com/imdario/mergo" ) -type timeMerger struct{} +type ( + merger struct{} +) var ( - timeType = reflect.TypeOf(time.Time{}) - timeMergerInstance = timeMerger{} + timeType = reflect.TypeOf(time.Time{}) + portSliceType = reflect.TypeOf([]ContainerPort{}) + mergerInstance = merger{} ) -func (tm timeMerger) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { - if typ != timeType { +func (merger) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { + switch typ { + case timeType: + return timeMerge + case portSliceType: + return portSliceMerge + } + + return nil +} + +func timeMerge(dst, src reflect.Value) error { + if !dst.CanSet() { return nil } - return func(dst, src reflect.Value) error { - if dst.CanSet() { - isZero := src.MethodByName("IsZero") - result := isZero.Call([]reflect.Value{}) - if !result[0].Bool() { - dst.Set(src) - } - } + isZero := src.MethodByName("IsZero") + result := isZero.Call([]reflect.Value{}) + if !result[0].Bool() { + dst.Set(src) + } + return nil +} + +func portSliceMerge(dst, src reflect.Value) error { + if !dst.CanSet() { + return nil + } + + srcSlice := src.Interface().([]ContainerPort) + dstSlice := dst.Interface().([]ContainerPort) + + // Not allocation the map if nothing to do + if len(srcSlice) == 0 || len(dstSlice) == 0 { return nil } + + mergeMap := make(map[string]ContainerPort, len(srcSlice)+len(dstSlice)) + for _, port := range dstSlice { + mergeContainerPort(mergeMap, port) + } + + for _, port := range srcSlice { + mergeContainerPort(mergeMap, port) + } + + dstSlice = make([]ContainerPort, 0, len(mergeMap)) + for _, port := range mergeMap { + dstSlice = append(dstSlice, port) + } + dst.Set(reflect.ValueOf(dstSlice)) + + return nil +} + +func mergeContainerPort(mergeMap map[string]ContainerPort, port ContainerPort) { + portKey := strconv.Itoa(port.Port) + port.Protocol + existingPort, found := mergeMap[portKey] + + if found { + if existingPort.Name == "" && port.Name != "" { + mergeMap[portKey] = port + } + } else { + mergeMap[portKey] = port + } } func merge(dst, src interface{}) error { - return mergo.Merge(dst, src, mergo.WithTransformers(timeMergerInstance)) + return mergo.Merge(dst, src, mergo.WithAppendSlice, mergo.WithTransformers(mergerInstance)) } diff --git a/pkg/workloadmeta/merge_test.go b/pkg/workloadmeta/merge_test.go index 69c760733dacf6..932cbfcbd14e06 100644 --- a/pkg/workloadmeta/merge_test.go +++ b/pkg/workloadmeta/merge_test.go @@ -13,10 +13,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMerge(t *testing.T) { - testTime := time.Now() - - fromSource1 := Container{ +func container1(testTime time.Time) Container { + return Container{ EntityID: EntityID{ Kind: KindContainer, ID: "foo1", @@ -25,15 +23,32 @@ func TestMerge(t *testing.T) { Name: "foo1-name", Namespace: "", }, + Ports: []ContainerPort{ + { + Name: "port1", + Port: 42000, + Protocol: "tcp", + }, + { + Port: 42001, + Protocol: "udp", + }, + { + Port: 42002, + }, + }, State: ContainerState{ Running: true, CreatedAt: testTime, StartedAt: testTime, FinishedAt: time.Time{}, }, + CollectorTags: []string{"tag1", "tag2"}, } +} - fromSource2 := Container{ +func container2(testTime time.Time) Container { + return Container{ EntityID: EntityID{ Kind: KindContainer, ID: "foo1", @@ -42,13 +57,35 @@ func TestMerge(t *testing.T) { Name: "foo1-name", Namespace: "", }, + Ports: []ContainerPort{ + { + Port: 42000, + Protocol: "tcp", + }, + { + Port: 42001, + Protocol: "udp", + }, + { + Port: 42002, + Protocol: "tcp", + }, + { + Port: 42003, + }, + }, State: ContainerState{ CreatedAt: time.Time{}, StartedAt: time.Time{}, FinishedAt: time.Time{}, ExitCode: pointer.UInt32Ptr(100), }, + CollectorTags: []string{"tag3"}, } +} + +func TestMerge(t *testing.T) { + testTime := time.Now() expectedContainer := Container{ EntityID: EntityID{ @@ -68,12 +105,63 @@ func TestMerge(t *testing.T) { }, } + expectedPorts := []ContainerPort{ + { + Name: "port1", + Port: 42000, + Protocol: "tcp", + }, + { + Port: 42001, + Protocol: "udp", + }, + { + Port: 42002, + }, + { + Port: 42002, + Protocol: "tcp", + }, + { + Port: 42003, + }, + } + + expectedTags := []string{"tag1", "tag2", "tag3"} + // Test merging both ways + fromSource1 := container1(testTime) + fromSource2 := container2(testTime) err := merge(&fromSource1, &fromSource2) assert.NoError(t, err) + assert.ElementsMatch(t, expectedPorts, fromSource1.Ports) + assert.ElementsMatch(t, expectedTags, fromSource1.CollectorTags) + fromSource1.Ports = nil + fromSource1.CollectorTags = nil assert.Equal(t, expectedContainer, fromSource1) + fromSource1 = container1(testTime) + fromSource2 = container2(testTime) err = merge(&fromSource2, &fromSource1) assert.NoError(t, err) + assert.ElementsMatch(t, expectedPorts, fromSource2.Ports) + assert.ElementsMatch(t, expectedTags, fromSource2.CollectorTags) + fromSource2.Ports = nil + fromSource2.CollectorTags = nil assert.Equal(t, expectedContainer, fromSource2) + + // Test merging nil slice in src/dst + fromSource1 = container1(testTime) + fromSource2 = container2(testTime) + fromSource2.Ports = nil + err = merge(&fromSource1, &fromSource2) + assert.NoError(t, err) + assert.ElementsMatch(t, container1(testTime).Ports, fromSource1.Ports) + + fromSource1 = container1(testTime) + fromSource2 = container2(testTime) + fromSource2.Ports = nil + err = merge(&fromSource2, &fromSource1) + assert.NoError(t, err) + assert.ElementsMatch(t, container1(testTime).Ports, fromSource2.Ports) }