-
Notifications
You must be signed in to change notification settings - Fork 458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for splitting in_features in linear layers #8715
Open
metascroy
wants to merge
1
commit into
main
Choose a base branch
from
split-in-feat
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+190
−58
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
|
||
sys.path.insert(0, ".") | ||
import copy | ||
|
||
import torch | ||
from utils import replace_linear_with_split_linear | ||
|
||
|
||
def get_split_model( | ||
model, | ||
out_target_split_size=1, | ||
out_max_splits=1, | ||
in_target_split_size=1, | ||
in_max_splits=1, | ||
): | ||
model_copy = copy.deepcopy(model) | ||
replace_linear_with_split_linear( | ||
model_copy, | ||
out_target_split_size, | ||
out_max_splits, | ||
in_target_split_size, | ||
in_max_splits, | ||
) | ||
return model_copy | ||
|
||
|
||
def test_split_model(): | ||
inputs = torch.randn(10, 5, 1, 512) | ||
|
||
model = torch.nn.Sequential(*[torch.nn.Linear(512, 1024, bias=False)]) | ||
model1 = get_split_model(model, 64, 2, 64, 1000) | ||
model2 = get_split_model(model, 64, 2, 64, 1) | ||
model3 = get_split_model(model, 64, 1, 64, 1000) | ||
|
||
assert torch.allclose(model(inputs), model1(inputs), atol=1e-5) | ||
assert torch.allclose(model(inputs), model2(inputs), atol=1e-5) | ||
assert torch.allclose(model(inputs), model3(inputs), atol=1e-5) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_split_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
|
||
|
||
class SplitLinearModule(torch.nn.Module): | ||
def __init__( | ||
self, | ||
in_features, | ||
out_features, | ||
out_target_split_size=1, | ||
out_max_splits=1, | ||
in_target_split_size=1, | ||
in_max_splits=1, | ||
): | ||
super(SplitLinearModule, self).__init__() | ||
self.out_split_sizes = self._get_split_sizes( | ||
out_features, out_target_split_size, out_max_splits | ||
) | ||
self.in_split_sizes = self._get_split_sizes( | ||
in_features, in_target_split_size, in_max_splits | ||
) | ||
print( | ||
f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}." | ||
) | ||
print( | ||
f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}." | ||
) | ||
|
||
# self.ops contains a list of linear ops for different pieces of the output matrix | ||
# The index of an op at (in_idx, out_idx) is given by self.op_index(in_idx, out_idx) | ||
self.ops = torch.nn.ModuleList() | ||
for idx_out, s_out in enumerate(self.out_split_sizes): | ||
for idx_in, s_in in enumerate(self.in_split_sizes): | ||
assert len(self.ops) == self.op_index(idx_in, idx_out) | ||
self.ops.append(torch.nn.Linear(s_in, s_out, bias=False)) | ||
|
||
def op_index(self, in_index, out_index): | ||
idx = out_index * len(self.in_split_sizes) + in_index | ||
return idx | ||
|
||
def _get_split_sizes(self, n_features, target_split_size, max_splits): | ||
num_splits = max(n_features // target_split_size, 1) | ||
if num_splits > max_splits: | ||
num_splits = max_splits | ||
|
||
split_size = n_features // num_splits | ||
split_remainder = n_features % num_splits | ||
if split_remainder > 0: | ||
raise ValueError( | ||
f"Cannot split {n_features} with target_split_size={target_split_size} and max_splits={max_splits} because it leaves a remainder of {split_remainder}." | ||
) | ||
|
||
ret = [split_size for _ in range(num_splits)] | ||
return ret | ||
|
||
def set_params(self, weight): | ||
split_weights = [] | ||
for w_out in weight.split(self.out_split_sizes, dim=0): | ||
for w in w_out.split(self.in_split_sizes, dim=1): | ||
split_weights.append(w) | ||
|
||
for i, split in enumerate(self.ops): | ||
split.weight = torch.nn.Parameter(split_weights[i]) | ||
|
||
def forward(self, x): | ||
if len(self.in_split_sizes) == 1: | ||
out_chunks = [op(x) for op in self.ops] | ||
else: | ||
x_splits = x.split(self.in_split_sizes, dim=-1) | ||
out_chunks = [ | ||
torch.sum( | ||
torch.stack( | ||
[ | ||
self.ops[self.op_index(in_idx, out_idx)].forward( | ||
x_splits[in_idx] | ||
) | ||
for in_idx in range(len(self.in_split_sizes)) | ||
], | ||
), | ||
dim=0, | ||
) | ||
for out_idx in range(len(self.out_split_sizes)) | ||
] | ||
|
||
return torch.concat(out_chunks, dim=-1) | ||
|
||
|
||
def replace_linear_with_split_linear( | ||
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1 | ||
): | ||
for name, module in model.named_children(): | ||
if isinstance(module, torch.nn.Linear): | ||
assert module.bias is None, "SplitLinearModule does not support bias" | ||
new_module = SplitLinearModule( | ||
module.in_features, | ||
module.out_features, | ||
out_target_split_size, | ||
out_max_splits, | ||
in_target_split_size, | ||
in_max_splits, | ||
) | ||
new_module.set_params(module.weight) | ||
setattr(model, name, new_module) | ||
else: | ||
replace_linear_with_split_linear( | ||
module, | ||
out_target_split_size, | ||
out_max_splits, | ||
in_target_split_size, | ||
in_max_splits, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might miss some context - I thought you said splitting linear speeds up the perf in ANE?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, you said in features, in the description. What did you split to get better perf?