-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #680 from omertuc/wrapper
`ilab` wrapper script adjustments
- Loading branch information
Showing
2 changed files
with
19 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,185 +1,29 @@ | ||
#!/bin/bash | ||
|
||
# Template values replaced by container build | ||
ENDPOINT_URL="__REPLACE_ENDPOINT_URL__" | ||
TRAIN_DEVICE="__REPLACE_TRAIN_DEVICE__" | ||
CONTAINER_DEVICE="__REPLACE_CONTAINER_DEVICE__" | ||
IMAGE_NAME="__REPLACE_IMAGE_NAME__" | ||
VLLM_NAME="__REPLACE_VLLM_NAME__" | ||
TRAIN_NAME="__REPLACE_TRAIN_NAME__" | ||
GPU_COUNT_COMMAND="__REPLACE_GPU_COUNT_COMMAND__" | ||
|
||
# ENDPOINT_URL="http://0.0.0.0:8080/v1" | ||
# TRAIN_DEVICE="cuda" | ||
# CONTAINER_DEVICE="nvidia.com/gpu=all" | ||
# IMAGE_NAME="quay.io/ai-lab/instructlab-nvidia:latest" | ||
# VLLM_NAME="quay.io/ai-lab/vllm:latest" | ||
# TRAIN_NAME="quay.io/ai-lab/deepspeed-trainer:latest" | ||
# GPU_COUNT_COMMAND="nvidia-ctk --quiet cdi list | grep -P nvidia.com/gpu='\d+' | wc -l" | ||
export ENTRYPOINT="/opt/python3.11/venv/bin/ilab" | ||
export PARAMS=("$@") | ||
|
||
# HF caching uses relative symlink structures, so keep cache relative to | ||
# the central working directory | ||
CONTAINER_CACHE="/instructlab/cache" | ||
HOST_CACHE="$(pwd)/cache" | ||
WORKDIR="$(pwd)" | ||
SCRIPT_DIR=$(dirname "$0") | ||
DEFAULT_SERVE_MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1" | ||
for dir in "$HOME/.cache" "$HOME/.config" "$HOME/.local"; do | ||
mkdir -p "$dir" | ||
done | ||
|
||
if [[ -z "${GPU_AMOUNT}" ]]; then | ||
GPU_AMOUNT=$(bash -c "${GPU_COUNT_COMMAND}") | ||
if [[ "$?" != "0" ]]; then | ||
echo "Could not determine GPU count, set export GPU_AMOUNT= manually" | ||
exit | ||
fi | ||
if [[ "$1" = "shell" ]]; then | ||
export ENTRYPOINT=bash | ||
export PARAMS=() | ||
fi | ||
|
||
if [[ "$GPU_AMOUNT" -lt 2 ]]; then | ||
echo "WARNING: You need at least 2 GPUs to load full precision models" | ||
fi | ||
|
||
NPROC_PER_NODE=${GPU_AMOUNT} | ||
EFFECTIVE_BATCH_SIZE=$((12*${GPU_AMOUNT})) | ||
NUM_INSTRUCTIONS=5000 | ||
NUM_EPOCHS=10 | ||
|
||
has_argument() { | ||
match=$1 | ||
shift | ||
for arg in "$@"; do | ||
if [[ "$arg" == *"$match"* ]]; then | ||
return 0 | ||
fi | ||
done | ||
return 1 | ||
} | ||
|
||
get_argument() { | ||
local match=$1 | ||
shift | ||
|
||
local found=false | ||
local arg | ||
while [ "$#" -gt 0 ]; do | ||
arg="$1" | ||
shift | ||
if [[ "$arg" == "$match" ]]; then | ||
found=true | ||
if [ "$#" -gt 0 ]; then | ||
echo "$1" | ||
return 0 | ||
else | ||
echo "" | ||
return 0 | ||
fi | ||
fi | ||
done | ||
|
||
if ! $found; then | ||
echo "" | ||
return 0 | ||
fi | ||
} | ||
|
||
get_argument_default() { | ||
local match=$1 | ||
local default=$2 | ||
shift | ||
shift | ||
local result=$(get_argument ${match} "$@") | ||
if [[ -z "${result}" ]]; then | ||
echo $default | ||
return 0 | ||
fi | ||
echo "${result}" | ||
} | ||
|
||
get_model() { | ||
model=$(get_argument_default "--model" "${DEFAULT_SERVE_MODEL}" "$@") | ||
if [[ ! "${model}" =~ ^/instructlab/models.* ]]; then | ||
echo /instructlab/models/"${model}" | ||
else | ||
echo "${model}" | ||
fi | ||
} | ||
|
||
mkdir -p "${HOST_CACHE}" | ||
PODMAN_COMMAND=("podman" "run" "--rm" "-it" "--device" "${CONTAINER_DEVICE}" \ | ||
"--security-opt" "label=disable" "--net" "host" \ | ||
"-v" "${WORKDIR}:/instructlab" "--entrypoint" "" \ | ||
"-e" "HF_HOME=${CONTAINER_CACHE}" \ | ||
"-e" "HF_TOKEN=${HF_TOKEN}" \ | ||
"${IMAGE_NAME}") | ||
PODMAN_COMMAND_SERVE=("podman" "run" "--rm" "-it" "--device" "${CONTAINER_DEVICE}" \ | ||
"--security-opt" "label=disable" "--net" "host" \ | ||
"-v" "${WORKDIR}:/instructlab" \ | ||
"--shm-size=10gb" \ | ||
"-e" "HF_HOME=${CONTAINER_CACHE}/" \ | ||
"-e" "HF_TOKEN=${HF_TOKEN}" \ | ||
"${VLLM_NAME}" "--host=0.0.0.0" "--port=8080" "--tensor-parallel-size=${GPU_AMOUNT}") | ||
|
||
if [[ "$1" = "init" ]]; then | ||
if ! has_argument "--repository" "$@"; then | ||
shift | ||
"${PODMAN_COMMAND[@]}" ilab init \ | ||
--repository https://github.com/instructlab/taxonomy.git "$@" | ||
exit $? | ||
fi | ||
elif [[ "$1" = "train" ]]; then | ||
samples=$(get_argument_default "--num-samples" ${NUM_INSTRUCTIONS} "$@") | ||
epochs=$(get_argument_default "--num-epochs" ${NUM_EPOCHS} "$@") | ||
${SCRIPT_DIR}/ilab-training-launcher ${NPROC_PER_NODE} ${EFFECTIVE_BATCH_SIZE} \ | ||
${TRAIN_DEVICE} ${samples} ${epochs} ${CONTAINER_DEVICE} ${TRAIN_NAME} | ||
exit $? | ||
elif [[ "$1" = "serve" ]]; then | ||
# run vllm container which will serve vllm and ilab generate | ||
args=() | ||
model=$(get_model "$@") | ||
if [[ "${model}" == *"${DEFAULT_SERVE_MODEL}" ]]; then | ||
args+=("--chat-template=mixtral.jinja") | ||
fi | ||
args+=("--model" "${model}") | ||
"${PODMAN_COMMAND_SERVE[@]}" "${args[@]}" | ||
exit $? | ||
elif [[ "$1" = "chat" ]]; then | ||
shift | ||
args=($@) | ||
if ! has_argument "--endpoint-url" "$@"; then | ||
args+=("--endpoint-url" "http://0.0.0.0:8080/v1") | ||
fi | ||
if ! has_argument "--model-family" "$@"; then | ||
args+=("--model-family" "mixtral") | ||
fi | ||
args+=("--model" $(get_model "$@")) | ||
"${PODMAN_COMMAND[@]}" ilab chat "${args[@]}" | ||
exit $? | ||
elif [[ "$1" = "generate" ]]; then | ||
shift | ||
args=($@) | ||
if ! has_argument "--endpoint-url" "$@"; then | ||
args+=("--endpoint-url" "http://0.0.0.0:8080/v1") | ||
fi | ||
if ! has_argument "--model-family" "$@"; then | ||
args+=("--model-family" "mixtral") | ||
fi | ||
if ! has_argument "--num-instructions" "$@"; then | ||
args+=("--num-instructions" "5000") | ||
fi | ||
args+=("--model" $(get_model "$@")) | ||
echo ilab generate "${args[@]}" | ||
|
||
"${PODMAN_COMMAND[@]}" ilab generate "${args[@]}" | ||
exit $? | ||
elif [[ "$1" == "download" && $# -lt 2 ]]; then | ||
echo "You must specify the model to download." | ||
echo | ||
echo "High-fidelity generation and training requires two models:" | ||
echo | ||
echo "Mixtral: ilab download --repository ${DEFAULT_SERVE_MODEL}" | ||
echo "Granite: ilab download --repository ibm/granite-7b-base" | ||
echo | ||
echo "For more options type ilab --help" | ||
exit 1 | ||
fi | ||
|
||
"${PODMAN_COMMAND[@]}" ilab "$@" | ||
|
||
PODMAN_COMMAND=("podman" "run" "--rm" "-it" | ||
"--device" "${CONTAINER_DEVICE}" | ||
"--security-opt" "label=disable" "--net" "host" | ||
"-v" "$HOME/.cache:/root/.cache" | ||
"-v" "$HOME/.config:/root/.config" | ||
"-v" "$HOME/.local:/root/.local" | ||
"--entrypoint" "$ENTRYPOINT" | ||
"--env" "HF_TOKEN" | ||
"${IMAGE_NAME}") | ||
|
||
"${PODMAN_COMMAND[@]}" "${PARAMS[@]}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters