Skip to content

Commit

Permalink
all: fix panic in cert parsing (#19)
Browse files Browse the repository at this point in the history
Fixes #18.
  • Loading branch information
ainar-g authored Mar 15, 2023
1 parent aeff595 commit 266e248
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 50 deletions.
120 changes: 72 additions & 48 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package dnscrypt
import (
"crypto/ed25519"
"encoding/binary"
"fmt"
"net"
"strings"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
Expand Down Expand Up @@ -69,6 +71,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
if err != nil {
return nil, err
}

resolverInfo.ResolverCert = cert

// Compute shared key that we'll use to encrypt/decrypt messages
Expand All @@ -83,23 +86,24 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ResolverInfo, error) {
// Exchange performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
// This method creates a new network connection for every call so avoid using it for TCP.
// DNSCrypt cert needs to be fetched and validated prior to this call using the c.DialStamp method.
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (*dns.Msg, error) {
func (c *Client) Exchange(m *dns.Msg, resolverInfo *ResolverInfo) (resp *dns.Msg, err error) {
network := "udp"
if c.Net == "tcp" {
network = "tcp"
}

conn, err := net.Dial(network, resolverInfo.ServerAddress)
if err != nil {
return nil, err
return nil, fmt.Errorf("dialing: %w", err)
}
defer conn.Close()
defer func() { err = errors.WithDeferred(err, conn.Close()) }()

r, err := c.ExchangeConn(conn, m, resolverInfo)
resp, err = c.ExchangeConn(conn, m, resolverInfo)
if err != nil {
return nil, err
return nil, fmt.Errorf("exchanging: %w", err)
}
return r, nil

return resp, nil
}

// ExchangeConn performs a synchronous DNS query to the specified DNSCrypt server and returns a DNS response.
Expand Down Expand Up @@ -217,7 +221,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error)
}

// fetchCert loads DNSCrypt cert from the specified server
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (cert *Cert, err error) {
providerName := stamp.ProviderName
if !strings.HasSuffix(providerName, ".") {
providerName = providerName + "."
Expand All @@ -236,67 +240,87 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
return nil, ErrFailedToFetchCert
}

var certErr error
currentCert := &Cert{}
foundValid := false

for _, rr := range r.Answer {
txt, ok := rr.(*dns.TXT)
if !ok {
continue
}
var b []byte
b, certErr = unpackTxtString(strings.Join(txt.Txt, ""))
if certErr != nil {
log.Debug("[%s] failed to pack TXT record: %v", providerName, certErr)
continue
}

cert := &Cert{}
certErr = cert.Deserialize(b)
if certErr != nil {
log.Debug("[%s] failed to deserialize cert: %v", providerName, certErr)
cert, err = parseCert(stamp, currentCert, providerName, strings.Join(txt.Txt, ""))
if err != nil {
log.Debug("[%s] bad cert: %s", providerName, err)
} else if cert == nil {
continue
}

log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)
currentCert = cert
foundValid = true
}

if !cert.VerifyDate() {
certErr = ErrInvalidDate
log.Debug("[%s] cert %d date is not valid", providerName, cert.Serial)
continue
}
if foundValid {
return currentCert, nil
} else if err == nil {
err = errors.Error("no valid txt records")
}

if !cert.VerifySignature(stamp.ServerPk) {
certErr = ErrInvalidCertSignature
log.Debug("[%s] cert %d signature is not valid", providerName, cert.Serial)
continue
}
return nil, err
}

if cert.Serial < currentCert.Serial {
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)
continue
}
// parseCert parses a certificate from its string form and returns it if it has
// priority over currentCert.
func parseCert(
stamp dnsstamps.ServerStamp,
currentCert *Cert,
providerName string,
certStr string,
) (cert *Cert, err error) {
certBytes, err := unpackTxtString(certStr)
if err != nil {
return nil, fmt.Errorf("unpacking txt record: %w", err)
}

if cert.Serial == currentCert.Serial {
if cert.EsVersion > currentCert.EsVersion {
log.Debug("[%v] Upgrading the construction from %v to %v", providerName, currentCert.EsVersion, cert.EsVersion)
} else {
log.Debug("[%v] Keeping the previous, preferred crypto construction", providerName)
continue
}
}
cert = &Cert{}
err = cert.Deserialize(certBytes)
if err != nil {
return nil, fmt.Errorf("deserializing cert for: %w", err)
}

// Setting the cert
currentCert = cert
foundValid = true
log.Debug("[%s] fetched certificate %d", providerName, cert.Serial)

if !cert.VerifyDate() {
return nil, ErrInvalidDate
}

if foundValid {
return currentCert, nil
if !cert.VerifySignature(stamp.ServerPk) {
return nil, ErrInvalidCertSignature
}

if cert.Serial < currentCert.Serial {
log.Debug("[%v] cert %d superseded by a previous certificate", providerName, cert.Serial)

return nil, nil
}

if cert.Serial > currentCert.Serial {
return cert, nil
}

return nil, certErr
if cert.EsVersion <= currentCert.EsVersion {
log.Debug("[%v] keeping the previous, preferred crypto construction", providerName)

return nil, nil
}

log.Debug(
"[%v] upgrading the construction from %v to %v",
providerName,
currentCert.EsVersion,
cert.EsVersion,
)

return cert, nil
}

func (c *Client) maxQuerySize() int {
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestTimeoutOnDialExchange(t *testing.T) {

// Check error
require.NotNil(t, err)
require.True(t, os.IsTimeout(err))
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
}

func TestFetchCertPublicResolvers(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/ameshkov/dnscrypt/v2

go 1.18
go 1.19

require (
github.com/AdguardTeam/golibs v0.10.9
Expand Down

0 comments on commit 266e248

Please sign in to comment.