This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Parallel loading of the model tensors #79

Open
philpax opened this issue Mar 26, 2023 · 5 comments
Labels
issue:enhancement New feature or request

Comments

@philpax
Collaborator

philpax commented Mar 26, 2023

People have reported faster model loading upstream when the tensors are loaded in parallel: ggerganov/llama.cpp#85

This should be pretty easy to do in Rust if we convert loading to an iterator and then use Rayon's par_iter instead. It seems like this should be I/O-bound, but perhaps the actual loading process has some computational overhead?
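A minimal sketch of the idea, for illustration only. The comment suggests Rayon's `par_iter`; to keep this dependency-free it uses `std::thread::scope` instead. `TensorInfo` and `load_tensors_parallel` are hypothetical names, and copying byte ranges out of an in-memory buffer stands in for the real per-tensor loading work.

```rust
use std::thread;

/// Hypothetical descriptor: where one tensor's bytes live in the model file.
pub struct TensorInfo {
    pub name: String,
    pub offset: usize,
    pub len: usize,
}

/// Copies each tensor's bytes out of `file`, splitting the tensor list across
/// `n_threads` scoped threads. Output order matches `infos`.
/// Panics if a tensor's range lies outside `file` (a real loader would return an error).
pub fn load_tensors_parallel(
    file: &[u8],
    infos: &[TensorInfo],
    n_threads: usize,
) -> Vec<(String, Vec<u8>)> {
    let n = n_threads.max(1);
    // Ceiling division; `max(1)` keeps it nonzero for an empty list, since `chunks(0)` panics.
    let chunk_size = ((infos.len() + n - 1) / n).max(1);
    thread::scope(|s| {
        // Collecting forces every worker to be spawned before any is joined.
        let handles: Vec<_> = infos
            .chunks(chunk_size)
            .map(|chunk| {
                s.spawn(move || {
                    chunk
                        .iter()
                        .map(|t| (t.name.clone(), file[t.offset..t.offset + t.len].to_vec()))
                        .collect::<Vec<_>>()
                })
            })
            .collect();
        // Join in spawn order so the output order matches `infos`.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("loader thread panicked"))
            .collect()
    })
}
```

With Rayon, the body would shrink to roughly `infos.par_iter().map(|t| ...).collect()`, and collecting an indexed parallel iterator into a `Vec` also preserves order. Either way, it only pays off if loading is not purely I/O-bound, which is the open question above.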

philpax added the issue:enhancement (New feature or request) label on Mar 26, 2023
philpax mentioned this issue on Mar 26, 2023
@KerfuffleV2
Contributor

Somewhat related to speeding up loading: I've been experimenting with rewriting the loader to use an mmap-based approach and nom. I don't know if it's really on the right track.

This is what loading just the header and vocabulary looks like:

```rust
pub mod mmap_loader {
    use mmap_rs::{MmapFlags, MmapOptions};
    #[allow(unused_imports)]
    use nom::{
        branch::alt,
        bytes::complete as nby,
        combinator as ncom,
        error::ParseError,
        multi as nm,
        number::complete::{self as nnum, le_f32, le_i32, le_u32},
        sequence as nseq, IResult, Parser, Slice,
    };
    use std::fs::File;

    use super::*;

    pub struct Flib;

    /// File header: which format flavor this is, plus the model hyperparameters.
    #[derive(Debug)]
    struct Header {
        /// `true` for the unversioned format, whose vocabulary entries carry no score.
        legacy: bool,
        hyper: Hyperparameters,
    }

    impl Flib {
        /// Reads the magic number, then the hyperparameters that follow it.
        fn parse_header(i: &[u8]) -> IResult<&[u8], Header> {
            let (i, magic) = le_i32(i)?;
            let legacy = match magic {
                ggml::FILE_MAGIC => false,
                ggml::FILE_MAGIC_UNVERSIONED => true,
                _ => return nom::error::context("unrecognized file magic", ncom::fail)(i),
            };
            ncom::map(Flib::parse_hyperparameters, move |hyper| Header {
                legacy,
                hyper,
            })(i)
        }

        /// Seven little-endian `i32` fields in file order. `n_ctx` is not read
        /// from the file here, so it is set to 0.
        fn parse_hyperparameters(i: &[u8]) -> IResult<&[u8], Hyperparameters> {
            ncom::map(
                nseq::tuple((le_i32, le_i32, le_i32, le_i32, le_i32, le_i32, le_i32)),
                |(n_vocab, n_embd, n_mult, n_head, n_layer, n_rot, f16_)| Hyperparameters {
                    n_vocab,
                    n_ctx: 0,
                    n_embd,
                    n_mult,
                    n_head,
                    n_layer,
                    n_rot,
                    f16_,
                },
            )(i)
        }

        /// Each vocabulary entry is a `u32` length, that many token bytes, and
        /// (non-legacy files only) an `f32` score.
        fn parse_vocabulary<'a>(i: &'a [u8], hdr: &Header) -> IResult<&'a [u8], Vocabulary> {
            const TOKEN_PLACEHOLDER: &str = "�";
            let n_vocab = hdr.hyper.n_vocab as usize;
            let legacy = hdr.legacy;
            let mut id_to_token = Vec::with_capacity(n_vocab);
            let mut id_to_token_score = Vec::with_capacity(n_vocab);
            let mut token_to_id = HashMap::with_capacity(n_vocab);
            // Parses one (token bytes, score) pair; legacy files default the score to 0.
            let vocabitem_parser = |i| {
                nseq::tuple((nm::length_data(le_u32), ncom::cond(!legacy, le_f32)))(i)
                    .map(|(i, (sbytes, score))| (i, (sbytes, score.unwrap_or_default())))
            };
            // Stores one entry and tracks the longest valid token seen so far.
            // Tokens that aren't valid UTF-8 get a placeholder and no reverse mapping.
            let add_token = |mut max_len: usize, (sbytes, score)| {
                let tid = id_to_token.len();
                let (ok, token) = std::str::from_utf8(sbytes).map_or_else(
                    |_| (false, TOKEN_PLACEHOLDER.to_string()),
                    |s| (true, s.to_string()),
                );
                if ok {
                    max_len = max_len.max(token.len());
                    token_to_id.insert(token.clone(), tid as TokenId);
                }
                id_to_token.push(token);
                id_to_token_score.push(score);
                max_len
            };
            // Exactly `n_vocab` entries, folded into the tables above.
            let (i, max_token_length) =
                nm::fold_many_m_n(n_vocab, n_vocab, vocabitem_parser, || 0, add_token)(i)?;
            Ok((
                i,
                Vocabulary {
                    id_to_token,
                    id_to_token_score,
                    token_to_id,
                    max_token_length,
                },
            ))
        }

        pub fn load(path: impl AsRef<Path>) -> Result<(), LoadError> {
            let path = path.as_ref();
            let fp = File::open(path).map_err(|e| LoadError::OpenFileFailed {
                source: e,
                path: path.to_owned(),
            })?;
            let flen = fp.metadata()?.len();
            // Map the whole file read-only; the parsers then work directly on the mapped bytes.
            let m = unsafe {
                MmapOptions::new(flen as usize).and_then(|mo| {
                    mo.with_file(fp, 0)
                        .with_flags(MmapFlags::NO_CORE_DUMP)
                        .map()
                })
            }
            .map_err(|e| LoadError::MmapFailed { source: e })?;
            let mb = m.as_slice();
            // Parse errors still panic here; a real loader would map them into `LoadError`.
            let (i, hdr) = Self::parse_header(mb).expect("failed to parse header");
            println!("Got: {hdr:?}");
            // Tensor data would be parsed next, starting from `_rest`.
            let (_rest, vocab) =
                Self::parse_vocabulary(i, &hdr).expect("failed to parse vocabulary");
            println!(
                "Got: {} - {} - {}",
                vocab.max_token_length,
                vocab.id_to_token.len(),
                vocab.token_to_id.len()
            );
            Ok(())
        }
    }
}
```

I honestly don't love writing parsers in Rust; it's so much nicer in Haskell. Still, I guess this is more readable than the current code. A while ago I experimented with combining nom with monadic do-style notation, but it wasn't really practical: https://github.com/KerfuffleV2/mdoexperiments
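For comparison with the nom-based `parse_vocabulary` above, here is a dependency-free sketch of the same vocabulary layout: each entry is a little-endian `u32` length, that many token bytes, and (outside the legacy format) a little-endian `f32` score that otherwise defaults to 0. `take`, `read_u32`, `read_f32`, and `parse_vocab` are hypothetical helpers. Token bytes are borrowed from the input slice, so nothing is copied when that slice comes from an mmap; UTF-8 validation and the `token_to_id` map are left out to keep it short.

```rust
/// Borrows `n` bytes at `*pos` and advances the cursor; `None` if the input is too short.
fn take<'a>(buf: &'a [u8], pos: &mut usize, n: usize) -> Option<&'a [u8]> {
    let bytes = buf.get(*pos..pos.checked_add(n)?)?;
    *pos += n;
    Some(bytes)
}

fn read_u32(buf: &[u8], pos: &mut usize) -> Option<u32> {
    let b = take(buf, pos, 4)?;
    Some(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
}

fn read_f32(buf: &[u8], pos: &mut usize) -> Option<f32> {
    let b = take(buf, pos, 4)?;
    Some(f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
}

/// Parses `n_vocab` entries from the start of `buf`. Returns the (token bytes, score)
/// pairs plus the offset where the vocabulary ends, so the caller can continue from there.
fn parse_vocab(buf: &[u8], n_vocab: usize, legacy: bool) -> Option<(Vec<(&[u8], f32)>, usize)> {
    let mut pos = 0;
    let mut items = Vec::with_capacity(n_vocab);
    for _ in 0..n_vocab {
        let len = read_u32(buf, &mut pos)? as usize;
        let token = take(buf, &mut pos, len)?;
        // Legacy (unversioned) files store no score.
        let score = if legacy { 0.0 } else { read_f32(buf, &mut pos)? };
        items.push((token, score));
    }
    Some((items, pos))
}
```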

@philpax
Collaborator Author

philpax commented Apr 2, 2023

Along the lines of parser-based loading, it might also be interesting to explore https://github.com/jam1garner/binrw.

Not sure how that would impact parallel loading or #93, though.

@KerfuffleV2
Contributor

Interesting. Oddly enough, it only has limited support for non-streams (i.e. parsing directly from an mmapped buffer). I don't know whether handling the GGML format would require its seek features, but if so, that would rule out mmapping.

@iacore
Copy link
Contributor

iacore commented Apr 8, 2023

We don't really need mmap; smol + nuclei with two file descriptors should be enough.

@philpax
Collaborator Author

philpax commented Apr 24, 2023

Now that we have mmap support, I'm not sure how relevant this is. Loading doesn't do much actual work when setting up the tensors.
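A sketch of why that is, assuming the whole file is mapped into one byte slice: setting up a tensor then amounts to bounds-checking an offset and length and taking a subslice, with no copying, so there is little left to parallelize. `TensorInfo` and `tensor_views` are hypothetical names, and a plain `&[u8]` stands in for the mapped file.

```rust
/// Hypothetical descriptor: where one tensor's bytes live in the mapped file.
pub struct TensorInfo {
    pub name: &'static str,
    pub offset: usize,
    pub len: usize,
}

/// Builds a borrowed view of every tensor inside `mapped` without copying any data.
/// Returns `None` if any tensor's range falls outside the buffer.
pub fn tensor_views<'a>(
    mapped: &'a [u8],
    infos: &[TensorInfo],
) -> Option<Vec<(&'static str, &'a [u8])>> {
    infos
        .iter()
        .map(|t| {
            let end = t.offset.checked_add(t.len)?;
            // A subslice is just a pointer and a length into the existing mapping.
            Some((t.name, mapped.get(t.offset..end)?))
        })
        .collect()
}
```

Each view costs O(1) regardless of tensor size; the actual page reads happen lazily, when the data is first touched.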

3 participants