Skip to content

Commit

Permalink
quick commit to show to victor
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed Jan 21, 2025
1 parent ddfb96b commit 3f76bd1
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 118 deletions.
82 changes: 38 additions & 44 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ var containerHostPorts = map[string]string{
}

// Mapping for per pipeline container images.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text"
"audio-to-text": "livepeer/ai-runner:audio-to-text",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"liveportrait": "livepeer/ai-runner:live-app-liveportrait",
Expand All @@ -67,6 +67,35 @@ var livePipelineToImage = map[string]string{
"noop": "livepeer/ai-runner:live-app-noop",
}

// overridePipelineImages updates base and pipeline images with the provided overrides.
func overridePipelineImages(imageOverrides string) error {
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err != nil {
// If not JSON, treat it as a single image string to set the base image.
defaultBaseImage = imageOverrides
return nil
}

// Successfully parsed JSON, update the mappings.
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}

// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
// If the pipeline is not found in the map, throw an error.
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
}
}
return nil
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand Down Expand Up @@ -99,58 +128,23 @@ type DockerManager struct {
mu *sync.Mutex
}

// updatePipelineMappings updates the specified mapping with pipeline to image overriding.
// It logs a warning if a pipeline is not found in the given mapping.
//
// Parameters:
// - overrides: A map of pipeline names to custom image names.
// - mapping: The map to be updated with the provided overrides.
// - mapName: The name of the map (used for logging purposes).
func updatePipelineMappings(overrides map[string]string, mapping map[string]string, mapName string) {
for pipeline, image := range overrides {
if _, exists := mapping[pipeline]; exists {
mapping[pipeline] = image
} else {
slog.Warn("Pipeline not found in map", "map", mapName, "pipeline", pipeline)
}
}
}

// overridePipelineImages function parses a JSON string containing pipeline-to-image mappings and overrides the default mappings if valid.
// It updates the `pipelineToImage` and `livePipelineToImage` maps with custom images.
// Parameters:
// - defaultImage: A string that can either be containerImage name or a JSON string with overrides for pipeline-to-image mappings.
//
// Returns:
// - error: An error if the JSON parsing fails or if the mapping is not found in existing maps else `nil`.
func overridePipelineImages(defaultImage string) error {
if strings.HasPrefix(defaultImage, "{") || strings.HasSuffix(defaultImage, "}") {
var pipelineOverrides map[string]string
if err := json.Unmarshal([]byte(defaultImage), &pipelineOverrides); err != nil {
slog.Error("Error parsing JSON", "error", err)
return err
}
updatePipelineMappings(pipelineOverrides, pipelineToImage, "pipelineToImage")
updatePipelineMappings(pipelineOverrides, livePipelineToImage, "livePipelineToImage")
}
return nil
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(imageOverrides string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
return nil, err
}
cancel()

// call to handle image overriding logic
if err := overridePipelineImages(defaultImage); err != nil {
return nil, err
// Override pipeline images if provided.
if imageOverrides != "" {
if err := overridePipelineImages(imageOverrides); err != nil {
return nil, err
}
}

manager := &DockerManager{
defaultImage: defaultImage,
defaultImage: defaultBaseImage,
gpus: gpus,
modelDir: modelDir,
dockerClient: client,
Expand Down
167 changes: 95 additions & 72 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,101 @@ func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
}
}

func TestOverridePipelineImages(t *testing.T) {
mockDockerClient := new(MockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

// Store the original values of the maps
originalPipelineToImage := make(map[string]string)
for k, v := range pipelineToImage {
originalPipelineToImage[k] = v
}

originalLivePipelineToImage := make(map[string]string)
for k, v := range livePipelineToImage {
originalLivePipelineToImage[k] = v
}

// Reset the maps before each test
t.Cleanup(func() {
pipelineToImage = make(map[string]string)
for k, v := range originalPipelineToImage {
pipelineToImage[k] = v
}
livePipelineToImage = make(map[string]string)
for k, v := range originalLivePipelineToImage {
livePipelineToImage[k] = v
}
})

tests := []struct {
name string
inputJSON string
expectedImages map[string]string
expectError bool
}{
{
name: "ValidPipelineOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
expectedImages: map[string]string{
"segment-anything-2": "custom-image:1.0",
"text-to-speech": "speech-image:2.0",
},
expectError: false,
},
{
name: "NoOverrideFallback",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
expectedImages: map[string]string{
"streamdiffusion": "default-image",
},
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
expectedImages: map[string]string{
"segment-anything-2": pipelineToImage["segment-anything-2"],
},
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
expectError: true,
},
{
name: "EmptyString",
inputJSON: "",
expectError: true,
},
{
name: "UnknownPipeline",
inputJSON: `{"unknown-pipeline": "unknown-image:latest"}`,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)

// Verify the expected image.
for pipeline, expectedImage := range tt.expectedImages {
image, _ := dockerManager.getContainerImageName(pipeline, "")
require.Equal(t, expectedImage, image)
}
}
})
}
}

func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

Expand Down Expand Up @@ -784,75 +879,3 @@ func TestDockerWaitUntilRunning(t *testing.T) {
mockDockerClient.AssertExpectations(t)
})
}

func TestDockerManager_overridePipelineImages(t *testing.T) {
mockDockerClient := new(MockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

tests := []struct {
name string
inputJSON string
pipeline string
expectedImage string
expectError bool
}{
{
name: "ValidOverride",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MultipleOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
pipeline: "text-to-speech",
expectedImage: "speech-image:2.0",
expectError: false,
},
{
name: "NoOverrideFallback",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
pipeline: "streamdiffusion",
expectedImage: "default-image",
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
pipeline: "segment-anything-2",
expectedImage: "custom-image:1.0",
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
pipeline: "segment-anything-2",
expectError: true,
},
{
name: "RegularStringInput",
inputJSON: "",
pipeline: "image-to-video",
expectedImage: "default-image",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)

// Verify the expected image.
image, _ := dockerManager.getContainerImageName(tt.pipeline, "")
require.Equal(t, tt.expectedImage, image)
}
})
}
}
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides string, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 3f76bd1

Please sign in to comment.