Commit
Remove default typing argument in PACKAGE file [batch:79/416] [shard:4/N]

Reviewed By: MaggieMoss

Differential Revision: D64867383

fbshipit-source-id: efc412f93077c0c1dbb0911a10dc051d0309fe06
generatedunixname89002005307016 authored and facebook-github-bot committed Oct 28, 2024
1 parent 9480338 commit 16e3f14
Showing 13 changed files with 72 additions and 6 deletions.
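Nearly all of the added lines below are Pyre suppression comments. As a minimal illustration (not part of this commit; the function name and signature are invented for the example), a `# pyre-fixme[N]` comment silences the Pyre error with code N on the line immediately following it, the same pattern used for the `Optional` attribute errors suppressed in flowtorch/bijectors/permute.py:

from typing import Optional

import torch


def first_dim(t: Optional[torch.Tensor]) -> int:
    # Pyre reports "Undefined attribute [16]: `Optional` has no attribute `size`"
    # here because `t` may be None. The marker on the next line suppresses only
    # error code 16 for the statement that follows and has no runtime effect.
    # pyre-fixme[16]: `Optional` has no attribute `size`.
    return t.size(0)

The type-safe alternative would be to narrow the Optional first (for example with an explicit None check), which is roughly what the TODO about type-safe buffer access in autoregressive.py alludes to.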
4 changes: 4 additions & 0 deletions flowtorch/bijectors/autoregressive.py
@@ -56,11 +56,15 @@ def inverse(
# NOTE: Inversion is an expensive operation that scales in the
# dimension of the input
permutation = (
# pyre-fixme[16]: Item `None` of `Union[None, Parameters, ModuleList]`
# has no attribute `permutation`.
self._params_fn.permutation
) # TODO: type-safe named buffer (e.g. "permutation") access
# TODO: Make permutation, inverse work for other event shapes
log_detJ: torch.Tensor | None = None
for idx in cast(torch.LongTensor, permutation):
# pyre-fixme[29]: `Union[None, flowtorch.parameters.base.Parameters,
# torch.nn.modules.container.ModuleList]` is not a function.
_params = self._params_fn(x_new.clone(), context=context)
x_temp, log_detJ = self._inverse(y, params=_params)
x_new[..., idx] = x_temp[..., idx]
19 changes: 19 additions & 0 deletions flowtorch/bijectors/bijective_tensor.py
@@ -22,11 +22,17 @@ def register(
log_detJ: Tensor | None,
mode: str,
) -> "BijectiveTensor":
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_input`.
self._input = input
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_output`.
self._output = output
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_context`.
self._context = context
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_bijector`.
self._bijector = bijector
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_log_detJ`.
self._log_detJ = log_detJ
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_mode`.
self._mode = mode

if not (self.from_forward() or self.from_inverse()):
@@ -57,12 +63,14 @@ def check_bijector(self, bijector: "Bijector") -> bool:
return is_bijector

def bijectors(self) -> Iterator["Bijector"]:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_bijector`.
yield self._bijector
for parent in self.parents():
if isinstance(parent, BijectiveTensor):
yield parent._bijector

def get_parent_from_bijector(self, bijector: "Bijector") -> Tensor:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_bijector`.
if self._bijector is bijector:
return self.parent
for parent in self.parents():
@@ -73,15 +81,20 @@ def get_parent_from_bijector(self, bijector: "Bijector") -> Tensor:
raise RuntimeError("bijector not found in flow")

def check_context(self, context: Tensor | None) -> bool:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_context`.
return self._context is context

def from_forward(self) -> bool:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_mode`.
return self._mode == "forward"

def from_inverse(self) -> bool:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_mode`.
return self._mode == "inverse"

def detach_from_flow(self) -> Tensor:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_output`.
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_input`.
detached_tensor = self._output if self.from_forward() else self._input
if isinstance(detached_tensor, BijectiveTensor):
raise RuntimeError("the detached tensor is an instance of BijectiveTensor.")
@@ -90,8 +103,10 @@ def detach_from_flow(self) -> Tensor:
def has_ancestor(self, tensor: Tensor) -> bool:
if tensor is self:
return False # self is no parent of self
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_input`.
elif self.from_forward() and self._input is tensor:
return True
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_output`.
elif self.from_inverse() and self._output is tensor:
return True
elif self.from_forward() and isinstance(self._input, BijectiveTensor):
@@ -103,19 +118,23 @@ def has_ancestor(self, tensor: Tensor) -> bool:

@property
def log_detJ(self) -> Tensor | None:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_log_detJ`.
return self._log_detJ

@property
def parent(self) -> Tensor:
if self.from_forward():
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_input`.
return self._input
else:
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_output`.
return self._output

def parents(self) -> Iterator[Tensor]:
child: Tensor | BijectiveTensor = self
while True:
assert isinstance(child, BijectiveTensor)
# pyre-fixme[16]: `Tensor` has no attribute `parent`.
child = parent = child.parent
yield parent
if not isinstance(child, BijectiveTensor):
5 changes: 5 additions & 0 deletions flowtorch/bijectors/compose.py
@@ -58,6 +58,7 @@ def forward(
y = bijector.forward(x_temp, context) # type: ignore
if is_record_flow_graph_enabled() and requires_log_detJ():
if isinstance(y, BijectiveTensor) and y.from_forward():
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_log_detJ`.
_log_detJ = y._log_detJ
elif isinstance(x_temp, BijectiveTensor) and x_temp.from_inverse():
_log_detJ = x_temp._log_detJ
@@ -74,11 +75,14 @@ def forward(
# TODO: Check that this doesn't contain bugs!
if (
is_record_flow_graph_enabled()
# pyre-fixme[61]: `y` is undefined, or not always defined.
and not isinstance(y, BijectiveTensor)
# pyre-fixme[61]: `y` is undefined, or not always defined.
and not (isinstance(x, BijectiveTensor) and y in set(x.parents()))
):
# we exclude y that are bijective tensors for Compose
y = to_bijective_tensor(x, x_temp, context, self, log_detJ, mode="forward")
# pyre-fixme[61]: `y` is undefined, or not always defined.
return y

def inverse(
@@ -93,6 +97,7 @@ def inverse(
x = bijector.inverse(y_temp, context) # type: ignore
if is_record_flow_graph_enabled() and requires_log_detJ():
if isinstance(y_temp, BijectiveTensor) and y_temp.from_forward():
# pyre-fixme[16]: `BijectiveTensor` has no attribute `_log_detJ`.
_log_detJ = y_temp._log_detJ
elif isinstance(x, BijectiveTensor) and x.from_inverse():
_log_detJ = x._log_detJ
8 changes: 7 additions & 1 deletion flowtorch/bijectors/permute.py
@@ -33,6 +33,7 @@ def _forward(
if self.permutation is None:
self.permutation = torch.randperm(x.shape[-1])

# pyre-fixme[6]: For 3rd argument expected `Tensor` but got `Optional[Tensor]`.
y = torch.index_select(x, -1, self.permutation)
ladj = self._log_abs_det_jacobian(x, y, params)
return y, ladj
@@ -54,6 +55,11 @@ def inv_permutation(self) -> torch.Tensor | None:

result = torch.empty_like(self.permutation, dtype=torch.long)
result[self.permutation] = torch.arange(
self.permutation.size(0), dtype=torch.long, device=self.permutation.device
# pyre-fixme[16]: `Optional` has no attribute `size`.
# pyre-fixme[16]: `Optional` has no attribute `device`.
self.permutation.size(0),
dtype=torch.long,
# pyre-fixme[16]: `Optional` has no attribute `device`.
device=self.permutation.device,
)
return result
1 change: 1 addition & 0 deletions flowtorch/bijectors/tanh.py
@@ -34,4 +34,5 @@ def _inverse(
def _log_abs_det_jacobian(
self, x: torch.Tensor, y: torch.Tensor, params: Sequence[torch.Tensor] | None
) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got `float`.
return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))
2 changes: 1 addition & 1 deletion flowtorch/bijectors/volume_preserving.py
@@ -17,6 +17,6 @@ def _log_abs_det_jacobian(
return torch.zeros(
x.size()[: -self.domain.event_dim],
dtype=x.dtype,
layout=x.layout, # pyre-ignore[16]
layout=x.layout,
device=x.device,
)
23 changes: 19 additions & 4 deletions flowtorch/distributions/flow.py
@@ -28,21 +28,30 @@ def __init__(
# TODO: Confirm that the following logic works. Shouldn't it use
# .domain and .codomain?? Infer shape from constructed self.bijector
shape = (
self.base_dist.batch_shape + self.base_dist.event_shape # pyre-ignore[16]
self.base_dist.batch_shape
# pyre-fixme[58]: `+` is not supported for operand types `Size` and `Size`.
+ self.base_dist.event_shape
)
event_dim = self.bijector.domain.event_dim # type: ignore
event_dim = max(event_dim, len(self.base_dist.event_shape))
batch_shape = shape[: len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim :]

dist.Distribution.__init__(
self, batch_shape, event_shape, validate_args=validate_args
self,
# pyre-fixme[6]: For 2nd argument expected `Size` but got `Tuple[int, ...]`.
batch_shape,
# pyre-fixme[6]: For 3rd argument expected `Size` but got `Tuple[int, ...]`.
event_shape,
validate_args=validate_args,
)

def condition(self, context: torch.Tensor) -> "Flow":
self._context = context
return self

# pyre-fixme[14]: `sample` overrides method defined in `Distribution`
# inconsistently.
def sample(
self,
sample_shape: Tensor | torch.Size = _default_sample_shape,
@@ -57,10 +66,14 @@ def sample(
if context is None:
context = self._context
with torch.no_grad():
# pyre-fixme[6]: For 1st argument expected `Union[List[int], Size,
# typing.Tuple[int, ...]]` but got `Union[Size, Tensor]`.
x = self.base_dist.sample(sample_shape)
x = self.bijector.forward(x, context) # type: ignore
return x

# pyre-fixme[14]: `rsample` overrides method defined in `Distribution`
# inconsistently.
def rsample(
self,
sample_shape: Tensor | torch.Size = _default_sample_shape,
@@ -74,6 +87,8 @@ def rsample(
"""
if context is None:
context = self._context
# pyre-fixme[6]: For 1st argument expected `Union[List[int], Size,
# typing.Tuple[int, ...]]` but got `Union[Size, Tensor]`.
x = self.base_dist.rsample(sample_shape)
x = self.bijector.forward(x, context) # type: ignore
return x
@@ -109,7 +124,7 @@ def log_prob(
"""
if context is None:
context = self._context
event_dim = len(self.event_shape) # pyre-ignore[16]
event_dim = len(self.event_shape)

x = self.bijector.inverse(value, context) # type: ignore
log_prob = -_sum_rightmost(
@@ -118,7 +133,7 @@ def log_prob(
)
log_prob = log_prob + _sum_rightmost(
self.base_dist.log_prob(x),
event_dim - len(self.base_dist.event_shape), # pyre-ignore[16]
event_dim - len(self.base_dist.event_shape),
)

return log_prob
3 changes: 3 additions & 0 deletions flowtorch/distributions/neals_funnel.py
@@ -19,8 +19,11 @@ class NealsFunnel(dist.Distribution):
def __init__(self, validate_args: Any = None) -> None:
d = 2
batch_shape, event_shape = torch.Size([]), (d,)
# pyre-fixme[6]: For 2nd argument expected `Size` but got `Tuple[int]`.
super().__init__(batch_shape, event_shape, validate_args=validate_args)

# pyre-fixme[14]: `rsample` overrides method defined in `Distribution`
# inconsistently.
def rsample(
self,
sample_shape: torch.Tensor | torch.Size | None = None,
6 changes: 6 additions & 0 deletions flowtorch/docs.py
@@ -89,6 +89,10 @@ def generate_class_markdown(symbol_name: str, entity: Any) -> str:
try:
if hasattr(member_object, "__wrapped__"):
# decorators = get_decorators(member_object)
# pyre-fixme[16]: Item `BuiltinFunctionType` of
# `Union[BuiltinFunctionType, ClassMethodDescriptorType, FunctionType,
# MethodDescriptorType, MethodType, WrapperDescriptorType]` has no
# attribute `__wrapped__`.
member_object = member_object.__wrapped__
except Exception:
pass
@@ -267,6 +271,7 @@ def walk_packages(
else:
finder = None

# pyre-fixme[61]: `finder` is undefined, or not always defined.
if finder is not None:
module = finder.load_module(this_modname)

@@ -278,6 +283,7 @@
modules[this_modname] = (module, documentable_symbols(module))

del module
# pyre-fixme[61]: `finder` is undefined, or not always defined.
del finder

else:
1 change: 1 addition & 0 deletions flowtorch/nn/made.py
@@ -116,6 +116,7 @@ def __init__(
super().__init__(in_features, out_features, bias)
self.register_buffer("mask", mask.data)

# pyre-fixme[14]: `forward` overrides method defined in `Linear` inconsistently.
def forward(self, _input: torch.Tensor) -> torch.Tensor:
masked_weight = self.weight * self.mask
return F.linear(_input, masked_weight, self.bias)
3 changes: 3 additions & 0 deletions flowtorch/ops/__init__.py
@@ -46,6 +46,8 @@ def _select_bins(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
# Note that by convention, the context variable batch dimensions must broadcast
# over the input batch dimensions.
if len(idx.shape) >= len(x.shape):
# pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, ...]`
# and `Size`.
x = x.reshape((1,) * (len(idx.shape) - len(x.shape)) + x.shape)
x = x.expand(idx.shape[:-2] + (-1,) * 2)

@@ -268,6 +270,7 @@ def monotonic_rational_spline(
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
# pyre-fixme[16]: `float` has no attribute `pow`.
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
1 change: 1 addition & 0 deletions flowtorch/parameters/dense_autoregressive.py
@@ -179,6 +179,7 @@ def _forward(

# results ~ (batch_shape, param_shapes[0]), ...
result = tuple(
# pyre-fixme[58]: `+` is not supported for operand types `Size` and `Size`.
h_slice.view(batch_shape + p_shape)
for h_slice, p_shape in zip(result, list(self.param_shapes))
)
2 changes: 2 additions & 0 deletions flowtorch/utils.py
@@ -79,6 +79,7 @@ def _walk_packages(
else:
finder = None

# pyre-fixme[61]: `finder` is undefined, or not always defined.
if finder is not None:
module = finder.load_module(this_modname)

@@ -90,6 +91,7 @@
classes.extend(this_classes)

del module
# pyre-fixme[61]: `finder` is undefined, or not always defined.
del finder

else:
