diff --git a/pkg/cdi/cache.go b/pkg/cdi/cache.go index 04f15e02..0abe8332 100644 --- a/pkg/cdi/cache.go +++ b/pkg/cdi/cache.go @@ -28,6 +28,7 @@ import ( "github.com/fsnotify/fsnotify" oci "github.com/opencontainers/runtime-spec/specs-go" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -281,30 +282,31 @@ func (c *Cache) highestPrioritySpecDir() (string, int) { // priority Spec directory. If name has a "json" or "yaml" extension it // choses the encoding. Otherwise the default YAML encoding is used. func (c *Cache) WriteSpec(raw *cdi.Spec, name string) error { - var ( - specDir string - path string - prio int - spec *Spec - err error - ) - - specDir, prio = c.highestPrioritySpecDir() + specDir, _ := c.highestPrioritySpecDir() if specDir == "" { return errors.New("no Spec directories to write to") } - path = filepath.Join(specDir, name) - if ext := filepath.Ext(path); ext != ".json" && ext != ".yaml" { - path += defaultSpecExt + // Ideally we would like to pass the configured spec validator to the + // producer, but we would need to handle the synchronisation. + // Instead we call `validateSpec` here which is a no-op if no validator is + // configured. + if err := validateSpec(raw); err != nil { + return err } - spec, err = newSpec(raw, path, prio) + path := filepath.Join(specDir, name) + + p, err := producer.New(raw, + producer.WithOverwrite(true), + ) if err != nil { return err } - - return spec.write(true) + if _, err := p.Save(path); err != nil { + return err + } + return nil } // RemoveSpec removes a Spec with the given name from the highest diff --git a/pkg/cdi/container-edits.go b/pkg/cdi/container-edits.go index a7ac70d0..3632176e 100644 --- a/pkg/cdi/container-edits.go +++ b/pkg/cdi/container-edits.go @@ -26,6 +26,8 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" ocigen "github.com/opencontainers/runtime-tools/generate" + + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -44,18 +46,6 @@ const ( PoststopHook = "poststop" ) -var ( - // Names of recognized hooks. - validHookNames = map[string]struct{}{ - PrestartHook: {}, - CreateRuntimeHook: {}, - CreateContainerHook: {}, - StartContainerHook: {}, - PoststartHook: {}, - PoststopHook: {}, - } -) - // ContainerEdits represent updates to be applied to an OCI Spec. // These updates can be specific to a CDI device, or they can be // specific to a CDI Spec. In the former case these edits should @@ -167,32 +157,7 @@ func (e *ContainerEdits) Validate() error { if e == nil || e.ContainerEdits == nil { return nil } - - if err := ValidateEnv(e.Env); err != nil { - return fmt.Errorf("invalid container edits: %w", err) - } - for _, d := range e.DeviceNodes { - if err := (&DeviceNode{d}).Validate(); err != nil { - return err - } - } - for _, h := range e.Hooks { - if err := (&Hook{h}).Validate(); err != nil { - return err - } - } - for _, m := range e.Mounts { - if err := (&Mount{m}).Validate(); err != nil { - return err - } - } - if e.IntelRdt != nil { - if err := (&IntelRdt{e.IntelRdt}).Validate(); err != nil { - return err - } - } - - return nil + return validator.Default.ValidateAny(e.ContainerEdits) } // Append other edits into this one. If called with a nil receiver, @@ -220,43 +185,6 @@ func (e *ContainerEdits) Append(o *ContainerEdits) *ContainerEdits { return e } -// isEmpty returns true if these edits are empty. This is valid in a -// global Spec context but invalid in a Device context. -func (e *ContainerEdits) isEmpty() bool { - if e == nil { - return false - } - if len(e.Env) > 0 { - return false - } - if len(e.DeviceNodes) > 0 { - return false - } - if len(e.Hooks) > 0 { - return false - } - if len(e.Mounts) > 0 { - return false - } - if len(e.AdditionalGIDs) > 0 { - return false - } - if e.IntelRdt != nil { - return false - } - return true -} - -// ValidateEnv validates the given environment variables. -func ValidateEnv(env []string) error { - for _, v := range env { - if strings.IndexByte(v, byte('=')) <= 0 { - return fmt.Errorf("invalid environment variable %q", v) - } - } - return nil -} - // DeviceNode is a CDI Spec DeviceNode wrapper, used for validating DeviceNodes. type DeviceNode struct { *cdi.DeviceNode @@ -264,27 +192,7 @@ type DeviceNode struct { // Validate a CDI Spec DeviceNode. func (d *DeviceNode) Validate() error { - validTypes := map[string]struct{}{ - "": {}, - "b": {}, - "c": {}, - "u": {}, - "p": {}, - } - - if d.Path == "" { - return errors.New("invalid (empty) device path") - } - if _, ok := validTypes[d.Type]; !ok { - return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) - } - for _, bit := range d.Permissions { - if bit != 'r' && bit != 'w' && bit != 'm' { - return fmt.Errorf("device %q: invalid permissions %q", - d.Path, d.Permissions) - } - } - return nil + return validator.Default.ValidateAny(d.DeviceNode) } // Hook is a CDI Spec Hook wrapper, used for validating hooks. @@ -294,16 +202,7 @@ type Hook struct { // Validate a hook. func (h *Hook) Validate() error { - if _, ok := validHookNames[h.HookName]; !ok { - return fmt.Errorf("invalid hook name %q", h.HookName) - } - if h.Path == "" { - return fmt.Errorf("invalid hook %q with empty path", h.HookName) - } - if err := ValidateEnv(h.Env); err != nil { - return fmt.Errorf("invalid hook %q: %w", h.HookName, err) - } - return nil + return validator.Default.ValidateAny(h.Hook) } // Mount is a CDI Mount wrapper, used for validating mounts. @@ -313,13 +212,7 @@ type Mount struct { // Validate a mount. func (m *Mount) Validate() error { - if m.HostPath == "" { - return errors.New("invalid mount, empty host path") - } - if m.ContainerPath == "" { - return errors.New("invalid mount, empty container path") - } - return nil + return validator.Default.ValidateAny(m.Mount) } // IntelRdt is a CDI IntelRdt wrapper. @@ -337,11 +230,7 @@ func ValidateIntelRdt(i *cdi.IntelRdt) error { // Validate validates the IntelRdt configuration. func (i *IntelRdt) Validate() error { - // ClosID must be a valid Linux filename - if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { - return errors.New("invalid ClosID") - } - return nil + return validator.Default.ValidateAny(i.IntelRdt) } // Ensure OCI Spec hooks are not nil so we can add hooks. diff --git a/pkg/cdi/device.go b/pkg/cdi/device.go index 2e5fa57f..9ac050ff 100644 --- a/pkg/cdi/device.go +++ b/pkg/cdi/device.go @@ -17,10 +17,8 @@ package cdi import ( - "fmt" - oci "github.com/opencontainers/runtime-spec/specs-go" - "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -67,22 +65,5 @@ func (d *Device) edits() *ContainerEdits { // Validate the device. func (d *Device) validate() error { - if err := parser.ValidateDeviceName(d.Name); err != nil { - return err - } - name := d.Name - if d.spec != nil { - name = d.GetQualifiedName() - } - if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { - return err - } - edits := d.edits() - if edits.isEmpty() { - return fmt.Errorf("invalid device, empty device edits") - } - if err := edits.Validate(); err != nil { - return fmt.Errorf("invalid device %q: %w", d.Name, err) - } - return nil + return validator.Default.ValidateAny(d.Device) } diff --git a/pkg/cdi/producer/api.go b/pkg/cdi/producer/api.go new file mode 100644 index 00000000..1a9811e8 --- /dev/null +++ b/pkg/cdi/producer/api.go @@ -0,0 +1,36 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package producer + +import cdi "tags.cncf.io/container-device-interface/specs-go" + +type SpecFormat string + +const ( + // DefaultSpecFormat defines the default encoding used to write CDI specs. + DefaultSpecFormat = SpecFormatYAML + + // SpecFormatJSON defines a CDI spec formatted as JSON. + SpecFormatJSON = SpecFormat(".json") + // SpecFormatYAML defines a CDI spec formatted as YAML. + SpecFormatYAML = SpecFormat(".yaml") +) + +// A SpecValidator is used to validate a CDI spec. +type SpecValidator interface { + Validate(*cdi.Spec) error +} diff --git a/pkg/cdi/producer/options.go b/pkg/cdi/producer/options.go new file mode 100644 index 00000000..3b85974f --- /dev/null +++ b/pkg/cdi/producer/options.go @@ -0,0 +1,75 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package producer + +import ( + "fmt" + "io/fs" + + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" +) + +// An Option defines a functional option for constructing a producer. +type Option func(*options) error + +type options struct { + specFormat SpecFormat + specValidator SpecValidator + overwrite bool + permissions fs.FileMode +} + +// WithSpecFormat sets the output format of a CDI specification. +func WithSpecFormat(format SpecFormat) Option { + return func(o *options) error { + switch format { + case SpecFormatJSON, SpecFormatYAML: + o.specFormat = format + default: + return fmt.Errorf("invalid CDI spec format %v", format) + } + return nil + } +} + +// WithSpecValidator sets a validator to be used when writing an output spec. +func WithSpecValidator(specValidator SpecValidator) Option { + return func(o *options) error { + if specValidator == nil { + specValidator = validator.Disabled + } + o.specValidator = specValidator + return nil + } +} + +// WithOverwrite specifies whether a producer should overwrite a CDI spec when +// saving to file. +func WithOverwrite(overwrite bool) Option { + return func(o *options) error { + o.overwrite = overwrite + return nil + } +} + +// WithPermissions sets the file mode to be used for a saved CDI spec. +func WithPermissions(permissions fs.FileMode) Option { + return func(o *options) error { + o.permissions = permissions + return nil + } +} diff --git a/pkg/cdi/producer/producer.go b/pkg/cdi/producer/producer.go new file mode 100644 index 00000000..6ab787b3 --- /dev/null +++ b/pkg/cdi/producer/producer.go @@ -0,0 +1,188 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package producer + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + + "sigs.k8s.io/yaml" + + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +// A SpecProducer defines a structure for outputting CDI specifications. +type SpecProducer struct { + *cdi.Spec + options +} + +// New creates a new producer with the supplied options. +func New(raw *cdi.Spec, opts ...Option) (*SpecProducer, error) { + sp := &SpecProducer{ + Spec: raw, + options: options{ + overwrite: true, + // TODO: This could be updated to 0644 to be world-readable. + permissions: 0600, + specFormat: DefaultSpecFormat, + specValidator: validator.Default, + }, + } + return sp.applyOptions(opts...) +} + +// cloneWithOptions creates a copy of the specified spec producer with the given options. +func cloneWithOptions(p *SpecProducer, opts ...Option) (*SpecProducer, error) { + sp := &SpecProducer{ + Spec: p.Spec, + options: p.options, + } + return sp.applyOptions(opts...) +} + +// applyOptions applies the specified options to a SpecProducer +func (p *SpecProducer) applyOptions(opts ...Option) (*SpecProducer, error) { + for _, opt := range opts { + err := opt(&p.options) + if err != nil { + return nil, err + } + } + return p, nil +} + +// Save writes a CDI spec to a file with the specified name. +// If the filename ends in a supported extension, the format implied by the +// extension takes precedence over the format with which the SpecProducer was +// configured. +func (p *SpecProducer) Save(filename string) (string, error) { + filename = p.normalizeFilename(filename) + format := p.specFormatFromFilename(filename) + // If the currently configured format doesn't match the expected format for + // the filename, we create a new + if p.specFormat != format { + pWithFormat, err := cloneWithOptions(p, WithSpecFormat(format)) + if err != nil { + return "", err + } + return pWithFormat.Save(filename) + } + + dir := filepath.Dir(filename) + if dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", fmt.Errorf("failed to create Spec dir: %w", err) + } + } + + tmp, err := os.CreateTemp(dir, "spec.*.tmp") + if err != nil { + return "", fmt.Errorf("failed to create Spec file: %w", err) + } + _, err = p.WriteTo(tmp) + tmp.Close() + if err != nil { + return "", fmt.Errorf("failed to write Spec file: %w", err) + } + + if err := os.Chmod(tmp.Name(), p.permissions); err != nil { + return "", fmt.Errorf("failed to set permissions on spec file: %w", err) + } + + err = renameIn(dir, filepath.Base(tmp.Name()), filepath.Base(filename), p.overwrite) + if err != nil { + _ = os.Remove(tmp.Name()) + return "", fmt.Errorf("failed to write Spec file: %w", err) + } + return filename, nil +} + +// Validate performs an explicit validation of the spec. +// If no validator is configured, the spec is considered unconditionally valid. +func (p *SpecProducer) Validate() error { + if p == nil || p.specValidator == nil { + return nil + } + return p.specValidator.Validate(p.Spec) +} + +// WriteTo writes the spec to the specified writer. +func (p *SpecProducer) WriteTo(w io.Writer) (int64, error) { + data, err := p.contents() + if err != nil { + return 0, fmt.Errorf("failed to marshal Spec file: %w", err) + } + + n, err := w.Write(data) + return int64(n), err +} + +// contents returns the raw contents of a CDI specification. +// Validation is performed before marshalling the contentent based on the spec format. +func (p *SpecProducer) contents() ([]byte, error) { + if err := p.Validate(); err != nil { + return nil, fmt.Errorf("spec validation failed: %w", err) + } + return p.marshal() +} + +// marshal returns the raw contents of a CDI specification. +// No validation is performed. +func (p *SpecProducer) marshal() ([]byte, error) { + switch p.options.specFormat { + case SpecFormatYAML: + data, err := yaml.Marshal(p.Spec) + if err != nil { + return nil, err + } + data = append([]byte("---\n"), data...) + return data, nil + case SpecFormatJSON: + return json.Marshal(p.Spec) + default: + return nil, fmt.Errorf("undefined CDI spec format %v", s.options.specFormat) + } +} + +// specFormatFromFilename determines the CDI spec format for the given filename. +func (p *SpecProducer) specFormatFromFilename(filename string) SpecFormat { + switch filepath.Ext(filename) { + case ".json": + return SpecFormatJSON + case ".yaml", ".yml": + return SpecFormatYAML + default: + return p.specFormat + } +} + +// normalizeFilename ensures that the specified filename ends in a supported extension. +func (p *SpecProducer) normalizeFilename(filename string) string { + switch filepath.Ext(filename) { + case ".json": + fallthrough + case ".yaml": + return filename + default: + return filename + string(p.specFormat) + } +} diff --git a/pkg/cdi/producer/producer_test.go b/pkg/cdi/producer/producer_test.go new file mode 100644 index 00000000..1144d6a7 --- /dev/null +++ b/pkg/cdi/producer/producer_test.go @@ -0,0 +1,149 @@ +package producer + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +func TestSave(t *testing.T) { + testCases := []struct { + description string + spec cdi.Spec + options []Option + filename string + expectedError error + expectedFilename string + expectedPermissions os.FileMode + expectedOutput string + }{ + { + description: "output as json", + spec: cdi.Spec{ + Version: "v0.3.0", + Kind: "example.com/class", + ContainerEdits: cdi.ContainerEdits{ + DeviceNodes: []*cdi.DeviceNode{ + { + Path: "/dev/foo", + }, + }, + }, + }, + options: []Option{}, + filename: "foo.json", + expectedFilename: "foo.json", + expectedPermissions: 0600, + expectedOutput: `{"cdiVersion":"v0.3.0","kind":"example.com/class","devices":null,"containerEdits":{"deviceNodes":[{"path":"/dev/foo"}]}}`, + }, + { + description: "output with permissions", + spec: cdi.Spec{ + Version: "v0.3.0", + Kind: "example.com/class", + ContainerEdits: cdi.ContainerEdits{ + DeviceNodes: []*cdi.DeviceNode{ + { + Path: "/dev/foo", + }, + }, + }, + }, + options: []Option{WithPermissions(0644)}, + filename: "foo.json", + expectedFilename: "foo.json", + expectedPermissions: 0644, + expectedOutput: `{"cdiVersion":"v0.3.0","kind":"example.com/class","devices":null,"containerEdits":{"deviceNodes":[{"path":"/dev/foo"}]}}`, + }, + { + description: "spec is validated on save", + spec: cdi.Spec{ + Version: "v99.3.0", + }, + options: []Option{}, + filename: "foo.json", + expectedError: validator.ErrInvalid, + }, + { + description: "filename overwrites format", + spec: cdi.Spec{ + Version: "v0.3.0", + Kind: "example.com/class", + ContainerEdits: cdi.ContainerEdits{ + DeviceNodes: []*cdi.DeviceNode{ + { + Path: "/dev/foo", + }, + }, + }, + }, + options: []Option{WithSpecFormat(SpecFormatJSON)}, + filename: "foo.yaml", + expectedFilename: "foo.yaml", + expectedPermissions: 0600, + expectedOutput: `--- +cdiVersion: v0.3.0 +containerEdits: + deviceNodes: + - path: /dev/foo +devices: null +kind: example.com/class +`, + }, + { + description: "filename is inferred from format", + spec: cdi.Spec{ + Version: "v0.3.0", + Kind: "example.com/class", + ContainerEdits: cdi.ContainerEdits{ + DeviceNodes: []*cdi.DeviceNode{ + { + Path: "/dev/foo", + }, + }, + }, + }, + options: []Option{WithSpecFormat(SpecFormatYAML)}, + filename: "foo", + expectedFilename: "foo.yaml", + expectedPermissions: 0600, + expectedOutput: `--- +cdiVersion: v0.3.0 +containerEdits: + deviceNodes: + - path: /dev/foo +devices: null +kind: example.com/class +`, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + + outputDir := t.TempDir() + + p, err := New(&tc.spec, tc.options...) + require.NoError(t, err) + + f, err := p.Save(filepath.Join(outputDir, tc.filename)) + require.ErrorIs(t, err, tc.expectedError) + if tc.expectedError != nil { + return + } + + require.Equal(t, filepath.Join(outputDir, tc.expectedFilename), f) + info, err := os.Stat(f) + require.NoError(t, err) + + require.Equal(t, tc.expectedPermissions, info.Mode()) + + contents, _ := os.ReadFile(f) + require.Equal(t, tc.expectedOutput, string(contents)) + }) + } +} diff --git a/pkg/cdi/spec_linux.go b/pkg/cdi/producer/renamein_linux.go similarity index 98% rename from pkg/cdi/spec_linux.go rename to pkg/cdi/producer/renamein_linux.go index 9ad27392..7d17b2f3 100644 --- a/pkg/cdi/spec_linux.go +++ b/pkg/cdi/producer/renamein_linux.go @@ -14,7 +14,7 @@ limitations under the License. */ -package cdi +package producer import ( "fmt" diff --git a/pkg/cdi/spec_other.go b/pkg/cdi/producer/renamein_other.go similarity index 98% rename from pkg/cdi/spec_other.go rename to pkg/cdi/producer/renamein_other.go index 285e04e2..96ba268a 100644 --- a/pkg/cdi/spec_other.go +++ b/pkg/cdi/producer/renamein_other.go @@ -17,7 +17,7 @@ limitations under the License. */ -package cdi +package producer import ( "os" diff --git a/pkg/cdi/producer/validator/api.go b/pkg/cdi/producer/validator/api.go new file mode 100644 index 00000000..fce9ff7e --- /dev/null +++ b/pkg/cdi/producer/validator/api.go @@ -0,0 +1,28 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +import "errors" + +// Validators as constants. +const ( + Default = defaultValidator("default") + Disabled = disabledValidator("disabled") +) + +// An ErrInvalid error can be returned if CDI spec validation fails. +var ErrInvalid = errors.New("invalid") diff --git a/pkg/cdi/producer/validator/validator-default.go b/pkg/cdi/producer/validator/validator-default.go new file mode 100644 index 00000000..d3c03ca0 --- /dev/null +++ b/pkg/cdi/producer/validator/validator-default.go @@ -0,0 +1,257 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +import ( + "errors" + "fmt" + "strings" + + "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/parser" + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +type defaultValidator string + +// ValidateAny implements a generic validation handler for the defaultValidator. +func (v defaultValidator) ValidateAny(o interface{}) (rerr error) { + defer func() { + if rerr != nil { + rerr = errors.Join(rerr, ErrInvalid) + } + }() + + switch o := o.(type) { + case *cdi.ContainerEdits: + return v.validateEdits(o) + case *cdi.Device: + return v.validateDevice("", "", o) + case *cdi.DeviceNode: + return v.validateDeviceNode(o) + case *cdi.Hook: + return v.validateHook(o) + case *cdi.IntelRdt: + return v.validateIntelRdt(o) + case *cdi.Mount: + return v.validateMount(o) + case *cdi.Spec: + return v.Validate(o) + default: + return fmt.Errorf("unsupported validation type: %T", o) + } +} + +// Validate performs a default validation on a CDI spec. +func (v defaultValidator) Validate(s *cdi.Spec) (rerr error) { + defer func() { + if rerr != nil { + rerr = errors.Join(rerr, ErrInvalid) + } + }() + + if err := cdi.ValidateVersion(s); err != nil { + return err + } + vendor, class := parser.ParseQualifier(s.Kind) + if err := parser.ValidateVendorName(vendor); err != nil { + return err + } + if err := parser.ValidateClassName(class); err != nil { + return err + } + if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { + return err + } + if err := v.validateEdits(&s.ContainerEdits); err != nil { + return err + } + + seen := make(map[string]bool) + for _, d := range s.Devices { + if seen[d.Name] { + return fmt.Errorf("invalid spec, multiple device %q", d.Name) + } + seen[d.Name] = true + if err := v.validateDevice(vendor, class, &d); err != nil { + return fmt.Errorf("invalid device %q: %w", d.Name, err) + } + } + return nil +} + +func (v defaultValidator) validateDevice(vendor string, class string, d *cdi.Device) error { + if err := parser.ValidateDeviceName(d.Name); err != nil { + return err + } + + name := parser.QualifiedName(vendor, class, d.Name) + if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { + return err + } + + if err := v.assertNonEmptyEdits(&d.ContainerEdits); err != nil { + return err + } + if err := v.validateEdits(&d.ContainerEdits); err != nil { + return err + } + return nil +} + +func (v defaultValidator) assertNonEmptyEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if len(e.Env) > 0 { + return nil + } + if len(e.DeviceNodes) > 0 { + return nil + } + if len(e.Hooks) > 0 { + return nil + } + if len(e.Mounts) > 0 { + return nil + } + if len(e.AdditionalGIDs) > 0 { + return nil + } + if e.IntelRdt != nil { + return nil + } + return errors.New("empty container edits") +} + +func (v defaultValidator) validateEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if err := v.validateEnv(e.Env); err != nil { + return fmt.Errorf("invalid container edits: %w", err) + } + for _, d := range e.DeviceNodes { + if err := v.validateDeviceNode(d); err != nil { + return err + } + } + for _, h := range e.Hooks { + if err := v.validateHook(h); err != nil { + return err + } + } + for _, m := range e.Mounts { + if err := v.validateMount(m); err != nil { + return err + } + } + if err := v.validateIntelRdt(e.IntelRdt); err != nil { + return err + } + return nil +} + +func (v defaultValidator) validateEnv(env []string) error { + for _, v := range env { + if strings.IndexByte(v, byte('=')) <= 0 { + return fmt.Errorf("invalid environment variable %q", v) + } + } + return nil +} + +func (v defaultValidator) validateDeviceNode(d *cdi.DeviceNode) error { + validTypes := map[string]struct{}{ + "": {}, + "b": {}, + "c": {}, + "u": {}, + "p": {}, + } + + if d.Path == "" { + return errors.New("invalid (empty) device path") + } + if _, ok := validTypes[d.Type]; !ok { + return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) + } + for _, bit := range d.Permissions { + if bit != 'r' && bit != 'w' && bit != 'm' { + return fmt.Errorf("device %q: invalid permissions %q", + d.Path, d.Permissions) + } + } + return nil +} + +func (v defaultValidator) validateHook(h *cdi.Hook) error { + const ( + // PrestartHook is the name of the OCI "prestart" hook. + PrestartHook = "prestart" + // CreateRuntimeHook is the name of the OCI "createRuntime" hook. + CreateRuntimeHook = "createRuntime" + // CreateContainerHook is the name of the OCI "createContainer" hook. + CreateContainerHook = "createContainer" + // StartContainerHook is the name of the OCI "startContainer" hook. + StartContainerHook = "startContainer" + // PoststartHook is the name of the OCI "poststart" hook. + PoststartHook = "poststart" + // PoststopHook is the name of the OCI "poststop" hook. + PoststopHook = "poststop" + ) + validHookNames := map[string]struct{}{ + PrestartHook: {}, + CreateRuntimeHook: {}, + CreateContainerHook: {}, + StartContainerHook: {}, + PoststartHook: {}, + PoststopHook: {}, + } + + if _, ok := validHookNames[h.HookName]; !ok { + return fmt.Errorf("invalid hook name %q", h.HookName) + } + if h.Path == "" { + return fmt.Errorf("invalid hook %q with empty path", h.HookName) + } + if err := v.validateEnv(h.Env); err != nil { + return fmt.Errorf("invalid hook %q: %w", h.HookName, err) + } + return nil +} + +func (v defaultValidator) validateMount(m *cdi.Mount) error { + if m.HostPath == "" { + return errors.New("invalid mount, empty host path") + } + if m.ContainerPath == "" { + return errors.New("invalid mount, empty container path") + } + return nil +} + +func (v defaultValidator) validateIntelRdt(i *cdi.IntelRdt) error { + if i == nil { + return nil + } + // ClosID must be a valid Linux filename + if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { + return errors.New("invalid ClosID") + } + return nil +} diff --git a/pkg/cdi/producer/validator/validator-disabled.go b/pkg/cdi/producer/validator/validator-disabled.go new file mode 100644 index 00000000..9323dc7e --- /dev/null +++ b/pkg/cdi/producer/validator/validator-disabled.go @@ -0,0 +1,29 @@ +/* + Copyright © 2024 The CDI Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package validator + +import ( + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +// A disabledValidator performs no validation. +type disabledValidator string + +// Validate always passes for a disabledValidator. +func (v disabledValidator) Validate(*cdi.Spec) error { + return nil +} diff --git a/pkg/cdi/spec.go b/pkg/cdi/spec.go index f0231d81..cdd4f709 100644 --- a/pkg/cdi/spec.go +++ b/pkg/cdi/spec.go @@ -17,7 +17,6 @@ package cdi import ( - "encoding/json" "fmt" "os" "path/filepath" @@ -27,7 +26,8 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" "sigs.k8s.io/yaml" - "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" + "tags.cncf.io/container-device-interface/pkg/cdi/producer/validator" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -118,52 +118,20 @@ func newSpec(raw *cdi.Spec, path string, priority int) (*Spec, error) { // Write the CDI Spec to the file associated with it during instantiation // by newSpec() or ReadSpec(). func (s *Spec) write(overwrite bool) error { - var ( - data []byte - dir string - tmp *os.File - err error + p, err := producer.New( + s.Spec, + producer.WithOverwrite(overwrite), ) - - err = validateSpec(s.Spec) if err != nil { return err } - if filepath.Ext(s.path) == ".yaml" { - data, err = yaml.Marshal(s.Spec) - data = append([]byte("---\n"), data...) - } else { - data, err = json.Marshal(s.Spec) - } + savedPath, err := p.Save(s.path) if err != nil { - return fmt.Errorf("failed to marshal Spec file: %w", err) - } - - dir = filepath.Dir(s.path) - err = os.MkdirAll(dir, 0o755) - if err != nil { - return fmt.Errorf("failed to create Spec dir: %w", err) - } - - tmp, err = os.CreateTemp(dir, "spec.*.tmp") - if err != nil { - return fmt.Errorf("failed to create Spec file: %w", err) - } - _, err = tmp.Write(data) - tmp.Close() - if err != nil { - return fmt.Errorf("failed to write Spec file: %w", err) - } - - err = renameIn(dir, filepath.Base(tmp.Name()), filepath.Base(s.path), overwrite) - - if err != nil { - os.Remove(tmp.Name()) - err = fmt.Errorf("failed to write Spec file: %w", err) + return err } - - return err + s.path = savedPath + return nil } // GetVendor returns the vendor of this Spec. @@ -209,22 +177,12 @@ func MinimumRequiredVersion(spec *cdi.Spec) (string, error) { // Validate the Spec. func (s *Spec) validate() (map[string]*Device, error) { - if err := cdi.ValidateVersion(s.Spec); err != nil { - return nil, err - } - if err := parser.ValidateVendorName(s.vendor); err != nil { - return nil, err - } - if err := parser.ValidateClassName(s.class); err != nil { - return nil, err - } - if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { - return nil, err - } - if err := s.edits().Validate(); err != nil { + if err := validator.Default.Validate(s.Spec); err != nil { return nil, err } + // TODO: The validator above should perform the same validation as below but + // we still need to construct the device map. devices := make(map[string]*Device) for _, d := range s.Devices { dev, err := newDevice(s, d)