Commit a433425

Use eval mode when performing dummy input forward pass
Signed-off-by: Kevin Hsieh <[email protected]>
quic-klhsieh authored May 19, 2022
1 parent 2bea378 commit a433425
Showing 7 changed files with 18 additions and 20 deletions.
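Every hunk in this commit applies the same pattern: a bare model.eval() call around a dummy-input forward pass is replaced by the utils.in_eval_mode(...) context manager, stacked with torch.no_grad(). The implementation of in_eval_mode is not part of this diff; the following is a hypothetical minimal sketch, assuming it saves each submodule's training flag and restores it on exit:

import torch
from contextlib import contextmanager


@contextmanager
def in_eval_mode(model: torch.nn.Module):
    """Temporarily put model in eval mode, restoring each submodule's
    original training/eval flag on exit. A sketch only; the real
    aimet_torch.utils helper may differ.
    """
    # Snapshot every submodule's current training flag
    saved = [(m, m.training) for m in model.modules()]
    try:
        model.eval()
        yield
    finally:
        # Restore the snapshot even if the forward pass raised
        for m, training in saved:
            m.training = training

Unlike a bare model.eval(), leaving the context restores whatever mode the caller had set, which matters when these dummy forward passes happen in the middle of a training loop.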
File: (path not captured)
@@ -76,12 +76,11 @@ def forward_pass(model: torch.nn.Module, batch: torch.Tensor):
     :param batch: batch
     :return: Nothing
     """
-    model.eval()
     # first check if the model is on GPU or not
     if utils.is_model_on_gpu(model):
         batch = batch.cuda()
     try:
-        with torch.no_grad():
+        with utils.in_eval_mode(model), torch.no_grad():
             _ = model(batch)
     except StopForwardException:
         pass
File: (path not captured)
@@ -51,6 +51,7 @@
 from aimet_torch.layer_database import LayerDatabase, Layer
 from aimet_torch.data_subsampler import DataSubSampler
 from aimet_torch.channel_pruning.weight_reconstruction import WeightReconstructor
+from aimet_torch import utils
 from aimet_torch.winnow.winnow import winnow_model


@@ -150,7 +151,7 @@ def sorting_hook(module, _inp, _out):
         handles.append(pair.layer.module.register_forward_hook(sorting_hook))

         # run one forward pass with hooks
-        with torch.no_grad():
+        with utils.in_eval_mode(model), torch.no_grad():
             _ = model(input_data)

         # remove hooks
File: (path not captured)
@@ -167,8 +167,6 @@ def _forward_pass(model: torch.nn.Module, batch: Union[torch.Tensor, List, Tuple
     :param model: model
     :param batch: batch
     """
-    # keep the model in eval mode
-    model.eval()

     # get the model's device placement information
     device = utils.get_device(model)
@@ -179,7 +177,8 @@
         batch = [batch]

     try:
-        with torch.no_grad():
+        # keep the model in eval mode
+        with utils.in_eval_mode(model), torch.no_grad():
             _ = model(*batch)
     except StopForwardException:
         pass
File: (path not captured)
@@ -656,7 +656,8 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
         cls._add_markers(model, module_name_map, module_marker_map, is_conditional)
         temp_file = os.path.join(working_dir, 'temp_onnx_model_with_markers.onnx')
         if is_conditional:
-            dummy_output = model(*dummy_input)
+            with aimet_torch.utils.in_eval_mode(model), torch.no_grad():
+                dummy_output = model(*dummy_input)
             scripted_model = torch.jit.script(model)
             torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output,
                               enable_onnx_checker=False, **onnx_export_args.kwargs)
File: TrainingExtensions/torch/src/python/aimet_torch/quantsim.py (13 changes: 6 additions & 7 deletions)
@@ -258,8 +258,7 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
                 layer.set_mode(QcQuantizeOpMode.ANALYSIS)

         # Run forward iterations so we can collect statistics to compute the appropriate encodings
-        self.model.eval()
-        with torch.no_grad():
+        with utils.in_eval_mode(self.model), torch.no_grad():
             _ = forward_pass_callback(self.model, forward_pass_callback_args)

         # Get the computed per-layer encodings and log them
@@ -371,7 +370,7 @@ def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
         :param dummy_input: Dummy input to the model. Used to parse model graph.
         :return: None
         """
-        with torch.no_grad():
+        with utils.in_eval_mode(original_model), torch.no_grad():
             trace = torch.jit.trace(original_model, dummy_input)
         ts_path = os.path.join(path, filename_prefix + '.torchscript.pth')
         trace.save(ts_path)
@@ -1144,8 +1143,7 @@ def _export_conditional(self, path: str, filename_prefix: str, dummy_input: Unio
         if self._is_conditional:
             self._add_inputs_hook(hooks)

-        self.model.eval()
-        with torch.no_grad():
+        with utils.in_eval_mode(self.model), torch.no_grad():
             _ = forward_pass_callback(self.model, forward_pass_callback_args)

         # Any hooks that were hit during forward pass callback would have removed themselves. Remove the remaining
@@ -1219,8 +1217,9 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
             module = getattr(module, '_module_to_wrap')
             # Only perform init and trace if the given module is a leaf module, and we have not recorded it before
             if module in module_to_name_map and module_to_name_map[module] not in self._module_marker_map:
-                marker_layer = torch.jit.trace(CustomMarker(module, module_to_name_map[module]),
-                                               dummy_input)
+                with utils.in_eval_mode(module), torch.no_grad():
+                    marker_layer = torch.jit.trace(CustomMarker(module, module_to_name_map[module]),
+                                                   dummy_input)
                 self._module_marker_map[module_to_name_map[module]] = marker_layer
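With this change, compute_encodings manages eval mode and gradient suppression itself, so a user-supplied forward-pass callback only needs to push representative data through the model. A usage sketch; the SmallNet model, the QuantizationSimModel constructor arguments, and the calibration data here are illustrative assumptions, not taken from this diff:

import torch
from aimet_torch.quantsim import QuantizationSimModel


class SmallNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


def calibration_callback(model, batches):
    # As of this commit, compute_encodings wraps this callback in
    # utils.in_eval_mode(...) and torch.no_grad(), so no explicit
    # model.eval() or gradient handling is needed here.
    for batch in batches:
        model(batch)


model = SmallNet()
sim = QuantizationSimModel(model, dummy_input=torch.randn(1, 3, 32, 32))  # illustrative args
calibration_data = [torch.randn(4, 3, 32, 32) for _ in range(8)]
sim.compute_encodings(calibration_callback, calibration_data)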
File: TrainingExtensions/torch/src/python/aimet_torch/utils.py (9 changes: 3 additions & 6 deletions)
@@ -114,9 +114,6 @@ def _hook_to_collect_inp_out_data(_, inp, out):
         handle = self._module.register_forward_hook(_hook_to_collect_inp_out_data)

-        # keep the model in eval mode
-        self._model.eval()
-
         # get the model's device placement information
         device = get_device(self._model)

@@ -127,7 +124,7 @@
             model_input = [model_input]

         try:
-            with torch.no_grad():
+            with in_eval_mode(self._model), torch.no_grad():
                 _ = self._model(*model_input)

         except StopForwardException:
@@ -234,7 +231,7 @@ def run_hook_for_layers(model: torch.nn.Module, input_shapes: Union[Tuple, List[
     device = get_device(model)
     dummy_tensors = create_rand_tensors_given_shapes(input_shapes)
     dummy_tensors = [tensor.to(device) for tensor in dummy_tensors]
-    with torch.no_grad():
+    with in_eval_mode(model), torch.no_grad():
         _ = model(*dummy_tensors)

 # --------------------------
@@ -271,7 +268,7 @@ def run_hook_for_layers_with_given_input(model: torch.nn.Module, input_tensor: U
     # ------------------------------------------------
     # Run forward pass to execute the hook functions
    # ------------------------------------------------
-    with torch.no_grad():
+    with in_eval_mode(model), torch.no_grad():
         if isinstance(input_tensor, (list, tuple)):
             _ = model(*input_tensor)
         else:
File: TrainingExtensions/torch/test/python/test_utils.py (2 changes: 2 additions & 0 deletions)
@@ -244,6 +244,7 @@ def test_change_tensor_device(self):

     def _collect_inp_out_data(self, device):
         model = TinyModel().to(device=device)
+        model.eval()
         model_input = torch.randn(1, 3, 32, 32).to(device=device)

         module_data = utils.ModuleData(model, model.conv1)
@@ -287,6 +288,7 @@ def test_collect_inp_out_data_gpu(self):

     def _collect_inp_out_data_multi_input(self, device):
         model = MultiInput().to(device=device)
+        model.eval()
         inp_shape_1 = (1, 3, 32, 32)
         inp_shape_2 = (1, 3, 20, 20)
         model_input = utils.create_rand_tensors_given_shapes([inp_shape_1, inp_shape_2])
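The tests now call model.eval() up front, since the utilities under test no longer leave the model in eval mode as a side effect. That scoping is the behavioral point of the whole commit: a model that enters a dummy forward pass in training mode comes back out in training mode. A small illustration, assuming in_eval_mode behaves like the sketch at the top of this page:

import torch
from aimet_torch.utils import in_eval_mode  # helper referenced throughout this diff

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
model.train()

# A bare model.eval() here would leave the model in eval mode afterwards,
# silently freezing BatchNorm statistic updates in later training steps.
with in_eval_mode(model), torch.no_grad():
    _ = model(torch.randn(1, 3, 32, 32))

assert model.training  # the training flag survives the dummy pass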