Merge pull request #2 from google-research/main
Catchup
MrXandbadas authored Feb 3, 2023
2 parents 41faed3 + f0cec8f commit b2f608b
Showing 1 changed file with 90 additions and 38 deletions.
128 changes: 90 additions & 38 deletions challenges/point_tracking/dataset.py
@@ -11,13 +11,6 @@
from tensorflow_graphics.geometry.transformation import rotation_matrix_3d


TOTAL_TRACKS = 256
MAX_SAMPLED_FRAC = .1
MAX_SEG_ID = 22
INPUT_SIZE = (None, 256, 256)
STRIDE = 4 # Make sure this divides all axes of INPUT_SIZE


def project_point(cam, point3d, num_frames):
"""Compute the image space coordinates [0, 1] for a set of points.
@@ -188,6 +181,7 @@ def get_camera_matrices(
cam_positions,
cam_quaternions,
cam_sensor_width,
input_size,
num_frames=None,
):
"""Tf function that converts camera positions into projection matrices."""
@@ -198,7 +192,7 @@
focal_length = tf.cast(cam_focal_length, tf.float32)
sensor_width = tf.cast(cam_sensor_width, tf.float32)
f_x = focal_length / sensor_width
f_y = focal_length / sensor_width * INPUT_SIZE[1] / INPUT_SIZE[2]
f_y = focal_length / sensor_width * input_size[0] / input_size[1]
p_x = 0.5
p_y = 0.5
intrinsics.append(
@@ -235,6 +229,7 @@ def single_object_reproject(
num_frames=None,
depth_map=None,
window=None,
input_size=None,
):
"""Reproject points for a single object.
@@ -247,6 +242,7 @@
num_frames: Number of frames
depth_map: Depth map video for the camera
window: the window inside which we're sampling points
input_size: [height, width] of the input images.
Returns:
Position for each point, of shape [num_points, num_frames, 2], in pixel
@@ -265,8 +261,8 @@
)

occluded = tf.less(reproj[:, :, 2], 0)
reproj = reproj[:, :, 0:2] * np.array(INPUT_SIZE[2:0:-1])[np.newaxis,
np.newaxis, :]
reproj = reproj[:, :, 0:2] * np.array(input_size[::-1])[np.newaxis,
np.newaxis, :]
occluded = tf.logical_or(
occluded,
tf.less(
@@ -285,19 +281,23 @@
return obj_reproj, obj_occ


def get_num_to_sample(counts):
def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
"""Computes the number of points to sample for each object.
Args:
counts: The number of points available per object. An int array of length
n, where n is the number of objects.
max_seg_id: The maximum number of segment id's in the video.
max_sampled_frac: The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
tracks_to_sample: Total number of tracks to sample per video.
Returns:
The number of points to sample for each object. An int array of length n.
"""
seg_order = tf.argsort(counts)
sorted_counts = tf.gather(counts, seg_order)
initializer = (0, TOTAL_TRACKS, 0)
initializer = (0, tracks_to_sample, 0)

def scan_fn(prev_output, count_seg):
index = prev_output[0]
@@ -308,7 +308,7 @@ def scan_fn(prev_output, count_seg):
tf.cast(desired_frac, tf.float32))
want_to_sample = tf.cast(tf.round(want_to_sample), tf.int32)
max_to_sample = (
tf.cast(count_seg, tf.float32) * tf.cast(MAX_SAMPLED_FRAC, tf.float32))
tf.cast(count_seg, tf.float32) * tf.cast(max_sampled_frac, tf.float32))
max_to_sample = tf.cast(tf.round(max_to_sample), tf.int32)
num_to_sample = tf.minimum(want_to_sample, max_to_sample)

@@ -323,7 +323,7 @@ def scan_fn(prev_output, count_seg):
num_to_sample = tf.concat(
[
num_to_sample,
tf.zeros([MAX_SEG_ID - tf.shape(num_to_sample)[0]], dtype=tf.int32),
tf.zeros([max_seg_id - tf.shape(num_to_sample)[0]], dtype=tf.int32),
],
axis=0,
)
@@ -344,7 +344,10 @@ def track_points(
cam_quaternions,
cam_sensor_width,
window,
num_frames=None,
tracks_to_sample=256,
sampling_stride=4,
max_seg_id=25,
max_sampled_frac=0.1,
):
"""Track points in 2D using Kubric data.
@@ -364,7 +367,12 @@
window: the window inside which we're sampling points. Integer valued
in the format [x_min, y_min, x_max, y_max], where min is inclusive and
max is exclusive.
num_frames: number of frames in the video
tracks_to_sample: Total number of tracks to sample per video.
sampling_stride: For efficiency, query points are sampled from a random grid
of this stride.
max_seg_id: The maximum segment id in the video.
max_sampled_frac: The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
Returns:
A set of queries, randomly sampled from the video (with a bias toward
@@ -388,22 +396,25 @@
depth_f32 = tf.cast(depth, tf.float32)
depth_map = depth_min + depth_f32 * (depth_max-depth_min) / 65535

input_size = object_coordinates.shape.as_list()[1:3]
num_frames = object_coordinates.shape.as_list()[0]

# We first sample query points within the given window. That means first
# extracting the window from the segmentation tensor, because we want to have
# a bias toward moving objects.
# Note: for speed we sample points on a grid. The grid start position is
# randomized within the window.
start_vec = [
tf.random.uniform([], minval=0, maxval=STRIDE, dtype=tf.int32)
for _ in range(len(INPUT_SIZE))
tf.random.uniform([], minval=0, maxval=sampling_stride, dtype=tf.int32)
for _ in range(3)
]
start_vec[1] += window[0]
start_vec[2] += window[1]
end_vec = [num_frames, window[2], window[3]]

def extract_box(x):
x = x[start_vec[0]::STRIDE, start_vec[1]:window[2]:STRIDE,
start_vec[2]:window[3]:STRIDE]
x = x[start_vec[0]::sampling_stride, start_vec[1]:window[2]:sampling_stride,
start_vec[2]:window[3]:sampling_stride]
return x

segmentations_box = extract_box(segmentations)
@@ -413,13 +424,19 @@ def extract_box(x):
# how many points are available for each object.

cnt = tf.math.bincount(tf.cast(tf.reshape(segmentations_box, [-1]), tf.int32))
num_to_sample = get_num_to_sample(cnt)
num_to_sample.set_shape([MAX_SEG_ID])
num_to_sample = get_num_to_sample(
cnt,
max_seg_id,
max_sampled_frac,
tracks_to_sample,
)
num_to_sample.set_shape([max_seg_id])
intrinsics, matrix_world = get_camera_matrices(
cam_focal_length,
cam_positions,
cam_quaternions,
cam_sensor_width,
input_size,
num_frames=num_frames,
)

@@ -431,11 +448,14 @@ def get_camera(fr=None):
# Construct pixel coordinates for each pixel within the window.
window = tf.cast(window, tf.float32)
z, y, x = tf.meshgrid(
*[tf.range(st, ed, STRIDE) for st, ed in zip(start_vec, end_vec)],
*[
tf.range(st, ed, sampling_stride)
for st, ed in zip(start_vec, end_vec)
],
indexing='ij')
pix_coords = tf.reshape(tf.stack([z, y, x], axis=-1), [-1, 3])

for i in range(MAX_SEG_ID):
for i in range(max_seg_id):
# sample points on object i in the first frame. obj_id is the position
# within the object_coordinates array, which is one lower than the value
# in the segmentation mask (0 in the segmentation mask is the background
@@ -492,6 +512,7 @@ def get_camera(fr=None):
num_frames=num_frames,
depth_map=depth_map,
window=window,
input_size=input_size,
),
lambda: # pylint: disable=g-long-lambda
(tf.zeros([0, num_frames, 2], dtype=tf.float32),
@@ -508,7 +529,7 @@ def get_camera(fr=None):
np.array([num_frames]), window[2:4]],
axis=0)
wd = wd[tf.newaxis, tf.newaxis, :]
coord_multiplier = [num_frames, INPUT_SIZE[1], INPUT_SIZE[2]]
coord_multiplier = [num_frames, input_size[0], input_size[1]]
all_reproj = tf.concat(all_reproj, axis=0)
# We need to extract x,y, but the format of the window is [t1,y1,x1,t2,y2,x2]
window_size = wd[:, :, 5:3:-1] - wd[:, :, 2:0:-1]
@@ -558,17 +579,28 @@ def _get_distorted_bounding_box(


def add_tracks(data,
train_size=(200, 200),
train_size=(256, 256),
vflip=False,
random_crop=True):
random_crop=True,
tracks_to_sample=256,
sampling_stride=4,
max_seg_id=25,
max_sampled_frac=0.1):
"""Track points in 2D using Kubric data.
Args:
data: kubric data, including RGB/depth/object coordinate/segmentation
data: Kubric data, including RGB/depth/object coordinate/segmentation
videos and camera parameters.
train_size: cropped output will be at this resolution
train_size: Cropped output will be at this resolution. Ignored if
random_crop is False.
vflip: whether to vertically flip images and tracks (to test generalization)
random_crop: whether to randomly crop videos
random_crop: Whether to randomly crop videos
tracks_to_sample: Total number of tracks to sample per video.
sampling_stride: For efficiency, query points are sampled from a random grid
of this stride.
max_seg_id: The maximum segment id in the video.
max_sampled_frac: The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
Returns:
A dict with the following keys:
@@ -589,6 +621,8 @@ def add_tracks(data,
"""
shp = data['video'].shape.as_list()
num_frames = shp[0]
if any([s % sampling_stride != 0 for s in shp[:-1]]):
raise ValueError('All video dims must be a multiple of sampling_stride.')

bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
min_area = 0.3
@@ -604,7 +638,7 @@ def add_tracks(data,
area_range=(min_area, max_area),
max_attempts=20)
else:
crop_window = tf.constant([0, 0, INPUT_SIZE[1], INPUT_SIZE[2]],
crop_window = tf.constant([0, 0, shp[1], shp[2]],
dtype=tf.int32,
shape=[4])

@@ -613,14 +647,14 @@ def add_tracks(data,
data['metadata']['depth_range'], data['segmentations'],
data['instances']['bboxes_3d'], data['camera']['focal_length'],
data['camera']['positions'], data['camera']['quaternions'],
data['camera']['sensor_width'], crop_window, num_frames)
data['camera']['sensor_width'], crop_window, tracks_to_sample,
sampling_stride, max_seg_id, max_sampled_frac)
video = data['video']

shp = video.shape.as_list()
num_frames = shp[0]
query_points.set_shape([TOTAL_TRACKS, 3])
target_points.set_shape([TOTAL_TRACKS, num_frames, 2])
occluded.set_shape([TOTAL_TRACKS, num_frames])
query_points.set_shape([tracks_to_sample, 3])
target_points.set_shape([tracks_to_sample, num_frames, 2])
occluded.set_shape([tracks_to_sample, num_frames])

# Crop the video to the sampled window, in a way which matches the coordinate
# frame produced by the track_points function.
@@ -654,6 +688,11 @@ def create_point_tracking_dataset(
repeat=True,
vflip=False,
random_crop=True,
tracks_to_sample=256,
sampling_stride=4,
max_seg_id=25,
max_sampled_frac=0.1,
num_parallel_point_extraction_calls=16,
**kwargs):
"""Construct a dataset for point tracking using Kubric: go/kubric.
@@ -667,6 +706,15 @@
repeat: Bool. whether to repeat the dataset.
vflip: Bool. whether to vertically flip the dataset to test generalization.
random_crop: Bool. whether to randomly crop videos
tracks_to_sample: Int. Total number of tracks to sample per video.
sampling_stride: Int. For efficiency, query points are sampled from a
random grid of this stride.
max_seg_id: Int. The maximum segment id in the video. Note that the size of
the graph is proportional to this number, so prefer small values.
max_sampled_frac: Float. The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
map function for point extraction.
**kwargs: additional args to pass to tfds.load.
Returns:
@@ -686,8 +734,12 @@ def create_point_tracking_dataset(
add_tracks,
train_size=train_size,
vflip=vflip,
random_crop=random_crop),
num_parallel_calls=2)
random_crop=random_crop,
tracks_to_sample=tracks_to_sample,
sampling_stride=sampling_stride,
max_seg_id=max_seg_id,
max_sampled_frac=max_sampled_frac),
num_parallel_calls=num_parallel_point_extraction_calls)
if shuffle_buffer_size is not None:
ds = ds.shuffle(shuffle_buffer_size)

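For context, a minimal usage sketch of the newly parameterized loader. It assumes the module is importable as challenges.point_tracking.dataset and that the Kubric TFDS data the loader reads is available locally; the import path and the element key names are assumptions (they are not shown in this diff), and the argument values simply restate the new defaults from the signature above. The shapes in the comments are the per-video shapes set in add_tracks, before any batching the pipeline may apply.

from challenges.point_tracking import dataset  # import path assumed from the file location

ds = dataset.create_point_tracking_dataset(
    train_size=(256, 256),
    vflip=False,
    random_crop=True,
    tracks_to_sample=256,   # replaces the old TOTAL_TRACKS constant
    sampling_stride=4,      # replaces STRIDE; frame count and spatial dims must be divisible by it
    max_seg_id=25,          # replaces MAX_SEG_ID; graph size grows with this value
    max_sampled_frac=0.1,   # replaces MAX_SAMPLED_FRAC
    num_parallel_point_extraction_calls=16,
)

for element in ds.take(1):
    # Key names below are assumed from the variables returned by add_tracks.
    print(element['query_points'].shape)   # [tracks_to_sample, 3]
    print(element['target_points'].shape)  # [tracks_to_sample, num_frames, 2]
    print(element['occluded'].shape)       # [tracks_to_sample, num_frames]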
