From 5787bb203d0906ca053188a10d2dcccca0f33a01 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 27 Jun 2023 10:01:07 +0800 Subject: [PATCH] add Suspend/Resume operation (#605) * add Suspend/Resume operation * add tests --- cluster/calcium/control.go | 42 +++++++++ cluster/calcium/control_test.go | 125 ++++++++++++++++++++++++++ cluster/cluster.go | 4 + engine/docker/container.go | 10 +++ engine/engine.go | 2 + engine/fake/fake.go | 10 +++ engine/mocks/API.go | 155 +++++++++++++++++++++++++------- engine/mocks/fakeengine/mock.go | 2 + engine/virt/virt.go | 12 +++ go.mod | 2 +- go.sum | 6 +- types/specs.go | 8 +- types/workload.go | 16 ++++ types/workload_test.go | 8 ++ 14 files changed, 362 insertions(+), 40 deletions(-) diff --git a/cluster/calcium/control.go b/cluster/calcium/control.go index 28feb0c1a..610bbb5d4 100644 --- a/cluster/calcium/control.go +++ b/cluster/calcium/control.go @@ -43,6 +43,12 @@ func (c *Calcium) ControlWorkload(ctx context.Context, IDs []string, typ string, startHook, err := c.doStartWorkload(ctx, workload, force) message = append(message, startHook...) return err + case cluster.WorkloadSuspend: + message, err = c.doSuspendWorkload(ctx, workload, force) + return err + case cluster.WorkloadResume: + message, err = c.doResumeWorkload(ctx, workload, force) + return err } return types.ErrInvaildControlType }) @@ -103,3 +109,39 @@ func (c *Calcium) doStopWorkload(ctx context.Context, workload *types.Workload, } return message, err } + +func (c *Calcium) doSuspendWorkload(ctx context.Context, workload *types.Workload, force bool) (message []*bytes.Buffer, err error) { + if workload.Hook != nil && len(workload.Hook.BeforeSuspend) > 0 { + message, err = c.doHook( + ctx, + workload.ID, workload.User, + workload.Hook.BeforeSuspend, workload.Env, + workload.Hook.Force, workload.Privileged, + force, workload.Engine, + ) + if err != nil { + return message, err + } + } + + if err = workload.Suspend(ctx); err != nil { + message = append(message, bytes.NewBufferString(err.Error())) + } + return message, err +} + +func (c *Calcium) doResumeWorkload(ctx context.Context, workload *types.Workload, force bool) (message []*bytes.Buffer, err error) { + if err = workload.Resume(ctx); err != nil { + return message, err + } + if workload.Hook != nil && len(workload.Hook.AfterResume) > 0 { + message, err = c.doHook( + ctx, + workload.ID, workload.User, + workload.Hook.AfterResume, workload.Env, + workload.Hook.Force, workload.Privileged, + force, workload.Engine, + ) + } + return message, err +} diff --git a/cluster/calcium/control_test.go b/cluster/calcium/control_test.go index c4e8960ef..b9923175b 100644 --- a/cluster/calcium/control_test.go +++ b/cluster/calcium/control_test.go @@ -178,3 +178,128 @@ func TestControlRestart(t *testing.T) { assert.NoError(t, r.Error) } } + +func TestControlSuspend(t *testing.T) { + c := NewTestCluster() + ctx := context.Background() + store := c.store.(*storemocks.Store) + lock := &lockmocks.DistributedLock{} + lock.On("Lock", mock.Anything).Return(ctx, nil) + lock.On("Unlock", mock.Anything).Return(nil) + store.On("CreateLock", mock.Anything, mock.Anything).Return(lock, nil) + workload := &types.Workload{ + ID: "id1", + Privileged: true, + } + engine := &enginemocks.API{} + workload.Engine = engine + store.On("GetWorkloads", mock.Anything, mock.Anything).Return([]*types.Workload{workload}, nil) + // failed, hook true, remove always false + hook := &types.Hook{ + BeforeSuspend: []string{"cmd1"}, + } + workload.Hook = hook + workload.Hook.Force = true + engine.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return("", nil, nil, nil, types.ErrNilEngine) + ch, err := c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadSuspend, false) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + // stop failed + workload.Hook.Force = false + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadSuspend, false) + engine.On("VirtualizationSuspend", mock.Anything, mock.Anything, mock.Anything).Return(types.ErrNilEngine).Once() + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + engine.On("VirtualizationSuspend", mock.Anything, mock.Anything, mock.Anything).Return(nil) + // stop success + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadSuspend, false) + assert.NoError(t, err) + for r := range ch { + assert.NoError(t, r.Error) + } +} + +func TestControlResume(t *testing.T) { + c := NewTestCluster() + ctx := context.Background() + store := c.store.(*storemocks.Store) + lock := &lockmocks.DistributedLock{} + lock.On("Lock", mock.Anything).Return(ctx, nil) + lock.On("Unlock", mock.Anything).Return(nil) + store.On("CreateLock", mock.Anything, mock.Anything).Return(lock, nil) + // failed by GetWorkloads + store.On("GetWorkloads", mock.Anything, mock.Anything).Return(nil, types.ErrMockError).Once() + ch, err := c.ControlWorkload(ctx, []string{"id1"}, "", true) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + workload := &types.Workload{ + ID: "id1", + Privileged: true, + } + engine := &enginemocks.API{} + workload.Engine = engine + store.On("GetWorkloads", mock.Anything, mock.Anything).Return([]*types.Workload{workload}, nil) + // failed by type + ch, err = c.ControlWorkload(ctx, []string{"id1"}, "", true) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + // failed by start + engine.On("VirtualizationResume", mock.Anything, mock.Anything).Return(types.ErrNilEngine).Once() + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + engine.On("VirtualizationResume", mock.Anything, mock.Anything).Return(nil) + // failed by Execute + hook := &types.Hook{ + AfterResume: []string{"cmd1", "cmd2"}, + } + workload.Hook = hook + workload.Hook.Force = false + engine.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return("", nil, nil, nil, types.ErrNilEngine).Times(3) + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.NoError(t, r.Error) + } + // force false, get no error + workload.Hook.Force = true + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + assert.Equal(t, r.WorkloadID, "id1") + } + data := io.NopCloser(bytes.NewBufferString("output")) + engine.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return("eid", data, nil, nil, nil).Times(4) + // failed by ExecExitCode + engine.On("ExecExitCode", mock.Anything, mock.Anything, mock.Anything).Return(-1, types.ErrNilEngine).Once() + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + // exitCode is not 0 + engine.On("ExecExitCode", mock.Anything, mock.Anything, mock.Anything).Return(-1, nil).Once() + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.Error(t, r.Error) + } + // exitCode is 0 + engine.On("ExecExitCode", mock.Anything, mock.Anything, mock.Anything).Return(0, nil) + ch, err = c.ControlWorkload(ctx, []string{"id1"}, cluster.WorkloadResume, false) + assert.NoError(t, err) + for r := range ch { + assert.NoError(t, r.Error) + } +} diff --git a/cluster/cluster.go b/cluster/cluster.go index 2e388eb36..8fe2c71c4 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -28,6 +28,10 @@ const ( WorkloadStart = "start" // WorkloadRestart for restart workload WorkloadRestart = "restart" + // WorkloadSuspend for suspending workload + WorkloadSuspend = "suspend" + // WorkloadResume for resuming workload + WorkloadResume = "resume" // WorkloadLock for lock workload WorkloadLock = "clock_%s" // PodLock for lock pod diff --git a/engine/docker/container.go b/engine/docker/container.go index c3a226229..0a14315f5 100644 --- a/engine/docker/container.go +++ b/engine/docker/container.go @@ -307,6 +307,16 @@ func (e *Engine) VirtualizationStop(ctx context.Context, ID string, gracefulTime return e.client.ContainerStop(ctx, ID, dockercontainer.StopOptions{Timeout: timeout}) } +// VirtualizationSuspend suspends virtualization +func (e *Engine) VirtualizationSuspend(context.Context, string) error { + return nil +} + +// VirtualizationResume resumes virtualization +func (e *Engine) VirtualizationResume(context.Context, string) error { + return nil +} + // VirtualizationRemove remove virtualization func (e *Engine) VirtualizationRemove(ctx context.Context, ID string, removeVolumes, force bool) error { if err := e.client.ContainerRemove(ctx, ID, dockertypes.ContainerRemoveOptions{RemoveVolumes: removeVolumes, Force: force}); err != nil { diff --git a/engine/engine.go b/engine/engine.go index 3b3a24f82..f5434343e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -43,6 +43,8 @@ type API interface { VirtualizationStart(ctx context.Context, ID string) error VirtualizationStop(ctx context.Context, ID string, gracefulTimeout time.Duration) error VirtualizationRemove(ctx context.Context, ID string, volumes, force bool) error + VirtualizationSuspend(ctx context.Context, ID string) error + VirtualizationResume(ctx context.Context, ID string) error VirtualizationInspect(ctx context.Context, ID string) (*enginetypes.VirtualizationInfo, error) VirtualizationLogs(ctx context.Context, opts *enginetypes.VirtualizationLogStreamOptions) (stdout, stderr io.ReadCloser, err error) VirtualizationAttach(ctx context.Context, ID string, stream, openStdin bool) (stdout, stderr io.ReadCloser, stdin io.WriteCloser, err error) diff --git a/engine/fake/fake.go b/engine/fake/fake.go index f75a1d042..1a55677e9 100644 --- a/engine/fake/fake.go +++ b/engine/fake/fake.go @@ -140,6 +140,16 @@ func (f *EngineWithErr) VirtualizationStop(context.Context, string, time.Duratio return f.DefaultErr } +// VirtualizationSuspend . +func (f *EngineWithErr) VirtualizationSuspend(context.Context, string) error { + return f.DefaultErr +} + +// VirtualizationResume . +func (f *EngineWithErr) VirtualizationResume(context.Context, string) error { + return f.DefaultErr +} + // VirtualizationRemove . func (f *EngineWithErr) VirtualizationRemove(context.Context, string, bool, bool) error { return f.DefaultErr diff --git a/engine/mocks/API.go b/engine/mocks/API.go index 20676e4b9..8ed7cbfd5 100644 --- a/engine/mocks/API.go +++ b/engine/mocks/API.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.26.1. DO NOT EDIT. package mocks @@ -28,13 +28,17 @@ func (_m *API) BuildContent(ctx context.Context, scm source.Source, opts *types. ret := _m.Called(ctx, scm, opts) var r0 string + var r1 io.Reader + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, source.Source, *types.BuildContentOptions) (string, io.Reader, error)); ok { + return rf(ctx, scm, opts) + } if rf, ok := ret.Get(0).(func(context.Context, source.Source, *types.BuildContentOptions) string); ok { r0 = rf(ctx, scm, opts) } else { r0 = ret.Get(0).(string) } - var r1 io.Reader if rf, ok := ret.Get(1).(func(context.Context, source.Source, *types.BuildContentOptions) io.Reader); ok { r1 = rf(ctx, scm, opts) } else { @@ -43,7 +47,6 @@ func (_m *API) BuildContent(ctx context.Context, scm source.Source, opts *types. } } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, source.Source, *types.BuildContentOptions) error); ok { r2 = rf(ctx, scm, opts) } else { @@ -88,13 +91,16 @@ func (_m *API) ExecExitCode(ctx context.Context, ID string, execID string) (int, ret := _m.Called(ctx, ID, execID) var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (int, error)); ok { + return rf(ctx, ID, execID) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) int); ok { r0 = rf(ctx, ID, execID) } else { r0 = ret.Get(0).(int) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, ID, execID) } else { @@ -123,13 +129,19 @@ func (_m *API) Execute(ctx context.Context, ID string, config *types.ExecConfig) ret := _m.Called(ctx, ID, config) var r0 string + var r1 io.ReadCloser + var r2 io.ReadCloser + var r3 io.WriteCloser + var r4 error + if rf, ok := ret.Get(0).(func(context.Context, string, *types.ExecConfig) (string, io.ReadCloser, io.ReadCloser, io.WriteCloser, error)); ok { + return rf(ctx, ID, config) + } if rf, ok := ret.Get(0).(func(context.Context, string, *types.ExecConfig) string); ok { r0 = rf(ctx, ID, config) } else { r0 = ret.Get(0).(string) } - var r1 io.ReadCloser if rf, ok := ret.Get(1).(func(context.Context, string, *types.ExecConfig) io.ReadCloser); ok { r1 = rf(ctx, ID, config) } else { @@ -138,7 +150,6 @@ func (_m *API) Execute(ctx context.Context, ID string, config *types.ExecConfig) } } - var r2 io.ReadCloser if rf, ok := ret.Get(2).(func(context.Context, string, *types.ExecConfig) io.ReadCloser); ok { r2 = rf(ctx, ID, config) } else { @@ -147,7 +158,6 @@ func (_m *API) Execute(ctx context.Context, ID string, config *types.ExecConfig) } } - var r3 io.WriteCloser if rf, ok := ret.Get(3).(func(context.Context, string, *types.ExecConfig) io.WriteCloser); ok { r3 = rf(ctx, ID, config) } else { @@ -156,7 +166,6 @@ func (_m *API) Execute(ctx context.Context, ID string, config *types.ExecConfig) } } - var r4 error if rf, ok := ret.Get(4).(func(context.Context, string, *types.ExecConfig) error); ok { r4 = rf(ctx, ID, config) } else { @@ -171,6 +180,10 @@ func (_m *API) ImageBuild(ctx context.Context, input io.Reader, refs []string, p ret := _m.Called(ctx, input, refs, platform) var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, io.Reader, []string, string) (io.ReadCloser, error)); ok { + return rf(ctx, input, refs, platform) + } if rf, ok := ret.Get(0).(func(context.Context, io.Reader, []string, string) io.ReadCloser); ok { r0 = rf(ctx, input, refs, platform) } else { @@ -179,7 +192,6 @@ func (_m *API) ImageBuild(ctx context.Context, input io.Reader, refs []string, p } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, io.Reader, []string, string) error); ok { r1 = rf(ctx, input, refs, platform) } else { @@ -194,13 +206,16 @@ func (_m *API) ImageBuildCachePrune(ctx context.Context, all bool) (uint64, erro ret := _m.Called(ctx, all) var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, bool) (uint64, error)); ok { + return rf(ctx, all) + } if rf, ok := ret.Get(0).(func(context.Context, bool) uint64); ok { r0 = rf(ctx, all) } else { r0 = ret.Get(0).(uint64) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, bool) error); ok { r1 = rf(ctx, all) } else { @@ -215,13 +230,16 @@ func (_m *API) ImageBuildFromExist(ctx context.Context, ID string, refs []string ret := _m.Called(ctx, ID, refs, user) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, []string, string) (string, error)); ok { + return rf(ctx, ID, refs, user) + } if rf, ok := ret.Get(0).(func(context.Context, string, []string, string) string); ok { r0 = rf(ctx, ID, refs, user) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, []string, string) error); ok { r1 = rf(ctx, ID, refs, user) } else { @@ -236,6 +254,10 @@ func (_m *API) ImageList(ctx context.Context, image string) ([]*types.Image, err ret := _m.Called(ctx, image) var r0 []*types.Image + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]*types.Image, error)); ok { + return rf(ctx, image) + } if rf, ok := ret.Get(0).(func(context.Context, string) []*types.Image); ok { r0 = rf(ctx, image) } else { @@ -244,7 +266,6 @@ func (_m *API) ImageList(ctx context.Context, image string) ([]*types.Image, err } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, image) } else { @@ -259,6 +280,10 @@ func (_m *API) ImageLocalDigests(ctx context.Context, image string) ([]string, e ret := _m.Called(ctx, image) var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { + return rf(ctx, image) + } if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { r0 = rf(ctx, image) } else { @@ -267,7 +292,6 @@ func (_m *API) ImageLocalDigests(ctx context.Context, image string) ([]string, e } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, image) } else { @@ -282,6 +306,10 @@ func (_m *API) ImagePull(ctx context.Context, ref string, all bool) (io.ReadClos ret := _m.Called(ctx, ref, all) var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) (io.ReadCloser, error)); ok { + return rf(ctx, ref, all) + } if rf, ok := ret.Get(0).(func(context.Context, string, bool) io.ReadCloser); ok { r0 = rf(ctx, ref, all) } else { @@ -290,7 +318,6 @@ func (_m *API) ImagePull(ctx context.Context, ref string, all bool) (io.ReadClos } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { r1 = rf(ctx, ref, all) } else { @@ -305,6 +332,10 @@ func (_m *API) ImagePush(ctx context.Context, ref string) (io.ReadCloser, error) ret := _m.Called(ctx, ref) var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (io.ReadCloser, error)); ok { + return rf(ctx, ref) + } if rf, ok := ret.Get(0).(func(context.Context, string) io.ReadCloser); ok { r0 = rf(ctx, ref) } else { @@ -313,7 +344,6 @@ func (_m *API) ImagePush(ctx context.Context, ref string) (io.ReadCloser, error) } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, ref) } else { @@ -328,13 +358,16 @@ func (_m *API) ImageRemoteDigest(ctx context.Context, image string) (string, err ret := _m.Called(ctx, image) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return rf(ctx, image) + } if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { r0 = rf(ctx, image) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, image) } else { @@ -349,6 +382,10 @@ func (_m *API) ImageRemove(ctx context.Context, image string, force bool, prune ret := _m.Called(ctx, image, force, prune) var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool, bool) ([]string, error)); ok { + return rf(ctx, image, force, prune) + } if rf, ok := ret.Get(0).(func(context.Context, string, bool, bool) []string); ok { r0 = rf(ctx, image, force, prune) } else { @@ -357,7 +394,6 @@ func (_m *API) ImageRemove(ctx context.Context, image string, force bool, prune } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, bool, bool) error); ok { r1 = rf(ctx, image, force, prune) } else { @@ -386,6 +422,10 @@ func (_m *API) Info(ctx context.Context) (*types.Info, error) { ret := _m.Called(ctx) var r0 *types.Info + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*types.Info, error)); ok { + return rf(ctx) + } if rf, ok := ret.Get(0).(func(context.Context) *types.Info); ok { r0 = rf(ctx) } else { @@ -394,7 +434,6 @@ func (_m *API) Info(ctx context.Context) (*types.Info, error) { } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { @@ -409,6 +448,10 @@ func (_m *API) NetworkConnect(ctx context.Context, network string, target string ret := _m.Called(ctx, network, target, ipv4, ipv6) var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) ([]string, error)); ok { + return rf(ctx, network, target, ipv4, ipv6) + } if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) []string); ok { r0 = rf(ctx, network, target, ipv4, ipv6) } else { @@ -417,7 +460,6 @@ func (_m *API) NetworkConnect(ctx context.Context, network string, target string } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string) error); ok { r1 = rf(ctx, network, target, ipv4, ipv6) } else { @@ -446,6 +488,10 @@ func (_m *API) NetworkList(ctx context.Context, drivers []string) ([]*types.Netw ret := _m.Called(ctx, drivers) var r0 []*types.Network + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) ([]*types.Network, error)); ok { + return rf(ctx, drivers) + } if rf, ok := ret.Get(0).(func(context.Context, []string) []*types.Network); ok { r0 = rf(ctx, drivers) } else { @@ -454,7 +500,6 @@ func (_m *API) NetworkList(ctx context.Context, drivers []string) ([]*types.Netw } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { r1 = rf(ctx, drivers) } else { @@ -483,6 +528,12 @@ func (_m *API) VirtualizationAttach(ctx context.Context, ID string, stream bool, ret := _m.Called(ctx, ID, stream, openStdin) var r0 io.ReadCloser + var r1 io.ReadCloser + var r2 io.WriteCloser + var r3 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool, bool) (io.ReadCloser, io.ReadCloser, io.WriteCloser, error)); ok { + return rf(ctx, ID, stream, openStdin) + } if rf, ok := ret.Get(0).(func(context.Context, string, bool, bool) io.ReadCloser); ok { r0 = rf(ctx, ID, stream, openStdin) } else { @@ -491,7 +542,6 @@ func (_m *API) VirtualizationAttach(ctx context.Context, ID string, stream bool, } } - var r1 io.ReadCloser if rf, ok := ret.Get(1).(func(context.Context, string, bool, bool) io.ReadCloser); ok { r1 = rf(ctx, ID, stream, openStdin) } else { @@ -500,7 +550,6 @@ func (_m *API) VirtualizationAttach(ctx context.Context, ID string, stream bool, } } - var r2 io.WriteCloser if rf, ok := ret.Get(2).(func(context.Context, string, bool, bool) io.WriteCloser); ok { r2 = rf(ctx, ID, stream, openStdin) } else { @@ -509,7 +558,6 @@ func (_m *API) VirtualizationAttach(ctx context.Context, ID string, stream bool, } } - var r3 error if rf, ok := ret.Get(3).(func(context.Context, string, bool, bool) error); ok { r3 = rf(ctx, ID, stream, openStdin) } else { @@ -524,6 +572,13 @@ func (_m *API) VirtualizationCopyFrom(ctx context.Context, ID string, path strin ret := _m.Called(ctx, ID, path) var r0 []byte + var r1 int + var r2 int + var r3 int64 + var r4 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]byte, int, int, int64, error)); ok { + return rf(ctx, ID, path) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) []byte); ok { r0 = rf(ctx, ID, path) } else { @@ -532,28 +587,24 @@ func (_m *API) VirtualizationCopyFrom(ctx context.Context, ID string, path strin } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, string, string) int); ok { r1 = rf(ctx, ID, path) } else { r1 = ret.Get(1).(int) } - var r2 int if rf, ok := ret.Get(2).(func(context.Context, string, string) int); ok { r2 = rf(ctx, ID, path) } else { r2 = ret.Get(2).(int) } - var r3 int64 if rf, ok := ret.Get(3).(func(context.Context, string, string) int64); ok { r3 = rf(ctx, ID, path) } else { r3 = ret.Get(3).(int64) } - var r4 error if rf, ok := ret.Get(4).(func(context.Context, string, string) error); ok { r4 = rf(ctx, ID, path) } else { @@ -582,6 +633,10 @@ func (_m *API) VirtualizationCreate(ctx context.Context, opts *types.Virtualizat ret := _m.Called(ctx, opts) var r0 *types.VirtualizationCreated + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.VirtualizationCreateOptions) (*types.VirtualizationCreated, error)); ok { + return rf(ctx, opts) + } if rf, ok := ret.Get(0).(func(context.Context, *types.VirtualizationCreateOptions) *types.VirtualizationCreated); ok { r0 = rf(ctx, opts) } else { @@ -590,7 +645,6 @@ func (_m *API) VirtualizationCreate(ctx context.Context, opts *types.Virtualizat } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, *types.VirtualizationCreateOptions) error); ok { r1 = rf(ctx, opts) } else { @@ -605,6 +659,10 @@ func (_m *API) VirtualizationInspect(ctx context.Context, ID string) (*types.Vir ret := _m.Called(ctx, ID) var r0 *types.VirtualizationInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*types.VirtualizationInfo, error)); ok { + return rf(ctx, ID) + } if rf, ok := ret.Get(0).(func(context.Context, string) *types.VirtualizationInfo); ok { r0 = rf(ctx, ID) } else { @@ -613,7 +671,6 @@ func (_m *API) VirtualizationInspect(ctx context.Context, ID string) (*types.Vir } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, ID) } else { @@ -628,6 +685,11 @@ func (_m *API) VirtualizationLogs(ctx context.Context, opts *types.Virtualizatio ret := _m.Called(ctx, opts) var r0 io.ReadCloser + var r1 io.ReadCloser + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *types.VirtualizationLogStreamOptions) (io.ReadCloser, io.ReadCloser, error)); ok { + return rf(ctx, opts) + } if rf, ok := ret.Get(0).(func(context.Context, *types.VirtualizationLogStreamOptions) io.ReadCloser); ok { r0 = rf(ctx, opts) } else { @@ -636,7 +698,6 @@ func (_m *API) VirtualizationLogs(ctx context.Context, opts *types.Virtualizatio } } - var r1 io.ReadCloser if rf, ok := ret.Get(1).(func(context.Context, *types.VirtualizationLogStreamOptions) io.ReadCloser); ok { r1 = rf(ctx, opts) } else { @@ -645,7 +706,6 @@ func (_m *API) VirtualizationLogs(ctx context.Context, opts *types.Virtualizatio } } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, *types.VirtualizationLogStreamOptions) error); ok { r2 = rf(ctx, opts) } else { @@ -683,6 +743,20 @@ func (_m *API) VirtualizationResize(ctx context.Context, ID string, height uint, return r0 } +// VirtualizationResume provides a mock function with given fields: ctx, ID +func (_m *API) VirtualizationResume(ctx context.Context, ID string) error { + ret := _m.Called(ctx, ID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, ID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // VirtualizationStart provides a mock function with given fields: ctx, ID func (_m *API) VirtualizationStart(ctx context.Context, ID string) error { ret := _m.Called(ctx, ID) @@ -711,6 +785,20 @@ func (_m *API) VirtualizationStop(ctx context.Context, ID string, gracefulTimeou return r0 } +// VirtualizationSuspend provides a mock function with given fields: ctx, ID +func (_m *API) VirtualizationSuspend(ctx context.Context, ID string) error { + ret := _m.Called(ctx, ID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, ID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // VirtualizationUpdateResource provides a mock function with given fields: ctx, ID, params func (_m *API) VirtualizationUpdateResource(ctx context.Context, ID string, params resourcetypes.Resources) error { ret := _m.Called(ctx, ID, params) @@ -730,6 +818,10 @@ func (_m *API) VirtualizationWait(ctx context.Context, ID string, state string) ret := _m.Called(ctx, ID, state) var r0 *types.VirtualizationWaitResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*types.VirtualizationWaitResult, error)); ok { + return rf(ctx, ID, state) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *types.VirtualizationWaitResult); ok { r0 = rf(ctx, ID, state) } else { @@ -738,7 +830,6 @@ func (_m *API) VirtualizationWait(ctx context.Context, ID string, state string) } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, ID, state) } else { diff --git a/engine/mocks/fakeengine/mock.go b/engine/mocks/fakeengine/mock.go index fe677f09b..20096e3c3 100644 --- a/engine/mocks/fakeengine/mock.go +++ b/engine/mocks/fakeengine/mock.go @@ -121,6 +121,8 @@ func MakeClient(_ context.Context, _ coretypes.Config, _, _, _, _, _ string) (en e.On("VirtualizationStart", mock.Anything, mock.Anything).Return(nil) e.On("VirtualizationStop", mock.Anything, mock.Anything, mock.Anything).Return(nil) e.On("VirtualizationRemove", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + e.On("VirtualizationSuspend", mock.Anything, mock.Anything).Return(nil) + e.On("VirtualizationResume", mock.Anything, mock.Anything).Return(nil) vcJSON := &enginetypes.VirtualizationInfo{ID: ID, Image: "mock-image", Running: true, Networks: map[string]string{"mock-network": "1.1.1.1"}} e.On("VirtualizationInspect", mock.Anything, mock.Anything).Return(vcJSON, nil) logs := io.NopCloser(bytes.NewBufferString("logs1...\nlogs2...\n")) diff --git a/engine/virt/virt.go b/engine/virt/virt.go index 33b6b61ca..c752b79a1 100644 --- a/engine/virt/virt.go +++ b/engine/virt/virt.go @@ -257,6 +257,18 @@ func (v *Virt) VirtualizationRemove(ctx context.Context, ID string, _, force boo return } +// VirtualizationSuspend suspends a guest. +func (v *Virt) VirtualizationSuspend(ctx context.Context, ID string) (err error) { + _, err = v.client.SuspendGuest(ctx, ID) + return +} + +// VirtualizationResume resumes a guest. +func (v *Virt) VirtualizationResume(ctx context.Context, ID string) (err error) { + _, err = v.client.ResumeGuest(ctx, ID) + return +} + // VirtualizationInspect gets a guest. func (v *Virt) VirtualizationInspect(ctx context.Context, ID string) (*enginetypes.VirtualizationInfo, error) { guest, err := v.client.GetGuest(ctx, ID) diff --git a/go.mod b/go.mod index 3d58ca7d3..f1393422b 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/opencontainers/image-spec v1.1.0-rc2.0.20221005185240-3a7f492d3f1b github.com/panjf2000/ants/v2 v2.7.3 github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/projecteru2/libyavirt v0.0.0-20230608083732-6473d0aff88b + github.com/projecteru2/libyavirt v0.0.0-20230621055438-179374d3115d github.com/prometheus/client_golang v1.15.0 github.com/rs/zerolog v1.29.1 github.com/sanity-io/litter v1.5.5 diff --git a/go.sum b/go.sum index e3c03da38..d69e71720 100644 --- a/go.sum +++ b/go.sum @@ -459,10 +459,8 @@ github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= -github.com/projecteru2/libyavirt v0.0.0-20230524090109-0faf050e0f3b h1:mXvbNYdr2uh2mhk5HdiBBSc9DhaR2RuulURaXhJaP2I= -github.com/projecteru2/libyavirt v0.0.0-20230524090109-0faf050e0f3b/go.mod h1:N41KaKmqbailweGs4x/mt2H0O0Y7MizObZQ+igLdzpw= -github.com/projecteru2/libyavirt v0.0.0-20230608083732-6473d0aff88b h1:mLdupCVfmXXGpaVW4QFblvsNJjYqhDJCM6LmEWnxLDE= -github.com/projecteru2/libyavirt v0.0.0-20230608083732-6473d0aff88b/go.mod h1:N41KaKmqbailweGs4x/mt2H0O0Y7MizObZQ+igLdzpw= +github.com/projecteru2/libyavirt v0.0.0-20230621055438-179374d3115d h1:L+mU66mq9qBDFyFGs3O0KDqo1ZzCTu7ZfRvQ6j/hDRY= +github.com/projecteru2/libyavirt v0.0.0-20230621055438-179374d3115d/go.mod h1:N41KaKmqbailweGs4x/mt2H0O0Y7MizObZQ+igLdzpw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= diff --git a/types/specs.go b/types/specs.go index b6b01db04..1e1a38c88 100644 --- a/types/specs.go +++ b/types/specs.go @@ -6,9 +6,11 @@ import ( // Hook define hooks type Hook struct { - AfterStart []string `yaml:"after_start,omitempty"` - BeforeStop []string `yaml:"before_stop,omitempty"` - Force bool `yaml:"force,omitempty"` + AfterStart []string `yaml:"after_start,omitempty"` + BeforeStop []string `yaml:"before_stop,omitempty"` + AfterResume []string `yaml:"after_resume,omitempty"` + BeforeSuspend []string `yaml:"before_suspend,omitempty"` + Force bool `yaml:"force,omitempty"` } // HealthCheck define healthcheck diff --git a/types/workload.go b/types/workload.go index 4831e4c0d..6347675d5 100644 --- a/types/workload.go +++ b/types/workload.go @@ -82,6 +82,22 @@ func (c *Workload) Stop(ctx context.Context, force bool) error { return c.Engine.VirtualizationStop(ctx, c.ID, gracefulTimeout) } +// Suspend a workload +func (c *Workload) Suspend(ctx context.Context) error { + if c.Engine == nil { + return ErrNilEngine + } + return c.Engine.VirtualizationSuspend(ctx, c.ID) +} + +// Resume a workload +func (c *Workload) Resume(ctx context.Context) error { + if c.Engine == nil { + return ErrNilEngine + } + return c.Engine.VirtualizationResume(ctx, c.ID) +} + // Remove a workload func (c *Workload) Remove(ctx context.Context, force bool) (err error) { if c.Engine == nil { diff --git a/types/workload_test.go b/types/workload_test.go index 9fa0e7474..92d663f1e 100644 --- a/types/workload_test.go +++ b/types/workload_test.go @@ -30,12 +30,16 @@ func TestWorkloadControl(t *testing.T) { mockEngine.On("VirtualizationStart", mock.Anything, mock.Anything).Return(nil) mockEngine.On("VirtualizationStop", mock.Anything, mock.Anything, mock.Anything).Return(nil) mockEngine.On("VirtualizationRemove", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockEngine.On("VirtualizationSuspend", mock.Anything, mock.Anything).Return(nil) + mockEngine.On("VirtualizationResume", mock.Anything, mock.Anything).Return(nil) ctx := context.Background() c := Workload{} assert.Error(t, c.Start(ctx)) assert.Error(t, c.Stop(ctx, true)) assert.Error(t, c.Remove(ctx, true)) + assert.Error(t, c.Suspend(ctx)) + assert.Error(t, c.Resume(ctx)) c.Engine = mockEngine err := c.Start(ctx) @@ -44,4 +48,8 @@ func TestWorkloadControl(t *testing.T) { assert.NoError(t, err) err = c.Remove(ctx, true) assert.NoError(t, err) + err = c.Suspend(ctx) + assert.NoError(t, err) + err = c.Resume(ctx) + assert.NoError(t, err) }