From 1a16a342efff10d1950a56b9cfed2763a9117d4c Mon Sep 17 00:00:00 2001 From: Young Yoon Lee Date: Fri, 25 Aug 2023 19:30:18 -0700 Subject: [PATCH] Mitigates the need of unnecessary calculations 2d processing when loss2d_factor=0 (#45) --- src/main.py | 18 +++++++++++------- src/models/metro.py | 22 ++++++++++++++-------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/main.py b/src/main.py index e38912c..e360b95 100755 --- a/src/main.py +++ b/src/main.py @@ -59,11 +59,12 @@ def train(): n_completed_steps = get_n_completed_steps(FLAGS.checkpoint_dir, FLAGS.load_path) rng = np.random.RandomState(FLAGS.seed) - data2d = build_dataflow( - examples2d, data.data_loading.load_and_transform2d, (joint_info2d, TRAIN), - TRAIN, batch_size=FLAGS.batch_size_2d * n_repl, n_workers=FLAGS.workers, - rng=util.new_rng(rng), n_completed_steps=n_completed_steps, - n_total_steps=FLAGS.training_steps) + if FLAGS.loss2d_factor > 0.0: + data2d = build_dataflow( + examples2d, data.data_loading.load_and_transform2d, (joint_info2d, TRAIN), + TRAIN, batch_size=FLAGS.batch_size_2d * n_repl, n_workers=FLAGS.workers, + rng=util.new_rng(rng), n_completed_steps=n_completed_steps, + n_total_steps=FLAGS.training_steps) data3d = build_dataflow( example_sections, data.data_loading.load_and_transform3d, (joint_info3d, TRAIN), @@ -72,8 +73,11 @@ def train(): rng=util.new_rng(rng), n_completed_steps=n_completed_steps, n_total_steps=FLAGS.training_steps, roundrobin_sizes=roundrobin_sizes) - data_train = tf.data.Dataset.zip((data3d, data2d)) - data_train = data_train.map(lambda batch3d, batch2d: {**batch3d, **batch2d}) + if FLAGS.loss2d_factor > 0.0: + data_train = tf.data.Dataset.zip((data3d, data2d)) + data_train = data_train.map(lambda batch3d, batch2d: {**batch3d, **batch2d}) + else: + data_train = data3d if not FLAGS.multi_gpu: data_train = data_train.apply(tf.data.experimental.prefetch_to_device('GPU:0', 2)) diff --git a/src/models/metro.py b/src/models/metro.py index 1702feb..2da3fcb 100644 --- a/src/models/metro.py +++ b/src/models/metro.py @@ -97,12 +97,14 @@ def forward_test(self, inps): def forward_train(self, inps, training): preds = AttrDict() - image_both = tf.concat([inps.image, inps.image_2d], axis=0) - coords3d_pred_both = self.model(image_both, training=training) - - batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]] - preds.coords3d_rel_pred, preds.coords3d_pred_2d = tf.split( - coords3d_pred_both, batch_sizes, axis=0) + if FLAGS.loss2d_factor > 0.0: + image_both = tf.concat([inps.image, inps.image_2d], axis=0) + coords3d_pred_both = self.model(image_both, training=training) + batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]] + preds.coords3d_rel_pred, preds.coords3d_pred_2d = tf.split( + coords3d_pred_both, batch_sizes, axis=0) + else: + preds.coords3d_rel_pred = self.model(inps.image, training=training) joint_ids_3d = [ [self.joint_info.ids[n2] for n2 in self.joint_info.names if n2.startswith(n1)] @@ -117,8 +119,9 @@ def get_2dlike_joints(coords): for ids in joint_ids_3d], axis=1) # numbers mean: like 2d dataset joints, 2d batch - preds.coords2d_pred_2d = get_2dlike_joints(preds.coords3d_pred_2d[..., :2]) - + if FLAGS.loss2d_factor > 0.0: + preds.coords2d_pred_2d = get_2dlike_joints(preds.coords3d_pred_2d[..., :2]) + return preds def compute_losses(self, inps, preds): @@ -137,6 +140,9 @@ def compute_losses(self, inps, preds): rootrel_absdiff = tf.abs((coords3d_true_rootrel - coords3d_pred_rootrel) / 1000) losses.loss3d = tfu.reduce_mean_masked(rootrel_absdiff, inps.joint_validity_mask[:, joint_index_start:]) + if FLAGS.loss2d_factor==0.0: + losses.loss = losses.loss3d + return losses #################### # 2D BATCH