diff --git a/dnsutils/message.go b/dnsutils/message.go index 3d8e6988..b295c1d4 100644 --- a/dnsutils/message.go +++ b/dnsutils/message.go @@ -17,6 +17,7 @@ import ( "github.com/dmachard/go-dnstap-protobuf" "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/miekg/dns" "github.com/nqd/flat" "google.golang.org/protobuf/proto" ) @@ -731,7 +732,7 @@ func (dm *DnsMessage) ToPacketLayer() ([]gopacket.SerializableLayer, error) { ip6.SrcIP = net.ParseIP(srcIp) ip6.DstIP = net.ParseIP(dstIp) default: - return nil, errors.New("family " + dm.NetworkInfo.Family + " not yet implemented") + return nil, errors.New("family (" + dm.NetworkInfo.Family + ") not yet implemented") } // set transport @@ -830,3 +831,17 @@ func GetFakeDnsMessage() DnsMessage { dm.DNS.Qtype = "A" return dm } + +func GetFakeDnsMessageWithPayload() DnsMessage { + // fake dns query payload + dnsmsg := new(dns.Msg) + dnsmsg.SetQuestion("dnscollector.dev.", dns.TypeA) + dnsquestion, _ := dnsmsg.Pack() + + dm := GetFakeDnsMessage() + dm.NetworkInfo.Family = PROTO_IPV4 + dm.NetworkInfo.Protocol = PROTO_UDP + dm.DNS.Payload = dnsquestion + dm.DNS.Length = len(dnsquestion) + return dm +} diff --git a/loggers/stdout.go b/loggers/stdout.go index 7291f8d9..77a8bfbc 100644 --- a/loggers/stdout.go +++ b/loggers/stdout.go @@ -3,6 +3,7 @@ package loggers import ( "bytes" "encoding/json" + "io" "log" "os" "strings" @@ -39,7 +40,8 @@ type StdOut struct { config *dnsutils.Config configChan chan *dnsutils.Config logger *logger.Logger - stdout *log.Logger + writerText *log.Logger + writerPcap *pcapgo.Writer name string } @@ -55,7 +57,7 @@ func NewStdOut(config *dnsutils.Config, console *logger.Logger, name string) *St logger: console, config: config, configChan: make(chan *dnsutils.Config), - stdout: log.New(os.Stdout, "", 0), + writerText: log.New(os.Stdout, "", 0), name: name, } o.ReadConfig() @@ -70,6 +72,7 @@ func (c *StdOut) ReadConfig() { if !IsStdoutValidMode(c.config.Loggers.Stdout.Mode) { c.logger.Fatal("["+c.name+"] logger=stdout - invalid mode: ", c.config.Loggers.Stdout.Mode) } + if len(c.config.Loggers.Stdout.TextFormat) > 0 { c.textFormat = strings.Fields(c.config.Loggers.Stdout.TextFormat) } else { @@ -90,8 +93,18 @@ func (c *StdOut) LogError(msg string, v ...interface{}) { c.logger.Error("["+c.name+"] logger=stdout - "+msg, v...) } -func (o *StdOut) SetBuffer(b *bytes.Buffer) { - o.stdout.SetOutput(b) +func (o *StdOut) SetTextWriter(b *bytes.Buffer) { + o.writerText = log.New(os.Stdout, "", 0) + o.writerText.SetOutput(b) +} + +func (o *StdOut) SetPcapWriter(w io.Writer) { + o.LogInfo("init pcap writer") + + o.writerPcap = pcapgo.NewWriter(w) + if err := o.writerPcap.WriteFileHeader(65536, layers.LinkTypeEthernet); err != nil { + o.logger.Fatal("["+o.name+"] logger=stdout - pcap init error: %e", err) + } } func (o *StdOut) Channel() chan dnsutils.DnsMessage { @@ -140,7 +153,7 @@ RUN_LOOP: case dm, opened := <-o.inputChan: if !opened { - o.LogInfo("input channel closed!") + o.LogInfo("run: input channel closed!") return } @@ -162,13 +175,8 @@ func (o *StdOut) Process() { // standard output buffer buffer := new(bytes.Buffer) - // pcap init ? - var writerPcap *pcapgo.Writer - if o.config.Loggers.Stdout.Mode == dnsutils.MODE_PCAP { - writerPcap = pcapgo.NewWriter(os.Stdout) - if err := writerPcap.WriteFileHeader(65536, layers.LinkTypeEthernet); err != nil { - o.LogError("pcap init error: %e", err) - } + if o.config.Loggers.Stdout.Mode == dnsutils.MODE_PCAP && o.writerPcap == nil { + o.SetPcapWriter(os.Stdout) } o.LogInfo("ready to process") @@ -181,14 +189,14 @@ PROCESS_LOOP: case dm, opened := <-o.outputChan: if !opened { - o.LogInfo("output channel closed!") + o.LogInfo("process: output channel closed!") return } switch o.config.Loggers.Stdout.Mode { case dnsutils.MODE_PCAP: if len(dm.DNS.Payload) == 0 { - o.LogError("no dns payload to encode, drop it!") + o.LogError("process: no dns payload to encode, drop it") continue } @@ -214,25 +222,25 @@ PROCESS_LOOP: Length: bufSize, } - writerPcap.WritePacket(ci, buf.Bytes()) + o.writerPcap.WritePacket(ci, buf.Bytes()) case dnsutils.MODE_TEXT: - o.stdout.Print(dm.String(o.textFormat, + o.writerText.Print(dm.String(o.textFormat, o.config.Global.TextFormatDelimiter, o.config.Global.TextFormatBoundary)) case dnsutils.MODE_JSON: json.NewEncoder(buffer).Encode(dm) - o.stdout.Print(buffer.String()) + o.writerText.Print(buffer.String()) buffer.Reset() case dnsutils.MODE_FLATJSON: flat, err := dm.Flatten() if err != nil { - o.LogError("flattening DNS message failed: %e", err) + o.LogError("process: flattening DNS message failed: %e", err) } json.NewEncoder(buffer).Encode(flat) - o.stdout.Print(buffer.String()) + o.writerText.Print(buffer.String()) buffer.Reset() } } diff --git a/loggers/stdout_test.go b/loggers/stdout_test.go index a1361f2b..5a5e340a 100644 --- a/loggers/stdout_test.go +++ b/loggers/stdout_test.go @@ -8,6 +8,7 @@ import ( "github.com/dmachard/go-dnscollector/dnsutils" "github.com/dmachard/go-logger" + "github.com/google/gopacket/pcapgo" ) func Test_StdoutTextMode(t *testing.T) { @@ -68,7 +69,7 @@ func Test_StdoutTextMode(t *testing.T) { cfg.Global.TextFormatBoundary = tc.boundary g := NewStdOut(cfg, logger.New(false), "test") - g.SetBuffer(&stdout) + g.SetTextWriter(&stdout) go g.Run() @@ -112,7 +113,7 @@ func Test_StdoutJsonMode(t *testing.T) { cfg := dnsutils.GetFakeConfig() cfg.Loggers.Stdout.Mode = tc.mode g := NewStdOut(cfg, logger.New(false), "test") - g.SetBuffer(&stdout) + g.SetTextWriter(&stdout) go g.Run() @@ -133,3 +134,74 @@ func Test_StdoutJsonMode(t *testing.T) { }) } } + +func Test_StdoutPcapMode(t *testing.T) { + // redirect stdout output to bytes buffer + var pcap bytes.Buffer + + // init logger and run + cfg := dnsutils.GetFakeConfig() + cfg.Loggers.Stdout.Mode = "pcap" + + g := NewStdOut(cfg, logger.New(false), "test") + g.SetPcapWriter(&pcap) + + go g.Run() + + // send DNSMessage to channel + dm := dnsutils.GetFakeDnsMessageWithPayload() + g.Channel() <- dm + + // stop logger + time.Sleep(time.Second) + g.Stop() + + // check pcap output + pcapReader, err := pcapgo.NewReader(bytes.NewReader(pcap.Bytes())) + if err != nil { + t.Errorf("unable to read pcap: %s", err) + return + } + data, _, err := pcapReader.ReadPacketData() + if err != nil { + t.Errorf("unable to read packet: %s", err) + return + } + if len(data) < dm.DNS.Length { + t.Errorf("incorrect packet size: %d", len(data)) + } +} + +func Test_StdoutPcapMode_NoDNSPayload(t *testing.T) { + // redirect stdout output to bytes buffer + logger := logger.New(false) + var logs bytes.Buffer + logger.SetOutput(&logs) + + var pcap bytes.Buffer + + // init logger and run + cfg := dnsutils.GetFakeConfig() + cfg.Loggers.Stdout.Mode = "pcap" + + g := NewStdOut(cfg, logger, "test") + g.SetPcapWriter(&pcap) + + go g.Run() + + // send DNSMessage to channel + dm := dnsutils.GetFakeDnsMessage() + g.Channel() <- dm + + // stop logger + time.Sleep(time.Second) + g.Stop() + + // check output + regxp := "ERROR:.*process: no dns payload to encode, drop it.*" + pattern := regexp.MustCompile(regxp) + ret := logs.String() + if !pattern.MatchString(ret) { + t.Errorf("stdout error want %s, got: %s", regxp, ret) + } +}