From 03a1aa072ea9a6e8205e756075ec5253c9d13b8b Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Mon, 22 Jan 2024 11:09:58 +0300 Subject: [PATCH] fix: removed the jax import and defaulted jax to use the cpu due to out of memory issues in the unet demo --- examples_and_demos/image_segmentation_with_ivy_unet.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples_and_demos/image_segmentation_with_ivy_unet.ipynb b/examples_and_demos/image_segmentation_with_ivy_unet.ipynb index c86f217f..505adb73 100644 --- a/examples_and_demos/image_segmentation_with_ivy_unet.ipynb +++ b/examples_and_demos/image_segmentation_with_ivy_unet.ipynb @@ -61,8 +61,6 @@ "source": [ "import ivy\n", "ivy.set_default_device(\"gpu:0\")\n", - "import jax\n", - "jax.devices()\n", "import torch\n", "import numpy as np" ] @@ -562,6 +560,7 @@ "import jax\n", "\n", "jax.config.update('jax_enable_x64', True)\n", + "ivy.set_default_device(\"cpu\")\n", "ivy.set_backend(\"jax\")\n", "ivy_unet = ivy_models.unet_carvana(n_channels=3, n_classes=2, pretrained=True)" ]