Skip to content

Commit

Permalink
create minimal conda env for RTX4090 to fix slow inference
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Jan 9, 2024
1 parent a12b7e3 commit 1c70335
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 155 deletions.
3 changes: 2 additions & 1 deletion contact_graspnet/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import yaml
from yaml import Loader

def recursive_key_value_assign(d,ks,v):
"""
Expand Down Expand Up @@ -37,7 +38,7 @@ def load_config(checkpoint_dir, batch_size=None, max_epoch=None, data_path=None,
config_path = os.path.join(checkpoint_dir, 'config.yaml')
config_path = config_path if os.path.exists(config_path) else os.path.join(os.path.dirname(__file__),'config.yaml')
with open(config_path,'r') as f:
global_config = yaml.load(f)
global_config = yaml.load(f,Loader=Loader)

for conf in arg_configs:
k_str, v = conf.split(':')
Expand Down
4 changes: 2 additions & 2 deletions contact_graspnet/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.join(BASE_DIR))
Expand Down
2 changes: 1 addition & 1 deletion contact_graspnet/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def show_image(rgb, segmap):
"""
plt.figure()
figManager = plt.get_current_fig_manager()
figManager.window.showMaximized()
#figManager.window.showMaximized()

plt.ion()
plt.show()
Expand Down
151 changes: 0 additions & 151 deletions contact_graspnet_env.yml

This file was deleted.

17 changes: 17 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: contact-graspnet
channels:
- conda-forge
- anaconda
- defaults
dependencies:
- python=3.9
- cudatoolkit==11.8
- cudnn
- pip
- pip:
- tensorflow==2.5 # requires python<=3.9
- opencv-python-headless
- pyyaml
- pyrender
- tqdm
- mayavi

0 comments on commit 1c70335

Please sign in to comment.