Bug fix for an illegal memory access error caused when running a Medusa LoRA and plain LoRAs in parallel. #525

Merged 2 commits on Jun 26, 2024
server/lorax_server/adapters/lora.py: 16 additions & 6 deletions
@@ -301,8 +301,6 @@ def load(
             idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
         }

-        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
-
         rank_indices = defaultdict(list)
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in adapter_weights:
@@ -338,10 +336,22 @@ def load(
                         segment_starts[i] = prefill_head_segment_starts[segment_index]
                         segment_ends[i] = prefill_head_segment_ends[segment_index]
             else:
-                rank_indices = set(indices)
-                batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
-                batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
-                batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)
+                # `indices` indexes into `segment_indices`, which holds the segment-wise adapter index.
+                # `lora_a_ptr` contains segment-wise pointers to the lora weights,
+                # so the lengths of `lora_a_ptr` and `segment_indices` must be the same.
+                # `indices` will be used to slice the `lora_a_ptr` tensor.
+                # First, find the mapping between each adapter index and its location in the `indices` array.
+                idx_locs = {}
+                for loc, idx in enumerate(indices):
+                    # use the idx to find the adapter index
+                    if segment_indices[idx] not in idx_locs:
+                        # save the first location at which a particular adapter index is encountered
+                        idx_locs[segment_indices[idx]] = loc
+                # Second, iterate over the adapter index of each token and find its location in the `indices` array.
+                batch_indices = torch.tensor([
+                    idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
+                    for idx in meta.adapter_indices.tolist()
+                ], dtype=torch.int64, device=device)

             rank_data[rank] = RankSegments(
                 rank=rank,
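
For intuition, here is a minimal standalone sketch of the new decode-path indexing (toy values only: `adapter_ranks` stands in for `adapter_weights[idx].lora_a_r`, and the example `segment_indices`/`adapter_indices` are made up). Each token is mapped to the position of its adapter within this rank's slice of the pointer arrays, and tokens whose adapter uses a different rank get -1; the previous code used global segment positions, which could overrun the rank-sliced `lora_a_ptr` once other adapters (such as a Medusa speculator) occupied segments of their own, consistent with the illegal memory access described in the title.

# Minimal sketch of the per-rank decode indexing introduced above (toy data, not the lorax_server API).
segment_indices = [0, 1, 0, 1, 2]      # adapter id of each contiguous segment in the batch
adapter_ranks = {0: 8, 1: 8, 2: 16}    # hypothetical: adapter id -> LoRA rank
adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]  # adapter id of each token (meta.adapter_indices)

rank = 8
# Segments whose adapter has this rank; the same positions are used to slice lora_a_ptr / lora_b_ptr.
indices = [i for i, adapter_idx in enumerate(segment_indices) if adapter_ranks[adapter_idx] == rank]

# First, map each adapter id to the first location where it appears in `indices`.
idx_locs = {}
for loc, idx in enumerate(indices):
    if segment_indices[idx] not in idx_locs:
        idx_locs[segment_indices[idx]] = loc

# Second, for every token look up that location; tokens whose adapter uses another rank get -1,
# so they never index past the end of the sliced pointer array.
batch_indices = [
    idx_locs[idx] if adapter_ranks.get(idx) == rank else -1
    for idx in adapter_indices
]
print(batch_indices)  # [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]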
server/tests/utils/test_lora.py: 64 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type
 from unittest import mock

 import pytest
@@ -102,6 +102,69 @@ def test_batched_lora_weights(lora_ranks: List[int]):
         assert rd.segment_ends.shape == (2,)


+@pytest.mark.parametrize(
+    "lora_ranks,adapter_indices,expected",
+    [
+        (
+            [8, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]),
+                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
+            },
+        ),
+        (
+            [4, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
+                8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
+                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
+            },
+        ),
+    ],
+)
+def test_batched_lora_weights_decode(
+    lora_ranks: List[int],
+    adapter_indices: List[int],
+    expected: Dict[int, Tuple[int, List[int]]],
+):
+    from lorax_server.utils.segments import find_segments
+    batched_weights = LayerAdapterWeights()
+    assert batched_weights.is_empty()
+
+    h = 1024
+    for idx, lora_rank in enumerate(lora_ranks):
+        weights = LoraWeights(
+            weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
+            weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
+            adapter_config=LoraConfig(r=lora_rank),
+        )
+        batched_weights.add_adapter(idx, weights)
+
+    segments, segment_indices = find_segments(adapter_indices)
+
+    meta = AdapterBatchMetadata(
+        adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
+        adapter_set=set(adapter_indices),
+        adapter_segments=torch.tensor(segments, dtype=torch.int64),
+        segment_indices=segment_indices,
+    )
+
+    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
+        data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)
+
+    for lora_rank, rd in data.rank_data.items():
+        expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
+        assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
+        assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
+        assert torch.all(rd.indices == expected_indices)
+        assert rd.segment_starts is None
+        assert rd.segment_ends is None
+        assert rd.tmp_shrink is None
+        assert rd.tmp_expand is None
+
+
 def test_batched_lora_weights_no_segments():
     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
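
To sanity-check the expected values by hand, the sketch below re-derives the per-rank pointer-array lengths for the second case (`lora_ranks=[4, 8, 16]`). `group_segments` is a hypothetical stand-in that only mimics what the test assumes `find_segments` returns (segment boundaries plus the adapter id of each segment); it is not the lorax_server implementation.

# Hypothetical re-derivation of the expected shapes in the second parametrized case.
from typing import List, Tuple


def group_segments(adapter_indices: List[int]) -> Tuple[List[int], List[int]]:
    """Group consecutive equal adapter ids: returns (segment boundaries, adapter id per segment)."""
    boundaries, per_segment = [0], []
    for i, idx in enumerate(adapter_indices):
        if per_segment and idx == per_segment[-1]:
            continue
        if i > 0:
            boundaries.append(i)
        per_segment.append(idx)
    boundaries.append(len(adapter_indices))
    return boundaries, per_segment


adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]
ranks = {0: 4, 1: 8, 2: 16}  # adapter id -> rank, mirroring lora_ranks=[4, 8, 16]

boundaries, segment_indices = group_segments(adapter_indices)
# boundaries == [0, 2, 4, 6, 8, 10], segment_indices == [0, 1, 0, 1, 2]

for rank in (4, 8, 16):
    n_segments = sum(1 for a in segment_indices if ranks[a] == rank)
    print(rank, n_segments)  # 4 -> 2, 8 -> 2, 16 -> 1: the first element of each expected tuple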