import logging
import math
import os
import subprocess
from typing import Optional, Tuple
import click
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from PIL import Image
from plyfile import PlyData, PlyElement
from spherical_harmonics import sh_to_rgb
from utils import read_color_components, read_scene
logging.basicConfig(
format="[%(asctime)s] %(levelname)s [%(pathname)s:%(lineno)d] - %(message)s",
datefmt="%m-%d %H:%M:%S",
level=logging.NOTSET,
)
logger = logging.getLogger(__name__)
# Z_FAR and Z_NEAR are standard computer-graphics clipping distances that mark the near and far visibility limits,
# i.e. nothing closer than Z_NEAR or farther than Z_FAR is rendered
Z_FAR = 100.0
Z_NEAR = 0.01
# Scaling factor set in the original implementation: a gaussian is considered to extend
# GAUSSIAN_SPREAD (= 3) standard deviations from its mean, which covers ~99.7% of its mass
GAUSSIAN_SPREAD = 3
# Size of a processing block in the original CUDA kernels (i.e. a block processes a 16x16 tile of pixels)
BLOCK_SIZE = 16
# Maximum density, to prevent overflow issues
MAX_GAUSSIAN_DENSITY = 0.99
# Minimum alpha below which a gaussian's contribution to a pixel is skipped (it would not be visible anyway)
MIN_ALPHA = 1 / 255
def quaternion_to_rotation_matrix(quaternion: torch.Tensor) -> torch.Tensor:
"""
    Quaternions are a common, compact representation of rotations.
    The corresponding rotation matrix is recovered with the standard closed-form expression below.
"""
w_q = quaternion[0, :]
x = quaternion[1, :]
y = quaternion[2, :]
z = quaternion[3, :]
return torch.stack(
[
torch.stack([1 - 2 * y ** 2 - 2 * z ** 2, 2 * x * y - 2 * z * w_q, 2 * x * z + 2 * y * w_q,]),
torch.stack([2 * x * y + 2 * z * w_q, 1 - 2 * x ** 2 - 2 * z ** 2, 2 * y * z - 2 * x * w_q,]),
torch.stack([2 * x * z - 2 * y * w_q, 2 * y * z + 2 * x * w_q, 1 - 2 * x ** 2 - 2 * y ** 2,]),
]
).float()
def get_world_to_camera_matrix(normalized_qvec: torch.Tensor, tvec: torch.Tensor) -> torch.Tensor:
"""
    We build the matrix that transforms coordinates from world space (i.e. the reference coordinate
    system, independent of any point of view) to camera space, the coordinate system attached to the camera's POV.
    Given a rotation matrix R and a translation vector T, the matrix is defined as:
    M = [[R11, R12, R13, T_x],
         [R21, R22, R23, T_y],
         [R31, R32, R33, T_z],
         [  0,   0,   0,   1]]
"""
rotation_matrix = quaternion_to_rotation_matrix(normalized_qvec.unsqueeze(1))
    world_to_camera = torch.zeros((4, 4))
    world_to_camera[:3, :3] = rotation_matrix.squeeze(-1)
    world_to_camera[:3, 3] = tvec
    world_to_camera[3, 3] = 1
    return world_to_camera
def project_to_camera_space(gaussian_means: torch.Tensor, world_to_camera: torch.Tensor) -> torch.Tensor:
"""
    This is plain 3D geometry: the camera-space coordinates are obtained by
    - applying the rotation (i.e. multiplying by the rotation matrix)
    - adding the translation
"""
return gaussian_means @ world_to_camera[:3, :3] + world_to_camera[-1, :3]
def get_covariance_matrix_from_mesh(mesh: PlyElement) -> torch.Tensor:
"""
    Covariance matrices are trained parameters. They define the spread of each gaussian in 3D space, and therefore
    the set of pixels covered by each gaussian once it is projected to 2D.
    See the paper: gaussian covariances are parameterized with a scale matrix S and a rotation matrix R
    such that Cov = R * S * S^T * R^T
"""
scales = torch.exp(
torch.tensor(np.stack([mesh.elements[0]["scale_0"], mesh.elements[0]["scale_1"], mesh.elements[0]["scale_2"],]))
)
rotations = torch.tensor(
np.stack(
[
mesh.elements[0]["rot_0"],
mesh.elements[0]["rot_1"],
mesh.elements[0]["rot_2"],
mesh.elements[0]["rot_3"],
]
)
)
# Learned quaternions do not guarantee a unit norm, therefore we have to normalize them
unit_quaternions = torch.nn.functional.normalize(rotations, p=2.0, dim=0)
rotation_matrices = quaternion_to_rotation_matrix(unit_quaternions).permute(2, 0, 1)
scale_matrices = torch.zeros((scales.shape[-1], 3, 3))
indices = torch.arange(3)
scale_matrices[:, indices, indices] = scales.T
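    # Build M = R @ S so that Cov = M @ M^T = R @ S @ S^T @ R^T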
M = rotation_matrices @ scale_matrices
return M @ torch.permute(M, (0, 2, 1))
def get_projection_matrix(fov_x: float, fov_y: float) -> torch.Tensor:
"""
The projection matrix models the transformation from the 3D camera space to
the 2D screen space: effectively, you project all points onto that 2D plane
using this matrix.
    It takes into account the field of view (what is visible from the camera along the x and y axes)
    along with Z_FAR and Z_NEAR, which bound what is visible depth-wise.
"""
tan_half_fov_x = math.tan((fov_x / 2))
tan_half_fov_y = math.tan((fov_y / 2))
top = tan_half_fov_y * Z_NEAR
bottom = -top
right = tan_half_fov_x * Z_NEAR
left = -right
P = torch.zeros(4, 4)
z_sign = 1.0
P[0, 0] = 2.0 * Z_NEAR / (right - left)
P[1, 1] = 2.0 * Z_NEAR / (top - bottom)
P[0, 2] = (right + left) / (right - left)
P[1, 2] = (top + bottom) / (top - bottom)
P[3, 2] = z_sign
P[2, 2] = z_sign * Z_FAR / (Z_FAR - Z_NEAR)
P[2, 3] = -(Z_FAR * Z_NEAR) / (Z_FAR - Z_NEAR)
return P
def compute_covering_bbox(
screen_means: torch.Tensor, projected_covariances: torch.Tensor, width: float, height: float,
) -> torch.Tensor:
"""
    For each 2D projected gaussian, we first compute its spread from the eigenvalues of its covariance. Since
    we need the covered area in terms of pixels (i.e. we cannot rasterize an ellipse directly), we approximate
    the spread with a bounding box centered on the gaussian.
"""
det = (
projected_covariances[:, 0, 0] * projected_covariances[:, 1, 1]
- projected_covariances[:, 1, 0] * projected_covariances[:, 0, 1]
)
trace = projected_covariances[:, 0, 0] + projected_covariances[:, 1, 1]
    # The term under the square root could be negative (there is no guarantee it is not), which would produce NaNs
    # To prevent instabilities, we floor it at 0.1 (value defined in the original implementation)
lambda1 = trace / 2.0 + torch.sqrt(
torch.max((trace ** 2) / 4.0 - det, torch.tensor([0.1], device=screen_means.device))
)
lambda2 = trace / 2.0 - torch.sqrt(
torch.max((trace ** 2) / 4.0 - det, torch.tensor([0.1], device=screen_means.device))
)
# This is equivalent to taking 3 times the maximum standard deviation from the projected gaussian
max_spread = torch.ceil(
GAUSSIAN_SPREAD * torch.sqrt(torch.max(torch.stack([lambda1, lambda2], dim=-1), dim=-1).values)
)
# The original implementation divides the screen space in blocks of size BLOCK_SIZE
# We keep this paradigm here so that we can more easily map this step back to the original implementation
# but for this simplified implementation, this is not required.
bboxes = torch.stack(
[
torch.clamp((screen_means[:, 0] - (max_spread)) / BLOCK_SIZE, 0, width - 1),
torch.clamp((screen_means[:, 1] - (max_spread)) / BLOCK_SIZE, 0, height - 1),
torch.clamp((screen_means[:, 0] + (max_spread + BLOCK_SIZE - 1)) / BLOCK_SIZE, 0, width - 1,),
torch.clamp((screen_means[:, 1] + (max_spread + BLOCK_SIZE - 1)) / BLOCK_SIZE, 0, height - 1,),
],
dim=-1,
)
    # Floor to integer block coordinates (the clamps above already handle gaussians that spread outside the screen)
bboxes = torch.floor(bboxes).to(int)
return bboxes
def compute_2d_covariance(
cov_matrices, camera_space_points, tan_fov_x, tan_fov_y, focals, world_to_camera
) -> torch.Tensor:
"""
The spread of each gaussian needs to be projected in screen space (similarly to each gaussian center).
This is done by projecting the covariance matrices of each gaussian using the EWA Splatting technique.
The original implementation is located at: https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d/cuda_rasterizer/forward.cu#L74
"""
limx = torch.tensor([1.3 * tan_fov_x], device=cov_matrices.device)
limy = torch.tensor([1.3 * tan_fov_y], device=cov_matrices.device)
# In the original implementation, the formula for the focals is missing a factor 2
# See: https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d/cuda_rasterizer/rasterizer_impl.cu#L222-L223
# To account for it, we divide by 2 here
focal_x, focal_y = focals / 2
txtz = camera_space_points[:, 0] / camera_space_points[:, 2]
tytz = camera_space_points[:, 1] / camera_space_points[:, 2]
tx = torch.min(limx, torch.max(-limx, txtz)) * camera_space_points[:, 2]
ty = torch.min(limy, torch.max(-limy, tytz)) * camera_space_points[:, 2]
# Compute the Jacobian matrix
J = torch.zeros((camera_space_points.shape[0], 3, 3), device=cov_matrices.device)
J[:, 0, 0] = focal_x / camera_space_points[:, 2]
J[:, 0, 2] = -(focal_x * tx) / (camera_space_points[:, 2] * camera_space_points[:, 2])
J[:, 1, 1] = focal_y / camera_space_points[:, 2]
J[:, 1, 2] = -(focal_y * ty) / (camera_space_points[:, 2] * camera_space_points[:, 2])
W = world_to_camera[:-1, :-1].T
T = torch.bmm(W.expand(J.shape[0], 3, 3).transpose(2, 1), J.transpose(2, 1)).transpose(2, 1)
vrk = torch.zeros((camera_space_points.shape[0], 3, 3), device=cov_matrices.device)
vrk[:, 0, 0] = cov_matrices[:, 0, 0]
vrk[:, 0, 1] = cov_matrices[:, 0, 1]
vrk[:, 0, 2] = cov_matrices[:, 0, 2]
vrk[:, 1, 0] = cov_matrices[:, 0, 1]
vrk[:, 1, 1] = cov_matrices[:, 1, 1]
vrk[:, 1, 2] = cov_matrices[:, 1, 2]
vrk[:, 2, 0] = cov_matrices[:, 0, 2]
vrk[:, 2, 1] = cov_matrices[:, 1, 2]
vrk[:, 2, 2] = cov_matrices[:, 2, 2]
proj_cov = T @ vrk @ T.transpose(2, 1)
# Apply low-pass filter: every Gaussian should be at least
# one pixel wide/high. Discard 3rd row and column.
proj_cov[:, 0, 0] += 0.3
proj_cov[:, 1, 1] += 0.3
return proj_cov[:, :2, :2]
def rasterize_gaussian(
gaussian_index: int,
bboxes: torch.Tensor,
screen: torch.Tensor,
screen_means: torch.Tensor,
sigmas: torch.Tensor,
rgb: torch.Tensor,
opacity_buffer: torch.Tensor,
opacity: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Here we rasterize a single gaussian, i.e. compute which pixels are covered by the gaussian given its spread,
    and alpha-blend it onto the existing screen (where previous, closer gaussians have already been blended).
"""
sigma_x, sigma_y, sigma_x_y = sigmas[gaussian_index]
x_grid = torch.arange(bboxes[gaussian_index, 0], bboxes[gaussian_index, 2])
y_grid = torch.arange(bboxes[gaussian_index, 1], bboxes[gaussian_index, 3])
mesh_x, mesh_y = torch.meshgrid(x_grid, y_grid, indexing="ij")
mesh = torch.stack([mesh_x, mesh_y], dim=-1).view(-1, 2).to(screen_means.device)
    # We evaluate the gaussian's 2D density at each covered pixel, which (together with the gaussian's
    # learned opacity) determines how much it contributes to the color of that pixel
dist_to_mean = screen_means[gaussian_index] - mesh
gaussian_density = (
-0.5 * (sigma_x * (dist_to_mean[:, 0] ** 2) + sigma_y * (dist_to_mean[:, 1] ** 2))
- sigma_x_y * dist_to_mean[:, 0] * dist_to_mean[:, 1]
)
alpha = torch.min(
opacity[gaussian_index] * torch.exp(gaussian_density),
torch.tensor([MAX_GAUSSIAN_DENSITY], device=screen_means.device),
).float()
# For numerical stability
valid = (alpha > MIN_ALPHA) & (gaussian_density <= 0)
valid_mesh = mesh[valid, :]
# Update the screen pixels with the alpha blending values for each of the pixel
screen[valid_mesh[:, 0], valid_mesh[:, 1], :] += (
alpha[valid, None] * rgb[gaussian_index] * opacity_buffer[valid_mesh[:, 0], valid_mesh[:, 1], None]
)
# Update the opacity buffer to track how much transmittance is left before each pixel is "saturated"
# i.e cannot transmit color from "deeper" gaussians
opacity_buffer[valid_mesh[:, 0], valid_mesh[:, 1]] = opacity_buffer[valid_mesh[:, 0], valid_mesh[:, 1]] * (
1 - alpha[valid]
)
return screen, opacity_buffer
@click.command()
@click.option("--input_dir", type=str, default="")
@click.option("--trained_model_path", type=str, default="")
@click.option("--output_path", type=str, default="")
@click.option("--scene-index", type=int, default=0)
@click.option("--scale-factor", type=int, default=2)
@click.option("--generate_video", is_flag=True, type=bool, default=False)
def run_rasterization(
input_dir: str,
    trained_model_path: str,
output_path: Optional[str],
scene_index: int = 0,
scale_factor: int = 2,
generate_video: bool = False,
) -> None:
torch.set_num_threads(os.cpu_count() - 1)
device = "cuda" if torch.cuda.is_available() else "cpu"
    # Loading the scenes, i.e. the per-image scene-specific information (camera pose and image name)
logger.info(f"Fetching scenes from: {input_dir}")
scenes, cam_info = read_scene(path_to_scene=input_dir)
scene = scenes[scene_index]
    # Loading the ground-truth image from Mip-NeRF 360
    # images_{scale_factor}: i.e. if scale_factor is 2, the image has been downscaled by a factor of 2
gt_img_path = os.path.join(os.path.join(input_dir, f"images_{scale_factor}"), scene.name)
img = Image.open(gt_img_path)
fx, fy, _, _ = cam_info[1].params
focals = np.array([fx, fy])
width, height = img.size
    # Note: in the original implementation, the formula relating the focals to the FOV is missing a factor 2
    # We'll account for it downstream (see compute_2d_covariance)
fov_x = 2 * np.arctan(cam_info[1].width / (2 * fx))
fov_y = 2 * np.arctan(cam_info[1].height / (2 * fy))
tan_fov_x = math.tan(fov_x * 0.5)
tan_fov_y = math.tan(fov_y * 0.5)
qvec = torch.tensor(scene.qvec)
tvec = torch.tensor(scene.tvec)
logger.info(
f'Fetching trained model from: {os.path.join(trained_model_path, "point_cloud/iteration_30000/point_cloud.ply")}'
)
plydata = PlyData.read(os.path.join(trained_model_path, "point_cloud/iteration_30000/point_cloud.ply"))
gaussian_means = torch.tensor(
np.stack([plydata.elements[0]["x"], plydata.elements[0]["y"], plydata.elements[0]["z"],]).T, device=device
).float()
covariance_matrices = get_covariance_matrix_from_mesh(plydata).float().to(device)
opacity = torch.sigmoid(torch.tensor(np.array(plydata.elements[0]["opacity"]), device=device))
    # Matrices to transform coordinates from world space to camera space, and from camera space to the screen
world_to_camera = get_world_to_camera_matrix(qvec, tvec).transpose(0, 1).to(device)
projection_matrix = get_projection_matrix(fov_x, fov_y).transpose(0, 1).to(device)
# Combine both
full_proj_transform = (world_to_camera.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
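    # full_proj_transform maps world coordinates directly to homogeneous clip space
    # (both matrices are stored transposed, so points multiply on the right as row vectors)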
# Extracting the gaussian colors
colors = read_color_components(plydata).to(device)
rgb = sh_to_rgb(gaussian_means, colors, world_to_camera, degree=3)
    # Projecting points into camera space for sub-tasks downstream
camera_space_gaussian_means = project_to_camera_space(gaussian_means, world_to_camera)
    # Projecting points into homogeneous clip space (screen coordinates come after the perspective divide below)
points = gaussian_means @ full_proj_transform[:3, :] + full_proj_transform[-1, :]
# Frustum culling (essentially filtering the points that are too close to the camera)
frustum_culling_filter = camera_space_gaussian_means[:, 2] < 0.2
points[frustum_culling_filter] = 0.0
    # Applying the perspective divide to project points into NDC space
p_w = 1.0 / (points[:, -1] + 0.0000001)
p_proj = points[:, :-1] * p_w[:, None]
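    # p_proj now holds normalized device coordinates: x and y lie in [-1, 1] for points inside the view frustum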
projected_covariances = compute_2d_covariance(
covariance_matrices, camera_space_gaussian_means, tan_fov_x, tan_fov_y, focals, world_to_camera,
)
    # Zeroing the projected covariance of culled gaussians ensures they are skipped during rasterization
projected_covariances[frustum_culling_filter] = 0.0
    # Map NDC coordinates to pixel (screen) coordinates
screen_gaussians = ((p_proj[:, :2] + 1.0) * torch.tensor([width, height], device=device) - 1.0) / 2
covering_bboxes = compute_covering_bbox(screen_gaussians, projected_covariances, width, height)
det = (
projected_covariances[:, 0, 0] * projected_covariances[:, 1, 1]
- projected_covariances[:, 1, 0] * projected_covariances[:, 0, 1]
)
    # det can underflow to 0, so we have to zero out the inverse of det as well
    # More generally, a zero determinant means the projected gaussian is degenerate (its density collapses to a line or a point)
det_inv = torch.where(det == 0, 0, 1 / det)
    # Compute the components of the inverse 2D covariance (the conic), used to evaluate each gaussian's density
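    # For a 2x2 covariance [[a, b], [b, c]], the inverse is (1/det) * [[c, -b], [-b, a]]:
    # the three components stacked below are exactly c/det, a/det and -b/det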
sigmas = torch.stack(
[
projected_covariances[:, 1, 1] * det_inv[:],
projected_covariances[:, 0, 0] * det_inv[:],
-projected_covariances[:, 0, 1] * det_inv[:],
],
dim=-1,
)
    # As mentioned earlier, we convert the block-space bboxes back to pixel coordinates
    # This step would not be needed if we dropped the BLOCK_SIZE paradigm of the CUDA kernels
x_min = torch.clamp(covering_bboxes[:, 0] * BLOCK_SIZE, 0, width - 1)
y_min = torch.clamp(covering_bboxes[:, 1] * BLOCK_SIZE, 0, height - 1)
x_max = torch.clamp(covering_bboxes[:, 2] * BLOCK_SIZE, 0, width - 1)
y_max = torch.clamp(covering_bboxes[:, 3] * BLOCK_SIZE, 0, height - 1)
bboxes = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
bbox_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])
    # Since rasterization uses alpha blending, we need to rasterize gaussians in
    # increasing depth order (i.e. from near to far)
depths = camera_space_gaussian_means[:, 2]
depth_sorted_gaussians = torch.sort(depths).indices
if generate_video:
os.makedirs(output_path, exist_ok=True)
os.makedirs(os.path.join(output_path, "images"), exist_ok=True)
"""
Here is where the inefficiency comes in compared to the official implementation:
since we cannot parallelize the rasterization for each pixel, we have to loop over the
gaussians and rasterize them one by one (instead of distributing this process with a CUDA kernel)
"""
iteration_step = 0
screen = torch.zeros((int(width), int(height), 3), device=device).float()
opacity_buffer = torch.ones((int(width), int(height)), device=device).float()
for gaussian_index in tqdm.tqdm(depth_sorted_gaussians):
if bbox_area[gaussian_index] == 0 or (torch.any(sigmas[gaussian_index] == 0)):
            # Either the gaussian covers no pixels or its projected covariance is degenerate (e.g. it was culled)
continue
screen, opacity_buffer = rasterize_gaussian(
gaussian_index, bboxes, screen, screen_gaussians, sigmas, rgb, opacity_buffer, opacity,
)
if iteration_step % 1000 == 0 and generate_video:
img = Image.fromarray((screen[:, :, :3].transpose(1, 0).cpu().numpy() * 255.0).astype(np.uint8))
img.save(os.path.join(output_path, "images", f"image_iter_{str(iteration_step).zfill(7)}.png",))
iteration_step += 1
if generate_video:
framerate = 20
for i in range(1, 2 * framerate + 1):
            # We append 2 seconds of video to leave time to see the fully reconstructed image before the video ends
img.save(
os.path.join(output_path, "images", f"image_iter_{str(iteration_step + 1000*i + 1).zfill(7)}.png",)
)
video_path = os.path.join(output_path, "video_render.mp4")
if os.path.exists(video_path):
os.remove(video_path)
cmd = f'ffmpeg -framerate {framerate} -pattern_type glob -i "{os.path.join(output_path, "images", "image_iter_*.png")}" -r 10 -vcodec libx264 -s {width - (width % 2)}x{height - (height % 2)} -pix_fmt yuv420p {video_path}'
subprocess.run(cmd, shell=True, check=True)
plt.figure(figsize=(10, 10))
plt.subplot(2, 1, 1)
plt.imshow(screen[:, :, :3].transpose(1, 0).cpu())
plt.title("Rendered Image")
plt.subplot(2, 1, 2)
plt.imshow(mpimg.imread(gt_img_path))
plt.title("Reference Image")
plt.show()
if __name__ == "__main__":
run_rasterization()