Skip to content

Commit

Permalink
[Bugfix] recompute dep filter param (#49010)
Browse files Browse the repository at this point in the history
* recompute dep filter param

* recompute dep for reshard
  • Loading branch information
JZ-LIANG authored Dec 14, 2022
1 parent a8d139a commit b9fad5d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2497,6 +2497,8 @@ def _reshard_output(self, block):
"read",
"write_to_array",
"read_from_array",
"nop",
"depend",
]
global _g_special_ops
skip_ops += _g_special_ops
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,9 @@ def insert_dependencies_for_two_ops(

def _select_best_depend_var(vars):

# parameter should not be dep var since it maybe partition in sharding pass
vars = [var for var in vars if not var.is_parameter]
assert len(vars) > 0
vars_with_numels = [(var, get_var_numel(var)) for var in vars]
vars_with_numels.sort(key=lambda x: x[1])

Expand Down

0 comments on commit b9fad5d

Please sign in to comment.