Add support for splitting in_features in linear layers #8715

Open · wants to merge 1 commit into main
67 changes: 14 additions & 53 deletions examples/apple/coreml/llama/export.py
@@ -1,6 +1,8 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

@@ -24,55 +26,7 @@

sys.path.insert(0, ".")
from llama_transformer import InputManager, load_model


class SplitLinearModule(torch.nn.Module):
def __init__(self, in_features, out_features, target_split_size, max_splits):
super(SplitLinearModule, self).__init__()
num_splits = max(out_features // target_split_size, 1)
if num_splits > max_splits:
num_splits = max_splits

self.split_size = out_features // num_splits
self.split_remainder = out_features % num_splits
self.splits = torch.nn.ModuleList(
[torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
)
print(
f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}"
)
if self.split_remainder > 0:
print(
f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}"
)
self.splits.append(torch.nn.Linear(in_features, self.split_remainder))

def split_sizes(self):
return [split.out_features for split in self.splits]

def forward(self, x):
return torch.cat([split(x) for split in self.splits], dim=-1)


def replace_linear_with_split_linear(model, target_split_size, max_splits):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
new_module = SplitLinearModule(
module.in_features, module.out_features, target_split_size, max_splits
)
split_sizes = new_module.split_sizes()
if module.bias is not None:
split_bias = module.bias.split(split_sizes)
split_weights = module.weight.split(split_sizes, dim=0)
for i, split in enumerate(new_module.splits):
split.weight = torch.nn.Parameter(split_weights[i])
if module.bias is not None:
split.bias = torch.nn.Parameter(split_bias[i])
else:
split.bias = None
setattr(model, name, new_module)
else:
replace_linear_with_split_linear(module, target_split_size, max_splits)
from utils import replace_linear_with_split_linear


def main() -> None:
@@ -175,7 +129,13 @@ def main() -> None:

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
model, export_args.target_split_size, export_args.max_splits
model,
out_target_split_size=export_args.target_split_size,
out_max_splits=export_args.max_splits,
# I have not found splitting on in_features to be beneficial,
# and it often leads to OOM so I set in_max_splits to 1
in_target_split_size=1,
in_max_splits=1,
)

model.eval()
@@ -241,6 +201,7 @@ def main() -> None:
ep,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
],
compile_config=EdgeCompileConfig(
11 changes: 6 additions & 5 deletions examples/apple/coreml/llama/readme.md
@@ -38,8 +38,9 @@ The runner can also be used to run an eager model to compare with CoreML n

We are actively experimenting with different settings. But here are ones that we've found work well for Llama1B on iPhone 15 Pro:

* Set use_cache_list
* Split linear layers with target_split_size=1024, max_splits=8
* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance. seq_length=32 is better at decode and seq_length=64 is better at prefill.

In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length.
* Set use_cache_list.
* Use seq_length = 32, which offers a good balance between prefill and decode performance.
* Split out_features in linear layers with target_split_size=1024, max_splits=8; see the sketch after this list.
* For ANE, set dtype = fp16 and coreml-quantize = c4w. This requires doing QAT on Llama1B for good accuracy.
* Set embedding-quantize to "4,32".
* Set max_seq_length to 128, 256, 512, 1024, or 2048, depending on the needed context. Note that performance drops as max_seq_length grows; more precisely, it drops as cache_size grows, and the best experience may require a good cache eviction policy. The python runner in run.py uses a last-in-last-out policy when cache_size is specified.
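
A minimal sketch of those split settings, using the replace_linear_with_split_linear helper this PR adds in utils.py (here, model stands for whatever load_model returns):

```python
from utils import replace_linear_with_split_linear

# Split out_features into chunks of 1024 (at most 8 splits) and leave
# in_features unsplit; splitting in_features was not found beneficial.
replace_linear_with_split_linear(
    model,
    out_target_split_size=1024,
    out_max_splits=8,
    in_target_split_size=1,
    in_max_splits=1,
)
```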
6 changes: 6 additions & 0 deletions examples/apple/coreml/llama/run.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys

48 changes: 48 additions & 0 deletions examples/apple/coreml/llama/test.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

sys.path.insert(0, ".")
import copy

import torch
from utils import replace_linear_with_split_linear


def get_split_model(
model,
out_target_split_size=1,
out_max_splits=1,
in_target_split_size=1,
in_max_splits=1,
):
model_copy = copy.deepcopy(model)
replace_linear_with_split_linear(
model_copy,
out_target_split_size,
out_max_splits,
in_target_split_size,
in_max_splits,
)
return model_copy


def test_split_model():
inputs = torch.randn(10, 5, 1, 512)

model = torch.nn.Sequential(*[torch.nn.Linear(512, 1024, bias=False)])
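# model1 splits out_features into 2 chunks of 512 and in_features into
# 8 chunks of 64; model2 splits only out_features; model3 splits only
# in_features. All three should match the unsplit model numerically.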
model1 = get_split_model(model, 64, 2, 64, 1000)
model2 = get_split_model(model, 64, 2, 64, 1)
model3 = get_split_model(model, 64, 1, 64, 1000)

assert torch.allclose(model(inputs), model1(inputs), atol=1e-5)
assert torch.allclose(model(inputs), model2(inputs), atol=1e-5)
assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)


if __name__ == "__main__":
test_split_model()
116 changes: 116 additions & 0 deletions examples/apple/coreml/llama/utils.py
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class SplitLinearModule(torch.nn.Module):
Contributor commented:
I might miss some context - I thought you said splitting linear speeds up the perf in ANE?

Contributor commented:

Oh, you said in_features in the description. What did you split to get better perf?

def __init__(
self,
in_features,
out_features,
out_target_split_size=1,
out_max_splits=1,
in_target_split_size=1,
in_max_splits=1,
):
super(SplitLinearModule, self).__init__()
self.out_split_sizes = self._get_split_sizes(
out_features, out_target_split_size, out_max_splits
)
self.in_split_sizes = self._get_split_sizes(
in_features, in_target_split_size, in_max_splits
)
print(
f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}."
)
print(
f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}."
)

# self.ops contains a list of linear ops for different pieces of the output matrix
# The index of an op at (in_idx, out_idx) is given by self.op_index(in_idx, out_idx)
self.ops = torch.nn.ModuleList()
for idx_out, s_out in enumerate(self.out_split_sizes):
for idx_in, s_in in enumerate(self.in_split_sizes):
assert len(self.ops) == self.op_index(idx_in, idx_out)
self.ops.append(torch.nn.Linear(s_in, s_out, bias=False))

def op_index(self, in_index, out_index):
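# Example layout for 2 in-splits and 3 out-splits: (in=0,out=0)->0,
# (in=1,out=0)->1, (in=0,out=1)->2, (in=1,out=1)->3, ..., (in=1,out=2)->5.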
idx = out_index * len(self.in_split_sizes) + in_index
return idx

def _get_split_sizes(self, n_features, target_split_size, max_splits):
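# Worked example: n_features=1024, target_split_size=256, max_splits=8
# gives num_splits = min(1024 // 256, 8) = 4 and sizes [256, 256, 256, 256];
# a nonzero remainder raises below.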
num_splits = max(n_features // target_split_size, 1)
if num_splits > max_splits:
num_splits = max_splits

split_size = n_features // num_splits
split_remainder = n_features % num_splits
if split_remainder > 0:
raise ValueError(
f"Cannot split {n_features} with target_split_size={target_split_size} and max_splits={max_splits} because it leaves a remainder of {split_remainder}."
)

ret = [split_size for _ in range(num_splits)]
return ret

def set_params(self, weight):
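# Tile the original (out_features, in_features) weight to match the op
# layout: outer loop over out splits (dim 0), inner over in splits (dim 1).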
split_weights = []
for w_out in weight.split(self.out_split_sizes, dim=0):
for w in w_out.split(self.in_split_sizes, dim=1):
split_weights.append(w)

for i, split in enumerate(self.ops):
split.weight = torch.nn.Parameter(split_weights[i])

def forward(self, x):
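# Block matrix multiply: each output chunk is the sum over in-chunks of
# ops[op_index(in_idx, out_idx)](x_chunk), and chunks are concatenated.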
if len(self.in_split_sizes) == 1:
out_chunks = [op(x) for op in self.ops]
else:
x_splits = x.split(self.in_split_sizes, dim=-1)
out_chunks = [
torch.sum(
torch.stack(
[
self.ops[self.op_index(in_idx, out_idx)].forward(
x_splits[in_idx]
)
for in_idx in range(len(self.in_split_sizes))
],
),
dim=0,
)
for out_idx in range(len(self.out_split_sizes))
]

return torch.concat(out_chunks, dim=-1)


def replace_linear_with_split_linear(
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
assert module.bias is None, "SplitLinearModule does not support bias"
new_module = SplitLinearModule(
module.in_features,
module.out_features,
out_target_split_size,
out_max_splits,
in_target_split_size,
in_max_splits,
)
new_module.set_params(module.weight)
setattr(model, name, new_module)
else:
replace_linear_with_split_linear(
module,
out_target_split_size,
out_max_splits,
in_target_split_size,
in_max_splits,
)