diff --git a/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md b/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md
index cd544e3..d2b6a90 100644
--- a/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md
+++ b/Serving/Trillium/JetStream-Maxtext/Mixtral-8X22B/README.md
@@ -102,6 +102,7 @@ P99 TPOT: 16942.69 ms
 ```
 # Appendix
 Checkpoint conversion on CPU
+Note that at present we only recommend loading the checkpoint labelled 0.3, in `safetensors` format (as below), as we have seen problems with the 0.1 checkpoint in `pth` format.
 
 ```bash
 # Get checkpoint from https://github.com/mistralai/mistral-inference
@@ -117,7 +118,7 @@ export CHKPT_BUCKET=gs://
 export SCANNED_CHKPT_PATH=${CHKPT_BUCKET}/scanned_ckpt
 JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py \
 --base-model-path=${M8x22B_DIR} --model-size=mixtral-8x22b \
---maxtext-model-path=${SCANNED_CHKPT_PATH}
+--maxtext-model-path=${SCANNED_CHKPT_PATH} --checkpoint-type=safetensors
 
 # Convert checkpoint to unscanned version
 export UNSCANNED_RUN_NAME=unscanned_ckpt
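For anyone applying this patch by hand, the patched Appendix flow can be sketched end-to-end as below. This is a sketch, not the authoritative recipe: the local checkpoint directory name is an assumption, and the `gs://` bucket is left unfilled as in the README itself; only the script name, flags, and the new `--checkpoint-type=safetensors` argument come from the patch.

```bash
# Sketch of the patched scanned-checkpoint conversion on CPU.
# Assumes the Mixtral-8x22B v0.3 safetensors checkpoint was already
# downloaded per https://github.com/mistralai/mistral-inference and
# placed in M8x22B_DIR (directory name here is a placeholder).
export M8x22B_DIR=~/mixtral-8x22b-v0.3   # placeholder local path
export CHKPT_BUCKET=gs://                # fill in your bucket, as in the README
export SCANNED_CHKPT_PATH=${CHKPT_BUCKET}/scanned_ckpt

# Convert the v0.3 safetensors checkpoint to a scanned MaxText checkpoint;
# --checkpoint-type=safetensors is the flag this patch adds.
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py \
  --base-model-path=${M8x22B_DIR} \
  --model-size=mixtral-8x22b \
  --maxtext-model-path=${SCANNED_CHKPT_PATH} \
  --checkpoint-type=safetensors
```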