-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathcmd_workflow.go
298 lines (258 loc) · 8.42 KB
/
cmd_workflow.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
package sparta
import (
"context"
"fmt"
"strings"
"sync"
"time"
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
////////////////////////////////////////////////////////////////////////////////
// CONSTANTS
////////////////////////////////////////////////////////////////////////////////

// These tag keys are package-level vars rather than consts because their
// values are computed by spartaTagName during package initialization.
var (
	// SpartaTagBuildIDKey is the keyname used in the CloudFormation Output
	// that stores the user-supplied or automatically generated BuildID
	// for this run
	SpartaTagBuildIDKey = spartaTagName("buildId")
	// SpartaTagSpartaVersionKey is the keyname that records the Sparta
	// version used to provision the application
	SpartaTagSpartaVersionKey = spartaTagName("sparta-version")
	// SpartaTagBuildTagsKey is the keyname used in the CloudFormation Output
	// that stores the optional user-supplied golang build tags
	SpartaTagBuildTagsKey = spartaTagName("buildTags")
)
const (
	// MetadataParamCloudFormationStackPath is the path to the template
	MetadataParamCloudFormationStackPath = "CloudFormationStackPath"
	// MetadataParamServiceName is the name of the stack to use
	MetadataParamServiceName = "ServiceName"
	// MetadataParamS3Bucket is the Metadata param we use for the bucket
	MetadataParamS3Bucket = "ArtifactS3Bucket"

	// Metadata params for a ZIP archive
	//
	// MetadataParamCodeArchivePath is the intermediate local path to the code
	MetadataParamCodeArchivePath = "CodeArchivePath"
	// MetadataParamS3SiteArchivePath is the intermediate local path to the S3 site contents
	MetadataParamS3SiteArchivePath = "S3SiteArtifactPath"

	// Metadata params for OCI builds
	//
	// MetadataParamECRTag is the locally tagged Docker image to push
	MetadataParamECRTag = "ECRTag"
)

const (
	// StackParamS3CodeKeyName is the Stack Parameter to the S3 key of the uploaded asset
	StackParamS3CodeKeyName = "CodeArtifactS3Key"
	// StackParamArtifactBucketName is the Stack Parameter naming the bucket the
	// artifact was uploaded to. It deliberately shares its value with
	// MetadataParamS3Bucket.
	StackParamArtifactBucketName = MetadataParamS3Bucket
	// StackParamS3CodeVersion is the object version to use for the S3 item
	StackParamS3CodeVersion = "CodeArtifactS3ObjectVersion"
	// StackParamS3SiteArchiveKey is the param to the S3 archive for a static website.
	StackParamS3SiteArchiveKey = "SiteArtifactS3Key"
	// StackParamS3SiteArchiveVersion is the version of the S3 artifact to use
	StackParamS3SiteArchiveVersion = "SiteArtifactS3ObjectVersion"
	// StackParamCodeImageURI is the ImageURI to the uploaded image
	StackParamCodeImageURI = "CodeImageURI"
)

const (
	// StackOutputBuildTime is the Output param for when this template was built
	StackOutputBuildTime = "TemplateCreationTime"
	// StackOutputBuildID is the Output tag that holds the build id
	StackOutputBuildID = "BuildID"
)
// showOptionalAWSUsageInfo writes troubleshooting hints at error level when
// err is, or wraps, an *awsv2.MissingRegionError. A nil error or any other
// error type produces no output.
func showOptionalAWSUsageInfo(err error, logger *zerolog.Logger) {
	var regionErr *awsv2.MissingRegionError
	if err == nil || !errors.As(err, &regionErr) {
		return
	}
	// The leading and trailing empty messages presumably exist to visually
	// separate the hints from surrounding log output.
	usageHints := []string{
		"",
		"Consider setting env.AWS_REGION, env.AWS_DEFAULT_REGION, or env.AWS_SDK_LOAD_CONFIG to resolve this issue.",
		"See the documentation at https://pkg.go.dev/github.com/aws/aws-sdk-go-v2 for more information.",
		"",
	}
	for _, eachHint := range usageHints {
		logger.Error().Msg(eachHint)
	}
}
// spartaTagName namespaces baseKey under the "io:sparta:" prefix, e.g.
// "buildId" becomes "io:sparta:buildId".
func spartaTagName(baseKey string) string {
	const spartaTagPrefix = "io:sparta:"
	// Sprint inserts no separator between adjacent string operands.
	return fmt.Sprint(spartaTagPrefix, baseKey)
}
// sanitizedName returns input with every match of the package-level
// reSanitize pattern replaced by an underscore. The replacement contains no
// "$" template references, so the literal variant is equivalent to
// ReplaceAllString here.
func sanitizedName(input string) string {
	const replacement = "_"
	return reSanitize.ReplaceAllLiteralString(input, replacement)
}
type (
	// pipelineBaseOp is one unit of work inside a pipeline stage. Invoke
	// performs the work; Rollback is expected to undo it. In this file,
	// pipeline.Run rolls back a failed stage and every earlier stage.
	pipelineBaseOp interface {
		Invoke(ctx context.Context, logger *zerolog.Logger) error
		Rollback(ctx context.Context, logger *zerolog.Logger) error
	}

	// pipelineStageBase is a named step of a pipeline that groups ops.
	// Append registers an op and returns the stage so calls can be chained.
	pipelineStageBase interface {
		Run(ctx context.Context, logger *zerolog.Logger) error
		Append(opName string, op pipelineBaseOp) pipelineStageBase
		Rollback(ctx context.Context, logger *zerolog.Logger) error
	}
)
type (
	// pipelineStageOpEntry pairs an op with the name used when reporting
	// its errors.
	pipelineStageOpEntry struct {
		opName string
		op     pipelineBaseOp
	}

	// pipelineStage is the concrete pipelineStageBase. Its ops are held in
	// insertion order.
	pipelineStage struct {
		ops []*pipelineStageOpEntry
	}
)
// Append adds op to the end of the stage under opName and returns the same
// stage so further ops can be chained.
func (ps *pipelineStage) Append(opName string, op pipelineBaseOp) pipelineStageBase {
	newEntry := &pipelineStageOpEntry{opName: opName, op: op}
	ps.ops = append(ps.ops, newEntry)
	return ps
}
// Run invokes every op in the stage concurrently and waits for all of them
// to finish. If any op fails, it returns a single error whose text lists
// each failure as "opName=>error", joined by ", ", in the order the ops were
// appended. It returns nil when every op succeeds.
func (ps *pipelineStage) Run(ctx context.Context, logger *zerolog.Logger) error {
	// One slot per op: each goroutine writes only its own index, so no lock
	// is needed. Unlike a map keyed by opName, this keeps every error when
	// two ops share a name and yields a deterministic message order.
	opErrors := make([]error, len(ps.ops))
	var wg sync.WaitGroup
	for eachIndex, eachEntry := range ps.ops {
		wg.Add(1)
		go func(opIndex int, opEntry *pipelineStageOpEntry, goLogger *zerolog.Logger) {
			defer wg.Done()
			opErrors[opIndex] = opEntry.op.Invoke(ctx, goLogger)
		}(eachIndex, eachEntry, logger)
	}
	wg.Wait()

	// Were there any errors?
	var errorText []string
	for opIndex, opErr := range opErrors {
		if opErr != nil {
			errorText = append(errorText, fmt.Sprintf("%s=>%v",
				ps.ops[opIndex].opName,
				opErr))
		}
	}
	if len(errorText) != 0 {
		return errors.New(strings.Join(errorText, ", "))
	}
	return nil
}
// Rollback asks every op in the stage to roll back concurrently and waits
// for all of them. An op's rollback failure is logged as a warning and not
// propagated, so this method always returns nil.
func (ps *pipelineStage) Rollback(ctx context.Context, logger *zerolog.Logger) error {
	logger.Debug().Msgf("Rolling back %T due to errors", ps)

	var pending sync.WaitGroup
	pending.Add(len(ps.ops))
	for _, entry := range ps.ops {
		go func(target *pipelineStageOpEntry) {
			defer pending.Done()
			if err := target.op.Rollback(ctx, logger); err != nil {
				logger.Warn().Msgf("Operation (%s) rollback failed: %s", target.opName, err)
			}
		}(entry)
	}
	pending.Wait()
	return nil
}
type (
	// pipelineStageEntry is a named stage within a pipeline. duration is
	// filled in by pipeline.Run after the stage's Run succeeds.
	pipelineStageEntry struct {
		stageName string
		stage     pipelineStageBase
		duration  time.Duration
	}

	// pipeline runs its stages sequentially, in the order they were appended.
	// startTime is set when Run begins.
	pipeline struct {
		stages    []*pipelineStageEntry
		startTime time.Time
	}
)
// Append adds stage to the end of the pipeline under stageName and returns
// the pipeline so further stages can be chained.
func (p *pipeline) Append(stageName string, stage pipelineStageBase) *pipeline {
	newEntry := &pipelineStageEntry{stageName: stageName, stage: stage}
	p.stages = append(p.stages, newEntry)
	return p
}
// Run executes the stages sequentially. If a stage's Run returns an error,
// that stage and every earlier stage are rolled back, most recent first, and
// the failing stage's error is returned. A rollback failure is logged as a
// warning, naming the stage whose Rollback failed, and does not change the
// returned error. On success, each stage's duration is logged at debug level.
//
// The name argument is not used by this implementation.
func (p *pipeline) Run(ctx context.Context,
	name string,
	logger *zerolog.Logger) error {
	p.startTime = time.Now()
	// Run the stages; if one fails, roll back.
	for stageIndex, curStage := range p.stages {
		stageStart := time.Now()
		stageErr := curStage.stage.Run(ctx, logger)
		if stageErr != nil {
			logger.Error().Msgf("Pipeline stage %s failed", curStage.stageName)
			// Unwind in reverse order, from the failing stage back to the first.
			for index := stageIndex; index >= 0; index-- {
				rollbackEntry := p.stages[index]
				rollbackErr := rollbackEntry.stage.Rollback(ctx, logger)
				if rollbackErr != nil {
					// Report the stage whose Rollback failed, not the stage
					// whose Run triggered the unwind.
					logger.Warn().
						Err(rollbackErr).
						Msgf("Pipeline stage %s failed to Rollback", rollbackEntry.stageName)
				}
			}
			return stageErr
		}
		curStage.duration = time.Since(stageStart)
	}
	// Log the per-stage execution times...
	logger.Debug().Msg(headerDivider)
	for _, eachStageEntry := range p.stages {
		logger.Debug().
			Str("Name", eachStageEntry.stageName).
			Str("Duration", eachStageEntry.duration.String()).
			Msg("Stage duration")
	}
	return nil
}
////////////////////////////////////////////////////////////////////////////////
// Common stages
////////////////////////////////////////////////////////////////////////////////
// userFunctionRollbackOp is a pipelineBaseOp whose Invoke does nothing and
// whose Rollback calls every handler in rollbackFuncs concurrently, passing
// along serviceName, awsConfig, and noop.
type userFunctionRollbackOp struct {
// serviceName is forwarded to each handler's Rollback.
serviceName string
// awsConfig is forwarded to each handler's Rollback.
awsConfig awsv2.Config
// noop is forwarded to each handler; presumably a dry-run flag — confirm
// against RollbackHookHandler.
noop bool
// rollbackFuncs are the user-supplied hooks run on Rollback.
rollbackFuncs []RollbackHookHandler
}
// Rollback runs every user-supplied rollback hook concurrently and waits for
// all of them to return. A hook's error is logged as a warning, tagged with
// the hook's dynamic type, and is not propagated; this method always returns
// nil. Each hook's first return value is ignored.
func (ufro *userFunctionRollbackOp) Rollback(ctx context.Context, logger *zerolog.Logger) error {
	serviceName, cfg, noop := ufro.serviceName, ufro.awsConfig, ufro.noop

	var pending sync.WaitGroup
	pending.Add(len(ufro.rollbackFuncs))
	for _, hook := range ufro.rollbackFuncs {
		go func(handler RollbackHookHandler) {
			defer pending.Done()
			if _, err := handler.Rollback(ctx, serviceName, cfg, noop, logger); err != nil {
				logger.Warn().
					Err(err).
					Str("Function", fmt.Sprintf("%T", handler)).
					Msg("Rollback function failed")
			}
		}(hook)
	}
	pending.Wait()
	return nil
}
// Invoke satisfies pipelineBaseOp. This op has no forward work, so it always
// succeeds; the user hooks run only from Rollback.
func (ufro *userFunctionRollbackOp) Invoke(_ context.Context, _ *zerolog.Logger) error {
	return nil
}
// newUserRollbackEnabledPipeline builds a pipeline with a single stage,
// "userRollbackFunctions", that holds one userFunctionRollbackOp configured
// with the given arguments. Because that op's Invoke does nothing, running
// the pipeline has no side effects. The rollbackFuncs hooks fire only when
// the stage's Rollback is called.
func newUserRollbackEnabledPipeline(serviceName string,
	config awsv2.Config,
	rollbackFuncs []RollbackHookHandler,
	noop bool) *pipeline {
	userRollbackOp := &userFunctionRollbackOp{
		serviceName:   serviceName,
		awsConfig:     config,
		noop:          noop,
		rollbackFuncs: rollbackFuncs,
	}
	userRollbackStage := &pipelineStage{}
	userRollbackStage.Append("userRollbackFunctions", userRollbackOp)
	return (&pipeline{}).Append("userRollbackFunctions", userRollbackStage)
}