Skip to content

Commit

Permalink
packetbeat/protos/dns: clean up package (#33286)
Browse files Browse the repository at this point in the history
* avoid magic numbers
* fix hashableDNSTuple size and offsets
* avoid use of String and Error methods in formatted print calls
* remove redundant conversions
* quieten linter
* use plugin-owned logp.Logger
  • Loading branch information
efd6 authored Oct 12, 2022
1 parent b69e7b3 commit d39531e
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 113 deletions.
118 changes: 58 additions & 60 deletions packetbeat/protos/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ type dnsPlugin struct {

results protos.Reporter // Channel where results are pushed.
watcher procs.ProcessesWatcher
}

var debugf = logp.MakeDebug("dns")

const maxDNSTupleRawSize = 16 + 16 + 2 + 2 + 4 + 1
logger *logp.Logger
}

// Transport protocol.
type transport uint8
Expand Down Expand Up @@ -92,6 +90,15 @@ func (t transport) String() string {

type hashableDNSTuple [maxDNSTupleRawSize]byte

const (
maxDNSTupleRawSize = 2*(sizeofIP+sizeofPort) + sizeofID + sizeofTransport

sizeofIP = 16
sizeofPort = 2
sizeofID = 2
sizeofTransport = 1
)

// DnsMessage contains a single DNS message.
type dnsMessage struct {
ts time.Time // Time when the message was received.
Expand All @@ -109,8 +116,8 @@ type dnsTuple struct {
transport transport
id uint16

raw hashableDNSTuple // Src_ip:Src_port:Dst_ip:Dst_port:Transport:Id
revRaw hashableDNSTuple // Dst_ip:Dst_port:Src_ip:Src_port:Transport:Id
raw hashableDNSTuple // Src_ip:Src_port:Dst_ip:Dst_port:ID:Transport
revRaw hashableDNSTuple // Dst_ip:Dst_port:Src_ip:Src_port:ID:Transport
}

func dnsTupleFromIPPort(t *common.IPPortTuple, trans transport, id uint16) dnsTuple {
Expand Down Expand Up @@ -152,21 +159,21 @@ func (t *dnsTuple) computeHashables() {
copy(t.raw[18:34], t.DstIP)
copy(t.raw[34:36], []byte{byte(t.DstPort >> 8), byte(t.DstPort)})
copy(t.raw[36:38], []byte{byte(t.id >> 8), byte(t.id)})
t.raw[39] = byte(t.transport)
t.raw[38] = byte(t.transport)

copy(t.revRaw[0:16], t.DstIP)
copy(t.revRaw[16:18], []byte{byte(t.DstPort >> 8), byte(t.DstPort)})
copy(t.revRaw[18:34], t.SrcIP)
copy(t.revRaw[34:36], []byte{byte(t.SrcPort >> 8), byte(t.SrcPort)})
copy(t.revRaw[36:38], []byte{byte(t.id >> 8), byte(t.id)})
t.revRaw[39] = byte(t.transport)
t.revRaw[38] = byte(t.transport)
}

func (t *dnsTuple) String() string {
return fmt.Sprintf("DnsTuple src[%s:%d] dst[%s:%d] transport[%s] id[%d]",
t.SrcIP.String(),
t.SrcIP,
t.SrcPort,
t.DstIP.String(),
t.DstIP,
t.DstPort,
t.transport,
t.id)
Expand Down Expand Up @@ -212,13 +219,8 @@ func init() {
protos.Register("dns", New)
}

func New(
testMode bool,
results protos.Reporter,
watcher procs.ProcessesWatcher,
cfg *conf.C,
) (protos.Plugin, error) {
p := &dnsPlugin{}
func New(testMode bool, results protos.Reporter, watcher procs.ProcessesWatcher, cfg *conf.C) (protos.Plugin, error) {
p := &dnsPlugin{logger: logp.NewLogger("dns")}
config := defaultConfig
if !testMode {
if err := cfg.Unpack(&config); err != nil {
Expand All @@ -240,7 +242,7 @@ func (dns *dnsPlugin) init(results protos.Reporter, watcher procs.ProcessesWatch
func(k common.Key, v common.Value) {
trans, ok := v.(*dnsTransaction)
if !ok {
logp.Err("Expired value is not a *DnsTransaction.")
dns.logger.Error("Expired value is not a *DnsTransaction.")
return
}
dns.expireTransaction(trans)
Expand All @@ -253,14 +255,13 @@ func (dns *dnsPlugin) init(results protos.Reporter, watcher procs.ProcessesWatch
return nil
}

func (dns *dnsPlugin) setFromConfig(config *dnsConfig) error {
func (dns *dnsPlugin) setFromConfig(config *dnsConfig) {
dns.ports = config.Ports
dns.sendRequest = config.SendRequest
dns.sendResponse = config.SendResponse
dns.includeAuthorities = config.IncludeAuthorities
dns.includeAdditionals = config.IncludeAdditionals
dns.transactionTimeout = config.TransactionTimeout
return nil
}

func newTransaction(ts time.Time, tuple dnsTuple, cmd common.ProcessTuple) *dnsTransaction {
Expand Down Expand Up @@ -292,14 +293,14 @@ func (dns *dnsPlugin) ConnectionTimeout() time.Duration {
}

func (dns *dnsPlugin) receivedDNSRequest(tuple *dnsTuple, msg *dnsMessage) {
debugf("Processing query. %s", tuple.String())
dns.logger.Debugf("Processing query. %s", tuple)

trans := dns.deleteTransaction(tuple.hashable())
if trans != nil {
// This happens if a client puts multiple requests in flight
// with the same ID.
trans.notes = append(trans.notes, duplicateQueryMsg.Error())
debugf("%s %s", duplicateQueryMsg.Error(), tuple.String())
dns.logger.Debugf("%v %s", duplicateQueryMsg, tuple)
dns.publishTransaction(trans)
dns.deleteTransaction(trans.tuple.hashable())
}
Expand All @@ -308,21 +309,21 @@ func (dns *dnsPlugin) receivedDNSRequest(tuple *dnsTuple, msg *dnsMessage) {

if tuple.transport == transportUDP && (msg.data.IsEdns0() != nil) && msg.length > maxDNSPacketSize {
trans.notes = append(trans.notes, udpPacketTooLarge.Error())
debugf("%s", udpPacketTooLarge.Error())
dns.logger.Debugf("%v", udpPacketTooLarge)
}

dns.transactions.Put(tuple.hashable(), trans)
trans.request = msg
}

func (dns *dnsPlugin) receivedDNSResponse(tuple *dnsTuple, msg *dnsMessage) {
debugf("Processing response. %s", tuple.String())
dns.logger.Debugf("Processing response. %s", tuple)

trans := dns.getTransaction(tuple.revHashable())
if trans == nil {
trans = newTransaction(msg.ts, tuple.reverse(), msg.cmdlineTuple.Reverse())
trans.notes = append(trans.notes, orphanedResponse.Error())
debugf("%s %s", orphanedResponse.Error(), tuple.String())
dns.logger.Debugf("%v %s", orphanedResponse, tuple)
unmatchedResponses.Add(1)
}

Expand All @@ -332,7 +333,7 @@ func (dns *dnsPlugin) receivedDNSResponse(tuple *dnsTuple, msg *dnsMessage) {
respIsEdns := msg.data.IsEdns0() != nil
if !respIsEdns && msg.length > maxDNSPacketSize {
trans.notes = append(trans.notes, udpPacketTooLarge.responseError())
debugf("%s", udpPacketTooLarge.responseError())
dns.logger.Debugf("%s", udpPacketTooLarge.responseError())
}

request := trans.request
Expand All @@ -342,10 +343,10 @@ func (dns *dnsPlugin) receivedDNSResponse(tuple *dnsTuple, msg *dnsMessage) {
switch {
case reqIsEdns && !respIsEdns:
trans.notes = append(trans.notes, respEdnsNoSupport.Error())
debugf("%s %s", respEdnsNoSupport.Error(), tuple.String())
dns.logger.Debugf("%v %s", respEdnsNoSupport, tuple)
case !reqIsEdns && respIsEdns:
trans.notes = append(trans.notes, respEdnsUnexpected.Error())
debugf("%s %s", respEdnsUnexpected.Error(), tuple.String())
dns.logger.Debugf("%v %s", respEdnsUnexpected, tuple)
}
}
}
Expand All @@ -359,7 +360,7 @@ func (dns *dnsPlugin) publishTransaction(t *dnsTransaction) {
return
}

debugf("Publishing transaction. %s", t.tuple.String())
dns.logger.Debugf("Publishing transaction. %s", &t.tuple)

evt, pbf := pb.NewBeatEvent(t.ts)

Expand Down Expand Up @@ -388,18 +389,17 @@ func (dns *dnsPlugin) publishTransaction(t *dnsTransaction) {
fields["query"] = dnsQuestionToString(t.request.data.Question[0])
fields["resource"] = t.request.data.Question[0].Name
}
addDNSToMapStr(dnsEvent, pbf, t.response.data, dns.includeAuthorities,
dns.includeAdditionals)
addDNSToMapStr(dnsEvent, pbf, t.response.data, dns.includeAuthorities, dns.includeAdditionals, dns.logger)

if t.response.data.Rcode == 0 {
fields["status"] = common.OK_STATUS
}

if dns.sendRequest {
fields["request"] = dnsToString(t.request.data)
fields["request"] = dnsToString(t.request.data, dns.logger)
}
if dns.sendResponse {
fields["response"] = dnsToString(t.response.data)
fields["response"] = dnsToString(t.response.data, dns.logger)
}
} else if t.request != nil {
pbf.Source.Bytes = int64(t.request.length)
Expand All @@ -411,11 +411,10 @@ func (dns *dnsPlugin) publishTransaction(t *dnsTransaction) {
fields["query"] = dnsQuestionToString(t.request.data.Question[0])
fields["resource"] = t.request.data.Question[0].Name
}
addDNSToMapStr(dnsEvent, pbf, t.request.data, dns.includeAuthorities,
dns.includeAdditionals)
addDNSToMapStr(dnsEvent, pbf, t.request.data, dns.includeAuthorities, dns.includeAdditionals, dns.logger)

if dns.sendRequest {
fields["request"] = dnsToString(t.request.data)
fields["request"] = dnsToString(t.request.data, dns.logger)
}
} else if t.response != nil {
pbf.Destination.Bytes = int64(t.response.length)
Expand All @@ -427,10 +426,9 @@ func (dns *dnsPlugin) publishTransaction(t *dnsTransaction) {
fields["query"] = dnsQuestionToString(t.response.data.Question[0])
fields["resource"] = t.response.data.Question[0].Name
}
addDNSToMapStr(dnsEvent, pbf, t.response.data, dns.includeAuthorities,
dns.includeAdditionals)
addDNSToMapStr(dnsEvent, pbf, t.response.data, dns.includeAuthorities, dns.includeAdditionals, dns.logger)
if dns.sendResponse {
fields["response"] = dnsToString(t.response.data)
fields["response"] = dnsToString(t.response.data, dns.logger)
}
}

Expand All @@ -439,13 +437,13 @@ func (dns *dnsPlugin) publishTransaction(t *dnsTransaction) {

func (dns *dnsPlugin) expireTransaction(t *dnsTransaction) {
t.notes = append(t.notes, noResponse.Error())
debugf("%s %s", noResponse.Error(), t.tuple.String())
dns.logger.Debugf("%v %s", noResponse, &t.tuple)
dns.publishTransaction(t)
unmatchedRequests.Add(1)
}

// Adds the DNS message data to the supplied MapStr.
func addDNSToMapStr(m mapstr.M, pbf *pb.Fields, dns *mkdns.Msg, authority bool, additional bool) {
func addDNSToMapStr(m mapstr.M, pbf *pb.Fields, dns *mkdns.Msg, authority bool, additional bool, logger *logp.Logger) {
m["id"] = dns.Id
m["op_code"] = dnsOpCodeToString(dns.Opcode)

Expand Down Expand Up @@ -527,7 +525,7 @@ func addDNSToMapStr(m mapstr.M, pbf *pb.Fields, dns *mkdns.Msg, authority bool,
m["answers_count"] = len(dns.Answer)
if len(dns.Answer) > 0 {
var resolvedIPs []string
m["answers"], resolvedIPs = rrsToMapStrs(dns.Answer, true)
m["answers"], resolvedIPs = rrsToMapStrs(dns.Answer, true, logger)
if len(resolvedIPs) > 0 {
m["resolved_ip"] = resolvedIPs
pbf.AddIP(resolvedIPs...)
Expand All @@ -536,7 +534,7 @@ func addDNSToMapStr(m mapstr.M, pbf *pb.Fields, dns *mkdns.Msg, authority bool,

m["authorities_count"] = len(dns.Ns)
if authority && len(dns.Ns) > 0 {
m["authorities"], _ = rrsToMapStrs(dns.Ns, false)
m["authorities"], _ = rrsToMapStrs(dns.Ns, false, logger)
}

if rrOPT != nil {
Expand All @@ -545,7 +543,7 @@ func addDNSToMapStr(m mapstr.M, pbf *pb.Fields, dns *mkdns.Msg, authority bool,
m["additionals_count"] = len(dns.Extra)
}
if additional && len(dns.Extra) > 0 {
rrsMapStrs, _ := rrsToMapStrs(dns.Extra, false)
rrsMapStrs, _ := rrsToMapStrs(dns.Extra, false, logger)
// We do not want OPT RR to appear in the 'additional' section,
// that's why rrsMapStrs could be empty even though len(dns.Extra) > 0
if len(rrsMapStrs) > 0 {
Expand Down Expand Up @@ -590,13 +588,13 @@ func optToMapStr(rrOPT *mkdns.OPT) mapstr.M {

// rrsToMapStr converts an slice of RR's to an slice of MapStr's and optionally
// returns a list of the IP addresses found in the resource records.
func rrsToMapStrs(records []mkdns.RR, ipList bool) ([]mapstr.M, []string) {
func rrsToMapStrs(records []mkdns.RR, ipList bool, logger *logp.Logger) ([]mapstr.M, []string) {
var allIPs []string
mapStrSlice := make([]mapstr.M, 0, len(records))
for _, rr := range records {
rrHeader := rr.Header()

mapStr, ips := rrToMapStr(rr, ipList)
mapStr, ips := rrToMapStr(rr, ipList, logger)
if len(mapStr) == 0 { // OPT pseudo-RR returns an empty MapStr
continue
}
Expand All @@ -619,11 +617,11 @@ func rrsToMapStrs(records []mkdns.RR, ipList bool) ([]mapstr.M, []string) {
//
// TODO An improvement would be to replace 'data' by the real field name
// It would require some changes in unit tests
func rrToString(rr mkdns.RR) string {
func rrToString(rr mkdns.RR, logger *logp.Logger) string {
var st string
var keys []string

mapStr, _ := rrToMapStr(rr, false)
mapStr, _ := rrToMapStr(rr, false, logger)
data, ok := mapStr["data"]
delete(mapStr, "data")

Expand Down Expand Up @@ -656,7 +654,7 @@ func rrToString(rr mkdns.RR) string {
return b.String()
}

func rrToMapStr(rr mkdns.RR, ipList bool) (mapstr.M, []string) {
func rrToMapStr(rr mkdns.RR, ipList bool, logger *logp.Logger) (mapstr.M, []string) {
mapStr := mapstr.M{}
rrType := rr.Header().Rrtype

Expand All @@ -671,17 +669,17 @@ func rrToMapStr(rr mkdns.RR, ipList bool) (mapstr.M, []string) {
switch x := rr.(type) {
default:
// We don't have special handling for this type
debugf("No special handling for RR type %s", dnsTypeToString(rrType))
logger.Debugf("No special handling for RR type %s", dnsTypeToString(rrType))
unsupportedRR := new(mkdns.RFC3597)
err := unsupportedRR.ToRFC3597(x)
if err == nil {
rData, err := hexStringToString(unsupportedRR.Rdata)
mapStr["data"] = rData
if err != nil {
debugf("%s", err.Error())
logger.Debugf("%v", err)
}
} else {
debugf("Rdata for the unhandled RR type %s could not be fetched", dnsTypeToString(rrType))
logger.Debugf("Rdata for the unhandled RR type %s could not be fetched", dnsTypeToString(rrType))
}

// Don't attempt to render IPs for answers that are incomplete.
Expand Down Expand Up @@ -735,11 +733,11 @@ func rrToMapStr(rr mkdns.RR, ipList bool) (mapstr.M, []string) {
mapStr["data"] = trimRightDot(x.Ptr)
case *mkdns.RFC3597:
// Miekg/dns lib doesn't handle this type
debugf("Unknown RR type %s", dnsTypeToString(rrType))
logger.Debugf("Unknown RR type %s", dnsTypeToString(rrType))
rData, err := hexStringToString(x.Rdata)
mapStr["data"] = rData
if err != nil {
debugf("%s", err.Error())
logger.Debugf("%v", err)
}
case *mkdns.RRSIG:
mapStr["type_covered"] = dnsTypeToString(x.TypeCovered)
Expand Down Expand Up @@ -781,16 +779,16 @@ func dnsQuestionToString(q mkdns.Question) string {

// rrsToString converts an array of RR's to a
// string.
func rrsToString(r []mkdns.RR) string {
func rrsToString(r []mkdns.RR, logger *logp.Logger) string {
var rrStrs []string
for _, rr := range r {
rrStrs = append(rrStrs, rrToString(rr))
rrStrs = append(rrStrs, rrToString(rr, logger))
}
return strings.Join(rrStrs, "; ")
}

// dnsToString converts a DNS message to a string.
func dnsToString(dns *mkdns.Msg) string {
func dnsToString(dns *mkdns.Msg, logger *logp.Logger) string {
var msgType string
if dns.Response {
msgType = "response"
Expand Down Expand Up @@ -834,17 +832,17 @@ func dnsToString(dns *mkdns.Msg) string {

if len(dns.Answer) > 0 {
a = append(a, fmt.Sprintf("ANSWER %s",
rrsToString(dns.Answer)))
rrsToString(dns.Answer, logger)))
}

if len(dns.Ns) > 0 {
a = append(a, fmt.Sprintf("AUTHORITY %s",
rrsToString(dns.Ns)))
rrsToString(dns.Ns, logger)))
}

if len(dns.Extra) > 0 {
a = append(a, fmt.Sprintf("ADDITIONAL %s",
rrsToString(dns.Extra)))
rrsToString(dns.Extra, logger)))
}

return strings.Join(a, "; ")
Expand Down
Loading

0 comments on commit d39531e

Please sign in to comment.