From 6a7f77868e7888253c0d8091e17a811a8a7be60a Mon Sep 17 00:00:00 2001 From: Nikolay Bystritskiy Date: Sun, 6 Feb 2022 22:10:34 +0100 Subject: [PATCH] fixed test after refactor --- app/acme/acme.go | 129 ++++---- app/acme/acme_test.go | 127 +++++--- app/acme/dns_challenge.go | 423 +++++++++------------------ app/acme/dns_challenge_test.go | 276 ++++------------- app/acme/dnsprovider/README.md | 22 +- app/acme/dnsprovider/cloudns_test.go | 24 +- app/acme/dnsprovider/route53.go | 270 +++++++++++++---- app/acme/dnsprovider/route53_test.go | 12 +- app/main.go | 34 ++- 9 files changed, 588 insertions(+), 729 deletions(-) diff --git a/app/acme/acme.go b/app/acme/acme.go index b768dc5b..91396617 100644 --- a/app/acme/acme.go +++ b/app/acme/acme.go @@ -2,103 +2,86 @@ package acme import ( "context" - "log" "time" - "github.com/go-pkgz/repeater" + log "github.com/go-pkgz/lgr" ) -// var acmeOpTimeout = 5 * time.Minute +var ( + attemptInterval = time.Minute * 1 + maxAttemps = 5 +) // Solver is an interface for solving ACME DNS challenge type Solver interface { // PreSolve is called before solving the challenge. ACME Order will be created and DNS record will be added. - PreSolve(ctx context.Context) error - // Solve is called to present TXT record and accept challenge. - Solve(ctx context.Context) error - // PostSolve is called after obtaining the certificate. - PostSolve(ctx context.Context) error - // GetCertificateExpiration returns certificate expiration date - GetCertificateExpiration(certPath string) (time.Time, error) -} + PreSolve() error -// fqdns []string, provider string, nameservers []string + // Solve is called to accept the challenge and pull the certificate. + Solve() error -// ScheduleCertificateRenewal schedules certificate renewal -func ScheduleCertificateRenewal(solver Solver, timeout time.Duration) { - certPath := getEnvOptionalString("SSL_CERT", "./var/acme/cert.pem") + // ObtainCertificate is called to obtain the certificate. + // Certificate will be saved to the file path specified by flag (env: SSL_CERT). //TODO add proper descr + ObtainCertificate() error +} +// ScheduleCertificateRenewal schedules certificate renewal +func ScheduleCertificateRenewal(ctx context.Context, solver Solver, certPath string) { go func(certPath string) { - var ( - expiredAt time.Time - err error - ) + var nextAttemptAfter time.Duration - expiredAt, err = solver.GetCertificateExpiration(certPath) - if err != nil { - expiredAt = time.Now() - log.Printf("[INFO] failed to get certificate expiration date, probably not obtained yet: %v", err) + if expiredAt, err := getCertificateExpiration(certPath); err == nil { + nextAttemptAfter = time.Until(expiredAt.Add(time.Hour * 24 * -5)) + log.Printf("[INFO] certificate will expire in %v, next attempt in %v", expiredAt, nextAttemptAfter) } + attempted := 0 for { - <-time.After(time.Until(expiredAt.Add(time.Hour * 24 * -5))) + select { + case <-ctx.Done(): + return + case <-time.After(nextAttemptAfter): + } + attempted++ - // add DNS record and wait for propagation - { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error { - if errc := solver.PreSolve(ctx); errc != nil { - log.Printf("[INFO] error in ACME DNS Challenge Presolve: %v", errc) - return errc - } - return nil - }) - cancel() - if err != nil { - log.Printf("[ERROR] ACME DNS Challenge Presolve failed. Last error %v", err) - return - } + if attempted > maxAttemps { + log.Printf("[ERROR] Certificate renewal failed after %d attempts", attempted-1) + return } + log.Printf("[INFO] renewing certificate attempt %d", attempted) - // present TXT record and accept challenge - { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error { - if errc := solver.Solve(ctx); errc != nil { - log.Printf("[INFO] error in ACME DNS Challenge Solve: %v", errc) - return errc - } - return nil - }) - cancel() - if err != nil { - log.Printf("[ERROR] retry limit reached ACME DNS Challenge Solve failed. Last error: %v", err) - return - } + // create ACME order and add TXT record for the challenge + if err := solver.PreSolve(); err != nil { + nextAttemptAfter = time.Duration(attempted) * attemptInterval + log.Printf("[WARN] error during preparing ACME order: %v, next attempt in %v", err, nextAttemptAfter) + continue } - // pull the certificate - { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - err = repeater.NewDefault(10, timeout>>12).Do(ctx, func() error { - if errc := solver.PostSolve(ctx); errc != nil { - log.Printf("[INFO] error in ACME DNS Challenge PostSolve: %v", errc) - return errc - } - return nil - }) - cancel() - if err != nil { - log.Printf("[ERROR] retry limit reached, ACME DNS Challenge PostSolve failed. Last error: %v", err) - return - } + // solve the challenge + if err := solver.Solve(); err != nil { + nextAttemptAfter = time.Duration(attempted) * attemptInterval + log.Printf("[WARN] error during solving ACME DNS Challenge: %v, next attempt in %v", err, nextAttemptAfter) + continue } - expiredAt, err = solver.GetCertificateExpiration(certPath) - if err != nil { - log.Printf("[ERROR] failed to get certificate expiration date: %v", err) - return + // obtain certificate + if err := solver.ObtainCertificate(); err != nil { + nextAttemptAfter = time.Duration(attempted) * attemptInterval + log.Printf("[WARN] error during certificate obtaining: %v, next attempt in %v", err, nextAttemptAfter) + continue } + + expiredAt, err := getCertificateExpiration(certPath) + if err == nil { + // 5 days earlier than the certificate expiration + nextAttemptAfter = time.Until(expiredAt.Add(time.Hour * 24 * -5)) + log.Printf("[INFO] certificate will expire in %v, next attempt in %v", expiredAt, nextAttemptAfter) + attempted = 0 + continue + } + + log.Printf("[WARN] failed to get certificate expiration date, probably not obtained yet: %v", err) + nextAttemptAfter = time.Duration(attempted) * attemptInterval } }(certPath) } diff --git a/app/acme/acme_test.go b/app/acme/acme_test.go index 5137d6c8..9ab69723 100644 --- a/app/acme/acme_test.go +++ b/app/acme/acme_test.go @@ -2,22 +2,31 @@ package acme import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" + "math/big" + "os" "testing" "time" "github.com/stretchr/testify/assert" ) +const certPath = "./TestScheduleCertificateRenewal.pem" + type mockSolver struct { domain string expires time.Time preSolvedCalled int solveCalled int - postSolvedCalled int + obtainCertCalled int } -func (s *mockSolver) PreSolve(ctx context.Context) error { +func (s *mockSolver) PreSolve() error { s.preSolvedCalled++ switch s.domain { case "mycompany1.com": @@ -26,39 +35,31 @@ func (s *mockSolver) PreSolve(ctx context.Context) error { return nil } -func (s *mockSolver) Solve(ctx context.Context) error { +func (s *mockSolver) Solve() error { s.solveCalled++ switch s.domain { case "mycompany2.com": - return fmt.Errorf("solve failed") + return fmt.Errorf("postSolved failed") } return nil } -func (s *mockSolver) PostSolve(ctx context.Context) error { - s.postSolvedCalled++ +func (s *mockSolver) ObtainCertificate() error { + s.obtainCertCalled++ switch s.domain { case "mycompany3.com": - return fmt.Errorf("postSolved failed") + return fmt.Errorf("obtainCertificate failed") + case "mycompany5.com": + return nil + default: + return createCert(time.Now().Add(time.Hour*24*365), s.domain) } - return nil } - -func (s *mockSolver) GetCertificateExpiration(certPath string) (time.Time, error) { - // check called before loop starts - if s.preSolvedCalled == 0 { - switch s.domain { - case "mycompany4.com": - return time.Now().Add(time.Hour * 24 * 670), nil - default: - return time.Time{}, fmt.Errorf("certificate does not exist") - } - } - return time.Now().Add(time.Hour * 24 * 365), nil -} - func TestScheduleCertificateRenewal(t *testing.T) { - timeout := 15 * time.Second + testMaxAttemps := 10 + maxAttemps = testMaxAttemps + + attemptInterval = time.Microsecond * 10 type args struct { domain string @@ -69,7 +70,7 @@ func TestScheduleCertificateRenewal(t *testing.T) { type expected struct { preSolvedCalled int solveCalled int - postSolvedCalled int + obtainCertCalled int } tests := []struct { @@ -77,34 +78,78 @@ func TestScheduleCertificateRenewal(t *testing.T) { args args expected expected }{ - // {"certificate not existed before", - // args{"example.com", false, time.Now().Add(time.Hour * 100 * 24)}, - // expected{1, 1, 1}}, - // {"presolve failed", - // args{"mycompany1.com", false, time.Time{}}, - // expected{10, 0, 0}}, - // {"solve failed", - // args{"mycompany2.com", false, time.Time{}}, - // expected{1, 10, 0}}, - {"postsolve failed", + {"certificate not existed before", + args{"example.com", false, time.Time{}}, + expected{1, 1, 1}}, + {"presolve always fails", + args{"mycompany1.com", false, time.Time{}}, + expected{testMaxAttemps, 0, 0}}, + {"solve always fails", + args{"mycompany2.com", false, time.Time{}}, + expected{testMaxAttemps, testMaxAttemps, 0}}, + {"obtain cert failed", args{"mycompany3.com", false, time.Time{}}, - expected{1, 1, 10}}, - // {"certificate valid for a long time", - // args{"mycompany4.com", false, time.Time{}}, - // expected{0, 0, 0}}, + expected{maxAttemps, maxAttemps, maxAttemps}}, + {"certificate valid for a long time", + args{"mycompany4.com", true, time.Now().Add(time.Hour * 100 * 24)}, + expected{0, 0, 0}}, + {"obtain cert success, but file not created", + args{"mycompany5.com", false, time.Time{}}, + expected{maxAttemps, maxAttemps, maxAttemps}}, } for _, tt := range tests { + if tt.args.certExistedBefore { + if err := createCert(tt.args.expiryTime, tt.args.domain); err != nil { + t.Fatal(err) + } + } + s := &mockSolver{ domain: tt.args.domain, expires: tt.args.expiryTime, } - ScheduleCertificateRenewal(s, timeout) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + ScheduleCertificateRenewal(ctx, s, certPath) + time.Sleep(time.Second * 2) - time.Sleep(timeout) assert.Equal(t, tt.expected.preSolvedCalled, s.preSolvedCalled, fmt.Sprintf("[case %s] preSolvedCalled not match", tt.name)) assert.Equal(t, tt.expected.solveCalled, s.solveCalled, fmt.Sprintf("[case %s] solveCalled not match", tt.name)) - assert.Equal(t, tt.expected.postSolvedCalled, s.postSolvedCalled, fmt.Sprintf("[case %s] postSolvedCalled not match", tt.name)) + assert.Equal(t, tt.expected.obtainCertCalled, s.obtainCertCalled, fmt.Sprintf("[case %s] postSolvedCalled not match", tt.name)) + + os.Remove(certPath) + cancel() + } +} + +func createCert(expireAt time.Time, domain string) error { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: time.Now(), + NotAfter: expireAt, + + KeyUsage: x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{domain}, + } + // write cert to file + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return err + } + certFile, err := os.Create(certPath) + if err != nil { + return err + } + + if _, err := certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})); err != nil { + return err } + return certFile.Close() } diff --git a/app/acme/dns_challenge.go b/app/acme/dns_challenge.go index e4477bb2..f0de67bc 100644 --- a/app/acme/dns_challenge.go +++ b/app/acme/dns_challenge.go @@ -11,7 +11,6 @@ import ( "net" "os" "path/filepath" - "sync" "time" "github.com/umputun/reproxy/app/acme/dnsprovider" @@ -23,7 +22,8 @@ var defaultNameservers = []string{ "google-public-dns-b.google.com", } -var acmeV2Enpoint = "https://acme-v02.api.letsencrypt.org/directory" +// var acmeV2Enpoint = "https://acme-v02.api.letsencrypt.org/directory" +var acmeV2Enpoint = "https://acme-staging-v02.api.letsencrypt.org/directory" // DNSChallengeConfig contains configuration for DNS challenge type DNSChallengeConfig struct { @@ -33,6 +33,8 @@ type DNSChallengeConfig struct { Nameservers []string Timeout time.Duration PollingInterval time.Duration + CertPath string + KeyPath string } // DNSChallenge represents an ACME DNS challenge @@ -43,10 +45,10 @@ type DNSChallenge struct { accountKey *rsa.PrivateKey provider dnsprovider.Provider order *acme.Order - challenges []*acme.Challenge - records []dnsprovider.Record timeout time.Duration pollingInterval time.Duration + certPath string + keyPath string } // NewDNSChallege creates new DNSChallenge @@ -68,16 +70,21 @@ func NewDNSChallege(config DNSChallengeConfig) (*DNSChallenge, error) { } return &DNSChallenge{provider: p, - nameservers: config.Nameservers, - domains: config.Domains, - records: make([]dnsprovider.Record, 0), - challenges: make([]*acme.Challenge, 0), + nameservers: config.Nameservers, + domains: config.Domains, + timeout: config.Timeout, + pollingInterval: config.PollingInterval, + certPath: config.CertPath, + keyPath: config.KeyPath, }, nil } // PreSolve is called before solving the challenge. // ACME Order will be created and DNS record will be added. -func (d *DNSChallenge) PreSolve(ctx context.Context) error { +func (d *DNSChallenge) PreSolve() error { + ctx, cancel := context.WithTimeout(context.Background(), d.timeout) + defer cancel() + if err := d.register(); err != nil { return err } @@ -85,37 +92,95 @@ func (d *DNSChallenge) PreSolve(ctx context.Context) error { return err } - if err := d.presentRecords(); err != nil { - return err - } - return nil } -// Solve is called to present TXT record and accept challenge. -func (d *DNSChallenge) Solve(ctx context.Context) error { - errs := d.waitDNSPropagation(ctx) - for _, err := range errs { +// waitPropagation blocks until the DNS record is propagated or timeout is reached. +func (d *DNSChallenge) waitPropagation(record dnsprovider.Record) error { + if len(d.domains) == 0 { + return fmt.Errorf("no domain is provided") + } + + ctx, cancel := context.WithTimeout(context.Background(), d.timeout) + defer cancel() + + if err := d.provider.WaitUntilPropagated(ctx, record); err != nil { log.Printf("[WARN] %v", err) } - if err := d.checkWithNS(ctx); err != nil { + if err := d.checkWithNS(ctx, record); err != nil { log.Printf("[WARN] nameservers lookup failed with errors: %v", err) } - if err := d.acceptChallenges(ctx); err != nil { - return err + return nil +} + +// Solve is called to accept the challenge and pull the certificate. +func (d *DNSChallenge) Solve() error { + ctx, cancel := context.WithTimeout(context.Background(), d.timeout) + defer cancel() + + for _, authzURL := range d.order.AuthzURLs { + authz, err := d.client.GetAuthorization(ctx, authzURL) + if err != nil { + return err + } + + var chl *acme.Challenge + for i := range authz.Challenges { + if authz.Challenges[i].Type == "dns-01" { + chl = authz.Challenges[i] + break + } + } + + if chl == nil { + return fmt.Errorf("no DNS-01 challenge found for %v", authz.Identifier.Value) + } + + record := dnsprovider.Record{ + Type: "TXT", + Host: "_acme-challenge", + Domain: authz.Identifier.Value, + } + + if record.Value, err = d.client.DNS01ChallengeRecord(chl.Token); err != nil { + return fmt.Errorf("failed to get TXT record value: %v", err) + } + + err = d.provider.AddRecord(record) + if err != nil { + return fmt.Errorf("failed to add TXT record %s: %v", record.Host+record.Domain, err) + } + + err = d.waitPropagation(record) + if err != nil { + return fmt.Errorf("failed to wait for TXT record %s: %v", record.Host+record.Domain, err) + } + + _, err = d.client.Accept(ctx, chl) + if err != nil { + return err + } + + _, err = d.client.WaitAuthorization(ctx, authzURL) + if err != nil { + return err + } + + err = d.provider.RemoveRecord(record) + if err != nil { + log.Printf("[WARN] failed to remove TXT record %s: %v", record.Host+record.Domain, err) + } } return nil } -// PostSolve is called after obtaining the certificate. -func (d *DNSChallenge) PostSolve(ctx context.Context) error { - defer d.cleanupRecords() - if len(d.domains) == 0 { - return fmt.Errorf("no domain is provided") - } +// ObtainCertificate is called to obtain the certificate. +func (d *DNSChallenge) ObtainCertificate() error { + ctx, cancel := context.WithTimeout(context.Background(), d.timeout) + defer cancel() q := &x509.CertificateRequest{ DNSNames: d.domains, @@ -152,8 +217,7 @@ func (d *DNSChallenge) PostSolve(ctx context.Context) error { return nil } -// GetCertificateExpiration returns certificate expiration date -func (d *DNSChallenge) GetCertificateExpiration(certPath string) (time.Time, error) { +func getCertificateExpiration(certPath string) (time.Time, error) { b, err := os.ReadFile(filepath.Clean(certPath)) if err != nil { return time.Time{}, err @@ -170,68 +234,29 @@ func (d *DNSChallenge) GetCertificateExpiration(certPath string) (time.Time, err } func (d *DNSChallenge) register() error { - if d.client == nil { - var err error - d.accountKey, err = rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return err - } - - d.client = &acme.Client{ - DirectoryURL: acmeV2Enpoint, - Key: d.accountKey, - } - + if d.client != nil { + return nil } - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - - if _, err := d.client.Register(ctx, &acme.Account{}, acme.AcceptTOS); err != nil { + var err error + d.accountKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { return err } - return nil -} -func (d *DNSChallenge) acceptChallenges(ctx context.Context) error { - for _, chl := range d.challenges { - if _, err := d.client.Accept(ctx, chl); err != nil { - return err - } + client := acme.Client{ + DirectoryURL: acmeV2Enpoint, + Key: d.accountKey, } - errCh := make(chan error, len(d.order.AuthzURLs)) - var wg sync.WaitGroup - waitCh := make(chan struct{}, 3) // limit concurrency to 3 - - for _, authURL := range d.order.AuthzURLs { - wg.Add(1) - go func(authURL string) { - waitCh <- struct{}{} - defer func() { - <-waitCh - wg.Done() - }() - _, err := d.client.WaitAuthorization(ctx, authURL) - errCh <- err - }(authURL) - } - - go func() { - wg.Wait() - close(errCh) - }() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() - errs := make([]error, 0, len(d.order.AuthzURLs)) - for err := range errCh { - if err != nil { - errs = append(errs, err) - } + if _, err := client.Register(ctx, &acme.Account{}, acme.AcceptTOS); err != nil { + return err } - if len(errs) > 0 { - return fmt.Errorf("some challenges failed: %v", errs[0]) - } + d.client = &client return nil } @@ -248,176 +273,40 @@ func (d *DNSChallenge) prepareOrder(ctx context.Context, domains []string) error return err } - authCh := make(chan *acme.Authorization, len(d.order.AuthzURLs)) - var authWg sync.WaitGroup - - waitCh := make(chan struct{}, 3) // limit concurrency to 3 - for _, authURL := range d.order.AuthzURLs { - authWg.Add(1) - go func(authURL string) { - waitCh <- struct{}{} - defer func() { - <-waitCh - authWg.Done() - }() - auth, authErr := d.client.GetAuthorization(ctx, authURL) - if authErr != nil { - log.Printf("failed to get authorization: %v", authErr) - return - } - authCh <- auth - }(authURL) - } - - go func() { - authWg.Wait() - close(authCh) - }() - - for authz := range authCh { - // according to ACME spec, authorization objects are created in the "pending" state - if authz.Status != acme.StatusPending { - log.Printf("[ERROR] DNS-01 challenge for %v is not pending, with status %s", authz.Identifier.Value, authz.Status) - continue - } - - var chl *acme.Challenge - for i := range authz.Challenges { - if authz.Challenges[i].Type == "dns-01" { - chl = authz.Challenges[i] - break - } - } - - if chl == nil { - log.Printf("[ERROR] no DNS-01 challenge found for %v", authz.Identifier.Value) - continue - } - - d.challenges = append(d.challenges, chl) - - record := dnsprovider.Record{ - Type: "TXT", - Host: "_acme-challenge", - Domain: authz.Identifier.Value, - } - - record.Value, err = d.client.DNS01ChallengeRecord(chl.Token) - if err != nil { - log.Printf("failed to get TXT record value: %v", err) - continue - } - - d.records = append(d.records, record) - } - return nil } -func (d *DNSChallenge) presentRecords() error { - recCh := make(chan dnsprovider.Record, len(d.records)) - var recWg sync.WaitGroup - - waitCh := make(chan struct{}, 3) // limit concurrency to 3 - for _, r := range d.records { - recWg.Add(1) - go func(r dnsprovider.Record) { - waitCh <- struct{}{} - defer func() { - <-waitCh - recWg.Done() - }() - if err := d.provider.AddRecord(r); err != nil { - log.Printf("[ERROR] failed to add TXT record: %v", err) - return - } - recCh <- r - }(r) - } - - go func() { - recWg.Wait() - close(recCh) - }() +func (d *DNSChallenge) checkWithNS(ctx context.Context, record dnsprovider.Record) error { + ticker := time.NewTicker(d.pollingInterval) - addedRecords := make([]dnsprovider.Record, 0, len(d.records)) - for r := range recCh { - addedRecords = append(addedRecords, r) - } + var lastErr error + nextNameserver := d.getNameserverFn() - if len(addedRecords) != len(d.records) { - defer d.cleanupRecords() + nameserver := nextNameserver() + if lastErr = lookupTXTRecord(record, nameserver); lastErr == nil { + nameserver = nextNameserver() } - return nil -} - -func (d *DNSChallenge) waitDNSPropagation(ctx context.Context) []error { - errCh := make(chan error, len(d.records)) - var wg sync.WaitGroup - - waitCh := make(chan struct{}, 3) // limit concurrency to 3 records - for _, r := range d.records { - wg.Add(1) - go func(r dnsprovider.Record) { - waitCh <- struct{}{} - defer func() { - <-waitCh - wg.Done() - }() - if err := d.provider.WaitUntilPropagated(ctx, r); err != nil { - errCh <- fmt.Errorf("error waiting for %s record %s to propagate: %v", - r.Type, fmt.Sprintf("%s.%s", r.Host, r.Domain), err) +nsLoop: + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout while checking DNS record propagation. Last error: %v", lastErr) + case <-ticker.C: + // record propagated to all nameservers + if nameserver == "" { + break nsLoop } - }(r) - } - - go func() { - wg.Wait() - close(errCh) - }() - - errs := make([]error, 0, len(d.records)) - for err := range errCh { - if err != nil { - errs = append(errs, err) - } - } - - return errs -} - -func (d *DNSChallenge) checkWithNS(ctx context.Context) error { - ticker := time.NewTicker(10 * time.Second) - - for _, record := range d.records { - var lastErr error - nextNameserver := d.getNameserverFn() - - nameserver := nextNameserver() - if lastErr = lookupTXTRecord(record, nameserver); lastErr == nil { - nameserver = nextNameserver() - } - nsLoop: - for { - select { - case <-ctx.Done(): - return fmt.Errorf("timeout while checking DNS record propagation. Last error: %v", lastErr) - case <-ticker.C: - // record propagated to all nameservers - if nameserver == "" { - break nsLoop - } - err := lookupTXTRecord(record, nameserver) - if err == nil { - log.Printf("[INFO] DNS record %s.%s propagated to nameserver %s", record.Host, record.Domain, nameserver) - nameserver = nextNameserver() - continue - } - lastErr = err + err := lookupTXTRecord(record, nameserver) + if err == nil { + log.Printf("[INFO] DNS record %s.%s propagated to nameserver %s", record.Host, record.Domain, nameserver) + nameserver = nextNameserver() + continue } + lastErr = err } } + return nil } @@ -426,10 +315,7 @@ func (d *DNSChallenge) writeCertificates(privateKey *rsa.PrivateKey, cert *x509. return fmt.Errorf("private key or certificate is nil") } - keyPath := getEnvOptionalString("SSL_KEY", "./var/acme/key.pem") - certPath := getEnvOptionalString("SSL_CERT", "./var/acme/cert.pem") - - dir := filepath.Dir(keyPath) + dir := filepath.Dir(d.keyPath) if _, err := os.Stat(dir); err != nil { if os.IsNotExist(err) { if err := os.MkdirAll(dir, 0o700); err != nil { @@ -438,7 +324,7 @@ func (d *DNSChallenge) writeCertificates(privateKey *rsa.PrivateKey, cert *x509. } } - keyOut, err := os.OpenFile(filepath.Clean(keyPath), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + keyOut, err := os.OpenFile(filepath.Clean(d.keyPath), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return err } @@ -453,7 +339,7 @@ func (d *DNSChallenge) writeCertificates(privateKey *rsa.PrivateKey, cert *x509. return fmt.Errorf("error closing key.pem: %v", err) } - certOut, err := os.Create(certPath) + certOut, err := os.Create(d.certPath) if err != nil { return err } @@ -466,8 +352,8 @@ func (d *DNSChallenge) writeCertificates(privateKey *rsa.PrivateKey, cert *x509. return err } - log.Printf("[INFO] wrote certificate to %s", certPath) - log.Printf("[INFO] wrote private key to %s", keyPath) + log.Printf("[INFO] wrote certificate to %s", d.certPath) + log.Printf("[INFO] wrote private key to %s", d.keyPath) return nil } @@ -488,49 +374,6 @@ func (d *DNSChallenge) getNameserverFn() func() string { } } -func (d *DNSChallenge) cleanupRecords() { - type recErr struct { - record dnsprovider.Record - err error - } - recCh := make(chan recErr, len(d.records)) - var recWg sync.WaitGroup - - waitCh := make(chan struct{}, 3) // limit concurrency to 3 - for _, r := range d.records { - recWg.Add(1) - go func(r dnsprovider.Record) { - waitCh <- struct{}{} - defer func() { - <-waitCh - recWg.Done() - }() - if err := d.provider.RemoveRecord(r); err != nil { - log.Printf("[INFO] cleanup failed to remove TXT record: %v", err) - recCh <- recErr{r, err} - return - } - recCh <- recErr{r, nil} - }(r) - } - - go func() { - recWg.Wait() - close(recCh) - }() - - recs := d.records[:0] - for r := range recCh { - if r.err == nil { - log.Printf("[INFO] cleanup removed TXT record %s.%s", r.record.Host, r.record.Domain) - continue - } - recs = append(recs, r.record) - log.Printf("[INFO] cleanup failed to remove TXT record: %v", r.err) - } - d.records = recs -} - func lookupTXTRecord(record dnsprovider.Record, nameserver string) error { r := &net.Resolver{ PreferGo: true, diff --git a/app/acme/dns_challenge_test.go b/app/acme/dns_challenge_test.go index 76d8e4a5..fef9e9bf 100644 --- a/app/acme/dns_challenge_test.go +++ b/app/acme/dns_challenge_test.go @@ -119,7 +119,6 @@ func TestMain(m *testing.M) { setupMock() acmeV2Enpoint = mockACMEServer.URL provider = &mockDNSProvider{} - //acmeOpTimeout = 10 * time.Second } os.Exit(m.Run()) } @@ -525,8 +524,6 @@ func TestDNSChallenge_prepareOrder(t *testing.T) { fmt.Sprintf("%s: expected %d identifiers, got %d", tt.name, tt.expected.numIdentifiers, len(d.order.Identifiers))) assert.NotEmpty(t, d.order.FinalizeURL, fmt.Sprintf("%s: expected FinalizeURL to be set", tt.name)) - assert.Equal(t, tt.expected.numRecords, len(d.records), - fmt.Sprintf("%s: expected %d records, got %d", tt.name, tt.expected.numRecords, len(d.records))) }) } } @@ -560,18 +557,13 @@ func TestDNSChallenge_solveDNSChallengeLEStaging(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() + dc.timeout = time.Minute * 5 - if err := dc.PreSolve(ctx); err != nil { + if err := dc.PreSolve(); err != nil { t.Fatal(err) } - if err := dc.Solve(ctx); err != nil { - t.Fatal(err) - } - - if err := dc.PostSolve(ctx); err != nil { + if err := dc.Solve(); err != nil { t.Fatal(err) } @@ -600,34 +592,38 @@ func TestDNSChallenge_solveDNSChallenge(t *testing.T) { for _, tt := range tests { d := &DNSChallenge{ - provider: &mockDNSProvider{}, - domains: tt.args.domains, + provider: &mockDNSProvider{}, + domains: tt.args.domains, + pollingInterval: time.Second * 1, + timeout: timeoutForTests, } if err := d.register(); err != nil { t.Fatal(err) } - { - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - err := d.PreSolve(ctx) - assert.Empty(t, err) - cancel() - } + err := d.PreSolve() + assert.Empty(t, err) - { - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - err := d.Solve(ctx) - assert.Empty(t, err) - cancel() - } + // mock the DNS provider to return a successful response + // after challenge is accepted + for i := range d.order.AuthzURLs { + authURL := &d.order.AuthzURLs[i] - { - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - err := d.PostSolve(ctx) - assert.Empty(t, err) - cancel() + var u *url.URL + u, err = url.Parse(*authURL) + if err != nil { + t.Fatal(err) + } + q := u.Query() + q.Set("afterAccept", "true") + u.RawQuery = q.Encode() + *authURL = u.String() } + + err = d.Solve() + assert.Empty(t, err) + } } @@ -846,8 +842,7 @@ func Test_GetCertificateExpiration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &DNSChallenge{} - got, err := d.GetCertificateExpiration(tt.fileName) + got, err := getCertificateExpiration(tt.fileName) if (err != nil) && !tt.wantErr { t.Errorf("getCertExpiration() error = %v, wantErr %v", err, tt.wantErr) return @@ -927,11 +922,17 @@ func TestDNSChallenge_writeCertificates(t *testing.T) { os.RemoveAll("./var") }) - d := &DNSChallenge{} for _, tt := range tests { + keyPath := "./var/acme/key.pem" + certPath := "./var/acme/cert.pem" if tt.envs != nil { - os.Setenv("SSL_KEY", tt.envs.keyPath) - os.Setenv("SSL_CERT", tt.envs.certPath) + keyPath = tt.envs.keyPath + certPath = tt.envs.certPath + } + + d := &DNSChallenge{ + certPath: certPath, + keyPath: keyPath, } err = d.writeCertificates(tt.args.privateKey, tt.args.cert) @@ -943,13 +944,6 @@ func TestDNSChallenge_writeCertificates(t *testing.T) { continue } - keyPath := "./var/acme/key.pem" - certPath := "./var/acme/cert.pem" - if tt.envs != nil { - keyPath = tt.envs.keyPath - certPath = tt.envs.certPath - } - _, err := os.Stat(keyPath) assert.Equal(t, tt.wantErr, err != nil, "[case %s] file with key: %v, filepath: %s", tt.name, err, keyPath) _, err = os.Stat(certPath) @@ -960,42 +954,12 @@ func TestDNSChallenge_writeCertificates(t *testing.T) { } } -func TestDNSChallenge_waitDNSPropagation(t *testing.T) { - type args struct { - records []dnsprovider.Record - } - tests := []struct { - name string - args args - expectedErrsNum int - }{ - {"correct case", - args{records: []dnsprovider.Record{{Host: "errorcase"}}}, - 1}, - {"timeout", - args{records: []dnsprovider.Record{{Host: "timeout"}}}, - 1}, - } - d := &DNSChallenge{ - provider: &mockDNSProvider{}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - d.records = tt.args.records - errs := d.waitDNSPropagation(ctx) - assert.Equal(t, tt.expectedErrsNum, len(errs)) - cancel() - }) - } -} - func TestDNSChallenge_checkWithNS(t *testing.T) { type fields struct { nameservers []string } type args struct { - records []dnsprovider.Record + record dnsprovider.Record } tests := []struct { name string @@ -1003,126 +967,35 @@ func TestDNSChallenge_checkWithNS(t *testing.T) { args args wantErr bool }{ - {"timeout", fields{}, args{records: []dnsprovider.Record{{Host: "notexistinghost"}}}, true}, + {"timeout", fields{}, args{record: dnsprovider.Record{Host: "notexistinghost"}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &DNSChallenge{ - nameservers: tt.fields.nameservers, + nameservers: tt.fields.nameservers, + pollingInterval: time.Second * 10, } ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - d.records = tt.args.records - if err := d.checkWithNS(ctx); (err != nil) && !tt.wantErr { - t.Errorf("DNSChallenge.checkWithNS() error = %v, wantErr %v", err, tt.wantErr) - } - cancel() - }) - } -} -func TestDNSChallenge_acceptChallenges(t *testing.T) { - type args struct { - domains []string - } - tests := []struct { - name string - args args - wantErr bool - }{ - {"one domain", args{domains: []string{"mycompany-0.com"}}, false}, - {"multiple domain, wildcards", args{domains: []string{"mycompany-1.com", "*.mycompany-1.com"}}, false}, - {"wait auth status not valid", args{domains: []string{"mycompany-8.com"}}, true}, - {"wait auth for one domain not valid", args{domains: []string{"mycompany-0.com", "mycompany-8.com"}}, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // prepare test - d := &DNSChallenge{provider: &mockDNSProvider{}} - if err := d.register(); err != nil { - t.Fatal(err) - } - - { - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - if err := d.prepareOrder(ctx, tt.args.domains); err != nil { - t.Fatal(err) - } - cancel() - - // mock wait auth result - for i := range d.order.AuthzURLs { - authURL := &d.order.AuthzURLs[i] - u, err := url.Parse(*authURL) - if err != nil { - t.Fatal(err) - } - q := u.Query() - q.Set("afterAccept", "true") - u.RawQuery = q.Encode() - *authURL = u.String() - } - } - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - if err := d.acceptChallenges(ctx); (err != nil) && !tt.wantErr { - t.Errorf("DNSChallenge.acceptOrder() error = %v, wantErr %v", err, tt.wantErr) + if err := d.checkWithNS(ctx, tt.args.record); (err != nil) && !tt.wantErr { + t.Errorf("DNSChallenge.checkWithNS() error = %v, wantErr %v", err, tt.wantErr) } cancel() }) } } -func TestDNSChallenge_presentRecords(t *testing.T) { - type fields struct { - records []dnsprovider.Record - addedRecords []dnsprovider.Record // records should be added - removedRecords []dnsprovider.Record // records should be removed because one some record failed - } - tests := []struct { - name string - fields fields - wantErr bool - }{ - {"correct case", - fields{records: []dnsprovider.Record{{Domain: "mycompany-0.com"}}, - addedRecords: []dnsprovider.Record{{Domain: "mycompany-0.com"}}, - removedRecords: []dnsprovider.Record{}}, - false}, - {"add records failed", - fields{records: []dnsprovider.Record{{Domain: "mycompany-6.com"}}, - addedRecords: []dnsprovider.Record{}, - removedRecords: []dnsprovider.Record{{Domain: "mycompany-6.com"}}}, - true}, - {"one record added, one failed", - fields{records: []dnsprovider.Record{{Domain: "mycompany-0.com"}, {Domain: "mycompany-6.com"}}, - addedRecords: []dnsprovider.Record{{Domain: "mycompany-0.com"}}, - removedRecords: []dnsprovider.Record{{Domain: "mycompany-0.com"}, {Domain: "mycompany-6.com"}}}, - true}, - } - for _, tt := range tests { - addedRecords = make([]dnsprovider.Record, 0) - removedRecords = make([]dnsprovider.Record, 0) - - d := &DNSChallenge{ - provider: &mockDNSProvider{}, - records: tt.fields.records, - } - if err := d.presentRecords(); (err != nil) && !tt.wantErr { - t.Errorf("DNSChallenge.presentRecords() error = %v, wantErr %v", err, tt.wantErr) - } - - assert.Equal(t, tt.fields.addedRecords, addedRecords, fmt.Sprintf("case [%s], added records not match", tt.name)) - assert.Equal(t, tt.fields.removedRecords, removedRecords, fmt.Sprintf("case [%s] removed records not match", tt.name)) - } -} +func TestDNSChallenge_ObtainCertificate(t *testing.T) { + certPath := "./TestDNSChallenge_ObtainCertificate_Cert.pem" + keyPath := "./TestDNSChallenge_ObtainCertificate_Key.pem" -func TestDNSChallenge_PostSolve(t *testing.T) { if err := os.MkdirAll("./var/acme", os.ModePerm); err != nil { t.Fatal(err) } t.Cleanup(func() { - os.RemoveAll("./var") + os.Remove(certPath) + os.Remove(keyPath) }) type fields struct { @@ -1146,7 +1019,11 @@ func TestDNSChallenge_PostSolve(t *testing.T) { args{}, true}, } - d := &DNSChallenge{} + d := &DNSChallenge{ + timeout: timeoutForTests, + certPath: certPath, + keyPath: keyPath, + } if err := d.register(); err != nil { t.Fatal(err) } @@ -1157,59 +1034,26 @@ func TestDNSChallenge_PostSolve(t *testing.T) { d.order = tt.fields.order d.domains = tt.args.domains - ctx, cancel := context.WithTimeout(context.Background(), timeoutForTests) - err = d.PostSolve(ctx) - cancel() - if (err != nil) && !tt.wantErr { + if err = d.Solve(); (err != nil) && !tt.wantErr { t.Errorf("DNSChallenge.pullCert() error = %v, wantErr %v", err, tt.wantErr) continue } + if err = d.ObtainCertificate(); (err != nil) && !tt.wantErr { + t.Errorf("DNSChallenge.ObtainCertificate() error = %v, wantErr %v", err, tt.wantErr) + continue + } + if (err != nil) == tt.wantErr { continue } - _, err = os.Stat("./var/cert.pem") + _, err = os.Stat(certPath) assert.Empty(t, err, fmt.Sprintf("case [%s]: cert file not found", tt.name)) - _, err = os.Stat("./var/key.pem") + _, err = os.Stat(keyPath) assert.Empty(t, err, fmt.Sprintf("case [%s]: key file not found", tt.name)) // }) } } - -func TestDNSChallenge_cleanupRecords(t *testing.T) { - tests := []struct { - name string - records []dnsprovider.Record - expected []dnsprovider.Record - }{ - {name: "correct case", - records: []dnsprovider.Record{{Domain: "cleanupRecords1.com"}}, - expected: []dnsprovider.Record{}, - }, - {name: "deletion some records failed", - records: []dnsprovider.Record{{Domain: "cleanupRecords1.com"}, {Domain: "cleanupRecords2.com"}}, - expected: []dnsprovider.Record{{Domain: "cleanupRecords2.com"}}, - }, - {name: "deletion all records failed", - records: []dnsprovider.Record{{Domain: "cleanupRecords2.com"}, {Domain: "cleanupRecords3.com"}}, - expected: []dnsprovider.Record{{Domain: "cleanupRecords2.com"}, {Domain: "cleanupRecords3.com"}}, - }, - {name: "nothing to delete", - records: []dnsprovider.Record{}, - expected: []dnsprovider.Record{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &DNSChallenge{ - provider: &mockDNSProvider{}, - records: tt.records, - } - d.cleanupRecords() - assert.ElementsMatch(t, tt.expected, d.records) - }) - } -} diff --git a/app/acme/dnsprovider/README.md b/app/acme/dnsprovider/README.md index acfb87fc..d14eaf1c 100644 --- a/app/acme/dnsprovider/README.md +++ b/app/acme/dnsprovider/README.md @@ -1,11 +1,17 @@ -# Supported DNS providers: +# DNS Providers for ACME challenge +Following section describes how to configure DNS providers for ACME challenge. Reproxy supports configuration via configuration file and environment variables. -## CloudDNS -- **CLOUDNS_AUTH_ID**: Your CloudDNS Auth ID -- **CLOUDNS_SUB_AUTH_ID**: Your CloudDNS Sub Auth ID (optional, if auth id used) -- **CLOUDNS_AUTH_PASSWORD** : Your CloudDNS Auth Password -- CLOUDNS_TTL: The TTL for the DNS records (default: 300) -- CLOUDNS_DNS_PROPAGATION_TIMEOUT: The time to wait for DNS propagation (in seconds) (default: 180) -- CLOUDNS_DNS_PROPAGATION_CHECK_INTERVAL: The time between DNS propagation status checks (in seconds) (default: 10) +## Supported DNS providers: +### [Amazon Route 53](https://aws.amazon.com/route53/) +- **Access Key ID** yaml: `access_key_id`, env: `ROUTE53_ACCESS_KEY_ID` +- **Secret Access Key** yaml: `secret_access_key`, env: `ROUTE53_SECRET_ACCESS_KEY` +- **Hosted Zone ID** yaml: `hosted_zone_id`, env: `ROUTE53_HOSTED_ZONE_ID` +- TTL (optional, default `300s`) yaml: `ttl`, env: `ROUTE53_TTL` +- Region(optional, default `us-east-1`) yaml: `region`, env: `ROUTE53_REGION` +### [CloudDNS](https://www.cloudns.net/) +- **Authorized User ID** yaml:`auth_id` env:`CLOUDNS_AUTH_ID`` +- **Authorized Subuser ID** yaml:`sub_auth_id` env:"`CLOUDNS_SUB_AUTH_ID` +- **Password** yaml:`password` env:`CLOUDNS_AUTH_PASSWORD` +- TTL (optional, default `300s`) yaml: `ttl` env:`CLOUDNS_TTL` \ No newline at end of file diff --git a/app/acme/dnsprovider/cloudns_test.go b/app/acme/dnsprovider/cloudns_test.go index fda6924e..7be389d1 100644 --- a/app/acme/dnsprovider/cloudns_test.go +++ b/app/acme/dnsprovider/cloudns_test.go @@ -168,12 +168,10 @@ func setupMock() { func Test_newCloudnsProvider(t *testing.T) { type envs struct { - authID string - subAuthID string - authPassword string - TTL string - dnsPropagationTimeout string - dnsPropagationInterval string + authID string + subAuthID string + authPassword string + TTL string } tests := []struct { @@ -182,13 +180,13 @@ func Test_newCloudnsProvider(t *testing.T) { wantErr bool }{ {"envs for authID and subAuthID not set", - envs{"", "", "", "", "", ""}, + envs{"", "", "", ""}, true}, {"env for password not set", - envs{"account", "subaccount", "", "", "", ""}, + envs{"account", "subaccount", "", ""}, true}, {"with default optional parameters", - envs{"account", "subaccount", "init1234", "", "", ""}, + envs{"account", "subaccount", "init1234", ""}, false}, } for _, tt := range tests { @@ -197,8 +195,6 @@ func Test_newCloudnsProvider(t *testing.T) { setEnv(envSubAuthID, tt.envs.subAuthID) setEnv(envAuthPassword, tt.envs.authPassword) setEnv(envTTL, tt.envs.TTL) - setEnv(envDNSPropagationTimeout, tt.envs.dnsPropagationTimeout) - setEnv(envDNSPropagationCheckInteval, tt.envs.dnsPropagationInterval) got, err := newCloudnsProvider(Opts{}) if (err != nil) && !tt.wantErr { @@ -214,12 +210,6 @@ func Test_newCloudnsProvider(t *testing.T) { assert.Equal(t, tt.envs.subAuthID, got.subAuthID, "subAuthID") assert.Equal(t, tt.envs.authPassword, got.authPassword, "authPassword") - expTimeout := time.Second * time.Duration(180) - assert.Equal(t, expTimeout, got.timeout, "dnsPropagationTimeout") - - expInterval := time.Second * time.Duration(10) - assert.Equal(t, expInterval, got.poolingInterval, "dnsPropagationInterval") - os.Unsetenv(envAuthID) os.Unsetenv(envSubAuthID) os.Unsetenv(envAuthPassword) diff --git a/app/acme/dnsprovider/route53.go b/app/acme/dnsprovider/route53.go index f493cad1..611e48e2 100644 --- a/app/acme/dnsprovider/route53.go +++ b/app/acme/dnsprovider/route53.go @@ -1,15 +1,16 @@ package dnsprovider import ( + "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" + "encoding/xml" "fmt" "net/http" "net/url" "sort" - "strconv" "strings" "time" @@ -18,43 +19,45 @@ import ( const route53Endpoint = "https://route53.amazonaws.com" -const ( - payloadXML = ` - - - optional comment about the changes in this change batch request - - - %s - - %s - %s - %s - - - %s - - - - - - - ` -) +// changeRecordsReq is the payload for the ChangeResourceRecordSets API call. +type changeRecordsRequest struct { + XMLName xml.Name `xml:"https://route53.amazonaws.com/doc/2013-04-01/ ChangeResourceRecordSetsRequest"` + Comment string `xml:"ChangeBatch>Comment,omitempty"` + Action string `xml:"ChangeBatch>Changes>Change>Action"` + Name string `xml:"ChangeBatch>Changes>Change>ResourceRecordSet>Name"` + Type string `xml:"ChangeBatch>Changes>Change>ResourceRecordSet>Type"` + TTL int `xml:"ChangeBatch>Changes>Change>ResourceRecordSet>TTL"` + Value string `xml:"ChangeBatch>Changes>Change>ResourceRecordSet>ResourceRecords>ResourceRecord>Value"` +} + +type changeRecordsResponse struct { + XMLName xml.Name `xml:"ChangeResourceRecordSetsResponse"` + Comment string `xml:"ChangeInfo>Comment"` + ID string `xml:"ChangeInfo>Id"` + Status string `xml:"ChangeInfo>Status"` + SubmittedAt string `xml:"ChangeInfo>SubmittedAt"` +} + +type getChangeResponse struct { + XMLName xml.Name `xml:"GetChangeResponse"` + ID string `xml:"ChangeInfo>Id"` + Status string `xml:"ChangeInfo>Status"` + SubmittedAt string `xml:"ChangeInfo>SubmittedAt"` +} type kv struct { key string value string } -// awsRequestParams are the parameters used to make an AWS request. -type awsRequestParams struct { +// awsRequestOpts are the parameters used to make an AWS request. +type awsRequestOpts struct { method string uri string amzTime time.Time queryParams []kv headers []kv - payload string + payload []byte region string service string } @@ -67,15 +70,21 @@ type route53Config struct { Region string `yaml:"region" env:"ROUTE53_REGION"` } +type recrecordWithID struct { + ID string + Record +} + type route53 struct { accessKeyID string secretAccessKey string region string hostedZoneID string timeout time.Duration - interval time.Duration + poolingInterval time.Duration ttl int client *http.Client + addedRecords []recrecordWithID } func newRoute53Provider(opts Opts) (*route53, error) { @@ -84,7 +93,7 @@ func newRoute53Provider(opts Opts) (*route53, error) { // try to read config from file first and fallback to environment variables if err := cleanenv.ReadConfig(opts.ConfigPath, &conf); err != nil { if errc := cleanenv.ReadEnv(&conf); errc != nil { - return nil, fmt.Errorf("route53 provider: failed to read config: %s", err) + return nil, fmt.Errorf("route53: failed to read config: %s", err) } } @@ -93,87 +102,224 @@ func newRoute53Provider(opts Opts) (*route53, error) { secretAccessKey: conf.SecretAccessKey, region: conf.Region, hostedZoneID: conf.HostedZoneID, + addedRecords: []recrecordWithID{}, ttl: conf.TTL, + timeout: opts.Timeout, + poolingInterval: opts.PollingInterval, client: &http.Client{Timeout: opts.Timeout}, }, nil } // AddRecord creates TXT records for the specified FQDN and value. func (r *route53) AddRecord(record Record) error { - t := time.Now() - params := awsRequestParams{ - method: "POST", - uri: "/2013-04-01/hostedzone/" + r.hostedZoneID + "/rrset/", + resp, err := r.changeRecord("UPSERT", record) + if err != nil { + return err + } + + r.addedRecords = append(r.addedRecords, recrecordWithID{ + ID: strings.TrimPrefix(resp.ID, "/change/"), + Record: record, + }) + + return nil +} + +// RemoveRecord removes the TXT records matching the specified FQDN and value. +func (r *route53) RemoveRecord(record Record) error { + _, err := r.changeRecord("DELETE", record) + if err != nil { + return err + } + + recs := r.addedRecords[:0] + for _, rec := range r.addedRecords { + if rec.Host == record.Host && rec.Domain == record.Domain && + rec.Type == record.Type && rec.Value == record.Value { + continue + } + recs = append(recs, rec) + } + r.addedRecords = recs + + return nil +} + +// WaitUntilPropagated waits for the DNS records to propagate. +// The method will be called after creating TXT records. A provider API could be +// used to check propagation status. +func (r *route53) WaitUntilPropagated(ctx context.Context, record Record) error { + ticker := time.NewTicker(r.poolingInterval) + timer := time.NewTimer(r.timeout) + + var changeID string + for _, rec := range r.addedRecords { + if rec.Host == record.Host && rec.Domain == record.Domain && + rec.Type == record.Type && rec.Value == record.Value { + changeID = rec.ID + break + } + } + + if changeID == "" { + return fmt.Errorf("route53: failed to find change ID for record %s", record) + } + + for { + select { + case <-ticker.C: + updated, err := r.isUpdated(changeID) + if err != nil { + return err + } + if updated { + return nil + } + case <-ctx.Done(): + return fmt.Errorf("route53: timeout waiting for DNS propagation") + case <-timer.C: + return fmt.Errorf("route53: timeout waiting for DNS propagation") + } + } +} + +func (r *route53) isUpdated(changeID string) (bool, error) { + t := time.Now().UTC() + + reqOpts := awsRequestOpts{ + method: "GET", + uri: fmt.Sprintf("/2013-04-01/change/%s", changeID), amzTime: t, queryParams: []kv{ - {key: "Id", value: r.hostedZoneID}, + {key: "Action", value: "GetChange"}, + {key: "Id", value: changeID}, + {key: "Version", value: "2013-04-01"}, }, headers: []kv{ {key: "Host", value: "route53.amazonaws.com"}, - {key: "Content-Type", value: "text/xml"}, - {key: "X-Amz-Date", value: t.Format("2006-01-02")}, + {key: "X-Amz-Date", value: t.Format("20060102T150405Z")}, }, - payload: fmt.Sprintf(payloadXML, "CREATE", - fmt.Sprintf("%s%s.", record.Host, record.Domain), record.Type, strconv.Itoa(r.ttl), record.Value), + payload: []byte(""), region: r.region, service: "route53", } - req, err := r.prepareRequest(params) + req, err := r.prepareRequest(reqOpts) if err != nil { - return err + return false, err } resp, err := r.client.Do(req) if err != nil { - return err + return false, err } if resp.StatusCode != http.StatusOK { - return fmt.Errorf("route53 provider: failed to add record: %s", resp.Status) + return false, fmt.Errorf("route53: errorcode by retrieving record status %s", resp.Status) } - return nil -} + var response getChangeResponse + if err := xml.NewDecoder(resp.Body).Decode(&response); err != nil { + return false, err + } -// RemoveRecord removes the TXT records matching the specified FQDN and value. -func (r *route53) RemoveRecord(record Record) error { + if response.Status == "INSYNC" { + return true, nil + } - return nil + return false, fmt.Errorf("route53: status of change: %s (INSYNC required)", response.Status) } -// WaitUntilPropagated waits for the DNS records to propagate. -// The method will be called after creating TXT records. A provider API could be -// used to check propagation status. -func (r *route53) WaitUntilPropagated(ctx context.Context, record Record) error { +func (r *route53) changeRecord(action string, record Record) (*changeRecordsResponse, error) { + t := time.Now().UTC() - return nil -} + payload := changeRecordsRequest{ + Action: action, + Name: fmt.Sprintf("%s.%s.", record.Host, record.Domain), + Type: record.Type, + TTL: r.ttl, + Value: fmt.Sprintf("%q", record.Value), + } + + bp, err := xml.Marshal(payload) + if err != nil { + return nil, err + } + + reqOpts := awsRequestOpts{ + method: "POST", + uri: "/2013-04-01/hostedzone/" + r.hostedZoneID + "/rrset/", + amzTime: t, + queryParams: []kv{ + {key: "Action", value: "ChangeResourceRecordSets"}, + {key: "Id", value: r.hostedZoneID}, + {key: "Version", value: "2013-04-01"}, + }, + headers: []kv{ + {key: "Host", value: "route53.amazonaws.com"}, + {key: "X-Amz-Date", value: t.Format("20060102T150405Z")}, + }, + payload: bp, + region: r.region, + service: "route53", + } -func (r *route53) prepareRequest(params awsRequestParams) (*http.Request, error) { + req, err := r.prepareRequest(reqOpts) + if err != nil { + return nil, err + } - req, err := http.NewRequest(params.method, route53Endpoint+params.uri, http.NoBody) + resp, err := r.client.Do(req) if err != nil { return nil, err } - signature := r.calculateSignature(params) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("route53: failed to add record: %s", resp.Status) + } - hdrs := make([]string, 0, len(params.headers)) - for _, h := range params.headers { + var response changeRecordsResponse + if err := xml.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, err + } + + return &response, nil +} + +func (r *route53) prepareRequest(opts awsRequestOpts) (*http.Request, error) { + req, err := http.NewRequest(opts.method, route53Endpoint+opts.uri, bytes.NewReader(opts.payload)) + if err != nil { + return nil, err + } + + signature := r.calculateSignature(opts) + + hdrs := make([]string, 0, len(opts.headers)) + for _, h := range opts.headers { hdrs = append(hdrs, h.key) } shdrs := strings.Join(hdrs, ";") + cred := fmt.Sprintf("%s/%s/%s/%s/aws4_request", r.accessKeyID, opts.amzTime.Format("20060102"), opts.region, opts.service) req.Header.Set("Authorization", fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s, SignedHeaders=%s, Signature=%s", - r.accessKeyID, shdrs, signature)) + cred, shdrs, signature)) + + for _, h := range opts.headers { + req.Header.Set(h.key, h.value) + } + + q := req.URL.Query() + for _, p := range opts.queryParams { + q.Add(p.key, p.value) + } + req.URL.RawQuery = q.Encode() return req, nil } -func (r *route53) calculateSignature(params awsRequestParams) string { +func (r *route53) calculateSignature(params awsRequestOpts) string { canonicalReq := createCanonicalReq(params) stringToSign := "AWS4-HMAC-SHA256\n" @@ -201,7 +347,7 @@ func (r *route53) calculateSignature(params awsRequestParams) string { return signature } -func createCanonicalReq(params awsRequestParams) string { +func createCanonicalReq(params awsRequestOpts) string { // sort by value, url.Values.Encode takes care of sorting by key sort.Slice(params.queryParams, func(i, j int) bool { return params.queryParams[i].value > params.queryParams[j].value @@ -223,7 +369,7 @@ func createCanonicalReq(params awsRequestParams) string { canHKeys := strings.Join(headKeys, ";") phash := sha256.New() - phash.Write([]byte(params.payload)) + phash.Write(params.payload) payloadHashed := strings.ToLower(fmt.Sprintf("%x", phash.Sum(nil))) canReq := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", params.method, params.uri, canParam, canHead, canHKeys, payloadHashed) diff --git a/app/acme/dnsprovider/route53_test.go b/app/acme/dnsprovider/route53_test.go index 830019b3..3b4102b5 100644 --- a/app/acme/dnsprovider/route53_test.go +++ b/app/acme/dnsprovider/route53_test.go @@ -13,14 +13,14 @@ func Test_createCanonicalReq(t *testing.T) { tests := []struct { name string - args awsRequestParams + args awsRequestOpts want string }{ {"example from amazon documentation", - awsRequestParams{"GET", "/", ttime, + awsRequestOpts{"GET", "/", ttime, []kv{{"Version", "2010-05-08"}, {"Action", "ListUsers"}}, []kv{{"content-type", "application/x-www-form-urlencoded; charset=utf-8"}, {"host", "iam.amazonaws.com"}, {"x-amz-date", "20150830T123600Z"}}, - "", + []byte(""), "us-east-1", "iam"}, "f536975d06c0309214f805bb90ccff089219ecd68b2577efef23edd43b7e1a59"}, @@ -47,15 +47,15 @@ func Test_route53_calculateSignature(t *testing.T) { tests := []struct { name string fields fields - args awsRequestParams + args awsRequestOpts want string }{ {"example from amazon documentation", fields{"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"}, - awsRequestParams{"GET", "/", ttime, + awsRequestOpts{"GET", "/", ttime, []kv{{"Version", "2010-05-08"}, {"Action", "ListUsers"}}, []kv{{"content-type", "application/x-www-form-urlencoded; charset=utf-8"}, {"host", "iam.amazonaws.com"}, {"x-amz-date", "20150830T123600Z"}}, - "", + []byte(""), "us-east-1", "iam"}, "5d672d79c15b13162d9279b0855cfba6789a8edb4c82c400e06b5924a6f2b5d7"}, diff --git a/app/main.go b/app/main.go index 080a5c6d..660a9385 100644 --- a/app/main.go +++ b/app/main.go @@ -40,19 +40,19 @@ var opts struct { LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint SSL struct { - Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` //nolint - Cert string `long:"cert" default:"./var/acme/cert.pem" env:"CERT" description:"path to cert.pem file"` - Key string `long:"key" default:"./var/acme/key.pem" env:"KEY" description:"path to key.pem file"` - ACMELocation string `long:"acme-location" env:"ACME_LOCATION" description:"dir where certificates will be stored by autocert manager" default:"./var/acme"` - ACMEEmail string `long:"acme-email" env:"ACME_EMAIL" description:"admin email for certificate notifications"` - RedirHTTPPort int `long:"http-port" env:"HTTP_PORT" description:"http port for redirect to https and acme challenge test (default: 8080 under docker, 80 without)"` - FQDNs []string `long:"fqdn" env:"ACME_FQDN" env-delim:"," description:"FQDN(s) for ACME certificates"` - DNSChallengeEnabled bool `long:"dns-challenge-enabled" env:"ACME_DNS_CHALLENGE_ENABLED" description:"enable dns challenge"` - DNSProvider string `long:"dns-challenge-provider" env:"ACME_DNS_CHALLENGE_PROVIDER" description:"DNS provider" choice:"cloudns" choice:"cloudflare" choice:"route53" default:"cloudns"` //nolint - DNSResolvers []string `long:"dns-challenge-resolvers" env-delim:"," env:"ACME_DNS_CHALLENGE_RESOLVERS" description:"DNS resolvers" ` - DNSChallengeTimeout int `long:"dns-challenge-timeout" env:"ACME_DNS_CHALLENGE_TIMEOUT" description:"DNS challenge timeout in seconds" default:"180"` - DNSChallengeInterval int `long:"dns-challenge-interval" env:"ACME_DNS_CHALLENGE_INTERVAL" description:"DNS challenge polling interval in seconds" default:"10"` - DNSProviderConf string `long:"dns-provider-config" env:"SSL_ACME_DNS_PROVIDER_CONFIG" description:"path to DNS provider config file"` + Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` //nolint + Cert string `long:"cert" default:"./var/acme/cert.pem" env:"CERT" description:"path to cert.pem file"` + Key string `long:"key" default:"./var/acme/key.pem" env:"KEY" description:"path to key.pem file"` + ACMELocation string `long:"acme-location" env:"ACME_LOCATION" description:"dir where certificates will be stored by autocert manager" default:"./var/acme"` + ACMEEmail string `long:"acme-email" env:"ACME_EMAIL" description:"admin email for certificate notifications"` + RedirHTTPPort int `long:"http-port" env:"HTTP_PORT" description:"http port for redirect to https and acme challenge test (default: 8080 under docker, 80 without)"` + FQDNs []string `long:"fqdn" env:"ACME_FQDN" env-delim:"," description:"FQDN(s) for ACME certificates"` + DNSChallengeEnabled bool `long:"dns-challenge-enabled" env:"ACME_DNS_CHALLENGE_ENABLED" description:"enable dns challenge"` + DNSProvider string `long:"dns-challenge-provider" env:"ACME_DNS_CHALLENGE_PROVIDER" description:"DNS provider" choice:"cloudns" choice:"cloudflare" choice:"route53" default:"cloudns"` //nolint + DNSResolvers []string `long:"dns-challenge-resolvers" env-delim:"," env:"ACME_DNS_CHALLENGE_RESOLVERS" description:"DNS resolvers" ` + DNSChallengeTimeout time.Duration `long:"dns-challenge-timeout" env:"ACME_DNS_CHALLENGE_TIMEOUT" description:"DNS challenge timeout in seconds" default:"300s"` + DNSChallengeInterval time.Duration `long:"dns-challenge-interval" env:"ACME_DNS_CHALLENGE_INTERVAL" description:"DNS challenge polling interval in seconds" default:"10s"` + DNSProviderConf string `long:"dns-provider-config" env:"SSL_ACME_DNS_PROVIDER_CONFIG" description:"path to DNS provider config file"` } `group:"ssl" namespace:"ssl" env-namespace:"SSL"` Assets struct { @@ -214,8 +214,10 @@ func run() error { Domains: domains, Nameservers: opts.SSL.DNSResolvers, ProviderConfig: opts.SSL.DNSProviderConf, - Timeout: time.Second * time.Duration(opts.SSL.DNSChallengeTimeout), - PollingInterval: time.Second * time.Duration(opts.SSL.DNSChallengeInterval), + Timeout: opts.SSL.DNSChallengeTimeout, + PollingInterval: opts.SSL.DNSChallengeInterval, + CertPath: opts.SSL.Cert, + KeyPath: opts.SSL.Key, } var dc *acme.DNSChallenge @@ -225,7 +227,7 @@ func run() error { } if dc != nil { - acme.ScheduleCertificateRenewal(dc, dcc.Timeout) + acme.ScheduleCertificateRenewal(context.Background(), dc, opts.SSL.Cert) } }