From 3a4f2f2067350b5853a46d175ff6fc71466788b9 Mon Sep 17 00:00:00 2001 From: Dave Pifke Date: Mon, 8 Jan 2024 18:57:19 -0700 Subject: [PATCH 1/2] Allow use of fs.FS for $INCLUDE and wrap errors This adds ZoneParser.SetIncludeAllowedFS, to specify an fs.FS when enabling support for $INCLUDE, for reading included files from somewhere other than the local filesystem. I've also modified ParseError to support wrapping another error, such as errors encountered while opening the $INCLUDE target. This allows for much more robust handling, using errors.Is() instead of testing for particular strings (which may not be identical between fs.FS implementations). ParseError was being constructed in a lot of places using positional instead of named members. Updating ParseError initialization after the new member field was added makes this change seem a lot larger than it actually is. The changes here should be completely backwards compatible. The ParseError change should be invisible to anyone not trying to unwrap it, and ZoneParser will continue to use os.Open if the existing SetIncludeAllowed method is called instead of the new SetIncludeAllowedFS method. --- dnssec_keyscan.go | 2 +- generate.go | 2 +- privaterr.go | 2 +- scan.go | 67 ++++++--- scan_rr.go | 359 +++++++++++++++++++++++----------------------- scan_test.go | 40 ++++++ svcb.go | 20 +-- 7 files changed, 279 insertions(+), 213 deletions(-) diff --git a/dnssec_keyscan.go b/dnssec_keyscan.go index 5e72249b5..9c9972db6 100644 --- a/dnssec_keyscan.go +++ b/dnssec_keyscan.go @@ -160,7 +160,7 @@ func parseKey(r io.Reader, file string) (map[string]string, error) { k = l.token case zValue: if k == "" { - return nil, &ParseError{file, "no private key seen", l} + return nil, &ParseError{file: file, err: "no private key seen", lex: l} } m[strings.ToLower(k)] = l.token diff --git a/generate.go b/generate.go index 713e9d2da..a81d2bc51 100644 --- a/generate.go +++ b/generate.go @@ -116,7 +116,7 @@ func (r *generateReader) parseError(msg string, end int) *ParseError { l.token = r.s[r.si-1 : end] l.column += r.si // l.column starts one zBLANK before r.s - return &ParseError{r.file, msg, l} + return &ParseError{file: r.file, err: msg, lex: l} } func (r *generateReader) Read(p []byte) (int, error) { diff --git a/privaterr.go b/privaterr.go index d256b652e..350ea5a47 100644 --- a/privaterr.go +++ b/privaterr.go @@ -84,7 +84,7 @@ Fetch: err := r.Data.Parse(text) if err != nil { - return &ParseError{"", err.Error(), l} + return &ParseError{wrappedErr: err, lex: l} } return nil diff --git a/scan.go b/scan.go index 062d8ff3a..776546802 100644 --- a/scan.go +++ b/scan.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "io" + "io/fs" "os" "path/filepath" "strconv" @@ -64,20 +65,26 @@ const ( // ParseError is a parsing error. It contains the parse error and the location in the io.Reader // where the error occurred. type ParseError struct { - file string - err string - lex lex + file string + err string + wrappedErr error + lex lex } func (e *ParseError) Error() (s string) { if e.file != "" { s = e.file + ": " } + if e.err == "" && e.wrappedErr != nil { + e.err = e.wrappedErr.Error() + } s += "dns: " + e.err + ": " + strconv.QuoteToASCII(e.lex.token) + " at line: " + strconv.Itoa(e.lex.line) + ":" + strconv.Itoa(e.lex.column) return } +func (e *ParseError) Unwrap() error { return e.wrappedErr } + type lex struct { token string // text of the token err bool // when true, token text has lexer error @@ -168,8 +175,9 @@ type ZoneParser struct { // sub is used to parse $INCLUDE files and $GENERATE directives. // Next, by calling subNext, forwards the resulting RRs from this // sub parser to the calling code. - sub *ZoneParser - osFile *os.File + sub *ZoneParser + r io.Reader + fsys fs.FS includeDepth uint8 @@ -188,7 +196,7 @@ func NewZoneParser(r io.Reader, origin, file string) *ZoneParser { if origin != "" { origin = Fqdn(origin) if _, ok := IsDomainName(origin); !ok { - pe = &ParseError{file, "bad initial origin name", lex{}} + pe = &ParseError{file: file, err: "bad initial origin name"} } } @@ -220,6 +228,12 @@ func (zp *ZoneParser) SetIncludeAllowed(v bool) { zp.includeAllowed = v } +// SetIncludeAllowedFS controls whether $INCLUDE directives are allowed, and +// specifies an [fs.FS] from which to open and read files. +func (zp *ZoneParser) SetIncludeAllowedFS(v bool, fsys fs.FS) { + zp.includeAllowed, zp.fsys = v, fsys +} + // Err returns the first non-EOF error that was encountered by the // ZoneParser. func (zp *ZoneParser) Err() error { @@ -237,7 +251,7 @@ func (zp *ZoneParser) Err() error { } func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) { - zp.parseErr = &ParseError{zp.file, err, l} + zp.parseErr = &ParseError{file: zp.file, err: err, lex: l} return nil, false } @@ -260,9 +274,11 @@ func (zp *ZoneParser) subNext() (RR, bool) { return rr, true } - if zp.sub.osFile != nil { - zp.sub.osFile.Close() - zp.sub.osFile = nil + if zp.sub.r != nil { + if c, ok := zp.sub.r.(io.Closer); ok { + c.Close() + } + zp.sub.r = nil } if zp.sub.Err() != nil { @@ -406,20 +422,29 @@ func (zp *ZoneParser) Next() (RR, bool) { includePath = filepath.Join(filepath.Dir(zp.file), includePath) } - r1, e1 := os.Open(includePath) + var r1 io.Reader + var e1 error + if zp.fsys != nil { + r1, e1 = zp.fsys.Open(includePath) + } else { + r1, e1 = os.Open(includePath) + } if e1 != nil { var as string if !filepath.IsAbs(l.token) { as = fmt.Sprintf(" as `%s'", includePath) } - - msg := fmt.Sprintf("failed to open `%s'%s: %v", l.token, as, e1) - return zp.setParseError(msg, l) + zp.parseErr = &ParseError{ + file: zp.file, + wrappedErr: fmt.Errorf("failed to open `%s'%s: %w", l.token, as, e1), + lex: l, + } + return nil, false } zp.sub = NewZoneParser(r1, neworigin, includePath) - zp.sub.defttl, zp.sub.includeDepth, zp.sub.osFile = zp.defttl, zp.includeDepth+1, r1 - zp.sub.SetIncludeAllowed(true) + zp.sub.defttl, zp.sub.includeDepth, zp.sub.r = zp.defttl, zp.includeDepth+1, r1 + zp.sub.SetIncludeAllowedFS(true, zp.fsys) return zp.subNext() case zExpectDirTTLBl: if l.value != zBlank { @@ -1326,12 +1351,12 @@ func slurpRemainder(c *zlexer) *ParseError { case zBlank: l, _ = c.Next() if l.value != zNewline && l.value != zEOF { - return &ParseError{"", "garbage after rdata", l} + return &ParseError{err: "garbage after rdata", lex: l} } case zNewline: case zEOF: default: - return &ParseError{"", "garbage after rdata", l} + return &ParseError{err: "garbage after rdata", lex: l} } return nil } @@ -1340,16 +1365,16 @@ func slurpRemainder(c *zlexer) *ParseError { // Used for NID and L64 record. func stringToNodeID(l lex) (uint64, *ParseError) { if len(l.token) < 19 { - return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l} } // There must be three colons at fixes positions, if not its a parse error if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' { - return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l} } s := l.token[0:4] + l.token[5:9] + l.token[10:14] + l.token[15:19] u, err := strconv.ParseUint(s, 16, 64) if err != nil { - return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + return 0, &ParseError{file: l.token, err: "bad NID/L64 NodeID/Locator64", lex: l} } return u, nil } diff --git a/scan_rr.go b/scan_rr.go index a635e1c5c..97899552a 100644 --- a/scan_rr.go +++ b/scan_rr.go @@ -3,6 +3,7 @@ package dns import ( "encoding/base64" "errors" + "fmt" "net" "strconv" "strings" @@ -15,14 +16,14 @@ func endingToString(c *zlexer, errstr string) (string, *ParseError) { l, _ := c.Next() // zString for l.value != zNewline && l.value != zEOF { if l.err { - return s.String(), &ParseError{"", errstr, l} + return s.String(), &ParseError{err: errstr, lex: l} } switch l.value { case zString: s.WriteString(l.token) case zBlank: // Ok default: - return "", &ParseError{"", errstr, l} + return "", &ParseError{err: errstr, lex: l} } l, _ = c.Next() } @@ -36,7 +37,7 @@ func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) { // Get the remaining data until we see a zNewline l, _ := c.Next() if l.err { - return nil, &ParseError{"", errstr, l} + return nil, &ParseError{err: errstr, lex: l} } // Build the slice @@ -45,7 +46,7 @@ func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) { empty := false for l.value != zNewline && l.value != zEOF { if l.err { - return nil, &ParseError{"", errstr, l} + return nil, &ParseError{err: errstr, lex: l} } switch l.value { case zString: @@ -72,7 +73,7 @@ func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) { case zBlank: if quote { // zBlank can only be seen in between txt parts. - return nil, &ParseError{"", errstr, l} + return nil, &ParseError{err: errstr, lex: l} } case zQuote: if empty && quote { @@ -81,13 +82,13 @@ func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) { quote = !quote empty = true default: - return nil, &ParseError{"", errstr, l} + return nil, &ParseError{err: errstr, lex: l} } l, _ = c.Next() } if quote { - return nil, &ParseError{"", errstr, l} + return nil, &ParseError{err: errstr, lex: l} } return s, nil @@ -102,7 +103,7 @@ func (rr *A) parse(c *zlexer, o string) *ParseError { // IPv4. isIPv4 := !strings.Contains(l.token, ":") if rr.A == nil || !isIPv4 || l.err { - return &ParseError{"", "bad A A", l} + return &ParseError{err: "bad A A", lex: l} } return slurpRemainder(c) } @@ -114,7 +115,7 @@ func (rr *AAAA) parse(c *zlexer, o string) *ParseError { // addresses cannot include ":". isIPv6 := strings.Contains(l.token, ":") if rr.AAAA == nil || !isIPv6 || l.err { - return &ParseError{"", "bad AAAA AAAA", l} + return &ParseError{err: "bad AAAA AAAA", lex: l} } return slurpRemainder(c) } @@ -123,7 +124,7 @@ func (rr *NS) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad NS Ns", l} + return &ParseError{err: "bad NS Ns", lex: l} } rr.Ns = name return slurpRemainder(c) @@ -133,7 +134,7 @@ func (rr *PTR) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad PTR Ptr", l} + return &ParseError{err: "bad PTR Ptr", lex: l} } rr.Ptr = name return slurpRemainder(c) @@ -143,7 +144,7 @@ func (rr *NSAPPTR) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad NSAP-PTR Ptr", l} + return &ParseError{err: "bad NSAP-PTR Ptr", lex: l} } rr.Ptr = name return slurpRemainder(c) @@ -153,7 +154,7 @@ func (rr *RP) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() mbox, mboxOk := toAbsoluteName(l.token, o) if l.err || !mboxOk { - return &ParseError{"", "bad RP Mbox", l} + return &ParseError{err: "bad RP Mbox", lex: l} } rr.Mbox = mbox @@ -163,7 +164,7 @@ func (rr *RP) parse(c *zlexer, o string) *ParseError { txt, txtOk := toAbsoluteName(l.token, o) if l.err || !txtOk { - return &ParseError{"", "bad RP Txt", l} + return &ParseError{err: "bad RP Txt", lex: l} } rr.Txt = txt @@ -174,7 +175,7 @@ func (rr *MR) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MR Mr", l} + return &ParseError{err: "bad MR Mr", lex: l} } rr.Mr = name return slurpRemainder(c) @@ -184,7 +185,7 @@ func (rr *MB) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MB Mb", l} + return &ParseError{err: "bad MB Mb", lex: l} } rr.Mb = name return slurpRemainder(c) @@ -194,7 +195,7 @@ func (rr *MG) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MG Mg", l} + return &ParseError{err: "bad MG Mg", lex: l} } rr.Mg = name return slurpRemainder(c) @@ -227,7 +228,7 @@ func (rr *MINFO) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() rmail, rmailOk := toAbsoluteName(l.token, o) if l.err || !rmailOk { - return &ParseError{"", "bad MINFO Rmail", l} + return &ParseError{err: "bad MINFO Rmail", lex: l} } rr.Rmail = rmail @@ -237,7 +238,7 @@ func (rr *MINFO) parse(c *zlexer, o string) *ParseError { email, emailOk := toAbsoluteName(l.token, o) if l.err || !emailOk { - return &ParseError{"", "bad MINFO Email", l} + return &ParseError{err: "bad MINFO Email", lex: l} } rr.Email = email @@ -248,7 +249,7 @@ func (rr *MF) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MF Mf", l} + return &ParseError{err: "bad MF Mf", lex: l} } rr.Mf = name return slurpRemainder(c) @@ -258,7 +259,7 @@ func (rr *MD) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MD Md", l} + return &ParseError{err: "bad MD Md", lex: l} } rr.Md = name return slurpRemainder(c) @@ -268,7 +269,7 @@ func (rr *MX) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad MX Pref", l} + return &ParseError{err: "bad MX Pref", lex: l} } rr.Preference = uint16(i) @@ -278,7 +279,7 @@ func (rr *MX) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad MX Mx", l} + return &ParseError{err: "bad MX Mx", lex: l} } rr.Mx = name @@ -289,7 +290,7 @@ func (rr *RT) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil { - return &ParseError{"", "bad RT Preference", l} + return &ParseError{err: "bad RT Preference", lex: l} } rr.Preference = uint16(i) @@ -299,7 +300,7 @@ func (rr *RT) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad RT Host", l} + return &ParseError{err: "bad RT Host", lex: l} } rr.Host = name @@ -310,7 +311,7 @@ func (rr *AFSDB) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad AFSDB Subtype", l} + return &ParseError{err: "bad AFSDB Subtype", lex: l} } rr.Subtype = uint16(i) @@ -320,7 +321,7 @@ func (rr *AFSDB) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad AFSDB Hostname", l} + return &ParseError{err: "bad AFSDB Hostname", lex: l} } rr.Hostname = name return slurpRemainder(c) @@ -329,7 +330,7 @@ func (rr *AFSDB) parse(c *zlexer, o string) *ParseError { func (rr *X25) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() if l.err { - return &ParseError{"", "bad X25 PSDNAddress", l} + return &ParseError{err: "bad X25 PSDNAddress", lex: l} } rr.PSDNAddress = l.token return slurpRemainder(c) @@ -339,7 +340,7 @@ func (rr *KX) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad KX Pref", l} + return &ParseError{err: "bad KX Pref", lex: l} } rr.Preference = uint16(i) @@ -349,7 +350,7 @@ func (rr *KX) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad KX Exchanger", l} + return &ParseError{err: "bad KX Exchanger", lex: l} } rr.Exchanger = name return slurpRemainder(c) @@ -359,7 +360,7 @@ func (rr *CNAME) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad CNAME Target", l} + return &ParseError{err: "bad CNAME Target", lex: l} } rr.Target = name return slurpRemainder(c) @@ -369,7 +370,7 @@ func (rr *DNAME) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad DNAME Target", l} + return &ParseError{err: "bad DNAME Target", lex: l} } rr.Target = name return slurpRemainder(c) @@ -379,7 +380,7 @@ func (rr *SOA) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() ns, nsOk := toAbsoluteName(l.token, o) if l.err || !nsOk { - return &ParseError{"", "bad SOA Ns", l} + return &ParseError{err: "bad SOA Ns", lex: l} } rr.Ns = ns @@ -389,7 +390,7 @@ func (rr *SOA) parse(c *zlexer, o string) *ParseError { mbox, mboxOk := toAbsoluteName(l.token, o) if l.err || !mboxOk { - return &ParseError{"", "bad SOA Mbox", l} + return &ParseError{err: "bad SOA Mbox", lex: l} } rr.Mbox = mbox @@ -402,16 +403,16 @@ func (rr *SOA) parse(c *zlexer, o string) *ParseError { for i := 0; i < 5; i++ { l, _ = c.Next() if l.err { - return &ParseError{"", "bad SOA zone parameter", l} + return &ParseError{err: "bad SOA zone parameter", lex: l} } if j, err := strconv.ParseUint(l.token, 10, 32); err != nil { if i == 0 { // Serial must be a number - return &ParseError{"", "bad SOA zone parameter", l} + return &ParseError{err: "bad SOA zone parameter", lex: l} } // We allow other fields to be unitful duration strings if v, ok = stringToTTL(l.token); !ok { - return &ParseError{"", "bad SOA zone parameter", l} + return &ParseError{err: "bad SOA zone parameter", lex: l} } } else { @@ -441,7 +442,7 @@ func (rr *SRV) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad SRV Priority", l} + return &ParseError{err: "bad SRV Priority", lex: l} } rr.Priority = uint16(i) @@ -449,7 +450,7 @@ func (rr *SRV) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() // zString i, e1 := strconv.ParseUint(l.token, 10, 16) if e1 != nil || l.err { - return &ParseError{"", "bad SRV Weight", l} + return &ParseError{err: "bad SRV Weight", lex: l} } rr.Weight = uint16(i) @@ -457,7 +458,7 @@ func (rr *SRV) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() // zString i, e2 := strconv.ParseUint(l.token, 10, 16) if e2 != nil || l.err { - return &ParseError{"", "bad SRV Port", l} + return &ParseError{err: "bad SRV Port", lex: l} } rr.Port = uint16(i) @@ -467,7 +468,7 @@ func (rr *SRV) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad SRV Target", l} + return &ParseError{err: "bad SRV Target", lex: l} } rr.Target = name return slurpRemainder(c) @@ -477,7 +478,7 @@ func (rr *NAPTR) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad NAPTR Order", l} + return &ParseError{err: "bad NAPTR Order", lex: l} } rr.Order = uint16(i) @@ -485,7 +486,7 @@ func (rr *NAPTR) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() // zString i, e1 := strconv.ParseUint(l.token, 10, 16) if e1 != nil || l.err { - return &ParseError{"", "bad NAPTR Preference", l} + return &ParseError{err: "bad NAPTR Preference", lex: l} } rr.Preference = uint16(i) @@ -493,57 +494,57 @@ func (rr *NAPTR) parse(c *zlexer, o string) *ParseError { c.Next() // zBlank l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Flags", l} + return &ParseError{err: "bad NAPTR Flags", lex: l} } l, _ = c.Next() // Either String or Quote if l.value == zString { rr.Flags = l.token l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Flags", l} + return &ParseError{err: "bad NAPTR Flags", lex: l} } } else if l.value == zQuote { rr.Flags = "" } else { - return &ParseError{"", "bad NAPTR Flags", l} + return &ParseError{err: "bad NAPTR Flags", lex: l} } // Service c.Next() // zBlank l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Service", l} + return &ParseError{err: "bad NAPTR Service", lex: l} } l, _ = c.Next() // Either String or Quote if l.value == zString { rr.Service = l.token l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Service", l} + return &ParseError{err: "bad NAPTR Service", lex: l} } } else if l.value == zQuote { rr.Service = "" } else { - return &ParseError{"", "bad NAPTR Service", l} + return &ParseError{err: "bad NAPTR Service", lex: l} } // Regexp c.Next() // zBlank l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Regexp", l} + return &ParseError{err: "bad NAPTR Regexp", lex: l} } l, _ = c.Next() // Either String or Quote if l.value == zString { rr.Regexp = l.token l, _ = c.Next() // _QUOTE if l.value != zQuote { - return &ParseError{"", "bad NAPTR Regexp", l} + return &ParseError{err: "bad NAPTR Regexp", lex: l} } } else if l.value == zQuote { rr.Regexp = "" } else { - return &ParseError{"", "bad NAPTR Regexp", l} + return &ParseError{err: "bad NAPTR Regexp", lex: l} } // After quote no space?? @@ -553,7 +554,7 @@ func (rr *NAPTR) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad NAPTR Replacement", l} + return &ParseError{err: "bad NAPTR Replacement", lex: l} } rr.Replacement = name return slurpRemainder(c) @@ -563,7 +564,7 @@ func (rr *TALINK) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() previousName, previousNameOk := toAbsoluteName(l.token, o) if l.err || !previousNameOk { - return &ParseError{"", "bad TALINK PreviousName", l} + return &ParseError{err: "bad TALINK PreviousName", lex: l} } rr.PreviousName = previousName @@ -573,7 +574,7 @@ func (rr *TALINK) parse(c *zlexer, o string) *ParseError { nextName, nextNameOk := toAbsoluteName(l.token, o) if l.err || !nextNameOk { - return &ParseError{"", "bad TALINK NextName", l} + return &ParseError{err: "bad TALINK NextName", lex: l} } rr.NextName = nextName @@ -591,7 +592,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 32) if e != nil || l.err || i > 90 { - return &ParseError{"", "bad LOC Latitude", l} + return &ParseError{err: "bad LOC Latitude", lex: l} } rr.Latitude = 1000 * 60 * 60 * uint32(i) @@ -602,7 +603,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError { goto East } if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 { - return &ParseError{"", "bad LOC Latitude minutes", l} + return &ParseError{err: "bad LOC Latitude minutes", lex: l} } else { rr.Latitude += 1000 * 60 * uint32(i) } @@ -610,7 +611,7 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError { c.Next() // zBlank l, _ = c.Next() if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 { - return &ParseError{"", "bad LOC Latitude seconds", l} + return &ParseError{err: "bad LOC Latitude seconds", lex: l} } else { rr.Latitude += uint32(1000 * i) } @@ -621,14 +622,14 @@ func (rr *LOC) parse(c *zlexer, o string) *ParseError { goto East } // If still alive, flag an error - return &ParseError{"", "bad LOC Latitude North/South", l} + return &ParseError{err: "bad LOC Latitude North/South", lex: l} East: // East c.Next() // zBlank l, _ = c.Next() if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 180 { - return &ParseError{"", "bad LOC Longitude", l} + return &ParseError{err: "bad LOC Longitude", lex: l} } else { rr.Longitude = 1000 * 60 * 60 * uint32(i) } @@ -639,14 +640,14 @@ East: goto Altitude } if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 { - return &ParseError{"", "bad LOC Longitude minutes", l} + return &ParseError{err: "bad LOC Longitude minutes", lex: l} } else { rr.Longitude += 1000 * 60 * uint32(i) } c.Next() // zBlank l, _ = c.Next() if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 { - return &ParseError{"", "bad LOC Longitude seconds", l} + return &ParseError{err: "bad LOC Longitude seconds", lex: l} } else { rr.Longitude += uint32(1000 * i) } @@ -657,19 +658,19 @@ East: goto Altitude } // If still alive, flag an error - return &ParseError{"", "bad LOC Longitude East/West", l} + return &ParseError{err: "bad LOC Longitude East/West", lex: l} Altitude: c.Next() // zBlank l, _ = c.Next() if l.token == "" || l.err { - return &ParseError{"", "bad LOC Altitude", l} + return &ParseError{err: "bad LOC Altitude", lex: l} } if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' { l.token = l.token[0 : len(l.token)-1] } if i, err := strconv.ParseFloat(l.token, 64); err != nil { - return &ParseError{"", "bad LOC Altitude", l} + return &ParseError{err: "bad LOC Altitude", lex: l} } else { rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5) } @@ -684,19 +685,19 @@ Altitude: case 0: // Size exp, m, ok := stringToCm(l.token) if !ok { - return &ParseError{"", "bad LOC Size", l} + return &ParseError{err: "bad LOC Size", lex: l} } rr.Size = exp&0x0f | m<<4&0xf0 case 1: // HorizPre exp, m, ok := stringToCm(l.token) if !ok { - return &ParseError{"", "bad LOC HorizPre", l} + return &ParseError{err: "bad LOC HorizPre", lex: l} } rr.HorizPre = exp&0x0f | m<<4&0xf0 case 2: // VertPre exp, m, ok := stringToCm(l.token) if !ok { - return &ParseError{"", "bad LOC VertPre", l} + return &ParseError{err: "bad LOC VertPre", lex: l} } rr.VertPre = exp&0x0f | m<<4&0xf0 } @@ -704,7 +705,7 @@ Altitude: case zBlank: // Ok default: - return &ParseError{"", "bad LOC Size, HorizPre or VertPre", l} + return &ParseError{err: "bad LOC Size, HorizPre or VertPre", lex: l} } l, _ = c.Next() } @@ -716,14 +717,14 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad HIP PublicKeyAlgorithm", l} + return &ParseError{err: "bad HIP PublicKeyAlgorithm", lex: l} } rr.PublicKeyAlgorithm = uint8(i) c.Next() // zBlank l, _ = c.Next() // zString if l.token == "" || l.err { - return &ParseError{"", "bad HIP Hit", l} + return &ParseError{err: "bad HIP Hit", lex: l} } rr.Hit = l.token // This can not contain spaces, see RFC 5205 Section 6. rr.HitLength = uint8(len(rr.Hit)) / 2 @@ -731,12 +732,12 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError { c.Next() // zBlank l, _ = c.Next() // zString if l.token == "" || l.err { - return &ParseError{"", "bad HIP PublicKey", l} + return &ParseError{err: "bad HIP PublicKey", lex: l} } rr.PublicKey = l.token // This cannot contain spaces decodedPK, decodedPKerr := base64.StdEncoding.DecodeString(rr.PublicKey) if decodedPKerr != nil { - return &ParseError{"", "bad HIP PublicKey", l} + return &ParseError{err: "bad HIP PublicKey", lex: l} } rr.PublicKeyLength = uint16(len(decodedPK)) @@ -748,13 +749,13 @@ func (rr *HIP) parse(c *zlexer, o string) *ParseError { case zString: name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad HIP RendezvousServers", l} + return &ParseError{err: "bad HIP RendezvousServers", lex: l} } xs = append(xs, name) case zBlank: // Ok default: - return &ParseError{"", "bad HIP RendezvousServers", l} + return &ParseError{err: "bad HIP RendezvousServers", lex: l} } l, _ = c.Next() } @@ -768,7 +769,7 @@ func (rr *CERT) parse(c *zlexer, o string) *ParseError { if v, ok := StringToCertType[l.token]; ok { rr.Type = v } else if i, err := strconv.ParseUint(l.token, 10, 16); err != nil { - return &ParseError{"", "bad CERT Type", l} + return &ParseError{err: "bad CERT Type", lex: l} } else { rr.Type = uint16(i) } @@ -776,7 +777,7 @@ func (rr *CERT) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() // zString i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad CERT KeyTag", l} + return &ParseError{err: "bad CERT KeyTag", lex: l} } rr.KeyTag = uint16(i) c.Next() // zBlank @@ -784,7 +785,7 @@ func (rr *CERT) parse(c *zlexer, o string) *ParseError { if v, ok := StringToAlgorithm[l.token]; ok { rr.Algorithm = v } else if i, err := strconv.ParseUint(l.token, 10, 8); err != nil { - return &ParseError{"", "bad CERT Algorithm", l} + return &ParseError{err: "bad CERT Algorithm", lex: l} } else { rr.Algorithm = uint8(i) } @@ -810,7 +811,7 @@ func (rr *CSYNC) parse(c *zlexer, o string) *ParseError { j, e := strconv.ParseUint(l.token, 10, 32) if e != nil { // Serial must be a number - return &ParseError{"", "bad CSYNC serial", l} + return &ParseError{err: "bad CSYNC serial", lex: l} } rr.Serial = uint32(j) @@ -820,7 +821,7 @@ func (rr *CSYNC) parse(c *zlexer, o string) *ParseError { j, e1 := strconv.ParseUint(l.token, 10, 16) if e1 != nil { // Serial must be a number - return &ParseError{"", "bad CSYNC flags", l} + return &ParseError{err: "bad CSYNC flags", lex: l} } rr.Flags = uint16(j) @@ -838,12 +839,12 @@ func (rr *CSYNC) parse(c *zlexer, o string) *ParseError { tokenUpper := strings.ToUpper(l.token) if k, ok = StringToType[tokenUpper]; !ok { if k, ok = typeToInt(l.token); !ok { - return &ParseError{"", "bad CSYNC TypeBitMap", l} + return &ParseError{err: "bad CSYNC TypeBitMap", lex: l} } } rr.TypeBitMap = append(rr.TypeBitMap, k) default: - return &ParseError{"", "bad CSYNC TypeBitMap", l} + return &ParseError{err: "bad CSYNC TypeBitMap", lex: l} } l, _ = c.Next() } @@ -854,7 +855,7 @@ func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 32) if e != nil || l.err { - return &ParseError{"", "bad ZONEMD Serial", l} + return &ParseError{err: "bad ZONEMD Serial", lex: l} } rr.Serial = uint32(i) @@ -862,7 +863,7 @@ func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad ZONEMD Scheme", l} + return &ParseError{err: "bad ZONEMD Scheme", lex: l} } rr.Scheme = uint8(i) @@ -870,7 +871,7 @@ func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, err := strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad ZONEMD Hash Algorithm", l} + return &ParseError{err: "bad ZONEMD Hash Algorithm", lex: l} } rr.Hash = uint8(i) @@ -891,11 +892,11 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { if strings.HasPrefix(tokenUpper, "TYPE") { t, ok = typeToInt(l.token) if !ok { - return &ParseError{"", "bad RRSIG Typecovered", l} + return &ParseError{err: "bad RRSIG Typecovered", lex: l} } rr.TypeCovered = t } else { - return &ParseError{"", "bad RRSIG Typecovered", l} + return &ParseError{err: "bad RRSIG Typecovered", lex: l} } } else { rr.TypeCovered = t @@ -904,14 +905,14 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { c.Next() // zBlank l, _ = c.Next() if l.err { - return &ParseError{"", "bad RRSIG Algorithm", l} + return &ParseError{err: "bad RRSIG Algorithm", lex: l} } i, e := strconv.ParseUint(l.token, 10, 8) rr.Algorithm = uint8(i) // if 0 we'll check the mnemonic in the if if e != nil { v, ok := StringToAlgorithm[l.token] if !ok { - return &ParseError{"", "bad RRSIG Algorithm", l} + return &ParseError{err: "bad RRSIG Algorithm", lex: l} } rr.Algorithm = v } @@ -920,7 +921,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad RRSIG Labels", l} + return &ParseError{err: "bad RRSIG Labels", lex: l} } rr.Labels = uint8(i) @@ -928,7 +929,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e2 := strconv.ParseUint(l.token, 10, 32) if e2 != nil || l.err { - return &ParseError{"", "bad RRSIG OrigTtl", l} + return &ParseError{err: "bad RRSIG OrigTtl", lex: l} } rr.OrigTtl = uint32(i) @@ -939,7 +940,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { if i, err := strconv.ParseUint(l.token, 10, 32); err == nil { rr.Expiration = uint32(i) } else { - return &ParseError{"", "bad RRSIG Expiration", l} + return &ParseError{err: "bad RRSIG Expiration", lex: l} } } else { rr.Expiration = i @@ -951,7 +952,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { if i, err := strconv.ParseUint(l.token, 10, 32); err == nil { rr.Inception = uint32(i) } else { - return &ParseError{"", "bad RRSIG Inception", l} + return &ParseError{err: "bad RRSIG Inception", lex: l} } } else { rr.Inception = i @@ -961,7 +962,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e3 := strconv.ParseUint(l.token, 10, 16) if e3 != nil || l.err { - return &ParseError{"", "bad RRSIG KeyTag", l} + return &ParseError{err: "bad RRSIG KeyTag", lex: l} } rr.KeyTag = uint16(i) @@ -970,7 +971,7 @@ func (rr *RRSIG) parse(c *zlexer, o string) *ParseError { rr.SignerName = l.token name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad RRSIG SignerName", l} + return &ParseError{err: "bad RRSIG SignerName", lex: l} } rr.SignerName = name @@ -987,7 +988,7 @@ func (rr *NSEC) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad NSEC NextDomain", l} + return &ParseError{err: "bad NSEC NextDomain", lex: l} } rr.NextDomain = name @@ -1005,12 +1006,12 @@ func (rr *NSEC) parse(c *zlexer, o string) *ParseError { tokenUpper := strings.ToUpper(l.token) if k, ok = StringToType[tokenUpper]; !ok { if k, ok = typeToInt(l.token); !ok { - return &ParseError{"", "bad NSEC TypeBitMap", l} + return &ParseError{err: "bad NSEC TypeBitMap", lex: l} } } rr.TypeBitMap = append(rr.TypeBitMap, k) default: - return &ParseError{"", "bad NSEC TypeBitMap", l} + return &ParseError{err: "bad NSEC TypeBitMap", lex: l} } l, _ = c.Next() } @@ -1021,27 +1022,27 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad NSEC3 Hash", l} + return &ParseError{err: "bad NSEC3 Hash", lex: l} } rr.Hash = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad NSEC3 Flags", l} + return &ParseError{err: "bad NSEC3 Flags", lex: l} } rr.Flags = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e2 := strconv.ParseUint(l.token, 10, 16) if e2 != nil || l.err { - return &ParseError{"", "bad NSEC3 Iterations", l} + return &ParseError{err: "bad NSEC3 Iterations", lex: l} } rr.Iterations = uint16(i) c.Next() l, _ = c.Next() if l.token == "" || l.err { - return &ParseError{"", "bad NSEC3 Salt", l} + return &ParseError{err: "bad NSEC3 Salt", lex: l} } if l.token != "-" { rr.SaltLength = uint8(len(l.token)) / 2 @@ -1051,7 +1052,7 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError { c.Next() l, _ = c.Next() if l.token == "" || l.err { - return &ParseError{"", "bad NSEC3 NextDomain", l} + return &ParseError{err: "bad NSEC3 NextDomain", lex: l} } rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits) rr.NextDomain = l.token @@ -1070,12 +1071,12 @@ func (rr *NSEC3) parse(c *zlexer, o string) *ParseError { tokenUpper := strings.ToUpper(l.token) if k, ok = StringToType[tokenUpper]; !ok { if k, ok = typeToInt(l.token); !ok { - return &ParseError{"", "bad NSEC3 TypeBitMap", l} + return &ParseError{err: "bad NSEC3 TypeBitMap", lex: l} } } rr.TypeBitMap = append(rr.TypeBitMap, k) default: - return &ParseError{"", "bad NSEC3 TypeBitMap", l} + return &ParseError{err: "bad NSEC3 TypeBitMap", lex: l} } l, _ = c.Next() } @@ -1086,21 +1087,21 @@ func (rr *NSEC3PARAM) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad NSEC3PARAM Hash", l} + return &ParseError{err: "bad NSEC3PARAM Hash", lex: l} } rr.Hash = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad NSEC3PARAM Flags", l} + return &ParseError{err: "bad NSEC3PARAM Flags", lex: l} } rr.Flags = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e2 := strconv.ParseUint(l.token, 10, 16) if e2 != nil || l.err { - return &ParseError{"", "bad NSEC3PARAM Iterations", l} + return &ParseError{err: "bad NSEC3PARAM Iterations", lex: l} } rr.Iterations = uint16(i) c.Next() @@ -1115,7 +1116,7 @@ func (rr *NSEC3PARAM) parse(c *zlexer, o string) *ParseError { func (rr *EUI48) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() if len(l.token) != 17 || l.err { - return &ParseError{"", "bad EUI48 Address", l} + return &ParseError{err: "bad EUI48 Address", lex: l} } addr := make([]byte, 12) dash := 0 @@ -1124,7 +1125,7 @@ func (rr *EUI48) parse(c *zlexer, o string) *ParseError { addr[i+1] = l.token[i+1+dash] dash++ if l.token[i+1+dash] != '-' { - return &ParseError{"", "bad EUI48 Address", l} + return &ParseError{err: "bad EUI48 Address", lex: l} } } addr[10] = l.token[15] @@ -1132,7 +1133,7 @@ func (rr *EUI48) parse(c *zlexer, o string) *ParseError { i, e := strconv.ParseUint(string(addr), 16, 48) if e != nil { - return &ParseError{"", "bad EUI48 Address", l} + return &ParseError{err: "bad EUI48 Address", lex: l} } rr.Address = i return slurpRemainder(c) @@ -1141,7 +1142,7 @@ func (rr *EUI48) parse(c *zlexer, o string) *ParseError { func (rr *EUI64) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() if len(l.token) != 23 || l.err { - return &ParseError{"", "bad EUI64 Address", l} + return &ParseError{err: "bad EUI64 Address", lex: l} } addr := make([]byte, 16) dash := 0 @@ -1150,7 +1151,7 @@ func (rr *EUI64) parse(c *zlexer, o string) *ParseError { addr[i+1] = l.token[i+1+dash] dash++ if l.token[i+1+dash] != '-' { - return &ParseError{"", "bad EUI64 Address", l} + return &ParseError{err: "bad EUI64 Address", lex: l} } } addr[14] = l.token[21] @@ -1158,7 +1159,7 @@ func (rr *EUI64) parse(c *zlexer, o string) *ParseError { i, e := strconv.ParseUint(string(addr), 16, 64) if e != nil { - return &ParseError{"", "bad EUI68 Address", l} + return &ParseError{err: "bad EUI68 Address", lex: l} } rr.Address = i return slurpRemainder(c) @@ -1168,14 +1169,14 @@ func (rr *SSHFP) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad SSHFP Algorithm", l} + return &ParseError{err: "bad SSHFP Algorithm", lex: l} } rr.Algorithm = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad SSHFP Type", l} + return &ParseError{err: "bad SSHFP Type", lex: l} } rr.Type = uint8(i) c.Next() // zBlank @@ -1191,21 +1192,21 @@ func (rr *DNSKEY) parseDNSKEY(c *zlexer, o, typ string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad " + typ + " Flags", l} + return &ParseError{err: "bad " + typ + " Flags", lex: l} } rr.Flags = uint16(i) c.Next() // zBlank l, _ = c.Next() // zString i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad " + typ + " Protocol", l} + return &ParseError{err: "bad " + typ + " Protocol", lex: l} } rr.Protocol = uint8(i) c.Next() // zBlank l, _ = c.Next() // zString i, e2 := strconv.ParseUint(l.token, 10, 8) if e2 != nil || l.err { - return &ParseError{"", "bad " + typ + " Algorithm", l} + return &ParseError{err: "bad " + typ + " Algorithm", lex: l} } rr.Algorithm = uint8(i) s, e3 := endingToString(c, "bad "+typ+" PublicKey") @@ -1227,7 +1228,7 @@ func (rr *IPSECKEY) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() num, err := strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad IPSECKEY value", l} + return &ParseError{err: "bad IPSECKEY value", lex: l} } rr.Precedence = uint8(num) c.Next() // zBlank @@ -1235,7 +1236,7 @@ func (rr *IPSECKEY) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() num, err = strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad IPSECKEY value", l} + return &ParseError{err: "bad IPSECKEY value", lex: l} } rr.GatewayType = uint8(num) c.Next() // zBlank @@ -1243,19 +1244,19 @@ func (rr *IPSECKEY) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() num, err = strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad IPSECKEY value", l} + return &ParseError{err: "bad IPSECKEY value", lex: l} } rr.Algorithm = uint8(num) c.Next() // zBlank l, _ = c.Next() if l.err { - return &ParseError{"", "bad IPSECKEY gateway", l} + return &ParseError{err: "bad IPSECKEY gateway", lex: l} } rr.GatewayAddr, rr.GatewayHost, err = parseAddrHostUnion(l.token, o, rr.GatewayType) if err != nil { - return &ParseError{"", "IPSECKEY " + err.Error(), l} + return &ParseError{wrappedErr: fmt.Errorf("IPSECKEY %w", err), lex: l} } c.Next() // zBlank @@ -1272,14 +1273,14 @@ func (rr *AMTRELAY) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() num, err := strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad AMTRELAY value", l} + return &ParseError{err: "bad AMTRELAY value", lex: l} } rr.Precedence = uint8(num) c.Next() // zBlank l, _ = c.Next() if l.err || !(l.token == "0" || l.token == "1") { - return &ParseError{"", "bad discovery value", l} + return &ParseError{err: "bad discovery value", lex: l} } if l.token == "1" { rr.GatewayType = 0x80 @@ -1290,19 +1291,19 @@ func (rr *AMTRELAY) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() num, err = strconv.ParseUint(l.token, 10, 8) if err != nil || l.err { - return &ParseError{"", "bad AMTRELAY value", l} + return &ParseError{err: "bad AMTRELAY value", lex: l} } rr.GatewayType |= uint8(num) c.Next() // zBlank l, _ = c.Next() if l.err { - return &ParseError{"", "bad AMTRELAY gateway", l} + return &ParseError{err: "bad AMTRELAY gateway", lex: l} } rr.GatewayAddr, rr.GatewayHost, err = parseAddrHostUnion(l.token, o, rr.GatewayType&0x7f) if err != nil { - return &ParseError{"", "AMTRELAY " + err.Error(), l} + return &ParseError{wrappedErr: fmt.Errorf("AMTRELAY %w", err), lex: l} } return slurpRemainder(c) @@ -1338,21 +1339,21 @@ func (rr *RKEY) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad RKEY Flags", l} + return &ParseError{err: "bad RKEY Flags", lex: l} } rr.Flags = uint16(i) c.Next() // zBlank l, _ = c.Next() // zString i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad RKEY Protocol", l} + return &ParseError{err: "bad RKEY Protocol", lex: l} } rr.Protocol = uint8(i) c.Next() // zBlank l, _ = c.Next() // zString i, e2 := strconv.ParseUint(l.token, 10, 8) if e2 != nil || l.err { - return &ParseError{"", "bad RKEY Algorithm", l} + return &ParseError{err: "bad RKEY Algorithm", lex: l} } rr.Algorithm = uint8(i) s, e3 := endingToString(c, "bad RKEY PublicKey") @@ -1385,21 +1386,21 @@ func (rr *GPOS) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() _, e := strconv.ParseFloat(l.token, 64) if e != nil || l.err { - return &ParseError{"", "bad GPOS Longitude", l} + return &ParseError{err: "bad GPOS Longitude", lex: l} } rr.Longitude = l.token c.Next() // zBlank l, _ = c.Next() _, e1 := strconv.ParseFloat(l.token, 64) if e1 != nil || l.err { - return &ParseError{"", "bad GPOS Latitude", l} + return &ParseError{err: "bad GPOS Latitude", lex: l} } rr.Latitude = l.token c.Next() // zBlank l, _ = c.Next() _, e2 := strconv.ParseFloat(l.token, 64) if e2 != nil || l.err { - return &ParseError{"", "bad GPOS Altitude", l} + return &ParseError{err: "bad GPOS Altitude", lex: l} } rr.Altitude = l.token return slurpRemainder(c) @@ -1409,7 +1410,7 @@ func (rr *DS) parseDS(c *zlexer, o, typ string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad " + typ + " KeyTag", l} + return &ParseError{err: "bad " + typ + " KeyTag", lex: l} } rr.KeyTag = uint16(i) c.Next() // zBlank @@ -1418,7 +1419,7 @@ func (rr *DS) parseDS(c *zlexer, o, typ string) *ParseError { tokenUpper := strings.ToUpper(l.token) i, ok := StringToAlgorithm[tokenUpper] if !ok || l.err { - return &ParseError{"", "bad " + typ + " Algorithm", l} + return &ParseError{err: "bad " + typ + " Algorithm", lex: l} } rr.Algorithm = i } else { @@ -1428,7 +1429,7 @@ func (rr *DS) parseDS(c *zlexer, o, typ string) *ParseError { l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad " + typ + " DigestType", l} + return &ParseError{err: "bad " + typ + " DigestType", lex: l} } rr.DigestType = uint8(i) s, e2 := endingToString(c, "bad "+typ+" Digest") @@ -1443,7 +1444,7 @@ func (rr *TA) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad TA KeyTag", l} + return &ParseError{err: "bad TA KeyTag", lex: l} } rr.KeyTag = uint16(i) c.Next() // zBlank @@ -1452,7 +1453,7 @@ func (rr *TA) parse(c *zlexer, o string) *ParseError { tokenUpper := strings.ToUpper(l.token) i, ok := StringToAlgorithm[tokenUpper] if !ok || l.err { - return &ParseError{"", "bad TA Algorithm", l} + return &ParseError{err: "bad TA Algorithm", lex: l} } rr.Algorithm = i } else { @@ -1462,7 +1463,7 @@ func (rr *TA) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad TA DigestType", l} + return &ParseError{err: "bad TA DigestType", lex: l} } rr.DigestType = uint8(i) s, e2 := endingToString(c, "bad TA Digest") @@ -1477,21 +1478,21 @@ func (rr *TLSA) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad TLSA Usage", l} + return &ParseError{err: "bad TLSA Usage", lex: l} } rr.Usage = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad TLSA Selector", l} + return &ParseError{err: "bad TLSA Selector", lex: l} } rr.Selector = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e2 := strconv.ParseUint(l.token, 10, 8) if e2 != nil || l.err { - return &ParseError{"", "bad TLSA MatchingType", l} + return &ParseError{err: "bad TLSA MatchingType", lex: l} } rr.MatchingType = uint8(i) // So this needs be e2 (i.e. different than e), because...??t @@ -1507,21 +1508,21 @@ func (rr *SMIMEA) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad SMIMEA Usage", l} + return &ParseError{err: "bad SMIMEA Usage", lex: l} } rr.Usage = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad SMIMEA Selector", l} + return &ParseError{err: "bad SMIMEA Selector", lex: l} } rr.Selector = uint8(i) c.Next() // zBlank l, _ = c.Next() i, e2 := strconv.ParseUint(l.token, 10, 8) if e2 != nil || l.err { - return &ParseError{"", "bad SMIMEA MatchingType", l} + return &ParseError{err: "bad SMIMEA MatchingType", lex: l} } rr.MatchingType = uint8(i) // So this needs be e2 (i.e. different than e), because...??t @@ -1536,14 +1537,14 @@ func (rr *SMIMEA) parse(c *zlexer, o string) *ParseError { func (rr *RFC3597) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() if l.token != "\\#" { - return &ParseError{"", "bad RFC3597 Rdata", l} + return &ParseError{err: "bad RFC3597 Rdata", lex: l} } c.Next() // zBlank l, _ = c.Next() rdlength, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad RFC3597 Rdata ", l} + return &ParseError{err: "bad RFC3597 Rdata ", lex: l} } s, e1 := endingToString(c, "bad RFC3597 Rdata") @@ -1551,7 +1552,7 @@ func (rr *RFC3597) parse(c *zlexer, o string) *ParseError { return e1 } if int(rdlength)*2 != len(s) { - return &ParseError{"", "bad RFC3597 Rdata", l} + return &ParseError{err: "bad RFC3597 Rdata", lex: l} } rr.Rdata = s return nil @@ -1599,14 +1600,14 @@ func (rr *URI) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad URI Priority", l} + return &ParseError{err: "bad URI Priority", lex: l} } rr.Priority = uint16(i) c.Next() // zBlank l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 16) if e1 != nil || l.err { - return &ParseError{"", "bad URI Weight", l} + return &ParseError{err: "bad URI Weight", lex: l} } rr.Weight = uint16(i) @@ -1616,7 +1617,7 @@ func (rr *URI) parse(c *zlexer, o string) *ParseError { return e2 } if len(s) != 1 { - return &ParseError{"", "bad URI Target", l} + return &ParseError{err: "bad URI Target", lex: l} } rr.Target = s[0] return nil @@ -1636,7 +1637,7 @@ func (rr *NID) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad NID Preference", l} + return &ParseError{err: "bad NID Preference", lex: l} } rr.Preference = uint16(i) c.Next() // zBlank @@ -1653,14 +1654,14 @@ func (rr *L32) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad L32 Preference", l} + return &ParseError{err: "bad L32 Preference", lex: l} } rr.Preference = uint16(i) c.Next() // zBlank l, _ = c.Next() // zString rr.Locator32 = net.ParseIP(l.token) if rr.Locator32 == nil || l.err { - return &ParseError{"", "bad L32 Locator", l} + return &ParseError{err: "bad L32 Locator", lex: l} } return slurpRemainder(c) } @@ -1669,7 +1670,7 @@ func (rr *LP) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad LP Preference", l} + return &ParseError{err: "bad LP Preference", lex: l} } rr.Preference = uint16(i) @@ -1678,7 +1679,7 @@ func (rr *LP) parse(c *zlexer, o string) *ParseError { rr.Fqdn = l.token name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{"", "bad LP Fqdn", l} + return &ParseError{err: "bad LP Fqdn", lex: l} } rr.Fqdn = name return slurpRemainder(c) @@ -1688,7 +1689,7 @@ func (rr *L64) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad L64 Preference", l} + return &ParseError{err: "bad L64 Preference", lex: l} } rr.Preference = uint16(i) c.Next() // zBlank @@ -1705,7 +1706,7 @@ func (rr *UID) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 32) if e != nil || l.err { - return &ParseError{"", "bad UID Uid", l} + return &ParseError{err: "bad UID Uid", lex: l} } rr.Uid = uint32(i) return slurpRemainder(c) @@ -1715,7 +1716,7 @@ func (rr *GID) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 32) if e != nil || l.err { - return &ParseError{"", "bad GID Gid", l} + return &ParseError{err: "bad GID Gid", lex: l} } rr.Gid = uint32(i) return slurpRemainder(c) @@ -1737,7 +1738,7 @@ func (rr *PX) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{"", "bad PX Preference", l} + return &ParseError{err: "bad PX Preference", lex: l} } rr.Preference = uint16(i) @@ -1746,7 +1747,7 @@ func (rr *PX) parse(c *zlexer, o string) *ParseError { rr.Map822 = l.token map822, map822Ok := toAbsoluteName(l.token, o) if l.err || !map822Ok { - return &ParseError{"", "bad PX Map822", l} + return &ParseError{err: "bad PX Map822", lex: l} } rr.Map822 = map822 @@ -1755,7 +1756,7 @@ func (rr *PX) parse(c *zlexer, o string) *ParseError { rr.Mapx400 = l.token mapx400, mapx400Ok := toAbsoluteName(l.token, o) if l.err || !mapx400Ok { - return &ParseError{"", "bad PX Mapx400", l} + return &ParseError{err: "bad PX Mapx400", lex: l} } rr.Mapx400 = mapx400 return slurpRemainder(c) @@ -1765,14 +1766,14 @@ func (rr *CAA) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad CAA Flag", l} + return &ParseError{err: "bad CAA Flag", lex: l} } rr.Flag = uint8(i) c.Next() // zBlank l, _ = c.Next() // zString if l.value != zString { - return &ParseError{"", "bad CAA Tag", l} + return &ParseError{err: "bad CAA Tag", lex: l} } rr.Tag = l.token @@ -1782,7 +1783,7 @@ func (rr *CAA) parse(c *zlexer, o string) *ParseError { return e1 } if len(s) != 1 { - return &ParseError{"", "bad CAA Value", l} + return &ParseError{err: "bad CAA Value", lex: l} } rr.Value = s[0] return nil @@ -1793,7 +1794,7 @@ func (rr *TKEY) parse(c *zlexer, o string) *ParseError { // Algorithm if l.value != zString { - return &ParseError{"", "bad TKEY algorithm", l} + return &ParseError{err: "bad TKEY algorithm", lex: l} } rr.Algorithm = l.token c.Next() // zBlank @@ -1802,13 +1803,13 @@ func (rr *TKEY) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e := strconv.ParseUint(l.token, 10, 8) if e != nil || l.err { - return &ParseError{"", "bad TKEY key length", l} + return &ParseError{err: "bad TKEY key length", lex: l} } rr.KeySize = uint16(i) c.Next() // zBlank l, _ = c.Next() if l.value != zString { - return &ParseError{"", "bad TKEY key", l} + return &ParseError{err: "bad TKEY key", lex: l} } rr.Key = l.token c.Next() // zBlank @@ -1817,13 +1818,13 @@ func (rr *TKEY) parse(c *zlexer, o string) *ParseError { l, _ = c.Next() i, e1 := strconv.ParseUint(l.token, 10, 8) if e1 != nil || l.err { - return &ParseError{"", "bad TKEY otherdata length", l} + return &ParseError{err: "bad TKEY otherdata length", lex: l} } rr.OtherLen = uint16(i) c.Next() // zBlank l, _ = c.Next() if l.value != zString { - return &ParseError{"", "bad TKEY otherday", l} + return &ParseError{err: "bad TKEY otherday", lex: l} } rr.OtherData = l.token return nil @@ -1841,14 +1842,14 @@ func (rr *APL) parse(c *zlexer, o string) *ParseError { continue } if l.value != zString { - return &ParseError{"", "unexpected APL field", l} + return &ParseError{err: "unexpected APL field", lex: l} } // Expected format: [!]afi:address/prefix colon := strings.IndexByte(l.token, ':') if colon == -1 { - return &ParseError{"", "missing colon in APL field", l} + return &ParseError{err: "missing colon in APL field", lex: l} } family, cidr := l.token[:colon], l.token[colon+1:] @@ -1861,7 +1862,7 @@ func (rr *APL) parse(c *zlexer, o string) *ParseError { afi, e := strconv.ParseUint(family, 10, 16) if e != nil { - return &ParseError{"", "failed to parse APL family: " + e.Error(), l} + return &ParseError{wrappedErr: fmt.Errorf("failed to parse APL family: %w", e), lex: l} } var addrLen int switch afi { @@ -1870,19 +1871,19 @@ func (rr *APL) parse(c *zlexer, o string) *ParseError { case 2: addrLen = net.IPv6len default: - return &ParseError{"", "unrecognized APL family", l} + return &ParseError{err: "unrecognized APL family", lex: l} } ip, subnet, e1 := net.ParseCIDR(cidr) if e1 != nil { - return &ParseError{"", "failed to parse APL address: " + e1.Error(), l} + return &ParseError{wrappedErr: fmt.Errorf("failed to parse APL address: %w", e1), lex: l} } if !ip.Equal(subnet.IP) { - return &ParseError{"", "extra bits in APL address", l} + return &ParseError{err: "extra bits in APL address", lex: l} } if len(subnet.IP) != addrLen { - return &ParseError{"", "address mismatch with the APL family", l} + return &ParseError{err: "address mismatch with the APL family", lex: l} } prefixes = append(prefixes, APLPrefix{ diff --git a/scan_test.go b/scan_test.go index 218c9750b..fcc1e36b2 100644 --- a/scan_test.go +++ b/scan_test.go @@ -1,11 +1,14 @@ package dns import ( + "errors" "io" + "io/fs" "net" "os" "strings" "testing" + "testing/fstest" ) func TestZoneParserGenerate(t *testing.T) { @@ -96,6 +99,43 @@ func TestZoneParserInclude(t *testing.T) { } } +func TestZoneParserIncludeFS(t *testing.T) { + fsys := fstest.MapFS{ + "db.foo": &fstest.MapFile{ + Data: []byte("foo\tIN\tA\t127.0.0.1"), + }, + } + zone := "$ORIGIN example.org.\n$INCLUDE db.foo\nbar\tIN\tA\t127.0.0.2" + + var got int + z := NewZoneParser(strings.NewReader(zone), "", "") + z.SetIncludeAllowedFS(true, fsys) + for rr, ok := z.Next(); ok; _, ok = z.Next() { + switch rr.Header().Name { + case "foo.example.org.", "bar.example.org.": + default: + t.Fatalf("expected foo.example.org. or bar.example.org., but got %s", rr.Header().Name) + } + got++ + } + if err := z.Err(); err != nil { + t.Fatalf("expected no error, but got %s", err) + } + + if expected := 2; got != expected { + t.Errorf("failed to parse zone after include, expected %d records, got %d", expected, got) + } + + fsys = fstest.MapFS{} + + z = NewZoneParser(strings.NewReader(zone), "", "") + z.SetIncludeAllowedFS(true, fsys) + z.Next() + if err := z.Err(); !errors.Is(err, fs.ErrNotExist) { + t.Fatalf(`expected fs.ErrNotExist but got: %T %v`, err, err) + } +} + func TestZoneParserIncludeDisallowed(t *testing.T) { tmpfile, err := os.CreateTemp("", "dns") if err != nil { diff --git a/svcb.go b/svcb.go index d38aa2f05..c1a740b68 100644 --- a/svcb.go +++ b/svcb.go @@ -85,7 +85,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError { l, _ := c.Next() i, e := strconv.ParseUint(l.token, 10, 16) if e != nil || l.err { - return &ParseError{l.token, "bad SVCB priority", l} + return &ParseError{file: l.token, err: "bad SVCB priority", lex: l} } rr.Priority = uint16(i) @@ -95,7 +95,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError { name, nameOk := toAbsoluteName(l.token, o) if l.err || !nameOk { - return &ParseError{l.token, "bad SVCB Target", l} + return &ParseError{file: l.token, err: "bad SVCB Target", lex: l} } rr.Target = name @@ -111,7 +111,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError { if !canHaveNextKey { // The key we can now read was probably meant to be // a part of the last value. - return &ParseError{l.token, "bad SVCB value quotation", l} + return &ParseError{file: l.token, err: "bad SVCB value quotation", lex: l} } // In key=value pairs, value does not have to be quoted unless value @@ -124,7 +124,7 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError { // Key with no value and no equality sign key = l.token } else if idx == 0 { - return &ParseError{l.token, "bad SVCB key", l} + return &ParseError{file: l.token, err: "bad SVCB key", lex: l} } else { key, value = l.token[:idx], l.token[idx+1:] @@ -144,30 +144,30 @@ func (rr *SVCB) parse(c *zlexer, o string) *ParseError { value = l.token l, _ = c.Next() if l.value != zQuote { - return &ParseError{l.token, "SVCB unterminated value", l} + return &ParseError{file: l.token, err: "SVCB unterminated value", lex: l} } case zQuote: // There's nothing in double quotes. default: - return &ParseError{l.token, "bad SVCB value", l} + return &ParseError{file: l.token, err: "bad SVCB value", lex: l} } } } } kv := makeSVCBKeyValue(svcbStringToKey(key)) if kv == nil { - return &ParseError{l.token, "bad SVCB key", l} + return &ParseError{file: l.token, err: "bad SVCB key", lex: l} } if err := kv.parse(value); err != nil { - return &ParseError{l.token, err.Error(), l} + return &ParseError{file: l.token, wrappedErr: err, lex: l} } xs = append(xs, kv) case zQuote: - return &ParseError{l.token, "SVCB key can't contain double quotes", l} + return &ParseError{file: l.token, err: "SVCB key can't contain double quotes", lex: l} case zBlank: canHaveNextKey = true default: - return &ParseError{l.token, "bad SVCB values", l} + return &ParseError{file: l.token, err: "bad SVCB values", lex: l} } l, _ = c.Next() } From d59cf64ff29b3eef876c1192aad0b7cc44d50b69 Mon Sep 17 00:00:00 2001 From: Dave Pifke Date: Tue, 9 Jan 2024 10:11:01 -0700 Subject: [PATCH 2/2] Don't duplicate SetIncludeAllowed; clarify edge cases Rather than duplicate functionality between SetIncludeAllowed and SetIncludeAllowedFS, have a method SetIncludeFS, which only sets the fs.FS. I've improved the documentation to point out some considerations for users hoping to use fs.FS as a security boundary. Per the fs.ValidPath documentation, fs.FS implementations must use path (not filepath) semantics, with slash as a separator (even on Windows). Some, like os.DirFS, also require all paths to be relative. I've clarified this in the documentation, made the includePath manipulation more robust to edge cases, and added some additional tests for relative and absolute paths. --- scan.go | 44 ++++++++++++++++++++++++++++++++++---------- scan_test.go | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/scan.go b/scan.go index 776546802..1f92ae421 100644 --- a/scan.go +++ b/scan.go @@ -6,6 +6,7 @@ import ( "io" "io/fs" "os" + "path" "path/filepath" "strconv" "strings" @@ -228,10 +229,22 @@ func (zp *ZoneParser) SetIncludeAllowed(v bool) { zp.includeAllowed = v } -// SetIncludeAllowedFS controls whether $INCLUDE directives are allowed, and -// specifies an [fs.FS] from which to open and read files. -func (zp *ZoneParser) SetIncludeAllowedFS(v bool, fsys fs.FS) { - zp.includeAllowed, zp.fsys = v, fsys +// SetIncludeFS provides an [fs.FS] to use when looking for the target of +// $INCLUDE directives. ($INCLUDE must still be enabled separately by calling +// [ZoneParser.SetIncludeAllowed].) If fsys is nil, [os.Open] will be used. +// +// When fsys is an on-disk FS, the ability of $INCLUDE to reach files from +// outside its root directory depends upon the FS implementation. For +// instance, [os.DirFS] will refuse to open paths like "../../etc/passwd", +// however it will still follow links which may point anywhere on the system. +// +// FS paths are slash-separated on all systems, even Windows. $INCLUDE paths +// containing other characters such as backslash and colon may be accepted as +// valid, but those characters will never be interpreted by an FS +// implementation as path element separators. See [fs.ValidPath] for more +// details. +func (zp *ZoneParser) SetIncludeFS(fsys fs.FS) { + zp.fsys = fsys } // Err returns the first non-EOF error that was encountered by the @@ -418,20 +431,30 @@ func (zp *ZoneParser) Next() (RR, bool) { // Start with the new file includePath := l.token - if !filepath.IsAbs(includePath) { - includePath = filepath.Join(filepath.Dir(zp.file), includePath) - } - var r1 io.Reader var e1 error if zp.fsys != nil { + // fs.FS always uses / as separator, even on Windows, so use + // path instead of filepath here: + if !path.IsAbs(includePath) { + includePath = path.Join(path.Dir(zp.file), includePath) + } + + // os.DirFS, and probably others, expect all paths to be + // relative, so clean the path and remove leading / if + // present: + includePath = strings.TrimLeft(path.Clean(includePath), "/") + r1, e1 = zp.fsys.Open(includePath) } else { + if !filepath.IsAbs(includePath) { + includePath = filepath.Join(filepath.Dir(zp.file), includePath) + } r1, e1 = os.Open(includePath) } if e1 != nil { var as string - if !filepath.IsAbs(l.token) { + if includePath != l.token { as = fmt.Sprintf(" as `%s'", includePath) } zp.parseErr = &ParseError{ @@ -444,7 +467,8 @@ func (zp *ZoneParser) Next() (RR, bool) { zp.sub = NewZoneParser(r1, neworigin, includePath) zp.sub.defttl, zp.sub.includeDepth, zp.sub.r = zp.defttl, zp.includeDepth+1, r1 - zp.sub.SetIncludeAllowedFS(true, zp.fsys) + zp.sub.SetIncludeAllowed(true) + zp.sub.SetIncludeFS(zp.fsys) return zp.subNext() case zExpectDirTTLBl: if l.value != zBlank { diff --git a/scan_test.go b/scan_test.go index fcc1e36b2..3332c82d8 100644 --- a/scan_test.go +++ b/scan_test.go @@ -109,7 +109,8 @@ func TestZoneParserIncludeFS(t *testing.T) { var got int z := NewZoneParser(strings.NewReader(zone), "", "") - z.SetIncludeAllowedFS(true, fsys) + z.SetIncludeAllowed(true) + z.SetIncludeFS(fsys) for rr, ok := z.Next(); ok; _, ok = z.Next() { switch rr.Header().Name { case "foo.example.org.", "bar.example.org.": @@ -129,13 +130,47 @@ func TestZoneParserIncludeFS(t *testing.T) { fsys = fstest.MapFS{} z = NewZoneParser(strings.NewReader(zone), "", "") - z.SetIncludeAllowedFS(true, fsys) + z.SetIncludeAllowed(true) + z.SetIncludeFS(fsys) z.Next() if err := z.Err(); !errors.Is(err, fs.ErrNotExist) { t.Fatalf(`expected fs.ErrNotExist but got: %T %v`, err, err) } } +func TestZoneParserIncludeFSPaths(t *testing.T) { + fsys := fstest.MapFS{ + "baz/bat/db.foo": &fstest.MapFile{ + Data: []byte("foo\tIN\tA\t127.0.0.1"), + }, + } + + for _, p := range []string{ + "../bat/db.foo", + "/baz/bat/db.foo", + } { + zone := "$ORIGIN example.org.\n$INCLUDE " + p + "\nbar\tIN\tA\t127.0.0.2" + var got int + z := NewZoneParser(strings.NewReader(zone), "", "baz/quux/db.bar") + z.SetIncludeAllowed(true) + z.SetIncludeFS(fsys) + for rr, ok := z.Next(); ok; _, ok = z.Next() { + switch rr.Header().Name { + case "foo.example.org.", "bar.example.org.": + default: + t.Fatalf("$INCLUDE %q: expected foo.example.org. or bar.example.org., but got %s", p, rr.Header().Name) + } + got++ + } + if err := z.Err(); err != nil { + t.Fatalf("$INCLUDE %q: expected no error, but got %s", p, err) + } + if expected := 2; got != expected { + t.Errorf("$INCLUDE %q: failed to parse zone after include, expected %d records, got %d", p, expected, got) + } + } +} + func TestZoneParserIncludeDisallowed(t *testing.T) { tmpfile, err := os.CreateTemp("", "dns") if err != nil {