Fix In-place Assignments for PiecewiseRationalQuadratic Compatibility with functorch and torch2.0 (#77)
base: master
```diff
@@ -21,7 +21,6 @@ def unconstrained_rational_quadratic_spline(
     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
     min_derivative=DEFAULT_MIN_DERIVATIVE,
-    enable_identity_init=False,
 ):
     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
     outside_interval_mask = ~inside_interval_mask
```
```diff
@@ -34,16 +33,18 @@ def unconstrained_rational_quadratic_spline(
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
         unnormalized_derivatives[..., -1] = constant

-        outputs[outside_interval_mask] = inputs[outside_interval_mask]
-        logabsdet[outside_interval_mask] = 0
+        outputs = torch.where(outside_interval_mask, inputs, outputs)
+        logabsdet = torch.where(outside_interval_mask, torch.zeros_like(logabsdet), logabsdet)
     else:
         raise RuntimeError("{} tails are not implemented.".format(tails))
```

Review comment on the new `outputs = torch.where(...)` line: Should we drop the zero init above if we're copying here anyway?

Review comment on the new `logabsdet = torch.where(...)` line: Actually, does the original line even have an effect? We're assigning zeros to what is already initialized to zeros.
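For context on why this hunk swaps in-place masked assignment for `torch.where`: indexing with a data-dependent boolean mask produces a data-dependent shape, which functorch's `vmap` rejects, while the out-of-place `torch.where` form traces cleanly under `torch.func` transforms and `torch.compile`. A minimal sketch (function names are illustrative, not from the PR):

```python
import torch
from torch.func import vmap  # torch >= 2.0; functorch.vmap in earlier releases

def fill_outside_inplace(inputs, outside_interval_mask):
    # Old pattern: in-place assignment through a boolean mask. Indexing with a
    # data-dependent mask yields a data-dependent shape, which vmap rejects.
    outputs = torch.zeros_like(inputs)
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    return outputs

def fill_outside_where(inputs, outside_interval_mask):
    # New pattern: out-of-place select; the zero init doubles as the "else"
    # branch of torch.where, which is what the review question above notes.
    outputs = torch.zeros_like(inputs)
    return torch.where(outside_interval_mask, inputs, outputs)

x = torch.randn(4, 8)
mask = x.abs() > 1.0

# Both agree in eager mode...
assert torch.equal(fill_outside_inplace(x, mask), fill_outside_where(x, mask))

# ...but only the torch.where version composes with vmap:
vmap(fill_outside_where)(x, mask)      # works
# vmap(fill_outside_inplace)(x, mask)  # raises a data-dependent-shape error
```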
```diff
     if torch.any(inside_interval_mask):
         (
-            outputs[inside_interval_mask],
-            logabsdet[inside_interval_mask],
+            # outputs[inside_interval_mask],
+            # logabsdet[inside_interval_mask],
+            a,
+            b,
         ) = rational_quadratic_spline(
             inputs=inputs[inside_interval_mask],
             unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
```

Review comment on the commented-out lines: Let's not leave unused code in comments.
```diff
@@ -57,12 +58,29 @@ def unconstrained_rational_quadratic_spline(
             min_bin_width=min_bin_width,
             min_bin_height=min_bin_height,
             min_derivative=min_derivative,
-            enable_identity_init=enable_identity_init,
         )

+        # turn inside_interval_mask into an int tensor
+        in_mask = inside_interval_mask.reshape(-1).type(torch.int64)
+        repeat_values = in_mask * (torch.arange(in_mask.shape[-1], dtype=in_mask.dtype, device=in_mask.device) + 1)
+        # find the number of trailing zeros in repeat_values
+        repeat_values = repeat_values[repeat_values > 0]
+        trail = len(in_mask) - repeat_values[-1]
+        # add a zero to the beginning of the repeat_values and take the diff
+        repeat_values = torch.cat([torch.zeros_like(repeat_values[:1]), repeat_values], dim=-1).diff()
+        a_ = torch.repeat_interleave(a, repeat_values, dim=-1)
+        # add trail zeros to the end of a_
+        a_ = F.pad(a_, pad=(0, trail), mode="constant", value=0.0)
+        a_ = a_.reshape(outputs.shape)
+        b_ = torch.repeat_interleave(b, repeat_values, dim=-1)
+        b_ = F.pad(b_, pad=(0, trail), mode="constant", value=0.0)
+        b_ = b_.reshape(logabsdet.shape)
+        outputs = torch.where(inside_interval_mask, a_, outputs)
+        logabsdet = torch.where(inside_interval_mask, b_, logabsdet)

     return outputs, logabsdet


 def rational_quadratic_spline(
     inputs,
     unnormalized_widths,
```

Review comment on the added block: Could we simplify below by using masked_scatter?
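The added block reconstructs full-size tensors from the spline's inside-interval results: it turns the boolean mask into per-value repeat counts, stretches `a` and `b` with `repeat_interleave` so each value lands at its masked position, pads the trailing gap, and selects with `torch.where`. A sketch of the reviewer's `masked_scatter` suggestion on toy tensors (names and values here are illustrative, not from the PR) shows the two constructions agree:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: `a` holds the spline results for the inside-interval
# entries only, in row-major order.
outputs = torch.zeros(2, 3)
inside_interval_mask = torch.tensor([[True, False, True],
                                     [False, True, False]])
a = torch.tensor([10.0, 20.0, 30.0])  # one value per True entry

# The PR's construction: mask -> repeat counts -> repeat_interleave + pad.
in_mask = inside_interval_mask.reshape(-1).type(torch.int64)
repeat_values = in_mask * (torch.arange(in_mask.shape[-1], dtype=in_mask.dtype) + 1)
repeat_values = repeat_values[repeat_values > 0]
trail = int(len(in_mask) - repeat_values[-1])  # trailing zeros to pad
repeat_values = torch.cat([torch.zeros_like(repeat_values[:1]), repeat_values]).diff()
a_ = torch.repeat_interleave(a, repeat_values, dim=-1)
a_ = F.pad(a_, pad=(0, trail), mode="constant", value=0.0).reshape(outputs.shape)
via_repeat = torch.where(inside_interval_mask, a_, outputs)

# The reviewer's suggestion: out-of-place masked_scatter fills the True
# positions of the mask from `a` in the same row-major order, in one call.
via_scatter = outputs.masked_scatter(inside_interval_mask, a)

assert torch.equal(via_repeat, via_scatter)  # [[10., 0., 20.], [0., 30., 0.]]
```

Whether the out-of-place `masked_scatter` composes with vmap and torch.compile as well as the `repeat_interleave` construction does would still need to be verified against the PR's goal.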
```diff
@@ -76,7 +94,6 @@ def rational_quadratic_spline(
     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
     min_derivative=DEFAULT_MIN_DERIVATIVE,
-    enable_identity_init=False,
 ):
     if torch.min(inputs) < left or torch.max(inputs) > right:
         raise InputOutsideDomain()
@@ -97,11 +114,7 @@ def rational_quadratic_spline(
     cumwidths[..., -1] = right
     widths = cumwidths[..., 1:] - cumwidths[..., :-1]

-    if enable_identity_init:  # flow is the identity if initialized with parameters equal to zero
-        beta = np.log(2) / (1 - min_derivative)
-    else:  # backward compatibility
-        beta = 1
-    derivatives = min_derivative + F.softplus(unnormalized_derivatives, beta=beta)
+    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

     heights = F.softmax(unnormalized_heights, dim=-1)
     heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
```

Review comment on the removed `enable_identity_init` branch: Same as the earlier comment: we want to keep this; not sure if this is on purpose.
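For reference, the purpose of the deleted branch: with `beta = log(2) / (1 - min_derivative)`, we get `softplus(0, beta) = log(2) / beta = 1 - min_derivative`, so zero-initialized `unnormalized_derivatives` yield derivatives of exactly 1 and the zero-initialized flow is the identity; reverting to `beta = 1` loses that property. A quick numeric check (the default value of `min_derivative` is assumed here):

```python
import numpy as np
import torch
import torch.nn.functional as F

min_derivative = 1e-3  # DEFAULT_MIN_DERIVATIVE (value assumed here)

# Deleted branch: beta chosen so that softplus(0) == 1 - min_derivative.
beta = np.log(2) / (1 - min_derivative)
zeros = torch.zeros(5)
print(min_derivative + F.softplus(zeros, beta=beta))  # all 1.0 -> identity spline

# The line this diff reverts to (beta = 1): softplus(0) = log(2) ~ 0.693,
# so zero-initialized parameters no longer produce unit derivatives.
print(min_derivative + F.softplus(zeros))  # ~0.694, not 1.0
```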
Review comment: Is this on purpose, or should we rebase the changes?