Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new generic worker #696

Merged
merged 14 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<img src="https://img.shields.io/badge/go%20version-min%201.21-green" alt="Go version"/>
<img src="https://img.shields.io/badge/go%20tests-450-green" alt="Go tests"/>
<img src="https://img.shields.io/badge/go%20bench-20-green" alt="Go bench"/>
<img src="https://img.shields.io/badge/go%20lines-37761-green" alt="Go lines"/>
<img src="https://img.shields.io/badge/go%20lines-35097-green" alt="Go lines"/>
</p>

<p align="center">
Expand Down
121 changes: 50 additions & 71 deletions collectors/dnsmessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ type MatchSource struct {
}

type DNSMessage struct {
*pkgutils.Collector
*pkgutils.GenericWorker
inputChan chan dnsutils.DNSMessage
}

func NewDNSMessage(next []pkgutils.Worker, config *pkgconfig.Config, logger *logger.Logger, name string) *DNSMessage {
s := &DNSMessage{
Collector: pkgutils.NewCollector(config, logger, name, "dnsmessage"),
inputChan: make(chan dnsutils.DNSMessage, config.Collectors.DNSMessage.ChannelBufferSize),
GenericWorker: pkgutils.NewGenericWorker(config, logger, name, "dnsmessage", 0),
inputChan: make(chan dnsutils.DNSMessage, config.Collectors.DNSMessage.ChannelBufferSize),
}
s.SetDefaultRoutes(next)
s.ReadConfig()
return s
}

