From c62dd431f40b9b0910f7bcba5942250194009c90 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Tue, 2 May 2023 18:52:24 -0700 Subject: [PATCH] Add passing test that includes Torch and TensorFlow together Signed-off-by: Mattt Zmuda --- pkg/config/config_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index b33d903eb8..921eb58f6d 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -345,6 +345,33 @@ foo==1.0.0` require.Equal(t, expected, requirements) } +func TestPythonPackagesBothTorchAndTensorflow(t *testing.T) { + config := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.11.1", + PythonPackages: []string{ + "tensorflow==2.12.0", + "torch==2.0.1", + "torchvision==0.15.2", + }, + CUDA: "11.8", + }, + } + err := config.ValidateAndComplete("") + require.NoError(t, err) + require.Equal(t, "11.8", config.Build.CUDA) + require.Equal(t, "8", config.Build.CuDNN) + + requirements, err := config.PythonRequirementsForArch("", "") + require.NoError(t, err) + expected := `--extra-index-url https://download.pytorch.org/whl/cu118 +tensorflow==2.12.0 +torch==2.0.1+cu118 +torchvision==0.15.2` + require.Equal(t, expected, requirements) +} + func TestCUDABaseImageTag(t *testing.T) { config := &Config{ Build: &Build{