diff --git a/docs/supported_models.md b/docs/supported_models.md
index cab1d1f38..84c443010 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -171,14 +171,15 @@ Neural Speed supports the following models:
LazyTensor: + data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + numpy_dtype = DATA_TYPE_TO_NUMPY[data_type] + shape: List[int] = info['shape'] + begin, end = info['data_offsets'] + assert 0 <= begin <= end <= len(byte_buf) + assert end - begin == math.prod(shape) * numpy_dtype.itemsize + buf = byte_buf[begin:end] + + def load() -> UnquantizedTensor: + return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) + + description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return LazyTensor(load, shape, data_type, description) + + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} + return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + +def must_read(fp: IO[bytes], length: int) -> bytes: + ret = fp.read(length) + if len(ret) < length: + raise Exception("unexpectedly reached end of file") + return ret + + +def lazy_load_ne_file(fp: io.BufferedReader, path: Path) -> ModelPlus: + magic = must_read(fp, 4)[::-1] + if magic in (b'ggmf', b'ggjt'): + version, = struct.unpack("i", must_read(fp, 4)) + assert version == 1 + else: + assert magic == b'ne' + version = None + n_vocab, n_embd, n_mult, n_head, n_layer, rot, file_type = struct.unpack('<7i', must_read(fp, 28)) + + tokens: List[Tuple[bytes, float]] = [] + for i in range(n_vocab): + if i == 32000: + # HACK: GPT4All messed with the format without changing the magic + # number. Specifically, they changed the vocab section to contain + # `n_vocab - 1` tokens instead of `n_vocab` (i.e. omitting the + # extra pad token). Try to detect if we're reading a file like + # this. + orig_pos = fp.tell() + fp.seek(20, io.SEEK_CUR) + is_gpt4all = fp.read(21) == b'tok_embeddings.weight' + fp.seek(orig_pos) + if is_gpt4all: + break + + length, = struct.unpack("i", must_read(fp, 4)) + text = must_read(fp, length) + if magic != b'ne': + score, = struct.unpack("f", must_read(fp, 4)) + tokens.append((text, score)) + vocab = NEVocab(tokens) if magic != b'ne' else None + + model: LazyModel = {} + # Use mmap for the actual data to avoid race conditions with the file offset. 
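As a rough, self-contained illustration of the mmap-based lazy loading used below (the helper name and signature here are hypothetical, not the converter's API): the file is mapped once, each tensor keeps only an offset/size slice of the mapping, and the bytes are materialized only when np.frombuffer is evaluated.

import math
import mmap

import numpy as np

def lazy_slice(path, offset, shape, dtype=np.float32):
    # Map the file read-only; slicing the memoryview copies nothing yet.
    with open(path, 'rb') as fp:
        mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
    size = math.prod(shape) * np.dtype(dtype).itemsize
    buf = mapped[offset:offset + size]
    # Only this step actually touches the mapped pages.
    return np.frombuffer(buf, dtype=dtype).reshape(shape)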
+ off = fp.raw.tell() + mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)) + fp.raw.seek(off) # needed on Windows + + def read_tensor() -> None: # this is a function so that variables captured in `load` don't change + shape_len, name_len, ftype = struct.unpack("iii", must_read(fp, 12)) + assert 0 <= shape_len <= 3 + shape: List[int] = list(struct.unpack(f"{shape_len}i", must_read(fp, 4 * shape_len))) + shape = shape[::-1] + name = must_read(fp, name_len).decode('utf-8') + data_type = FTYPE_TO_DATA_TYPE[ftype] + + if magic == b'ggjt': + fp.seek((fp.tell() + 31) & -32) + + if data_type == DT_Q4_1: + # See GPTQForLLaMaQuantizedTensor.ne_ndarray() + size = 24 * (shape[1] // 32) * shape[0] + elif data_type == DT_Q4_0: + size = 20 * (shape[1] // 32) * shape[0] + else: + numpy_dtype = DATA_TYPE_TO_NUMPY[data_type] + elm_count = math.prod(shape) + size = elm_count * numpy_dtype.itemsize + offset = fp.tell() + buf = mapped[offset:offset + size] + fp.seek(size, io.SEEK_CUR) + + def load() -> Tensor: + if isinstance(data_type, QuantizedDataType): + ndarray = np.frombuffer(buf, dtype=np.uint32) + return NEQuantizedTensor(ndarray, shape, data_type) + else: + return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) + + description = f'ne offset={offset} type={data_type} path={path}' + model[name] = LazyTensor(load, shape, data_type, description) + + while fp.read(1) != b'': + fp.seek(-1, io.SEEK_CUR) + read_tensor() + + return ModelPlus(model=model, paths=[path], format='ne', vocab=vocab) + + +@functools.lru_cache(maxsize=None) +def lazy_load_file(path: Path) -> ModelPlus: + fp = open(path, 'rb') + first8 = fp.read(8) + fp.seek(0) + if first8[:2] == b'PK': + # A zip file, i.e. PyTorch format + return lazy_load_torch_file(fp, path) + elif first8[2:4] == b'gg': + # NE format + return lazy_load_ne_file(fp, path) + elif struct.unpack('Iterable[Out]: + '''Parallel map, but with backpressure. If the caller doesn't call `next` + fast enough, this will stop calling `func` at some point rather than + letting results pile up in memory. Specifically, there is a max of one + output value buffered per thread.''' + with concurrent.futures.ThreadPoolExecutor() as executor: + futures: List[concurrent.futures.Future[Out]] = [] + items_rev = list(iterable)[::-1] + for i in range(min(concurrency, len(items_rev))): + futures.append(executor.submit(func, items_rev.pop())) + while futures: + result = futures.pop(0).result() + if items_rev: + futures.append(executor.submit(func, items_rev.pop())) + yield result + + +def check_vocab_size(params: Params, vocab: Vocab) -> None: + if params.n_vocab != vocab.vocab_size: + # NEVocab comes from the same file as the model so shouldn't mismatch: + assert isinstance(vocab, SentencePieceVocab) + if params.n_vocab == vocab.vocab_size_base: + print("Ignoring added_tokens.json since model matches vocab size without it.") + vocab.added_tokens_list = [] + vocab.vocab_size = vocab.vocab_size_base + return + msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}" + if vocab.fname_added_tokens is not None: + msg += f" combined with {vocab.fname_added_tokens}" + msg += f" has {vocab.vocab_size})." + if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None: + msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." 
+ raise Exception(msg) + + +class OutputFile: + def __init__(self, fname_out: Path) -> None: + self.fout = open(fname_out, "wb") + + def write_file_header(self, params: Params, file_type: NEFileType) -> None: + self.fout.write(b"ggjt"[::-1]) # magic + values = [ + 1, # file version + params.n_vocab, + params.n_embd, + params.n_mult, + params.n_head, + params.n_head_kv, # n_head_kv (multi_query attention) + params.n_layer, + params.n_embd // params.n_head, # rot (obsolete) + file_type.value, + ] + self.fout.write(struct.pack("i" * len(values), *values)) + self.fout.write(struct.pack("i", 0)) + self.fout.write(struct.pack("f", 0)) + self.fout.write(struct.pack("f", 0)) + self.fout.write(struct.pack("i", 0)) + self.fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + self.fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + self.fout.write(struct.pack("i", 0)) + self.fout.write(struct.pack("i", params.ffn_hidden_size)) + self.fout.write(struct.pack("i", 0)) + self.fout.write(struct.pack("i", 8)) + self.fout.write(struct.pack("i", 2)) + self.fout.write(struct.pack("f", params.rms_norm_eps)) + self.fout.write(struct.pack("f", params.rope_theta)) + self.fout.write(struct.pack("f", params.rope_scale)) + + self.fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + self.fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + self.fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + + self.fout.write( + struct.pack("i", 1) + ) + # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json + # but bos_token_id = 1 in llama.cpp + self.fout.write(struct.pack("i", 2)) + + self.fout.write(struct.pack("i", 0)) + self.fout.write(struct.pack("i", 0)) + + def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None: + sname = name.encode('utf-8') + self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type])) + self.fout.write(struct.pack("i" * len(shape), *shape[::-1])) + self.fout.write(sname) + self.fout.seek((self.fout.tell() + 31) & -32) + + def write_vocab(self, vocab: Vocab) -> None: + for text, score in vocab.all_tokens(): + self.fout.write(struct.pack("i", len(text))) + self.fout.write(text) + self.fout.write(struct.pack("f", score)) + + @staticmethod + def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: + of = OutputFile(fname_out) + params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, file_type=NEFileType.AllF32) + of = OutputFile(fname_out) + of.write_file_header(params) + of.write_vocab(vocab) + of.fout.close() + + @staticmethod + def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab, file_type: NEFileType) -> None: + check_vocab_size(params, vocab) + of = OutputFile(fname_out) + of.write_file_header(params, file_type) + print("Writing vocab...") + of.write_vocab(vocab) + + def do_item(item: Tuple[str, LazyTensor]) -> NDArray: + name, lazy_tensor = item + return lazy_tensor.load().to_ne().ndarray + + ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + print( + f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} |\ + type {lazy_tensor.data_type}" + ) + of.write_tensor_header(name, lazy_tensor.shape, 
lazy_tensor.data_type) + ndarray.tofile(of.fout) + of.fout.close() + + +def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> NEFileType: + wq_type = model["layers.0.attention.wq.weight"].data_type + if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + return NEFileType.AllF32 + if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): + return NEFileType.MostlyF16 + if output_type_str == "q4_1" or (output_type_str is None and isinstance(wq_type, QuantizedDataType) + and wq_type.have_addends): + if isinstance(model["output.weight"].data_type, QuantizedDataType): + return NEFileType.MostlyQ4_1 + else: + return NEFileType.PerLayerIsQ4_1 + if output_type_str == "q4_0" or (output_type_str is None and isinstance(wq_type, QuantizedDataType)): + return NEFileType.MostlyQ4_0 + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + raise Exception(f"Unexpected combination of types: {name_to_type}") + + +def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel: + model = handle_quantization(model) + + if "lm_head.weight" in model: + model = convert_transformers_to_orig(model, params) + model = filter_and_sort_tensors(model) + + return model + + +def convert_to_output_type(model: LazyModel, output_type: NEFileType) -> LazyModel: + return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) for (name, tensor) in model.items()} + + +def nth_multifile_path(path: Path, n: int) -> Optional[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the nth path in the model. + ''' + # Support the following patterns: + patterns: List[Tuple[str, str]] = [ + # - x.00.pth, x.01.pth, etc. + (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. + (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + # x.bin, x.bin.1, etc. + (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + ] + for regex, replacement in patterns: + if re.search(regex, path.name): + new_path = path.with_name(re.sub(regex, replacement, path.name)) + if new_path.exists(): + return new_path + return None + + +def find_multifile_paths(path: Path) -> List[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the whole list of paths in the model. + ''' + ret: List[Path] = [] + for i in itertools.count(): + nth_path = nth_multifile_path(path, i) + if nth_path is None: + break + ret.append(nth_path) + if not ret: + # No matches. This should only happen if the file was named, e.g., + # foo.0, and there was no file named foo. Oh well, try to process it + # as a single file. + return [path] + return ret + + +def load_some_model(path: Path) -> ModelPlus: + '''Load a model of any supported format.''' + # Be extra-friendly and accept either a file or a directory: + if path.is_dir(): + # Check if it's a set of safetensors files first + files = list(path.glob("model-00001-of-*.safetensors")) + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + # Try NE too, but with lower priority, since if both a non-NE + # model and a NE model exist in the same directory, we assume the + # latter was converted from the former. 
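As a quick, standalone sketch of the multi-file naming schemes that nth_multifile_path handles earlier in this hunk (the helper below is hypothetical; the regex table mirrors the patterns in the code):

import re

_SHARD_PATTERNS = [
    (r'\.[0-9]{2}\.pth$', lambda n: f'.{n:02}.pth'),              # x.00.pth, x.01.pth, ...
    (r'-[0-9]{5}-of-(.*)$', lambda n: fr'-{n:05}-of-\1'),         # x-00001-of-00002.bin, ...
    (r'(\.[0-9]+)?$', lambda n: r'\1' if n == 0 else fr'\1.{n}'), # x.bin, x.bin.1, ...
]

def nth_shard_name(name: str, n: int) -> str:
    for regex, repl in _SHARD_PATTERNS:
        if re.search(regex, name):
            return re.sub(regex, repl(n), name)
    return name

# nth_shard_name('model-00001-of-00002.safetensors', 2) -> 'model-00002-of-00002.safetensors'
# nth_shard_name('consolidated.00.pth', 1)              -> 'consolidated.01.pth'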
+ files = list(path.glob("ne-model*.bin*")) + if not files: + raise Exception(f"Can't find model in directory {path}") + if len(files) > 1: + raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}") + path = files[0] + + paths = find_multifile_paths(path) + models_plus: List[ModelPlus] = [] + for path in paths: + print(f"Loading model file {path}") + models_plus.append(lazy_load_file(path)) + + model_plus = merge_multifile_models(models_plus) + return model_plus + + +def filter_and_sort_tensors(model: LazyModel) -> LazyModel: + return {name: model[name] for name in TENSORS_LIST if name in model} + + +def load_vocab(path: Path) -> SentencePieceVocab: + # Be extra-friendly and accept either a file or a directory. Also, if it's + # a directory, it might be the model directory, and tokenizer.model might + # be in the parent of that. + if path.is_dir(): + path2 = path / "tokenizer.model" + # Use `.parent` instead of /.. to handle the symlink case better. + path3 = path.parent / "tokenizer.model" + if path2.exists(): + path = path2 + elif path3.exists(): + path = path3 + else: + raise FileNotFoundError( + f"Could not find tokenizer.model in {path} or its parent; if it's in another directory,\ + pass the directory as --vocab-dir" + ) + added_tokens_path = path.parent / "added_tokens.json" + print(f"Loading vocab file {path}") + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) + + +def default_outfile(model_paths: List[Path], params: Params) -> Path: + namestr = { + NEFileType.AllF32: "f32", + NEFileType.MostlyF16: "f16", + NEFileType.MostlyQ4_0: "q4_0", + NEFileType.MostlyQ4_1: "q4_1", + NEFileType.PerLayerIsQ4_1: "q4_1", + }[params.file_type] + ret = model_paths[0].parent / f"ne-model-{namestr}.bin" + if ret in model_paths: + sys.stderr.write( + f"Error: Default output path ({ret}) would overwrite the input. 
Please explicitly specify \ + a path using --outfile.\n" + ) + sys.exit(1) + return ret + + +def do_dump_model(model_plus: ModelPlus) -> None: + print(f"model_plus.paths = {model_plus.paths!r}") + print(f"model_plus.format = {model_plus.format!r}") + print(f"model_plus.vocab = {model_plus.vocab!r}") + for name, lazy_tensor in model_plus.model.items(): + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") + + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a LLaMa model to a NE compatible file") + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", + action="store_true", + help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--outtype", + choices=["f32", "f16", "q4_1", "q4_0"], + help="output format (default: based on input)") + parser.add_argument("--vocab-dir", + type=Path, + help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", + type=Path, + help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + args = parser.parse_args(args_in) + + vocab: Vocab + if args.dump_single: + model_plus = lazy_load_file(args.model) + do_dump_model(model_plus) + elif args.vocab_only: + vocab = load_vocab(args.vocab_dir or args.model) + assert args.outfile, "need --outfile if using --vocab-only" + outfile = args.outfile + OutputFile.write_vocab_only(outfile, vocab) + print(f"Wrote {outfile}") + else: + if Path(args.model).is_dir(): + print("Loadding the model from the local path.") + else: + print("Loadding the model from HF.") + model = AutoModel.from_pretrained(args.model, low_cpu_mem_usage=True, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + cache_path = Path(tokenizer.vocab_file).parent + args.model = cache_path + + model_plus = load_some_model(args.model) + if args.dump: + do_dump_model(model_plus) + return + if model_plus.vocab is not None and args.vocab_dir is None: + vocab = model_plus.vocab + else: + vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent + vocab = load_vocab(vocab_dir) + model = model_plus.model + params = Params.load(model_plus) + model = do_necessary_conversions(model, params) + output_type = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, output_type) + outfile = args.outfile or default_outfile(model_plus.paths, params) + OutputFile.write_all(outfile, params, model, vocab, output_type) + print(f"Wrote {outfile}") + + +if __name__ == '__main__': + main() diff --git a/neural_speed/convert/convert_mpt.py b/neural_speed/convert/convert_mpt.py index fcad21bda..cd56af41d 100644 --- a/neural_speed/convert/convert_mpt.py +++ b/neural_speed/convert/convert_mpt.py @@ -95,6 +95,8 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # 
rope_factor diff --git a/neural_speed/convert/convert_opt.py b/neural_speed/convert/convert_opt.py index 0068311e9..07b7a632a 100644 --- a/neural_speed/convert/convert_opt.py +++ b/neural_speed/convert/convert_opt.py @@ -106,6 +106,8 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_phi.py b/neural_speed/convert/convert_phi.py index f74fdf5d1..6e02b0b55 100644 --- a/neural_speed/convert/convert_phi.py +++ b/neural_speed/convert/convert_phi.py @@ -197,9 +197,14 @@ def phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1)) fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1)) fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) diff --git a/neural_speed/convert/convert_quantized_bloom.py b/neural_speed/convert/convert_quantized_bloom.py index b68d4bff9..21ce87ab6 100644 --- a/neural_speed/convert/convert_quantized_bloom.py +++ b/neural_speed/convert/convert_quantized_bloom.py @@ -170,7 +170,12 @@ def bytes_to_unicode(): f.write(struct.pack("i", 0)) f.write(struct.pack("i", 0)) f.write(struct.pack("i", 0)) +f.write(struct.pack("i", 0)) # n_experts +f.write(struct.pack("i", 0)) # n_expert_used f.write(struct.pack("f", 1e-6)) # rms norm eps +f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled +f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings +f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) f.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) f.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 44a3a5d59..5ce8f863e 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -142,11 +142,12 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get( "rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor - fout.write(struct.pack("f", 
0.0)) # config.json "rope_scaling.factor", not enabled fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) diff --git a/neural_speed/convert/convert_quantized_llama.py b/neural_speed/convert/convert_quantized_llama.py index dc51b37b4..b3bbfd9c5 100644 --- a/neural_speed/convert/convert_quantized_llama.py +++ b/neural_speed/convert/convert_quantized_llama.py @@ -146,6 +146,8 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", ffn_hidden_size)) f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) # n_experts + f.write(struct.pack("i", 0)) # n_expert_used f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py index 82403dc94..8e154a8ca 100644 --- a/neural_speed/convert/convert_quantized_mistral.py +++ b/neural_speed/convert/convert_quantized_mistral.py @@ -82,10 +82,15 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", ffn_hidden_size)) f.write(struct.pack("i", 0)) + f.write(struct.pack("i", 0)) # n_experts + f.write(struct.pack("i", 0)) # n_expert_used f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) f.write(struct.pack("f", rope_scale)) + f.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + f.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + f.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json # but bos_token_id = 1 in llama.cpp diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py index 900d8cfe8..704aa9ee6 100644 --- a/neural_speed/convert/convert_qwen.py +++ b/neural_speed/convert/convert_qwen.py @@ -112,6 +112,8 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", hparams["intermediate_size"])) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_starcoder.py b/neural_speed/convert/convert_starcoder.py index 11dec1f70..f176ef8d9 100644 --- a/neural_speed/convert/convert_starcoder.py +++ b/neural_speed/convert/convert_starcoder.py @@ -110,6 +110,8 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/core/layers/Ops.h b/neural_speed/core/layers/Ops.h index dc3e28a8d..b0f441b42 100644 --- a/neural_speed/core/layers/Ops.h +++ b/neural_speed/core/layers/Ops.h @@ -48,6 +48,8 @@ enum ne_op { NE_OP_MUL_MAT, 
NE_OP_MUL_MAT_BIAS, + NE_OP_MUL_MAT_ID, + NE_OP_MUL_ID_FFN_SILU, NE_OP_SCALE, NE_OP_SET, NE_OP_CPY, @@ -88,6 +90,7 @@ enum ne_op { NE_OP_DUMP_TENSOR, NE_OP_DEBUG, NE_OP_CONV_1D, + NE_OP_ARGSORT, NE_OP_COUNT, }; diff --git a/neural_speed/core/layers/argsort.cpp b/neural_speed/core/layers/argsort.cpp new file mode 100644 index 000000000..958b4a8d1 --- /dev/null +++ b/neural_speed/core/layers/argsort.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "argsort.h" +#include+#include + +static void ne_compute_forward_argsort_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + return; + } + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + for (int64_t i = ith; i < nr; i += nth) { + int32_t* dst_data = (int32_t*)((char*)dst->data + i * nb1); + const float* src_data = (float*)((char*)src0->data + i * nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + std::sort(dst_data, dst_data + ne0, [src_data](int pos1, int pos2) { return (src_data[pos1] > src_data[pos2]); }); + } +} +void ne_compute_forward_argsort(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst) { + switch (src0->type) { + case NE_TYPE_F32: { + ne_compute_forward_argsort_f32(params, src0, dst); + } break; + default: { + NE_ASSERT(false); + } break; + } +} diff --git a/neural_speed/core/layers/argsort.h b/neural_speed/core/layers/argsort.h new file mode 100644 index 000000000..a9c7c2058 --- /dev/null +++ b/neural_speed/core/layers/argsort.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
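For readers following the new argsort kernel, here is its reference semantics in numpy (an illustrative sketch, not the C implementation): the indices of each row are sorted by descending source value, and ne_top_k then simply views the first k sorted indices.

import numpy as np

def argsort_desc(src: np.ndarray) -> np.ndarray:
    # src: (..., ne0) float32 -> (..., ne0) int32 indices, largest value first
    return np.argsort(-src, axis=-1, kind='stable').astype(np.int32)

def top_k(src: np.ndarray, k: int) -> np.ndarray:
    # Mirrors ne_top_k: argsort, then keep the first k indices per row.
    return argsort_desc(src)[..., :k]

scores = np.array([[0.1, 0.7, 0.2, 0.5]], dtype=np.float32)
print(top_k(scores, 2))  # [[1 3]] -> the two highest-scoring experts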
+ +#pragma once +#include "core/ne.h" +#include "core/data_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ne_compute_forward_argsort(const struct ne_compute_params* params, const struct ne_tensor* src0, + struct ne_tensor* dst); + +#ifdef __cplusplus +} +#endif diff --git a/neural_speed/core/layers/layers.h b/neural_speed/core/layers/layers.h index ba7ba0816..34f7620b8 100644 --- a/neural_speed/core/layers/layers.h +++ b/neural_speed/core/layers/layers.h @@ -16,3 +16,4 @@ #include "conv.h" #include "memory.h" +#include "argsort.h" diff --git a/neural_speed/core/layers/mha_dense.cpp b/neural_speed/core/layers/mha_dense.cpp index 0fcd39031..af2953514 100644 --- a/neural_speed/core/layers/mha_dense.cpp +++ b/neural_speed/core/layers/mha_dense.cpp @@ -72,7 +72,7 @@ bool bestla_reordered_attn_fp32_support(const attn_shape_t* params) { // TODO(Yi): check K V's layout if (_cd->AMX_BF16()) return true; #endif - return _cd->AVX512F() || _cd->AVX2(); // use avx2 and f16c on avx2 platforms + return !_cd->AVX512F() || _cd->AVX2(); // use avx2 and f16c on avx2 platforms } // kv cache sizes in bytes per layer per batch per beam for; void bestla_reordered_attn_fp32_batch_kv_info(const kv_shape_t* params, kv_cache_info_t* out) { diff --git a/neural_speed/core/ne.h b/neural_speed/core/ne.h index 4790a297e..33bf4f0b6 100644 --- a/neural_speed/core/ne.h +++ b/neural_speed/core/ne.h @@ -43,7 +43,7 @@ #define NE_MAX_NODES 16384 #define NE_MAX_PARAMS 256 #define NE_MAX_CONTEXTS 64 -#define NE_MAX_OPT 4 +#define NE_MAX_OPT 36 #define NE_DEFAULT_N_THREADS 4 #define NE_MAX_OP_PARAMS 32 diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 38c76a252..791c94076 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -402,9 +402,10 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "NORM", "RMS_NORM", "RMS_NORM_BACK", - + "ARGSORT", "MUL_MAT", "MUL_MAT_WITH_BIAS", + "MUL_MAT_ID", "SCALE", "SET", "CPY", @@ -431,10 +432,10 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "FFN_SILU", "FFN_GeLU", "FFN_ADD_GeLU", + "FFN_ID_SILU", "FLASH_ATTN", "FLASH_ATTN_KV_UPDATE", "FLASH_FF", - "MAP_UNARY", "MAP_BINARY", "SPLIT", @@ -445,7 +446,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "DEBUG", }; -static_assert(NE_OP_COUNT == 64, "NE_OP_COUNT != 64"); +static_assert(NE_OP_COUNT == 67, "NE_OP_COUNT != 67"); static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "none", @@ -479,6 +480,7 @@ static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "X*Y", "X*Y+Z", "x*v", + "matmul_id", "y-\\>view(x)", "x-\\>y", "cont(x)", @@ -502,16 +504,17 @@ static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "QKV(x)", "ffn_silu(x)", + "ffn_id_silu(x)", "ffn_gelu(x)", "ffn_gelu_with_bias(x)", "flash_attn(x)", "flash_attn_kv_update(x)", "flash_ff(x)", - "f(x)", "f(x,y)", "conv_1d(x)", "debug(x)", + "argsort(x)", }; static_assert(sizeof(struct ne_object) % NE_MEM_ALIGN == 0, "ne_object size must be a multiple of NE_MEM_ALIGN"); @@ -2178,6 +2181,109 @@ struct ne_tensor* ne_mul_mat_with_bias(struct ne_context* ctx, struct ne_tensor* return result; } +struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const as[], int n_as, struct ne_tensor* ids, + int id, struct ne_tensor* b) { + NE_ASSERT(ids->type == NE_TYPE_I32); + NE_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); + NE_ASSERT(ids->ne[1] == b->ne[1]); + NE_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); + NE_ASSERT(n_as > 0 && n_as <= 8); + NE_ASSERT(id >= 0 && id < ids->ne[0]); + + bool is_node = false; + + if 
(as[0]->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = {as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3]}; + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne, NE_SIZE_CALC); + int params[] = {id, n_as}; + ne_set_op_params(result, ¶ms, sizeof(params)); + result->op = NE_OP_MUL_MAT_ID; + result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->src0 = ids; + result->src1 = b; + + for (int i = 0; i < n_as; i++) { + struct ne_tensor* a = as[i]; + NE_ASSERT(ne_are_same_shape(as[0], a)); + NE_ASSERT(ne_can_mul_mat(a, b)); + NE_ASSERT(!ne_is_transposed(a)); + result->opt[i] = a; + } + + return result; +} + +struct ne_tensor* ne_mul_id_ffn_silu(struct ne_context* ctx, struct ne_tensor* const down[], + struct ne_tensor* const gate[], struct ne_tensor* const up[], int n_as, + struct ne_tensor* ids, int id, struct ne_tensor* src) { + struct ne_tensor* w1 = gate[0]; + struct ne_tensor* w2 = down[0]; + struct ne_tensor* w3 = up[0]; + NE_ASSERT(ids->type == NE_TYPE_I32); + NE_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); + NE_ASSERT(ids->ne[1] == src->ne[1]); + NE_ASSERT(ids->ne[2] == src->ne[2] && ids->ne[3] == src->ne[3]); + NE_ASSERT(n_as > 0 && n_as <= 8); + NE_ASSERT(id >= 0 && id < ids->ne[0]); + NE_ASSERT(ne_are_same_shape(w1, w3)); + NE_ASSERT(w2->ne[0] == w1->ne[1]); + + bool is_node = false; + + if (down[0]->grad || src->grad) { + is_node = true; + } + const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + int params[] = {id, n_as}; + ne_set_op_params(result, ¶ms, sizeof(params)); + result->op = NE_OP_MUL_ID_FFN_SILU; + result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->src0 = src; + result->src1 = ids; + for (int i = 0; i < n_as; i++) { + struct ne_tensor* a = gate[i]; + struct ne_tensor* b = down[i]; + struct ne_tensor* c = up[i]; + result->opt[i] = a; + result->opt[i + 8] = b; + result->opt[i + 16] = c; + } + result->opt[24] = tmp; + result->opt[25] = tmp1; + // struct ne_tensor *result = ne_ffn_silu(ctx,gate[row_id], down[row_id],up[row_id], b); + return result; +} +struct ne_tensor* ne_argsort(struct ne_context* ctx, struct ne_tensor* a) { + bool is_node = false; + + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_I32, NE_MAX_DIMS, a->ne, NE_SIZE_CALC); + + result->op = NE_OP_ARGSORT; + result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->src0 = a; + + return result; +} + +// ne_top_k + +struct ne_tensor* ne_top_k(struct ne_context* ctx, struct ne_tensor* a, int k) { + NE_ASSERT(a->ne[0] >= k); + + struct ne_tensor* result = ne_argsort(ctx, a); + + result = ne_view_4d(ctx, result, k, result->ne[1], result->ne[2], result->ne[3], result->nb[1], result->nb[2], + result->nb[3], 0); + + return result; +} // ne_mul_qkv struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw, struct ne_tensor* kw, struct ne_tensor* vw, @@ -2754,17 +2860,22 @@ struct ne_tensor* ne_transpose(struct ne_context* ctx, struct ne_tensor* a) { // ne_get_rows struct ne_tensor* ne_get_rows(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b) { - NE_ASSERT(ne_is_matrix(a) && ne_is_vector(b) && b->type == NE_TYPE_I32); + NE_ASSERT(a->ne[2] == b->ne[1]); + NE_ASSERT(b->ne[3] == 1); + NE_ASSERT(b->type == NE_TYPE_I32); bool is_node = false; if (a->grad || b->grad) { is_node = true; } - + enum ne_type type = NE_TYPE_F32; + if (a->type == NE_TYPE_I32) { + type = a->type; + } // TODO: implement non F32 return // struct ne_tensor * result = ne_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ne_tensor* result = ne_new_tensor_2d(ctx, NE_TYPE_F32, a->ne[0], b->ne[0], NE_SIZE_CALC); + struct ne_tensor* result = ne_new_tensor_4d(ctx, NE_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2], NE_SIZE_CALC); result->op = NE_OP_GET_ROWS; result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; @@ -5265,7 +5376,7 @@ static void ne_compute_forward_mul(const struct ne_compute_params* params, const static void ne_compute_forward_div_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, struct ne_tensor* dst) { - assert(params->ith == 0); + // assert(params->ith == 0); assert(ne_are_same_shape(src0, src1) && ne_are_same_shape(src0, dst)); if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { @@ -5494,7 +5605,7 @@ static void ne_compute_forward_sum(const struct ne_compute_params* params, const static void ne_compute_forward_sum_rows_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, struct ne_tensor* dst) { - NE_ASSERT(params->ith == 0); + // NE_ASSERT(params->ith == 0); if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { return; @@ -6435,7 +6546,7 @@ static void ne_compute_forward_mul_mat_f32(const struct ne_compute_params* param const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); const int64_t i03 = (ir0 / (ne02)); - // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2. + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: const int64_t i02 = (i12 / (ne12 / ne02)); // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) @@ -6558,7 +6669,7 @@ static void ne_compute_forward_mul_mat_f16_f32(const struct ne_compute_params* p const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); const int64_t i03 = (ir0 / (ne02)); - // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2. + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. 
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: const int64_t i02 = (i12 / (ne12 / ne02)); // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) @@ -6692,7 +6803,7 @@ static void ne_compute_forward_mul_mat_q_f32(const struct ne_compute_params* par const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); const int64_t i03 = (ir0 / (ne02)); - // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2. + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: const int64_t i02 = (i12 / (ne12 / ne02)); // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) @@ -6831,6 +6942,606 @@ static void ne_compute_forward_mul_mat(const struct ne_compute_params* params, c } } +static void ne_compute_forward_mul_mat_id_q_f32(const struct ne_compute_params* params, const struct ne_tensor* ids, + const struct ne_tensor* src1, struct ne_tensor* dst) { + int64_t t0 = ne_perf_time_us(); + UNUSED(t0); + const struct ne_tensor* src0 = dst->opt[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + + const size_t nb11 = src1->nb[1]; + UNUSED(nb11); + const size_t nb12 = src1->nb[2]; + UNUSED(nb12); + const size_t nb13 = src1->nb[3]; + UNUSED(nb13); + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ne_type type = src0->type; + quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; + vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum ne_type const vec_dot_type = quantize_fns[type].vec_dot_type; + + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + NE_ASSERT(nb00 == (int)NE_TYPE_SIZE[type]); + NE_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + NE_ASSERT(nb0 == sizeof(float)); + NE_ASSERT(nb0 <= nb1); + NE_ASSERT(nb1 <= nb2); + NE_ASSERT(nb2 <= nb3); + const int id = dst->op_params[0]; + const int n_as = dst->op_params[1]; + // char * wdata_src1_end = (char *)params->wdata; + // int64_t wdata_src1_end = 0; + +#define mmid_matrix_row(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (params->type == NE_TASK_INIT) { + if (ith != 0) { + return; + } + char* wdata = params->wdata; + const size_t row_size = ne10 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + quantize_row_q_dot((float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11), (void*)wdata, ne10); + wdata += 
row_size; + } + } + } + + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + int64_t matrix_row_counts[100]; // [n_as] + int64_t matrix_rows[30000]; // [n_as][ne11] + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + memset(matrix_rows, -1, 30000 * sizeof(int64_t)); + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t*)((const char*)ids->data + i01 * ids->nb[1] + id * ids->nb[0]); + NE_ASSERT(row_id >= 0 && row_id < n_as); + mmid_matrix_row(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + const struct ne_tensor* src0_cur = dst->opt[cur_a]; + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; + + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); + + // src1 rows + const int64_t nr1 = cne1 * ne12 * ne13; + + void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ne10 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; + + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * cne1)); + const int64_t i12 = (ir1 - i13 * ne12 * cne1) / cne1; + const int64_t _i11 = (ir1 - i13 * ne12 * cne1 - i12 * cne1); + const int64_t i11 = mmid_matrix_row(cur_a, _i11); + if (i11 == -1) { + continue; + } + + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. + // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: + const int64_t i02 = (i12 / (ne12 / ne02)); + // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) + // const int64_t i02 = (ir0 - i03*ne02); + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + char* src0_row = (char*)src0_cur->data + (0 + i02 * nb02 + i03 * nb03); + char* src1_col = (char*)wdata + (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size; + + float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + for (int64_t ir = ir10; ir < ir11; ++ir) { + vec_dot_q(ne00, &dst_col[ir], src0_row + ir * nb01, src1_col); + } + } + } +} + +static void ne_compute_forward_mul_mat_id_f32(const struct ne_compute_params* params, const struct ne_tensor* ids, + const struct ne_tensor* src1, struct ne_tensor* dst) { + int64_t t0 = ne_perf_time_us(); + UNUSED(t0); + const struct ne_tensor* src0 = dst->opt[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne11 = src1->ne[1]; + + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + + const size_t nb11 = src1->nb[1]; + UNUSED(nb11); + const size_t nb12 = src1->nb[2]; + UNUSED(nb12); + const size_t nb13 = src1->nb[3]; + UNUSED(nb13); + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; 
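To make the routing in these mul_mat_id kernels easier to follow, here is a semantic sketch in numpy (hypothetical shapes and names, not the C data layout): ids names the expert chosen for each token in a given slot, token rows are bucketed per expert, and each expert's weight matrix multiplies only its own rows.

import numpy as np

def mul_mat_id_ref(experts, ids, slot, x):
    # experts: list of (d_out, d_in) weight matrices, one per expert
    # ids:     (n_tokens, n_slots) int32, chosen expert per token and slot
    # x:       (n_tokens, d_in) activations -> (n_tokens, d_out) output
    out = np.zeros((x.shape[0], experts[0].shape[0]), dtype=x.dtype)
    for e, w in enumerate(experts):
        rows = np.where(ids[:, slot] == e)[0]  # tokens routed to expert e
        if rows.size:
            out[rows] = x[rows] @ w.T          # per-expert matmul on its rows
    return out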
+ + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + NE_ASSERT(nb00 == sizeof(float)); + NE_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + NE_ASSERT(nb0 == sizeof(float)); + NE_ASSERT(nb0 <= nb1); + NE_ASSERT(nb1 <= nb2); + NE_ASSERT(nb2 <= nb3); + const int id = dst->op_params[0]; + const int n_as = dst->op_params[1]; + // char * wdata_src1_end = (char *)params->wdata; + // int64_t wdata_src1_end = 0; + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (params->type == NE_TASK_INIT) { + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + int64_t matrix_row_counts[100]; // [n_as] + int64_t matrix_rows[30000]; // [n_as][ne11] +#define mmid_matrix_row(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + memset(matrix_rows, -1, 30000 * sizeof(int64_t)); + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t*)((const char*)ids->data + i01 * ids->nb[1] + id * ids->nb[0]); + NE_ASSERT(row_id >= 0 && row_id < n_as); + mmid_matrix_row(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + const struct ne_tensor* src0_cur = dst->opt[cur_a]; + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; + + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); + + // src1 rows + const int64_t nr1 = cne1 * ne12 * ne13; + + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * cne1)); + const int64_t i12 = (ir1 - i13 * ne12 * cne1) / cne1; + const int64_t _i11 = (ir1 - i13 * ne12 * cne1 - i12 * cne1); + const int64_t i11 = mmid_matrix_row(cur_a, _i11); + if (i11 == -1) { + continue; + } + + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. 
+ // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: + const int64_t i02 = (i12 / (ne12 / ne02)); + // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) + // const int64_t i02 = (ir0 - i03*ne02); + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + char* src0_row = (char*)src0_cur->data + (0 + i02 * nb02 + i03 * nb03); + char* src1_col = (char*)src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + + float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + for (int64_t ir = ir10; ir < ir11; ++ir) { + ne_vec_dot_f32(ne00, &dst_col[ir], (float*)(src0_row + ir * nb01), (float*)src1_col); + } + } + } +} + +static void ne_compute_forward_mul_mat_id_f16_f32(const struct ne_compute_params* params, const struct ne_tensor* ids, + const struct ne_tensor* src1, struct ne_tensor* dst) { + int64_t t0 = ne_perf_time_us(); + UNUSED(t0); + const struct ne_tensor* src0 = dst->opt[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + + const size_t nb11 = src1->nb[1]; + UNUSED(nb11); + const size_t nb12 = src1->nb[2]; + UNUSED(nb12); + const size_t nb13 = src1->nb[3]; + UNUSED(nb13); + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + NE_ASSERT(nb00 == sizeof(ne_fp16_t)); + + // dst cannot be transposed or permuted + NE_ASSERT(nb0 == sizeof(float)); + NE_ASSERT(nb0 <= nb1); + NE_ASSERT(nb1 <= nb2); + NE_ASSERT(nb2 <= nb3); + const int id = dst->op_params[0]; + const int n_as = dst->op_params[1]; + // char * wdata_src1_end = (char *)params->wdata; + // int64_t wdata_src1_end = 0; + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (params->type == NE_TASK_INIT) { + ne_fp16_t* const wdata = params->wdata; + + size_t id = 0; + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = + NE_FP32_TO_FP16(*(float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11 + i10 * nb10)); + } + } + } + } + + NE_ASSERT(id * sizeof(ne_fp16_t) <= params->wsize); + + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + int64_t matrix_row_counts[100]; // [n_as] + int64_t matrix_rows[30000]; // [n_as][ne11] +#define mmid_matrix_row(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + memset(matrix_rows, -1, 30000 * sizeof(int64_t)); + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t*)((const char*)ids->data + i01 * ids->nb[1] + id * ids->nb[0]); + NE_ASSERT(row_id >= 0 && 
row_id < n_as); + mmid_matrix_row(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + assert(nb10 / 2 == sizeof(ne_fp16_t)); + const struct ne_tensor* src0_cur = dst->opt[cur_a]; + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; + + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); + + // src1 rows + const int64_t nr1 = cne1 * ne12 * ne13; + void* wdata = params->wdata; + const size_t row_size = ne10 * NE_TYPE_SIZE[NE_TYPE_F16]; + + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * cne1)); + const int64_t i12 = (ir1 - i13 * ne12 * cne1) / cne1; + const int64_t _i11 = (ir1 - i13 * ne12 * cne1 - i12 * cne1); + const int64_t i11 = mmid_matrix_row(cur_a, _i11); + if (i11 == -1) { + continue; + } + + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. + // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: + const int64_t i02 = (i12 / (ne12 / ne02)); + // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) + // const int64_t i02 = (ir0 - i03*ne02); + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + char* src0_row = (char*)src0_cur->data + (0 + i02 * nb02 + i03 * nb03); + char* src1_col = (char*)wdata + (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size; + + float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + for (int64_t ir = ir10; ir < ir11; ++ir) { + ne_vec_dot_f16(ne00, &dst_col[ir], (ne_fp16_t*)(src0_row + ir * nb01), (ne_fp16_t*)src1_col); + } + } + } +} + +static void ne_compute_forward_mul_mat_id_q_f32_bestla(const struct ne_compute_params* params, + const struct ne_tensor* ids, const struct ne_tensor* src1, + struct ne_tensor* dst) { + int64_t t0 = ne_perf_time_us(); + UNUSED(t0); + const struct ne_tensor* src0 = dst->opt[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const size_t nb00 = src0->nb[0]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb10 = src1->nb[0]; + + const size_t nb11 = src1->nb[1]; + UNUSED(nb11); + const size_t nb12 = src1->nb[2]; + UNUSED(nb12); + const size_t nb13 = src1->nb[3]; + UNUSED(nb13); + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); + + const enum ne_type type = src0->type; + quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; + vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum ne_type const vec_dot_type = quantize_fns[type].vec_dot_type; + // we don't support permuted src0 or src1 + NE_ASSERT(nb00 == 
(int)NE_TYPE_SIZE[type]); + NE_ASSERT(nb10 == sizeof(float)); + // dst cannot be transposed or permuted + NE_ASSERT(nb0 == sizeof(float)); + NE_ASSERT(nb0 <= nb1); + NE_ASSERT(nb1 <= nb2); + NE_ASSERT(nb2 <= nb3); + const int id = dst->op_params[0]; + const int n_as = dst->op_params[1]; + // char * wdata_src1_end = (char *)params->wdata; + // int64_t wdata_src1_end = 0; + int64_t matrix_row_counts[100]; // [n_as] + int64_t matrix_rows[30000]; // [n_as][ne11] +#define mmid_matrix_row(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (params->type == NE_TASK_INIT) { + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + memset(matrix_rows, -1, 30000 * sizeof(int64_t)); + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t*)((const char*)ids->data + i01 * ids->nb[1] + id * ids->nb[0]); + NE_ASSERT(row_id >= 0 && row_id < n_as); + mmid_matrix_row(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + // assert(nb10 / 2 == sizeof(ne_fp16_t)); + const struct ne_tensor* src0_cur = dst->opt[cur_a]; + // parallelize by src0 rows + + // src1 rows + const int64_t nr1 = cne1 * ne12 * ne13; + const size_t row_size = ne10 * NE_TYPE_SIZE[src1->type]; + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * cne1)); + const int64_t i12 = (ir1 - i13 * ne12 * cne1) / cne1; + const int64_t _i11 = (ir1 - i13 * ne12 * cne1 - i12 * cne1); + const int64_t i11 = mmid_matrix_row(cur_a, _i11); + if (i11 == -1) { + continue; + } + + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + // Hack for "Falcon multi-query-attention key stutter" / alternative to ne_repeat2. 
+ // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: + const int64_t i02 = (i12 / (ne12 / ne02)); + // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) + // const int64_t i02 = (ir0 - i03*ne02); + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + char* src0_row = (char*)src0_cur->data; + char* src1_col = (char*)src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + + float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + // parallelize by src0 rows + + bestla_f32f32_forward((float*)src1_col, (float*)src0_row, dst_col, 1, ne0, ne10, nb11 / ne_element_size(src1), + nb1 / ne_element_size(dst), params->wdata); + } + } +} +static void ne_compute_forward_mul_mat_id(const struct ne_compute_params* params, const struct ne_tensor* ids, + const struct ne_tensor* src1, struct ne_tensor* dst) { + switch (dst->opt[0]->type) { + case NE_TYPE_Q4_0: + case NE_TYPE_Q4_1: + case NE_TYPE_Q5_0: + case NE_TYPE_Q5_1: + case NE_TYPE_Q8_0: + case NE_TYPE_Q6_K: + case NE_TYPE_Q8_1: { + ne_compute_forward_mul_mat_id_q_f32(params, ids, src1, dst); + } break; + case NE_TYPE_BTLA: { + ne_compute_forward_mul_mat_id_q_f32_bestla(params, ids, src1, dst); + } break; + case NE_TYPE_F16: { + ne_compute_forward_mul_mat_id_f16_f32(params, ids, src1, dst); + } break; + case NE_TYPE_F32: { + ne_compute_forward_mul_mat_id_f32(params, ids, src1, dst); + } break; + default: { + NE_ASSERT(false); + } break; + } +} + static void ne_compute_forward_mul_mat_bias_q_f32_bestla(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, const struct ne_tensor* bias, struct ne_tensor* dst) { @@ -6939,7 +7650,29 @@ static void ne_compute_forward_mul_qkv(const struct ne_compute_params* params, c bestla_fusion_QKV_f32f32_forward((float*)src->data, qw->data, kw->data, vw->data, (float*)dst->data, m, n, k, k, n, params->wdata); } +static void ne_compute_forward_ffn_id_silu(const struct ne_compute_params* params, const struct ne_tensor* src, + const struct ne_tensor* ids, const struct ne_tensor* tmp, + struct ne_tensor* tmp1, struct ne_tensor* dst) { + const int id = dst->op_params[0]; + if (params->type == NE_TASK_INIT) { + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + const int32_t row_id = *(const int32_t*)((const char*)ids->data + id * ids->nb[0]); + const struct ne_tensor* w1 = dst->opt[row_id]; + const struct ne_tensor* w2 = dst->opt[row_id + 8]; + const struct ne_tensor* w3 = dst->opt[row_id + 16]; + const int fin = src->ne[0]; + const int fout = dst->ne[0]; + const int fmid = w1->ne[1]; + const int seq = dst->ne[1]; + bestla_fusion_FFN_SiLu_f32f32_forward((float*)src->data, w1->data, w2->data, w3->data, (float*)tmp->data, + (float*)tmp1->data, (float*)dst->data, seq, fin, fmid, fout, params->wdata); +} static void ne_compute_forward_ffn_silu(const struct ne_compute_params* params, const struct ne_tensor* src, const struct ne_tensor* w1, const struct ne_tensor* w2, struct ne_tensor* w3, const struct ne_tensor* tmp, struct ne_tensor* tmp1, struct ne_tensor* dst) { @@ -7212,12 +7945,37 @@ static void ne_compute_forward_get_rows_q(const struct ne_compute_params* params assert(dst->ne[0] == nc); assert(dst->ne[1] == nr); assert(src0->nb[0] == NE_TYPE_SIZE[type]); + assert(src0->ne[2] == src1->ne[1]); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + const int 
ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t*)src1->data)[i]; + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t*)((char*)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); - dequantize_row_q((const void*)((char*)src0->data + r * src0->nb[1]), (float*)((char*)dst->data + i * dst->nb[1]), - nc); + dequantize_row_q((const void*)((char*)src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03), + (float*)((char*)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc); + } + } } } @@ -7235,13 +7993,35 @@ static void ne_compute_forward_get_rows_f16(const struct ne_compute_params* para assert(dst->ne[0] == nc); assert(dst->ne[1] == nr); assert(src0->nb[0] == sizeof(ne_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t*)src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ne_fp16_t v = ((ne_fp16_t*)((char*)src0->data + r * src0->nb[1]))[j]; - ((float*)((char*)dst->data + i * dst->nb[1]))[j] = NE_FP16_TO_FP32(v); + assert(src0->ne[2] == src1->ne[1]); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t*)((char*)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + ne_fp16_to_fp32_row((const void*)((char*)src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03), + (float*)((char*)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc); + } } } } @@ -7258,13 +8038,38 @@ static void ne_compute_forward_get_rows_f32(const struct ne_compute_params* para const int nr = ne_nelements(src1); assert(dst->ne[0] == nc); - assert(dst->ne[1] == nr); + assert(ne_nrows(dst) == nr); + assert(src0->ne[2] == src1->ne[1]); assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t*)src1->data)[i]; - - ne_vec_cpy_f32(nc, (float*)((char*)dst->data + i * dst->nb[1]), (float*)((char*)src0->data + r * src0->nb[1])); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + const int nb10 = src1->nb[0]; + const int nb11 
= src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t*)((char*)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + ne_vec_cpy_f32(nc, (float*)((char*)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), + (float*)((char*)src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03)); + } + } } } @@ -9487,9 +10292,18 @@ static void ne_compute_forward(struct ne_compute_params* params, struct ne_tenso case NE_OP_MUL_MAT: { ne_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); } break; + case NE_OP_MUL_MAT_ID: { + ne_compute_forward_mul_mat_id(params, tensor->src0, tensor->src1, tensor); + } break; + case NE_OP_ARGSORT: { + ne_compute_forward_argsort(params, tensor->src0, tensor); + } break; case NE_OP_MUL_QKV: { ne_compute_forward_mul_qkv(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor); } break; + case NE_OP_MUL_ID_FFN_SILU: { + ne_compute_forward_ffn_id_silu(params, tensor->src0, tensor->src1, tensor->opt[24], tensor->opt[25], tensor); + } break; case NE_OP_MUL_FFN_SILU: { ne_compute_forward_ffn_silu(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor->opt[3], tensor); @@ -10466,14 +11280,18 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { work_size = MAX(work_size, cur); } break; case NE_OP_SUB: + case NE_OP_SUM: case NE_OP_DIV: + case NE_OP_SUM_ROWS: + // { + // node->n_tasks = 1; + // } break; case NE_OP_SQR: case NE_OP_SQRT: case NE_OP_LOG: - case NE_OP_SUM: - case NE_OP_SUM_ROWS: case NE_OP_MEAN: case NE_OP_ABS: + case NE_OP_ARGSORT: case NE_OP_SGN: case NE_OP_NEG: case NE_OP_STEP: @@ -10504,6 +11322,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { node->n_tasks = n_threads; } break; case NE_OP_MUL_MAT_BIAS: + case NE_OP_MUL_MAT_ID: case NE_OP_CONV_1D: case NE_OP_MUL_MAT: { node->n_tasks = n_threads; @@ -10516,17 +11335,20 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { // printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); size_t cur = 0; - if (node->src0->type == NE_TYPE_BTLA) { - cur = bestla_f32f32_get_workspace_size(node->src1->ne[1], node->src0->ne[1], node->src1->ne[0], - node->src0->data); + struct ne_tensor* wei = node->src0; + if (node->op == NE_OP_MUL_MAT_ID) { + wei = node->opt[0]; + } + if (wei->type == NE_TYPE_BTLA) { + cur = bestla_f32f32_get_workspace_size(node->src1->ne[1], wei->ne[1], node->src1->ne[0], wei->data); node->n_tasks = 1; - } else if (node->src0->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { + } else if (wei->type == NE_TYPE_F16 && node->src1->type == NE_TYPE_F32) { cur = NE_TYPE_SIZE[NE_TYPE_F16] * ne_nelements(node->src1); - } else if (node->src0->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { + } else if (wei->type == NE_TYPE_F32 && node->src1->type == NE_TYPE_F32) { cur = 0; - } else if (ne_is_quantized(node->src0->type) && node->src1->type == NE_TYPE_F32) { + } else if (ne_is_quantized(wei->type) && node->src1->type == NE_TYPE_F32) { { - const enum ne_type type_q = quantize_fns[node->src0->type].vec_dot_type; + const enum ne_type type_q = quantize_fns[wei->type].vec_dot_type; cur = NE_TYPE_SIZE[type_q] * 
ne_nelements(node->src1) / NE_BLCK_SIZE[type_q]; } } else { @@ -10544,6 +11366,14 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { work_size = MAX(work_size, cur); node->n_tasks = 1; } break; + case NE_OP_MUL_ID_FFN_SILU: { + size_t cur = 0; + cur = + bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->opt[0]->ne[1], + node->opt[9]->ne[1], node->opt[0]->data, node->opt[9]->data); + work_size = MAX(work_size, cur); + node->n_tasks = 1; + } break; case NE_OP_MUL_QKV: { size_t cur = 0; cur = bestla_fusion_QKV_f32f32_get_workspace_size(node->src0->ne[1], node->src1->ne[1], node->src1->ne[0], @@ -10909,12 +11739,16 @@ void ne_graph_profiling(const struct ne_cgraph* cgraph) { NE_PRINT("=== GRAPH Profiling ===\n"); int64_t ip_duration = 0; + int64_t mul_mat_id_duration = 0; for (int i = 0; i < cgraph->n_nodes; i++) { struct ne_tensor* node = cgraph->nodes[i]; if (node->op == NE_OP_MUL_MAT && node->ne[1] == node->ne[2]) { ip_duration += node->perf_time_us; } else { perf_total_per_op_us[node->op] += node->perf_time_us; + if (node->op == NE_OP_MUL_MAT_ID) { + mul_mat_id_duration += node->perf_time_us; + } } } @@ -10925,6 +11759,7 @@ void ne_graph_profiling(const struct ne_cgraph* cgraph) { NE_PRINT("perf_total_per_op_us[%24s] = %7.3f ms\n", NE_OP_LABEL[i], (double)perf_total_per_op_us[i] / 1000.0); } NE_PRINT("perf_total_per_op_us[%24s] = %7.3f ms\n", "INNER PRODUCT", (double)ip_duration / 1000.0); + NE_PRINT("perf_total_per_op_us[%24s] = %7.3f ms\n", "MUL_MAT_ID", (double)mul_mat_id_duration / 1000.0); NE_PRINT("========================================\n"); #else diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h index c8332a6e8..21cd48d44 100644 --- a/neural_speed/core/ne_layers.h +++ b/neural_speed/core/ne_layers.h @@ -254,9 +254,17 @@ NE_API struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tens // result is m columns, p rows NE_API struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b); +NE_API struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const as[], int n_as, + struct ne_tensor* ids, int id, struct ne_tensor* b); +NE_API struct ne_tensor* ne_mul_id_ffn_silu(struct ne_context* ctx, struct ne_tensor* const down[], + struct ne_tensor* const gate[], struct ne_tensor* const up[], int n_as, + struct ne_tensor* ids, int id, struct ne_tensor* b); + NE_API struct ne_tensor* ne_mul_mat_with_bias(struct ne_context* ctx, struct ne_tensor* w, struct ne_tensor* b, struct ne_tensor* a); +NE_API struct ne_tensor* ne_argsort(struct ne_context* ctx, struct ne_tensor* a); +NE_API struct ne_tensor* ne_top_k(struct ne_context* ctx, struct ne_tensor* a, int k); // merged Q K V ne_mul_mat NE_API struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw, struct ne_tensor* kw, struct ne_tensor* vw, struct ne_tensor* src); diff --git a/neural_speed/models/llama/llama.cpp b/neural_speed/models/llama/llama.cpp index fad3d6e2d..41aedf08d 100644 --- a/neural_speed/models/llama/llama.cpp +++ b/neural_speed/models/llama/llama.cpp @@ -88,6 +88,8 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp int n_head = hparams.n_head; int head_size = n_embd / n_head; int n_head_kv = hparams.n_head_kv; + int n_expert = hparams.n_experts; + int n_expert_used = hparams.n_experts_used; bool enable_tp = false; #ifdef NS_TP_MODEL @@ -147,6 +149,7 @@ static bool llama_model_eval_internal(model_context* ctx, const 
model_input* inp struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, N, NE_SIZE_CALC); ne_set_name(embd, "embd"); + for (int i = 0; i < batch_size; ++i) { memcpy(static_cast<model_token*>(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); } @@ -351,17 +354,70 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp // cur = cur*ffn_norm(broadcasted) cur = ne_mul(ctx0, cur, model.layers[il].norm[1]); } - - if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data, - model.layers[il].ffn[2]->data, N, cur->ne[0], - model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) { - cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], cur); + if (n_expert == 0) { + if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data, + model.layers[il].ffn[2]->data, N, cur->ne[0], + model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) { + cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], cur); + } else { + struct ne_tensor* tmp = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur); + cur = ne_mul_mat(ctx0, model.layers[il].ffn[0], cur); + cur = ne_silu(ctx0, cur); + cur = ne_mul(ctx0, cur, tmp); + cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur); + } } else { - struct ne_tensor* tmp = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur); - cur = ne_mul_mat(ctx0, model.layers[il].ffn[0], cur); - cur = ne_silu(ctx0, cur); - cur = ne_mul(ctx0, cur, tmp); - cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur); + ne_tensor* logits = ne_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + ne_tensor* probs = ne_soft_max_inplace(ctx0, logits); + ne_tensor* selected_experts = ne_top_k(ctx0, probs, n_expert_used); + ne_tensor* weights = ne_get_rows(ctx0, ne_reshape_3d(ctx0, probs, 1, n_expert, N), selected_experts); + weights = ne_reshape_2d(ctx0, weights, n_expert_used, N); + ne_tensor* weights_sum = ne_sum_rows(ctx0, weights); + weights_sum = ne_repeat(ctx0, weights_sum, weights); + weights = ne_div(ctx0, weights, weights_sum); + ne_tensor* moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ne_tensor* cur_expert; + if (N == 1 && bestla_fusion_FFN_SiLu_f32f32_support( + model.layers[il].ffn_gate_exp[0]->data, model.layers[il].ffn_down_exp[0]->data, + model.layers[il].ffn_up_exp[0]->data, N, cur->ne[0], + model.layers[il].ffn_gate_exp[0]->ne[1], model.layers[il].ffn_down_exp[0]->ne[1])) { + cur_expert = ne_mul_id_ffn_silu(ctx0, model.layers[il].ffn_down_exp, model.layers[il].ffn_gate_exp, + model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + } else { + ne_tensor* cur_up = ne_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + ne_set_name(cur_up, "ffn_moe_up"); + + ne_tensor* cur_gate = + ne_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); + ne_set_name(cur_gate, "ffn_moe_gate"); + + cur_gate = ne_silu(ctx0, cur_gate); + ne_set_name(cur_gate, "ffn_moe_silu"); + + cur_expert = ne_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + ne_set_name(cur_expert, "ffn_moe_gate_par"); + + cur_expert = ne_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, + cur_expert); // [n_tokens, n_embd] + ne_set_name(cur_expert, "ffn_moe_down"); + } + + cur_expert = + ne_mul(ctx0, cur_expert, + ne_repeat(ctx0, ne_view_2d(ctx0, weights, 1, N,
weights->nb[1], i * weights->nb[0]), cur_expert)); + ne_set_name(cur_expert, "ffn_moe_weighted"); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ne_add(ctx0, moe_out, cur_expert); + ne_set_name(moe_out, "ffn_moe_out"); + } + } + + cur = moe_out; } #ifdef NS_TP_MODEL // ffn2 and ffn0 use split row, ffn1 use split column @@ -424,7 +480,6 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp sizeof(float) * n_vocab); } } - // extract embeddings if (!lctx.embedding.empty()) { auto& embedding_out = lctx.embedding; diff --git a/neural_speed/models/llama/llama.h b/neural_speed/models/llama/llama.h index e498de254..2cf7bdd08 100644 --- a/neural_speed/models/llama/llama.h +++ b/neural_speed/models/llama/llama.h @@ -47,7 +47,7 @@ class Llama : public IModel { private: model_archs arch = MODEL_LLAMA; std::unique_ptr<model_model_loader> ml; - uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv; + uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv, n_expert, n_expert_used; int n_gpu_layer; bool use_mmap, use_mlock, vocab_only; model_scratch scratch; diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp index fd8fe065b..128f249a9 100644 --- a/neural_speed/models/llama/llama_utils.cpp +++ b/neural_speed/models/llama/llama_utils.cpp @@ -79,6 +79,8 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b n_layer = hparams.n_layer; n_head_kv = hparams.n_head_kv; n_head = hparams.n_head; + n_expert = hparams.n_experts; + n_expert_used = hparams.n_experts_used; scratch = llama_mem_req(n_layer); model.scratchs = scratch; } @@ -140,9 +142,25 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); // ffn GEMM - layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend); - layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend); - layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend); + if (ml->verify_tensor(layers_i + ".ffn_gate.weight")) { + NE_ASSERT(n_expert == 0); + NE_ASSERT(n_expert_used == 0); + layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend); + layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend); + layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend); + } else { + NE_ASSERT(n_expert > 0); + NE_ASSERT(n_expert_used > 0); + layer.ffn_gate_inp = ml->get_tensor(layers_i + ".ffn_gate_inp.weight", {n_embd, n_expert}, backend); + for (uint32_t x = 0; x < n_expert; ++x) { + layer.ffn_gate_exp[x] = + ml->get_tensor(layers_i + ".ffn_gate." + std::to_string(x) + ".weight", {n_embd, n_ff}, backend); + layer.ffn_down_exp[x] = + ml->get_tensor(layers_i + ".ffn_down." + std::to_string(x) + ".weight", {n_ff, n_embd}, backend); + layer.ffn_up_exp[x] = + ml->get_tensor(layers_i + ".ffn_up."
+ std::to_string(x) + ".weight", {n_embd, n_ff}, backend); + } + } if (backend != NE_BACKEND_CPU) { vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + @@ -176,10 +194,26 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); // ffn GEMM - layer.ffn[0] = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend); - layer.ffn[1] = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {n_ff, n_embd}, backend); - layer.ffn[2] = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); + if (ml->verify_tensor(layers_i + ".feed_forward.w1.weight")) { + NE_ASSERT(n_expert == 0); + NE_ASSERT(n_expert_used == 0); + layer.ffn[0] = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend); + layer.ffn[1] = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {n_ff, n_embd}, backend); + layer.ffn[2] = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); + } else { + NE_ASSERT(n_expert > 0); + NE_ASSERT(n_expert_used > 0); + layer.ffn_gate_inp = ml->get_tensor(layers_i + ".ffn_gate_inp.weight", {n_embd, n_expert}, backend); + for (uint32_t x = 0; x < n_expert; ++x) { + layer.ffn_gate_exp[x] = + ml->get_tensor(layers_i + ".ffn_gate." + std::to_string(x) + ".weight", {n_embd, n_ff}, backend); + layer.ffn_down_exp[x] = + ml->get_tensor(layers_i + ".ffn_down." + std::to_string(x) + ".weight", {n_ff, n_embd}, backend); + layer.ffn_up_exp[x] = + ml->get_tensor(layers_i + ".ffn_up." + std::to_string(x) + ".weight", {n_embd, n_ff}, backend); + } + } if (backend != NE_BACKEND_CPU) { vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) + diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h index 0018ec7d3..251280fa3 100644 --- a/neural_speed/models/model_utils/gguf.h +++ b/neural_speed/models/model_utils/gguf.h @@ -423,6 +423,8 @@ enum llm_kv { LLM_KV_ATTENTION_CLAMP_KQV, LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_NUM_EXPERTS, + LLM_KV_NUM_EXPERTS_USED, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -466,6 +468,8 @@ static std::map LLM_KV_NAMES = { {LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length"}, {LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual"}, {LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout"}, + {LLM_KV_NUM_EXPERTS, "%s.expert_count"}, + {LLM_KV_NUM_EXPERTS_USED, "%s.expert_used_count"}, {LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count"}, {LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv"}, diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index 7c6ed97ec..3813a0a09 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -908,6 +908,9 @@ struct gguf_loader { GGUF_GET_KEY(ctx_gguf, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT)); GGUF_GET_KEY(ctx_gguf, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); + GGUF_GET_KEY(ctx_gguf, hparams.n_experts, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_NUM_EXPERTS)); + GGUF_GET_KEY(ctx_gguf, hparams.n_experts_used, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + kv(LLM_KV_NUM_EXPERTS_USED)); 
GGUF_GET_KEY(ctx_gguf, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_BLOCK_COUNT)); GGUF_GET_KEY(ctx_gguf, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); @@ -1095,6 +1098,8 @@ struct model_file_loader { // For ChatGLM-2 hparams.inner_hidden_size = file.read_u32(); + hparams.n_experts = file.read_u32(); + hparams.n_experts_used = file.read_u32(); file.read_raw(&hparams.rms_norm_eps, sizeof(float)); file.read_raw(&hparams.freq_base, sizeof(float)); @@ -1219,6 +1224,8 @@ struct model_file_saver { file.write_u32(hparams.multi_query_group_num); file.write_u32(hparams.ffn_hidden_size); file.write_u32(hparams.inner_hidden_size); + file.write_u32(hparams.n_experts); + file.write_u32(hparams.n_experts_used); file.write_raw(&hparams.rms_norm_eps, sizeof(float)); file.write_raw(&hparams.freq_base, sizeof(float)); diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 33e7df888..8f5bc43f1 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -45,6 +45,7 @@ #define MODEL_MAX_ATTN 8 #define MODEL_MAX_FFN 6 #define MODEL_MAX_OTHERS 7 +#define MODEL_MAX_EXPERTS 8 #define MODEL_USE_SCRATCH #define MODEL_MAX_SCRATCH_BUFFERS 16 @@ -139,6 +140,9 @@ struct model_hparams { // ChatGLM-1 int32_t inner_hidden_size = 0; + uint32_t n_experts = 0; + uint32_t n_experts_used = 0; + float rope_scaling_factor = 0.0f; int32_t original_max_position_embeddings = 0; int32_t use_yarn = 0; @@ -158,6 +162,14 @@ struct model_layer { // ff struct ne_tensor* ffn[MODEL_MAX_FFN]; + struct ne_tensor* ffn_gate_inp; + + struct ne_tensor* ffn_gate_exp[MODEL_MAX_EXPERTS]; + + struct ne_tensor* ffn_down_exp[MODEL_MAX_EXPERTS]; + + struct ne_tensor* ffn_up_exp[MODEL_MAX_EXPERTS]; + struct ne_tensor* k_cache; struct ne_tensor* v_cache; @@ -471,7 +483,8 @@ class model_name_to_arch { {"dolly", MODEL_GPTNEOX}, {"polyglot", MODEL_GPTNEOX}, {"starcoder", MODEL_STARCODER}, {"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2}, {"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA}, - {"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"whisper", MODEL_WHISPER}}; + {"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"whisper", MODEL_WHISPER}, + {"mixtral", MODEL_LLAMA}}; }; #ifdef __cplusplus diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh index 973c09892..500fb37fd 100644 --- a/tests/model-test/cpp_graph_inference.sh +++ b/tests/model-test/cpp_graph_inference.sh @@ -155,6 +155,7 @@ model_name_map["qwen-7b"]="Qwen/Qwen-7B-Chat" model_name_map["magicoder"]="ise-uiuc/Magicoder-S-DS-6.7B" model_name_map["whisper"]="openai/whisper-tiny" model_name_map["phi2"]="microsoft/phi-2" +model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1" function main() { conda_env="$1" @@ -263,6 +264,10 @@ function main() { quant_script="./build/bin/quant_phi" convert_script="${convert_script}/convert_phi.py" infer_cmd="./build/bin/run_phi" + elif [[ "${model}" == "mixtral" ]]; then + quant_script="./build/bin/quant_mixtral" + convert_script="${convert_script}/convert_mixtral.py" + infer_cmd="./build/bin/run_mixtral" else echo "Error: Unexpedted model: $model" 1>&2 exit 1
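The router that the llama.cpp hunk above builds out of ne_soft_max_inplace, ne_top_k, ne_get_rows, ne_sum_rows and ne_div computes, per token, a softmax over the gating logits, keeps the n_expert_used largest probabilities, renormalizes them so they sum to one, and then mixes the selected experts' FFN outputs with those weights. The following is a minimal standalone C++ sketch of that arithmetic on plain vectors; it is not part of the patch and does not use the ne_* API, and names such as Routing, route_token and the sample logits are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// Indices and renormalized weights of the experts chosen for one token.
struct Routing {
  std::vector<int> experts;
  std::vector<float> weights;
};

static Routing route_token(const std::vector<float>& gate_logits, int n_expert_used) {
  const int n_expert = static_cast<int>(gate_logits.size());

  // softmax over the gating logits (ne_soft_max_inplace in the graph)
  const float max_logit = *std::max_element(gate_logits.begin(), gate_logits.end());
  std::vector<float> probs(n_expert);
  float sum = 0.0f;
  for (int e = 0; e < n_expert; ++e) {
    probs[e] = std::exp(gate_logits[e] - max_logit);
    sum += probs[e];
  }
  for (float& p : probs) p /= sum;

  // top-k expert indices by probability (ne_top_k built on ne_argsort)
  std::vector<int> idx(n_expert);
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  idx.resize(n_expert_used);

  // renormalize the selected weights so they sum to 1 (ne_sum_rows + ne_div)
  float selected_sum = 0.0f;
  for (int e : idx) selected_sum += probs[e];
  Routing r;
  r.experts = idx;
  for (int e : idx) r.weights.push_back(probs[e] / selected_sum);
  return r;
}

int main() {
  // 8 experts, 2 used per token (a Mixtral-style configuration)
  const std::vector<float> gate_logits = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, -0.3f, 0.0f, 0.8f};
  const Routing r = route_token(gate_logits, 2);
  for (std::size_t i = 0; i < r.experts.size(); ++i) {
    std::printf("expert %d weight %.3f\n", r.experts[i], r.weights[i]);
  }
  return 0;
}

Per token, the layer output is then the weighted sum of the selected experts' SiLU-gated FFN outputs, which is what the per-expert ne_mul and ne_add accumulation into moe_out computes in the patch.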