Remove async eval and add sequential load
angeloskath committed Nov 5, 2024
1 parent 043fc2a commit 1c52719
Showing 2 changed files with 15 additions and 8 deletions.
4 changes: 4 additions & 0 deletions llms/mlx_lm/generate.py
@@ -200,6 +200,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )

     if args.use_default_chat_template:
@@ -238,6 +239,9 @@ def main():
         raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None

+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    mx.distributed.init().barrier()
     response = generate(
         model,
         tokenizer,
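Note (not part of the commit): a minimal standalone sketch of the distributed pattern the new generate.py lines rely on. It uses only the calls that appear in this diff, namely mx.distributed.init(), rank(), size(), and barrier(), and assumes an MLX build with distributed support.

import mlx.core as mx

# init() returns the global process group; repeated calls return the same
# group, which is why the diff can call it more than once interchangeably.
world = mx.distributed.init()
print(f"Node {world.rank()} of {world.size()}", flush=True)

# barrier() blocks until every node reaches this point, so later work such
# as generation only starts once all nodes are ready.
world.barrier()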
19 changes: 11 additions & 8 deletions llms/mlx_lm/utils.py
@@ -308,18 +308,13 @@ def _step(y):
         y = y[prefill_step_size:]
         mx.metal.clear_cache()

-    y, logprobs = _step(y)
-
-    mx.async_eval(y, logprobs)
     n = 0
     while True:
-        next_y, next_logprobs = _step(y)
-        mx.async_eval(next_y, next_logprobs)
+        y, logprobs = _step(y)
+        n += 1
         yield y.item(), logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
-        n += 1
-        y, logprobs = next_y, next_logprobs


 def stream_generate(
@@ -457,6 +452,7 @@ def load_config(model_path: Path) -> dict:
 def load_model(
     model_path: Path,
     lazy: bool = False,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
Expand Down Expand Up @@ -528,6 +524,10 @@ def class_predicate(p, m):
model.shard()

if not lazy:
weights.clear()
if sequential_load:
for layer in model.layers:
mx.eval(layer.parameters())
mx.eval(model.parameters())

model.eval()
@@ -540,6 +540,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -555,6 +556,8 @@ def load(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        sequential_load (bool): If True then load each layer sequentially to
+            ensure that we are not wasting memory.
     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -564,7 +567,7 @@
     """
     model_path = get_model_path(path_or_hf_repo)

-    model = load_model(model_path, lazy, model_config)
+    model = load_model(model_path, lazy, sequential_load, model_config)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()

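Usage note (not part of the commit): a sketch of how a caller might drive the new sequential_load flag, assuming the load and generate helpers exported by mlx_lm at this revision; the model path and prompt below are placeholders.

import mlx.core as mx
from mlx_lm import generate, load

world = mx.distributed.init()

# Mirror the check added in generate.py: load layer by layer only when more
# than one node shares the model, otherwise keep the original behavior.
model, tokenizer = load(
    "path/to/model",  # placeholder repo id or local path
    sequential_load=world.size() > 1,
)

# Wait for every node to finish loading before generating.
world.barrier()
print(generate(model, tokenizer, prompt="Write a haiku about autumn."))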