diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py
index 42268ea981..5c58d595d8 100644
--- a/functorch/_src/make_functional.py
+++ b/functorch/_src/make_functional.py
@@ -257,12 +257,12 @@ def __init__(self, stateless_model, param_names, buffer_names,
         self.all_names_map.update(buffer_names_map)
 
     @staticmethod
-    def _create_from(model, disable_params_grad=False):
+    def _create_from(model, disable_autograd_tracking=False):
         # TODO: We don't need to copy the model to create a stateless copy
         model_copy = copy.deepcopy(model)
         params, param_names, param_names_map = extract_weights(model_copy)
         buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
-        if disable_params_grad:
+        if disable_autograd_tracking:
             for param in params:
                 param.requires_grad_(False)
         return (
@@ -297,11 +297,11 @@ def __init__(self, stateless_model, param_names, names_map):
         self.names_map = names_map
 
     @staticmethod
-    def _create_from(model, disable_params_grad=False):
+    def _create_from(model, disable_autograd_tracking=False):
         # TODO: We don't need to copy the model to create a stateless copy
         model_copy = copy.deepcopy(model)
         params, param_names, names_map = extract_weights(model_copy)
-        if disable_params_grad:
+        if disable_autograd_tracking:
             for param in params:
                 param.requires_grad_(False)
         return FunctionalModule(model_copy, param_names, names_map), params
@@ -316,7 +316,7 @@ def forward(self, params, *args, **kwargs):
             _swap_state(self.stateless_model, self.names_map, old_state)
 
 
-def make_functional(model: nn.Module, disable_params_grad: bool = False):
+def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
     """make_functional(model) -> func, params
 
     Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@@ -361,20 +361,26 @@ def compute_loss(params, x, t):
 
     Args:
         model (torch.nn.Module): Input model.
-        disable_params_grad (bool): Flag to disable gradients for output parameters.
-            It could be helpful to set this flag to True if gradient computations only use the functorch API.
-            Disabling gradients also helps to avoid accumulating them in the backprop graph and to prevent
-            keeping them in memory. Default: False.
+        disable_autograd_tracking (bool): Flag to disable gradient tracking for output parameters.
+            The returned params are unrelated to the set of params from the original model. If False (default),
+            the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
+            PyTorch autograd), matching the requires_grad-ness of the params from the original model.
+            Otherwise, the returned params will have ``requires_grad=False``. Default: False.
+            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
+            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
+            Otherwise, if you're only planning on using functorch's gradient transforms,
+            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
+            history with PyTorch autograd.
 
     """
     buffers = list(model.buffers())
     if len(buffers) > 0:
         raise RuntimeError('make_functional(model): `model` has buffers. Please use '
                            'make_functional_with_buffers(model) instead.')
-    return FunctionalModule._create_from(model, disable_params_grad=disable_params_grad)
+    return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
 
 
-def make_functional_with_buffers(model: nn.Module, disable_params_grad: bool = False):
+def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
     """make_functional_with_buffers(model) -> func, params, buffers
 
     Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@@ -416,13 +422,19 @@ def compute_loss(params, buffers, x, t):
 
     Args:
         model (torch.nn.Module): Input model.
-        disable_params_grad (bool): Flag to disable gradients for output parameters.
-            It could be helpful to set this flag to True if gradient computations only use the functorch API.
-            Disabling gradients also helps to avoid accumulating them in the backprop graph and to prevent
-            keeping them in memory. Default: False.
+        disable_autograd_tracking (bool): Flag to disable gradient tracking for output parameters.
+            The returned params are unrelated to the set of params from the original model. If False (default),
+            the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
+            PyTorch autograd), matching the requires_grad-ness of the params from the original model.
+            Otherwise, the returned params will have ``requires_grad=False``. Default: False.
+            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
+            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
+            Otherwise, if you're only planning on using functorch's gradient transforms,
+            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
+            history with PyTorch autograd.
     """
-    return FunctionalModuleWithBuffers._create_from(model, disable_params_grad=disable_params_grad)
+    return FunctionalModuleWithBuffers._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
 
 
 def transpose_stack(tuple_of_tuple_of_tensors):
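
For reference, a minimal sketch of how the renamed flag is meant to be used with functorch's gradient transforms; the toy ``nn.Linear`` model and MSE loss below are illustrative assumptions, not part of this patch:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad

# Toy model with no buffers, so make_functional (rather than the
# *_with_buffers variant) applies.
model = nn.Linear(3, 3)
x = torch.randn(4, 3)
t = torch.randn(4, 3)

# Only functorch's gradient transforms are used below, so PyTorch autograd
# does not need to track history on the extracted params.
func_model, params = make_functional(model, disable_autograd_tracking=True)

def compute_loss(params, x, t):
    y = func_model(params, x)
    return F.mse_loss(y, t)

# grad(f) differentiates f with respect to its first argument (the params).
grads = grad(compute_loss)(params, x, t)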