diff --git a/.github/workflows/generate_grpc_cache.yaml b/.github/workflows/generate_grpc_cache.yaml new file mode 100644 index 000000000000..c6b080b5f7cb --- /dev/null +++ b/.github/workflows/generate_grpc_cache.yaml @@ -0,0 +1,90 @@ +name: 'generate and publish GRPC docker caches' + +on: +- workflow_dispatch + +concurrency: + group: grpc-cache-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + +jobs: + generate_caches: + strategy: + matrix: + include: + - grpc-base-image: ubuntu:22.04 + runs-on: 'ubuntu-latest' + platforms: 'linux/amd64' + runs-on: ${{matrix.runs-on}} + steps: + - name: Release space from worker + if: matrix.runs-on == 'ubuntu-latest' + run: | + echo "Listing top largest packages" + pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) + head -n 30 <<< "${pkgs}" + echo + df -h + echo + sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true + sudo apt-get remove --auto-remove android-sdk-platform-tools || true + sudo apt-get purge --auto-remove android-sdk-platform-tools || true + sudo rm -rf /usr/local/lib/android + sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true + sudo rm -rf /usr/share/dotnet + sudo apt-get remove -y '^mono-.*' || true + sudo apt-get remove -y '^ghc-.*' || true + sudo apt-get remove -y '.*jdk.*|.*jre.*' || true + sudo apt-get remove -y 'php.*' || true + sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true + sudo apt-get remove -y '^google-.*' || true + sudo apt-get remove -y azure-cli || true + sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true + sudo apt-get remove -y '^gfortran-.*' || true + sudo apt-get remove -y microsoft-edge-stable || true + sudo apt-get remove -y firefox || true + sudo apt-get remove -y powershell || true + sudo apt-get remove -y r-base-core || true + sudo apt-get autoremove -y + sudo apt-get clean + echo + echo "Listing top largest packages" + pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) + head -n 30 <<< "${pkgs}" + echo + sudo rm -rfv build || true + sudo rm -rf /usr/share/dotnet || true + sudo rm -rf /opt/ghc || true + sudo rm -rf "/usr/local/share/boost" || true + sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true + df -h + + - name: Set up QEMU + uses: docker/setup-qemu-action@master + with: + platforms: all + + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@master + + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache GRPC + uses: docker/build-push-action@v5 + with: + builder: ${{ steps.buildx.outputs.name }} + # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache. + # This means that even the MAKEFLAGS have to be an EXACT match. + # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch. + build-args: | + GRPC_BASE_IMAGE=${{ matrix.grpc-base-image }} + MAKEFLAGS=--jobs=4 --output-sync=target + GRPC_VERSION=v1.58.0 + context: . + file: ./Dockerfile + cache-to: type=gha,ignore-error=true + target: grpc + platforms: ${{ matrix.platforms }} + push: false \ No newline at end of file diff --git a/.github/workflows/image-pr.yml b/.github/workflows/image-pr.yml index b703b16d6ed0..9c4fece71226 100644 --- a/.github/workflows/image-pr.yml +++ b/.github/workflows/image-pr.yml @@ -22,6 +22,7 @@ jobs: platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} + grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} @@ -61,12 +62,14 @@ jobs: ffmpeg: 'false' image-type: 'extras' base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + grpc-base-image: "ubuntu:22.04" runs-on: 'arc-runner-set' makeflags: "--jobs=3 --output-sync=target" - build-type: 'sycl_f16' platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: 'sycl-f16-ffmpeg' ffmpeg: 'true' image-type: 'extras' @@ -85,6 +88,7 @@ jobs: platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} + grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} @@ -102,11 +106,12 @@ jobs: image-type: 'core' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" - build-type: 'sycl_f16' platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: 'sycl-f16-ffmpeg-core' ffmpeg: 'true' image-type: 'core' @@ -122,4 +127,4 @@ jobs: image-type: 'core' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" - makeflags: "--jobs=5 --output-sync=target" \ No newline at end of file + makeflags: "--jobs=4 --output-sync=target" \ No newline at end of file diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index d2607579b69f..255c1c656073 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -26,6 +26,7 @@ jobs: platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} + grpc-base-image: ${{ matrix.grpc-base-image }} aio: ${{ matrix.aio }} makeflags: ${{ matrix.makeflags }} latest-image: ${{ matrix.latest-image }} @@ -129,6 +130,7 @@ jobs: image-type: 'extras' aio: "-aio-gpu-hipblas" base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + grpc-base-image: "ubuntu:22.04" latest-image: 'latest-gpu-hipblas' latest-image-aio: 'latest-aio-gpu-hipblas' runs-on: 'arc-runner-set' @@ -140,12 +142,14 @@ jobs: ffmpeg: 'false' image-type: 'extras' base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + grpc-base-image: "ubuntu:22.04" runs-on: 'arc-runner-set' makeflags: "--jobs=3 --output-sync=target" - build-type: 'sycl_f16' platforms: 'linux/amd64' tag-latest: 'auto' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f16-ffmpeg' ffmpeg: 'true' image-type: 'extras' @@ -158,6 +162,7 @@ jobs: platforms: 'linux/amd64' tag-latest: 'auto' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f32-ffmpeg' ffmpeg: 'true' image-type: 'extras' @@ -171,6 +176,7 @@ jobs: platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f16-core' ffmpeg: 'false' image-type: 'core' @@ -180,6 +186,7 @@ jobs: platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f32-core' ffmpeg: 'false' image-type: 'core' @@ -189,6 +196,7 @@ jobs: platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f16-ffmpeg-core' ffmpeg: 'true' image-type: 'core' @@ -198,6 +206,7 @@ jobs: platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04" + grpc-base-image: "ubuntu:22.04" tag-suffix: '-sycl-f32-ffmpeg-core' ffmpeg: 'true' image-type: 'core' @@ -210,6 +219,7 @@ jobs: ffmpeg: 'true' image-type: 'core' base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + grpc-base-image: "ubuntu:22.04" runs-on: 'arc-runner-set' makeflags: "--jobs=3 --output-sync=target" - build-type: 'hipblas' @@ -219,6 +229,7 @@ jobs: ffmpeg: 'false' image-type: 'core' base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + grpc-base-image: "ubuntu:22.04" runs-on: 'arc-runner-set' makeflags: "--jobs=3 --output-sync=target" @@ -236,6 +247,7 @@ jobs: runs-on: ${{ matrix.runs-on }} aio: ${{ matrix.aio }} base-image: ${{ matrix.base-image }} + grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} latest-image: ${{ matrix.latest-image }} latest-image-aio: ${{ matrix.latest-image-aio }} @@ -258,7 +270,7 @@ jobs: aio: "-aio-cpu" latest-image: 'latest-cpu' latest-image-aio: 'latest-aio-cpu' - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" - build-type: 'cublas' cuda-major-version: "11" cuda-minor-version: "7" @@ -269,7 +281,7 @@ jobs: image-type: 'core' base-image: "ubuntu:22.04" runs-on: 'ubuntu-latest' - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "1" @@ -280,7 +292,7 @@ jobs: image-type: 'core' base-image: "ubuntu:22.04" runs-on: 'ubuntu-latest' - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" - build-type: 'cublas' cuda-major-version: "11" cuda-minor-version: "7" @@ -291,7 +303,7 @@ jobs: image-type: 'core' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "1" @@ -302,4 +314,4 @@ jobs: image-type: 'core' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" - makeflags: "--jobs=5 --output-sync=target" + makeflags: "--jobs=4 --output-sync=target" diff --git a/.github/workflows/image_build.yml b/.github/workflows/image_build.yml index b0684a4c66fc..b06100ff69da 100644 --- a/.github/workflows/image_build.yml +++ b/.github/workflows/image_build.yml @@ -6,6 +6,10 @@ on: inputs: base-image: description: 'Base image' + required: true + type: string + grpc-base-image: + description: 'GRPC Base image, must be a compatible image with base-image' required: false default: '' type: string @@ -57,7 +61,7 @@ on: makeflags: description: 'Make Flags' required: false - default: '--jobs=3 --output-sync=target' + default: '--jobs=4 --output-sync=target' type: string aio: description: 'AIO Image Name' @@ -201,15 +205,16 @@ jobs: uses: docker/build-push-action@v5 with: builder: ${{ steps.buildx.outputs.name }} + # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache. + # This means that even the MAKEFLAGS have to be an EXACT match. + # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch. build-args: | - IMAGE_TYPE=${{ inputs.image-type }} - BASE_IMAGE=${{ inputs.base-image }} - MAKEFLAGS=${{ inputs.makeflags }} + GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }} + MAKEFLAGS=--jobs=4 --output-sync=target GRPC_VERSION=v1.58.0 context: . file: ./Dockerfile cache-from: type=gha - cache-to: type=gha,ignore-error=true target: grpc platforms: ${{ inputs.platforms }} push: false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 156294b59da5..f50479e1db23 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -123,7 +123,9 @@ jobs: if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.18 with: + detached: true connect-timeout-seconds: 180 + limit-access-to-actor: true tests-aio-container: runs-on: ubuntu-latest @@ -176,7 +178,9 @@ jobs: if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.18 with: + detached: true connect-timeout-seconds: 180 + limit-access-to-actor: true tests-apple: runs-on: macOS-14 @@ -211,4 +215,6 @@ jobs: if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.18 with: - connect-timeout-seconds: 180 \ No newline at end of file + detached: true + connect-timeout-seconds: 180 + limit-access-to-actor: true diff --git a/Dockerfile b/Dockerfile index 397fbe22618d..805ac3a6d741 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ ARG IMAGE_TYPE=extras ARG BASE_IMAGE=ubuntu:22.04 +ARG GRPC_BASE_IMAGE=${BASE_IMAGE} # extras or core FROM ${BASE_IMAGE} as requirements-core @@ -104,7 +105,7 @@ RUN if [ ! -e /usr/bin/python ]; then \ ################################### ################################### -FROM ${BASE_IMAGE} as grpc +FROM ${GRPC_BASE_IMAGE} as grpc ARG MAKEFLAGS ARG GRPC_VERSION=v1.58.0 diff --git a/Makefile b/Makefile index f5b4dc2ac995..761c76d615b0 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ BINARY_NAME=local-ai # llama.cpp versions GOLLAMA_STABLE_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be -CPPLLAMA_VERSION?=7593639ce335e8d7f89aa9a54d616951f273af60 +CPPLLAMA_VERSION?=b8109bc0139f15a5b321909f47510b89dca47ffc # gpt4all version GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all @@ -16,7 +16,7 @@ RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6 # whisper.cpp version -WHISPER_CPP_VERSION?=a750868428868abd437e228ae5cab763ef3dc387 +WHISPER_CPP_VERSION?=b0c3cbf2e851cf232e432b590dcc514a689ec028 # bert.cpp version BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d @@ -179,20 +179,20 @@ endif all: help ## BERT embeddings -sources/go-bert: - git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp sources/go-bert - cd sources/go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1 +sources/go-bert.cpp: + git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp sources/go-bert.cpp + cd sources/go-bert.cpp && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1 -sources/go-bert/libgobert.a: sources/go-bert - $(MAKE) -C sources/go-bert libgobert.a +sources/go-bert.cpp/libgobert.a: sources/go-bert.cpp + $(MAKE) -C sources/go-bert.cpp libgobert.a -## go-llama-ggml -sources/go-llama-ggml: - git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama-ggml - cd sources/go-llama-ggml && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1 +## go-llama.cpp +sources/go-llama.cpp: + git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama.cpp + cd sources/go-llama.cpp && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1 -sources/go-llama-ggml/libbinding.a: sources/go-llama-ggml - $(MAKE) -C sources/go-llama-ggml BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a +sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp + $(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a ## go-piper sources/go-piper: @@ -211,12 +211,12 @@ sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a: sources/gpt4all $(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ libgpt4all.a ## RWKV -sources/go-rwkv: - git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv - cd sources/go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 +sources/go-rwkv.cpp: + git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv.cpp + cd sources/go-rwkv.cpp && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 -sources/go-rwkv/librwkv.a: sources/go-rwkv - cd sources/go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. +sources/go-rwkv.cpp/librwkv.a: sources/go-rwkv.cpp + cd sources/go-rwkv.cpp && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. ## stable diffusion sources/go-stable-diffusion: @@ -236,23 +236,24 @@ sources/go-tiny-dream/libtinydream.a: sources/go-tiny-dream ## whisper sources/whisper.cpp: - git clone https://github.com/ggerganov/whisper.cpp.git sources/whisper.cpp + git clone https://github.com/ggerganov/whisper.cpp sources/whisper.cpp cd sources/whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1 sources/whisper.cpp/libwhisper.a: sources/whisper.cpp cd sources/whisper.cpp && make libwhisper.a -get-sources: sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream +get-sources: sources/go-llama.cpp sources/gpt4all sources/go-piper sources/go-rwkv.cpp sources/whisper.cpp sources/go-bert.cpp sources/go-stable-diffusion sources/go-tiny-dream replace: - $(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(CURDIR)/sources/go-rwkv + $(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(CURDIR)/sources/go-rwkv.cpp $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(CURDIR)/sources/whisper.cpp/bindings/go - $(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(CURDIR)/sources/go-bert + $(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(CURDIR)/sources/go-bert.cpp $(GOCMD) mod edit -replace github.com/M0Rf30/go-tiny-dream=$(CURDIR)/sources/go-tiny-dream $(GOCMD) mod edit -replace github.com/mudler/go-piper=$(CURDIR)/sources/go-piper $(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(CURDIR)/sources/go-stable-diffusion $(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(CURDIR)/sources/gpt4all/gpt4all-bindings/golang + $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama.cpp dropreplace: $(GOCMD) mod edit -dropreplace github.com/donomii/go-rwkv.cpp @@ -271,12 +272,12 @@ prepare-sources: get-sources replace ## GENERIC rebuild: ## Rebuilds the project $(GOCMD) clean -cache - $(MAKE) -C sources/go-llama-ggml clean + $(MAKE) -C sources/go-llama.cpp clean $(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ clean - $(MAKE) -C sources/go-rwkv clean + $(MAKE) -C sources/go-rwkv.cpp clean $(MAKE) -C sources/whisper.cpp clean $(MAKE) -C sources/go-stable-diffusion clean - $(MAKE) -C sources/go-bert clean + $(MAKE) -C sources/go-bert.cpp clean $(MAKE) -C sources/go-piper clean $(MAKE) -C sources/go-tiny-dream clean $(MAKE) build @@ -301,9 +302,6 @@ clean-tests: rm -rf test-dir rm -rf core/http/backend-assets -halt-backends: ## Used to clean up stray backends sometimes left running when debugging manually - ps | grep 'backend-assets/grpc/' | awk '{print $$1}' | xargs -I {} kill -9 {} - ## Build: build: prepare backend-assets grpcs ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) @@ -368,13 +366,13 @@ run-e2e-image: run-e2e-aio: @echo 'Running e2e AIO tests' - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio test-e2e: @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e teardown-e2e: rm -rf $(TEST_DIR) || true @@ -382,15 +380,15 @@ teardown-e2e: test-gpt4all: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS) test-llama: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS) test-llama-gguf: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS) test-tts: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ @@ -601,8 +599,8 @@ backend-assets/gpt4all: sources/gpt4all sources/gpt4all/gpt4all-bindings/golang/ backend-assets/grpc: protogen-go replace mkdir -p backend-assets/grpc -backend-assets/grpc/bert-embeddings: sources/go-bert sources/go-bert/libgobert.a backend-assets/grpc - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-bert LIBRARY_PATH=$(CURDIR)/sources/go-bert \ +backend-assets/grpc/bert-embeddings: sources/go-bert.cpp sources/go-bert.cpp/libgobert.a backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-bert.cpp LIBRARY_PATH=$(CURDIR)/sources/go-bert.cpp \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./backend/go/llm/bert/ backend-assets/grpc/gpt4all: sources/gpt4all sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a backend-assets/gpt4all backend-assets/grpc @@ -644,20 +642,16 @@ ifeq ($(BUILD_TYPE),metal) cp backend/cpp/llama/llama.cpp/build/bin/default.metallib backend-assets/grpc/ endif -backend-assets/grpc/llama-ggml: sources/go-llama-ggml sources/go-llama-ggml/libbinding.a backend-assets/grpc - $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama-ggml - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama-ggml LIBRARY_PATH=$(CURDIR)/sources/go-llama-ggml \ +backend-assets/grpc/llama-ggml: sources/go-llama.cpp sources/go-llama.cpp/libbinding.a backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama.cpp LIBRARY_PATH=$(CURDIR)/sources/go-llama.cpp \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-ggml ./backend/go/llm/llama-ggml/ -# EXPERIMENTAL: -ifeq ($(BUILD_TYPE),metal) - cp $(CURDIR)/sources/go-llama-ggml/llama.cpp/ggml-metal.metal backend-assets/grpc/ -endif + backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/ -backend-assets/grpc/rwkv: sources/go-rwkv sources/go-rwkv/librwkv.a backend-assets/grpc - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-rwkv LIBRARY_PATH=$(CURDIR)/sources/go-rwkv \ +backend-assets/grpc/rwkv: sources/go-rwkv.cpp sources/go-rwkv.cpp/librwkv.a backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-rwkv.cpp LIBRARY_PATH=$(CURDIR)/sources/go-rwkv.cpp \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./backend/go/llm/rwkv backend-assets/grpc/stablediffusion: sources/go-stable-diffusion sources/go-stable-diffusion/libstablediffusion.a backend-assets/grpc @@ -720,4 +714,4 @@ docker-image-intel-xpu: .PHONY: swagger swagger: - swag init -g core/http/api.go --output swagger + swag init -g core/http/app.go --output swagger diff --git a/README.md b/README.md index 4c2f68b2c4fd..e28e3cb0862c 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ [Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) +- llama3: https://github.com/mudler/LocalAI/discussions/2076 - Parler-TTS: https://github.com/mudler/LocalAI/pull/2027 - Landing page: https://github.com/mudler/LocalAI/pull/1922 - Openvino support: https://github.com/mudler/LocalAI/pull/1892 diff --git a/aio/cpu/text-to-text.yaml b/aio/cpu/text-to-text.yaml index 6c4ec9e68ce0..cf18f659ae8d 100644 --- a/aio/cpu/text-to-text.yaml +++ b/aio/cpu/text-to-text.yaml @@ -6,14 +6,22 @@ parameters: template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} - {{- if .FunctionCall }}{{end}} - {{- if eq .RoleName "tool" }}{{end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + + {{- end }} {{- if .Content}} - {{.Content}} + {{.Content }} + {{- end }} + {{- if .FunctionCall}} + {{toJson .FunctionCall}} + {{- end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + {{- end }} - {{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }} - {{- if .FunctionCall }}{{end }} - {{- if eq .RoleName "tool" }}{{end }} <|im_end|> # https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling function: | diff --git a/aio/gpu-8g/text-to-text.yaml b/aio/gpu-8g/text-to-text.yaml index 8d5c84f772f0..0407bb2292dc 100644 --- a/aio/gpu-8g/text-to-text.yaml +++ b/aio/gpu-8g/text-to-text.yaml @@ -6,14 +6,22 @@ parameters: template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} - {{- if .FunctionCall }}{{end}} - {{- if eq .RoleName "tool" }}{{end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + + {{- end }} {{- if .Content}} - {{.Content}} + {{.Content }} + {{- end }} + {{- if .FunctionCall}} + {{toJson .FunctionCall}} + {{- end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + {{- end }} - {{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }} - {{- if .FunctionCall }}{{end }} - {{- if eq .RoleName "tool" }}{{end }} <|im_end|> # https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling function: | diff --git a/aio/intel/text-to-text.yaml b/aio/intel/text-to-text.yaml index a7cb5b4daf71..f5f93c14ff05 100644 --- a/aio/intel/text-to-text.yaml +++ b/aio/intel/text-to-text.yaml @@ -7,14 +7,22 @@ parameters: template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} - {{- if .FunctionCall }}{{end}} - {{- if eq .RoleName "tool" }}{{end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + + {{- end }} {{- if .Content}} - {{.Content}} + {{.Content }} + {{- end }} + {{- if .FunctionCall}} + {{toJson .FunctionCall}} + {{- end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + {{- end }} - {{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }} - {{- if .FunctionCall }}{{end }} - {{- if eq .RoleName "tool" }}{{end }} <|im_end|> # https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling function: | diff --git a/backend/backend.proto b/backend/backend.proto index 62e1a1a64448..ec01e4a7c7b9 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -177,6 +177,7 @@ message ModelOptions { bool EnforceEager = 52; int32 SwapSpace = 53; int32 MaxModelLen = 54; + int32 TensorParallelSize = 55; string MMProj = 41; diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index b38d5b9f5654..fdfaa974395d 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -21,7 +21,7 @@ func runCommand(command []string) (string, error) { // AudioToWav converts audio to wav for transcribe. // TODO: use https://github.com/mccoyst/ogg? func audioToWav(src, dst string) error { - command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} + command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} out, err := runCommand(command) if err != nil { return fmt.Errorf("error: %w out: %s", err, out) @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { - res := schema.TranscriptionResult{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { + res := schema.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index a9a62d249d2e..ac93be01195b 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py index c7f1cd75aebe..1b38a9567cb2 100755 --- a/backend/python/transformers/transformers_server.py +++ b/backend/python/transformers/transformers_server.py @@ -148,7 +148,8 @@ def LoadModel(self, request, context): else: device_map="CPU" self.model = OVModelForCausalLM.from_pretrained(model_name, - compile=True, + compile=True, + ov_config={"PERFORMANCE_HINT": "LATENCY"}, device=device_map) self.OV = True else: @@ -212,12 +213,25 @@ async def _predict(self, request, context, streaming=False): set_seed(request.Seed) if request.TopP == 0: request.TopP = 0.9 + + if request.TopK == 0: + request.TopK = 40 max_tokens = 200 if request.Tokens > 0: max_tokens = request.Tokens - inputs = self.tokenizer(request.Prompt, return_tensors="pt") + prompt = request.Prompt + if not request.Prompt and request.UseTokenizerTemplate and request.Messages: + prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True) + + eos_token_id = self.tokenizer.eos_token_id + if request.StopPrompts: + eos_token_id = [] + for word in request.StopPrompts: + eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word)) + + inputs = self.tokenizer(prompt, return_tensors="pt") if self.CUDA: inputs = inputs.to("cuda") if XPU and self.OV == False: @@ -235,7 +249,7 @@ async def _predict(self, request, context, streaming=False): top_k=request.TopK, do_sample=True, attention_mask=inputs["attention_mask"], - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=eos_token_id, pad_token_id=self.tokenizer.eos_token_id, streamer=streamer) thread=Thread(target=self.model.generate, kwargs=config) @@ -264,7 +278,7 @@ async def _predict(self, request, context, streaming=False): top_k=request.TopK, do_sample=True, attention_mask=inputs["attention_mask"], - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=eos_token_id, pad_token_id=self.tokenizer.eos_token_id) generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0] diff --git a/backend/python/vllm/backend_vllm.py b/backend/python/vllm/backend_vllm.py index ff0f0b2608f9..2d8b55db6190 100644 --- a/backend/python/vllm/backend_vllm.py +++ b/backend/python/vllm/backend_vllm.py @@ -95,6 +95,8 @@ async def LoadModel(self, request, context): engine_args.trust_remote_code = request.TrustRemoteCode if request.EnforceEager: engine_args.enforce_eager = request.EnforceEager + if request.TensorParallelSize: + engine_args.tensor_parallel_size = request.TensorParallelSize if request.SwapSpace != 0: engine_args.swap_space = request.SwapSpace if request.MaxModelLen != 0: diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 2c63dedc7cf0..03ff90b99208 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -2,100 +2,14 @@ package backend import ( "fmt" - "time" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" ) -type EmbeddingsBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} - -func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService { - return &EmbeddingsBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { - - resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - go func(request *schema.OpenAIRequest) { - if request.Model == "" { - request.Model = model.StableDiffusionBackend - } - - bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - items := []schema.Item{} - - for i, s := range bc.InputToken { - // get the model function to call for the result - embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - embeddings, err := embedFn() - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range bc.InputStrings { - // get the model function to call for the result - embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - embeddings, err := embedFn() - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - } - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} - close(resultChannel) - }(request) - return resultChannel -} - -func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { modelFile := backendConfig.Model grpcOpts := gRPCModelOpts(backendConfig) diff --git a/core/backend/image.go b/core/backend/image.go index affb3bb33863..b0cffb0b80e7 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -1,252 +1,18 @@ package backend import ( - "bufio" - "encoding/base64" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" - "github.com/rs/zerolog/log" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" ) -type ImageGenerationBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig - BaseUrlForGeneratedImages string -} - -func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService { - return &ImageGenerationBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { - resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - go func(request *schema.OpenAIRequest) { - bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - src := "" - if request.File != "" { - - var fileData []byte - // check if input.File is an URL, if so download it and save it - // to a temporary file - if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") { - out, err := downloadFile(request.File) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)} - close(resultChannel) - return - } - defer os.RemoveAll(out) - - fileData, err = os.ReadFile(out) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)} - close(resultChannel) - return - } - - } else { - // base 64 decode the file and write it somewhere - // that we will cleanup - fileData, err = base64.StdEncoding.DecodeString(request.File) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - } - - // Create a temporary file - outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64") - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - // write the base64 result - writer := bufio.NewWriter(outputFile) - _, err = writer.Write(fileData) - if err != nil { - outputFile.Close() - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - outputFile.Close() - src = outputFile.Name() - defer os.RemoveAll(src) - } - - log.Debug().Msgf("Parameter Config: %+v", bc) - - switch bc.Backend { - case "stablediffusion": - bc.Backend = model.StableDiffusionBackend - case "tinydream": - bc.Backend = model.TinyDreamBackend - case "": - bc.Backend = model.StableDiffusionBackend - if bc.Model == "" { - bc.Model = "stablediffusion_assets" // TODO: check? - } - } - - sizeParts := strings.Split(request.Size, "x") - if len(sizeParts) != 2 { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - - b64JSON := false - if request.ResponseFormat.Type == "b64_json" { - b64JSON = true - } - // src and clip_skip - var result []schema.Item - for _, i := range bc.PromptStrings { - n := request.N - if request.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := bc.Step - if step == 0 { - step = 15 - } - - if request.Mode != 0 { - mode = request.Mode - } - - if request.Step != 0 { - step = request.Step - } - - tempDir := "" - if !b64JSON { - tempDir = igbs.appConfig.ImageDir - } - // Create a temporary file - outputFile, err := os.CreateTemp(tempDir, "b64") - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - if request.Seed == nil { - zVal := 0 // Idiomatic way to do this? Actually needed? - request.Seed = &zVal - } - - fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - if err := fn(); err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - item := &schema.Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = igbs.BaseUrlForGeneratedImages + base - } - - result = append(result, *item) - } - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Data: result, - } - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} - close(resultChannel) - }(request) - return resultChannel -} - -func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { threads := backendConfig.Threads if *threads == 0 && appConfig.Threads != 0 { threads = &appConfig.Threads } - gRPCOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(backendConfig.Backend), model.WithAssetDir(appConfig.AssetsDestination), @@ -284,24 +50,3 @@ func imageGeneration(height, width, mode, step, seed int, positive_prompt, negat return fn, nil } - -// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change. -func downloadFile(url string) (string, error) { - // Get the data - resp, err := http.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // Create the file - out, err := os.CreateTemp("", "image") - if err != nil { - return "", err - } - defer out.Close() - - // Write the body to file - _, err = io.Copy(out, resp.Body) - return out.Name(), err -} diff --git a/core/backend/llm.go b/core/backend/llm.go index 75766d78d437..a4d1e5f35e42 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -11,117 +11,75 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/rs/zerolog/log" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -type LLMRequest struct { - Id int // TODO Remove if not used. - Text string - Images []string - RawMessages []schema.Message - // TODO: Other Modalities? -} - -type TokenUsage struct { - Prompt int - Completion int -} - type LLMResponse struct { - Request *LLMRequest Response string // should this be []byte? Usage TokenUsage } -// TODO: Does this belong here or in core/services/openai.go? -type LLMResponseBundle struct { - Request *schema.OpenAIRequest - Response []schema.Choice - Usage TokenUsage -} - -type LLMBackendService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig - ftMutex sync.Mutex - cutstrings map[string]*regexp.Regexp -} - -func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService { - return &LLMBackendService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - ftMutex: sync.Mutex{}, - cutstrings: make(map[string]*regexp.Regexp), - } +type TokenUsage struct { + Prompt int + Completion int } -// TODO: Should ctx param be removed and replaced with hardcoded req.Context? -func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) ( - resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) { - - threads := bc.Threads - if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 { - threads = &llmbs.appConfig.Threads +func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { + modelFile := c.Model + threads := c.Threads + if *threads == 0 && o.Threads != 0 { + threads = &o.Threads } - - grpcOpts := gRPCModelOpts(bc) + grpcOpts := gRPCModelOpts(c) var inferenceModel grpc.Backend + var err error - opts := modelOpts(bc, llmbs.appConfig, []model.Option{ + opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(llmbs.appConfig.AssetsDestination), - model.WithModel(bc.Model), - model.WithContext(llmbs.appConfig.Context), + model.WithAssetDir(o.AssetsDestination), + model.WithModel(modelFile), + model.WithContext(o.Context), }) - if bc.Backend != "" { - opts = append(opts, model.WithBackendString(bc.Backend)) + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) } - // Check if bc.Model exists, if it doesn't try to load it from the gallery - if llmbs.appConfig.AutoloadGalleries { // experimental - if _, err := os.Stat(bc.Model); os.IsNotExist(err) { + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) if err != nil { - return nil, nil, err + return nil, err } } } - if bc.Backend == "" { - log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model) - inferenceModel, err = llmbs.ml.GreedyLoader(opts...) + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) } else { - inferenceModel, err = llmbs.ml.BackendLoader(opts...) + inferenceModel, err = loader.BackendLoader(opts...) } if err != nil { - log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend") - return + return nil, err } - grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath) - grpcPredOpts.Prompt = req.Text - grpcPredOpts.Images = req.Images - - if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" { - grpcPredOpts.UseTokenizerTemplate = true - protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages)) - for i, message := range req.RawMessages { + var protoMessages []*proto.Message + // if we are using the tokenizer template, we need to convert the messages to proto messages + // unless the prompt has already been tokenized (non-chat endpoints + functions) + if c.TemplateConfig.UseTokenizerTemplate && s == "" { + protoMessages = make([]*proto.Message, len(messages), len(messages)) + for i, message := range messages { protoMessages[i] = &proto.Message{ Role: message.Role, } @@ -129,32 +87,47 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, case string: protoMessages[i].Content = ct default: - err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) - return + return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct) } } } - tokenUsage := TokenUsage{} + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (LLMResponse, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + opts.Messages = protoMessages + opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate + opts.Images = images - promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts) - if pErr == nil && promptInfo.Length > 0 { - tokenUsage.Prompt = int(promptInfo.Length) - } + tokenUsage := TokenUsage{} - rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse]) - // TODO this next line is the biggest argument for taking named return values _back_ out!!! - var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse] + // check the per-model feature flag for usage, since tokenCallback may have a cost. + // Defaults to off as for now it is still experimental + if c.FeatureFlag.Enabled("usage") { + userTokenCallback := tokenCallback + if userTokenCallback == nil { + userTokenCallback = func(token string, usage TokenUsage) bool { + return true + } + } - if enableTokenChannel { - rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse]) + promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } - // TODO Needs better name - ss := "" + tokenCallback = func(token string, usage TokenUsage) bool { + tokenUsage.Completion++ + return userTokenCallback(token, tokenUsage) + } + } + + if tokenCallback != nil { + ss := "" - go func() { var partialRune []byte - err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) { + err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { partialRune = append(partialRune, chars...) for len(partialRune) > 0 { @@ -164,126 +137,54 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, break } - tokenUsage.Completion++ - rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: string(r), - Usage: tokenUsage, - }} - + tokenCallback(string(r), tokenUsage) ss += string(r) partialRune = partialRune[size:] } }) - close(rawTokenChannel) + return LLMResponse{ + Response: ss, + Usage: tokenUsage, + }, err + } else { + // TODO: Is the chicken bit the only way to get here? is that acceptable? + reply, err := inferenceModel.Predict(ctx, opts) if err != nil { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} - } else { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: ss, - Usage: tokenUsage, - }} + return LLMResponse{}, err } - close(rawResultChannel) - }() - } else { - go func() { - reply, err := inferenceModel.Predict(ctx, grpcPredOpts) if tokenUsage.Prompt == 0 { tokenUsage.Prompt = int(reply.PromptTokens) } if tokenUsage.Completion == 0 { tokenUsage.Completion = int(reply.Tokens) } - if err != nil { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} - close(rawResultChannel) - } else { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: string(reply.Message), - Usage: tokenUsage, - }} - close(rawResultChannel) - } - }() - } - - resultChannel = rawResultChannel - tokenChannel = rawTokenChannel - return -} - -// TODO: Should predInput be a seperate param still, or should this fn handle extracting it from request?? -func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig, - mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) ( - // Returns: - resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) { - - rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle]) - resultChannel = rawChannel - - if request.N == 0 { // number of completions to return - request.N = 1 - } - images := []string{} - for _, m := range request.Messages { - images = append(images, m.StringImages...) - } - - for i := 0; i < request.N; i++ { - - individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{ - Text: predInput, - Images: images, - RawMessages: request.Messages, - }, bc, enableTokenChannels) - if infErr != nil { - err = infErr // Avoids complaints about redeclaring err but looks dumb - return + return LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }, err } - completionChannels = append(completionChannels, individualResultChannel) - tokenChannels = append(tokenChannels, tokenChannel) } - go func() { - initialBundle := LLMResponseBundle{ - Request: request, - Response: []schema.Choice{}, - Usage: TokenUsage{}, - } - - wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] { - if iv.Error != nil { - ov.Error = iv.Error - // TODO: Decide if we should wipe partials or not? - return ov - } - ov.Value.Usage.Prompt += iv.Value.Usage.Prompt - ov.Value.Usage.Completion += iv.Value.Usage.Completion - - ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value)) - return ov - }, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true) - wg.Wait() - - }() - - return + return fn, nil } -func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string { +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config config.BackendConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } for _, c := range config.Cutstrings { - llmbs.ftMutex.Lock() - reg, ok := llmbs.cutstrings[c] + mu.Lock() + reg, ok := cutstrings[c] if !ok { - llmbs.cutstrings[c] = regexp.MustCompile(c) - reg = llmbs.cutstrings[c] + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] } - llmbs.ftMutex.Unlock() + mu.Unlock() prediction = reg.ReplaceAllString(prediction, "") } diff --git a/core/backend/options.go b/core/backend/options.go index 0b4e56db5d14..60cb01ff24b9 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -10,7 +10,7 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { if so.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } @@ -19,12 +19,12 @@ func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []mo opts = append(opts, model.EnableParallelRequests) } - if bc.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(bc.GRPC.Attempts)) + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) } - if bc.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(bc.GRPC.AttemptsSleepTime)) + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) } for k, v := range so.ExternalGRPCBackends { @@ -34,7 +34,7 @@ func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []mo return opts } -func getSeed(c *config.BackendConfig) int32 { +func getSeed(c config.BackendConfig) int32 { seed := int32(*c.Seed) if seed == config.RAND_SEED { seed = rand.Int31() @@ -43,7 +43,7 @@ func getSeed(c *config.BackendConfig) int32 { return seed } -func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { +func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -74,6 +74,7 @@ func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { EnforceEager: c.EnforceEager, SwapSpace: int32(c.SwapSpace), MaxModelLen: int32(c.MaxModelLen), + TensorParallelSize: int32(c.TensorParallelSize), MMProj: c.MMProj, YarnExtFactor: c.YarnExtFactor, YarnAttnFactor: c.YarnAttnFactor, @@ -104,47 +105,47 @@ func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { } } -func gRPCPredictOpts(bc *config.BackendConfig, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" - if bc.PromptCachePath != "" { - p := filepath.Join(modelPath, bc.PromptCachePath) + if c.PromptCachePath != "" { + p := filepath.Join(modelPath, c.PromptCachePath) os.MkdirAll(filepath.Dir(p), 0755) promptCachePath = p } return &pb.PredictOptions{ - Temperature: float32(*bc.Temperature), - TopP: float32(*bc.TopP), - NDraft: bc.NDraft, - TopK: int32(*bc.TopK), - Tokens: int32(*bc.Maxtokens), - Threads: int32(*bc.Threads), - PromptCacheAll: bc.PromptCacheAll, - PromptCacheRO: bc.PromptCacheRO, + Temperature: float32(*c.Temperature), + TopP: float32(*c.TopP), + NDraft: c.NDraft, + TopK: int32(*c.TopK), + Tokens: int32(*c.Maxtokens), + Threads: int32(*c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, PromptCachePath: promptCachePath, - F16KV: *bc.F16, - DebugMode: *bc.Debug, - Grammar: bc.Grammar, - NegativePromptScale: bc.NegativePromptScale, - RopeFreqBase: bc.RopeFreqBase, - RopeFreqScale: bc.RopeFreqScale, - NegativePrompt: bc.NegativePrompt, - Mirostat: int32(*bc.LLMConfig.Mirostat), - MirostatETA: float32(*bc.LLMConfig.MirostatETA), - MirostatTAU: float32(*bc.LLMConfig.MirostatTAU), - Debug: *bc.Debug, - StopPrompts: bc.StopWords, - Repeat: int32(bc.RepeatPenalty), - NKeep: int32(bc.Keep), - Batch: int32(bc.Batch), - IgnoreEOS: bc.IgnoreEOS, - Seed: getSeed(bc), - FrequencyPenalty: float32(bc.FrequencyPenalty), - MLock: *bc.MMlock, - MMap: *bc.MMap, - MainGPU: bc.MainGPU, - TensorSplit: bc.TensorSplit, - TailFreeSamplingZ: float32(*bc.TFZ), - TypicalP: float32(*bc.TypicalP), + F16KV: *c.F16, + DebugMode: *c.Debug, + Grammar: c.Grammar, + NegativePromptScale: c.NegativePromptScale, + RopeFreqBase: c.RopeFreqBase, + RopeFreqScale: c.RopeFreqScale, + NegativePrompt: c.NegativePrompt, + Mirostat: int32(*c.LLMConfig.Mirostat), + MirostatETA: float32(*c.LLMConfig.MirostatETA), + MirostatTAU: float32(*c.LLMConfig.MirostatTAU), + Debug: *c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: getSeed(c), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: *c.MMlock, + MMap: *c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(*c.TFZ), + TypicalP: float32(*c.TypicalP), } } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 6761c2acef6d..4c3859dfed02 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -7,48 +7,11 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" ) -type TranscriptionBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} - -func NewTranscriptionBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TranscriptionBackendService { - return &TranscriptionBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (tbs *TranscriptionBackendService) Transcribe(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.TranscriptionResult] { - responseChannel := make(chan concurrency.ErrorOr[*schema.TranscriptionResult]) - go func(request *schema.OpenAIRequest) { - bc, request, err := tbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, tbs.appConfig) - if err != nil { - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: fmt.Errorf("failed reading parameters from request:%w", err)} - close(responseChannel) - return - } - - tr, err := modelTranscription(request.File, request.Language, tbs.ml, bc, tbs.appConfig) - if err != nil { - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: err} - close(responseChannel) - return - } - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Value: tr} - close(responseChannel) - }(request) - return responseChannel -} - -func modelTranscription(audio, language string, ml *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), diff --git a/core/backend/tts.go b/core/backend/tts.go index d1fa270dd5ec..f97b6202ef71 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -7,60 +7,29 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -type TextToSpeechBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} - -func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService { - return &TextToSpeechBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) <-chan concurrency.ErrorOr[*string] { - responseChannel := make(chan concurrency.ErrorOr[*string]) - go func(request *schema.TTSRequest) { - cfg, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath, - config.LoadOptionDebug(ttsbs.appConfig.Debug), - config.LoadOptionThreads(ttsbs.appConfig.Threads), - config.LoadOptionContextSize(ttsbs.appConfig.ContextSize), - config.LoadOptionF16(ttsbs.appConfig.F16), - ) - if err != nil { - responseChannel <- concurrency.ErrorOr[*string]{Error: err} - close(responseChannel) - return - } +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext - if request.Backend != "" { - cfg.Backend = request.Backend + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName } - outFile, _, err := modelTTS(cfg.Backend, request.Input, cfg.Model, request.Voice, ttsbs.ml, ttsbs.appConfig, cfg) - if err != nil { - responseChannel <- concurrency.ErrorOr[*string]{Error: err} - close(responseChannel) - return - } - responseChannel <- concurrency.ErrorOr[*string]{Value: &outFile} - close(responseChannel) - }(request) - return responseChannel + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } } -func modelTTS(backend, text, modelFile string, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig *config.BackendConfig) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend @@ -68,7 +37,7 @@ func modelTTS(backend, text, modelFile string, voice string, loader *model.Model grpcOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(&config.BackendConfig{}, appConfig, []model.Option{ + opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(appConfig.Context), @@ -118,19 +87,3 @@ func modelTTS(backend, text, modelFile string, voice string, loader *model.Model return filePath, res, err } - -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} diff --git a/core/cli/cli.go b/core/cli/cli.go index 5e757f648611..2f2dcd8ba1bb 100644 --- a/core/cli/cli.go +++ b/core/cli/cli.go @@ -4,7 +4,7 @@ import "embed" type Context struct { Debug bool `env:"LOCALAI_DEBUG,DEBUG" default:"false" hidden:"" help:"DEPRECATED, use --log-level=debug instead. Enable debug logging"` - LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug" help:"Set the level of logs to output [${enum}]"` + LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug,trace" help:"Set the level of logs to output [${enum}]"` // This field is not a command line argument/flag, the struct tag excludes it from the parsed CLI BackendAssets embed.FS `kong:"-"` diff --git a/core/cli/models.go b/core/cli/models.go index 62ef366bf34d..6615e21d342b 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -25,7 +25,7 @@ type ModelsInstall struct { } type ModelsCMD struct { - List ModelsList `cmd:"" help:"List the models avaiable in your galleries" default:"withargs"` + List ModelsList `cmd:"" help:"List the models available in your galleries" default:"withargs"` Install ModelsInstall `cmd:"" help:"Install a model from the gallery"` } diff --git a/core/cli/run.go b/core/cli/run.go index cafc0b549e8c..42185a28cd7c 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -2,30 +2,31 @@ package cli import ( "fmt" - "os" "strings" "time" "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/http" "github.com/go-skynet/LocalAI/core/startup" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) type RunCMD struct { ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` - ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` - BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` - ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"` - AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"` - UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"` - ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"` - LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"` + ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` + BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` + ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"` + AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"` + UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"` + ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"` + LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"` + LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"` // The alias on this option is there to preserve functionality with the old `--config-file` parameter ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"` - Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models"` + Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models"` RemoteLibrary string `env:"LOCALAI_REMOTE_LIBRARY,REMOTE_LIBRARY" default:"${remoteLibraryURL}" help:"A LocalAI remote library URL" group:"models"` PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"` @@ -60,15 +61,16 @@ func (r *RunCMD) Run(ctx *Context) error { config.WithYAMLConfigPreload(r.PreloadModelsConfig), config.WithModelPath(r.ModelsPath), config.WithContextSize(r.ContextSize), - config.WithDebug(*ctx.LogLevel == "debug"), + config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel), config.WithImageDir(r.ImagePath), config.WithAudioDir(r.AudioPath), config.WithUploadDir(r.UploadPath), config.WithConfigsDir(r.ConfigPath), + config.WithDynamicConfigDir(r.LocalaiConfigDir), + config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval), config.WithF16(r.F16), config.WithStringGalleries(r.Galleries), config.WithModelLibraryURL(r.RemoteLibrary), - config.WithDisableMessage(false), config.WithCors(r.CORS), config.WithCorsAllowOrigins(r.CORSAllowOrigins), config.WithThreads(r.Threads), @@ -124,28 +126,16 @@ func (r *RunCMD) Run(ctx *Context) error { } if r.PreloadBackendOnly { - _, err := startup.Startup(opts...) + _, _, _, err := startup.Startup(opts...) return err } - application, err := startup.Startup(opts...) - + cl, ml, options, err := startup.Startup(opts...) if err != nil { return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) } - // Watch the configuration directory - // If the directory does not exist, we don't watch it - if _, err := os.Stat(r.LocalaiConfigDir); err == nil { - closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, application.ApplicationConfig) - defer closeConfigWatcherFn() - - if err != nil { - return fmt.Errorf("failed while watching configuration directory %s", r.LocalaiConfigDir) - } - } - - appHTTP, err := http.App(application) + appHTTP, err := http.App(cl, ml, options) if err != nil { log.Error().Err(err).Msg("error during HTTP App construction") return err diff --git a/core/cli/transcript.go b/core/cli/transcript.go index f14a1a8771ff..9f36a77cebe7 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -7,7 +7,6 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -44,21 +43,11 @@ func (t *TranscriptCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - tbs := backend.NewTranscriptionBackendService(ml, cl, opts) - - resultChannel := tbs.Transcribe(&schema.OpenAIRequest{ - PredictionOptions: schema.PredictionOptions{ - Language: t.Language, - }, - File: t.Filename, - }) - - r := <-resultChannel - - if r.Error != nil { - return r.Error + tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts) + if err != nil { + return err } - for _, segment := range r.Value.Segments { + for _, segment := range tr.Segments { fmt.Println(segment.Start.String(), "-", segment.Text) } return nil diff --git a/core/cli/tts.go b/core/cli/tts.go index c7758c483f28..1d8fd3a39eca 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -9,7 +9,6 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -43,29 +42,20 @@ func (t *TTSCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts) + options := config.BackendConfig{} + options.SetDefaults() - request := &schema.TTSRequest{ - Model: t.Model, - Input: text, - Backend: t.Backend, - Voice: t.Voice, - } - - resultsChannel := ttsbs.TextToAudioFile(request) - - rawResult := <-resultsChannel - - if rawResult.Error != nil { - return rawResult.Error + filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options) + if err != nil { + return err } if outputFile != "" { - if err := os.Rename(*rawResult.Value, outputFile); err != nil { + if err := os.Rename(filePath, outputFile); err != nil { return err } - fmt.Printf("Generated file %q\n", outputFile) + fmt.Printf("Generate file %s\n", outputFile) } else { - fmt.Printf("Generated file %q\n", *rawResult.Value) + fmt.Printf("Generate file %s\n", filePath) } return nil } diff --git a/core/config/application_config.go b/core/config/application_config.go index 9525553a6868..2d733c1eb02f 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -17,11 +17,13 @@ type ApplicationConfig struct { UploadLimitMB, Threads, ContextSize int DisableWelcomePage bool F16 bool - Debug, DisableMessage bool + Debug bool ImageDir string AudioDir string UploadDir string ConfigsDir string + DynamicConfigsDir string + DynamicConfigsDirPollInterval time.Duration CORS bool PreloadJSONModels string PreloadModelsFromPath string @@ -55,12 +57,11 @@ type AppOption func(*ApplicationConfig) func NewApplicationConfig(o ...AppOption) *ApplicationConfig { opt := &ApplicationConfig{ - Context: context.Background(), - UploadLimitMB: 15, - Threads: 1, - ContextSize: 512, - Debug: true, - DisableMessage: true, + Context: context.Background(), + UploadLimitMB: 15, + Threads: 1, + ContextSize: 512, + Debug: true, } for _, oo := range o { oo(opt) @@ -234,12 +235,6 @@ func WithDebug(debug bool) AppOption { } } -func WithDisableMessage(disableMessage bool) AppOption { - return func(o *ApplicationConfig) { - o.DisableMessage = disableMessage - } -} - func WithAudioDir(audioDir string) AppOption { return func(o *ApplicationConfig) { o.AudioDir = audioDir @@ -264,6 +259,18 @@ func WithConfigsDir(configsDir string) AppOption { } } +func WithDynamicConfigDir(dynamicConfigsDir string) AppOption { + return func(o *ApplicationConfig) { + o.DynamicConfigsDir = dynamicConfigsDir + } +} + +func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption { + return func(o *ApplicationConfig) { + o.DynamicConfigsDirPollInterval = interval + } +} + func WithApiKeys(apiKeys []string) AppOption { return func(o *ApplicationConfig) { o.ApiKeys = apiKeys diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 47e4829d8a52..dfc216dc5b75 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -1,7 +1,23 @@ package config import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/functions" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" + + "github.com/charmbracelet/glamour" ) const ( @@ -24,7 +40,7 @@ type BackendConfig struct { InputToken [][]int `yaml:"-"` functionCallString, functionCallNameString string `yaml:"-"` - FunctionsConfig Functions `yaml:"function"` + FunctionsConfig functions.FunctionsConfig `yaml:"function"` FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early. // LLM configs (GPT4ALL, Llama.cpp, ...) @@ -124,6 +140,7 @@ type LLMConfig struct { EnforceEager bool `yaml:"enforce_eager"` // vLLM SwapSpace int `yaml:"swap_space"` // vLLM MaxModelLen int `yaml:"max_model_len"` // vLLM + TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM MMProj string `yaml:"mmproj"` RopeScaling string `yaml:"rope_scaling"` @@ -142,13 +159,6 @@ type AutoGPTQ struct { UseFastTokenizer bool `yaml:"use_fast_tokenizer"` } -type Functions struct { - DisableNoAction bool `yaml:"disable_no_action"` - NoActionFunctionName string `yaml:"no_action_function_name"` - NoActionDescriptionName string `yaml:"no_action_description_name"` - ParallelCalls bool `yaml:"parallel_calls"` -} - type TemplateConfig struct { Chat string `yaml:"chat"` ChatMessage string `yaml:"chat_message"` @@ -184,7 +194,7 @@ func (c *BackendConfig) FunctionToCall() string { } func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { - lo := &ConfigLoaderOptions{} + lo := &LoadOptions{} lo.Apply(opts...) ctx := lo.ctxSize @@ -195,15 +205,15 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { defaultTopP := 0.95 defaultTopK := 40 defaultTemp := 0.9 - defaultMaxTokens := 2048 defaultMirostat := 2 defaultMirostatTAU := 5.0 defaultMirostatETA := 0.1 defaultTypicalP := 1.0 defaultTFZ := 1.0 + defaultZero := 0 // Try to offload all GPU layers (if GPU is found) - defaultNGPULayers := 99999999 + defaultHigh := 99999999 trueV := true falseV := false @@ -244,7 +254,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { } if cfg.Maxtokens == nil { - cfg.Maxtokens = &defaultMaxTokens + cfg.Maxtokens = &defaultZero } if cfg.Mirostat == nil { @@ -259,7 +269,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.MirostatTAU = &defaultMirostatTAU } if cfg.NGPULayers == nil { - cfg.NGPULayers = &defaultNGPULayers + cfg.NGPULayers = &defaultHigh } if cfg.LowVRAM == nil { @@ -297,3 +307,287 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.Debug = &trueV } } + +////// Config Loader //////// + +type BackendConfigLoader struct { + configs map[string]BackendConfig + sync.Mutex +} + +type LoadOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*LoadOptions) + +func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + +// Load a config file for a model +func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + // Load a config file if present after the model name + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, + } + + cfgExisting, exists := cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := cl.LoadBackendConfig( + modelConfig, opts..., + ); err != nil { + return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } + } + } + + cfg.SetDefaults(opts...) + + return cfg, nil +} + +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), + } +} +func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + c := &[]*BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + for _, cc := range *c { + cc.SetDefaults(opts...) + } + + return *c, nil +} + +func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + + c := &BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + c.SetDefaults(opts...) + return c, nil +} + +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadBackendConfigFile(file, opts...) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { + cl.Lock() + defer cl.Unlock() + c, err := ReadBackendConfig(file, opts...) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cl.configs[c.Name] = *c + return nil +} + +func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + cl.Lock() + defer cl.Unlock() + v, exists := cl.configs[m] + return v, exists +} + +func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + cl.Lock() + defer cl.Unlock() + var res []BackendConfig + for _, v := range cl.configs { + res = append(res, v) + } + + sort.SliceStable(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + + return res +} + +func (cl *BackendConfigLoader) ListBackendConfigs() []string { + cl.Lock() + defer cl.Unlock() + var res []string + for k := range cl.configs { + res = append(res, k) + } + return res +} + +// Preload prepare models if they are not local but url or huggingface repositories +func (cl *BackendConfigLoader) Preload(modelPath string) error { + cl.Lock() + defer cl.Unlock() + + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + } + + log.Info().Msgf("Preloading models from %s", modelPath) + + renderMode := "dark" + if os.Getenv("COLOR") != "" { + renderMode = os.Getenv("COLOR") + } + + glamText := func(t string) { + out, err := glamour.Render(t, renderMode) + if err == nil && os.Getenv("NO_COLOR") == "" { + fmt.Println(out) + } else { + fmt.Println(t) + } + } + + for i, config := range cl.configs { + + // Download files and verify their SHA + for _, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) + + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) + + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + return err + } + } + + modelURL := config.PredictionOptions.Model + modelURL = downloader.ConvertURL(modelURL) + + if downloader.LooksLikeURL(modelURL) { + // md5 of model name + md5Name := utils.MD5(modelURL) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + cl.configs[i] = *c + } + if cl.configs[i].Name != "" { + glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) + } + if cl.configs[i].Description != "" { + //glamText("**Description**") + glamText(cl.configs[i].Description) + } + if cl.configs[i].Usage != "" { + //glamText("**Usage**") + glamText(cl.configs[i].Usage) + } + } + return nil +} + +// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// (non-recursive) +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + continue + } + c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go deleted file mode 100644 index 62dfc1e03150..000000000000 --- a/core/config/backend_config_loader.go +++ /dev/null @@ -1,509 +0,0 @@ -package config - -import ( - "encoding/json" - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "sort" - "strings" - "sync" - - "github.com/charmbracelet/glamour" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/downloader" - "github.com/go-skynet/LocalAI/pkg/grammar" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" -) - -type BackendConfigLoader struct { - configs map[string]BackendConfig - sync.Mutex -} - -type ConfigLoaderOptions struct { - debug bool - threads, ctxSize int - f16 bool -} - -func LoadOptionDebug(debug bool) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.debug = debug - } -} - -func LoadOptionThreads(threads int) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.threads = threads - } -} - -func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.ctxSize = ctxSize - } -} - -func LoadOptionF16(f16 bool) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.f16 = f16 - } -} - -type ConfigLoaderOption func(*ConfigLoaderOptions) - -func (lo *ConfigLoaderOptions) Apply(options ...ConfigLoaderOption) { - for _, l := range options { - l(lo) - } -} - -func NewBackendConfigLoader() *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), - } -} - -func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - c, err := readBackendConfig(file, opts...) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - bcl.configs[c.Name] = *c - return nil -} - -func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { - bcl.Lock() - defer bcl.Unlock() - v, exists := bcl.configs[m] - return v, exists -} - -func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { - bcl.Lock() - defer bcl.Unlock() - var res []BackendConfig - for _, v := range bcl.configs { - res = append(res, v) - } - sort.SliceStable(res, func(i, j int) bool { - return res[i].Name < res[j].Name - }) - return res -} - -func (bcl *BackendConfigLoader) ListBackendConfigs() []string { - bcl.Lock() - defer bcl.Unlock() - var res []string - for k := range bcl.configs { - res = append(res, k) - } - return res -} - -// Preload prepare models if they are not local but url or huggingface repositories -func (bcl *BackendConfigLoader) Preload(modelPath string) error { - bcl.Lock() - defer bcl.Unlock() - - status := func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - } - - log.Info().Msgf("Preloading models from %s", modelPath) - - renderMode := "dark" - if os.Getenv("COLOR") != "" { - renderMode = os.Getenv("COLOR") - } - - glamText := func(t string) { - out, err := glamour.Render(t, renderMode) - if err == nil && os.Getenv("NO_COLOR") == "" { - fmt.Println(out) - } else { - fmt.Println(t) - } - } - - for i, config := range bcl.configs { - - // Download files and verify their SHA - for _, file := range config.DownloadFiles { - log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - - if err := utils.VerifyPath(file.Filename, modelPath); err != nil { - return err - } - // Create file path - filePath := filepath.Join(modelPath, file.Filename) - - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { - return err - } - } - - modelURL := config.PredictionOptions.Model - modelURL = downloader.ConvertURL(modelURL) - - if downloader.LooksLikeURL(modelURL) { - // md5 of model name - md5Name := utils.MD5(modelURL) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) - if err != nil { - return err - } - } - - cc := bcl.configs[i] - c := &cc - c.PredictionOptions.Model = md5Name - bcl.configs[i] = *c - } - if bcl.configs[i].Name != "" { - glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name)) - } - if bcl.configs[i].Description != "" { - //glamText("**Description**") - glamText(bcl.configs[i].Description) - } - if bcl.configs[i].Usage != "" { - //glamText("**Usage**") - glamText(bcl.configs[i].Usage) - } - } - return nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { - continue - } - c, err := readBackendConfig(filepath.Join(path, file.Name()), opts...) - if err == nil { - bcl.configs[c.Name] = *c - } - } - - return nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - c, err := readBackendConfigFile(file, opts...) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - bcl.configs[cc.Name] = *cc - } - return nil -} - -////////// - -// Load a config file for a model -func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName string, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - - // Load a config file if present after the model name - cfg := &BackendConfig{ - PredictionOptions: schema.PredictionOptions{ - Model: modelName, - }, - } - - cfgExisting, exists := bcl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } else { - // Load a config file if present after the model name - modelConfig := filepath.Join(modelPath, modelName+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := bcl.LoadBackendConfig(modelConfig); err != nil { - return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = bcl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } - } - } - - cfg.SetDefaults(opts...) - return cfg, nil -} - -func readBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { - c := &[]*BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - for _, cc := range *c { - cc.SetDefaults(opts...) - } - - return *c, nil -} - -func readBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - c := &BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - c.SetDefaults(opts...) - return c, nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigForModelAndOpenAIRequest(modelFile string, input *schema.OpenAIRequest, appConfig *ApplicationConfig) (*BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := bcl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - LoadOptionContextSize(appConfig.ContextSize), - LoadOptionDebug(appConfig.Debug), - LoadOptionF16(appConfig.F16), - LoadOptionThreads(appConfig.Threads), - ) - - // Set the parameters for the language model prediction - updateBackendConfigFromOpenAIRequest(cfg, input) - - return cfg, input, err -} - -func updateBackendConfigFromOpenAIRequest(bc *BackendConfig, request *schema.OpenAIRequest) { - if request.Echo { - bc.Echo = request.Echo - } - if request.TopK != nil && *request.TopK != 0 { - bc.TopK = request.TopK - } - if request.TopP != nil && *request.TopP != 0 { - bc.TopP = request.TopP - } - - if request.Backend != "" { - bc.Backend = request.Backend - } - - if request.ClipSkip != 0 { - bc.Diffusers.ClipSkip = request.ClipSkip - } - - if request.ModelBaseName != "" { - bc.AutoGPTQ.ModelBaseName = request.ModelBaseName - } - - if request.NegativePromptScale != 0 { - bc.NegativePromptScale = request.NegativePromptScale - } - - if request.UseFastTokenizer { - bc.UseFastTokenizer = request.UseFastTokenizer - } - - if request.NegativePrompt != "" { - bc.NegativePrompt = request.NegativePrompt - } - - if request.RopeFreqBase != 0 { - bc.RopeFreqBase = request.RopeFreqBase - } - - if request.RopeFreqScale != 0 { - bc.RopeFreqScale = request.RopeFreqScale - } - - if request.Grammar != "" { - bc.Grammar = request.Grammar - } - - if request.Temperature != nil && *request.Temperature != 0 { - bc.Temperature = request.Temperature - } - - if request.Maxtokens != nil && *request.Maxtokens != 0 { - bc.Maxtokens = request.Maxtokens - } - - switch stop := request.Stop.(type) { - case string: - if stop != "" { - bc.StopWords = append(bc.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - bc.StopWords = append(bc.StopWords, s) - } - } - } - - if len(request.Tools) > 0 { - for _, tool := range request.Tools { - request.Functions = append(request.Functions, tool.Function) - } - } - - if request.ToolsChoice != nil { - var toolChoice grammar.Tool - switch content := request.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - request.FunctionCall = map[string]interface{}{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - index := 0 - for i, m := range request.Messages { - switch content := m.Content.(type) { - case string: - request.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - for _, pp := range c { - if pp.Type == "text" { - request.Messages[i].StringContent = pp.Text - } else if pp.Type == "image_url" { - // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL) - if err == nil { - request.Messages[i].StringImages = append(request.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - request.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + request.Messages[i].StringContent - index++ - } else { - fmt.Print("Failed encoding image", err) - } - } - } - } - } - - if request.RepeatPenalty != 0 { - bc.RepeatPenalty = request.RepeatPenalty - } - - if request.FrequencyPenalty != 0 { - bc.FrequencyPenalty = request.FrequencyPenalty - } - - if request.PresencePenalty != 0 { - bc.PresencePenalty = request.PresencePenalty - } - - if request.Keep != 0 { - bc.Keep = request.Keep - } - - if request.Batch != 0 { - bc.Batch = request.Batch - } - - if request.IgnoreEOS { - bc.IgnoreEOS = request.IgnoreEOS - } - - if request.Seed != nil { - bc.Seed = request.Seed - } - - if request.TypicalP != nil { - bc.TypicalP = request.TypicalP - } - - switch inputs := request.Input.(type) { - case string: - if inputs != "" { - bc.InputStrings = append(bc.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - bc.InputStrings = append(bc.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - bc.InputToken = append(bc.InputToken, tokens) - } - } - } - - // Can be either a string or an object - switch fnc := request.FunctionCall.(type) { - case string: - if fnc != "" { - bc.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - bc.SetFunctionCallNameString(name) - } - - switch p := request.Prompt.(type) { - case string: - bc.PromptStrings = append(bc.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - bc.PromptStrings = append(bc.PromptStrings, s) - } - } - } -} diff --git a/core/config/exports_test.go b/core/config/exports_test.go deleted file mode 100644 index 70ba84e6ac59..000000000000 --- a/core/config/exports_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package config - -// This file re-exports private functions to be used directly in unit tests. -// Since this file's name ends in _test.go, theoretically these should not be exposed past the tests. - -var ReadBackendConfigFile = readBackendConfigFile diff --git a/core/http/api.go b/core/http/api.go deleted file mode 100644 index 7094899a4378..000000000000 --- a/core/http/api.go +++ /dev/null @@ -1,278 +0,0 @@ -package http - -import ( - "errors" - "strings" - - "github.com/go-skynet/LocalAI/core" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/gofiber/swagger" // swagger handler - - "github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs" - "github.com/go-skynet/LocalAI/core/http/endpoints/localai" - "github.com/go-skynet/LocalAI/core/http/endpoints/openai" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/internal" - model "github.com/go-skynet/LocalAI/pkg/model" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/recover" -) - -func readAuthHeader(c *fiber.Ctx) string { - authHeader := c.Get("Authorization") - - // elevenlabs - xApiKey := c.Get("xi-api-key") - if xApiKey != "" { - authHeader = "Bearer " + xApiKey - } - - // anthropic - xApiKey = c.Get("x-api-key") - if xApiKey != "" { - authHeader = "Bearer " + xApiKey - } - - return authHeader -} - -// @title LocalAI API -// @version 2.0.0 -// @description The LocalAI Rest API. -// @termsOfService -// @contact.name LocalAI -// @contact.url https://localai.io -// @license.name MIT -// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE -// @BasePath / -// @securityDefinitions.apikey BearerAuth -// @in header -// @name Authorization -func App(application *core.Application) (*fiber.App, error) { - // Return errors as JSON responses - app := fiber.New(fiber.Config{ - Views: renderEngine(), - BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: application.ApplicationConfig.DisableMessage, - // Override default error handler - ErrorHandler: func(ctx *fiber.Ctx, err error) error { - // Status code defaults to 500 - code := fiber.StatusInternalServerError - - // Retrieve the custom status code if it's a *fiber.Error - var e *fiber.Error - if errors.As(err, &e) { - code = e.Code - } - - // Send custom error page - return ctx.Status(code).JSON( - schema.ErrorResponse{ - Error: &schema.APIError{Message: err.Error(), Code: code}, - }, - ) - }, - }) - - if application.ApplicationConfig.Debug { - app.Use(logger.New(logger.Config{ - Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", - })) - } - - // Default middleware config - - if !application.ApplicationConfig.Debug { - app.Use(recover.New()) - } - - metricsService, err := services.NewLocalAIMetricsService() - if err != nil { - return nil, err - } - - if metricsService != nil { - app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) - app.Hooks().OnShutdown(func() error { - return metricsService.Shutdown() - }) - } - - // Auth middleware checking if API key is valid. If no API key is set, no auth is required. - auth := func(c *fiber.Ctx) error { - if len(application.ApplicationConfig.ApiKeys) == 0 { - return c.Next() - } - - authHeader := readAuthHeader(c) - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - - // If it's a bearer token - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) - } - - apiKey := authHeaderParts[1] - for _, key := range application.ApplicationConfig.ApiKeys { - if apiKey == key { - return c.Next() - } - } - - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) - } - - if application.ApplicationConfig.CORS { - var c func(ctx *fiber.Ctx) error - if application.ApplicationConfig.CORSAllowOrigins == "" { - c = cors.New() - } else { - c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins}) - } - - app.Use(c) - } - - fiberContextExtractor := fiberContext.NewFiberContextExtractor(application.ModelLoader, application.ApplicationConfig) - - // LocalAI API endpoints - galleryService := services.NewGalleryService(application.ApplicationConfig.ModelPath) - galleryService.Start(application.ApplicationConfig.Context, application.BackendConfigLoader) - - app.Get("/version", auth, func(c *fiber.Ctx) error { - return c.JSON(struct { - Version string `json:"version"` - }{Version: internal.PrintableVersion()}) - }) - - app.Get("/swagger/*", swagger.HandlerDefault) // default - - welcomeRoute( - app, - application.BackendConfigLoader, - application.ModelLoader, - application.ApplicationConfig, - auth, - ) - - modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, galleryService) - app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) - app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint()) - app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint()) - - // Stores - storeLoader := model.NewModelLoader("") // TODO: Investigate if this should be migrated to application and reused. Should the path be configurable? Merging for now. - app.Post("/stores/set", auth, localai.StoresSetEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/get", auth, localai.StoresGetEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/find", auth, localai.StoresFindEndpoint(storeLoader, application.ApplicationConfig)) - - // openAI compatible API endpoints - - // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) - - // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) - - // assistant - // TODO: Refactor this to the new style eventually - app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - - // files - app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - - // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - - // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - - // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(fiberContextExtractor, application.TranscriptionBackendService)) - app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) - - // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(fiberContextExtractor, application.ImageGenerationBackendService)) - - // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) - - // LocalAI TTS? - app.Post("/tts", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) - - if application.ApplicationConfig.ImageDir != "" { - app.Static("/generated-images", application.ApplicationConfig.ImageDir) - } - - if application.ApplicationConfig.AudioDir != "" { - app.Static("/generated-audio", application.ApplicationConfig.AudioDir) - } - - ok := func(c *fiber.Ctx) error { - return c.SendStatus(200) - } - - // Kubernetes health checks - app.Get("/healthz", ok) - app.Get("/readyz", ok) - - // Experimental Backend Statistics Module - app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService)) - app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService)) - - // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) - app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) - - app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) - - // Define a custom 404 handler - // Note: keep this at the bottom! - app.Use(notFoundHandler) - - return app, nil -} diff --git a/core/http/app.go b/core/http/app.go new file mode 100644 index 000000000000..1061627f6550 --- /dev/null +++ b/core/http/app.go @@ -0,0 +1,199 @@ +package http + +import ( + "encoding/json" + "errors" + "os" + "strings" + + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + "github.com/go-skynet/LocalAI/core/http/routes" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/gofiber/contrib/fiberzerolog" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/recover" + + // swagger handler + "github.com/rs/zerolog/log" +) + +func readAuthHeader(c *fiber.Ctx) string { + authHeader := c.Get("Authorization") + + // elevenlabs + xApiKey := c.Get("xi-api-key") + if xApiKey != "" { + authHeader = "Bearer " + xApiKey + } + + // anthropic + xApiKey = c.Get("x-api-key") + if xApiKey != "" { + authHeader = "Bearer " + xApiKey + } + + return authHeader +} + +// @title LocalAI API +// @version 2.0.0 +// @description The LocalAI Rest API. +// @termsOfService +// @contact.name LocalAI +// @contact.url https://localai.io +// @license.name MIT +// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE +// @BasePath / +// @securityDefinitions.apikey BearerAuth +// @in header +// @name Authorization + +func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { + // Return errors as JSON responses + app := fiber.New(fiber.Config{ + Views: renderEngine(), + BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + // We disable the Fiber startup message as it does not conform to structured logging. + // We register a startup log line with connection information in the OnListen hook to keep things user friendly though + DisableStartupMessage: true, + // Override default error handler + ErrorHandler: func(ctx *fiber.Ctx, err error) error { + // Status code defaults to 500 + code := fiber.StatusInternalServerError + + // Retrieve the custom status code if it's a *fiber.Error + var e *fiber.Error + if errors.As(err, &e) { + code = e.Code + } + + // Send custom error page + return ctx.Status(code).JSON( + schema.ErrorResponse{ + Error: &schema.APIError{Message: err.Error(), Code: code}, + }, + ) + }, + }) + + app.Hooks().OnListen(func(listenData fiber.ListenData) error { + scheme := "http" + if listenData.TLS { + scheme = "https" + } + log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.") + return nil + }) + + // Have Fiber use zerolog like the rest of the application rather than it's built-in logger + logger := log.Logger + app.Use(fiberzerolog.New(fiberzerolog.Config{ + Logger: &logger, + })) + + // Default middleware config + + if !appConfig.Debug { + app.Use(recover.New()) + } + + metricsService, err := services.NewLocalAIMetricsService() + if err != nil { + return nil, err + } + + if metricsService != nil { + app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) + app.Hooks().OnShutdown(func() error { + return metricsService.Shutdown() + }) + } + + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. + auth := func(c *fiber.Ctx) error { + if len(appConfig.ApiKeys) == 0 { + return c.Next() + } + + // Check for api_keys.json file + fileContent, err := os.ReadFile("api_keys.json") + if err == nil { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) + } + + // Add file keys to options.ApiKeys + appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) + } + + if len(appConfig.ApiKeys) == 0 { + return c.Next() + } + + authHeader := readAuthHeader(c) + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + + // If it's a bearer token + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + for _, key := range appConfig.ApiKeys { + if apiKey == key { + return c.Next() + } + } + + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + } + + if appConfig.CORS { + var c func(ctx *fiber.Ctx) error + if appConfig.CORSAllowOrigins == "" { + c = cors.New() + } else { + c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins}) + } + + app.Use(c) + } + + // Make sure directories exists + os.MkdirAll(appConfig.ImageDir, 0755) + os.MkdirAll(appConfig.AudioDir, 0755) + os.MkdirAll(appConfig.UploadDir, 0755) + os.MkdirAll(appConfig.ConfigsDir, 0755) + os.MkdirAll(appConfig.ModelPath, 0755) + + // Load config jsons + utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) + + routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth) + routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, auth) + routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth) + routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth) + + // Define a custom 404 handler + // Note: keep this at the bottom! + app.Use(notFoundHandler) + + return app, nil +} diff --git a/core/http/api_test.go b/core/http/app_test.go similarity index 92% rename from core/http/api_test.go rename to core/http/app_test.go index bf8feb1c09cb..35e0a8bfc2f7 100644 --- a/core/http/api_test.go +++ b/core/http/app_test.go @@ -12,9 +12,7 @@ import ( "os" "path/filepath" "runtime" - "strings" - "github.com/go-skynet/LocalAI/core" "github.com/go-skynet/LocalAI/core/config" . "github.com/go-skynet/LocalAI/core/http" "github.com/go-skynet/LocalAI/core/schema" @@ -207,11 +205,12 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string var modelDir string - var application *core.Application + var bcl *config.BackendConfigLoader + var ml *model.ModelLoader + var applicationConfig *config.ApplicationConfig commonOpts := []config.AppOption{ config.WithDebug(true), - config.WithDisableMessage(true), } Context("API with ephemeral models", func() { @@ -252,7 +251,7 @@ var _ = Describe("API test", func() { }, } - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithGalleries(galleries), @@ -261,7 +260,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(backendAssetsDir))...) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -474,11 +473,11 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -487,9 +486,9 @@ var _ = Describe("API test", func() { }) It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() { - // if runtime.GOOS != "linux" { - // Skip("test supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } modelName := "codellama" response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml", @@ -504,7 +503,7 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) - }, "480s", "10s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) By("testing chat") resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ @@ -551,13 +550,11 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - fmt.Printf("\n--- %+v\n\n", resp2.Choices[0].Message) - Expect(resp2.Choices[0].Message.ToolCalls).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0]).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -611,7 +608,7 @@ var _ = Describe("API test", func() { }, } - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithAudioDir(tmpdir), @@ -622,7 +619,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -726,14 +723,14 @@ var _ = Describe("API test", func() { var err error - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), config.WithModelPath(modelPath), )...) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -763,11 +760,6 @@ var _ = Describe("API test", func() { Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8? }) It("can generate completions via ggml", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -775,11 +767,6 @@ var _ = Describe("API test", func() { }) It("can generate chat completions via ggml", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -787,11 +774,6 @@ var _ = Describe("API test", func() { }) It("can generate completions from model configs", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -799,11 +781,6 @@ var _ = Describe("API test", func() { }) It("can generate chat completions from model configs", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -890,9 +867,9 @@ var _ = Describe("API test", func() { Context("backends", func() { It("runs rwkv completion", func() { - // if runtime.GOOS != "linux" { - // Skip("test supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) @@ -913,20 +890,17 @@ var _ = Describe("API test", func() { } Expect(err).ToNot(HaveOccurred()) - - if len(response.Choices) > 0 { - text += response.Choices[0].Text - tokens++ - } + text += response.Choices[0].Text + tokens++ } Expect(text).ToNot(BeEmpty()) Expect(text).To(ContainSubstring("five")) Expect(tokens).ToNot(Or(Equal(1), Equal(0))) }) It("runs rwkv chat completion", func() { - // if runtime.GOOS != "linux" { - // Skip("test supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) Expect(err).ToNot(HaveOccurred()) @@ -1035,14 +1009,14 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithModelPath(modelPath), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -1066,33 +1040,18 @@ var _ = Describe("API test", func() { } }) It("can generate chat completions from config file (list1)", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate chat completions from config file (list2)", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate edit completions from config file", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - request := openaigo.EditCreateRequestBody{ Model: "list2", Instruction: "foo", diff --git a/core/http/ctx/fiber.go b/core/http/ctx/fiber.go index 99fbcde92a38..ffb631112abe 100644 --- a/core/http/ctx/fiber.go +++ b/core/http/ctx/fiber.go @@ -1,88 +1,43 @@ package fiberContext import ( - "context" - "encoding/json" "fmt" "strings" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -type FiberContextExtractor struct { - ml *model.ModelLoader - appConfig *config.ApplicationConfig -} - -func NewFiberContextExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContextExtractor { - return &FiberContextExtractor{ - ml: ml, - appConfig: appConfig, - } -} - // ModelFromContext returns the model from the context // If no model is specified, it will take the first available // Takes a model string as input which should be the one received from the user request. // It returns the model name resolved from the context and an error if any. -func (fce *FiberContextExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) { - ctxPM := ctx.Params("model") - if ctxPM != "" { - log.Debug().Msgf("[FCE] Overriding param modelInput %q with ctx.Params value %q", modelInput, ctxPM) - modelInput = ctxPM +func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) { + if ctx.Params("model") != "" { + modelInput = ctx.Params("model") } // Set model from bearer token, if available - bearer := strings.TrimPrefix(ctx.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer) + bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) // If no model was specified, take the first available if modelInput == "" && !bearerExists && firstModel { - models, _ := fce.ml.ListModels() + models, _ := loader.ListModels() if len(models) > 0 { modelInput = models[0] - log.Debug().Msgf("[FCE] No model specified, using first available: %s", modelInput) + log.Debug().Msgf("No model specified, using: %s", modelInput) } else { - log.Warn().Msgf("[FCE] No model specified, none available") - return "", fmt.Errorf("[fce] no model specified, none available") + log.Debug().Msgf("No model specified, returning error") + return "", fmt.Errorf("no model specified") } } // If a model is found in bearer token takes precedence if bearerExists { - log.Debug().Msgf("[FCE] Using model from bearer token: %s", bearer) + log.Debug().Msgf("Using model from bearer token: %s", bearer) modelInput = bearer } - - if modelInput == "" { - log.Warn().Msg("[FCE] modelInput is empty") - } return modelInput, nil } - -// TODO: Do we still need the first return value? -func (fce *FiberContextExtractor) OpenAIRequestFromContext(c *fiber.Ctx, firstModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - received, _ := json.Marshal(input) - - ctx, cancel := context.WithCancel(fce.appConfig.Context) - input.Context = ctx - input.Cancel = cancel - - log.Debug().Msgf("Request received: %s", string(received)) - - var err error - input.Model, err = fce.ModelFromContext(c, input.Model, firstModel) - - return input.Model, input, err -} diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 4f5db4638e0d..841f9b5f7846 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -2,7 +2,9 @@ package elevenlabs import ( "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -15,7 +17,7 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] -func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.ElevenLabsTTSRequest) @@ -26,21 +28,34 @@ func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToS return err } - var err error - input.ModelID, err = fce.ModelFromContext(c, input.ModelID, false) + modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false) if err != nil { + modelFile = input.ModelID log.Warn().Msgf("Model not found in context: %s", input.ModelID) } - responseChannel := ttsbs.TextToAudioFile(&schema.TTSRequest{ - Model: input.ModelID, - Voice: voiceID, - Input: input.Text, - }) - rawValue := <-responseChannel - if rawValue.Error != nil { - return rawValue.Error + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { + modelFile = input.ModelID + log.Warn().Msgf("Model not found in context: %s", input.ModelID) + } else { + if input.ModelID != "" { + modelFile = input.ModelID + } else { + modelFile = cfg.Model + } + } + log.Debug().Msgf("Request for model: %s", modelFile) + + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg) + if err != nil { + return err } - return c.Download(*rawValue.Value) + return c.Download(filePath) } } diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index dac20388d1d4..8c7a664a70b1 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -6,7 +6,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { +func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) @@ -23,7 +23,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ct } } -func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { +func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index df7841fb2426..7822e0242c2c 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -2,7 +2,9 @@ package localai import ( "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -14,26 +16,45 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/audio/speech [post] -func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - var err error + input := new(schema.TTSRequest) // Get input data from the request body - if err = c.BodyParser(input); err != nil { + if err := c.BodyParser(input); err != nil { return err } - input.Model, err = fce.ModelFromContext(c, input.Model, false) + modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false) + if err != nil { + modelFile = input.Model + log.Warn().Msgf("Model not found in context: %s", input.Model) + } + + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { + modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) + } else { + modelFile = cfg.Model + } + log.Debug().Msgf("Request for model: %s", modelFile) + + if input.Backend != "" { + cfg.Backend = input.Backend } - responseChannel := ttsbs.TextToAudioFile(input) - rawValue := <-responseChannel - if rawValue.Error != nil { - return rawValue.Error + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg) + if err != nil { + return err } - return c.Download(*rawValue.Value) + return c.Download(filePath) } } diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go new file mode 100644 index 000000000000..fd3e6230e4c6 --- /dev/null +++ b/core/http/endpoints/localai/welcome.go @@ -0,0 +1,28 @@ +package localai + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/internal" + "github.com/gofiber/fiber/v2" +) + +func WelcomeEndpoint(appConfig *config.ApplicationConfig, + models []string, backendConfigs []config.BackendConfig) func(*fiber.Ctx) error { + return func(c *fiber.Ctx) error { + summary := fiber.Map{ + "Title": "LocalAI API - " + internal.PrintableVersion(), + "Version": internal.PrintableVersion(), + "Models": models, + "ModelsConfig": backendConfigs, + "ApplicationConfig": appConfig, + } + + if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { + // The client expects a JSON response + return c.Status(fiber.StatusOK).JSON(summary) + } else { + // Render index + return c.Render("views/index", summary) + } + } +} diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go index 72cb8b4ab7d8..dceb378995fd 100644 --- a/core/http/endpoints/openai/assistant.go +++ b/core/http/endpoints/openai/assistant.go @@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model } } - return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistantID %q", assistantID)) + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) } } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index a240b024e4c7..9adba8eabdd6 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -5,11 +5,16 @@ import ( "bytes" "encoding/json" "fmt" + "strings" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/functions" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -19,82 +24,418 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { + emptyMessage := "" + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + + responses <- resp + return true + }) + close(responses) + } + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + result := "" + _, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + result += s + // TODO: Change generated BNF grammar to be compliant with the schema so we can + // stream the result token by token here. + return true + }) + + results := functions.ParseFunctionCall(result, config.FunctionsConfig) + noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0 + + switch { + case noActionToRun: + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + result, err := handleQuestion(config, req, ml, startupOptions, results, prompt) + if err != nil { + log.Error().Err(err).Msg("error handling question") + return + } + + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + + responses <- resp + + default: + for i, ss := range results { + name, args := ss.Name, ss.Arguments + + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + responses <- schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + } + } + + close(responses) + } + return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + modelFile, input, err := readRequest(c, ml, startupOptions, true) if err != nil { - return fmt.Errorf("failed reading parameters from request: %w", err) + return fmt.Errorf("failed reading parameters from request:%w", err) } - traceID, finalResultChannel, _, tokenChannel, err := oais.Chat(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + funcs := input.Functions + shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = functions.JSONBNF + } + + config.Grammar = input.Grammar + + if shouldUseFn { + log.Debug().Msgf("Response needs to process functions") + } + + switch { + case !config.FunctionsConfig.NoGrammar && shouldUseFn: + noActionGrammar := functions.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + case input.JSONFunctionGrammarObject != nil: + config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) + default: + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } } - if request.Stream { + // process functions if we have any defined or if we have a function call string + + // functions are not supported in stream mode (yet?) + toStream := input.Stream + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + // If we are using the tokenizer template, we don't need to process the messages + // unless we are processing functions + if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn { + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range input.Messages { + var content string + role := i.Role + + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + + fcall := i.FunctionCall + if len(i.ToolCalls) > 0 { + fcall = i.ToolCalls + } + + // First attempt to populate content via a chat message specific template + if config.TemplateConfig.ChatMessage != "" { + chatMessageData := model.ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + FunctionCall: fcall, + FunctionName: i.Name, + LastMessage: messageIndex == (len(input.Messages) - 1), + Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)), + MessageIndex: messageIndex, + } + templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + + marshalAnyRole := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + marshalAny := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + + if i.FunctionCall != nil { + marshalAnyRole(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAnyRole(i.ToolCalls) + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + if i.FunctionCall != nil { + marshalAny(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAny(i.ToolCalls) + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + templateFile := "" - log.Debug().Msgf("Chat Stream request received") + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + if config.TemplateConfig.Chat != "" && !shouldUseFn { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && shouldUseFn { + templateFile = config.TemplateConfig.Functions + } + + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if shouldUseFn && config.Grammar != "" { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + } + + switch { + case toStream: + + log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // + // c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") + responses := make(chan schema.OpenAIResponse) + + if !shouldUseFn { + go process(predInput, input, config, ml, responses) + } else { + go processTools(noActionName, predInput, input, config, ml, responses) + } + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { usage := &schema.OpenAIUsage{} toolsCalled := false - for ev := range tokenChannel { - if ev.Error != nil { - log.Debug().Err(ev.Error).Msg("chat streaming responseChannel error") - request.Cancel() - break - } - usage = &ev.Value.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - - if len(ev.Value.Choices[0].Delta.ToolCalls) > 0 { + for ev := range responses { + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { toolsCalled = true } var buf bytes.Buffer enc := json.NewEncoder(&buf) - if ev.Error != nil { - log.Debug().Err(ev.Error).Msg("[ChatEndpoint] error to debug during tokenChannel handler") - enc.Encode(ev.Error) - } else { - enc.Encode(ev.Value) - } - log.Debug().Msgf("chat streaming sending chunk: %s", buf.String()) + enc.Encode(ev) + log.Debug().Msgf("Sending chunk: %s", buf.String()) _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) if err != nil { - log.Debug().Err(err).Msgf("Sending chunk failed") - request.Cancel() - break - } - err = w.Flush() - if err != nil { - log.Debug().Msg("error while flushing, closing connection") - request.Cancel() + log.Debug().Msgf("Sending chunk failed: %v", err) + input.Cancel() break } + w.Flush() } finishReason := "stop" if toolsCalled { finishReason = "tool_calls" - } else if toolsCalled && len(request.Tools) == 0 { + } else if toolsCalled && len(input.Tools) == 0 { finishReason = "function_call" } resp := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { FinishReason: finishReason, Index: 0, - Delta: &schema.Message{Content: ""}, + Delta: &schema.Message{Content: &emptyMessage}, }}, Object: "chat.completion.chunk", Usage: *usage, @@ -105,21 +446,146 @@ func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAI w.WriteString("data: [DONE]\n\n") w.Flush() })) - return nil + + // no streaming mode + default: + result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) { + if !shouldUseFn { + // no function is called, just reply and use stop as finish reason + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + return + } + + results := functions.ParseFunctionCall(s, config.FunctionsConfig) + noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0 + + switch { + case noActionsToRun: + result, err := handleQuestion(config, input, ml, startupOptions, results, predInput) + if err != nil { + log.Error().Err(err).Msg("error handling question") + return + } + *c = append(*c, schema.Choice{ + Message: &schema.Message{Role: "assistant", Content: &result}}) + default: + toolChoice := schema.Choice{ + Message: &schema.Message{ + Role: "assistant", + }, + } + + if len(input.Tools) > 0 { + toolChoice.FinishReason = "tool_calls" + } + + for _, ss := range results { + name, args := ss.Name, ss.Arguments + if len(input.Tools) > 0 { + // If we are using tools, we condense the function calls into + // a single response choice with all the tools + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // otherwise we return more choices directly + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, + }, + }) + } + } + + if len(input.Tools) > 0 { + // we need to append our result if we are using tools + *c = append(*c, toolChoice) + } + } + + }, nil) + if err != nil { + return err + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) } + } +} - // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? - rawResponse := <-finalResultChannel +func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) { + log.Debug().Msgf("nothing to do, computing a reply") + arg := "" + if len(funcResults) > 0 { + arg = funcResults[0].Arguments + } + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + if err := json.Unmarshal([]byte(arg), &arguments); err != nil { + log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object") + } + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, prompt, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - if rawResponse.Error != nil { - return rawResponse.Error + return message, nil + } } + } - jsonResult, _ := json.Marshal(rawResponse.Value) - log.Debug().Str("jsonResult", string(jsonResult)).Msg("Chat Final Response") + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU/GPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil) + if err != nil { + log.Error().Err(err).Msg("model inference failed") + return "", err + } - // Return the prediction in the response body - return c.JSON(rawResponse.Value) + prediction, err := predFunc() + if err != nil { + log.Error().Err(err).Msg("prediction failed") + return "", err } + return backend.Finetune(*config, prompt, prediction.Response), nil } diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index d8b412a90167..bcd46db55c8a 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -4,13 +4,18 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/functions" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -20,50 +25,116 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - log.Debug().Msgf("`OpenAIRequest`: %+v", request) + log.Debug().Msgf("`input`: %+v", input) - traceID, finalResultChannel, _, _, tokenChannel, err := oais.Completion(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = functions.JSONBNF } - if request.Stream { - log.Debug().Msgf("Completion Stream request received") + config.Grammar = input.Grammar + + log.Debug().Msgf("Parameter Config: %+v", config) + if input.Stream { + log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) //c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") + } + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + } + + responses := make(chan schema.OpenAIResponse) + + go process(predInput, input, config, ml, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - for ev := range tokenChannel { + + for ev := range responses { var buf bytes.Buffer enc := json.NewEncoder(&buf) - if ev.Error != nil { - log.Debug().Msgf("[CompletionEndpoint] error to debug during tokenChannel handler: %q", ev.Error) - enc.Encode(ev.Error) - } else { - enc.Encode(ev.Value) - } - - log.Debug().Msgf("completion streaming sending chunk: %s", buf.String()) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) fmt.Fprintf(w, "data: %v\n", buf.String()) w.Flush() } resp := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { Index: 0, @@ -80,15 +151,55 @@ func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services. })) return nil } - // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? - rawResponse := <-finalResultChannel - if rawResponse.Error != nil { - return rawResponse.Error + + var result []schema.Choice + + totalTokenUsage := backend.TokenUsage{} + + for k, i := range config.PromptStrings { + if templateFile != "" { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices( + input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, } - jsonResult, _ := json.Marshal(rawResponse.Value) + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index a33050dd2b8b..254970958278 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -3,36 +3,92 @@ package openai import ( "encoding/json" "fmt" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" ) -func EditEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - _, finalResultChannel, _, _, _, err := oais.Edit(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []schema.Choice + totalTokenUsage := backend.TokenUsage{} + + for _, i := range config.InputStrings { + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) } - rawResponse := <-finalResultChannel - if rawResponse.Error != nil { - return rawResponse.Error + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, } - jsonResult, _ := json.Marshal(rawResponse.Value) + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index be5469917537..eca34f79b266 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -3,9 +3,14 @@ package openai import ( "encoding/json" "fmt" + "time" "github.com/go-skynet/LocalAI/core/backend" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -16,25 +21,63 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] -func EmbeddingsEndpoint(fce *fiberContext.FiberContextExtractor, ebs *backend.EmbeddingsBackendService) func(c *fiber.Ctx) error { +func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, input, err := fce.OpenAIRequestFromContext(c, true) + model, input, err := readRequest(c, ml, appConfig, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - responseChannel := ebs.Embeddings(input) + log.Debug().Msgf("Parameter Config: %+v", config) + items := []schema.Item{} - rawResponse := <-responseChannel + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } - if rawResponse.Error != nil { - return rawResponse.Error + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", } - jsonResult, _ := json.Marshal(rawResponse.Value) + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index ec3d84dabbc1..9e806b3e51a4 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -1,18 +1,50 @@ package openai import ( + "bufio" + "encoding/base64" "encoding/json" "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" "github.com/go-skynet/LocalAI/core/backend" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -// https://platform.openai.com/docs/api-reference/images/create +func downloadFile(url string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp("", "image") + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} + +// /* * @@ -27,36 +59,186 @@ import ( * */ - // ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create // @Summary Creates an image given a prompt. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] -func ImageEndpoint(fce *fiberContext.FiberContextExtractor, igbs *backend.ImageGenerationBackendService) func(c *fiber.Ctx) error { +func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - // TODO: Somewhat a hack. Is there a better place to assign this? - if igbs.BaseUrlForGeneratedImages == "" { - igbs.BaseUrlForGeneratedImages = c.BaseURL() + "/generated-images/" + m, input, err := readRequest(c, ml, appConfig, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) } - _, request, err := fce.OpenAIRequestFromContext(c, false) + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - responseChannel := igbs.GenerateImage(request) - rawResponse := <-responseChannel + src := "" + if input.File != "" { + + fileData := []byte{} + // check if input.File is an URL, if so download it and save it + // to a temporary file + if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { + out, err := downloadFile(input.File) + if err != nil { + return fmt.Errorf("failed downloading file:%w", err) + } + defer os.RemoveAll(out) + + fileData, err = os.ReadFile(out) + if err != nil { + return fmt.Errorf("failed reading file:%w", err) + } + + } else { + // base 64 decode the file and write it somewhere + // that we will cleanup + fileData, err = base64.StdEncoding.DecodeString(input.File) + if err != nil { + return err + } + } - if rawResponse.Error != nil { - return rawResponse.Error + // Create a temporary file + outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") + if err != nil { + return err + } + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(fileData) + if err != nil { + outputFile.Close() + return err + } + outputFile.Close() + src = outputFile.Name() + defer os.RemoveAll(src) } - jsonResult, err := json.Marshal(rawResponse.Value) + log.Debug().Msgf("Parameter Config: %+v", config) + + switch config.Backend { + case "stablediffusion": + config.Backend = model.StableDiffusionBackend + case "tinydream": + config.Backend = model.TinyDreamBackend + case "": + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) if err != nil { - return err + return fmt.Errorf("invalid value for 'size'") } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat.Type == "b64_json" { + b64JSON = true + } + // src and clip_skip + var result []schema.Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := config.Step + if step == 0 { + step = 15 + } + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = appConfig.ImageDir + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &schema.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: result, + } + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) + // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go new file mode 100644 index 000000000000..06e784b72048 --- /dev/null +++ b/core/http/endpoints/openai/inference.go @@ -0,0 +1,55 @@ +package openai + +import ( + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + + "github.com/go-skynet/LocalAI/core/schema" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ComputeChoices( + req *schema.OpenAIRequest, + predInput string, + config *config.BackendConfig, + o *config.ApplicationConfig, + loader *model.ModelLoader, + cb func(string, *[]schema.Choice), + tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { + n := req.N // number of completions to return + result := []schema.Choice{} + + if n == 0 { + n = 1 + } + + images := []string{} + for _, m := range req.Messages { + images = append(images, m.StringImages...) + } + + // get the model function to call for the result + predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, loader, *config, o, tokenCallback) + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage := backend.TokenUsage{} + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage.Prompt += prediction.Usage.Prompt + tokenUsage.Completion += prediction.Usage.Completion + + finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, tokenUsage, err +} diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 9bb2b2ca3696..04e611a20fed 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -1,21 +1,61 @@ package openai import ( + "regexp" + + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/services" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - // If blank, no filter is applied. + models, err := ml.ListModels() + if err != nil { + return err + } + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []schema.OpenAIModel{} + + var filterFn func(name string) bool filter := c.Query("filter") + + // If filter is not specified, do not filter the list by model name + if filter == "" { + filterFn = func(_ string) bool { return true } + } else { + // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn + rxp, err := regexp.Compile(filter) + if err != nil { + return err + } + filterFn = func(name string) bool { + return rxp.MatchString(name) + } + } + // By default, exclude any loose files that are already referenced by a configuration file. excludeConfigured := c.QueryBool("excludeConfigured", true) - dataModels, err := lms.ListModels(filter, excludeConfigured) - if err != nil { - return err + // Start with the known configurations + for _, c := range cl.GetAllBackendConfigs() { + if excludeConfigured { + mm[c.Model] = nil + } + + if filterFn(c.Name) { + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + } + } + + // Then iterate through the loose files: + for _, m := range models { + // And only adds them if they shouldn't be skipped. + if _, exists := mm[m]; !exists && filterFn(m) { + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + } } return c.JSON(struct { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go new file mode 100644 index 000000000000..9a107bab6aca --- /dev/null +++ b/core/http/endpoints/openai/request.go @@ -0,0 +1,285 @@ +package openai + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/functions" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { + input := new(schema.OpenAIRequest) + + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, fmt.Errorf("failed parsing request body: %w", err) + } + + received, _ := json.Marshal(input) + + ctx, cancel := context.WithCancel(o.Context) + input.Context = ctx + input.Cancel = cancel + + log.Debug().Msgf("Request received: %s", string(received)) + + modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel) + + return modelFile, input, err +} + +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string +func getBase64Image(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := http.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +} + +func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.ModelBaseName != "" { + config.AutoGPTQ.ModelBaseName = input.ModelBaseName + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.UseFastTokenizer { + config.UseFastTokenizer = input.UseFastTokenizer + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice functions.Tool + + switch content := input.ToolsChoice.(type) { + case string: + _ = json.Unmarshal([]byte(content), &toolChoice) + case map[string]interface{}: + dat, _ := json.Marshal(content) + _ = json.Unmarshal(dat, &toolChoice) + } + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + + // Decode each request's message content + index := 0 + for i, m := range input.Messages { + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + for _, pp := range c { + if pp.Type == "text" { + input.Messages[i].StringContent = pp.Text + } else if pp.Type == "image_url" { + // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: + base64, err := getBase64Image(pp.ImageURL.URL) + if err == nil { + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + // set a placeholder for each image + input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent + index++ + } else { + fmt.Print("Failed encoding image", err) + } + } + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { + cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, + config.LoadOptionDebug(debug), + config.LoadOptionThreads(threads), + config.LoadOptionContextSize(ctx), + config.LoadOptionF16(f16), + ) + + // Set the parameters for the language model prediction + updateRequestConfig(cfg, input) + + return cfg, input, err +} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 572cec1288c6..c7dd39e7bafb 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -9,7 +9,8 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/backend" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -22,15 +23,17 @@ import ( // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] -func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.TranscriptionBackendService) func(c *fiber.Ctx) error { +func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - // TODO: Investigate this file copy stuff later - potentially belongs in service. - + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { @@ -62,16 +65,13 @@ func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.Tr log.Debug().Msgf("Audio file copied to: %+v", dst) - request.File = dst - - responseChannel := tbs.Transcribe(request) - rawResponse := <-responseChannel - - if rawResponse.Error != nil { - return rawResponse.Error + tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) + if err != nil { + return err } - log.Debug().Msgf("Transcribed: %+v", rawResponse.Value) + + log.Debug().Msgf("Trascribed: %+v", tr) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(rawResponse.Value) + return c.Status(http.StatusOK).JSON(tr) } } diff --git a/core/http/apt_suite_test.go b/core/http/http_suite_test.go similarity index 100% rename from core/http/apt_suite_test.go rename to core/http/http_suite_test.go diff --git a/core/http/render.go b/core/http/render.go index c50458684cab..8f1b36c6dbf6 100644 --- a/core/http/render.go +++ b/core/http/render.go @@ -7,10 +7,7 @@ import ( "net/http" "github.com/Masterminds/sprig/v3" - "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" fiberhtml "github.com/gofiber/template/html/v2" "github.com/russross/blackfriday" @@ -33,40 +30,6 @@ func notFoundHandler(c *fiber.Ctx) error { return nil } -func welcomeRoute( - app *fiber.App, - cl *config.BackendConfigLoader, - ml *model.ModelLoader, - appConfig *config.ApplicationConfig, - auth func(*fiber.Ctx) error, -) { - if appConfig.DisableWelcomePage { - return - } - - models, _ := ml.ListModels() - backendConfigs := cl.GetAllBackendConfigs() - - app.Get("/", auth, func(c *fiber.Ctx) error { - summary := fiber.Map{ - "Title": "LocalAI API - " + internal.PrintableVersion(), - "Version": internal.PrintableVersion(), - "Models": models, - "ModelsConfig": backendConfigs, - "ApplicationConfig": appConfig, - } - - if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { - // The client expects a JSON response - return c.Status(fiber.StatusOK).JSON(summary) - } else { - // Render index - return c.Render("views/index", summary) - } - }) - -} - func renderEngine() *fiberhtml.Engine { engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html") engine.AddFuncMap(sprig.FuncMap()) diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go new file mode 100644 index 000000000000..e24a19a8471c --- /dev/null +++ b/core/http/routes/elevenlabs.go @@ -0,0 +1,19 @@ +package routes + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func RegisterElevenLabsRoutes(app *fiber.App, + cl *config.BackendConfigLoader, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + auth func(*fiber.Ctx) error) { + + // Elevenlabs + app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) + +} diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go new file mode 100644 index 000000000000..2651a53e1a18 --- /dev/null +++ b/core/http/routes/localai.go @@ -0,0 +1,64 @@ +package routes + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/swagger" +) + +func RegisterLocalAIRoutes(app *fiber.App, + cl *config.BackendConfigLoader, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + auth func(*fiber.Ctx) error) { + + app.Get("/swagger/*", swagger.HandlerDefault) // default + + // LocalAI API endpoints + galleryService := services.NewGalleryService(appConfig.ModelPath) + galleryService.Start(appConfig.Context, cl) + + modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) + app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) + app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) + app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint()) + app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint()) + app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint()) + app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint()) + + app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig)) + + // Stores + sl := model.NewModelLoader("") + app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig)) + app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig)) + app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig)) + app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig)) + + // Kubernetes health checks + ok := func(c *fiber.Ctx) error { + return c.SendStatus(200) + } + + app.Get("/healthz", ok) + app.Get("/readyz", ok) + + app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) + + // Experimental Backend Statistics Module + backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now + app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor)) + app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor)) + + app.Get("/version", auth, func(c *fiber.Ctx) error { + return c.JSON(struct { + Version string `json:"version"` + }{Version: internal.PrintableVersion()}) + }) + +} diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go new file mode 100644 index 000000000000..c51ccdcb0716 --- /dev/null +++ b/core/http/routes/openai.go @@ -0,0 +1,86 @@ +package routes + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func RegisterOpenAIRoutes(app *fiber.App, + cl *config.BackendConfigLoader, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + auth func(*fiber.Ctx) error) { + // openAI compatible API endpoint + + // chat + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + + // edit + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + + // assistant + app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + + // files + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + + // completion + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + + // embeddings + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + + // audio + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig)) + + // images + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig)) + + if appConfig.ImageDir != "" { + app.Static("/generated-images", appConfig.ImageDir) + } + + if appConfig.AudioDir != "" { + app.Static("/generated-audio", appConfig.AudioDir) + } + + // models + app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) +} diff --git a/core/http/routes/welcome.go b/core/http/routes/welcome.go new file mode 100644 index 000000000000..29b9e58631af --- /dev/null +++ b/core/http/routes/welcome.go @@ -0,0 +1,23 @@ +package routes + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func RegisterPagesRoutes(app *fiber.App, + cl *config.BackendConfigLoader, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + auth func(*fiber.Ctx) error) { + + models, _ := ml.ListModels() + backendConfigs := cl.GetAllBackendConfigs() + + if !appConfig.DisableWelcomePage { + app.Get("/", auth, localai.WelcomeEndpoint(appConfig, models, backendConfigs)) + } + +} diff --git a/core/schema/openai.go b/core/schema/openai.go index 6aa0f1b0602a..a251ba681dcd 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -3,7 +3,7 @@ package schema import ( "context" - "github.com/go-skynet/LocalAI/pkg/grammar" + functions "github.com/go-skynet/LocalAI/pkg/functions" ) // APIError provides error information returned by the OpenAI API. @@ -108,7 +108,7 @@ type ChatCompletionResponseFormat struct { type OpenAIRequest struct { PredictionOptions - Context context.Context `json:"-"` + Context context.Context `json:"-"` Cancel context.CancelFunc `json:"-"` // whisper @@ -130,11 +130,11 @@ type OpenAIRequest struct { Messages []Message `json:"messages" yaml:"messages"` // A list of available functions to call - Functions []grammar.Function `json:"functions" yaml:"functions"` - FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + Functions functions.Functions `json:"functions" yaml:"functions"` + FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object - Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"` - ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"` + Tools []functions.Tool `json:"tools,omitempty" yaml:"tools"` + ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"` Stream bool `json:"stream"` @@ -145,7 +145,7 @@ type OpenAIRequest struct { // A grammar to constrain the LLM output Grammar string `json:"grammar" yaml:"grammar"` - JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` + JSONFunctionGrammarObject *functions.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` Backend string `json:"backend" yaml:"backend"` diff --git a/core/schema/transcription.go b/core/schema/whisper.go similarity index 90% rename from core/schema/transcription.go rename to core/schema/whisper.go index fe1799fa3223..41413c1f06ed 100644 --- a/core/schema/transcription.go +++ b/core/schema/whisper.go @@ -10,7 +10,7 @@ type Segment struct { Tokens []int `json:"tokens"` } -type TranscriptionResult struct { +type Result struct { Segments []Segment `json:"segments"` Text string `json:"text"` } diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index a610432c1e8a..979a67a3981e 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -15,22 +15,22 @@ import ( gopsutil "github.com/shirou/gopsutil/v3/process" ) -type BackendMonitorService struct { +type BackendMonitor struct { configLoader *config.BackendConfigLoader modelLoader *model.ModelLoader options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. } -func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { - return &BackendMonitorService{ +func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { + return BackendMonitor{ configLoader: configLoader, modelLoader: modelLoader, options: appConfig, } } -func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bms.configLoader.GetBackendConfig(modelName) +func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { + config, exists := bm.configLoader.GetBackendConfig(modelName) var backendId string if exists { backendId = config.Model @@ -46,8 +46,8 @@ func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) return backendId, nil } -func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bms.configLoader.GetBackendConfig(model) +func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { + config, exists := bm.configLoader.GetBackendConfig(model) var backend string if exists { backend = config.Model @@ -60,7 +60,7 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche backend = fmt.Sprintf("%s.bin", backend) } - pid, err := bms.modelLoader.GetGRPCPID(backend) + pid, err := bm.modelLoader.GetGRPCPID(backend) if err != nil { log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid") @@ -101,12 +101,12 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche }, nil } -func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { - backendId, err := bms.getModelLoaderIDFromModelName(modelName) +func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) if err != nil { return nil, err } - modelAddr := bms.modelLoader.CheckIsLoaded(backendId) + modelAddr := bm.modelLoader.CheckIsLoaded(backendId) if modelAddr == "" { return nil, fmt.Errorf("backend %s is not currently loaded", backendId) } @@ -114,7 +114,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bms.SampleLocalBackendProcess(backendId) + val, slbErr := bm.SampleLocalBackendProcess(backendId) if slbErr != nil { return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) } @@ -131,10 +131,10 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status return status, nil } -func (bms BackendMonitorService) ShutdownModel(modelName string) error { - backendId, err := bms.getModelLoaderIDFromModelName(modelName) +func (bm BackendMonitor) ShutdownModel(modelName string) error { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) if err != nil { return err } - return bms.modelLoader.ShutdownModel(backendId) + return bm.modelLoader.ShutdownModel(backendId) } diff --git a/core/services/gallery.go b/core/services/gallery.go index 1ef8e3e2a780..b068abbb1e61 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -3,18 +3,14 @@ package services import ( "context" "encoding/json" - "errors" "os" - "path/filepath" "strings" "sync" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/embedded" - "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -33,6 +29,18 @@ func NewGalleryService(modelPath string) *GalleryService { } } +func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetGalleryConfigFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) +} + func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) { g.Lock() defer g.Unlock() @@ -84,10 +92,10 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) } } else if op.ConfigURL != "" { - PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) + startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) err = cl.Preload(g.modelPath) } else { - err = prepareModel(g.modelPath, op.Req, progressCallback) + err = prepareModel(g.modelPath, op.Req, cl, progressCallback) } if err != nil { @@ -119,12 +127,13 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error { +func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { if strings.Contains(r.ID, "@") { err = gallery.InstallModelFromGallery( @@ -149,7 +158,7 @@ func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, g return err } - return processRequests(modelPath, galleries, requests) + return processRequests(modelPath, s, cl, galleries, requests) } func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error { @@ -159,90 +168,5 @@ func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, return err } - return processRequests(modelPath, galleries, requests) -} - -// PreloadModelsConfigurations will preload models from the given list of URLs -// It will download the model if it is not already present in the model path -// It will also try to resolve if the model is an embedded model YAML configuration -func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { - for _, url := range models { - - // As a best effort, try to resolve the model from the remote library - // if it's not resolved we try with the other method below - if modelLibraryURL != "" { - lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) - if err == nil { - if lib[url] != "" { - log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) - url = lib[url] - } - } - } - - url = embedded.ModelShortURL(url) - switch { - case embedded.ExistsInModelsLibrary(url): - modelYAML, err := embedded.ResolveContent(url) - // If we resolve something, just save it to disk and continue - if err != nil { - log.Error().Err(err).Msg("error resolving model content") - continue - } - - log.Debug().Msgf("[startup] resolved embedded model: %s", url) - md5Name := utils.MD5(url) - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") - } - case downloader.LooksLikeURL(url): - log.Debug().Msgf("[startup] resolved model to download: %s", url) - - // md5 of model name - md5Name := utils.MD5(url) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) - if err != nil { - log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") - } - } - default: - if _, err := os.Stat(url); err == nil { - log.Debug().Msgf("[startup] resolved local model: %s", url) - // copy to modelPath - md5Name := utils.MD5(url) - - modelYAML, err := os.ReadFile(url) - if err != nil { - log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") - continue - } - - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") - } - } else { - log.Warn().Msgf("[startup] failed resolving model '%s'", url) - } - } - } -} - -func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetGalleryConfigFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) - - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) + return processRequests(modelPath, s, cl, galleries, requests) } diff --git a/core/services/list_models.go b/core/services/list_models.go deleted file mode 100644 index a21e6fafc6e9..000000000000 --- a/core/services/list_models.go +++ /dev/null @@ -1,72 +0,0 @@ -package services - -import ( - "regexp" - - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/model" -) - -type ListModelsService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig -} - -func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService { - return &ListModelsService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - } -} - -func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) { - - models, err := lms.ml.ListModels() - if err != nil { - return nil, err - } - - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []schema.OpenAIModel{} - - var filterFn func(name string) bool - - // If filter is not specified, do not filter the list by model name - if filter == "" { - filterFn = func(_ string) bool { return true } - } else { - // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn - rxp, err := regexp.Compile(filter) - if err != nil { - return nil, err - } - filterFn = func(name string) bool { - return rxp.MatchString(name) - } - } - - // Start with the known configurations - for _, c := range lms.bcl.GetAllBackendConfigs() { - if excludeConfigured { - mm[c.Model] = nil - } - - if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) - } - } - - // Then iterate through the loose files: - for _, m := range models { - // And only adds them if they shouldn't be skipped. - if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) - } - } - - return dataModels, nil -} diff --git a/core/services/openai.go b/core/services/openai.go deleted file mode 100644 index 7a2679adc3d0..000000000000 --- a/core/services/openai.go +++ /dev/null @@ -1,808 +0,0 @@ -package services - -import ( - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "time" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" - "github.com/go-skynet/LocalAI/pkg/grammar" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/google/uuid" - "github.com/imdario/mergo" - "github.com/rs/zerolog/log" -) - -type endpointGenerationConfigurationFn func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration - -type endpointConfiguration struct { - SchemaObject string - TemplatePath string - TemplateData model.PromptTemplateData - ResultMappingFn func(resp *backend.LLMResponse, index int) schema.Choice - CompletionMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] - TokenMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] -} - -// TODO: This is used for completion and edit. I am pretty sure I forgot parts, but fix it later. -func simpleMapper(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - if resp.Error != nil || resp.Value == nil { - return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} - } - return concurrency.ErrorOr[*schema.OpenAIResponse]{ - Value: &schema.OpenAIResponse{ - Choices: []schema.Choice{ - { - Text: resp.Value.Response, - }, - }, - Usage: schema.OpenAIUsage{ - PromptTokens: resp.Value.Usage.Prompt, - CompletionTokens: resp.Value.Usage.Completion, - TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, - }, - }, - } -} - -// TODO: Consider alternative names for this. -// The purpose of this struct is to hold a reference to the OpenAI request context information -// This keeps things simple within core/services/openai.go and allows consumers to "see" this information if they need it -type OpenAIRequestTraceID struct { - ID string - Created int -} - -// This type split out from core/backend/llm.go - I'm still not _totally_ sure about this, but it seems to make sense to keep the generic LLM code from the OpenAI specific higher level functionality -type OpenAIService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig - llmbs *backend.LLMBackendService -} - -func NewOpenAIService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, llmbs *backend.LLMBackendService) *OpenAIService { - return &OpenAIService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - llmbs: llmbs, - } -} - -// Keeping in place as a reminder to POTENTIALLY ADD MORE VALIDATION HERE??? -func (oais *OpenAIService) getConfig(request *schema.OpenAIRequest) (*config.BackendConfig, *schema.OpenAIRequest, error) { - return oais.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, oais.appConfig) -} - -// TODO: It would be a lot less messy to make a return struct that had references to each of these channels -// INTENTIONALLY not doing that quite yet - I believe we need to let the references to unused channels die for the GC to automatically collect -- can we manually free()? -// finalResultsChannel is the primary async return path: one result for the entire request. -// promptResultsChannels is DUBIOUS. It's expected to be raw fan-out used within the function itself, but I am exposing for testing? One bundle of LLMResponseBundle per PromptString? Gets all N completions for a single prompt. -// completionsChannel is a channel that emits one *LLMResponse per generated completion, be that different prompts or N. Seems the most useful other than "entire request" Request is available to attempt tracing??? -// tokensChannel is a channel that emits one *LLMResponse per generated token. Let's see what happens! -func (oais *OpenAIService) Completion(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { - return endpointConfiguration{ - SchemaObject: "text_completion", - TemplatePath: bc.TemplateConfig.Completion, - TemplateData: model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - }, - ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { - return schema.Choice{ - Index: promptIndex, - FinishReason: "stop", - Text: resp.Response, - } - }, - CompletionMappingFn: simpleMapper, - TokenMappingFn: simpleMapper, - } - }, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) Edit(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { - - return endpointConfiguration{ - SchemaObject: "edit", - TemplatePath: bc.TemplateConfig.Edit, - TemplateData: model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - Instruction: request.Instruction, - }, - ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { - return schema.Choice{ - Index: promptIndex, - FinishReason: "stop", - Text: resp.Response, - } - }, - CompletionMappingFn: simpleMapper, - TokenMappingFn: simpleMapper, - } - }, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) Chat(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateFromMultipleMessagesChatRequest(request, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest, endpointConfigFn endpointGenerationConfigurationFn, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - if initialTraceID == nil { - traceID = &OpenAIRequestTraceID{ - ID: uuid.New().String(), - Created: int(time.Now().Unix()), - } - } else { - traceID = initialTraceID - } - - bc, request, err := oais.getConfig(request) - if err != nil { - log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration") - return - } - - if request.ResponseFormat.Type == "json_object" { - request.Grammar = grammar.JSONBNF - } - - bc.Grammar = request.Grammar - - if request.Stream && len(bc.PromptStrings) > 1 { - log.Warn().Msg("potentially cannot handle more than 1 `PromptStrings` when Streaming?") - } - - rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - finalResultChannel = rawFinalResultChannel - promptResultsChannels = []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle]{} - var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - if notifyOnPromptResult { - rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - if notifyOnToken { - rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - - promptResultsChannelLock := sync.Mutex{} - - endpointConfig := endpointConfigFn(bc, request) - - if len(endpointConfig.TemplatePath) == 0 { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { - endpointConfig.TemplatePath = bc.Model - } else { - log.Warn().Msgf("failed to find any template for %+v", request) - } - } - - setupWG := sync.WaitGroup{} - var prompts []string - if lPS := len(bc.PromptStrings); lPS > 0 { - setupWG.Add(lPS) - prompts = bc.PromptStrings - } else { - setupWG.Add(len(bc.InputStrings)) - prompts = bc.InputStrings - } - - var setupError error = nil - - for pI, p := range prompts { - - go func(promptIndex int, prompt string) { - if endpointConfig.TemplatePath != "" { - promptTemplateData := model.PromptTemplateData{ - Input: prompt, - } - err := mergo.Merge(promptTemplateData, endpointConfig.TemplateData, mergo.WithOverride) - if err == nil { - templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, endpointConfig.TemplatePath, promptTemplateData) - if err == nil { - prompt = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", prompt) - } - } - } - - log.Debug().Msgf("[OAIS GenerateTextFromRequest] Prompt: %q", prompt) - promptResultsChannel, completionChannels, tokenChannels, err := oais.llmbs.GenerateText(prompt, request, bc, - func(r *backend.LLMResponse) schema.Choice { - return endpointConfig.ResultMappingFn(r, promptIndex) - }, notifyOnPromptResult, notifyOnToken) - if err != nil { - log.Error().Msgf("Unable to generate text prompt: %q\nerr: %q", prompt, err) - promptResultsChannelLock.Lock() - setupError = errors.Join(setupError, err) - promptResultsChannelLock.Unlock() - setupWG.Done() - return - } - if notifyOnPromptResult { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(completionChannels, endpointConfig.CompletionMappingFn), rawCompletionsChannel, true) - } - if notifyOnToken { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, endpointConfig.TokenMappingFn), rawTokenChannel, true) - } - promptResultsChannelLock.Lock() - promptResultsChannels = append(promptResultsChannels, promptResultsChannel) - promptResultsChannelLock.Unlock() - setupWG.Done() - }(pI, p) - - } - setupWG.Wait() - - // If any of the setup goroutines experienced an error, quit early here. - if setupError != nil { - go func() { - log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup") - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError} - close(rawFinalResultChannel) - }() - return - } - - initialResponse := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, - Object: endpointConfig.SchemaObject, - Usage: schema.OpenAIUsage{}, - } - - // utils.SliceOfChannelsRawMerger[[]schema.Choice](promptResultsChannels, rawFinalResultChannel, func(results []schema.Choice) (*schema.OpenAIResponse, error) { - concurrency.SliceOfChannelsReducer( - promptResultsChannels, rawFinalResultChannel, - func(iv concurrency.ErrorOr[*backend.LLMResponseBundle], result concurrency.ErrorOr[*schema.OpenAIResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - - if iv.Error != nil { - result.Error = iv.Error - return result - } - result.Value.Usage.PromptTokens += iv.Value.Usage.Prompt - result.Value.Usage.CompletionTokens += iv.Value.Usage.Completion - result.Value.Usage.TotalTokens = result.Value.Usage.PromptTokens + result.Value.Usage.CompletionTokens - - result.Value.Choices = append(result.Value.Choices, iv.Value.Response...) - - return result - }, concurrency.ErrorOr[*schema.OpenAIResponse]{Value: initialResponse}, true) - - completionsChannel = rawCompletionsChannel - tokenChannel = rawTokenChannel - - return -} - -// TODO: For porting sanity, this is distinct from GenerateTextFromRequest and is _currently_ specific to Chat purposes -// this is not a final decision -- just a reality of moving a lot of parts at once -// / This has _become_ Chat which wasn't the goal... More cleanup in the future once it's stable? -func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - if initialTraceID == nil { - traceID = &OpenAIRequestTraceID{ - ID: uuid.New().String(), - Created: int(time.Now().Unix()), - } - } else { - traceID = initialTraceID - } - - bc, request, err := oais.getConfig(request) - if err != nil { - return - } - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if bc.FunctionsConfig.NoActionFunctionName != "" { - noActionName = bc.FunctionsConfig.NoActionFunctionName - } - if bc.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = bc.FunctionsConfig.NoActionDescriptionName - } - - if request.ResponseFormat.Type == "json_object" { - request.Grammar = grammar.JSONBNF - } - - bc.Grammar = request.Grammar - - processFunctions := false - funcs := grammar.Functions{} - // process functions if we have any defined or if we have a function call string - if len(request.Functions) > 0 && bc.ShouldUseFunctions() { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, request.Functions...) - if !bc.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if bc.FunctionToCall() != "" { - funcs = funcs.Select(bc.FunctionToCall()) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - bc.Grammar = jsStruct.Grammar("", bc.FunctionsConfig.ParallelCalls) - } else if request.JSONFunctionGrammarObject != nil { - bc.Grammar = request.JSONFunctionGrammarObject.Grammar("", bc.FunctionsConfig.ParallelCalls) - } - - if request.Stream && processFunctions { - log.Warn().Msg("Streaming + Functions is highly experimental in this version") - } - - var predInput string - - if !bc.TemplateConfig.UseTokenizerTemplate || processFunctions { - - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range request.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := bc.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := bc.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - - fcall := i.FunctionCall - if len(i.ToolCalls) > 0 { - fcall = i.ToolCalls - } - - // First attempt to populate content via a chat message specific template - if bc.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: bc.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - FunctionCall: fcall, - FunctionName: i.Name, - LastMessage: messageIndex == (len(request.Messages) - 1), - Function: bc.Grammar != "" && (messageIndex == (len(request.Messages) - 1)), - MessageIndex: messageIndex, - } - templatedChatMessage, err := oais.ml.EvaluateTemplateForChatMessage(bc.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, bc.TemplateConfig.ChatMessage, err) - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", bc.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - marshalAnyRole := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - marshalAny := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. - if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - - if i.FunctionCall != nil { - marshalAnyRole(i.FunctionCall) - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - - if i.FunctionCall != nil { - marshalAny(i.FunctionCall) - } - - if i.ToolCalls != nil { - marshalAny(i.ToolCalls) - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - templateFile := "" - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { - templateFile = bc.Model - } - - if bc.TemplateConfig.Chat != "" && !processFunctions { - templateFile = bc.TemplateConfig.Chat - } - - if bc.TemplateConfig.Functions != "" && processFunctions { - templateFile = bc.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } - } - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", bc.Grammar) - } - - rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - if notifyOnPromptResult { - rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - if notifyOnToken { - rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - - rawResultChannel, individualCompletionChannels, tokenChannels, err := oais.llmbs.GenerateText(predInput, request, bc, func(resp *backend.LLMResponse) schema.Choice { - return schema.Choice{ - Index: 0, // ??? - FinishReason: "stop", - Message: &schema.Message{ - Role: "assistant", - Content: resp.Response, - }, - } - }, notifyOnPromptResult, notifyOnToken) - - chatSimpleMappingFn := func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - if resp.Error != nil || resp.Value == nil { - return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} - } - return concurrency.ErrorOr[*schema.OpenAIResponse]{ - Value: &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Delta: &schema.Message{ - Role: "assistant", - Content: resp.Value.Response, - }, - Index: 0, - }, - }, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: resp.Value.Usage.Prompt, - CompletionTokens: resp.Value.Usage.Completion, - TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, - }, - }, - } - } - - if notifyOnPromptResult { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(individualCompletionChannels, chatSimpleMappingFn), rawCompletionsChannel, true) - } - if notifyOnToken { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, chatSimpleMappingFn), rawTokenChannel, true) - } - - go func() { - rawResult := <-rawResultChannel - if rawResult.Error != nil { - log.Warn().Msgf("OpenAIService::processTools GenerateText error [DEBUG THIS?] %q", rawResult.Error) - return - } - llmResponseChoices := rawResult.Value.Response - - if processFunctions && len(llmResponseChoices) > 1 { - log.Warn().Msgf("chat functions response with %d choices in response, debug this?", len(llmResponseChoices)) - log.Debug().Msgf("%+v", llmResponseChoices) - } - - for _, result := range rawResult.Value.Response { - // If no functions, just return the raw result. - if !processFunctions { - - resp := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{result}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: rawResult.Value.Usage.Prompt, - CompletionTokens: rawResult.Value.Usage.Completion, - TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, - }, - } - - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} - - continue - } - // At this point, things are function specific! - - // Oh no this can't be the right way to do this... but it works. Save us, mudler! - fString := fmt.Sprintf("%s", result.Message.Content) - results := parseFunctionCall(fString, bc.FunctionsConfig.ParallelCalls) - noActionToRun := (len(results) > 0 && results[0].name == noActionName) - - if noActionToRun { - log.Debug().Msg("-- noActionToRun branch --") - initialMessage := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: ""}}}, - Object: "stop", - } - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} - - result, err := oais.handleQuestion(bc, request, results[0].arguments, predInput) - if err != nil { - log.Error().Msgf("error handling question: %s", err.Error()) - return - } - - resp := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: rawResult.Value.Usage.Prompt, - CompletionTokens: rawResult.Value.Usage.Completion, - TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, - }, - } - - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} - - } else { - log.Debug().Msgf("[GenerateFromMultipleMessagesChatRequest] fnResultsBranch: %+v", results) - for i, ss := range results { - name, args := ss.name, ss.arguments - - initialMessage := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - FinishReason: "function_call", - Message: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: i, - ID: traceID.ID, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} - } - } - } - - close(rawFinalResultChannel) - }() - - finalResultChannel = rawFinalResultChannel - completionsChannel = rawCompletionsChannel - tokenChannel = rawTokenChannel - return -} - -func (oais *OpenAIService) handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, args, prompt string) (string, error) { - log.Debug().Msgf("[handleQuestion called] nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(args), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = oais.llmbs.Finetune(*config, prompt, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - return message, nil - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU/GPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) - } - - resultChannel, _, err := oais.llmbs.Inference(input.Context, &backend.LLMRequest{ - Text: prompt, - Images: images, - RawMessages: input.Messages, // Experimental - }, config, false) - - if err != nil { - log.Error().Msgf("inference setup error: %s", err.Error()) - return "", err - } - - raw := <-resultChannel - if raw.Error != nil { - log.Error().Msgf("inference error: %q", raw.Error.Error()) - return "", err - } - if raw.Value == nil { - log.Warn().Msgf("nil inference response") - return "", nil - } - return oais.llmbs.Finetune(*config, prompt, raw.Value.Response), nil -} - -type funcCallResults struct { - name string - arguments string -} - -func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { - - results := []funcCallResults{} - - // TODO: use generics to avoid this code duplication - if multipleResults { - ss := []map[string]interface{}{} - s := utils.EscapeNewLines(llmresult) - json.Unmarshal([]byte(s), &ss) - - for _, s := range ss { - func_name, ok := s["function"] - if !ok { - continue - } - args, ok := s["arguments"] - if !ok { - continue - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - continue - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - } else { - // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s := utils.EscapeNewLines(llmresult) - if err := json.Unmarshal([]byte(s), &ss); err != nil { - log.Error().Msgf("error unmarshalling JSON: %s", err.Error()) - return results - } - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss["function"] - if !ok { - log.Debug().Msgf("ss[function] is not OK!, llm result: %q", llmresult) - return results - } - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - if !ok { - log.Debug().Msg("ss[arguments] is not OK!") - return results - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - log.Debug().Msgf("unexpected func_name: %+v", func_name) - return results - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - return results -} diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go index 9c758e252ddc..5f6834d424a2 100644 --- a/core/startup/config_file_watcher.go +++ b/core/startup/config_file_watcher.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path" + "time" "github.com/fsnotify/fsnotify" "github.com/go-skynet/LocalAI/core/config" @@ -12,89 +13,157 @@ import ( "github.com/rs/zerolog/log" ) -type WatchConfigDirectoryCloser func() error - -func ReadApiKeysJson(configDir string, appConfig *config.ApplicationConfig) error { - fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json")) - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err == nil { - appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) - return nil - } - return err - } - return err +type fileHandler func(fileContent []byte, appConfig *config.ApplicationConfig) error + +type configFileHandler struct { + handlers map[string]fileHandler + + watcher *fsnotify.Watcher + + configDir string + appConfig *config.ApplicationConfig } -func ReadExternalBackendsJson(configDir string, appConfig *config.ApplicationConfig) error { - fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json")) - if err != nil { - return err +// TODO: This should be a singleton eventually so other parts of the code can register config file handlers, +// then we can export it to other packages +func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler { + c := configFileHandler{ + handlers: make(map[string]fileHandler), + configDir: appConfig.DynamicConfigsDir, + appConfig: appConfig, } - // Parse JSON content from the file - var fileBackends map[string]string - err = json.Unmarshal(fileContent, &fileBackends) - if err != nil { - return err + c.Register("api_keys.json", readApiKeysJson(*appConfig), true) + c.Register("external_backends.json", readExternalBackendsJson(*appConfig), true) + return c +} + +func (c *configFileHandler) Register(filename string, handler fileHandler, runNow bool) error { + _, ok := c.handlers[filename] + if ok { + return fmt.Errorf("handler already registered for file %s", filename) } - err = mergo.Merge(&appConfig.ExternalGRPCBackends, fileBackends) - if err != nil { - return err + c.handlers[filename] = handler + if runNow { + c.callHandler(path.Join(c.appConfig.DynamicConfigsDir, filename), handler) } return nil } -var CONFIG_FILE_UPDATES = map[string]func(configDir string, appConfig *config.ApplicationConfig) error{ - "api_keys.json": ReadApiKeysJson, - "external_backends.json": ReadExternalBackendsJson, -} +func (c *configFileHandler) callHandler(filename string, handler fileHandler) { + fileContent, err := os.ReadFile(filename) + if err != nil && !os.IsNotExist(err) { + log.Error().Err(err).Str("filename", filename).Msg("could not read file") + } -func WatchConfigDirectory(configDir string, appConfig *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) { - if len(configDir) == 0 { - return nil, fmt.Errorf("configDir blank") + if err = handler(fileContent, c.appConfig); err != nil { + log.Error().Err(err).Msg("WatchConfigDirectory goroutine failed to update options") } +} + +func (c *configFileHandler) Watch() error { configWatcher, err := fsnotify.NewWatcher() + c.watcher = configWatcher if err != nil { - log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err) + log.Fatal().Err(err).Str("configdir", c.configDir).Msg("wnable to create a watcher for configuration directory") } - ret := func() error { - configWatcher.Close() - return nil + + if c.appConfig.DynamicConfigsDirPollInterval > 0 { + log.Debug().Msg("Poll interval set, falling back to polling for configuration changes") + ticker := time.NewTicker(c.appConfig.DynamicConfigsDirPollInterval) + go func() { + for { + <-ticker.C + for file, handler := range c.handlers { + log.Debug().Str("file", file).Msg("polling config file") + c.callHandler(file, handler) + } + } + }() } // Start listening for events. go func() { for { select { - case event, ok := <-configWatcher.Events: + case event, ok := <-c.watcher.Events: if !ok { return } - if event.Has(fsnotify.Write) { - for targetName, watchFn := range CONFIG_FILE_UPDATES { - if event.Name == targetName { - err := watchFn(configDir, appConfig) - log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err) - } + if event.Has(fsnotify.Write | fsnotify.Create | fsnotify.Remove) { + handler, ok := c.handlers[path.Base(event.Name)] + if !ok { + continue } + + c.callHandler(event.Name, handler) } - case _, ok := <-configWatcher.Errors: + case err, ok := <-c.watcher.Errors: + log.Error().Err(err).Msg("config watcher error received") if !ok { return } - log.Error().Err(err).Msg("error encountered while watching config directory") } } }() // Add a path. - err = configWatcher.Add(configDir) + err = c.watcher.Add(c.appConfig.DynamicConfigsDir) if err != nil { - return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err) + return fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err) } - return ret, nil + return nil +} + +// TODO: When we institute graceful shutdown, this should be called +func (c *configFileHandler) Stop() { + c.watcher.Close() +} + +func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler { + handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { + log.Debug().Msg("processing api_keys.json") + + if len(fileContent) > 0 { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err != nil { + return err + } + + appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...) + } else { + appConfig.ApiKeys = startupAppConfig.ApiKeys + } + log.Debug().Msg("api keys loaded from api_keys.json") + return nil + } + + return handler +} + +func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler { + handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { + log.Debug().Msg("processing external_backends.json") + + if len(fileContent) > 0 { + // Parse JSON content from the file + var fileBackends map[string]string + err := json.Unmarshal(fileContent, &fileBackends) + if err != nil { + return err + } + appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends + err = mergo.Merge(&appConfig.ExternalGRPCBackends, &fileBackends) + if err != nil { + return err + } + } else { + appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends + } + log.Debug().Msg("external backends loaded from external_backends.json") + return nil + } + return handler } diff --git a/core/startup/startup.go b/core/startup/startup.go index 92ccaa9d616a..97882a22126a 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -4,102 +4,85 @@ import ( "fmt" "os" - "github.com/go-skynet/LocalAI/core" - "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - openaiendpoint "github.com/go-skynet/LocalAI/core/http/endpoints/openai" // TODO: This is dubious. Fix this when splitting assistant api up. "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog" + pkgStartup "github.com/go-skynet/LocalAI/pkg/startup" "github.com/rs/zerolog/log" ) -// (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { -func Startup(opts ...config.AppOption) (*core.Application, error) { +func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { options := config.NewApplicationConfig(opts...) - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) // Make sure directories exists if options.ModelPath == "" { - return nil, fmt.Errorf("options.ModelPath cannot be empty") + return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty") } err := os.MkdirAll(options.ModelPath, 0755) if err != nil { - return nil, fmt.Errorf("unable to create ModelPath: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err) } if options.ImageDir != "" { err := os.MkdirAll(options.ImageDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create ImageDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err) } } if options.AudioDir != "" { err := os.MkdirAll(options.AudioDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create AudioDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err) } } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create UploadDir: %q", err) - } - } - if options.ConfigsDir != "" { - err := os.MkdirAll(options.ConfigsDir, 0755) - if err != nil { - return nil, fmt.Errorf("unable to create ConfigsDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err) } } - // Load config jsons - utils.LoadConfig(options.UploadDir, openaiendpoint.UploadedFilesFile, &openaiendpoint.UploadedFiles) - utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsConfigFile, &openaiendpoint.Assistants) - utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsFileConfigFile, &openaiendpoint.AssistantFiles) + // + pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) - app := createApplication(options) + cl := config.NewBackendConfigLoader() + ml := model.NewModelLoader(options.ModelPath) - services.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) + configLoaderOpts := options.ToConfigLoaderOptions() - if err := app.BackendConfigLoader.LoadBackendConfigsFromPath(options.ModelPath, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { + if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config files") } if options.ConfigFile != "" { - if err := app.BackendConfigLoader.LoadBackendConfigFile(options.ConfigFile, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { + if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config file") } } - if err := app.BackendConfigLoader.Preload(options.ModelPath); err != nil { + if err := cl.Preload(options.ModelPath); err != nil { log.Error().Err(err).Msg("error downloading models") } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, app.BackendConfigLoader, options.Galleries); err != nil { - return nil, err + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, app.BackendConfigLoader, options.Galleries); err != nil { - return nil, err + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, nil, err } } if options.Debug { - for _, v := range app.BackendConfigLoader.ListBackendConfigs() { - cfg, _ := app.BackendConfigLoader.GetBackendConfig(v) + for _, v := range cl.ListBackendConfigs() { + cfg, _ := cl.GetBackendConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } @@ -117,17 +100,17 @@ func Startup(opts ...config.AppOption) (*core.Application, error) { go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") - app.ModelLoader.StopAllGRPC() + ml.StopAllGRPC() }() if options.WatchDog { wd := model.NewWatchDog( - app.ModelLoader, + ml, options.WatchDogBusyTimeout, options.WatchDogIdleTimeout, options.WatchDogBusy, options.WatchDogIdle) - app.ModelLoader.SetWatchDog(wd) + ml.SetWatchDog(wd) go wd.Run() go func() { <-options.Context.Done() @@ -136,36 +119,11 @@ func Startup(opts ...config.AppOption) (*core.Application, error) { }() } - log.Info().Msg("core/startup process completed!") - return app, nil -} + // Watch the configuration directory + // If the directory does not exist, we don't watch it + configHandler := newConfigFileHandler(options) + configHandler.Watch() -// In Lieu of a proper DI framework, this function wires up the Application manually. -// This is in core/startup rather than core/state.go to keep package references clean! -func createApplication(appConfig *config.ApplicationConfig) *core.Application { - app := &core.Application{ - ApplicationConfig: appConfig, - BackendConfigLoader: config.NewBackendConfigLoader(), - ModelLoader: model.NewModelLoader(appConfig.ModelPath), - } - - var err error - - app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - - app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath) - app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) - - app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() - if err != nil { - log.Warn().Msg("Unable to initialize LocalAIMetricsService - non-fatal, optional service") - } - - return app + log.Info().Msg("core/startup process completed!") + return cl, ml, options, nil } diff --git a/core/state.go b/core/state.go deleted file mode 100644 index cf0d614b05d2..000000000000 --- a/core/state.go +++ /dev/null @@ -1,41 +0,0 @@ -package core - -import ( - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" -) - -// TODO: Can I come up with a better name or location for this? -// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy -// Perhaps a proper DI system is worth it in the future, but for now keep things simple. -type Application struct { - - // Application-Level Config - ApplicationConfig *config.ApplicationConfig - // ApplicationState *ApplicationState - - // Core Low-Level Services - BackendConfigLoader *config.BackendConfigLoader - ModelLoader *model.ModelLoader - - // Backend Services - EmbeddingsBackendService *backend.EmbeddingsBackendService - ImageGenerationBackendService *backend.ImageGenerationBackendService - LLMBackendService *backend.LLMBackendService - TranscriptionBackendService *backend.TranscriptionBackendService - TextToSpeechBackendService *backend.TextToSpeechBackendService - - // LocalAI System Services - BackendMonitorService *services.BackendMonitorService - GalleryService *services.GalleryService - ListModelsService *services.ListModelsService - LocalAIMetricsService *services.LocalAIMetricsService - OpenAIService *services.OpenAIService -} - -// TODO [NEXT PR?]: Break up ApplicationConfig. -// Migrate over stuff that is not set via config at all - especially runtime stuff -type ApplicationState struct { -} diff --git a/docs/content/docs/advanced/advanced-usage.md b/docs/content/docs/advanced/advanced-usage.md index 4bd160308bbb..cbf7dba34122 100644 --- a/docs/content/docs/advanced/advanced-usage.md +++ b/docs/content/docs/advanced/advanced-usage.md @@ -402,6 +402,7 @@ In the help text below, BASEPATH is the location that local-ai is being executed | --upload-path | /tmp/localai/upload | Path to store uploads from files api | $LOCALAI_UPLOAD_PATH | | --config-path | /tmp/localai/config | | $LOCALAI_CONFIG_PATH | | --localai-config-dir | BASEPATH/configuration | Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json) | $LOCALAI_CONFIG_DIR | +| --localai-config-dir-poll-interval | | Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to a time duration to poll the LocalAI Config Dir (example: 1m) | $LOCALAI_CONFIG_DIR_POLL_INTERVAL | | --models-config-file | STRING | YAML file containing a list of model backend configs | $LOCALAI_MODELS_CONFIG_FILE | #### Models Flags diff --git a/docs/content/docs/features/GPU-acceleration.md b/docs/content/docs/features/GPU-acceleration.md index aa931f07dd44..b382309ec318 100644 --- a/docs/content/docs/features/GPU-acceleration.md +++ b/docs/content/docs/features/GPU-acceleration.md @@ -12,7 +12,7 @@ Section under construction This section contains instruction on how to use LocalAI with GPU acceleration. {{% alert icon="âš¡" context="warning" %}} -For accelleration for AMD or Metal HW there are no specific container images, see the [build]({{%relref "docs/getting-started/build#Acceleration" %}}) +For accelleration for AMD or Metal HW is still in development, for additional details see the [build]({{%relref "docs/getting-started/build#Acceleration" %}}) {{% /alert %}} @@ -110,6 +110,143 @@ llama_model_load_internal: total VRAM used: 1598 MB llama_init_from_file: kv self size = 512.00 MB ``` +## ROCM(AMD) acceleration + +There are a limited number of tested configurations for ROCm systems however most newer deditated GPU consumer grade devices seem to be supported under the current ROCm6 implementation. + +Due to the nature of ROCm it is best to run all implementations in containers as this limits the number of packages required for installation on host system, compatability and package versions for dependencies across all variations of OS must be tested independently if disired, please refer to the [build]({{%relref "docs/getting-started/build#Acceleration" %}}) documentation. + +### Requirements + +- `ROCm 6.x.x` compatible GPU/accelerator +- OS: `Ubuntu` (22.04, 20.04), `RHEL` (9.3, 9.2, 8.9, 8.8), `SLES` (15.5, 15.4) +- Installed to host: `amdgpu-dkms` and `rocm` >=6.0.0 as per ROCm documentation. + +### Recommendations + +- Do not use on a system running Wayland. +- If running with Xorg do not use GPU assigned for compute for desktop rendering. +- Ensure at least 100GB of free space on disk hosting container runtime and storing images prior to installation. + +### Limitations + +Ongoing verification testing of ROCm compatability with integrated backends. +Please note the following list of verified backends and devices. + +### Verified + +The devices in the following list have been tested with `hipblas` images running `ROCm 6.0.0` + +| Backend | Verified | Devices | +| ---- | ---- | ---- | +| llama.cpp | yes | Radeon VII (gfx906) | +| diffusers | yes | Radeon VII (gfx906) | +| piper | yes | Radeon VII (gfx906) | +| whisper | no | none | +| autogptq | no | none | +| bark | no | none | +| coqui | no | none | +| transformers | no | none | +| exllama | no | none | +| exllama2 | no | none | +| mamba | no | none | +| petals | no | none | +| sentencetransformers | no | none | +| transformers-musicgen | no | none | +| vall-e-x | no | none | +| vllm | no | none | + +**You can help by expanding this list.** + +### System Prep + +1. Check your GPU LLVM target is compatible with the version of ROCm. This can be found in the [LLVM Docs](https://llvm.org/docs/AMDGPUUsage.html). +2. Check which ROCm version is compatible with your LLVM target and your chosen OS (pay special attention to supported kernel versions). See the following for compatability for ([ROCm 6.0.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.0.0/reference/system-requirements.html)) or ([ROCm 6.0.2](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html)) +3. Install you chosen version of the `dkms` and `rocm` (it is recommended that the native package manager be used for this process for any OS as version changes are executed more easily via this method if updates are required). Take care to restart after installing `amdgpu-dkms` and before installing `rocm`, for details regarding this see the installation documentation for your chosen OS ([6.0.2](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/native-install/index.html) or [6.0.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.0.0/how-to/native-install/index.html)) +4. Deploy. Yes it's that easy. + +#### Setup Example (Docker/containerd) + +The following are examples of the ROCm specific configuration elements required. + +```yaml +# docker-compose.yaml + # For full functionality select a non-'core' image, version locking the image is recommended for debug purposes. + image: quay.io/go-skynet/local-ai:master-aio-gpu-hipblas + environment: + - DEBUG=true + # If your gpu is not already included in the current list of default targets the following build details are required. + - REBUILD=true + - BUILD_TYPE=hipblas + - GPU_TARGETS=gfx906 # Example for Radeon VII + devices: + # AMD GPU only require the following devices be passed through to the container for offloading to occur. + - /dev/dri + - /dev/kfd +``` + +The same can also be executed as a `run` for your container runtime + +``` +docker run \ + -e DEBUG=true \ + -e REBUILD=true \ + -e BUILD_TYPE=hipblas \ + -e GPU_TARGETS=gfx906 \ + --device /dev/dri \ + --device /dev/kfd \ + quay.io/go-skynet/local-ai:master-aio-gpu-hipblas +``` + +Please ensure to add all other required environment variables, port forwardings, etc to your `compose` file or `run` command. + +The rebuild process will take some time to complete when deploying these containers and it is recommended that you `pull` the image prior to deployment as depending on the version these images may be ~20GB in size. + +#### Example (k8s) (Advanced Deployment/WIP) + +For k8s deployments there is an additional step required before deployment, this is the deployment of the [ROCm/k8s-device-plugin](https://artifacthub.io/packages/helm/amd-gpu-helm/amd-gpu). +For any k8s environment the documentation provided by AMD from the ROCm project should be successful. It is recommended that if you use rke2 or OpenShift that you deploy the SUSE or RedHat provided version of this resource to ensure compatability. +After this has been completed the [helm chart from go-skynet](https://github.com/go-skynet/helm-charts) can be configured and deployed mostly un-edited. + +The following are details of the changes that should be made to ensure proper function. +While these details may be configurable in the `values.yaml` development of this Helm chart is ongoing and is subject to change. + +The following details indicate the final state of the localai deployment relevant to GPU function. + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {NAME}-local-ai +... +spec: + ... + template: + ... + spec: + containers: + - env: + - name: HIP_VISIBLE_DEVICES + value: '0' + # This variable indicates the devices availible to container (0:device1 1:device2 2:device3) etc. + # For multiple devices (say device 1 and 3) the value would be equivelant to HIP_VISIBLE_DEVICES="0,2" + # Please take note of this when an iGPU is present in host system as compatability is not assured. + ... + resources: + limits: + amd.com/gpu: '1' + requests: + amd.com/gpu: '1' +``` + +This configuration has been tested on a 'custom' cluster managed by SUSE Rancher that was deployed on top of Ubuntu 22.04.4, certification of other configuration is ongoing and compatability is not gauranteed. + +### Notes + +- When installing the ROCM kernel driver on your system ensure that you are installing an equal or newer version that that which is currently implemented in LocalAI (6.0.0 at time of writing). +- AMD documentation indicates that this will ensure functionality however your milage may vary depending on the GPU and distro you are using. +- If you encounter an `Error 413` on attempting to upload an audio file or image for whisper or llava/bakllava on a k8s deployment, note that the ingress for your deployment may require the annontation `nginx.ingress.kubernetes.io/proxy-body-size: "25m"` to allow larger uploads. This may be included in future versions of the helm chart. + ## Intel acceleration (sycl) ### Requirements diff --git a/docs/content/docs/features/text-generation.md b/docs/content/docs/features/text-generation.md index c11894e7e409..3f3f0b56ef00 100644 --- a/docs/content/docs/features/text-generation.md +++ b/docs/content/docs/features/text-generation.md @@ -257,6 +257,10 @@ parameters: # swap_space: 2 # Uncomment to specify the maximum length of a sequence (including prompt and output) # max_model_len: 32768 +# Uncomment and specify the number of Tensor divisions. +# Allows you to partition and run large models. Performance gains are limited. +# https://github.com/vllm-project/vllm/issues/1435 +# tensor_parallel_size: 2 ``` The backend will automatically download the required files in order to run the model. @@ -356,4 +360,4 @@ template: completion: | {{.Input}} -``` \ No newline at end of file +``` diff --git a/docs/content/docs/reference/aio-images.md b/docs/content/docs/reference/aio-images.md index 40f01f06d933..b5253ee49ccd 100644 --- a/docs/content/docs/reference/aio-images.md +++ b/docs/content/docs/reference/aio-images.md @@ -9,13 +9,14 @@ All-In-One images are images that come pre-configured with a set of models and b In the AIO images there are models configured with the names of OpenAI models, however, they are really backed by Open Source models. You can find the table below -| Category | Model name | Real model | -| Text Generation | `gpt-4` | `phi-2`(CPU) or `hermes-2-pro-mistral`(GPU) | -| Multimodal | `gpt-4-vision-preview` | `bakllava`(CPU) or `llava-1.6-mistral`(GPU) | -| Text generation | `stablediffusion` | `stablediffusion`(CPU) `dreamshaper-8` (GPU) | -| Audio transcription | `whisper-1` | `whisper` with the `whisper-base` model | -| Text to Audio | `tts-1` | the `en-us-amy-low.onnx` model with `rhasspy` | -| Embeddings | `text-embedding-ada-002` | | +| Category | Model name | Real model (CPU) | Real model (GPU) | +| ---- | ---- | ---- | ---- | +| Text Generation | `gpt-4` | `phi-2` | `hermes-2-pro-mistral` | +| Multimodal Vision | `gpt-4-vision-preview` | `bakllava` | `llava-1.6-mistral` | +| Image Generation | `stablediffusion` | `stablediffusion` | `dreamshaper-8` | +| Speech to Text | `whisper-1` | `whisper` with `whisper-base` model | <= same | +| Text to Speech | `tts-1` | `en-us-amy-low.onnx` from `rhasspy/piper` | <= same | +| Embeddings | `text-embedding-ada-002` | `all-MiniLM-L6-v2` in Q4 | `all-MiniLM-L6-v2` | ## Usage diff --git a/docs/data/version.json b/docs/data/version.json index 6a618115d5e2..55eebaebe995 100644 --- a/docs/data/version.json +++ b/docs/data/version.json @@ -1,3 +1,3 @@ { - "version": "v2.12.4" + "version": "null" } diff --git a/embedded/models/hermes-2-pro-mistral.yaml b/embedded/models/hermes-2-pro-mistral.yaml index 7bfa94180548..dd18ce6f862a 100644 --- a/embedded/models/hermes-2-pro-mistral.yaml +++ b/embedded/models/hermes-2-pro-mistral.yaml @@ -6,14 +6,22 @@ parameters: template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} - {{- if .FunctionCall }}{{end}} - {{- if eq .RoleName "tool" }}{{end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + + {{- end }} {{- if .Content}} - {{.Content}} + {{.Content }} + {{- end }} + {{- if .FunctionCall}} + {{toJson .FunctionCall}} + {{- end }} + {{- if .FunctionCall }} + + {{- else if eq .RoleName "tool" }} + {{- end }} - {{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }} - {{- if .FunctionCall }}{{end }} - {{- if eq .RoleName "tool" }}{{end }} <|im_end|> # https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling function: | diff --git a/embedded/models/llama3-instruct.yaml b/embedded/models/llama3-instruct.yaml new file mode 100644 index 000000000000..d483d2b2a16e --- /dev/null +++ b/embedded/models/llama3-instruct.yaml @@ -0,0 +1,48 @@ +name: llama3-8b-instruct +mmap: true +parameters: + model: huggingface://second-state/Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q5_K_M.gguf + +template: + chat_message: | + <|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|> + + {{ if .FunctionCall -}} + Function call: + {{ else if eq .RoleName "tool" -}} + Function response: + {{ end -}} + {{ if .Content -}} + {{.Content -}} + {{ else if .FunctionCall -}} + {{ toJson .FunctionCall -}} + {{ end -}} + <|eot_id|> + function: | + <|start_header_id|>system<|end_header_id|> + + You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: + + {{range .Functions}} + {'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }} + {{end}} + + Use the following pydantic model json schema for each tool call you will make: + {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'}, 'name': {'title': 'Name', 'type': 'string'}}, 'required': ['arguments', 'name']}<|eot_id|><|start_header_id|>assistant<|end_header_id|> + Function call: + chat: | + <|begin_of_text|>{{.Input }} + <|start_header_id|>assistant<|end_header_id|> + completion: | + {{.Input}} +context_size: 8192 +f16: true +stopwords: +- <|im_end|> +- +- "<|eot_id|>" +usage: | + curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "llama3-8b-instruct", + "messages": [{"role": "user", "content": "How are you doing?", "temperature": 0.1}] + }' diff --git a/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru b/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru deleted file mode 100644 index c33bafe1e778..000000000000 --- a/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru +++ /dev/null @@ -1,25 +0,0 @@ -meta { - name: -completions Stream - type: http - seq: 4 -} - -post { - url: {{PROTOCOL}}{{HOST}}:{{PORT}}/completions - body: json - auth: none -} - -headers { - Content-Type: application/json -} - -body:json { - { - "model": "{{DEFAULT_MODEL}}", - "prompt": "function downloadFile(string url, string outputPath) {", - "max_tokens": 256, - "temperature": 0.5, - "stream": true - } -} diff --git a/examples/langchain/langchainpy-localai-example/requirements.txt b/examples/langchain/langchainpy-localai-example/requirements.txt index 1e63b0bf9e66..ba7f8429252d 100644 --- a/examples/langchain/langchainpy-localai-example/requirements.txt +++ b/examples/langchain/langchainpy-localai-example/requirements.txt @@ -1,4 +1,4 @@ -aiohttp==3.9.2 +aiohttp==3.9.4 aiosignal==1.3.1 async-timeout==4.0.2 attrs==23.1.0 diff --git a/gallery/bert-embeddings.yaml b/gallery/bert-embeddings.yaml new file mode 100644 index 000000000000..0798bf5433a9 --- /dev/null +++ b/gallery/bert-embeddings.yaml @@ -0,0 +1,15 @@ +name: "bert-embeddings" +license: "Apache 2.0" +urls: +- https://huggingface.co/skeskinen/ggml +description: | + Bert model that can be used for embeddings +config_file: | + parameters: + model: bert-MiniLM-L6-v2q4_0.bin + backend: bert-embeddings + embeddings: true +files: +- filename: "bert-MiniLM-L6-v2q4_0.bin" + sha256: "a5a174d8772c8a569faf9f3136c441f2c3855b5bf35ed32274294219533feaad" + uri: "https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin" \ No newline at end of file diff --git a/gallery/index.yaml b/gallery/index.yaml new file mode 100644 index 000000000000..6b882768ca44 --- /dev/null +++ b/gallery/index.yaml @@ -0,0 +1,503 @@ +## Whisper +- url: "github:mudler/LocalAI/gallery/whisper-base.yaml@master" + name: "whisper-1" + license: other +## Bert embeddings +- url: "github:mudler/LocalAI/gallery/bert-embeddings.yaml@master" + name: "bert-embeddings" + license: other +- url: "github:mudler/LocalAI/gallery/bert-embeddings.yaml@master" + name: "text-embedding-ada-002" + license: other +## Stable Diffusion +- url: github:mudler/LocalAI/gallery/stablediffusion.yaml@master + name: stablediffusion + license: other +## Tiny Dream +- url: github:mudler/LocalAI/gallery/tinydream.yaml@master + name: tinydream + license: other +## Piper TTS +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-kathleen-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-kathleen-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-kathleen-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-ca-upc_ona-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-ca-upc_ona-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-ca-upc_ona-x-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-ca-upc_pau-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-ca-upc_pau-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-ca-upc_pau-x-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-da-nst_talesyntese-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-da-nst_talesyntese-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-da-nst_talesyntese-medium.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-eva_k-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-eva_k-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-eva_k-x-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-karlsson-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-karlsson-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-karlsson-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-kerstin-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-kerstin-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-kerstin-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-pavoque-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-pavoque-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-pavoque-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-ramona-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-ramona-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-ramona-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-de-thorsten-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-de-thorsten-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-de-thorsten-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-el-gr-rapunzelina-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-el-gr-rapunzelina-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-el-gr-rapunzelina-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-gb-alan-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-gb-alan-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-gb-alan-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-gb-southern_english_female-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-gb-southern_english_female-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-gb-southern_english_female-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-amy-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-amy-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-amy-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-danny-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-danny-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-danny-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-kathleen-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-kathleen-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-kathleen-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-lessac-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-lessac-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-lessac-low.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-lessac-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-lessac-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-lessac-medium.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-libritts-high + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-libritts-high.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-libritts-high.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-ryan-high + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-ryan-high.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-ryan-high.tar.gz +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-ryan-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-ryan-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-ryan-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us-ryan-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us-ryan-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us-ryan-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-en-us_lessac + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-en-us_lessac.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-en-us_lessac.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-es-carlfm-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-es-carlfm-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-es-carlfm-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-es-mls_10246-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-es-mls_10246-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-es-mls_10246-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-es-mls_9972-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-es-mls_9972-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-es-mls_9972-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-fi-harri-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-fi-harri-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-fi-harri-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-fr-gilles-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-fr-gilles-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-fr-gilles-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-fr-mls_1840-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-fr-mls_1840-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-fr-mls_1840-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-fr-siwis-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-fr-siwis-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-fr-siwis-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-fr-siwis-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-fr-siwis-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-fr-siwis-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-is-bui-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-is-bui-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-is-bui-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-is-salka-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-is-salka-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-is-salka-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-is-steinn-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-is-steinn-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-is-steinn-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-is-ugla-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-is-ugla-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-is-ugla-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-it-riccardo_fasol-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-it-riccardo_fasol-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-it-riccardo_fasol-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-kk-iseke-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-kk-iseke-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-kk-iseke-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-kk-issai-high + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-kk-issai-high.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-kk-issai-high.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-kk-raya-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-kk-raya-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-kk-raya-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-ne-google-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-ne-google-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-ne-google-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-ne-google-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-ne-google-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-ne-google-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-nl-mls_5809-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-nl-mls_5809-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-nl-mls_5809-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-nl-mls_7432-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-nl-mls_7432-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-nl-mls_7432-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-nl-nathalie-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-nl-nathalie-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-nl-nathalie-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-nl-rdh-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-nl-rdh-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-nl-rdh-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-nl-rdh-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-nl-rdh-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-nl-rdh-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-no-talesyntese-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-no-talesyntese-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-no-talesyntese-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-pl-mls_6892-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-pl-mls_6892-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-pl-mls_6892-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-pt-br-edresson-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-pt-br-edresson-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-pt-br-edresson-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-ru-irinia-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-ru-irinia-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-ru-irinia-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-sv-se-nst-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-sv-se-nst-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-sv-se-nst-medium.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-uk-lada-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-uk-lada-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-uk-lada-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-vi-25hours-single-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-vi-25hours-single-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-vi-25hours-single-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-vi-vivos-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-vi-vivos-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-vi-vivos-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-zh-cn-huayan-x-low + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-zh-cn-huayan-x-low.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-zh-cn-huayan-x-low.tar.gz + +- url: github:mudler/LocalAI/gallery/virtual.yaml@master + name: voice-zh_CN-huayan-medium + license: other + urls: + - https://github.com/rhasspy/piper/releases/download/v0.0.2/ + files: + - filename: voice-zh_CN-huayan-medium.tar.gz + uri: https://github.com/rhasspy/piper/releases/download/v0.0.2/voice-zh_CN-huayan-medium.tar.gz \ No newline at end of file diff --git a/gallery/stablediffusion.yaml b/gallery/stablediffusion.yaml new file mode 100644 index 000000000000..c8a0eb8baa57 --- /dev/null +++ b/gallery/stablediffusion.yaml @@ -0,0 +1,54 @@ +name: "stablediffusion-cpp" +license: "BSD-3" +urls: +- https://github.com/EdVince/Stable-Diffusion-NCNN +- https://github.com/EdVince/Stable-Diffusion-NCNN/blob/main/LICENSE + +description: | + Stable Diffusion in NCNN with c++, supported txt2img and img2img +config_file: | + name: stablediffusion-cpp + backend: stablediffusion + parameters: + model: stablediffusion_assets + +files: +- filename: "stablediffusion_assets/AutoencoderKL-256-256-fp16-opt.param" + sha256: "18ca4b66685e21406bcf64c484b3b680b4949900415536d599cc876579c85c82" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/AutoencoderKL-256-256-fp16-opt.param" +- filename: "stablediffusion_assets/AutoencoderKL-512-512-fp16-opt.param" + sha256: "cf45f63aacf3dbbab0f59ed92a6f2c14d9a1801314631cd3abe91e3c85639a20" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/AutoencoderKL-512-512-fp16-opt.param" +- filename: "stablediffusion_assets/AutoencoderKL-base-fp16.param" + sha256: "0254a056dce61b0c27dc9ec1b78b53bcf55315c540f55f051eb841aa992701ba" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/AutoencoderKL-base-fp16.param" +- filename: "stablediffusion_assets/AutoencoderKL-encoder-512-512-fp16.bin" + sha256: "ddcb79a9951b9f91e05e087739ed69da2c1c4ae30ba4168cce350b49d617c9fa" + uri: "https://github.com/EdVince/Stable-Diffusion-NCNN/releases/download/naifu/AutoencoderKL-encoder-512-512-fp16.bin" +- filename: "stablediffusion_assets/AutoencoderKL-fp16.bin" + sha256: "f02e71f80e70252734724bbfaed5c4ddd3a8ed7e61bb2175ff5f53099f0e35dd" + uri: "https://github.com/EdVince/Stable-Diffusion-NCNN/releases/download/naifu/AutoencoderKL-fp16.bin" +- filename: "stablediffusion_assets/FrozenCLIPEmbedder-fp16.bin" + sha256: "1c9a12f4e1dd1b295a388045f7f28a2352a4d70c3dc96a542189a3dd7051fdd6" + uri: "https://github.com/EdVince/Stable-Diffusion-NCNN/releases/download/naifu/FrozenCLIPEmbedder-fp16.bin" +- filename: "stablediffusion_assets/FrozenCLIPEmbedder-fp16.param" + sha256: "471afbe678dd1fd3fe764ef9c6eccaccb0a7d7e601f27b462aa926b20eb368c9" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/FrozenCLIPEmbedder-fp16.param" +- filename: "stablediffusion_assets/log_sigmas.bin" + sha256: "a2089f8aa4c61f9c200feaec541ab3f5c94233b28deb6d5e8bcd974fa79b68ac" + uri: "https://github.com/EdVince/Stable-Diffusion-NCNN/raw/main/x86/linux/assets/log_sigmas.bin" +- filename: "stablediffusion_assets/UNetModel-256-256-MHA-fp16-opt.param" + sha256: "a58c380229f09491776df837b7aa7adffc0a87821dc4708b34535da2e36e3da1" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/UNetModel-256-256-MHA-fp16-opt.param" +- filename: "stablediffusion_assets/UNetModel-512-512-MHA-fp16-opt.param" + sha256: "f12034067062827bd7f43d1d21888d1f03905401acf6c6eea22be23c259636fa" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/UNetModel-512-512-MHA-fp16-opt.param" +- filename: "stablediffusion_assets/UNetModel-base-MHA-fp16.param" + sha256: "696f6975de49f4325b53ce32aff81861a6d6c07cd9ce3f0aae2cc405350af38d" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/UNetModel-base-MHA-fp16.param" +- filename: "stablediffusion_assets/UNetModel-MHA-fp16.bin" + sha256: "d618918d011bfc1f644c0f2a33bf84931bd53b28a98492b0a8ed6f3a818852c3" + uri: "https://github.com/EdVince/Stable-Diffusion-NCNN/releases/download/naifu/UNetModel-MHA-fp16.bin" +- filename: "stablediffusion_assets/vocab.txt" + sha256: "e30e57b6f1e47616982ef898d8922be24e535b4fa3d0110477b3a6f02ebbae7d" + uri: "https://raw.githubusercontent.com/EdVince/Stable-Diffusion-NCNN/main/x86/linux/assets/vocab.txt" \ No newline at end of file diff --git a/gallery/tinydream.yaml b/gallery/tinydream.yaml new file mode 100644 index 000000000000..415762def9e6 --- /dev/null +++ b/gallery/tinydream.yaml @@ -0,0 +1,42 @@ +name: "tinydream" +license: "BSD-3" +urls: + - https://github.com/symisc/tiny-dream + - https://github.com/symisc/tiny-dream/blob/main/LICENSE + +description: | + An embedded, Header Only, Stable Diffusion C++ implementation +config_file: | + name: tinydream + backend: tinydream + parameters: + model: tinydream_assets + +files: + - filename: "tinydream_assets/AutoencoderKL-fp16.bin" + sha256: "f02e71f80e70252734724bbfaed5c4ddd3a8ed7e61bb2175ff5f53099f0e35dd" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/AutoencoderKL-fp16.bin" + - filename: "tinydream_assets/AutoencoderKL-fp16.param" + sha256: "0254a056dce61b0c27dc9ec1b78b53bcf55315c540f55f051eb841aa992701ba" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/AutoencoderKL-fp16.param" + - filename: "tinydream_assets/FrozenCLIPEmbedder-fp16.bin" + sha256: "1c9a12f4e1dd1b295a388045f7f28a2352a4d70c3dc96a542189a3dd7051fdd6" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/FrozenCLIPEmbedder-fp16.bin" + - filename: "tinydream_assets/FrozenCLIPEmbedder-fp16.param" + sha256: "471afbe678dd1fd3fe764ef9c6eccaccb0a7d7e601f27b462aa926b20eb368c9" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/FrozenCLIPEmbedder-fp16.param" + - filename: "tinydream_assets/RealESRGAN_x4plus_anime.bin" + sha256: "fe01c269cfd10cdef8e018ab66ebe750cf79c7af4d1f9c16c737e1295229bacc" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/RealESRGAN_x4plus_anime.bin" + - filename: "tinydream_assets/RealESRGAN_x4plus_anime.param" + sha256: "2b8fb6e0ae4d2d85704ca08c119a2f5ea40add4f2ecd512eb7f4cd44b6127ed4" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/RealESRGAN_x4plus_anime.param" + - filename: "tinydream_assets/UNetModel-fp16.bin" + sha256: "d618918d011bfc1f644c0f2a33bf84931bd53b28a98492b0a8ed6f3a818852c3" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/UNetModel-fp16.bin" + - filename: "tinydream_assets/UNetModel-fp16.param" + sha256: "696f6975de49f4325b53ce32aff81861a6d6c07cd9ce3f0aae2cc405350af38d" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/UNetModel-fp16.param" + - filename: "tinydream_assets/vocab.txt" + sha256: "e30e57b6f1e47616982ef898d8922be24e535b4fa3d0110477b3a6f02ebbae7d" + uri: "https://github.com/M0Rf30/tiny-dream-bins/releases/download/1.0/vocab.txt" \ No newline at end of file diff --git a/gallery/virtual.yaml b/gallery/virtual.yaml new file mode 100644 index 000000000000..054c3257794f --- /dev/null +++ b/gallery/virtual.yaml @@ -0,0 +1,6 @@ +name: "virtual" + +description: | + A Base model definition + +license: "N/A" \ No newline at end of file diff --git a/gallery/whisper-base.yaml b/gallery/whisper-base.yaml new file mode 100644 index 000000000000..574dbb13f5bf --- /dev/null +++ b/gallery/whisper-base.yaml @@ -0,0 +1,18 @@ +name: "whisper-base" +license: "MIT" +urls: +- https://github.com/ggerganov/whisper.cpp +- https://huggingface.co/ggerganov/whisper.cpp + +description: | + Port of OpenAI's Whisper model in C/C++ + +config_file: | + backend: whisper + parameters: + model: ggml-whisper-base.bin + +files: +- filename: "ggml-whisper-base.bin" + sha256: "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe" + uri: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin" \ No newline at end of file diff --git a/go.mod b/go.mod index 99af8ce7957f..0bf9aa029a81 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/otiai10/openaigo v1.6.0 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/prometheus/client_golang v1.17.0 - github.com/rs/zerolog v1.31.0 + github.com/rs/zerolog v1.32.0 github.com/russross/blackfriday v1.6.0 github.com/sashabaranov/go-openai v1.20.4 github.com/schollz/progressbar/v3 v3.13.1 @@ -145,6 +145,7 @@ require ( github.com/go-audio/riff v1.0.0 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/gofiber/contrib/fiberzerolog v1.0.0 github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect diff --git a/go.sum b/go.sum index a421e79c6850..55fdaf06be6e 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg78 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofiber/contrib/fiberzerolog v1.0.0 h1:IB8q+NO2zPNS4VHKde1x5DqtMJ5vGrvDCydnAjlFw3E= +github.com/gofiber/contrib/fiberzerolog v1.0.0/go.mod h1:SOi+Wo7RQlO/HV0jsYTu6uFQy+8ZPTzCZW4fDEKD3l8= github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM= github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= github.com/gofiber/swagger v1.0.0 h1:BzUzDS9ZT6fDUa692kxmfOjc1DZiloLiPK/W5z1H1tc= @@ -281,6 +283,8 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww= github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY= github.com/sashabaranov/go-openai v1.20.4 h1:095xQ/fAtRa0+Rj21sezVJABgKfGPNbyx/sAN/hJUmg= diff --git a/main.go b/main.go index 8b5696d13589..9976906bed34 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,7 @@ Version: ${version} kong.Vars{ "basepath": kong.ExpandPath("."), "remoteLibraryURL": "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml", + "galleries": `[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml"}]`, "version": internal.PrintableVersion(), }, ) @@ -91,17 +92,20 @@ Version: ${version} switch *cli.CLI.LogLevel { case "error": - log.Info().Msg("Setting logging to error") zerolog.SetGlobalLevel(zerolog.ErrorLevel) + log.Info().Msg("Setting logging to error") case "warn": - log.Info().Msg("Setting logging to warn") zerolog.SetGlobalLevel(zerolog.WarnLevel) + log.Info().Msg("Setting logging to warn") case "info": - log.Info().Msg("Setting logging to info") zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Info().Msg("Setting logging to info") case "debug": - log.Info().Msg("Setting logging to debug") zerolog.SetGlobalLevel(zerolog.DebugLevel) + log.Debug().Msg("Setting logging to debug") + case "trace": + zerolog.SetGlobalLevel(zerolog.TraceLevel) + log.Trace().Msg("Setting logging to trace") } // Populate the application with the embedded backend assets diff --git a/pkg/concurrency/concurrency.go b/pkg/concurrency/concurrency.go deleted file mode 100644 index 324e8cc5afb4..000000000000 --- a/pkg/concurrency/concurrency.go +++ /dev/null @@ -1,135 +0,0 @@ -package concurrency - -import ( - "sync" -) - -// TODO: closeWhenDone bool parameter :: -// It currently is experimental, and therefore exists. -// Is there ever a situation to use false? - -// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of a second type. -// mappingFn allows the caller to convert from the input type to the output type -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsRawMerger[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan IndividualResultType, outputChannel chan<- OutputResultType, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { - var wg sync.WaitGroup - wg.Add(len(individualResultChannels)) - mergingFn := func(c <-chan IndividualResultType) { - for r := range c { - mr, err := mappingFn(r) - if err == nil { - outputChannel <- mr - } - } - wg.Done() - } - for _, irc := range individualResultChannels { - go mergingFn(irc) - } - if closeWhenDone { - go func() { - wg.Wait() - close(outputChannel) - }() - } - - return &wg -} - -// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of THE SAME TYPE. -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsRawMergerWithoutMapping[ResultType any](individualResultsChannels []<-chan ResultType, outputChannel chan<- ResultType, closeWhenDone bool) *sync.WaitGroup { - return SliceOfChannelsRawMerger(individualResultsChannels, outputChannel, func(v ResultType) (ResultType, error) { return v, nil }, closeWhenDone) -} - -// This function is used to merge the results of a slice of channels of a specific result type down to a single succcess result channel of a second type, and an error channel -// mappingFn allows the caller to convert from the input type to the output type -// This variant is designed to be aware of concurrency.ErrorOr[T], splitting successes from failures. -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsMergerWithErrors[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan ErrorOr[IndividualResultType], successChannel chan<- OutputResultType, errorChannel chan<- error, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { - var wg sync.WaitGroup - wg.Add(len(individualResultChannels)) - mergingFn := func(c <-chan ErrorOr[IndividualResultType]) { - for r := range c { - if r.Error != nil { - errorChannel <- r.Error - } else { - mv, err := mappingFn(r.Value) - if err != nil { - errorChannel <- err - } else { - successChannel <- mv - } - } - } - wg.Done() - } - for _, irc := range individualResultChannels { - go mergingFn(irc) - } - if closeWhenDone { - go func() { - wg.Wait() - close(successChannel) - close(errorChannel) - }() - } - return &wg -} - -// This function is used to reduce down the results of a slice of channels of a specific result type down to a single result value of a second type. -// reducerFn allows the caller to convert from the input type to the output type -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsReducer[InputResultType any, OutputResultType any](individualResultsChannels []<-chan InputResultType, outputChannel chan<- OutputResultType, - reducerFn func(iv InputResultType, ov OutputResultType) OutputResultType, initialValue OutputResultType, closeWhenDone bool) (wg *sync.WaitGroup) { - wg = &sync.WaitGroup{} - wg.Add(len(individualResultsChannels)) - reduceLock := sync.Mutex{} - reducingFn := func(c <-chan InputResultType) { - for iv := range c { - reduceLock.Lock() - initialValue = reducerFn(iv, initialValue) - reduceLock.Unlock() - } - wg.Done() - } - for _, irc := range individualResultsChannels { - go reducingFn(irc) - } - go func() { - wg.Wait() - outputChannel <- initialValue - if closeWhenDone { - close(outputChannel) - } - }() - return wg -} - -// This function is primarily designed to be used in combination with the above utility functions. -// A slice of input result channels of a specific type is provided, along with a function to map those values to another type -// A slice of output result channels is returned, where each value is mapped as it comes in. -// The order of the slice will be retained. -func SliceOfChannelsTransformer[InputResultType any, OutputResultType any](inputChanels []<-chan InputResultType, mappingFn func(v InputResultType) OutputResultType) (outputChannels []<-chan OutputResultType) { - rawOutputChannels := make([]<-chan OutputResultType, len(inputChanels)) - - transformingFn := func(ic <-chan InputResultType, oc chan OutputResultType) { - for iv := range ic { - oc <- mappingFn(iv) - } - close(oc) - } - - for ci, c := range inputChanels { - roc := make(chan OutputResultType) - go transformingFn(c, roc) - rawOutputChannels[ci] = roc - } - - outputChannels = rawOutputChannels - return -} diff --git a/pkg/concurrency/concurrency_test.go b/pkg/concurrency/concurrency_test.go deleted file mode 100644 index fedd74be7d62..000000000000 --- a/pkg/concurrency/concurrency_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package concurrency_test - -// TODO: noramlly, these go in utils_tests, right? Why does this cause problems only in pkg/utils? - -import ( - "fmt" - "slices" - - . "github.com/go-skynet/LocalAI/pkg/concurrency" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("utils/concurrency tests", func() { - It("SliceOfChannelsReducer works", func() { - individualResultsChannels := []<-chan int{} - initialValue := 0 - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - finalResultChannel := make(chan int) - wg := SliceOfChannelsReducer[int, int](individualResultsChannels, finalResultChannel, func(input int, val int) int { - return val + input - }, initialValue, true) - - Expect(wg).ToNot(BeNil()) - - result := <-finalResultChannel - - Expect(result).ToNot(Equal(0)) - Expect(result).To(Equal(18)) - }) - - It("SliceOfChannelsRawMergerWithoutMapping works", func() { - individualResultsChannels := []<-chan int{} - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - outputChannel := make(chan int) - wg := SliceOfChannelsRawMergerWithoutMapping(individualResultsChannels, outputChannel, true) - Expect(wg).ToNot(BeNil()) - outputSlice := []int{} - for v := range outputChannel { - outputSlice = append(outputSlice, v) - } - Expect(len(outputSlice)).To(Equal(9)) - slices.Sort(outputSlice) - Expect(outputSlice[0]).To(BeZero()) - Expect(outputSlice[3]).To(Equal(1)) - Expect(outputSlice[8]).To(Equal(6)) - }) - - It("SliceOfChannelsTransformer works", func() { - individualResultsChannels := []<-chan int{} - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - mappingFn := func(i int) string { - return fmt.Sprintf("$%d", i) - } - - outputChannels := SliceOfChannelsTransformer(individualResultsChannels, mappingFn) - Expect(len(outputChannels)).To(Equal(3)) - rSlice := []string{} - for ii := 1; ii < 4; ii++ { - for i := 0; i < 3; i++ { - res := <-outputChannels[i] - rSlice = append(rSlice, res) - } - } - slices.Sort(rSlice) - Expect(rSlice[0]).To(Equal("$0")) - Expect(rSlice[3]).To(Equal("$1")) - Expect(rSlice[8]).To(Equal("$6")) - }) -}) diff --git a/pkg/concurrency/types.go b/pkg/concurrency/types.go deleted file mode 100644 index 76081ba3b808..000000000000 --- a/pkg/concurrency/types.go +++ /dev/null @@ -1,6 +0,0 @@ -package concurrency - -type ErrorOr[T any] struct { - Value T - Error error -} diff --git a/pkg/grammar/functions.go b/pkg/functions/functions.go similarity index 98% rename from pkg/grammar/functions.go rename to pkg/functions/functions.go index 1038f5e6f147..d75a2ee391d7 100644 --- a/pkg/grammar/functions.go +++ b/pkg/functions/functions.go @@ -1,4 +1,4 @@ -package grammar +package functions import ( "encoding/json" diff --git a/pkg/grammar/grammar_suite_test.go b/pkg/functions/functions_suite_test.go similarity index 90% rename from pkg/grammar/grammar_suite_test.go rename to pkg/functions/functions_suite_test.go index 652643b6003f..8964b1c806a2 100644 --- a/pkg/grammar/grammar_suite_test.go +++ b/pkg/functions/functions_suite_test.go @@ -1,4 +1,4 @@ -package grammar +package functions import ( "testing" diff --git a/pkg/grammar/functions_test.go b/pkg/functions/functions_test.go similarity index 96% rename from pkg/grammar/functions_test.go rename to pkg/functions/functions_test.go index 6e8a56ed252b..97953a5edf24 100644 --- a/pkg/grammar/functions_test.go +++ b/pkg/functions/functions_test.go @@ -1,7 +1,7 @@ -package grammar_test +package functions_test import ( - . "github.com/go-skynet/LocalAI/pkg/grammar" + . "github.com/go-skynet/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/pkg/grammar/json_schema.go b/pkg/functions/grammar_json_schema.go similarity index 99% rename from pkg/grammar/json_schema.go rename to pkg/functions/grammar_json_schema.go index 76f9778f5b7f..010463908bc7 100644 --- a/pkg/grammar/json_schema.go +++ b/pkg/functions/grammar_json_schema.go @@ -1,4 +1,4 @@ -package grammar +package functions // a golang port of https://github.com/ggerganov/llama.cpp/pull/1887 diff --git a/pkg/grammar/json_schema_test.go b/pkg/functions/grammar_json_schema_test.go similarity index 98% rename from pkg/grammar/json_schema_test.go rename to pkg/functions/grammar_json_schema_test.go index 39d2a4d57886..fc9029a8d09f 100644 --- a/pkg/grammar/json_schema_test.go +++ b/pkg/functions/grammar_json_schema_test.go @@ -1,9 +1,9 @@ -package grammar_test +package functions_test import ( "strings" - . "github.com/go-skynet/LocalAI/pkg/grammar" + . "github.com/go-skynet/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go new file mode 100644 index 000000000000..5324e8c6f9be --- /dev/null +++ b/pkg/functions/parse.go @@ -0,0 +1,108 @@ +package functions + +import ( + "encoding/json" + "regexp" + + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +type FunctionsConfig struct { + DisableNoAction bool `yaml:"disable_no_action"` + NoActionFunctionName string `yaml:"no_action_function_name"` + NoActionDescriptionName string `yaml:"no_action_description_name"` + ParallelCalls bool `yaml:"parallel_calls"` + NoGrammar bool `yaml:"no_grammar"` + ResponseRegex string `yaml:"response_regex"` +} + +type FuncCallResults struct { + Name string + Arguments string +} + +func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults { + multipleResults := functionConfig.ParallelCalls + useGrammars := !functionConfig.NoGrammar + + results := []FuncCallResults{} + + // if no grammar is used, we have to extract function and arguments from the result + if !useGrammars { + // the response is a string that we have to parse + + // We use named regexes here to extract the function name and arguments + // obviously, this expects the LLM to be stable and return correctly formatted JSON + // TODO: optimize this and pre-compile it + var respRegex = regexp.MustCompile(functionConfig.ResponseRegex) + match := respRegex.FindStringSubmatch(llmresult) + result := make(map[string]string) + for i, name := range respRegex.SubexpNames() { + if i != 0 && name != "" && len(match) > i { + result[name] = match[i] + } + } + + // TODO: open point about multiple results and/or mixed with chat messages + // This is not handled as for now, we only expect one function call per response + functionName := result["function"] + if functionName == "" { + return results + } + + return append(results, FuncCallResults{Name: result["function"], Arguments: result["arguments"]}) + } + + // with grammars + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + for _, s := range ss { + func_name, ok := s["function"] + if !ok { + continue + } + args, ok := s["arguments"] + if !ok { + continue + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + continue + } + results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name, ok := ss["function"] + if !ok { + return results + } + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + if !ok { + return results + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + return results + } + results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)}) + } + + return results +} diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go new file mode 100644 index 000000000000..5168a7d1ae22 --- /dev/null +++ b/pkg/functions/parse_test.go @@ -0,0 +1,85 @@ +package functions_test + +import ( + . "github.com/go-skynet/LocalAI/pkg/functions" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LocalAI function parse tests", func() { + var functionConfig FunctionsConfig + + BeforeEach(func() { + // Default configuration setup + functionConfig = FunctionsConfig{ + ParallelCalls: false, + NoGrammar: false, + ResponseRegex: `(?P\w+)\s*\((?P.*)\)`, + } + }) + + Context("when using grammars and single result expected", func() { + It("should parse the function name and arguments correctly", func() { + input := `{"function": "add", "arguments": {"x": 5, "y": 3}}` + functionConfig.ParallelCalls = false + functionConfig.NoGrammar = false + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + }) + }) + + Context("when not using grammars and regex is needed", func() { + It("should extract function name and arguments from the regex", func() { + input := `add({"x":5,"y":3})` + functionConfig.NoGrammar = true + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(1)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + }) + }) + + Context("when having invalid input", func() { + It("returns no results when there is no input", func() { + input := "" + functionConfig.NoGrammar = true + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(0)) + + functionConfig.NoGrammar = false + + results = ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(0)) + }) + It("returns no results when is invalid", func() { + input := "invalid input" + functionConfig.NoGrammar = true + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(0)) + functionConfig.NoGrammar = false + + results = ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(0)) + }) + }) + Context("when parallel calls are enabled", func() { + It("should handle multiple function calls", func() { + input := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]` + functionConfig.ParallelCalls = true + functionConfig.NoGrammar = false + + results := ParseFunctionCall(input, functionConfig) + Expect(results).To(HaveLen(2)) + Expect(results[0].Name).To(Equal("add")) + Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`)) + Expect(results[1].Name).To(Equal("subtract")) + Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`)) + }) + }) +}) diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 49a6b1bd175d..8fb8c39dee44 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,7 +41,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index c0b4bc345ffa..0af5d94faf8a 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { - return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { + return schema.Result{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 0e0e56c73578..882db12aaf62 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.TranscriptionResult{} + tresult := &schema.Result{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index b4ba48847663..73b185a34d0c 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. return e.s.TTS(ctx, in) } -func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { r, err := e.s.AudioTranscription(ctx, in) if err != nil { return nil, err } - tr := &schema.TranscriptionResult{} + tr := &schema.Result{} for _, s := range r.Segments { var tks []int for _, t := range s.Tokens { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index aa7a3fbc4e53..4d06544dcd59 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) + AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 617d8f624928..5d9808a4f05b 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -81,7 +81,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if _, err := os.Stat(uri); err == nil { serverAddress, err := getFreeAddress() if err != nil { - return "", fmt.Errorf("%s failed allocating free ports: %s", backend, err.Error()) + return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) } // Make sure the process is executable if err := ml.startProcess(uri, o.model, serverAddress); err != nil { @@ -134,7 +134,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if !ready { log.Debug().Msgf("GRPC Service NOT ready") - return "", fmt.Errorf("%s grpc service not ready", backend) + return "", fmt.Errorf("grpc service not ready") } options := *o.gRPCOptions @@ -145,10 +145,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options) if err != nil { - return "", fmt.Errorf("\"%s\" could not load model: %w", backend, err) + return "", fmt.Errorf("could not load model: %w", err) } if !res.Success { - return "", fmt.Errorf("\"%s\" could not load model (no success): %s", backend, res.Message) + return "", fmt.Errorf("could not load model (no success): %s", res.Message) } return client, nil diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 003d832745a0..1b5c9aa066e3 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -1,18 +1,19 @@ package model import ( - "bytes" "context" "fmt" "os" "path/filepath" "strings" "sync" - "text/template" - "github.com/Masterminds/sprig/v3" - grammar "github.com/go-skynet/LocalAI/pkg/grammar" + "github.com/go-skynet/LocalAI/pkg/templates" + + "github.com/go-skynet/LocalAI/pkg/functions" "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/utils" + process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -25,7 +26,7 @@ type PromptTemplateData struct { SuppressSystemPrompt bool // used by chat specifically to indicate that SystemPrompt above should be _ignored_ Input string Instruction string - Functions []grammar.Function + Functions []functions.Function MessageIndex int } @@ -42,21 +43,6 @@ type ChatMessageTemplateData struct { LastMessage bool } -// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? -// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go -type TemplateType int - -const ( - ChatPromptTemplate TemplateType = iota - ChatMessageTemplate - CompletionPromptTemplate - EditPromptTemplate - FunctionsPromptTemplate - - // The following TemplateType is **NOT** a valid value and MUST be last. It exists to make the sanity integration tests simpler! - IntegrationTestTemplate -) - // new idea: what if we declare a struct of these here, and use a loop to check? // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl @@ -67,7 +53,7 @@ type ModelLoader struct { grpcClients map[string]grpc.Backend models map[string]ModelAddress grpcProcesses map[string]*process.Process - templates map[TemplateType]map[string]*template.Template + templates *templates.TemplateCache wd *WatchDog } @@ -86,11 +72,10 @@ func NewModelLoader(modelPath string) *ModelLoader { ModelPath: modelPath, grpcClients: make(map[string]grpc.Backend), models: make(map[string]ModelAddress), - templates: make(map[TemplateType]map[string]*template.Template), + templates: templates.NewTemplateCache(modelPath), grpcProcesses: make(map[string]*process.Process), } - nml.initializeTemplateMap() return nml } @@ -99,7 +84,7 @@ func (ml *ModelLoader) SetWatchDog(wd *WatchDog) { } func (ml *ModelLoader) ExistsInModelPath(s string) bool { - return existsInPath(ml.ModelPath, s) + return utils.ExistsInPath(ml.ModelPath, s) } func (ml *ModelLoader) ListModels() ([]string, error) { @@ -194,82 +179,22 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { return "" } -func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { +const ( + ChatPromptTemplate templates.TemplateType = iota + ChatMessageTemplate + CompletionPromptTemplate + EditPromptTemplate + FunctionsPromptTemplate +) + +func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType templates.TemplateType, templateName string, in PromptTemplateData) (string, error) { // TODO: should this check be improved? if templateType == ChatMessageTemplate { return "", fmt.Errorf("invalid templateType: ChatMessage") } - return ml.evaluateTemplate(templateType, templateName, in) + return ml.templates.EvaluateTemplate(templateType, templateName, in) } func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) { - return ml.evaluateTemplate(ChatMessageTemplate, templateName, messageData) -} - -func existsInPath(path string, s string) bool { - _, err := os.Stat(filepath.Join(path, s)) - return err == nil -} - -func (ml *ModelLoader) initializeTemplateMap() { - // This also seems somewhat clunky as we reference the Test / End of valid data value slug, but it works? - for tt := TemplateType(0); tt < IntegrationTestTemplate; tt++ { - ml.templates[tt] = make(map[string]*template.Template) - } -} - -func (ml *ModelLoader) evaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) { - ml.mu.Lock() - defer ml.mu.Unlock() - - m, ok := ml.templates[templateType][templateName] - if !ok { - // return "", fmt.Errorf("template not loaded: %s", templateName) - loadErr := ml.loadTemplateIfExists(templateType, templateName) - if loadErr != nil { - return "", loadErr - } - m = ml.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked - } - if m == nil { - return "", fmt.Errorf("failed loading a template for %s", templateName) - } - - var buf bytes.Buffer - - if err := m.Execute(&buf, in); err != nil { - return "", err - } - return buf.String(), nil -} - -func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateName string) error { - // Check if the template was already loaded - if _, ok := ml.templates[templateType][templateName]; ok { - return nil - } - - // Check if the model path exists - // skip any error here - we run anyway if a template does not exist - modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) - - dat := "" - if ml.ExistsInModelPath(modelTemplateFile) { - d, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile)) - if err != nil { - return err - } - dat = string(d) - } else { - dat = templateName - } - - // Parse the template - tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) - if err != nil { - return err - } - ml.templates[templateType][templateName] = tmpl - - return nil + return ml.templates.EvaluateTemplate(ChatMessageTemplate, templateName, messageData) } diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go new file mode 100644 index 000000000000..d3956b63ee98 --- /dev/null +++ b/pkg/model/loader_test.go @@ -0,0 +1,199 @@ +package model_test + +import ( + "github.com/go-skynet/LocalAI/pkg/model" + . "github.com/go-skynet/LocalAI/pkg/model" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }} +{{- if .Content}} +{{.Content }} +{{- end }} +{{- if .FunctionCall}} +{{toJson .FunctionCall}} +{{- end }} +{{- if .FunctionCall }} + +{{- else if eq .RoleName "tool" }} + +{{- end }} +<|im_end|>` + +const llama3 = `<|start_header_id|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}<|end_header_id|> + +{{ if .FunctionCall -}} +Function call: +{{ else if eq .RoleName "tool" -}} +Function response: +{{ end -}} +{{ if .Content -}} +{{.Content -}} +{{ else if .FunctionCall -}} +{{ toJson .FunctionCall -}} +{{ end -}} +<|eot_id|>` + +var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "template": llama3, + "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "user", + RoleName: "user", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "assistant": { + "template": llama3, + "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_call": { + "template": llama3, + "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "", + FunctionCall: map[string]string{"function": "test"}, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_response": { + "template": llama3, + "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "tool", + RoleName: "tool", + Content: "Response from tool", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, +} + +var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]interface{}{ + "user": { + "template": chatML, + "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "user", + RoleName: "user", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "assistant": { + "template": chatML, + "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "A long time ago in a galaxy far, far away...", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_call": { + "template": chatML, + "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "assistant", + RoleName: "assistant", + Content: "", + FunctionCall: map[string]string{"function": "test"}, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, + "function_response": { + "template": chatML, + "expected": "<|im_start|>tool\n\nResponse from tool\n\n<|im_end|>", + "data": model.ChatMessageTemplateData{ + SystemPrompt: "", + Role: "tool", + RoleName: "tool", + Content: "Response from tool", + FunctionCall: nil, + FunctionName: "", + LastMessage: false, + Function: false, + MessageIndex: 0, + }, + }, +} + +var _ = Describe("Templates", func() { + Context("chat message ChatML", func() { + var modelLoader *ModelLoader + BeforeEach(func() { + modelLoader = NewModelLoader("") + }) + for key := range chatMLTestMatch { + foo := chatMLTestMatch[key] + It("renders correctly `"+key+"`", func() { + templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) + Expect(err).ToNot(HaveOccurred()) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) + Context("chat message llama3", func() { + var modelLoader *ModelLoader + BeforeEach(func() { + modelLoader = NewModelLoader("") + }) + for key := range llama3TestMatch { + foo := llama3TestMatch[key] + It("renders correctly `"+key+"`", func() { + templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) + Expect(err).ToNot(HaveOccurred()) + Expect(templated).To(Equal(foo["expected"]), templated) + }) + } + }) +}) diff --git a/pkg/model/model_suite_test.go b/pkg/model/model_suite_test.go new file mode 100644 index 000000000000..6fa9c0049b4a --- /dev/null +++ b/pkg/model/model_suite_test.go @@ -0,0 +1,13 @@ +package model_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestModel(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "LocalAI model test") +} diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go new file mode 100644 index 000000000000..b09516a7f242 --- /dev/null +++ b/pkg/startup/model_preload.go @@ -0,0 +1,85 @@ +package startup + +import ( + "errors" + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/embedded" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +// PreloadModelsConfigurations will preload models from the given list of URLs +// It will download the model if it is not already present in the model path +// It will also try to resolve if the model is an embedded model YAML configuration +func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { + for _, url := range models { + + // As a best effort, try to resolve the model from the remote library + // if it's not resolved we try with the other method below + if modelLibraryURL != "" { + lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) + if err == nil { + if lib[url] != "" { + log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) + url = lib[url] + } + } + } + + url = embedded.ModelShortURL(url) + switch { + case embedded.ExistsInModelsLibrary(url): + modelYAML, err := embedded.ResolveContent(url) + // If we resolve something, just save it to disk and continue + if err != nil { + log.Error().Err(err).Msg("error resolving model content") + continue + } + + log.Debug().Msgf("[startup] resolved embedded model: %s", url) + md5Name := utils.MD5(url) + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") + } + case downloader.LooksLikeURL(url): + log.Debug().Msgf("[startup] resolved model to download: %s", url) + + // md5 of model name + md5Name := utils.MD5(url) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + }) + if err != nil { + log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + } + } + default: + if _, err := os.Stat(url); err == nil { + log.Debug().Msgf("[startup] resolved local model: %s", url) + // copy to modelPath + md5Name := utils.MD5(url) + + modelYAML, err := os.ReadFile(url) + if err != nil { + log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") + continue + } + + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") + } + } else { + log.Warn().Msgf("[startup] failed resolving model '%s'", url) + } + } + } +} diff --git a/core/services/model_preload_test.go b/pkg/startup/model_preload_test.go similarity index 96% rename from core/services/model_preload_test.go rename to pkg/startup/model_preload_test.go index fc65d565bdc9..63a8f8b03e3b 100644 --- a/core/services/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -1,14 +1,13 @@ -package services_test +package startup_test import ( "fmt" "os" "path/filepath" + . "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" - . "github.com/go-skynet/LocalAI/core/services" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go new file mode 100644 index 000000000000..9ff5560585ff --- /dev/null +++ b/pkg/templates/cache.go @@ -0,0 +1,103 @@ +package templates + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sync" + "text/template" + + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/Masterminds/sprig/v3" +) + +// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? +// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go +type TemplateType int + +type TemplateCache struct { + mu sync.Mutex + templatesPath string + templates map[TemplateType]map[string]*template.Template +} + +func NewTemplateCache(templatesPath string) *TemplateCache { + tc := &TemplateCache{ + templatesPath: templatesPath, + templates: make(map[TemplateType]map[string]*template.Template), + } + return tc +} + +func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) { + if _, ok := tc.templates[tt]; !ok { + tc.templates[tt] = make(map[string]*template.Template) + } +} + +func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) { + tc.mu.Lock() + defer tc.mu.Unlock() + + tc.initializeTemplateMapKey(templateType) + m, ok := tc.templates[templateType][templateName] + if !ok { + // return "", fmt.Errorf("template not loaded: %s", templateName) + loadErr := tc.loadTemplateIfExists(templateType, templateName) + if loadErr != nil { + return "", loadErr + } + m = tc.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked + } + if m == nil { + return "", fmt.Errorf("failed loading a template for %s", templateName) + } + + var buf bytes.Buffer + + if err := m.Execute(&buf, in); err != nil { + return "", err + } + return buf.String(), nil +} + +func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error { + + // Check if the template was already loaded + if _, ok := tc.templates[templateType][templateName]; ok { + return nil + } + + // Check if the model path exists + // skip any error here - we run anyway if a template does not exist + modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) + + dat := "" + file := filepath.Join(tc.templatesPath, modelTemplateFile) + + // Security check + if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil { + return fmt.Errorf("template file outside path: %s", file) + } + + if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) { + d, err := os.ReadFile(file) + if err != nil { + return err + } + dat = string(d) + } else { + dat = templateName + } + + // Parse the template + tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat) + if err != nil { + return err + } + tc.templates[templateType][templateName] = tmpl + + return nil +} diff --git a/pkg/templates/cache_test.go b/pkg/templates/cache_test.go new file mode 100644 index 000000000000..83af02b2f38e --- /dev/null +++ b/pkg/templates/cache_test.go @@ -0,0 +1,73 @@ +package templates_test + +import ( + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/pkg/templates" // Update with your module path + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TemplateCache", func() { + var ( + templateCache *templates.TemplateCache + tempDir string + ) + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "templates") + Expect(err).NotTo(HaveOccurred()) + + // Writing example template files + err = os.WriteFile(filepath.Join(tempDir, "example.tmpl"), []byte("Hello, {{.Name}}!"), 0644) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(filepath.Join(tempDir, "empty.tmpl"), []byte(""), 0644) + Expect(err).NotTo(HaveOccurred()) + + templateCache = templates.NewTemplateCache(tempDir) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) // Clean up + }) + + Describe("EvaluateTemplate", func() { + Context("when template is loaded successfully", func() { + It("should evaluate the template correctly", func() { + result, err := templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("Hello, Gopher!")) + }) + }) + + Context("when template isn't a file", func() { + It("should parse from string", func() { + result, err := templateCache.EvaluateTemplate(1, "{{.Name}}", map[string]string{"Name": "Gopher"}) + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("Gopher")) + }) + }) + + Context("when template is empty", func() { + It("should return an empty string", func() { + result, err := templateCache.EvaluateTemplate(1, "empty", nil) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("")) + }) + }) + }) + + Describe("concurrency", func() { + It("should handle multiple concurrent accesses", func(done Done) { + go func() { + _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + }() + go func() { + _, _ = templateCache.EvaluateTemplate(1, "example", map[string]string{"Name": "Gopher"}) + }() + close(done) + }, 0.1) // timeout in seconds + }) +}) diff --git a/pkg/templates/utils_suite_test.go b/pkg/templates/utils_suite_test.go new file mode 100644 index 000000000000..011ba8f61fdc --- /dev/null +++ b/pkg/templates/utils_suite_test.go @@ -0,0 +1,13 @@ +package templates_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestTemplates(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Templates test suite") +} diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go deleted file mode 100644 index 769d8a88c734..000000000000 --- a/pkg/utils/base64.go +++ /dev/null @@ -1,50 +0,0 @@ -package utils - -import ( - "encoding/base64" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -var base64DownloadClient http.Client = http.Client{ - Timeout: 30 * time.Second, -} - -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string - -// This may look weird down in pkg/utils while it is currently only used in core/config -// -// but I believe it may be useful for MQTT as well in the near future, so I'm -// extracting it while I'm thinking of it. -func GetImageURLAsBase64(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := base64DownloadClient.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -} diff --git a/pkg/utils/path.go b/pkg/utils/path.go index f95b0138133a..9982bc1e6ffb 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -2,10 +2,16 @@ package utils import ( "fmt" + "os" "path/filepath" "strings" ) +func ExistsInPath(path string, s string) bool { + _, err := os.Stat(filepath.Join(path, s)) + return err == nil +} + func inTrustedRoot(path string, trustedRoot string) error { for path != "/" { path = filepath.Dir(path)