Skip to content

Commit

Permalink
Mitigates the need of unnecessary calculations 2d processing when los…
Browse files Browse the repository at this point in the history
…s2d_factor=0 (isarandi#45)
  • Loading branch information
ylee-rbx authored and GitHub Enterprise committed Aug 26, 2023
1 parent cfc00c7 commit 1a16a34
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
18 changes: 11 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ def train():
n_completed_steps = get_n_completed_steps(FLAGS.checkpoint_dir, FLAGS.load_path)

rng = np.random.RandomState(FLAGS.seed)
data2d = build_dataflow(
examples2d, data.data_loading.load_and_transform2d, (joint_info2d, TRAIN),
TRAIN, batch_size=FLAGS.batch_size_2d * n_repl, n_workers=FLAGS.workers,
rng=util.new_rng(rng), n_completed_steps=n_completed_steps,
n_total_steps=FLAGS.training_steps)
if FLAGS.loss2d_factor > 0.0:
data2d = build_dataflow(
examples2d, data.data_loading.load_and_transform2d, (joint_info2d, TRAIN),
TRAIN, batch_size=FLAGS.batch_size_2d * n_repl, n_workers=FLAGS.workers,
rng=util.new_rng(rng), n_completed_steps=n_completed_steps,
n_total_steps=FLAGS.training_steps)

data3d = build_dataflow(
example_sections, data.data_loading.load_and_transform3d, (joint_info3d, TRAIN),
Expand All @@ -72,8 +73,11 @@ def train():
rng=util.new_rng(rng), n_completed_steps=n_completed_steps,
n_total_steps=FLAGS.training_steps, roundrobin_sizes=roundrobin_sizes)

data_train = tf.data.Dataset.zip((data3d, data2d))
data_train = data_train.map(lambda batch3d, batch2d: {**batch3d, **batch2d})
if FLAGS.loss2d_factor > 0.0:
data_train = tf.data.Dataset.zip((data3d, data2d))
data_train = data_train.map(lambda batch3d, batch2d: {**batch3d, **batch2d})
else:
data_train = data3d
if not FLAGS.multi_gpu:
data_train = data_train.apply(tf.data.experimental.prefetch_to_device('GPU:0', 2))

Expand Down
22 changes: 14 additions & 8 deletions src/models/metro.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ def forward_test(self, inps):
def forward_train(self, inps, training):
preds = AttrDict()

image_both = tf.concat([inps.image, inps.image_2d], axis=0)
coords3d_pred_both = self.model(image_both, training=training)

batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]]
preds.coords3d_rel_pred, preds.coords3d_pred_2d = tf.split(
coords3d_pred_both, batch_sizes, axis=0)
if FLAGS.loss2d_factor > 0.0:
image_both = tf.concat([inps.image, inps.image_2d], axis=0)
coords3d_pred_both = self.model(image_both, training=training)
batch_sizes = [t.shape.as_list()[0] for t in [inps.image, inps.image_2d]]
preds.coords3d_rel_pred, preds.coords3d_pred_2d = tf.split(
coords3d_pred_both, batch_sizes, axis=0)
else:
preds.coords3d_rel_pred = self.model(inps.image, training=training)

joint_ids_3d = [
[self.joint_info.ids[n2] for n2 in self.joint_info.names if n2.startswith(n1)]
Expand All @@ -117,8 +119,9 @@ def get_2dlike_joints(coords):
for ids in joint_ids_3d], axis=1)

# numbers mean: like 2d dataset joints, 2d batch
preds.coords2d_pred_2d = get_2dlike_joints(preds.coords3d_pred_2d[..., :2])

if FLAGS.loss2d_factor > 0.0:
preds.coords2d_pred_2d = get_2dlike_joints(preds.coords3d_pred_2d[..., :2])

return preds

def compute_losses(self, inps, preds):
Expand All @@ -137,6 +140,9 @@ def compute_losses(self, inps, preds):

rootrel_absdiff = tf.abs((coords3d_true_rootrel - coords3d_pred_rootrel) / 1000)
losses.loss3d = tfu.reduce_mean_masked(rootrel_absdiff, inps.joint_validity_mask[:, joint_index_start:])
if FLAGS.loss2d_factor==0.0:
losses.loss = losses.loss3d
return losses

####################
# 2D BATCH
Expand Down

0 comments on commit 1a16a34

Please sign in to comment.