Skip to content

Commit

Permalink
feat(worker): auto pull images if not found
Browse files Browse the repository at this point in the history
This commit ensures that the worker tries to pull the docker images if
they are not found locally.
  • Loading branch information
rickstaa committed Sep 12, 2024
1 parent 36d796d commit 6138eec
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package worker

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"strings"
"sync"
Expand All @@ -15,6 +17,8 @@ import (
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/client"
"github.com/docker/docker/errdefs"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/go-connections/nat"
)

Expand Down Expand Up @@ -159,6 +163,32 @@ func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID strin
return err == nil
}

// pullImage pulls the specified image from the registry.
func (m *DockerManager) pullImage(ctx context.Context, imageName string) error {
reader, err := m.dockerClient.ImagePull(ctx, imageName, types.ImagePullOptions{})
if err != nil {
return fmt.Errorf("failed to pull image: %w", err)
}
defer reader.Close()

// Show progress.
decoder := json.NewDecoder(reader)
for {
var progress jsonmessage.JSONMessage
if err := decoder.Decode(&progress); err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("error decoding progress message: %w", err)
}
if progress.Status != "" && progress.Progress != nil {
fmt.Printf("\r%s: %s", progress.Status, progress.Progress.String())
}
}
fmt.Println()

return nil
}

func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
gpu, err := m.allocGPU(ctx)
if err != nil {
Expand Down Expand Up @@ -222,9 +252,23 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
},
}

// Create container and pull image if not found locally.
resp, err := m.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName)
if err != nil {
return nil, err
if errdefs.IsNotFound(err) {
slog.Info("Image not found locally, pulling image", slog.String("image", containerImage))
if err := m.pullImage(ctx, containerImage); err != nil {
return nil, err
}

// Retry container creation after pulling the image.
resp, err = m.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName)
if err != nil {
return nil, fmt.Errorf("failed to create container after pulling image: %w", err)
}
} else {
return nil, err
}
}

cctx, cancel := context.WithTimeout(ctx, containerTimeout)
Expand Down

0 comments on commit 6138eec

Please sign in to comment.