Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(worker): auto pull images if not found #200

Merged
merged 15 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ require (
github.com/getkin/kin-openapi v0.128.0
github.com/go-chi/chi/v5 v5.1.0
github.com/oapi-codegen/runtime v1.1.1
github.com/opencontainers/image-spec v1.1.0
github.com/stretchr/testify v1.9.0
github.com/vincent-petithory/dataurl v1.0.0
)

require (
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/containerd/log v0.1.0 // indirect
Expand All @@ -36,7 +38,6 @@ require (
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
7 changes: 5 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
Expand All @@ -10,6 +10,8 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -134,6 +136,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
Expand Down
5 changes: 4 additions & 1 deletion worker/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type RunnerContainerConfig struct {
containerTimeout time.Duration
}

// runnerWaitUntilReadyFunc is a package-level reference to runnerWaitUntilReady.
// NewRunnerContainer calls the runner readiness check through this variable
// rather than directly, so tests can replace it with a mock.
var runnerWaitUntilReadyFunc = runnerWaitUntilReady

func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name string) (*RunnerContainer, error) {
// Ensure that timeout is set to a non-zero value.
timeout := cfg.containerTimeout
Expand All @@ -63,7 +66,7 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name str
}

cctx, cancel := context.WithTimeout(ctx, cfg.containerTimeout)
if err := runnerWaitUntilReady(cctx, client, pollingInterval); err != nil {
if err := runnerWaitUntilReadyFunc(cctx, client, pollingInterval); err != nil {
cancel()
return nil, err
}
Expand Down
149 changes: 125 additions & 24 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@ package worker

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"

"github.com/docker/cli/opts"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/api/types/network"
docker "github.com/docker/docker/client"
"github.com/docker/docker/errdefs"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/go-connections/nat"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)

const containerModelDir = "/models"
Expand Down Expand Up @@ -57,41 +64,76 @@ var livePipelineToImage = map[string]string{
"noop": "livepeer/ai-runner:live-app-noop",
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// It is satisfied by *docker.Client, and each method mirrors the docker.Client
// method of the same name and signature.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
	// Container methods.
	ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
	ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error)
	// ContainerList is used by removeExistingContainers to find containers
	// carrying this package's creator label.
	ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error)
	ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
	ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
	ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error

	// Image methods.
	// ImageInspectWithRaw is used by isImageAvailable; a nil error is treated
	// as "image is available locally".
	ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error)
	// ImagePull is used by pullImage; the returned reader is decoded as a
	// stream of JSON progress messages and must be closed by the caller.
	ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error)
}

// Compile-time assertion to ensure docker.Client implements DockerClient.
// The build fails here if a method signature in DockerClient drifts from
// the one on *docker.Client.
var _ DockerClient = (*docker.Client)(nil)

// dockerWaitUntilRunningFunc is a package-level reference to
// dockerWaitUntilRunning. createContainer calls the wait through this
// variable rather than directly, so tests can replace it with a mock.
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning

type DockerManager struct {
defaultImage string
gpus []string
modelDir string

dockerClient *docker.Client
dockerClient DockerClient
// gpu ID => container name
gpuContainers map[string]string
// container name => container
containers map[string]*RunnerContainer
mu *sync.Mutex
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string) (*DockerManager, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, dockerClient); err != nil {
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
return nil, err
}
cancel()

return &DockerManager{
manager := &DockerManager{
defaultImage: defaultImage,
gpus: gpus,
modelDir: modelDir,
dockerClient: dockerClient,
dockerClient: client,
gpuContainers: make(map[string]string),
containers: make(map[string]*RunnerContainer),
mu: &sync.Mutex{},
}, nil
}

return manager, nil
}

// EnsureImageAvailable makes sure the container image that backs the given
// pipeline and model ID exists locally, pulling it when it does not.
//
// It returns the error from resolving the image name (for example an unknown
// live pipeline) or from the pull itself, and nil when the image is already
// present or was pulled successfully.
func (m *DockerManager) EnsureImageAvailable(ctx context.Context, pipeline string, modelID string) error {
	ref, resolveErr := m.getContainerImageName(pipeline, modelID)
	if resolveErr != nil {
		return resolveErr
	}

	// Nothing to do when the image is already present locally.
	if m.isImageAvailable(ctx, pipeline, modelID) {
		return nil
	}

	slog.Info(fmt.Sprintf("Pulling image for pipeline %s and modelID %s: %s", pipeline, modelID, ref))
	return m.pullImage(ctx, ref)
}

func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string, optimizationFlags OptimizationFlags) error {
Expand Down Expand Up @@ -157,6 +199,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
m.containers[rc.Name] = rc
}

// getContainerImageName resolves the container image for a pipeline/model pair.
//
// For every pipeline other than "live-video-to-video" the pipeline-specific
// image from pipelineToImage is used when one is registered, otherwise the
// manager's default image. For "live-video-to-video" the image is looked up in
// livePipelineToImage and an error is returned when no entry exists.
func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
	if pipeline != "live-video-to-video" {
		name, found := pipelineToImage[pipeline]
		if !found {
			return m.defaultImage, nil
		}
		return name, nil
	}

	// We currently use the model ID as the live pipeline name for legacy reasons.
	name, found := livePipelineToImage[modelID]
	if !found {
		return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
	}
	return name, nil
}

// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID string) bool {
m.mu.Lock()
Expand All @@ -169,11 +229,58 @@ func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID strin
}
}

// TODO: This can be removed if we optimize the selection algorithm.
// Currently, using CreateContainer errors only can cause orchestrator reselection.
if !m.isImageAvailable(ctx, pipeline, modelID) {
return false
}

// Check for available GPU to allocate for a new container for the requested model.
_, err := m.allocGPU(ctx)
return err == nil
}

// isImageAvailable reports whether the image for the given pipeline and model
// ID can be inspected through the Docker client, i.e. is available locally.
//
// It returns false when the image name cannot be resolved or when the inspect
// call fails for any reason. The inspect error is attached to the log record
// so that a daemon or connection failure can be told apart from an image that
// is simply missing.
func (m *DockerManager) isImageAvailable(ctx context.Context, pipeline string, modelID string) bool {
	imageName, err := m.getContainerImageName(pipeline, modelID)
	if err != nil {
		slog.Error(err.Error())
		return false
	}

	if _, _, err := m.dockerClient.ImageInspectWithRaw(ctx, imageName); err != nil {
		slog.Error(
			fmt.Sprintf("Image for pipeline %s and modelID %s is not available locally: %s", pipeline, modelID, imageName),
			slog.String("err", err.Error()),
		)
		return false
	}
	return true
}

// pullImage pulls the specified image from the registry.
func (m *DockerManager) pullImage(ctx context.Context, imageName string) error {
reader, err := m.dockerClient.ImagePull(ctx, imageName, image.PullOptions{})
if err != nil {
return fmt.Errorf("failed to pull image: %w", err)
}
defer reader.Close()

// Show progress.
decoder := json.NewDecoder(reader)
rickstaa marked this conversation as resolved.
Show resolved Hide resolved
for {
var progress jsonmessage.JSONMessage
if err := decoder.Decode(&progress); err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("error decoding progress message: %w", err)
}
if progress.Status != "" && progress.Progress != nil {
fmt.Printf("\r%s: %s", progress.Status, progress.Progress.String())
}
}
fmt.Println()
rickstaa marked this conversation as resolved.
Show resolved Hide resolved

return nil
}

func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
gpu, err := m.allocGPU(ctx)
if err != nil {
Expand All @@ -183,15 +290,9 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
// NOTE: We currently allow only one container per GPU for each pipeline.
containerHostPort := containerHostPorts[pipeline][:3] + gpu
containerName := dockerContainerName(pipeline, modelID, containerHostPort)
containerImage := m.defaultImage
if pipelineSpecificImage, ok := pipelineToImage[pipeline]; ok {
containerImage = pipelineSpecificImage
} else if pipeline == "live-video-to-video" {
// We currently use the model ID as the live pipeline name for legacy reasons
containerImage = livePipelineToImage[modelID]
if containerImage == "" {
return nil, fmt.Errorf("no container image found for live pipeline %s", modelID)
}
containerImage, err := m.getContainerImageName(pipeline, modelID)
if err != nil {
return nil, err
}

slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID), slog.String("containerImage", containerImage))
Expand Down Expand Up @@ -258,7 +359,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
cancel()

cctx, cancel = context.WithTimeout(ctx, containerTimeout)
if err := dockerWaitUntilRunning(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil {
if err := dockerWaitUntilRunningFunc(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil {
cancel()
dockerRemoveContainer(m.dockerClient, resp.ID)
return nil, err
Expand Down Expand Up @@ -390,7 +491,7 @@ func (m *DockerManager) watchContainer(rc *RunnerContainer, borrowCtx context.Co
}
}

func removeExistingContainers(ctx context.Context, client *docker.Client) error {
func removeExistingContainers(ctx context.Context, client DockerClient) error {
filters := filters.NewArgs(filters.Arg("label", containerCreatorLabel+"="+containerCreator))
containers, err := client.ContainerList(ctx, container.ListOptions{All: true, Filters: filters})
if err != nil {
Expand All @@ -416,7 +517,7 @@ func dockerContainerName(pipeline string, modelID string, suffix ...string) stri
return fmt.Sprintf("%s_%s", pipeline, sanitizedModelID)
}

func dockerRemoveContainer(client *docker.Client, containerID string) error {
func dockerRemoveContainer(client DockerClient, containerID string) error {
ctx, cancel := context.WithTimeout(context.Background(), containerRemoveTimeout)
defer cancel()

Expand Down Expand Up @@ -449,7 +550,7 @@ func dockerRemoveContainer(client *docker.Client, containerID string) error {
}
}

func dockerWaitUntilRunning(ctx context.Context, client *docker.Client, containerID string, pollingInterval time.Duration) error {
func dockerWaitUntilRunning(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error {
ticker := time.NewTicker(pollingInterval)
defer ticker.Stop()

Expand Down
Loading