func (c *DNSMessage) ReadConfigMatching(value interface{}) {
func (w *DNSMessage) ReadConfigMatching(value interface{}) {
reflectedValue := reflect.ValueOf(value)
if reflectedValue.Kind() == reflect.Map {
keys := reflectedValue.MapKeys()
Expand All @@ -60,9 +60,9 @@ func (c *DNSMessage) ReadConfigMatching(value interface{}) {
}
}
if len(matchSrc) > 0 {
sourceData, err := c.LoadData(matchSrc, srcKind)
sourceData, err := w.LoadData(matchSrc, srcKind)
if err != nil {
c.LogFatal(err)
w.LogFatal(err)
}
if len(sourceData.regexList) > 0 {
value.(map[interface{}]interface{})[srcKind] = sourceData.regexList
Expand All @@ -74,44 +74,44 @@ func (c *DNSMessage) ReadConfigMatching(value interface{}) {
}
}

func (c *DNSMessage) ReadConfig() {
func (w *DNSMessage) ReadConfig() {
// load external file for include
if len(c.GetConfig().Collectors.DNSMessage.Matching.Include) > 0 {
for _, value := range c.GetConfig().Collectors.DNSMessage.Matching.Include {
c.ReadConfigMatching(value)
if len(w.GetConfig().Collectors.DNSMessage.Matching.Include) > 0 {
for _, value := range w.GetConfig().Collectors.DNSMessage.Matching.Include {
w.ReadConfigMatching(value)
}
}
// load external file for exclude
if len(c.GetConfig().Collectors.DNSMessage.Matching.Exclude) > 0 {
for _, value := range c.GetConfig().Collectors.DNSMessage.Matching.Exclude {
c.ReadConfigMatching(value)
if len(w.GetConfig().Collectors.DNSMessage.Matching.Exclude) > 0 {
for _, value := range w.GetConfig().Collectors.DNSMessage.Matching.Exclude {
w.ReadConfigMatching(value)
}
}
}

func (c *DNSMessage) GetInputChannel() chan dnsutils.DNSMessage {
return c.inputChan
func (w *DNSMessage) GetInputChannel() chan dnsutils.DNSMessage {
return w.inputChan
}

func (c *DNSMessage) LoadData(matchSource string, srcKind string) (MatchSource, error) {
func (w *DNSMessage) LoadData(matchSource string, srcKind string) (MatchSource, error) {
if isFileSource(matchSource) {
dataSource, err := c.LoadFromFile(matchSource, srcKind)
dataSource, err := w.LoadFromFile(matchSource, srcKind)
if err != nil {
c.LogFatal(err)
w.LogFatal(err)
}
return dataSource, nil
} else if isURLSource(matchSource) {
dataSource, err := c.LoadFromURL(matchSource, srcKind)
dataSource, err := w.LoadFromURL(matchSource, srcKind)
if err != nil {
c.LogFatal(err)
w.LogFatal(err)
}
return dataSource, nil
}
return MatchSource{}, fmt.Errorf("match source not supported %s", matchSource)
}

func (c *DNSMessage) LoadFromURL(matchSource string, srcKind string) (MatchSource, error) {
c.LogInfo("loading matching source from url=%s", matchSource)
func (w *DNSMessage) LoadFromURL(matchSource string, srcKind string) (MatchSource, error) {
w.LogInfo("loading matching source from url=%s", matchSource)
resp, err := http.Get(matchSource)
if err != nil {
return MatchSource{}, err
Expand All @@ -130,21 +130,21 @@ func (c *DNSMessage) LoadFromURL(matchSource string, srcKind string) (MatchSourc
for scanner.Scan() {
matchSources.regexList = append(matchSources.regexList, regexp.MustCompile(scanner.Text()))
}
c.LogInfo("remote source loaded with %d entries kind=%s", len(matchSources.regexList), srcKind)
w.LogInfo("remote source loaded with %d entries kind=%s", len(matchSources.regexList), srcKind)
case dnsutils.MatchingKindString:
for scanner.Scan() {
matchSources.stringList = append(matchSources.stringList, scanner.Text())
}
c.LogInfo("remote source loaded with %d entries kind=%s", len(matchSources.stringList), srcKind)
w.LogInfo("remote source loaded with %d entries kind=%s", len(matchSources.stringList), srcKind)
}

return matchSources, nil
}

func (c *DNSMessage) LoadFromFile(filePath string, srcKind string) (MatchSource, error) {
func (w *DNSMessage) LoadFromFile(filePath string, srcKind string) (MatchSource, error) {
localFile := strings.TrimPrefix(filePath, "file://")

c.LogInfo("loading matching source from file=%s", localFile)
w.LogInfo("loading matching source from file=%s", localFile)
file, err := os.Open(localFile)
if err != nil {
return MatchSource{}, fmt.Errorf("unable to open file: %w", err)
Expand All @@ -158,48 +158,45 @@ func (c *DNSMessage) LoadFromFile(filePath string, srcKind string) (MatchSource,
for scanner.Scan() {
matchSources.regexList = append(matchSources.regexList, regexp.MustCompile(scanner.Text()))
}
c.LogInfo("file loaded with %d entries kind=%s", len(matchSources.regexList), srcKind)
w.LogInfo("file loaded with %d entries kind=%s", len(matchSources.regexList), srcKind)
case dnsutils.MatchingKindString:
for scanner.Scan() {
matchSources.stringList = append(matchSources.stringList, scanner.Text())
}
c.LogInfo("file loaded with %d entries kind=%s", len(matchSources.stringList), srcKind)
w.LogInfo("file loaded with %d entries kind=%s", len(matchSources.stringList), srcKind)
}

return matchSources, nil
}

func (c *DNSMessage) Run() {
c.LogInfo("running collector...")
defer func() {
c.LogInfo("run terminated")
c.StopIsDone()
}()
func (w *DNSMessage) StartCollect() {
w.LogInfo("worker is starting collection")
defer w.CollectDone()

var err error

// prepare next channels
defaultRoutes, defaultNames := pkgutils.GetRoutes(c.GetDefaultRoutes())
droppedRoutes, droppedNames := pkgutils.GetRoutes(c.GetDefaultRoutes())
defaultRoutes, defaultNames := pkgutils.GetRoutes(w.GetDefaultRoutes())
droppedRoutes, droppedNames := pkgutils.GetRoutes(w.GetDroppedRoutes())

// prepare transforms
subprocessors := transformers.NewTransforms(&c.GetConfig().IngoingTransformers, c.GetLogger(), c.GetName(), defaultRoutes, 0)
subprocessors := transformers.NewTransforms(&w.GetConfig().IngoingTransformers, w.GetLogger(), w.GetName(), defaultRoutes, 0)

// read incoming dns message
c.LogInfo("waiting dns message to process...")
w.LogInfo("waiting dns message to process...")
for {
select {
case <-c.OnStop():
case <-w.OnStop():
return

// save the new config
case cfg := <-c.NewConfig():
c.SetConfig(cfg)
c.ReadConfig()
case cfg := <-w.NewConfig():
w.SetConfig(cfg)
w.ReadConfig()

case dm, opened := <-c.inputChan:
case dm, opened := <-w.GetInputChannel():
if !opened {
c.LogInfo("channel closed, exit")
w.LogInfo("channel closed, exit")
return
}

Expand All @@ -208,10 +205,10 @@ func (c *DNSMessage) Run() {
matchedInclude := false
matchedExclude := false

if len(c.GetConfig().Collectors.DNSMessage.Matching.Include) > 0 {
err, matchedInclude = dm.Matching(c.GetConfig().Collectors.DNSMessage.Matching.Include)
if len(w.GetConfig().Collectors.DNSMessage.Matching.Include) > 0 {
err, matchedInclude = dm.Matching(w.GetConfig().Collectors.DNSMessage.Matching.Include)
if err != nil {
c.LogError(err.Error())
w.LogError(err.Error())
}
if matched && matchedInclude {
matched = true
Expand All @@ -220,10 +217,10 @@ func (c *DNSMessage) Run() {
}
}

if len(c.GetConfig().Collectors.DNSMessage.Matching.Exclude) > 0 {
err, matchedExclude = dm.Matching(c.GetConfig().Collectors.DNSMessage.Matching.Exclude)
if len(w.GetConfig().Collectors.DNSMessage.Matching.Exclude) > 0 {
err, matchedExclude = dm.Matching(w.GetConfig().Collectors.DNSMessage.Matching.Exclude)
if err != nil {
c.LogError(err.Error())
w.LogError(err.Error())
}
if matched && !matchedExclude {
matched = true
Expand All @@ -237,37 +234,19 @@ func (c *DNSMessage) Run() {
if matched {
subprocessors.InitDNSMessageFormat(&dm)
if subprocessors.ProcessMessage(&dm) == transformers.ReturnDrop {
for i := range droppedRoutes {
select {
case droppedRoutes[i] <- dm:
default:
c.NextStanzaIsBusy(droppedNames[i])
}
}
w.SendTo(droppedRoutes, droppedNames, dm)
continue
}
}

// drop packet ?
if !matched {
for i := range droppedRoutes {
select {
case droppedRoutes[i] <- dm:
default:
c.NextStanzaIsBusy(droppedNames[i])
}
}
w.SendTo(droppedRoutes, droppedNames, dm)
continue
}

// send to next
for i := range defaultRoutes {
select {
case defaultRoutes[i] <- dm:
default:
c.NextStanzaIsBusy(defaultNames[i])
}
}
w.SendTo(defaultRoutes, defaultNames, dm)
}
}
}
2 changes: 1 addition & 1 deletion collectors/dnsmessage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func Test_DnsMessage_BufferLoggerIsFull(t *testing.T) {
c.AddDefaultRoute(nxt)

// run collector
go c.Run()
go c.StartCollect()

// add a shot of dnsmessages to collector
dmIn := dnsutils.GetFakeDNSMessage()
Expand Down
Loading
Loading