Skip to content

Commit

Permalink
refactor: enhance tests and improve overridePipelineImages structure
Browse files Browse the repository at this point in the history
This commit expands the test coverage to ensure more robust behavior and
refactors the `overridePipelineImages` function to improve error handling and
readability.
  • Loading branch information
rickstaa committed Jan 22, 2025
1 parent 3f76bd1 commit 3e48a78
Show file tree
Hide file tree
Showing 3 changed files with 411 additions and 152 deletions.
108 changes: 78 additions & 30 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"io"
"log/slog"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -35,7 +37,9 @@ const containerRemoveTimeout = 30 * time.Second
const containerCreatorLabel = "creator"
const containerCreator = "ai-worker"

// Tunables for the container watch routine.
var (
	// containerWatchInterval is the period between two health checks of a
	// watched container; the same duration bounds each health request.
	containerWatchInterval = 5 * time.Second

	// pipelineStartGracePeriod is how long after the watch begins a container
	// is neither destroyed for failed health checks nor returned to the pool
	// for reporting an idle status.
	pipelineStartGracePeriod = 60 * time.Second

	// maxHealthCheckFailures is the number of failed health checks (the count
	// is reset by a healthy response) at which the container is destroyed once
	// the grace period has elapsed.
	maxHealthCheckFailures = 2
)

// This only works right now on a single GPU because if there is another container
// using the GPU we stop it so we don't have to worry about having enough ports
Expand All @@ -58,41 +62,49 @@ var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}
// livePipelineToImage maps each live pipeline name to its default runner
// image (the "live-app-*" tags of livepeer/ai-runner).
// overridePipelineImages replaces an entry here when the overridden pipeline
// name is not a key of pipelineToImage but is a key of this map.
// Note that keys here use underscores ("segment_anything_2"), whereas
// pipelineToImage uses hyphens ("segment-anything-2"); the two are distinct keys.
var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"liveportrait": "livepeer/ai-runner:live-app-liveportrait",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

// overridePipelineImages updates base and pipeline images with the provided overrides.
func overridePipelineImages(imageOverrides string) error {
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err != nil {
// If not JSON, treat it as a single image string to set the base image.
defaultBaseImage = imageOverrides
return nil
if imageOverrides == "" {
return fmt.Errorf("empty string is not a valid image override")
}

// Successfully parsed JSON, update the mappings.
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}
// Handle JSON format for multiple pipeline images.
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err == nil {
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}

// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
// If the pipeline is not found in the map, throw an error.
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
}
}
return nil
}

// Check for invalid docker image string.
if strings.ContainsAny(imageOverrides, "{}[]\",") {
return fmt.Errorf("invalid JSON format for image overrides")
}

// Update the base image.
defaultBaseImage = imageOverrides
return nil
}

Expand Down Expand Up @@ -502,35 +514,71 @@ func (m *DockerManager) watchContainer(rc *RunnerContainer, borrowCtx context.Co
if r := recover(); r != nil {
slog.Error("Panic in container watch routine",
slog.String("container", rc.Name),
slog.Any("panic", r))
slog.Any("panic", r),
slog.String("stack", string(debug.Stack())))
}
}()

ticker := time.NewTicker(containerWatchInterval)
defer ticker.Stop()

slog.Info("Watching container", slog.String("container", rc.Name))
failures := 0
startTime := time.Now()
for {
if failures >= maxHealthCheckFailures && time.Since(startTime) > pipelineStartGracePeriod {
slog.Error("Container health check failed too many times", slog.String("container", rc.Name))
m.destroyContainer(rc, false)
return
}

select {
case <-borrowCtx.Done():
m.returnContainer(rc)
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), containerWatchInterval)
container, err := m.dockerClient.ContainerInspect(ctx, rc.ID)
health, err := rc.Client.HealthWithResponse(ctx)
cancel()

if docker.IsErrNotFound(err) {
// skip to destroy below to update internal state
} else if err != nil {
slog.Error("Error inspecting container",
if err != nil {
failures++
slog.Error("Error getting health for runner",
slog.String("container", rc.Name),
slog.String("error", err.Error()))
continue
} else if container.State.Running {
} else if health.StatusCode() != 200 {
failures++
slog.Error("Container health check failed with HTTP status code",
slog.String("container", rc.Name),
slog.Int("status_code", health.StatusCode()),
slog.String("body", string(health.Body)))
continue
}
m.destroyContainer(rc, false)
return
slog.Debug("Health check response",
slog.String("status", health.Status()),
slog.Any("JSON200", health.JSON200),
slog.String("body", string(health.Body)))

status := health.JSON200.Status
switch status {
case IDLE:
if time.Since(startTime) > pipelineStartGracePeriod {
slog.Info("Container is idle, returning to pool", slog.String("container", rc.Name))
m.returnContainer(rc)
return
}
fallthrough
case OK:
failures = 0
continue
default:
failures++
slog.Error("Container not healthy",
slog.String("container", rc.Name),
slog.String("status", string(status)),
slog.String("failures", strconv.Itoa(failures)))
}
}
}
}
Expand Down
Loading

0 comments on commit 3e48a78

Please sign in to comment.