Cross GPU device mapping feature #395

Closed
joshpopelka20 opened this issue Jun 5, 2024 · 45 comments
Labels
backend (Backend work), models (Additions to model or architectures), new feature (New feature or request)

Comments

@joshpopelka20
Contributor

I'm working with a long-context model (gradientai/Llama-3-8B-Instruct-262k) that exceeds the memory of a single A100 GPU. The model weights load, but when I try to run inference I get a CUDA out-of-memory error.

Requesting a new feature to allow users to use cross GPU device mapping.

@EricLBuehler added the new feature, backend, and models labels on Jun 5, 2024
@b0xtch

b0xtch commented Jun 21, 2024

Related issue: huggingface/candle#2007

There was an attempt to do tensor parallelism:

#72

@EricLBuehler
Owner

Hi @joshpopelka20 and @b0xtch! I just merged #462 which adds cross-GPU device mapping support (including for Python). I plan on implementing tensor parallelism, too, in the future.

@joshpopelka20
Contributor Author

@EricLBuehler thanks for adding this feature.

When using the pypi package, I'm getting this error:

        = note: /usr/bin/ld: /tmp/pip-install-8y1wtzkm/mistralrs-cuda_36180abfef7b4f0687d7842e9b298d9e/target/release/build/mistralrs-core-f5a6ed7a31cbdb62/out/libmistralcuda.a(nonzero_bitwise-b50867152df76f01.o): relocation R_X86_64_32 against symbol `_Z17transform_indicesPKjjS0_jPj' can not be used when making a shared object; recompile with -fPIC
                /usr/bin/ld: final link failed: Nonrepresentable section on output
                collect2: error: ld returned 1 exit status

I'm passing the CUDA_NVCC_FLAGS flag so not sure why it's saying to "recompile with -fPIC". These are the commands I'm using:

import os, subprocess
env = dict(os.environ, CUDA_NVCC_FLAGS="-fPIE")
result = subprocess.run(["pip", "install", "mistralrs-cuda"], env=env)

@EricLBuehler
Owner

> I'm passing the CUDA_NVCC_FLAGS flag so not sure why it's saying to "recompile with -fPIC". These are the commands I'm using:

Can you try:

import os, subprocess
env = dict(os.environ, CUDA_NVCC_FLAGS="-fPIC")
result = subprocess.run(["pip", "install", "mistralrs-cuda"], env=env)

The -fPIC requirement may stem from your Linux distribution (some require -fPIE instead; I'll add this to the README).

@joshpopelka20gmail

Sorry, may not have been clear. I got that error when running with the flag set to '-fPIC'. I haven't tried running the code using the git repo and cargo. Do you want me to verify if that works?

I'd guess they'd be the same though.

@EricLBuehler
Owner

> Sorry, may not have been clear. I got that error when running with the flag set to '-fPIC'. I haven't tried running the code using the git repo and cargo. Do you want me to verify if that works?

Ah, ok. Not sure if we discussed this before, but what Linux distribution are you using?

@joshpopelka20
Contributor Author

I'm using a Sagemaker Jupyter notebook so Amazon Linux. This is the distro info:

NAME="Amazon Linux"
VERSION="2"
ID="amzn"
ID_LIKE="centos rhel fedora"
VERSION_ID="2"
PRETTY_NAME="Amazon Linux 2"
ANSI_COLOR="0;33"
CPE_NAME="cpe:2.3:o:amazon:amazon_linux:2"
HOME_URL="https://amazonlinux.com/"
SUPPORT_END="2025-06-30"
Amazon Linux release 2 (Karoo)

@joshpopelka20
Contributor Author

Adding some observations:

When I run CUDA_NVCC_FLAGS=-fPIE cargo build --release --features "cuda cudnn" and CUDA_NVCC_FLAGS=-fPIE cargo build --release --features "cuda flash-attn", I don't get the error.

When I run CUDA_NVCC_FLAGS=-fPIE cargo build --release --features "flash-attn cudnn" or CUDA_NVCC_FLAGS=-fPIE cargo build --release --features "cuda flash-attn cudnn", I get the same error:

= note: /usr/bin/ld: /home/ec2-user/mistral.rs/target/release/build/mistralrs-core-9c498c55121e0e87/out/libmistralcuda.a(nonzero_bitwise-b50867152df76f01.o): relocation R_X86_64_32 against symbol `_Z17transform_indicesPKjjS0_jPj' can not be used when making a shared object; recompile with -fPIC

Running CUDA_NVCC_FLAGS=-fPIE cargo build --release --features "cudnn" by itself is also successful.

So the build seems to fail only when cudnn and flash-attn are enabled together.
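To rerun that comparison systematically, a small driver script can build each feature set with the same CUDA_NVCC_FLAGS value. This is a minimal sketch. FEATURE_SETS copies the combinations listed above, and build_command/run_matrix are hypothetical helpers that are not part of mistral.rs. With dry_run=True, the script only prints the commands.

```python
import os
import subprocess

# The feature combinations tried above, in the same order.
FEATURE_SETS = [
    "cuda cudnn",
    "cuda flash-attn",
    "flash-attn cudnn",
    "cuda flash-attn cudnn",
    "cudnn",
]


def build_command(features, nvcc_flags="-fPIE"):
    """Return (env, argv) for one `cargo build --release --features ...` run."""
    env = dict(os.environ, CUDA_NVCC_FLAGS=nvcc_flags)
    argv = ["cargo", "build", "--release", "--features", features]
    return env, argv


def run_matrix(feature_sets=FEATURE_SETS, dry_run=True):
    """Build each feature set and record whether cargo exited successfully."""
    results = {}
    for features in feature_sets:
        env, argv = build_command(features)
        if dry_run:
            print("would run:", " ".join(argv), f"with CUDA_NVCC_FLAGS={env['CUDA_NVCC_FLAGS']}")
            continue
        results[features] = subprocess.run(argv, env=env).returncode == 0
    return results


if __name__ == "__main__":
    run_matrix(dry_run=True)
```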

@joshpopelka20
Contributor Author

I'm thinking the issue is related to mistralrs-core/build.rs file. I tried to add

.arg("--compiler-options")
.arg("-fPIC")

but it didn't help.

I think it's more of an issue with this line of code println!("cargo:rustc-link-lib=mistralcuda"). How do I find that C library and determine if it was compiled with the -fPIC flag?

@EricLBuehler
Owner

> I'm thinking the issue is related to mistralrs-core/build.rs file. I tried to add

I added support for the NVCC flag envvar to that, so it should be seamless to use the envvar instead of changing the code, now.

> I think it's more of an issue with this line of code println!("cargo:rustc-link-lib=mistralcuda"). How do I find that C library and determine if it was compiled with the -fPIC flag?

It's in whatever the OUT_DIR environment variable is. Perhaps you can panic! on it: panic!("{build_dir}");?

@joshpopelka20
Contributor Author

I added a pull request, #471, with a fix for the issue. It looks like it was a divide-by-zero issue.

I didn't add any error message; I just let it continue to run. I ran llama and there were no issues.

Also, I'm wondering if there should be other code added to the build.rs file. Like in the candle project:

    let target = std::env::var("TARGET").unwrap();
    if target.contains("msvc") {
        // nothing to link to
    } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") {
        println!("cargo:rustc-link-lib=dylib=c++");
    } else if target.contains("android") {
        println!("cargo:rustc-link-lib=dylib=c++_shared");
    } else {
        println!("cargo:rustc-link-lib=dylib=stdc++");
    }

I didn't have any issues with these, but someone else might be using those OSs in the future.

@EricLBuehler
Owner

> Also, I'm wondering if there should be other code added to the build.rs file. Like in the candle project:

I just merged #472 which adds this, thanks for pointing that out.

@joshpopelka20
Contributor Author

The layers are now distributed across my 4 A10G GPUs

[screenshot]

One request: could a progress bar be added while the model is loading? For larger models (40+ GB), loading takes about 20 minutes, and it's hard to know what is going on.

No rush on this, but it would be a nice enhancement.

@EricLBuehler
Owner

EricLBuehler commented Jun 25, 2024

Hi @joshpopelka20! I just merged #479 which adds a loading bar while loading the repeating layers. It would be great if you could install from source with maturin ahead of the PyPI rolling release (in ~2 days) to try it out!

@joshpopelka20
Contributor Author

There were no issues building from source. Also, the 2 day delay is not an issue for me.

Finally, I think the ask for this issue is complete. Would you like me to leave it open for adding tensor parallelism in the future? I'm not sure how you are tracking that.

@EricLBuehler
Owner

> There were no issues building from source. Also, the 2 day delay is not an issue for me.

Great, just one thing to confirm: does the progress bar function to show the loading?

> Would you like me to leave it open for adding tensor parallelism in the future? I'm not sure how you are tracking that.

I'll create a separate issue, as device mapping is a bit different from tensor parallelism.

@b0xtch

b0xtch commented Jun 25, 2024

Amazing stuff! The tensor parallelism I am guessing will be on the core candle repo? or do you plan to abstract that in some way under this repo?

I have linked the device mapping issue with one in candle. huggingface/candle#2007

@EricLBuehler
Owner

> Amazing stuff! The tensor parallelism I am guessing will be on the core candle repo? or do you plan to abstract that in some way under this repo?

I plan on implementing the higher-level aspects here: the synchronization, the reduce ops, etc., which can all be done with the public Candle APIs. I actually have a fork of Candle which I maintain (https://github.com/EricLBuehler/candle). I have this because some of the features which make mistral.rs faster than Candle would not get merged quickly enough/fit that project's goal well. However, I do want to contribute any progress I make with tensor parallelism and so I'll try to contribute what makes sense!
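The two approaches differ in what gets split. Device mapping places whole layers on different GPUs and runs them one after another. Tensor parallelism splits a single layer's weights across GPUs and needs a reduce op to combine the partial results. The toy below uses plain Python lists to stand in for per-GPU tensors; it is not Candle's or mistral.rs's API. It shows the row-split variant, where the all-reduce is an elementwise sum.

```python
def matmul(a, b):
    """Naive product of two matrices stored as nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def row_parallel_matmul(x, w, n_devices):
    """Compute x @ w as if w's rows were sharded across n_devices GPUs.

    Each shard holds a slice of the input dimension k. Every "device" multiplies
    the matching columns of x by its rows of w, producing a partial result with
    the full output shape; an all-reduce (here an elementwise sum) combines them.
    """
    k = len(w)
    if not 1 <= n_devices <= k:
        raise ValueError("need between 1 and k shards")
    bounds = [k * d // n_devices for d in range(n_devices + 1)]
    partials = []
    for d in range(n_devices):
        lo, hi = bounds[d], bounds[d + 1]
        x_shard = [row[lo:hi] for row in x]  # columns lo..hi-1 of x
        w_shard = w[lo:hi]                   # rows lo..hi-1 of w
        partials.append(matmul(x_shard, w_shard))
    # All-reduce: after this step every device would hold the same sum.
    result = partials[0]
    for p in partials[1:]:
        result = add(result, p)
    return result


x = [[1, 2, 3, 4]]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]
print(row_parallel_matmul(x, w, 2), matmul(x, w))
```

A column-split layer would instead concatenate the partial outputs (an all-gather) rather than sum them.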

@joshpopelka20
Contributor Author

When I run CUDA_NVCC_FLAGS="-fPIE" maturin develop -r --features "cuda flash-attn cudnn" and try to load the model with the Runner class, I get this error:

panicked at /home/ec2-user/.cargo/registry/src/index.crates.io-6f17d22bba15001f/cudarc-0.11.6/src/driver/result.rs:63:43:
thread panicked while processing panic. aborting.

I don't normally run from the command line so I tried with the current pip package (mistralrs-cuda 0.1.22). That was able to load the model.

@EricLBuehler
Owner

> When I run CUDA_NVCC_FLAGS="-fPIE" maturin develop -r --features "cuda flash-attn cudnn" and try to load the model with the Runner class, I get this error:

Can you try to run that with RUST_BACKTRACE=1?

@joshpopelka20
Contributor Author

I tried both RUST_BACKTRACE=1 CUDA_NVCC_FLAGS="-fPIE" maturin develop -r --features "cuda flash-attn cudnn" and export RUST_BACKTRACE=1, but neither gives me any additional output. I'm not sure if I'm doing it wrong or if that's the whole stack trace.

@joshpopelka20
Contributor Author

[screenshot]

Tried 'full' as well:
[screenshot]

@joshpopelka20
Contributor Author

#478 also seems to have an issue with that library, though in a different file: cudarc-0.11.6/src/lib.rs

Not sure if they're related.

@joshpopelka20
Contributor Author

It seems to be throwing the error at this line:

let err_str = self.error_string().unwrap();

From this https://users.rust-lang.org/t/how-to-prevent-thread-panicked-while-processing-panic-aborting/56508/2, it looks like it might be an issue with trying to unwrap that error string. I'll try to debug tonight.

@joshpopelka20
Contributor Author

joshpopelka20 commented Jun 26, 2024

I cloned https://github.com/coreylowman/cudarc and added this code to the mistral.rs root Cargo.toml (on my box):

[patch.crates-io]
cudarc = { path = "/home/ec2-user/cudarc" }

I removed the unwrap function and now I'm getting this error: Segmentation fault

Any suggestions on further debugging?

I'll see if I can get more output tomorrow; right now, that error looks like another CUDA bug.

@NiuBlibing

I tried to run with ./target/release/mistralrs-server --port 1234 -n "0:20;1:20;2:20;3:20" plain -m ./Qwen/Qwen2-72B-Instruct/ -a qwen2, but it still runs out of memory, and only one GPU's memory grows.

By contrast, ./target/release/mistralrs-server --port 1234 -n "0:6;1:6;2:6;3:6" plain -m /jr-sec-ai-train/open-models/Qwen/Qwen1.5-0.5B-Chat/ -a qwen2 succeeds, and memory grows on all four GPUs.

Am I doing something wrong? I'm using git version e04f8400.
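The -n values in these commands are semicolon-separated ORD:NUM pairs, each giving a device ordinal and a number of layers. The sketch below uses hypothetical helpers, not mistral.rs's own parser. It expands such a spec into a per-layer device list so the totals can be checked against the model's layer count. This sketch does not show how mistral.rs treats leftover or over-assigned layers; raising an error on too few layers and truncating extras are assumptions made here.

```python
def parse_device_layers(spec):
    """Parse "ORD:NUM;ORD:NUM;..." into [(device_ordinal, num_layers), ...]."""
    mapping = []
    for part in spec.split(";"):
        part = part.strip()
        if not part:
            continue  # tolerate a trailing ";"
        ordinal, count = part.split(":")
        mapping.append((int(ordinal), int(count)))
    return mapping


def assign_layers(spec, num_layers):
    """Return the device ordinal for each repeating layer, in layer order.

    Assumption of this sketch: a spec covering too few layers is an error,
    and entries beyond num_layers are ignored.
    """
    assignment = []
    for ordinal, count in parse_device_layers(spec):
        assignment.extend([ordinal] * count)
    if len(assignment) < num_layers:
        raise ValueError(f"spec covers {len(assignment)} layers, model has {num_layers}")
    return assignment[:num_layers]
```

For example, if Qwen2-72B has 80 repeating layers (an assumption here), assign_layers("0:20;1:20;2:20;3:20", 80) places layers 0-19 on GPU 0 and layers 60-79 on GPU 3.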

@joshpopelka20
Contributor Author

I was able to get more logging when I used panic in line 53 of cudarc/src/driver/result.rs

    pub fn error_string(&self) -> Result<&CStr, DriverError> {
        let mut err_str = MaybeUninit::uninit();
        panic!("{:?}", err_str);

Not sure it's helpful though

thread '<unnamed>' panicked at /home/ec2-user/cudarc/src/driver/result.rs:53:9:
core::mem::maybe_uninit::MaybeUninit<*const i8>
stack backtrace:
   0:     0x7ff5fd0eb3f5 - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::h1e1a1972118942ad
   1:     0x7ff5fd1197cb - core::fmt::write::hc090a2ffd6b28c4a
   2:     0x7ff5fd0e74df - std::io::Write::write_fmt::h8898bac6ff039a23
   3:     0x7ff5fd0eb1ce - std::sys_common::backtrace::print::ha96650907276675e
   4:     0x7ff5fd0ec639 - std::panicking::default_hook::{{closure}}::h215c2a0a8346e0e0
   5:     0x7ff5fd0ec37d - std::panicking::default_hook::h207342be97478370
   6:     0x7ff5fd0ecad3 - std::panicking::rust_panic_with_hook::hac8bdceee1e4fe2c
   7:     0x7ff5fd0ec9b4 - std::panicking::begin_panic_handler::{{closure}}::h00d785e82757ce3c
   8:     0x7ff5fd0eb8b9 - std::sys_common::backtrace::__rust_end_short_backtrace::h1628d957bcd06996
   9:     0x7ff5fd0ec6e7 - rust_begin_unwind
  10:     0x7ff5fc369e43 - core::panicking::panic_fmt::hdc63834ffaaefae5
  11:     0x7ff5fd070eba - <&T as core::fmt::Debug>::fmt::hbb771b0a79147136
  12:     0x7ff5fd1197cb - core::fmt::write::hc090a2ffd6b28c4a
  13:     0x7ff5fd071010 - <cudarc::driver::result::DriverError as core::fmt::Display>::fmt::heb0f09e810474a5e
  14:     0x7ff5fd1197cb - core::fmt::write::hc090a2ffd6b28c4a
  15:     0x7ff5fcf92ae7 - <candle_core::error::Error as core::fmt::Display>::fmt::hf6848a77fb28bd8b
  16:     0x7ff5fc37ca6e - mistralrs::Runner::new::h62e9d9fcf7c2e3fa
  17:     0x7ff5fc3843d4 - mistralrs::Runner::__pymethod___new____::h12cf9fba34601bfc
  18:     0x7ff5fc37980a - pyo3::impl_::trampoline::trampoline::hf137faff76e4bf3b
  19:     0x7ff5fc3837c1 - mistralrs::<impl pyo3::impl_::pyclass::PyMethods<mistralrs::Runner> for pyo3::impl_::pyclass::PyClassImplCollector<mistralrs::Runner>>::py_methods::ITEMS::trampoline::hcb06f753a45992c0
  20:     0x556be9392db2 - type_call
                               at /usr/local/src/conda/python-3.10.8/Objects/typeobject.c:1123:11
  21:     0x556be9392db2 - _PyObject_MakeTpCall
                               at /usr/local/src/conda/python-3.10.8/Objects/call.c:215:18
  22:     0x556be938f097 - _PyObject_VectorcallTstate
                               at /usr/local/src/conda/python-3.10.8/Include/cpython/abstract.h:112:16
  23:     0x556be938f097 - _PyObject_VectorcallTstate
                               at /usr/local/src/conda/python-3.10.8/Include/cpython/abstract.h:99:1
  24:     0x556be938f097 - PyObject_Vectorcall
                               at /usr/local/src/conda/python-3.10.8/Include/cpython/abstract.h:123:12
  25:     0x556be938f097 - call_function
                               at /usr/local/src/conda/python-3.10.8/Python/ceval.c:5891:13
  26:     0x556be938f097 - _PyEval_EvalFrameDefault
                               at /usr/local/src/conda/python-3.10.8/Python/ceval.c:4231:19
  27:     0x556be943c732 - _PyEval_EvalFrame
                               at /usr/local/src/conda/python-3.10.8/Include/internal/pycore_ceval.h:46:12
  28:     0x556be943c732 - _PyEval_Vector
                               at /usr/local/src/conda/python-3.10.8/Python/ceval.c:5065:24
  29:     0x556be943c677 - PyEval_EvalCode
                               at /usr/local/src/conda/python-3.10.8/Python/ceval.c:1134:12
  30:     0x556be9470049 - run_eval_code_obj
                               at /usr/local/src/conda/python-3.10.8/Python/pythonrun.c:1291:9
  31:     0x556be946a964 - run_mod
                               at /usr/local/src/conda/python-3.10.8/Python/pythonrun.c:1312:19
  32:     0x556be92ee123 - pyrun_file
                               at /usr/local/src/conda/python-3.10.8/Python/pythonrun.c:1208:15
  33:     0x556be9464c9f - _PyRun_SimpleFileObject
                               at /usr/local/src/conda/python-3.10.8/Python/pythonrun.c:456:13
  34:     0x556be9464863 - _PyRun_AnyFileObject
                               at /usr/local/src/conda/python-3.10.8/Python/pythonrun.c:90:15
  35:     0x556be9461a1f - pymain_run_file_obj
                               at /usr/local/src/conda/python-3.10.8/Modules/main.c:357:15
  36:     0x556be9461a1f - pymain_run_file
                               at /usr/local/src/conda/python-3.10.8/Modules/main.c:376:15
  37:     0x556be9461a1f - pymain_run_python
                               at /usr/local/src/conda/python-3.10.8/Modules/main.c:591:21
  38:     0x556be9461a1f - Py_RunMain
                               at /usr/local/src/conda/python-3.10.8/Modules/main.c:670:5
  39:     0x556be942f969 - Py_BytesMain
                               at /usr/local/src/conda/python-3.10.8/Modules/main.c:1090:12
  40:     0x7ff60dd0b13a - __libc_start_main
  41:     0x556be942f871 - <unknown>
Traceback (most recent call last):
  File "/home/ec2-user/test.py", line 27, in <module>
    llm = Runner(
pyo3_runtime.PanicException: core::mem::maybe_uninit::MaybeUninit<*const i8>

@joshpopelka20
Contributor Author

I've narrowed it down to this line in mistralrs-pyo3/src/lib.rs, in the non-Metal get_device() function:

let res = Device::cuda_if_available(0)?;

The error handling doesn't make it easy to see exactly what the error is. I'm trying some additional error handling, but nothing has worked so far.

@EricLBuehler
Owner

> I've narrowed it down to this line in mistralrs-pyo3/src/lib.rs, in the non-Metal get_device() function:

The error is probably happening there because that is when the CUDA stuff will get initialized for the first time.

@joshpopelka20
Contributor Author

CUDA operation failed with error: CUDA_ERROR_STUB_LIBRARY

I think it's an issue with LD_LIBRARY_PATH pointing to cuda12.1. I tried to manually update it, but that didn't work. I'll get a ticket open with AWS.

Also, I'm going to create a PR with cudarc to add the additional error handling. Think that'll be beneficial going forward.
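One common cause of CUDA_ERROR_STUB_LIBRARY is the dynamic loader picking up the wrong libcuda.so. The CUDA toolkit ships a stub libcuda.so under its lib64/stubs directory, meant only for linking, and if the loader finds it before the driver's real library, CUDA calls fail. A quick check, sketched below as a hypothetical helper, is whether any LD_LIBRARY_PATH entry is such a stubs directory. It only inspects the variable and cannot prove which libcuda was actually loaded.

```python
import os


def stub_library_dirs(ld_library_path=None):
    """Return LD_LIBRARY_PATH entries that are CUDA toolkit 'stubs' directories.

    Defaults to the current process's LD_LIBRARY_PATH. An entry matches when
    one of its path components is exactly "stubs" (e.g. /usr/local/cuda/lib64/stubs).
    """
    if ld_library_path is None:
        ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    entries = [p for p in ld_library_path.split(os.pathsep) if p]
    return [p for p in entries if "stubs" in os.path.normpath(p).split(os.sep)]


if __name__ == "__main__":
    suspicious = stub_library_dirs()
    print("stub directories on LD_LIBRARY_PATH:", suspicious or "none")
```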

@joshpopelka20
Contributor Author

Just a quick update on this. AWS responded to me, but they'll need more time to troubleshoot. Hope this isn't holding anybody else up.

@b0xtch

b0xtch commented Jun 28, 2024

> Just a quick update on this. AWS responded to me, but they'll need more time to troubleshoot. Hope this isn't holding anybody else up.

I am getting an issue with flash-attn when running this on CUDA > 12.x on AWS instances (Ubuntu, Amazon AMIs).

@joshpopelka20
Contributor Author

@b0xtch your issue may not be the same. Are you seeing any CUDA error codes (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1ggc6c391505e117393cc2558fff6bfc2e990696c86fcee1f536a1ec7d25867feeb)?

Like these:

CUDA_ERROR_INVALID_VALUE = 1
This indicates that one or more of the parameters passed to the API call is not within an acceptable range of values.
CUDA_ERROR_OUT_OF_MEMORY = 2
The API call failed because it was unable to allocate enough memory or other resources to perform the requested operation.
CUDA_ERROR_NOT_INITIALIZED = 3
This indicates that the CUDA driver has not been initialized with cuInit() or that initialization has failed.
CUDA_ERROR_DEINITIALIZED = 4
This indicates that the CUDA driver is in the process of shutting down.

Sorry, the CUDA docs aren't very user friendly, but you'll find more in the above link.
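The "thread panicked while processing panic" abort earlier in this thread came from unwrapping the result of error_string() while already formatting an error. The sketch below applies the same defensive idea in Python rather than in cudarc's Rust: formatting an error code must never fail, so a failing message lookup falls back to the code's name and number. The table holds only the four codes quoted above. lookup_message is a hypothetical stand-in for a driver call such as cuGetErrorString.

```python
# Subset of CUresult codes, taken from the list above.
CUDA_ERROR_NAMES = {
    1: "CUDA_ERROR_INVALID_VALUE",
    2: "CUDA_ERROR_OUT_OF_MEMORY",
    3: "CUDA_ERROR_NOT_INITIALIZED",
    4: "CUDA_ERROR_DEINITIALIZED",
}


def describe_cuda_error(code, lookup_message=None):
    """Format a CUDA driver error code without ever raising.

    lookup_message is a hypothetical callable standing in for a driver query
    (such as cuGetErrorString) that can itself fail. Any failure, or an empty
    message, falls back to the name and number, so reporting one error can
    never trigger a second one.
    """
    name = CUDA_ERROR_NAMES.get(code, "unrecognized CUDA error")
    if lookup_message is not None:
        try:
            message = lookup_message(code)
        except Exception:
            message = None
        if message:
            return f"{name} ({code}): {message}"
    return f"{name} ({code})"
```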

@b0xtch

b0xtch commented Jun 28, 2024

CUDA_ERROR_OUT_OF_MEMORY

  • yes getting the CUDA_ERROR_OUT_OF_MEMORY
  • it’s hard to debug when both flash-attn and CUDA are giving errors

@EricLBuehler
Owner

> CUDA_ERROR_OUT_OF_MEMORY
>
>   • yes getting the CUDA_ERROR_OUT_OF_MEMORY
>   • it’s hard to debug when both flash-attn and CUDA are giving errors

@chenwanqq mentioned on #472 that there may be a redundant copy. I'm looking into it.

@Leflak

Leflak commented Jun 29, 2024

I don't know if this is the right place (sorry if not), but I have 3 x RTX 3090 (Debian 12.5), and when launching:

cargo run --release --features cuda -- -n "0:16;1:16;2:16" -i plain -m google/gemma-2-27b-it -a gemma2

it always tries to fill device 0 and leads to CUDA_ERROR_OUT_OF_MEMORY (I tried different layer values, but the result is the same). Any idea?

@joshpopelka20
Contributor Author

I retried with the latest PyPI package (version 0.1.24), and it loads the device layers into my four A10G GPUs:
[screenshot]

When I try to run inference:

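from mistralrs import ChatCompletionRequest

# `llm` is the Runner that loaded the model; `messages` is the chat history.
output = llm.send_chat_completion_request(
    ChatCompletionRequest(
        model="llama",
        messages=messages,
        max_tokens=256,
        presence_penalty=1.0,
        top_p=0.1,
        temperature=0,
    )
)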

I get this error:

called Result::unwrap() on an Err value: Cuda(Cuda(DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")))

This is how the GPU memory is distributed after the error message:
[screenshot]

I think it's loading the KV cache into only one of the four GPUs. Is there an available profiling tool to know for certain what is in the memory of each GPU?

@joshpopelka20
Contributor Author

Don't know if this will help, but this is the code that seems to be causing the issue (in mistralrs-core/src/pipeline/mod.rs):

    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        prefix_cacher: &mut PrefixCacheManager,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    ) -> Result<(), candle_core::Error> {
        let inputs = self
            .get_processor()
            .inputs_processor()
            .process_inputs(
                self.tokenizer(),
                input_seqs,
                is_prompt,
                self.get_metadata().is_xlora,
                &self.device(),
                self.get_metadata().has_no_kv_cache,
                None,
                self.get_input_processor_config(),
            )
            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;

Since &self.device() is the same each time, only its memory is going up during inference.

Systems programming is new to me, so I'm not sure what the fix could be. Is there a way to check for multiple devices and choose the one with the most available memory? Ideally, I think some method that creates a device pool and splits the KV cache across the devices would be best.

@Leflak, @b0xtch, @NiuBlibing not sure if you are seeing it error at the same place, but that's my guess.

@EricLBuehler
Owner

@joshpopelka20 yes, we store the KV cache on one GPU. This is because the attention operation necessitates that the Q, K and V matrices be on the same GPU. I'm not sure if there is a way to solve this other than KV cache quantization or other compression techniques.
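For a sense of scale, the KV cache grows linearly with sequence length. The back-of-the-envelope estimate below assumes the Llama-3-8B attention shape (32 layers, 8 KV heads, head dimension 128) and ignores weights, activations, and attention scratch space. The 8-bit rows show what KV cache quantization would save. kv_cache_bytes is a hypothetical helper, not a mistral.rs function.

```python
GIB = 1024 ** 3


def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_elem=2, batch=1):
    """Bytes needed to store K and V for every layer at a given sequence length."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # 2 = one K and one V
    return batch * num_tokens * per_token


# Assumed Llama-3-8B attention shape: 32 layers, 8 KV heads (grouped-query attention), head_dim 128.
LLAMA3_8B = dict(num_layers=32, num_kv_heads=8, head_dim=128)

for tokens in (10_000, 262_144):
    for label, width in (("16-bit", 2), ("8-bit", 1)):
        size = kv_cache_bytes(tokens, bytes_per_elem=width, **LLAMA3_8B)
        print(f"{tokens:>7} tokens, {label} KV cache: {size / GIB:.2f} GiB")
```

Under these assumptions, the cache costs 128 KiB per token in 16-bit precision. That comes to about 1.22 GiB for a 10K-token prompt and 32 GiB for the full 262,144-token context, before model weights or activations are counted. 8-bit storage halves both figures.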

@joshpopelka20gmail

@EricLBuehler thanks for the reply. You can let me know if I'm wrong, but I think the prefix cache might help here.

So for my use case, the prompt will mostly be the same each time (approx 9900 out of the 10K prompt tokens). What I noticed is on the first request, the KV cache grows on one GPU. Then on the prompt phase of the second request, I get OOM. Can we get the prefixes (matching on the 9900 tokens) and evict from the GPU prior to the prompt phase?

I know this is use case specific and probably won't work for most use cases. Hopefully, there is some on-going research on how to split the KV cache across devices.
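The matching half of that idea can be sketched independently of mistral.rs's PrefixCacheManager, whose internals are not shown in this thread. The idea is to find the cached sequence that shares the longest token prefix with the new request, so only the remaining suffix needs a prompt (prefill) pass. The helpers below are hypothetical and work on plain lists of token ids. Reusing a prefix saves recomputation, but whether it also frees GPU memory depends on where the cached K/V tensors are kept.

```python
def shared_prefix_len(a, b):
    """Length of the longest common prefix of two token-id sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def best_cached_prefix(new_tokens, cached_sequences):
    """Return (index, prefix_len) of the cached sequence sharing the longest prefix.

    The K/V entries for the first prefix_len positions could be reused, so only
    new_tokens[prefix_len:] would need a prompt (prefill) pass. Returns
    (None, 0) when nothing matches.
    """
    best_index, best_len = None, 0
    for i, cached in enumerate(cached_sequences):
        n = shared_prefix_len(new_tokens, cached)
        if n > best_len:
            best_index, best_len = i, n
    return best_index, best_len
```

With the numbers from the comment above, two roughly 10K-token prompts sharing their first 9,900 tokens would leave only about 100 tokens to prefill for the second request.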

@joshpopelka20
Contributor Author

I've been doing some more debugging. I added a logger in mistralrs-core/src/models/quantized_llama.rs after the line let x = (attn + residual)?; in this function:

impl ModelWeights {
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        start_offsets_kernel: Tensor,
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {

This is the logger:

            let bytes = x.elem_count() * x.dtype().size_in_bytes();
            let size_mb = (bytes as f64) / (1024.0 * 1024.0);
            println!("Tensor size: {:.2} MB", size_mb);

After 5 runs with the same prompt, this is how the variable size looks:
[screenshot]

This looks like a memory leak and I don't think it is the correct behavior. I haven't had luck in figuring out what is causing this behavior, but I suspect it is something in the Engine. It isn't doing garbage collection correctly after each run.

@EricLBuehler
Owner

EricLBuehler commented Jul 6, 2024

I think this depends on the situation. Assuming you are running in interactive mode, this is expected as the chat interaction increases the KV cache size. Taking some quick measurements, the growth rate of the size is linear, which is what we would expect as the sequence length increases by some constant step.

> It isn't doing garbage collection correctly after each run.

In Rust, memory deallocation is not handled by a garbage collector; instead, it is coded into the executable through the Drop trait, which runs when a value goes out of scope. Leaks are unlikely in safe Rust (though not strictly impossible, e.g. via reference cycles or std::mem::forget), but holding allocations for a long time is possible. So, if you run mistral.rs, the KV cache for a sequence is only dropped once that sequence is finished.

In the interactive mode, each time you send a request, a new sequence is created. The number of tokens will be higher and higher each time, increasing the prompt length and therefore the KV cache size.
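The linear growth is easy to reproduce on paper. The sketch below assumes that each interactive turn resends the full history plus a fixed-size new message. The token counts, hidden size 4096, and f32 dtype are illustrative assumptions. Under them, the prompt length, and therefore the size of the tensor logged above, grows by a constant step per turn. Both helpers are hypothetical.

```python
MIB = 1024 * 1024


def tensor_mib(shape, bytes_per_elem):
    """Same arithmetic as the Rust logger: element count * dtype size / 1 MiB."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_elem / MIB


def interactive_prompt_lengths(turns, user_tokens, reply_tokens):
    """Prompt length at each turn if the whole chat history is resent every time."""
    lengths, history = [], 0
    for _ in range(turns):
        history += user_tokens      # new user message joins the history
        lengths.append(history)     # this is the prompt the model processes
        history += reply_tokens     # the model's reply is appended afterwards
    return lengths


# Illustrative assumptions: batch 1, hidden size 4096, f32 activations,
# 500-token user messages and 256-token replies.
for turn, seq_len in enumerate(interactive_prompt_lengths(5, 500, 256), start=1):
    size = tensor_mib((1, seq_len, 4096), 4)
    print(f"turn {turn}: seq_len={seq_len}, x is {size:.2f} MiB")
```

With these assumed numbers, the prompt grows by 756 tokens per turn, so the logged size grows by the same amount each time (about 11.81 MiB per turn). That is linear growth, not a leak.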

I think the main question is: are you using interactive mode?

I will look into this though!

@joshpopelka20gmail

Yes, for debugging the issue, I've been using interactive mode.

I'll retry the debugger with a maturin build on Monday (though AWS is still looking into my issue when I do a "build from source" so that might not work).

Thanks! Look at it whenever you have time; there's no rush. Enjoy your weekend :)

@joshpopelka20
Contributor Author

I was able to retest with the provided example, examples/http.md. Using the API server, the tensor was the same size on each run:
image

I think this issue has been fixed. I'll close it.

Finally, @EricLBuehler thank you for your time and patience. The codebase is starting to make sense to me and reading research papers about transformers/llms is getting easier. I've spent most of my career working with Java and Javascript, so really excited to start my career transition to an "AI engineer" role.

@EricLBuehler
Owner

@joshpopelka20 glad that it works, I'm happy to help. Please let me know if you have any questions about mistral.rs or transformers/LLMs in general!
