Skip to content

Commit

Permalink
refactor: enhance tests and improve overridePipelineImages structure
Browse files Browse the repository at this point in the history
This commit expands the test coverage to ensure more robust behavior and
refactors the `overridePipelineImages` function to improve error handling and
readability.
  • Loading branch information
rickstaa committed Jan 22, 2025
1 parent 66c5f84 commit 57b5ef7
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 76 deletions.
46 changes: 27 additions & 19 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,38 @@ var livePipelineToImage = map[string]string{

// overridePipelineImages updates base and pipeline images with the provided overrides.
func overridePipelineImages(imageOverrides string) error {
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err != nil {
// If not JSON, treat it as a single image string to set the base image.
defaultBaseImage = imageOverrides
return nil
if imageOverrides == "" {
return fmt.Errorf("empty string is not a valid image override")
}

// Successfully parsed JSON, update the mappings.
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}
// Handle JSON format for multiple pipeline images.
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err == nil {
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}

// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
// If the pipeline is not found in the map, throw an error.
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
} else {
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
}
}
return nil
}

// Check for invalid docker image string.
if strings.ContainsAny(imageOverrides, "{}[]\",") {
return fmt.Errorf("invalid JSON format for image overrides")
}

// Update the base image.
defaultBaseImage = imageOverrides
return nil
}

Expand Down
136 changes: 79 additions & 57 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,95 +106,117 @@ func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
}
}

func TestOverridePipelineImages(t *testing.T) {
mockDockerClient := new(MockDockerClient)
dockerManager := createDockerManager(mockDockerClient)

// Store the original values of the maps
originalPipelineToImage := make(map[string]string)
for k, v := range pipelineToImage {
originalPipelineToImage[k] = v
}

originalLivePipelineToImage := make(map[string]string)
for k, v := range livePipelineToImage {
originalLivePipelineToImage[k] = v
// copyMap returns a deep copy of the given map.
func copyMap(m map[string]string) map[string]string {
copy := make(map[string]string)
for k, v := range m {
copy[k] = v
}
return copy
}

// Reset the maps before each test
t.Cleanup(func() {
pipelineToImage = make(map[string]string)
for k, v := range originalPipelineToImage {
pipelineToImage[k] = v
}
livePipelineToImage = make(map[string]string)
for k, v := range originalLivePipelineToImage {
livePipelineToImage[k] = v
}
})
func TestOverridePipelineImages(t *testing.T) {
// Store the original values of the maps.
originalDefaultBaseImage := defaultBaseImage
originalPipelineToImage := copyMap(pipelineToImage)
originalLivePipelineToImage := copyMap(livePipelineToImage)

tests := []struct {
name string
inputJSON string
expectedImages map[string]string
expectError bool
name string
inputJSON string
expectedBase string
expectedPipelineImages map[string]string
expectedLiveImages map[string]string
expectError bool
}{
{
name: "ValidPipelineOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
expectedImages: map[string]string{
name: "ValidPipelineOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: map[string]string{
"segment-anything-2": "custom-image:1.0",
"text-to-speech": "speech-image:2.0",
"audio-to-text": originalPipelineToImage["audio-to-text"],
},
expectError: false,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "NoOverrideFallback",
inputJSON: `{"segment-anything-2": "custom-image:1.0"}`,
expectedImages: map[string]string{
"streamdiffusion": "default-image",
},
expectError: false,
name: "OverrideBaseImage",
inputJSON: "new-base-image:latest",
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
expectedImages: map[string]string{
"segment-anything-2": pipelineToImage["segment-anything-2"],
},
expectError: false,
name: "OverrideBaseImageJSON",
inputJSON: `{"base": "new-base-image:latest"}`,
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
expectError: true,
name: "EmptyJSON",
inputJSON: `{}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "EmptyString",
inputJSON: "",
expectError: true,
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "UnknownPipeline",
inputJSON: `{"unknown-pipeline": "unknown-image:latest"}`,
expectError: true,
name: "EmptyString",
inputJSON: "",
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "UnknownPipeline",
inputJSON: `{"unknown-pipeline": "unknown-image:latest"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Register a cleanup function to reset state after the subtest.
t.Cleanup(func() {
defaultBaseImage = originalDefaultBaseImage
pipelineToImage = copyMap(originalPipelineToImage)
livePipelineToImage = copyMap(originalLivePipelineToImage)
})

// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedBase, defaultBaseImage)

// Verify the expected pipeline images.
for pipeline, expectedImage := range tt.expectedPipelineImages {
require.Equal(t, expectedImage, pipelineToImage[pipeline])
}

// Verify the expected image.
for pipeline, expectedImage := range tt.expectedImages {
image, _ := dockerManager.getContainerImageName(pipeline, "")
require.Equal(t, expectedImage, image)
// Verify the expected live pipeline images.
for livePipeline, expectedImage := range tt.expectedLiveImages {
require.Equal(t, expectedImage, livePipelineToImage[livePipeline])
}
}
})
Expand Down

0 comments on commit 57b5ef7

Please sign in to comment.