Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring: early return 'if else' syntax -> 'if' syntax #167

Open
wants to merge 1 commit into
base: v2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sonnet/src/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ def __call__(self, inputs: tf.Tensor, multiplier: types.FloatLike = None):
self._initialize(inputs)
if multiplier is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems to me we have a pattern of :

if A:
  return B + (C * D)
else:
  return B + D

I think perhaps in this case, if we want to avoid the extra else perhaps we should extract more of the common part of the two expressions (e.g. the if A block should just scale the thing we add to B.

if A:
  D = C * D
return B + D

Concretely in this case it would look like the following (I think we cannot use *= because this is not supported on tf.Variable):

b = self.b
if multiplier is not None:
  b = b * multiplier
return inputs + b

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Thanks You!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, please feel free to push a new commit with these changes so we can review and get them merged 😄

return inputs + (self.b * multiplier)
else:
return inputs + self.b
return inputs + self.b


def calculate_bias_shape(input_shape: types.ShapeLike,
Expand Down