From 6138eeca3ffa0155ac99abe8b6ba9a7a3e37ff18 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 12 Sep 2024 21:29:17 +0200 Subject: [PATCH] feat(worker): auto pull images if not found This commit ensures that the worker tries to pull the docker images if they are not found locally. --- worker/docker.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/worker/docker.go b/worker/docker.go index f42c0e49..e63cb5ec 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -2,8 +2,10 @@ package worker import ( "context" + "encoding/json" "errors" "fmt" + "io" "log/slog" "strings" "sync" @@ -15,6 +17,8 @@ import ( "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/mount" "github.com/docker/docker/client" + "github.com/docker/docker/errdefs" + "github.com/docker/docker/pkg/jsonmessage" "github.com/docker/go-connections/nat" ) @@ -159,6 +163,32 @@ func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID strin return err == nil } +// pullImage pulls the specified image from the registry. +func (m *DockerManager) pullImage(ctx context.Context, imageName string) error { + reader, err := m.dockerClient.ImagePull(ctx, imageName, types.ImagePullOptions{}) + if err != nil { + return fmt.Errorf("failed to pull image: %w", err) + } + defer reader.Close() + + // Show progress. + decoder := json.NewDecoder(reader) + for { + var progress jsonmessage.JSONMessage + if err := decoder.Decode(&progress); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("error decoding progress message: %w", err) + } + if progress.Status != "" && progress.Progress != nil { + fmt.Printf("\r%s: %s", progress.Status, progress.Progress.String()) + } + } + fmt.Println() + + return nil +} + func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) { gpu, err := m.allocGPU(ctx) if err != nil { @@ -222,9 +252,23 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo }, } + // Create container and pull image if not found locally. resp, err := m.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName) if err != nil { - return nil, err + if errdefs.IsNotFound(err) { + slog.Info("Image not found locally, pulling image", slog.String("image", containerImage)) + if err := m.pullImage(ctx, containerImage); err != nil { + return nil, err + } + + // Retry container creation after pulling the image. + resp, err = m.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName) + if err != nil { + return nil, fmt.Errorf("failed to create container after pulling image: %w", err) + } + } else { + return nil, err + } } cctx, cancel := context.WithTimeout(ctx, containerTimeout)