Skip to content

Commit

Permalink
Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Browse files Browse the repository at this point in the history
Summary: See D52890934

Reviewed By: malfet, r-barnes

Differential Revision: D66245100

fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
  • Loading branch information
ezyang authored and facebook-github-bot committed Nov 21, 2024
1 parent e20cbe9 commit f6c2ca6
Show file tree
Hide file tree
Showing 23 changed files with 147 additions and 1 deletion.
1 change: 1 addition & 0 deletions projects/implicitron_trainer/impl/optimizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __call__(
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
p_groups = [
Expand Down
1 change: 1 addition & 0 deletions projects/implicitron_trainer/impl/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def _training_or_validation_epoch(
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
if hasattr(model, "visualize"):
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
model.visualize(
viz,
visdom_env_imgs,
Expand Down
4 changes: 4 additions & 0 deletions pytorch3d/implicitron/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def adjust_camera_to_bbox_crop_(

focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
camera.principal_point[0],
image_size_wh,
)
Expand All @@ -341,6 +342,7 @@ def adjust_camera_to_bbox_crop_(
)

camera.focal_length = focal_length[None]
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
camera.principal_point = principal_point_cropped[None]


Expand All @@ -352,6 +354,7 @@ def adjust_camera_to_image_scale_(
) -> PerspectiveCameras:
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
camera.principal_point[0],
original_size_wh,
)
Expand All @@ -368,6 +371,7 @@ def adjust_camera_to_image_scale_(
image_size_wh_output,
)
camera.focal_length = focal_length_scaled[None]
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
camera.principal_point = principal_point_scaled[None]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,15 @@ def _get_resnet_stage_feature_name(self, stage) -> str:
return f"res_layer_{stage + 1}"

def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
return (img - self._resnet_mean) / self._resnet_std

def get_feat_dims(self) -> int:
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return sum(self._feat_dim.values())

def forward(
Expand Down Expand Up @@ -183,7 +189,12 @@ def forward(
else:
imgs_normed = imgs_resized
# is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
feats = self.stem(imgs_normed)
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
# `Union[Tensor, Module]`.
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
# `Union[Tensor, Module]`.
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
feats = layer(feats)
# just a sanity check below
Expand Down
4 changes: 4 additions & 0 deletions pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ def curried_viewpooler(pts):
)
custom_args["global_code"] = global_code

# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions:
func.bind_args(**custom_args)

Expand All @@ -500,6 +502,8 @@ def curried_viewpooler(pts):
# Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the
# bound arguments.
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions:
func.unbind_args()

Expand Down
3 changes: 3 additions & 0 deletions pytorch3d/implicitron/models/global_encoder/autodecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _build_key_map(
return key_map

def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
return (self._autodecoder_codes.weight**2).mean()

def get_encoding_dim(self) -> int:
Expand All @@ -95,13 +96,15 @@ def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tenso
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
# `Tensor`.
x = torch.tensor(
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
[self._key_map[elem] for elem in x],
dtype=torch.long,
device=next(self.parameters()).device,
)
except StopIteration:
raise ValueError("Not enough n_instances in the autodecoder") from None

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._autodecoder_codes(x)

def _load_key_map_hook(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def forward(
if frame_timestamp.shape[-1] != 1:
raise ValueError("Frame timestamp's last dimensions should be one.")
time = frame_timestamp / self.time_divisor
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._harmonic_embedding(time)

def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,14 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
# if the skip tensor is None, we use `x` instead.
z = x
skipi = 0
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
# `Union[Tensor, Module]`.
for li, layer in enumerate(self.mlp):
# pyre-fixme[58]: `in` is not supported for right operand type
# `Union[Tensor, Module]`.
if li in self._input_skips:
if self._skip_affine_trans:
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else:
y = torch.cat((y, z), dim=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,16 @@ def forward(
self.embed_fn is None and fun_viewpool is None and global_code is None
):
return torch.tensor(
[], device=rays_points_world.device, dtype=rays_points_world.dtype
[],
device=rays_points_world.device,
dtype=rays_points_world.dtype,
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
# `Union[Module, Tensor]`.
).view(0, self.out_dim)

embeddings = []
if self.embed_fn is not None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
embeddings.append(self.embed_fn(rays_points_world))

if fun_viewpool is not None:
Expand All @@ -164,13 +169,19 @@ def forward(

embedding = torch.cat(embeddings, dim=-1)
x = embedding
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
for layer_idx in range(self.num_layers - 1):
if layer_idx in self.skip_in:
x = torch.cat([x, embedding], dim=-1) / 2**0.5

# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
x = self.linear_layers[layer_idx](x)

# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
if layer_idx < self.num_layers - 2:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
x = self.softplus(x)

return x
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self.color_layer((self.intermediate_linear(features), rays_embedding))

@staticmethod
Expand Down Expand Up @@ -195,6 +197,8 @@ def forward(
embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world,
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
# got `Union[None, Tensor, Module]`.
xyz_embedding_function=(
self.harmonic_embedding_xyz if self.input_xyz else None
),
Expand All @@ -206,19 +210,23 @@ def forward(
)

# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
features = self.xyz_encoder(embeds)
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattenned rays; reshaping to the correct spatial size
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
features = features.reshape(*rays_points_world.shape[:-1], -1)

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raw_densities = self.density_layer(features)
# raw_densities.shape = [minibatch x ... x 1] in [0-1]

if self.xyz_ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")

# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def forward(

embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world,
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
# got `Union[Tensor, Module]`.
xyz_embedding_function=self._harmonic_embedding,
global_code=global_code,
fun_viewpool=fun_viewpool,
Expand All @@ -112,6 +114,7 @@ def forward(

# Before running the network, we have to resize embeds to ndims=3,
# otherwise the SRN layers consume huge amounts of memory.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raymarch_features = self._net(
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
)
Expand Down Expand Up @@ -166,7 +169,9 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
rays_embedding = self._harmonic_embedding(rays_directions_normed)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._color_layer((features, rays_embedding))

def forward(
Expand Down Expand Up @@ -195,20 +200,24 @@ def forward(
denoting the color of each ray point.
"""
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
features = self._net(raymarch_features)
# features.shape = [minibatch x ... x self.n_hidden_units]

if self.ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")

# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions

# NNs operate on the flattenned rays; reshaping to the correct spatial size
features = features.reshape(*raymarch_features.shape[:-1], -1)

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raw_densities = self._density_layer(features)

rays_colors = self._get_colors(features, directions)
Expand Down Expand Up @@ -269,6 +278,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]
srn_raymarch_function.
"""

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
net = self._hypernet(global_code)

# use the hyper-net generated network to instantiate the raymarch module
Expand Down Expand Up @@ -304,6 +314,8 @@ def forward(
# across LSTM iterations for the same global_code.
if self.cached_srn_raymarch_function is None:
# generate the raymarching network from the hypernet
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
# `cached_srn_raymarch_function`.
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
(srn_raymarch_function,) = cast(
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
Expand Down Expand Up @@ -331,6 +343,7 @@ def __post_init__(self):
def create_raymarch_function(self) -> None:
self.raymarch_function = SRNRaymarchFunction(
latent_dim=self.latent_dim,
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
**self.raymarch_function_args,
)

Expand Down Expand Up @@ -389,6 +402,7 @@ def create_hypernet(self) -> None:
self.hypernet = SRNRaymarchHyperNet(
latent_dim=self.latent_dim,
latent_dim_hypernet=self.latent_dim_hypernet,
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
**self.hypernet_args,
)

Expand Down
11 changes: 11 additions & 0 deletions pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def change_individual_resolution(tensor, wanted_resolution):
for name, tensor in vars(grid_values_with_wanted_resolution).items()
}

# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self.values_type(**params), True

def get_resolution_change_epochs(self) -> Tuple[int, ...]:
Expand Down Expand Up @@ -882,6 +883,7 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
torch.Tensor of shape (..., n_features)
"""
locator = self._get_volume_locator()
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
Expand All @@ -895,6 +897,7 @@ def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
replace current parameters
"""
if self.hold_voxel_grid_as_parameters:
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
self.params = torch.nn.ParameterDict(
{
k: torch.nn.Parameter(val)
Expand Down Expand Up @@ -945,6 +948,7 @@ def _apply_epochs(self, epoch: int) -> bool:
Returns:
True if parameter change has happened else False.
"""
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
grid_values = self.voxel_grid.values_type(**self.params)
grid_values, change = self.voxel_grid.change_resolution(
grid_values, epoch=epoch
Expand Down Expand Up @@ -992,16 +996,21 @@ def _create_parameters_with_new_size(
"""
'''
new_params = {}
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for name in self.params:
key = prefix + "params." + name
if key in state_dict:
new_params[name] = torch.zeros_like(state_dict[key])
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))

def get_device(self) -> torch.device:
"""
Returns torch.device on which module parameters are located
"""
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return next(val for val in self.params.values() if val is not None).device

def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
Expand All @@ -1018,13 +1027,15 @@ def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
"""
locator = self._get_volume_locator()
# torch.nn.modules.module.Module]` is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
old_grid_values = self.voxel_grid.values_type(**self.params)
new_grid_values = self.voxel_grid.crop_world(
min_point, max_point, old_grid_values, locator
)
grid_values, _ = self.voxel_grid.change_resolution(
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
)
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
self.params = torch.nn.ParameterDict(
{
k: torch.nn.Parameter(val)
Expand Down
Loading

0 comments on commit f6c2ca6

Please sign in to comment.