Skip to content

Commit

Permalink
shard llama model after conversion and unshard on loading (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
dastrobu authored Dec 25, 2023
1 parent 738448c commit 2bd20ef
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
24 changes: 22 additions & 2 deletions llms/llama/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten


def llama(model_path):
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
Expand Down Expand Up @@ -140,6 +139,22 @@ def quantize(weights, config, args):
return quantized_weights, quantized_config


def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
# TODO: simplify to v.nbytes as soon as mx.array exposes it
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
Expand Down Expand Up @@ -200,6 +215,11 @@ def quantize(weights, config, args):
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
np.savez(str(mlx_path / "weights.npz"), **weights)
shards = make_shards(weights)
if len(shards) == 1:
np.savez(str(mlx_path / f"weights.npz"), **shards[0])
else:
for i, shard in enumerate(shards):
np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
with open(mlx_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)
20 changes: 18 additions & 2 deletions llms/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import json
import time
import glob
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
Expand Down Expand Up @@ -330,7 +331,23 @@ def sanitize_config(config, weights):

def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))

unsharded_weights_path = Path(model_path / "weights.npz")
if unsharded_weights_path.is_file():
print("[INFO] Loading model from {}.".format(unsharded_weights_path))
weights = mx.load(str(unsharded_weights_path))
else:
sharded_weights_glob = str(model_path / "weights.*.npz")
weight_files = glob.glob(sharded_weights_glob)
print("[INFO] Loading model from {}.".format(sharded_weights_glob))

if len(weight_files) == 0:
raise FileNotFoundError("No weights found in {}".format(model_path))

weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())

with open(model_path / "config.json", "r") as f:
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
Expand Down Expand Up @@ -373,7 +390,6 @@ def load_model(model_path):

mx.random.seed(args.seed)

print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
if args.few_shot:
few_shot_generate(args)
Expand Down

0 comments on commit 2bd20ef

Please sign in to comment.