diff --git a/loader/loader.go b/loader/loader.go index 02f5bbdc..fbbe6291 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -624,7 +624,6 @@ func createTransformHook(additionalTransformers ...Transformer) mapstructure.Dec reflect.TypeOf(types.BuildConfig{}): transformBuildConfig, reflect.TypeOf(types.DependsOnConfig{}): transformDependsOnConfig, reflect.TypeOf(types.ExtendsConfig{}): transformExtendsConfig, - reflect.TypeOf(types.DeviceRequest{}): transformServiceDeviceRequest, reflect.TypeOf(types.SSHConfig{}): transformSSHConfig, reflect.TypeOf(types.IncludeConfig{}): transformIncludeConfig, } @@ -1087,35 +1086,6 @@ var transformServicePort TransformerFunc = func(data interface{}) (interface{}, } } -var transformServiceDeviceRequest TransformerFunc = func(data interface{}) (interface{}, error) { - switch value := data.(type) { - case map[string]interface{}: - count, ok := value["count"] - if ok { - switch val := count.(type) { - case int: - return value, nil - case string: - if strings.ToLower(val) == "all" { - value["count"] = -1 - return value, nil - } - i, err := strconv.ParseInt(val, 10, 64) - if err == nil { - value["count"] = i - return value, nil - } - return data, errors.Errorf("invalid string value for 'count' (the only value allowed is 'all' or a number)") - default: - return data, errors.Errorf("invalid type %T for device count", val) - } - } - return data, nil - default: - return data, errors.Errorf("invalid type %T for resource reservation", value) - } -} - var transformFileReferenceConfig TransformerFunc = func(data interface{}) (interface{}, error) { switch value := data.(type) { case string: diff --git a/loader/loader_test.go b/loader/loader_test.go index 44ffc156..a3f4a175 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -2129,9 +2129,9 @@ services: devices: - driver: nvidia capabilities: [gpu] - count: somestring + count: some_string `) - assert.ErrorContains(t, err, "invalid string value for 'count' (the only value allowed is 'all' or a number)") + assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`) } func TestServicePullPolicy(t *testing.T) { diff --git a/types/device.go b/types/device.go new file mode 100644 index 00000000..81b4bea4 --- /dev/null +++ b/types/device.go @@ -0,0 +1,53 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package types + +import ( + "strconv" + "strings" + + "github.com/pkg/errors" +) + +type DeviceRequest struct { + Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"` + Driver string `yaml:"driver,omitempty" json:"driver,omitempty"` + Count DeviceCount `yaml:"count,omitempty" json:"count,omitempty"` + IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"` +} + +type DeviceCount int64 + +func (c *DeviceCount) DecodeMapstructure(value interface{}) error { + switch v := value.(type) { + case int: + *c = DeviceCount(v) + case string: + if strings.ToLower(v) == "all" { + *c = -1 + return nil + } + i, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return errors.Errorf("invalid value %q, the only value allowed is 'all' or a number", v) + } + *c = DeviceCount(i) + default: + return errors.Errorf("invalid type %T for device count", v) + } + return nil +} diff --git a/types/types.go b/types/types.go index 9c1ad497..5dc93bb1 100644 --- a/types/types.go +++ b/types/types.go @@ -584,13 +584,6 @@ type Resource struct { Extensions Extensions `yaml:"#extensions,inline" json:"-"` } -type DeviceRequest struct { - Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"` - Driver string `yaml:"driver,omitempty" json:"driver,omitempty"` - Count int64 `yaml:"count,omitempty" json:"count,omitempty"` - IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"` -} - // GenericResource represents a "user defined" resource which can // only be an integer (e.g: SSD=3) for a service type GenericResource struct {