From 6ab380a41fb2cb9a634a7c05b5a58812c349f20f Mon Sep 17 00:00:00 2001 From: Mattt Date: Tue, 25 Jun 2024 16:07:22 -0700 Subject: [PATCH] Add passing test that includes Torch and TensorFlow together (#1123) Signed-off-by: Mattt Zmuda Co-authored-by: Ben Firshman --- pkg/config/config_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index fa6696f9cc..4b59fe75df 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -495,6 +495,33 @@ torch==2.3.1+cu121` require.Equal(t, expected, requirements) } +func TestPythonPackagesBothTorchAndTensorflow(t *testing.T) { + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.11.1", + PythonPackages: []string{ + "tensorflow==2.12.0", + "torch==2.0.1", + "torchvision==0.15.2", + }, + CUDA: "11.8", + }, + } + err := config.ValidateAndComplete("") + require.NoError(t, err) + require.Equal(t, "11.8", config.Build.CUDA) + require.Equal(t, "8", config.Build.CuDNN) + + requirements, err := config.PythonRequirementsForArch("", "") + require.NoError(t, err) + expected := `--extra-index-url https://download.pytorch.org/whl/cu118 +tensorflow==2.12.0 +torch==2.0.1+cu118 +torchvision==0.15.2` + require.Equal(t, expected, requirements) +} + func TestCUDABaseImageTag(t *testing.T) { config := &Config{ Build: &Build{