Commit

[skip ci] Updated docstring according to the review
vfdev-5 committed Apr 22, 2022
1 parent cb1273a commit db662b1
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions functorch/_src/make_functional.py
@@ -257,12 +257,12 @@ def __init__(self, stateless_model, param_names, buffer_names,
        self.all_names_map.update(buffer_names_map)

    @staticmethod
-    def _create_from(model, disable_params_grad=False):
+    def _create_from(model, disable_autograd_tracking=False):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(model_copy)
        buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
-        if disable_params_grad:
+        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return (
@@ -297,11 +297,11 @@ def __init__(self, stateless_model, param_names, names_map):
        self.names_map = names_map

    @staticmethod
-    def _create_from(model, disable_params_grad=False):
+    def _create_from(model, disable_autograd_tracking=False):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(model_copy)
-        if disable_params_grad:
+        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return FunctionalModule(model_copy, param_names, names_map), params
@@ -316,7 +316,7 @@ def forward(self, params, *args, **kwargs):
        _swap_state(self.stateless_model, self.names_map, old_state)


-def make_functional(model: nn.Module, disable_params_grad: bool = False):
+def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    """make_functional(model) -> func, params
    Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@@ -361,20 +361,26 @@ def compute_loss(params, x, t):
    Args:
        model (torch.nn.Module): Input model.
-        disable_params_grad (bool): Flag to disable gradients for output parameters.
-            It could be helpful to set this flag to True if gradient computations only use the functorch API.
-            Disabling gradients also helps to avoid accumulating them in the backprop graph and to prevent
-            keeping them in memory. Default: False.
+        disable_autograd_tracking (bool): Flag to disable gradient tracking for output parameters.
+            The returned params are unrelated to the set of params from the original model. If False (default),
+            the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
+            PyTorch autograd), matching the requires_grad-ness of the params from the original model.
+            Otherwise, the returned params will have ``requires_grad=False``. Default: False.
+            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
+            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
+            Otherwise, if you're only planning on using functorch's gradient transforms,
+            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
+            history with PyTorch autograd.
    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
                           'make_functional_with_buffers(model) instead.')
-    return FunctionalModule._create_from(model, disable_params_grad=disable_params_grad)
+    return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)


-def make_functional_with_buffers(model: nn.Module, disable_params_grad: bool = False):
+def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
    """make_functional_with_buffers(model) -> func, params, buffers
    Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@@ -416,13 +422,19 @@ def compute_loss(params, buffers, x, t):
    Args:
        model (torch.nn.Module): Input model.
-        disable_params_grad (bool): Flag to disable gradients for output parameters.
-            It could be helpful to set this flag to True if gradient computations only use the functorch API.
-            Disabling gradients also helps to avoid accumulating them in the backprop graph and to prevent
-            keeping them in memory. Default: False.
+        disable_autograd_tracking (bool): Flag to disable gradient tracking for output parameters.
+            The returned params are unrelated to the set of params from the original model. If False (default),
+            the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
+            PyTorch autograd), matching the requires_grad-ness of the params from the original model.
+            Otherwise, the returned params will have ``requires_grad=False``. Default: False.
+            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
+            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
+            Otherwise, if you're only planning on using functorch's gradient transforms,
+            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
+            history with PyTorch autograd.
    """
-    return FunctionalModuleWithBuffers._create_from(model, disable_params_grad=disable_params_grad)
+    return FunctionalModuleWithBuffers._create_from(model, disable_autograd_tracking=disable_autograd_tracking)


def transpose_stack(tuple_of_tuple_of_tensors):
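
For reference, a minimal sketch of the usage the updated docstring describes; the toy model, data, and loss below are chosen only for illustration and are not part of this commit.

import torch
import torch.nn as nn
from functorch import grad, make_functional

# Illustrative model and data; any buffer-free nn.Module works with make_functional.
model = nn.Linear(3, 1)
x = torch.randn(4, 3)
t = torch.randn(4, 1)

# Only functorch's gradient transforms are used below, so, per the docstring,
# autograd tracking on the returned params can be disabled.
func, params = make_functional(model, disable_autograd_tracking=True)

def compute_loss(params, x, t):
    y = func(params, x)
    return nn.functional.mse_loss(y, t)

# grad() returns gradients w.r.t. params without ever calling .backward().
grads = grad(compute_loss)(params, x, t)

# A module with buffers (e.g., one containing BatchNorm) would instead go through
# make_functional_with_buffers(model, disable_autograd_tracking=True), and its
# functional version is called as func(params, buffers, *args).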
