OOM with 8 A800 #34
@647sherry Our training was conducted on 8 A100-80G GPUs, which is two times larger than your setting. For larger models, you could try reducing per_device_train_batch_size as needed and increasing gradient_accumulation_steps, such that per_device_train_batch_size * gradient_accumulation_steps = 8.
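The rule in the reply above is plain arithmetic. Lowering per_device_train_batch_size while raising gradient_accumulation_steps keeps the number of samples per optimizer step the same but reduces the activations held in memory at any one time. A minimal sketch, assuming the per-device target of 8 from the reply and 8 data-parallel GPUs; the helper names are hypothetical, not part of the SPIN code:

```python
def accumulation_steps(target_per_device: int, per_device_train_batch_size: int) -> int:
    """Return gradient_accumulation_steps so that
    per_device_train_batch_size * gradient_accumulation_steps == target_per_device."""
    if target_per_device % per_device_train_batch_size != 0:
        raise ValueError("target must be divisible by per_device_train_batch_size")
    return target_per_device // per_device_train_batch_size


def global_batch_size(per_device: int, accum: int, num_gpus: int) -> int:
    # Samples that contribute to one optimizer step across all data-parallel ranks.
    return per_device * accum * num_gpus


# Every combination below yields the same samples per optimizer step.
for bs in (8, 4, 2, 1):
    steps = accumulation_steps(8, bs)
    print(f"per_device={bs} accum={steps} global={global_batch_size(bs, steps, num_gpus=8)}")
```

Accumulation only changes how many micro-batches are summed before each update; it does not reduce the memory taken by weights, gradients, or optimizer states.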
Thanks for your reply. I set per_device_train_batch_size = 1 and gradient_accumulation_steps = 8, and tried both ZeRO stage 2 and ZeRO stage 3, but I still get OOM.
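Batch size 1 mainly shrinks activations; the memory for weights, gradients, and optimizer states is set by the model size and the ZeRO stage. A rough per-GPU estimate using the accounting from the ZeRO paper (mixed-precision Adam: 2 bytes/param fp16 weights, 2 bytes/param fp16 gradients, 12 bytes/param fp32 master weights and Adam moments). This is an approximation under those assumptions: it excludes activations, communication buffers, fragmentation, and any additional frozen copy of the model the trainer may hold.

```python
def model_state_gib(num_params: float, num_gpus: int, zero_stage: int, k: int = 12) -> float:
    """Approximate per-GPU GiB for weights, gradients and Adam states (ZeRO accounting).

    Stage 0: nothing sharded; stage 1: optimizer states sharded;
    stage 2: + gradients sharded; stage 3: + parameters sharded.
    """
    p, n = num_params, num_gpus
    if zero_stage == 0:
        total = (2 + 2 + k) * p
    elif zero_stage == 1:
        total = (2 + 2) * p + k * p / n
    elif zero_stage == 2:
        total = 2 * p + (2 + k) * p / n
    elif zero_stage == 3:
        total = (2 + 2 + k) * p / n
    else:
        raise ValueError("zero_stage must be 0, 1, 2 or 3")
    return total / 2**30


# Example: a 7B-parameter model sharded over 8 GPUs.
for stage in (0, 1, 2, 3):
    print(f"ZeRO stage {stage}: ~{model_state_gib(7e9, 8, stage):.1f} GiB per GPU")
```

Whatever remains after these model states is what activations, logits, and temporary buffers must fit into.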
I ran into a similar issue with a deepseek-math-7b model on an A100 80G GPU: I cannot get training to work even with per_device_train_batch_size = 1 and gradient_accumulation_steps = 1. I suspect there is a model-related bug.
@jojo23333 Is that an OOM error or some other issue? Can you check the GPU utilization while this training is not running, to see whether the GPUs are already occupied by other processes? As far as I know, OOM should not be possible under your setting.
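The check suggested above can be scripted with nvidia-smi's query mode. A minimal sketch, assuming nvidia-smi is on the PATH of the training node; the 1024 MiB threshold for calling a GPU "occupied" is an arbitrary, hypothetical choice:

```python
import subprocess

# nvidia-smi query mode: one "index, used, total" line per GPU, values in MiB.
QUERY = [
    "nvidia-smi",
    "--query-gpu=index,memory.used,memory.total",
    "--format=csv,noheader,nounits",
]


def query_gpu_memory() -> str:
    """Run nvidia-smi and return its raw CSV output (requires an NVIDIA driver)."""
    return subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout


def parse_gpu_memory(csv_text: str) -> list:
    """Turn lines like '0, 1234, 81920' into (index, used_mib, total_mib) tuples."""
    rows = []
    for line in csv_text.strip().splitlines():
        idx, used, total = (int(field) for field in line.split(","))
        rows.append((idx, used, total))
    return rows


def busy_gpus(rows: list, threshold_mib: int = 1024) -> list:
    # GPUs that already hold more than threshold_mib before training starts.
    return [idx for idx, used, _total in rows if used > threshold_mib]
```

On the training node, `busy_gpus(parse_gpu_memory(query_gpu_memory()))` run before launching lists any GPU that another process is already using.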
Hi, thanks so much for the reply! I'm pretty sure all of the memory was consumed by this single process. I checked in detail: model loading and the forward pass take up to about 30 GB of the 80 GB, and the backward pass immediately leads to OOM. That said, my setup is somewhat different.
I'm not sure whether this is related to the response/context length of a single sample; I roughly estimate that my responses are at most 2x longer than those in the superchat dataset. Still, not being able to get even batch_size=1 working is somewhat weird.
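Context length can matter more than a 2x factor suggests. With a non-fused ("eager") attention implementation, each layer materializes at least one attention matrix of shape (batch, heads, seq_len, seq_len), so doubling the sequence length roughly quadruples that memory; fused kernels such as FlashAttention avoid materializing it. A rough sketch; the 32 heads and 32 layers are hypothetical values, not read from any of the models discussed here:

```python
def eager_attention_scores_bytes(batch: int, num_heads: int, seq_len: int,
                                 num_layers: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one (batch, heads, seq_len, seq_len) attention matrix per layer,
    summed over layers. bytes_per_elem=2 assumes fp16/bf16 activations."""
    return batch * num_heads * seq_len * seq_len * num_layers * bytes_per_elem


# Hypothetical model shape: 32 heads, 32 layers, batch size 1.
for seq_len in (2048, 4096):
    gib = eager_attention_scores_bytes(1, 32, seq_len, 32) / 2**30
    print(f"seq_len={seq_len}: {gib:.1f} GiB of attention scores")
```

This is a lower bound for that one tensor type; other saved activations grow only linearly with sequence length.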
Actually, I tried zephyr-7b-sft-full with the original settings on my data, and I was able to train with a per-device batch size of 8, but somehow not with the deepseek model.
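One model-dependent factor that could contribute to a same-size model fitting in one case but not the other is the output logits tensor, whose size scales with the vocabulary. This is a hypothesis for investigation, not a diagnosis from this thread. The vocabulary sizes below are assumptions; read the real `vocab_size` from each model's config.json. The loss computation may create further temporaries of similar size.

```python
def logits_gib(batch: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 4) -> float:
    """GiB for one (batch, seq_len, vocab_size) logits tensor; 4 bytes/elem assumes fp32."""
    return batch * seq_len * vocab_size * bytes_per_elem / 2**30


# Assumed vocabulary sizes; verify against each model's config.json.
VOCAB = {"mistral-based (e.g. zephyr-7b)": 32_000, "deepseek-7b family": 102_400}

for name, vocab in VOCAB.items():
    print(f"{name}: {logits_gib(1, 4096, vocab):.2f} GiB per logits tensor")
```

Under these assumed sizes the deepseek logits would be 3.2x larger at the same batch size and sequence length.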
Hi @jojo23333, have you solved the OOM? This might be a related issue: huggingface/transformers#29484
Thanks for the pointer, I'll take a look.
Hi, I got an OOM error while fine-tuning qwen-14b-chat as well as the default model, using:
accelerate launch --config_file configs/deepspeed_zero3.yaml --multi_gpu --num_processes=8 --main_process_port 29501 spin/run_spin.py configs/config.yaml --num_train_epochs=3 --output_dir="xxx/spin_outputs/iter0-ckpt"
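To apply the batch-size advice from the first reply to this launch, the two training arguments can be appended the same way `--num_train_epochs=3` is. A sketch that only builds the command string; it assumes (not confirmed here) that spin/run_spin.py accepts `--per_device_train_batch_size` and `--gradient_accumulation_steps` as overrides just as it accepts `--num_train_epochs`. The helper name is hypothetical and the output path is kept as the placeholder from the report:

```python
import shlex


def spin_launch_command(per_device_train_batch_size: int,
                        gradient_accumulation_steps: int,
                        output_dir: str,
                        num_processes: int = 8) -> str:
    """Rebuild the reported launch command with the two batch-size overrides appended."""
    args = [
        "accelerate", "launch",
        "--config_file", "configs/deepspeed_zero3.yaml",
        "--multi_gpu", f"--num_processes={num_processes}",
        "--main_process_port", "29501",
        "spin/run_spin.py", "configs/config.yaml",
        "--num_train_epochs=3",
        # Assumed overrides, passed like --num_train_epochs above.
        f"--per_device_train_batch_size={per_device_train_batch_size}",
        f"--gradient_accumulation_steps={gradient_accumulation_steps}",
        f"--output_dir={output_dir}",
    ]
    return shlex.join(args)  # quotes any argument that needs shell escaping


print(spin_launch_command(1, 8, "xxx/spin_outputs/iter0-ckpt"))
```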
System info:
absl-py 2.1.0
accelerate 0.23.0
aiohttp 3.9.5
aioprometheus 23.12.0
aiosignal 1.3.1
annotated-types 0.7.0
anyio 4.4.0
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.41.2.post2
certifi 2024.6.2
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
cmake 3.29.6
contourpy 1.2.1
cycler 0.12.1
datasets 2.14.6
deepspeed 0.12.2
dill 0.3.7
diskcache 5.6.3
dnspython 2.6.1
docstring_parser 0.16
einops 0.8.0
email_validator 2.1.1
evaluate 0.4.0
exceptiongroup 1.2.1
fastapi 0.111.0
fastapi-cli 0.0.4
filelock 3.15.1
flash_attn 2.5.9.post1
fonttools 4.53.0
frozenlist 1.4.1
fsspec 2023.10.0
grpcio 1.64.1
h11 0.14.0
hjson 3.1.0
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.3
idna 3.7
interegular 0.3.3
Jinja2 3.1.4
joblib 1.4.2
jsonlines 4.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lark 1.1.9
llvmlite 0.43.0
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
opencv-python 4.10.0.84
orjson 3.10.5
outlines 0.0.34
packaging 24.1
pandas 2.2.2
peft 0.6.1
pillow 10.4.0
pip 24.0
prometheus_client 0.20.0
protobuf 3.20.2
psutil 5.9.8
py-cpuinfo 9.0.0
py4j 0.10.9.7
pyarrow 16.1.0
pydantic 2.7.4
pydantic_core 2.18.4
Pygments 2.18.0
pynvml 11.5.0
pyparsing 3.1.2
pyspark 3.5.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.9
pytz 2024.1
PyYAML 6.0.1
quantile-python 1.1
ray 2.24.0
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
responses 0.18.0
rich 13.7.1
rpds-py 0.18.1
safetensors 0.4.3
scipy 1.13.1
seaborn 0.13.2
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
shtab 1.7.1
six 1.16.0
sniffio 1.3.1
spin 0.1.0.dev0
starlette 0.37.2
sympy 1.12.1
tensorboard 2.17.0
tensorboard-data-server 0.7.2
tiktoken 0.6.0
tokenizers 0.15.2
torch 2.1.0
torchvision 0.18.1
tqdm 4.66.4
transformers 4.36.2
transformers-stream-generator 0.0.5
triton 2.1.0
trl 0.7.4
typer 0.12.3
typing_extensions 4.12.2
tyro 0.8.4
tzdata 2024.1
ujson 5.10.0
ultralytics-thop 2.0.0
urllib3 2.2.1
uvicorn 0.30.1
uvloop 0.19.0
vllm 0.3.0
watchfiles 0.22.0
websockets 12.0
Werkzeug 3.0.3
wheel 0.43.0
xformers 0.0.23.post1
xxhash 3.4.1
yarl 1.9.4
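When comparing this environment with the other reporters (or with the one used for the original training), it can help to print only the packages that most affect memory behavior instead of the full list. A small sketch using the standard library's importlib.metadata; the package selection is an arbitrary choice, not a list from the maintainers:

```python
from importlib.metadata import PackageNotFoundError, version

# Arbitrary selection of packages that influence training memory behavior.
RELEVANT = ("torch", "transformers", "accelerate", "deepspeed", "trl", "flash_attn", "peft")


def relevant_versions(packages=RELEVANT) -> dict:
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in relevant_versions().items():
        print(f"{name:15} {ver}")
```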
Thanks for your help in advance!