diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index bb6a8c4..fbb1e82 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -21,7 +21,6 @@ def unconstrained_rational_quadratic_spline( min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE, - enable_identity_init=False, ): inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask @@ -34,16 +33,18 @@ def unconstrained_rational_quadratic_spline( constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 + + outputs = torch.where(outside_interval_mask, inputs, outputs) + logabsdet = torch.where(outside_interval_mask, torch.zeros_like(logabsdet), logabsdet) else: raise RuntimeError("{} tails are not implemented.".format(tails)) if torch.any(inside_interval_mask): ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], + # outputs[inside_interval_mask], + # logabsdet[inside_interval_mask], + a, + b, ) = rational_quadratic_spline( inputs=inputs[inside_interval_mask], unnormalized_widths=unnormalized_widths[inside_interval_mask, :], @@ -57,12 +58,29 @@ def unconstrained_rational_quadratic_spline( min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, - enable_identity_init=enable_identity_init, ) + + + # turn inside_interval_mask into an int tensor + in_mask = inside_interval_mask.reshape(-1).type(torch.int64) + repeat_values = in_mask * (torch.arange(in_mask.shape[-1], dtype=in_mask.dtype, device=in_mask.device) + 1) + # find the number of trailing zeros in repeat_values + repeat_values = repeat_values[repeat_values > 0] + trail = len(in_mask) - repeat_values[-1] + # add a zero to the beginning of the repeat_values and take the diff + repeat_values = torch.cat([torch.zeros_like(repeat_values[:1]), repeat_values], dim=-1).diff() + a_ = torch.repeat_interleave(a, repeat_values, dim=-1) + # add trail zeros to the end of a_ + a_ = F.pad(a_, pad=(0, trail), mode="constant", value=0.0) + a_ = a_.reshape(outputs.shape) + b_ = torch.repeat_interleave(b, repeat_values, dim=-1) + b_ = F.pad(b_, pad=(0, trail), mode="constant", value=0.0) + b_ = b_.reshape(logabsdet.shape) + outputs = torch.where(inside_interval_mask, a_, outputs) + logabsdet = torch.where(inside_interval_mask, b_, logabsdet) return outputs, logabsdet - def rational_quadratic_spline( inputs, unnormalized_widths, @@ -76,7 +94,6 @@ def rational_quadratic_spline( min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE, - enable_identity_init=False, ): if torch.min(inputs) < left or torch.max(inputs) > right: raise InputOutsideDomain() @@ -97,11 +114,7 @@ def rational_quadratic_spline( cumwidths[..., -1] = right widths = cumwidths[..., 1:] - cumwidths[..., :-1] - if enable_identity_init: #flow is the identity if initialized with parameters equal to zero - beta = np.log(2) / (1 - min_derivative) - else: #backward compatibility - beta = 1 - derivatives = min_derivative + F.softplus(unnormalized_derivatives, beta=beta) + derivatives = min_derivative + F.softplus(unnormalized_derivatives) heights = F.softmax(unnormalized_heights, dim=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights