From 87c89643e1b46a587eac75f2f25be1459fc2947e Mon Sep 17 00:00:00 2001 From: Rich James Date: Tue, 12 Nov 2024 21:12:39 +0000 Subject: [PATCH] Safetensors checkpoint loading --- Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md b/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md index cd544e3..d2b6a90 100644 --- a/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md +++ b/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md @@ -102,6 +102,7 @@ P99 TPOT: 16942.69 ms ``` # Appendix Checkpoint conversion on CPU +Note that at present we only recommend loading the checkpoint labelled 0.3, in `safetensors` format (as below), as we have seen problems with the 0.1 checkpoint in `pth` format ```bash # Get checkpoint from https://github.com/mistralai/mistral-inference @@ -117,7 +118,7 @@ export CHKPT_BUCKET=gs:// export SCANNED_CHKPT_PATH=${CHKPT_BUCKET}/scanned_ckpt JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py \ --base-model-path=${M8x22B_DIR} --model-size=mixtral-8x22b \ ---maxtext-model-path=${SCANNED_CHKPT_PATH} +--maxtext-model-path=${SCANNED_CHKPT_PATH} --checkpoint-type=safetensors # Convert checkpoint to unscanned version export UNSCANNED_RUN_NAME=unscanned_ckpt