-
Notifications
You must be signed in to change notification settings - Fork 1
/
gsMerge.py
84 lines (70 loc) · 3.03 KB
/
gsMerge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from myUtils import gaus_transform, gaus_append, gaus_copy, rescale
import torch
from gaussian_renderer import render, network_gui
from scene import GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
# Optional dependency: record whether torch's TensorBoard writer could be imported.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    TENSORBOARD_FOUND = False
else:
    TENSORBOARD_FOUND = True
#Utils:
def gaus_copy(g, gnew):
gnew.active_sh_degree = g.active_sh_degree
gnew._xyz = g._xyz
gnew._features_dc = g._features_dc
gnew._features_rest = g._features_rest
gnew._scaling = g._scaling
gnew._rotation = g._rotation
gnew._opacity = g._opacity
def _make_gaussians(dataset: ModelParams, opt: OptimizationParams):
    """Create an empty GaussianModel at ``dataset.sh_degree`` and call ``training_setup(opt)`` on it."""
    gaussians = GaussianModel(dataset.sh_degree)
    gaussians.training_setup(opt)
    return gaussians


def _restore_checkpoint(gaussians, checkpoint: str, opt: OptimizationParams) -> None:
    """Restore *gaussians* from a ``(model_params, iteration)`` tuple saved at *checkpoint*."""
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    model_params, _iteration = torch.load(checkpoint)
    gaussians.restore(model_params, opt)


def _serve_viewer(gaussians, dataset: ModelParams, pipe: PipelineParams, background,
                  poll_interval: float = 0.01):
    """Render *gaussians* for the network GUI viewer in an endless loop (never returns).

    Args:
        gaussians: the model handed to ``render`` for every camera the viewer sends.
        dataset: only ``dataset.source_path`` is used; it is sent back with each frame.
        pipe: pipeline params; ``convert_SHs_python`` and ``compute_cov3D_python``
            are overwritten with the values received from the viewer.
        background: background colour tensor passed to ``render``.
        poll_interval: seconds to wait between connection attempts while no viewer
            is connected.
    """
    import time

    while True:
        if network_gui.conn is None:
            network_gui.try_connect()
            if network_gui.conn is None:
                # Nothing else happens in this loop, so without a pause it would
                # spin a CPU core whenever try_connect returns immediately (the
                # upstream 3DGS network_gui listener is non-blocking — TODO confirm
                # for this project's copy).
                time.sleep(poll_interval)
                continue
        while network_gui.conn is not None:
            try:
                net_image_bytes = None
                (custom_cam, _do_training, pipe.convert_SHs_python,
                 pipe.compute_cov3D_python, _keep_alive, scaling_modifier) = network_gui.receive()
                if custom_cam is not None:
                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifier)["render"]
                    # Clamp to [0, 1], scale to bytes, reorder to (H, W, C) and hand the
                    # host-side buffer to the viewer without another copy.
                    net_image_bytes = memoryview(
                        (torch.clamp(net_image, min=0, max=1.0) * 255)
                        .byte()
                        .permute(1, 2, 0)
                        .contiguous()
                        .cpu()
                        .numpy()
                    )
                network_gui.send(net_image_bytes, dataset.source_path)
            except Exception as e:
                # Any receive/render/send failure drops the connection; the outer
                # loop then waits for the viewer to reconnect.
                print(e)
                network_gui.conn = None


def gsMerge(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, testing_iterations, saving_iterations,
            checkpoint_iterations, checkpoint: str, checkpoint2: str, debug_from, RI1, TI1, RI2, TI2, R1, T1, R2, T2, Sscale,
            output_path: str = "gnew.pth"):
    """Merge two Gaussian models into one, save it, then serve it to the network GUI forever.

    Steps:
      1. Load the two models from *checkpoint* / *checkpoint2* (each is skipped when
         falsy, leaving that model empty).
      2. Copy each model by attribute reference (``gaus_copy``).
      3. Apply ``gaus_transform`` twice per copy: first with ``RI.t()`` / ``TI``, then
         with ``R`` / ``T``.
      4. Rescale the second copy by *Sscale*.
      5. Append both copies into a new model.
      6. Save the new model to *output_path* and start the viewer loop.

    Args:
        dataset, opt, pipe: model / optimisation / pipeline parameter objects.
        testing_iterations, saving_iterations, checkpoint_iterations, debug_from:
            accepted for call compatibility; not used here.
        checkpoint, checkpoint2: paths of ``(model_params, iteration)`` checkpoints.
        RI1, TI1, RI2, TI2: per-model tensors applied first (``RI`` is transposed via
            ``.t()``; presumably undoing each model's own pose — TODO confirm).
        R1, T1, R2, T2: per-model transforms applied second.
        Sscale: scale factor passed to ``rescale`` for the second model.
        output_path: where the merged model is saved (default ``"gnew.pth"``).

    Never returns: ends in an endless viewer loop.
    """
    g1 = _make_gaussians(dataset, opt)
    g2 = _make_gaussians(dataset, opt)
    if checkpoint:
        _restore_checkpoint(g1, checkpoint, opt)
        print("First model loaded from checkpoint.")
    if checkpoint2:
        _restore_checkpoint(g2, checkpoint2, opt)
        print("Second model loaded from checkpoint2.")

    g1copy = _make_gaussians(dataset, opt)
    g2copy = _make_gaussians(dataset, opt)
    gnew = _make_gaussians(dataset, opt)
    gaus_copy(g1, g1copy)
    gaus_copy(g2, g2copy)

    gaus_transform(g1copy, RI1.t(), TI1)
    gaus_transform(g1copy, R1, T1)
    gaus_transform(g2copy, RI2.t(), TI2)
    gaus_transform(g2copy, R2, T2)
    rescale(g2copy, Sscale)
    gaus_append(g1copy, g2copy, gnew)
    print(g1.active_sh_degree, g2.active_sh_degree)
    print(gnew.active_sh_degree)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # Saved in the same (captured_params, iteration) tuple layout that
    # _restore_checkpoint reads back; the iteration number is a fixed placeholder.
    saved_iteration = 10000
    torch.save((gnew.capture(), saved_iteration), output_path)
    print(f"fused model saved in {output_path}!")

    _serve_viewer(gnew, dataset, pipe, background)