diff --git a/internal/pkg/cli/job_init.go b/internal/pkg/cli/job_init.go index 462040b69b4..1c45b0c17ba 100644 --- a/internal/pkg/cli/job_init.go +++ b/internal/pkg/cli/job_init.go @@ -208,7 +208,7 @@ func (o *initJobOpts) Execute() error { DockerfilePath: o.dockerfilePath, Image: o.image, Platform: manifest.PlatformArgsOrString{ - PlatformString: o.platform, + PlatformString: manifest.PlatformStringP(o.platform), }, }, diff --git a/internal/pkg/cli/svc_init.go b/internal/pkg/cli/svc_init.go index a8cf37caa94..2781ce7a595 100644 --- a/internal/pkg/cli/svc_init.go +++ b/internal/pkg/cli/svc_init.go @@ -276,7 +276,7 @@ func (o *initSvcOpts) Execute() error { DockerfilePath: o.dockerfilePath, Image: o.image, Platform: manifest.PlatformArgsOrString{ - PlatformString: o.platform, + PlatformString: manifest.PlatformStringP(o.platform), }, Topics: o.topics, }, diff --git a/internal/pkg/deploy/cloudformation/stack/backend_svc_test.go b/internal/pkg/deploy/cloudformation/stack/backend_svc_test.go index c139e035f5b..3a757731d88 100644 --- a/internal/pkg/deploy/cloudformation/stack/backend_svc_test.go +++ b/internal/pkg/deploy/cloudformation/stack/backend_svc_test.go @@ -245,7 +245,8 @@ Outputs: if tc.setUpManifest != nil { tc.setUpManifest(conf) - conf.manifest.Network.VPC.Placement = aws.String(manifest.PrivateSubnetPlacement) + privatePlacement := manifest.Placement(manifest.PrivateSubnetPlacement) + conf.manifest.Network.VPC.Placement = &privatePlacement conf.manifest.Network.VPC.SecurityGroups = []string{"sg-1234"} } diff --git a/internal/pkg/deploy/cloudformation/stack/lb_web_svc.go b/internal/pkg/deploy/cloudformation/stack/lb_web_svc.go index 6c2becf0d17..605a35e1b35 100644 --- a/internal/pkg/deploy/cloudformation/stack/lb_web_svc.go +++ b/internal/pkg/deploy/cloudformation/stack/lb_web_svc.go @@ -163,7 +163,10 @@ func (s *LoadBalancedWebService) Template() (string, error) { deregistrationDelay = aws.Int64(int64(s.manifest.RoutingRule.DeregistrationDelay.Seconds())) } - allowedSourceIPs := s.manifest.AllowedSourceIps + var allowedSourceIPs []string + for _, ipNet := range s.manifest.AllowedSourceIps { + allowedSourceIPs = append(allowedSourceIPs, string(ipNet)) + } content, err := s.parser.ParseLoadBalancedWebService(template.WorkloadOpts{ Variables: s.manifest.Variables, Secrets: s.manifest.Secrets, diff --git a/internal/pkg/deploy/cloudformation/stack/transformers.go b/internal/pkg/deploy/cloudformation/stack/transformers.go index fa61128503b..c5d34e850bb 100644 --- a/internal/pkg/deploy/cloudformation/stack/transformers.go +++ b/internal/pkg/deploy/cloudformation/stack/transformers.go @@ -619,7 +619,10 @@ func convertNetworkConfig(network manifest.NetworkConfig) *template.NetworkOpts SubnetsType: template.PublicSubnetsPlacement, SecurityGroups: network.VPC.SecurityGroups, } - if aws.StringValue(network.VPC.Placement) != manifest.PublicSubnetPlacement { + if network.VPC.Placement == nil { + return opts + } + if *network.VPC.Placement != manifest.PublicSubnetPlacement { opts.AssignPublicIP = template.DisablePublicIP opts.SubnetsType = template.PrivateSubnetsPlacement } diff --git a/internal/pkg/deploy/cloudformation/stack/transformers_test.go b/internal/pkg/deploy/cloudformation/stack/transformers_test.go index 27ab594a7f7..64f04e44e1a 100644 --- a/internal/pkg/deploy/cloudformation/stack/transformers_test.go +++ b/internal/pkg/deploy/cloudformation/stack/transformers_test.go @@ -303,6 +303,7 @@ func Test_convertSidecar(t *testing.T) { func Test_convertAdvancedCount(t *testing.T) { mockRange := manifest.IntRangeBand("1-10") + mockPerc := manifest.Percentage(70) testCases := map[string]struct { input *manifest.AdvancedCount expected *template.AdvancedCount @@ -335,7 +336,7 @@ func Test_convertAdvancedCount(t *testing.T) { Range: manifest.Range{ Value: &mockRange, }, - CPU: aws.Int(70), + CPU: &mockPerc, }, expected: &template.AdvancedCount{ Autoscaling: &template.AutoscalingOpts{ @@ -354,7 +355,7 @@ func Test_convertAdvancedCount(t *testing.T) { SpotFrom: aws.Int(5), }, }, - CPU: aws.Int(70), + CPU: &mockPerc, }, expected: &template.AdvancedCount{ Autoscaling: &template.AutoscalingOpts{ @@ -486,10 +487,14 @@ func Test_convertCapacityProviders(t *testing.T) { } func Test_convertAutoscaling(t *testing.T) { - mockRange := manifest.IntRangeBand("1-100") - badRange := manifest.IntRangeBand("badRange") - mockRequests := 1000 - mockResponseTime := 512 * time.Millisecond + var ( + mockRange = manifest.IntRangeBand("1-100") + badRange = manifest.IntRangeBand("badRange") + mockRequests = 1000 + mockResponseTime = 512 * time.Millisecond + mockCPU = manifest.Percentage(70) + mockMem = manifest.Percentage(80) + ) testAcceptableLatency := 10 * time.Minute testAvgProcessingTime := 250 * time.Millisecond @@ -513,8 +518,8 @@ func Test_convertAutoscaling(t *testing.T) { Range: manifest.Range{ Value: &mockRange, }, - CPU: aws.Int(70), - Memory: aws.Int(80), + CPU: &mockCPU, + Memory: &mockMem, Requests: aws.Int(mockRequests), ResponseTime: &mockResponseTime, }, @@ -537,8 +542,8 @@ func Test_convertAutoscaling(t *testing.T) { SpotFrom: aws.Int(5), }, }, - CPU: aws.Int(70), - Memory: aws.Int(80), + CPU: &mockCPU, + Memory: &mockMem, Requests: aws.Int(mockRequests), ResponseTime: &mockResponseTime, }, diff --git a/internal/pkg/deploy/cloudformation/stack/worker_svc_test.go b/internal/pkg/deploy/cloudformation/stack/worker_svc_test.go index 384b01bbfff..171a6ea5523 100644 --- a/internal/pkg/deploy/cloudformation/stack/worker_svc_test.go +++ b/internal/pkg/deploy/cloudformation/stack/worker_svc_test.go @@ -281,7 +281,7 @@ Outputs: if tc.setUpManifest != nil { tc.setUpManifest(conf) - conf.manifest.Network.VPC.Placement = aws.String(manifest.PrivateSubnetPlacement) + conf.manifest.Network.VPC.Placement = &manifest.PrivateSubnetPlacement conf.manifest.Network.VPC.SecurityGroups = []string{"sg-1234"} } diff --git a/internal/pkg/manifest/applyenv_test.go b/internal/pkg/manifest/applyenv_test.go index 31ad2eca600..22730a81de5 100644 --- a/internal/pkg/manifest/applyenv_test.go +++ b/internal/pkg/manifest/applyenv_test.go @@ -525,28 +525,28 @@ func TestApplyEnv_StringSlice(t *testing.T) { }{ "string slice overridden": { inSvc: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{"walk", "like", "an", "egyptian"} - svc.Environments["test"].RoutingRule.AllowedSourceIps = []string{"walk", "on", "the", "wild", "side"} + svc.ImageConfig.HealthCheck.Command = []string{"walk", "like", "an", "egyptian"} + svc.Environments["test"].ImageConfig.HealthCheck.Command = []string{"walk", "on", "the", "wild", "side"} }, wanted: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{"walk", "on", "the", "wild", "side"} + svc.ImageConfig.HealthCheck.Command = []string{"walk", "on", "the", "wild", "side"} }, }, "string slice overridden by zero value": { inSvc: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{"walk", "like", "an", "egyptian"} - svc.Environments["test"].RoutingRule.AllowedSourceIps = []string{} + svc.ImageConfig.HealthCheck.Command = []string{"walk", "like", "an", "egyptian"} + svc.Environments["test"].ImageConfig.HealthCheck.Command = []string{} }, wanted: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{} + svc.ImageConfig.HealthCheck.Command = []string{} }, }, "string slice not overridden": { inSvc: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{"walk", "like", "an", "egyptian"} + svc.ImageConfig.HealthCheck.Command = []string{"walk", "like", "an", "egyptian"} }, wanted: func(svc *LoadBalancedWebService) { - svc.RoutingRule.AllowedSourceIps = []string{"walk", "like", "an", "egyptian"} + svc.ImageConfig.HealthCheck.Command = []string{"walk", "like", "an", "egyptian"} }, }, } diff --git a/internal/pkg/manifest/backend_svc.go b/internal/pkg/manifest/backend_svc.go index 1c80eea2b55..d54e97915b0 100644 --- a/internal/pkg/manifest/backend_svc.go +++ b/internal/pkg/manifest/backend_svc.go @@ -144,7 +144,7 @@ func newDefaultBackendService() *BackendService { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, diff --git a/internal/pkg/manifest/backend_svc_test.go b/internal/pkg/manifest/backend_svc_test.go index 7bc740f675c..58fd097dfa2 100644 --- a/internal/pkg/manifest/backend_svc_test.go +++ b/internal/pkg/manifest/backend_svc_test.go @@ -56,7 +56,7 @@ func TestNewBackendSvc(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -102,7 +102,7 @@ func TestNewBackendSvc(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -255,6 +255,7 @@ func TestBackendService_Publish(t *testing.T) { } func TestBackendSvc_ApplyEnv(t *testing.T) { + mockPercentage := Percentage(70) mockBackendServiceWithNoEnvironments := BackendService{ Workload: Workload{ Name: aws.String("phonetool"), @@ -366,7 +367,7 @@ func TestBackendSvc_ApplyEnv(t *testing.T) { TaskConfig: TaskConfig{ Count: Count{ AdvancedCount: AdvancedCount{ - CPU: aws.Int(70), + CPU: &mockPercentage, }, }, CPU: aws.Int(512), @@ -562,7 +563,7 @@ func TestBackendSvc_ApplyEnv(t *testing.T) { Memory: aws.Int(256), Count: Count{ AdvancedCount: AdvancedCount{ - CPU: aws.Int(70), + CPU: &mockPercentage, }, }, Variables: map[string]string{ @@ -689,6 +690,7 @@ func TestBackendSvc_ApplyEnv(t *testing.T) { func TestBackendSvc_ApplyEnv_CountOverrides(t *testing.T) { mockRange := IntRangeBand("1-10") + mockPercentage := Percentage(80) testCases := map[string]struct { svcCount Count envCount Count @@ -699,7 +701,7 @@ func TestBackendSvc_ApplyEnv_CountOverrides(t *testing.T) { svcCount: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPercentage, }, }, envCount: Count{}, @@ -709,7 +711,7 @@ func TestBackendSvc_ApplyEnv_CountOverrides(t *testing.T) { Count: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPercentage, }, }, }, diff --git a/internal/pkg/manifest/job.go b/internal/pkg/manifest/job.go index 7656ff711ba..6d9d090a224 100644 --- a/internal/pkg/manifest/job.go +++ b/internal/pkg/manifest/job.go @@ -157,7 +157,7 @@ func newDefaultScheduledJob() *ScheduledJob { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, diff --git a/internal/pkg/manifest/job_test.go b/internal/pkg/manifest/job_test.go index e62266aabe8..bc9b9513878 100644 --- a/internal/pkg/manifest/job_test.go +++ b/internal/pkg/manifest/job_test.go @@ -125,7 +125,7 @@ func TestScheduledJob_ApplyEnv(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, @@ -171,7 +171,7 @@ func TestScheduledJob_ApplyEnv(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, diff --git a/internal/pkg/manifest/lb_web_svc.go b/internal/pkg/manifest/lb_web_svc.go index 634931e3022..47c21604f4e 100644 --- a/internal/pkg/manifest/lb_web_svc.go +++ b/internal/pkg/manifest/lb_web_svc.go @@ -109,7 +109,7 @@ func newDefaultLoadBalancedWebService() *LoadBalancedWebService { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, @@ -182,11 +182,14 @@ type RoutingRule struct { Alias Alias `yaml:"alias"` DeregistrationDelay *time.Duration `yaml:"deregistration_delay"` // TargetContainer is the container load balancer routes traffic to. - TargetContainer *string `yaml:"target_container"` - TargetContainerCamelCase *string `yaml:"targetContainer"` // "targetContainerCamelCase" for backwards compatibility - AllowedSourceIps []string `yaml:"allowed_source_ips"` + TargetContainer *string `yaml:"target_container"` + TargetContainerCamelCase *string `yaml:"targetContainer"` // "targetContainerCamelCase" for backwards compatibility + AllowedSourceIps []IPNet `yaml:"allowed_source_ips"` } +// IPNet represents an IP network string. For example: 10.1.0.0/16 +type IPNet string + // Alias is a custom type which supports unmarshaling "http.alias" yaml which // can either be of type string or type slice of string. type Alias stringSliceOrString diff --git a/internal/pkg/manifest/lb_web_svc_test.go b/internal/pkg/manifest/lb_web_svc_test.go index a90077eb574..a4b789435e6 100644 --- a/internal/pkg/manifest/lb_web_svc_test.go +++ b/internal/pkg/manifest/lb_web_svc_test.go @@ -76,7 +76,7 @@ func TestNewLoadBalancedWebService(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -191,7 +191,12 @@ func TestLoadBalancedWebService_MarshalBinary(t *testing.T) { } func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { - mockRange := IntRangeBand("1-10") + var ( + mockIPNet1 = IPNet("10.1.0.0/24") + mockIPNet2 = IPNet("10.1.1.0/24") + mockRange = IntRangeBand("1-10") + mockPerc = Percentage(80) + ) testCases := map[string]struct { in *LoadBalancedWebService envToApply string @@ -369,7 +374,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, SecurityGroups: []string{"sg-123"}, }, }, @@ -526,7 +531,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, SecurityGroups: []string{"sg-456", "sg-789"}, }, }, @@ -540,7 +545,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { Count: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, }, @@ -566,7 +571,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { Value: nil, AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, }, @@ -591,13 +596,13 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { Count: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, SecurityGroups: []string{"sg-456", "sg-789"}, }, }, @@ -615,13 +620,13 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { Value: nil, AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, SecurityGroups: []string{"sg-456", "sg-789"}, }, }, @@ -1051,13 +1056,13 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("path"), }, - AllowedSourceIps: []string{"ip1", "ip2"}, + AllowedSourceIps: []IPNet{mockIPNet1}, }, }, Environments: map[string]*LoadBalancedWebServiceConfig{ "prod-iad": { RoutingRule: RoutingRule{ - AllowedSourceIps: []string{"ip1", "ip3"}, + AllowedSourceIps: []IPNet{mockIPNet2}, }, }, }, @@ -1074,7 +1079,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("path"), }, - AllowedSourceIps: []string{"ip1", "ip3"}, + AllowedSourceIps: []IPNet{mockIPNet2}, }, }, }, @@ -1090,7 +1095,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("path"), }, - AllowedSourceIps: []string{"ip1", "ip2"}, + AllowedSourceIps: []IPNet{mockIPNet1, mockIPNet2}, }, }, Environments: map[string]*LoadBalancedWebServiceConfig{ @@ -1115,7 +1120,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("another-path"), }, - AllowedSourceIps: []string{"ip1", "ip2"}, + AllowedSourceIps: []IPNet{mockIPNet1, mockIPNet2}, }, }, }, @@ -1131,7 +1136,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("path"), }, - AllowedSourceIps: []string{"ip1", "ip2"}, + AllowedSourceIps: []IPNet{mockIPNet1, mockIPNet2}, }, }, Environments: map[string]*LoadBalancedWebServiceConfig{ @@ -1140,7 +1145,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("another-path"), }, - AllowedSourceIps: []string{}, + AllowedSourceIps: []IPNet{}, }, }, }, @@ -1157,7 +1162,7 @@ func TestLoadBalancedWebService_ApplyEnv(t *testing.T) { HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("another-path"), }, - AllowedSourceIps: []string{}, + AllowedSourceIps: []IPNet{}, }, }, }, diff --git a/internal/pkg/manifest/rd_web_svc.go b/internal/pkg/manifest/rd_web_svc.go index 1cfba6e6a40..b19ce1a33df 100644 --- a/internal/pkg/manifest/rd_web_svc.go +++ b/internal/pkg/manifest/rd_web_svc.go @@ -94,7 +94,8 @@ func (s *RequestDrivenWebService) TaskPlatform() (*string, error) { if s.InstanceConfig.Platform.PlatformString == nil { return nil, nil } - return s.InstanceConfig.Platform.PlatformString, nil + str := string(*s.InstanceConfig.Platform.PlatformString) + return &str, nil } // BuildArgs returns a docker.BuildArguments object given a ws root directory. diff --git a/internal/pkg/manifest/svc.go b/internal/pkg/manifest/svc.go index 7cf470edf78..66000f6a46b 100644 --- a/internal/pkg/manifest/svc.go +++ b/internal/pkg/manifest/svc.go @@ -175,13 +175,16 @@ func (c *Count) Desired() (*int, error) { return aws.Int(min), nil } +// Percentage represents a valid percentage integer ranging from 0 to 100. +type Percentage int + // AdvancedCount represents the configurable options for Auto Scaling as well as // Capacity configuration (spot). type AdvancedCount struct { Spot *int `yaml:"spot"` // mutually exclusive with other fields Range Range `yaml:"range"` - CPU *int `yaml:"cpu_percentage"` - Memory *int `yaml:"memory_percentage"` + CPU *Percentage `yaml:"cpu_percentage"` + Memory *Percentage `yaml:"memory_percentage"` Requests *int `yaml:"requests"` ResponseTime *time.Duration `yaml:"response_time"` QueueScaling QueueScaling `yaml:"queue_delay"` diff --git a/internal/pkg/manifest/svc_test.go b/internal/pkg/manifest/svc_test.go index 72909d5ec7c..34a9ed2c5dd 100644 --- a/internal/pkg/manifest/svc_test.go +++ b/internal/pkg/manifest/svc_test.go @@ -15,6 +15,7 @@ import ( ) func TestUnmarshalSvc(t *testing.T) { + mockPerc := Percentage(70) testCases := map[string]struct { inContent string @@ -43,6 +44,9 @@ http: alias: - foobar.com - v1.foobar.com + allowed_source_ips: + - 10.1.0.0/24 + - 10.1.1.0/24 variables: LOG_LEVEL: "WARN" secrets: @@ -103,6 +107,7 @@ environments: HealthCheck: HealthCheckArgsOrString{ HealthCheckPath: aws.String("/"), }, + AllowedSourceIps: []IPNet{IPNet("10.1.0.0/24"), IPNet("10.1.1.0/24")}, }, TaskConfig: TaskConfig{ CPU: aws.Int(512), @@ -144,7 +149,7 @@ environments: }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, TaskDefOverrides: []OverrideRule{ @@ -200,7 +205,7 @@ environments: Range: Range{ Value: &mockRange, }, - CPU: aws.Int(70), + CPU: &mockPerc, }, }, }, @@ -263,7 +268,7 @@ secrets: }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -326,7 +331,7 @@ subscribe: }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, Subscribe: SubscribeConfig{ @@ -379,8 +384,12 @@ type: 'OH NO' } func TestCount_UnmarshalYAML(t *testing.T) { - mockResponseTime := 500 * time.Millisecond - mockRange := IntRangeBand("1-10") + var ( + mockResponseTime = 500 * time.Millisecond + mockRange = IntRangeBand("1-10") + mockCPU = Percentage(70) + mockMem = Percentage(80) + ) testCases := map[string]struct { inContent []byte @@ -405,8 +414,8 @@ func TestCount_UnmarshalYAML(t *testing.T) { wantedStruct: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(70), - Memory: aws.Int(80), + CPU: &mockCPU, + Memory: &mockMem, Requests: aws.Int(1000), ResponseTime: &mockResponseTime, }, @@ -445,7 +454,7 @@ func TestCount_UnmarshalYAML(t *testing.T) { min: 2 max: 8 spot_from: 3 - cpu_percentage: 50 + cpu_percentage: 70 `), wantedStruct: Count{ AdvancedCount: AdvancedCount{ @@ -456,7 +465,7 @@ func TestCount_UnmarshalYAML(t *testing.T) { SpotFrom: aws.Int(3), }, }, - CPU: aws.Int(50), + CPU: &mockCPU, }, }, }, diff --git a/internal/pkg/manifest/transform_test.go b/internal/pkg/manifest/transform_test.go index 559063943ed..5f4f72f3843 100644 --- a/internal/pkg/manifest/transform_test.go +++ b/internal/pkg/manifest/transform_test.go @@ -333,6 +333,7 @@ func TestStringSliceOrStringTransformer_Transformer(t *testing.T) { } func TestPlatformArgsOrStringTransformer_Transformer(t *testing.T) { + mockPlatformStr := PlatformString("mockString") testCases := map[string]struct { original func(p *PlatformArgsOrString) override func(p *PlatformArgsOrString) @@ -340,7 +341,7 @@ func TestPlatformArgsOrStringTransformer_Transformer(t *testing.T) { }{ "string set to empty if args is not nil": { original: func(p *PlatformArgsOrString) { - p.PlatformString = aws.String("mockString") + p.PlatformString = &mockPlatformStr }, override: func(p *PlatformArgsOrString) { p.PlatformArgs = PlatformArgs{ @@ -363,10 +364,10 @@ func TestPlatformArgsOrStringTransformer_Transformer(t *testing.T) { } }, override: func(p *PlatformArgsOrString) { - p.PlatformString = aws.String("mockString") + p.PlatformString = &mockPlatformStr }, wanted: func(p *PlatformArgsOrString) { - p.PlatformString = aws.String("mockString") + p.PlatformString = &mockPlatformStr }, }, } @@ -513,6 +514,7 @@ func TestCountTransformer_Transformer(t *testing.T) { } func TestAdvancedCountTransformer_Transformer(t *testing.T) { + mockPerc := Percentage(80) testCases := map[string]struct { original func(a *AdvancedCount) override func(a *AdvancedCount) @@ -526,14 +528,14 @@ func TestAdvancedCountTransformer_Transformer(t *testing.T) { a.Range = Range{ Value: (*IntRangeBand)(aws.String("1-10")), } - a.CPU = aws.Int(1024) + a.CPU = &mockPerc a.Requests = aws.Int(42) }, wanted: func(a *AdvancedCount) { a.Range = Range{ Value: (*IntRangeBand)(aws.String("1-10")), } - a.CPU = aws.Int(1024) + a.CPU = &mockPerc a.Requests = aws.Int(42) }, }, @@ -542,7 +544,7 @@ func TestAdvancedCountTransformer_Transformer(t *testing.T) { a.Range = Range{ Value: (*IntRangeBand)(aws.String("1-10")), } - a.CPU = aws.Int(1024) + a.CPU = &mockPerc a.Requests = aws.Int(42) }, override: func(a *AdvancedCount) { diff --git a/internal/pkg/manifest/validate.go b/internal/pkg/manifest/validate.go index 3059a864b51..c09b2d9bb58 100644 --- a/internal/pkg/manifest/validate.go +++ b/internal/pkg/manifest/validate.go @@ -6,6 +6,7 @@ package manifest import ( "errors" "fmt" + "net" "regexp" "strconv" "strings" @@ -284,6 +285,12 @@ func (r *RoutingRule) Validate() error { secondField: "targetContainer", } } + for ind, ip := range r.AllowedSourceIps { + if err = ip.Validate(); err != nil { + return fmt.Errorf(`validate "allowed_source_ips[%v]": %w`, ind, err) + } + } + return nil } @@ -308,6 +315,17 @@ func (*Alias) Validate() error { return nil } +// Validate returns nil if IPNet is configured correctly. +func (ip *IPNet) Validate() error { + if ip == nil { + return nil + } + if _, _, err := net.ParseCIDR(string(*ip)); err != nil { + return err + } + return nil +} + // Validate returns nil if TaskConfig is configured correctly. func (t *TaskConfig) Validate() error { var err error @@ -328,9 +346,25 @@ func (t *TaskConfig) Validate() error { // Validate returns nil if PlatformArgsOrString is configured correctly. func (p *PlatformArgsOrString) Validate() error { + if err := p.PlatformString.Validate(); err != nil { + return err + } return p.PlatformArgs.Validate() } +// Validate returns nil if PlatformString is configured correctly. +func (p *PlatformString) Validate() error { + if p == nil { + return nil + } + val := string(*p) + reg := regexp.MustCompile(`^.+\/.+$`) + if !reg.MatchString(val) { + return fmt.Errorf(`cannot use %s for platform. Must match the regex "^.+\/.+$"`, val) + } + return nil +} + // Validate returns nil if PlatformArgsOrString is configured correctly. // TODO: add validation once "feat/pencere" is merged. func (p *PlatformArgs) Validate() error { @@ -384,6 +418,24 @@ func (a *AdvancedCount) Validate() error { return fmt.Errorf(`validate "queue_delay": %w`, err) } } + + if err := a.CPU.Validate(); err != nil { + return fmt.Errorf(`validate "cpu_percentage": %w`, err) + } + if err := a.Memory.Validate(); err != nil { + return fmt.Errorf(`validate "memory_percentage": %w`, err) + } + return nil +} + +// Validate returns nil if Percentage is configured correctly. +func (p *Percentage) Validate() error { + if p == nil { + return nil + } + if val := int(*p); val < 0 || val > 100 { + return fmt.Errorf("cannot specify %v as Percentage. Must be an integer from 0 to 100", val) + } return nil } @@ -593,9 +645,25 @@ func (v *vpcConfig) Validate() error { if v.isEmpty() { return nil } + if err := v.Placement.Validate(); err != nil { + return fmt.Errorf(`validate "placement": %w`, err) + } return nil } +// Validate returns nil if Placement is configured correctly. +func (p *Placement) Validate() error { + if p == nil { + return fmt.Errorf(`"placement" cannot be empty`) + } + for _, allowed := range subnetPlacements { + if string(*p) == allowed { + return nil + } + } + return fmt.Errorf(`"placement" %s is invalid. Must be one of %#v"`, string(*p), subnetPlacements) +} + // Validate returns nil if RequestDrivenWebServiceHttpConfig is configured correctly. func (r *RequestDrivenWebServiceHttpConfig) Validate() error { return r.HealthCheckConfiguration.Validate() diff --git a/internal/pkg/manifest/validate_test.go b/internal/pkg/manifest/validate_test.go index 8d751129131..79f8ba47d1d 100644 --- a/internal/pkg/manifest/validate_test.go +++ b/internal/pkg/manifest/validate_test.go @@ -50,6 +50,20 @@ func TestLoadBalancedWebServiceConfig_Validate(t *testing.T) { }, wantedErrorPrefix: `validate "http": `, }, + "error if fail to validate network": { + lbConfig: LoadBalancedWebServiceConfig{ + ImageConfig: testImageConfig, + RoutingRule: RoutingRule{ + TargetContainer: aws.String("mockTargetContainer"), + }, + Network: NetworkConfig{ + vpcConfig{ + SecurityGroups: []string{}, + }, + }, + }, + wantedErrorPrefix: `validate "network": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -65,6 +79,14 @@ func TestLoadBalancedWebServiceConfig_Validate(t *testing.T) { } func TestBackendServiceConfig_Validate(t *testing.T) { + testImageConfig := ImageWithPortAndHealthcheck{ + ImageWithPort: ImageWithPort{ + Image: Image{ + Build: BuildArgsOrString{BuildString: aws.String("mockBuild")}, + }, + Port: uint16P(80), + }, + } testCases := map[string]struct { config BackendServiceConfig @@ -83,6 +105,17 @@ func TestBackendServiceConfig_Validate(t *testing.T) { }, wantedErrorPrefix: `validate "image": `, }, + "error if fail to validate network": { + config: BackendServiceConfig{ + ImageConfig: testImageConfig, + Network: NetworkConfig{ + vpcConfig{ + SecurityGroups: []string{}, + }, + }, + }, + wantedErrorPrefix: `validate "network": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -129,6 +162,11 @@ func TestRequestDrivenWebServiceConfig_Validate(t *testing.T) { } func TestWorkerServiceConfig_Validate(t *testing.T) { + testImageConfig := ImageWithHealthcheck{ + Image: Image{ + Build: BuildArgsOrString{BuildString: aws.String("mockBuild")}, + }, + } testCases := map[string]struct { config WorkerServiceConfig @@ -145,6 +183,17 @@ func TestWorkerServiceConfig_Validate(t *testing.T) { }, wantedErrorPrefix: `validate "image": `, }, + "error if fail to validate network": { + config: WorkerServiceConfig{ + ImageConfig: testImageConfig, + Network: NetworkConfig{ + vpcConfig{ + SecurityGroups: []string{}, + }, + }, + }, + wantedErrorPrefix: `validate "network": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -160,6 +209,11 @@ func TestWorkerServiceConfig_Validate(t *testing.T) { } func TestScheduledJobConfig_Validate(t *testing.T) { + testImageConfig := ImageWithHealthcheck{ + Image: Image{ + Build: BuildArgsOrString{BuildString: aws.String("mockBuild")}, + }, + } testCases := map[string]struct { config ScheduledJobConfig @@ -176,6 +230,17 @@ func TestScheduledJobConfig_Validate(t *testing.T) { }, wantedErrorPrefix: `validate "image": `, }, + "error if fail to validate network": { + config: ScheduledJobConfig{ + ImageConfig: testImageConfig, + Network: NetworkConfig{ + vpcConfig{ + SecurityGroups: []string{}, + }, + }, + }, + wantedErrorPrefix: `validate "network": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -255,7 +320,8 @@ func TestRoutingRule_Validate(t *testing.T) { testCases := map[string]struct { RoutingRule RoutingRule - wantedError error + wantedErrorMsgPrefix string + wantedError error }{ "error if both target_container and targetContainer are specified": { RoutingRule: RoutingRule{ @@ -264,6 +330,16 @@ func TestRoutingRule_Validate(t *testing.T) { }, wantedError: fmt.Errorf(`must specify one, not both, of "target_container" and "targetContainer"`), }, + "error if one of allowed_source_ips is not valid": { + RoutingRule: RoutingRule{ + AllowedSourceIps: []IPNet{ + IPNet("10.1.0.0/24"), + IPNet("badIP"), + IPNet("10.1.1.0/24"), + }, + }, + wantedErrorMsgPrefix: `validate "allowed_source_ips[1]": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -271,25 +347,63 @@ func TestRoutingRule_Validate(t *testing.T) { if tc.wantedError != nil { require.EqualError(t, gotErr, tc.wantedError.Error()) + return + } + if tc.wantedErrorMsgPrefix != "" { + require.Error(t, gotErr) + require.Contains(t, gotErr.Error(), tc.wantedErrorMsgPrefix) + return + } + require.NoError(t, gotErr) + }) + } +} + +func TestIPNet_Validate(t *testing.T) { + testCases := map[string]struct { + in IPNet + wanted error + }{ + "should return an error if IPNet is not valid": { + in: IPNet("badIPNet"), + wanted: errors.New("invalid CIDR address: badIPNet"), + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := tc.in.Validate() + + if tc.wanted != nil { + require.EqualError(t, err, tc.wanted.Error()) } else { - require.NoError(t, gotErr) + require.NoError(t, err) } }) } } func TestTaskConfig_Validate(t *testing.T) { + mockPerc := Percentage(70) + mockBadPlatformStr := PlatformString("mockBadPlatform") testCases := map[string]struct { TaskConfig TaskConfig wantedErrorPrefix string }{ + "error if fail to validate platform": { + TaskConfig: TaskConfig{ + Platform: PlatformArgsOrString{ + PlatformString: &mockBadPlatformStr, + }, + }, + wantedErrorPrefix: `validate "platform": `, + }, "error if fail to validate count": { TaskConfig: TaskConfig{ Count: Count{ AdvancedCount: AdvancedCount{ Spot: aws.Int(123), - CPU: aws.Int(123), + CPU: &mockPerc, }, }, }, @@ -326,7 +440,34 @@ func TestTaskConfig_Validate(t *testing.T) { } } +func TestPlatformString_Validate(t *testing.T) { + testCases := map[string]struct { + in PlatformString + wanted error + }{ + "should return an error if platform string is not valid": { + in: PlatformString("NS"), + wanted: errors.New(`cannot use NS for platform. Must match the regex "^.+\/.+$"`), + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := tc.in.Validate() + + if tc.wanted != nil { + require.EqualError(t, err, tc.wanted.Error()) + } else { + require.NoError(t, err) + } + }) + } +} + func TestAdvancedCount_Validate(t *testing.T) { + var ( + mockPerc = Percentage(70) + invalidPerc = Percentage(-1) + ) testCases := map[string]struct { AdvancedCount AdvancedCount @@ -351,7 +492,7 @@ func TestAdvancedCount_Validate(t *testing.T) { Range: Range{ Value: (*IntRangeBand)(aws.String("1-10")), }, - CPU: aws.Int(70), + CPU: &mockPerc, QueueScaling: QueueScaling{ AcceptableLatency: durationp(10 * time.Second), AvgProcessingTime: durationp(1 * time.Second), @@ -362,7 +503,7 @@ func TestAdvancedCount_Validate(t *testing.T) { "error if both spot and autoscaling fields are specified": { AdvancedCount: AdvancedCount{ Spot: aws.Int(123), - CPU: aws.Int(70), + CPU: &mockPerc, workloadType: LoadBalancedWebServiceType, }, wantedError: fmt.Errorf(`must specify one, not both, of "spot" and "range/cpu_percentage/memory_percentage/requests/response_time"`), @@ -412,14 +553,14 @@ func TestAdvancedCount_Validate(t *testing.T) { }, "error if range is missing when autoscaling fields are set for Backend Service": { AdvancedCount: AdvancedCount{ - CPU: aws.Int(123), + CPU: &mockPerc, workloadType: BackendServiceType, }, wantedError: fmt.Errorf(`"range" must be specified if "cpu_percentage or memory_percentage" are specified`), }, "error if range is missing when autoscaling fields are set for Worker Service": { AdvancedCount: AdvancedCount{ - CPU: aws.Int(123), + CPU: &mockPerc, workloadType: WorkerServiceType, }, wantedError: fmt.Errorf(`"range" must be specified if "cpu_percentage, memory_percentage or queue_delay" are specified`), @@ -441,6 +582,26 @@ func TestAdvancedCount_Validate(t *testing.T) { }, wantedErrorMsgPrefix: `validate "queue_delay": `, }, + "error if CPU perc is not valid": { + AdvancedCount: AdvancedCount{ + Range: Range{ + Value: (*IntRangeBand)(stringP("1-2")), + }, + CPU: &invalidPerc, + workloadType: LoadBalancedWebServiceType, + }, + wantedErrorMsgPrefix: `validate "cpu_percentage": `, + }, + "error if memory perc is not valid": { + AdvancedCount: AdvancedCount{ + Range: Range{ + Value: (*IntRangeBand)(stringP("1-2")), + }, + Memory: &invalidPerc, + workloadType: LoadBalancedWebServiceType, + }, + wantedErrorMsgPrefix: `validate "memory_percentage": `, + }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { @@ -451,6 +612,7 @@ func TestAdvancedCount_Validate(t *testing.T) { return } if tc.wantedErrorMsgPrefix != "" { + require.Error(t, gotErr) require.Contains(t, gotErr.Error(), tc.wantedErrorMsgPrefix) return } @@ -459,6 +621,29 @@ func TestAdvancedCount_Validate(t *testing.T) { } } +func TestPercentage_Validate(t *testing.T) { + testCases := map[string]struct { + in Percentage + wanted error + }{ + "should return an error if percentage is not valid": { + in: Percentage(120), + wanted: errors.New("cannot specify 120 as Percentage. Must be an integer from 0 to 100"), + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := tc.in.Validate() + + if tc.wanted != nil { + require.EqualError(t, err, tc.wanted.Error()) + } else { + require.NoError(t, err) + } + }) + } +} + func TestQueueScaling_Validate(t *testing.T) { testCases := map[string]struct { in QueueScaling @@ -683,3 +868,84 @@ func TestEFSVolumeConfiguration_Validate(t *testing.T) { }) } } + +func TestNetworkConfig_Validate(t *testing.T) { + testCases := map[string]struct { + config NetworkConfig + + wantedErrorPrefix string + }{ + "error if fail to validate vpc": { + config: NetworkConfig{ + VPC: vpcConfig{ + SecurityGroups: []string{}, + }, + }, + wantedErrorPrefix: `validate "vpc": `, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + gotErr := tc.config.Validate() + + if tc.wantedErrorPrefix != "" { + require.Contains(t, gotErr.Error(), tc.wantedErrorPrefix) + } else { + require.NoError(t, gotErr) + } + }) + } +} + +func TestVpcConfig_Validate(t *testing.T) { + testCases := map[string]struct { + config vpcConfig + + wantedErrorPrefix string + }{ + "error if fail to validate placement": { + config: vpcConfig{ + SecurityGroups: []string{}, + }, + wantedErrorPrefix: `validate "placement": `, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + gotErr := tc.config.Validate() + + if tc.wantedErrorPrefix != "" { + require.Contains(t, gotErr.Error(), tc.wantedErrorPrefix) + } else { + require.NoError(t, gotErr) + } + }) + } +} + +func TestPlacement_Validate(t *testing.T) { + mockInvalidPlacement := Placement("external") + testCases := map[string]struct { + in *Placement + wanted error + }{ + "should return an error if placement is empty": { + wanted: errors.New(`"placement" cannot be empty`), + }, + "should return an error if placement is invalid": { + in: &mockInvalidPlacement, + wanted: errors.New(`"placement" external is invalid. Must be one of []string{"public", "private"}"`), + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + err := tc.in.Validate() + + if tc.wanted != nil { + require.EqualError(t, err, tc.wanted.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/pkg/manifest/worker_svc.go b/internal/pkg/manifest/worker_svc.go index 6badcd0de4a..0efdb8c4c59 100644 --- a/internal/pkg/manifest/worker_svc.go +++ b/internal/pkg/manifest/worker_svc.go @@ -171,7 +171,7 @@ func newDefaultWorkerService() *WorkerService { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: aws.String(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, diff --git a/internal/pkg/manifest/worker_svc_test.go b/internal/pkg/manifest/worker_svc_test.go index ad20e660a42..3549e57f597 100644 --- a/internal/pkg/manifest/worker_svc_test.go +++ b/internal/pkg/manifest/worker_svc_test.go @@ -55,7 +55,7 @@ func TestNewWorkerSvc(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -96,7 +96,7 @@ func TestNewWorkerSvc(t *testing.T) { }, Network: NetworkConfig{ VPC: vpcConfig{ - Placement: stringP("public"), + Placement: &PublicSubnetPlacement, }, }, }, @@ -188,6 +188,7 @@ func TestWorkerSvc_MarshalBinary(t *testing.T) { } func TestWorkerSvc_ApplyEnv(t *testing.T) { + mockPerc := Percentage(70) mockWorkerServiceWithNoEnvironments := WorkerService{ Workload: Workload{ Name: aws.String("phonetool"), @@ -293,7 +294,7 @@ func TestWorkerSvc_ApplyEnv(t *testing.T) { TaskConfig: TaskConfig{ Count: Count{ AdvancedCount: AdvancedCount{ - CPU: aws.Int(70), + CPU: &mockPerc, }, }, CPU: aws.Int(512), @@ -714,7 +715,7 @@ func TestWorkerSvc_ApplyEnv(t *testing.T) { Memory: aws.Int(256), Count: Count{ AdvancedCount: AdvancedCount{ - CPU: aws.Int(70), + CPU: &mockPerc, }, }, Variables: map[string]string{ @@ -1084,6 +1085,7 @@ func TestWorkerSvc_ApplyEnv(t *testing.T) { func TestWorkerSvc_ApplyEnv_CountOverrides(t *testing.T) { mockRange := IntRangeBand("1-10") + mockPerc := Percentage(80) testCases := map[string]struct { svcCount Count envCount Count @@ -1094,7 +1096,7 @@ func TestWorkerSvc_ApplyEnv_CountOverrides(t *testing.T) { svcCount: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, envCount: Count{}, @@ -1104,7 +1106,7 @@ func TestWorkerSvc_ApplyEnv_CountOverrides(t *testing.T) { Count: Count{ AdvancedCount: AdvancedCount{ Range: Range{Value: &mockRange}, - CPU: aws.Int(80), + CPU: &mockPerc, }, }, }, diff --git a/internal/pkg/manifest/workload.go b/internal/pkg/manifest/workload.go index 98763862730..7cbc1a813b1 100644 --- a/internal/pkg/manifest/workload.go +++ b/internal/pkg/manifest/workload.go @@ -24,18 +24,18 @@ import ( const ( defaultFluentbitImage = "amazon/aws-for-fluent-bit:latest" defaultDockerfileName = "Dockerfile" - - // AWS VPC subnet placement options. - PublicSubnetPlacement = "public" - PrivateSubnetPlacement = "private" ) var ( + // AWS VPC subnet placement options. + PublicSubnetPlacement = Placement("public") + PrivateSubnetPlacement = Placement("private") + // WorkloadTypes holds all workload manifest types. WorkloadTypes = append(ServiceTypes, JobTypes...) // All placement options. - subnetPlacements = []string{PublicSubnetPlacement, PrivateSubnetPlacement} + subnetPlacements = []string{string(PublicSubnetPlacement), string(PrivateSubnetPlacement)} validPlatforms = []string{dockerengine.DockerBuildPlatform(dockerengine.LinuxOS, dockerengine.Amd64Arch)} validOperatingSystems = []string{dockerengine.LinuxOS} @@ -434,7 +434,8 @@ func (t *TaskConfig) TaskPlatform() (*string, error) { if t.Platform.PlatformString == nil { return nil, nil } - return t.Platform.PlatformString, nil + val := string(*t.Platform.PlatformString) + return &val, nil } // PublishConfig represents the configurable options for setting up publishers. @@ -461,8 +462,9 @@ func (c *NetworkConfig) IsEmpty() bool { // If the user specified a placement that's not valid then throw an error. func (c *NetworkConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { type networkWithDefaults NetworkConfig + publicPlacement := Placement(PublicSubnetPlacement) defaultVPCConf := vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &publicPlacement, } conf := networkWithDefaults{ VPC: defaultVPCConf, @@ -474,15 +476,18 @@ func (c *NetworkConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { conf.VPC = defaultVPCConf } if !conf.VPC.isValidPlacement() { - return fmt.Errorf("field '%s' is '%v' must be one of %#v", "network.vpc.placement", aws.StringValue(conf.VPC.Placement), subnetPlacements) + return fmt.Errorf("field '%s' is '%v' must be one of %#v", "network.vpc.placement", string(*conf.VPC.Placement), subnetPlacements) } *c = NetworkConfig(conf) return nil } +// Placement represents where to place tasks (public or private subnets). +type Placement string + // vpcConfig represents the security groups and subnets attached to a task. type vpcConfig struct { - Placement *string `yaml:"placement"` + *Placement `yaml:"placement"` SecurityGroups []string `yaml:"security_groups"` } @@ -495,7 +500,7 @@ func (c *vpcConfig) isValidPlacement() bool { return false } for _, allowed := range subnetPlacements { - if *c.Placement == allowed { + if string(*c.Placement) == allowed { return true } } @@ -594,11 +599,24 @@ func (hc *ContainerHealthCheck) ApplyIfNotSet(other *ContainerHealthCheck) { } } +// PlatformString represents the platform string consisting of OS family and architecture type. +// For example: "windows/x86" +type PlatformString string + +// PlatformStringP converts a string pointer to a PlatformString pointer. +func PlatformStringP(s *string) *PlatformString { + if s == nil { + return nil + } + val := PlatformString(*s) + return &val +} + // PlatformArgsOrString is a custom type which supports unmarshaling yaml which // can either be of type string or type PlatformArgs. type PlatformArgsOrString struct { - PlatformString *string - PlatformArgs PlatformArgs + *PlatformString + PlatformArgs PlatformArgs } // UnmarshalYAML overrides the default YAML unmarshaling logic for the PlatformArgsOrString @@ -651,16 +669,16 @@ func (p *PlatformArgs) bothSpecified() bool { return (p.OSFamily != nil) && (p.Arch != nil) } -func validatePlatform(platform *string) error { +func validatePlatform(platform *PlatformString) error { if platform == nil { return nil } for _, validPlatform := range validPlatforms { - if aws.StringValue(platform) == validPlatform { + if string(*platform) == validPlatform { return nil } } - return fmt.Errorf("platform %s is invalid; %s: %s", aws.StringValue(platform), english.PluralWord(len(validPlatforms), "the valid platform is", "valid platforms are"), english.WordSeries(validPlatforms, "and")) + return fmt.Errorf("platform %s is invalid; %s: %s", string(*platform), english.PluralWord(len(validPlatforms), "the valid platform is", "valid platforms are"), english.WordSeries(validPlatforms, "and")) } func validateOS(os *string) error { diff --git a/internal/pkg/manifest/workload_test.go b/internal/pkg/manifest/workload_test.go index 0503726cb9c..5947955b24f 100644 --- a/internal/pkg/manifest/workload_test.go +++ b/internal/pkg/manifest/workload_test.go @@ -287,6 +287,7 @@ func TestBuildArgs_UnmarshalYAML(t *testing.T) { } func TestPlatformArgsOrString_UnmarshalYAML(t *testing.T) { + mockPlatformStr := PlatformString("linux/amd64") testCases := map[string]struct { inContent []byte @@ -331,7 +332,7 @@ func TestPlatformArgsOrString_UnmarshalYAML(t *testing.T) { inContent: []byte(`platform: linux/amd64`), wantedStruct: PlatformArgsOrString{ - PlatformString: aws.String("linux/amd64"), + PlatformString: &mockPlatformStr, }, }, "both os/arch specified with valid values": { @@ -668,7 +669,7 @@ network: `, wantedConfig: &NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, }, }, }, @@ -691,7 +692,7 @@ network: `, wantedConfig: &NetworkConfig{ VPC: vpcConfig{ - Placement: stringP(PublicSubnetPlacement), + Placement: &PublicSubnetPlacement, SecurityGroups: []string{"sg-1234", "sg-4567"}, }, },