Skip to content

Commit

Permalink
Temporarily prevent CG from traversing custom functional only modules
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <[email protected]>
  • Loading branch information
quic-klhsieh authored Jul 30, 2022
1 parent 3d503f3 commit e09d587
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ def _parse_trace_graph(self, # pylint: disable=too-many-locals
# Recursive parsing is needed 1) if the module is not leaf module.
# 2) If the module is leaf module but has multiple functional operations in
# forward method.
if self._is_recursive_parsing_needed(subgraph_model, module_to_jit_trace):
# TODO: Temporarily reverting this change while we rework CG to be able to handle reused custom modules
# with only functionals inside.
# if self._is_recursive_parsing_needed(subgraph_model, module_to_jit_trace):
if not is_leaf_module(subgraph_model):
node_name_to_subgraph_model[getattr_node_info.node_alias] = (subgraph_model, getattr_node_info)

# invoking forward method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_get_blacklisted_modules(self):
class TestModelValidatorPreparer:

@pytest.mark.cuda
@pytest.mark.skip('Disable temporarily while CG does not dive into custom modules with only functionals.')
def test_model_validator_preparer(self):
""" Validate model validator and preparer workflow """
model = Model().eval().cuda()
Expand Down

0 comments on commit e09d587

Please sign in to comment.