From 95091013da2a918422dea63d18b4dc4b85526478 Mon Sep 17 00:00:00 2001
From: "Sean T. Allen"
Date: Tue, 3 Aug 2021 09:46:46 -0400
Subject: [PATCH] Add basis for allowing the creation of configuration
 enforcement in gcs

This commit is the minimal set of functionality needed to allow users to
create a configuration policy that gcs can enforce. Policy enforcement will
allow users to state that only specific containers, with specific command
lines, and so on, should be run. If anything in gcs doesn't match the
user-supplied policy, gcs will stop the container from running and report an
error.

Currently, only container filesystem policy is enforced. This is done at two
points. When a pmem device is mounted, its dm-verity root hash is checked
against policy to see if it is allowed. At the time of overlay creation, the
order of layers is compared to policy to make sure that the container is
being constructed as the user expected.

Additional policy enforcement coming in future commits includes:

- enforce policy for scsi mounts
- enforce container command line
- enforce environment variables
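For review context, the host-to-guest handoff works by marshaling the policy
to JSON and base64-encoding it so that it can ride in a string-valued
annotation. A minimal sketch of that encoding step follows; the
`encodePolicy` helper is illustrative only and is not part of this patch:

```go
package policyutil

import (
	"encoding/base64"
	"encoding/json"

	sp "github.com/Microsoft/hcsshim/pkg/securitypolicy"
)

// encodePolicy marshals a policy to JSON and base64-encodes it. The result
// is the form carried by the io.microsoft.virtualmachine.lcow.securitypolicy
// annotation; Host.SetSecurityPolicy in the guest reverses these two steps.
func encodePolicy(p sp.SecurityPolicy) (string, error) {
	j, err := json.Marshal(p)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(j), nil
}
```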
---
 .github/workflows/ci.yml                      |   1 +
 cmd/containerd-shim-runhcs-v1/pod.go          |  10 +-
 go.mod                                        |   1 +
 go.sum                                        |   1 +
 internal/guest/prot/protocol.go               |  10 +
 internal/guest/runtime/hcsv2/uvm.go           |  76 +-
 internal/guest/storage/overlay/overlay.go     |   7 +-
 .../guest/storage/overlay/overlay_test.go     |  80 +-
 internal/guest/storage/pmem/pmem.go           |  18 +-
 internal/guest/storage/pmem/pmem_test.go      |  56 +-
 .../mountmonitoringsecuritypolicyenforcer.go  |  24 +
 internal/guestrequest/types.go                |  10 +-
 internal/hcsoci/resources_lcow.go             |   2 +-
 internal/hcsoci/resources_wcow.go             |   2 +-
 internal/jobcontainers/jobcontainer.go        |   2 +-
 internal/jobcontainers/storage.go             |   4 +-
 internal/layers/layers.go                     |  18 +-
 internal/oci/annotations.go                   |   3 +
 internal/oci/uvm.go                           |   2 +
 internal/tools/securitypolicy/README.md       |  83 ++
 internal/tools/securitypolicy/main.go         | 179 ++++
 internal/uvm/combine_layers.go                |  24 +-
 internal/uvm/create_lcow.go                   |   2 +
 internal/uvm/security_policy.go               |  50 +
 pkg/securitypolicy/securitypolicy.go          |  38 +
 pkg/securitypolicy/securitypolicy_test.go     | 459 +++++++++
 pkg/securitypolicy/securitypolicyenforcer.go  | 204 ++++
 vendor/github.com/BurntSushi/toml/.gitignore  |   5 +
 vendor/github.com/BurntSushi/toml/.travis.yml |  15 +
 vendor/github.com/BurntSushi/toml/COMPATIBLE  |   3 +
 vendor/github.com/BurntSushi/toml/COPYING     |  21 +
 vendor/github.com/BurntSushi/toml/Makefile    |  19 +
 vendor/github.com/BurntSushi/toml/README.md   | 218 ++++
 vendor/github.com/BurntSushi/toml/decode.go   | 509 ++++++++++
 .../github.com/BurntSushi/toml/decode_meta.go | 121 +++
 vendor/github.com/BurntSushi/toml/doc.go      |  27 +
 vendor/github.com/BurntSushi/toml/encode.go   | 568 +++++++++++
 .../BurntSushi/toml/encoding_types.go         |  19 +
 .../BurntSushi/toml/encoding_types_1.1.go     |  18 +
 vendor/github.com/BurntSushi/toml/lex.go      | 953 ++++++++++++++++++
 vendor/github.com/BurntSushi/toml/parse.go    | 592 +++++++++++
 vendor/github.com/BurntSushi/toml/session.vim |   1 +
 .../github.com/BurntSushi/toml/type_check.go  |  91 ++
 .../github.com/BurntSushi/toml/type_fields.go | 242 +++++
 vendor/modules.txt                            |   2 +
 45 files changed, 4747 insertions(+), 43 deletions(-)
 create mode 100644 internal/guest/storage/test/policy/mountmonitoringsecuritypolicyenforcer.go
 create mode 100644 internal/tools/securitypolicy/README.md
 create mode 100644 internal/tools/securitypolicy/main.go
 create mode 100644 internal/uvm/security_policy.go
 create mode 100644 pkg/securitypolicy/securitypolicy.go
 create mode 100644 pkg/securitypolicy/securitypolicy_test.go
 create mode 100644 pkg/securitypolicy/securitypolicyenforcer.go
 create mode 100644 vendor/github.com/BurntSushi/toml/.gitignore
 create mode 100644 vendor/github.com/BurntSushi/toml/.travis.yml
 create mode 100644 vendor/github.com/BurntSushi/toml/COMPATIBLE
 create mode 100644 vendor/github.com/BurntSushi/toml/COPYING
 create mode 100644 vendor/github.com/BurntSushi/toml/Makefile
 create mode 100644 vendor/github.com/BurntSushi/toml/README.md
 create mode 100644 vendor/github.com/BurntSushi/toml/decode.go
 create mode 100644 vendor/github.com/BurntSushi/toml/decode_meta.go
 create mode 100644 vendor/github.com/BurntSushi/toml/doc.go
 create mode 100644 vendor/github.com/BurntSushi/toml/encode.go
 create mode 100644 vendor/github.com/BurntSushi/toml/encoding_types.go
 create mode 100644 vendor/github.com/BurntSushi/toml/encoding_types_1.1.go
 create mode 100644 vendor/github.com/BurntSushi/toml/lex.go
 create mode 100644 vendor/github.com/BurntSushi/toml/parse.go
 create mode 100644 vendor/github.com/BurntSushi/toml/session.vim
 create mode 100644 vendor/github.com/BurntSushi/toml/type_check.go
 create mode 100644 vendor/github.com/BurntSushi/toml/type_fields.go

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8790872b02..b872db8d77 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -70,6 +70,7 @@ jobs:
       - run: go build ./cmd/ncproxy
       - run: go build ./cmd/dmverity-vhd
       - run: go build ./internal/tools/grantvmgroupaccess
+      - run: go build ./internal/tools/securitypolicy
       - run: go build ./internal/tools/uvmboot
       - run: go build ./internal/tools/zapdir

diff --git a/cmd/containerd-shim-runhcs-v1/pod.go b/cmd/containerd-shim-runhcs-v1/pod.go
index bf9dfa930b..abd9cbc506 100644
--- a/cmd/containerd-shim-runhcs-v1/pod.go
+++ b/cmd/containerd-shim-runhcs-v1/pod.go
@@ -89,6 +89,7 @@ func createPod(ctx context.Context, events publisher, req *task.CreateTaskReques
 	}

 	var parent *uvm.UtilityVM
+	var lopts *uvm.OptionsLCOW
 	if oci.IsIsolated(s) {
 		// Create the UVM parent
 		opts, err := oci.SpecToUVMCreateOpts(ctx, s, fmt.Sprintf("%s@vm", req.ID), owner)
@@ -97,7 +98,7 @@ func createPod(ctx context.Context, events publisher, req *task.CreateTaskReques
 		}
 		switch opts.(type) {
 		case *uvm.OptionsLCOW:
-			lopts := (opts).(*uvm.OptionsLCOW)
+			lopts = (opts).(*uvm.OptionsLCOW)
 			parent, err = uvm.CreateLCOW(ctx, lopts)
 			if err != nil {
 				return nil, err
@@ -130,6 +131,14 @@ func createPod(ctx context.Context, events publisher, req *task.CreateTaskReques
 			parent.Close()
 			return nil, err
 		}
+
+		if lopts != nil {
+			err := parent.SetSecurityPolicy(ctx, lopts.SecurityPolicy)
+			if err != nil {
+				parent.Close()
+				return nil, errors.Wrap(err, "unable to set security policy")
+			}
+		}
 	} else if oci.IsJobContainer(s) {
 		// If we're making a job container, fake a task (i.e. reuse the wcowPodSandbox logic)
 		p.sandboxTask = newWcowPodSandboxTask(ctx, events, req.ID, req.Bundle, parent, "")

diff --git a/go.mod b/go.mod
index b9db24af84..1556ffe635 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/Microsoft/hcsshim
 go 1.13

 require (
+	github.com/BurntSushi/toml v0.3.1
 	github.com/Microsoft/go-winio v0.4.17
 	github.com/containerd/cgroups v1.0.1
 	github.com/containerd/console v1.0.2

diff --git a/go.sum
b/go.sum index e5443094d2..87d1893f05 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,7 @@ github.com/Azure/go-autorest/autorest/mocks v0.4.0/go.mod h1:LTp+uSrOhSkaKrUy935 github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= diff --git a/internal/guest/prot/protocol.go b/internal/guest/prot/protocol.go index b12554eb06..91ee04f992 100644 --- a/internal/guest/prot/protocol.go +++ b/internal/guest/prot/protocol.go @@ -8,6 +8,7 @@ import ( "strconv" "github.com/Microsoft/hcsshim/internal/guest/commonutils" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" v1 "github.com/containerd/cgroups/stats/v1" oci "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" @@ -519,6 +520,8 @@ const ( MrtVPCIDevice = ModifyResourceType("VPCIDevice") // MrtContainerConstraints is the modify resource type for updating container constraints MrtContainerConstraints = ModifyResourceType("ContainerConstraints") + // MrtSecurityPolicy is the modify resource type for updating the security policy + MrtSecurityPolicy = ModifyResourceType("SecurityPolicy") ) // ModifyRequestType is the type of operation to perform on a given modify @@ -618,6 +621,12 @@ func UnmarshalContainerModifySettings(b []byte) (*ContainerModifySettings, error return &request, errors.Wrap(err, "failed to unmarshal settings as ContainerConstraintsV2") } msr.Settings = cc + case MrtSecurityPolicy: + policy := &securitypolicy.EncodedSecurityPolicy{} + if err := commonutils.UnmarshalJSONWithHresult(msrRawSettings, policy); err != nil { + return &request, errors.Wrap(err, "failed to unmarshal settings as EncodedSecurityPolicy") + } + msr.Settings = policy default: return &request, errors.Errorf("invalid ResourceType '%s'", msr.ResourceType) } @@ -713,6 +722,7 @@ type CombinedLayersV2 struct { Layers []Layer `json:",omitempty"` ScratchPath string `json:",omitempty"` ContainerRootPath string + ContainerId string `json:",omitempty"` } // NetworkAdapter represents a network interface and its associated diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 1bd360ef82..37583684f9 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -5,6 +5,7 @@ package hcsv2 import ( "bufio" "context" + "encoding/base64" "encoding/json" "fmt" "os" @@ -26,6 +27,7 @@ import ( "github.com/Microsoft/hcsshim/internal/guest/storage/pmem" "github.com/Microsoft/hcsshim/internal/guest/storage/scsi" "github.com/Microsoft/hcsshim/internal/guest/transport" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" shellwords "github.com/mattn/go-shellwords" "github.com/pkg/errors" ) @@ -46,15 +48,60 @@ type Host struct { // Rtime is the Runtime interface used by the GCS core. 
	rtime runtime.Runtime
 	vsock transport.Transport
+
+	// state required for the security policy enforcement
+	policyMutex               sync.Mutex
+	securityPolicyEnforcer    securitypolicy.SecurityPolicyEnforcer
+	securityPolicyEnforcerSet bool
 }

 func NewHost(rtime runtime.Runtime, vsock transport.Transport) *Host {
 	return &Host{
-		containers:        make(map[string]*Container),
-		externalProcesses: make(map[int]*externalProcess),
-		rtime:             rtime,
-		vsock:             vsock,
+		containers:                make(map[string]*Container),
+		externalProcesses:         make(map[int]*externalProcess),
+		rtime:                     rtime,
+		vsock:                     vsock,
+		securityPolicyEnforcerSet: false,
+		securityPolicyEnforcer:    &securitypolicy.OpenDoorSecurityPolicyEnforcer{},
 	}
 }

+// SetSecurityPolicy takes a base64-encoded security policy and sets up the
+// internal data structures we use to store said policy. The security policy
+// is transmitted as JSON in an annotation, so we first have to remove the
+// base64 encoding that allows the JSON-based policy to be passed as a string.
+// From there, we decode the JSON and set up our security policy state.
+func (h *Host) SetSecurityPolicy(base64_policy string) error {
+	h.policyMutex.Lock()
+	defer h.policyMutex.Unlock()
+	if h.securityPolicyEnforcerSet {
+		return errors.New("security policy has already been set")
+	}
+
+	// base64 decode the incoming policy string. It's base64 encoded because
+	// it is coming from an annotation, and annotations are a map of string
+	// to string; we want to store a complex JSON object, so... base64 it is.
+	jsonPolicy, err := base64.StdEncoding.DecodeString(base64_policy)
+	if err != nil {
+		return errors.Wrap(err, "unable to decode policy from Base64 format")
+	}
+
+	// JSON unmarshal the decoded bytes into a SecurityPolicy
+	securityPolicy := &securitypolicy.SecurityPolicy{}
+	if err := json.Unmarshal(jsonPolicy, securityPolicy); err != nil {
+		return errors.Wrap(err, "unable to unmarshal policy from JSON")
+	}
+
+	p, err := securitypolicy.NewSecurityPolicyEnforcer(securityPolicy)
+	if err != nil {
+		return err
+	}
+
+	h.securityPolicyEnforcer = p
+	h.securityPolicyEnforcerSet = true
+
+	return nil
+}

 func (h *Host) RemoveContainer(id string) {
@@ -200,9 +247,9 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, setti
 	case prot.MrtMappedDirectory:
 		return modifyMappedDirectory(ctx, h.vsock, settings.RequestType, settings.Settings.(*prot.MappedDirectoryV2))
 	case prot.MrtVPMemDevice:
-		return modifyMappedVPMemDevice(ctx, settings.RequestType, settings.Settings.(*prot.MappedVPMemDeviceV2))
+		return modifyMappedVPMemDevice(ctx, settings.RequestType, settings.Settings.(*prot.MappedVPMemDeviceV2), h.securityPolicyEnforcer)
 	case prot.MrtCombinedLayers:
-		return modifyCombinedLayers(ctx, settings.RequestType, settings.Settings.(*prot.CombinedLayersV2))
+		return modifyCombinedLayers(ctx, settings.RequestType, settings.Settings.(*prot.CombinedLayersV2), h.securityPolicyEnforcer)
 	case prot.MrtNetwork:
 		return modifyNetwork(ctx, settings.RequestType, settings.Settings.(*prot.NetworkAdapterV2))
 	case prot.MrtVPCIDevice:
@@ -213,6 +260,13 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, setti
 			return err
 		}
 		return c.modifyContainerConstraints(ctx, settings.RequestType, settings.Settings.(*prot.ContainerConstraintsV2))
+	case prot.MrtSecurityPolicy:
+		policy, ok := settings.Settings.(*securitypolicy.EncodedSecurityPolicy)
+		if !ok {
+			return errors.New("the request's settings are not of type EncodedSecurityPolicy")
+		}
+
+		return h.SetSecurityPolicy(policy.SecurityPolicy)
 	default:
 		return errors.Errorf("the ResourceType \"%s\" is not supported for UVM",
settings.ResourceType) } @@ -381,12 +435,12 @@ func modifyMappedDirectory(ctx context.Context, vsock transport.Transport, rt pr } } -func modifyMappedVPMemDevice(ctx context.Context, rt prot.ModifyRequestType, vpd *prot.MappedVPMemDeviceV2) (err error) { +func modifyMappedVPMemDevice(ctx context.Context, rt prot.ModifyRequestType, vpd *prot.MappedVPMemDeviceV2, securityPolicy securitypolicy.SecurityPolicyEnforcer) (err error) { switch rt { case prot.MreqtAdd: - return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, vpd.VerityInfo) + return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, vpd.VerityInfo, securityPolicy) case prot.MreqtRemove: - return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, vpd.VerityInfo) + return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, vpd.VerityInfo, securityPolicy) default: return newInvalidRequestTypeError(rt) } @@ -401,7 +455,7 @@ func modifyMappedVPCIDevice(ctx context.Context, rt prot.ModifyRequestType, vpci } } -func modifyCombinedLayers(ctx context.Context, rt prot.ModifyRequestType, cl *prot.CombinedLayersV2) (err error) { +func modifyCombinedLayers(ctx context.Context, rt prot.ModifyRequestType, cl *prot.CombinedLayersV2, securityPolicy securitypolicy.SecurityPolicyEnforcer) (err error) { switch rt { case prot.MreqtAdd: layerPaths := make([]string, len(cl.Layers)) @@ -420,7 +474,7 @@ func modifyCombinedLayers(ctx context.Context, rt prot.ModifyRequestType, cl *pr workdirPath = filepath.Join(cl.ScratchPath, "work") } - return overlay.Mount(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) + return overlay.Mount(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly, cl.ContainerId, securityPolicy) case prot.MreqtRemove: return storage.UnmountPath(ctx, cl.ContainerRootPath, true) default: diff --git a/internal/guest/storage/overlay/overlay.go b/internal/guest/storage/overlay/overlay.go index 07fc24eff1..696ff09565 100644 --- a/internal/guest/storage/overlay/overlay.go +++ b/internal/guest/storage/overlay/overlay.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/Microsoft/hcsshim/internal/oc" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/pkg/errors" "go.opencensus.io/trace" "golang.org/x/sys/unix" @@ -30,11 +31,15 @@ var ( // // Always creates `rootfsPath`. On mount failure the created `rootfsPath` will // be automatically cleaned up. 
-func Mount(ctx context.Context, layerPaths []string, upperdirPath, workdirPath, rootfsPath string, readonly bool) (err error) { +func Mount(ctx context.Context, layerPaths []string, upperdirPath, workdirPath, rootfsPath string, readonly bool, containerId string, securityPolicy securitypolicy.SecurityPolicyEnforcer) (err error) { _, span := trace.StartSpan(ctx, "overlay::Mount") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() + if err := securityPolicy.EnforceOverlayMountPolicy(containerId, layerPaths); err != nil { + return err + } + lowerdir := strings.Join(layerPaths, ":") span.AddAttributes( trace.StringAttribute("layerPaths", lowerdir), diff --git a/internal/guest/storage/overlay/overlay_test.go b/internal/guest/storage/overlay/overlay_test.go index 48363c1d41..bbbf9019ea 100644 --- a/internal/guest/storage/overlay/overlay_test.go +++ b/internal/guest/storage/overlay/overlay_test.go @@ -7,6 +7,13 @@ import ( "errors" "os" "testing" + + "github.com/Microsoft/hcsshim/internal/guest/storage/test/policy" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" +) + +const ( + fakeContainerId = "1" ) type undo struct { @@ -76,7 +83,7 @@ func Test_Mount_Success(t *testing.T) { return nil } - err := Mount(context.Background(), []string{"/layer1", "/layer2"}, "/upper", "/work", "/root", false) + err := Mount(context.Background(), []string{"/layer1", "/layer2"}, "/upper", "/work", "/root", false, fakeContainerId, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected no error got: %v", err) } @@ -120,7 +127,7 @@ func Test_Mount_Readonly_Success(t *testing.T) { return nil } - err := Mount(context.Background(), []string{"/layer1", "/layer2"}, "", "", "/root", false) + err := Mount(context.Background(), []string{"/layer1", "/layer2"}, "", "", "/root", false, fakeContainerId, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected no error got: %v", err) } @@ -128,3 +135,72 @@ func Test_Mount_Readonly_Success(t *testing.T) { t.Fatal("expected root to be created") } } + +func Test_Security_Policy_Enforcement(t *testing.T) { + undo := captureTestMethods() + defer undo.Close() + + var upperCreated, workCreated, rootCreated bool + osMkdirAll = func(path string, perm os.FileMode) error { + if perm != 0755 { + t.Errorf("os.MkdirAll at: %s, perm: %v expected perm: 0755", path, perm) + } + switch path { + case "/upper": + upperCreated = true + return nil + case "/work": + workCreated = true + return nil + case "/root": + rootCreated = true + return nil + } + return errors.New("unexpected os.MkdirAll path") + } + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + if source != "overlay" { + t.Errorf("expected source: 'overlay' got: %v", source) + } + if target != "/root" { + t.Errorf("expected target: '/root' got: %v", target) + } + if fstype != "overlay" { + t.Errorf("expected fstype: 'overlay' got: %v", fstype) + } + if flags != 0 { + t.Errorf("expected flags: '0' got: %v", flags) + } + if data != "lowerdir=/layer1:/layer2,upperdir=/upper,workdir=/work" { + t.Errorf("expected data: 'lowerdir=/layer1:/layer2,upperdir=/upper,workdir=/work' got: %v", data) + } + return nil + } + + enforcer := mountMonitoringSecurityPolicyEnforcer() + err := Mount(context.Background(), []string{"/layer1", "/layer2"}, "/upper", "/work", "/root", false, fakeContainerId, enforcer) + if err != nil { + t.Fatalf("expected no error got: %v", err) + } + if !upperCreated || !workCreated || !rootCreated { + t.Fatalf("expected all 
upper: %v, work: %v, root: %v to be created", upperCreated, workCreated, rootCreated) + } + + expectedPmem := 0 + if enforcer.PmemMountCalls != expectedPmem { + t.Errorf("expected %d attempt at pmem mount enforcement, got %d", expectedPmem, enforcer.PmemMountCalls) + } + + expectedOverlay := 1 + if enforcer.OverlayMountCalls != expectedOverlay { + t.Fatalf("expected %d attempts at overlay mount enforcement, got %d", expectedOverlay, enforcer.OverlayMountCalls) + } +} + +func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { + return &securitypolicy.OpenDoorSecurityPolicyEnforcer{} +} + +func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { + return &policy.MountMonitoringSecurityPolicyEnforcer{} +} diff --git a/internal/guest/storage/pmem/pmem.go b/internal/guest/storage/pmem/pmem.go index 63a9158987..e4374d61ed 100644 --- a/internal/guest/storage/pmem/pmem.go +++ b/internal/guest/storage/pmem/pmem.go @@ -5,9 +5,11 @@ package pmem import ( "context" "fmt" + "os" + "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/log" - "os" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/Microsoft/hcsshim/internal/guest/storage" dm "github.com/Microsoft/hcsshim/internal/guest/storage/devicemapper" @@ -63,7 +65,7 @@ func mountInternal(ctx context.Context, source, target string) (err error) { // // Note: both mappingInfo and verityInfo can be non-nil at the same time, in that case // linear target is created first and it becomes the data/hash device for verity target. -func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.DeviceMappingInfo, verityInfo *prot.DeviceVerityInfo) (err error) { +func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.DeviceMappingInfo, verityInfo *prot.DeviceVerityInfo, securityPolicy securitypolicy.SecurityPolicyEnforcer) (err error) { mCtx, span := trace.StartSpan(ctx, "pmem::Mount") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() @@ -73,6 +75,16 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. trace.StringAttribute("target", target)) devicePath := fmt.Sprintf(pMemFmt, device) + + var deviceHash string + if verityInfo != nil { + deviceHash = verityInfo.RootDigest + } + err = securityPolicy.EnforcePmemMountPolicy(target, deviceHash) + if err != nil { + return errors.Wrapf(err, "won't mount pmem device %d onto %s", device, target) + } + // dm linear target has to be created first. when verity info is also present, the linear target becomes the data // device instead of the original VPMem. 
if mappingInfo != nil { @@ -178,7 +190,7 @@ func createDMVerityTarget(ctx context.Context, devPath, devName, target string, } // Unmount unmounts `target` and removes corresponding linear and verity targets when needed -func Unmount(ctx context.Context, devNumber uint32, target string, mappingInfo *prot.DeviceMappingInfo, verityInfo *prot.DeviceVerityInfo) (err error) { +func Unmount(ctx context.Context, devNumber uint32, target string, mappingInfo *prot.DeviceMappingInfo, verityInfo *prot.DeviceVerityInfo, securityPolicy securitypolicy.SecurityPolicyEnforcer) (err error) { _, span := trace.StartSpan(ctx, "pmem::Unmount") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() diff --git a/internal/guest/storage/pmem/pmem_test.go b/internal/guest/storage/pmem/pmem_test.go index 35f1489e72..446b95c1fd 100644 --- a/internal/guest/storage/pmem/pmem_test.go +++ b/internal/guest/storage/pmem/pmem_test.go @@ -8,6 +8,8 @@ import ( "os" "testing" + "github.com/Microsoft/hcsshim/internal/guest/storage/test/policy" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/pkg/errors" "golang.org/x/sys/unix" ) @@ -25,7 +27,7 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } - err := Mount(context.Background(), 0, "", nil, nil) + err := Mount(context.Background(), 0, "", nil, nil, openDoorSecurityPolicyEnforcer()) if errors.Cause(err) != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -49,7 +51,7 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, target, nil, nil) + err := Mount(context.Background(), 0, target, nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -73,7 +75,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, target, nil, nil) + err := Mount(context.Background(), 0, target, nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil error got: %v", err) } @@ -100,7 +102,7 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, target, nil, nil) + err := Mount(context.Background(), 0, target, nil, nil, openDoorSecurityPolicyEnforcer()) if errors.Cause(err) != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } @@ -127,7 +129,7 @@ func Test_Mount_Valid_Source(t *testing.T) { } return nil } - err := Mount(context.Background(), device, "/fake/path", nil, nil) + err := Mount(context.Background(), device, "/fake/path", nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -150,7 +152,7 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, expectedTarget, nil, nil) + err := Mount(context.Background(), 0, expectedTarget, nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -173,7 +175,7 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, "/fake/path", nil, nil) + err := Mount(context.Background(), 0, "/fake/path", nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -196,7 +198,7 @@ func Test_Mount_Valid_Flags(t 
*testing.T) { } return nil } - err := Mount(context.Background(), 0, "/fake/path", nil, nil) + err := Mount(context.Background(), 0, "/fake/path", nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -219,8 +221,44 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, "/fake/path", nil, nil) + err := Mount(context.Background(), 0, "/fake/path", nil, nil, openDoorSecurityPolicyEnforcer()) if err != nil { t.Fatalf("expected nil err, got: %v", err) } } + +func Test_Security_Policy_Enforcement(t *testing.T) { + clearTestDependencies() + + osMkdirAll = func(path string, perm os.FileMode) error { + return nil + } + + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + return nil + } + + enforcer := mountMonitoringSecurityPolicyEnforcer() + err := Mount(context.Background(), 0, "/fake/path", nil, nil, enforcer) + if err != nil { + t.Fatalf("expected nil err, got: %v", err) + } + + expectedPmem := 1 + if enforcer.PmemMountCalls != expectedPmem { + t.Fatalf("expected %d attempt at pmem mount enforcement, got %d", expectedPmem, enforcer.PmemMountCalls) + } + + expectedOverlay := 0 + if enforcer.OverlayMountCalls != expectedOverlay { + t.Fatalf("expected %d attempts at overlay mount enforcement, got %d", expectedOverlay, enforcer.OverlayMountCalls) + } +} + +func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { + return &securitypolicy.OpenDoorSecurityPolicyEnforcer{} +} + +func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { + return &policy.MountMonitoringSecurityPolicyEnforcer{} +} diff --git a/internal/guest/storage/test/policy/mountmonitoringsecuritypolicyenforcer.go b/internal/guest/storage/test/policy/mountmonitoringsecuritypolicyenforcer.go new file mode 100644 index 0000000000..c67527ffa8 --- /dev/null +++ b/internal/guest/storage/test/policy/mountmonitoringsecuritypolicyenforcer.go @@ -0,0 +1,24 @@ +package policy + +import ( + "github.com/Microsoft/hcsshim/pkg/securitypolicy" +) + +// For testing. Records the number of calls to each method so we can verify +// the expected interactions took place. +type MountMonitoringSecurityPolicyEnforcer struct { + PmemMountCalls int + OverlayMountCalls int +} + +var _ securitypolicy.SecurityPolicyEnforcer = (*MountMonitoringSecurityPolicyEnforcer)(nil) + +func (p *MountMonitoringSecurityPolicyEnforcer) EnforcePmemMountPolicy(target string, deviceHash string) (err error) { + p.PmemMountCalls++ + return nil +} + +func (p *MountMonitoringSecurityPolicyEnforcer) EnforceOverlayMountPolicy(containerID string, layerPaths []string) (err error) { + p.OverlayMountCalls++ + return nil +} diff --git a/internal/guestrequest/types.go b/internal/guestrequest/types.go index 548a8b43fd..a03cf09aa3 100644 --- a/internal/guestrequest/types.go +++ b/internal/guestrequest/types.go @@ -16,7 +16,14 @@ import ( // since the container path is already the scratch path. For linux, the GCS unions // the specified layers and ScratchPath together, placing the resulting union // filesystem at ContainerRootPath. 
-type CombinedLayers struct {
+type LCOWCombinedLayers struct {
+	ContainerID       string            `json:"ContainerID"`
+	ContainerRootPath string            `json:"ContainerRootPath,omitempty"`
+	Layers            []hcsschema.Layer `json:"Layers,omitempty"`
+	ScratchPath       string            `json:"ScratchPath,omitempty"`
+}
+
+type WCOWCombinedLayers struct {
 	ContainerRootPath string            `json:"ContainerRootPath,omitempty"`
 	Layers            []hcsschema.Layer `json:"Layers,omitempty"`
 	ScratchPath       string            `json:"ScratchPath,omitempty"`
@@ -113,6 +120,7 @@ const (
 	ResourceTypeVPCIDevice           ResourceType = "VPCIDevice"
 	ResourceTypeContainerConstraints ResourceType = "ContainerConstraints"
 	ResourceTypeHvSocket             ResourceType = "HvSocket"
+	ResourceTypeSecurityPolicy       ResourceType = "SecurityPolicy"
 )

 // GuestRequest is for modify commands passed to the guest.
diff --git a/internal/hcsoci/resources_lcow.go b/internal/hcsoci/resources_lcow.go
index c024605f73..a6f6e78949 100644
--- a/internal/hcsoci/resources_lcow.go
+++ b/internal/hcsoci/resources_lcow.go
@@ -42,7 +42,7 @@ func allocateLinuxResources(ctx context.Context, coi *createOptionsInternal, r *
 	containerRootInUVM := r.ContainerRootInUVM()
 	if coi.Spec.Windows != nil && len(coi.Spec.Windows.LayerFolders) > 0 {
 		log.G(ctx).Debug("hcsshim::allocateLinuxResources mounting storage")
-		rootPath, err := layers.MountContainerLayers(ctx, coi.Spec.Windows.LayerFolders, containerRootInUVM, "", coi.HostingSystem)
+		rootPath, err := layers.MountContainerLayers(ctx, coi.actualID, coi.Spec.Windows.LayerFolders, containerRootInUVM, "", coi.HostingSystem)
 		if err != nil {
 			return errors.Wrap(err, "failed to mount container storage")
 		}
diff --git a/internal/hcsoci/resources_wcow.go b/internal/hcsoci/resources_wcow.go
index ef0ae89c38..ef7c8c1373 100644
--- a/internal/hcsoci/resources_wcow.go
+++ b/internal/hcsoci/resources_wcow.go
@@ -53,7 +53,7 @@ func allocateWindowsResources(ctx context.Context, coi *createOptionsInternal, r
 	if coi.Spec.Root.Path == "" && (coi.HostingSystem != nil || coi.Spec.Windows.HyperV == nil) {
 		log.G(ctx).Debug("hcsshim::allocateWindowsResources mounting storage")
 		containerRootInUVM := r.ContainerRootInUVM()
-		containerRootPath, err := layers.MountContainerLayers(ctx, coi.Spec.Windows.LayerFolders, containerRootInUVM, "", coi.HostingSystem)
+		containerRootPath, err := layers.MountContainerLayers(ctx, coi.actualID, coi.Spec.Windows.LayerFolders, containerRootInUVM, "", coi.HostingSystem)
 		if err != nil {
 			return errors.Wrap(err, "failed to mount container storage")
 		}
diff --git a/internal/jobcontainers/jobcontainer.go b/internal/jobcontainers/jobcontainer.go
index f2d338e82f..576f8420aa 100644
--- a/internal/jobcontainers/jobcontainer.go
+++ b/internal/jobcontainers/jobcontainer.go
@@ -131,7 +131,7 @@ func Create(ctx context.Context, id string, s *specs.Spec) (_ cow.Container, _ *
 	}()

 	sandboxPath := fmt.Sprintf(sandboxMountFormat, id)
-	if err := mountLayers(ctx, s, sandboxPath); err != nil {
+	if err := mountLayers(ctx, id, s, sandboxPath); err != nil {
 		return nil, nil, errors.Wrap(err, "failed to mount container layers")
 	}
 	container.sandboxMount = sandboxPath
diff --git a/internal/jobcontainers/storage.go b/internal/jobcontainers/storage.go
index 7c4fe4b168..1ff1ac4744 100644
--- a/internal/jobcontainers/storage.go
+++ b/internal/jobcontainers/storage.go
@@ -15,7 +15,7 @@ import (
 // Trailing backslash required for SetVolumeMountPoint and DeleteVolumeMountPoint
 const sandboxMountFormat = `C:\C\%s\`

-func mountLayers(ctx context.Context, s *specs.Spec, volumeMountPath string) error {
+func mountLayers(ctx context.Context, containerId string, s *specs.Spec, volumeMountPath string) error {
 	if s == nil || s.Windows == nil || s.Windows.LayerFolders == nil {
 		return errors.New("field 'Spec.Windows.Layerfolders' is not populated")
 	}
@@ -41,7 +41,7 @@ func mountLayers(ctx context.Context, s *specs.Spec, volumeMountPath string) err
 	if s.Root.Path == "" {
 		log.G(ctx).Debug("mounting job container storage")
-		containerRootPath, err := layers.MountContainerLayers(ctx, s.Windows.LayerFolders, "", volumeMountPath, nil)
+		containerRootPath, err := layers.MountContainerLayers(ctx, containerId, s.Windows.LayerFolders, "", volumeMountPath, nil)
 		if err != nil {
 			return errors.Wrap(err, "failed to mount container storage")
 		}
diff --git a/internal/layers/layers.go b/internal/layers/layers.go
index 86ff5e53a2..3af5883137 100644
--- a/internal/layers/layers.go
+++ b/internal/layers/layers.go
@@ -75,7 +75,7 @@ func (layers *ImageLayers) Release(ctx context.Context, all bool) error {
 // the host at `volumeMountPath`.
 //
 // TODO dcantah: Keep better track of the layers that are added, don't simply discard the SCSI, VSMB, etc. resource types gotten inside.
-func MountContainerLayers(ctx context.Context, layerFolders []string, guestRoot, volumeMountPath string, uvm *uvmpkg.UtilityVM) (_ string, err error) {
+func MountContainerLayers(ctx context.Context, containerId string, layerFolders []string, guestRoot string, volumeMountPath string, uvm *uvmpkg.UtilityVM) (_ string, err error) {
 	log.G(ctx).WithField("layerFolders", layerFolders).Debug("hcsshim::mountContainerLayers")

 	if uvm == nil {
@@ -212,7 +213,7 @@ func MountContainerLayers(ctx context.Context, layerFolders []string, guestRoot,
 		rootfs = containerScratchPathInUVM
 	} else {
 		rootfs = ospath.Join(uvm.OS(), guestRoot, uvmpkg.RootfsPath)
-		err = uvm.CombineLayersLCOW(ctx, lcowUvmLayerPaths, containerScratchPathInUVM, rootfs)
+		err = uvm.CombineLayersLCOW(ctx, containerId, lcowUvmLayerPaths, containerScratchPathInUVM, rootfs)
 	}
 	if err != nil {
 		return "", err
@@ -326,9 +327,16 @@ func UnmountContainerLayers(ctx context.Context, layerFolders []string, containe
 	// Always remove the combined layers as they are part of scsi/vsmb/vpmem
 	// removals.
-	if err := uvm.RemoveCombinedLayers(ctx, containerRootPath); err != nil {
-		log.G(ctx).WithError(err).Warn("failed guest request to remove combined layers")
-		retError = err
+	if uvm.OS() == "windows" {
+		if err := uvm.RemoveCombinedLayersWCOW(ctx, containerRootPath); err != nil {
+			log.G(ctx).WithError(err).Warn("failed guest request to remove combined layers")
+			retError = err
+		}
+	} else {
+		if err := uvm.RemoveCombinedLayersLCOW(ctx, containerRootPath); err != nil {
+			log.G(ctx).WithError(err).Warn("failed guest request to remove combined layers")
+			retError = err
+		}
 	}

 	// Unload the SCSI scratch path
diff --git a/internal/oci/annotations.go b/internal/oci/annotations.go
index eb011c656a..e8acdb6422 100644
--- a/internal/oci/annotations.go
+++ b/internal/oci/annotations.go
@@ -210,4 +210,7 @@ const (
 	// AnnotationNcproxyContainerID indicates whether or not to use the hcsshim container ID
 	// when setting up ncproxy and computeagent
 	AnnotationNcproxyContainerID = "io.microsoft.network.ncproxy.containerid"
+
+	// AnnotationSecurityPolicy is used to specify a security policy for opengcs to enforce
+	AnnotationSecurityPolicy = "io.microsoft.virtualmachine.lcow.securitypolicy"
 )
diff --git a/internal/oci/uvm.go b/internal/oci/uvm.go
index 06715099f0..9ddae4b445 100644
--- a/internal/oci/uvm.go
+++ b/internal/oci/uvm.go
@@ -329,6 +329,8 @@ func SpecToUVMCreateOpts(ctx context.Context, s *specs.Spec, id, owner string) (
 	lopts.BootFilesPath = parseAnnotationsString(s.Annotations, AnnotationBootFilesRootPath, lopts.BootFilesPath)
 	lopts.CPUGroupID = parseAnnotationsString(s.Annotations, AnnotationCPUGroupID, lopts.CPUGroupID)
 	lopts.NetworkConfigProxy = parseAnnotationsString(s.Annotations, AnnotationNetworkConfigProxy, lopts.NetworkConfigProxy)
+	lopts.SecurityPolicy = parseAnnotationsString(s.Annotations, AnnotationSecurityPolicy, lopts.SecurityPolicy)
+
 	handleAnnotationPreferredRootFSType(ctx, s.Annotations, lopts)
 	handleAnnotationKernelDirectBoot(ctx, s.Annotations, lopts)
diff --git a/internal/tools/securitypolicy/README.md b/internal/tools/securitypolicy/README.md
new file mode 100644
index 0000000000..e2bb4c4264
--- /dev/null
+++ b/internal/tools/securitypolicy/README.md
@@ -0,0 +1,83 @@
+# securitypolicy
+
+Takes a TOML configuration file and outputs a Base64 encoded string of the
+generated security policy.
+
+`securitypolicy` exists as a tool to make it easier to generate security
+policies for developers working on functionality related to security policy
+in this repository. It is not intended to be used by "end users" but could
+be used as a basis for such a tool.
+
+A Base64 encoded version of the policy is sent as an annotation to GCS for
+processing. The `securitypolicy` tool will, by default, output Base64
+encoded JSON.
+
+Running the tool can take a long time as each layer for each container must
+be downloaded, turned into an ext4 filesystem, and finally have a dm-verity
+root hash calculated.
+
+## Example TOML configuration file
+
+```toml
+[[image]]
+name = "rust:1.52.1"
+command = "rustc --help"
+```
+
+### Converted to JSON
+
+The above TOML configuration gets translated into the appropriate policy,
+represented in JSON.
+
+```json
+{
+    "allow_all": false,
+    "containers": [
+        {
+            "command": "/pause",
+            "layers": [
+                "16b514057a06ad665f92c02863aca074fd5976c755d26bff16365299169e8415"
+            ]
+        },
+        {
+            "command": "rustc --help",
+            "layers": [
+                "fe84c9d5bfddd07a2624d00333cf13c1a9c941f3a261f13ead44fc6a93bc0e7a",
+                "4dedae42847c704da891a28c25d32201a1ae440bce2aecccfa8e6f03b97a6a6c",
+                "41d64cdeb347bf236b4c13b7403b633ff11f1cf94dbc7cf881a44d6da88c5156",
+                "eb36921e1f82af46dfe248ef8f1b3afb6a5230a64181d960d10237a08cd73c79",
+                "e769d7487cc314d3ee748a4440805317c19262c7acd2fdbdb0d47d2e4613a15c",
+                "1b80f120dbd88e4355d6241b519c3e25290215c469516b49dece9cf07175a766"
+            ]
+        }
+    ]
+}
+```
+
+## CLI Options
+
+- `-c`: TOML configuration file to process (required)
+- `-j`: output raw JSON in addition to the Base64 encoded version
+- `-u`: username to use to log in to remote container services (defaults to anonymous)
+- `-p`: password to use to log in to remote container services (defaults to anonymous)
+
+## Pause container
+
+All LCOW pods require a pause container to run. The pause container must be
+included in the policy. As this tool is aimed at LCOW developers, a default
+version of the pause container is automatically added to the policy even
+though it isn't in the TOML configuration.
+
+If the version of the pause container changes from 3.1, you will need to
+update the hardcoded root hash by running the `dmverity-vhd` tool to compute
+the root hash for the new container and update this tool accordingly.
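For illustration, a typical invocation might look like the following (the
file name `policy.toml` is just an example):

```
securitypolicy -c policy.toml -j
```

With `-j`, the tool prints the policy JSON followed by its Base64 encoding;
the Base64 string is the value to place in the
`io.microsoft.virtualmachine.lcow.securitypolicy` annotation.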
diff --git a/internal/tools/securitypolicy/main.go b/internal/tools/securitypolicy/main.go
new file mode 100644
index 0000000000..64d0d724e4
--- /dev/null
+++ b/internal/tools/securitypolicy/main.go
@@ -0,0 +1,179 @@
+package main
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io/ioutil"
+	"os"
+
+	"github.com/BurntSushi/toml"
+	"github.com/Microsoft/hcsshim/ext4/dmverity"
+	"github.com/Microsoft/hcsshim/ext4/tar2ext4"
+	sp "github.com/Microsoft/hcsshim/pkg/securitypolicy"
+	"github.com/google/go-containerregistry/pkg/authn"
+	"github.com/google/go-containerregistry/pkg/name"
+	"github.com/google/go-containerregistry/pkg/v1/remote"
+)
+
+var (
+	config_file = flag.String("c", "", "config")
+	output_json = flag.Bool("j", false, "json")
+	username    = flag.String("u", "", "username")
+	password    = flag.String("p", "", "password")
+)
+
+func main() {
+	flag.Parse()
+	if flag.NArg() != 0 || len(*config_file) == 0 {
+		flag.Usage()
+		os.Exit(1)
+	}
+
+	err := func() (err error) {
+		configData, err := ioutil.ReadFile(*config_file)
+		if err != nil {
+			return err
+		}
+
+		config := &Config{
+			AllowAll: false,
+			Images:   []Image{},
+		}
+
+		err = toml.Unmarshal(configData, config)
+		if err != nil {
+			return err
+		}
+
+		policy, err := func() (sp.SecurityPolicy, error) {
+			if config.AllowAll {
+				return createOpenDoorPolicy(), nil
+			} else {
+				return createPolicyFromConfig(*config)
+			}
+		}()
+
+		if err != nil {
+			return err
+		}
+
+		j, err := json.Marshal(policy)
+		if err != nil {
+			return err
+		}
+		if *output_json {
+			fmt.Printf("%s\n", j)
+		}
+		b := base64.StdEncoding.EncodeToString(j)
+		fmt.Printf("%s\n", b)
+
+		return nil
+	}()
+
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err)
+		os.Exit(1)
+	}
+}
+
+type Image struct {
+	Name    string `toml:"name"`
+	Command string `toml:"command"`
+}
+
+type Config struct {
+	AllowAll bool    `toml:"allow_all"`
+	Images   []Image `toml:"image"`
+}
+
+func createOpenDoorPolicy() sp.SecurityPolicy {
+	return sp.SecurityPolicy{
+		AllowAll: true,
+	}
+}
+
+func createPolicyFromConfig(config Config) (sp.SecurityPolicy, error) {
+	p := sp.SecurityPolicy{}
+
+	// For now, we hardcode version 3.1 of the pause container here. A final
+	// end-user tool would not do it this way, but this is a tool for use by
+	// developers currently working on the security policy implementation
+	// code.
+	pausec := sp.SecurityPolicyContainer{
+		Command: "/pause",
+		Layers:  []string{"16b514057a06ad665f92c02863aca074fd5976c755d26bff16365299169e8415"},
+	}
+	p.Containers = append(p.Containers, pausec)
+
+	var imageOptions []remote.Option
+	if len(*username) != 0 && len(*password) != 0 {
+		auth := authn.Basic{
+			Username: *username,
+			Password: *password,
+		}
+		c, _ := auth.Authorization()
+		authOption := remote.WithAuth(authn.FromConfig(*c))
+		imageOptions = append(imageOptions, authOption)
+	}
+
+	for _, image := range config.Images {
+		container := sp.SecurityPolicyContainer{
+			Command: image.Command,
+			Layers:  []string{},
+		}
+		ref, err := name.ParseReference(image.Name)
+		if err != nil {
+			return p, fmt.Errorf("'%s' isn't a valid image name", image.Name)
+		}
+		img, err := remote.Image(ref, imageOptions...)
+		if err != nil {
+			return p, fmt.Errorf("unable to fetch image '%s': %s", image.Name, err.Error())
+		}
+
+		layers, err := img.Layers()
+		if err != nil {
+			return p, err
+		}
+
+		for _, layer := range layers {
+			r, err := layer.Uncompressed()
+			if err != nil {
+				return p, err
+			}
+
+			out, err := ioutil.TempFile("", "")
+			if err != nil {
+				return p, err
+			}
+			defer os.Remove(out.Name())
+
+			opts := []tar2ext4.Option{
+				tar2ext4.ConvertWhiteout,
+				tar2ext4.MaximumDiskSize(128 * 1024 * 1024 * 1024),
+			}
+
+			err = tar2ext4.Convert(r, out, opts...)
+			if err != nil {
+				return p, err
+			}
+
+			data, err := ioutil.ReadFile(out.Name())
+			if err != nil {
+				return p, err
+			}
+
+			tree, err := dmverity.MerkleTree(data)
+			if err != nil {
+				return p, err
+			}
+			hash := dmverity.RootHash(tree)
+			hash_string := fmt.Sprintf("%x", hash)
+			container.Layers = append(container.Layers, hash_string)
+		}
+
+		p.Containers = append(p.Containers, container)
+	}
+
+	return p, nil
+}
diff --git a/internal/uvm/combine_layers.go b/internal/uvm/combine_layers.go
index 74c0ac70e2..5971cc49dc 100644
--- a/internal/uvm/combine_layers.go
+++ b/internal/uvm/combine_layers.go
@@ -20,7 +20,7 @@ func (uvm *UtilityVM) CombineLayersWCOW(ctx context.Context, layerPaths []hcssch
 		GuestRequest: guestrequest.GuestRequest{
 			ResourceType: guestrequest.ResourceTypeCombinedLayers,
 			RequestType:  requesttype.Add,
-			Settings: guestrequest.CombinedLayers{
+			Settings: guestrequest.WCOWCombinedLayers{
 				ContainerRootPath: containerRootPath,
 				Layers:            layerPaths,
 			},
@@ -35,7 +35,7 @@ func (uvm *UtilityVM) CombineLayersWCOW(ctx context.Context, layerPaths []hcssch
 //
 // NOTE: `layerPaths`, `scrathPath`, and `rootfsPath` are paths from within the
 // UVM.
-func (uvm *UtilityVM) CombineLayersLCOW(ctx context.Context, layerPaths []string, scratchPath, rootfsPath string) error { +func (uvm *UtilityVM) CombineLayersLCOW(ctx context.Context, containerId string, layerPaths []string, scratchPath, rootfsPath string) error { if uvm.operatingSystem != "linux" { return errNotSupported } @@ -48,7 +48,8 @@ func (uvm *UtilityVM) CombineLayersLCOW(ctx context.Context, layerPaths []string GuestRequest: guestrequest.GuestRequest{ ResourceType: guestrequest.ResourceTypeCombinedLayers, RequestType: requesttype.Add, - Settings: guestrequest.CombinedLayers{ + Settings: guestrequest.LCOWCombinedLayers{ + ContainerID: containerId, ContainerRootPath: rootfsPath, Layers: layers, ScratchPath: scratchPath, @@ -61,12 +62,25 @@ func (uvm *UtilityVM) CombineLayersLCOW(ctx context.Context, layerPaths []string // RemoveCombinedLayers removes the previously combined layers at `rootfsPath`. // // NOTE: `rootfsPath` is the path from within the UVM. -func (uvm *UtilityVM) RemoveCombinedLayers(ctx context.Context, rootfsPath string) error { +func (uvm *UtilityVM) RemoveCombinedLayersWCOW(ctx context.Context, rootfsPath string) error { msr := &hcsschema.ModifySettingRequest{ GuestRequest: guestrequest.GuestRequest{ ResourceType: guestrequest.ResourceTypeCombinedLayers, RequestType: requesttype.Remove, - Settings: guestrequest.CombinedLayers{ + Settings: guestrequest.WCOWCombinedLayers{ + ContainerRootPath: rootfsPath, + }, + }, + } + return uvm.modify(ctx, msr) +} + +func (uvm *UtilityVM) RemoveCombinedLayersLCOW(ctx context.Context, rootfsPath string) error { + msr := &hcsschema.ModifySettingRequest{ + GuestRequest: guestrequest.GuestRequest{ + ResourceType: guestrequest.ResourceTypeCombinedLayers, + RequestType: requesttype.Remove, + Settings: guestrequest.LCOWCombinedLayers{ ContainerRootPath: rootfsPath, }, }, diff --git a/internal/uvm/create_lcow.go b/internal/uvm/create_lcow.go index de18ca1c60..192d16a95a 100644 --- a/internal/uvm/create_lcow.go +++ b/internal/uvm/create_lcow.go @@ -75,6 +75,7 @@ type OptionsLCOW struct { PreferredRootFSType PreferredRootFSType // If `KernelFile` is `InitrdFile` use `PreferredRootFSTypeInitRd`. If `KernelFile` is `VhdFile` use `PreferredRootFSTypeVHD` EnableColdDiscardHint bool // Whether the HCS should use cold discard hints. Defaults to false VPCIEnabled bool // Whether the kernel should enable pci + SecurityPolicy string // Optional security policy } // defaultLCOWOSBootFilesPath returns the default path used to locate the LCOW @@ -120,6 +121,7 @@ func NewDefaultOptionsLCOW(id, owner string) *OptionsLCOW { PreferredRootFSType: PreferredRootFSTypeInitRd, EnableColdDiscardHint: false, VPCIEnabled: false, + SecurityPolicy: "", } if _, err := os.Stat(filepath.Join(opts.BootFilesPath, VhdFile)); err == nil { diff --git a/internal/uvm/security_policy.go b/internal/uvm/security_policy.go new file mode 100644 index 0000000000..e826b6d679 --- /dev/null +++ b/internal/uvm/security_policy.go @@ -0,0 +1,50 @@ +package uvm + +import ( + "context" + "errors" + "fmt" + + "github.com/Microsoft/hcsshim/internal/guestrequest" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/requesttype" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" +) + +var ( + ErrBadPolicy = errors.New("your policy looks suspicious or is badly formatted") +) + +// SetSecurityPolicy tells the gcs instance in the UVM what policy to apply. 
+//
+// This has to happen before we start mounting things or generally changing
+// the state of the UVM after it has been measured at startup.
+func (uvm *UtilityVM) SetSecurityPolicy(ctx context.Context, policy string) error {
+	if uvm.operatingSystem != "linux" {
+		return errNotSupported
+	}
+
+	uvm.m.Lock()
+	defer uvm.m.Unlock()
+
+	modification := &hcsschema.ModifySettingRequest{
+		RequestType: requesttype.Add,
+		Settings: securitypolicy.EncodedSecurityPolicy{
+			SecurityPolicy: policy,
+		},
+	}
+
+	modification.GuestRequest = guestrequest.GuestRequest{
+		ResourceType: guestrequest.ResourceTypeSecurityPolicy,
+		RequestType:  requesttype.Add,
+		Settings: securitypolicy.EncodedSecurityPolicy{
+			SecurityPolicy: policy,
+		},
+	}
+
+	if err := uvm.modify(ctx, modification); err != nil {
+		return fmt.Errorf("uvm::Policy: failed to modify utility VM configuration: %s", err)
+	}
+
+	return nil
+}
diff --git a/pkg/securitypolicy/securitypolicy.go b/pkg/securitypolicy/securitypolicy.go
new file mode 100644
index 0000000000..988b63713d
--- /dev/null
+++ b/pkg/securitypolicy/securitypolicy.go
@@ -0,0 +1,38 @@
+package securitypolicy
+
+// SecurityPolicy is the user supplied security policy to enforce.
+type SecurityPolicy struct {
+	// Flag that when set to true allows for all checks to pass. Currently
+	// used to run with security policy enforcement "running dark"; checks
+	// can be in place but the default policy that is created on startup has
+	// AllowAll set to true, thus making policy enforcement effectively "off"
+	// from a logical standpoint. Policy enforcement isn't actually off, as
+	// the policy is "allow everything".
+	AllowAll bool `json:"allow_all"`
+	// One or more containers that are allowed to run
+	Containers []SecurityPolicyContainer `json:"containers"`
+}
+
+// SecurityPolicyContainer contains information about a container that should
+// be allowed to run. "Allowed to run" is a bit of a misnomer. For example, we
+// enforce that when an overlay file system is constructed, it must be an
+// ordering of layers (as seen through dm-verity root hashes of devices) that
+// matches a listing from Layers in one of any valid SecurityPolicyContainer
+// entries. Even once that overlay creation is allowed, the command could
+// still fail to match policy, in which case running the command would be
+// rejected.
+type SecurityPolicyContainer struct {
+	// The command that we will allow the container to execute
+	Command string `json:"command"`
+	// An ordered list of dm-verity root hashes for each layer that makes up
+	// "a container". Containers are constructed as an overlay file system.
+	// The order that the layers are overlaid is important and needs to be
+	// enforced as part of policy.
+	Layers []string `json:"layers"`
+}
+
+// EncodedSecurityPolicy is a JSON representation of SecurityPolicy that has
+// been base64 encoded for storage in an annotation embedded within another
+// JSON configuration.
+type EncodedSecurityPolicy struct {
+	SecurityPolicy string `json:"SecurityPolicy,omitempty"`
+}
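To make the shape of these types concrete, here is a minimal sketch that
builds a one-container policy and prints its JSON form. This is illustrative
only; the layer hash shown is the pause-container root hash that the
securitypolicy tool above hardcodes:

```go
package main

import (
	"encoding/json"
	"fmt"

	sp "github.com/Microsoft/hcsshim/pkg/securitypolicy"
)

func main() {
	// One allowed container: a single command plus the ordered dm-verity
	// root hashes of its layers.
	p := sp.SecurityPolicy{
		AllowAll: false,
		Containers: []sp.SecurityPolicyContainer{{
			Command: "/pause",
			Layers:  []string{"16b514057a06ad665f92c02863aca074fd5976c755d26bff16365299169e8415"},
		}},
	}

	j, err := json.Marshal(p)
	if err != nil {
		panic(err)
	}
	// Prints: {"allow_all":false,"containers":[{"command":"/pause","layers":[...]}]}
	fmt.Println(string(j))
}
```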
diff --git a/pkg/securitypolicy/securitypolicy_test.go b/pkg/securitypolicy/securitypolicy_test.go
new file mode 100644
index 0000000000..1b47e1ae6d
--- /dev/null
+++ b/pkg/securitypolicy/securitypolicy_test.go
@@ -0,0 +1,459 @@
+package securitypolicy
+
+import (
+	"math/rand"
+	"reflect"
+	"strconv"
+	"strings"
+	"testing"
+	"testing/quick"
+	"time"
+)
+
+const (
+	maxContainersInGeneratedPolicy = 32
+	maxLayersInGeneratedContainer  = 32
+	maxGeneratedContainerID        = 1000000
+	maxGeneratedCommandLength      = 128
+	maxGeneratedMountTargetLength  = 256
+	rootHashLength                 = 64
+)
+
+// Do we correctly set up the data structures that are part of creating a new
+// StandardSecurityPolicyEnforcer?
+func Test_StandardSecurityPolicyEnforcer_Devices_Initialization(t *testing.T) {
+	f := func(p *SecurityPolicy) bool {
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			return false
+		}
+
+		// there should be a device entry for each container
+		if len(p.Containers) != len(policy.Devices) {
+			return false
+		}
+
+		// in each device entry that corresponds to a container,
+		// the array should have space for all the root hashes
+		for i := 0; i < len(p.Containers); i++ {
+			if len(p.Containers[i].Layers) != len(policy.Devices[i]) {
+				return false
+			}
+		}
+
+		return true
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+		t.Errorf("Test_StandardSecurityPolicyEnforcer_Devices_Initialization failed: %v", err)
+	}
+}
+
+// Verify that StandardSecurityPolicyEnforcer.EnforcePmemMountPolicy will
+// return an error when there is no matching root hash in the policy.
+func Test_EnforcePmemMountPolicy_No_Matches(t *testing.T) {
+	f := func(p *SecurityPolicy) bool {
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			return false
+		}
+
+		r := rand.New(rand.NewSource(time.Now().UnixNano()))
+		target := generateMountTarget(r)
+		rootHash := generateInvalidRootHash(r)
+
+		err = policy.EnforcePmemMountPolicy(target, rootHash)
+
+		// we expect an error, not getting one means something is broken
+		return err != nil
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+		t.Errorf("Test_EnforcePmemMountPolicy_No_Matches failed: %v", err)
+	}
+}
+
+// Verify that StandardSecurityPolicyEnforcer.EnforcePmemMountPolicy doesn't
+// return an error when there's a matching root hash in the policy.
+func Test_EnforcePmemMountPolicy_Matches(t *testing.T) {
+	f := func(p *SecurityPolicy) bool {
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			return false
+		}
+
+		r := rand.New(rand.NewSource(time.Now().UnixNano()))
+		target := generateMountTarget(r)
+		rootHash := selectRootHashFromPolicy(p, r)
+
+		err = policy.EnforcePmemMountPolicy(target, rootHash)
+
+		// getting an error means something is broken
+		return err == nil
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+		t.Errorf("Test_EnforcePmemMountPolicy_Matches failed: %v", err)
+	}
+}
+
+// Verify that StandardSecurityPolicyEnforcer.EnforceOverlayMountPolicy will
+// return an error when there are no matching overlay targets.
+func Test_EnforceOverlayMountPolicy_No_Matches(t *testing.T) {
+	f := func(p *SecurityPolicy) bool {
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			return false
+		}
+
+		r := rand.New(rand.NewSource(time.Now().UnixNano()))
+		containerID := generateContainerId(r)
+		container := selectContainerFromPolicy(p, r)
+
+		layerPaths, err := createInvalidOverlayForContainer(policy, container, r)
+		if err != nil {
+			return false
+		}
+
+		err = policy.EnforceOverlayMountPolicy(containerID, layerPaths)
+
+		// not getting an error means something is broken
+		return err != nil
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+		t.Errorf("Test_EnforceOverlayMountPolicy_No_Matches failed: %v", err)
+	}
+}
+
+// Verify that StandardSecurityPolicyEnforcer.EnforceOverlayMountPolicy doesn't
+// return an error when there's a valid overlay target.
+func Test_EnforceOverlayMountPolicy_Matches(t *testing.T) {
+	f := func(p *SecurityPolicy) bool {
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			return false
+		}
+
+		r := rand.New(rand.NewSource(time.Now().UnixNano()))
+		containerID := generateContainerId(r)
+		container := selectContainerFromPolicy(p, r)
+
+		layerPaths, err := createValidOverlayForContainer(policy, container, r)
+		if err != nil {
+			return false
+		}
+
+		err = policy.EnforceOverlayMountPolicy(containerID, layerPaths)
+
+		// getting an error means something is broken
+		return err == nil
+	}
+
+	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+		t.Errorf("Test_EnforceOverlayMountPolicy_Matches: %v", err)
+	}
+}
+
+// Tests the specific case of trying to mount the same overlay twice using the
+// same container id. This should be disallowed.
+func Test_EnforceOverlayMountPolicy_Overlay_Single_Container_Twice(t *testing.T) {
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	p := generateSecurityPolicy(r, 1)
+
+	policy, err := NewStandardSecurityPolicyEnforcer(p)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	containerID := generateContainerId(r)
+	container := selectContainerFromPolicy(p, r)
+
+	layerPaths, err := createValidOverlayForContainer(policy, container, r)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	err = policy.EnforceOverlayMountPolicy(containerID, layerPaths)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	err = policy.EnforceOverlayMountPolicy(containerID, layerPaths)
+	if err == nil {
+		t.Fatalf("able to create overlay for the same container twice")
+	}
+}
+
+// Test that if more than one instance of the same image is started, we can
+// create all the overlays that are required. For example, if there are 13
+// instances of image X that all share the same overlay of root hashes,
+// all 13 should be allowed.
+
+// Test that if more than one instance of the same image is started, we can
+// create all the overlays that are required. So for example, if there are
+// 13 instances of image X that all share the same overlay of root hashes,
+// all 13 should be allowed.
+func Test_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *testing.T) {
+	for containersToCreate := 2; containersToCreate <= maxContainersInGeneratedPolicy; containersToCreate++ {
+		r := rand.New(rand.NewSource(time.Now().UnixNano()))
+		var containers []SecurityPolicyContainer
+
+		for i := 1; i <= containersToCreate; i++ {
+			c := SecurityPolicyContainer{
+				Command: "command " + strconv.Itoa(i),
+				Layers:  []string{"1", "2"},
+			}
+
+			containers = append(containers, c)
+		}
+
+		p := &SecurityPolicy{
+			AllowAll:   false,
+			Containers: containers,
+		}
+
+		policy, err := NewStandardSecurityPolicyEnforcer(p)
+		if err != nil {
+			t.Fatal("unexpected error on test setup")
+		}
+
+		idsUsed := map[string]bool{}
+		for i := 0; i < len(p.Containers); i++ {
+			layerPaths, err := createValidOverlayForContainer(policy, p.Containers[i], r)
+			if err != nil {
+				t.Fatal("unexpected error on test setup")
+			}
+
+			idUnique := false
+			var id string
+			for !idUnique {
+				id = generateContainerId(r)
+				_, found := idsUsed[id]
+				idUnique = !found
+				idsUsed[id] = true
+			}
+			err = policy.EnforceOverlayMountPolicy(id, layerPaths)
+			if err != nil {
+				t.Fatalf("failed with %d containers", containersToCreate)
+			}
+		}
+
+		t.Logf("ok for %d", containersToCreate)
+	}
+}
+
+// Verify that we can't create more containers using an overlay than the
+// policy allows. For example, if there is a single instance of image Foo in
+// the policy, we should be able to create a single container for that overlay
+// but no more than that one.
+func Test_EnforceOverlayMountPolicy_Overlay_Single_Container_Twice_With_Different_IDs(t *testing.T) {
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	p := generateSecurityPolicy(r, 1)
+
+	policy, err := NewStandardSecurityPolicyEnforcer(p)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	var containerIDOne, containerIDTwo string
+
+	for containerIDOne == containerIDTwo {
+		containerIDOne = generateContainerId(r)
+		containerIDTwo = generateContainerId(r)
+	}
+	container := selectContainerFromPolicy(p, r)
+
+	layerPaths, err := createValidOverlayForContainer(policy, container, r)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	err = policy.EnforceOverlayMountPolicy(containerIDOne, layerPaths)
+	if err != nil {
+		t.Fatalf("expected nil error got: %v", err)
+	}
+
+	err = policy.EnforceOverlayMountPolicy(containerIDTwo, layerPaths)
+	if err == nil {
+		t.Fatalf("able to reuse an overlay across containers")
+	}
+}
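+
+// These tests rely on testing/quick for property-based input: the
+// (*SecurityPolicy).Generate method below makes SecurityPolicy satisfy
+// quick.Generator, so quick.Check can synthesize random policies. A minimal
+// standalone use of that machinery (a sketch, not one of the tests above):
+//
+//	f := func(p *SecurityPolicy) bool { return len(p.Containers) >= 1 }
+//	if err := quick.Check(f, &quick.Config{MaxCount: 1000}); err != nil {
+//		t.Error(err)
+//	}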
+
+//
+// Setup and "fixtures" follow...
+//
+
+func (*SecurityPolicy) Generate(r *rand.Rand, size int) reflect.Value {
+	p := generateSecurityPolicy(r, maxContainersInGeneratedPolicy)
+	return reflect.ValueOf(p)
+}
+
+func generateSecurityPolicy(r *rand.Rand, numContainers int32) *SecurityPolicy {
+	p := &SecurityPolicy{}
+	p.AllowAll = false
+	containers := atLeastOneAtMost(r, numContainers)
+	for i := 0; i < (int)(containers); i++ {
+		p.Containers = append(p.Containers, generateSecurityPolicyContainer(r, maxLayersInGeneratedContainer))
+	}
+
+	return p
+}
+
+func generateSecurityPolicyContainer(r *rand.Rand, size int32) SecurityPolicyContainer {
+	c := SecurityPolicyContainer{}
+	c.Command = generateCommand(r)
+	layers := int(atLeastOneAtMost(r, size))
+	for i := 0; i < layers; i++ {
+		c.Layers = append(c.Layers, generateRootHash(r))
+	}
+
+	return c
+}
+
+func generateRootHash(r *rand.Rand) string {
+	return randString(r, rootHashLength)
+}
+
+func generateCommand(r *rand.Rand) string {
+	return randVariableString(r, maxGeneratedCommandLength)
+}
+
+func generateMountTarget(r *rand.Rand) string {
+	return randVariableString(r, maxGeneratedMountTargetLength)
+}
+
+func generateInvalidRootHash(r *rand.Rand) string {
+	// Guaranteed to be an incorrect size as it maxes out at one less than
+	// the correct length. If this ever creates a hash that passes, we
+	// have a seriously weird bug
+	return randVariableString(r, rootHashLength-1)
+}
+
+func selectRootHashFromPolicy(policy *SecurityPolicy, r *rand.Rand) string {
+	numberOfContainersInPolicy := len(policy.Containers)
+	container := policy.Containers[r.Intn(numberOfContainersInPolicy)]
+	numberOfLayersInContainer := len(container.Layers)
+
+	return container.Layers[r.Intn(numberOfLayersInContainer)]
+}
+
+func generateContainerId(r *rand.Rand) string {
+	id := atLeastOneAtMost(r, maxGeneratedContainerID)
+	return strconv.FormatInt(int64(id), 10)
+}
+
+func selectContainerFromPolicy(policy *SecurityPolicy, r *rand.Rand) SecurityPolicyContainer {
+	numberOfContainersInPolicy := len(policy.Containers)
+	return policy.Containers[r.Intn(numberOfContainersInPolicy)]
+}
+
+func createValidOverlayForContainer(enforcer SecurityPolicyEnforcer, container SecurityPolicyContainer, r *rand.Rand) ([]string, error) {
+	// storage for our mount paths
+	overlay := make([]string, len(container.Layers))
+
+	for i := 0; i < len(container.Layers); i++ {
+		mount := generateMountTarget(r)
+		err := enforcer.EnforcePmemMountPolicy(mount, container.Layers[i])
+		if err != nil {
+			return overlay, err
+		}
+
+		overlay[len(overlay)-i-1] = mount
+	}
+
+	return overlay, nil
+}
+
+func createInvalidOverlayForContainer(enforcer SecurityPolicyEnforcer, container SecurityPolicyContainer, r *rand.Rand) ([]string, error) {
+	method := r.Intn(3)
+	if method == 0 {
+		return invalidOverlaySameSizeWrongMounts(enforcer, container, r)
+	} else if method == 1 {
+		return invalidOverlayCorrectDevicesWrongOrderSomeMissing(enforcer, container, r)
+	} else {
+		return invalidOverlayRandomJunk(enforcer, container, r)
+	}
+}
+
+func invalidOverlaySameSizeWrongMounts(enforcer SecurityPolicyEnforcer, container SecurityPolicyContainer, r *rand.Rand) ([]string, error) {
+	// storage for our mount paths
+	overlay := make([]string, len(container.Layers))
+
+	for i := 0; i < len(container.Layers); i++ {
+		mount := generateMountTarget(r)
+		err := enforcer.EnforcePmemMountPolicy(mount, container.Layers[i])
+		if err != nil {
+			return overlay, err
+		}
+
+		// generate a random new mount point to cause an error
+		overlay[len(overlay)-i-1] = generateMountTarget(r)
+	}
+
+	return overlay, nil
+}
+
+func invalidOverlayCorrectDevicesWrongOrderSomeMissing(enforcer SecurityPolicyEnforcer, container SecurityPolicyContainer, r *rand.Rand) ([]string, error) {
+	if len(container.Layers) == 1 {
+		// won't work with only 1, we need to bail out to another method
+		return invalidOverlayRandomJunk(enforcer, container, r)
+	}
+	// storage for our mount paths
+	var overlay []string
+
+	for i := 0; i < len(container.Layers); i++ {
+		mount := generateMountTarget(r)
+		err := enforcer.EnforcePmemMountPolicy(mount, container.Layers[i])
+		if err != nil {
+			return overlay, err
+		}
+
+		if r.Intn(10) != 0 {
+			overlay = append(overlay, mount)
+		}
+	}
+
+	return overlay, nil
+}
+
+func invalidOverlayRandomJunk(enforcer SecurityPolicyEnforcer, container SecurityPolicyContainer, r *rand.Rand) ([]string, error) {
+	// create "junk" for entry
+	layersToCreate := r.Int31n(maxLayersInGeneratedContainer)
+	overlay := make([]string, layersToCreate)
+
+	for i := 0; i < int(layersToCreate); i++ {
+		overlay[i] = generateMountTarget(r)
+	}
+
+	// setup entirely different and "correct" expected mounting
+	for i := 0; i < len(container.Layers); i++ {
+		mount := generateMountTarget(r)
+		err := enforcer.EnforcePmemMountPolicy(mount, container.Layers[i])
+		if err != nil {
+			return overlay, err
+		}
+	}
+
+	return overlay, nil
+}
+
+func randVariableString(r *rand.Rand, maxLen int32) string {
+	return randString(r, atLeastOneAtMost(r, maxLen))
+}
+
+func randString(r *rand.Rand, length int32) string {
+	var s strings.Builder
+	for i := 0; i < (int)(length); i++ {
+		s.WriteRune((rune)(0x00ff & r.Int31n(256)))
+	}
+
+	return s.String()
+}
+
+func randMinMax(r *rand.Rand, min int32, max int32) int32 {
+	return r.Int31n(max-min+1) + min
+}
+
+func atLeastOneAtMost(r *rand.Rand, most int32) int32 {
+	return randMinMax(r, 1, most)
+}
diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go
new file mode 100644
index 0000000000..53484afd29
--- /dev/null
+++ b/pkg/securitypolicy/securitypolicyenforcer.go
@@ -0,0 +1,204 @@
+package securitypolicy
+
+import (
+	"errors"
+	"fmt"
+	"sync"
+)
+
+type SecurityPolicyEnforcer interface {
+	EnforcePmemMountPolicy(target string, deviceHash string) (err error)
+	EnforceOverlayMountPolicy(containerID string, layerPaths []string) (err error)
+}
+
+func NewSecurityPolicyEnforcer(policy *SecurityPolicy) (SecurityPolicyEnforcer, error) {
+	if policy == nil {
+		return nil, errors.New("security policy can't be nil")
+	}
+
+	if policy.AllowAll {
+		return &OpenDoorSecurityPolicyEnforcer{}, nil
+	} else {
+		return NewStandardSecurityPolicyEnforcer(policy)
+	}
+}
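+
+// Sketch of the dispatch above, assuming the caller has already deserialized
+// a policy: an AllowAll policy yields the permissive enforcer, anything else
+// the standard one.
+//
+//	enforcer, err := NewSecurityPolicyEnforcer(&SecurityPolicy{AllowAll: true})
+//	// err == nil here, and enforcer is an *OpenDoorSecurityPolicyEnforcer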
+
+type StandardSecurityPolicyEnforcer struct {
+	// The user supplied security policy.
+	SecurityPolicy SecurityPolicy
+	// Devices and ContainerIndexToContainerIds are used to build up an
+	// understanding of the containers running within a UVM as they come up
+	// and map them back to a container definition from the user supplied
+	// SecurityPolicy
+	//
+	// Devices is a listing of dm-verity root hashes seen when mounting a
+	// device, stored on a per-container basis. As the UVM goes through its
+	// process of bringing up containers, we have to piece together
+	// information about what is going on.
+	//
+	// At the time that devices are being mounted, we do not know which
+	// container they will be used for; only that there is a device with a
+	// given root hash that is being mounted. We check to make sure that the
+	// root hash for the device is a root hash that exists for 1 or more
+	// layers in any container in the supplied SecurityPolicy. Each "seen"
+	// layer is recorded in Devices as it is mounted. So, for example, if the
+	// root hash of the device being mounted matches the first layer of the
+	// first container, we record the device's mount target in Devices[0][0].
+	//
+	// Later, when overlay filesystems are created, we verify that the ordered
+	// layers for said overlay filesystem match one of the device orderings in
+	// Devices. When a match is found, the index in Devices is the same index
+	// in SecurityPolicy.Containers. Overlay filesystem creation is the first
+	// time we have a "container id" available to us. The container id
+	// identifies the container in question going forward. We record the
+	// mapping of container index to container id so that when we have future
+	// operations like "run command", which come with a container id, we can
+	// find the corresponding container index and use that to look up the
+	// command in the appropriate SecurityPolicyContainer instance.
+	//
+	// As containers can have exactly the same base image and be "the same" at
+	// the time we are doing the overlay, ContainerIndexToContainerIds maps a
+	// container index to the list of container ids that could correspond to
+	// it.
+	//
+	// Implementation details are available in:
+	// - EnforcePmemMountPolicy
+	// - EnforceOverlayMountPolicy
+	// - NewStandardSecurityPolicyEnforcer
+	Devices                      [][]string
+	ContainerIndexToContainerIds map[int][]string
+	// Mutex to prevent concurrent access to fields
+	mutex *sync.Mutex
+}
+
+var _ SecurityPolicyEnforcer = (*StandardSecurityPolicyEnforcer)(nil)
+
+func NewStandardSecurityPolicyEnforcer(policy *SecurityPolicy) (*StandardSecurityPolicyEnforcer, error) {
+	if policy == nil {
+		return nil, errors.New("security policy can't be nil")
+	}
+
+	// Create a new StandardSecurityPolicyEnforcer and add the new
+	// SecurityPolicy to it. Fill out the corresponding Devices structure by
+	// creating a "same shaped" listing that matches our container root hash
+	// lists; the Devices list will get filled out as layers are mounted.
+	devices := make([][]string, len(policy.Containers))
+
+	for i, container := range policy.Containers {
+		devices[i] = make([]string, len(container.Layers))
+	}
+
+	return &StandardSecurityPolicyEnforcer{
+		SecurityPolicy:               *policy,
+		Devices:                      devices,
+		ContainerIndexToContainerIds: map[int][]string{},
+		mutex:                        &sync.Mutex{},
+	}, nil
+}
+
+func (policyState *StandardSecurityPolicyEnforcer) EnforcePmemMountPolicy(target string, deviceHash string) (err error) {
+	policyState.mutex.Lock()
+	defer policyState.mutex.Unlock()
+
+	if len(policyState.SecurityPolicy.Containers) < 1 {
+		return errors.New("policy doesn't allow mounting containers")
+	}
+
+	if deviceHash == "" {
+		return errors.New("device is missing verity root hash")
+	}
+
+	found := false
+
+	for i, container := range policyState.SecurityPolicy.Containers {
+		for ii, layer := range container.Layers {
+			if deviceHash == layer {
+				policyState.Devices[i][ii] = target
+				found = true
+			}
+		}
+	}
+
+	if !found {
+		return fmt.Errorf("roothash %s for mount %s doesn't match policy", deviceHash, target)
+	}
+
+	return nil
+}
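+
+// Worked example of the bookkeeping described above (values illustrative):
+// for a policy whose only container has Layers ["hashA", "hashB"], successful
+// calls
+//
+//	EnforcePmemMountPolicy("/run/layers/p0", "hashA")
+//	EnforcePmemMountPolicy("/run/layers/p1", "hashB")
+//
+// leave Devices[0] == []string{"/run/layers/p0", "/run/layers/p1"}, which
+// EnforceOverlayMountPolicy below then compares against layerPaths in reverse
+// (top-to-bottom) order.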
+
+func (policyState *StandardSecurityPolicyEnforcer) EnforceOverlayMountPolicy(containerID string, layerPaths []string) (err error) {
+	policyState.mutex.Lock()
+	defer policyState.mutex.Unlock()
+
+	if len(policyState.SecurityPolicy.Containers) < 1 {
+		return errors.New("policy doesn't allow mounting containers")
+	}
+
+	// find the maximum number of containers that could share this overlay
+	maxPossibleContainerIdsForOverlay := 0
+	for _, deviceList := range policyState.Devices {
+		if equalForOverlay(layerPaths, deviceList) {
+			maxPossibleContainerIdsForOverlay++
+		}
+	}
+
+	if maxPossibleContainerIdsForOverlay == 0 {
+		errmsg := fmt.Sprintf("layerPaths '%v' doesn't match any valid layer path: '%v'", layerPaths, policyState.Devices)
+		return errors.New(errmsg)
+	}
+
+	for i, deviceList := range policyState.Devices {
+		if equalForOverlay(layerPaths, deviceList) {
+			existing := policyState.ContainerIndexToContainerIds[i]
+			if len(existing) < maxPossibleContainerIdsForOverlay {
+				policyState.ContainerIndexToContainerIds[i] = append(existing, containerID)
+			} else {
+				errmsg := fmt.Sprintf("layerPaths '%v' already used in maximum number of container overlays", layerPaths)
+				return errors.New(errmsg)
+			}
+		}
+	}
+
+	return nil
+}
+
+func equalForOverlay(a1 []string, a2 []string) bool {
+	// We've stored the layers from bottom to top; they are in layerPaths as
+	// top to bottom (the order a string gets concatenated for the unix mount
+	// command). We do our check with that in mind.
+	if len(a1) == len(a2) {
+		topIndex := len(a2) - 1
+		for i, v := range a1 {
+			if v != a2[topIndex-i] {
+				return false
+			}
+		}
+	} else {
+		return false
+	}
+	return true
+}
+
+type OpenDoorSecurityPolicyEnforcer struct{}
+
+var _ SecurityPolicyEnforcer = (*OpenDoorSecurityPolicyEnforcer)(nil)
+
+func (p *OpenDoorSecurityPolicyEnforcer) EnforcePmemMountPolicy(target string, deviceHash string) (err error) {
+	return nil
+}
+
+func (p *OpenDoorSecurityPolicyEnforcer) EnforceOverlayMountPolicy(containerID string, layerPaths []string) (err error) {
+	return nil
+}
+
+type ClosedDoorSecurityPolicyEnforcer struct{}
+
+var _ SecurityPolicyEnforcer = (*ClosedDoorSecurityPolicyEnforcer)(nil)
+
+func (p *ClosedDoorSecurityPolicyEnforcer) EnforcePmemMountPolicy(target string, deviceHash string) (err error) {
+	return errors.New("mounting is denied by policy")
+}
+
+func (p *ClosedDoorSecurityPolicyEnforcer) EnforceOverlayMountPolicy(containerID string, layerPaths []string) (err error) {
+	return errors.New("creating an overlay fs is denied by policy")
+}
diff --git a/vendor/github.com/BurntSushi/toml/.gitignore b/vendor/github.com/BurntSushi/toml/.gitignore
new file mode 100644
index 0000000000..0cd3800377
--- /dev/null
+++ b/vendor/github.com/BurntSushi/toml/.gitignore
@@ -0,0 +1,5 @@
+TAGS
+tags
+.*.swp
+tomlcheck/tomlcheck
+toml.test
diff --git a/vendor/github.com/BurntSushi/toml/.travis.yml b/vendor/github.com/BurntSushi/toml/.travis.yml
new file mode 100644
index 0000000000..8b8afc4f0e
--- /dev/null
+++ b/vendor/github.com/BurntSushi/toml/.travis.yml
@@ -0,0 +1,15 @@
+language: go
+go:
+  - 1.1
+  - 1.2
+  - 1.3
+  - 1.4
+  - 1.5
+  - 1.6
+  - tip
+install:
+  - go install ./...
+ - go get github.com/BurntSushi/toml-test +script: + - export PATH="$PATH:$HOME/gopath/bin" + - make test diff --git a/vendor/github.com/BurntSushi/toml/COMPATIBLE b/vendor/github.com/BurntSushi/toml/COMPATIBLE new file mode 100644 index 0000000000..6efcfd0ce5 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/COMPATIBLE @@ -0,0 +1,3 @@ +Compatible with TOML version +[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md) + diff --git a/vendor/github.com/BurntSushi/toml/COPYING b/vendor/github.com/BurntSushi/toml/COPYING new file mode 100644 index 0000000000..01b5743200 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/COPYING @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013 TOML authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/BurntSushi/toml/Makefile b/vendor/github.com/BurntSushi/toml/Makefile new file mode 100644 index 0000000000..3600848d33 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/Makefile @@ -0,0 +1,19 @@ +install: + go install ./... + +test: install + go test -v + toml-test toml-test-decoder + toml-test -encoder toml-test-encoder + +fmt: + gofmt -w *.go */*.go + colcheck *.go */*.go + +tags: + find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS + +push: + git push origin master + git push github master + diff --git a/vendor/github.com/BurntSushi/toml/README.md b/vendor/github.com/BurntSushi/toml/README.md new file mode 100644 index 0000000000..7c1b37ecc7 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/README.md @@ -0,0 +1,218 @@ +## TOML parser and encoder for Go with reflection + +TOML stands for Tom's Obvious, Minimal Language. This Go package provides a +reflection interface similar to Go's standard library `json` and `xml` +packages. This package also supports the `encoding.TextUnmarshaler` and +`encoding.TextMarshaler` interfaces so that you can define custom data +representations. (There is an example of this below.) 
+ +Spec: https://github.com/toml-lang/toml + +Compatible with TOML version +[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md) + +Documentation: https://godoc.org/github.com/BurntSushi/toml + +Installation: + +```bash +go get github.com/BurntSushi/toml +``` + +Try the toml validator: + +```bash +go get github.com/BurntSushi/toml/cmd/tomlv +tomlv some-toml-file.toml +``` + +[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml) + +### Testing + +This package passes all tests in +[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder +and the encoder. + +### Examples + +This package works similarly to how the Go standard library handles `XML` +and `JSON`. Namely, data is loaded into Go values via reflection. + +For the simplest example, consider some TOML file as just a list of keys +and values: + +```toml +Age = 25 +Cats = [ "Cauchy", "Plato" ] +Pi = 3.14 +Perfection = [ 6, 28, 496, 8128 ] +DOB = 1987-07-05T05:45:00Z +``` + +Which could be defined in Go as: + +```go +type Config struct { + Age int + Cats []string + Pi float64 + Perfection []int + DOB time.Time // requires `import time` +} +``` + +And then decoded with: + +```go +var conf Config +if _, err := toml.Decode(tomlData, &conf); err != nil { + // handle error +} +``` + +You can also use struct tags if your struct field name doesn't map to a TOML +key value directly: + +```toml +some_key_NAME = "wat" +``` + +```go +type TOML struct { + ObscureKey string `toml:"some_key_NAME"` +} +``` + +### Using the `encoding.TextUnmarshaler` interface + +Here's an example that automatically parses duration strings into +`time.Duration` values: + +```toml +[[song]] +name = "Thunder Road" +duration = "4m49s" + +[[song]] +name = "Stairway to Heaven" +duration = "8m03s" +``` + +Which can be decoded with: + +```go +type song struct { + Name string + Duration duration +} +type songs struct { + Song []song +} +var favorites songs +if _, err := toml.Decode(blob, &favorites); err != nil { + log.Fatal(err) +} + +for _, s := range favorites.Song { + fmt.Printf("%s (%s)\n", s.Name, s.Duration) +} +``` + +And you'll also need a `duration` type that satisfies the +`encoding.TextUnmarshaler` interface: + +```go +type duration struct { + time.Duration +} + +func (d *duration) UnmarshalText(text []byte) error { + var err error + d.Duration, err = time.ParseDuration(string(text)) + return err +} +``` + +### More complex usage + +Here's an example of how to load the example from the official spec page: + +```toml +# This is a TOML document. Boom. + +title = "TOML Example" + +[owner] +name = "Tom Preston-Werner" +organization = "GitHub" +bio = "GitHub Cofounder & CEO\nLikes tater tots and beer." +dob = 1979-05-27T07:32:00Z # First class dates? Why not? + +[database] +server = "192.168.1.1" +ports = [ 8001, 8001, 8002 ] +connection_max = 5000 +enabled = true + +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. 
+ [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" + +[clients] +data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it + +# Line breaks are OK when inside arrays +hosts = [ + "alpha", + "omega" +] +``` + +And the corresponding Go types are: + +```go +type tomlConfig struct { + Title string + Owner ownerInfo + DB database `toml:"database"` + Servers map[string]server + Clients clients +} + +type ownerInfo struct { + Name string + Org string `toml:"organization"` + Bio string + DOB time.Time +} + +type database struct { + Server string + Ports []int + ConnMax int `toml:"connection_max"` + Enabled bool +} + +type server struct { + IP string + DC string +} + +type clients struct { + Data [][]interface{} + Hosts []string +} +``` + +Note that a case insensitive match will be tried if an exact match can't be +found. + +A working example of the above can be found in `_examples/example.{go,toml}`. diff --git a/vendor/github.com/BurntSushi/toml/decode.go b/vendor/github.com/BurntSushi/toml/decode.go new file mode 100644 index 0000000000..b0fd51d5b6 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/decode.go @@ -0,0 +1,509 @@ +package toml + +import ( + "fmt" + "io" + "io/ioutil" + "math" + "reflect" + "strings" + "time" +) + +func e(format string, args ...interface{}) error { + return fmt.Errorf("toml: "+format, args...) +} + +// Unmarshaler is the interface implemented by objects that can unmarshal a +// TOML description of themselves. +type Unmarshaler interface { + UnmarshalTOML(interface{}) error +} + +// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`. +func Unmarshal(p []byte, v interface{}) error { + _, err := Decode(string(p), v) + return err +} + +// Primitive is a TOML value that hasn't been decoded into a Go value. +// When using the various `Decode*` functions, the type `Primitive` may +// be given to any value, and its decoding will be delayed. +// +// A `Primitive` value can be decoded using the `PrimitiveDecode` function. +// +// The underlying representation of a `Primitive` value is subject to change. +// Do not rely on it. +// +// N.B. Primitive values are still parsed, so using them will only avoid +// the overhead of reflection. They can be useful when you don't know the +// exact type of TOML data until run time. +type Primitive struct { + undecoded interface{} + context Key +} + +// DEPRECATED! +// +// Use MetaData.PrimitiveDecode instead. +func PrimitiveDecode(primValue Primitive, v interface{}) error { + md := MetaData{decoded: make(map[string]bool)} + return md.unify(primValue.undecoded, rvalue(v)) +} + +// PrimitiveDecode is just like the other `Decode*` functions, except it +// decodes a TOML value that has already been parsed. Valid primitive values +// can *only* be obtained from values filled by the decoder functions, +// including this method. (i.e., `v` may contain more `Primitive` +// values.) +// +// Meta data for primitive values is included in the meta data returned by +// the `Decode*` functions with one exception: keys returned by the Undecoded +// method will only reflect keys that were decoded. Namely, any keys hidden +// behind a Primitive will be considered undecoded. Executing this method will +// update the undecoded keys in the meta data. (See the example.) 
+func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error { + md.context = primValue.context + defer func() { md.context = nil }() + return md.unify(primValue.undecoded, rvalue(v)) +} + +// Decode will decode the contents of `data` in TOML format into a pointer +// `v`. +// +// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be +// used interchangeably.) +// +// TOML arrays of tables correspond to either a slice of structs or a slice +// of maps. +// +// TOML datetimes correspond to Go `time.Time` values. +// +// All other TOML types (float, string, int, bool and array) correspond +// to the obvious Go types. +// +// An exception to the above rules is if a type implements the +// encoding.TextUnmarshaler interface. In this case, any primitive TOML value +// (floats, strings, integers, booleans and datetimes) will be converted to +// a byte string and given to the value's UnmarshalText method. See the +// Unmarshaler example for a demonstration with time duration strings. +// +// Key mapping +// +// TOML keys can map to either keys in a Go map or field names in a Go +// struct. The special `toml` struct tag may be used to map TOML keys to +// struct fields that don't match the key name exactly. (See the example.) +// A case insensitive match to struct names will be tried if an exact match +// can't be found. +// +// The mapping between TOML values and Go values is loose. That is, there +// may exist TOML values that cannot be placed into your representation, and +// there may be parts of your representation that do not correspond to +// TOML values. This loose mapping can be made stricter by using the IsDefined +// and/or Undecoded methods on the MetaData returned. +// +// This decoder will not handle cyclic types. If a cyclic type is passed, +// `Decode` will not terminate. +func Decode(data string, v interface{}) (MetaData, error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v)) + } + if rv.IsNil() { + return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v)) + } + p, err := parse(data) + if err != nil { + return MetaData{}, err + } + md := MetaData{ + p.mapping, p.types, p.ordered, + make(map[string]bool, len(p.ordered)), nil, + } + return md, md.unify(p.mapping, indirect(rv)) +} + +// DecodeFile is just like Decode, except it will automatically read the +// contents of the file at `fpath` and decode it for you. +func DecodeFile(fpath string, v interface{}) (MetaData, error) { + bs, err := ioutil.ReadFile(fpath) + if err != nil { + return MetaData{}, err + } + return Decode(string(bs), v) +} + +// DecodeReader is just like Decode, except it will consume all bytes +// from the reader and decode it for you. +func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { + bs, err := ioutil.ReadAll(r) + if err != nil { + return MetaData{}, err + } + return Decode(string(bs), v) +} + +// unify performs a sort of type unification based on the structure of `rv`, +// which is the client representation. +// +// Any type mismatch produces an error. Finding a type that we don't know +// how to handle produces an unsupported type error. +func (md *MetaData) unify(data interface{}, rv reflect.Value) error { + + // Special case. Look for a `Primitive` value. + if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() { + // Save the undecoded data and the key context into the primitive + // value. 
+ context := make(Key, len(md.context)) + copy(context, md.context) + rv.Set(reflect.ValueOf(Primitive{ + undecoded: data, + context: context, + })) + return nil + } + + // Special case. Unmarshaler Interface support. + if rv.CanAddr() { + if v, ok := rv.Addr().Interface().(Unmarshaler); ok { + return v.UnmarshalTOML(data) + } + } + + // Special case. Handle time.Time values specifically. + // TODO: Remove this code when we decide to drop support for Go 1.1. + // This isn't necessary in Go 1.2 because time.Time satisfies the encoding + // interfaces. + if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) { + return md.unifyDatetime(data, rv) + } + + // Special case. Look for a value satisfying the TextUnmarshaler interface. + if v, ok := rv.Interface().(TextUnmarshaler); ok { + return md.unifyText(data, v) + } + // BUG(burntsushi) + // The behavior here is incorrect whenever a Go type satisfies the + // encoding.TextUnmarshaler interface but also corresponds to a TOML + // hash or array. In particular, the unmarshaler should only be applied + // to primitive TOML values. But at this point, it will be applied to + // all kinds of values and produce an incorrect error whenever those values + // are hashes or arrays (including arrays of tables). + + k := rv.Kind() + + // laziness + if k >= reflect.Int && k <= reflect.Uint64 { + return md.unifyInt(data, rv) + } + switch k { + case reflect.Ptr: + elem := reflect.New(rv.Type().Elem()) + err := md.unify(data, reflect.Indirect(elem)) + if err != nil { + return err + } + rv.Set(elem) + return nil + case reflect.Struct: + return md.unifyStruct(data, rv) + case reflect.Map: + return md.unifyMap(data, rv) + case reflect.Array: + return md.unifyArray(data, rv) + case reflect.Slice: + return md.unifySlice(data, rv) + case reflect.String: + return md.unifyString(data, rv) + case reflect.Bool: + return md.unifyBool(data, rv) + case reflect.Interface: + // we only support empty interfaces. + if rv.NumMethod() > 0 { + return e("unsupported type %s", rv.Type()) + } + return md.unifyAnything(data, rv) + case reflect.Float32: + fallthrough + case reflect.Float64: + return md.unifyFloat64(data, rv) + } + return e("unsupported type %s", rv.Kind()) +} + +func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { + tmap, ok := mapping.(map[string]interface{}) + if !ok { + if mapping == nil { + return nil + } + return e("type mismatch for %s: expected table but found %T", + rv.Type().String(), mapping) + } + + for key, datum := range tmap { + var f *field + fields := cachedTypeFields(rv.Type()) + for i := range fields { + ff := &fields[i] + if ff.name == key { + f = ff + break + } + if f == nil && strings.EqualFold(ff.name, key) { + f = ff + } + } + if f != nil { + subv := rv + for _, i := range f.index { + subv = indirect(subv.Field(i)) + } + if isUnifiable(subv) { + md.decoded[md.context.add(key).String()] = true + md.context = append(md.context, key) + if err := md.unify(datum, subv); err != nil { + return err + } + md.context = md.context[0 : len(md.context)-1] + } else if f.name != "" { + // Bad user! No soup for you! 
+ return e("cannot write unexported field %s.%s", + rv.Type().String(), f.name) + } + } + } + return nil +} + +func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { + tmap, ok := mapping.(map[string]interface{}) + if !ok { + if tmap == nil { + return nil + } + return badtype("map", mapping) + } + if rv.IsNil() { + rv.Set(reflect.MakeMap(rv.Type())) + } + for k, v := range tmap { + md.decoded[md.context.add(k).String()] = true + md.context = append(md.context, k) + + rvkey := indirect(reflect.New(rv.Type().Key())) + rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) + if err := md.unify(v, rvval); err != nil { + return err + } + md.context = md.context[0 : len(md.context)-1] + + rvkey.SetString(k) + rv.SetMapIndex(rvkey, rvval) + } + return nil +} + +func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { + datav := reflect.ValueOf(data) + if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } + return badtype("slice", data) + } + sliceLen := datav.Len() + if sliceLen != rv.Len() { + return e("expected array length %d; got TOML array of length %d", + rv.Len(), sliceLen) + } + return md.unifySliceArray(datav, rv) +} + +func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { + datav := reflect.ValueOf(data) + if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } + return badtype("slice", data) + } + n := datav.Len() + if rv.IsNil() || rv.Cap() < n { + rv.Set(reflect.MakeSlice(rv.Type(), n, n)) + } + rv.SetLen(n) + return md.unifySliceArray(datav, rv) +} + +func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { + sliceLen := data.Len() + for i := 0; i < sliceLen; i++ { + v := data.Index(i).Interface() + sliceval := indirect(rv.Index(i)) + if err := md.unify(v, sliceval); err != nil { + return err + } + } + return nil +} + +func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error { + if _, ok := data.(time.Time); ok { + rv.Set(reflect.ValueOf(data)) + return nil + } + return badtype("time.Time", data) +} + +func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error { + if s, ok := data.(string); ok { + rv.SetString(s) + return nil + } + return badtype("string", data) +} + +func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error { + if num, ok := data.(float64); ok { + switch rv.Kind() { + case reflect.Float32: + fallthrough + case reflect.Float64: + rv.SetFloat(num) + default: + panic("bug") + } + return nil + } + return badtype("float", data) +} + +func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { + if num, ok := data.(int64); ok { + if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 { + switch rv.Kind() { + case reflect.Int, reflect.Int64: + // No bounds checking necessary. + case reflect.Int8: + if num < math.MinInt8 || num > math.MaxInt8 { + return e("value %d is out of range for int8", num) + } + case reflect.Int16: + if num < math.MinInt16 || num > math.MaxInt16 { + return e("value %d is out of range for int16", num) + } + case reflect.Int32: + if num < math.MinInt32 || num > math.MaxInt32 { + return e("value %d is out of range for int32", num) + } + } + rv.SetInt(num) + } else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 { + unum := uint64(num) + switch rv.Kind() { + case reflect.Uint, reflect.Uint64: + // No bounds checking necessary. 
+ case reflect.Uint8: + if num < 0 || unum > math.MaxUint8 { + return e("value %d is out of range for uint8", num) + } + case reflect.Uint16: + if num < 0 || unum > math.MaxUint16 { + return e("value %d is out of range for uint16", num) + } + case reflect.Uint32: + if num < 0 || unum > math.MaxUint32 { + return e("value %d is out of range for uint32", num) + } + } + rv.SetUint(unum) + } else { + panic("unreachable") + } + return nil + } + return badtype("integer", data) +} + +func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error { + if b, ok := data.(bool); ok { + rv.SetBool(b) + return nil + } + return badtype("boolean", data) +} + +func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error { + rv.Set(reflect.ValueOf(data)) + return nil +} + +func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error { + var s string + switch sdata := data.(type) { + case TextMarshaler: + text, err := sdata.MarshalText() + if err != nil { + return err + } + s = string(text) + case fmt.Stringer: + s = sdata.String() + case string: + s = sdata + case bool: + s = fmt.Sprintf("%v", sdata) + case int64: + s = fmt.Sprintf("%d", sdata) + case float64: + s = fmt.Sprintf("%f", sdata) + default: + return badtype("primitive (string-like)", data) + } + if err := v.UnmarshalText([]byte(s)); err != nil { + return err + } + return nil +} + +// rvalue returns a reflect.Value of `v`. All pointers are resolved. +func rvalue(v interface{}) reflect.Value { + return indirect(reflect.ValueOf(v)) +} + +// indirect returns the value pointed to by a pointer. +// Pointers are followed until the value is not a pointer. +// New values are allocated for each nil pointer. +// +// An exception to this rule is if the value satisfies an interface of +// interest to us (like encoding.TextUnmarshaler). +func indirect(v reflect.Value) reflect.Value { + if v.Kind() != reflect.Ptr { + if v.CanSet() { + pv := v.Addr() + if _, ok := pv.Interface().(TextUnmarshaler); ok { + return pv + } + } + return v + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + return indirect(reflect.Indirect(v)) +} + +func isUnifiable(rv reflect.Value) bool { + if rv.CanSet() { + return true + } + if _, ok := rv.Interface().(TextUnmarshaler); ok { + return true + } + return false +} + +func badtype(expected string, data interface{}) error { + return e("cannot load TOML value of type %T into a Go %s", data, expected) +} diff --git a/vendor/github.com/BurntSushi/toml/decode_meta.go b/vendor/github.com/BurntSushi/toml/decode_meta.go new file mode 100644 index 0000000000..b9914a6798 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/decode_meta.go @@ -0,0 +1,121 @@ +package toml + +import "strings" + +// MetaData allows access to meta information about TOML data that may not +// be inferrable via reflection. In particular, whether a key has been defined +// and the TOML type of a key. +type MetaData struct { + mapping map[string]interface{} + types map[string]tomlType + keys []Key + decoded map[string]bool + context Key // Used only during decoding. +} + +// IsDefined returns true if the key given exists in the TOML data. The key +// should be specified hierarchially. e.g., +// +// // access the TOML key 'a.b.c' +// IsDefined("a", "b", "c") +// +// IsDefined will return false if an empty key given. Keys are case sensitive. 
+func (md *MetaData) IsDefined(key ...string) bool { + if len(key) == 0 { + return false + } + + var hash map[string]interface{} + var ok bool + var hashOrVal interface{} = md.mapping + for _, k := range key { + if hash, ok = hashOrVal.(map[string]interface{}); !ok { + return false + } + if hashOrVal, ok = hash[k]; !ok { + return false + } + } + return true +} + +// Type returns a string representation of the type of the key specified. +// +// Type will return the empty string if given an empty key or a key that +// does not exist. Keys are case sensitive. +func (md *MetaData) Type(key ...string) string { + fullkey := strings.Join(key, ".") + if typ, ok := md.types[fullkey]; ok { + return typ.typeString() + } + return "" +} + +// Key is the type of any TOML key, including key groups. Use (MetaData).Keys +// to get values of this type. +type Key []string + +func (k Key) String() string { + return strings.Join(k, ".") +} + +func (k Key) maybeQuotedAll() string { + var ss []string + for i := range k { + ss = append(ss, k.maybeQuoted(i)) + } + return strings.Join(ss, ".") +} + +func (k Key) maybeQuoted(i int) string { + quote := false + for _, c := range k[i] { + if !isBareKeyChar(c) { + quote = true + break + } + } + if quote { + return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\"" + } + return k[i] +} + +func (k Key) add(piece string) Key { + newKey := make(Key, len(k)+1) + copy(newKey, k) + newKey[len(k)] = piece + return newKey +} + +// Keys returns a slice of every key in the TOML data, including key groups. +// Each key is itself a slice, where the first element is the top of the +// hierarchy and the last is the most specific. +// +// The list will have the same order as the keys appeared in the TOML data. +// +// All keys returned are non-empty. +func (md *MetaData) Keys() []Key { + return md.keys +} + +// Undecoded returns all keys that have not been decoded in the order in which +// they appear in the original TOML document. +// +// This includes keys that haven't been decoded because of a Primitive value. +// Once the Primitive value is decoded, the keys will be considered decoded. +// +// Also note that decoding into an empty interface will result in no decoding, +// and so no keys will be considered decoded. +// +// In this sense, the Undecoded keys correspond to keys in the TOML document +// that do not have a concrete type in your representation. +func (md *MetaData) Undecoded() []Key { + undecoded := make([]Key, 0, len(md.keys)) + for _, key := range md.keys { + if !md.decoded[key.String()] { + undecoded = append(undecoded, key) + } + } + return undecoded +} diff --git a/vendor/github.com/BurntSushi/toml/doc.go b/vendor/github.com/BurntSushi/toml/doc.go new file mode 100644 index 0000000000..b371f396ed --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/doc.go @@ -0,0 +1,27 @@ +/* +Package toml provides facilities for decoding and encoding TOML configuration +files via reflection. There is also support for delaying decoding with +the Primitive type, and querying the set of keys in a TOML document with the +MetaData type. + +The specification implemented: https://github.com/toml-lang/toml + +The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify +whether a file is a valid TOML document. It can also be used to print the +type of each key in a TOML document. + +Testing + +There are two important types of tests used for this package. 
The first is +contained inside '*_test.go' files and uses the standard Go unit testing +framework. These tests are primarily devoted to holistically testing the +decoder and encoder. + +The second type of testing is used to verify the implementation's adherence +to the TOML specification. These tests have been factored into their own +project: https://github.com/BurntSushi/toml-test + +The reason the tests are in a separate project is so that they can be used by +any implementation of TOML. Namely, it is language agnostic. +*/ +package toml diff --git a/vendor/github.com/BurntSushi/toml/encode.go b/vendor/github.com/BurntSushi/toml/encode.go new file mode 100644 index 0000000000..d905c21a24 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/encode.go @@ -0,0 +1,568 @@ +package toml + +import ( + "bufio" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "time" +) + +type tomlEncodeError struct{ error } + +var ( + errArrayMixedElementTypes = errors.New( + "toml: cannot encode array with mixed element types") + errArrayNilElement = errors.New( + "toml: cannot encode array with nil element") + errNonString = errors.New( + "toml: cannot encode a map with non-string key type") + errAnonNonStruct = errors.New( + "toml: cannot encode an anonymous field that is not a struct") + errArrayNoTable = errors.New( + "toml: TOML array element cannot contain a table") + errNoKey = errors.New( + "toml: top-level values must be Go maps or structs") + errAnything = errors.New("") // used in testing +) + +var quotedReplacer = strings.NewReplacer( + "\t", "\\t", + "\n", "\\n", + "\r", "\\r", + "\"", "\\\"", + "\\", "\\\\", +) + +// Encoder controls the encoding of Go values to a TOML document to some +// io.Writer. +// +// The indentation level can be controlled with the Indent field. +type Encoder struct { + // A single indentation level. By default it is two spaces. + Indent string + + // hasWritten is whether we have written any output to w yet. + hasWritten bool + w *bufio.Writer +} + +// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer +// given. By default, a single indentation level is 2 spaces. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: bufio.NewWriter(w), + Indent: " ", + } +} + +// Encode writes a TOML representation of the Go value to the underlying +// io.Writer. If the value given cannot be encoded to a valid TOML document, +// then an error is returned. +// +// The mapping between Go values and TOML values should be precisely the same +// as for the Decode* functions. Similarly, the TextMarshaler interface is +// supported by encoding the resulting bytes as strings. (If you want to write +// arbitrary binary data then you will need to use something like base64 since +// TOML does not have any binary types.) +// +// When encoding TOML hashes (i.e., Go maps or structs), keys without any +// sub-hashes are encoded first. +// +// If a Go map is encoded, then its keys are sorted alphabetically for +// deterministic output. More control over this behavior may be provided if +// there is demand for it. +// +// Encoding Go values without a corresponding TOML representation---like map +// types with non-string keys---will cause an error to be returned. Similarly +// for mixed arrays/slices, arrays/slices with nil elements, embedded +// non-struct types and nested slices containing maps or structs. 
+// (e.g., [][]map[string]string is not allowed but []map[string]string is OK +// and so is []map[string][]string.) +func (enc *Encoder) Encode(v interface{}) error { + rv := eindirect(reflect.ValueOf(v)) + if err := enc.safeEncode(Key([]string{}), rv); err != nil { + return err + } + return enc.w.Flush() +} + +func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) { + defer func() { + if r := recover(); r != nil { + if terr, ok := r.(tomlEncodeError); ok { + err = terr.error + return + } + panic(r) + } + }() + enc.encode(key, rv) + return nil +} + +func (enc *Encoder) encode(key Key, rv reflect.Value) { + // Special case. Time needs to be in ISO8601 format. + // Special case. If we can marshal the type to text, then we used that. + // Basically, this prevents the encoder for handling these types as + // generic structs (or whatever the underlying type of a TextMarshaler is). + switch rv.Interface().(type) { + case time.Time, TextMarshaler: + enc.keyEqElement(key, rv) + return + } + + k := rv.Kind() + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, + reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: + enc.keyEqElement(key, rv) + case reflect.Array, reflect.Slice: + if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { + enc.eArrayOfTables(key, rv) + } else { + enc.keyEqElement(key, rv) + } + case reflect.Interface: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Map: + if rv.IsNil() { + return + } + enc.eTable(key, rv) + case reflect.Ptr: + if rv.IsNil() { + return + } + enc.encode(key, rv.Elem()) + case reflect.Struct: + enc.eTable(key, rv) + default: + panic(e("unsupported type for key '%s': %s", key, k)) + } +} + +// eElement encodes any value that can be an array element (primitives and +// arrays). +func (enc *Encoder) eElement(rv reflect.Value) { + switch v := rv.Interface().(type) { + case time.Time: + // Special case time.Time as a primitive. Has to come before + // TextMarshaler below because time.Time implements + // encoding.TextMarshaler, but we need to always use UTC. + enc.wf(v.UTC().Format("2006-01-02T15:04:05Z")) + return + case TextMarshaler: + // Special case. Use text marshaler if it's available for this value. + if s, err := v.MarshalText(); err != nil { + encPanic(err) + } else { + enc.writeQuoted(string(s)) + } + return + } + switch rv.Kind() { + case reflect.Bool: + enc.wf(strconv.FormatBool(rv.Bool())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64: + enc.wf(strconv.FormatInt(rv.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, + reflect.Uint32, reflect.Uint64: + enc.wf(strconv.FormatUint(rv.Uint(), 10)) + case reflect.Float32: + enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32))) + case reflect.Float64: + enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64))) + case reflect.Array, reflect.Slice: + enc.eArrayOrSliceElement(rv) + case reflect.Interface: + enc.eElement(rv.Elem()) + case reflect.String: + enc.writeQuoted(rv.String()) + default: + panic(e("unexpected primitive type: %s", rv.Kind())) + } +} + +// By the TOML spec, all floats must have a decimal with at least one +// number on either side. 
+func floatAddDecimal(fstr string) string { + if !strings.Contains(fstr, ".") { + return fstr + ".0" + } + return fstr +} + +func (enc *Encoder) writeQuoted(s string) { + enc.wf("\"%s\"", quotedReplacer.Replace(s)) +} + +func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { + length := rv.Len() + enc.wf("[") + for i := 0; i < length; i++ { + elem := rv.Index(i) + enc.eElement(elem) + if i != length-1 { + enc.wf(", ") + } + } + enc.wf("]") +} + +func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { + if len(key) == 0 { + encPanic(errNoKey) + } + for i := 0; i < rv.Len(); i++ { + trv := rv.Index(i) + if isNil(trv) { + continue + } + panicIfInvalidKey(key) + enc.newline() + enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll()) + enc.newline() + enc.eMapOrStruct(key, trv) + } +} + +func (enc *Encoder) eTable(key Key, rv reflect.Value) { + panicIfInvalidKey(key) + if len(key) == 1 { + // Output an extra newline between top-level tables. + // (The newline isn't written if nothing else has been written though.) + enc.newline() + } + if len(key) > 0 { + enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll()) + enc.newline() + } + enc.eMapOrStruct(key, rv) +} + +func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) { + switch rv := eindirect(rv); rv.Kind() { + case reflect.Map: + enc.eMap(key, rv) + case reflect.Struct: + enc.eStruct(key, rv) + default: + panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) + } +} + +func (enc *Encoder) eMap(key Key, rv reflect.Value) { + rt := rv.Type() + if rt.Key().Kind() != reflect.String { + encPanic(errNonString) + } + + // Sort keys so that we have deterministic output. And write keys directly + // underneath this key first, before writing sub-structs or sub-maps. + var mapKeysDirect, mapKeysSub []string + for _, mapKey := range rv.MapKeys() { + k := mapKey.String() + if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) { + mapKeysSub = append(mapKeysSub, k) + } else { + mapKeysDirect = append(mapKeysDirect, k) + } + } + + var writeMapKeys = func(mapKeys []string) { + sort.Strings(mapKeys) + for _, mapKey := range mapKeys { + mrv := rv.MapIndex(reflect.ValueOf(mapKey)) + if isNil(mrv) { + // Don't write anything for nil fields. + continue + } + enc.encode(key.add(mapKey), mrv) + } + } + writeMapKeys(mapKeysDirect) + writeMapKeys(mapKeysSub) +} + +func (enc *Encoder) eStruct(key Key, rv reflect.Value) { + // Write keys for fields directly under this key first, because if we write + // a field that creates a new table, then all keys under it will be in that + // table (not the one we're writing here). + rt := rv.Type() + var fieldsDirect, fieldsSub [][]int + var addFields func(rt reflect.Type, rv reflect.Value, start []int) + addFields = func(rt reflect.Type, rv reflect.Value, start []int) { + for i := 0; i < rt.NumField(); i++ { + f := rt.Field(i) + // skip unexported fields + if f.PkgPath != "" && !f.Anonymous { + continue + } + frv := rv.Field(i) + if f.Anonymous { + t := f.Type + switch t.Kind() { + case reflect.Struct: + // Treat anonymous struct fields with + // tag names as though they are not + // anonymous, like encoding/json does. + if getOptions(f.Tag).name == "" { + addFields(t, frv, f.Index) + continue + } + case reflect.Ptr: + if t.Elem().Kind() == reflect.Struct && + getOptions(f.Tag).name == "" { + if !frv.IsNil() { + addFields(t.Elem(), frv.Elem(), f.Index) + } + continue + } + // Fall through to the normal field encoding logic below + // for non-struct anonymous fields. 
+ } + } + + if typeIsHash(tomlTypeOfGo(frv)) { + fieldsSub = append(fieldsSub, append(start, f.Index...)) + } else { + fieldsDirect = append(fieldsDirect, append(start, f.Index...)) + } + } + } + addFields(rt, rv, nil) + + var writeFields = func(fields [][]int) { + for _, fieldIndex := range fields { + sft := rt.FieldByIndex(fieldIndex) + sf := rv.FieldByIndex(fieldIndex) + if isNil(sf) { + // Don't write anything for nil fields. + continue + } + + opts := getOptions(sft.Tag) + if opts.skip { + continue + } + keyName := sft.Name + if opts.name != "" { + keyName = opts.name + } + if opts.omitempty && isEmpty(sf) { + continue + } + if opts.omitzero && isZero(sf) { + continue + } + + enc.encode(key.add(keyName), sf) + } + } + writeFields(fieldsDirect) + writeFields(fieldsSub) +} + +// tomlTypeName returns the TOML type name of the Go value's type. It is +// used to determine whether the types of array elements are mixed (which is +// forbidden). If the Go value is nil, then it is illegal for it to be an array +// element, and valueIsNil is returned as true. + +// Returns the TOML type of a Go value. The type may be `nil`, which means +// no concrete TOML type could be found. +func tomlTypeOfGo(rv reflect.Value) tomlType { + if isNil(rv) || !rv.IsValid() { + return nil + } + switch rv.Kind() { + case reflect.Bool: + return tomlBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, + reflect.Uint64: + return tomlInteger + case reflect.Float32, reflect.Float64: + return tomlFloat + case reflect.Array, reflect.Slice: + if typeEqual(tomlHash, tomlArrayType(rv)) { + return tomlArrayHash + } + return tomlArray + case reflect.Ptr, reflect.Interface: + return tomlTypeOfGo(rv.Elem()) + case reflect.String: + return tomlString + case reflect.Map: + return tomlHash + case reflect.Struct: + switch rv.Interface().(type) { + case time.Time: + return tomlDatetime + case TextMarshaler: + return tomlString + default: + return tomlHash + } + default: + panic("unexpected reflect.Kind: " + rv.Kind().String()) + } +} + +// tomlArrayType returns the element type of a TOML array. The type returned +// may be nil if it cannot be determined (e.g., a nil slice or a zero length +// slize). This function may also panic if it finds a type that cannot be +// expressed in TOML (such as nil elements, heterogeneous arrays or directly +// nested arrays of tables). +func tomlArrayType(rv reflect.Value) tomlType { + if isNil(rv) || !rv.IsValid() || rv.Len() == 0 { + return nil + } + firstType := tomlTypeOfGo(rv.Index(0)) + if firstType == nil { + encPanic(errArrayNilElement) + } + + rvlen := rv.Len() + for i := 1; i < rvlen; i++ { + elem := rv.Index(i) + switch elemType := tomlTypeOfGo(elem); { + case elemType == nil: + encPanic(errArrayNilElement) + case !typeEqual(firstType, elemType): + encPanic(errArrayMixedElementTypes) + } + } + // If we have a nested array, then we must make sure that the nested + // array contains ONLY primitives. + // This checks arbitrarily nested arrays. 
+ if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) { + nest := tomlArrayType(eindirect(rv.Index(0))) + if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) { + encPanic(errArrayNoTable) + } + } + return firstType +} + +type tagOptions struct { + skip bool // "-" + name string + omitempty bool + omitzero bool +} + +func getOptions(tag reflect.StructTag) tagOptions { + t := tag.Get("toml") + if t == "-" { + return tagOptions{skip: true} + } + var opts tagOptions + parts := strings.Split(t, ",") + opts.name = parts[0] + for _, s := range parts[1:] { + switch s { + case "omitempty": + opts.omitempty = true + case "omitzero": + opts.omitzero = true + } + } + return opts +} + +func isZero(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint() == 0 + case reflect.Float32, reflect.Float64: + return rv.Float() == 0.0 + } + return false +} + +func isEmpty(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + return rv.Len() == 0 + case reflect.Bool: + return !rv.Bool() + } + return false +} + +func (enc *Encoder) newline() { + if enc.hasWritten { + enc.wf("\n") + } +} + +func (enc *Encoder) keyEqElement(key Key, val reflect.Value) { + if len(key) == 0 { + encPanic(errNoKey) + } + panicIfInvalidKey(key) + enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) + enc.eElement(val) + enc.newline() +} + +func (enc *Encoder) wf(format string, v ...interface{}) { + if _, err := fmt.Fprintf(enc.w, format, v...); err != nil { + encPanic(err) + } + enc.hasWritten = true +} + +func (enc *Encoder) indentStr(key Key) string { + return strings.Repeat(enc.Indent, len(key)-1) +} + +func encPanic(err error) { + panic(tomlEncodeError{err}) +} + +func eindirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + return eindirect(v.Elem()) + default: + return v + } +} + +func isNil(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return rv.IsNil() + default: + return false + } +} + +func panicIfInvalidKey(key Key) { + for _, k := range key { + if len(k) == 0 { + encPanic(e("Key '%s' is not a valid table name. Key names "+ + "cannot be empty.", key.maybeQuotedAll())) + } + } +} + +func isValidKeyName(s string) bool { + return len(s) != 0 +} diff --git a/vendor/github.com/BurntSushi/toml/encoding_types.go b/vendor/github.com/BurntSushi/toml/encoding_types.go new file mode 100644 index 0000000000..d36e1dd600 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/encoding_types.go @@ -0,0 +1,19 @@ +// +build go1.2 + +package toml + +// In order to support Go 1.1, we define our own TextMarshaler and +// TextUnmarshaler types. For Go 1.2+, we just alias them with the +// standard library interfaces. + +import ( + "encoding" +) + +// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here +// so that Go 1.1 can be supported. +type TextMarshaler encoding.TextMarshaler + +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. 
+type TextUnmarshaler encoding.TextUnmarshaler diff --git a/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go b/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go new file mode 100644 index 0000000000..e8d503d046 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/encoding_types_1.1.go @@ -0,0 +1,18 @@ +// +build !go1.2 + +package toml + +// These interfaces were introduced in Go 1.2, so we add them manually when +// compiling for Go 1.1. + +// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here +// so that Go 1.1 can be supported. +type TextMarshaler interface { + MarshalText() (text []byte, err error) +} + +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. +type TextUnmarshaler interface { + UnmarshalText(text []byte) error +} diff --git a/vendor/github.com/BurntSushi/toml/lex.go b/vendor/github.com/BurntSushi/toml/lex.go new file mode 100644 index 0000000000..e0a742a887 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/lex.go @@ -0,0 +1,953 @@ +package toml + +import ( + "fmt" + "strings" + "unicode" + "unicode/utf8" +) + +type itemType int + +const ( + itemError itemType = iota + itemNIL // used in the parser to indicate no type + itemEOF + itemText + itemString + itemRawString + itemMultilineString + itemRawMultilineString + itemBool + itemInteger + itemFloat + itemDatetime + itemArray // the start of an array + itemArrayEnd + itemTableStart + itemTableEnd + itemArrayTableStart + itemArrayTableEnd + itemKeyStart + itemCommentStart + itemInlineTableStart + itemInlineTableEnd +) + +const ( + eof = 0 + comma = ',' + tableStart = '[' + tableEnd = ']' + arrayTableStart = '[' + arrayTableEnd = ']' + tableSep = '.' + keySep = '=' + arrayStart = '[' + arrayEnd = ']' + commentStart = '#' + stringStart = '"' + stringEnd = '"' + rawStringStart = '\'' + rawStringEnd = '\'' + inlineTableStart = '{' + inlineTableEnd = '}' +) + +type stateFn func(lx *lexer) stateFn + +type lexer struct { + input string + start int + pos int + line int + state stateFn + items chan item + + // Allow for backing up up to three runes. + // This is necessary because TOML contains 3-rune tokens (""" and '''). + prevWidths [3]int + nprev int // how many of prevWidths are in use + // If we emit an eof, we can still back up, but it is not OK to call + // next again. + atEOF bool + + // A stack of state functions used to maintain context. + // The idea is to reuse parts of the state machine in various places. + // For example, values can appear at the top level or within arbitrarily + // nested arrays. The last state on the stack is used after a value has + // been lexed. Similarly for comments. 
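+	// For example, lexArrayValue pushes lexArrayValueEnd before delegating
+	// to lexValue; whichever state finishes the value then calls pop to
+	// hand control back to the array.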
+ stack []stateFn +} + +type item struct { + typ itemType + val string + line int +} + +func (lx *lexer) nextItem() item { + for { + select { + case item := <-lx.items: + return item + default: + lx.state = lx.state(lx) + } + } +} + +func lex(input string) *lexer { + lx := &lexer{ + input: input, + state: lexTop, + line: 1, + items: make(chan item, 10), + stack: make([]stateFn, 0, 10), + } + return lx +} + +func (lx *lexer) push(state stateFn) { + lx.stack = append(lx.stack, state) +} + +func (lx *lexer) pop() stateFn { + if len(lx.stack) == 0 { + return lx.errorf("BUG in lexer: no states to pop") + } + last := lx.stack[len(lx.stack)-1] + lx.stack = lx.stack[0 : len(lx.stack)-1] + return last +} + +func (lx *lexer) current() string { + return lx.input[lx.start:lx.pos] +} + +func (lx *lexer) emit(typ itemType) { + lx.items <- item{typ, lx.current(), lx.line} + lx.start = lx.pos +} + +func (lx *lexer) emitTrim(typ itemType) { + lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line} + lx.start = lx.pos +} + +func (lx *lexer) next() (r rune) { + if lx.atEOF { + panic("next called after EOF") + } + if lx.pos >= len(lx.input) { + lx.atEOF = true + return eof + } + + if lx.input[lx.pos] == '\n' { + lx.line++ + } + lx.prevWidths[2] = lx.prevWidths[1] + lx.prevWidths[1] = lx.prevWidths[0] + if lx.nprev < 3 { + lx.nprev++ + } + r, w := utf8.DecodeRuneInString(lx.input[lx.pos:]) + lx.prevWidths[0] = w + lx.pos += w + return r +} + +// ignore skips over the pending input before this point. +func (lx *lexer) ignore() { + lx.start = lx.pos +} + +// backup steps back one rune. Can be called only twice between calls to next. +func (lx *lexer) backup() { + if lx.atEOF { + lx.atEOF = false + return + } + if lx.nprev < 1 { + panic("backed up too far") + } + w := lx.prevWidths[0] + lx.prevWidths[0] = lx.prevWidths[1] + lx.prevWidths[1] = lx.prevWidths[2] + lx.nprev-- + lx.pos -= w + if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { + lx.line-- + } +} + +// accept consumes the next rune if it's equal to `valid`. +func (lx *lexer) accept(valid rune) bool { + if lx.next() == valid { + return true + } + lx.backup() + return false +} + +// peek returns but does not consume the next rune in the input. +func (lx *lexer) peek() rune { + r := lx.next() + lx.backup() + return r +} + +// skip ignores all input that matches the given predicate. +func (lx *lexer) skip(pred func(rune) bool) { + for { + r := lx.next() + if pred(r) { + continue + } + lx.backup() + lx.ignore() + return + } +} + +// errorf stops all lexing by emitting an error and returning `nil`. +// Note that any value that is a character is escaped if it's a special +// character (newlines, tabs, etc.). +func (lx *lexer) errorf(format string, values ...interface{}) stateFn { + lx.items <- item{ + itemError, + fmt.Sprintf(format, values...), + lx.line, + } + return nil +} + +// lexTop consumes elements at the top level of TOML data. +func lexTop(lx *lexer) stateFn { + r := lx.next() + if isWhitespace(r) || isNL(r) { + return lexSkip(lx, lexTop) + } + switch r { + case commentStart: + lx.push(lexTop) + return lexCommentStart + case tableStart: + return lexTableStart + case eof: + if lx.pos > lx.start { + return lx.errorf("unexpected EOF") + } + lx.emit(itemEOF) + return nil + } + + // At this point, the only valid item can be a key, so we back up + // and let the key lexer do the rest. + lx.backup() + lx.push(lexTopEnd) + return lexKeyStart +} + +// lexTopEnd is entered whenever a top-level item has been consumed. (A value +// or a table.) 
It must see only whitespace, and will turn back to lexTop
+// upon a newline. If it sees EOF, it will quit the lexer successfully.
+func lexTopEnd(lx *lexer) stateFn {
+	r := lx.next()
+	switch {
+	case r == commentStart:
+		// a comment will read to a newline for us.
+		lx.push(lexTop)
+		return lexCommentStart
+	case isWhitespace(r):
+		return lexTopEnd
+	case isNL(r):
+		lx.ignore()
+		return lexTop
+	case r == eof:
+		lx.emit(itemEOF)
+		return nil
+	}
+	return lx.errorf("expected a top-level item to end with a newline, "+
+		"comment, or EOF, but got %q instead", r)
+}
+
+// lexTableStart lexes the beginning of a table. Namely, it makes sure that
+// it starts with a character other than '.' and ']'.
+// It assumes that '[' has already been consumed.
+// It also handles the case that this is an item in an array of tables.
+// e.g., '[[name]]'.
+func lexTableStart(lx *lexer) stateFn {
+	if lx.peek() == arrayTableStart {
+		lx.next()
+		lx.emit(itemArrayTableStart)
+		lx.push(lexArrayTableEnd)
+	} else {
+		lx.emit(itemTableStart)
+		lx.push(lexTableEnd)
+	}
+	return lexTableNameStart
+}
+
+func lexTableEnd(lx *lexer) stateFn {
+	lx.emit(itemTableEnd)
+	return lexTopEnd
+}
+
+func lexArrayTableEnd(lx *lexer) stateFn {
+	if r := lx.next(); r != arrayTableEnd {
+		return lx.errorf("expected end of table array name delimiter %q, "+
+			"but got %q instead", arrayTableEnd, r)
+	}
+	lx.emit(itemArrayTableEnd)
+	return lexTopEnd
+}
+
+func lexTableNameStart(lx *lexer) stateFn {
+	lx.skip(isWhitespace)
+	switch r := lx.peek(); {
+	case r == tableEnd || r == eof:
+		return lx.errorf("unexpected end of table name " +
+			"(table names cannot be empty)")
+	case r == tableSep:
+		return lx.errorf("unexpected table separator " +
+			"(table names cannot be empty)")
+	case r == stringStart || r == rawStringStart:
+		lx.ignore()
+		lx.push(lexTableNameEnd)
+		return lexValue // reuse string lexing
+	default:
+		return lexBareTableName
+	}
+}
+
+// lexBareTableName lexes the name of a table. It assumes that at least one
+// valid character for the table has already been read.
+func lexBareTableName(lx *lexer) stateFn {
+	r := lx.next()
+	if isBareKeyChar(r) {
+		return lexBareTableName
+	}
+	lx.backup()
+	lx.emit(itemText)
+	return lexTableNameEnd
+}
+
+// lexTableNameEnd reads the end of a piece of a table name, optionally
+// consuming whitespace.
+func lexTableNameEnd(lx *lexer) stateFn {
+	lx.skip(isWhitespace)
+	switch r := lx.next(); {
+	case isWhitespace(r):
+		return lexTableNameEnd
+	case r == tableSep:
+		lx.ignore()
+		return lexTableNameStart
+	case r == tableEnd:
+		return lx.pop()
+	default:
+		return lx.errorf("expected '.' or ']' to end table name, "+
+			"but got %q instead", r)
+	}
+}
+
+// lexKeyStart skips leading whitespace and then begins lexing a key name,
+// which may be bare or quoted.
+func lexKeyStart(lx *lexer) stateFn {
+	r := lx.peek()
+	switch {
+	case r == keySep:
+		return lx.errorf("unexpected key separator %q", keySep)
+	case isWhitespace(r) || isNL(r):
+		lx.next()
+		return lexSkip(lx, lexKeyStart)
+	case r == stringStart || r == rawStringStart:
+		lx.ignore()
+		lx.emit(itemKeyStart)
+		lx.push(lexKeyEnd)
+		return lexValue // reuse string lexing
+	default:
+		lx.ignore()
+		lx.emit(itemKeyStart)
+		return lexBareKey
+	}
+}
+
+// lexBareKey consumes the text of a bare key. Assumes that the first character
+// (which is not whitespace) has not yet been consumed.
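+// Bare keys may contain only the characters accepted by isBareKeyChar:
+// ASCII letters, digits, '_' and '-'.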
+func lexBareKey(lx *lexer) stateFn {
+	switch r := lx.next(); {
+	case isBareKeyChar(r):
+		return lexBareKey
+	case isWhitespace(r):
+		lx.backup()
+		lx.emit(itemText)
+		return lexKeyEnd
+	case r == keySep:
+		lx.backup()
+		lx.emit(itemText)
+		return lexKeyEnd
+	default:
+		return lx.errorf("bare keys cannot contain %q", r)
+	}
+}
+
+// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
+// separator).
+func lexKeyEnd(lx *lexer) stateFn {
+	switch r := lx.next(); {
+	case r == keySep:
+		return lexSkip(lx, lexValue)
+	case isWhitespace(r):
+		return lexSkip(lx, lexKeyEnd)
+	default:
+		return lx.errorf("expected key separator %q, but got %q instead",
+			keySep, r)
+	}
+}
+
+// lexValue starts the consumption of a value anywhere a value is expected.
+// lexValue will ignore whitespace.
+// After a value is lexed, the last state on the stack is popped and returned.
+func lexValue(lx *lexer) stateFn {
+	// We allow whitespace to precede a value, but NOT newlines.
+	// In array syntax, the array states are responsible for ignoring newlines.
+	r := lx.next()
+	switch {
+	case isWhitespace(r):
+		return lexSkip(lx, lexValue)
+	case isDigit(r):
+		lx.backup() // avoid an extra state and use the same as above
+		return lexNumberOrDateStart
+	}
+	switch r {
+	case arrayStart:
+		lx.ignore()
+		lx.emit(itemArray)
+		return lexArrayValue
+	case inlineTableStart:
+		lx.ignore()
+		lx.emit(itemInlineTableStart)
+		return lexInlineTableValue
+	case stringStart:
+		if lx.accept(stringStart) {
+			if lx.accept(stringStart) {
+				lx.ignore() // Ignore """
+				return lexMultilineString
+			}
+			lx.backup()
+		}
+		lx.ignore() // ignore the '"'
+		return lexString
+	case rawStringStart:
+		if lx.accept(rawStringStart) {
+			if lx.accept(rawStringStart) {
+				lx.ignore() // Ignore '''
+				return lexMultilineRawString
+			}
+			lx.backup()
+		}
+		lx.ignore() // ignore the "'"
+		return lexRawString
+	case '+', '-':
+		return lexNumberStart
+	case '.': // special error case, be kind to users
+		return lx.errorf("floats must start with a digit, not '.'")
+	}
+	if unicode.IsLetter(r) {
+		// Be permissive here; lexBool will give a nice error if the
+		// user wrote something like
+		// x = foo
+		// (i.e. not 'true' or 'false' but is something else word-like.)
+		lx.backup()
+		return lexBool
+	}
+	return lx.errorf("expected value but found %q instead", r)
+}
+
+// lexArrayValue consumes one value in an array. It assumes that '[' or ','
+// have already been consumed. All whitespace and newlines are ignored.
+func lexArrayValue(lx *lexer) stateFn {
+	r := lx.next()
+	switch {
+	case isWhitespace(r) || isNL(r):
+		return lexSkip(lx, lexArrayValue)
+	case r == commentStart:
+		lx.push(lexArrayValue)
+		return lexCommentStart
+	case r == comma:
+		return lx.errorf("unexpected comma")
+	case r == arrayEnd:
+		// NOTE(caleb): The spec isn't clear about whether you can have
+		// a trailing comma or not, so we'll allow it.
+		return lexArrayEnd
+	}
+
+	lx.backup()
+	lx.push(lexArrayValueEnd)
+	return lexValue
+}
+
+// lexArrayValueEnd consumes everything between the end of an array value and
+// the next value (or the end of the array): it ignores whitespace and newlines
+// and expects either a ',' or a ']'.
+func lexArrayValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r) || isNL(r): + return lexSkip(lx, lexArrayValueEnd) + case r == commentStart: + lx.push(lexArrayValueEnd) + return lexCommentStart + case r == comma: + lx.ignore() + return lexArrayValue // move on to the next value + case r == arrayEnd: + return lexArrayEnd + } + return lx.errorf( + "expected a comma or array terminator %q, but got %q instead", + arrayEnd, r, + ) +} + +// lexArrayEnd finishes the lexing of an array. +// It assumes that a ']' has just been consumed. +func lexArrayEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemArrayEnd) + return lx.pop() +} + +// lexInlineTableValue consumes one key/value pair in an inline table. +// It assumes that '{' or ',' have already been consumed. Whitespace is ignored. +func lexInlineTableValue(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexInlineTableValue) + case isNL(r): + return lx.errorf("newlines not allowed within inline tables") + case r == commentStart: + lx.push(lexInlineTableValue) + return lexCommentStart + case r == comma: + return lx.errorf("unexpected comma") + case r == inlineTableEnd: + return lexInlineTableEnd + } + lx.backup() + lx.push(lexInlineTableValueEnd) + return lexKeyStart +} + +// lexInlineTableValueEnd consumes everything between the end of an inline table +// key/value pair and the next pair (or the end of the table): +// it ignores whitespace and expects either a ',' or a '}'. +func lexInlineTableValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexInlineTableValueEnd) + case isNL(r): + return lx.errorf("newlines not allowed within inline tables") + case r == commentStart: + lx.push(lexInlineTableValueEnd) + return lexCommentStart + case r == comma: + lx.ignore() + return lexInlineTableValue + case r == inlineTableEnd: + return lexInlineTableEnd + } + return lx.errorf("expected a comma or an inline table terminator %q, "+ + "but got %q instead", inlineTableEnd, r) +} + +// lexInlineTableEnd finishes the lexing of an inline table. +// It assumes that a '}' has just been consumed. +func lexInlineTableEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemInlineTableEnd) + return lx.pop() +} + +// lexString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. +func lexString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == eof: + return lx.errorf("unexpected EOF") + case isNL(r): + return lx.errorf("strings cannot contain newlines") + case r == '\\': + lx.push(lexString) + return lexStringEscape + case r == stringEnd: + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexString +} + +// lexMultilineString consumes the inner contents of a string. It assumes that +// the beginning '"""' has already been consumed and ignored. +func lexMultilineString(lx *lexer) stateFn { + switch lx.next() { + case eof: + return lx.errorf("unexpected EOF") + case '\\': + return lexMultilineStringEscape + case stringEnd: + if lx.accept(stringEnd) { + if lx.accept(stringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineString +} + +// lexRawString consumes a raw string. Nothing can be escaped in such a string. +// It assumes that the beginning "'" has already been consumed and ignored. 
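+// Raw strings cannot span lines; multi-line raw strings are handled by
+// lexMultilineRawString instead.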
+func lexRawString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == eof: + return lx.errorf("unexpected EOF") + case isNL(r): + return lx.errorf("strings cannot contain newlines") + case r == rawStringEnd: + lx.backup() + lx.emit(itemRawString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexRawString +} + +// lexMultilineRawString consumes a raw string. Nothing can be escaped in such +// a string. It assumes that the beginning "'''" has already been consumed and +// ignored. +func lexMultilineRawString(lx *lexer) stateFn { + switch lx.next() { + case eof: + return lx.errorf("unexpected EOF") + case rawStringEnd: + if lx.accept(rawStringEnd) { + if lx.accept(rawStringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemRawMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineRawString +} + +// lexMultilineStringEscape consumes an escaped character. It assumes that the +// preceding '\\' has already been consumed. +func lexMultilineStringEscape(lx *lexer) stateFn { + // Handle the special case first: + if isNL(lx.next()) { + return lexMultilineString + } + lx.backup() + lx.push(lexMultilineString) + return lexStringEscape(lx) +} + +func lexStringEscape(lx *lexer) stateFn { + r := lx.next() + switch r { + case 'b': + fallthrough + case 't': + fallthrough + case 'n': + fallthrough + case 'f': + fallthrough + case 'r': + fallthrough + case '"': + fallthrough + case '\\': + return lx.pop() + case 'u': + return lexShortUnicodeEscape + case 'U': + return lexLongUnicodeEscape + } + return lx.errorf("invalid escape character %q; only the following "+ + "escape characters are allowed: "+ + `\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r) +} + +func lexShortUnicodeEscape(lx *lexer) stateFn { + var r rune + for i := 0; i < 4; i++ { + r = lx.next() + if !isHexadecimal(r) { + return lx.errorf(`expected four hexadecimal digits after '\u', `+ + "but got %q instead", lx.current()) + } + } + return lx.pop() +} + +func lexLongUnicodeEscape(lx *lexer) stateFn { + var r rune + for i := 0; i < 8; i++ { + r = lx.next() + if !isHexadecimal(r) { + return lx.errorf(`expected eight hexadecimal digits after '\U', `+ + "but got %q instead", lx.current()) + } + } + return lx.pop() +} + +// lexNumberOrDateStart consumes either an integer, a float, or datetime. +func lexNumberOrDateStart(lx *lexer) stateFn { + r := lx.next() + if isDigit(r) { + return lexNumberOrDate + } + switch r { + case '_': + return lexNumber + case 'e', 'E': + return lexFloat + case '.': + return lx.errorf("floats must start with a digit, not '.'") + } + return lx.errorf("expected a digit but got %q", r) +} + +// lexNumberOrDate consumes either an integer, float or datetime. +func lexNumberOrDate(lx *lexer) stateFn { + r := lx.next() + if isDigit(r) { + return lexNumberOrDate + } + switch r { + case '-': + return lexDatetime + case '_': + return lexNumber + case '.', 'e', 'E': + return lexFloat + } + + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexDatetime consumes a Datetime, to a first approximation. +// The parser validates that it matches one of the accepted formats. +func lexDatetime(lx *lexer) stateFn { + r := lx.next() + if isDigit(r) { + return lexDatetime + } + switch r { + case '-', 'T', ':', '.', 'Z', '+': + return lexDatetime + } + + lx.backup() + lx.emit(itemDatetime) + return lx.pop() +} + +// lexNumberStart consumes either an integer or a float. 
It assumes that a sign
+// has already been read, but that *no* digits have been consumed.
+// lexNumberStart will move to the appropriate integer or float states.
+func lexNumberStart(lx *lexer) stateFn {
+	// We MUST see a digit. Even floats have to start with a digit.
+	r := lx.next()
+	if !isDigit(r) {
+		if r == '.' {
+			return lx.errorf("floats must start with a digit, not '.'")
+		}
+		return lx.errorf("expected a digit but got %q", r)
+	}
+	return lexNumber
+}
+
+// lexNumber consumes an integer or a float after seeing the first digit.
+func lexNumber(lx *lexer) stateFn {
+	r := lx.next()
+	if isDigit(r) {
+		return lexNumber
+	}
+	switch r {
+	case '_':
+		return lexNumber
+	case '.', 'e', 'E':
+		return lexFloat
+	}
+
+	lx.backup()
+	lx.emit(itemInteger)
+	return lx.pop()
+}
+
+// lexFloat consumes the elements of a float. It allows any sequence of
+// float-like characters, so floats emitted by the lexer are only a first
+// approximation and must be validated by the parser.
+func lexFloat(lx *lexer) stateFn {
+	r := lx.next()
+	if isDigit(r) {
+		return lexFloat
+	}
+	switch r {
+	case '_', '.', '-', '+', 'e', 'E':
+		return lexFloat
+	}
+
+	lx.backup()
+	lx.emit(itemFloat)
+	return lx.pop()
+}
+
+// lexBool consumes a bool string: 'true' or 'false'.
+func lexBool(lx *lexer) stateFn {
+	var rs []rune
+	for {
+		r := lx.next()
+		if !unicode.IsLetter(r) {
+			lx.backup()
+			break
+		}
+		rs = append(rs, r)
+	}
+	s := string(rs)
+	switch s {
+	case "true", "false":
+		lx.emit(itemBool)
+		return lx.pop()
+	}
+	return lx.errorf("expected value but found %q instead", s)
+}
+
+// lexCommentStart begins the lexing of a comment. It will emit
+// itemCommentStart and consume no characters, passing control to lexComment.
+func lexCommentStart(lx *lexer) stateFn {
+	lx.ignore()
+	lx.emit(itemCommentStart)
+	return lexComment
+}
+
+// lexComment lexes an entire comment. It assumes that '#' has been consumed.
+// It will consume *up to* the first newline character, and pass control
+// back to the last state on the stack.
+func lexComment(lx *lexer) stateFn {
+	r := lx.peek()
+	if isNL(r) || r == eof {
+		lx.emit(itemText)
+		return lx.pop()
+	}
+	lx.next()
+	return lexComment
+}
+
+// lexSkip ignores all slurped input and moves on to the next state.
+func lexSkip(lx *lexer, nextState stateFn) stateFn {
+	return func(lx *lexer) stateFn {
+		lx.ignore()
+		return nextState
+	}
+}
+
+// isWhitespace returns true if `r` is a whitespace character according
+// to the spec.
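+// Newlines are intentionally excluded: they are significant in TOML and
+// are matched separately by isNL.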
+func isWhitespace(r rune) bool { + return r == '\t' || r == ' ' +} + +func isNL(r rune) bool { + return r == '\n' || r == '\r' +} + +func isDigit(r rune) bool { + return r >= '0' && r <= '9' +} + +func isHexadecimal(r rune) bool { + return (r >= '0' && r <= '9') || + (r >= 'a' && r <= 'f') || + (r >= 'A' && r <= 'F') +} + +func isBareKeyChar(r rune) bool { + return (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '_' || + r == '-' +} + +func (itype itemType) String() string { + switch itype { + case itemError: + return "Error" + case itemNIL: + return "NIL" + case itemEOF: + return "EOF" + case itemText: + return "Text" + case itemString, itemRawString, itemMultilineString, itemRawMultilineString: + return "String" + case itemBool: + return "Bool" + case itemInteger: + return "Integer" + case itemFloat: + return "Float" + case itemDatetime: + return "DateTime" + case itemTableStart: + return "TableStart" + case itemTableEnd: + return "TableEnd" + case itemKeyStart: + return "KeyStart" + case itemArray: + return "Array" + case itemArrayEnd: + return "ArrayEnd" + case itemCommentStart: + return "CommentStart" + } + panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype))) +} + +func (item item) String() string { + return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val) +} diff --git a/vendor/github.com/BurntSushi/toml/parse.go b/vendor/github.com/BurntSushi/toml/parse.go new file mode 100644 index 0000000000..50869ef926 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/parse.go @@ -0,0 +1,592 @@ +package toml + +import ( + "fmt" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" +) + +type parser struct { + mapping map[string]interface{} + types map[string]tomlType + lx *lexer + + // A list of keys in the order that they appear in the TOML data. + ordered []Key + + // the full key for the current hash in scope + context Key + + // the base key name for everything except hashes + currentKey string + + // rough approximation of line number + approxLine int + + // A map of 'key.group.names' to whether they were created implicitly. 
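+	// For example, the table header [a.b.c] implicitly creates the 'a'
+	// and 'a.b' tables; each may later be defined explicitly exactly once.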
+ implicits map[string]bool +} + +type parseError string + +func (pe parseError) Error() string { + return string(pe) +} + +func parse(data string) (p *parser, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + if err, ok = r.(parseError); ok { + return + } + panic(r) + } + }() + + p = &parser{ + mapping: make(map[string]interface{}), + types: make(map[string]tomlType), + lx: lex(data), + ordered: make([]Key, 0), + implicits: make(map[string]bool), + } + for { + item := p.next() + if item.typ == itemEOF { + break + } + p.topLevel(item) + } + + return p, nil +} + +func (p *parser) panicf(format string, v ...interface{}) { + msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s", + p.approxLine, p.current(), fmt.Sprintf(format, v...)) + panic(parseError(msg)) +} + +func (p *parser) next() item { + it := p.lx.nextItem() + if it.typ == itemError { + p.panicf("%s", it.val) + } + return it +} + +func (p *parser) bug(format string, v ...interface{}) { + panic(fmt.Sprintf("BUG: "+format+"\n\n", v...)) +} + +func (p *parser) expect(typ itemType) item { + it := p.next() + p.assertEqual(typ, it.typ) + return it +} + +func (p *parser) assertEqual(expected, got itemType) { + if expected != got { + p.bug("Expected '%s' but got '%s'.", expected, got) + } +} + +func (p *parser) topLevel(item item) { + switch item.typ { + case itemCommentStart: + p.approxLine = item.line + p.expect(itemText) + case itemTableStart: + kg := p.next() + p.approxLine = kg.line + + var key Key + for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) + } + p.assertEqual(itemTableEnd, kg.typ) + + p.establishContext(key, false) + p.setType("", tomlHash) + p.ordered = append(p.ordered, key) + case itemArrayTableStart: + kg := p.next() + p.approxLine = kg.line + + var key Key + for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) + } + p.assertEqual(itemArrayTableEnd, kg.typ) + + p.establishContext(key, true) + p.setType("", tomlArrayHash) + p.ordered = append(p.ordered, key) + case itemKeyStart: + kname := p.next() + p.approxLine = kname.line + p.currentKey = p.keyString(kname) + + val, typ := p.value(p.next()) + p.setValue(p.currentKey, val) + p.setType(p.currentKey, typ) + p.ordered = append(p.ordered, p.context.add(p.currentKey)) + p.currentKey = "" + default: + p.bug("Unexpected type at top level: %s", item.typ) + } +} + +// Gets a string for a key (or part of a key in a table name). +func (p *parser) keyString(it item) string { + switch it.typ { + case itemText: + return it.val + case itemString, itemMultilineString, + itemRawString, itemRawMultilineString: + s, _ := p.value(it) + return s.(string) + default: + p.bug("Unexpected key type: %s", it.typ) + panic("unreachable") + } +} + +// value translates an expected value from the lexer into a Go value wrapped +// as an empty interface. 
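+// Compound values (arrays and inline tables) are handled recursively,
+// pulling further items from the lexer until the matching end item is seen.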
+func (p *parser) value(it item) (interface{}, tomlType) { + switch it.typ { + case itemString: + return p.replaceEscapes(it.val), p.typeOfPrimitive(it) + case itemMultilineString: + trimmed := stripFirstNewline(stripEscapedWhitespace(it.val)) + return p.replaceEscapes(trimmed), p.typeOfPrimitive(it) + case itemRawString: + return it.val, p.typeOfPrimitive(it) + case itemRawMultilineString: + return stripFirstNewline(it.val), p.typeOfPrimitive(it) + case itemBool: + switch it.val { + case "true": + return true, p.typeOfPrimitive(it) + case "false": + return false, p.typeOfPrimitive(it) + } + p.bug("Expected boolean value, but got '%s'.", it.val) + case itemInteger: + if !numUnderscoresOK(it.val) { + p.panicf("Invalid integer %q: underscores must be surrounded by digits", + it.val) + } + val := strings.Replace(it.val, "_", "", -1) + num, err := strconv.ParseInt(val, 10, 64) + if err != nil { + // Distinguish integer values. Normally, it'd be a bug if the lexer + // provides an invalid integer, but it's possible that the number is + // out of range of valid values (which the lexer cannot determine). + // So mark the former as a bug but the latter as a legitimate user + // error. + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + + p.panicf("Integer '%s' is out of the range of 64-bit "+ + "signed integers.", it.val) + } else { + p.bug("Expected integer value, but got '%s'.", it.val) + } + } + return num, p.typeOfPrimitive(it) + case itemFloat: + parts := strings.FieldsFunc(it.val, func(r rune) bool { + switch r { + case '.', 'e', 'E': + return true + } + return false + }) + for _, part := range parts { + if !numUnderscoresOK(part) { + p.panicf("Invalid float %q: underscores must be "+ + "surrounded by digits", it.val) + } + } + if !numPeriodsOK(it.val) { + // As a special case, numbers like '123.' or '1.e2', + // which are valid as far as Go/strconv are concerned, + // must be rejected because TOML says that a fractional + // part consists of '.' followed by 1+ digits. + p.panicf("Invalid float %q: '.' 
must be followed "+
+				"by one or more digits", it.val)
+		}
+		val := strings.Replace(it.val, "_", "", -1)
+		num, err := strconv.ParseFloat(val, 64)
+		if err != nil {
+			if e, ok := err.(*strconv.NumError); ok &&
+				e.Err == strconv.ErrRange {
+
+				p.panicf("Float '%s' is out of the range of 64-bit "+
+					"IEEE-754 floating-point numbers.", it.val)
+			} else {
+				p.panicf("Invalid float value: %q", it.val)
+			}
+		}
+		return num, p.typeOfPrimitive(it)
+	case itemDatetime:
+		var t time.Time
+		var ok bool
+		var err error
+		for _, format := range []string{
+			"2006-01-02T15:04:05Z07:00",
+			"2006-01-02T15:04:05",
+			"2006-01-02",
+		} {
+			t, err = time.ParseInLocation(format, it.val, time.Local)
+			if err == nil {
+				ok = true
+				break
+			}
+		}
+		if !ok {
+			p.panicf("Invalid TOML Datetime: %q.", it.val)
+		}
+		return t, p.typeOfPrimitive(it)
+	case itemArray:
+		array := make([]interface{}, 0)
+		types := make([]tomlType, 0)
+
+		for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
+			if it.typ == itemCommentStart {
+				p.expect(itemText)
+				continue
+			}
+
+			val, typ := p.value(it)
+			array = append(array, val)
+			types = append(types, typ)
+		}
+		return array, p.typeOfArray(types)
+	case itemInlineTableStart:
+		var (
+			hash         = make(map[string]interface{})
+			outerContext = p.context
+			outerKey     = p.currentKey
+		)
+
+		p.context = append(p.context, p.currentKey)
+		p.currentKey = ""
+		for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
+			// Skip comments before checking for a key, so the comment
+			// branch is actually reachable.
+			if it.typ == itemCommentStart {
+				p.expect(itemText)
+				continue
+			}
+			if it.typ != itemKeyStart {
+				p.bug("Expected key start but instead found %q, around line %d",
+					it.val, p.approxLine)
+			}
+
+			// retrieve key
+			k := p.next()
+			p.approxLine = k.line
+			kname := p.keyString(k)
+
+			// retrieve value
+			p.currentKey = kname
+			val, typ := p.value(p.next())
+			// make sure we keep metadata up to date
+			p.setType(kname, typ)
+			p.ordered = append(p.ordered, p.context.add(p.currentKey))
+			hash[kname] = val
+		}
+		p.context = outerContext
+		p.currentKey = outerKey
+		return hash, tomlHash
+	}
+	p.bug("Unexpected value type: %s", it.typ)
+	panic("unreachable")
+}
+
+// numUnderscoresOK checks whether each underscore in s is surrounded by
+// characters that are not underscores.
+func numUnderscoresOK(s string) bool {
+	accept := false
+	for _, r := range s {
+		if r == '_' {
+			if !accept {
+				return false
+			}
+			accept = false
+			continue
+		}
+		accept = true
+	}
+	return accept
+}
+
+// numPeriodsOK checks whether every period in s is followed by a digit.
+func numPeriodsOK(s string) bool {
+	period := false
+	for _, r := range s {
+		if period && !isDigit(r) {
+			return false
+		}
+		period = r == '.'
+	}
+	return !period
+}
+
+// establishContext sets the current context of the parser,
+// where the context is either a hash or an array of hashes. Which one is
+// set depends on the value of the `array` parameter.
+//
+// Establishing the context also makes sure that the key isn't a duplicate, and
+// will create implicit hashes automatically.
+func (p *parser) establishContext(key Key, array bool) {
+	var ok bool
+
+	// Always start at the top level and drill down for our context.
+	hashContext := p.mapping
+	keyContext := make(Key, 0)
+
+	// We only need implicit hashes for key[0:-1]
+	for _, k := range key[0 : len(key)-1] {
+		_, ok = hashContext[k]
+		keyContext = append(keyContext, k)
+
+		// No key? Make an implicit hash and move on.
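+		// (Marking it implicit allows a later explicit [a.b] header to
+		// legitimately re-define it exactly once; see setValue.)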
+		if !ok {
+			p.addImplicit(keyContext)
+			hashContext[k] = make(map[string]interface{})
+		}
+
+		// If the hash context is actually an array of tables, then set
+		// the hash context to the last element in that array.
+		//
+		// Otherwise, it better be a table, since this MUST be a key group (by
+		// virtue of it not being the last element in a key).
+		switch t := hashContext[k].(type) {
+		case []map[string]interface{}:
+			hashContext = t[len(t)-1]
+		case map[string]interface{}:
+			hashContext = t
+		default:
+			p.panicf("Key '%s' was already created as a hash.", keyContext)
+		}
+	}
+
+	p.context = keyContext
+	if array {
+		// If this is the first element for this array, then allocate a new
+		// list of tables for it.
+		k := key[len(key)-1]
+		if _, ok := hashContext[k]; !ok {
+			hashContext[k] = make([]map[string]interface{}, 0, 5)
+		}
+
+		// Add a new table. But make sure the key hasn't already been used
+		// for something else.
+		if hash, ok := hashContext[k].([]map[string]interface{}); ok {
+			hashContext[k] = append(hash, make(map[string]interface{}))
+		} else {
+			p.panicf("Key '%s' was already created and cannot be used as "+
+				"an array.", keyContext)
+		}
+	} else {
+		p.setValue(key[len(key)-1], make(map[string]interface{}))
+	}
+	p.context = append(p.context, key[len(key)-1])
+}
+
+// setValue sets the given key to the given value in the current context.
+// It will make sure that the key hasn't already been defined, accounting for
+// implicit key groups.
+func (p *parser) setValue(key string, value interface{}) {
+	var tmpHash interface{}
+	var ok bool
+
+	hash := p.mapping
+	keyContext := make(Key, 0)
+	for _, k := range p.context {
+		keyContext = append(keyContext, k)
+		if tmpHash, ok = hash[k]; !ok {
+			p.bug("Context for key '%s' has not been established.", keyContext)
+		}
+		switch t := tmpHash.(type) {
+		case []map[string]interface{}:
+			// The context is a table of hashes. Pick the most recent table
+			// defined as the current hash.
+			hash = t[len(t)-1]
+		case map[string]interface{}:
+			hash = t
+		default:
+			p.bug("Expected hash to have type 'map[string]interface{}', but "+
+				"it has '%T' instead.", tmpHash)
+		}
+	}
+	keyContext = append(keyContext, key)
+
+	if _, ok := hash[key]; ok {
+		// Typically, if the given key has already been set, then we have
+		// to raise an error since duplicate keys are disallowed. However,
+		// it's possible that a key was previously defined implicitly. In this
+		// case, it is allowed to be redefined concretely. (See the
+		// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
+		//
+		// But we have to make sure to stop marking it as an implicit. (So that
+		// another redefinition provokes an error.)
+		//
+		// Note that since it has already been defined (as a hash), we don't
+		// want to overwrite it. So our business is done.
+		if p.isImplicit(keyContext) {
+			p.removeImplicit(keyContext)
+			return
+		}
+
+		// Otherwise, we have a concrete key trying to override a previous
+		// key, which is *always* wrong.
+		p.panicf("Key '%s' has already been defined.", keyContext)
+	}
+	hash[key] = value
+}
+
+// setType sets the type of a particular value at a given key.
+// It should be called immediately AFTER setValue.
+//
+// Note that if `key` is empty, then the type given will be applied to the
+// current context (which is either a table or an array of tables).
+func (p *parser) setType(key string, typ tomlType) {
+	keyContext := make(Key, 0, len(p.context)+1)
+	for _, k := range p.context {
+		keyContext = append(keyContext, k)
+	}
+	if len(key) > 0 { // allow type setting for hashes
+		keyContext = append(keyContext, key)
+	}
+	p.types[keyContext.String()] = typ
+}
+
+// addImplicit sets the given Key as having been created implicitly.
+func (p *parser) addImplicit(key Key) {
+	p.implicits[key.String()] = true
+}
+
+// removeImplicit stops tagging the given key as having been implicitly
+// created.
+func (p *parser) removeImplicit(key Key) {
+	p.implicits[key.String()] = false
+}
+
+// isImplicit returns true if the key group pointed to by the key was created
+// implicitly.
+func (p *parser) isImplicit(key Key) bool {
+	return p.implicits[key.String()]
+}
+
+// current returns the full key name of the current context.
+func (p *parser) current() string {
+	if len(p.currentKey) == 0 {
+		return p.context.String()
+	}
+	if len(p.context) == 0 {
+		return p.currentKey
+	}
+	return fmt.Sprintf("%s.%s", p.context, p.currentKey)
+}
+
+func stripFirstNewline(s string) string {
+	if len(s) == 0 || s[0] != '\n' {
+		return s
+	}
+	return s[1:]
+}
+
+func stripEscapedWhitespace(s string) string {
+	esc := strings.Split(s, "\\\n")
+	if len(esc) > 1 {
+		for i := 1; i < len(esc); i++ {
+			esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
+		}
+	}
+	return strings.Join(esc, "")
+}
+
+func (p *parser) replaceEscapes(str string) string {
+	var replaced []rune
+	s := []byte(str)
+	r := 0
+	for r < len(s) {
+		if s[r] != '\\' {
+			c, size := utf8.DecodeRune(s[r:])
+			r += size
+			replaced = append(replaced, c)
+			continue
+		}
+		r += 1
+		if r >= len(s) {
+			p.bug("Escape sequence at end of string.")
+			return ""
+		}
+		switch s[r] {
+		default:
+			p.bug("Expected valid escape code after \\, but got %q.", s[r])
+			return ""
+		case 'b':
+			replaced = append(replaced, rune(0x0008))
+			r += 1
+		case 't':
+			replaced = append(replaced, rune(0x0009))
+			r += 1
+		case 'n':
+			replaced = append(replaced, rune(0x000A))
+			r += 1
+		case 'f':
+			replaced = append(replaced, rune(0x000C))
+			r += 1
+		case 'r':
+			replaced = append(replaced, rune(0x000D))
+			r += 1
+		case '"':
+			replaced = append(replaced, rune(0x0022))
+			r += 1
+		case '\\':
+			replaced = append(replaced, rune(0x005C))
+			r += 1
+		case 'u':
+			// At this point, we know we have a Unicode escape of the form
+			// `uXXXX` at [r, r+5). (Because the lexer guarantees this
+			// for us.)
+			escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
+			replaced = append(replaced, escaped)
+			r += 5
+		case 'U':
+			// At this point, we know we have a Unicode escape of the form
+			// `UXXXXXXXX` at [r, r+9). (Because the lexer guarantees this
+			// for us.)
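+			// (One byte for 'U' plus eight hex digits, hence r+9.)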
+ escaped := p.asciiEscapeToUnicode(s[r+1 : r+9]) + replaced = append(replaced, escaped) + r += 9 + } + } + return string(replaced) +} + +func (p *parser) asciiEscapeToUnicode(bs []byte) rune { + s := string(bs) + hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32) + if err != nil { + p.bug("Could not parse '%s' as a hexadecimal number, but the "+ + "lexer claims it's OK: %s", s, err) + } + if !utf8.ValidRune(rune(hex)) { + p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s) + } + return rune(hex) +} + +func isStringType(ty itemType) bool { + return ty == itemString || ty == itemMultilineString || + ty == itemRawString || ty == itemRawMultilineString +} diff --git a/vendor/github.com/BurntSushi/toml/session.vim b/vendor/github.com/BurntSushi/toml/session.vim new file mode 100644 index 0000000000..562164be06 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/session.vim @@ -0,0 +1 @@ +au BufWritePost *.go silent!make tags > /dev/null 2>&1 diff --git a/vendor/github.com/BurntSushi/toml/type_check.go b/vendor/github.com/BurntSushi/toml/type_check.go new file mode 100644 index 0000000000..c73f8afc1a --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/type_check.go @@ -0,0 +1,91 @@ +package toml + +// tomlType represents any Go type that corresponds to a TOML type. +// While the first draft of the TOML spec has a simplistic type system that +// probably doesn't need this level of sophistication, we seem to be militating +// toward adding real composite types. +type tomlType interface { + typeString() string +} + +// typeEqual accepts any two types and returns true if they are equal. +func typeEqual(t1, t2 tomlType) bool { + if t1 == nil || t2 == nil { + return false + } + return t1.typeString() == t2.typeString() +} + +func typeIsHash(t tomlType) bool { + return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash) +} + +type tomlBaseType string + +func (btype tomlBaseType) typeString() string { + return string(btype) +} + +func (btype tomlBaseType) String() string { + return btype.typeString() +} + +var ( + tomlInteger tomlBaseType = "Integer" + tomlFloat tomlBaseType = "Float" + tomlDatetime tomlBaseType = "Datetime" + tomlString tomlBaseType = "String" + tomlBool tomlBaseType = "Bool" + tomlArray tomlBaseType = "Array" + tomlHash tomlBaseType = "Hash" + tomlArrayHash tomlBaseType = "ArrayHash" +) + +// typeOfPrimitive returns a tomlType of any primitive value in TOML. +// Primitive values are: Integer, Float, Datetime, String and Bool. +// +// Passing a lexer item other than the following will cause a BUG message +// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime. +func (p *parser) typeOfPrimitive(lexItem item) tomlType { + switch lexItem.typ { + case itemInteger: + return tomlInteger + case itemFloat: + return tomlFloat + case itemDatetime: + return tomlDatetime + case itemString: + return tomlString + case itemMultilineString: + return tomlString + case itemRawString: + return tomlString + case itemRawMultilineString: + return tomlString + case itemBool: + return tomlBool + } + p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) + panic("unreachable") +} + +// typeOfArray returns a tomlType for an array given a list of types of its +// values. +// +// In the current spec, if an array is homogeneous, then its type is always +// "Array". If the array is not homogeneous, an error is generated. +func (p *parser) typeOfArray(types []tomlType) tomlType { + // Empty arrays are cool. 
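+	// (With no elements there is nothing to be heterogeneous, so the
+	// array is simply typed as an Array.)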
+ if len(types) == 0 { + return tomlArray + } + + theType := types[0] + for _, t := range types[1:] { + if !typeEqual(theType, t) { + p.panicf("Array contains values of type '%s' and '%s', but "+ + "arrays must be homogeneous.", theType, t) + } + } + return tomlArray +} diff --git a/vendor/github.com/BurntSushi/toml/type_fields.go b/vendor/github.com/BurntSushi/toml/type_fields.go new file mode 100644 index 0000000000..608997c22f --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/type_fields.go @@ -0,0 +1,242 @@ +package toml + +// Struct field handling is adapted from code in encoding/json: +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the Go distribution. + +import ( + "reflect" + "sort" + "sync" +) + +// A field represents a single field found in a struct. +type field struct { + name string // the name of the field (`toml` tag included) + tag bool // whether field has a `toml` tag + index []int // represents the depth of an anonymous field + typ reflect.Type // the type of the field +} + +// byName sorts field by name, breaking ties with depth, +// then breaking ties with "name came from toml tag", then +// breaking ties with index sequence. +type byName []field + +func (x byName) Len() int { return len(x) } + +func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byName) Less(i, j int) bool { + if x[i].name != x[j].name { + return x[i].name < x[j].name + } + if len(x[i].index) != len(x[j].index) { + return len(x[i].index) < len(x[j].index) + } + if x[i].tag != x[j].tag { + return x[i].tag + } + return byIndex(x).Less(i, j) +} + +// byIndex sorts field by index sequence. +type byIndex []field + +func (x byIndex) Len() int { return len(x) } + +func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byIndex) Less(i, j int) bool { + for k, xik := range x[i].index { + if k >= len(x[j].index) { + return false + } + if xik != x[j].index[k] { + return xik < x[j].index[k] + } + } + return len(x[i].index) < len(x[j].index) +} + +// typeFields returns a list of fields that TOML should recognize for the given +// type. The algorithm is breadth-first search over the set of structs to +// include - the top struct and then any reachable anonymous structs. +func typeFields(t reflect.Type) []field { + // Anonymous fields to explore at the current level and the next. + current := []field{} + next := []field{{typ: t}} + + // Count of queued names for current level and the next. + count := map[reflect.Type]int{} + nextCount := map[reflect.Type]int{} + + // Types already visited at an earlier level. + visited := map[reflect.Type]bool{} + + // Fields found. + var fields []field + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, map[reflect.Type]int{} + + for _, f := range current { + if visited[f.typ] { + continue + } + visited[f.typ] = true + + // Scan f.typ for fields to include. + for i := 0; i < f.typ.NumField(); i++ { + sf := f.typ.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { // unexported + continue + } + opts := getOptions(sf.Tag) + if opts.skip { + continue + } + index := make([]int, len(f.index)+1) + copy(index, f.index) + index[len(f.index)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + // Follow pointer. + ft = ft.Elem() + } + + // Record found field and index sequence. 
+ if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := opts.name != "" + name := opts.name + if name == "" { + name = sf.Name + } + fields = append(fields, field{name, tagged, index, ft}) + if count[f.typ] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, fields[len(fields)-1]) + } + continue + } + + // Record new anonymous struct to explore in next round. + nextCount[ft]++ + if nextCount[ft] == 1 { + f := field{name: ft.Name(), index: index, typ: ft} + next = append(next, f) + } + } + } + } + + sort.Sort(byName(fields)) + + // Delete all fields that are hidden by the Go rules for embedded fields, + // except that fields with TOML tags are promoted. + + // The fields are sorted in primary order of name, secondary order + // of field index length. Loop over names; for each name, delete + // hidden fields by choosing the one dominant field that survives. + out := fields[:0] + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of fields with the name of this first field. + fi := fields[i] + name := fi.name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.name != name { + break + } + } + if advance == 1 { // Only one field with this name + out = append(out, fi) + continue + } + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) + } + } + + fields = out + sort.Sort(byIndex(fields)) + + return fields +} + +// dominantField looks through the fields, all of which are known to +// have the same name, to find the single field that dominates the +// others using Go's embedding rules, modified by the presence of +// TOML tags. If there are multiple top-level fields, the boolean +// will be false: This condition is an error in Go and we skip all +// the fields. +func dominantField(fields []field) (field, bool) { + // The fields are sorted in increasing index-length order. The winner + // must therefore be one with the shortest index length. Drop all + // longer entries, which is easy: just truncate the slice. + length := len(fields[0].index) + tagged := -1 // Index of first tagged field. + for i, f := range fields { + if len(f.index) > length { + fields = fields[:i] + break + } + if f.tag { + if tagged >= 0 { + // Multiple tagged fields at the same level: conflict. + // Return no field. + return field{}, false + } + tagged = i + } + } + if tagged >= 0 { + return fields[tagged], true + } + // All remaining fields have the same length. If there's more than one, + // we have a conflict (two fields named "X" at the same level) and we + // return no field. + if len(fields) > 1 { + return field{}, false + } + return fields[0], true +} + +var fieldCache struct { + sync.RWMutex + m map[reflect.Type][]field +} + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) []field { + fieldCache.RLock() + f := fieldCache.m[t] + fieldCache.RUnlock() + if f != nil { + return f + } + + // Compute fields without lock. + // Might duplicate effort but won't hold other computations back. 
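+	// (Concurrent callers may both compute the same type's fields; the
+	// write under the lock below then just stores an identical slice.)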
+ f = typeFields(t) + if f == nil { + f = []field{} + } + + fieldCache.Lock() + if fieldCache.m == nil { + fieldCache.m = map[reflect.Type][]field{} + } + fieldCache.m[t] = f + fieldCache.Unlock() + return f +} diff --git a/vendor/modules.txt b/vendor/modules.txt index eccfc3a7b8..c65f8a9a76 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,3 +1,5 @@ +# github.com/BurntSushi/toml v0.3.1 +github.com/BurntSushi/toml # github.com/Microsoft/go-winio v0.4.17 github.com/Microsoft/go-winio github.com/Microsoft/go-winio/backuptar
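
For orientation, a minimal, self-contained sketch of how the vendored
BurntSushi/toml v0.3.1 API wired in above is typically consumed, exercising
the `toml` struct-tag options (name, omitempty, omitzero) implemented by
getOptions, typeFields, and the encoder's writeFields. The Policy struct and
its fields are hypothetical illustrations, not types from this repository.

    package main

    import (
        "bytes"
        "fmt"

        "github.com/BurntSushi/toml"
    )

    // Policy is a hypothetical stand-in for a TOML-backed config struct;
    // the tags exercise the options parsed by getOptions in encode.go.
    type Policy struct {
        AllowAll bool     `toml:"allow_all"`
        Commands []string `toml:"commands,omitempty"` // dropped when empty
        Retries  int      `toml:"retries,omitzero"`   // dropped when zero
    }

    func main() {
        src := `
    allow_all = false
    commands = ["/pause"]
    `
        var p Policy
        // toml.Decode drives the lexer/parser vendored above.
        if _, err := toml.Decode(src, &p); err != nil {
            panic(err)
        }

        // Encoding round-trips through typeFields/eStruct; fields tagged
        // omitempty/omitzero are skipped when empty or zero, so Retries
        // does not appear in the output here.
        var buf bytes.Buffer
        if err := toml.NewEncoder(&buf).Encode(p); err != nil {
            panic(err)
        }
        fmt.Print(buf.String())
    }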