Skip to content

Commit

Permalink
Fix AD port name interpolation by properly merging slices in workload…
Browse files Browse the repository at this point in the history
…meta
  • Loading branch information
vboulineau committed Mar 25, 2022
1 parent d3daed8 commit 20caff0
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 19 deletions.
83 changes: 69 additions & 14 deletions pkg/workloadmeta/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,90 @@ package workloadmeta

import (
"reflect"
"strconv"
"time"

"github.com/imdario/mergo"
)

type timeMerger struct{}
type (
merger struct{}
)

var (
timeType = reflect.TypeOf(time.Time{})
timeMergerInstance = timeMerger{}
timeType = reflect.TypeOf(time.Time{})
portSliceType = reflect.TypeOf([]ContainerPort{})
mergerInstance = merger{}
)

func (tm timeMerger) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
if typ != timeType {
func (merger) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
switch typ {
case timeType:
return timeMerge
case portSliceType:
return portSliceMerge
}

return nil
}

func timeMerge(dst, src reflect.Value) error {
if !dst.CanSet() {
return nil
}

return func(dst, src reflect.Value) error {
if dst.CanSet() {
isZero := src.MethodByName("IsZero")
result := isZero.Call([]reflect.Value{})
if !result[0].Bool() {
dst.Set(src)
}
}
isZero := src.MethodByName("IsZero")
result := isZero.Call([]reflect.Value{})
if !result[0].Bool() {
dst.Set(src)
}
return nil
}

func portSliceMerge(dst, src reflect.Value) error {
if !dst.CanSet() {
return nil
}

srcSlice := src.Interface().([]ContainerPort)
dstSlice := dst.Interface().([]ContainerPort)

// Not allocation the map if nothing to do
if len(srcSlice) == 0 || len(dstSlice) == 0 {
return nil
}

mergeMap := make(map[string]ContainerPort, len(srcSlice)+len(dstSlice))
for _, port := range dstSlice {
mergeContainerPort(mergeMap, port)
}

for _, port := range srcSlice {
mergeContainerPort(mergeMap, port)
}

dstSlice = make([]ContainerPort, 0, len(mergeMap))
for _, port := range mergeMap {
dstSlice = append(dstSlice, port)
}
dst.Set(reflect.ValueOf(dstSlice))

return nil
}

func mergeContainerPort(mergeMap map[string]ContainerPort, port ContainerPort) {
portKey := strconv.Itoa(port.Port) + port.Protocol
existingPort, found := mergeMap[portKey]

if found {
if existingPort.Name == "" && port.Name != "" {
mergeMap[portKey] = port
}
} else {
mergeMap[portKey] = port
}
}

func merge(dst, src interface{}) error {
return mergo.Merge(dst, src, mergo.WithTransformers(timeMergerInstance))
return mergo.Merge(dst, src, mergo.WithAppendSlice, mergo.WithTransformers(mergerInstance))
}
98 changes: 93 additions & 5 deletions pkg/workloadmeta/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ import (
"github.com/stretchr/testify/assert"
)

func TestMerge(t *testing.T) {
testTime := time.Now()

fromSource1 := Container{
func container1(testTime time.Time) Container {
return Container{
EntityID: EntityID{
Kind: KindContainer,
ID: "foo1",
Expand All @@ -25,15 +23,32 @@ func TestMerge(t *testing.T) {
Name: "foo1-name",
Namespace: "",
},
Ports: []ContainerPort{
{
Name: "port1",
Port: 42000,
Protocol: "tcp",
},
{
Port: 42001,
Protocol: "udp",
},
{
Port: 42002,
},
},
State: ContainerState{
Running: true,
CreatedAt: testTime,
StartedAt: testTime,
FinishedAt: time.Time{},
},
CollectorTags: []string{"tag1", "tag2"},
}
}

fromSource2 := Container{
func container2(testTime time.Time) Container {
return Container{
EntityID: EntityID{
Kind: KindContainer,
ID: "foo1",
Expand All @@ -42,13 +57,35 @@ func TestMerge(t *testing.T) {
Name: "foo1-name",
Namespace: "",
},
Ports: []ContainerPort{
{
Port: 42000,
Protocol: "tcp",
},
{
Port: 42001,
Protocol: "udp",
},
{
Port: 42002,
Protocol: "tcp",
},
{
Port: 42003,
},
},
State: ContainerState{
CreatedAt: time.Time{},
StartedAt: time.Time{},
FinishedAt: time.Time{},
ExitCode: pointer.UInt32Ptr(100),
},
CollectorTags: []string{"tag3"},
}
}

func TestMerge(t *testing.T) {
testTime := time.Now()

expectedContainer := Container{
EntityID: EntityID{
Expand All @@ -68,12 +105,63 @@ func TestMerge(t *testing.T) {
},
}

expectedPorts := []ContainerPort{
{
Name: "port1",
Port: 42000,
Protocol: "tcp",
},
{
Port: 42001,
Protocol: "udp",
},
{
Port: 42002,
},
{
Port: 42002,
Protocol: "tcp",
},
{
Port: 42003,
},
}

expectedTags := []string{"tag1", "tag2", "tag3"}

// Test merging both ways
fromSource1 := container1(testTime)
fromSource2 := container2(testTime)
err := merge(&fromSource1, &fromSource2)
assert.NoError(t, err)
assert.ElementsMatch(t, expectedPorts, fromSource1.Ports)
assert.ElementsMatch(t, expectedTags, fromSource1.CollectorTags)
fromSource1.Ports = nil
fromSource1.CollectorTags = nil
assert.Equal(t, expectedContainer, fromSource1)

fromSource1 = container1(testTime)
fromSource2 = container2(testTime)
err = merge(&fromSource2, &fromSource1)
assert.NoError(t, err)
assert.ElementsMatch(t, expectedPorts, fromSource2.Ports)
assert.ElementsMatch(t, expectedTags, fromSource2.CollectorTags)
fromSource2.Ports = nil
fromSource2.CollectorTags = nil
assert.Equal(t, expectedContainer, fromSource2)

// Test merging nil slice in src/dst
fromSource1 = container1(testTime)
fromSource2 = container2(testTime)
fromSource2.Ports = nil
err = merge(&fromSource1, &fromSource2)
assert.NoError(t, err)
assert.ElementsMatch(t, container1(testTime).Ports, fromSource1.Ports)

fromSource1 = container1(testTime)
fromSource2 = container2(testTime)
fromSource2.Ports = nil
err = merge(&fromSource2, &fromSource1)
assert.NoError(t, err)
assert.ElementsMatch(t, container1(testTime).Ports, fromSource2.Ports)
}

0 comments on commit 20caff0

Please sign in to comment.