-
-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
2,862 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,5 @@ docker-compose-private.yml | |
.vscode | ||
.idea | ||
*.gpg | ||
.DS_Store | ||
*.pem |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package acme | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"log" | ||
"time" | ||
|
||
"github.com/go-pkgz/repeater" | ||
) | ||
|
||
var acmeOpTimeout = 5 * time.Minute | ||
|
||
// Solver is an interface for solving ACME DNS challenge | ||
type Solver interface { | ||
// PreSolve is called before solving the challenge. ACME Order will be created and DNS record will be added. | ||
PreSolve(ctx context.Context) error | ||
// Solve is called to present TXT record and accept challenge. | ||
Solve(ctx context.Context) error | ||
// PostSolve is called after obtaining the certificate. | ||
PostSolve(ctx context.Context) error | ||
// GetCertificateExpiration returns certificate expiration date | ||
GetCertificateExpiration(certPath string) (time.Time, error) | ||
} | ||
|
||
// fqdns []string, provider string, nameservers []string | ||
|
||
// ScheduleCertificateRenewal schedules certificate renewal | ||
func ScheduleCertificateRenewal(solver Solver) { | ||
certPath := getEnvOptionalString("SSL_CERT", "./var/acme/cert.pem") | ||
|
||
go func(certPath string) { | ||
var ( | ||
expiredAt time.Time | ||
err error | ||
) | ||
|
||
dur := acmeOpTimeout >> 12 | ||
fmt.Println(dur.Milliseconds()) | ||
|
||
expiredAt, err = solver.GetCertificateExpiration(certPath) | ||
if err != nil { | ||
expiredAt = time.Now() | ||
log.Printf("[INFO] failed to get certificate expiration date, probably not obtained yet: %v", err) | ||
} | ||
|
||
for { | ||
<-time.After(time.Until(expiredAt.Add(time.Hour * 24 * -5))) | ||
|
||
// add DNS record and wait for propagation | ||
{ | ||
ctx, cancel := context.WithTimeout(context.Background(), acmeOpTimeout) | ||
err = repeater.NewDefault(10, acmeOpTimeout>>12).Do(ctx, func() error { | ||
if errc := solver.PreSolve(ctx); errc != nil { | ||
log.Printf("[INFO] error in ACME DNS Challenge Presolve: %v", errc) | ||
return errc | ||
} | ||
return nil | ||
}) | ||
cancel() | ||
if err != nil { | ||
log.Printf("[ERROR] ACME DNS Challenge Presolve failed. Last error %v", err) | ||
return | ||
} | ||
} | ||
|
||
// present TXT record and accept challenge | ||
{ | ||
ctx, cancel := context.WithTimeout(context.Background(), acmeOpTimeout) | ||
err = repeater.NewDefault(10, acmeOpTimeout>>12).Do(ctx, func() error { | ||
if errc := solver.Solve(ctx); errc != nil { | ||
log.Printf("[INFO] error in ACME DNS Challenge Solve: %v", errc) | ||
return errc | ||
} | ||
return nil | ||
}) | ||
cancel() | ||
if err != nil { | ||
log.Printf("[ERROR] retry limit reached ACME DNS Challenge Solve failed. Last error: %v", err) | ||
return | ||
} | ||
} | ||
|
||
// pull the certificate | ||
{ | ||
ctx, cancel := context.WithTimeout(context.Background(), acmeOpTimeout) | ||
err = repeater.NewDefault(10, acmeOpTimeout>>12).Do(ctx, func() error { | ||
if errc := solver.PostSolve(ctx); errc != nil { | ||
log.Printf("[INFO] error in ACME DNS Challenge PostSolve: %v", errc) | ||
return errc | ||
} | ||
return nil | ||
}) | ||
cancel() | ||
if err != nil { | ||
log.Printf("[ERROR] retry limit reached, ACME DNS Challenge PostSolve failed. Last error: %v", err) | ||
return | ||
} | ||
} | ||
|
||
expiredAt, err = solver.GetCertificateExpiration(certPath) | ||
if err != nil { | ||
log.Printf("[ERROR] failed to get certificate expiration date: %v", err) | ||
return | ||
} | ||
} | ||
}(certPath) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package acme | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
type mockSolver struct { | ||
domain string | ||
expires time.Time | ||
preSolvedCalled int | ||
solveCalled int | ||
postSolvedCalled int | ||
} | ||
|
||
func (s *mockSolver) PreSolve(ctx context.Context) error { | ||
s.preSolvedCalled++ | ||
switch s.domain { | ||
case "mycompany1.com": | ||
return fmt.Errorf("preSolve failed") | ||
} | ||
return nil | ||
} | ||
|
||
func (s *mockSolver) Solve(ctx context.Context) error { | ||
s.solveCalled++ | ||
switch s.domain { | ||
case "mycompany2.com": | ||
return fmt.Errorf("solve failed") | ||
} | ||
return nil | ||
} | ||
|
||
func (s *mockSolver) PostSolve(ctx context.Context) error { | ||
s.postSolvedCalled++ | ||
switch s.domain { | ||
case "mycompany3.com": | ||
return fmt.Errorf("postSolved failed") | ||
} | ||
return nil | ||
} | ||
|
||
func (s *mockSolver) GetCertificateExpiration(certPath string) (time.Time, error) { | ||
// check called before loop starts | ||
if s.preSolvedCalled == 0 { | ||
switch s.domain { | ||
case "mycompany4.com": | ||
return time.Now().Add(time.Hour * 24 * 670), nil | ||
default: | ||
return time.Time{}, fmt.Errorf("certificate does not exist") | ||
} | ||
} | ||
return time.Now().Add(time.Hour * 24 * 365), nil | ||
} | ||
|
||
func TestScheduleCertificateRenewal(t *testing.T) { | ||
acmeOpTimeout = 15 * time.Second | ||
|
||
type args struct { | ||
domain string | ||
certExistedBefore bool | ||
expiryTime time.Time | ||
} | ||
|
||
type expected struct { | ||
preSolvedCalled int | ||
solveCalled int | ||
postSolvedCalled int | ||
} | ||
|
||
tests := []struct { | ||
name string | ||
args args | ||
expected expected | ||
}{ | ||
// {"certificate not existed before", | ||
// args{"example.com", false, time.Now().Add(time.Hour * 100 * 24)}, | ||
// expected{1, 1, 1}}, | ||
// {"presolve failed", | ||
// args{"mycompany1.com", false, time.Time{}}, | ||
// expected{10, 0, 0}}, | ||
// {"solve failed", | ||
// args{"mycompany2.com", false, time.Time{}}, | ||
// expected{1, 10, 0}}, | ||
{"postsolve failed", | ||
args{"mycompany3.com", false, time.Time{}}, | ||
expected{1, 1, 10}}, | ||
// {"certificate valid for a long time", | ||
// args{"mycompany4.com", false, time.Time{}}, | ||
// expected{0, 0, 0}}, | ||
} | ||
|
||
for _, tt := range tests { | ||
s := &mockSolver{ | ||
domain: tt.args.domain, | ||
expires: tt.args.expiryTime, | ||
} | ||
|
||
ScheduleCertificateRenewal(s) | ||
|
||
time.Sleep(acmeOpTimeout) | ||
assert.Equal(t, tt.expected.preSolvedCalled, s.preSolvedCalled, fmt.Sprintf("[case %s] preSolvedCalled not match", tt.name)) | ||
assert.Equal(t, tt.expected.solveCalled, s.solveCalled, fmt.Sprintf("[case %s] solveCalled not match", tt.name)) | ||
assert.Equal(t, tt.expected.postSolvedCalled, s.postSolvedCalled, fmt.Sprintf("[case %s] postSolvedCalled not match", tt.name)) | ||
} | ||
} |
Oops, something went wrong.