diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 1d9687ac30..0000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,37 +0,0 @@ ---- -name: Report -about: Create a report to help us improve Jan -title: 'bug: ' -labels: 'type: bug' -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**Steps to reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your issue. - -**Environment details** -- Operating System: [Specify your OS. e.g., MacOS Sonoma 14.2.1, Windows 11, Ubuntu 22, etc] -- Jan Version: [e.g., 0.4.xxx nightly or manual] -- Processor: [e.g., Apple M1, Intel Core i7, AMD Ryzen 5, etc] -- RAM: [e.g., 8GB, 16GB] -- Any additional relevant hardware specifics: [e.g., Graphics card, SSD/HDD] - -**Logs** -If the cause of the error is not clear, kindly provide your usage logs: https://jan.ai/docs/troubleshooting#how-to-get-error-logs - -**Additional context** -Add any other context or information that could be helpful in diagnosing the problem. diff --git a/.github/workflows/jan-electron-linter-and-test.yml b/.github/workflows/jan-electron-linter-and-test.yml index 5ae64c4eb2..3a95e804e4 100644 --- a/.github/workflows/jan-electron-linter-and-test.yml +++ b/.github/workflows/jan-electron-linter-and-test.yml @@ -67,9 +67,9 @@ jobs: run: | echo "REPORT_PORTAL_DESCRIPTION=${{github.sha}})" >> $GITHUB_ENV -# - name: 'Config report portal' -# run: | -# make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App macos" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" + - name: 'Config report portal' + run: | + make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App macos" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" - name: Linter and test run: | @@ -147,10 +147,10 @@ jobs: run: | echo "REPORT_PORTAL_DESCRIPTION=${{github.sha}}" >> $GITHUB_ENV -# - name: 'Config report portal' -# shell: bash -# run: | -# make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Windows ${{ matrix.antivirus-tools }}" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" + - name: 'Config report portal' + shell: bash + run: | + make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Windows ${{ matrix.antivirus-tools }}" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" - name: Linter and test shell: powershell @@ -195,14 +195,11 @@ jobs: run: | echo "REPORT_PORTAL_DESCRIPTION=${{github.event.after}}" >> $GITHUB_ENV -# - name: 'Config report portal' -# shell: bash -# run: | -# 
make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Windows" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" + - name: 'Config report portal' + shell: bash + run: | + make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Windows" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" - - name: Setup node-gyp - distutils - run: pip3 install --upgrade setuptools - - name: Linter and test shell: powershell run: | @@ -278,10 +275,10 @@ jobs: run: | echo "REPORT_PORTAL_DESCRIPTION=${{github.sha}}" >> $GITHUB_ENV -# - name: 'Config report portal' -# shell: bash -# run: | -# make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Linux" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" + - name: 'Config report portal' + shell: bash + run: | + make update-playwright-config REPORT_PORTAL_URL=${{ secrets.REPORT_PORTAL_URL }} REPORT_PORTAL_API_KEY=${{ secrets.REPORT_PORTAL_API_KEY }} REPORT_PORTAL_PROJECT_NAME=${{ secrets.REPORT_PORTAL_PROJECT_NAME }} REPORT_PORTAL_LAUNCH_NAME="Jan App Linux" REPORT_PORTAL_DESCRIPTION="${{env.REPORT_PORTAL_DESCRIPTION}}" - name: Linter and test run: | diff --git a/.github/workflows/jan-server-build-nightly.yml b/.github/workflows/jan-server-build-nightly.yml new file mode 100644 index 0000000000..29e13804ee --- /dev/null +++ b/.github/workflows/jan-server-build-nightly.yml @@ -0,0 +1,40 @@ +name: Docker Builder - Nightly / Manual + +on: + push: + branches: + - main + - feature/helmchart-and-ci-jan-server + paths-ignore: + - 'README.md' + - 'docs/**' + schedule: + - cron: '0 21 * * 1,2,3' # At 8 PM UTC on Monday, Tuesday, and Wednesday which is 4 AM UTC+7 Tuesday, Wednesday, and Thursday + workflow_dispatch: + +jobs: + # Job create Update app version based on latest release tag with build number and save to output + get-update-version: + uses: ./.github/workflows/template-get-update-version.yml + + build-cpu: + uses: ./.github/workflows/template-build-jan-server.yml + permissions: + packages: write + secrets: inherit + needs: [get-update-version] + with: + dockerfile_path: ./Dockerfile + docker_image_tag: "ghcr.io/janhq/jan-server:dev-cpu-latest,ghcr.io/janhq/jan-server:dev-cpu-${{ needs.get-update-version.outputs.new_version }}" + + build-gpu: + uses: ./.github/workflows/template-build-jan-server.yml + permissions: + packages: write + secrets: inherit + needs: [get-update-version] + with: + dockerfile_path: ./Dockerfile.gpu + docker_image_tag: "ghcr.io/janhq/jan-server:dev-cuda-12.2-latest,ghcr.io/janhq/jan-server:dev-cuda-12.2-${{ needs.get-update-version.outputs.new_version }}" + + diff --git a/.github/workflows/jan-server-build.yml b/.github/workflows/jan-server-build.yml new file mode 100644 index 0000000000..503efd2989 --- /dev/null +++ b/.github/workflows/jan-server-build.yml @@ -0,0 +1,30 @@ +name: Docker Builder - Tag + +on: + push: + tags: ["v[0-9]+.[0-9]+.[0-9]+"] + +jobs: + # Job create Update app version based on latest release tag with build number and save to 
output + get-update-version: + uses: ./.github/workflows/template-get-update-version.yml + + build-cpu: + permissions: + packages: write + uses: ./.github/workflows/template-build-jan-server.yml + secrets: inherit + needs: [get-update-version] + with: + dockerfile_path: ./Dockerfile + docker_image_tag: "ghcr.io/janhq/jan-server:cpu-latest,ghcr.io/janhq/jan-server:cpu-${{ needs.get-update-version.outputs.new_version }}" + + build-gpu: + permissions: + packages: write + uses: ./.github/workflows/template-build-jan-server.yml + secrets: inherit + needs: [get-update-version] + with: + dockerfile_path: ./Dockerfile.gpu + docker_image_tag: "ghcr.io/janhq/jan-server:cuda-12.2-latest,ghcr.io/janhq/jan-server:cuda-12.2-${{ needs.get-update-version.outputs.new_version }}" diff --git a/.github/workflows/template-build-linux-x64.yml b/.github/workflows/template-build-linux-x64.yml index c3df9be962..08cb1dadaf 100644 --- a/.github/workflows/template-build-linux-x64.yml +++ b/.github/workflows/template-build-linux-x64.yml @@ -91,7 +91,6 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.CLOUDFLARE_R2_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }} AWS_EC2_METADATA_DISABLED: "true" - AWS_MAX_ATTEMPTS: "5" - name: Build and publish app to github if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && inputs.public_provider == 'github' diff --git a/.github/workflows/template-build-macos-arm64.yml b/.github/workflows/template-build-macos-arm64.yml index a442984093..a5bc1e5394 100644 --- a/.github/workflows/template-build-macos-arm64.yml +++ b/.github/workflows/template-build-macos-arm64.yml @@ -56,11 +56,6 @@ jobs: with: node-version: 20 - - name: Install python - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - name: Install jq uses: dcarbone/install-jq-action@v2.0.1 @@ -135,7 +130,6 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }} AWS_DEFAULT_REGION: auto AWS_EC2_METADATA_DISABLED: "true" - AWS_MAX_ATTEMPTS: "5" - name: Build and publish app to github if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && inputs.public_provider == 'github' diff --git a/.github/workflows/template-build-macos-x64.yml b/.github/workflows/template-build-macos-x64.yml index c35c0ca36e..d9543194d6 100644 --- a/.github/workflows/template-build-macos-x64.yml +++ b/.github/workflows/template-build-macos-x64.yml @@ -56,11 +56,6 @@ jobs: with: node-version: 20 - - name: Install python - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - name: Install jq uses: dcarbone/install-jq-action@v2.0.1 @@ -135,7 +130,6 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }} AWS_DEFAULT_REGION: auto AWS_EC2_METADATA_DISABLED: "true" - AWS_MAX_ATTEMPTS: "5" - name: Build and publish app to github if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && inputs.public_provider == 'github' diff --git a/.github/workflows/template-build-windows-x64.yml b/.github/workflows/template-build-windows-x64.yml index 7fef1810ad..b81997bde2 100644 --- a/.github/workflows/template-build-windows-x64.yml +++ b/.github/workflows/template-build-windows-x64.yml @@ -120,7 +120,6 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }} AWS_DEFAULT_REGION: auto AWS_EC2_METADATA_DISABLED: "true" - AWS_MAX_ATTEMPTS: "5" - name: Build app and publish app to github if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && inputs.public_provider == 'github' diff --git 
a/.gitignore b/.gitignore index dd14a2238a..0b6f98465a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,9 +12,6 @@ yarn.lock dist build .DS_Store -electron/resources/win/* -electron/resources/linux/* -electron/resources/mac/* electron/renderer electron/models electron/docs diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000..7fbbda2cfb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,60 @@ +FROM node:20-bookworm AS base + +# 1. Install dependencies only when needed +FROM base AS builder + +# Install g++ 11 +RUN apt update && apt install -y gcc-11 g++-11 cpp-11 jq xsel && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install dependencies based on the preferred package manager +COPY . ./ + +RUN export NITRO_VERSION=$(cat extensions/inference-nitro-extension/bin/version.txt) && \ + jq --arg nitroVersion $NITRO_VERSION '(.scripts."downloadnitro:linux" | gsub("\\${NITRO_VERSION}"; $nitroVersion)) | gsub("\r"; "")' extensions/inference-nitro-extension/package.json > /tmp/newcommand.txt && export NEW_COMMAND=$(sed 's/^"//;s/"$//' /tmp/newcommand.txt) && jq --arg newCommand "$NEW_COMMAND" '.scripts."downloadnitro:linux" = $newCommand' extensions/inference-nitro-extension/package.json > /tmp/package.json && mv /tmp/package.json extensions/inference-nitro-extension/package.json +RUN make install-and-build + +# # 2. Rebuild the source code only when needed +FROM base AS runner + +# Install g++ 11 +RUN apt update && apt install -y gcc-11 g++-11 cpp-11 jq xsel && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy the package.json and yarn.lock of root yarn space to leverage Docker cache +COPY --from=builder /app/package.json ./package.json +COPY --from=builder /app/node_modules ./node_modules/ +COPY --from=builder /app/yarn.lock ./yarn.lock + +# Copy the package.json, yarn.lock, and build output of server yarn space to leverage Docker cache +COPY --from=builder /app/core ./core/ +COPY --from=builder /app/server ./server/ +RUN cd core && yarn install && yarn run build +RUN yarn workspace @janhq/server install && yarn workspace @janhq/server build +COPY --from=builder /app/docs/openapi ./docs/openapi/ + +# Copy pre-install dependencies +COPY --from=builder /app/pre-install ./pre-install/ + +# Copy the package.json, yarn.lock, and output of web yarn space to leverage Docker cache +COPY --from=builder /app/joi ./joi/ +COPY --from=builder /app/web ./web/ + +RUN yarn workspace @janhq/joi install && yarn workspace @janhq/joi build +RUN yarn workspace @janhq/web install + +RUN npm install -g serve@latest + +EXPOSE 1337 3000 3928 + +ENV JAN_API_HOST 0.0.0.0 +ENV JAN_API_PORT 1337 + +ENV API_BASE_URL http://localhost:1337 + +CMD ["sh", "-c", "export NODE_ENV=production && yarn workspace @janhq/web build && cd web && npx serve out & cd server && node build/main.js"] + +# docker build -t jan . +# docker run -p 1337:1337 -p 3000:3000 -p 3928:3928 jan diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 0000000000..195a28d429 --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,87 @@ +# Please change the base image to the appropriate CUDA version base on NVIDIA Driver Compatibility +# Run nvidia-smi to check the CUDA version and the corresponding driver version +# Then update the base image to the appropriate CUDA version refer https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + +FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 AS base + +# 1. 
Install dependencies only when needed +FROM base AS builder + +# Install g++ 11 +RUN apt update && apt install -y gcc-11 g++-11 cpp-11 jq xsel curl gnupg make python3-dev && curl -sL https://deb.nodesource.com/setup_20.x | bash - && apt install nodejs -y && rm -rf /var/lib/apt/lists/* + +# Update alternatives for GCC and related tools +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 110 \ + --slave /usr/bin/g++ g++ /usr/bin/g++-11 \ + --slave /usr/bin/gcov gcov /usr/bin/gcov-11 \ + --slave /usr/bin/gcc-ar gcc-ar /usr/bin/gcc-ar-11 \ + --slave /usr/bin/gcc-ranlib gcc-ranlib /usr/bin/gcc-ranlib-11 && \ + update-alternatives --install /usr/bin/cpp cpp /usr/bin/cpp-11 110 + +RUN npm install -g yarn + +WORKDIR /app + +# Install dependencies based on the preferred package manager +COPY . ./ + +RUN export NITRO_VERSION=$(cat extensions/inference-nitro-extension/bin/version.txt) && \ + jq --arg nitroVersion $NITRO_VERSION '(.scripts."downloadnitro:linux" | gsub("\\${NITRO_VERSION}"; $nitroVersion)) | gsub("\r"; "")' extensions/inference-nitro-extension/package.json > /tmp/newcommand.txt && export NEW_COMMAND=$(sed 's/^"//;s/"$//' /tmp/newcommand.txt) && jq --arg newCommand "$NEW_COMMAND" '.scripts."downloadnitro:linux" = $newCommand' extensions/inference-nitro-extension/package.json > /tmp/package.json && mv /tmp/package.json extensions/inference-nitro-extension/package.json +RUN make install-and-build + +# # 2. Rebuild the source code only when needed +FROM base AS runner + +# Install g++ 11 +RUN apt update && apt install -y gcc-11 g++-11 cpp-11 jq xsel curl gnupg make python3-dev && curl -sL https://deb.nodesource.com/setup_20.x | bash - && apt-get install nodejs -y && rm -rf /var/lib/apt/lists/* + +# Update alternatives for GCC and related tools +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 110 \ + --slave /usr/bin/g++ g++ /usr/bin/g++-11 \ + --slave /usr/bin/gcov gcov /usr/bin/gcov-11 \ + --slave /usr/bin/gcc-ar gcc-ar /usr/bin/gcc-ar-11 \ + --slave /usr/bin/gcc-ranlib gcc-ranlib /usr/bin/gcc-ranlib-11 && \ + update-alternatives --install /usr/bin/cpp cpp /usr/bin/cpp-11 110 + +RUN npm install -g yarn + +WORKDIR /app + +# Copy the package.json and yarn.lock of root yarn space to leverage Docker cache +COPY --from=builder /app/package.json ./package.json +COPY --from=builder /app/node_modules ./node_modules/ +COPY --from=builder /app/yarn.lock ./yarn.lock + +# Copy the package.json, yarn.lock, and build output of server yarn space to leverage Docker cache +COPY --from=builder /app/core ./core/ +COPY --from=builder /app/server ./server/ +RUN cd core && yarn install && yarn run build +RUN yarn workspace @janhq/server install && yarn workspace @janhq/server build +COPY --from=builder /app/docs/openapi ./docs/openapi/ + +# Copy pre-install dependencies +COPY --from=builder /app/pre-install ./pre-install/ + +# Copy the package.json, yarn.lock, and output of web yarn space to leverage Docker cache +COPY --from=builder /app/joi ./joi/ +COPY --from=builder /app/web ./web/ + +RUN yarn workspace @janhq/joi install && yarn workspace @janhq/joi build +RUN yarn workspace @janhq/web install + +RUN npm install -g serve@latest + +EXPOSE 1337 3000 3928 + +ENV LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda-12.0/compat${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +ENV JAN_API_HOST 0.0.0.0 +ENV JAN_API_PORT 1337 + +ENV API_BASE_URL http://localhost:1337 + +CMD ["sh", "-c", "export NODE_ENV=production && yarn workspace @janhq/web build && cd web && npx 
serve out & cd server && node build/main.js"] + +# pre-requisites: nvidia-docker +# docker build -t jan-gpu . -f Dockerfile.gpu +# docker run -p 1337:1337 -p 3000:3000 -p 3928:3928 --gpus all jan-gpu diff --git a/Makefile b/Makefile index e2c6a4a2ac..1687f8bbe8 100644 --- a/Makefile +++ b/Makefile @@ -18,14 +18,16 @@ else cd joi && yarn install && yarn build endif -# Installs yarn dependencies and builds core +# Installs yarn dependencies and builds core and extensions install-and-build: build-joi ifeq ($(OS),Windows_NT) yarn config set network-timeout 300000 endif yarn global add turbo@1.13.2 yarn build:core + yarn build:server yarn install + yarn build:extensions check-file-counts: install-and-build ifeq ($(OS),Windows_NT) @@ -34,11 +36,11 @@ else @tgz_count=$$(find pre-install -type f -name "*.tgz" | wc -l); dir_count=$$(find extensions -mindepth 1 -maxdepth 1 -type d -exec test -e '{}/package.json' \; -print | wc -l); if [ $$tgz_count -ne $$dir_count ]; then echo "Number of .tgz files in pre-install ($$tgz_count) does not match the number of subdirectories in extension ($$dir_count)"; exit 1; else echo "Extension build successful"; fi endif -dev: install-and-build +dev: check-file-counts yarn dev # Linting -lint: install-and-build +lint: check-file-counts yarn lint update-playwright-config: @@ -106,11 +108,11 @@ test: lint yarn test # Builds and publishes the app -build-and-publish: install-and-build +build-and-publish: check-file-counts yarn build:publish # Build -build: install-and-build +build: check-file-counts yarn build clean: diff --git a/README.md b/README.md index 3ba58438b4..e1622b0812 100644 --- a/README.md +++ b/README.md @@ -210,6 +210,12 @@ Contributions are welcome! Please read the [CONTRIBUTING.md](CONTRIBUTING.md) fi This will start the development server and open the desktop app. +3. (Optional) **Run the API server without frontend** + + ```bash + yarn dev:server + ``` + ### For production build ```bash diff --git a/charts/server/Chart.lock b/charts/server/Chart.lock new file mode 100644 index 0000000000..915788d617 --- /dev/null +++ b/charts/server/Chart.lock @@ -0,0 +1,6 @@ +dependencies: +- name: common + repository: oci://ghcr.io/janhq/charts + version: 0.1.2 +digest: sha256:35e98bde174130787755b0f8ea2359b7b6790d965a7157c2f7cabf1bc8c04471 +generated: "2024-02-20T16:20:37.6530108+07:00" diff --git a/charts/server/Chart.yaml b/charts/server/Chart.yaml new file mode 100644 index 0000000000..fb2e1c91bd --- /dev/null +++ b/charts/server/Chart.yaml @@ -0,0 +1,10 @@ +apiVersion: v2 +name: jan-server +description: A Helm chart for Kubernetes +type: application +version: 0.1.0 +appVersion: '1.0.0' +dependencies: + - name: common + version: 0.1.2 # common-chart-version + repository: oci://ghcr.io/janhq/charts diff --git a/charts/server/charts/common-0.1.2.tgz b/charts/server/charts/common-0.1.2.tgz new file mode 100644 index 0000000000..946617eabb Binary files /dev/null and b/charts/server/charts/common-0.1.2.tgz differ diff --git a/charts/server/config.json b/charts/server/config.json new file mode 100644 index 0000000000..62e9682fa6 --- /dev/null +++ b/charts/server/config.json @@ -0,0 +1,4 @@ +{ + "image-list": "server=ghcr.io/janhq/jan-server", + "platforms": "linux/amd64" +} \ No newline at end of file diff --git a/charts/server/values.yaml b/charts/server/values.yaml new file mode 100644 index 0000000000..b31f476569 --- /dev/null +++ b/charts/server/values.yaml @@ -0,0 +1,256 @@ +common: + imageTag: v0.4.6-cpu + # DO NOT CHANGE THE LINE ABOVE. 
MAKE ALL CHANGES BELOW + + # Global pvc for all workload + pvc: + enabled: false + name: 'janroot' + accessModes: 'ReadWriteOnce' + storageClassName: '' + capacity: '50Gi' + + # Global image pull secret + imagePullSecrets: [] + + externalSecret: + create: false + name: '' + annotations: {} + + nameOverride: 'jan-server' + fullnameOverride: 'jan-server' + + serviceAccount: + create: true + annotations: {} + name: 'jan-server-service-account' + + podDisruptionBudget: + create: false + minAvailable: 1 + + workloads: + - name: server + image: + repository: ghcr.io/janhq/jan-server + pullPolicy: Always + + command: ['/bin/sh', '-c'] + args: ['cd server && node build/main.js'] + + replicaCount: 1 + ports: + containerPort: 1337 + + strategy: + canary: + steps: + - setWeight: 50 + - pause: { duration: 1m } + + ingress: + enabled: true + className: 'nginx' + annotations: + nginx.ingress.kubernetes.io/proxy-body-size: '100m' + nginx.ingress.kubernetes.io/proxy-read-timeout: '1800' + nginx.ingress.kubernetes.io/proxy-send-timeout: '1800' + # cert-manager.io/cluster-issuer: 'jan-ai-dns01-cluster-issuer' + # nginx.ingress.kubernetes.io/force-ssl-redirect: 'true' + nginx.ingress.kubernetes.io/backend-protocol: HTTP + hosts: + - host: server.local + paths: + - path: / + pathType: Prefix + tls: + [] + # - hosts: + # - server-dev.jan.ai + # secretName: jan-server-prod-tls-v2 + + instrumentation: + enabled: false + podAnnotations: {} + + podSecurityContext: {} + + securityContext: {} + + service: + externalLabel: {} + type: ClusterIP + port: 1337 + targetPort: 1337 + + # If you want to use GPU, please uncomment the following lines and change imageTag to the one with GPU support + resources: + # limits: + # nvidia.com/gpu: 1 + requests: + cpu: 2000m + memory: 8192M + + # If you want to use pv, please uncomment the following lines and enable pvc.enabled + volumes: + [] + # - name: janroot + # persistentVolumeClaim: + # claimName: janroot + + volumeMounts: + [] + # - name: janroot + # mountPath: /app/server/build/jan + + # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_BUCKET_NAME, AWS_ENDPOINT, AWS_REGION should mount as a secret env instead of plain text here + # Change API_BASE_URL to your server's public domain + env: + - name: API_BASE_URL + value: 'http://server.local' + + lifecycle: {} + autoscaling: + enabled: false + minReplicas: 2 + maxReplicas: 3 + targetCPUUtilizationPercentage: 95 + targetMemoryUtilizationPercentage: 95 + + kedaScaling: + enabled: false # ignore if autoscaling.enable = true + cooldownPeriod: 30 + pollingInterval: 2 + minReplicas: 1 + maxReplicas: 5 + metricName: celery_queue_length + query: celery_queue_length{queue_name="myqueue"} # change queue_name here + serverAddress: http://prometheus-prod-kube-prome-prometheus.monitoring.svc:9090 + threshold: '3' + + nodeSelector: {} + + tolerations: [] + + podSecurityGroup: + enabled: false + securitygroupid: [] + + # Reloader Option + reloader: 'false' + vpa: + enabled: false + + - name: web + image: + repository: ghcr.io/janhq/jan-server + pullPolicy: Always + + command: ['/bin/sh', '-c'] + args: + [ + 'export NODE_ENV=production && yarn workspace @janhq/web build && cd web && npx serve out', + ] + + replicaCount: 1 + ports: + containerPort: 3000 + + strategy: + canary: + steps: + - setWeight: 50 + - pause: { duration: 1m } + + ingress: + enabled: true + className: 'nginx' + annotations: + nginx.ingress.kubernetes.io/proxy-body-size: '100m' + nginx.ingress.kubernetes.io/proxy-read-timeout: '1800' + 
nginx.ingress.kubernetes.io/proxy-send-timeout: '1800' + # cert-manager.io/cluster-issuer: 'jan-ai-dns01-cluster-issuer' + # nginx.ingress.kubernetes.io/force-ssl-redirect: 'true' + nginx.ingress.kubernetes.io/backend-protocol: HTTP + hosts: + - host: web.local + paths: + - path: / + pathType: Prefix + tls: + [] + # - hosts: + # - server-dev.jan.ai + # secretName: jan-server-prod-tls-v2 + + instrumentation: + enabled: false + podAnnotations: {} + + podSecurityContext: {} + + securityContext: {} + + service: + externalLabel: {} + type: ClusterIP + port: 3000 + targetPort: 3000 + + resources: + limits: + cpu: 1000m + memory: 2048M + requests: + cpu: 50m + memory: 500M + + volumes: + [] + # - name: janroot + # persistentVolumeClaim: + # claimName: janroot + + volumeMounts: + [] + # - name: janroot + # mountPath: /app/server/build/jan + + # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_BUCKET_NAME, AWS_ENDPOINT, AWS_REGION should mount as a secret env instead of plain text here + # Change API_BASE_URL to your server's public domain + env: + - name: API_BASE_URL + value: 'http://server.local' + + lifecycle: {} + autoscaling: + enabled: true + minReplicas: 1 + maxReplicas: 3 + targetCPUUtilizationPercentage: 95 + targetMemoryUtilizationPercentage: 95 + + kedaScaling: + enabled: false # ignore if autoscaling.enable = true + cooldownPeriod: 30 + pollingInterval: 2 + minReplicas: 1 + maxReplicas: 5 + metricName: celery_queue_length + query: celery_queue_length{queue_name="myqueue"} # change queue_name here + serverAddress: http://prometheus-prod-kube-prome-prometheus.monitoring.svc:9090 + threshold: '3' + + nodeSelector: {} + + tolerations: [] + + podSecurityGroup: + enabled: false + securitygroupid: [] + + # Reloader Option + reloader: 'false' + vpa: + enabled: false diff --git a/core/README.md b/core/README.md index 293e6668c3..925ffaf7b8 100644 --- a/core/README.md +++ b/core/README.md @@ -40,7 +40,7 @@ import * as node from "@janhq/core/node"; private static inference(incomingMessage: MessageRequestData) { // Prepare customized message content - const content: MessageContent = { + const content: ThreadContent = { type: ContentType.Text, text: { value: "I'm Jan Assistant!", @@ -49,7 +49,7 @@ import * as node from "@janhq/core/node"; }; // Modify message and send out - const outGoingMessage: Message = { + const outGoingMessage: ThreadMessage = { ...incomingMessage, content }; diff --git a/core/package.json b/core/package.json index 082ce9a8d9..9e4d8d69a3 100644 --- a/core/package.json +++ b/core/package.json @@ -8,7 +8,7 @@ ], "homepage": "https://jan.ai", "license": "AGPL-3.0", - "main": "dist/lib/index.js", + "main": "dist/core.es5.js", "module": "dist/core.cjs.js", "typings": "dist/types/index.d.ts", "files": [ @@ -17,18 +17,18 @@ ], "author": "Jan ", "exports": { - ".": "./dist/lib/index.js", + ".": "./dist/core.es5.js", "./node": "./dist/node/index.cjs.js" }, "typesVersions": { "*": { ".": [ - "./dist/lib/index.js", + "./dist/core.es5.js.map", "./dist/types/index.d.ts" ], "node": [ "./dist/node/index.cjs.js.map", - "./dist/types/index.d.ts" + "./dist/types/node/index.d.ts" ] } }, @@ -40,7 +40,6 @@ "start": "rollup -c rollup.config.ts -w" }, "devDependencies": { - "openai": "4.51.0", "@rollup/plugin-replace": "^5.0.5", "@types/jest": "^29.5.12", "@types/node": "^20.11.4", @@ -59,6 +58,7 @@ "typescript": "^5.3.3" }, "dependencies": { - "rxjs": "^7.8.1" + "rxjs": "^7.8.1", + "ulidx": "^2.3.0" } } diff --git a/core/rollup.config.ts b/core/rollup.config.ts index b4268d82e6..e3336bfad6 100644 --- 
a/core/rollup.config.ts +++ b/core/rollup.config.ts @@ -43,7 +43,7 @@ export default [ ], }, { - input: `src/index.ts`, + input: `src/node/index.ts`, output: [{ file: 'dist/node/index.cjs.js', format: 'cjs', sourcemap: true }], // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') external: [ @@ -52,6 +52,7 @@ export default [ 'pacote', '@types/pacote', '@npmcli/arborist', + 'ulidx', 'node-fetch', 'fs', 'request', diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts new file mode 100644 index 0000000000..fdbceb06bb --- /dev/null +++ b/core/src/browser/core.ts @@ -0,0 +1,165 @@ +import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from '../types' + +/** + * Execute a extension module function in main process + * + * @param extension extension name to import + * @param method function name to execute + * @param args arguments to pass to the function + * @returns Promise + * + */ +const executeOnMain: (extension: string, method: string, ...args: any[]) => Promise = ( + extension, + method, + ...args +) => globalThis.core?.api?.invokeExtensionFunc(extension, method, ...args) + +/** + * Downloads a file from a URL and saves it to the local file system. + * + * @param {DownloadRequest} downloadRequest - The request to download the file. + * @param {NetworkConfig} network - Optional object to specify proxy/whether to ignore SSL certificates. + * + * @returns {Promise} A promise that resolves when the file is downloaded. + */ +const downloadFile: (downloadRequest: DownloadRequest, network?: NetworkConfig) => Promise = ( + downloadRequest, + network +) => globalThis.core?.api?.downloadFile(downloadRequest, network) + +/** + * Get unit in bytes for a remote file. + * + * @param url - The url of the file. + * @returns {Promise} - A promise that resolves with the file size. + */ +const getFileSize: (url: string) => Promise = (url: string) => + globalThis.core.api?.getFileSize(url) + +/** + * Aborts the download of a specific file. + * @param {string} fileName - The name of the file whose download is to be aborted. + * @returns {Promise} A promise that resolves when the download has been aborted. + */ +const abortDownload: (fileName: string) => Promise = (fileName) => + globalThis.core.api?.abortDownload(fileName) + +/** + * Gets Jan's data folder path. + * + * @returns {Promise} A Promise that resolves with Jan's data folder path. + */ +const getJanDataFolderPath = (): Promise => globalThis.core.api?.getJanDataFolderPath() + +/** + * Opens the file explorer at a specific path. + * @param {string} path - The path to open in the file explorer. + * @returns {Promise} A promise that resolves when the file explorer is opened. + */ +const openFileExplorer: (path: string) => Promise = (path) => + globalThis.core.api?.openFileExplorer(path) + +/** + * Joins multiple paths together. + * @param paths - The paths to join. + * @returns {Promise} A promise that resolves with the joined path. + */ +const joinPath: (paths: string[]) => Promise = (paths) => + globalThis.core.api?.joinPath(paths) + +/** + * Retrieve the basename from an url. + * @param path - The path to retrieve. + * @returns {Promise} A promise that resolves with the basename. + */ +const baseName: (paths: string) => Promise = (path) => globalThis.core.api?.baseName(path) + +/** + * Opens an external URL in the default web browser. + * + * @param {string} url - The URL to open. + * @returns {Promise} - A promise that resolves when the URL has been successfully opened. 
+ */ +const openExternalUrl: (url: string) => Promise = (url) => + globalThis.core.api?.openExternalUrl(url) + +/** + * Gets the resource path of the application. + * + * @returns {Promise} - A promise that resolves with the resource path. + */ +const getResourcePath: () => Promise = () => globalThis.core.api?.getResourcePath() + +/** + * Gets the user's home path. + * @returns return user's home path + */ +const getUserHomePath = (): Promise => globalThis.core.api?.getUserHomePath() + +/** + * Log to file from browser processes. + * + * @param message - Message to log. + */ +const log: (message: string, fileName?: string) => void = (message, fileName) => + globalThis.core.api?.log(message, fileName) + +/** + * Check whether the path is a subdirectory of another path. + * + * @param from - The path to check. + * @param to - The path to check against. + * + * @returns {Promise} - A promise that resolves with a boolean indicating whether the path is a subdirectory. + */ +const isSubdirectory: (from: string, to: string) => Promise = (from: string, to: string) => + globalThis.core.api?.isSubdirectory(from, to) + +/** + * Get system information + * @returns {Promise} - A promise that resolves with the system information. + */ +const systemInformation: () => Promise = () => + globalThis.core.api?.systemInformation() + +/** + * Show toast message from browser processes. + * @param title + * @param message + * @returns + */ +const showToast: (title: string, message: string) => void = (title, message) => + globalThis.core.api?.showToast(title, message) + +/** + * Register extension point function type definition + */ +export type RegisterExtensionPoint = ( + extensionName: string, + extensionId: string, + method: Function, + priority?: number +) => void + +/** + * Functions exports + */ +export { + executeOnMain, + downloadFile, + abortDownload, + getJanDataFolderPath, + openFileExplorer, + getResourcePath, + joinPath, + openExternalUrl, + baseName, + log, + isSubdirectory, + getUserHomePath, + systemInformation, + showToast, + getFileSize, + FileStat, +} diff --git a/core/src/browser/events.ts b/core/src/browser/events.ts new file mode 100644 index 0000000000..da85f7e3be --- /dev/null +++ b/core/src/browser/events.ts @@ -0,0 +1,35 @@ +/** + * Adds an observer for an event. + * + * @param eventName The name of the event to observe. + * @param handler The handler function to call when the event is observed. + */ +const on: (eventName: string, handler: Function) => void = (eventName, handler) => { + globalThis.core?.events?.on(eventName, handler) +} + +/** + * Removes an observer for an event. + * + * @param eventName The name of the event to stop observing. + * @param handler The handler function to call when the event is observed. + */ +const off: (eventName: string, handler: Function) => void = (eventName, handler) => { + globalThis.core?.events?.off(eventName, handler) +} + +/** + * Emits an event. + * + * @param eventName The name of the event to emit. + * @param object The object to pass to the event callback. 
+ */ +const emit: (eventName: string, object: any) => void = (eventName, object) => { + globalThis.core?.events?.emit(eventName, object) +} + +export const events = { + on, + off, + emit, +} diff --git a/core/src/browser/extension.ts b/core/src/browser/extension.ts new file mode 100644 index 0000000000..18a6e44919 --- /dev/null +++ b/core/src/browser/extension.ts @@ -0,0 +1,211 @@ +import { SettingComponentProps } from '../types' +import { getJanDataFolderPath, joinPath } from './core' +import { fs } from './fs' + +export enum ExtensionTypeEnum { + Assistant = 'assistant', + Conversational = 'conversational', + Inference = 'inference', + Model = 'model', + SystemMonitoring = 'systemMonitoring', + HuggingFace = 'huggingFace', +} + +export interface ExtensionType { + type(): ExtensionTypeEnum | undefined +} + +export interface Compatibility { + platform: string[] + version: string +} + +const ALL_INSTALLATION_STATE = [ + 'NotRequired', // not required. + 'Installed', // require and installed. Good to go. + 'NotInstalled', // require to be installed. + 'Corrupted', // require but corrupted. Need to redownload. + 'NotCompatible', // require but not compatible. +] as const + +export type InstallationStateTuple = typeof ALL_INSTALLATION_STATE +export type InstallationState = InstallationStateTuple[number] + +/** + * Represents a base extension. + * This class should be extended by any class that represents an extension. + */ +export abstract class BaseExtension implements ExtensionType { + protected settingFolderName = 'settings' + protected settingFileName = 'settings.json' + + /** @type {string} Name of the extension. */ + name: string + + /** @type {string} Product Name of the extension. */ + productName?: string + + /** @type {string} The URL of the extension to load. */ + url: string + + /** @type {boolean} Whether the extension is activated or not. */ + active + + /** @type {string} Extension's description. */ + description + + /** @type {string} Extension's version. */ + version + + constructor( + url: string, + name: string, + productName?: string, + active?: boolean, + description?: string, + version?: string + ) { + this.name = name + this.productName = productName + this.url = url + this.active = active + this.description = description + this.version = version + } + + /** + * Returns the type of the extension. + * @returns {ExtensionType} The type of the extension + * Undefined means its not extending any known extension by the application. + */ + type(): ExtensionTypeEnum | undefined { + return undefined + } + + /** + * Called when the extension is loaded. + * Any initialization logic for the extension should be put here. + */ + abstract onLoad(): void + + /** + * Called when the extension is unloaded. + * Any cleanup logic for the extension should be put here. + */ + abstract onUnload(): void + + /** + * The compatibility of the extension. + * This is used to check if the extension is compatible with the current environment. 
+ * @property {Array} platform + */ + compatibility(): Compatibility | undefined { + return undefined + } + + async registerSettings(settings: SettingComponentProps[]): Promise { + if (!this.name) { + console.error('Extension name is not defined') + return + } + + const extensionSettingFolderPath = await joinPath([ + await getJanDataFolderPath(), + 'settings', + this.name, + ]) + settings.forEach((setting) => { + setting.extensionName = this.name + }) + try { + await fs.mkdir(extensionSettingFolderPath) + const settingFilePath = await joinPath([extensionSettingFolderPath, this.settingFileName]) + + if (await fs.existsSync(settingFilePath)) return + await fs.writeFileSync(settingFilePath, JSON.stringify(settings, null, 2)) + } catch (err) { + console.error(err) + } + } + + async getSetting(key: string, defaultValue: T) { + const keySetting = (await this.getSettings()).find((setting) => setting.key === key) + + const value = keySetting?.controllerProps.value + return (value as T) ?? defaultValue + } + + onSettingUpdate(key: string, value: T) { + return + } + + /** + * Determine if the prerequisites for the extension are installed. + * + * @returns {boolean} true if the prerequisites are installed, false otherwise. + */ + async installationState(): Promise { + return 'NotRequired' + } + + /** + * Install the prerequisites for the extension. + * + * @returns {Promise} + */ + async install(): Promise { + return + } + + async getSettings(): Promise { + if (!this.name) return [] + + const settingPath = await joinPath([ + await getJanDataFolderPath(), + this.settingFolderName, + this.name, + this.settingFileName, + ]) + + try { + const content = await fs.readFileSync(settingPath, 'utf-8') + const settings: SettingComponentProps[] = JSON.parse(content) + return settings + } catch (err) { + console.warn(err) + return [] + } + } + + async updateSettings(componentProps: Partial[]): Promise { + if (!this.name) return + + const settings = await this.getSettings() + + const updatedSettings = settings.map((setting) => { + const updatedSetting = componentProps.find( + (componentProp) => componentProp.key === setting.key + ) + if (updatedSetting && updatedSetting.controllerProps) { + setting.controllerProps.value = updatedSetting.controllerProps.value + } + return setting + }) + + const settingPath = await joinPath([ + await getJanDataFolderPath(), + this.settingFolderName, + this.name, + this.settingFileName, + ]) + + await fs.writeFileSync(settingPath, JSON.stringify(updatedSettings, null, 2)) + + updatedSettings.forEach((setting) => { + this.onSettingUpdate( + setting.key, + setting.controllerProps.value + ) + }) + } +} diff --git a/core/src/browser/extensions/assistant.ts b/core/src/browser/extensions/assistant.ts new file mode 100644 index 0000000000..d025c67868 --- /dev/null +++ b/core/src/browser/extensions/assistant.ts @@ -0,0 +1,19 @@ +import { Assistant, AssistantInterface } from '../../types' +import { BaseExtension, ExtensionTypeEnum } from '../extension' + +/** + * Assistant extension for managing assistants. + * @extends BaseExtension + */ +export abstract class AssistantExtension extends BaseExtension implements AssistantInterface { + /** + * Assistant extension type. 
+ */ + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.Assistant + } + + abstract createAssistant(assistant: Assistant): Promise + abstract deleteAssistant(assistant: Assistant): Promise + abstract getAssistants(): Promise +} diff --git a/core/src/browser/extensions/conversational.ts b/core/src/browser/extensions/conversational.ts new file mode 100644 index 0000000000..ec53fbbbf9 --- /dev/null +++ b/core/src/browser/extensions/conversational.ts @@ -0,0 +1,26 @@ +import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types' +import { BaseExtension, ExtensionTypeEnum } from '../extension' + +/** + * Conversational extension. Persists and retrieves conversations. + * @abstract + * @extends BaseExtension + */ +export abstract class ConversationalExtension + extends BaseExtension + implements ThreadInterface, MessageInterface +{ + /** + * Conversation extension type. + */ + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.Conversational + } + + abstract getThreads(): Promise + abstract saveThread(thread: Thread): Promise + abstract deleteThread(threadId: string): Promise + abstract addNewMessage(message: ThreadMessage): Promise + abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise + abstract getAllMessages(threadId: string): Promise +} diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts new file mode 100644 index 0000000000..7cd9f513e2 --- /dev/null +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -0,0 +1,104 @@ +import { getJanDataFolderPath, joinPath } from '../../core' +import { events } from '../../events' +import { BaseExtension } from '../../extension' +import { fs } from '../../fs' +import { MessageRequest, Model, ModelEvent } from '../../../types' +import { EngineManager } from './EngineManager' + +/** + * Base AIEngine + * Applicable to all AI Engines + */ +export abstract class AIEngine extends BaseExtension { + private static modelsFolder = 'models' + + // The inference engine + abstract provider: string + + /** + * On extension load, subscribe to events. 
+ */ + override onLoad() { + this.registerEngine() + + events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) + } + + /** + * Registers AI Engines + */ + registerEngine() { + EngineManager.instance().register(this) + } + + async registerModels(models: Model[]): Promise { + const modelFolderPath = await joinPath([await getJanDataFolderPath(), AIEngine.modelsFolder]) + + let shouldNotifyModelUpdate = false + for (const model of models) { + const modelPath = await joinPath([modelFolderPath, model.id]) + const isExist = await fs.existsSync(modelPath) + + if (isExist) { + await this.migrateModelIfNeeded(model, modelPath) + continue + } + + await fs.mkdir(modelPath) + await fs.writeFileSync( + await joinPath([modelPath, 'model.json']), + JSON.stringify(model, null, 2) + ) + shouldNotifyModelUpdate = true + } + + if (shouldNotifyModelUpdate) { + events.emit(ModelEvent.OnModelsUpdate, {}) + } + } + + async migrateModelIfNeeded(model: Model, modelPath: string): Promise { + try { + const modelJson = await fs.readFileSync(await joinPath([modelPath, 'model.json']), 'utf-8') + const currentModel: Model = JSON.parse(modelJson) + if (currentModel.version !== model.version) { + await fs.writeFileSync( + await joinPath([modelPath, 'model.json']), + JSON.stringify(model, null, 2) + ) + + events.emit(ModelEvent.OnModelsUpdate, {}) + } + } catch (error) { + console.warn('Error while try to migrating model', error) + } + } + + /** + * Loads the model. + */ + async loadModel(model: Model): Promise { + if (model.engine.toString() !== this.provider) return Promise.resolve() + events.emit(ModelEvent.OnModelReady, model) + return Promise.resolve() + } + /** + * Stops the model. + */ + async unloadModel(model?: Model): Promise { + if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() + events.emit(ModelEvent.OnModelStopped, model ?? {}) + return Promise.resolve() + } + + /* + * Inference request + */ + inference(data: MessageRequest) {} + + /** + * Stop inference + */ + stopInference() {} +} diff --git a/core/src/browser/extensions/engines/EngineManager.ts b/core/src/browser/extensions/engines/EngineManager.ts new file mode 100644 index 0000000000..2980c5c65e --- /dev/null +++ b/core/src/browser/extensions/engines/EngineManager.ts @@ -0,0 +1,32 @@ +import { AIEngine } from './AIEngine' + +/** + * Manages the registration and retrieval of inference engines. + */ +export class EngineManager { + public engines = new Map() + + /** + * Registers an engine. + * @param engine - The engine to register. + */ + register(engine: T) { + this.engines.set(engine.provider, engine) + } + + /** + * Retrieves a engine by provider. + * @param provider - The name of the engine to retrieve. + * @returns The engine, if found. + */ + get(provider: string): T | undefined { + return this.engines.get(provider) as T | undefined + } + + /** + * The instance of the engine manager. + */ + static instance(): EngineManager { + return window.core?.engineManager as EngineManager ?? 
new EngineManager() + } +} diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts new file mode 100644 index 0000000000..fb9e4962c4 --- /dev/null +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -0,0 +1,64 @@ +import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' +import { events } from '../../events' +import { Model, ModelEvent } from '../../../types' +import { OAIEngine } from './OAIEngine' + +/** + * Base OAI Local Inference Provider + * Added the implementation of loading and unloading model (applicable to local inference providers) + */ +export abstract class LocalOAIEngine extends OAIEngine { + // The inference engine + abstract nodeModule: string + loadModelFunctionName: string = 'loadModel' + unloadModelFunctionName: string = 'unloadModel' + + /** + * On extension load, subscribe to events. + */ + override onLoad() { + super.onLoad() + // These events are applicable to local inference providers + events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) + } + + /** + * Load the model. + */ + override async loadModel(model: Model): Promise { + if (model.engine.toString() !== this.provider) return + const modelFolderName = 'models' + const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) + const systemInfo = await systemInformation() + const res = await executeOnMain( + this.nodeModule, + this.loadModelFunctionName, + { + modelFolder, + model, + }, + systemInfo + ) + + if (res?.error) { + events.emit(ModelEvent.OnModelFail, { error: res.error }) + return Promise.reject(res.error) + } else { + this.loadedModel = model + events.emit(ModelEvent.OnModelReady, model) + return Promise.resolve() + } + } + /** + * Stops the model. + */ + override async unloadModel(model?: Model) { + if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve() + + this.loadedModel = undefined + await executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { + events.emit(ModelEvent.OnModelStopped, {}) + }) + } +} diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts new file mode 100644 index 0000000000..01ef55e5e4 --- /dev/null +++ b/core/src/browser/extensions/engines/OAIEngine.ts @@ -0,0 +1,157 @@ +import { requestInference } from './helpers/sse' +import { ulid } from 'ulidx' +import { AIEngine } from './AIEngine' +import { + ChatCompletionRole, + ContentType, + InferenceEvent, + MessageEvent, + MessageRequest, + MessageRequestType, + MessageStatus, + Model, + ModelInfo, + ThreadContent, + ThreadMessage, +} from '../../../types' +import { events } from '../../events' + +/** + * Base OAI Inference Provider + * Applicable to all OAI compatible inference providers + */ +export abstract class OAIEngine extends AIEngine { + // The inference engine + abstract inferenceUrl: string + + // Controller to handle stop requests + controller = new AbortController() + isCancelled = false + + // The loaded model instance + loadedModel: Model | undefined + + // Transform the payload + transformPayload?: Function + + // Transform the response + transformResponse?: Function + + /** + * On extension load, subscribe to events. 
+ */ + override onLoad() { + super.onLoad() + events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data)) + events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference()) + } + + /** + * On extension unload + */ + override onUnload(): void {} + + /* + * Inference request + */ + override async inference(data: MessageRequest) { + if (data.model?.engine?.toString() !== this.provider) return + + const timestamp = Date.now() + const message: ThreadMessage = { + id: ulid(), + thread_id: data.threadId, + type: data.type, + assistant_id: data.assistantId, + role: ChatCompletionRole.Assistant, + content: [], + status: MessageStatus.Pending, + created: timestamp, + updated: timestamp, + object: 'thread.message', + } + + if (data.type !== MessageRequestType.Summary) { + events.emit(MessageEvent.OnMessageResponse, message) + } + + this.isCancelled = false + this.controller = new AbortController() + + const model: ModelInfo = { + ...(this.loadedModel ? this.loadedModel : {}), + ...data.model, + } + + const header = await this.headers() + let requestBody = { + messages: data.messages ?? [], + model: model.id, + stream: true, + ...model.parameters, + } + if (this.transformPayload) { + requestBody = this.transformPayload(requestBody) + } + + requestInference( + this.inferenceUrl, + requestBody, + model, + this.controller, + header, + this.transformResponse + ).subscribe({ + next: (content: any) => { + const messageContent: ThreadContent = { + type: ContentType.Text, + text: { + value: content.trim(), + annotations: [], + }, + } + message.content = [messageContent] + events.emit(MessageEvent.OnMessageUpdate, message) + }, + complete: async () => { + message.status = message.content.length ? MessageStatus.Ready : MessageStatus.Error + events.emit(MessageEvent.OnMessageUpdate, message) + }, + error: async (err: any) => { + console.debug('inference url: ', this.inferenceUrl) + console.debug('header: ', header) + console.error(`Inference error:`, JSON.stringify(err)) + if (this.isCancelled || message.content.length) { + message.status = MessageStatus.Stopped + events.emit(MessageEvent.OnMessageUpdate, message) + return + } + message.status = MessageStatus.Error + message.content[0] = { + type: ContentType.Text, + text: { + value: err.message, + annotations: [], + }, + } + message.error_code = err.code + events.emit(MessageEvent.OnMessageUpdate, message) + }, + }) + } + + /** + * Stops the inference. + */ + override stopInference() { + this.isCancelled = true + this.controller?.abort() + } + + /** + * Headers for the inference request + */ + async headers(): Promise { + return {} + } +} diff --git a/core/src/browser/extensions/engines/RemoteOAIEngine.ts b/core/src/browser/extensions/engines/RemoteOAIEngine.ts new file mode 100644 index 0000000000..b112353707 --- /dev/null +++ b/core/src/browser/extensions/engines/RemoteOAIEngine.ts @@ -0,0 +1,27 @@ +import { OAIEngine } from './OAIEngine' + +/** + * Base OAI Remote Inference Provider + * Added the implementation of loading and unloading model (applicable to local inference providers) + */ +export abstract class RemoteOAIEngine extends OAIEngine { + apiKey?: string + /** + * On extension load, subscribe to events. 
+ */ + override onLoad() { + super.onLoad() + } + + /** + * Headers for the inference request + */ + override async headers(): Promise { + return { + ...(this.apiKey && { + 'Authorization': `Bearer ${this.apiKey}`, + 'api-key': `${this.apiKey}`, + }), + } + } +} diff --git a/core/src/browser/extensions/engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts new file mode 100644 index 0000000000..024ced4703 --- /dev/null +++ b/core/src/browser/extensions/engines/helpers/sse.ts @@ -0,0 +1,95 @@ +import { Observable } from 'rxjs' +import { ErrorCode, ModelRuntimeParams } from '../../../../types' +/** + * Sends a request to the inference server to generate a response based on the recent messages. + * @param recentMessages - An array of recent messages to use as context for the inference. + * @returns An Observable that emits the generated response as a string. + */ +export function requestInference( + inferenceUrl: string, + requestBody: any, + model: { + id: string + parameters: ModelRuntimeParams + }, + controller?: AbortController, + headers?: HeadersInit, + transformResponse?: Function +): Observable { + return new Observable((subscriber) => { + fetch(inferenceUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Access-Control-Allow-Origin': '*', + 'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json', + ...headers, + }, + body: JSON.stringify(requestBody), + signal: controller?.signal, + }) + .then(async (response) => { + if (!response.ok) { + const data = await response.json() + let errorCode = ErrorCode.Unknown + if (data.error) { + errorCode = data.error.code ?? data.error.type ?? ErrorCode.Unknown + } else if (response.status === 401) { + errorCode = ErrorCode.InvalidApiKey + } + const error = { + message: data.error?.message ?? 'Error occurred.', + code: errorCode, + } + subscriber.error(error) + subscriber.complete() + return + } + if (model.parameters.stream === false) { + const data = await response.json() + if (transformResponse) { + subscriber.next(transformResponse(data)) + } else { + subscriber.next(data.choices[0]?.message?.content ?? '') + } + } else { + const stream = response.body + const decoder = new TextDecoder('utf-8') + const reader = stream?.getReader() + let content = '' + + while (true && reader) { + const { done, value } = await reader.read() + if (done) { + break + } + const text = decoder.decode(value) + const lines = text.trim().split('\n') + let cachedLines = '' + for (const line of lines) { + try { + if (transformResponse) { + content += transformResponse(line) + subscriber.next(content ?? '') + } else { + const toParse = cachedLines + line + if (!line.includes('data: [DONE]')) { + const data = JSON.parse(toParse.replace('data: ', '')) + content += data.choices[0]?.delta?.content ?? 
'' + if (content.startsWith('assistant: ')) { + content = content.replace('assistant: ', '') + } + if (content !== '') subscriber.next(content) + } + } + } catch { + cachedLines = line + } + } + } + } + subscriber.complete() + }) + .catch((err) => subscriber.error(err)) + }) +} diff --git a/core/src/browser/extensions/engines/index.ts b/core/src/browser/extensions/engines/index.ts new file mode 100644 index 0000000000..34ef45afd1 --- /dev/null +++ b/core/src/browser/extensions/engines/index.ts @@ -0,0 +1,5 @@ +export * from './AIEngine' +export * from './OAIEngine' +export * from './LocalOAIEngine' +export * from './RemoteOAIEngine' +export * from './EngineManager' diff --git a/core/src/browser/extensions/index.ts b/core/src/browser/extensions/index.ts new file mode 100644 index 0000000000..85d5a85835 --- /dev/null +++ b/core/src/browser/extensions/index.ts @@ -0,0 +1,30 @@ +/** + * Conversational extension. Persists and retrieves conversations. + * @module + */ +export { ConversationalExtension } from './conversational' + +/** + * Inference extension. Start, stop and inference models. + */ +export { InferenceExtension } from './inference' + +/** + * Monitoring extension for system monitoring. + */ +export { MonitoringExtension } from './monitoring' + +/** + * Assistant extension for managing assistants. + */ +export { AssistantExtension } from './assistant' + +/** + * Model extension for managing models. + */ +export { ModelExtension } from './model' + +/** + * Base AI Engines. + */ +export * from './engines' diff --git a/core/src/browser/extensions/inference.ts b/core/src/browser/extensions/inference.ts new file mode 100644 index 0000000000..44c50f7f82 --- /dev/null +++ b/core/src/browser/extensions/inference.ts @@ -0,0 +1,16 @@ +import { InferenceInterface, MessageRequest, ThreadMessage } from '../../types' +import { BaseExtension, ExtensionTypeEnum } from '../extension' + +/** + * Inference extension. Start, stop and inference models. + */ +export abstract class InferenceExtension extends BaseExtension implements InferenceInterface { + /** + * Inference extension type. + */ + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.Inference + } + + abstract inference(data: MessageRequest): Promise +} diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts new file mode 100644 index 0000000000..5b3089403f --- /dev/null +++ b/core/src/browser/extensions/model.ts @@ -0,0 +1,36 @@ +import { BaseExtension, ExtensionTypeEnum } from '../extension' +import { + GpuSetting, + HuggingFaceRepoData, + ImportingModel, + Model, + ModelInterface, + OptionType, +} from '../../types' + +/** + * Model extension for managing models. + */ +export abstract class ModelExtension extends BaseExtension implements ModelInterface { + /** + * Model extension type. 
+ */ + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.Model + } + + abstract downloadModel( + model: Model, + gpuSettings?: GpuSetting, + network?: { proxy: string; ignoreSSL?: boolean } + ): Promise + abstract cancelModelDownload(modelId: string): Promise + abstract deleteModel(modelId: string): Promise + abstract saveModel(model: Model): Promise + abstract getDownloadedModels(): Promise + abstract getConfiguredModels(): Promise + abstract importModels(models: ImportingModel[], optionType: OptionType): Promise + abstract updateModelInfo(modelInfo: Partial): Promise + abstract fetchHuggingFaceRepoData(repoId: string): Promise + abstract getDefaultModel(): Promise +} diff --git a/core/src/browser/extensions/monitoring.ts b/core/src/browser/extensions/monitoring.ts new file mode 100644 index 0000000000..cb544b6b72 --- /dev/null +++ b/core/src/browser/extensions/monitoring.ts @@ -0,0 +1,20 @@ +import { BaseExtension, ExtensionTypeEnum } from '../extension' +import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../../types' + +/** + * Monitoring extension for system monitoring. + * @extends BaseExtension + */ +export abstract class MonitoringExtension extends BaseExtension implements MonitoringInterface { + /** + * Monitoring extension type. + */ + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.SystemMonitoring + } + + abstract getGpuSetting(): Promise + abstract getResourcesInfo(): Promise + abstract getCurrentLoad(): Promise + abstract getOsInfo(): Promise +} diff --git a/core/src/browser/fs.ts b/core/src/browser/fs.ts new file mode 100644 index 0000000000..cca9bb1d3f --- /dev/null +++ b/core/src/browser/fs.ts @@ -0,0 +1,87 @@ +import { FileStat } from '../types' + +/** + * Writes data to a file at the specified path. + * @returns {Promise} A Promise that resolves when the file is written successfully. + */ +const writeFileSync = (...args: any[]) => globalThis.core.api?.writeFileSync(...args) + +/** + * Writes blob data to a file at the specified path. + * @param path - The path to file. + * @param data - The blob data. + * @returns + */ +const writeBlob: (path: string, data: string) => Promise = (path, data) => + globalThis.core.api?.writeBlob(path, data) + +/** + * Reads the contents of a file at the specified path. + * @returns {Promise} A Promise that resolves with the contents of the file. + */ +const readFileSync = (...args: any[]) => globalThis.core.api?.readFileSync(...args) +/** + * Check whether the file exists + * @param {string} path + * @returns {boolean} A boolean indicating whether the path is a file. + */ +const existsSync = (...args: any[]) => globalThis.core.api?.existsSync(...args) +/** + * List the directory files + * @returns {Promise} A Promise that resolves with the contents of the directory. + */ +const readdirSync = (...args: any[]) => globalThis.core.api?.readdirSync(...args) +/** + * Creates a directory at the specified path. + * @returns {Promise} A Promise that resolves when the directory is created successfully. + */ +const mkdir = (...args: any[]) => globalThis.core.api?.mkdir(...args) + +/** + * Removes a directory at the specified path. + * @returns {Promise} A Promise that resolves when the directory is removed successfully. + */ +const rm = (...args: any[]) => globalThis.core.api?.rm(...args, { recursive: true, force: true }) + +/** + * Deletes a file from the local file system. + * @param {string} path - The path of the file to delete. 
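// A sketch of a concrete InferenceExtension (see extensions/inference.ts above).
// The package name, the onLoad/onUnload hooks and the returned message shape are
// assumptions; the only contract shown in this changeset is the abstract inference() method.
import { InferenceExtension, MessageRequest, ThreadMessage } from '@janhq/core' // assumed package name

export default class EchoInference extends InferenceExtension {
  onLoad(): void {}
  onUnload(): void {}

  async inference(data: MessageRequest): Promise<ThreadMessage> {
    // Placeholder: a real extension would forward `data` to an engine and build
    // a proper ThreadMessage from the response.
    return {} as ThreadMessage
  }
}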
+ * @returns {Promise} A Promise that resolves when the file is deleted. + */ +const unlinkSync = (...args: any[]) => globalThis.core.api?.unlinkSync(...args) + +/** + * Appends data to a file at the specified path. + */ +const appendFileSync = (...args: any[]) => globalThis.core.api?.appendFileSync(...args) + +const copyFile: (src: string, dest: string) => Promise = (src, dest) => + globalThis.core.api?.copyFile(src, dest) + +/** + * Gets the file's stats. + * + * @param path - The path to the file. + * @param outsideJanDataFolder - Whether the file is outside the Jan data folder. + * @returns {Promise} - A promise that resolves with the file's stats. + */ +const fileStat: (path: string, outsideJanDataFolder?: boolean) => Promise = ( + path, + outsideJanDataFolder +) => globalThis.core.api?.fileStat(path, outsideJanDataFolder) + +// TODO: Export `dummy` fs functions automatically +// Currently adding these manually +export const fs = { + writeFileSync, + readFileSync, + existsSync, + readdirSync, + mkdir, + rm, + unlinkSync, + appendFileSync, + copyFile, + fileStat, + writeBlob, +} diff --git a/core/src/browser/index.ts b/core/src/browser/index.ts new file mode 100644 index 0000000000..a7803c7e04 --- /dev/null +++ b/core/src/browser/index.ts @@ -0,0 +1,35 @@ +/** + * Export Core module + * @module + */ +export * from './core' + +/** + * Export Event module. + * @module + */ +export * from './events' + +/** + * Export Filesystem module. + * @module + */ +export * from './fs' + +/** + * Export Extension module. + * @module + */ +export * from './extension' + +/** + * Export all base extensions. + * @module + */ +export * from './extensions' + +/** + * Export all base tools. + * @module + */ +export * from './tools' diff --git a/core/src/browser/tools/index.ts b/core/src/browser/tools/index.ts new file mode 100644 index 0000000000..24cd127804 --- /dev/null +++ b/core/src/browser/tools/index.ts @@ -0,0 +1,2 @@ +export * from './manager' +export * from './tool' diff --git a/core/src/browser/tools/manager.ts b/core/src/browser/tools/manager.ts new file mode 100644 index 0000000000..b323ad7ced --- /dev/null +++ b/core/src/browser/tools/manager.ts @@ -0,0 +1,47 @@ +import { AssistantTool, MessageRequest } from '../../types' +import { InferenceTool } from './tool' + +/** + * Manages the registration and retrieval of inference tools. + */ +export class ToolManager { + public tools = new Map() + + /** + * Registers a tool. + * @param tool - The tool to register. + */ + register(tool: T) { + this.tools.set(tool.name, tool) + } + + /** + * Retrieves a tool by it's name. + * @param name - The name of the tool to retrieve. + * @returns The tool, if found. + */ + get(name: string): T | undefined { + return this.tools.get(name) as T | undefined + } + + /* + ** Process the message request with the tools. + */ + process(request: MessageRequest, tools: AssistantTool[]): Promise { + return tools.reduce((prevPromise, currentTool) => { + return prevPromise.then((prevResult) => { + return currentTool.enabled + ? this.get(currentTool.type)?.process(prevResult, currentTool) ?? + Promise.resolve(prevResult) + : Promise.resolve(prevResult) + }) + }, Promise.resolve(request)) + } + + /** + * The instance of the tool manager. + */ + static instance(): ToolManager { + return (window.core?.toolManager as ToolManager) ?? 
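// A short sketch of the browser fs facade exported above. Every call is proxied to
// globalThis.core.api, so this only runs inside Jan's renderer or extension context;
// the import path and the file paths below are hypothetical.
import { fs } from '@janhq/core' // assumed package name

async function ensureSettingsFile() {
  const settingsPath = 'file://settings/example.json' // hypothetical path inside the Jan data folder
  if (!(await fs.existsSync(settingsPath))) {
    await fs.mkdir('file://settings')
    await fs.writeFileSync(settingsPath, JSON.stringify({ theme: 'dark' }))
  }
  return JSON.parse(await fs.readFileSync(settingsPath, 'utf-8'))
}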
new ToolManager() + } +} diff --git a/core/src/browser/tools/tool.ts b/core/src/browser/tools/tool.ts new file mode 100644 index 0000000000..0fd3429331 --- /dev/null +++ b/core/src/browser/tools/tool.ts @@ -0,0 +1,12 @@ +import { AssistantTool, MessageRequest } from '../../types' + +/** + * Represents a base inference tool. + */ +export abstract class InferenceTool { + abstract name: string + /* + ** Process a message request and return the processed message request. + */ + abstract process(request: MessageRequest, tool?: AssistantTool): Promise +} diff --git a/core/src/index.ts b/core/src/index.ts index 1bb83fc879..cfd69f93d1 100644 --- a/core/src/index.ts +++ b/core/src/index.ts @@ -4,6 +4,12 @@ */ export * from './types' +/** + * Export browser module + * @module + */ +export * from './browser' + /** * Declare global object */ diff --git a/core/src/node/api/HttpServer.ts b/core/src/node/api/HttpServer.ts new file mode 100644 index 0000000000..32d5977175 --- /dev/null +++ b/core/src/node/api/HttpServer.ts @@ -0,0 +1,8 @@ +export interface HttpServer { + post: (route: string, handler: (req: any, res: any) => Promise) => void + get: (route: string, handler: (req: any, res: any) => Promise) => void + patch: (route: string, handler: (req: any, res: any) => Promise) => void + put: (route: string, handler: (req: any, res: any) => Promise) => void + delete: (route: string, handler: (req: any, res: any) => Promise) => void + register: (router: any, opts?: any) => void +} diff --git a/core/src/node/api/common/adapter.ts b/core/src/node/api/common/adapter.ts new file mode 100644 index 0000000000..2beacf3254 --- /dev/null +++ b/core/src/node/api/common/adapter.ts @@ -0,0 +1,43 @@ +import { + AppRoute, + DownloadRoute, + ExtensionRoute, + FileManagerRoute, + FileSystemRoute, +} from '../../../types/api' +import { Downloader } from '../processors/download' +import { FileSystem } from '../processors/fs' +import { Extension } from '../processors/extension' +import { FSExt } from '../processors/fsExt' +import { App } from '../processors/app' + +export class RequestAdapter { + downloader: Downloader + fileSystem: FileSystem + extension: Extension + fsExt: FSExt + app: App + + constructor(observer?: Function) { + this.downloader = new Downloader(observer) + this.fileSystem = new FileSystem() + this.extension = new Extension() + this.fsExt = new FSExt() + this.app = new App() + } + + // TODO: Clearer Factory pattern here + process(route: string, ...args: any) { + if (route in DownloadRoute) { + return this.downloader.process(route, ...args) + } else if (route in FileSystemRoute) { + return this.fileSystem.process(route, ...args) + } else if (route in ExtensionRoute) { + return this.extension.process(route, ...args) + } else if (route in FileManagerRoute) { + return this.fsExt.process(route, ...args) + } else if (route in AppRoute) { + return this.app.process(route, ...args) + } + } +} diff --git a/core/src/node/api/common/handler.ts b/core/src/node/api/common/handler.ts new file mode 100644 index 0000000000..5cf232d8a6 --- /dev/null +++ b/core/src/node/api/common/handler.ts @@ -0,0 +1,20 @@ +import { CoreRoutes } from '../../../types/api' +import { RequestAdapter } from './adapter' + +export type Handler = (route: string, args: any) => any + +export class RequestHandler { + handler: Handler + adapter: RequestAdapter + + constructor(handler: Handler, observer?: Function) { + this.handler = handler + this.adapter = new RequestAdapter(observer) + } + + handle() { + CoreRoutes.map((route) => { + 
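// A sketch of registering a custom InferenceTool with the ToolManager above. The tool
// name is hypothetical; note that ToolManager.process() looks tools up by the assistant
// tool's `type`, so that value must match the tool's `name` for the tool to run.
class PassthroughTool extends InferenceTool {
  name = 'passthrough'
  async process(request: MessageRequest, _tool?: AssistantTool): Promise<MessageRequest> {
    // A real tool (retrieval, for example) would enrich the request here.
    return request
  }
}

ToolManager.instance().register(new PassthroughTool())
// Later, before inference:
// const processed = await ToolManager.instance().process(request, assistant.tools ?? [])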
this.handler(route, async (...args: any[]) => this.adapter.process(route, ...args)) + }) + } +} diff --git a/core/src/node/api/index.ts b/core/src/node/api/index.ts new file mode 100644 index 0000000000..ab0c516569 --- /dev/null +++ b/core/src/node/api/index.ts @@ -0,0 +1,3 @@ +export * from './HttpServer' +export * from './restful/v1' +export * from './common/handler' diff --git a/core/src/node/api/processors/Processor.ts b/core/src/node/api/processors/Processor.ts new file mode 100644 index 0000000000..8ef0c6e191 --- /dev/null +++ b/core/src/node/api/processors/Processor.ts @@ -0,0 +1,3 @@ +export abstract class Processor { + abstract process(key: string, ...args: any[]): any +} diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts new file mode 100644 index 0000000000..c98060da49 --- /dev/null +++ b/core/src/node/api/processors/app.ts @@ -0,0 +1,93 @@ +import { basename, isAbsolute, join, relative } from 'path' + +import { Processor } from './Processor' +import { + log as writeLog, + appResourcePath, + getAppConfigurations as appConfiguration, + updateAppConfiguration, +} from '../../helper' + +export class App implements Processor { + observer?: Function + + constructor(observer?: Function) { + this.observer = observer + } + + process(key: string, ...args: any[]): any { + const instance = this as any + const func = instance[key] + return func(...args) + } + + /** + * Joins multiple paths together, respect to the current OS. + */ + joinPath(args: any[]) { + return join(...args) + } + + /** + * Checks if the given path is a subdirectory of the given directory. + * + * @param _event - The IPC event object. + * @param from - The path to check. + * @param to - The directory to check against. + * + * @returns {Promise} - A promise that resolves with the result. + */ + isSubdirectory(from: any, to: any) { + const rel = relative(from, to) + const isSubdir = rel && !rel.startsWith('..') && !isAbsolute(rel) + + if (isSubdir === '') return false + else return isSubdir + } + + /** + * Retrieve basename from given path, respect to the current OS. + */ + baseName(args: any) { + return basename(args) + } + + /** + * Log message to log file. + */ + log(args: any) { + writeLog(args) + } + + getAppConfigurations() { + return appConfiguration() + } + + async updateAppConfiguration(args: any) { + await updateAppConfiguration(args) + } + + /** + * Start Jan API Server. + */ + async startServer(args?: any) { + const { startServer } = require('@janhq/server') + return startServer({ + host: args?.host, + port: args?.port, + isCorsEnabled: args?.isCorsEnabled, + isVerboseEnabled: args?.isVerboseEnabled, + schemaPath: join(await appResourcePath(), 'docs', 'openapi', 'jan.yaml'), + baseDir: join(await appResourcePath(), 'docs', 'openapi'), + prefix: args?.prefix, + }) + } + + /** + * Stop Jan API Server. 
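// A sketch of wiring RequestHandler (common/handler.ts above) to a transport. The
// Electron ipcMain variant below is an illustration, not code from this changeset,
// and the import path is assumed.
import { ipcMain } from 'electron'
import { Handler, RequestHandler } from '@janhq/core/node' // assumed import path

const ipcWrapper: Handler = (route, listener: any) =>
  ipcMain.handle(route, (_event, ...args: any[]) => listener(...args))

new RequestHandler(ipcWrapper).handle() // registers a handler for every route in CoreRoutes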
+ */ + stopServer() { + const { stopServer } = require('@janhq/server') + return stopServer() + } +} diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts new file mode 100644 index 0000000000..07486bdf88 --- /dev/null +++ b/core/src/node/api/processors/download.ts @@ -0,0 +1,161 @@ +import { resolve, sep } from 'path' +import { DownloadEvent } from '../../../types/api' +import { normalizeFilePath, validatePath } from '../../helper/path' +import { getJanDataFolderPath } from '../../helper' +import { DownloadManager } from '../../helper/download' +import { createWriteStream, renameSync } from 'fs' +import { Processor } from './Processor' +import { DownloadRequest, DownloadState, NetworkConfig } from '../../../types' + +export class Downloader implements Processor { + observer?: Function + + constructor(observer?: Function) { + this.observer = observer + } + + process(key: string, ...args: any[]): any { + const instance = this as any + const func = instance[key] + return func(this.observer, ...args) + } + + downloadFile(observer: any, downloadRequest: DownloadRequest, network?: NetworkConfig) { + const request = require('request') + const progress = require('request-progress') + + const strictSSL = !network?.ignoreSSL + const proxy = network?.proxy?.startsWith('http') ? network.proxy : undefined + + const { localPath, url } = downloadRequest + let normalizedPath = localPath + if (typeof localPath === 'string') { + normalizedPath = normalizeFilePath(localPath) + } + const array = normalizedPath.split(sep) + const fileName = array.pop() ?? '' + const modelId = array.pop() ?? '' + + const destination = resolve(getJanDataFolderPath(), normalizedPath) + validatePath(destination) + const rq = request({ url, strictSSL, proxy }) + + // Put request to download manager instance + DownloadManager.instance.setRequest(normalizedPath, rq) + + // Downloading file to a temp file first + const downloadingTempFile = `${destination}.download` + + // adding initial download state + const initialDownloadState: DownloadState = { + modelId, + fileName, + time: { + elapsed: 0, + remaining: 0, + }, + speed: 0, + percent: 0, + size: { + total: 0, + transferred: 0, + }, + children: [], + downloadState: 'downloading', + extensionId: downloadRequest.extensionId, + downloadType: downloadRequest.downloadType, + localPath: normalizedPath, + } + DownloadManager.instance.downloadProgressMap[modelId] = initialDownloadState + DownloadManager.instance.downloadInfo[normalizedPath] = initialDownloadState + + if (downloadRequest.downloadType === 'extension') { + observer?.(DownloadEvent.onFileDownloadUpdate, initialDownloadState) + } + + progress(rq, {}) + .on('progress', (state: any) => { + const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId] + const downloadState: DownloadState = { + ...currentDownloadState, + ...state, + fileName: fileName, + downloadState: 'downloading', + } + console.debug('progress: ', downloadState) + observer?.(DownloadEvent.onFileDownloadUpdate, downloadState) + DownloadManager.instance.downloadProgressMap[modelId] = downloadState + }) + .on('error', (error: Error) => { + const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId] + const downloadState: DownloadState = { + ...currentDownloadState, + fileName: fileName, + error: error.message, + downloadState: 'error', + } + + observer?.(DownloadEvent.onFileDownloadError, downloadState) + DownloadManager.instance.downloadProgressMap[modelId] = downloadState + }) 
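// Note: the stream is piped into `<destination>.download` first; only the 'end'
// handler below renames that temp file to the final destination, and it does so
// only while the request is still tracked in networkRequests, i.e. the download
// was not aborted in the meantime.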
+ .on('end', () => { + const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId] + if (currentDownloadState && DownloadManager.instance.networkRequests[normalizedPath]) { + // Finished downloading, rename temp file to actual file + renameSync(downloadingTempFile, destination) + const downloadState: DownloadState = { + ...currentDownloadState, + fileName: fileName, + downloadState: 'end', + } + observer?.(DownloadEvent.onFileDownloadSuccess, downloadState) + DownloadManager.instance.downloadProgressMap[modelId] = downloadState + } + }) + .pipe(createWriteStream(downloadingTempFile)) + } + + abortDownload(observer: any, fileName: string) { + const rq = DownloadManager.instance.networkRequests[fileName] + if (rq) { + DownloadManager.instance.networkRequests[fileName] = undefined + rq?.abort() + } + + const downloadInfo = DownloadManager.instance.downloadInfo[fileName] + observer?.(DownloadEvent.onFileDownloadError, { + ...downloadInfo, + fileName, + error: 'aborted', + }) + } + + resumeDownload(_observer: any, fileName: any) { + DownloadManager.instance.networkRequests[fileName]?.resume() + } + + pauseDownload(_observer: any, fileName: any) { + DownloadManager.instance.networkRequests[fileName]?.pause() + } + + async getFileSize(_observer: any, url: string): Promise { + return new Promise((resolve, reject) => { + const request = require('request') + request( + { + url, + method: 'HEAD', + }, + function (err: any, response: any) { + if (err) { + console.error('Getting file size failed:', err) + reject(err) + } else { + const size: number = response.headers['content-length'] ?? -1 + resolve(size) + } + } + ) + }) + } +} diff --git a/core/src/node/api/processors/extension.ts b/core/src/node/api/processors/extension.ts new file mode 100644 index 0000000000..df5d2d945c --- /dev/null +++ b/core/src/node/api/processors/extension.ts @@ -0,0 +1,88 @@ +import { readdirSync } from 'fs' +import { join, extname } from 'path' + +import { Processor } from './Processor' +import { ModuleManager } from '../../helper/module' +import { getJanExtensionsPath as getPath } from '../../helper' +import { + getActiveExtensions as getExtensions, + getExtension, + removeExtension, + installExtensions, +} from '../../extension/store' +import { appResourcePath } from '../../helper/path' + +export class Extension implements Processor { + observer?: Function + + constructor(observer?: Function) { + this.observer = observer + } + + process(key: string, ...args: any[]): any { + const instance = this as any + const func = instance[key] + return func(...args) + } + + invokeExtensionFunc(modulePath: string, method: string, ...params: any[]) { + const module = require(join(getPath(), modulePath)) + ModuleManager.instance.setModule(modulePath, module) + + if (typeof module[method] === 'function') { + return module[method](...params) + } else { + console.debug(module[method]) + console.error(`Function "${method}" does not exist in the module.`) + } + } + + /** + * Returns the paths of the base extensions. + * @returns An array of paths to the base extensions. 
+ */ + async baseExtensions() { + const baseExtensionPath = join(await appResourcePath(), 'pre-install') + return readdirSync(baseExtensionPath) + .filter((file) => extname(file) === '.tgz') + .map((file) => join(baseExtensionPath, file)) + } + + /**MARK: Extension Manager handlers */ + async installExtension(extensions: any) { + // Install and activate all provided extensions + const installed = await installExtensions(extensions) + return JSON.parse(JSON.stringify(installed)) + } + + // Register IPC route to uninstall a extension + async uninstallExtension(extensions: any) { + // Uninstall all provided extensions + for (const ext of extensions) { + const extension = getExtension(ext) + await extension.uninstall() + if (extension.name) removeExtension(extension.name) + } + + // Reload all renderer pages if needed + return true + } + + // Register IPC route to update a extension + async updateExtension(extensions: any) { + // Update all provided extensions + const updated: any[] = [] + for (const ext of extensions) { + const extension = getExtension(ext) + const res = await extension.update() + if (res) updated.push(extension) + } + + // Reload all renderer pages if needed + return JSON.parse(JSON.stringify(updated)) + } + + getActiveExtensions() { + return JSON.parse(JSON.stringify(getExtensions())) + } +} diff --git a/core/src/node/api/processors/fs.ts b/core/src/node/api/processors/fs.ts new file mode 100644 index 0000000000..0557d21875 --- /dev/null +++ b/core/src/node/api/processors/fs.ts @@ -0,0 +1,95 @@ +import { join, resolve } from 'path' +import { normalizeFilePath, validatePath } from '../../helper/path' +import { getJanDataFolderPath } from '../../helper' +import { Processor } from './Processor' +import fs from 'fs' + +export class FileSystem implements Processor { + observer?: Function + private static moduleName = 'fs' + + constructor(observer?: Function) { + this.observer = observer + } + + process(route: string, ...args: any): any { + const instance = this as any + const func = instance[route] + if (func) { + return func(...args) + } else { + return import(FileSystem.moduleName).then((mdl) => + mdl[route]( + ...args.map((arg: any, index: number) => { + if(index !== 0) { + return arg + } + if (index === 0 && typeof arg !== 'string') { + throw new Error(`Invalid argument ${JSON.stringify(args)}`) + } + const path = + (arg.startsWith(`file:/`) || arg.startsWith(`file:\\`)) + ? 
join(getJanDataFolderPath(), normalizeFilePath(arg)) + : arg + + if(path.startsWith(`http://`) || path.startsWith(`https://`)) { + return path + } + const absolutePath = resolve(path) + validatePath(absolutePath) + return absolutePath + }) + ) + ) + } + } + + rm(...args: any): Promise { + if (typeof args[0] !== 'string') { + throw new Error(`rm error: Invalid argument ${JSON.stringify(args)}`) + } + + let path = args[0] + if (path.startsWith(`file:/`) || path.startsWith(`file:\\`)) { + path = join(getJanDataFolderPath(), normalizeFilePath(path)) + } + + const absolutePath = resolve(path) + validatePath(absolutePath) + + return new Promise((resolve, reject) => { + fs.rm(absolutePath, { recursive: true, force: true }, (err) => { + if (err) { + reject(err) + } else { + resolve() + } + }) + }) + } + + mkdir(...args: any): Promise { + if (typeof args[0] !== 'string') { + throw new Error(`mkdir error: Invalid argument ${JSON.stringify(args)}`) + } + + let path = args[0] + if (path.startsWith(`file:/`) || path.startsWith(`file:\\`)) { + path = join(getJanDataFolderPath(), normalizeFilePath(path)) + } + + const absolutePath = resolve(path) + validatePath(absolutePath) + + return new Promise((resolve, reject) => { + fs.mkdir(absolutePath, { recursive: true }, (err) => { + if (err) { + reject(err) + } else { + resolve() + } + }) + }) + } + +} diff --git a/core/src/node/api/processors/fsExt.ts b/core/src/node/api/processors/fsExt.ts new file mode 100644 index 0000000000..155732cfce --- /dev/null +++ b/core/src/node/api/processors/fsExt.ts @@ -0,0 +1,82 @@ +import { join } from 'path' +import fs from 'fs' +import { appResourcePath, normalizeFilePath, validatePath } from '../../helper/path' +import { getJanDataFolderPath, getJanDataFolderPath as getPath } from '../../helper' +import { Processor } from './Processor' +import { FileStat } from '../../../types' + +export class FSExt implements Processor { + observer?: Function + + constructor(observer?: Function) { + this.observer = observer + } + + process(key: string, ...args: any): any { + const instance = this as any + const func = instance[key] + return func(...args) + } + + // Handles the 'getJanDataFolderPath' IPC event. This event is triggered to get the user space path. + getJanDataFolderPath() { + return Promise.resolve(getPath()) + } + + // Handles the 'getResourcePath' IPC event. This event is triggered to get the resource path. + getResourcePath() { + return appResourcePath() + } + + // Handles the 'getUserHomePath' IPC event. This event is triggered to get the user home path. + getUserHomePath() { + return process.env[process.platform == 'win32' ? 'USERPROFILE' : 'HOME'] + } + + // handle fs is directory here + fileStat(path: string, outsideJanDataFolder?: boolean) { + const normalizedPath = normalizeFilePath(path) + + const fullPath = outsideJanDataFolder + ? 
normalizedPath + : join(getJanDataFolderPath(), normalizedPath) + const isExist = fs.existsSync(fullPath) + if (!isExist) return undefined + + const isDirectory = fs.lstatSync(fullPath).isDirectory() + const size = fs.statSync(fullPath).size + + const fileStat: FileStat = { + isDirectory, + size, + } + + return fileStat + } + + writeBlob(path: string, data: any) { + try { + const normalizedPath = normalizeFilePath(path) + + const dataBuffer = Buffer.from(data, 'base64') + const writePath = join(getJanDataFolderPath(), normalizedPath) + validatePath(writePath) + fs.writeFileSync(writePath, dataBuffer) + } catch (err) { + console.error(`writeFile ${path} result: ${err}`) + } + } + + copyFile(src: string, dest: string): Promise { + validatePath(dest) + return new Promise((resolve, reject) => { + fs.copyFile(src, dest, (err) => { + if (err) { + reject(err) + } else { + resolve() + } + }) + }) + } +} diff --git a/core/src/node/api/restful/app/download.ts b/core/src/node/api/restful/app/download.ts new file mode 100644 index 0000000000..5e0c83d01a --- /dev/null +++ b/core/src/node/api/restful/app/download.ts @@ -0,0 +1,23 @@ +import { DownloadRoute } from '../../../../types/api' +import { DownloadManager } from '../../../helper/download' +import { HttpServer } from '../../HttpServer' + +export const downloadRouter = async (app: HttpServer) => { + app.get(`/download/${DownloadRoute.getDownloadProgress}/:modelId`, async (req, res) => { + const modelId = req.params.modelId + + console.debug(`Getting download progress for model ${modelId}`) + console.debug( + `All Download progress: ${JSON.stringify(DownloadManager.instance.downloadProgressMap)}` + ) + + // check if null DownloadManager.instance.downloadProgressMap + if (!DownloadManager.instance.downloadProgressMap[modelId]) { + return res.status(404).send({ + message: 'Download progress not found', + }) + } else { + return res.status(200).send(DownloadManager.instance.downloadProgressMap[modelId]) + } + }) +} diff --git a/core/src/node/api/restful/app/handlers.ts b/core/src/node/api/restful/app/handlers.ts new file mode 100644 index 0000000000..43c3f7add9 --- /dev/null +++ b/core/src/node/api/restful/app/handlers.ts @@ -0,0 +1,13 @@ +import { HttpServer } from '../../HttpServer' +import { Handler, RequestHandler } from '../../common/handler' + +export function handleRequests(app: HttpServer) { + const restWrapper: Handler = (route: string, listener: (...args: any[]) => any) => { + app.post(`/app/${route}`, async (request: any, reply: any) => { + const args = JSON.parse(request.body) as any[] + reply.send(JSON.stringify(await listener(...args))) + }) + } + const handler = new RequestHandler(restWrapper) + handler.handle() +} diff --git a/core/src/node/api/restful/common.ts b/core/src/node/api/restful/common.ts new file mode 100644 index 0000000000..c8061c34a2 --- /dev/null +++ b/core/src/node/api/restful/common.ts @@ -0,0 +1,82 @@ +import { HttpServer } from '../HttpServer' +import { + chatCompletions, + deleteBuilder, + downloadModel, + getBuilder, + retrieveBuilder, + createMessage, + createThread, + getMessages, + retrieveMessage, + updateThread, +} from './helper/builder' + +import { JanApiRouteConfiguration } from './helper/configuration' +import { startModel, stopModel } from './helper/startStopModel' +import { ModelSettingParams } from '../../../types' + +export const commonRouter = async (app: HttpServer) => { + const normalizeData = (data: any) => { + return { + object: 'list', + data, + } + } + // Common Routes + // Read & Delete :: 
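// A sketch of calling the internal /app/<route> wrapper registered by handleRequests()
// above. The base URL and the joinPath route name are assumptions; arguments are sent
// as a JSON-encoded array, mirroring how restWrapper parses request.body.
const res = await fetch('http://localhost:1337/app/joinPath', {
  method: 'POST',
  body: JSON.stringify([['models', 'example-model']]),
})
const joined = JSON.parse(await res.text())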
Threads | Models | Assistants + Object.keys(JanApiRouteConfiguration).forEach((key) => { + app.get(`/${key}`, async (_request) => + getBuilder(JanApiRouteConfiguration[key]).then(normalizeData) + ) + + app.get(`/${key}/:id`, async (request: any) => + retrieveBuilder(JanApiRouteConfiguration[key], request.params.id) + ) + + app.delete(`/${key}/:id`, async (request: any) => + deleteBuilder(JanApiRouteConfiguration[key], request.params.id) + ) + }) + + // Threads + app.post(`/threads`, async (req, res) => createThread(req.body)) + + app.get(`/threads/:threadId/messages`, async (req, res) => + getMessages(req.params.threadId).then(normalizeData) + ) + + app.get(`/threads/:threadId/messages/:messageId`, async (req, res) => + retrieveMessage(req.params.threadId, req.params.messageId) + ) + + app.post(`/threads/:threadId/messages`, async (req, res) => + createMessage(req.params.threadId as any, req.body as any) + ) + + app.patch(`/threads/:threadId`, async (request: any) => + updateThread(request.params.threadId, request.body) + ) + + // Models + app.get(`/models/download/:modelId`, async (request: any) => + downloadModel(request.params.modelId, { + ignoreSSL: request.query.ignoreSSL === 'true', + proxy: request.query.proxy, + }) + ) + + app.put(`/models/:modelId/start`, async (request: any) => { + let settingParams: ModelSettingParams | undefined = undefined + if (Object.keys(request.body).length !== 0) { + settingParams = JSON.parse(request.body) as ModelSettingParams + } + + return startModel(request.params.modelId, settingParams) + }) + + app.put(`/models/:modelId/stop`, async (request: any) => stopModel(request.params.modelId)) + + // Chat Completion + app.post(`/chat/completions`, async (request: any, reply: any) => chatCompletions(request, reply)) +} diff --git a/core/src/node/api/restful/helper/builder.ts b/core/src/node/api/restful/helper/builder.ts new file mode 100644 index 0000000000..cd121cdb7e --- /dev/null +++ b/core/src/node/api/restful/helper/builder.ts @@ -0,0 +1,362 @@ +import { + existsSync, + readdirSync, + readFileSync, + writeFileSync, + mkdirSync, + appendFileSync, + createWriteStream, + rmdirSync, +} from 'fs' +import { JanApiRouteConfiguration, RouteConfiguration } from './configuration' +import { join } from 'path' +import { ContentType, MessageStatus, Model, ThreadMessage } from '../../../../types' +import { getEngineConfiguration, getJanDataFolderPath } from '../../../helper' +import { DEFAULT_CHAT_COMPLETION_URL } from './consts' + +// TODO: Refactor these +export const getBuilder = async (configuration: RouteConfiguration) => { + const directoryPath = join(getJanDataFolderPath(), configuration.dirName) + try { + if (!existsSync(directoryPath)) { + console.debug('model folder not found') + return [] + } + + const files: string[] = readdirSync(directoryPath) + + const allDirectories: string[] = [] + for (const file of files) { + if (file === '.DS_Store') continue + allDirectories.push(file) + } + + const results = allDirectories + .map((dirName) => { + const jsonPath = join(directoryPath, dirName, configuration.metadataFileName) + return readModelMetadata(jsonPath) + }) + .filter((data) => !!data) + const modelData = results + .map((result: any) => { + try { + return JSON.parse(result) + } catch (err) { + console.error(err) + } + }) + .filter((e: any) => !!e) + + return modelData + } catch (err) { + console.error(err) + return [] + } +} + +const readModelMetadata = (path: string): string | undefined => { + if (existsSync(path)) { + return readFileSync(path, 'utf-8') + 
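// A sketch of the public routes registered by commonRouter above, assuming the Jan API
// server listens on localhost:1337 and a default 'jan' assistant exists.
const base = 'http://localhost:1337'
const jsonHeaders = { 'Content-Type': 'application/json' }

// Create a thread, then add a message to it
const thread = await (
  await fetch(`${base}/threads`, {
    method: 'POST',
    headers: jsonHeaders,
    body: JSON.stringify({ assistants: [{ assistant_id: 'jan' }] }),
  })
).json()

await fetch(`${base}/threads/${thread.id}/messages`, {
  method: 'POST',
  headers: jsonHeaders,
  body: JSON.stringify({ role: 'user', content: 'Hello' }),
})

// OpenAI-compatible chat completion
await fetch(`${base}/chat/completions`, {
  method: 'POST',
  headers: jsonHeaders,
  body: JSON.stringify({ model: 'example-model', messages: [{ role: 'user', content: 'Hello' }] }),
})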
} else { + return undefined + } +} + +export const retrieveBuilder = async (configuration: RouteConfiguration, id: string) => { + const data = await getBuilder(configuration) + const filteredData = data.filter((d: any) => d.id === id)[0] + + if (!filteredData) { + return undefined + } + + return filteredData +} + +export const deleteBuilder = async (configuration: RouteConfiguration, id: string) => { + if (configuration.dirName === 'assistants' && id === 'jan') { + return { + message: 'Cannot delete Jan assistant', + } + } + + const directoryPath = join(getJanDataFolderPath(), configuration.dirName) + try { + const data = await retrieveBuilder(configuration, id) + if (!data) { + return { + message: 'Not found', + } + } + + const objectPath = join(directoryPath, id) + rmdirSync(objectPath, { recursive: true }) + return { + id: id, + object: configuration.delete.object, + deleted: true, + } + } catch (ex) { + console.error(ex) + } +} + +export const getMessages = async (threadId: string): Promise => { + const threadDirPath = join(getJanDataFolderPath(), 'threads', threadId) + const messageFile = 'messages.jsonl' + try { + const files: string[] = readdirSync(threadDirPath) + if (!files.includes(messageFile)) { + console.error(`${threadDirPath} not contains message file`) + return [] + } + + const messageFilePath = join(threadDirPath, messageFile) + if (!existsSync(messageFilePath)) { + console.debug('message file not found') + return [] + } + + const lines = readFileSync(messageFilePath, 'utf-8') + .toString() + .split('\n') + .filter((line: any) => line !== '') + + const messages: ThreadMessage[] = [] + lines.forEach((line: string) => { + messages.push(JSON.parse(line) as ThreadMessage) + }) + return messages + } catch (err) { + console.error(err) + return [] + } +} + +export const retrieveMessage = async (threadId: string, messageId: string) => { + const messages = await getMessages(threadId) + const filteredMessages = messages.filter((m) => m.id === messageId) + if (!filteredMessages || filteredMessages.length === 0) { + return { + message: 'Not found', + } + } + + return filteredMessages[0] +} + +export const createThread = async (thread: any) => { + const threadMetadataFileName = 'thread.json' + // TODO: add validation + if (!thread.assistants || thread.assistants.length === 0) { + return { + message: 'Thread must have at least one assistant', + } + } + + const threadId = generateThreadId(thread.assistants[0].assistant_id) + try { + const updatedThread = { + ...thread, + id: threadId, + created: Date.now(), + updated: Date.now(), + } + const threadDirPath = join(getJanDataFolderPath(), 'threads', updatedThread.id) + const threadJsonPath = join(threadDirPath, threadMetadataFileName) + + if (!existsSync(threadDirPath)) { + mkdirSync(threadDirPath) + } + + await writeFileSync(threadJsonPath, JSON.stringify(updatedThread, null, 2)) + return updatedThread + } catch (err) { + return { + error: err, + } + } +} + +export const updateThread = async (threadId: string, thread: any) => { + const threadMetadataFileName = 'thread.json' + const currentThreadData = await retrieveBuilder(JanApiRouteConfiguration.threads, threadId) + if (!currentThreadData) { + return { + message: 'Thread not found', + } + } + // we don't want to update the id and object + delete thread.id + delete thread.object + + const updatedThread = { + ...currentThreadData, + ...thread, + updated: Date.now(), + } + try { + const threadDirPath = join(getJanDataFolderPath(), 'threads', updatedThread.id) + const threadJsonPath = 
join(threadDirPath, threadMetadataFileName) + + await writeFileSync(threadJsonPath, JSON.stringify(updatedThread, null, 2)) + return updatedThread + } catch (err) { + return { + message: err, + } + } +} + +const generateThreadId = (assistantId: string) => { + return `${assistantId}_${(Date.now() / 1000).toFixed(0)}` +} + +export const createMessage = async (threadId: string, message: any) => { + const threadMessagesFileName = 'messages.jsonl' + + try { + const { ulid } = require('ulidx') + const msgId = ulid() + const createdAt = Date.now() + const threadMessage: ThreadMessage = { + id: msgId, + thread_id: threadId, + status: MessageStatus.Ready, + created: createdAt, + updated: createdAt, + object: 'thread.message', + role: message.role, + content: [ + { + type: ContentType.Text, + text: { + value: message.content, + annotations: [], + }, + }, + ], + } + + const threadDirPath = join(getJanDataFolderPath(), 'threads', threadId) + const threadMessagePath = join(threadDirPath, threadMessagesFileName) + + if (!existsSync(threadDirPath)) { + mkdirSync(threadDirPath) + } + appendFileSync(threadMessagePath, JSON.stringify(threadMessage) + '\n') + return threadMessage + } catch (err) { + return { + message: err, + } + } +} + +export const downloadModel = async ( + modelId: string, + network?: { proxy?: string; ignoreSSL?: boolean } +) => { + const strictSSL = !network?.ignoreSSL + const proxy = network?.proxy?.startsWith('http') ? network.proxy : undefined + const model = await retrieveBuilder(JanApiRouteConfiguration.models, modelId) + if (!model || model.object !== 'model') { + return { + message: 'Model not found', + } + } + + const directoryPath = join(getJanDataFolderPath(), 'models', modelId) + if (!existsSync(directoryPath)) { + mkdirSync(directoryPath) + } + + // path to model binary + const modelBinaryPath = join(directoryPath, modelId) + + const request = require('request') + const progress = require('request-progress') + + for (const source of model.sources) { + const rq = request({ url: source, strictSSL, proxy }) + progress(rq, {}) + .on('progress', function (state: any) { + console.debug('progress', JSON.stringify(state, null, 2)) + }) + .on('error', function (err: Error) { + console.error('error', err) + }) + .on('end', function () { + console.debug('end') + }) + .pipe(createWriteStream(modelBinaryPath)) + } + + return { + message: `Starting download ${modelId}`, + } +} + +export const chatCompletions = async (request: any, reply: any) => { + const modelList = await getBuilder(JanApiRouteConfiguration.models) + const modelId = request.body.model + + const matchedModels = modelList.filter((model: Model) => model.id === modelId) + if (matchedModels.length === 0) { + const error = { + error: { + message: `The model ${request.body.model} does not exist`, + type: 'invalid_request_error', + param: null, + code: 'model_not_found', + }, + } + reply.code(404).send(error) + return + } + + const requestedModel = matchedModels[0] + + const engineConfiguration = await getEngineConfiguration(requestedModel.engine) + + let apiKey: string | undefined = undefined + let apiUrl: string = DEFAULT_CHAT_COMPLETION_URL + + if (engineConfiguration) { + apiKey = engineConfiguration.api_key + apiUrl = engineConfiguration.full_url ?? 
DEFAULT_CHAT_COMPLETION_URL + } + + const headers: Record = { + 'Content-Type': 'application/json', + } + + if (apiKey) { + headers['Authorization'] = `Bearer ${apiKey}` + headers['api-key'] = apiKey + } + + if (requestedModel.engine === 'openai' && request.body.stop) { + // openai only allows max 4 stop words + request.body.stop = request.body.stop.slice(0, 4) + } + + const fetch = require('node-fetch') + const response = await fetch(apiUrl, { + method: 'POST', + headers: headers, + body: JSON.stringify(request.body), + }) + if (response.status !== 200) { + console.error(response) + reply.code(400).send(response) + } else { + reply.raw.writeHead(200, { + 'Content-Type': request.body.stream === true ? 'text/event-stream' : 'application/json', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'Access-Control-Allow-Origin': '*', + }) + response.body.pipe(reply.raw) + } +} diff --git a/core/src/node/api/restful/helper/configuration.ts b/core/src/node/api/restful/helper/configuration.ts new file mode 100644 index 0000000000..88e5ffb61e --- /dev/null +++ b/core/src/node/api/restful/helper/configuration.ts @@ -0,0 +1,31 @@ +export const JanApiRouteConfiguration: Record = { + models: { + dirName: 'models', + metadataFileName: 'model.json', + delete: { + object: 'model', + }, + }, + assistants: { + dirName: 'assistants', + metadataFileName: 'assistant.json', + delete: { + object: 'assistant', + }, + }, + threads: { + dirName: 'threads', + metadataFileName: 'thread.json', + delete: { + object: 'thread', + }, + }, +} + +export type RouteConfiguration = { + dirName: string + metadataFileName: string + delete: { + object: string + } +} diff --git a/core/src/node/api/restful/helper/consts.ts b/core/src/node/api/restful/helper/consts.ts new file mode 100644 index 0000000000..8d8f8e3410 --- /dev/null +++ b/core/src/node/api/restful/helper/consts.ts @@ -0,0 +1,19 @@ +// The PORT to use for the Nitro subprocess +export const NITRO_DEFAULT_PORT = 3928 + +// The HOST address to use for the Nitro subprocess +export const LOCAL_HOST = '127.0.0.1' + +export const SUPPORTED_MODEL_FORMAT = '.gguf' + +// The URL for the Nitro subprocess +const NITRO_HTTP_SERVER_URL = `http://${LOCAL_HOST}:${NITRO_DEFAULT_PORT}` +// The URL for the Nitro subprocess to load a model +export const NITRO_HTTP_LOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/server/loadmodel` +// The URL for the Nitro subprocess to validate a model +export const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/server/modelstatus` + +// The URL for the Nitro subprocess to kill itself +export const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy` + +export const DEFAULT_CHAT_COMPLETION_URL = `http://${LOCAL_HOST}:${NITRO_DEFAULT_PORT}/inferences/server/chat_completion` // default nitro url diff --git a/core/src/node/api/restful/helper/startStopModel.ts b/core/src/node/api/restful/helper/startStopModel.ts new file mode 100644 index 0000000000..8665850da8 --- /dev/null +++ b/core/src/node/api/restful/helper/startStopModel.ts @@ -0,0 +1,355 @@ +import fs from 'fs' +import { join } from 'path' +import { + getJanDataFolderPath, + getJanExtensionsPath, + getSystemResourceInfo, + log, +} from '../../../helper' +import { ChildProcessWithoutNullStreams, spawn } from 'child_process' +import { Model, ModelSettingParams, PromptTemplate } from '../../../../types' +import { + LOCAL_HOST, + NITRO_DEFAULT_PORT, + NITRO_HTTP_KILL_URL, + NITRO_HTTP_LOAD_MODEL_URL, + NITRO_HTTP_VALIDATE_MODEL_URL, + 
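// For reference, with the defaults in consts.ts above these constants resolve to:
//   NITRO_HTTP_LOAD_MODEL_URL     -> http://127.0.0.1:3928/inferences/server/loadmodel
//   NITRO_HTTP_VALIDATE_MODEL_URL -> http://127.0.0.1:3928/inferences/server/modelstatus
//   NITRO_HTTP_KILL_URL           -> http://127.0.0.1:3928/processmanager/destroy
//   DEFAULT_CHAT_COMPLETION_URL   -> http://127.0.0.1:3928/inferences/server/chat_completion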
SUPPORTED_MODEL_FORMAT, +} from './consts' + +// The subprocess instance for Nitro +let subprocess: ChildProcessWithoutNullStreams | undefined = undefined + +// TODO: move this to core type +interface NitroModelSettings extends ModelSettingParams { + llama_model_path: string + cpu_threads: number +} + +export const startModel = async (modelId: string, settingParams?: ModelSettingParams) => { + try { + await runModel(modelId, settingParams) + + return { + message: `Model ${modelId} started`, + } + } catch (e) { + return { + error: e, + } + } +} + +const runModel = async (modelId: string, settingParams?: ModelSettingParams): Promise => { + const janDataFolderPath = getJanDataFolderPath() + const modelFolderFullPath = join(janDataFolderPath, 'models', modelId) + + if (!fs.existsSync(modelFolderFullPath)) { + throw new Error(`Model not found: ${modelId}`) + } + + const files: string[] = fs.readdirSync(modelFolderFullPath) + + // Look for GGUF model file + const ggufBinFile = files.find((file) => file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT)) + + const modelMetadataPath = join(modelFolderFullPath, 'model.json') + const modelMetadata: Model = JSON.parse(fs.readFileSync(modelMetadataPath, 'utf-8')) + + if (!ggufBinFile) { + throw new Error('No GGUF model file found') + } + const modelBinaryPath = join(modelFolderFullPath, ggufBinFile) + + const nitroResourceProbe = await getSystemResourceInfo() + const nitroModelSettings: NitroModelSettings = { + // This is critical and requires real CPU physical core count (or performance core) + cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore), + ...modelMetadata.settings, + ...settingParams, + llama_model_path: modelBinaryPath, + ...(modelMetadata.settings.mmproj && { + mmproj: join(modelFolderFullPath, modelMetadata.settings.mmproj), + }), + } + + log(`[SERVER]::Debug: Nitro model settings: ${JSON.stringify(nitroModelSettings)}`) + + // Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt + if (modelMetadata.settings.prompt_template) { + const promptTemplate = modelMetadata.settings.prompt_template + const prompt = promptTemplateConverter(promptTemplate) + if (prompt?.error) { + throw new Error(prompt.error) + } + nitroModelSettings.system_prompt = prompt.system_prompt + nitroModelSettings.user_prompt = prompt.user_prompt + nitroModelSettings.ai_prompt = prompt.ai_prompt + } + + await runNitroAndLoadModel(modelId, nitroModelSettings) +} + +// TODO: move to util +const promptTemplateConverter = (promptTemplate: string): PromptTemplate => { + // Split the string using the markers + const systemMarker = '{system_message}' + const promptMarker = '{prompt}' + + if (promptTemplate.includes(systemMarker) && promptTemplate.includes(promptMarker)) { + // Find the indices of the markers + const systemIndex = promptTemplate.indexOf(systemMarker) + const promptIndex = promptTemplate.indexOf(promptMarker) + + // Extract the parts of the string + const system_prompt = promptTemplate.substring(0, systemIndex) + const user_prompt = promptTemplate.substring(systemIndex + systemMarker.length, promptIndex) + const ai_prompt = promptTemplate.substring(promptIndex + promptMarker.length) + + // Return the split parts + return { system_prompt, user_prompt, ai_prompt } + } else if (promptTemplate.includes(promptMarker)) { + // Extract the parts of the string for the case where only promptMarker is present + const promptIndex = promptTemplate.indexOf(promptMarker) + const user_prompt = promptTemplate.substring(0, promptIndex) + const 
ai_prompt = promptTemplate.substring(promptIndex + promptMarker.length) + + // Return the split parts + return { user_prompt, ai_prompt } + } + + // Return an error if none of the conditions are met + return { error: 'Cannot split prompt template' } +} + +const runNitroAndLoadModel = async (modelId: string, modelSettings: NitroModelSettings) => { + // Gather system information for CPU physical cores and memory + const tcpPortUsed = require('tcp-port-used') + + await stopModel(modelId) + await tcpPortUsed.waitUntilFree(NITRO_DEFAULT_PORT, 300, 5000) + + /** + * There is a problem with Windows process manager + * Should wait for awhile to make sure the port is free and subprocess is killed + * The tested threshold is 500ms + **/ + if (process.platform === 'win32') { + await new Promise((resolve) => setTimeout(resolve, 500)) + } + + await spawnNitroProcess() + await loadLLMModel(modelSettings) + await validateModelStatus() +} + +const spawnNitroProcess = async (): Promise => { + log(`[SERVER]::Debug: Spawning cortex subprocess...`) + + let binaryFolder = join( + getJanExtensionsPath(), + '@janhq', + 'inference-cortex-extension', + 'dist', + 'bin' + ) + + let executableOptions = executableNitroFile() + const tcpPortUsed = require('tcp-port-used') + + const args: string[] = ['1', LOCAL_HOST, NITRO_DEFAULT_PORT.toString()] + // Execute the binary + log( + `[SERVER]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + ) + subprocess = spawn( + executableOptions.executablePath, + ['1', LOCAL_HOST, NITRO_DEFAULT_PORT.toString()], + { + cwd: binaryFolder, + env: { + ...process.env, + CUDA_VISIBLE_DEVICES: executableOptions.cudaVisibleDevices, + }, + } + ) + + // Handle subprocess output + subprocess.stdout.on('data', (data: any) => { + log(`[SERVER]::Debug: ${data}`) + }) + + subprocess.stderr.on('data', (data: any) => { + log(`[SERVER]::Error: ${data}`) + }) + + subprocess.on('close', (code: any) => { + log(`[SERVER]::Debug: cortex exited with code: ${code}`) + subprocess = undefined + }) + + tcpPortUsed.waitUntilUsed(NITRO_DEFAULT_PORT, 300, 30000).then(() => { + log(`[SERVER]::Debug: cortex is ready`) + }) +} + +type NitroExecutableOptions = { + executablePath: string + cudaVisibleDevices: string +} + +const executableNitroFile = (): NitroExecutableOptions => { + const nvidiaInfoFilePath = join(getJanDataFolderPath(), 'settings', 'settings.json') + let binaryFolder = join( + getJanExtensionsPath(), + '@janhq', + 'inference-cortex-extension', + 'dist', + 'bin' + ) + + let cudaVisibleDevices = '' + let binaryName = 'cortex-cpp' + /** + * The binary folder is different for each platform. 
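// Worked example for promptTemplateConverter above: given the (hypothetical) template
// 'SYSTEM: {system_message}\nUSER: {prompt}\nASSISTANT:', it returns
//   system_prompt: 'SYSTEM: '
//   user_prompt:   '\nUSER: '
//   ai_prompt:     '\nASSISTANT:'
// which are then passed to Nitro as the system_prompt / user_prompt / ai_prompt settings.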
+ */ + if (process.platform === 'win32') { + /** + * For Windows: win-cpu, win-cuda-11-7, win-cuda-12-0 + */ + let nvidiaInfo = JSON.parse(fs.readFileSync(nvidiaInfoFilePath, 'utf-8')) + if (nvidiaInfo['run_mode'] === 'cpu') { + binaryFolder = join(binaryFolder, 'win-cpu') + } else { + if (nvidiaInfo['cuda'].version === '12') { + binaryFolder = join(binaryFolder, 'win-cuda-12-0') + } else { + binaryFolder = join(binaryFolder, 'win-cuda-11-7') + } + cudaVisibleDevices = nvidiaInfo['gpu_highest_vram'] + } + binaryName = 'cortex-cpp.exe' + } else if (process.platform === 'darwin') { + /** + * For MacOS: mac-universal both Silicon and InteL + */ + if(process.arch === 'arm64') { + binaryFolder = join(binaryFolder, 'mac-arm64') + } else { + binaryFolder = join(binaryFolder, 'mac-amd64') + } + } else { + /** + * For Linux: linux-cpu, linux-cuda-11-7, linux-cuda-12-0 + */ + let nvidiaInfo = JSON.parse(fs.readFileSync(nvidiaInfoFilePath, 'utf-8')) + if (nvidiaInfo['run_mode'] === 'cpu') { + binaryFolder = join(binaryFolder, 'linux-cpu') + } else { + if (nvidiaInfo['cuda'].version === '12') { + binaryFolder = join(binaryFolder, 'linux-cuda-12-0') + } else { + binaryFolder = join(binaryFolder, 'linux-cuda-11-7') + } + cudaVisibleDevices = nvidiaInfo['gpu_highest_vram'] + } + } + + return { + executablePath: join(binaryFolder, binaryName), + cudaVisibleDevices, + } +} + +const validateModelStatus = async (): Promise => { + // Send a GET request to the validation URL. + // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. + const fetchRT = require('fetch-retry') + const fetchRetry = fetchRT(fetch) + + return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + retries: 5, + retryDelay: 500, + }).then(async (res: Response) => { + log(`[SERVER]::Debug: Validate model state success with response ${JSON.stringify(res)}`) + // If the response is OK, check model_loaded status. + if (res.ok) { + const body = await res.json() + // If the model is loaded, return an empty object. + // Otherwise, return an object with an error message. + if (body.model_loaded) { + return Promise.resolve() + } + } + return Promise.reject('Validate model status failed') + }) +} + +const loadLLMModel = async (settings: NitroModelSettings): Promise => { + log(`[SERVER]::Debug: Loading model with params ${JSON.stringify(settings)}`) + const fetchRT = require('fetch-retry') + const fetchRetry = fetchRT(fetch) + + return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(settings), + retries: 3, + retryDelay: 500, + }) + .then((res: any) => { + log(`[SERVER]::Debug: Load model request with response ${JSON.stringify(res)}`) + return Promise.resolve(res) + }) + .catch((err: any) => { + log(`[SERVER]::Error: Load model failed with error ${err}`) + return Promise.reject(err) + }) +} + +/** + * Stop model and kill nitro process. 
+ */ +export const stopModel = async (_modelId: string) => { + if (!subprocess) { + return { + error: "Model isn't running", + } + } + return new Promise((resolve, reject) => { + const controller = new AbortController() + setTimeout(() => { + controller.abort() + reject({ + error: 'Failed to stop model: Timedout', + }) + }, 5000) + const tcpPortUsed = require('tcp-port-used') + log(`[SERVER]::Debug: Request to kill cortex`) + + fetch(NITRO_HTTP_KILL_URL, { + method: 'DELETE', + signal: controller.signal, + }) + .then(() => { + subprocess?.kill() + subprocess = undefined + }) + .catch(() => { + // don't need to do anything, we still kill the subprocess + }) + .then(() => tcpPortUsed.waitUntilFree(NITRO_DEFAULT_PORT, 300, 5000)) + .then(() => log(`[SERVER]::Debug: Nitro process is terminated`)) + .then(() => + resolve({ + message: 'Model stopped', + }) + ) + }) +} diff --git a/core/src/node/api/restful/v1.ts b/core/src/node/api/restful/v1.ts new file mode 100644 index 0000000000..5eb8f50679 --- /dev/null +++ b/core/src/node/api/restful/v1.ts @@ -0,0 +1,16 @@ +import { HttpServer } from '../HttpServer' +import { commonRouter } from './common' +import { downloadRouter } from './app/download' +import { handleRequests } from './app/handlers' + +export const v1Router = async (app: HttpServer) => { + // MARK: Public API Routes + app.register(commonRouter) + + // MARK: Internal Application Routes + handleRequests(app) + + // Expanded route for tracking download progress + // TODO: Replace by Observer Wrapper (ZeroMQ / Vanilla Websocket) + app.register(downloadRouter) +} diff --git a/core/src/node/extension/extension.ts b/core/src/node/extension/extension.ts new file mode 100644 index 0000000000..849f2d5f28 --- /dev/null +++ b/core/src/node/extension/extension.ts @@ -0,0 +1,203 @@ +import { rmdirSync } from 'fs' +import { resolve, join } from 'path' +import { ExtensionManager } from './manager' + +/** + * An NPM package that can be used as an extension. + * Used to hold all the information and functions necessary to handle the extension lifecycle. + */ +export default class Extension { + /** + * @property {string} origin Original specification provided to fetch the package. + * @property {Object} installOptions Options provided to pacote when fetching the manifest. + * @property {name} name The name of the extension as defined in the manifest. + * @property {name} productName The display name of the extension as defined in the manifest. + * @property {string} url Electron URL where the package can be accessed. + * @property {string} version Version of the package as defined in the manifest. + * @property {string} main The entry point as defined in the main entry of the manifest. + * @property {string} description The description of extension as defined in the manifest. + */ + origin?: string + installOptions: any + name?: string + productName?: string + url?: string + version?: string + main?: string + description?: string + + /** @private */ + _active = false + + /** + * @private + * @property {Object.} #listeners A list of callbacks to be executed when the Extension is updated. + */ + listeners: Record void> = {} + + /** + * Set installOptions with defaults for options that have not been provided. + * @param {string} [origin] Original specification provided to fetch the package. + * @param {Object} [options] Options provided to pacote when fetching the manifest. 
+ */ + constructor(origin?: string, options = {}) { + const Arborist = require('@npmcli/arborist') + const defaultOpts = { + version: false, + fullMetadata: true, + Arborist, + } + + this.origin = origin + this.installOptions = { ...defaultOpts, ...options } + } + + /** + * Package name with version number. + * @type {string} + */ + get specifier() { + return this.origin + (this.installOptions.version ? '@' + this.installOptions.version : '') + } + + /** + * Whether the extension should be registered with its activation points. + * @type {boolean} + */ + get active() { + return this._active + } + + /** + * Set Package details based on it's manifest + * @returns {Promise.} Resolves to true when the action completed + */ + async getManifest() { + // Get the package's manifest (package.json object) + try { + await import('pacote').then((pacote) => { + return pacote.manifest(this.specifier, this.installOptions).then((mnf) => { + // set the Package properties based on the it's manifest + this.name = mnf.name + this.productName = mnf.productName as string | undefined + this.version = mnf.version + this.main = mnf.main + this.description = mnf.description + }) + }) + } catch (error) { + throw new Error(`Package ${this.origin} does not contain a valid manifest: ${error}`) + } + + return true + } + + /** + * Extract extension to extensions folder. + * @returns {Promise.} This extension + * @private + */ + async _install() { + try { + // import the manifest details + await this.getManifest() + + // Install the package in a child folder of the given folder + const pacote = await import('pacote') + await pacote.extract( + this.specifier, + join(ExtensionManager.instance.getExtensionsPath() ?? '', this.name ?? ''), + this.installOptions + ) + + // Set the url using the custom extensions protocol + this.url = `extension://${this.name}/${this.main}` + + this.emitUpdate() + } catch (err) { + // Ensure the extension is not stored and the folder is removed if the installation fails + this.setActive(false) + throw err + } + + return [this] + } + + /** + * Subscribe to updates of this extension + * @param {string} name name of the callback to register + * @param {callback} cb The function to execute on update + */ + subscribe(name: string, cb: () => void) { + this.listeners[name] = cb + } + + /** + * Remove subscription + * @param {string} name name of the callback to remove + */ + unsubscribe(name: string) { + delete this.listeners[name] + } + + /** + * Execute listeners + */ + emitUpdate() { + for (const cb in this.listeners) { + this.listeners[cb].call(null, this) + } + } + + /** + * Check for updates and install if available. + * @param {string} version The version to update to. + * @returns {boolean} Whether an update was performed. + */ + async update(version = false) { + if (await this.isUpdateAvailable()) { + this.installOptions.version = version + await this._install() + return true + } + + return false + } + + /** + * Check if a new version of the extension is available at the origin. + * @returns the latest available version if a new version is available or false if not. + */ + async isUpdateAvailable() { + return import('pacote').then((pacote) => { + if (this.origin) { + return pacote.manifest(this.origin).then((mnf) => { + return mnf.version !== this.version ? mnf.version : false + }) + } + }) + } + + /** + * Remove extension and refresh renderers. 
+ * @returns {Promise} + */ + async uninstall(): Promise { + const path = ExtensionManager.instance.getExtensionsPath() + const extPath = resolve(path ?? '', this.name ?? '') + rmdirSync(extPath, { recursive: true }) + + this.emitUpdate() + } + + /** + * Set a extension's active state. This determines if a extension should be loaded on initialisation. + * @param {boolean} active State to set _active to + * @returns {Extension} This extension + */ + setActive(active: boolean) { + this._active = active + this.emitUpdate() + return this + } +} diff --git a/core/src/node/extension/index.ts b/core/src/node/extension/index.ts new file mode 100644 index 0000000000..994fc97f2f --- /dev/null +++ b/core/src/node/extension/index.ts @@ -0,0 +1,136 @@ +import { readFileSync } from 'fs' + +import { normalize } from 'path' + +import Extension from './extension' +import { + getAllExtensions, + removeExtension, + persistExtensions, + installExtensions, + getExtension, + getActiveExtensions, + addExtension, +} from './store' +import { ExtensionManager } from './manager' + +export function init(options: any) { + // Create extensions protocol to serve extensions to renderer + registerExtensionProtocol() + + // perform full setup if extensionsPath is provided + if (options.extensionsPath) { + return useExtensions(options.extensionsPath) + } + + return {} +} + +/** + * Create extensions protocol to provide extensions to renderer + * @private + * @returns {boolean} Whether the protocol registration was successful + */ +async function registerExtensionProtocol() { + let electron: any = undefined + + try { + const moduleName = 'electron' + electron = await import(moduleName) + } catch (err) { + console.error('Electron is not available') + } + const extensionPath = ExtensionManager.instance.getExtensionsPath() + if (electron && electron.protocol) { + return electron.protocol?.registerFileProtocol('extension', (request: any, callback: any) => { + const entry = request.url.substr('extension://'.length - 1) + + const url = normalize(extensionPath + entry) + callback({ path: url }) + }) + } +} + +/** + * Set extensions up to run from the extensionPath folder if it is provided and + * load extensions persisted in that folder. + * @param {string} extensionsPath Path to the extensions folder. Required if not yet set up. + * @returns {extensionManager} A set of functions used to manage the extension lifecycle. + */ +export function useExtensions(extensionsPath: string) { + if (!extensionsPath) throw Error('A path to the extensions folder is required to use extensions') + // Store the path to the extensions folder + ExtensionManager.instance.setExtensionsPath(extensionsPath) + + // Remove any registered extensions + for (const extension of getAllExtensions()) { + if (extension.name) removeExtension(extension.name, false) + } + + // Read extension list from extensions folder + const extensions = JSON.parse( + readFileSync(ExtensionManager.instance.getExtensionsFile(), 'utf-8') + ) + try { + // Create and store a Extension instance for each extension in list + for (const p in extensions) { + loadExtension(extensions[p]) + } + persistExtensions() + } catch (error) { + // Throw meaningful error if extension loading fails + throw new Error( + 'Could not successfully rebuild list of installed extensions.\n' + + error + + '\nPlease check the extensions.json file in the extensions folder.' + ) + } + + // Return the extension lifecycle functions + return getStore() +} + +/** + * Check the given extension object. 
If it is marked for uninstalling, the extension files are removed. + * Otherwise a Extension instance for the provided object is created and added to the store. + * @private + * @param {Object} ext Extension info + */ +function loadExtension(ext: any) { + // Create new extension, populate it with ext details and save it to the store + const extension = new Extension() + + for (const key in ext) { + if (Object.prototype.hasOwnProperty.call(ext, key)) { + // Use Object.defineProperty to set the properties as writable + Object.defineProperty(extension, key, { + value: ext[key], + writable: true, + enumerable: true, + configurable: true, + }) + } + } + addExtension(extension, false) + extension.subscribe('pe-persist', persistExtensions) +} + +/** + * Returns the publicly available store functions. + * @returns {extensionManager} A set of functions used to manage the extension lifecycle. + */ +export function getStore() { + if (!ExtensionManager.instance.getExtensionsFile()) { + throw new Error( + 'The extension path has not yet been set up. Please run useExtensions before accessing the store' + ) + } + + return { + installExtensions, + getExtension, + getAllExtensions, + getActiveExtensions, + removeExtension, + } +} diff --git a/core/src/node/extension/manager.ts b/core/src/node/extension/manager.ts new file mode 100644 index 0000000000..c66d7b1633 --- /dev/null +++ b/core/src/node/extension/manager.ts @@ -0,0 +1,45 @@ +import { join, resolve } from 'path' + +import { existsSync, mkdirSync, writeFileSync } from 'fs' + +/** + * Manages extension installation and migration. + */ + +export class ExtensionManager { + public static instance: ExtensionManager = new ExtensionManager() + + private extensionsPath: string | undefined + + constructor() { + if (ExtensionManager.instance) { + return ExtensionManager.instance + } + } + + getExtensionsPath(): string | undefined { + return this.extensionsPath + } + + setExtensionsPath(extPath: string) { + // Create folder if it does not exist + let extDir + try { + extDir = resolve(extPath) + if (extDir.length < 2) throw new Error() + + if (!existsSync(extDir)) mkdirSync(extDir) + + const extensionsJson = join(extDir, 'extensions.json') + if (!existsSync(extensionsJson)) writeFileSync(extensionsJson, '{}') + + this.extensionsPath = extDir + } catch (error) { + throw new Error('Invalid path provided to the extensions folder') + } + } + + getExtensionsFile() { + return join(this.extensionsPath ?? '', 'extensions.json') + } +} diff --git a/core/src/node/extension/store.ts b/core/src/node/extension/store.ts new file mode 100644 index 0000000000..630756485d --- /dev/null +++ b/core/src/node/extension/store.ts @@ -0,0 +1,125 @@ +import { writeFileSync } from 'fs' +import Extension from './extension' +import { ExtensionManager } from './manager' + +/** + * @module store + * @private + */ + +/** + * Register of installed extensions + * @type {Object.} extension - List of installed extensions + */ +const extensions: Record = {} + +/** + * Get a extension from the stored extensions. + * @param {string} name Name of the extension to retrieve + * @returns {Extension} Retrieved extension + * @alias extensionManager.getExtension + */ +export function getExtension(name: string) { + if (!Object.prototype.hasOwnProperty.call(extensions, name)) { + throw new Error(`Extension ${name} does not exist`) + } + + return extensions[name] +} + +/** + * Get list of all extension objects. 
+ * @returns {Array.} All extension objects + * @alias extensionManager.getAllExtensions + */ +export function getAllExtensions() { + return Object.values(extensions) +} + +/** + * Get list of active extension objects. + * @returns {Array.} Active extension objects + * @alias extensionManager.getActiveExtensions + */ +export function getActiveExtensions() { + return Object.values(extensions).filter((extension) => extension.active) +} + +/** + * Remove extension from store and maybe save stored extensions to file + * @param {string} name Name of the extension to remove + * @param {boolean} persist Whether to save the changes to extensions to file + * @returns {boolean} Whether the delete was successful + * @alias extensionManager.removeExtension + */ +export function removeExtension(name: string, persist = true) { + const del = delete extensions[name] + if (persist) persistExtensions() + return del +} + +/** + * Add extension to store and maybe save stored extensions to file + * @param {Extension} extension Extension to add to store + * @param {boolean} persist Whether to save the changes to extensions to file + * @returns {void} + */ +export function addExtension(extension: Extension, persist = true) { + if (extension.name) extensions[extension.name] = extension + if (persist) { + persistExtensions() + extension.subscribe('pe-persist', persistExtensions) + } +} + +/** + * Save stored extensions to file + * @returns {void} + */ +export function persistExtensions() { + const persistData: Record = {} + for (const name in extensions) { + persistData[name] = extensions[name] + } + writeFileSync(ExtensionManager.instance.getExtensionsFile(), JSON.stringify(persistData)) +} + +/** + * Create and install a new extension for the given specifier. + * @param {Array.} extensions A list of NPM specifiers, or installation configuration objects. + * @param {boolean} [store=true] Whether to store the installed extensions in the store + * @returns {Promise.>} New extension + * @alias extensionManager.installExtensions + */ +export async function installExtensions(extensions: any) { + const installed: Extension[] = [] + const installations = extensions.map((ext: any): Promise => { + const isObject = typeof ext === 'object' + const spec = isObject ? [ext.specifier, ext] : [ext] + const activate = isObject ? ext.activate !== false : true + + // Install and possibly activate extension + const extension = new Extension(...spec) + if (!extension.origin) { + return Promise.resolve() + } + return extension._install().then(() => { + if (activate) extension.setActive(true) + // Add extension to store if needed + addExtension(extension) + installed.push(extension) + }) + }) + + await Promise.all(installations) + + // Return list of all installed extensions + return installed +} + +/** + * @typedef {Object.} installOptions The {@link https://www.npmjs.com/package/pacote|pacote} + * options used to install the extension with some extra options. + * @param {string} specifier the NPM specifier that identifies the package. + * @param {boolean} [activate] Whether this extension should be activated after installation. Defaults to true. 
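+ * @example
+ * // Hedged sketch with illustrative specifiers (not from the source): install one extension by
+ * // plain specifier and another via a configuration object that skips activation.
+ * // await installExtensions(['@janhq/example-extension', { specifier: '@janhq/other-extension', activate: false }])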
+ */ diff --git a/core/src/node/helper/config.ts b/core/src/node/helper/config.ts new file mode 100644 index 0000000000..1a341a6252 --- /dev/null +++ b/core/src/node/helper/config.ts @@ -0,0 +1,157 @@ +import { AppConfiguration, SettingComponentProps } from '../../types' +import { join } from 'path' +import fs from 'fs' +import os from 'os' +import childProcess from 'child_process' + +const configurationFileName = 'settings.json' + +// TODO: do no specify app name in framework module +// TODO: do not default the os.homedir +const defaultJanDataFolder = join(os?.homedir() || '', 'jan') +const defaultAppConfig: AppConfiguration = { + data_folder: defaultJanDataFolder, + quick_ask: false, +} + +/** + * Getting App Configurations. + * + * @returns {AppConfiguration} The app configurations. + */ +export const getAppConfigurations = (): AppConfiguration => { + // Retrieve Application Support folder path + // Fallback to user home directory if not found + const configurationFile = getConfigurationFilePath() + + if (!fs.existsSync(configurationFile)) { + // create default app config if we don't have one + console.debug(`App config not found, creating default config at ${configurationFile}`) + fs.writeFileSync(configurationFile, JSON.stringify(defaultAppConfig)) + return defaultAppConfig + } + + try { + const appConfigurations: AppConfiguration = JSON.parse( + fs.readFileSync(configurationFile, 'utf-8') + ) + return appConfigurations + } catch (err) { + console.error(`Failed to read app config, return default config instead! Err: ${err}`) + return defaultAppConfig + } +} + +const getConfigurationFilePath = () => + join( + global.core?.appPath() || process.env[process.platform == 'win32' ? 'USERPROFILE' : 'HOME'], + configurationFileName + ) + +export const updateAppConfiguration = (configuration: AppConfiguration): Promise => { + const configurationFile = getConfigurationFilePath() + console.debug('updateAppConfiguration, configurationFile: ', configurationFile) + + fs.writeFileSync(configurationFile, JSON.stringify(configuration)) + return Promise.resolve() +} + +/** + * Utility function to get data folder path + * + * @returns {string} The data folder path. + */ +export const getJanDataFolderPath = (): string => { + const appConfigurations = getAppConfigurations() + return appConfigurations.data_folder +} + +/** + * Utility function to get extension path + * + * @returns {string} The extensions path. + */ +export const getJanExtensionsPath = (): string => { + const appConfigurations = getAppConfigurations() + return join(appConfigurations.data_folder, 'extensions') +} + +/** + * Utility function to physical cpu count + * + * @returns {number} The physical cpu count. 
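+ *
+ * @example
+ * // Hedged usage sketch: derive a conservative inference thread count from the physical core count.
+ * // const cores = await physicalCpuCount()
+ * // const cpuThreads = Math.max(1, cores - 1)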
+ */ +export const physicalCpuCount = async (): Promise => { + const platform = os.platform() + try { + if (platform === 'linux') { + const output = await exec('lscpu -p | egrep -v "^#" | sort -u -t, -k 2,4 | wc -l') + return parseInt(output.trim(), 10) + } else if (platform === 'darwin') { + const output = await exec('sysctl -n hw.physicalcpu_max') + return parseInt(output.trim(), 10) + } else if (platform === 'win32') { + const output = await exec('WMIC CPU Get NumberOfCores') + return output + .split(os.EOL) + .map((line: string) => parseInt(line)) + .filter((value: number) => !isNaN(value)) + .reduce((sum: number, number: number) => sum + number, 1) + } else { + const cores = os.cpus().filter((cpu: any, index: number) => { + const hasHyperthreading = cpu.model.includes('Intel') + const isOdd = index % 2 === 1 + return !hasHyperthreading || isOdd + }) + return cores.length + } + } catch (err) { + console.warn('Failed to get physical CPU count', err) + // Divide by 2 to get rid of hyper threading + const coreCount = Math.ceil(os.cpus().length / 2) + console.debug('Using node API to get physical CPU count:', coreCount) + return coreCount + } +} + +const exec = async (command: string): Promise => { + return new Promise((resolve, reject) => { + childProcess.exec(command, { encoding: 'utf8' }, (error, stdout) => { + if (error) { + reject(error) + } else { + resolve(stdout) + } + }) + }) +} + +// a hacky way to get the api key. we should comes up with a better +// way to handle this +export const getEngineConfiguration = async (engineId: string) => { + if (engineId !== 'openai' && engineId !== 'groq') return undefined + + const settingDirectoryPath = join( + getJanDataFolderPath(), + 'settings', + '@janhq', + engineId === 'openai' ? 'inference-openai-extension' : 'inference-groq-extension', + 'settings.json' + ) + + const content = fs.readFileSync(settingDirectoryPath, 'utf-8') + const settings: SettingComponentProps[] = JSON.parse(content) + const apiKeyId = engineId === 'openai' ? 'openai-api-key' : 'groq-api-key' + const keySetting = settings.find((setting) => setting.key === apiKeyId) + let fullUrl = settings.find((setting) => setting.key === 'chat-completions-endpoint') + ?.controllerProps.value + + let apiKey = keySetting?.controllerProps.value + if (typeof apiKey !== 'string') apiKey = '' + if (typeof fullUrl !== 'string') fullUrl = '' + + return { + api_key: apiKey, + full_url: fullUrl, + } +} diff --git a/core/src/node/helper/download.ts b/core/src/node/helper/download.ts new file mode 100644 index 0000000000..51a0b0a8f7 --- /dev/null +++ b/core/src/node/helper/download.ts @@ -0,0 +1,30 @@ +import { DownloadState } from '../../types' + +/** + * Manages file downloads and network requests. + */ +export class DownloadManager { + public networkRequests: Record = {} + + public static instance: DownloadManager = new DownloadManager() + + // store the download information with key is model id + public downloadProgressMap: Record = {} + + // store the download information with key is normalized file path + public downloadInfo: Record = {} + + constructor() { + if (DownloadManager.instance) { + return DownloadManager.instance + } + } + /** + * Sets a network request for a specific file. + * @param {string} fileName - The name of the file. + * @param {Request | undefined} request - The network request to set, or undefined to clear the request. 
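+   * @example
+   * // Hedged sketch; the file name is illustrative. Register an in-flight request, then clear it.
+   * // DownloadManager.instance.setRequest('example-model.gguf', request)
+   * // DownloadManager.instance.setRequest('example-model.gguf', undefined)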
+ */ + setRequest(fileName: string, request: any | undefined) { + this.networkRequests[fileName] = request + } +} diff --git a/core/src/node/helper/index.ts b/core/src/node/helper/index.ts new file mode 100644 index 0000000000..51030023f8 --- /dev/null +++ b/core/src/node/helper/index.ts @@ -0,0 +1,6 @@ +export * from './config' +export * from './download' +export * from './logger' +export * from './module' +export * from './path' +export * from './resource' diff --git a/core/src/node/helper/logger.ts b/core/src/node/helper/logger.ts new file mode 100644 index 0000000000..a6b3c8befb --- /dev/null +++ b/core/src/node/helper/logger.ts @@ -0,0 +1,81 @@ +// Abstract Logger class that all loggers should extend. +export abstract class Logger { + // Each logger must have a unique name. + abstract name: string + + /** + * Log message to log file. + * This method should be overridden by subclasses to provide specific logging behavior. + */ + abstract log(args: any): void +} + +// LoggerManager is a singleton class that manages all registered loggers. +export class LoggerManager { + // Map of registered loggers, keyed by their names. + public loggers = new Map() + + // Array to store logs that are queued before the loggers are registered. + queuedLogs: any[] = [] + + // Flag to indicate whether flushLogs is currently running. + private isFlushing = false + + // Register a new logger. If a logger with the same name already exists, it will be replaced. + register(logger: Logger) { + this.loggers.set(logger.name, logger) + } + // Unregister a logger by its name. + unregister(name: string) { + this.loggers.delete(name) + } + + get(name: string) { + return this.loggers.get(name) + } + + // Flush queued logs to all registered loggers. + flushLogs() { + // If flushLogs is already running, do nothing. + if (this.isFlushing) { + return + } + + this.isFlushing = true + + while (this.queuedLogs.length > 0 && this.loggers.size > 0) { + const log = this.queuedLogs.shift() + this.loggers.forEach((logger) => { + logger.log(log) + }) + } + + this.isFlushing = false + } + + // Log message using all registered loggers. + log(args: any) { + this.queuedLogs.push(args) + + this.flushLogs() + } + + /** + * The instance of the logger. + * If an instance doesn't exist, it creates a new one. + * This ensures that there is only one LoggerManager instance at any time. + */ + static instance(): LoggerManager { + let instance: LoggerManager | undefined = global.core?.logger + if (!instance) { + instance = new LoggerManager() + if (!global.core) global.core = {} + global.core.logger = instance + } + return instance + } +} + +export const log = (...args: any) => { + LoggerManager.instance().log(args) +} diff --git a/core/src/node/helper/module.ts b/core/src/node/helper/module.ts new file mode 100644 index 0000000000..0919667df5 --- /dev/null +++ b/core/src/node/helper/module.ts @@ -0,0 +1,31 @@ +/** + * Manages imported modules. + */ +export class ModuleManager { + public requiredModules: Record = {} + public cleaningResource = false + + public static instance: ModuleManager = new ModuleManager() + + constructor() { + if (ModuleManager.instance) { + return ModuleManager.instance + } + } + + /** + * Sets a module. + * @param {string} moduleName - The name of the module. + * @param {any | undefined} nodule - The module to set, or undefined to clear the module. + */ + setModule(moduleName: string, nodule: any | undefined) { + this.requiredModules[moduleName] = nodule + } + + /** + * Clears all imported modules. 
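+   * @example
+   * // Hedged sketch: drop every cached module reference, e.g. before reloading extensions.
+   * // ModuleManager.instance.clearImportedModules()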
+ */ + clearImportedModules() { + this.requiredModules = {} + } +} diff --git a/core/src/node/helper/path.ts b/core/src/node/helper/path.ts new file mode 100644 index 0000000000..a2d57ed3e7 --- /dev/null +++ b/core/src/node/helper/path.ts @@ -0,0 +1,44 @@ +import { join, resolve } from 'path' +import { getJanDataFolderPath } from './config' + +/** + * Normalize file path + * Remove all file protocol prefix + * @param path + * @returns + */ +export function normalizeFilePath(path: string): string { + return path.replace(/^(file:[\\/]+)([^:\s]+)$/, '$2') +} + +export async function appResourcePath(): Promise { + let electron: any = undefined + + try { + const moduleName = 'electron' + electron = await import(moduleName) + } catch (err) { + console.error('Electron is not available') + } + + // electron + if (electron && electron.protocol) { + let appPath = join(electron.app.getAppPath(), '..', 'app.asar.unpacked') + + if (!electron.app.isPackaged) { + // for development mode + appPath = join(electron.app.getAppPath()) + } + return appPath + } + // server + return join(global.core.appPath(), '../../..') +} + +export function validatePath(path: string) { + const janDataFolderPath = getJanDataFolderPath() + const absolutePath = resolve(__dirname, path) + if (!absolutePath.startsWith(janDataFolderPath)) { + throw new Error(`Invalid path: ${absolutePath}`) + } +} diff --git a/core/src/node/helper/resource.ts b/core/src/node/helper/resource.ts new file mode 100644 index 0000000000..c7bfbf20c7 --- /dev/null +++ b/core/src/node/helper/resource.ts @@ -0,0 +1,13 @@ +import { SystemResourceInfo } from '../../types' +import { physicalCpuCount } from './config' +import { log } from './logger' + +export const getSystemResourceInfo = async (): Promise => { + const cpu = await physicalCpuCount() + log(`[CORTEX]::CPU information - ${cpu}`) + + return { + numCpuPhysicalCore: cpu, + memAvailable: 0, // TODO: this should not be 0 + } +} diff --git a/core/src/node/index.ts b/core/src/node/index.ts new file mode 100644 index 0000000000..eb60270752 --- /dev/null +++ b/core/src/node/index.ts @@ -0,0 +1,8 @@ +export * from './extension/index' +export * from './extension/extension' +export * from './extension/manager' +export * from './extension/store' +export * from './api' +export * from './helper' +export * from './../types' +export * from '../types/api' diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index 267441f4ab..e50dce6de8 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -1,3 +1,5 @@ +import { ChatCompletionMessage } from '../inference' + /** * Native Route APIs * @description Enum of all the routes exposed by the app @@ -25,19 +27,23 @@ export enum NativeRoute { quickAskSizeUpdated = 'quickAskSizeUpdated', ackDeepLink = 'ackDeepLink', - homePath = 'homePath', - getThemes = 'getThemes', - readTheme = 'readTheme', - - // used for migration. Please remove this later on. 
- getAllMessagesAndThreads = 'getAllMessagesAndThreads', - getAllLocalModels = 'getAllLocalModels', - syncModelFileToCortex = 'syncModelFileToCortex', +} - openAppLog = 'openAppLog', - appDataFolder = 'appDataFolder', - changeDataFolder = 'changeDataFolder', - isDirectoryEmpty = 'isDirectoryEmpty', +/** + * App Route APIs + * @description Enum of all the routes exposed by the app + */ +export enum AppRoute { + getAppConfigurations = 'getAppConfigurations', + updateAppConfiguration = 'updateAppConfiguration', + joinPath = 'joinPath', + isSubdirectory = 'isSubdirectory', + baseName = 'baseName', + startServer = 'startServer', + stopServer = 'stopServer', + log = 'log', + systemInformation = 'systemInformation', + showToast = 'showToast', } export enum AppEvent { @@ -51,6 +57,22 @@ export enum AppEvent { onDeepLink = 'onDeepLink', } +export enum DownloadRoute { + abortDownload = 'abortDownload', + downloadFile = 'downloadFile', + pauseDownload = 'pauseDownload', + resumeDownload = 'resumeDownload', + getDownloadProgress = 'getDownloadProgress', + getFileSize = 'getFileSize', +} + +export enum DownloadEvent { + onFileDownloadUpdate = 'onFileDownloadUpdate', + onFileDownloadError = 'onFileDownloadError', + onFileDownloadSuccess = 'onFileDownloadSuccess', + onFileUnzipSuccess = 'onFileUnzipSuccess', +} + export enum LocalImportModelEvent { onLocalImportModelUpdate = 'onLocalImportModelUpdate', onLocalImportModelFailed = 'onLocalImportModelFailed', @@ -58,17 +80,92 @@ export enum LocalImportModelEvent { onLocalImportModelFinished = 'onLocalImportModelFinished', } +export enum ExtensionRoute { + baseExtensions = 'baseExtensions', + getActiveExtensions = 'getActiveExtensions', + installExtension = 'installExtension', + invokeExtensionFunc = 'invokeExtensionFunc', + updateExtension = 'updateExtension', + uninstallExtension = 'uninstallExtension', +} +export enum FileSystemRoute { + appendFileSync = 'appendFileSync', + unlinkSync = 'unlinkSync', + existsSync = 'existsSync', + readdirSync = 'readdirSync', + rm = 'rm', + mkdir = 'mkdir', + readFileSync = 'readFileSync', + writeFileSync = 'writeFileSync', +} +export enum FileManagerRoute { + copyFile = 'copyFile', + getJanDataFolderPath = 'getJanDataFolderPath', + getResourcePath = 'getResourcePath', + getUserHomePath = 'getUserHomePath', + fileStat = 'fileStat', + writeBlob = 'writeBlob', +} + export type ApiFunction = (...args: any[]) => any export type NativeRouteFunctions = { [K in NativeRoute]: ApiFunction } +export type AppRouteFunctions = { + [K in AppRoute]: ApiFunction +} + export type AppEventFunctions = { [K in AppEvent]: ApiFunction } -export type APIFunctions = NativeRouteFunctions & AppEventFunctions +export type DownloadRouteFunctions = { + [K in DownloadRoute]: ApiFunction +} + +export type DownloadEventFunctions = { + [K in DownloadEvent]: ApiFunction +} + +export type ExtensionRouteFunctions = { + [K in ExtensionRoute]: ApiFunction +} + +export type FileSystemRouteFunctions = { + [K in FileSystemRoute]: ApiFunction +} + +export type FileManagerRouteFunctions = { + [K in FileManagerRoute]: ApiFunction +} -export const APIRoutes = [...Object.values(NativeRoute)] -export const APIEvents = [...Object.values(AppEvent), ...Object.values(LocalImportModelEvent)] +export type APIFunctions = NativeRouteFunctions & + AppRouteFunctions & + AppEventFunctions & + DownloadRouteFunctions & + DownloadEventFunctions & + ExtensionRouteFunctions & + FileSystemRouteFunctions & + FileManagerRoute + +export const CoreRoutes = [ + ...Object.values(AppRoute), 
+ ...Object.values(DownloadRoute), + ...Object.values(ExtensionRoute), + ...Object.values(FileSystemRoute), + ...Object.values(FileManagerRoute), +] + +export const APIRoutes = [...CoreRoutes, ...Object.values(NativeRoute)] +export const APIEvents = [ + ...Object.values(AppEvent), + ...Object.values(DownloadEvent), + ...Object.values(LocalImportModelEvent), +] +export type PayloadType = { + messages: ChatCompletionMessage[] + model: string + stream: boolean +} diff --git a/core/src/types/assistant/assistantEntity.ts b/core/src/types/assistant/assistantEntity.ts index 1c60bae7a1..27592e26b6 100644 --- a/core/src/types/assistant/assistantEntity.ts +++ b/core/src/types/assistant/assistantEntity.ts @@ -1,27 +1,38 @@ -import { - AssistantTool as OpenAiAssistantTool, - Assistant as OpenAiAssistant, - AssistantCreateParams as OpenAiAssistantCreateParams, - AssistantUpdateParams as OpenAiAssistantUpdateParams, -} from 'openai/resources/beta/assistants' -import { AssistantResponseFormatOption as OpenAIAssistantResponseFormatOption } from 'openai/resources/beta/threads/threads' - -export interface Assistant extends OpenAiAssistant { - avatar?: string - - tools: AssistantTool[] -} - -export type AssistantResponseFormatOption = OpenAIAssistantResponseFormatOption - -export interface AssistantToolResources extends OpenAiAssistant.ToolResources {} - -export type AssistantTool = OpenAiAssistantTool & { - enabled?: boolean - +/** + * Assistant type defines the shape of an assistant object. + * @stored + */ + +export type AssistantTool = { + type: string + enabled: boolean useTimeWeightedRetriever?: boolean + settings: any } -export interface AssistantCreateParams extends OpenAiAssistantCreateParams {} - -export interface AssistantUpdateParams extends OpenAiAssistantUpdateParams {} +export type Assistant = { + /** Represents the avatar of the user. */ + avatar: string + /** Represents the location of the thread. */ + thread_location: string | undefined + /** Represents the unique identifier of the object. */ + id: string + /** Represents the object. */ + object: string + /** Represents the creation timestamp of the object. */ + created_at: number + /** Represents the name of the object. */ + name: string + /** Represents the description of the object. */ + description?: string + /** Represents the model of the object. */ + model: string + /** Represents the instructions for the object. */ + instructions?: string + /** Represents the tools associated with the object. */ + tools?: AssistantTool[] + /** Represents the file identifiers associated with the object. */ + file_ids: string[] + /** Represents the metadata of the object. */ + metadata?: Record +} diff --git a/core/src/types/assistant/assistantEvent.ts b/core/src/types/assistant/assistantEvent.ts new file mode 100644 index 0000000000..8c32f5d37a --- /dev/null +++ b/core/src/types/assistant/assistantEvent.ts @@ -0,0 +1,7 @@ +/** + * The `EventName` enumeration contains the names of all the available events in the Jan platform. + */ +export enum AssistantEvent { + /** The `OnAssistantsUpdate` event is emitted when the assistant list is updated. */ + OnAssistantsUpdate = 'OnAssistantsUpdate', +} diff --git a/core/src/types/assistant/assistantInterface.ts b/core/src/types/assistant/assistantInterface.ts new file mode 100644 index 0000000000..3c10bbb7fe --- /dev/null +++ b/core/src/types/assistant/assistantInterface.ts @@ -0,0 +1,26 @@ +import { Assistant } from './assistantEntity' +/** + * Assistant extension for managing assistants. 
+ * @extends BaseExtension + */ +export interface AssistantInterface { + /** + * Creates a new assistant. + * @param {Assistant} assistant - The assistant object to be created. + * @returns {Promise} A promise that resolves when the assistant has been created. + */ + createAssistant(assistant: Assistant): Promise + + /** + * Deletes an existing assistant. + * @param {Assistant} assistant - The assistant object to be deleted. + * @returns {Promise} A promise that resolves when the assistant has been deleted. + */ + deleteAssistant(assistant: Assistant): Promise + + /** + * Retrieves all existing assistants. + * @returns {Promise} A promise that resolves to an array of all assistants. + */ + getAssistants(): Promise +} diff --git a/core/src/types/assistant/index.ts b/core/src/types/assistant/index.ts index 8682319af2..e18589551a 100644 --- a/core/src/types/assistant/index.ts +++ b/core/src/types/assistant/index.ts @@ -1 +1,3 @@ export * from './assistantEntity' +export * from './assistantEvent' +export * from './assistantInterface' diff --git a/core/src/types/config/appConfigEntity.ts b/core/src/types/config/appConfigEntity.ts index 6180303c1c..1402aeca12 100644 --- a/core/src/types/config/appConfigEntity.ts +++ b/core/src/types/config/appConfigEntity.ts @@ -1,8 +1,4 @@ export type AppConfiguration = { - dataFolderPath: string, - quickAsk: boolean, - cortexCppHost: string, - cortexCppPort: number, - apiServerHost: string, - apiServerPort: number, + data_folder: string + quick_ask: boolean } diff --git a/core/src/types/events/index.ts b/core/src/types/events/index.ts deleted file mode 100644 index d154c92d4d..0000000000 --- a/core/src/types/events/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from './model.event' -export * from './resource.event' diff --git a/core/src/types/events/model.event.ts b/core/src/types/events/model.event.ts deleted file mode 100644 index 98e9965cd8..0000000000 --- a/core/src/types/events/model.event.ts +++ /dev/null @@ -1,40 +0,0 @@ -export type ModelId = string - -const ModelLoadingEvents = [ - 'starting', - 'stopping', - 'started', - 'stopped', - 'starting-failed', - 'stopping-failed', - 'model-downloaded', - 'model-deleted', -] as const -export type ModelLoadingEvent = (typeof ModelLoadingEvents)[number] - -const AllModelStates = ['starting', 'stopping', 'started'] as const -export type ModelState = (typeof AllModelStates)[number] - -// TODO: should make this model -> id -export interface ModelStatus { - model: ModelId - status: ModelState - metadata: Record -} - -export interface ModelEvent { - model: ModelId - event: ModelLoadingEvent - metadata: Record -} - -export const EmptyModelEvent = {} - -export type StatusAndEvent = { - status: Record - event: ModelEvent | typeof EmptyModelEvent -} - -export interface ModelStatusAndEvent { - data: StatusAndEvent -} diff --git a/core/src/types/events/resource.event.ts b/core/src/types/events/resource.event.ts deleted file mode 100644 index 6bb5aa7970..0000000000 --- a/core/src/types/events/resource.event.ts +++ /dev/null @@ -1,21 +0,0 @@ -export interface ResourceEvent { - data: ResourceStatus -} - -export interface ResourceStatus { - mem: UsedMemInfo - cpu: { - usage: number - } - gpus: GpuInfo[] -} - -export interface UsedMemInfo { - total: number - used: number -} - -export interface GpuInfo { - name: string | undefined - vram: UsedMemInfo -} diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index a13a921361..d941987ef1 100644 --- a/core/src/types/file/index.ts +++ 
b/core/src/types/file/index.ts @@ -52,82 +52,3 @@ type DownloadSize = { total: number transferred: number } - -export interface DownloadState2 { - /** - * The id of a particular download. Being used to prevent duplication of downloads. - */ - id: string - - /** - * For displaying purposes. - */ - title: string - - /** - * The type of download. - */ - type: DownloadType2 - - /** - * Percentage of the download. - */ - progress: number - - /** - * The status of the download. - */ - status: DownloadStatus - - /** - * Explanation of the error if the download failed. - */ - error?: string - - /** - * The actual downloads. [DownloadState] is just a group to supporting for download multiple files. - */ - children: DownloadItem[] -} - -export enum DownloadStatus { - Pending = 'pending', - Downloading = 'downloading', - Error = 'error', - Downloaded = 'downloaded', -} - -export interface DownloadItem { - /** - * Filename of the download. - */ - id: string - - time: { - elapsed: number - remaining: number - } - - size: { - total: number - transferred: number - } - - checksum?: string - - status: DownloadStatus - - error?: string - - metadata?: Record -} - -export interface DownloadStateEvent { - data: DownloadState[] -} - -export enum DownloadType2 { - Model = 'model', - Miscelanous = 'miscelanous', - Engine = 'engine', -} diff --git a/core/src/types/huggingface/huggingfaceEntity.ts b/core/src/types/huggingface/huggingfaceEntity.ts index 1f7e3fb400..da846900ba 100644 --- a/core/src/types/huggingface/huggingfaceEntity.ts +++ b/core/src/types/huggingface/huggingfaceEntity.ts @@ -40,11 +40,6 @@ export type CardDataKeysTuple = typeof CardDataKeys export type CardDataKeys = CardDataKeysTuple[number] export const AllQuantizations = [ - 'IQ1_M', - 'IQ1_S', - 'IQ3_S', - 'Q3_K_XL', - 'IQ4_NL', 'Q3_K_S', 'Q3_K_M', 'Q3_K_L', @@ -56,16 +51,8 @@ export const AllQuantizations = [ 'Q4_1', 'Q5_0', 'Q5_1', - 'Q5_K_L', - 'Q4_K_L', 'IQ2_XXS', 'IQ2_XS', - 'IQ2_S', - 'IQ2_M', - 'IQ3_M', - 'IQ3_XS', - 'IQ3_XXS', - 'IQ4_XS', 'Q2_K', 'Q2_K_S', 'Q6_K', diff --git a/core/src/types/index.ts b/core/src/types/index.ts index 293f138556..6627ebff9b 100644 --- a/core/src/types/index.ts +++ b/core/src/types/index.ts @@ -2,6 +2,7 @@ export * from './assistant' export * from './model' export * from './thread' export * from './message' +export * from './inference' export * from './monitoring' export * from './file' export * from './config' @@ -9,4 +10,3 @@ export * from './huggingface' export * from './miscellaneous' export * from './api' export * from './setting' -export * from './events' diff --git a/core/src/types/inference/index.ts b/core/src/types/inference/index.ts new file mode 100644 index 0000000000..a0a71f1427 --- /dev/null +++ b/core/src/types/inference/index.ts @@ -0,0 +1,3 @@ +export * from './inferenceEntity' +export * from './inferenceInterface' +export * from './inferenceEvent' diff --git a/core/src/types/inference/inferenceEntity.ts b/core/src/types/inference/inferenceEntity.ts new file mode 100644 index 0000000000..c37e3b0793 --- /dev/null +++ b/core/src/types/inference/inferenceEntity.ts @@ -0,0 +1,46 @@ +import { ContentType, ContentValue } from '../message' + +/** + * The role of the author of this message. + */ +export enum ChatCompletionRole { + System = 'system', + Assistant = 'assistant', + User = 'user', +} + +/** + * The `MessageRequest` type defines the shape of a new message request object. + * @data_transfer_object + */ +export type ChatCompletionMessage = { + /** The contents of the message. 
**/ + content?: ChatCompletionMessageContent + /** The role of the author of this message. **/ + role: ChatCompletionRole +} + +export type ChatCompletionMessageContent = + | string + | (ChatCompletionMessageContentText & + ChatCompletionMessageContentImage & + ChatCompletionMessageContentDoc)[] + +export enum ChatCompletionMessageContentType { + Text = 'text', + Image = 'image_url', + Doc = 'doc_url', +} + +export type ChatCompletionMessageContentText = { + type: ChatCompletionMessageContentType + text: string +} +export type ChatCompletionMessageContentImage = { + type: ChatCompletionMessageContentType + image_url: { url: string } +} +export type ChatCompletionMessageContentDoc = { + type: ChatCompletionMessageContentType + doc_url: { url: string } +} diff --git a/core/src/types/inference/inferenceEvent.ts b/core/src/types/inference/inferenceEvent.ts new file mode 100644 index 0000000000..f685a54b37 --- /dev/null +++ b/core/src/types/inference/inferenceEvent.ts @@ -0,0 +1,7 @@ +/** + * The `EventName` enumeration contains the names of all the available events in the Jan platform. + */ +export enum InferenceEvent { + /** The `OnInferenceStopped` event is emitted when a inference is stopped. */ + OnInferenceStopped = 'OnInferenceStopped', +} diff --git a/core/src/types/inference/inferenceInterface.ts b/core/src/types/inference/inferenceInterface.ts new file mode 100644 index 0000000000..21e327e45d --- /dev/null +++ b/core/src/types/inference/inferenceInterface.ts @@ -0,0 +1,13 @@ +import { MessageRequest, ThreadMessage } from '../message' + +/** + * Inference extension. Start, stop and inference models. + */ +export interface InferenceInterface { + /** + * Processes an inference request. + * @param data - The data for the inference request. + * @returns The result of the inference request. + */ + inference(data: MessageRequest): Promise +} diff --git a/core/src/types/message/index.ts b/core/src/types/message/index.ts index b1a6f1e181..ebb4c363d8 100644 --- a/core/src/types/message/index.ts +++ b/core/src/types/message/index.ts @@ -1 +1,4 @@ export * from './messageEntity' +export * from './messageInterface' +export * from './messageEvent' +export * from './messageRequestType' diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index 8cba418a95..26bcad1a74 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -1,26 +1,122 @@ -import { - ChatCompletionMessageParam as OpenAiChatCompletionMessageParam, - ChatCompletionMessage as OpenAiChatCompletionMessage, -} from 'openai/resources' -import { - MessageCreateParams as OpenAiMessageCreateParams, - Message as OpenAiMessage, - MessageContent as OpenAiMessageContent, - TextContentBlock as OpenAiTextContentBlock, -} from 'openai/resources/beta/threads/messages' +import { ChatCompletionMessage, ChatCompletionRole } from '../inference' +import { ModelInfo } from '../model' +import { Thread } from '../thread' -export interface Message extends OpenAiMessage {} +/** + * The `ThreadMessage` type defines the shape of a thread's message object. + * @stored + */ +export type ThreadMessage = { + /** Unique identifier for the message, generated by default using the ULID method. **/ + id: string + /** Object name **/ + object: string + /** Thread id, default is a ulid. **/ + thread_id: string + /** The assistant id of this thread. **/ + assistant_id?: string + /** The role of the author of this message. **/ + role: ChatCompletionRole + /** The content of this message. 
**/ + content: ThreadContent[] + /** The status of this message. **/ + status: MessageStatus + /** The timestamp indicating when this message was created. Represented in Unix time. **/ + created: number + /** The timestamp indicating when this message was updated. Represented in Unix time. **/ + updated: number + /** The additional metadata of this message. **/ + metadata?: Record -export type MessageContent = OpenAiMessageContent + type?: string -export type TextContentBlock = OpenAiTextContentBlock + /** The error code which explain what error type. Used in conjunction with MessageStatus.Error */ + error_code?: ErrorCode +} -export interface MessageIncompleteDetails extends OpenAiMessage.IncompleteDetails {} +/** + * The `MessageRequest` type defines the shape of a new message request object. + * @data_transfer_object + */ +export type MessageRequest = { + id?: string -export interface MessageAttachment extends OpenAiMessage.Attachment {} + /** + * @deprecated Use thread object instead + * The thread id of the message request. + */ + threadId: string -export interface ChatCompletionMessage extends OpenAiChatCompletionMessage {} + /** + * The assistant id of the message request. + */ + assistantId?: string -export type ChatCompletionMessageParam = OpenAiChatCompletionMessageParam + /** Messages for constructing a chat completion request **/ + messages?: ChatCompletionMessage[] -export interface MessageCreateParams extends OpenAiMessageCreateParams {} + /** Settings for constructing a chat completion request **/ + model?: ModelInfo + + /** The thread of this message is belong to. **/ + // TODO: deprecate threadId field + thread?: Thread + + type?: string +} + +/** + * The status of the message. + * @data_transfer_object + */ +export enum MessageStatus { + /** Message is fully loaded. **/ + Ready = 'ready', + /** Message is not fully loaded. **/ + Pending = 'pending', + /** Message loaded with error. **/ + Error = 'error', + /** Message is cancelled streaming */ + Stopped = 'stopped', +} + +export enum ErrorCode { + InvalidApiKey = 'invalid_api_key', + + AuthenticationError = 'authentication_error', + + InsufficientQuota = 'insufficient_quota', + + InvalidRequestError = 'invalid_request_error', + + Unknown = 'unknown', +} + +/** + * The content type of the message. + */ +export enum ContentType { + Text = 'text', + Image = 'image', + Pdf = 'pdf', +} + +/** + * The `ContentValue` type defines the shape of a content value object + * @data_transfer_object + */ +export type ContentValue = { + value: string + annotations: string[] + name?: string + size?: number +} + +/** + * The `ThreadContent` type defines the shape of a message's content object + * @data_transfer_object + */ +export type ThreadContent = { + type: ContentType + text: ContentValue +} diff --git a/core/src/types/message/messageEvent.ts b/core/src/types/message/messageEvent.ts new file mode 100644 index 0000000000..40fd84c30b --- /dev/null +++ b/core/src/types/message/messageEvent.ts @@ -0,0 +1,8 @@ +export enum MessageEvent { + /** The `OnMessageSent` event is emitted when a message is sent. */ + OnMessageSent = 'OnMessageSent', + /** The `OnMessageResponse` event is emitted when a message is received. */ + OnMessageResponse = 'OnMessageResponse', + /** The `OnMessageUpdate` event is emitted when a message is updated. 
*/ + OnMessageUpdate = 'OnMessageUpdate', +} diff --git a/core/src/types/message/messageInterface.ts b/core/src/types/message/messageInterface.ts new file mode 100644 index 0000000000..f6579da88b --- /dev/null +++ b/core/src/types/message/messageInterface.ts @@ -0,0 +1,30 @@ +import { ThreadMessage } from './messageEntity' + +/** + * Conversational extension. Persists and retrieves conversations. + * @abstract + * @extends BaseExtension + */ +export interface MessageInterface { + /** + * Adds a new message to the thread. + * @param {ThreadMessage} message - The message to be added. + * @returns {Promise} A promise that resolves when the message has been added. + */ + addNewMessage(message: ThreadMessage): Promise + + /** + * Writes an array of messages to a specific thread. + * @param {string} threadId - The ID of the thread to write the messages to. + * @param {ThreadMessage[]} messages - The array of messages to be written. + * @returns {Promise} A promise that resolves when the messages have been written. + */ + writeMessages(threadId: string, messages: ThreadMessage[]): Promise + + /** + * Retrieves all messages from a specific thread. + * @param {string} threadId - The ID of the thread to retrieve the messages from. + * @returns {Promise} A promise that resolves to an array of messages from the thread. + */ + getAllMessages(threadId: string): Promise +} diff --git a/core/src/types/message/messageRequestType.ts b/core/src/types/message/messageRequestType.ts new file mode 100644 index 0000000000..cbb4cf4217 --- /dev/null +++ b/core/src/types/message/messageRequestType.ts @@ -0,0 +1,5 @@ +export enum MessageRequestType { + Thread = 'Thread', + Assistant = 'Assistant', + Summary = 'Summary', +} diff --git a/core/src/types/model/chatCompletion.ts b/core/src/types/model/chatCompletion.ts deleted file mode 100644 index 4cb5c9f971..0000000000 --- a/core/src/types/model/chatCompletion.ts +++ /dev/null @@ -1,10 +0,0 @@ -import { - ChatCompletionCreateParamsNonStreaming as OpenAiChatCompletionCreateParamsNonStreaming, - ChatCompletionCreateParamsStreaming as OpenAiChatCompletionCreateParamsStreaming, -} from 'openai/resources/chat/completions' - -export interface ChatCompletionCreateParamsNonStreaming - extends OpenAiChatCompletionCreateParamsNonStreaming {} - -export interface ChatCompletionCreateParamsStreaming - extends OpenAiChatCompletionCreateParamsStreaming {} diff --git a/core/src/types/model/index.ts b/core/src/types/model/index.ts index bcfe476d38..fdbf018636 100644 --- a/core/src/types/model/index.ts +++ b/core/src/types/model/index.ts @@ -1,3 +1,4 @@ export * from './modelEntity' +export * from './modelInterface' +export * from './modelEvent' export * from './modelImport' -export * from './chatCompletion' diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index 56c3f6abdf..426b308462 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -1,237 +1,146 @@ -import { Model as OpenAiModel } from 'openai/resources' - -export const LocalEngines = ['cortex.llamacpp', 'cortex.onnx', 'cortex.tensorrt-llm'] as const - -export const RemoteEngines = [ - 'anthropic', - 'mistral', - 'martian', - 'openrouter', - 'openai', - 'groq', - 'triton_trtllm', - 'cohere', - 'nvidia', -] as const - -export const LlmEngines = [...LocalEngines, ...RemoteEngines] as const -export type LlmEngine = (typeof LlmEngines)[number] -export type LocalEngine = (typeof LocalEngines)[number] -export type RemoteEngine = (typeof RemoteEngines)[number] 
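+// Hedged, illustrative ModelInfo value (identifiers are made up, not part of the source):
+// const info: ModelInfo = {
+//   id: 'example-llama-7b-q4',
+//   settings: { ctx_len: 4096, ngl: 32 },
+//   parameters: { temperature: 0.7, stream: true },
+//   engine: InferenceEngine.nitro,
+// }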
+/** + * Represents the information about a model. + * @stored + */ +export type ModelInfo = { + id: string + settings: ModelSettingParams + parameters: ModelRuntimeParams + engine?: InferenceEngine +} /** - * The available engine statuses. + * Represents the inference engine. + * @stored */ -export enum EngineStatus { - Ready = 'ready', - MissingConfiguration = 'missing_configuration', - NotInitialized = 'not_initialized', - NotSupported = 'not_supported', - Error = 'error', -} + +export enum InferenceEngine { + anthropic = 'anthropic', + mistral = 'mistral', + martian = 'martian', + openrouter = 'openrouter', + nitro = 'nitro', + openai = 'openai', + groq = 'groq', + triton_trtllm = 'triton_trtllm', + nitro_tensorrt_llm = 'nitro-tensorrt-llm', + cohere = 'cohere', +} export type ModelArtifact = { filename: string url: string } -export interface Model extends OpenAiModel, ModelSettingParams, ModelRuntimeParams { - /** - * Model identifier. - */ - model: string - - /** - * GGUF metadata: general.name - */ - name?: string - - /** - * GGUF metadata: version - */ - version?: string - - /** - * Currently we only have 'embedding' | 'llm' - */ - model_type?: string - - /** - * The model download source. It can be an external url or a local filepath. - */ - files: string[] | ModelArtifact - - metadata?: Record -} - /** - * The available model settings. + * Model type defines the shape of a model object. + * @stored */ -export interface ModelSettingParams { - /** - * The context length for model operations varies; the maximum depends on the specific model used. - */ - ctx_len?: number - - /** - * The number of layers to load onto the GPU for acceleration. - */ - ngl?: number - embedding?: boolean - +export type Model = { /** - * Number of parallel sequences to decode + * The type of the object. + * Default: "model" */ - n_parallel?: number + object: string /** - * Determines CPU inference threads, limited by hardware and OS. (Maximum determined by system) + * The version of the model. */ - cpu_threads?: number + version: string /** - * GGUF metadata: tokenizer.chat_template + * The format of the model. */ - prompt_template?: string - system_prompt?: string - ai_prompt?: string - user_prompt?: string - llama_model_path?: string - mmproj?: string - cont_batching?: boolean + format: string /** - * The model engine. + * The model download source. It can be an external url or a local filepath. */ - engine?: LlmEngine + sources: ModelArtifact[] /** - * The prompt to use for internal configuration + * The model identifier, which can be referenced in the API endpoints. */ - pre_prompt?: string + id: string /** - * The batch size for prompt eval step + * Human-readable name that is used for UI. */ - n_batch?: number + name: string /** - * To enable prompt caching or not + * The Unix timestamp (in seconds) for when the model was created */ - caching_enabled?: boolean + created: number /** - * Group attention factor in self-extend + * Default: "A cool model from Huggingface" */ - grp_attn_n?: number + description: string /** - * Group attention width in self-extend + * The model settings. */ - grp_attn_w?: number + settings: ModelSettingParams /** - * Prevent system swapping of the model to disk in macOS + * The model runtime parameters. */ - mlock?: boolean + parameters: ModelRuntimeParams /** - * You can constrain the sampling using GBNF grammars by providing path to a grammar file + * Metadata of the model. 
*/ - grammar_file?: string - + metadata: ModelMetadata /** - * To enable Flash Attention, default is true + * The model engine. */ - flash_attn?: boolean + engine: InferenceEngine +} - /** - * KV cache type: f16, q8_0, q4_0, default is f16 - */ - cache_type?: string +export type ModelMetadata = { + author: string + tags: string[] + size: number + cover?: string +} - /** - * To enable mmap, default is true - */ - use_mmap?: boolean +/** + * The available model settings. + */ +export type ModelSettingParams = { + ctx_len?: number + ngl?: number + embedding?: boolean + n_parallel?: number + cpu_threads?: number + prompt_template?: string + system_prompt?: string + ai_prompt?: string + user_prompt?: string + llama_model_path?: string + mmproj?: string + cont_batching?: boolean + vision_model?: boolean + text_model?: boolean } -type ModelSettingParamsKeys = keyof ModelSettingParams -export const modelSettingParamsKeys: ModelSettingParamsKeys[] = [ - 'ctx_len', - 'ngl', - 'embedding', - 'n_parallel', - 'cpu_threads', - 'prompt_template', - 'system_prompt', - 'ai_prompt', - 'user_prompt', - 'llama_model_path', - 'mmproj', - 'cont_batching', - 'engine', - 'pre_prompt', - 'n_batch', - 'caching_enabled', - 'grp_attn_n', - 'grp_attn_w', - 'mlock', - 'grammar_file', - 'flash_attn', - 'cache_type', - 'use_mmap', -] /** * The available model runtime parameters. */ -export interface ModelRuntimeParams { - /** - * Controls the randomness of the model’s output. - */ +export type ModelRuntimeParams = { temperature?: number token_limit?: number top_k?: number - - /** - * Set probability threshold for more relevant outputs. - */ top_p?: number - - /** - * Enable real-time data processing for faster predictions. - */ stream?: boolean - - /* - * The maximum number of tokens the model will generate in a single response. - */ max_tokens?: number - - /** - * Defines specific tokens or phrases at which the model will stop generating further output. - */ stop?: string[] - - /** - * Adjusts the likelihood of the model repeating words or phrases in its output. - */ frequency_penalty?: number - - /** - * Influences the generation of new and varied concepts in the model’s output. - */ presence_penalty?: number + engine?: string +} + +export type ModelInitFailed = Model & { + error: Error } -type ModelRuntimeParamsKeys = keyof ModelRuntimeParams -export const modelRuntimeParamsKeys: ModelRuntimeParamsKeys[] = [ - 'temperature', - 'token_limit', - 'top_k', - 'top_p', - 'stream', - 'max_tokens', - 'stop', - 'frequency_penalty', - 'presence_penalty', -] diff --git a/core/src/types/model/modelEvent.ts b/core/src/types/model/modelEvent.ts new file mode 100644 index 0000000000..443f3a34fb --- /dev/null +++ b/core/src/types/model/modelEvent.ts @@ -0,0 +1,17 @@ +/** + * The `EventName` enumeration contains the names of all the available events in the Jan platform. + */ +export enum ModelEvent { + /** The `OnModelInit` event is emitted when a model inits. */ + OnModelInit = 'OnModelInit', + /** The `OnModelReady` event is emitted when a model ready. */ + OnModelReady = 'OnModelReady', + /** The `OnModelFail` event is emitted when a model fails loading. */ + OnModelFail = 'OnModelFail', + /** The `OnModelStop` event is emitted when a model start to stop. */ + OnModelStop = 'OnModelStop', + /** The `OnModelStopped` event is emitted when a model stopped ok. */ + OnModelStopped = 'OnModelStopped', + /** The `OnModelUpdate` event is emitted when the model list is updated. 
*/ + OnModelsUpdate = 'OnModelsUpdate', +} diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts new file mode 100644 index 0000000000..639c7c8d34 --- /dev/null +++ b/core/src/types/model/modelInterface.ts @@ -0,0 +1,52 @@ +import { GpuSetting } from '../miscellaneous' +import { Model } from './modelEntity' + +/** + * Model extension for managing models. + */ +export interface ModelInterface { + /** + * Downloads a model. + * @param model - The model to download. + * @param network - Optional object to specify proxy/whether to ignore SSL certificates. + * @returns A Promise that resolves when the model has been downloaded. + */ + downloadModel( + model: Model, + gpuSettings?: GpuSetting, + network?: { ignoreSSL?: boolean; proxy?: string } + ): Promise + + /** + * Cancels the download of a specific model. + * @param {string} modelId - The ID of the model to cancel the download for. + * @returns {Promise} A promise that resolves when the download has been cancelled. + */ + cancelModelDownload(modelId: string): Promise + + /** + * Deletes a model. + * @param modelId - The ID of the model to delete. + * @returns A Promise that resolves when the model has been deleted. + */ + deleteModel(modelId: string): Promise + + /** + * Saves a model. + * @param model - The model to save. + * @returns A Promise that resolves when the model has been saved. + */ + saveModel(model: Model): Promise + + /** + * Gets a list of downloaded models. + * @returns A Promise that resolves with an array of downloaded models. + */ + getDownloadedModels(): Promise + + /** + * Gets a list of configured models. + * @returns A Promise that resolves with an array of configured models. + */ + getConfiguredModels(): Promise +} diff --git a/core/src/types/thread/index.ts b/core/src/types/thread/index.ts index 2349b1bbf6..32155e1cd3 100644 --- a/core/src/types/thread/index.ts +++ b/core/src/types/thread/index.ts @@ -1 +1,3 @@ export * from './threadEntity' +export * from './threadInterface' +export * from './threadEvent' diff --git a/core/src/types/thread/threadEntity.ts b/core/src/types/thread/threadEntity.ts index 8ec1e9dbf5..dd88b10eca 100644 --- a/core/src/types/thread/threadEntity.ts +++ b/core/src/types/thread/threadEntity.ts @@ -1,12 +1,46 @@ -import { Thread as OpenAiThread } from 'openai/resources/beta/threads/threads' -import { Assistant } from '../assistant' +import { AssistantTool } from '../assistant' +import { ModelInfo } from '../model' -export interface ThreadToolResources extends OpenAiThread.ToolResources {} - -export interface Thread extends OpenAiThread { +/** + * The `Thread` type defines the shape of a thread object. + * @stored + */ +export type Thread = { + /** Unique identifier for the thread, generated by default using the ULID method. **/ + id: string + /** Object name **/ + object: string + /** The title of this thread. **/ title: string + /** Assistants in this thread. **/ + assistants: ThreadAssistantInfo[] + /** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/ + created: number + /** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/ + updated: number + /** The additional metadata of this thread. **/ + metadata?: Record +} - assistants: Assistant[] +/** + * Represents the information about an assistant in a thread. 
+ * @stored + */ +export type ThreadAssistantInfo = { + assistant_id: string + assistant_name: string + model: ModelInfo + instructions?: string + tools?: AssistantTool[] +} - tool_resources: ThreadToolResources | null +/** + * Represents the state of a thread. + * @stored + */ +export type ThreadState = { + hasMore: boolean + waitingForResponse: boolean + error?: Error + lastMessage?: string } diff --git a/core/src/types/thread/threadEvent.ts b/core/src/types/thread/threadEvent.ts new file mode 100644 index 0000000000..4b19b09c10 --- /dev/null +++ b/core/src/types/thread/threadEvent.ts @@ -0,0 +1,4 @@ +export enum ThreadEvent { + /** The `OnThreadStarted` event is emitted when a thread is started. */ + OnThreadStarted = 'OnThreadStarted', +} diff --git a/core/src/types/thread/threadInterface.ts b/core/src/types/thread/threadInterface.ts new file mode 100644 index 0000000000..792c8c8a5f --- /dev/null +++ b/core/src/types/thread/threadInterface.ts @@ -0,0 +1,31 @@ +import { Thread } from './threadEntity' + +/** + * Conversational extension. Persists and retrieves conversations. + * @abstract + * @extends BaseExtension + */ +export interface ThreadInterface { + /** + * Returns a list of thread. + * @abstract + * @returns {Promise} A promise that resolves to an array of threads. + */ + getThreads(): Promise + + /** + * Saves a thread. + * @abstract + * @param {Thread} thread - The thread to save. + * @returns {Promise} A promise that resolves when the thread is saved. + */ + saveThread(thread: Thread): Promise + + /** + * Deletes a thread. + * @abstract + * @param {string} threadId - The ID of the thread to delete. + * @returns {Promise} A promise that resolves when the thread is deleted. + */ + deleteThread(threadId: string): Promise +} diff --git a/core/tests/node/path.test.ts b/core/tests/node/path.test.ts index fd278ed462..5390df1193 100644 --- a/core/tests/node/path.test.ts +++ b/core/tests/node/path.test.ts @@ -1,11 +1,12 @@ -describe('Test file normalize', () => { - test('returns no file protocol prefix on Unix', async () => { - // expect(normalizeFilePath('file://test.txt')).toBe('test.txt') - // expect(normalizeFilePath('file:/test.txt')).toBe('test.txt') - expect(1 + 1).toBe(2) - }) - // test("returns no file protocol prefix on Windows", async () => { - // expect(normalizeFilePath("file:\\\\test.txt")).toBe("test.txt"); - // expect(normalizeFilePath("file:\\test.txt")).toBe("test.txt"); - // }); -}) +import { normalizeFilePath } from "../../src/node/helper/path"; + +describe("Test file normalize", () => { + test("returns no file protocol prefix on Unix", async () => { + expect(normalizeFilePath("file://test.txt")).toBe("test.txt"); + expect(normalizeFilePath("file:/test.txt")).toBe("test.txt"); + }); + test("returns no file protocol prefix on Windows", async () => { + expect(normalizeFilePath("file:\\\\test.txt")).toBe("test.txt"); + expect(normalizeFilePath("file:\\test.txt")).toBe("test.txt"); + }); +}); diff --git a/core/tsconfig.json b/core/tsconfig.json index 14e15e4ec5..daeb7eeffe 100644 --- a/core/tsconfig.json +++ b/core/tsconfig.json @@ -1,9 +1,9 @@ { "compilerOptions": { "moduleResolution": "node", - "target": "es2022", + "target": "es5", "module": "ES2020", - "lib": ["es2018", "dom"], + "lib": ["es2015", "es2016", "es2017", "dom"], "strict": true, "sourceMap": true, "declaration": true, diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml new file mode 100644 index 0000000000..2e09d641b3 --- /dev/null +++ b/docker-compose-dev.yml @@ -0,0 +1,171 @@ +# Docker 
Compose file for setting up Minio, createbuckets, app_cpu, and app_gpu services + +version: '3.7' + +services: + # Minio service for object storage + minio: + image: minio/minio + volumes: + - minio_data:/data + ports: + - '9000:9000' + - '9001:9001' + environment: + # Set the root user and password for Minio + MINIO_ROOT_USER: minioadmin # This acts as AWS_ACCESS_KEY + MINIO_ROOT_PASSWORD: minioadmin # This acts as AWS_SECRET_ACCESS_KEY + command: server --console-address ":9001" /data + restart: always + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + interval: 30s + timeout: 20s + retries: 3 + networks: + vpcbr: + ipv4_address: 10.5.0.2 + + # createbuckets service to create a bucket and set its policy + createbuckets: + image: minio/mc + depends_on: + - minio + entrypoint: > + /bin/sh -c " + /usr/bin/mc alias set myminio http://minio:9000 minioadmin minioadmin; + /usr/bin/mc mb myminio/mybucket; + /usr/bin/mc policy set public myminio/mybucket; + exit 0; + " + networks: + vpcbr: + + # app_cpu service for running the CPU version of the application + app_cpu_s3fs: + image: jan:latest + volumes: + - app_data_cpu_s3fs:/app/server/build/jan + build: + context: . + dockerfile: Dockerfile + environment: + # Set the AWS access key, secret access key, bucket name, endpoint, and region for app_cpu + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin + S3_BUCKET_NAME: mybucket + AWS_ENDPOINT: http://10.5.0.2:9000 + AWS_REGION: us-east-1 + API_BASE_URL: http://localhost:1337 + restart: always + profiles: + - cpu-s3fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.3 + + # app_gpu service for running the GPU version of the application + app_gpu_s3fs: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + image: jan-gpu:latest + volumes: + - app_data_gpu_s3fs:/app/server/build/jan + build: + context: . + dockerfile: Dockerfile.gpu + restart: always + environment: + # Set the AWS access key, secret access key, bucket name, endpoint, and region for app_gpu + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin + S3_BUCKET_NAME: mybucket + AWS_ENDPOINT: http://10.5.0.2:9000 + AWS_REGION: us-east-1 + API_BASE_URL: http://localhost:1337 + profiles: + - gpu-s3fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.4 + + app_cpu_fs: + image: jan:latest + volumes: + - app_data_cpu_fs:/app/server/build/jan + build: + context: . + dockerfile: Dockerfile + environment: + API_BASE_URL: http://localhost:1337 + restart: always + profiles: + - cpu-fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.5 + + # app_gpu service for running the GPU version of the application + app_gpu_fs: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + image: jan-gpu:latest + volumes: + - app_data_gpu_fs:/app/server/build/jan + build: + context: . 
+ dockerfile: Dockerfile.gpu + restart: always + environment: + API_BASE_URL: http://localhost:1337 + profiles: + - gpu-fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.6 + +volumes: + minio_data: + app_data_cpu_s3fs: + app_data_gpu_s3fs: + app_data_cpu_fs: + app_data_gpu_fs: + +networks: + vpcbr: + driver: bridge + ipam: + config: + - subnet: 10.5.0.0/16 + gateway: 10.5.0.1 +# Usage: +# - Run 'docker compose -f docker-compose-dev.yml --profile cpu-s3fs up -d' to start the app_cpu service +# - Run 'docker compose -f docker-compose-dev.yml --profile gpu-s3fs up -d' to start the app_gpu service +# - Run 'docker compose -f docker-compose-dev.yml --profile cpu-fs up -d' to start the app_cpu service +# - Run 'docker compose -f docker-compose-dev.yml --profile gpu-fs up -d' to start the app_gpu service diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000..1e5660c12b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,159 @@ +# Docker Compose file for setting up Minio, createbuckets, app_cpu, and app_gpu services + +version: '3.7' + +services: + # Minio service for object storage + minio: + image: minio/minio + volumes: + - minio_data:/data + ports: + - '9000:9000' + - '9001:9001' + environment: + # Set the root user and password for Minio + MINIO_ROOT_USER: minioadmin # This acts as AWS_ACCESS_KEY + MINIO_ROOT_PASSWORD: minioadmin # This acts as AWS_SECRET_ACCESS_KEY + command: server --console-address ":9001" /data + restart: always + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + interval: 30s + timeout: 20s + retries: 3 + networks: + vpcbr: + ipv4_address: 10.5.0.2 + + # createbuckets service to create a bucket and set its policy + createbuckets: + image: minio/mc + depends_on: + - minio + entrypoint: > + /bin/sh -c " + /usr/bin/mc alias set myminio http://minio:9000 minioadmin minioadmin; + /usr/bin/mc mb myminio/mybucket; + /usr/bin/mc policy set public myminio/mybucket; + exit 0; + " + networks: + vpcbr: + + # app_cpu service for running the CPU version of the application + app_cpu_s3fs: + volumes: + - app_data_cpu_s3fs:/app/server/build/jan + image: ghcr.io/janhq/jan-server:dev-cpu-latest + environment: + # Set the AWS access key, secret access key, bucket name, endpoint, and region for app_cpu + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin + S3_BUCKET_NAME: mybucket + AWS_ENDPOINT: http://10.5.0.2:9000 + AWS_REGION: us-east-1 + API_BASE_URL: http://localhost:1337 + restart: always + profiles: + - cpu-s3fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.3 + + # app_gpu service for running the GPU version of the application + app_gpu_s3fs: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + image: ghcr.io/janhq/jan-server:dev-cuda-12.2-latest + volumes: + - app_data_gpu_s3fs:/app/server/build/jan + restart: always + environment: + # Set the AWS access key, secret access key, bucket name, endpoint, and region for app_gpu + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin + S3_BUCKET_NAME: mybucket + AWS_ENDPOINT: http://10.5.0.2:9000 + AWS_REGION: us-east-1 + API_BASE_URL: http://localhost:1337 + profiles: + - gpu-s3fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.4 + + app_cpu_fs: + image: ghcr.io/janhq/jan-server:dev-cpu-latest + volumes: + - 
app_data_cpu_fs:/app/server/build/jan + environment: + API_BASE_URL: http://localhost:1337 + restart: always + profiles: + - cpu-fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.5 + + # app_gpu service for running the GPU version of the application + app_gpu_fs: + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + image: ghcr.io/janhq/jan-server:dev-cuda-12.2-latest + volumes: + - app_data_gpu_fs:/app/server/build/jan + restart: always + environment: + API_BASE_URL: http://localhost:1337 + profiles: + - gpu-fs + ports: + - '3000:3000' + - '1337:1337' + - '3928:3928' + networks: + vpcbr: + ipv4_address: 10.5.0.6 + +volumes: + minio_data: + app_data_cpu_s3fs: + app_data_gpu_s3fs: + app_data_cpu_fs: + app_data_gpu_fs: + +networks: + vpcbr: + driver: bridge + ipam: + config: + - subnet: 10.5.0.0/16 + gateway: 10.5.0.1 +# Usage: +# - Run 'docker compose --profile cpu-s3fs up -d' to start the app_cpu service +# - Run 'docker compose --profile gpu-s3fs up -d' to start the app_gpu service +# - Run 'docker compose --profile cpu-fs up -d' to start the app_cpu service +# - Run 'docker compose --profile gpu-fs up -d' to start the app_gpu service diff --git a/electron/.eslintrc.js b/electron/.eslintrc.js index a8b7c00cb8..d252ec42bd 100644 --- a/electron/.eslintrc.js +++ b/electron/.eslintrc.js @@ -34,5 +34,5 @@ module.exports = { { name: 'Link', linkAttribute: 'to' }, ], }, - ignorePatterns: ['build', 'renderer', 'node_modules', '@global', 'playwright-report'], + ignorePatterns: ['build', 'renderer', 'node_modules', '@global'], } diff --git a/electron/download.bat b/electron/download.bat deleted file mode 100644 index 04f763cd0e..0000000000 --- a/electron/download.bat +++ /dev/null @@ -1,24 +0,0 @@ -@echo off -setlocal - -:: Read the version from the version.txt file -set /p CORTEX_VERSION=<./resources/version.txt - -:: Set the download URL -set DOWNLOAD_URL=https://github.com/janhq/cortex/releases/download/v%CORTEX_VERSION%/cortex-%CORTEX_VERSION%-amd64-windows.tar.gz - -:: Set the output directory and file name -set OUTPUT_DIR=./resources/win -set OUTPUT_FILE=%OUTPUT_DIR%/cortex.exe - -echo %OUTPUT_FILE% - -:: Check if the file already exists -if exist %OUTPUT_FILE% ( - echo File %OUTPUT_FILE% already exists. Skipping download. 
-) else ( - echo Downloading from %DOWNLOAD_URL% - .\node_modules\.bin\download %DOWNLOAD_URL% -e -o %OUTPUT_DIR% -) - -endlocal \ No newline at end of file diff --git a/electron/handlers/common.ts b/electron/handlers/common.ts new file mode 100644 index 0000000000..a2a1bd2f79 --- /dev/null +++ b/electron/handlers/common.ts @@ -0,0 +1,20 @@ +import { Handler, RequestHandler } from '@janhq/core/node' +import { ipcMain } from 'electron' +import { windowManager } from '../managers/window' + +export function injectHandler() { + const ipcWrapper: Handler = ( + route: string, + listener: (...args: any[]) => any + ) => + ipcMain.handle(route, async (_event, ...args: any[]) => { + return listener(...args) + }) + + const handler = new RequestHandler( + ipcWrapper, + (channel: string, args: any) => + windowManager.mainWindow?.webContents.send(channel, args) + ) + handler.handle() +} diff --git a/electron/handlers/native.ts b/electron/handlers/native.ts index d11dc40d62..869b9fd58e 100644 --- a/electron/handlers/native.ts +++ b/electron/handlers/native.ts @@ -1,32 +1,30 @@ import { app, ipcMain, dialog, shell, nativeTheme } from 'electron' +import { join } from 'path' import { windowManager } from '../managers/window' import { + ModuleManager, + getJanDataFolderPath, + getJanExtensionsPath, + init, AppEvent, NativeRoute, SelectFileProp, - SelectFileOption, - AppConfiguration, } from '@janhq/core/node' +import { SelectFileOption } from '@janhq/core' import { menu } from '../utils/menu' -import { join } from 'path' -import { - getAppConfigurations, - getJanDataFolderPath, - legacyDataPath, - updateAppConfiguration, -} from './../utils/path' -import { - readdirSync, - writeFileSync, - readFileSync, - existsSync, - mkdirSync, - lstatSync, -} from 'fs' -import { dump, load } from 'js-yaml' + const isMac = process.platform === 'darwin' export function handleAppIPCs() { + /** + * Handles the "openAppDirectory" IPC message by opening the app's user data directory. + * The `shell.openPath` method is used to open the directory in the user's default file explorer. + * @param _event - The IPC event object. + */ + ipcMain.handle(NativeRoute.openAppDirectory, async (_event) => { + shell.openPath(getJanDataFolderPath()) + }) + /** * Handles the "setNativeThemeLight" IPC message by setting the native theme source to "light". * This will change the appearance of the app to the light theme. @@ -43,13 +41,6 @@ export function handleAppIPCs() { windowManager.mainWindow?.minimize() }) - ipcMain.handle(NativeRoute.homePath, () => { - // Handles the 'get jan home path' IPC event. This event is triggered to get the default jan home path. - return join( - process.env[process.platform == 'win32' ? 'USERPROFILE' : 'HOME'] ?? 
'', - 'jan' - ) - }) ipcMain.handle(NativeRoute.setMaximizeApp, async (_event) => { if (windowManager.mainWindow?.isMaximized()) { windowManager.mainWindow.unmaximize() @@ -58,28 +49,6 @@ export function handleAppIPCs() { } }) - ipcMain.handle(NativeRoute.getThemes, async () => { - const folderPath = join(getJanDataFolderPath(), 'themes') - const installedThemes = readdirSync(folderPath) - - const themesOptions = Promise.all( - installedThemes - .filter((x: string) => x !== '.DS_Store') - .map(async (x: string) => { - const y = join(folderPath, x, `theme.json`) - const c = JSON.parse(readFileSync(y, 'utf-8')) - return { name: c?.displayName, value: c.id } - }) - ) - return themesOptions - }) - - ipcMain.handle(NativeRoute.readTheme, async (_event, themeId: string) => { - const folderPath = join(getJanDataFolderPath(), 'themes') - const filePath = join(folderPath, themeId, `theme.json`) - return JSON.parse(readFileSync(filePath, 'utf-8')) - }) - /** * Handles the "setNativeThemeDark" IPC message by setting the native theme source to "dark". * This will change the appearance of the app to the dark theme. @@ -112,8 +81,27 @@ export function handleAppIPCs() { * @param url - The URL to reload. */ ipcMain.handle(NativeRoute.relaunch, async (_event) => { - app.relaunch() - app.exit() + ModuleManager.instance.clearImportedModules() + + if (app.isPackaged) { + app.relaunch() + app.exit() + } else { + for (const modulePath in ModuleManager.instance.requiredModules) { + delete require.cache[ + require.resolve(join(getJanExtensionsPath(), modulePath)) + ] + } + init({ + // Function to check from the main process that user wants to install a extension + confirmInstall: async (_extensions: string[]) => { + return true + }, + // Path to install extension to + extensionsPath: getJanExtensionsPath(), + }) + windowManager.mainWindow?.reload() + } }) ipcMain.handle(NativeRoute.selectDirectory, async () => { @@ -183,7 +171,7 @@ export function handleAppIPCs() { } ) - ipcMain.handle(NativeRoute.showOpenMenu, function (_e, args) { + ipcMain.handle(NativeRoute.showOpenMenu, function (e, args) { if (!isMac && windowManager.mainWindow) { menu.popup({ window: windowManager.mainWindow, @@ -212,273 +200,4 @@ export function handleAppIPCs() { ipcMain.handle(NativeRoute.ackDeepLink, async (_event): Promise => { windowManager.ackDeepLink() }) - - ipcMain.handle(NativeRoute.openAppLog, async (_event): Promise => { - const configuration = getAppConfigurations() - const dataFolder = configuration.dataFolderPath - - try { - const errorMessage = await shell.openPath(join(dataFolder)) - if (errorMessage) { - console.error(`An error occurred: ${errorMessage}`) - } else { - console.log('Path opened successfully') - } - } catch (error) { - console.error(`Failed to open path: ${error}`) - } - }) - - ipcMain.handle(NativeRoute.syncModelFileToCortex, async (_event) => { - // Read models from legacy data folder - const janModelFolderPath = join(legacyDataPath(), 'models') - const allModelFolders = readdirSync(janModelFolderPath) - - // Latest app configs - const configration = getAppConfigurations() - const destinationFolderPath = join(configration.dataFolderPath, 'models') - - if (!existsSync(destinationFolderPath)) mkdirSync(destinationFolderPath) - - console.log( - `Syncing model from ${allModelFolders} to ${destinationFolderPath}` - ) - const reflect = require('@alumna/reflect') - - for (const modelName of allModelFolders) { - const modelFolderPath = join(janModelFolderPath, modelName) - // check if exist and is a directory - if 
(!existsSync(modelFolderPath)) { - console.debug(`Model folder ${modelFolderPath} does not exist`) - continue - } - - // check if it is a directory - if (!lstatSync(modelFolderPath).isDirectory()) { - console.debug(`${modelFolderPath} is not a directory`) - continue - } - - try { - const filesInModelFolder = readdirSync(modelFolderPath) - const destinationPath = join(destinationFolderPath, modelName) - - const modelJsonFullPath = join( - janModelFolderPath, - modelName, - 'model.json' - ) - if (!existsSync(modelJsonFullPath)) { - console.error(`Model json file not found in ${modelName}`) - continue - } - - const model = JSON.parse(readFileSync(modelJsonFullPath, 'utf-8')) - const fileNames: string[] = model.sources.map((x: any) => x.filename) - let files: string[] = [] - - if (filesInModelFolder.length > 1) { - // prepend fileNames with model folder path - files = fileNames.map((x: string) => - join(destinationFolderPath, model.id, x) - ) - } else if ( - model.sources.length && - !/^(http|https):\/\/[^/]+\/.*/.test(model.sources[0].url) - ) { - // Symlink case - files = [model.sources[0].url] - } else continue - - // create folder if not exist - // only for local model files - if (!existsSync(destinationPath) && filesInModelFolder.length > 1) { - mkdirSync(destinationPath, { recursive: true }) - } - - const engine = - model.engine === 'nitro' || model.engine === 'cortex' - ? 'cortex.llamacpp' - : (model.engine ?? 'cortex.llamacpp') - - const updatedModelFormat = { - id: model.id, - name: model.id, - model: model.id, - version: Number(model.version), - files: files ?? [], - created: Date.now(), - object: 'model', - owned_by: model.metadata?.author ?? '', - - // settings - ngl: model.settings?.ngl, - ctx_len: model.settings?.ctx_len ?? 2048, - engine: engine, - prompt_template: model.settings?.prompt_template ?? '', - - // parameters - stop: model.parameters?.stop ?? [], - top_p: model.parameters?.top_p, - temperature: model.parameters?.temperature, - frequency_penalty: model.parameters?.frequency_penalty, - presence_penalty: model.parameters?.presence_penalty, - max_tokens: model.parameters?.max_tokens ?? 2048, - stream: model.parameters?.stream ?? 
true, - } - if (filesInModelFolder.length > 1) { - const { err } = await reflect({ - src: modelFolderPath, - dest: destinationPath, - recursive: true, - delete: false, - overwrite: true, - errorOnExist: false, - }) - - if (err) { - console.error(err) - continue - } - } - // create the model.yml file - const modelYamlData = dump(updatedModelFormat) - const modelYamlPath = join(destinationFolderPath, `${modelName}.yaml`) - - writeFileSync(modelYamlPath, modelYamlData) - } catch (err) { - console.error(err) - } - } - }) - - ipcMain.handle( - NativeRoute.getAllMessagesAndThreads, - async (_event): Promise => { - const janThreadFolderPath = join(legacyDataPath(), 'threads') - // check if exist - if (!existsSync(janThreadFolderPath)) { - return { - threads: [], - messages: [], - } - } - // get children of thread folder - const allThreadFolders = readdirSync(janThreadFolderPath) - const threads: any[] = [] - const messages: any[] = [] - for (const threadFolder of allThreadFolders) { - try { - const threadJsonFullPath = join( - janThreadFolderPath, - threadFolder, - 'thread.json' - ) - const thread = JSON.parse(readFileSync(threadJsonFullPath, 'utf-8')) - threads.push(thread) - - const messageFullPath = join( - janThreadFolderPath, - threadFolder, - 'messages.jsonl' - ) - - if (!existsSync(messageFullPath)) continue - const lines = readFileSync(messageFullPath, 'utf-8') - .toString() - .split('\n') - .filter((line: any) => line !== '') - for (const line of lines) { - messages.push(JSON.parse(line)) - } - } catch (err) { - console.error(err) - } - } - return { - threads, - messages, - } - } - ) - - ipcMain.handle( - NativeRoute.getAllLocalModels, - async (_event): Promise => { - const janModelsFolderPath = join(legacyDataPath(), 'models') - - if (!existsSync(janModelsFolderPath)) { - console.debug('No local models found') - return false - } - - // get children of thread folder - const allModelsFolders = readdirSync(janModelsFolderPath) - let hasLocalModels = false - for (const modelFolder of allModelsFolders) { - try { - const modelsFullPath = join(janModelsFolderPath, modelFolder) - const dir = readdirSync(modelsFullPath) - const ggufFile = dir.some((file) => file.endsWith('.gguf')) - if (ggufFile) { - hasLocalModels = true - break - } - } catch (err) { - console.error(err) - } - } - return hasLocalModels - } - ) - ipcMain.handle(NativeRoute.appDataFolder, () => { - return getJanDataFolderPath() - }) - - ipcMain.handle(NativeRoute.changeDataFolder, async (_event, path) => { - const appConfiguration: AppConfiguration = getAppConfigurations() - const currentJanDataFolder = appConfiguration.dataFolderPath - - appConfiguration.dataFolderPath = path - - const reflect = require('@alumna/reflect') - const { err } = await reflect({ - src: currentJanDataFolder, - dest: path, - recursive: true, - delete: false, - overwrite: true, - errorOnExist: false, - }) - if (err) { - console.error(err) - throw err - } - - // Migrate models - const janModelsPath = join(path, 'models') - if (existsSync(janModelsPath)) { - const modelYamls = readdirSync(janModelsPath).filter( - (x) => x.endsWith('.yaml') || x.endsWith('.yml') - ) - for (const yaml of modelYamls) { - const modelPath = join(janModelsPath, yaml) - const model = load(readFileSync(modelPath, 'utf-8')) as any - if ( - 'files' in model && - Array.isArray(model.files) && - model.files.length > 0 - ) { - model.files[0] = model.files[0].replace(currentJanDataFolder, path) - } - writeFileSync(modelPath, dump(model)) - } - } - await 
updateAppConfiguration(appConfiguration) - }) - - ipcMain.handle(NativeRoute.isDirectoryEmpty, async (_event, path) => { - const dirChildren = readdirSync(path) - return dirChildren.filter((x) => x !== '.DS_Store').length === 0 - }) } diff --git a/electron/main.ts b/electron/main.ts index c346f8e713..6ce7f476a5 100644 --- a/electron/main.ts +++ b/electron/main.ts @@ -1,15 +1,16 @@ import { app, BrowserWindow } from 'electron' import { join, resolve } from 'path' - /** * Managers **/ import { windowManager } from './managers/window' +import { getAppConfigurations, log } from '@janhq/core/node' /** * IPC Handlers **/ +import { injectHandler } from './handlers/common' import { handleAppUpdates } from './handlers/update' import { handleAppIPCs } from './handlers/native' @@ -17,22 +18,24 @@ import { handleAppIPCs } from './handlers/native' * Utils **/ import { setupMenu } from './utils/menu' -import { createUserSpace, getJanDataFolderPath } from './utils/path' +import { createUserSpace } from './utils/path' import { migrate } from './utils/migration' import { cleanUpAndQuit } from './utils/clean' +import { setupExtensions } from './utils/extension' import { setupCore } from './utils/setup' import { setupReactDevTool } from './utils/dev' -import log from 'electron-log' - -import { start } from 'cortexso' -import { cortexCppPort, cortexJsPort, cortexHost, cleanCortexProcesses } from './utils/cortex' +import { trayManager } from './managers/tray' +import { logSystemInfo } from './utils/system' +import { registerGlobalShortcuts } from './utils/shortcut' const preloadPath = join(__dirname, 'preload.js') const rendererPath = join(__dirname, '..', 'renderer') +const quickAskPath = join(rendererPath, 'search.html') const mainPath = join(rendererPath, 'index.html') const mainUrl = 'http://localhost:3000' +const quickAskUrl = `${mainUrl}/search` const gotTheLock = app.requestSingleInstanceLock() @@ -51,19 +54,8 @@ const createMainWindow = () => { windowManager.createMainWindow(preloadPath, startUrl) } -log.initialize() -log.info('Starting jan from main thread..') - -// replace all console.log to log -Object.assign(console, log.functions) - app .whenReady() - .then(() => { - const dataFolderPath = join(getJanDataFolderPath(), 'jan.log') - log.transports.file.resolvePathFn = () => dataFolderPath - }) - .then(() => setupCore()) .then(() => { if (!gotTheLock) { app.quit() @@ -85,23 +77,24 @@ app ) } }) - - .then(() => cleanCortexProcesses()) - .then(() => { - start('jan', cortexHost, cortexJsPort, cortexCppPort, getJanDataFolderPath()) - }) + .then(setupCore) .then(createUserSpace) .then(migrate) + .then(setupExtensions) .then(setupMenu) .then(handleIPCs) .then(handleAppUpdates) + .then(() => process.env.CI !== 'e2e' && createQuickAskWindow()) .then(createMainWindow) + .then(registerGlobalShortcuts) .then(() => { if (!app.isPackaged) { setupReactDevTool() windowManager.mainWindow?.webContents.openDevTools() } }) + .then(() => process.env.CI !== 'e2e' && trayManager.createSystemTray()) + .then(logSystemInfo) .then(() => { app.on('activate', () => { if (!BrowserWindow.getAllWindows().length) { @@ -116,27 +109,45 @@ app.on('open-url', (_event, url) => { windowManager.sendMainAppDeepLink(url) }) -app.once('quit', async () => { +app.on('before-quit', function (_event) { + trayManager.destroyCurrentTray() +}) + +app.once('quit', () => { cleanUpAndQuit() }) -app.once('window-all-closed', async () => { +app.once('window-all-closed', () => { + // Feature Toggle for Quick Ask + if ( + getAppConfigurations().quick_ask 
&& + !windowManager.isQuickAskWindowDestroyed() + ) + return cleanUpAndQuit() }) +function createQuickAskWindow() { + // Feature Toggle for Quick Ask + if (!getAppConfigurations().quick_ask) return + const startUrl = app.isPackaged ? `file://${quickAskPath}` : quickAskUrl + windowManager.createQuickAskWindow(preloadPath, startUrl) +} /** * Handles various IPC messages from the renderer process. */ function handleIPCs() { // Inject core handlers for IPCs + injectHandler() + // Handle native IPCs handleAppIPCs() } -/** - * Suppress Node error messages +/* + ** Suppress Node error messages */ process.on('uncaughtException', function (err) { - log.error(`Error: ${err}`) + log(`Error: ${err}`) }) diff --git a/electron/managers/tray.ts b/electron/managers/tray.ts index 470499238d..b81b1e5565 100644 --- a/electron/managers/tray.ts +++ b/electron/managers/tray.ts @@ -1,14 +1,14 @@ import { join } from 'path' import { Tray, app, Menu } from 'electron' import { windowManager } from '../managers/window' -import { getAppConfigurations } from './../utils/path' +import { getAppConfigurations } from '@janhq/core/node' class TrayManager { currentTray: Tray | undefined createSystemTray = () => { // Feature Toggle for Quick Ask - if (!getAppConfigurations().quickAsk) return + if (!getAppConfigurations().quick_ask) return if (this.currentTray) { return diff --git a/electron/managers/window.ts b/electron/managers/window.ts index d837505aa7..3d5107b280 100644 --- a/electron/managers/window.ts +++ b/electron/managers/window.ts @@ -1,9 +1,8 @@ import { BrowserWindow, app, shell } from 'electron' import { quickAskWindowConfig } from './quickAskWindowConfig' import { mainWindowConfig } from './mainWindowConfig' -import { getAppConfigurations } from './../utils/path' +import { getAppConfigurations, AppEvent } from '@janhq/core/node' import { getBounds, saveBounds } from '../utils/setup' -import { AppEvent } from '@janhq/core/node' /** * Manages the current window instance. @@ -32,7 +31,6 @@ class WindowManager { x: bounds.x, y: bounds.y, webPreferences: { - allowRunningInsecureContent: true, nodeIntegration: true, preload: preloadPath, webSecurity: false, @@ -73,7 +71,7 @@ class WindowManager { windowManager.mainWindow?.on('close', function (evt) { // Feature Toggle for Quick Ask - if (!getAppConfigurations().quickAsk) return + if (!getAppConfigurations().quick_ask) return if (!isAppQuitting) { evt.preventDefault() diff --git a/electron/package.json b/electron/package.json index 6e32a2bdbb..feaee5e16d 100644 --- a/electron/package.json +++ b/electron/package.json @@ -10,15 +10,6 @@ "build": { "appId": "jan.ai.app", "productName": "Jan", - "extraResources": [ - { - "from": "resources/${os}", - "to": "bin", - "filter": [ - "**/*" - ] - } - ], "files": [ "renderer/**/*", "build/**/*.{js,map}", @@ -35,8 +26,7 @@ "docs", "scripts", "icons", - "themes", - "package.json" + "themes" ], "publish": [ { @@ -92,28 +82,23 @@ "dev": "yarn copy:assets && tsc -p . && electron .", "compile": "tsc -p .", "start": "electron .", + "build": "yarn copy:assets && run-script-os", "build:test": "yarn copy:assets && run-script-os", "build:test:darwin": "tsc -p . && electron-builder -p never -m --dir", "build:test:win32": "tsc -p . && electron-builder -p never -w --dir", "build:test:linux": "tsc -p . 
&& electron-builder -p never -l --dir", - "downloadcortex": "run-script-os", - "downloadcortex:linux": "CORTEX_VERSION=$(cat ./resources/version.txt) && echo https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-${CORTEX_VERSION}-amd64-linux.tar.gz && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-${CORTEX_VERSION}-amd64-linux.tar.gz -e -o ./resources/linux && rm -rf ./resources/linux/cortex-${CORTEX_VERSION}-amd64-linux.tar.gz && chmod +x ./resources/linux/cortex", - "downloadcortex:darwin": "CORTEX_VERSION=$(cat ./resources/version.txt) && ARCH=$(node -e \"console.log(process.arch === 'arm64' ? 'arm64' : 'amd64')\") && echo https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-${CORTEX_VERSION}-${ARCH}-mac.tar.gz && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-${CORTEX_VERSION}-${ARCH}-mac.tar.gz -e -o ./resources/mac/ && rm -rf ./resources/mac/cortex-${CORTEX_VERSION}-${ARCH}-mac.tar.gz && chmod +x ./resources/mac/cortex", - "downloadcortex:win32": "download.bat", - "build": "yarn copy:assets && run-script-os", "build:darwin": "tsc -p . && electron-builder -p never -m", "build:win32": "tsc -p . && electron-builder -p never -w", "build:linux": "tsc -p . && electron-builder -p never -l deb -l AppImage", - "build:publish": "yarn copy:assets && yarn downloadcortex && run-script-os", + "build:publish": "yarn copy:assets && run-script-os", "build:publish:darwin": "tsc -p . && electron-builder -p always -m", "build:publish:win32": "tsc -p . && electron-builder -p always -w", "build:publish:linux": "tsc -p . && electron-builder -p always -l deb -l AppImage" }, "dependencies": { - "js-yaml": "4.1.0", - "electron-log": "^5.1.5", "@alumna/reflect": "^1.1.3", "@janhq/core": "link:./core", + "@janhq/server": "link:./server", "@npmcli/arborist": "^7.1.0", "electron-store": "^8.1.0", "electron-updater": "^6.1.7", @@ -122,12 +107,10 @@ "pacote": "^17.0.4", "request": "^2.88.2", "request-progress": "^3.0.0", - "@kirillvakalov/nut-tree__nut-js": "4.2.1-2", - "cortexso": "v0.5.0-43" + "ulidx": "^2.3.0", + "@kirillvakalov/nut-tree__nut-js": "4.2.1-2" }, "devDependencies": { - "@types/js-yaml": "4.0.9", - "download-cli": "^1.1.1", "@electron/notarize": "^2.1.0", "@playwright/test": "^1.38.1", "@types/npmcli__arborist": "^5.6.4", diff --git a/electron/preload.ts b/electron/preload.ts index 7378add11f..05f48d37ad 100644 --- a/electron/preload.ts +++ b/electron/preload.ts @@ -3,8 +3,9 @@ * @module preload */ -import { APIEvents, APIRoutes } from '@janhq/core/node' +import { APIEvents, APIRoutes, AppConfiguration, getAppConfigurations, updateAppConfiguration } from '@janhq/core/node' import { contextBridge, ipcRenderer } from 'electron' +import { readdirSync } from 'fs' const interfaces: { [key: string]: (...args: any[]) => any } = {} @@ -12,8 +13,9 @@ const interfaces: { [key: string]: (...args: any[]) => any } = {} APIRoutes.forEach((method) => { // For each method, create a function on the interfaces object // This function invokes the method on the ipcRenderer with any provided arguments - + interfaces[method] = (...args: any[]) => ipcRenderer.invoke(method, ...args) + }) // Loop over each method in APIEvents @@ -25,6 +27,30 @@ APIEvents.forEach((method) => { }) +interfaces['changeDataFolder'] = async path => { + const appConfiguration: AppConfiguration = await ipcRenderer.invoke('getAppConfigurations') + const currentJanDataFolder = 
appConfiguration.data_folder + appConfiguration.data_folder = path + const reflect = require('@alumna/reflect') + const { err } = await reflect({ + src: currentJanDataFolder, + dest: path, + recursive: true, + delete: false, + overwrite: true, + errorOnExist: false, + }) + if (err) { + console.error(err) + throw err + } + await ipcRenderer.invoke('updateAppConfiguration', appConfiguration) +} + +interfaces['isDirectoryEmpty'] = async path => { + const dirChildren = await readdirSync(path) + return dirChildren.filter((x) => x !== '.DS_Store').length === 0 +} // Expose the 'interfaces' object in the main world under the name 'electronAPI' // This allows the renderer process to access these methods directly diff --git a/electron/resources/version.txt b/electron/resources/version.txt deleted file mode 100644 index ec4dae953e..0000000000 --- a/electron/resources/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.5.0-33 diff --git a/electron/tests/e2e/hub.e2e.spec.ts b/electron/tests/e2e/hub.e2e.spec.ts index 8a4a5680d9..23d4d0b6d7 100644 --- a/electron/tests/e2e/hub.e2e.spec.ts +++ b/electron/tests/e2e/hub.e2e.spec.ts @@ -16,9 +16,9 @@ test.beforeAll(async () => { test('explores hub', async ({ hubPage }) => { await hubPage.navigateByMenu() await hubPage.verifyContainerVisible() + const useModelBtn= page.getByTestId(/^use-model-btn-.*/).first() - const searchBar = page.getByTestId('hub-search-bar').first() - await expect(searchBar).toBeVisible({ + await expect(useModelBtn).toBeVisible({ timeout: TIMEOUT, }) }) diff --git a/electron/tests/e2e/navigation.e2e.spec.ts b/electron/tests/e2e/navigation.e2e.spec.ts index 7c416aac78..b599a951c1 100644 --- a/electron/tests/e2e/navigation.e2e.spec.ts +++ b/electron/tests/e2e/navigation.e2e.spec.ts @@ -7,13 +7,12 @@ test('renders left navigation panel', async () => { .first() .isEnabled({ timeout: TIMEOUT }) expect([settingsBtn].filter((e) => !e).length).toBe(0) - - // System Monitor should be there - await page.getByText('System Monitor').first().click({ + // Chat section should be there + await page.getByTestId('Local API Server').first().click({ timeout: TIMEOUT, }) - const systemMonitors = page.getByText('Running Models').first() - await expect(systemMonitors).toBeVisible({ + const localServer = page.getByTestId('local-server-testid').first() + await expect(localServer).toBeVisible({ timeout: TIMEOUT, }) }) diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts index 60899fbfb2..c13e911191 100644 --- a/electron/tests/e2e/thread.e2e.spec.ts +++ b/electron/tests/e2e/thread.e2e.spec.ts @@ -7,40 +7,29 @@ test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage // Select the first GPT model await page - .locator('[data-testid*="GPT"]') + .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]') .first().click() - // TBU - // await page - // .getByTestId('btn-setup') - // .click() + // Attempt to create thread and chat in Thread page + await page + .getByTestId('btn-create-thread') + .click() + + await page + .getByTestId('txt-input-chat') + .fill('dummy value') + + await page + .getByTestId('btn-send-chat') + .click() - // const APIKeyError = page.getByTestId('setup-api-key-modal') - // await expect(APIKeyError).toBeVisible({ - // timeout: TIMEOUT, - // }) + await page.waitForFunction(() => { + const loaders = document.querySelectorAll('[data-testid$="loader"]'); + return !loaders.length; + }, { timeout: TIMEOUT }); - // Deprecated since Jan is no longer allow chat with remote model without API Key, 
but keep it here to wait for a new feature - // // Attempt to create thread and chat in Thread page - // await page - // .getByTestId('btn-create-thread') - // .click() - // - // await page - // .getByTestId('txt-input-chat') - // .fill('dummy value') - // - // await page - // .getByTestId('btn-send-chat') - // .click() - // - // await page.waitForFunction(() => { - // const loaders = document.querySelectorAll('[data-testid$="loader"]'); - // return !loaders.length; - // }, { timeout: TIMEOUT }); - // - // const APIKeyError = page.getByTestId('invalid-API-key-error') - // await expect(APIKeyError).toBeVisible({ - // timeout: TIMEOUT, - // }) + const APIKeyError = page.getByTestId('invalid-API-key-error') + await expect(APIKeyError).toBeVisible({ + timeout: TIMEOUT, + }) }) diff --git a/electron/tsconfig.json b/electron/tsconfig.json index 9a3f5823c6..11c9d85770 100644 --- a/electron/tsconfig.json +++ b/electron/tsconfig.json @@ -1,6 +1,5 @@ { "compilerOptions": { - "resolveJsonModule": true, "target": "es5", "module": "commonjs", "noImplicitAny": true, diff --git a/electron/utils/clean.ts b/electron/utils/clean.ts index 9ca094876c..12a68d39e4 100644 --- a/electron/utils/clean.ts +++ b/electron/utils/clean.ts @@ -1,14 +1,14 @@ +import { ModuleManager } from '@janhq/core/node' import { windowManager } from './../managers/window' +import { dispose } from './disposable' import { app } from 'electron' -import { cleanCortexProcesses, stopCortexApiServer } from './cortex' - -/** - * Clean up windows then quit - */ -export async function cleanUpAndQuit() { - windowManager.cleanUp() - await stopCortexApiServer() - await cleanCortexProcesses() - app.quit() -} \ No newline at end of file +export function cleanUpAndQuit() { + if (!ModuleManager.instance.cleaningResource) { + ModuleManager.instance.cleaningResource = true + windowManager.cleanUp() + dispose(ModuleManager.instance.requiredModules) + ModuleManager.instance.clearImportedModules() + app.quit() + } +} diff --git a/electron/utils/cortex.ts b/electron/utils/cortex.ts deleted file mode 100644 index afd2534c78..0000000000 --- a/electron/utils/cortex.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { killProcessesOnPort } from './process' - -// Cortex server configurations -export const cortexJsPort = 1338 -export const cortexCppPort = 3940 -export const cortexHost = '127.0.0.1' - -/** - * Kills all possible running cortex processes - */ -export async function cleanCortexProcesses() { - await killProcessesOnPort(cortexCppPort) - await killProcessesOnPort(cortexJsPort) -} - -/** - * Stops the cortex API server - */ -export async function stopCortexApiServer() { - // this function is not meant to be success. It will throw an error. 
- try { - await fetch(`http://${cortexHost}:${cortexJsPort}/v1/system`, { - method: 'DELETE', - }) - } catch (error) { - // Do nothing - // Accept failure here - } -} diff --git a/electron/utils/extension.ts b/electron/utils/extension.ts new file mode 100644 index 0000000000..e055411a68 --- /dev/null +++ b/electron/utils/extension.ts @@ -0,0 +1,12 @@ +import { getJanExtensionsPath, init } from '@janhq/core/node' + +export const setupExtensions = async () => { + init({ + // Function to check from the main process that user wants to install a extension + confirmInstall: async (_extensions: string[]) => { + return true + }, + // Path to install extension to + extensionsPath: getJanExtensionsPath(), + }) +} diff --git a/electron/utils/menu.ts b/electron/utils/menu.ts index 1abb0c6ab8..3f838e5caa 100644 --- a/electron/utils/menu.ts +++ b/electron/utils/menu.ts @@ -1,6 +1,7 @@ // @ts-nocheck import { app, Menu, shell, dialog } from 'electron' import { autoUpdater } from 'electron-updater' +import { log } from '@janhq/core/node' const isMac = process.platform === 'darwin' const template: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [ @@ -12,7 +13,7 @@ const template: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [ click: () => dialog.showMessageBox({ title: `Jan`, - message: `Jan Version v${app.getVersion()}\nCortex Version ${global.core.cortexVersion()}\n\nCopyright © 2024 Jan`, + message: `Jan Version v${app.getVersion()}\n\nCopyright © 2024 Jan`, }), }, { @@ -33,9 +34,7 @@ const template: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [ } }) .catch((error) => { - console.error( - 'Error checking for updates:' + JSON.stringify(error) - ) + log('Error checking for updates:' + JSON.stringify(error)) }), }, { type: 'separator' }, diff --git a/electron/utils/migration.ts b/electron/utils/migration.ts index 9e0af53ed2..defe0cebb8 100644 --- a/electron/utils/migration.ts +++ b/electron/utils/migration.ts @@ -11,7 +11,11 @@ import { lstatSync, } from 'fs' import Store from 'electron-store' -import { getJanDataFolderPath, appResourcePath } from './../utils/path' +import { + getJanExtensionsPath, + getJanDataFolderPath, + appResourcePath, +} from '@janhq/core/node' /** * Migrates the extensions & themes. 
@@ -24,6 +28,8 @@ export async function migrate() { if (store.get('migrated_version') !== app.getVersion()) { console.debug('start migration:', store.get('migrated_version')) + // if (existsSync(getJanExtensionsPath())) + // rmdirSync(getJanExtensionsPath(), { recursive: true }) await migrateThemes() store.set('migrated_version', app.getVersion()) diff --git a/electron/utils/path.ts b/electron/utils/path.ts index 6ca495c978..4438156bcb 100644 --- a/electron/utils/path.ts +++ b/electron/utils/path.ts @@ -1,23 +1,6 @@ import { mkdir } from 'fs-extra' -import { existsSync, writeFileSync, readFileSync } from 'fs' -import { join } from 'path' -import { AppConfiguration } from '@janhq/core/node' -import os from 'os' -import { dump, load } from 'js-yaml' -import { app } from 'electron' - -const configurationFileName = '.janrc' - -const defaultJanDataFolder = join(os.homedir(), 'jan') - -const defaultAppConfig: AppConfiguration = { - dataFolderPath: defaultJanDataFolder, - quickAsk: false, - cortexCppHost: '127.0.0.1', - cortexCppPort: 3940, - apiServerHost: '127.0.0.1', - apiServerPort: 1338, -} +import { existsSync } from 'fs' +import { getJanDataFolderPath } from '@janhq/core/node' export async function createUserSpace(): Promise { const janDataFolderPath = getJanDataFolderPath() @@ -31,110 +14,3 @@ export async function createUserSpace(): Promise { } } } - -export async function appResourcePath(): Promise { - let electron: any = undefined - - try { - const moduleName = 'electron' - electron = await import(moduleName) - } catch (err) { - console.error('Electron is not available') - } - - // electron - if (electron && electron.protocol) { - let appPath = join(electron.app.getAppPath(), '..', 'app.asar.unpacked') - - if (!electron.app.isPackaged) { - // for development mode - appPath = join(electron.app.getAppPath()) - } - return appPath - } - // server - return join(global.core.appPath(), '../../..') -} - -/** - * Getting App Configurations. - * - * @returns {AppConfiguration} The app configurations. - */ -export const getAppConfigurations = (): AppConfiguration => { - // Retrieve Application Support folder path - // Fallback to user home directory if not found - const configurationFile = getConfigurationFilePath() - console.debug('getAppConfiguration file path', configurationFile) - - if (!existsSync(configurationFile)) { - // create default app config if we don't have one - console.debug( - `App config not found, creating default config at ${configurationFile}` - ) - writeFileSync(configurationFile, dump(defaultAppConfig)) - return defaultAppConfig - } - - try { - const configYaml = readFileSync(configurationFile, 'utf-8') - const appConfigurations = load(configYaml) as AppConfiguration - console.debug('app config', appConfigurations) - return { - ...appConfigurations, - quickAsk: false, - } - } catch (err) { - console.error( - `Failed to read app config, return default config instead! 
Err: ${err}` - ) - return defaultAppConfig - } -} - -// Get configuration file path of the application -const getConfigurationFilePath = () => { - const homeDir = os.homedir() - const configPath = join(homeDir, configurationFileName) - return configPath -} - -export const updateAppConfiguration = ( - configuration: AppConfiguration -): Promise => { - const configurationFile = getConfigurationFilePath() - console.debug( - 'updateAppConfiguration, configurationFile: ', - configurationFile - ) - - writeFileSync(configurationFile, dump(configuration)) - return Promise.resolve() -} - -/** - * Utility function to get data folder path - * - * @returns {string} The data folder path. - */ -export const getJanDataFolderPath = (): string => { - return getAppConfigurations().dataFolderPath -} - -// This is to support pulling legacy configs for migration purpose -export const legacyConfigs = () => { - const legacyConfigFilePath = join(app.getPath('userData'), 'settings.json') - - const legacyConfigs = JSON.parse( - readFileSync(legacyConfigFilePath, 'utf-8') - ) as any - - console.debug('legacyConfigs', legacyConfigs) - - return legacyConfigs -} - -// This is to support pulling legacy data path for migration purpose -export const legacyDataPath = () => { - return legacyConfigs().data_folder -} diff --git a/electron/utils/process.ts b/electron/utils/process.ts deleted file mode 100644 index 90d9f49e46..0000000000 --- a/electron/utils/process.ts +++ /dev/null @@ -1,97 +0,0 @@ -import { execSync } from 'child_process' - -/** - * Kill process on port util - * @param port port number to kill - */ -export function killProcessesOnPort(port: number): void { - try { - console.log(`Killing processes on port ${port}...`) - if (process.platform === 'win32') { - killProcessesOnWindowsPort(port) - } else { - killProcessesOnUnixPort(port) - } - } catch (error) { - console.error( - `Failed to kill process(es) on port ${port}: ${(error as Error).message}` - ) - } - } - -/** - * Kill process on port - Windows - * @param port - * @returns - */ -function killProcessesOnWindowsPort(port: number): void { - let result: string - try { - result = execSync(`netstat -ano | findstr :${port}`).toString() - } catch (error) { - console.log(`No processes found on port ${port}.`) - return - } - - const lines = result.split('\n').filter(Boolean) - - if (lines.length === 0) { - console.log(`No processes found on port ${port}.`) - return - } - - const pids = lines - .map((line) => { - const parts = line.trim().split(/\s+/) - return parts[parts.length - 1] - }) - .filter((pid): pid is string => Boolean(pid) && !isNaN(Number(pid))) - - if (pids.length === 0) { - console.log(`No valid PIDs found for port ${port}.`) - return - } - const uniquePids = Array.from(new Set(pids)) - console.log('uniquePids', uniquePids) - - uniquePids.forEach((pid) => { - try { - execSync(`taskkill /PID ${pid} /F`) - console.log( - `Process with PID ${pid} on port ${port} has been terminated.` - ) - } catch (error) { - console.error( - `Failed to kill process with PID ${pid}: ${(error as Error).message}` - ) - } - }) - } - - /** - * Kill process on port - Unix - * @param port - * @returns - */ - function killProcessesOnUnixPort(port: number): void { - let pids: string[] - - try { - pids = execSync(`lsof -ti tcp:${port}`) - .toString() - .trim() - .split('\n') - .filter(Boolean) - } catch (error) { - if ((error as { status?: number }).status === 1) { - console.log(`No processes found on port ${port}.`) - return - } - throw error // Re-throw if it's not the "no 
processes found" error - } - - pids.forEach((pid) => { - process.kill(parseInt(pid), 'SIGTERM') - console.log(`Process with PID ${pid} on port ${port} has been terminated.`) - }) - } \ No newline at end of file diff --git a/electron/utils/setup.ts b/electron/utils/setup.ts index a6297fd24e..437e21f977 100644 --- a/electron/utils/setup.ts +++ b/electron/utils/setup.ts @@ -1,26 +1,16 @@ import { app } from 'electron' import Store from 'electron-store' -import { existsSync, readFileSync } from 'original-fs' -import { appResourcePath } from './path' -import { join } from 'path' + const DEFAULT_WIDTH = 1000 const DEFAULT_HEIGHT = 800 const storage = new Store() export const setupCore = async () => { - let cortexVersion = 'N/A' - // Read package.json - const pkgPath = join(await appResourcePath(), 'package.json') - if(existsSync(pkgPath)) { - const pkg = JSON.parse(readFileSync(pkgPath, 'utf-8')) - cortexVersion = pkg.dependencies['cortexso'] - } // Setup core api for main process global.core = { // Define appPath function for app to retrieve app path globally appPath: () => app.getPath('userData'), - cortexVersion: () => cortexVersion, } } diff --git a/electron/utils/shortcut.ts b/electron/utils/shortcut.ts index 7f4d4c4f70..aa4607d9a1 100644 --- a/electron/utils/shortcut.ts +++ b/electron/utils/shortcut.ts @@ -1,11 +1,11 @@ -import { getAppConfigurations } from './../utils/path' +import { getAppConfigurations } from '@janhq/core/node' import { registerShortcut } from './selectedText' import { windowManager } from '../managers/window' // TODO: Retrieve from config later const quickAskHotKey = 'CommandOrControl+J' export function registerGlobalShortcuts() { - if (!getAppConfigurations().quickAsk) return + if (!getAppConfigurations().quick_ask) return const ret = registerShortcut(quickAskHotKey, (selectedText: string) => { // Feature Toggle for Quick Ask if (!windowManager.isQuickAskWindowVisible()) { diff --git a/electron/utils/system.ts b/electron/utils/system.ts new file mode 100644 index 0000000000..5799de8616 --- /dev/null +++ b/electron/utils/system.ts @@ -0,0 +1,16 @@ +import { log } from '@janhq/core/node' +import { app } from 'electron' +import os from 'os' + +export const logSystemInfo = (): void => { + log(`[SPECS]::Version: ${app.getVersion()}`) + log(`[SPECS]::CPUs: ${JSON.stringify(os.cpus())}`) + log(`[SPECS]::Machine: ${os.machine()}`) + log(`[SPECS]::Endianness: ${os.endianness()}`) + log(`[SPECS]::Parallelism: ${os.availableParallelism()}`) + log(`[SPECS]::Free Mem: ${os.freemem()}`) + log(`[SPECS]::Total Mem: ${os.totalmem()}`) + log(`[SPECS]::OS Version: ${os.version()}`) + log(`[SPECS]::OS Platform: ${os.platform()}`) + log(`[SPECS]::OS Release: ${os.release()}`) +} diff --git a/extensions/assistant-extension/README.md b/extensions/assistant-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/assistant-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. 
Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! 
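+
+If your extension needs to react to model lifecycle changes, the same `events.on` subscription pattern shown above applies. The sketch below is illustrative only: it assumes the `ModelEvent` enum (added in `core/src/types/model/modelEvent.ts` in this change set) is re-exported from `@janhq/core`, and it simply logs whenever the model list changes.
+
+```typescript
+import { events, ModelEvent } from '@janhq/core'
+
+// Minimal sketch: react whenever the model list changes.
+// ModelEvent.OnModelsUpdate is emitted when the list of models is
+// updated; replace the handler body with your own logic.
+export function onStart(): void {
+  events.on(ModelEvent.OnModelsUpdate, () => {
+    console.log('Model list updated')
+  })
+}
+```
+
+Swap `OnModelsUpdate` for another member of the enum (for example `OnModelReady` or `OnModelFail`) to hook a different point in the model lifecycle.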
diff --git a/extensions/assistant-extension/package.json b/extensions/assistant-extension/package.json new file mode 100644 index 0000000000..aa5dba6922 --- /dev/null +++ b/extensions/assistant-extension/package.json @@ -0,0 +1,50 @@ +{ + "name": "@janhq/assistant-extension", + "productName": "Jan Assistant", + "version": "1.0.1", + "description": "This extension enables assistants, including Jan, a default assistant that can call all downloaded models", + "main": "dist/index.js", + "node": "dist/node/index.js", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "clean:modules": "rimraf node_modules/pdf-parse/test && cd node_modules/pdf-parse/lib/pdf.js && rimraf v1.9.426 v1.10.88 v2.0.550", + "build": "yarn clean:modules && tsc --module commonjs && rollup -c rollup.config.ts", + "build:publish:linux": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", + "build:publish:darwin": "rimraf *.tgz --glob && yarn build && ../../.github/scripts/auto-sign.sh && npm pack && cpx *.tgz ../../pre-install", + "build:publish:win32": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", + "build:publish": "run-script-os" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@types/pdf-parse": "^1.1.4", + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "rollup": "^2.38.5", + "rollup-plugin-define": "^1.0.1", + "rollup-plugin-sourcemaps": "^0.6.3", + "rollup-plugin-typescript2": "^0.36.0", + "typescript": "^5.3.3", + "run-script-os": "^1.1.6" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "@langchain/community": "0.0.13", + "hnswlib-node": "^1.4.2", + "langchain": "^0.0.214", + "pdf-parse": "^1.1.1", + "ts-loader": "^9.5.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "@janhq/core", + "hnswlib-node" + ] +} diff --git a/extensions/assistant-extension/rollup.config.ts b/extensions/assistant-extension/rollup.config.ts new file mode 100644 index 0000000000..263f6cc605 --- /dev/null +++ b/extensions/assistant-extension/rollup.config.ts @@ -0,0 +1,73 @@ +import resolve from '@rollup/plugin-node-resolve' +import commonjs from '@rollup/plugin-commonjs' +import sourceMaps from 'rollup-plugin-sourcemaps' +import typescript from 'rollup-plugin-typescript2' +import json from '@rollup/plugin-json' +import replace from '@rollup/plugin-replace' + +const packageJson = require('./package.json') + +export default [ + { + input: `src/index.ts`, + output: [{ file: packageJson.main, format: 'es', sourcemap: true }], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: [], + watch: { + include: 'src/**', + }, + plugins: [ + replace({ + preventAssignment: true, + NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), + VERSION: JSON.stringify(packageJson.version), + }), + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Compile TypeScript files + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.js', '.ts', '.svelte'], + browser: true, + }), + + // Resolve source maps to the 
original source + sourceMaps(), + ], + }, + { + input: `src/node/index.ts`, + output: [{ dir: 'dist/node', format: 'cjs', sourcemap: false }], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: ['@janhq/core/node', 'path', 'hnswlib-node'], + watch: { + include: 'src/node/**', + }, + // inlineDynamicImports: true, + plugins: [ + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs({ + ignoreDynamicRequires: true, + }), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.ts', '.js', '.json'], + }), + + // Resolve source maps to the original source + // sourceMaps(), + ], + }, +] diff --git a/extensions/assistant-extension/src/@types/global.d.ts b/extensions/assistant-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..2ca4a40809 --- /dev/null +++ b/extensions/assistant-extension/src/@types/global.d.ts @@ -0,0 +1,2 @@ +declare const NODE: string +declare const VERSION: string diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts new file mode 100644 index 0000000000..12441995ee --- /dev/null +++ b/extensions/assistant-extension/src/index.ts @@ -0,0 +1,150 @@ +import { + fs, + Assistant, + events, + joinPath, + AssistantExtension, + AssistantEvent, + ToolManager, +} from '@janhq/core' +import { RetrievalTool } from './tools/retrieval' + +export default class JanAssistantExtension extends AssistantExtension { + private static readonly _homeDir = 'file://assistants' + + async onLoad() { + // Register the retrieval tool + ToolManager.instance().register(new RetrievalTool()) + + // making the assistant directory + const assistantDirExist = await fs.existsSync( + JanAssistantExtension._homeDir + ) + if ( + localStorage.getItem(`${this.name}-version`) !== VERSION || + !assistantDirExist + ) { + if (!assistantDirExist) await fs.mkdir(JanAssistantExtension._homeDir) + + // Write assistant metadata + await this.createJanAssistant() + // Finished migration + localStorage.setItem(`${this.name}-version`, VERSION) + // Update the assistant list + events.emit(AssistantEvent.OnAssistantsUpdate, {}) + } + } + + /** + * Called when the extension is unloaded. 
+ */ + onUnload(): void {} + + async createAssistant(assistant: Assistant): Promise { + const assistantDir = await joinPath([ + JanAssistantExtension._homeDir, + assistant.id, + ]) + if (!(await fs.existsSync(assistantDir))) await fs.mkdir(assistantDir) + + // store the assistant metadata json + const assistantMetadataPath = await joinPath([ + assistantDir, + 'assistant.json', + ]) + try { + await fs.writeFileSync( + assistantMetadataPath, + JSON.stringify(assistant, null, 2) + ) + } catch (err) { + console.error(err) + } + } + + async getAssistants(): Promise { + // get all the assistant directories + // get all the assistant metadata json + const results: Assistant[] = [] + const allFileName: string[] = await fs.readdirSync( + JanAssistantExtension._homeDir + ) + for (const fileName of allFileName) { + const filePath = await joinPath([ + JanAssistantExtension._homeDir, + fileName, + ]) + + if (!(await fs.fileStat(filePath))?.isDirectory) continue + const jsonFiles: string[] = (await fs.readdirSync(filePath)).filter( + (file: string) => file === 'assistant.json' + ) + + if (jsonFiles.length !== 1) { + // has more than one assistant file -> ignore + continue + } + + const content = await fs.readFileSync( + await joinPath([filePath, jsonFiles[0]]), + 'utf-8' + ) + const assistant: Assistant = + typeof content === 'object' ? content : JSON.parse(content) + + results.push(assistant) + } + + return results + } + + async deleteAssistant(assistant: Assistant): Promise { + if (assistant.id === 'jan') { + return Promise.reject('Cannot delete Jan Assistant') + } + + // remove the directory + const assistantDir = await joinPath([ + JanAssistantExtension._homeDir, + assistant.id, + ]) + return fs.rm(assistantDir) + } + + private async createJanAssistant(): Promise { + const janAssistant: Assistant = { + avatar: '', + thread_location: undefined, + id: 'jan', + object: 'assistant', + created_at: Date.now(), + name: 'Jan', + description: 'A default assistant that can use all downloaded models', + model: '*', + instructions: '', + tools: [ + { + type: 'retrieval', + enabled: false, + useTimeWeightedRetriever: false, + settings: { + top_k: 2, + chunk_size: 1024, + chunk_overlap: 64, + retrieval_template: `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. +---------------- +CONTEXT: {CONTEXT} +---------------- +QUESTION: {QUESTION} +---------------- +Helpful Answer:`, + }, + }, + ], + file_ids: [], + metadata: undefined, + } + + await this.createAssistant(janAssistant) + } +} diff --git a/extensions/assistant-extension/src/node/engine.ts b/extensions/assistant-extension/src/node/engine.ts new file mode 100644 index 0000000000..05a3803406 --- /dev/null +++ b/extensions/assistant-extension/src/node/engine.ts @@ -0,0 +1,38 @@ +import fs from 'fs' +import path from 'path' +import { SettingComponentProps, getJanDataFolderPath } from '@janhq/core/node' + +// Sec: Do not send engine settings over requests +// Read it manually instead +export const readEmbeddingEngine = (engineName: string) => { + if (engineName !== 'openai' && engineName !== 'groq') { + const engineSettings = fs.readFileSync( + path.join(getJanDataFolderPath(), 'engines', `${engineName}.json`), + 'utf-8' + ) + return JSON.parse(engineSettings) + } else { + const settingDirectoryPath = path.join( + getJanDataFolderPath(), + 'settings', + '@janhq', + // TODO: James - To be removed + engineName === 'openai' + ? 
'inference-openai-extension' + : 'inference-groq-extension', + 'settings.json' + ) + + const content = fs.readFileSync(settingDirectoryPath, 'utf-8') + const settings: SettingComponentProps[] = JSON.parse(content) + const apiKeyId = engineName === 'openai' ? 'openai-api-key' : 'groq-api-key' + const keySetting = settings.find((setting) => setting.key === apiKeyId) + + let apiKey = keySetting?.controllerProps.value + if (typeof apiKey !== 'string') apiKey = '' + + return { + api_key: apiKey, + } + } +} diff --git a/extensions/assistant-extension/src/node/index.ts b/extensions/assistant-extension/src/node/index.ts new file mode 100644 index 0000000000..83a4a19831 --- /dev/null +++ b/extensions/assistant-extension/src/node/index.ts @@ -0,0 +1,44 @@ +import { getJanDataFolderPath, normalizeFilePath } from '@janhq/core/node' +import { retrieval } from './retrieval' +import path from 'path' + +export function toolRetrievalUpdateTextSplitter( + chunkSize: number, + chunkOverlap: number +) { + retrieval.updateTextSplitter(chunkSize, chunkOverlap) +} +export async function toolRetrievalIngestNewDocument( + file: string, + model: string, + engine: string, + useTimeWeighted: boolean +) { + const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file)) + const threadPath = path.dirname(filePath.replace('files', '')) + retrieval.updateEmbeddingEngine(model, engine) + return retrieval + .ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted) + .catch((err) => { + console.error(err) + }) +} + +export async function toolRetrievalLoadThreadMemory(threadId: string) { + return retrieval + .loadRetrievalAgent( + path.join(getJanDataFolderPath(), 'threads', threadId, 'memory') + ) + .catch((err) => { + console.error(err) + }) +} + +export async function toolRetrievalQueryResult( + query: string, + useTimeWeighted: boolean = false +) { + return retrieval.generateResult(query, useTimeWeighted).catch((err) => { + console.error(err) + }) +} diff --git a/extensions/assistant-extension/src/node/retrieval.ts b/extensions/assistant-extension/src/node/retrieval.ts new file mode 100644 index 0000000000..28d629aa80 --- /dev/null +++ b/extensions/assistant-extension/src/node/retrieval.ts @@ -0,0 +1,128 @@ +import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter' +import { formatDocumentsAsString } from 'langchain/util/document' +import { PDFLoader } from 'langchain/document_loaders/fs/pdf' + +import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted' +import { MemoryVectorStore } from 'langchain/vectorstores/memory' + +import { HNSWLib } from 'langchain/vectorstores/hnswlib' + +import { OpenAIEmbeddings } from 'langchain/embeddings/openai' +import { readEmbeddingEngine } from './engine' + +import path from 'path' + +export class Retrieval { + public chunkSize: number = 100 + public chunkOverlap?: number = 0 + private retriever: any + + private embeddingModel?: OpenAIEmbeddings = undefined + private textSplitter?: RecursiveCharacterTextSplitter + + // to support time-weighted retrieval + private timeWeightedVectorStore: MemoryVectorStore + private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever + + constructor(chunkSize: number = 4000, chunkOverlap: number = 200) { + this.updateTextSplitter(chunkSize, chunkOverlap) + + // declare time-weighted retriever and storage + this.timeWeightedVectorStore = new MemoryVectorStore( + new OpenAIEmbeddings( + { openAIApiKey: 'nitro-embedding' }, + { basePath: 'http://127.0.0.1:3928/v1' } + ) + ) + 
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({ + vectorStore: this.timeWeightedVectorStore, + memoryStream: [], + searchKwargs: 2, + }) + } + + public updateTextSplitter(chunkSize: number, chunkOverlap: number): void { + this.chunkSize = chunkSize + this.chunkOverlap = chunkOverlap + this.textSplitter = new RecursiveCharacterTextSplitter({ + chunkSize: chunkSize, + chunkOverlap: chunkOverlap, + }) + } + + public updateEmbeddingEngine(model: string, engine: string): void { + // Engine settings are not compatible with the current embedding model params + // Switch case manually for now + if (engine === 'nitro') { + this.embeddingModel = new OpenAIEmbeddings( + { openAIApiKey: 'nitro-embedding', model }, + // TODO: Raw settings + { basePath: 'http://127.0.0.1:3928/v1' }, + ) + } else { + // Fallback to OpenAI Settings + const settings = readEmbeddingEngine(engine) + this.embeddingModel = new OpenAIEmbeddings({ + openAIApiKey: settings.api_key, + }) + } + + // update time-weighted embedding model + this.timeWeightedVectorStore.embeddings = this.embeddingModel + } + + public ingestAgentKnowledge = async ( + filePath: string, + memoryPath: string, + useTimeWeighted: boolean + ): Promise => { + const loader = new PDFLoader(filePath, { + splitPages: true, + }) + if (!this.embeddingModel) return Promise.reject() + const doc = await loader.load() + const docs = await this.textSplitter!.splitDocuments(doc) + const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel) + + // add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval + if (useTimeWeighted && this.timeWeightedretriever) { + await ( + this.timeWeightedretriever as TimeWeightedVectorStoreRetriever + ).addDocuments(docs) + } + return vectorStore.save(memoryPath) + } + + public loadRetrievalAgent = async (memoryPath: string): Promise => { + if (!this.embeddingModel) return Promise.reject() + const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel) + this.retriever = vectorStore.asRetriever(2) + return Promise.resolve() + } + + public generateResult = async ( + query: string, + useTimeWeighted: boolean + ): Promise => { + if (useTimeWeighted) { + if (!this.timeWeightedretriever) { + return Promise.resolve(' ') + } + // use invoke because getRelevantDocuments is deprecated + const relevantDocs = await this.timeWeightedretriever.invoke(query) + const serializedDoc = formatDocumentsAsString(relevantDocs) + return Promise.resolve(serializedDoc) + } + + if (!this.retriever) { + return Promise.resolve(' ') + } + + // should use invoke(query) because getRelevantDocuments is deprecated + const relevantDocs = await this.retriever.getRelevantDocuments(query) + const serializedDoc = formatDocumentsAsString(relevantDocs) + return Promise.resolve(serializedDoc) + } +} + +export const retrieval = new Retrieval() diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts new file mode 100644 index 0000000000..7631922871 --- /dev/null +++ b/extensions/assistant-extension/src/tools/retrieval.ts @@ -0,0 +1,117 @@ +import { + AssistantTool, + executeOnMain, + fs, + InferenceTool, + joinPath, + MessageRequest, +} from '@janhq/core' + +export class RetrievalTool extends InferenceTool { + private _threadDir = 'file://threads' + private retrievalThreadId: string | undefined = undefined + + name: string = 'retrieval' + + async process( + data: MessageRequest, + tool?: AssistantTool + ): Promise { + if 
(!data.model || !data.messages) { + return Promise.resolve(data) + } + + const latestMessage = data.messages[data.messages.length - 1] + + // 1. Ingest the document if needed + if ( + latestMessage && + latestMessage.content && + typeof latestMessage.content !== 'string' && + latestMessage.content.length > 1 + ) { + const docFile = latestMessage.content[1]?.doc_url?.url + if (docFile) { + await executeOnMain( + NODE, + 'toolRetrievalIngestNewDocument', + docFile, + data.model?.id, + data.model?.engine, + tool?.useTimeWeightedRetriever ?? false + ) + } else { + return Promise.resolve(data) + } + } else if ( + // Check whether we need to ingest document or not + // Otherwise wrong context will be sent + !(await fs.existsSync( + await joinPath([this._threadDir, data.threadId, 'memory']) + )) + ) { + // No document ingested, reroute the result to inference engine + + return Promise.resolve(data) + } + // 2. Load agent on thread changed + if (this.retrievalThreadId !== data.threadId) { + await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) + + this.retrievalThreadId = data.threadId + + // Update the text splitter + await executeOnMain( + NODE, + 'toolRetrievalUpdateTextSplitter', + tool?.settings?.chunk_size ?? 4000, + tool?.settings?.chunk_overlap ?? 200 + ) + } + + // 3. Using the retrieval template with the result and query + if (latestMessage.content) { + const prompt = + typeof latestMessage.content === 'string' + ? latestMessage.content + : latestMessage.content[0].text + // Retrieve the result + const retrievalResult = await executeOnMain( + NODE, + 'toolRetrievalQueryResult', + prompt, + tool?.useTimeWeightedRetriever ?? false + ) + console.debug('toolRetrievalQueryResult', retrievalResult) + + // Update message content + if (retrievalResult) + data.messages[data.messages.length - 1].content = + tool?.settings?.retrieval_template + ?.replace('{CONTEXT}', retrievalResult) + .replace('{QUESTION}', prompt) + } + + // 4. Reroute the result to inference engine + return Promise.resolve(this.normalize(data)) + } + + // Filter out all the messages that are not text + // TODO: Remove it until engines can handle multiple content types + normalize(request: MessageRequest): MessageRequest { + request.messages = request.messages?.map((message) => { + if ( + message.content && + typeof message.content !== 'string' && + (message.content.length ?? 
0) > 0 + ) { + return { + ...message, + content: [message.content[0]], + } + } + return message + }) + return request + } +} diff --git a/extensions/assistant-extension/tsconfig.json b/extensions/assistant-extension/tsconfig.json new file mode 100644 index 0000000000..e425358c35 --- /dev/null +++ b/extensions/assistant-extension/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "moduleResolution": "node", + "target": "es5", + "module": "ES2020", + "lib": ["es2015", "es2016", "es2017", "dom"], + "strict": true, + "sourceMap": true, + "declaration": true, + "allowSyntheticDefaultImports": true, + "experimentalDecorators": true, + "emitDecoratorMetadata": true, + "declarationDir": "dist/types", + "outDir": "dist", + "importHelpers": true, + "typeRoots": ["node_modules/@types"], + "skipLibCheck": true + }, + "include": ["src"] +} diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json new file mode 100644 index 0000000000..d062ce9c33 --- /dev/null +++ b/extensions/conversational-extension/package.json @@ -0,0 +1,36 @@ +{ + "name": "@janhq/conversational-extension", + "productName": "Conversational", + "version": "1.0.0", + "description": "This extension enables conversations and state persistence via your filesystem", + "main": "dist/index.js", + "author": "Jan ", + "license": "MIT", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [] +} diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts new file mode 100644 index 0000000000..1bca75347d --- /dev/null +++ b/extensions/conversational-extension/src/index.ts @@ -0,0 +1,277 @@ +import { + fs, + joinPath, + ConversationalExtension, + Thread, + ThreadMessage, +} from '@janhq/core' + +/** + * JSONConversationalExtension is a ConversationalExtension implementation that provides + * functionality for managing threads. + */ +export default class JSONConversationalExtension extends ConversationalExtension { + private static readonly _threadFolder = 'file://threads' + private static readonly _threadInfoFileName = 'thread.json' + private static readonly _threadMessagesFileName = 'messages.jsonl' + + /** + * Called when the extension is loaded. + */ + async onLoad() { + if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) { + await fs.mkdir(JSONConversationalExtension._threadFolder) + } + } + + /** + * Called when the extension is unloaded. + */ + onUnload() { + console.debug('JSONConversationalExtension unloaded') + } + + /** + * Returns a Promise that resolves to an array of Conversation objects. + */ + async getThreads(): Promise { + try { + const threadDirs = await this.getValidThreadDirs() + + const promises = threadDirs.map((dirName) => this.readThread(dirName)) + const promiseResults = await Promise.allSettled(promises) + const convos = promiseResults + .map((result) => { + if (result.status === 'fulfilled') { + return typeof result.value === 'object' + ? 
result.value + : JSON.parse(result.value) + } + }) + .filter((convo) => convo != null) + convos.sort( + (a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime() + ) + + return convos + } catch (error) { + console.error(error) + return [] + } + } + + /** + * Saves a Thread object to a json file. + * @param thread The Thread object to save. + */ + async saveThread(thread: Thread): Promise { + try { + const threadDirPath = await joinPath([ + JSONConversationalExtension._threadFolder, + thread.id, + ]) + const threadJsonPath = await joinPath([ + threadDirPath, + JSONConversationalExtension._threadInfoFileName, + ]) + if (!(await fs.existsSync(threadDirPath))) { + await fs.mkdir(threadDirPath) + } + + await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2)) + } catch (err) { + console.error(err) + Promise.reject(err) + } + } + + /** + * Delete a thread with the specified ID. + * @param threadId The ID of the thread to delete. + */ + async deleteThread(threadId: string): Promise { + const path = await joinPath([ + JSONConversationalExtension._threadFolder, + `${threadId}`, + ]) + try { + await fs.rm(path) + } catch (err) { + console.error(err) + } + } + + async addNewMessage(message: ThreadMessage): Promise { + try { + const threadDirPath = await joinPath([ + JSONConversationalExtension._threadFolder, + message.thread_id, + ]) + const threadMessagePath = await joinPath([ + threadDirPath, + JSONConversationalExtension._threadMessagesFileName, + ]) + if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath) + + if (message.content[0]?.type === 'image') { + const filesPath = await joinPath([threadDirPath, 'files']) + if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath) + + const imagePath = await joinPath([filesPath, `${message.id}.png`]) + const base64 = message.content[0].text.annotations[0] + await this.storeImage(base64, imagePath) + if ((await fs.existsSync(imagePath)) && message.content?.length) { + // Use file path instead of blob + message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png` + } + } + + if (message.content[0]?.type === 'pdf') { + const filesPath = await joinPath([threadDirPath, 'files']) + if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath) + + const filePath = await joinPath([filesPath, `${message.id}.pdf`]) + const blob = message.content[0].text.annotations[0] + await this.storeFile(blob, filePath) + + if ((await fs.existsSync(filePath)) && message.content?.length) { + // Use file path instead of blob + message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf` + } + } + await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n') + Promise.resolve() + } catch (err) { + Promise.reject(err) + } + } + + async storeImage(base64: string, filePath: string): Promise { + const base64Data = base64.replace(/^data:image\/\w+;base64,/, '') + + try { + await fs.writeBlob(filePath, base64Data) + } catch (err) { + console.error(err) + } + } + + async storeFile(base64: string, filePath: string): Promise { + const base64Data = base64.replace(/^data:application\/pdf;base64,/, '') + try { + await fs.writeBlob(filePath, base64Data) + } catch (err) { + console.error(err) + } + } + + async writeMessages( + threadId: string, + messages: ThreadMessage[] + ): Promise { + try { + const threadDirPath = await joinPath([ + JSONConversationalExtension._threadFolder, + threadId, + ]) + const threadMessagePath = await joinPath([ + threadDirPath, + 
JSONConversationalExtension._threadMessagesFileName, + ]) + if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath) + await fs.writeFileSync( + threadMessagePath, + messages.map((msg) => JSON.stringify(msg)).join('\n') + + (messages.length ? '\n' : '') + ) + Promise.resolve() + } catch (err) { + Promise.reject(err) + } + } + + /** + * A promise builder for reading a thread from a file. + * @param threadDirName the thread dir we are reading from. + * @returns data of the thread + */ + private async readThread(threadDirName: string): Promise { + return fs.readFileSync( + await joinPath([ + JSONConversationalExtension._threadFolder, + threadDirName, + JSONConversationalExtension._threadInfoFileName, + ]), + 'utf-8' + ) + } + + /** + * Returns a Promise that resolves to an array of thread directories. + * @private + */ + private async getValidThreadDirs(): Promise { + const fileInsideThread: string[] = await fs.readdirSync( + JSONConversationalExtension._threadFolder + ) + + const threadDirs: string[] = [] + for (let i = 0; i < fileInsideThread.length; i++) { + const path = await joinPath([ + JSONConversationalExtension._threadFolder, + fileInsideThread[i], + ]) + if (!(await fs.fileStat(path))?.isDirectory) continue + + const isHavingThreadInfo = (await fs.readdirSync(path)).includes( + JSONConversationalExtension._threadInfoFileName + ) + if (!isHavingThreadInfo) { + console.debug(`Ignore ${path} because it does not have thread info`) + continue + } + + threadDirs.push(fileInsideThread[i]) + } + return threadDirs + } + + async getAllMessages(threadId: string): Promise { + try { + const threadDirPath = await joinPath([ + JSONConversationalExtension._threadFolder, + threadId, + ]) + + const files: string[] = await fs.readdirSync(threadDirPath) + if ( + !files.includes(JSONConversationalExtension._threadMessagesFileName) + ) { + console.debug(`${threadDirPath} not contains message file`) + return [] + } + + const messageFilePath = await joinPath([ + threadDirPath, + JSONConversationalExtension._threadMessagesFileName, + ]) + + let readResult = await fs.readFileSync(messageFilePath, 'utf-8') + + if (typeof readResult === 'object') { + readResult = JSON.stringify(readResult) + } + + const result = readResult.split('\n').filter((line) => line !== '') + + const messages: ThreadMessage[] = [] + result.forEach((line: string) => { + messages.push(JSON.parse(line)) + }) + return messages + } catch (err) { + console.error(err) + return [] + } + } +} diff --git a/extensions/conversational-extension/tsconfig.json b/extensions/conversational-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/conversational-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/conversational-extension/webpack.config.js b/extensions/conversational-extension/webpack.config.js new file mode 100644 index 0000000000..e4a0b2179e --- /dev/null +++ b/extensions/conversational-extension/webpack.config.js @@ -0,0 +1,29 @@ +const webpack = require('webpack') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + 
use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + plugins: [new webpack.DefinePlugin({})], + resolve: { + extensions: ['.ts', '.js'], + }, + // Do not minify the output, otherwise it breaks the class registration + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-anthropic-extension/README.md b/extensions/inference-anthropic-extension/README.md new file mode 100644 index 0000000000..1c0dcbd3d4 --- /dev/null +++ b/extensions/inference-anthropic-extension/README.md @@ -0,0 +1,79 @@ +# Anthropic Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! 
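Before the manifest and sources for this engine, it may help to see the wire format it targets. The sketch below is illustrative and not part of the diff: it assumes Node 18+'s global `fetch` and a placeholder API key, while the endpoint, headers, payload shape, and response shape mirror `resources/settings.json`, `resources/models.json`, and the `headers`/`transformPayload`/`transformResponse` overrides in `src/index.ts` further down.

```typescript
// Roughly the request the Anthropic engine ends up issuing. In the real
// extension the shared RemoteOAIEngine plumbing performs the call; this
// standalone function only demonstrates the shape of the exchange.
async function exampleAnthropicCall(apiKey: string): Promise<string> {
  const res = await fetch('https://api.anthropic.com/v1/messages', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'x-api-key': apiKey, // value of the extension's "anthropic-api-key" setting
      'anthropic-version': '2023-06-01',
    },
    body: JSON.stringify({
      model: 'claude-3-haiku-20240307', // any id from resources/models.json
      max_tokens: 4096,
      stream: false,
      messages: [{ role: 'user', content: 'Hello!' }],
    }),
  })
  const data = await res.json()
  // Non-streaming responses carry the text at content[0].text, which is what
  // transformResponse extracts.
  return data.content?.[0]?.text ?? ''
}
```

With `stream: true`, the same endpoint instead returns `data: {...}` server-sent events whose `content_block_delta` chunks are what `transformResponse` unwraps during streaming.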
diff --git a/extensions/inference-anthropic-extension/package.json b/extensions/inference-anthropic-extension/package.json new file mode 100644 index 0000000000..a9d30a8e5d --- /dev/null +++ b/extensions/inference-anthropic-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-anthropic-extension", + "productName": "Anthropic Inference Engine", + "version": "1.0.2", + "description": "This extension enables Anthropic chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "anthropic", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", + "sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-anthropic-extension && yarn && yarn build:publish" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-anthropic-extension/resources/models.json b/extensions/inference-anthropic-extension/resources/models.json new file mode 100644 index 0000000000..1462837ac0 --- /dev/null +++ b/extensions/inference-anthropic-extension/resources/models.json @@ -0,0 +1,98 @@ +[ + { + "sources": [ + { + "url": "https://www.anthropic.com/" + } + ], + "id": "claude-3-opus-20240229", + "object": "model", + "name": "Claude 3 Opus", + "version": "1.0", + "description": "Claude 3 Opus is a powerful model suitable for highly complex tasks.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "stream": false + }, + "metadata": { + "author": "Anthropic", + "tags": ["General", "Big Context Length"] + }, + "engine": "anthropic" + }, + { + "sources": [ + { + "url": "https://www.anthropic.com/" + } + ], + "id": "claude-3-sonnet-20240229", + "object": "model", + "name": "Claude 3 Sonnet", + "version": "1.0", + "description": "Claude 3 Sonnet offers an ideal balance of intelligence and speed for enterprise workloads.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "stream": false + }, + "metadata": { + "author": "Anthropic", + "tags": ["General", "Big Context Length"] + }, + "engine": "anthropic" + }, + { + "sources": [ + { + "url": "https://www.anthropic.com/" + } + ], + "id": "claude-3-haiku-20240307", + "object": "model", + "name": "Claude 3 Haiku", + "version": "1.0", + "description": "Claude 3 Haiku is the fastest model, providing near-instant responsiveness.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "stream": false + }, + "metadata": { + "author": "Anthropic", + "tags": ["General", "Big Context Length"] + }, + "engine": "anthropic" + }, + { + "sources": [ + { + "url": "https://www.anthropic.com/" + } + ], + "id": "claude-3-5-sonnet-20240620", + "object": "model", + "name": "Claude 3.5 Sonnet", + "version": "1.0", + "description": "Claude 3.5 Sonnet raises the industry bar for intelligence, outperforming competitor models and Claude 3 Opus on a wide 
range of evaluations, with the speed and cost of our mid-tier model, Claude 3 Sonnet.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "stream": true + }, + "metadata": { + "author": "Anthropic", + "tags": ["General", "Big Context Length"] + }, + "engine": "anthropic" + } +] diff --git a/extensions/inference-anthropic-extension/resources/settings.json b/extensions/inference-anthropic-extension/resources/settings.json new file mode 100644 index 0000000000..bb35e6b3d3 --- /dev/null +++ b/extensions/inference-anthropic-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [Anthropic API documentation](https://docs.anthropic.com/claude/docs/intro-to-claude) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://api.anthropic.com/v1/messages", + "value": "https://api.anthropic.com/v1/messages" + } + }, + { + "key": "anthropic-api-key", + "title": "API Key", + "description": "The Anthropic API uses API keys for authentication. Visit your [API Keys](https://console.anthropic.com/settings/keys) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] \ No newline at end of file diff --git a/extensions/inference-anthropic-extension/src/index.ts b/extensions/inference-anthropic-extension/src/index.ts new file mode 100644 index 0000000000..f28a584f25 --- /dev/null +++ b/extensions/inference-anthropic-extension/src/index.ts @@ -0,0 +1,148 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-anthropic-extension/src/index + */ + +import { RemoteOAIEngine } from '@janhq/core' +import { PayloadType } from '@janhq/core' +import { ChatCompletionRole } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'anthropic-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} + +type AnthropicPayloadType = { + stream: boolean + model?: string + max_tokens?: number + messages?: Array<{ role: string; content: string }> +} + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ */ +export default class JanInferenceAnthropicExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'anthropic' + maxTokens: number = 4096 + + override async onLoad(): Promise { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + // Override the headers method to include the x-API-key in the request headers + override async headers(): Promise { + return { + 'Content-Type': 'application/json', + 'x-api-key': this.apiKey, + 'anthropic-version': '2023-06-01', + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } + + // Override the transformPayload method to convert the payload to the required format + transformPayload = (payload: PayloadType): AnthropicPayloadType => { + if (!payload.messages || payload.messages.length === 0) { + return { + max_tokens: this.maxTokens, + messages: [], + model: payload.model, + stream: payload.stream, + } + } + + const convertedData: AnthropicPayloadType = { + max_tokens: this.maxTokens, + messages: [], + model: payload.model, + stream: payload.stream, + } + + payload.messages.forEach((item) => { + if (item.role === ChatCompletionRole.User) { + convertedData.messages.push({ + role: 'user', + content: item.content as string, + }) + } else if (item.role === ChatCompletionRole.Assistant) { + convertedData.messages.push({ + role: 'assistant', + content: item.content as string, + }) + } + }) + + return convertedData + } + + // Sample returned stream data from anthropic + // {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + // {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"} } + // {"type":"content_block_stop","index":0 } + // {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12} } + + // Override the transformResponse method to convert the response to the required format + transformResponse = (data: any): string => { + // handling stream response + if (typeof data === 'string' && data.trim().length === 0) return '' + if (typeof data === 'string' && data.startsWith('event: ')) return '' + if (typeof data === 'string' && data.startsWith('data: ')) { + data = data.replace('data: ', '') + const parsedData = JSON.parse(data) + if (parsedData.type !== 'content_block_delta') return '' + return parsedData.delta?.text ?? 
'' + } + + // non stream response + if (data.content && data.content.length > 0 && data.content[0].text) { + return data.content[0].text + } + + console.error('Invalid response format:', data) + return '' + } +} diff --git a/extensions/inference-anthropic-extension/tsconfig.json b/extensions/inference-anthropic-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-anthropic-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-anthropic-extension/webpack.config.js b/extensions/inference-anthropic-extension/webpack.config.js new file mode 100644 index 0000000000..cd5e65c725 --- /dev/null +++ b/extensions/inference-anthropic-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-cohere-extension/README.md b/extensions/inference-cohere-extension/README.md new file mode 100644 index 0000000000..089a096e8b --- /dev/null +++ b/extensions/inference-cohere-extension/README.md @@ -0,0 +1,79 @@ +# Cohere Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. 
:building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-cohere-extension/package.json b/extensions/inference-cohere-extension/package.json new file mode 100644 index 0000000000..ea03bb33b9 --- /dev/null +++ b/extensions/inference-cohere-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-cohere-extension", + "productName": "Cohere Inference Engine", + "version": "1.0.0", + "description": "This extension enables Cohere chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "cohere", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", + "sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-cohere-extension && yarn && yarn build:publish" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-cohere-extension/resources/models.json b/extensions/inference-cohere-extension/resources/models.json new file mode 100644 index 0000000000..2b4cc3e8e4 --- /dev/null +++ b/extensions/inference-cohere-extension/resources/models.json @@ -0,0 +1,56 @@ +[ + { + "sources": [ + { + "url": "https://cohere.com" + } + ], + "id": "command-r-plus", + "object": "model", + "name": "Command R+", + "version": "1.0", + "description": "Command R+ is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. 
It is best suited for complex RAG workflows and multi-step tool use.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 128000, + "temperature": 0.7, + "stream": false + }, + "metadata": { + "author": "Cohere", + "tags": [ + "General", + "Big Context Length" + ] + }, + "engine": "cohere" + }, + { + "sources": [ + { + "url": "https://cohere.com" + } + ], + "id": "command-r", + "object": "model", + "name": "Command R", + "version": "1.0", + "description": "Command R is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. It can be used for complex workflows like code generation, retrieval augmented generation (RAG), tool use, and agents.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 128000, + "temperature": 0.7, + "stream": false + }, + "metadata": { + "author": "Cohere", + "tags": [ + "General", + "Big Context Length" + ] + }, + "engine": "cohere" + } +] diff --git a/extensions/inference-cohere-extension/resources/settings.json b/extensions/inference-cohere-extension/resources/settings.json new file mode 100644 index 0000000000..2a32b57f8b --- /dev/null +++ b/extensions/inference-cohere-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [Cohere API documentation](https://docs.cohere.com/reference/chat) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://api.cohere.ai/v1/chat", + "value": "https://api.cohere.ai/v1/chat" + } + }, + { + "key": "cohere-api-key", + "title": "API Key", + "description": "The Cohere API uses API keys for authentication. Visit your [API Keys](https://dashboard.cohere.com/api-keys) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-cohere-extension/src/index.ts b/extensions/inference-cohere-extension/src/index.ts new file mode 100644 index 0000000000..dd7f033174 --- /dev/null +++ b/extensions/inference-cohere-extension/src/index.ts @@ -0,0 +1,118 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-cohere-extension/src/index + */ + +import { RemoteOAIEngine } from '@janhq/core' +import { PayloadType } from '@janhq/core' +import { ChatCompletionRole } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'cohere-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} + +enum RoleType { + user = 'USER', + chatbot = 'CHATBOT', + system = 'SYSTEM', +} + +type CoherePayloadType = { + chat_history?: Array<{ role: RoleType; message: string }> + message?: string + preamble?: string +} + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. 
+ * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceCohereExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'cohere' + + override async onLoad(): Promise { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } + + transformPayload = (payload: PayloadType): CoherePayloadType => { + if (payload.messages.length === 0) { + return {} + } + + const { messages, ...params } = payload + const convertedData: CoherePayloadType = { + ...params, + chat_history: [], + message: '', + } + messages.forEach((item, index) => { + // Assign the message of the last item to the `message` property + if (index === messages.length - 1) { + convertedData.message = item.content as string + return + } + if (item.role === ChatCompletionRole.User) { + convertedData.chat_history.push({ + role: RoleType.user, + message: item.content as string, + }) + } else if (item.role === ChatCompletionRole.Assistant) { + convertedData.chat_history.push({ + role: RoleType.chatbot, + message: item.content as string, + }) + } else if (item.role === ChatCompletionRole.System) { + convertedData.preamble = item.content as string + } + }) + return convertedData + } + + transformResponse = (data: any) => { + return typeof data === 'object' ? data.text : JSON.parse(data).text ?? 
'' + } +} diff --git a/extensions/inference-cohere-extension/tsconfig.json b/extensions/inference-cohere-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-cohere-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-cohere-extension/webpack.config.js b/extensions/inference-cohere-extension/webpack.config.js new file mode 100644 index 0000000000..cd5e65c725 --- /dev/null +++ b/extensions/inference-cohere-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-groq-extension/README.md b/extensions/inference-groq-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/inference-groq-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. 
:white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-groq-extension/package.json b/extensions/inference-groq-extension/package.json new file mode 100644 index 0000000000..509cb7611d --- /dev/null +++ b/extensions/inference-groq-extension/package.json @@ -0,0 +1,41 @@ +{ + "name": "@janhq/inference-groq-extension", + "productName": "Groq Inference Engine", + "version": "1.0.1", + "description": "This extension enables fast Groq chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "author": "Carsen Klock & Jan", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . 
&& webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-groq-extension/resources/models.json b/extensions/inference-groq-extension/resources/models.json new file mode 100644 index 0000000000..81275f47ce --- /dev/null +++ b/extensions/inference-groq-extension/resources/models.json @@ -0,0 +1,125 @@ +[ + { + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "llama3-70b-8192", + "object": "model", + "name": "Groq Llama 3 70b", + "version": "1.1", + "description": "Groq Llama 3 70b with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 8192, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Meta", + "tags": [ + "General", + "Big Context Length" + ] + }, + "engine": "groq" + }, + { + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "llama3-8b-8192", + "object": "model", + "name": "Groq Llama 3 8b", + "version": "1.1", + "description": "Groq Llama 3 8b with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 8192, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Meta", + "tags": [ + "General", + "Big Context Length" + ] + }, + "engine": "groq" + }, + { + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "gemma-7b-it", + "object": "model", + "name": "Groq Gemma 7b Instruct", + "version": "1.1", + "description": "Groq Gemma 7b Instruct with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 8192, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Google", + "tags": [ + "General" + ] + }, + "engine": "groq" + }, + { + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "mixtral-8x7b-32768", + "object": "model", + "name": "Groq Mixtral 8x7b Instruct", + "version": "1.1", + "description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 32768, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Mistral", + "tags": [ + "General", + "Big Context Length" + ] + }, + "engine": "groq" + } +] \ No newline at end of file diff --git a/extensions/inference-groq-extension/resources/settings.json b/extensions/inference-groq-extension/resources/settings.json new file mode 100644 index 0000000000..493b602cd9 --- /dev/null +++ b/extensions/inference-groq-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. 
See the [Groq documentation](https://console.groq.com/docs/openai) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://api.groq.com/openai/v1/chat/completions", + "value": "https://api.groq.com/openai/v1/chat/completions" + } + }, + { + "key": "groq-api-key", + "title": "API Key", + "description": "The Groq API uses API keys for authentication. Visit your [API Keys](https://console.groq.com/keys) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-groq-extension/src/index.ts b/extensions/inference-groq-extension/src/index.ts new file mode 100644 index 0000000000..eafb7fe8ad --- /dev/null +++ b/extensions/inference-groq-extension/src/index.ts @@ -0,0 +1,67 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-groq-extension/src/index + */ + +import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'groq-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ */ +export default class JanInferenceGroqExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider = 'groq' + + override async onLoad(): Promise<void> { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + // Retrieve API Key Setting + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate<T>(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-groq-extension/tsconfig.json b/extensions/inference-groq-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-groq-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-groq-extension/webpack.config.js b/extensions/inference-groq-extension/webpack.config.js new file mode 100644 index 0000000000..199dee42cb --- /dev/null +++ b/extensions/inference-groq-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-martian-extension/README.md b/extensions/inference-martian-extension/README.md new file mode 100644 index 0000000000..5b8e898d7c --- /dev/null +++ b/extensions/inference-martian-extension/README.md @@ -0,0 +1,79 @@ +# Martian Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. 
Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-martian-extension/package.json b/extensions/inference-martian-extension/package.json new file mode 100644 index 0000000000..15d392b9c1 --- /dev/null +++ b/extensions/inference-martian-extension/package.json @@ -0,0 +1,42 @@ +{ + "name": "@janhq/inference-martian-extension", + "productName": "Martian Inference Engine", + "version": "1.0.1", + "description": "This extension enables Martian chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "martian", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . 
&& webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-martian-extension/resources/models.json b/extensions/inference-martian-extension/resources/models.json new file mode 100644 index 0000000000..cf59e958e7 --- /dev/null +++ b/extensions/inference-martian-extension/resources/models.json @@ -0,0 +1,32 @@ +[ + { + "sources": [ + { + "url": "https://withmartian.com/" + } + ], + "id": "router", + "object": "model", + "name": "Martian Model Router", + "version": "1.0", + "description": "Martian Model Router dynamically routes requests to the best LLM in real-time", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Martian", + "tags": [ + "General" + ] + }, + "engine": "martian" + } +] \ No newline at end of file diff --git a/extensions/inference-martian-extension/resources/settings.json b/extensions/inference-martian-extension/resources/settings.json new file mode 100644 index 0000000000..bc83d76d40 --- /dev/null +++ b/extensions/inference-martian-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [Martian API documentation](https://docs.withmartian.com/martian-model-router/api-reference/get-chat-completions) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://withmartian.com/api/openai/v1/chat/completions", + "value": "https://withmartian.com/api/openai/v1/chat/completions" + } + }, + { + "key": "martian-api-key", + "title": "API Key", + "description": "The Martian API uses API keys for authentication. Visit your [API Keys](https://withmartian.com/dashboard) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-martian-extension/src/index.ts b/extensions/inference-martian-extension/src/index.ts new file mode 100644 index 0000000000..f59a6b7fc0 --- /dev/null +++ b/extensions/inference-martian-extension/src/index.ts @@ -0,0 +1,66 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ * @version 1.0.0 + * @module inference-martian-extension/src/index + */ + +import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'martian-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceMartianExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'martian' + + override async onLoad(): Promise<void> { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate<T>(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-martian-extension/tsconfig.json b/extensions/inference-martian-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-martian-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-martian-extension/webpack.config.js b/extensions/inference-martian-extension/webpack.config.js new file mode 100644 index 0000000000..cd5e65c725 --- /dev/null +++ b/extensions/inference-martian-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git 
a/extensions/inference-mistral-extension/README.md b/extensions/inference-mistral-extension/README.md new file mode 100644 index 0000000000..adb36558cf --- /dev/null +++ b/extensions/inference-mistral-extension/README.md @@ -0,0 +1,79 @@ +# Mistral Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-mistral-extension/package.json b/extensions/inference-mistral-extension/package.json new file mode 100644 index 0000000000..7cdb612538 --- /dev/null +++ b/extensions/inference-mistral-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-mistral-extension", + "productName": "MistralAI Inference Engine", + "version": "1.0.1", + "description": "This extension enables Mistral chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "mistral", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . 
&& webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "path-browserify": "^1.0.1", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-mistral-extension/resources/models.json b/extensions/inference-mistral-extension/resources/models.json new file mode 100644 index 0000000000..23ecd6fdd4 --- /dev/null +++ b/extensions/inference-mistral-extension/resources/models.json @@ -0,0 +1,83 @@ +[ + { + "sources": [ + { + "url": "https://docs.mistral.ai/api/" + } + ], + "id": "mistral-small-latest", + "object": "model", + "name": "Mistral Small", + "version": "1.1", + "description": "Mistral Small is the ideal choice for simple tasks (Classification, Customer Support, or Text Generation) at an affordable price.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 32000, + "temperature": 0.7, + "top_p": 0.95, + "stream": true + }, + "metadata": { + "author": "Mistral", + "tags": [ + "General" + ] + }, + "engine": "mistral" + }, + { + "sources": [ + { + "url": "https://docs.mistral.ai/api/" + } + ], + "id": "mistral-large-latest", + "object": "model", + "name": "Mistral Large", + "version": "1.1", + "description": "Mistral Large is ideal for complex tasks (Synthetic Text Generation, Code Generation, RAG, or Agents).", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 32000, + "temperature": 0.7, + "top_p": 0.95, + "stream": true + }, + "metadata": { + "author": "Mistral", + "tags": [ + "General" + ] + }, + "engine": "mistral" + }, + { + "sources": [ + { + "url": "https://docs.mistral.ai/api/" + } + ], + "id": "open-mixtral-8x22b", + "object": "model", + "name": "Mixtral 8x22B", + "version": "1.1", + "description": "Mixtral 8x22B is a high-performance, cost-effective model designed for complex tasks.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 32000, + "temperature": 0.7, + "top_p": 0.95, + "stream": true + }, + "metadata": { + "author": "Mistral", + "tags": [ + "General" + ] + }, + "engine": "mistral" + } +] diff --git a/extensions/inference-mistral-extension/resources/settings.json b/extensions/inference-mistral-extension/resources/settings.json new file mode 100644 index 0000000000..2ca8ec7e55 --- /dev/null +++ b/extensions/inference-mistral-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [Mistral API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://api.mistral.ai/v1/chat/completions", + "value": "https://api.mistral.ai/v1/chat/completions" + } + }, + { + "key": "mistral-api-key", + "title": "API Key", + "description": "The Mistral API uses API keys for authentication. 
Visit your [API Keys](https://console.mistral.ai/api-keys/) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-mistral-extension/src/index.ts b/extensions/inference-mistral-extension/src/index.ts new file mode 100644 index 0000000000..461fc326e7 --- /dev/null +++ b/extensions/inference-mistral-extension/src/index.ts @@ -0,0 +1,66 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-mistral-extension/src/index + */ + +import { RemoteOAIEngine } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'mistral-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceMistralExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'mistral' + + override async onLoad(): Promise<void> { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate<T>(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-mistral-extension/tsconfig.json b/extensions/inference-mistral-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-mistral-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-mistral-extension/webpack.config.js b/extensions/inference-mistral-extension/webpack.config.js new file mode 100644 index 0000000000..0e35fc227b --- /dev/null +++ b/extensions/inference-mistral-extension/webpack.config.js @@ -0,0 +1,42 @@ +const path = require('path') +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = 
require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + MODELS: JSON.stringify(modelsJson), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + path: path.resolve(__dirname, 'dist'), + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + fallback: { + path: require.resolve('path-browserify'), + }, + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-nitro-extension/.gitignore b/extensions/inference-nitro-extension/.gitignore new file mode 100644 index 0000000000..10780f1d4c --- /dev/null +++ b/extensions/inference-nitro-extension/.gitignore @@ -0,0 +1,2 @@ +bin +!version.txt \ No newline at end of file diff --git a/extensions/inference-nitro-extension/README.md b/extensions/inference-nitro-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/inference-nitro-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. 
+ +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-nitro-extension/bin/version.txt b/extensions/inference-nitro-extension/bin/version.txt new file mode 100644 index 0000000000..2b2a18d265 --- /dev/null +++ b/extensions/inference-nitro-extension/bin/version.txt @@ -0,0 +1 @@ +0.4.20 diff --git a/extensions/inference-nitro-extension/download.bat b/extensions/inference-nitro-extension/download.bat new file mode 100644 index 0000000000..9bd2d4b074 --- /dev/null +++ b/extensions/inference-nitro-extension/download.bat @@ -0,0 +1,3 @@ +@echo off +set /p CORTEX_VERSION=<./bin/version.txt +.\node_modules\.bin\download https://github.com/janhq/cortex/releases/download/v%CORTEX_VERSION%/cortex-cpp-%CORTEX_VERSION%-windows-amd64-avx2-cuda-12-0.tar.gz -e --strip 1 -o ./bin/win-cuda-12-0 && .\node_modules\.bin\download https://github.com/janhq/cortex/releases/download/v%CORTEX_VERSION%/cortex-cpp-%CORTEX_VERSION%-windows-amd64-avx2-cuda-11-7.tar.gz -e --strip 1 -o ./bin/win-cuda-11-7 && .\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%CORTEX_VERSION%/cortex-cpp-%CORTEX_VERSION%-windows-amd64-avx2.tar.gz -e --strip 1 -o ./bin/win-cpu && .\node_modules\.bin\download https://github.com/janhq/cortex/releases/download/v%CORTEX_VERSION%/cortex-cpp-%CORTEX_VERSION%-windows-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/win-vulkan diff --git a/extensions/inference-nitro-extension/jest.config.js b/extensions/inference-nitro-extension/jest.config.js new file mode 100644 index 0000000000..b413e106db --- /dev/null +++ b/extensions/inference-nitro-extension/jest.config.js @@ -0,0 +1,5 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', +}; \ No newline at end of file diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json new file mode 100644 index 0000000000..3150108c48 --- /dev/null +++ b/extensions/inference-nitro-extension/package.json @@ -0,0 +1,73 @@ +{ + "name": "@janhq/inference-cortex-extension", + "productName": "Cortex Inference Engine", + "version": "1.0.14", + "description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. 
See https://nitro.jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.", + "main": "dist/index.js", + "node": "dist/node/index.cjs.js", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "test": "jest", + "build": "tsc --module commonjs && rollup -c rollup.config.ts", + "downloadnitro:linux": "CORTEX_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-linux-amd64-avx2.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/cortex-cpp && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-linux-amd64-avx2-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/cortex-cpp && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-linux-amd64-avx2-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/cortex-cpp && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/cortex-cpp", + "downloadnitro:darwin": "CORTEX_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-mac-arm64.tar.gz -o ./bin/ && mkdir -p ./bin/mac-arm64 && tar -zxvf ./bin/cortex-cpp-${CORTEX_VERSION}-mac-arm64.tar.gz --strip-components=1 -C ./bin/mac-arm64 && rm -rf ./bin/cortex-cpp-${CORTEX_VERSION}-mac-arm64.tar.gz && chmod +x ./bin/mac-arm64/cortex-cpp && download https://github.com/janhq/cortex/releases/download/v${CORTEX_VERSION}/cortex-cpp-${CORTEX_VERSION}-mac-amd64.tar.gz -o ./bin/ && mkdir -p ./bin/mac-amd64 && tar -zxvf ./bin/cortex-cpp-${CORTEX_VERSION}-mac-amd64.tar.gz --strip-components=1 -C ./bin/mac-amd64 && rm -rf ./bin/cortex-cpp-${CORTEX_VERSION}-mac-amd64.tar.gz && chmod +x ./bin/mac-amd64/cortex-cpp", + "downloadnitro:win32": "download.bat", + "downloadnitro": "run-script-os", + "build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish": "yarn test && run-script-os" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/node/index.cjs.js" + }, + "devDependencies": { + "@babel/preset-typescript": "^7.24.1", + "@jest/globals": "^29.7.0", + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@types/decompress": "^4.2.7", + "@types/jest": "^29.5.12", + "@types/node": "^20.11.4", + "@types/os-utils": "^0.0.4", + "@types/tcp-port-used": "^1.0.4", + "cpx": "^1.5.0", + "download-cli": "^1.1.1", + "jest": "^29.7.0", + "rimraf": "^3.0.2", + "rollup": "^2.38.5", + "rollup-plugin-define": "^1.0.1", + "rollup-plugin-sourcemaps": "^0.6.3", + "rollup-plugin-typescript2": "^0.36.0", + "run-script-os": "^1.1.6", + "ts-jest": "^29.1.2", + 
"typescript": "^5.3.3" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "decompress": "^4.2.1", + "fetch-retry": "^5.0.6", + "rxjs": "^7.8.1", + "tcp-port-used": "^1.0.2", + "terminate": "2.6.1", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "tcp-port-used", + "fetch-retry", + "@janhq/core", + "decompress" + ] +} diff --git a/extensions/inference-nitro-extension/resources/default_settings.json b/extensions/inference-nitro-extension/resources/default_settings.json new file mode 100644 index 0000000000..09d014a12c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/default_settings.json @@ -0,0 +1,33 @@ +[ + { + "key": "test", + "title": "Test", + "description": "Test", + "controllerType": "input", + "controllerProps": { + "placeholder": "Test", + "value": "" + } + }, + { + "key": "embedding", + "title": "Embedding", + "description": "Whether to enable embedding.", + "controllerType": "checkbox", + "controllerProps": { + "value": true + } + }, + { + "key": "ctx_len", + "title": "Context Length", + "description": "The context length for model operations varies; the maximum depends on the specific model used.", + "controllerType": "slider", + "controllerProps": { + "min": 0, + "max": 4096, + "step": 128, + "value": 2048 + } + } +] diff --git a/extensions/inference-nitro-extension/resources/models/aya-23-35b/model.json b/extensions/inference-nitro-extension/resources/models/aya-23-35b/model.json new file mode 100644 index 0000000000..8c3029be0a --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/aya-23-35b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "aya-23-35B-Q4_K_M.gguf", + "url": "https://huggingface.co/bartowski/aya-23-35B-GGUF/resolve/main/aya-23-35B-Q4_K_M.gguf" + } + ], + "id": "aya-23-35b", + "object": "model", + "name": "Aya 23 35B Q4", + "version": "1.1", + "description": "Aya 23 can talk upto 23 languages fluently.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + "llama_model_path": "aya-23-35B-Q4_K_M.gguf", + "ngl": 41 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "frequency_penalty": 0, + "presence_penalty": 0, + "stop": ["<|END_OF_TURN_TOKEN|>"] + }, + "metadata": { + "author": "CohereForAI", + "tags": ["34B", "Finetuned"], + "size": 21556982144 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/aya-23-8b/model.json b/extensions/inference-nitro-extension/resources/models/aya-23-8b/model.json new file mode 100644 index 0000000000..b82cf2f393 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/aya-23-8b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "aya-23-8B-Q4_K_M.gguf", + "url": "https://huggingface.co/bartowski/aya-23-8B-GGUF/resolve/main/aya-23-8B-Q4_K_M.gguf" + } + ], + "id": "aya-23-8b", + "object": "model", + "name": "Aya 23 8B Q4", + "version": "1.1", + "description": "Aya 23 can talk upto 23 languages fluently.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": 
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + "llama_model_path": "aya-23-8B-Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "frequency_penalty": 0, + "presence_penalty": 0, + "stop": ["<|END_OF_TURN_TOKEN|>"] + }, + "metadata": { + "author": "CohereForAI", + "tags": ["7B", "Finetuned","Featured"], + "size": 5056982144 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/bakllava-1/model.json b/extensions/inference-nitro-extension/resources/models/bakllava-1/model.json new file mode 100644 index 0000000000..93f87c7f46 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/bakllava-1/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "ggml-model-q5_k.gguf", + "url": "https://huggingface.co/mys/ggml_bakllava-1/resolve/main/ggml-model-q5_k.gguf" + }, + { + "filename": "mmproj-model-f16.gguf", + "url": "https://huggingface.co/mys/ggml_bakllava-1/resolve/main/mmproj-model-f16.gguf" + } + ], + "id": "bakllava-1", + "object": "model", + "name": "BakLlava 1", + "version": "1.0", + "description": "BakLlava 1 can bring vision understanding to Jan", + "format": "gguf", + "settings": { + "vision_model": true, + "text_model": false, + "ctx_len": 4096, + "prompt_template": "\n### Instruction:\n{prompt}\n### Response:\n", + "llama_model_path": "ggml-model-q5_k.gguf", + "mmproj": "mmproj-model-f16.gguf" + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "Mys", + "tags": ["Vision"], + "size": 5750000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/codeninja-1.0-7b/model.json b/extensions/inference-nitro-extension/resources/models/codeninja-1.0-7b/model.json new file mode 100644 index 0000000000..fb2a5f3467 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/codeninja-1.0-7b/model.json @@ -0,0 +1,34 @@ +{ + "sources": [ + { + "filename": "codeninja-1.0-openchat-7b.Q4_K_M.gguf", + "url": "https://huggingface.co/beowolx/CodeNinja-1.0-OpenChat-7B-GGUF/resolve/main/codeninja-1.0-openchat-7b.Q4_K_M.gguf" + } + ], + "id": "codeninja-1.0-7b", + "object": "model", + "name": "CodeNinja 7B Q4", + "version": "1.2", + "description": "CodeNinja is good for coding tasks and can handle various languages including Python, C, C++, Rust, Java, JavaScript, and more.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:", + "llama_model_path": "codeninja-1.0-openchat-7b.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Beowolx", + "tags": ["7B", "Finetuned"], + "size": 4370000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/codestral-22b/model.json b/extensions/inference-nitro-extension/resources/models/codestral-22b/model.json new file mode 100644 index 0000000000..f90f848dd9 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/codestral-22b/model.json @@ -0,0 +1,36 @@ +{ + "sources": [ + { + "filename": "Codestral-22B-v0.1-Q4_K_M.gguf", + "url": 
"https://huggingface.co/bartowski/Codestral-22B-v0.1-GGUF/resolve/main/Codestral-22B-v0.1-Q4_K_M.gguf" + } + ], + "id": "codestral-22b", + "object": "model", + "name": "Codestral 22B Q4", + "version": "1.1", + "description": "Latest model from MistralAI optimized for code generation tasks.", + "format": "gguf", + "settings": { + "ctx_len": 32000, + "prompt_template": "{system_message} [INST] {prompt} [/INST]", + "llama_model_path": "Codestral-22B-v0.1-Q4_K_M.gguf", + "ngl": 57 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32000, + "stop": [", [/INST]"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MistralAI", + "tags": ["22B", "Finetuned", "Featured"], + "size": 13341237440 + }, + "engine": "nitro" + } + diff --git a/extensions/inference-nitro-extension/resources/models/command-r-34b/model.json b/extensions/inference-nitro-extension/resources/models/command-r-34b/model.json new file mode 100644 index 0000000000..d29e70a17f --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/command-r-34b/model.json @@ -0,0 +1,36 @@ +{ + "sources": [ + { + "filename": "c4ai-command-r-v01-Q4_K_M.gguf", + "url": "https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF/resolve/main/c4ai-command-r-v01-Q4_K_M.gguf" + } + ], + "id": "command-r-34b", + "object": "model", + "name": "Command-R v01 34B Q4", + "version": "1.5", + "description": "C4AI Command-R developed by CohereAI is optimized for a variety of use cases including reasoning, summarization, and question answering.", + "format": "gguf", + "settings": { + "ctx_len": 131072, + "prompt_template": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + "llama_model_path": "c4ai-command-r-v01-Q4_K_M.gguf", + "ngl": 41 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 131072, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "CohereAI", + "tags": ["34B", "Finetuned", "Featured"], + "size": 21500000000 + }, + "engine": "nitro" + } + diff --git a/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json new file mode 100644 index 0000000000..53f7f43e9f --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "deepseek-coder-1.3b-instruct.Q8_0.gguf", + "url": "https://huggingface.co/TheBloke/deepseek-coder-1.3b-instruct-GGUF/resolve/main/deepseek-coder-1.3b-instruct.Q8_0.gguf" + } + ], + "id": "deepseek-coder-1.3b", + "object": "model", + "name": "Deepseek Coder 1.3B Q8", + "version": "1.2", + "description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.", + "format": "gguf", + "settings": { + "ctx_len": 16384, + "prompt_template": "### Instruction:\n{prompt}\n### Response:", + "llama_model_path": "deepseek-coder-1.3b-instruct.Q8_0.gguf", + "ngl": 25 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 16384, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Deepseek, The Bloke", + "tags": ["Tiny", "Foundational Model"], + "size": 1430000000 + }, + "engine": "nitro" +} diff --git 
a/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json new file mode 100644 index 0000000000..0a3e58b489 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "deepseek-coder-33b-instruct.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/deepseek-coder-33B-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf" + } + ], + "id": "deepseek-coder-34b", + "object": "model", + "name": "Deepseek Coder 33B Q4", + "version": "1.2", + "description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.", + "format": "gguf", + "settings": { + "ctx_len": 16384, + "prompt_template": "### Instruction:\n{prompt}\n### Response:", + "llama_model_path": "deepseek-coder-33b-instruct.Q4_K_M.gguf", + "ngl": 63 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 16384, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Deepseek, The Bloke", + "tags": ["34B", "Foundational Model"], + "size": 19940000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/gemma-2b/model.json b/extensions/inference-nitro-extension/resources/models/gemma-2b/model.json new file mode 100644 index 0000000000..68cff325a7 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/gemma-2b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "gemma-2b-it-q4_k_m.gguf", + "url": "https://huggingface.co/lmstudio-ai/gemma-2b-it-GGUF/resolve/main/gemma-2b-it-q4_k_m.gguf" + } + ], + "id": "gemma-2b", + "object": "model", + "name": "Gemma 2B Q4", + "version": "1.3", + "description": "Gemma is built from the same technology as Google's Gemini.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model", + "llama_model_path": "gemma-2b-it-q4_k_m.gguf", + "ngl": 19 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "stop": ["<end_of_turn>"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Google", + "tags": ["2B", "Finetuned", "Tiny"], + "size": 1500000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/gemma-7b/model.json b/extensions/inference-nitro-extension/resources/models/gemma-7b/model.json new file mode 100644 index 0000000000..615f1149b3 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/gemma-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "gemma-7b-it-q4_K_M.gguf", + "url": "https://huggingface.co/mmnga/gemma-7b-it-gguf/resolve/main/gemma-7b-it-q4_K_M.gguf" + } + ], + "id": "gemma-7b", + "object": "model", + "name": "Gemma 7B Q4", + "version": "1.2", + "description": "Google's Gemma is built for multilingual purposes", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model", + "llama_model_path": "gemma-7b-it-q4_K_M.gguf", + "ngl": 29 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Google", + "tags": ["7B", "Finetuned", "Featured"], + "size": 5330000000 + }, + "engine": "nitro" +} 
diff --git a/extensions/inference-nitro-extension/resources/models/llama2-chat-70b/model.json b/extensions/inference-nitro-extension/resources/models/llama2-chat-70b/model.json new file mode 100644 index 0000000000..0c770b1896 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llama2-chat-70b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "llama-2-70b-chat.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGUF/resolve/main/llama-2-70b-chat.Q4_K_M.gguf" + } + ], + "id": "llama2-chat-70b", + "object": "model", + "name": "Llama 2 Chat 70B Q4", + "version": "1.1", + "description": "Llama 2 is specifically designed for a comprehensive understanding of the world.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "[INST] <<SYS>>\n{system_message}<</SYS>>\n{prompt}[/INST]", + "llama_model_path": "llama-2-70b-chat.Q4_K_M.gguf", + "ngl": 81 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 4096, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MetaAI", + "tags": ["70B", "Foundational Model"], + "size": 43920000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/llama2-chat-7b/model.json b/extensions/inference-nitro-extension/resources/models/llama2-chat-7b/model.json new file mode 100644 index 0000000000..9efd634b59 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llama2-chat-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "llama-2-7b-chat.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf" + } + ], + "id": "llama2-chat-7b", + "object": "model", + "name": "Llama 2 Chat 7B Q4", + "version": "1.1", + "description": "Llama 2 is specifically designed for a comprehensive understanding of the world.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "[INST] <<SYS>>\n{system_message}<</SYS>>\n{prompt}[/INST]", + "llama_model_path": "llama-2-7b-chat.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 4096, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MetaAI", + "tags": ["7B", "Foundational Model"], + "size": 4080000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/llama3-8b-instruct/model.json b/extensions/inference-nitro-extension/resources/models/llama3-8b-instruct/model.json new file mode 100644 index 0000000000..313bf84257 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llama3-8b-instruct/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", + "url": "https://huggingface.co/lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf" + } + ], + "id": "llama3-8b-instruct", + "object": "model", + "name": "Llama 3 8B Q4", + "version": "1.2", + "description": "Meta's Llama 3 excels at general usage situations, including chat, general world knowledge, and coding.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + "llama_model_path": 
"Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "stop": ["<|end_of_text|>","<|eot_id|>"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MetaAI", + "tags": ["7B", "Featured"], + "size": 4920000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/llama3-hermes-8b/model.json b/extensions/inference-nitro-extension/resources/models/llama3-hermes-8b/model.json new file mode 100644 index 0000000000..a3601c8cdd --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llama3-hermes-8b/model.json @@ -0,0 +1,38 @@ +{ + "sources": [ + { + "filename": "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", + "url": "https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/resolve/main/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf" + } + ], + "id": "llama3-hermes-8b", + "object": "model", + "name": "Hermes Pro Llama 3 8B Q4", + "version": "1.2", + "description": "Hermes Pro is well-designed for General chat and JSON output.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "NousResearch", + "tags": [ + "7B", + "Finetuned" + ], + "size": 4920000000 + }, + "engine": "nitro" + } diff --git a/extensions/inference-nitro-extension/resources/models/llamacorn-1.1b/model.json b/extensions/inference-nitro-extension/resources/models/llamacorn-1.1b/model.json new file mode 100644 index 0000000000..94b62ec82d --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llamacorn-1.1b/model.json @@ -0,0 +1,38 @@ +{ + "sources": [ + { + "url":"https://huggingface.co/janhq/llamacorn-1.1b-chat-GGUF/resolve/main/llamacorn-1.1b-chat.Q8_0.gguf", + "filename": "llamacorn-1.1b-chat.Q8_0.gguf" + } + ], + "id": "llamacorn-1.1b", + "object": "model", + "name": "LlamaCorn 1.1B Q8", + "version": "1.1", + "description": "LlamaCorn is designed to improve chat functionality from TinyLlama.", + "format": "gguf", + "settings": { + "ctx_len": 2048, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "llamacorn-1.1b-chat.Q8_0.gguf", + "ngl": 23 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 2048, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Jan", + "tags": [ + "Tiny", + "Finetuned" + ], + "size": 1170000000 + }, + "engine": "nitro" + } \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/llava-13b/model.json b/extensions/inference-nitro-extension/resources/models/llava-13b/model.json new file mode 100644 index 0000000000..caca33b7e0 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llava-13b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "llava-v1.6-vicuna-13b.Q4_K_M.gguf", + "url": "https://huggingface.co/cjpais/llava-v1.6-vicuna-13b-gguf/resolve/main/llava-v1.6-vicuna-13b.Q4_K_M.gguf" + }, + { + "filename": "mmproj-model-f16.gguf", + "url": 
"https://huggingface.co/cjpais/llava-v1.6-vicuna-13b-gguf/resolve/main/mmproj-model-f16.gguf" + } + ], + "id": "llava-13b", + "object": "model", + "name": "LlaVa 13B Q4", + "version": "1.1", + "description": "LlaVa can bring vision understanding to Jan", + "format": "gguf", + "settings": { + "vision_model": true, + "text_model": false, + "ctx_len": 4096, + "prompt_template": "\n### Instruction:\n{prompt}\n### Response:\n", + "llama_model_path": "llava-v1.6-vicuna-13b.Q4_K_M.gguf", + "mmproj": "mmproj-model-f16.gguf" + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "liuhaotian", + "tags": ["Vision"], + "size": 7870000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/llava-7b/model.json b/extensions/inference-nitro-extension/resources/models/llava-7b/model.json new file mode 100644 index 0000000000..b61ec38c2c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/llava-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "llava-v1.6-mistral-7b.Q4_K_M.gguf", + "url": "https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q4_K_M.gguf" + }, + { + "filename": "mmproj-model-f16.gguf", + "url": "https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf" + } + ], + "id": "llava-7b", + "object": "model", + "name": "LlaVa 7B", + "version": "1.1", + "description": "LlaVa can bring vision understanding to Jan", + "format": "gguf", + "settings": { + "vision_model": true, + "text_model": false, + "ctx_len": 4096, + "prompt_template": "\n### Instruction:\n{prompt}\n### Response:\n", + "llama_model_path": "llava-v1.6-mistral-7b.Q4_K_M.gguf", + "mmproj": "mmproj-model-f16.gguf" + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "liuhaotian", + "tags": ["Vision"], + "size": 4370000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/mistral-ins-7b-q4/model.json b/extensions/inference-nitro-extension/resources/models/mistral-ins-7b-q4/model.json new file mode 100644 index 0000000000..d223306f8c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/mistral-ins-7b-q4/model.json @@ -0,0 +1,36 @@ +{ + "sources": [ + { + "filename": "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf", + "url": "https://huggingface.co/bartowski/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf" + } + ], + "id": "mistral-ins-7b-q4", + "object": "model", + "name": "Mistral Instruct 7B Q4", + "version": "1.3", + "description": "Mistral Instruct 7b model, specifically designed for a comprehensive understanding of the world.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "{system_message} [INST] {prompt} [/INST]", + "llama_model_path": "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "stop": ["[/INST]"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MistralAI", + "tags": ["Featured", "7B", "Foundational Model"], + "size": 4370000000, + "cover": "https://raw.githubusercontent.com/janhq/jan/dev/models/mistral-ins-7b-q4/cover.png" + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/mixtral-8x7b-instruct/model.json b/extensions/inference-nitro-extension/resources/models/mixtral-8x7b-instruct/model.json new file mode 100644 index 
0000000000..4413b415c4 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/mixtral-8x7b-instruct/model.json @@ -0,0 +1,34 @@ +{ + "sources": [ + { + "filename": "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/resolve/main/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf" + } + ], + "id": "mixtral-8x7b-instruct", + "object": "model", + "name": "Mixtral 8x7B Instruct Q4", + "version": "1.1", + "description": "The Mixtral-8x7B is a pretrained generative Sparse Mixture of Experts. The Mixtral-8x7B outperforms 70B models on most benchmarks.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "[INST] {prompt} [/INST]", + "llama_model_path": "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf", + "ngl": 100 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "MistralAI, TheBloke", + "tags": ["70B", "Foundational Model"], + "size": 26440000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/noromaid-7b/model.json b/extensions/inference-nitro-extension/resources/models/noromaid-7b/model.json new file mode 100644 index 0000000000..10c17c3109 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/noromaid-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "Noromaid-7B-0.4-DPO.q4_k_m.gguf", + "url": "https://huggingface.co/NeverSleep/Noromaid-7B-0.4-DPO-GGUF/resolve/main/Noromaid-7B-0.4-DPO.q4_k_m.gguf" + } + ], + "id": "noromaid-7b", + "object": "model", + "name": "Noromaid 7B Q4", + "version": "1.2", + "description": "The Noromaid 7b model is designed for role-playing with human-like behavior.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "Noromaid-7B-0.4-DPO.q4_k_m.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "NeverSleep", + "tags": ["7B", "Finetuned"], + "size": 4370000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/openchat-3.5-7b/model.json b/extensions/inference-nitro-extension/resources/models/openchat-3.5-7b/model.json new file mode 100644 index 0000000000..e743a74c9c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/openchat-3.5-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "openchat-3.5-0106.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/openchat-3.5-0106-GGUF/resolve/main/openchat-3.5-0106.Q4_K_M.gguf" + } + ], + "id": "openchat-3.5-7b", + "object": "model", + "name": "Openchat-3.5 7B Q4", + "version": "1.2", + "description": "The performance of Openchat surpasses ChatGPT-3.5 and Grok-1 across various benchmarks.", + "format": "gguf", + "settings": { + "ctx_len": 8192, + "prompt_template": "GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:", + "llama_model_path": "openchat-3.5-0106.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 8192, + "stop": ["<|end_of_turn|>"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Openchat", + 
"tags": ["Recommended", "7B", "Finetuned"], + "size": 4370000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/phi3-3.8b/model.json b/extensions/inference-nitro-extension/resources/models/phi3-3.8b/model.json new file mode 100644 index 0000000000..2a572db92c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/phi3-3.8b/model.json @@ -0,0 +1,38 @@ +{ + "sources": [ + { + "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf", + "filename": "Phi-3-mini-4k-instruct-q4.gguf" + } + ], + "id": "phi3-3.8b", + "object": "model", + "name": "Phi-3 Mini", + "version": "1.2", + "description": "Phi-3 Mini is Microsoft's newest, compact model designed for mobile use.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "<|user|>\n{prompt}<|end|>\n<|assistant|>\n", + "llama_model_path": "Phi-3-mini-4k-instruct-q4.gguf", + "ngl": 33 + }, + "parameters": { + "max_tokens": 4096, + "stop": ["<|end|>"], + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Microsoft", + "tags": [ + "3B", + "Finetuned" + ], + "size": 2320000000 + }, + "engine": "nitro" + } \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json new file mode 100644 index 0000000000..ac83ca0777 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json @@ -0,0 +1,38 @@ +{ + "sources": [ + { + "url": "https://huggingface.co/bartowski/Phi-3-medium-128k-instruct-GGUF/resolve/main/Phi-3-medium-128k-instruct-Q4_K_M.gguf", + "filename": "Phi-3-medium-128k-instruct-Q4_K_M.gguf" + } + ], + "id": "phi3-medium", + "object": "model", + "name": "Phi-3 Medium", + "version": "1.2", + "description": "Phi-3 Medium is Microsoft's latest SOTA model.", + "format": "gguf", + "settings": { + "ctx_len": 128000, + "prompt_template": "<|user|>\n{prompt}<|end|>\n<|assistant|>\n", + "llama_model_path": "Phi-3-medium-128k-instruct-Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "max_tokens": 128000, + "stop": ["<|end|>"], + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Microsoft", + "tags": [ + "14B", + "Finetuned" + ], + "size": 8366000000 + }, + "engine": "nitro" + } \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/phind-34b/model.json b/extensions/inference-nitro-extension/resources/models/phind-34b/model.json new file mode 100644 index 0000000000..14099a635e --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/phind-34b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "phind-codellama-34b-v2.Q5_K_M.gguf", + "url": "https://huggingface.co/TheBloke/Phind-CodeLlama-34B-v2-GGUF/resolve/main/phind-codellama-34b-v2.Q5_K_M.gguf" + } + ], + "id": "phind-34b", + "object": "model", + "name": "Phind 34B Q4", + "version": "1.3", + "description": "Phind 34B is the best Open-source coding model.", + "format": "gguf", + "settings": { + "ctx_len": 16384, + "prompt_template": "### System Prompt\n{system_message}\n### User Message\n{prompt}\n### Assistant", + "llama_model_path": "phind-codellama-34b-v2.Q4_K_M.gguf", + "ngl": 49 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + 
"stream": true, + "max_tokens": 16384, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Phind", + "tags": ["34B", "Finetuned"], + "size": 20220000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/qwen-7b/model.json b/extensions/inference-nitro-extension/resources/models/qwen-7b/model.json new file mode 100644 index 0000000000..85081a605a --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/qwen-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "qwen1_5-7b-chat-q4_k_m.gguf", + "url": "https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GGUF/resolve/main/qwen1_5-7b-chat-q4_k_m.gguf" + } + ], + "id": "qwen-7b", + "object": "model", + "name": "Qwen Chat 7B Q4", + "version": "1.2", + "description": "Qwen is optimized at Chinese, ideal for everyday tasks.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "qwen1_5-7b-chat-q4_k_m.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Alibaba", + "tags": ["7B", "Finetuned"], + "size": 4770000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/qwen2-7b/model.json b/extensions/inference-nitro-extension/resources/models/qwen2-7b/model.json new file mode 100644 index 0000000000..8939a98f30 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/qwen2-7b/model.json @@ -0,0 +1,36 @@ +{ + "sources": [ + { + "filename": "Qwen2-7B-Instruct-Q4_K_M.gguf", + "url": "https://huggingface.co/bartowski/Qwen2-7B-Instruct-GGUF/resolve/main/Qwen2-7B-Instruct-Q4_K_M.gguf" + } + ], + "id": "qwen2-7b", + "object": "model", + "name": "Qwen 2 Instruct 7B Q4", + "version": "1.1", + "description": "Qwen is optimized at Chinese, ideal for everyday tasks.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "Qwen2-7B-Instruct-Q4_K_M.gguf", + "ngl": 29 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Alibaba", + "tags": ["7B", "Finetuned"], + "size": 4680000000 + }, + "engine": "nitro" + } + \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/stable-zephyr-3b/model.json b/extensions/inference-nitro-extension/resources/models/stable-zephyr-3b/model.json new file mode 100644 index 0000000000..938e03fb71 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/stable-zephyr-3b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "url": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q8_0.gguf", + "filename": "stablelm-zephyr-3b.Q8_0.gguf" + } + ], + "id": "stable-zephyr-3b", + "object": "model", + "name": "Stable Zephyr 3B Q8", + "version": "1.1", + "description": "StableLM Zephyr 3B is a best model for low-end machine.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "<|user|>\n{prompt}<|endoftext|>\n<|assistant|>", + "llama_model_path": 
"stablelm-zephyr-3b.Q8_0.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 4096, + "stop": ["<|endoftext|>"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "StabilityAI", + "tags": ["3B", "Finetuned", "Tiny"], + "size": 2970000000 + }, + "engine": "nitro" + } \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/stealth-v1.2-7b/model.json b/extensions/inference-nitro-extension/resources/models/stealth-v1.2-7b/model.json new file mode 100644 index 0000000000..c17d1c35e7 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/stealth-v1.2-7b/model.json @@ -0,0 +1,34 @@ +{ + "sources": [ + { + "filename": "stealth-v1.3.Q4_K_M.gguf", + "url": "https://huggingface.co/janhq/stealth-v1.3-GGUF/resolve/main/stealth-v1.3.Q4_K_M.gguf" + } + ], + "id": "stealth-v1.2-7b", + "object": "model", + "name": "Stealth 7B Q4", + "version": "1.2", + "description": "This is a new experimental family designed to enhance Mathematical and Logical abilities.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "stealth-v1.3.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Jan", + "tags": ["7B", "Finetuned"], + "size": 4370000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/tinyllama-1.1b/model.json b/extensions/inference-nitro-extension/resources/models/tinyllama-1.1b/model.json new file mode 100644 index 0000000000..a49e790734 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/tinyllama-1.1b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + } + ], + "id": "tinyllama-1.1b", + "object": "model", + "name": "TinyLlama Chat 1.1B Q4", + "version": "1.1", + "description": "TinyLlama is a tiny model with only 1.1B. 
It's a good model for less powerful computers.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "<|system|>\n{system_message}<|user|>\n{prompt}<|assistant|>", + "llama_model_path": "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + "ngl": 23 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 2048, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "TinyLlama", + "tags": ["Tiny", "Foundation Model"], + "size": 669000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/trinity-v1.2-7b/model.json b/extensions/inference-nitro-extension/resources/models/trinity-v1.2-7b/model.json new file mode 100644 index 0000000000..6c9aa2b89c --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/trinity-v1.2-7b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "trinity-v1.2.Q4_K_M.gguf", + "url": "https://huggingface.co/janhq/trinity-v1.2-GGUF/resolve/main/trinity-v1.2.Q4_K_M.gguf" + } + ], + "id": "trinity-v1.2-7b", + "object": "model", + "name": "Trinity-v1.2 7B Q4", + "version": "1.2", + "description": "Trinity is an experimental model merge using the Slerp method. Recommended for daily assistance purposes.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "trinity-v1.2.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Jan", + "tags": ["7B", "Merged"], + "size": 4370000000, + "cover": "https://raw.githubusercontent.com/janhq/jan/dev/models/trinity-v1.2-7b/cover.png" + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/vistral-7b/model.json b/extensions/inference-nitro-extension/resources/models/vistral-7b/model.json new file mode 100644 index 0000000000..b84f2c676e --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/vistral-7b/model.json @@ -0,0 +1,36 @@ +{ + "sources": [ + { + "filename": "vistral-7b-chat-dpo.Q4_K_M.gguf", + "url": "https://huggingface.co/janhq/vistral-7b-chat-dpo-GGUF/resolve/main/vistral-7b-chat-dpo.Q4_K_M.gguf" + } + ], + "id": "vistral-7b", + "object": "model", + "name": "Vistral 7B Q4", + "version": "1.2", + "description": "Vistral 7B has a deep understanding of Vietnamese.", + "format": "gguf", + "settings": { + "ctx_len": 32768, + "prompt_template": "[INST] <>\n{system_message}\n<>\n{prompt} [/INST]", + "llama_model_path": "vistral-7b-chat-dpo.Q4_K_M.gguf", + "ngl": 33 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 32768, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "Viet Mistral, Jan", + "tags": ["7B", "Finetuned"], + "size": 4410000000 + }, + "engine": "nitro" + } + \ No newline at end of file diff --git a/extensions/inference-nitro-extension/resources/models/wizardcoder-13b/model.json b/extensions/inference-nitro-extension/resources/models/wizardcoder-13b/model.json new file mode 100644 index 0000000000..101eedfd19 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/wizardcoder-13b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "wizardcoder-python-13b-v1.0.Q4_K_M.gguf", + 
"url": "https://huggingface.co/TheBloke/WizardCoder-Python-13B-V1.0-GGUF/resolve/main/wizardcoder-python-13b-v1.0.Q4_K_M.gguf" + } + ], + "id": "wizardcoder-13b", + "object": "model", + "name": "Wizard Coder Python 13B Q4", + "version": "1.2", + "description": "WizardCoder 13B is a Python coding model. This model demonstrate high proficiency in specific domains like coding and mathematics.", + "format": "gguf", + "settings": { + "ctx_len": 16384, + "prompt_template": "### Instruction:\n{prompt}\n### Response:", + "llama_model_path": "wizardcoder-python-13b-v1.0.Q4_K_M.gguf", + "ngl": 41 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 16384, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "WizardLM, The Bloke", + "tags": ["Recommended", "13B", "Finetuned"], + "size": 7870000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/resources/models/yi-34b/model.json b/extensions/inference-nitro-extension/resources/models/yi-34b/model.json new file mode 100644 index 0000000000..db7df9f2d4 --- /dev/null +++ b/extensions/inference-nitro-extension/resources/models/yi-34b/model.json @@ -0,0 +1,35 @@ +{ + "sources": [ + { + "filename": "yi-34b-chat.Q4_K_M.gguf", + "url": "https://huggingface.co/TheBloke/Yi-34B-Chat-GGUF/resolve/main/yi-34b-chat.Q4_K_M.gguf" + } + ], + "id": "yi-34b", + "object": "model", + "name": "Yi 34B Q4", + "version": "1.1", + "description": "Yi-34B, a specialized chat model, is known for its diverse and creative responses and excels across various NLP tasks and benchmarks.", + "format": "gguf", + "settings": { + "ctx_len": 4096, + "prompt_template": "<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "llama_model_path": "yi-34b-chat.Q4_K_M.gguf", + "ngl": 61 + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 4096, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "01-ai, The Bloke", + "tags": ["34B", "Foundational Model"], + "size": 20660000000 + }, + "engine": "nitro" +} diff --git a/extensions/inference-nitro-extension/rollup.config.ts b/extensions/inference-nitro-extension/rollup.config.ts new file mode 100644 index 0000000000..71712a4d67 --- /dev/null +++ b/extensions/inference-nitro-extension/rollup.config.ts @@ -0,0 +1,155 @@ +import resolve from '@rollup/plugin-node-resolve' +import commonjs from '@rollup/plugin-commonjs' +import sourceMaps from 'rollup-plugin-sourcemaps' +import typescript from 'rollup-plugin-typescript2' +import json from '@rollup/plugin-json' +import replace from '@rollup/plugin-replace' +const packageJson = require('./package.json') +const defaultSettingJson = require('./resources/default_settings.json') + +const bakllavaJson = require('./resources/models/bakllava-1/model.json') +const codeninja7bJson = require('./resources/models/codeninja-1.0-7b/model.json') +const commandr34bJson = require('./resources/models/command-r-34b/model.json') +const deepseekCoder13bJson = require('./resources/models/deepseek-coder-1.3b/model.json') +const deepseekCoder34bJson = require('./resources/models/deepseek-coder-34b/model.json') +const gemma2bJson = require('./resources/models/gemma-2b/model.json') +const gemma7bJson = require('./resources/models/gemma-7b/model.json') +const llama2Chat70bJson = require('./resources/models/llama2-chat-70b/model.json') +const llama2Chat7bJson = 
require('./resources/models/llama2-chat-7b/model.json') +const llamacorn1bJson = require('./resources/models/llamacorn-1.1b/model.json') +const llava13bJson = require('./resources/models/llava-13b/model.json') +const llava7bJson = require('./resources/models/llava-7b/model.json') +const mistralIns7bq4Json = require('./resources/models/mistral-ins-7b-q4/model.json') +const mixtral8x7bInstructJson = require('./resources/models/mixtral-8x7b-instruct/model.json') +const noromaid7bJson = require('./resources/models/noromaid-7b/model.json') +const openchat357bJson = require('./resources/models/openchat-3.5-7b/model.json') +const phi3bJson = require('./resources/models/phi3-3.8b/model.json') +const phind34bJson = require('./resources/models/phind-34b/model.json') +const qwen7bJson = require('./resources/models/qwen-7b/model.json') +const stableZephyr3bJson = require('./resources/models/stable-zephyr-3b/model.json') +const stealthv127bJson = require('./resources/models/stealth-v1.2-7b/model.json') +const tinyllama11bJson = require('./resources/models/tinyllama-1.1b/model.json') +const trinityv127bJson = require('./resources/models/trinity-v1.2-7b/model.json') +const vistral7bJson = require('./resources/models/vistral-7b/model.json') +const wizardcoder13bJson = require('./resources/models/wizardcoder-13b/model.json') +const yi34bJson = require('./resources/models/yi-34b/model.json') +const llama3Json = require('./resources/models/llama3-8b-instruct/model.json') +const llama3Hermes8bJson = require('./resources/models/llama3-hermes-8b/model.json') +const aya8bJson = require('./resources/models/aya-23-8b/model.json') +const aya35bJson = require('./resources/models/aya-23-35b/model.json') +const phimediumJson = require('./resources/models/phi3-medium/model.json') +const codestralJson = require('./resources/models/codestral-22b/model.json') +const qwen2Json = require('./resources/models/qwen2-7b/model.json') + + +export default [ + { + input: `src/index.ts`, + output: [{ file: packageJson.main, format: 'es', sourcemap: true }], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: [], + watch: { + include: 'src/**', + }, + plugins: [ + replace({ + preventAssignment: true, + MODELS: JSON.stringify([ + bakllavaJson, + codeninja7bJson, + commandr34bJson, + deepseekCoder13bJson, + deepseekCoder34bJson, + gemma2bJson, + gemma7bJson, + llama2Chat70bJson, + llama2Chat7bJson, + llamacorn1bJson, + llava13bJson, + llava7bJson, + mistralIns7bq4Json, + mixtral8x7bInstructJson, + noromaid7bJson, + openchat357bJson, + phi3bJson, + phind34bJson, + qwen7bJson, + stableZephyr3bJson, + stealthv127bJson, + tinyllama11bJson, + trinityv127bJson, + vistral7bJson, + wizardcoder13bJson, + yi34bJson, + llama3Json, + llama3Hermes8bJson, + phimediumJson, + aya8bJson, + aya35bJson, + codestralJson, + qwen2Json + ]), + NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), + DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson), + INFERENCE_URL: JSON.stringify( + process.env.INFERENCE_URL || + 'http://127.0.0.1:3928/inferences/server/chat_completion' + ), + TROUBLESHOOTING_URL: JSON.stringify( + 'https://jan.ai/guides/troubleshooting' + ), + JAN_SERVER_INFERENCE_URL: JSON.stringify( + 'http://localhost:1337/v1/chat/completions' + ), + CUDA_DOWNLOAD_URL: JSON.stringify( + 'https://catalog.jan.ai/dist/cuda-dependencies///cuda.tar.gz' + ), + }), + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Compile 
TypeScript files + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.js', '.ts', '.svelte'], + browser: true, + }), + + // Resolve source maps to the original source + sourceMaps(), + ], + }, + { + input: `src/node/index.ts`, + output: [ + { file: 'dist/node/index.cjs.js', format: 'cjs', sourcemap: true }, + ], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: ['@janhq/core/node'], + watch: { + include: 'src/node/**', + }, + plugins: [ + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.ts', '.js', '.json'], + }), + + // Resolve source maps to the original source + sourceMaps(), + ], + }, +] diff --git a/extensions/inference-nitro-extension/src/@types/global.d.ts b/extensions/inference-nitro-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..85c9b939f5 --- /dev/null +++ b/extensions/inference-nitro-extension/src/@types/global.d.ts @@ -0,0 +1,15 @@ +declare const NODE: string +declare const INFERENCE_URL: string +declare const TROUBLESHOOTING_URL: string +declare const JAN_SERVER_INFERENCE_URL: string +declare const DEFAULT_SETTINGS: Array +declare const MODELS: Array + +/** + * The response from the initModel function. + * @property error - An error message if the model fails to load. + */ +interface ModelOperationResponse { + error?: any + modelFile?: string +} diff --git a/extensions/inference-nitro-extension/src/babel.config.js b/extensions/inference-nitro-extension/src/babel.config.js new file mode 100644 index 0000000000..befbdd148b --- /dev/null +++ b/extensions/inference-nitro-extension/src/babel.config.js @@ -0,0 +1,6 @@ +module.exports = { + presets: [ + ['@babel/preset-env', { targets: { node: 'current' } }], + '@babel/preset-typescript', + ], +} diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts new file mode 100644 index 0000000000..a027e88449 --- /dev/null +++ b/extensions/inference-nitro-extension/src/index.ts @@ -0,0 +1,191 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-extension/src/index + */ + +import { + events, + executeOnMain, + Model, + ModelEvent, + LocalOAIEngine, + InstallationState, + systemInformation, + fs, + getJanDataFolderPath, + joinPath, + DownloadRequest, + baseName, + downloadFile, + DownloadState, + DownloadEvent, +} from '@janhq/core' + +declare const CUDA_DOWNLOAD_URL: string +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. 
+ * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceNitroExtension extends LocalOAIEngine { + nodeModule: string = NODE + provider: string = 'nitro' + + /** + * Checking the health for Nitro's process each 5 secs. + */ + private static readonly _intervalHealthCheck = 5 * 1000 + + /** + * The interval id for the health check. Used to stop the health check. + */ + private getNitroProcessHealthIntervalId: NodeJS.Timeout | undefined = undefined + + /** + * Tracking the current state of nitro process. + */ + private nitroProcessInfo: any = undefined + + /** + * The URL for making inference requests. + */ + inferenceUrl = '' + + /** + * Subscribes to events emitted by the @janhq/core package. + */ + async onLoad() { + this.inferenceUrl = INFERENCE_URL + + // If the extension is running in the browser, use the base API URL from the core package. + if (!('electronAPI' in window)) { + this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions` + } + + this.getNitroProcessHealthIntervalId = setInterval( + () => this.periodicallyGetNitroHealth(), + JanInferenceNitroExtension._intervalHealthCheck + ) + const models = MODELS as unknown as Model[] + this.registerModels(models) + super.onLoad() + + executeOnMain(NODE, 'addAdditionalDependencies', { + name: this.name, + version: this.version, + }) + } + + /** + * Periodically check for nitro process's health. + */ + private async periodicallyGetNitroHealth(): Promise { + const health = await executeOnMain(NODE, 'getCurrentNitroProcessInfo') + + const isRunning = this.nitroProcessInfo?.isRunning ?? false + if (isRunning && health.isRunning === false) { + console.debug('Nitro process is stopped') + events.emit(ModelEvent.OnModelStopped, {}) + } + this.nitroProcessInfo = health + } + + override loadModel(model: Model): Promise { + if (model.engine !== this.provider) return Promise.resolve() + this.getNitroProcessHealthIntervalId = setInterval( + () => this.periodicallyGetNitroHealth(), + JanInferenceNitroExtension._intervalHealthCheck + ) + return super.loadModel(model) + } + + override async unloadModel(model?: Model): Promise { + if (model?.engine && model.engine !== this.provider) return + + // stop the periocally health check + if (this.getNitroProcessHealthIntervalId) { + clearInterval(this.getNitroProcessHealthIntervalId) + this.getNitroProcessHealthIntervalId = undefined + } + return super.unloadModel(model) + } + + override async install(): Promise { + const info = await systemInformation() + + const platform = info.osInfo?.platform === 'win32' ? 'windows' : 'linux' + const downloadUrl = CUDA_DOWNLOAD_URL + + const url = downloadUrl + .replace('', info.gpuSetting?.cuda?.version ?? '12.4') + .replace('', platform) + + console.debug('Downloading Cuda Toolkit Dependency: ', url) + + const janDataFolderPath = await getJanDataFolderPath() + + const executableFolderPath = await joinPath([ + janDataFolderPath, + 'engines', + this.name ?? 'cortex-cpp', + this.version ?? 
'1.0.0', + ]) + + if (!(await fs.existsSync(executableFolderPath))) { + await fs.mkdir(executableFolderPath) + } + + const tarball = await baseName(url) + const tarballFullPath = await joinPath([executableFolderPath, tarball]) + + const downloadRequest: DownloadRequest = { + url, + localPath: tarballFullPath, + extensionId: this.name, + downloadType: 'extension', + } + downloadFile(downloadRequest) + + const onFileDownloadSuccess = async (state: DownloadState) => { + console.log(state) + // if other download, ignore + if (state.fileName !== tarball) return + events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + await executeOnMain( + NODE, + 'decompressRunner', + tarballFullPath, + executableFolderPath + ) + events.emit(DownloadEvent.onFileUnzipSuccess, state) + } + events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + } + + override async installationState(): Promise { + const info = await systemInformation() + if ( + info.gpuSetting?.run_mode === 'gpu' && + !info.gpuSetting?.vulkan && + info.osInfo && + info.osInfo.platform !== 'darwin' && + !info.gpuSetting?.cuda?.exist + ) { + const janDataFolderPath = await getJanDataFolderPath() + + const executableFolderPath = await joinPath([ + janDataFolderPath, + 'engines', + this.name ?? 'cortex-cpp', + this.version ?? '1.0.0', + ]) + + if (!(await fs.existsSync(executableFolderPath))) return 'NotInstalled' + return 'Installed' + } + return 'NotRequired' + } +} diff --git a/extensions/inference-nitro-extension/src/node/execute.test.ts b/extensions/inference-nitro-extension/src/node/execute.test.ts new file mode 100644 index 0000000000..cf9e84acf7 --- /dev/null +++ b/extensions/inference-nitro-extension/src/node/execute.test.ts @@ -0,0 +1,227 @@ +import { describe, expect, it } from '@jest/globals' +import { executableNitroFile } from './execute' +import { GpuSetting } from '@janhq/core' +import { sep } from 'path' + +let testSettings: GpuSetting = { + run_mode: 'cpu', + vulkan: false, + cuda: { + exist: false, + version: '11', + }, + gpu_highest_vram: '0', + gpus: [], + gpus_in_use: [], + is_initial: false, + notify: true, + nvidia_driver: { + exist: false, + version: '11', + }, +} +const originalPlatform = process.platform + +describe('test executable nitro file', () => { + afterAll(function () { + Object.defineProperty(process, 'platform', { + value: originalPlatform, + }) + }) + + it('executes on MacOS', () => { + Object.defineProperty(process, 'platform', { + value: 'darwin', + }) + Object.defineProperty(process, 'arch', { + value: 'arm64', + }) + expect(executableNitroFile(testSettings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`mac-arm64${sep}cortex-cpp`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + Object.defineProperty(process, 'arch', { + value: 'amd64', + }) + expect(executableNitroFile(testSettings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`mac-amd64${sep}cortex-cpp`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Windows CPU', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'cpu', + cuda: { + exist: true, + version: '11', + }, + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cpu${sep}cortex-cpp.exe`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Windows Cuda 
11', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '11', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cuda-11-7${sep}cortex-cpp.exe`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Windows Cuda 12', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '12', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cuda-12-0${sep}cortex-cpp.exe`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Linux CPU', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'cpu', + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cpu${sep}cortex-cpp`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Linux Cuda 11', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '11', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cuda-11-7${sep}cortex-cpp`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Linux Cuda 12', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '12', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cuda-12-0${sep}cortex-cpp`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) +}) diff --git a/extensions/inference-nitro-extension/src/node/execute.ts b/extensions/inference-nitro-extension/src/node/execute.ts new file mode 100644 index 0000000000..417734afa7 --- /dev/null +++ b/extensions/inference-nitro-extension/src/node/execute.ts @@ -0,0 +1,62 @@ +import { GpuSetting } from '@janhq/core' +import * as path from 'path' + +export interface NitroExecutableOptions { + executablePath: string + cudaVisibleDevices: string + vkVisibleDevices: string +} +const runMode = (settings?: GpuSetting): string => { + if (process.platform === 'darwin') + // MacOS now has universal binaries + return '' + + if (!settings) return 'cpu' + + return settings.vulkan === 
true + ? 'vulkan' + : settings.run_mode === 'cpu' + ? 'cpu' + : 'cuda' +} + +const os = (): string => { + return process.platform === 'win32' + ? 'win' + : process.platform === 'darwin' + ? process.arch === 'arm64' ? 'mac-arm64' : 'mac-amd64' + : 'linux' +} + +const extension = (): '.exe' | '' => { + return process.platform === 'win32' ? '.exe' : '' +} + +const cudaVersion = (settings?: GpuSetting): '11-7' | '12-0' | undefined => { + const isUsingCuda = + settings?.vulkan !== true && settings?.run_mode === 'gpu' && os() !== 'mac' + + if (!isUsingCuda) return undefined + return settings?.cuda?.version === '11' ? '11-7' : '12-0' +} + +/** + * Find which executable file to run based on the current platform. + * @returns The name of the executable file to run. + */ +export const executableNitroFile = ( + gpuSetting?: GpuSetting +): NitroExecutableOptions => { + let binaryFolder = [os(), runMode(gpuSetting), cudaVersion(gpuSetting)] + .filter((e) => !!e) + .join('-') + let cudaVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? '' + let vkVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? '' + let binaryName = `cortex-cpp${extension()}` + + return { + executablePath: path.join(__dirname, '..', 'bin', binaryFolder, binaryName), + cudaVisibleDevices, + vkVisibleDevices, + } +} diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts new file mode 100644 index 0000000000..1b24e0a381 --- /dev/null +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -0,0 +1,464 @@ +import fs from 'fs' +import path from 'path' +import { ChildProcessWithoutNullStreams, spawn } from 'child_process' +import tcpPortUsed from 'tcp-port-used' +import fetchRT from 'fetch-retry' +import { + log, + getSystemResourceInfo, + Model, + InferenceEngine, + ModelSettingParams, + PromptTemplate, + SystemInformation, + getJanDataFolderPath, +} from '@janhq/core/node' +import { executableNitroFile } from './execute' +import terminate from 'terminate' +import decompress from 'decompress' + +// Polyfill fetch with retry +const fetchRetry = fetchRT(fetch) + +/** + * The response object for model init operation. + */ +interface ModelInitOptions { + modelFolder: string + model: Model +} +// The PORT to use for the Nitro subprocess +const PORT = 3928 +// The HOST address to use for the Nitro subprocess +const LOCAL_HOST = '127.0.0.1' +// The URL for the Nitro subprocess +const NITRO_HTTP_SERVER_URL = `http://${LOCAL_HOST}:${PORT}` +// The URL for the Nitro subprocess to load a model +const NITRO_HTTP_LOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/server/loadmodel` +// The URL for the Nitro subprocess to validate a model +const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/server/modelstatus` +// The URL for the Nitro subprocess to kill itself +const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy` + +const NITRO_PORT_FREE_CHECK_INTERVAL = 100 + +// The supported model format +// TODO: Should be an array to support more models +const SUPPORTED_MODEL_FORMAT = '.gguf' + +// The subprocess instance for Nitro +let subprocess: ChildProcessWithoutNullStreams | undefined = undefined + +// The current model settings +let currentSettings: (ModelSettingParams & { model?: string }) | undefined = + undefined + +/** + * Stops a Nitro subprocess. + * @param wrapper - The model wrapper. 
+ * @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate. + */ +function unloadModel(): Promise { + return killSubprocess() +} + +/** + * Initializes a Nitro subprocess to load a machine learning model. + * @param wrapper - The model wrapper. + * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load. + * TODO: Should pass absolute of the model file instead of just the name - So we can modurize the module.ts to npm package + */ +async function loadModel( + params: ModelInitOptions, + systemInfo?: SystemInformation +): Promise { + if (params.model.engine !== InferenceEngine.nitro) { + // Not a nitro model + return Promise.resolve() + } + + if (params.model.engine !== InferenceEngine.nitro) { + return Promise.reject('Not a cortex model') + } else { + const nitroResourceProbe = await getSystemResourceInfo() + // Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt + if (params.model.settings.prompt_template) { + const promptTemplate = params.model.settings.prompt_template + const prompt = promptTemplateConverter(promptTemplate) + if (prompt?.error) { + return Promise.reject(prompt.error) + } + params.model.settings.system_prompt = prompt.system_prompt + params.model.settings.user_prompt = prompt.user_prompt + params.model.settings.ai_prompt = prompt.ai_prompt + } + + // modelFolder is the absolute path to the running model folder + // e.g. ~/jan/models/llama-2 + let modelFolder = params.modelFolder + + let llama_model_path = params.model.settings.llama_model_path + + // Absolute model path support + if ( + params.model?.sources.length && + params.model.sources.every((e) => fs.existsSync(e.url)) + ) { + llama_model_path = + params.model.sources.length === 1 + ? params.model.sources[0].url + : params.model.sources.find((e) => + e.url.includes(llama_model_path ?? params.model.id) + )?.url + } + + if (!llama_model_path || !path.isAbsolute(llama_model_path)) { + // Look for GGUF model file + const modelFiles: string[] = fs.readdirSync(modelFolder) + const ggufBinFile = modelFiles.find( + (file) => + // 1. Prioritize llama_model_path (predefined) + (llama_model_path && file === llama_model_path) || + // 2. Prioritize GGUF File (manual import) + file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) || + // 3. Fallback Model ID (for backward compatibility) + file === params.model.id + ) + if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile) + } + + // Look for absolute source path for single model + + if (!llama_model_path) return Promise.reject('No GGUF model file found') + + currentSettings = { + cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore), + // model.settings can override the default settings + ...params.model.settings, + llama_model_path, + model: params.model.id, + // This is critical and requires real CPU physical core count (or performance core) + ...(params.model.settings.mmproj && { + mmproj: path.isAbsolute(params.model.settings.mmproj) + ? params.model.settings.mmproj + : path.join(modelFolder, params.model.settings.mmproj), + }), + } + return runNitroAndLoadModel(params.model.id, systemInfo) + } +} + +/** + * 1. Spawn Nitro process + * 2. Load model into Nitro subprocess + * 3. 
Validate model status + * @returns + */ +async function runNitroAndLoadModel( + modelId: string, + systemInfo?: SystemInformation +) { + // Gather system information for CPU physical cores and memory + return killSubprocess() + .then(() => + tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) + ) + .then(() => spawnNitroProcess(systemInfo)) + .then(() => loadLLMModel(currentSettings)) + .then(() => validateModelStatus(modelId)) + .catch((err) => { + // TODO: Broadcast error so app could display proper error message + log(`[CORTEX]::Error: ${err}`) + return { error: err } + }) +} + +/** + * Parse prompt template into agrs settings + * @param promptTemplate Template as string + * @returns + */ +function promptTemplateConverter(promptTemplate: string): PromptTemplate { + // Split the string using the markers + const systemMarker = '{system_message}' + const promptMarker = '{prompt}' + + if ( + promptTemplate.includes(systemMarker) && + promptTemplate.includes(promptMarker) + ) { + // Find the indices of the markers + const systemIndex = promptTemplate.indexOf(systemMarker) + const promptIndex = promptTemplate.indexOf(promptMarker) + + // Extract the parts of the string + const system_prompt = promptTemplate.substring(0, systemIndex) + const user_prompt = promptTemplate.substring( + systemIndex + systemMarker.length, + promptIndex + ) + const ai_prompt = promptTemplate.substring( + promptIndex + promptMarker.length + ) + + // Return the split parts + return { system_prompt, user_prompt, ai_prompt } + } else if (promptTemplate.includes(promptMarker)) { + // Extract the parts of the string for the case where only promptMarker is present + const promptIndex = promptTemplate.indexOf(promptMarker) + const user_prompt = promptTemplate.substring(0, promptIndex) + const ai_prompt = promptTemplate.substring( + promptIndex + promptMarker.length + ) + + // Return the split parts + return { user_prompt, ai_prompt } + } + + // Return an error if none of the conditions are met + return { error: 'Cannot split prompt template' } +} + +/** + * Loads a LLM model into the Nitro subprocess by sending a HTTP POST request. + * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load. + */ +function loadLLMModel(settings: any): Promise { + if (!settings?.ngl) { + settings.ngl = 100 + } + log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) + return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(settings), + retries: 3, + retryDelay: 300, + }) + .then((res) => { + log( + `[CORTEX]::Debug: Load model success with response ${JSON.stringify( + res + )}` + ) + return Promise.resolve(res) + }) + .catch((err) => { + log(`[CORTEX]::Error: Load model failed with error ${err}`) + return Promise.reject(err) + }) +} + +/** + * Validates the status of a model. + * @returns {Promise} A promise that resolves to an object. + * If the model is loaded successfully, the object is empty. + * If the model is not loaded successfully, the object contains an error message. + */ +async function validateModelStatus(modelId: string): Promise { + // Send a GET request to the validation URL. + // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. 
+ return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { + method: 'POST', + body: JSON.stringify({ model: modelId }), + headers: { + 'Content-Type': 'application/json', + }, + retries: 5, + retryDelay: 300, + }).then(async (res: Response) => { + log( + `[CORTEX]::Debug: Validate model state with response ${JSON.stringify( + res.status + )}` + ) + // If the response is OK, check model_loaded status. + if (res.ok) { + const body = await res.json() + // If the model is loaded, return an empty object. + // Otherwise, return an object with an error message. + if (body.model_loaded) { + log( + `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( + body + )}` + ) + return Promise.resolve() + } + } + log( + `[CORTEX]::Debug: Validate model state failed with response ${JSON.stringify( + res.statusText + )}` + ) + return Promise.reject('Validate model status failed') + }) +} + +/** + * Terminates the Nitro subprocess. + * @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate. + */ +async function killSubprocess(): Promise { + const controller = new AbortController() + setTimeout(() => controller.abort(), 5000) + log(`[CORTEX]::Debug: Request to kill cortex`) + + const killRequest = () => { + return fetch(NITRO_HTTP_KILL_URL, { + method: 'DELETE', + signal: controller.signal, + }) + .catch(() => {}) // Do nothing with this attempt + .then(() => + tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) + ) + .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .catch((err) => { + log( + `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` + ) + throw 'PORT_NOT_AVAILABLE' + }) + } + + if (subprocess?.pid && process.platform !== 'darwin') { + log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + const pid = subprocess.pid + return new Promise((resolve, reject) => { + terminate(pid, function (err) { + if (err) { + log('[CORTEX]::Failed to kill PID - sending request to kill') + killRequest().then(resolve).catch(reject) + } else { + tcpPortUsed + .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) + .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => resolve()) + .catch(() => { + log( + '[CORTEX]::Failed to kill PID (Port check timeout) - sending request to kill' + ) + killRequest().then(resolve).catch(reject) + }) + } + }) + }) + } else { + return killRequest() + } +} + +/** + * Spawns a Nitro subprocess. + * @returns A promise that resolves when the Nitro subprocess is started. 
+ */ +function spawnNitroProcess(systemInfo?: SystemInformation): Promise { + log(`[CORTEX]::Debug: Spawning cortex subprocess...`) + + return new Promise(async (resolve, reject) => { + let executableOptions = executableNitroFile(systemInfo?.gpuSetting) + + const args: string[] = ['1', LOCAL_HOST, PORT.toString()] + // Execute the binary + log( + `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + ) + log(path.parse(executableOptions.executablePath).dir) + subprocess = spawn( + executableOptions.executablePath, + ['1', LOCAL_HOST, PORT.toString()], + { + cwd: path.join(path.parse(executableOptions.executablePath).dir), + env: { + ...process.env, + CUDA_VISIBLE_DEVICES: executableOptions.cudaVisibleDevices, + // Vulkan - Support 1 device at a time for now + ...(executableOptions.vkVisibleDevices?.length > 0 && { + GGML_VULKAN_DEVICE: executableOptions.vkVisibleDevices[0], + }), + }, + } + ) + + // Handle subprocess output + subprocess.stdout.on('data', (data: any) => { + log(`[CORTEX]::Debug: ${data}`) + }) + + subprocess.stderr.on('data', (data: any) => { + log(`[CORTEX]::Error: ${data}`) + }) + + subprocess.on('close', (code: any) => { + log(`[CORTEX]::Debug: cortex exited with code: ${code}`) + subprocess = undefined + reject(`child process exited with code ${code}`) + }) + + tcpPortUsed + .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) + .then(() => { + log(`[CORTEX]::Debug: cortex is ready`) + resolve() + }) + }) +} + +/** + * Every module should have a dispose function + * This will be called when the extension is unloaded and should clean up any resources + * Also called when app is closed + */ +function dispose() { + // clean other registered resources here + killSubprocess() +} + +/** + * Nitro process info + */ +export interface NitroProcessInfo { + isRunning: boolean +} + +/** + * Retrieve current nitro process + */ +const getCurrentNitroProcessInfo = (): NitroProcessInfo => { + return { + isRunning: subprocess != null, + } +} + +const addAdditionalDependencies = (data: { name: string; version: string }) => { + const additionalPath = path.delimiter.concat( + path.join(getJanDataFolderPath(), 'engines', data.name, data.version) + ) + // Set the updated PATH + process.env.PATH = (process.env.PATH || '').concat(additionalPath) + process.env.LD_LIBRARY_PATH = (process.env.LD_LIBRARY_PATH || '').concat( + additionalPath + ) +} + +const decompressRunner = async (zipPath: string, output: string) => { + console.debug(`Decompressing ${zipPath} to ${output}...`) + try { + const files = await decompress(zipPath, output) + console.debug('Decompress finished!', files) + } catch (err) { + console.error(`Decompress ${zipPath} failed: ${err}`) + } +} + +export default { + loadModel, + unloadModel, + dispose, + getCurrentNitroProcessInfo, + addAdditionalDependencies, + decompressRunner, +} diff --git a/extensions/inference-nitro-extension/tsconfig.json b/extensions/inference-nitro-extension/tsconfig.json new file mode 100644 index 0000000000..bada43fc7b --- /dev/null +++ b/extensions/inference-nitro-extension/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "moduleResolution": "node", + "target": "es5", + "module": "ES2020", + "lib": ["es2015", "es2016", "es2017", "dom"], + "strict": true, + "sourceMap": true, + "declaration": true, + "allowSyntheticDefaultImports": true, + "experimentalDecorators": true, + "emitDecoratorMetadata": true, + "declarationDir": "dist/types", + "outDir": "dist", + "importHelpers": true, + "typeRoots": 
["node_modules/@types"] + }, + "include": ["src"] +} diff --git a/extensions/inference-nvidia-extension/README.md b/extensions/inference-nvidia-extension/README.md new file mode 100644 index 0000000000..65a1b2b593 --- /dev/null +++ b/extensions/inference-nvidia-extension/README.md @@ -0,0 +1,79 @@ +# Nvidia Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! 
diff --git a/extensions/inference-nvidia-extension/package.json b/extensions/inference-nvidia-extension/package.json new file mode 100644 index 0000000000..8bd7708bc8 --- /dev/null +++ b/extensions/inference-nvidia-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-nvidia-extension", + "productName": "NVIDIA NIM Inference Engine", + "version": "1.0.1", + "description": "This extension enables NVIDIA chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "nvidia", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "path-browserify": "^1.0.1", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-nvidia-extension/resources/models.json b/extensions/inference-nvidia-extension/resources/models.json new file mode 100644 index 0000000000..b97644fc99 --- /dev/null +++ b/extensions/inference-nvidia-extension/resources/models.json @@ -0,0 +1,31 @@ +[ + { + "sources": [ + { + "url": "https://integrate.api.nvidia.com/v1/chat/completions" + } + ], + "id": "mistralai/mistral-7b-instruct-v0.2", + "object": "model", + "name": "Mistral 7B", + "version": "1.1", + "description": "Mistral 7B with NVIDIA", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 1024, + "temperature": 0.3, + "top_p": 1, + "stream": false, + "frequency_penalty": 0, + "presence_penalty": 0, + "stop": null, + "seed": null + }, + "metadata": { + "author": "NVIDIA", + "tags": ["General"] + }, + "engine": "nvidia" + } +] diff --git a/extensions/inference-nvidia-extension/resources/settings.json b/extensions/inference-nvidia-extension/resources/settings.json new file mode 100644 index 0000000000..e7647b5621 --- /dev/null +++ b/extensions/inference-nvidia-extension/resources/settings.json @@ -0,0 +1,24 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [NVIDIA API documentation](https://www.nvidia.com/en-us/ai/) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://integrate.api.nvidia.com/v1/chat/completions", + "value": "https://integrate.api.nvidia.com/v1/chat/completions" + } + }, + { + "key": "nvidia-api-key", + "title": "API Key", + "description": "The NVIDIA API uses API keys for authentication. 
Visit your [API Keys](https://org.ngc.nvidia.com/setup/personal-keys) page to retrieve the API key you'll use in your requests..", + "controllerType": "input", + "controllerProps": { + "placeholder": "nvapi-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password", + "inputActions": ["unobscure", "copy"] + } + } +] diff --git a/extensions/inference-nvidia-extension/src/index.ts b/extensions/inference-nvidia-extension/src/index.ts new file mode 100644 index 0000000000..9af27d90c7 --- /dev/null +++ b/extensions/inference-nvidia-extension/src/index.ts @@ -0,0 +1,66 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-mistral-extension/src/index + */ + +import { RemoteOAIEngine } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'nvidia-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanNVIDIANIMInferenceEngine extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'nvidia' + + override async onLoad(): Promise { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-nvidia-extension/tsconfig.json b/extensions/inference-nvidia-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-nvidia-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-nvidia-extension/webpack.config.js b/extensions/inference-nvidia-extension/webpack.config.js new file mode 100644 index 0000000000..0e35fc227b --- /dev/null +++ b/extensions/inference-nvidia-extension/webpack.config.js @@ -0,0 +1,42 @@ +const path = require('path') +const webpack = require('webpack') +const packageJson = require('./package.json') 
+const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + MODELS: JSON.stringify(modelsJson), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + path: path.resolve(__dirname, 'dist'), + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + fallback: { + path: require.resolve('path-browserify'), + }, + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-openai-extension/README.md b/extensions/inference-openai-extension/README.md new file mode 100644 index 0000000000..c716c725c0 --- /dev/null +++ b/extensions/inference-openai-extension/README.md @@ -0,0 +1,79 @@ +# OpenAI Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. 
+ + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-openai-extension/package.json b/extensions/inference-openai-extension/package.json new file mode 100644 index 0000000000..cd776257c4 --- /dev/null +++ b/extensions/inference-openai-extension/package.json @@ -0,0 +1,42 @@ +{ + "name": "@janhq/inference-openai-extension", + "productName": "OpenAI Inference Engine", + "version": "1.0.2", + "description": "This extension enables OpenAI chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "openai", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json new file mode 100644 index 0000000000..6852a1892e --- /dev/null +++ b/extensions/inference-openai-extension/resources/models.json @@ -0,0 +1,123 @@ +[ + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "gpt-4-turbo", + "object": "model", + "name": "OpenAI GPT 4 Turbo", + "version": "1.2", + "description": "OpenAI GPT 4 Turbo model is extremely good", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "gpt-4-vision-preview", + "object": "model", + "name": "OpenAI GPT 4 with Vision (Preview)", + "version": "1.1", + "description": "OpenAI GPT-4 Vision model features vision understanding capabilities", + "format": "api", + "settings": { + "vision_model": true, + "textModel": false + }, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General", + "Vision" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "gpt-3.5-turbo", + "object": "model", + "name": "OpenAI GPT 3.5 Turbo", + "version": "1.1", + "description": "OpenAI GPT 3.5 Turbo model is extremely fast", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + 
"General" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "gpt-4o", + "object": "model", + "name": "OpenAI GPT 4o", + "version": "1.1", + "description": "OpenAI GPT 4o is a new flagship model with fast speed and high quality", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" + } +] diff --git a/extensions/inference-openai-extension/resources/settings.json b/extensions/inference-openai-extension/resources/settings.json new file mode 100644 index 0000000000..ccd7dd5454 --- /dev/null +++ b/extensions/inference-openai-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://api.openai.com/v1/chat/completions", + "value": "https://api.openai.com/v1/chat/completions" + } + }, + { + "key": "openai-api-key", + "title": "API Key", + "description": "The OpenAI API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts new file mode 100644 index 0000000000..60446ccce6 --- /dev/null +++ b/extensions/inference-openai-extension/src/index.ts @@ -0,0 +1,66 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-openai-extension/src/index + */ + +import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'openai-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ */ +export default class JanInferenceOpenAIExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'openai' + + override async onLoad(): Promise { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-openai-extension/tsconfig.json b/extensions/inference-openai-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-openai-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-openai-extension/webpack.config.js b/extensions/inference-openai-extension/webpack.config.js new file mode 100644 index 0000000000..cd5e65c725 --- /dev/null +++ b/extensions/inference-openai-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-openrouter-extension/README.md b/extensions/inference-openrouter-extension/README.md new file mode 100644 index 0000000000..aab10755d4 --- /dev/null +++ b/extensions/inference-openrouter-extension/README.md @@ -0,0 +1,79 @@ +# Open Router Engine Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. 
Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/inference-openrouter-extension/package.json b/extensions/inference-openrouter-extension/package.json new file mode 100644 index 0000000000..9d3d68d470 --- /dev/null +++ b/extensions/inference-openrouter-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-openrouter-extension", + "productName": "OpenRouter Inference Engine", + "version": "1.0.0", + "description": "This extension enables Open Router chat completion API calls", + "main": "dist/index.js", + "module": "dist/module.js", + "engine": "openrouter", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", + "sync:core": "cd ../.. 
&& yarn build:core && cd extensions && rm yarn.lock && cd inference-openrouter-extension && yarn && yarn build:publish" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4", + "ts-loader": "^9.5.0" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-openrouter-extension/resources/models.json b/extensions/inference-openrouter-extension/resources/models.json new file mode 100644 index 0000000000..d89c07e5af --- /dev/null +++ b/extensions/inference-openrouter-extension/resources/models.json @@ -0,0 +1,28 @@ + [ + { + "sources": [ + { + "url": "https://openrouter.ai" + } + ], + "id": "open-router-auto", + "object": "model", + "name": "OpenRouter", + "version": "1.0", + "description": " OpenRouter scouts for the lowest prices and best latencies/throughputs across dozens of providers, and lets you choose how to prioritize them.", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 1024, + "temperature": 0.7, + "top_p": 0.95, + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenRouter", + "tags": ["General", "Big Context Length"] + }, + "engine": "openrouter" + } +] diff --git a/extensions/inference-openrouter-extension/resources/settings.json b/extensions/inference-openrouter-extension/resources/settings.json new file mode 100644 index 0000000000..85040e96bd --- /dev/null +++ b/extensions/inference-openrouter-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions. See the [OpenRouter API documentation](https://openrouter.ai/docs) for more information.", + "controllerType": "input", + "controllerProps": { + "placeholder": "https://openrouter.ai/api/v1/chat/completions", + "value": "https://openrouter.ai/api/v1/chat/completions" + } + }, + { + "key": "openrouter-api-key", + "title": "API Key", + "description": "The OpenRouter API uses API keys for authentication. Visit your [API Keys](https://openrouter.ai/keys) page to retrieve the API key you'll use in your requests.", + "controllerType": "input", + "controllerProps": { + "placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-openrouter-extension/src/index.ts b/extensions/inference-openrouter-extension/src/index.ts new file mode 100644 index 0000000000..5417503e5d --- /dev/null +++ b/extensions/inference-openrouter-extension/src/index.ts @@ -0,0 +1,76 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ * @version 1.0.0 + * @module inference-openai-extension/src/index + */ + +import { RemoteOAIEngine } from '@janhq/core' +import { PayloadType } from '@janhq/core' +import { ChatCompletionRole } from '@janhq/core' + +declare const SETTINGS: Array +declare const MODELS: Array + +enum Settings { + apiKey = 'openrouter-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} + +enum RoleType { + user = 'USER', + chatbot = 'CHATBOT', + system = 'SYSTEM', +} + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceOpenRouterExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'openrouter' + + override async onLoad(): Promise { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + this.registerModels(MODELS) + + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } + + transformPayload = (payload: PayloadType)=>({...payload,model:"openrouter/auto"}) +} diff --git a/extensions/inference-openrouter-extension/tsconfig.json b/extensions/inference-openrouter-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-openrouter-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-openrouter-extension/webpack.config.js b/extensions/inference-openrouter-extension/webpack.config.js new file mode 100644 index 0000000000..cd5e65c725 --- /dev/null +++ b/extensions/inference-openrouter-extension/webpack.config.js @@ -0,0 +1,37 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') +const modelsJson = require('./resources/models.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODELS: JSON.stringify(modelsJson), + SETTINGS: JSON.stringify(settingJson), + ENGINE: JSON.stringify(packageJson.engine), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + 
library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/inference-triton-trtllm-extension/README.md b/extensions/inference-triton-trtllm-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! 
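Each engine extension resolves its endpoint the same way: a non-empty stored setting wins, otherwise it falls back to the default bundled in `resources/settings.json` (the `controllerProps.value` of the `chat-completions-endpoint` entry). A condensed sketch of that fallback under those assumptions; the helper name is illustrative, and the shipped extensions inline this logic instead:

```typescript
import { SettingComponentProps } from '@janhq/core'

// Pick the chat-completions endpoint: prefer a non-empty user-provided value,
// otherwise fall back to the default bundled with the extension's settings.
function resolveEndpoint(
  settings: SettingComponentProps[],
  userValue: string | undefined
): string {
  if (userValue && userValue.trim().length > 0) {
    return userValue
  }
  const fallback = settings.find((s) => s.key === 'chat-completions-endpoint')
  return (fallback?.controllerProps.value as string) ?? ''
}
```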
diff --git a/extensions/inference-triton-trtllm-extension/package.json b/extensions/inference-triton-trtllm-extension/package.json new file mode 100644 index 0000000000..6612dc1911 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "@janhq/inference-triton-trt-llm-extension", + "productName": "Triton-TRT-LLM Inference Engine", + "version": "1.0.0", + "description": "This extension enables Nvidia's TensorRT-LLM as an inference engine option", + "main": "dist/index.js", + "module": "dist/module.js", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc -b . && webpack --config webpack.config.js", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/module.js" + }, + "devDependencies": { + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "ts-loader": "^9.5.0", + "typescript": "5.3.3", + "webpack": "^5.88.2", + "webpack-cli": "^5.1.4" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "fetch-retry": "^5.0.6", + "rxjs": "^7.8.1", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "fetch-retry" + ] +} diff --git a/extensions/inference-triton-trtllm-extension/resources/settings.json b/extensions/inference-triton-trtllm-extension/resources/settings.json new file mode 100644 index 0000000000..9c220eed7b --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/resources/settings.json @@ -0,0 +1,23 @@ +[ + { + "key": "chat-completions-endpoint", + "title": "Chat Completions Endpoint", + "description": "The endpoint to use for chat completions.", + "controllerType": "input", + "controllerProps": { + "placeholder": "http://localhost:8000/v2/models/tensorrt_llm_bls/generate", + "value": "http://localhost:8000/v2/models/tensorrt_llm_bls/generate" + } + }, + { + "key": "tritonllm-api-key", + "title": "Triton LLM API Key", + "description": "The Triton LLM API uses API keys for authentication.", + "controllerType": "input", + "controllerProps": { + "placeholder": "xxxxxxxxxxxxxxxxxxxx", + "value": "", + "type": "password" + } + } +] diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts new file mode 100644 index 0000000000..be34837ac5 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -0,0 +1,67 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-nvidia-triton-trt-llm-extension/src/index + */ + +import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core' + +declare const SETTINGS: Array +enum Settings { + apiKey = 'tritonllm-api-key', + chatCompletionsEndPoint = 'chat-completions-endpoint', +} +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
+ */ +export default class JanInferenceTritonTrtLLMExtension extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'triton_trtllm' + + /** + * Subscribes to events emitted by the @janhq/core package. + */ + async onLoad() { + super.onLoad() + + // Register Settings + this.registerSettings(SETTINGS) + + // Retrieve API Key Setting + this.apiKey = await this.getSetting(Settings.apiKey, '') + this.inferenceUrl = await this.getSetting( + Settings.chatCompletionsEndPoint, + '' + ) + + if (this.inferenceUrl.length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.apiKey) { + this.apiKey = value as string + } else if (key === Settings.chatCompletionsEndPoint) { + if (typeof value !== 'string') return + + if (value.trim().length === 0) { + SETTINGS.forEach((setting) => { + if (setting.key === Settings.chatCompletionsEndPoint) { + this.inferenceUrl = setting.controllerProps.value as string + } + }) + } else { + this.inferenceUrl = value + } + } + } +} diff --git a/extensions/inference-triton-trtllm-extension/tsconfig.json b/extensions/inference-triton-trtllm-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-triton-trtllm-extension/webpack.config.js b/extensions/inference-triton-trtllm-extension/webpack.config.js new file mode 100644 index 0000000000..6486d5efc9 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/webpack.config.js @@ -0,0 +1,35 @@ +const webpack = require('webpack') +const packageJson = require('./package.json') +const settingJson = require('./resources/settings.json') + +module.exports = { + experiments: { outputModule: true }, + entry: './src/index.ts', // Adjust the entry point to match your project's main file + mode: 'production', + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + SETTINGS: JSON.stringify(settingJson), + MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), + }), + ], + output: { + filename: 'index.js', // Adjust the output file name as needed + library: { type: 'module' }, // Specify ESM output format + }, + resolve: { + extensions: ['.ts', '.js'], + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +} diff --git a/extensions/model-extension/README.md b/extensions/model-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/model-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. 
Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/model-extension/download.bat b/extensions/model-extension/download.bat new file mode 100644 index 0000000000..de055cb805 --- /dev/null +++ b/extensions/model-extension/download.bat @@ -0,0 +1,3 @@ +@echo off +set /p LLAMA_CPP_VERSION=<./scripts/version.txt +.\node_modules\.bin\download https://github.com/ggerganov/llama.cpp/archive/refs/tags/%LLAMA_CPP_VERSION%.tar.gz -o . 
--filename ./scripts/llama.cpp.tar.gz && tar -xzf .\scripts\llama.cpp.tar.gz "llama.cpp-%LLAMA_CPP_VERSION%/convert.py" "llama.cpp-%LLAMA_CPP_VERSION%/convert-hf-to-gguf.py" "llama.cpp-%LLAMA_CPP_VERSION%/gguf-py" && cpx "./llama.cpp-%LLAMA_CPP_VERSION%/**" "scripts" && rimraf "./scripts/llama.cpp.tar.gz" && rimraf "./llama.cpp-%LLAMA_CPP_VERSION%" \ No newline at end of file diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json new file mode 100644 index 0000000000..6bd8bbe5e0 --- /dev/null +++ b/extensions/model-extension/package.json @@ -0,0 +1,48 @@ +{ + "name": "@janhq/model-extension", + "productName": "Model Management", + "version": "1.0.33", + "description": "Model Management Extension provides model exploration and seamless downloads", + "main": "dist/index.js", + "node": "dist/node/index.cjs.js", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", + "download:llama": "run-script-os", + "download:llama:linux": "LLAMA_CPP_VERSION=$(cat ./scripts/version.txt) && download https://github.com/ggerganov/llama.cpp/archive/refs/tags/${LLAMA_CPP_VERSION}.tar.gz -o . --filename ./scripts/llama.cpp.tar.gz && tar -xzf ./scripts/llama.cpp.tar.gz --wildcards '*/convert.py' '*/convert-hf-to-gguf.py' '*/gguf-py' && cpx \"./llama.cpp-$LLAMA_CPP_VERSION/**\" \"scripts\" && rimraf \"./scripts/llama.cpp.tar.gz\" && rimraf \"./llama.cpp-$LLAMA_CPP_VERSION\"", + "download:llama:darwin": "LLAMA_CPP_VERSION=$(cat ./scripts/version.txt) && download https://github.com/ggerganov/llama.cpp/archive/refs/tags/${LLAMA_CPP_VERSION}.tar.gz -o . --filename ./scripts/llama.cpp.tar.gz && tar -xzf ./scripts/llama.cpp.tar.gz '*/convert.py' '*/convert-hf-to-gguf.py' '*/gguf-py' && cpx \"./llama.cpp-$LLAMA_CPP_VERSION/**\" \"scripts\" && rimraf \"./scripts/llama.cpp.tar.gz\" && rimraf \"./llama.cpp-$LLAMA_CPP_VERSION\"", + "download:llama:win32": "download.bat", + "build:publish:linux": "rimraf *.tgz --glob && yarn build && yarn download:llama && cpx \"scripts/**\" \"dist/scripts\" && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish:darwin": "rimraf *.tgz --glob && yarn build && yarn download:llama && cpx \"scripts/**\" \"dist/scripts\" && cpx \"bin/**\" \"dist/bin\" && ../../.github/scripts/auto-sign.sh && npm pack && cpx *.tgz ../../pre-install", + "build:publish:win32": "rimraf *.tgz --glob && yarn build && yarn download:llama && cpx \"scripts/**\" \"dist/scripts\" && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish": "run-script-os" + }, + "devDependencies": { + "cpx": "^1.5.0", + "download-cli": "^1.1.1", + "rimraf": "^3.0.2", + "ts-loader": "^9.5.0", + "typescript": "5.3.3", + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@rollup/plugin-typescript": "^11.1.6", + "@types/pdf-parse": "^1.1.4", + "rollup": "^2.38.5", + "rollup-plugin-define": "^1.0.1", + "rollup-plugin-sourcemaps": "^0.6.3", + "rollup-plugin-typescript2": "^0.36.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "dependencies": { + "@janhq/core": "file:../../core", + "@huggingface/gguf": "^0.0.11", + "python-shell": "^5.0.0" + } +} diff --git a/extensions/model-extension/resources/default-model.json 
b/extensions/model-extension/resources/default-model.json new file mode 100644 index 0000000000..c02008cd64 --- /dev/null +++ b/extensions/model-extension/resources/default-model.json @@ -0,0 +1,36 @@ +{ + "object": "model", + "version": "1.0", + "format": "gguf", + "sources": [ + { + "url": "N/A", + "filename": "N/A" + } + ], + "id": "N/A", + "name": "N/A", + "created": 0, + "description": "User self import model", + "settings": { + "ctx_len": 2048, + "embedding": false, + "prompt_template": "{system_message}\n### Instruction: {prompt}\n### Response:", + "llama_model_path": "N/A" + }, + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "max_tokens": 2048, + "stop": ["<|END_OF_TURN_TOKEN|>", "", "[/INST]", "<|end_of_text|>", "<|eot_id|>", "<|im_end|>", "<|end|>"], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "User", + "tags": [], + "size": 0 + }, + "engine": "nitro" +} diff --git a/extensions/model-extension/resources/settings.json b/extensions/model-extension/resources/settings.json new file mode 100644 index 0000000000..d896f1271d --- /dev/null +++ b/extensions/model-extension/resources/settings.json @@ -0,0 +1,14 @@ +[ + { + "key": "hugging-face-access-token", + "title": "Hugging Face Access Token", + "description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.", + "controllerType": "input", + "controllerProps": { + "value": "", + "placeholder": "hf_**********************************", + "type": "password", + "inputActions": ["unobscure", "copy"] + } + } +] diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts new file mode 100644 index 0000000000..aa22bd1f6e --- /dev/null +++ b/extensions/model-extension/rollup.config.ts @@ -0,0 +1,46 @@ +import resolve from '@rollup/plugin-node-resolve' +import sourceMaps from 'rollup-plugin-sourcemaps' +import typescript from 'rollup-plugin-typescript2' +import json from '@rollup/plugin-json' +import replace from '@rollup/plugin-replace' + +const settingJson = require('./resources/settings.json') +const packageJson = require('./package.json') +const defaultModelJson = require('./resources/default-model.json') + +export default [ + { + input: `src/index.ts`, + output: [{ file: packageJson.main, format: 'es', sourcemap: true }], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: [], + watch: { + include: 'src/**', + }, + plugins: [ + replace({ + preventAssignment: true, + DEFAULT_MODEL: JSON.stringify(defaultModelJson), + SETTINGS: JSON.stringify(settingJson), + NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), + }), + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Compile TypeScript files + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + // commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.js', '.ts', '.svelte'], + browser: true, + }), + + // Resolve source maps to the original source + sourceMaps(), + ], + }, +] diff --git a/extensions/model-extension/scripts/convert-hf-to-gguf.py b/extensions/model-extension/scripts/convert-hf-to-gguf.py new file mode 
100755 index 0000000000..0d4ea03b44 --- /dev/null +++ b/extensions/model-extension/scripts/convert-hf-to-gguf.py @@ -0,0 +1,1720 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import contextlib +import json +import os +import re +import sys +from enum import IntEnum +from pathlib import Path +from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast + +import numpy as np +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +from convert import HfVocab + + +# check for any of the given keys in the dictionary and return the value of the first key found +def get_key_opts(d, keys): + for k in keys: + if k in d: + return d[k] + print(f"Could not find any of {keys}") + sys.exit() + + +###### MODEL DEFINITIONS ###### + +class SentencePieceTokenTypes(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class Model: + def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool): + self.dir_model = dir_model + self.ftype = ftype + self.fname_out = fname_out + self.is_big_endian = is_big_endian + self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.is_safetensors = self._is_model_safetensors() + self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") + self.part_names = self._get_part_names() + self.hparams = Model.load_hparams(self.dir_model) + self.model_arch = self._get_model_architecture() + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) + + def set_vocab(self): + self._set_vocab_gpt2() + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for part_name in self.part_names: + print(f"gguf: loading model part '{part_name}'") + ctx: ContextManager[Any] + if self.is_safetensors: + from safetensors import safe_open + ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + else: + ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + + with ctx as model_part: + for name in model_part.keys(): + data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] + yield name, data + + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.hparams.get( + "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + )) + if (n_ctx := self.hparams.get("max_position_embeddings")) is not None: + self.gguf_writer.add_context_length(n_ctx) + if (n_embd := self.hparams.get("hidden_size")) is not None: + self.gguf_writer.add_embedding_length(n_embd) + if (n_ff := self.hparams.get("intermediate_size")) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + if (n_head := self.hparams.get("num_attention_heads")) is not None: + self.gguf_writer.add_head_count(n_head) + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + + if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps) + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (n_experts_used := 
self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + + self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + def write(self): + self.write_tensors() + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file() + self.gguf_writer.close() + + def write_vocab(self): + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.close() + + @staticmethod + def count_model_parts(dir_model: Path, prefix: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.endswith(prefix): + num_parts += 1 + + return num_parts + + @staticmethod + def load_hparams(dir_model): + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def from_model_architecture(model_architecture): + if model_architecture == "GPTNeoXForCausalLM": + return GPTNeoXModel + if model_architecture == "BloomForCausalLM": + return BloomModel + if model_architecture == "MPTForCausalLM": + return MPTModel + if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): + return BaichuanModel + if model_architecture in ("FalconForCausalLM", "RWForCausalLM"): + return FalconModel + if model_architecture == "GPTBigCodeForCausalLM": + return StarCoderModel + if model_architecture == "GPTRefactForCausalLM": + return RefactModel + if model_architecture == "PersimmonForCausalLM": + return PersimmonModel + if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): + return StableLMModel + if model_architecture == "QWenLMHeadModel": + return QwenModel + if model_architecture == "Qwen2ForCausalLM": + return Model + if model_architecture == "MixtralForCausalLM": + return MixtralModel + if model_architecture == "GPT2LMHeadModel": + return GPT2Model + if model_architecture == "PhiForCausalLM": + return Phi2Model + if model_architecture == 
"PlamoForCausalLM": + return PlamoModel + if model_architecture == "CodeShellForCausalLM": + return CodeShellModel + if model_architecture == "OrionForCausalLM": + return OrionModel + if model_architecture == "InternLM2ForCausalLM": + return InternLM2Model + if model_architecture == "MiniCPMForCausalLM": + return MiniCPMModel + return Model + + def _is_model_safetensors(self) -> bool: + return Model.count_model_parts(self.dir_model, ".safetensors") > 0 + + def _get_part_names(self): + if self.is_safetensors: + if self.num_parts == 1: # there's only one .safetensors file + return ("model.safetensors",) + return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) + + if self.num_parts == 1: # there's only one .bin file + return ("pytorch_model.bin",) + return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) + + def _get_model_architecture(self) -> gguf.MODEL_ARCH: + arch = self.hparams["architectures"][0] + if arch == "GPTNeoXForCausalLM": + return gguf.MODEL_ARCH.GPTNEOX + if arch == "BloomForCausalLM": + return gguf.MODEL_ARCH.BLOOM + if arch == "MPTForCausalLM": + return gguf.MODEL_ARCH.MPT + if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): + return gguf.MODEL_ARCH.BAICHUAN + if arch in ("FalconForCausalLM", "RWForCausalLM"): + return gguf.MODEL_ARCH.FALCON + if arch == "GPTBigCodeForCausalLM": + return gguf.MODEL_ARCH.STARCODER + if arch == "GPTRefactForCausalLM": + return gguf.MODEL_ARCH.REFACT + if arch == "PersimmonForCausalLM": + return gguf.MODEL_ARCH.PERSIMMON + if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): + return gguf.MODEL_ARCH.STABLELM + if arch == "QWenLMHeadModel": + return gguf.MODEL_ARCH.QWEN + if arch == "Qwen2ForCausalLM": + return gguf.MODEL_ARCH.QWEN2 + if arch == "MixtralForCausalLM": + return gguf.MODEL_ARCH.LLAMA + if arch == "GPT2LMHeadModel": + return gguf.MODEL_ARCH.GPT2 + if arch == "PhiForCausalLM": + return gguf.MODEL_ARCH.PHI2 + if arch == "PlamoForCausalLM": + return gguf.MODEL_ARCH.PLAMO + if arch == "CodeShellForCausalLM": + return gguf.MODEL_ARCH.CODESHELL + if arch == "OrionForCausalLM": + return gguf.MODEL_ARCH.ORION + if arch == "InternLM2ForCausalLM": + return gguf.MODEL_ARCH.INTERNLM2 + if arch == "MiniCPMForCausalLM": + return gguf.MODEL_ARCH.MINICPM + + raise NotImplementedError(f'Architecture "{arch}" not supported!') + + def _set_vocab_gpt2(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[bytearray] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model) + vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + pad_token = f"[PAD{i}]".encode('utf-8') + tokens.append(bytearray(pad_token)) + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = 
gguf.SpecialVocab(dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_qwen(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[bytearray] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams["vocab_size"] + assert max(tokenizer.get_vocab().values()) < vocab_size + + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + assert len(merged) == 2 + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined + added_vocab = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()} + + for i in range(vocab_size): + if i not in reverse_vocab: + pad_token = f"[PAD{i}]".encode("utf-8") + tokens.append(bytearray(pad_token)) + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.CONTROL) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) + special_vocab.merges = merges + # only add special tokens when they were not already loaded from config.json + if len(special_vocab.special_token_ids) == 0: + special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) + # this one is usually not in config.json anyway + special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_sentencepiece(self): + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(vocab_size): + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + tokens.append(key.encode("utf-8")) + scores.append(-1000.0) + 
toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_hf(self): + path = self.dir_model + added_tokens_path = self.dir_model + vocab = HfVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + +class GPTNeoXModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count( + int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), + ) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + + +class BloomModel(Model): + def set_gguf_parameters(self): + self.gguf_writer.add_name("Bloom") + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) + self.gguf_writer.add_embedding_length(n_embed) + self.gguf_writer.add_feed_forward_length(4 * n_embed) + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + block_count = self.hparams["n_layer"] + tensors = dict(self.get_tensors()) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + has_lm_head = True + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + + for name, data_torch in tensors.items(): + if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys(): + has_lm_head = False + + name = re.sub(r'transformer\.', '', name) + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed)) + data = np.concatenate( + ( + qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed)), + ), + axis=0, + ) + print("re-format attention.linear_qkv.weight") + elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name): + qkv_bias = data.reshape((n_head, 3, n_embed // n_head)) + data = np.concatenate( + ( + qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,)), + ), + axis=0, + ) + print("re-format attention.linear_qkv.bias") + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "word_embeddings.weight": + self.gguf_writer.add_tensor("output.weight", data) + print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + +class MPTModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layers"] + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) + self.gguf_writer.add_head_count(self.hparams["n_heads"]) + if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): + self.gguf_writer.add_head_count_kv(kv_n_heads) + self.gguf_writer.add_layer_norm_eps(1e-5) + if self.hparams["attn_config"]["clip_qkv"] is not None: + self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"]) + self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"]) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers")) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + if "scales" in name: + new_name = tensor_map.get_name(name, 
try_suffixes=(".weight", ".bias", ".scales")) + if new_name is not None: + new_name = new_name.replace("scales", "act.scales") + else: + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + # note: MPT output is tied to (same as) wte in original model; + # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ + if new_name == "token_embd.weight": + self.gguf_writer.add_tensor("output.weight", data) + + +class OrionModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + hf_repo = self.hparams.get("_name_or_path", "") + + ctx_length = 0 + if "max_sequence_length" in self.hparams: + ctx_length = self.hparams["max_sequence_length"] + elif "max_position_embeddings" in self.hparams: + ctx_length = self.hparams["max_position_embeddings"] + elif "model_max_length" in self.hparams: + ctx_length = self.hparams["model_max_length"] + else: + print("gguf: can not find ctx length parameter.") + sys.exit() + + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_source_hf_repo(hf_repo) + self.gguf_writer.add_tensor_data_layout("Meta AI original pth") + self.gguf_writer.add_context_length(ctx_length) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"]) + + def write_tensors(self): + # Collect tensors from generator object + model_kv = dict(self.get_tensors()) + block_count = self.hparams["num_hidden_layers"] + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in model_kv.items(): + # we don't need these + if name.endswith(".rotary_emb.inv_freq"): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why 
cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + + +class BaichuanModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + hf_repo = self.hparams.get("_name_or_path", "") + + ctx_length = 0 + if "max_sequence_length" in self.hparams: + ctx_length = self.hparams["max_sequence_length"] + elif "max_position_embeddings" in self.hparams: + ctx_length = self.hparams["max_position_embeddings"] + elif "model_max_length" in self.hparams: + ctx_length = self.hparams["model_max_length"] + else: + print("gguf: can not find ctx length parameter.") + sys.exit() + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_source_hf_repo(hf_repo) + self.gguf_writer.add_tensor_data_layout("Meta AI original pth") + self.gguf_writer.add_context_length(ctx_length) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + def write_tensors(self): + # Collect tensors from generator object + model_kv = dict(self.get_tensors()) + block_count = self.hparams["num_hidden_layers"] + head_count = self.hparams["num_attention_heads"] + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + head_count_kv = self.hparams.get("num_key_value_heads", head_count) + + for i in range(block_count): + if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None: + print(f"Unpacking and permuting layer {i}") + model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \ + self._reverse_hf_permute_part(w, 0, head_count, head_count) + model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \ + self._reverse_hf_permute_part(w, 1, head_count, head_count_kv) + model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \ + self._reverse_hf_part(w, 2) + del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"] + + for name, data_torch in model_kv.items(): + # we don't need these + if name.endswith(".rotary_emb.inv_freq"): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = 
tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + def _reverse_hf_permute_part( + self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None, + ) -> Tensor: + r = weights.shape[0] // 3 + return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv) + + def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor: + r = weights.shape[0] // 3 + return weights[r * n_part:r * n_part + r, ...] + + +class FalconModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams.get("num_hidden_layers") + if block_count is None: + block_count = self.hparams["n_layer"] # old name + + n_head = self.hparams.get("num_attention_heads") + if n_head is None: + n_head = self.hparams["n_head"] # old name + + n_head_kv = self.hparams.get("num_kv_heads") + if n_head_kv is None: + n_head_kv = self.hparams.get("n_head_kv", 1) # old name + + self.gguf_writer.add_name("Falcon") + self.gguf_writer.add_context_length(2048) # not in config.json + self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + block_count = self.hparams.get("num_hidden_layers") + if block_count is None: + block_count = self.hparams["n_layer"] # old name + + n_head = self.hparams.get("num_attention_heads") + if n_head is None: + n_head = self.hparams["n_head"] # old name + + n_head_kv = self.hparams.get("num_kv_heads") + if n_head_kv is None: + n_head_kv = self.hparams.get("n_head_kv", 1) # old name + + head_dim = self.hparams["hidden_size"] // n_head + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # QKV tensor transform + # The original query_key_value tensor contains n_head_kv "kv groups", + # each consisting of n_head/n_head_kv query weights followed by 
one key + # and one value weight (shared by all query heads in the kv group). + # This layout makes it a big pain to work with in GGML. + # So we rearrange them here, so that we have n_head query weights + # followed by n_head_kv key weights followed by n_head_kv value weights, + # in contiguous fashion. + # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py + + if "query_key_value" in name: + qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head) + k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) + v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) + data_torch = torch.cat((q, k, v)).reshape_as(data_torch) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + +class StarCoderModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("StarCoder") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(1) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + +class RefactModel(Model): + def set_gguf_parameters(self): + hidden_dim = self.hparams["n_embd"] + inner_dim = 4 * hidden_dim + hidden_dim = int(2 * inner_dim / 3) + multiple_of = 256 + ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("Refact") + # refact uses Alibi. So this is from config.json which might be used by training.
+ self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + + self.gguf_writer.add_feed_forward_length(ff_dim) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(1) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + hidden_dim = self.hparams["n_embd"] + inner_dim = 4 * hidden_dim + hidden_dim = int(2 * inner_dim / 3) + multiple_of = 256 + ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + n_head = self.hparams["n_head"] + n_head_kv = 1 + head_dim = self.hparams["n_embd"] // n_head + block_count = self.hparams["n_layer"] + + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + tensors = dict(self.get_tensors()) + for i in range(block_count): + if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None: + tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim] + tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:] + del tensors[f"transformer.h.{i}.attn.kv.weight"] + if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None: + tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w + del tensors[f"transformer.h.{i}.attn.q.weight"] + if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None: + tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim] + tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:] + del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"] + + for name, data_torch in tensors.items(): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight",)) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + +class PersimmonModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + head_count = self.hparams["num_attention_heads"] + head_count_kv = head_count + hidden_size = self.hparams["hidden_size"] + + self.gguf_writer.add_name('persimmon-8b-chat') + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller + # than the head size? + # ref: https://github.com/ggerganov/llama.cpp/pull/4889 + # self.gguf_writer.add_rope_dimension_count(hidden_size // head_count) + self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2) + + self.gguf_writer.add_head_count(head_count) + self.gguf_writer.add_head_count_kv(head_count_kv) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + + def set_vocab(self): + self._set_vocab_sentencepiece() + # self.gguf_writer.add_bos_token_id(71013) + # self.gguf_writer.add_eos_token_id(71013) + + def write_tensors(self): + block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + if name.endswith(".self_attention.rotary_emb.inv_freq"): + continue + old_dtype = data_torch.dtype + # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?) 
+ data = data_torch.to(torch.float32).squeeze().numpy() + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + n_dims = len(data.shape) + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + + +class StableLMModel(Model): + def set_vocab(self): + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab + self._set_vocab_qwen() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) + self.gguf_writer.add_layer_norm_eps(1e-5) + + +class MixtralModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + +class MiniCPMModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + self.gguf_writer.add_name("MiniCPM") + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + def set_vocab(self): + self._set_vocab_hf() + + def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # HF models permute some of the tensors, so we need to undo that + if name.endswith(("q_proj.weight")): + data_torch = self._reverse_hf_permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight")): + data_torch = 
self._reverse_hf_permute(data_torch, n_head, n_kv_head) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + +class QwenModel(Model): + @staticmethod + def token_bytes_to_string(b): + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() + return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + + @staticmethod + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def set_vocab(self): + self._set_vocab_qwen() + + def set_gguf_parameters(self): + self.gguf_writer.add_name("Qwen") + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + + def write_tensors(self): + block_count = self.hparams["num_hidden_layers"] + model_kv = dict(self.get_tensors()) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in model_kv.items(): + # we don't need these + if name.endswith(".rotary_emb.inv_freq"): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we 
use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + + +class GPT2Model(Model): + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_context_length(self.hparams["n_ctx"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")): + continue + + if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")): + data_torch = data_torch.transpose(1, 0) + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + # note: GPT2 output is tied to (same as) wte in original model + if new_name == "token_embd.weight": + print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor("output.weight", data) + + +class Phi2Model(Model): + def set_gguf_parameters(self): + block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"]) + + rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"]) + n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"]) + n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"]) + + self.gguf_writer.add_name("Phi2") + self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"])) + + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(4 * n_embd) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"])) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_add_bos_token(False) + + +class PlamoModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_name("PLaMo") + self.gguf_writer.add_context_length(4096) # not in config.json + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + + def shuffle_attn_q_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(8, 5, 128, 5120) + data_torch = torch.permute(data_torch, (1, 0, 2, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def shuffle_attn_output_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(5120, 8, 5, 128) + data_torch = torch.permute(data_torch, (0, 2, 1, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def write_tensors(self): + block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + if "self_attn.rotary_emb.inv_freq" in name: + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + # shuffle for broadcasting of gqa in ggml_mul_mat + if new_name.endswith("attn_q.weight"): + data_torch = 
self.shuffle_attn_q_weight(data_torch) + elif new_name.endswith("attn_output.weight"): + data_torch = self.shuffle_attn_output_weight(data_torch) + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + +class CodeShellModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("CodeShell") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(10000.0) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + tensors = dict(self.get_tensors()) + has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys() + for name, data_torch in tensors.items(): + # we don't need these + if name.endswith((".attn.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "transformer.wte.weight": + self.gguf_writer.add_tensor("output.weight", data) + print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + +class InternLM2Model(Model): + def set_vocab(self): + # (TODO): Is there a better way? + # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character + # \x00 specially and convert it into an emoji character to prevent it from being mistakenly + # recognized as an empty string in C++. + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + + sentencepiece_model = model.ModelProto() + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(vocab_size): + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + if text == b"\x00": + # (TODO): fixme + # Hack here and replace the \x00 characters. + print(f"InternLM2 convert token '{text}' to '🐉'!") + text = "🐉" + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + tokens.append(key.encode("utf-8")) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + old_eos = special_vocab.special_token_ids["eos"] + if "chat" in os.path.basename(self.dir_model.absolute()): + # For the chat model, we replace the eos with '<|im_end|>'. 
+ special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer) + print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \ +in chat mode so that the conversation can end normally.") + + special_vocab.add_to_gguf(self.gguf_writer) + + def _try_get_sft_eos(self, tokenizer): + unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]') + im_end_list = tokenizer.encode('<|im_end|>') + assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1) + if len(unused_145_list) == 1: + eos_token = unused_145_list[0] + if len(im_end_list) == 1: + eos_token = im_end_list[0] + return eos_token + + def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def set_gguf_parameters(self): + self.gguf_writer.add_name("InternLM2") + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + + def post_write_tensors(self, tensor_map, name, data_torch): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + + def write_tensors(self): + from einops import rearrange + + num_heads = self.hparams.get("num_attention_heads") + num_kv_heads = self.hparams.get("num_key_value_heads") + hidden_size = self.hparams.get("hidden_size") + q_per_kv = num_heads // num_kv_heads + head_dim = hidden_size // num_heads + num_groups = num_heads // q_per_kv + + block_count = self.hparams["num_hidden_layers"] + model_kv = dict(self.get_tensors()) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv" + for name, data_torch in model_kv.items(): + # we don't need these + if name.endswith(".rotary_emb.inv_freq"): + continue + + if re.match(qkv_pattern, name): + bid = re.findall(qkv_pattern, name)[0] + qkv = data_torch + qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) + q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :] + # The model weights of q and k require additional reshape. + q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads) + k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads) + v = rearrange(v, " o g n i -> o (g n i)").T + self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q) + self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k) + self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v) + else: + self.post_write_tensors(tensor_map, name, data_torch) + + +###### CONVERSION LOGIC ###### + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--awq-path", type=Path, default=None, + help="Path to scale awq cache file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16"], default="f16", + help="output format - use f32 for float32, f16 for float16", + ) + parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + parser.add_argument( + "model", type=Path, + help="directory containing model file", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + dir_model = args.model + + if args.awq_path: + sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) + from awq.apply_awq import add_scale_weights # type: ignore[import-not-found] + tmp_model_path = args.model / "weighted_model" + dir_model = tmp_model_path + if tmp_model_path.is_dir(): + print(f"{tmp_model_path} exists as a weighted model.") + else: + tmp_model_path.mkdir(parents=True, exist_ok=True) + print("Saving new weighted model ...") + add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path)) + print(f"Saved weighted model at
{tmp_model_path}.") + + if not dir_model.is_dir(): + print(f'Error: {args.model} is not a directory', file=sys.stderr) + sys.exit(1) + + ftype_map = { + "f32": gguf.GGMLQuantizationType.F32, + "f16": gguf.GGMLQuantizationType.F16, + } + + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_model / f'ggml-model-{args.outtype}.gguf' + + print(f"Loading model: {dir_model.name}") + + hparams = Model.load_hparams(dir_model) + + with torch.inference_mode(): + model_class = Model.from_model_architecture(hparams["architectures"][0]) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian) + + print("Set model parameters") + model_instance.set_gguf_parameters() + + print("Set model tokenizer") + model_instance.set_vocab() + + if args.vocab_only: + print(f"Exporting model vocab to '{fname_out}'") + model_instance.write_vocab() + else: + print(f"Exporting model to '{fname_out}'") + model_instance.write() + + print(f"Model successfully exported to '{fname_out}'") + + +if __name__ == '__main__': + main() diff --git a/extensions/model-extension/scripts/convert.py b/extensions/model-extension/scripts/convert.py new file mode 100755 index 0000000000..323e8058d5 --- /dev/null +++ b/extensions/model-extension/scripts/convert.py @@ -0,0 +1,1478 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import concurrent.futures +import enum +import faulthandler +import functools +import itertools +import json +import math +import mmap +import os +import pickle +import re +import signal +import struct +import sys +import time +import zipfile +from abc import ABCMeta, abstractmethod +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar + +import numpy as np +from sentencepiece import SentencePieceProcessor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +if TYPE_CHECKING: + from typing import TypeAlias + +if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): + faulthandler.register(signal.SIGUSR1) + +NDArray: TypeAlias = 'np.ndarray[Any, Any]' + +ARCH = gguf.MODEL_ARCH.LLAMA + +DEFAULT_CONCURRENCY = 8 + +# +# data types +# + + +@dataclass(frozen=True) +class DataType: + name: str + dtype: np.dtype[Any] + valid_conversions: list[str] + + def elements_to_bytes(self, n_elements: int) -> int: + return n_elements * self.dtype.itemsize + + +@dataclass(frozen=True) +class UnquantizedDataType(DataType): + pass + + +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) + + +@dataclass(frozen=True) +class QuantizedDataType(DataType): + block_size: int + quantized_dtype: np.dtype[Any] + ggml_type: gguf.GGMLQuantizationType + + def quantize(self, arr: NDArray) -> NDArray: + raise NotImplementedError(f'Quantization for {self.name} not implemented') + + def elements_to_bytes(self, n_elements: int) -> int: + assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} 
for {self.name} with block size {self.block_size}' + return self.quantized_dtype.itemsize * (n_elements // self.block_size) + + +@dataclass(frozen=True) +class Q8_0QuantizedDataType(QuantizedDataType): + # Mini Q8_0 quantization in Python! + def quantize(self, arr: NDArray) -> NDArray: + assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' + assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + n_blocks = arr.size // self.block_size + blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre + + def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: + d = abs(blocks).max(axis = 1) / np.float32(127) + with np.errstate(divide = 'ignore'): + qs = (blocks / d[:, None]).round() + qs[d == 0] = 0 + yield from zip(d, qs) + return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) + + +DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', + dtype = np.dtype(np.float32), valid_conversions = [], + ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, + quantized_dtype = np.dtype([('d', ' DataType: + dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self) + if dt is None: + raise ValueError(self) + # 1D tensors are always F32. + return dt if len(tensor.shape) > 1 else DT_F32 + + +GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { + GGMLFileType.AllF32 : DT_F32, + GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.MostlyQ8_0: DT_Q8_0, +} + +# +# hparams loading +# + + +@dataclass +class Params: + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None + + rope_scaling_type: gguf.RopeScalingType | None = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None + n_orig_ctx: int | None = None + rope_finetuned: bool | None = None + + ftype: GGMLFileType | None = None + + # path to the directory containing the model files + path_model: Path | None = None + + @staticmethod + def guessed(model: LazyModel) -> Params: + # try transformer naming first + n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + + # try transformer naming first + if "model.layers.0.self_attn.q_proj.weight" in model: + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) + elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + else: + n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + + if n_layer < 1: + raise Exception("failed to guess 'n_layer'. 
This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed + + # TODO: verify this + n_ff = int(2 * (4 * n_embd) / 3) + n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) + + return Params( + n_vocab = n_vocab, + n_embd = n_embd, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, + ) + + @staticmethod + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: + config = json.load(open(config_path)) + + rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None + rope_scaling = config.get("rope_scaling") + + if rope_scaling is not None and (typ := rope_scaling.get("type")): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = gguf.RopeScalingType.LINEAR + elif typ == "yarn": + rope_scaling_type = gguf.RopeScalingType.YARN + n_orig_ctx = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] + else: + raise NotImplementedError(f'Unknown rope scaling type: {typ}') + + if "max_sequence_length" in config: + n_ctx = config["max_sequence_length"] + elif "max_position_embeddings" in config: + n_ctx = config["max_position_embeddings"] + else: + raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + + n_experts = None + n_experts_used = None + + if "num_local_experts" in config: + n_experts = config["num_local_experts"] + n_experts_used = config["num_experts_per_tok"] + + return Params( + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = (n_head := config["num_attention_heads"]), + n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, + f_rope_scale = f_rope_scale, + n_orig_ctx = n_orig_ctx, + rope_finetuned = rope_finetuned, + ) + + # LLaMA v2 70B params.json + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} + @staticmethod + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: + config = json.load(open(config_path)) + + n_experts = None + n_experts_used = None + f_rope_freq_base = None + + # hack to determine LLaMA v1 vs v2 vs CodeLlama + if config.get("moe"): + # Mixtral + n_ctx = 32768 + elif config.get("rope_theta") == 1000000: + # CodeLlama + n_ctx = 16384 + elif config["norm_eps"] == 1e-05: + # LLaMA v2 + n_ctx = 4096 + else: + # LLaMA v1 + n_ctx = 2048 + + if "layers.0.feed_forward.w1.weight" in model: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + + if config.get("moe"): + n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] + n_experts = config["moe"]["num_experts"] + n_experts_used = config["moe"]["num_experts_per_tok"] + f_rope_freq_base = 1e6 + + return Params( + n_vocab = model["tok_embeddings.weight"].shape[0], + n_embd = config["dim"], + n_layer = config["n_layers"], + n_ctx = n_ctx, + n_ff = n_ff, + n_head = (n_head := config["n_heads"]), + n_head_kv = config.get("n_kv_heads", 
n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["norm_eps"], + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + ) + + @staticmethod + def load(model_plus: ModelPlus) -> Params: + hf_config_path = model_plus.paths[0].parent / "config.json" + orig_config_path = model_plus.paths[0].parent / "params.json" + + if hf_config_path.exists(): + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + elif orig_config_path.exists(): + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) + elif model_plus.format != 'none': + params = Params.guessed(model_plus.model) + else: + raise ValueError('Cannot guess params when model format is none') + + params.path_model = model_plus.paths[0].parent + + return params + + +# +# vocab +# + +class BpeVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: + self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) + if isinstance(self.bpe_tokenizer.get('model'), dict): + self.vocab = self.bpe_tokenizer["model"]["vocab"] + else: + self.vocab = self.bpe_tokenizer + added_tokens: dict[str, int] + if fname_added_tokens is not None: + # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) + else: + # Fall back to trying to find the added tokens in tokenizer.json + tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json' + if not tokenizer_json_file.is_file(): + added_tokens = {} + else: + tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) + added_tokens = dict( + (item['content'], item['id']) + for item in tokenizer_json.get('added_tokens', []) + # Added tokens here can be duplicates of the main vocabulary. 
+ if item['content'] not in self.bpe_tokenizer) + + vocab_size: int = len(self.vocab) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) + if expected_ids != actual_ids: + expected_end_id = vocab_size + len(actual_ids) - 1 + raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}") + + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base: int = vocab_size + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} + + for i, _ in enumerate(self.vocab): + yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.CONTROL + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.bpe_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class SentencePieceVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: + self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) + added_tokens: dict[str, int] + if fname_added_tokens is not None: + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) + else: + added_tokens = {} + + vocab_size: int = self.sentencepiece_tokenizer.vocab_size() + + new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) + actual_new_ids = sorted(new_tokens.keys()) + + if expected_new_ids != actual_new_ids: + raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + + # Token pieces that were added to the base vocabulary. + self.added_tokens_dict = added_tokens + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.sentencepiece_tokenizer + for i in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(i) + text: bytes = piece.encode("utf-8") + score: float = tokenizer.get_score(i) + + toktype = gguf.TokenType.NORMAL + if tokenizer.is_unknown(i): + toktype = gguf.TokenType.UNKNOWN + if tokenizer.is_control(i): + toktype = gguf.TokenType.CONTROL + + # NOTE: I think added_tokens are user defined. 
+ # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED + + if tokenizer.is_unused(i): + toktype = gguf.TokenType.UNUSED + if tokenizer.is_byte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.sentencepiece_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class HfVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None: + try: + from transformers import AutoTokenizer + except ImportError as e: + raise ImportError( + "To use HfVocab, please install the `transformers` package. " + "You can install it with `pip install transformers`." + ) from e + + print("fname_tokenizer:", fname_tokenizer) + # Allow the tokenizer to default to slow or fast versions. + # Explicitly set tokenizer to use local paths. + self.tokenizer = AutoTokenizer.from_pretrained( + fname_tokenizer, + cache_dir=fname_tokenizer, + local_files_only=True, + ) + + # Initialize lists and dictionaries for added tokens + self.added_tokens_list = [] + self.added_tokens_dict = dict() + self.added_tokens_ids = set() + + # Process added tokens + for tok, tokidx in sorted( + self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] + ): + # Only consider added tokens that are not in the base vocabulary + if tokidx >= self.tokenizer.vocab_size: + self.added_tokens_list.append(tok) + self.added_tokens_dict[tok] = tokidx + self.added_tokens_ids.add(tokidx) + + # Store special tokens and their IDs + self.specials = { + tok: self.tokenizer.get_vocab()[tok] + for tok in self.tokenizer.all_special_tokens + } + self.special_ids = set(self.tokenizer.all_special_ids) + + # Set vocabulary sizes + self.vocab_size_base = self.tokenizer.vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = { + id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() + } + + for token_id in range(self.vocab_size_base): + # Skip processing added tokens here + if token_id in self.added_tokens_ids: + continue + + # Convert token text to bytes + token_text = reverse_vocab[token_id].encode("utf-8") + + # Yield token text, score, and type + yield token_text, self.get_token_score(token_id), self.get_token_type( + token_id, token_text, self.special_ids # Reuse already stored special IDs + ) + + def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType: + # Special case for byte tokens + if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + return gguf.TokenType.BYTE + + # Determine token type based on whether it's a special token + return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + + def get_token_score(self, token_id: int) -> float: + # Placeholder for actual logic to determine the token's score + # This needs to be implemented based on specific requirements + return -1000.0 # Default score + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for 
text in self.added_tokens_list: + if text in self.specials: + toktype = self.get_token_type(self.specials[text], b'', self.special_ids) + score = self.get_token_score(self.specials[text]) + else: + toktype = gguf.TokenType.USER_DEFINED + score = -1000.0 + + yield text.encode("utf-8"), score, toktype + + def has_newline_token(self): + return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.hf_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab" + + +# +# data loading +# TODO: reuse (probably move to gguf.py?) +# + + +def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: + # print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +class Tensor(metaclass=ABCMeta): + data_type: DataType + + @abstractmethod + def astype(self, data_type: DataType) -> Tensor: ... + @abstractmethod + def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... + @abstractmethod + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... + @abstractmethod + def part(self, n_part: int) -> UnquantizedTensor: ... + @abstractmethod + def to_ggml(self) -> GGMLCompatibleTensor: ... + + +def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: + assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + fp32_arr = bf16_arr.astype(np.uint32) << 16 + return fp32_arr.view(np.float32) + + +class UnquantizedTensor(Tensor): + def __init__(self, ndarray: NDArray) -> None: + assert isinstance(ndarray, np.ndarray) + self.ndarray = ndarray + self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] + + def astype(self, data_type: DataType) -> Tensor: + dtype = data_type.dtype + if self.data_type == DT_BF16: + self.ndarray = bf16_to_fp32(self.ndarray) + return UnquantizedTensor(self.ndarray.astype(dtype)) + + def to_ggml(self) -> UnquantizedTensor: + return self + + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + + def part(self, n_part: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) + + def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: + return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) + + +def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: + tensor = lazy_tensor.load() + assert isinstance(tensor, UnquantizedTensor) + + # double-check: + actual_shape = list(tensor.ndarray.shape) + assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape) + if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype: + if convert: + tensor.ndarray = tensor.ndarray.astype(expected_dtype) + else: + raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + + return tensor.ndarray + + +GGMLCompatibleTensor = 
UnquantizedTensor + + +@dataclass +class LazyTensor: + _load: Callable[[], Tensor] + shape: list[int] + data_type: DataType + description: str + + def load(self) -> Tensor: + ret = self._load() + # Should be okay if it maps to the same numpy type? + assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ + (self.data_type, ret.data_type, self.description) + return ret + + def astype(self, data_type: DataType) -> LazyTensor: + self.validate_conversion_to(data_type) + + def load() -> Tensor: + return self.load().astype(data_type) + return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + def validate_conversion_to(self, data_type: DataType) -> None: + if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: + raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + + +LazyModel: TypeAlias = 'dict[str, LazyTensor]' + + +@dataclass +class ModelPlus: + model: LazyModel + paths: list[Path] # Where this was read from. + format: Literal['ggml', 'torch', 'safetensors', 'none'] + vocab: Vocab | None # For GGML models (which have vocab built in), the vocab. + + +def merge_sharded(models: list[LazyModel]) -> LazyModel: + # Original LLaMA models have each file contain one part of each tensor. + # Use a dict instead of a set to preserve order. + names = {name: None for model in models for name in model} + + def convert(name: str) -> LazyTensor: + lazy_tensors: list[LazyTensor] = [model[name] for model in models] + if len(lazy_tensors) == 1: + # only one file; don't go through this procedure since there might + # be quantized tensors + return lazy_tensors[0] + if len(lazy_tensors[0].shape) == 1: + # the tensor is just duplicated in every file + return lazy_tensors[0] + if name.startswith('tok_embeddings.') or \ + name.endswith('.attention.wo.weight') or \ + name.endswith('.feed_forward.w2.weight'): + # split by columns + axis = 1 + else: + # split by rows + axis = 0 + concatenated_shape = list(lazy_tensors[0].shape) + concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors) + + def load() -> UnquantizedTensor: + ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] + concatenated: NDArray = np.concatenate(ndarrays, axis=axis) + return UnquantizedTensor(concatenated) + description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' + return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + return {name: convert(name) for name in names} + + +def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: + formats = set(mp.format for mp in models_plus) + assert len(formats) == 1, "different formats?" + format = formats.pop() + paths = [path for mp in models_plus for path in mp.paths] + # Use the first non-None vocab, if any. + try: + vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None) + except StopIteration: + vocab = None + + if any("model.embed_tokens.weight" in mp.model for mp in models_plus): + # Transformers models put different tensors in different files, but + # don't split individual tensors between files. 
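As an aside on the lazy-loading machinery above: a `LazyTensor` is just a stored zero-argument loader plus shape/dtype metadata, and `astype` wraps that loader in another deferred call instead of touching any data. A minimal sketch of the same pattern (simplified names, illustrative only, not part of the patch):

```python
import numpy as np

# Simplified stand-in for LazyTensor: nothing is read until load() is called,
# and astype() merely composes a new deferred loader.
class MiniLazyTensor:
    def __init__(self, load, shape, dtype, description):
        self._load, self.shape, self.dtype, self.description = load, shape, dtype, description

    def load(self):
        return self._load()

    def astype(self, dtype):
        return MiniLazyTensor(lambda: self.load().astype(dtype),
                              self.shape, dtype, f"convert({dtype}) {self.description}")


lazy = MiniLazyTensor(lambda: np.ones((2, 3), dtype=np.float32), [2, 3], np.float32, "demo")
converted = lazy.astype(np.float16)   # still nothing materialized
print(converted.load().dtype)         # float16, loaded only here
```

This is what lets `merge_sharded` and `merge_multifile_models` collect tensors from many shard files while deferring the actual reads until the tensors are finally written out.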
+ model: LazyModel = {} + for mp in models_plus: + model.update(mp.model) + else: + model = merge_sharded([mp.model for mp in models_plus]) + + return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types + + +def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute(n_head, n_head_kv) + return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + + +# Functionality that simulates `torch.load` but where individual tensors are +# only loaded into memory on demand, not all at once. +# PyTorch can't do this natively as of time of writing: +# - https://github.com/pytorch/pytorch/issues/64327 +# This allows us to de-shard without multiplying RAM usage, and also +# conveniently drops the PyTorch dependency (though we still need numpy). + + +@dataclass +class LazyStorageKind: + data_type: DataType + + +@dataclass +class LazyStorage: + load: Callable[[int, int], NDArray] + kind: LazyStorageKind + description: str + + +class LazyUnpickler(pickle.Unpickler): + def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): + super().__init__(fp) + self.data_base_path = data_base_path + self.zip_file = zip_file + + def persistent_load(self, pid: Any) -> Any: + assert pid[0] == 'storage' + assert isinstance(pid[1], LazyStorageKind) + data_type = pid[1].data_type + filename_stem = pid[2] + filename = f'{self.data_base_path}/{filename_stem}' + info = self.zip_file.getinfo(filename) + + def load(offset: int, elm_count: int) -> NDArray: + dtype = data_type.dtype + fp = self.zip_file.open(info) + fp.seek(offset * dtype.itemsize) + size = elm_count * dtype.itemsize + data = fp.read(size) + assert len(data) == size + return np.frombuffer(data, dtype) + description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + return LazyStorage(load=load, kind=pid[1], description=description) + + @staticmethod + def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, + requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + assert isinstance(storage, LazyStorage) + + def load() -> UnquantizedTensor: + elm_count = stride[0] * size[0] + return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) + description = f'pickled storage_offset={storage_offset} in {storage.description}' + return LazyTensor(load, list(size), storage.kind.data_type, description) + + @staticmethod + def rebuild_from_type_v2(func, new_type, args, state): + return func(*args) + + CLASSES: dict[tuple[str, str], Any] = { + # getattr used here as a workaround for mypy not being smart enough to determine + # the staticmethods have a __func__ attribute. 
+ ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), + ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), + ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), + ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), + ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), + ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, + } + + def find_class(self, module: str, name: str) -> Any: + if not module.startswith('torch'): + return super().find_class(module, name) + return self.CLASSES[(module, name)] + + +def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: + zf = zipfile.ZipFile(outer_fp) + pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + assert len(pickle_paths) == 1, pickle_paths + pickle_fp = zf.open(pickle_paths[0], 'r') + unpickler = LazyUnpickler(pickle_fp, + data_base_path=pickle_paths[0][:-4], + zip_file=zf) + model = unpickler.load() + if 'model' in model: model = model['model'] + as_dict = dict(model.items()) + return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + + +def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: + header_size, = struct.unpack(' LazyTensor: + data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + numpy_dtype = data_type.dtype + shape: list[int] = info['shape'] + begin, end = info['data_offsets'] + assert 0 <= begin <= end <= len(byte_buf) + assert end - begin == math.prod(shape) * numpy_dtype.itemsize + buf = byte_buf[begin:end] + + def load() -> UnquantizedTensor: + return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) + description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return LazyTensor(load, shape, data_type, description) + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} + return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + +def must_read(fp: IO[bytes], length: int) -> bytes: + ret = fp.read(length) + if len(ret) < length: + raise Exception("unexpectedly reached end of file") + return ret + + +@functools.lru_cache(maxsize=None) +def lazy_load_file(path: Path) -> ModelPlus: + fp = open(path, 'rb') + first8 = fp.read(8) + fp.seek(0) + if first8[:2] == b'PK': + # A zip file, i.e. PyTorch format + return lazy_load_torch_file(fp, path) + elif struct.unpack(' Iterable[Out]: + '''Parallel map, but with backpressure. If the caller doesn't call `next` + fast enough, this will stop calling `func` at some point rather than + letting results pile up in memory. Specifically, there is a max of one + output value buffered per thread.''' + if concurrency < 2: + yield from map(func, iterable) + # Not reached. 
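One step worth making explicit from `lazy_load_file` above: the container format is picked from the first eight bytes of the file alone. A hedged sketch of that dispatch, assuming the 16 MiB safetensors-header threshold used by upstream llama.cpp's `convert.py`:

```python
import struct
from pathlib import Path

def sniff_model_format(path: Path) -> str:
    """Guess the checkpoint container from its first 8 bytes, mirroring lazy_load_file()."""
    with open(path, 'rb') as fp:
        first8 = fp.read(8)
    if first8[:2] == b'PK':
        # PyTorch checkpoints are zip archives, so they begin with the zip magic.
        return 'torch'
    if len(first8) == 8 and struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024:  # assumed threshold
        # safetensors files begin with a little-endian uint64 giving the JSON header size.
        return 'safetensors'
    return 'unknown'
```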
+ iterable = iter(iterable) + executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor] + if use_processpool_executor: + executor_class = ProcessPoolExecutor + else: + executor_class = ThreadPoolExecutor + with executor_class(max_workers=max_workers) as executor: + futures: list[concurrent.futures.Future[Out]] = [] + done = False + for _ in range(concurrency): + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + + while futures: + result = futures.pop(0).result() + while not done and len(futures) < concurrency: + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + yield result + + +def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: + # Handle special case where the model's vocab size is not set + if params.n_vocab == -1: + raise ValueError( + f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?" + ) + + # Check for a vocab size mismatch + if params.n_vocab == vocab.vocab_size: + print("Ignoring added_tokens.json since model matches vocab size without it.") + return + + if pad_vocab and params.n_vocab > vocab.vocab_size: + pad_count = params.n_vocab - vocab.vocab_size + print( + f"Padding vocab with {pad_count} token(s) - through " + ) + for i in range(1, pad_count + 1): + vocab.added_tokens_dict[f""] = -1 + vocab.added_tokens_list.append(f"") + vocab.vocab_size = params.n_vocab + return + + msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})." + if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20: + msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." + if vocab.vocab_size < params.n_vocab: + msg += " Add the --pad-vocab option and try again." 
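A quick worked example of the `--pad-vocab` branch above (the dummy token literal below is a hypothetical stand-in, used only for illustration): if the model declares 32016 vocabulary entries but the tokenizer provides 32000, sixteen placeholder entries are appended before writing.

```python
# Padding arithmetic from check_vocab_size(), with a hypothetical placeholder name.
n_vocab_model, n_vocab_tokenizer = 32016, 32000
pad_count = n_vocab_model - n_vocab_tokenizer                 # 16
padding = [f"<extra_pad_{i:05}>" for i in range(1, pad_count + 1)]
print(pad_count, padding[0], padding[-1])                     # 16 <extra_pad_00001> <extra_pad_00016>
```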
+ + raise Exception(msg) + + +class OutputFile: + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) + + def add_meta_arch(self, params: Params) -> None: + name = "LLaMA" + + # TODO: better logic to determine model name + if params.n_ctx == 4096: + name = "LLaMA v2" + elif params.path_model is not None: + name = str(params.path_model.parent).split('/')[-1] + + self.gguf.add_name (name) + self.gguf.add_context_length (params.n_ctx) + self.gguf.add_embedding_length (params.n_embd) + self.gguf.add_block_count (params.n_layer) + self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) + + if params.n_experts: + self.gguf.add_expert_count(params.n_experts) + + if params.n_experts_used: + self.gguf.add_expert_used_count(params.n_experts_used) + + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') + + if params.f_rope_freq_base is not None: + self.gguf.add_rope_freq_base(params.f_rope_freq_base) + + if params.rope_scaling_type: + assert params.f_rope_scale is not None + self.gguf.add_rope_scaling_type(params.rope_scaling_type) + self.gguf.add_rope_scaling_factor(params.f_rope_scale) + + if params.n_orig_ctx is not None: + self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx) + + if params.rope_finetuned is not None: + self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) + + if params.ftype is not None: + self.gguf.add_file_type(params.ftype) + + def handle_tokenizer_model(self, vocab: Vocab) -> str: + # Map the vocab types to the supported tokenizer models + tokenizer_model = { + SentencePieceVocab: "llama", + HfVocab: "llama", + BpeVocab: "gpt2", + }.get(type(vocab)) + + # Block if vocab type is not predefined + if tokenizer_model is None: + raise ValueError("Unknown vocab type: Not supported") + + return tokenizer_model + + def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + tokens = [] + scores = [] + toktypes = [] + + # NOTE: `all_tokens` returns the base vocabulary and added tokens + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + return tokens, scores, toktypes + + def add_meta_vocab(self, vocab: Vocab) -> None: + # Handle the tokenizer model + tokenizer_model = self.handle_tokenizer_model(vocab) + + # Ensure that tokenizer_model is added to the GGUF model + self.gguf.add_tokenizer_model(tokenizer_model) + + # Extract model vocabulary for model conversion + tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) + + # Add extracted token information for model conversion + self.gguf.add_token_list(tokens) + self.gguf.add_token_scores(scores) + self.gguf.add_token_types(toktypes) + + def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: + svocab.add_to_gguf(self.gguf) + + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) + self.gguf.add_tensor_info(name, 
tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + + def write_meta(self) -> None: + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() + + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() + + def close(self) -> None: + self.gguf.close() + + @staticmethod + def write_vocab_only( + fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, + ) -> None: + check_vocab_size(params, vocab, pad_vocab = pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + + of.write_meta() + + of.close() + + @staticmethod + def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]: + name, lazy_tensor = item + tensor = lazy_tensor.load().to_ggml() + return (lazy_tensor.data_type, tensor.ndarray) + + @staticmethod + def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: + dt, arr = item + if not isinstance(dt, QuantizedDataType): + return arr + return dt.quantize(arr) + + @staticmethod + def write_all( + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() + + # tensor data + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + print( + f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + of.gguf.write_tensor_data(ndarray) + + of.close() + + +def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: + wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type + + if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): + return GGMLFileType.AllF32 + if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)): + return GGMLFileType.MostlyF16 + if output_type_str == "q8_0": + return GGMLFileType.MostlyQ8_0 + + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + + raise Exception(f"Unexpected combination of types: {name_to_type}") + + +def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: + return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items()} + + +def convert_model_names(model: LazyModel, params: Params) -> LazyModel: 
+ tmap = gguf.TensorNameMap(ARCH, params.n_layer) + should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) + + tmp = model + + # HF models permut or pack some of the tensors, so we need to undo that + for i in itertools.count(): + if f"model.layers.{i}.self_attn.q_proj.weight" in model: + print(f"Permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) + # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + elif f"model.layers.{i}.self_attn.W_pack.weight" in model: + print(f"Unpacking and permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] + else: + break + + out: LazyModel = {} + for name, lazy_tensor in model.items(): + tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + if name_new is None: + raise Exception(f"Unexpected tensor name: {name}") + + if tensor_type in should_skip: + print(f"skipping tensor {name_new}") + continue + + print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + out[name_new] = lazy_tensor + + return out + + +def nth_multifile_path(path: Path, n: int) -> Path | None: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the nth path in the model. + ''' + # Support the following patterns: + patterns: list[tuple[str, str]] = [ + # - x.00.pth, x.01.pth, etc. + (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. + (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + # x.bin, x.bin.1, etc. + (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + ] + for regex, replacement in patterns: + if re.search(regex, path.name): + new_path = path.with_name(re.sub(regex, replacement, path.name)) + if new_path.exists(): + return new_path + return None + + +def find_multifile_paths(path: Path) -> list[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the whole list of paths in the model. + ''' + ret: list[Path] = [] + for i in itertools.count(): + nth_path = nth_multifile_path(path, i) + if nth_path is None: + break + ret.append(nth_path) + if not ret: + # No matches. This should only happen if the file was named, e.g., + # foo.0, and there was no file named foo. Oh well, try to process it + # as a single file. 
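The `permute_lazy` calls above defer to the small `permute` helper defined earlier in this file, which interleaves the two rotary halves of each attention head. A tiny numeric check (the helper is repeated here so the snippet stands alone):

```python
import numpy as np

def permute(weights, n_head, n_head_kv):
    # Same reshape / swapaxes trick as the permute() helper defined above.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

# One head, four rows: the half [rows 0-1] and the half [rows 2-3] interleave to 0, 2, 1, 3.
w = np.arange(4, dtype=np.float32)[:, None] * np.ones((1, 8), dtype=np.float32)
print(permute(w, 1, 1)[:, 0])   # [0. 2. 1. 3.]
```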
+ return [path] + return ret + + +def load_some_model(path: Path) -> ModelPlus: + '''Load a model of any supported format.''' + # Be extra-friendly and accept either a file or a directory: + if path.is_dir(): + # Check if it's a set of safetensors files first + globs = ["model-00001-of-*.safetensors", "model.safetensors"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + raise Exception(f"Can't find model in directory {path}") + if len(files) > 1: + raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}") + path = files[0] + + paths = find_multifile_paths(path) + models_plus: list[ModelPlus] = [] + for path in paths: + print(f"Loading model file {path}") + models_plus.append(lazy_load_file(path)) + + model_plus = merge_multifile_models(models_plus) + return model_plus + + +class VocabFactory: + def __init__(self, path: Path): + self.path = path + self.files: dict[str, Path | None] = { + "tokenizer.model": None, + "vocab.json": None, + "tokenizer.json": None, + } + self._detect_files() + + def _detect_files(self): + for file in self.files.keys(): + file_path = self.path / file + parent_file_path = self.path.parent / file + if file_path.exists(): + self.files[file] = file_path + elif parent_file_path.exists(): + self.files[file] = parent_file_path + print(f"Found vocab files: {self.files}") + + def _select_file(self, vocabtype: str | None) -> Path: + if vocabtype in ["spm", "bpe"]: + for file_key in self.files.keys(): + if (file := self.files[file_key]) is not None: + return file + raise FileNotFoundError(f"{vocabtype} vocab not found.") + if vocabtype == "hfft": + # For Hugging Face Fast Tokenizer, return the directory path instead of a specific file + return self.path + raise ValueError(f"Unsupported vocabulary type {vocabtype}") + + def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocabtype == "bpe" + n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None + return gguf.SpecialVocab( + model_parent_path, + load_merges=load_merges, + special_token_types=None, # Predetermined or passed as a parameter + n_vocab=n_vocab, + ) + + def load_vocab(self, vocabtype: str, model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + path = self._select_file(vocabtype) + print(f"Loading vocab file '{path}', type '{vocabtype}'") + + added_tokens_path = path.parent / "added_tokens.json" + vocab: Vocab + if vocabtype == "bpe": + vocab = BpeVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + elif vocabtype == "spm": + vocab = SentencePieceVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + elif vocabtype == "hfft": + vocab = HfVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) + else: + raise ValueError(f"Unsupported vocabulary type {vocabtype}") + # FIXME: Respect --vocab-dir? 
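For orientation, `main()` further below drives this factory in two steps: detect the tokenizer files, then build the vocab class matching `--vocab-type` (`spm` for SentencePiece, `bpe` for the JSON BPE vocab, `hfft` for a Hugging Face fast tokenizer). A usage sketch with a hypothetical model directory, assuming the classes defined in this script are in scope:

```python
from pathlib import Path

model_dir = Path("/models/my-llama")        # hypothetical path

factory = VocabFactory(model_dir)           # looks for tokenizer.model / vocab.json / tokenizer.json
vocab, special_vocab = factory.load_vocab("spm", model_dir)
print(type(vocab).__name__, vocab.vocab_size)   # e.g. SentencePieceVocab 32000
```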
+ special_vocab = self._create_special_vocab( + vocab, + vocabtype, + model_parent_path, + ) + return vocab, special_vocab + + +def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: + namestr = { + GGMLFileType.AllF32: "f32", + GGMLFileType.MostlyF16: "f16", + GGMLFileType.MostlyQ8_0:"q8_0", + }[file_type] + ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" + if ret in model_paths: + sys.stderr.write( + f"Error: Default output path ({ret}) would overwrite the input. " + "Please explicitly specify a path using --outfile.\n") + sys.exit(1) + return ret + + +def do_dump_model(model_plus: ModelPlus) -> None: + print(f"model_plus.paths = {model_plus.paths!r}") + print(f"model_plus.format = {model_plus.format!r}") + print(f"model_plus.vocab = {model_plus.vocab!r}") + for name, lazy_tensor in model_plus.model.items(): + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") + + +def main(args_in: list[str] | None = None) -> None: + output_choices = ["f32", "f16"] + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # We currently only support Q8_0 output on little endian systems. + output_choices.append("q8_0") + vocab_types = ["spm", "bpe", "hfft"] + parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") + parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None) + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--vocab-type", choices=vocab_types, help="The vocabulary format used to define the tokenizer model (default: spm)", default="spm") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) + parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") + + args = parser.parse_args(args_in) + if args.awq_path: + sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) + from awq.apply_awq import add_scale_weights # type: ignore[import-not-found] + tmp_model_path = args.model / "weighted_model" + if tmp_model_path.is_dir(): + print(f"{tmp_model_path} exists as a weighted model.") + else: + tmp_model_path.mkdir(parents=True, exist_ok=True) + print("Saving new weighted model ...") + add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path)) + print(f"Saved weighted model at {tmp_model_path}.") + args.model = tmp_model_path + + if args.dump_single: + model_plus = 
lazy_load_file(args.model) + do_dump_model(model_plus) + return + + if not args.vocab_only: + model_plus = load_some_model(args.model) + else: + model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) + + if args.dump: + do_dump_model(model_plus) + return + endianess = gguf.GGUFEndian.LITTLE + if args.big_endian: + endianess = gguf.GGUFEndian.BIG + + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n") + params.n_ctx = args.ctx + + if args.outtype: + params.ftype = { + "f32": GGMLFileType.AllF32, + "f16": GGMLFileType.MostlyF16, + "q8_0": GGMLFileType.MostlyQ8_0, + }[args.outtype] + + print(f"params = {params}") + + model_parent_path = model_plus.paths[0].parent + vocab_path = Path(args.vocab_dir or args.model or model_parent_path) + vocab_factory = VocabFactory(vocab_path) + vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path) + + if args.vocab_only: + if not args.outfile: + raise ValueError("need --outfile if using --vocab-only") + outfile = args.outfile + OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, + endianess=endianess, pad_vocab=args.pad_vocab) + print(f"Wrote {outfile}") + return + + if model_plus.vocab is not None and args.vocab_dir is None: + vocab = model_plus.vocab + + print(f"Vocab info: {vocab}") + print(f"Special vocab info: {special_vocab}") + + model = model_plus.model + model = convert_model_names(model, params) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile(model_plus.paths, ftype) + + params.ftype = ftype + print(f"Writing {outfile}, format {ftype}") + + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) + print(f"Wrote {outfile}") + + +if __name__ == '__main__': + main() diff --git a/extensions/model-extension/scripts/gguf-py/LICENSE b/extensions/model-extension/scripts/gguf-py/LICENSE new file mode 100644 index 0000000000..76f67efdc6 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Georgi Gerganov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
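Since `main()` above takes an explicit argument list, the converter added in this patch can be exercised either from the command line or programmatically. A hedged sketch (the module name and model path are assumptions for illustration):

```python
# Equivalent to:  python extensions/model-extension/scripts/convert.py /models/my-llama --outtype f16
# Assumes the script is importable as `convert` and that /models/my-llama holds the checkpoint.
import convert

convert.main(["/models/my-llama", "--outtype", "f16"])

# Vocab-only export, exercising the --vocab-only / --pad-vocab flags parsed above.
convert.main(["/models/my-llama", "--vocab-only", "--vocab-type", "hfft",
              "--pad-vocab", "--outfile", "/models/my-llama/vocab.gguf"])
```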
diff --git a/extensions/model-extension/scripts/gguf-py/README.md b/extensions/model-extension/scripts/gguf-py/README.md new file mode 100644 index 0000000000..22d7ffa52d --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/README.md @@ -0,0 +1,81 @@ +## gguf + +This is a Python package for writing binary files in the [GGUF](https://github.com/ggerganov/ggml/pull/302) +(GGML Universal File) format. + +See [convert-llama-hf-to-gguf.py](https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py) +as an example for its usage. + +## Installation +```sh +pip install gguf +``` + +## API Examples/Simple Tools + +[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. + +[scripts/gguf-dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-dump.py) — Dumps a GGUF file's metadata to the console. + +[scripts/gguf-set-metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-set-metadata.py) — Allows changing simple metadata values in a GGUF file by key. + +[scripts/gguf-convert-endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-convert-endian.py) — Allows converting the endianness of GGUF files. + +## Development +Maintainers who participate in development of this package are advised to install it in editable mode: + +```sh +cd /path/to/llama.cpp/gguf-py + +pip install --editable . +``` + +**Note**: This may require to upgrade your Pip installation, with a message saying that editable installation currently requires `setup.py`. +In this case, upgrade Pip to the latest: + +```sh +pip install --upgrade pip +``` + +## Automatic publishing with CI + +There's a GitHub workflow to make a release automatically upon creation of tags in a specified format. + +1. Bump the version in `pyproject.toml`. +2. Create a tag named `gguf-vx.x.x` where `x.x.x` is the semantic version number. + +```sh +git tag -a gguf-v1.0.0 -m "Version 1.0 release" +``` + +3. Push the tags. + +```sh +git push origin --tags +``` + +## Manual publishing +If you want to publish the package manually for any reason, you need to have `twine` and `build` installed: + +```sh +pip install build twine +``` + +Then, follow these steps to release a new version: + +1. Bump the version in `pyproject.toml`. +2. Build the package: + +```sh +python -m build +``` + +3. Upload the generated distribution archives: + +```sh +python -m twine upload dist/* +``` + +## TODO +- [ ] Add tests +- [ ] Include conversion scripts as command line entry points in this package. 
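The README above introduces GGUF as a binary container; as a concrete anchor, the 4-byte magic that `gguf/constants.py` (added next in this patch) defines as `0x46554747` is simply the ASCII bytes `GGUF` read as a little-endian 32-bit integer:

```python
import struct

GGUF_MAGIC = 0x46554747          # value from gguf/constants.py
print(struct.pack("<I", GGUF_MAGIC))                   # b'GGUF'
print(struct.unpack("<I", b"GGUF")[0] == GGUF_MAGIC)   # True
```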
diff --git a/extensions/model-extension/scripts/gguf-py/examples/writer.py b/extensions/model-extension/scripts/gguf-py/examples/writer.py new file mode 100755 index 0000000000..f39eed1afe --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/examples/writer.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +import sys +from pathlib import Path + +import numpy as np + +# Necessary to load the local gguf package +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf import GGUFWriter # noqa: E402 + + +# Example usage: +def writer_example() -> None: + # Example usage with a file + gguf_writer = GGUFWriter("example.gguf", "llama") + + gguf_writer.add_architecture() + gguf_writer.add_block_count(12) + gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer + gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float + gguf_writer.add_custom_alignment(64) + + tensor1 = np.ones((32,), dtype=np.float32) * 100.0 + tensor2 = np.ones((64,), dtype=np.float32) * 101.0 + tensor3 = np.ones((96,), dtype=np.float32) * 102.0 + + gguf_writer.add_tensor("tensor1", tensor1) + gguf_writer.add_tensor("tensor2", tensor2) + gguf_writer.add_tensor("tensor3", tensor3) + + gguf_writer.write_header_to_file() + gguf_writer.write_kv_data_to_file() + gguf_writer.write_tensors_to_file() + + gguf_writer.close() + + +if __name__ == '__main__': + writer_example() diff --git a/extensions/model-extension/scripts/gguf-py/gguf/__init__.py b/extensions/model-extension/scripts/gguf-py/gguf/__init__.py new file mode 100644 index 0000000000..110ab342cc --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/__init__.py @@ -0,0 +1,5 @@ +from .constants import * +from .gguf_reader import * +from .gguf_writer import * +from .tensor_mapping import * +from .vocab import * diff --git a/extensions/model-extension/scripts/gguf-py/gguf/constants.py b/extensions/model-extension/scripts/gguf-py/gguf/constants.py new file mode 100644 index 0000000000..1cfd41c0be --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/constants.py @@ -0,0 +1,665 @@ +from __future__ import annotations + +import sys +from enum import Enum, IntEnum, auto +from typing import Any + +# +# constants +# + +GGUF_MAGIC = 0x46554747 # "GGUF" +GGUF_VERSION = 3 +GGUF_DEFAULT_ALIGNMENT = 32 + +# +# metadata keys +# + + +class Keys: + class General: + ARCHITECTURE = "general.architecture" + QUANTIZATION_VERSION = "general.quantization_version" + ALIGNMENT = "general.alignment" + NAME = "general.name" + AUTHOR = "general.author" + URL = "general.url" + DESCRIPTION = "general.description" + LICENSE = "general.license" + SOURCE_URL = "general.source.url" + SOURCE_HF_REPO = "general.source.huggingface.repository" + FILE_TYPE = "general.file_type" + + class LLM: + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + + class Attention: + HEAD_COUNT = "{arch}.attention.head_count" + HEAD_COUNT_KV = "{arch}.attention.head_count_kv" + MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" + CLAMP_KQV = "{arch}.attention.clamp_kqv" + KEY_LENGTH = "{arch}.attention.key_length" + VALUE_LENGTH = "{arch}.attention.value_length" + LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" + LAYERNORM_RMS_EPS = 
"{arch}.attention.layer_norm_rms_epsilon" + + class Rope: + DIMENSION_COUNT = "{arch}.rope.dimension_count" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + + class Tokenizer: + MODEL = "tokenizer.ggml.model" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" + + +# +# recommended mapping of model tensor names for storage in gguf +# + + +class MODEL_ARCH(IntEnum): + LLAMA = auto() + FALCON = auto() + BAICHUAN = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + PERSIMMON = auto() + REFACT = auto() + BERT = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + PHI2 = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + + +class MODEL_TENSOR(IntEnum): + TOKEN_EMBD = auto() + TOKEN_EMBD_NORM = auto() + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_NORM = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_ACT = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + + +MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.PERSIMMON: "persimmon", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", +} + +TENSOR_NAMES: dict[MODEL_TENSOR, str] = { + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + 
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", +} + +MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_ARCH.LLAMA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.GPTNEOX: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.FALCON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BAICHUAN: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.STARCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.MPT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_ACT, + ], + MODEL_ARCH.GPTJ: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PERSIMMON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.REFACT: [ + 
MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.BLOOM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.STABLELM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PLAMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GPT2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.PHI2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.ORION: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.INTERNLM2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + 
MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.MINICPM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + # TODO +} + +# tensors that will not be serialized +MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_ARCH.LLAMA: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.BAICHUAN: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.PERSIMMON: [ + MODEL_TENSOR.ROPE_FREQS, + ], + MODEL_ARCH.QWEN: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], + MODEL_ARCH.ORION: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], +} + +# +# types +# + + +class TokenType(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +class RopeScalingType(Enum): + NONE = 'none' + LINEAR = 'linear' + YARN = 'yarn' + + +class GGMLQuantizationType(IntEnum): + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + + +class GGUFEndian(IntEnum): + LITTLE = 0 + BIG = 1 + + +class GGUFValueType(IntEnum): + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 + FLOAT32 = 6 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 + FLOAT64 = 12 + + @staticmethod + def get_type(val: Any) -> GGUFValueType: + if isinstance(val, (str, bytes, bytearray)): + return GGUFValueType.STRING + elif isinstance(val, list): + return GGUFValueType.ARRAY + elif isinstance(val, float): + return GGUFValueType.FLOAT32 + elif isinstance(val, bool): + return GGUFValueType.BOOL + elif isinstance(val, int): + return GGUFValueType.INT32 + # TODO: need help with 64-bit types in Python + else: + print("Unknown type:", type(val)) + sys.exit() + + +# Note: Does not support GGML_QKK_64 +QK_K = 256 +# Items here are (block size, type size) +GGML_QUANT_SIZES = { + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q4_0: (32, 2 + 16), + GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), + GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), + GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), + GGMLQuantizationType.Q8_0: (32, 2 + 32), + GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), + GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), + GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), + GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), + GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), +} + + +# Aliases for backward compatibility. 
+ +# general +KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE +KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION +KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT +KEY_GENERAL_NAME = Keys.General.NAME +KEY_GENERAL_AUTHOR = Keys.General.AUTHOR +KEY_GENERAL_URL = Keys.General.URL +KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION +KEY_GENERAL_LICENSE = Keys.General.LICENSE +KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL +KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO +KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE + +# LLM +KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH +KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH +KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT +KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH +KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL +KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT + +# attention +KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT +KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV +KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS +KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV +KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS +KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS + +# RoPE +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED + +# tokenization +KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL +KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST +KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE +KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES +KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES +KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID +KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID +KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID +KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID +KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON +KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV diff --git a/extensions/model-extension/scripts/gguf-py/gguf/gguf.py b/extensions/model-extension/scripts/gguf-py/gguf/gguf.py new file mode 100644 index 0000000000..651a81eb82 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/gguf.py @@ -0,0 +1,15 @@ +# This file left for compatibility. If you want to use the GGUF API from Python +# then don't import gguf/gguf.py directly. If you're looking for examples, see the +# examples/ directory for gguf-py + +import importlib +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Compatibility for people trying to import gguf/gguf.py directly instead of as a package. +importlib.invalidate_caches() +import gguf # noqa: E402 + +importlib.reload(gguf) diff --git a/extensions/model-extension/scripts/gguf-py/gguf/gguf_reader.py b/extensions/model-extension/scripts/gguf-py/gguf/gguf_reader.py new file mode 100644 index 0000000000..5b6d4ba6bc --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/gguf_reader.py @@ -0,0 +1,264 @@ +# +# GGUF file reading/modification support. For API usage information, +# please see the files scripts/ for some fairly simple examples. 
+# +from __future__ import annotations + +import os +from collections import OrderedDict +from typing import Any, Literal, NamedTuple, TypeVar, Union + +import numpy as np +import numpy.typing as npt + +if __name__ == "__main__": + import sys + from pathlib import Path + + # Allow running file in package as a script. + sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf.constants import ( + GGML_QUANT_SIZES, + GGUF_DEFAULT_ALIGNMENT, + GGUF_MAGIC, + GGUF_VERSION, + GGMLQuantizationType, + GGUFValueType, +) + + +READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION] + + +class ReaderField(NamedTuple): + # Offset to start of this field. + offset: int + + # Name of the field (not necessarily from file data). + name: str + + # Data parts. Some types have multiple components, such as strings + # that consist of a length followed by the string data. + parts: list[npt.NDArray[Any]] = [] + + # Indexes into parts that we can call the actual data. For example + # an array of strings will be populated with indexes to the actual + # string data. + data: list[int] = [-1] + + types: list[GGUFValueType] = [] + + +class ReaderTensor(NamedTuple): + name: str + tensor_type: GGMLQuantizationType + shape: npt.NDArray[np.uint32] + n_elements: int + n_bytes: int + data_offset: int + data: npt.NDArray[Any] + field: ReaderField + + +class GGUFReader: + # I - same as host, S - swapped + byte_order: Literal['I' | 'S'] = 'I' + alignment: int = GGUF_DEFAULT_ALIGNMENT + + # Note: Internal helper, API may change. + gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { + GGUFValueType.UINT8: np.uint8, + GGUFValueType.INT8: np.int8, + GGUFValueType.UINT16: np.uint16, + GGUFValueType.INT16: np.int16, + GGUFValueType.UINT32: np.uint32, + GGUFValueType.INT32: np.int32, + GGUFValueType.FLOAT32: np.float32, + GGUFValueType.UINT64: np.uint64, + GGUFValueType.INT64: np.int64, + GGUFValueType.FLOAT64: np.float64, + GGUFValueType.BOOL: np.bool_, + } + + def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'): + self.data = np.memmap(path, mode = mode) + offs = 0 + if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: + raise ValueError('GGUF magic invalid') + offs += 4 + temp_version = self._get(offs, np.uint32) + if temp_version[0] & 65535 == 0: + # If we get 0 here that means it's (probably) a GGUF file created for + # the opposite byte order of the machine this script is running on. 
+ self.byte_order = 'S' + temp_version = temp_version.newbyteorder(self.byte_order) + version = temp_version[0] + if version not in READER_SUPPORTED_VERSIONS: + raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') + self.fields: OrderedDict[str, ReaderField] = OrderedDict() + self.tensors: list[ReaderTensor] = [] + offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + temp_counts = self._get(offs, np.uint64, 2) + offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) + offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) + tensor_count, kv_count = temp_counts + offs = self._build_fields(offs, kv_count) + offs, tensors_fields = self._build_tensors_fields(offs, tensor_count) + new_align = self.fields.get('general.alignment') + if new_align is not None: + if new_align.types != [GGUFValueType.UINT32]: + raise ValueError('Bad type for general.alignment field') + self.alignment = new_align.parts[-1][0] + padding = offs % self.alignment + if padding != 0: + offs += self.alignment - padding + self._build_tensors(offs, tensors_fields) + + _DT = TypeVar('_DT', bound = npt.DTypeLike) + + # Fetch a key/value metadata field by key. + def get_field(self, key: str) -> Union[ReaderField, None]: + return self.fields.get(key, None) + + # Fetch a tensor from the list by index. + def get_tensor(self, idx: int) -> ReaderTensor: + return self.tensors[idx] + + def _get( + self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None, + ) -> npt.NDArray[Any]: + count = int(count) + itemsize = int(np.empty([], dtype = dtype).itemsize) + end_offs = offset + itemsize * count + return ( + self.data[offset:end_offs] + .view(dtype = dtype)[:count] + .newbyteorder(override_order or self.byte_order) + ) + + def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: + if field.name in self.fields: + raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') + self.fields[field.name] = field + return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) + + def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + slen = self._get(offset, np.uint64) + return slen, self._get(offset + 8, np.uint8, slen[0]) + + def _get_field_parts( + self, orig_offs: int, raw_type: int, + ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: + offs = orig_offs + types: list[GGUFValueType] = [] + gtype = GGUFValueType(raw_type) + types.append(gtype) + # Handle strings. + if gtype == GGUFValueType.STRING: + sparts: list[npt.NDArray[Any]] = list(self._get_str(offs)) + size = sum(int(part.nbytes) for part in sparts) + return size, sparts, [1], types + # Check if it's a simple scalar type. + nptype = self.gguf_scalar_to_np.get(gtype) + if nptype is not None: + val = self._get(offs, nptype) + return int(val.nbytes), [val], [0], types + # Handle arrays. 
+ if gtype == GGUFValueType.ARRAY: + raw_itype = self._get(offs, np.uint32) + offs += int(raw_itype.nbytes) + alen = self._get(offs, np.uint64) + offs += int(alen.nbytes) + aparts: list[npt.NDArray[Any]] = [raw_itype, alen] + data_idxs: list[int] = [] + for idx in range(alen[0]): + curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) + if idx == 0: + types += curr_types + idxs_offs = len(aparts) + aparts += curr_parts + data_idxs += (idx + idxs_offs for idx in curr_idxs) + offs += curr_size + return offs - orig_offs, aparts, data_idxs, types + # We can't deal with this one. + raise ValueError('Unknown/unhandled field type {gtype}') + + def _get_tensor(self, orig_offs: int) -> ReaderField: + offs = orig_offs + name_len, name_data = self._get_str(offs) + offs += int(name_len.nbytes + name_data.nbytes) + n_dims = self._get(offs, np.uint32) + offs += int(n_dims.nbytes) + dims = self._get(offs, np.uint64, n_dims[0]) + offs += int(dims.nbytes) + raw_dtype = self._get(offs, np.uint32) + offs += int(raw_dtype.nbytes) + offset_tensor = self._get(offs, np.uint64) + offs += int(offset_tensor.nbytes) + return ReaderField( + orig_offs, + str(bytes(name_data), encoding = 'utf-8'), + [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor], + [1, 3, 4, 5], + ) + + def _build_fields(self, offs: int, count: int) -> int: + for _ in range(count): + orig_offs = offs + kv_klen, kv_kdata = self._get_str(offs) + offs += int(kv_klen.nbytes + kv_kdata.nbytes) + raw_kv_type = self._get(offs, np.uint32) + offs += int(raw_kv_type.nbytes) + parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] + idxs_offs = len(parts) + field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) + parts += field_parts + self._push_field(ReaderField( + orig_offs, + str(bytes(kv_kdata), encoding = 'utf-8'), + parts, + [idx + idxs_offs for idx in field_idxs], + field_types, + ), skip_sum = True) + offs += field_size + return offs + + def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + tensor_fields = [] + for _ in range(count): + field = self._get_tensor(offs) + offs += sum(int(part.nbytes) for part in field.parts) + tensor_fields.append(field) + return offs, tensor_fields + + def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: + tensors = [] + for field in fields: + _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts + ggml_type = GGMLQuantizationType(raw_dtype[0]) + n_elems = np.prod(dims) + block_size, type_size = GGML_QUANT_SIZES[ggml_type] + n_bytes = n_elems * type_size // block_size + data_offs = int(start_offs + offset_tensor[0]) + item_type: npt.DTypeLike + if ggml_type == GGMLQuantizationType.F32: + item_count = n_elems + item_type = np.float32 + elif ggml_type == GGMLQuantizationType.F16: + item_count = n_elems + item_type = np.float16 + else: + item_count = n_bytes + item_type = np.uint8 + tensors.append(ReaderTensor( + name = str(bytes(name_data), encoding = 'utf-8'), + tensor_type = ggml_type, + shape = dims, + n_elements = n_elems, + n_bytes = n_bytes, + data_offset = data_offs, + data = self._get(data_offs, item_type, item_count), + field = field, + )) + self.tensors = tensors diff --git a/extensions/model-extension/scripts/gguf-py/gguf/gguf_writer.py b/extensions/model-extension/scripts/gguf-py/gguf/gguf_writer.py new file mode 100644 index 0000000000..16808196e7 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/gguf_writer.py @@ 
-0,0 +1,427 @@ +from __future__ import annotations + +import os +import shutil +import struct +import tempfile +from enum import Enum, auto +from io import BufferedWriter +from typing import IO, Any, Sequence + +import numpy as np + +from .constants import ( + GGUF_DEFAULT_ALIGNMENT, + GGUF_MAGIC, + GGUF_VERSION, + GGMLQuantizationType, + GGUFEndian, + GGUFValueType, + Keys, + RopeScalingType, + TokenType, +) + + +class WriterState(Enum): + EMPTY = auto() + HEADER = auto() + KV_DATA = auto() + TI_DATA = auto() + + +class GGUFWriter: + fout: BufferedWriter + temp_file: tempfile.SpooledTemporaryFile[bytes] | None + tensors: list[np.ndarray[Any, Any]] + _simple_value_packing = { + GGUFValueType.UINT8: "B", + GGUFValueType.INT8: "b", + GGUFValueType.UINT16: "H", + GGUFValueType.INT16: "h", + GGUFValueType.UINT32: "I", + GGUFValueType.INT32: "i", + GGUFValueType.FLOAT32: "f", + GGUFValueType.UINT64: "Q", + GGUFValueType.INT64: "q", + GGUFValueType.FLOAT64: "d", + GGUFValueType.BOOL: "?", + } + + def __init__( + self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, + endianess: GGUFEndian = GGUFEndian.LITTLE, + ): + self.fout = open(path, "wb") + self.arch = arch + self.endianess = endianess + self.offset_tensor = 0 + self.data_alignment = GGUF_DEFAULT_ALIGNMENT + self.kv_data = bytearray() + self.kv_data_count = 0 + self.ti_data = bytearray() + self.ti_data_count = 0 + self.use_temp_file = use_temp_file + self.temp_file = None + self.tensors = [] + print("gguf: This GGUF file is for {0} Endian only".format( + "Big" if self.endianess == GGUFEndian.BIG else "Little", + )) + self.state = WriterState.EMPTY + + self.add_architecture() + + def write_header_to_file(self) -> None: + if self.state is not WriterState.EMPTY: + raise ValueError(f'Expected output file to be empty, got {self.state}') + + self._write_packed("<I", GGUF_MAGIC) + self._write_packed("I", GGUF_VERSION) + self._write_packed("Q", self.ti_data_count) + self._write_packed("Q", self.kv_data_count) + self.flush() + self.state = WriterState.HEADER + + def write_kv_data_to_file(self) -> None: + if self.state is not WriterState.HEADER: + raise ValueError(f'Expected output file to contain the header, got {self.state}') + + self.fout.write(self.kv_data) + self.flush() + self.state = WriterState.KV_DATA + + def write_ti_data_to_file(self) -> None: + if self.state is not WriterState.KV_DATA: + raise ValueError(f'Expected output file to contain KV data, got {self.state}') + + self.fout.write(self.ti_data) + self.flush() + self.state = WriterState.TI_DATA + + def add_key(self, key: str) -> None: + self.add_val(key, GGUFValueType.STRING, add_vtype=False) + + def add_uint8(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.UINT8) + + def add_int8(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.INT8) + + def add_uint16(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.UINT16) + + def add_int16(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.INT16) + + def add_uint32(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.UINT32) + + def add_int32(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.INT32) + + def add_float32(self, key: str, val: float) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.FLOAT32) + + def add_uint64(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.UINT64) + + def add_int64(self, key: str, val: int) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.INT64) + + def add_float64(self, key: str, val: float) -> None: + self.add_key(key) + 
self.add_val(val, GGUFValueType.FLOAT64) + + def add_bool(self, key: str, val: bool) -> None: + self.add_key(key) + self.add_val(val, GGUFValueType.BOOL) + + def add_string(self, key: str, val: str) -> None: + if not val: + return + self.add_key(key) + self.add_val(val, GGUFValueType.STRING) + + def add_array(self, key: str, val: Sequence[Any]) -> None: + if not isinstance(val, Sequence): + raise ValueError("Value must be a sequence for array type") + + self.add_key(key) + self.add_val(val, GGUFValueType.ARRAY) + + def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: + if vtype is None: + vtype = GGUFValueType.get_type(val) + + if add_vtype: + self.kv_data += self._pack("I", vtype) + self.kv_data_count += 1 + + pack_fmt = self._simple_value_packing.get(vtype) + if pack_fmt is not None: + self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + elif vtype == GGUFValueType.STRING: + encoded_val = val.encode("utf8") if isinstance(val, str) else val + self.kv_data += self._pack("Q", len(encoded_val)) + self.kv_data += encoded_val + elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") + self.kv_data += self._pack("I", ltype) + self.kv_data += self._pack("Q", len(val)) + for item in val: + self.add_val(item, add_vtype=False) + else: + raise ValueError("Invalid GGUF metadata value type or value") + + @staticmethod + def ggml_pad(x: int, n: int) -> int: + return ((x + n - 1) // n) * n + + def add_tensor_info( + self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], + tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, + ) -> None: + if self.state is not WriterState.EMPTY: + raise ValueError(f'Expected output file to be empty, got {self.state}') + + if raw_dtype is None and tensor_dtype not in (np.float32, np.float16): + raise ValueError("Only F32 and F16 tensors are supported for now") + + encoded_name = name.encode("utf8") + self.ti_data += self._pack("Q", len(encoded_name)) + self.ti_data += encoded_name + n_dims = len(tensor_shape) + self.ti_data += self._pack("I", n_dims) + for i in range(n_dims): + self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) + if raw_dtype is None: + dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16 + else: + dtype = raw_dtype + self.ti_data += self._pack("I", dtype) + self.ti_data += self._pack("Q", self.offset_tensor) + self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) + self.ti_data_count += 1 + + def add_tensor( + self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, + raw_dtype: GGMLQuantizationType | None = None, + ) -> None: + if self.endianess == GGUFEndian.BIG: + tensor.byteswap(inplace=True) + if self.use_temp_file and self.temp_file is None: + fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) + fp.seek(0) + self.temp_file = fp + + shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + + if self.temp_file is None: + self.tensors.append(tensor) + return + + tensor.tofile(self.temp_file) + self.write_padding(self.temp_file, tensor.nbytes) + + def write_padding(self, fp: 
IO[bytes], n: int, align: int | None = None) -> None: + pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n + if pad != 0: + fp.write(bytes([0] * pad)) + + def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + if self.state is not WriterState.TI_DATA: + raise ValueError(f'Expected output file to contain tensor info, got {self.state}') + + if self.endianess == GGUFEndian.BIG: + tensor.byteswap(inplace=True) + self.write_padding(self.fout, self.fout.tell()) + tensor.tofile(self.fout) + self.write_padding(self.fout, tensor.nbytes) + + def write_tensors_to_file(self) -> None: + self.write_ti_data_to_file() + + self.write_padding(self.fout, self.fout.tell()) + + if self.temp_file is None: + while True: + try: + tensor = self.tensors.pop(0) + except IndexError: + break + tensor.tofile(self.fout) + self.write_padding(self.fout, tensor.nbytes) + return + + self.temp_file.seek(0) + + shutil.copyfileobj(self.temp_file, self.fout) + self.flush() + self.temp_file.close() + + def flush(self) -> None: + self.fout.flush() + + def close(self) -> None: + self.fout.close() + + def add_architecture(self) -> None: + self.add_string(Keys.General.ARCHITECTURE, self.arch) + + def add_author(self, author: str) -> None: + self.add_string(Keys.General.AUTHOR, author) + + def add_tensor_data_layout(self, layout: str) -> None: + self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) + + def add_url(self, url: str) -> None: + self.add_string(Keys.General.URL, url) + + def add_description(self, description: str) -> None: + self.add_string(Keys.General.DESCRIPTION, description) + + def add_source_url(self, url: str) -> None: + self.add_string(Keys.General.SOURCE_URL, url) + + def add_source_hf_repo(self, repo: str) -> None: + self.add_string(Keys.General.SOURCE_HF_REPO, repo) + + def add_file_type(self, ftype: int) -> None: + self.add_uint32(Keys.General.FILE_TYPE, ftype) + + def add_name(self, name: str) -> None: + self.add_string(Keys.General.NAME, name) + + def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None: + self.add_uint32( + Keys.General.QUANTIZATION_VERSION, quantization_version) + + def add_custom_alignment(self, alignment: int) -> None: + self.data_alignment = alignment + self.add_uint32(Keys.General.ALIGNMENT, alignment) + + def add_context_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length) + + def add_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_block_count(self, length: int) -> None: + self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) + + def add_feed_forward_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) + + def add_parallel_residual(self, use: bool) -> None: + self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) + + def add_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) + + def add_head_count_kv(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + + def add_key_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) + + def add_value_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length) + + def add_max_alibi_bias(self, 
bias: float) -> None: + self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) + + def add_clamp_kqv(self, value: float) -> None: + self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) + + def add_expert_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) + + def add_expert_used_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + + def add_layer_norm_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) + + def add_layer_norm_rms_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + + def add_rope_dimension_count(self, count: int) -> None: + self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) + + def add_rope_freq_base(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value) + + def add_rope_scaling_type(self, value: RopeScalingType) -> None: + self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value) + + def add_rope_scaling_factor(self, value: float) -> None: + self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_orig_ctx_len(self, value: int) -> None: + self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) + + def add_rope_scaling_finetuned(self, value: bool) -> None: + self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + + def add_tokenizer_model(self, model: str) -> None: + self.add_string(Keys.Tokenizer.MODEL, model) + + def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + self.add_array(Keys.Tokenizer.LIST, tokens) + + def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + self.add_array(Keys.Tokenizer.MERGES, merges) + + def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: + self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) + + def add_token_scores(self, scores: Sequence[float]) -> None: + self.add_array(Keys.Tokenizer.SCORES, scores) + + def add_bos_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.BOS_ID, id) + + def add_eos_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOS_ID, id) + + def add_unk_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.UNK_ID, id) + + def add_sep_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.SEP_ID, id) + + def add_pad_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.PAD_ID, id) + + def add_add_bos_token(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_BOS, value) + + def add_add_eos_token(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_EOS, value) + + def add_add_space_prefix(self, value: bool) -> None: + self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) + + def add_chat_template(self, value: str) -> None: + self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) + + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: + pack_prefix = '' + if not skip_pack_prefix: + pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' + return struct.pack(f'{pack_prefix}{fmt}', value) + + def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: + self.fout.write(self._pack(fmt, value, skip_pack_prefix)) diff --git 
a/extensions/model-extension/scripts/gguf-py/gguf/py.typed b/extensions/model-extension/scripts/gguf-py/gguf/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/extensions/model-extension/scripts/gguf-py/gguf/tensor_mapping.py b/extensions/model-extension/scripts/gguf-py/gguf/tensor_mapping.py new file mode 100644 index 0000000000..4f16d85044 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/tensor_mapping.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +from typing import Sequence + +from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES + + +class TensorNameMap: + mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { + # Token embeddings + MODEL_TENSOR.TOKEN_EMBD: ( + "gpt_neox.embed_in", # gptneox + "transformer.wte", # gpt2 gpt-j mpt refact qwen + "transformer.word_embeddings", # falcon + "word_embeddings", # bloom + "model.embed_tokens", # llama-hf + "tok_embeddings", # llama-pth + "embeddings.word_embeddings", # bert + "language_model.embedding.word_embeddings", # persimmon + "wte", # gpt2 + "transformer.embd.wte", # phi2 + "model.tok_embeddings", # internlm2 + ), + + # Token type embeddings + MODEL_TENSOR.TOKEN_TYPES: ( + "embeddings.token_type_embeddings", # bert + ), + + # Normalization of token embeddings + MODEL_TENSOR.TOKEN_EMBD_NORM: ( + "word_embeddings_layernorm", # bloom + ), + + # Position embeddings + MODEL_TENSOR.POS_EMBD: ( + "transformer.wpe", # gpt2 + "embeddings.position_embeddings", # bert + "wpe", # gpt2 + ), + + # Output + MODEL_TENSOR.OUTPUT: ( + "embed_out", # gptneox + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen + "output", # llama-pth bloom internlm2 + "word_embeddings_for_head", # persimmon + "lm_head.linear", # phi2 + ), + + # Output norm + MODEL_TENSOR.OUTPUT_NORM: ( + "gpt_neox.final_layer_norm", # gptneox + "transformer.ln_f", # gpt2 gpt-j falcon + "model.norm", # llama-hf baichuan internlm2 + "norm", # llama-pth + "embeddings.LayerNorm", # bert + "transformer.norm_f", # mpt + "ln_f", # refact bloom qwen gpt2 + "language_model.encoder.final_layernorm", # persimmon + "model.final_layernorm", # persimmon + "lm_head.ln", # phi2 + ), + + # Rope frequencies + MODEL_TENSOR.ROPE_FREQS: ( + "rope.freqs", # llama-pth + ), + } + + block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { + # Attention norm + MODEL_TENSOR.ATTN_NORM: ( + "gpt_neox.layers.{bid}.input_layernorm", # gptneox + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen + "transformer.blocks.{bid}.norm_1", # mpt + "transformer.h.{bid}.input_layernorm", # falcon7b + "h.{bid}.input_layernorm", # bloom + "transformer.h.{bid}.ln_mlp", # falcon40b + "model.layers.{bid}.input_layernorm", # llama-hf + "layers.{bid}.attention_norm", # llama-pth + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "language_model.encoder.layers.{bid}.input_layernorm", # persimmon + "model.layers.{bid}.ln1", # yi + "h.{bid}.ln_1", # gpt2 + "transformer.h.{bid}.ln", # phi2 + "model.layers.layers.{bid}.norm", # plamo + "model.layers.{bid}.attention_norm", # internlm2 + ), + + # Attention norm 2 + MODEL_TENSOR.ATTN_NORM_2: ( + "transformer.h.{bid}.ln_attn", # falcon40b + ), + + # Attention query-key-value + MODEL_TENSOR.ATTN_QKV: ( + "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox + "transformer.h.{bid}.attn.c_attn", # gpt2 qwen + "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.h.{bid}.self_attention.query_key_value", # falcon + "h.{bid}.self_attention.query_key_value", # bloom + 
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon + "model.layers.{bid}.self_attn.query_key_value", # persimmon + "h.{bid}.attn.c_attn", # gpt2 + "transformer.h.{bid}.mixer.Wqkv", # phi2 + ), + + # Attention query + MODEL_TENSOR.ATTN_Q: ( + "model.layers.{bid}.self_attn.q_proj", # llama-hf + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.q_proj", # plamo + "model.layers.{bid}.attention.wq" # internlm2 + ), + + # Attention key + MODEL_TENSOR.ATTN_K: ( + "model.layers.{bid}.self_attn.k_proj", # llama-hf + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.k_proj", # plamo + "model.layers.{bid}.attention.wk" # internlm2 + ), + + # Attention value + MODEL_TENSOR.ATTN_V: ( + "model.layers.{bid}.self_attn.v_proj", # llama-hf + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.v_proj", # plamo + "model.layers.{bid}.attention.wv" # internlm2 + ), + + # Attention output + MODEL_TENSOR.ATTN_OUT: ( + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom + "model.layers.{bid}.self_attn.o_proj", # llama-hf + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.h.{bid}.attn.out_proj", # gpt-j + "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.{bid}.self_attn.dense", # persimmon + "h.{bid}.attn.c_proj", # gpt2 + "transformer.h.{bid}.mixer.out_proj", # phi2 + "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.{bid}.attention.wo", # internlm2 + ), + + # Rotary embeddings + MODEL_TENSOR.ATTN_ROT_EMBD: ( + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo + "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell + ), + + # Feed-forward norm + MODEL_TENSOR.FFN_NORM: ( + "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox + "transformer.h.{bid}.ln_2", # gpt2 refact qwen + "h.{bid}.post_attention_layernorm", # bloom + "transformer.blocks.{bid}.norm_2", # mpt + "model.layers.{bid}.post_attention_layernorm", # llama-hf + "layers.{bid}.ffn_norm", # llama-pth + "encoder.layer.{bid}.output.LayerNorm", # bert + "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon + "model.layers.{bid}.ln2", # yi + "h.{bid}.ln_2", # gpt2 + "model.layers.{bid}.ffn_norm", # internlm2 + ), + + MODEL_TENSOR.FFN_GATE_INP: ( + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral + ), + + # Feed-forward up + MODEL_TENSOR.FFN_UP: ( + "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox + "transformer.h.{bid}.mlp.c_fc", # gpt2 + "transformer.blocks.{bid}.ffn.up_proj", # mpt + "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "h.{bid}.mlp.dense_h_to_4h", # bloom + "model.layers.{bid}.mlp.up_proj", # llama-hf refact + "layers.{bid}.feed_forward.w3", # llama-pth + 
"encoder.layer.{bid}.intermediate.dense", # bert + "transformer.h.{bid}.mlp.fc_in", # gpt-j + "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "transformer.h.{bid}.mlp.w1", # qwen + "h.{bid}.mlp.c_fc", # gpt2 + "transformer.h.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.fc1", # phi2 + "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.{bid}.feed_forward.w3", # internlm2 + ), + + MODEL_TENSOR.FFN_UP_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral + ), + + # AWQ-activation gate + MODEL_TENSOR.FFN_ACT: ( + "transformer.blocks.{bid}.ffn.act", # mpt + ), + + # Feed-forward gate + MODEL_TENSOR.FFN_GATE: ( + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact + "layers.{bid}.feed_forward.w1", # llama-pth + "transformer.h.{bid}.mlp.w2", # qwen + "model.layers.layers.{bid}.mlp.gate_proj", # plamo + "model.layers.{bid}.feed_forward.w1", # internlm2 + ), + + MODEL_TENSOR.FFN_GATE_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral + ), + + # Feed-forward down + MODEL_TENSOR.FFN_DOWN: ( + "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen + "transformer.blocks.{bid}.ffn.down_proj", # mpt + "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "h.{bid}.mlp.dense_4h_to_h", # bloom + "model.layers.{bid}.mlp.down_proj", # llama-hf + "layers.{bid}.feed_forward.w2", # llama-pth + "encoder.layer.{bid}.output.dense", # bert + "transformer.h.{bid}.mlp.fc_out", # gpt-j + "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "h.{bid}.mlp.c_proj", # gpt2 + "transformer.h.{bid}.mlp.fc2", # phi2 + "model.layers.{bid}.mlp.fc2", # phi2 + "model.layers.layers.{bid}.mlp.down_proj", # plamo + "model.layers.{bid}.feed_forward.w2", # internlm2 + ), + + MODEL_TENSOR.FFN_DOWN_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral + ), + + MODEL_TENSOR.ATTN_Q_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.q_layernorm", + "model.layers.{bid}.self_attn.q_layernorm", # persimmon + ), + + MODEL_TENSOR.ATTN_K_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.k_layernorm", + "model.layers.{bid}.self_attn.k_layernorm", # persimmon + ), + + MODEL_TENSOR.ROPE_FREQS: ( + "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon + ), + } + + mapping: dict[str, tuple[MODEL_TENSOR, str]] + + def __init__(self, arch: MODEL_ARCH, n_blocks: int): + self.mapping = {} + for tensor, keys in self.mappings_cfg.items(): + if tensor not in MODEL_TENSORS[arch]: + continue + tensor_name = TENSOR_NAMES[tensor] + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + self.mapping[key] = (tensor, tensor_name) + for bid in range(n_blocks): + for tensor, keys in self.block_mappings_cfg.items(): + if tensor not in MODEL_TENSORS[arch]: + continue + # TODO: make this configurable + n_experts = 8 + for xid in range(n_experts): + tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + key = key.format(bid = bid, xid = xid) + self.mapping[key] = (tensor, tensor_name) + + def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> 
tuple[MODEL_TENSOR, str] | None: + result = self.mapping.get(key) + if result is not None: + return result + for suffix in try_suffixes: + if key.endswith(suffix): + result = self.mapping.get(key[:-len(suffix)]) + if result is not None: + return result[0], result[1] + suffix + return None + + def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: + result = self.get_type_and_name(key, try_suffixes = try_suffixes) + if result is None: + return None + return result[1] + + def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: + result = self.get_type_and_name(key, try_suffixes = try_suffixes) + if result is None: + return None + return result[0] + + def __getitem__(self, key: str) -> str: + try: + return self.mapping[key][1] + except KeyError: + raise KeyError(key) + + def __contains__(self, key: str) -> bool: + return key in self.mapping + + def __repr__(self) -> str: + return repr(self.mapping) + + +def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap: + return TensorNameMap(arch, n_blocks) diff --git a/extensions/model-extension/scripts/gguf-py/gguf/vocab.py b/extensions/model-extension/scripts/gguf-py/gguf/vocab.py new file mode 100644 index 0000000000..cd19429754 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/gguf/vocab.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from typing import Any, Callable + +from .gguf_writer import GGUFWriter + + +class SpecialVocab: + merges: list[str] + add_special_token: dict[str, bool] + special_token_ids: dict[str, int] + chat_template: str | None + + def __init__( + self, path: str | os.PathLike[str], load_merges: bool = False, + special_token_types: tuple[str, ...] 
| None = None, + n_vocab: int | None = None, + ): + self.special_token_ids = {} + self.add_special_token = {} + self.n_vocab = n_vocab + self.load_merges = load_merges + self.merges = [] + self.chat_template = None + if special_token_types is not None: + self.special_token_types = special_token_types + else: + self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad') + self._load(Path(path)) + + def __repr__(self) -> str: + return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format( + len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset", + ) + + def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None: + if self.merges: + if not quiet: + print(f'gguf: Adding {len(self.merges)} merge(s).') + gw.add_token_merges(self.merges) + elif self.load_merges: + print( + 'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.', + file = sys.stderr, + ) + for typ, tokid in self.special_token_ids.items(): + id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) + if id_handler is None: + print( + f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', + file = sys.stderr, + ) + continue + if not quiet: + print(f'gguf: Setting special token type {typ} to {tokid}') + id_handler(tokid) + for typ, value in self.add_special_token.items(): + add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None) + if add_handler is None: + print( + f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping', + file = sys.stderr, + ) + continue + if not quiet: + print(f'gguf: Setting add_{typ}_token to {value}') + add_handler(value) + if self.chat_template is not None: + if not quiet: + print(f'gguf: Setting chat_template to {self.chat_template}') + gw.add_chat_template(self.chat_template) + + def _load(self, path: Path) -> None: + self._try_load_from_tokenizer_json(path) + self._try_load_from_config_json(path) + if self.load_merges and not self.merges: + self._try_load_merges_txt(path) + + def _try_load_merges_txt(self, path: Path) -> bool: + merges_file = path / 'merges.txt' + if not merges_file.is_file(): + return False + with open(merges_file, 'r', encoding = 'utf-8') as fp: + first_line = next(fp, '').strip() + if not first_line.startswith('#'): + fp.seek(0) + line_num = 0 + else: + line_num = 1 + merges = [] + for line in fp: + line_num += 1 + line = line.strip() + if not line: + continue + parts = line.split(None, 3) + if len(parts) != 2: + print( + f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring', + file = sys.stderr, + ) + continue + merges.append(f'{parts[0]} {parts[1]}') + self.merges = merges + return True + + def _set_special_token(self, typ: str, tid: Any) -> None: + if not isinstance(tid, int): + return + if tid < 0: + raise ValueError(f'invalid value for special token type {typ}: {tid}') + if self.n_vocab is None or tid < self.n_vocab: + if typ in self.special_token_ids: + return + self.special_token_ids[typ] = tid + return + print( + f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping', + file = sys.stderr, + ) + + def _try_load_from_tokenizer_json(self, path: Path) -> bool: + tokenizer_file = path / 'tokenizer.json' + if tokenizer_file.is_file(): + with open(tokenizer_file, encoding = 'utf-8') as f: + tokenizer = json.load(f) + if self.load_merges: + merges = tokenizer.get('model', {}).get('merges') + if isinstance(merges, list) and merges and isinstance(merges[0], str): + 
self.merges = merges + added_tokens = tokenizer.get('added_tokens', {}) + else: + added_tokens = {} + tokenizer_config_file = path / 'tokenizer_config.json' + if not tokenizer_config_file.is_file(): + return True + with open(tokenizer_config_file, encoding = 'utf-8') as f: + tokenizer_config = json.load(f) + chat_template = tokenizer_config.get('chat_template') + if chat_template is None or isinstance(chat_template, str): + self.chat_template = chat_template + else: + print( + f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring', + file = sys.stderr + ) + for typ in self.special_token_types: + add_entry = tokenizer_config.get(f'add_{typ}_token') + if isinstance(add_entry, bool): + self.add_special_token[typ] = add_entry + if not added_tokens: + # We will need this to get the content for the token, so if it's empty + # may as well just give up. + continue + entry = tokenizer_config.get(f'{typ}_token') + if isinstance(entry, str): + tc_content = entry + elif isinstance(entry, dict): + entry_content = entry.get('content') + if not isinstance(entry_content, str): + continue + tc_content = entry_content + else: + continue + # We only need the first match here. + maybe_token_id = next( + (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content), + None, + ) + self._set_special_token(typ, maybe_token_id) + return True + + def _try_load_from_config_json(self, path: Path) -> bool: + config_file = path / 'config.json' + if not config_file.is_file(): + return False + with open(config_file, encoding = 'utf-8') as f: + config = json.load(f) + for typ in self.special_token_types: + self._set_special_token(typ, config.get(f'{typ}_token_id')) + return True diff --git a/extensions/model-extension/scripts/gguf-py/pyproject.toml b/extensions/model-extension/scripts/gguf-py/pyproject.toml new file mode 100644 index 0000000000..9789c2c877 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/pyproject.toml @@ -0,0 +1,35 @@ +[tool.poetry] +name = "gguf" +version = "0.7.0" +description = "Read and write ML models in GGUF for GGML" +authors = ["GGML "] +packages = [ + {include = "gguf"}, + {include = "gguf/py.typed"}, + {include = "scripts"}, +] +readme = "README.md" +homepage = "https://ggml.ai" +repository = "https://github.com/ggerganov/llama.cpp" +keywords = ["ggml", "gguf", "llama.cpp"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[tool.poetry.dependencies] +python = ">=3.8" +numpy = ">=1.17" + +[tool.poetry.dev-dependencies] +pytest = "^5.2" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry.scripts] +gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint" +gguf-dump = "scripts:gguf_dump_entrypoint" +gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint" diff --git a/extensions/model-extension/scripts/gguf-py/scripts/__init__.py b/extensions/model-extension/scripts/gguf-py/scripts/__init__.py new file mode 100644 index 0000000000..77132db7a0 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/scripts/__init__.py @@ -0,0 +1,12 @@ +import os + +from importlib import import_module + + +os.environ["NO_LOCAL_GGUF"] = "TRUE" + +gguf_convert_endian_entrypoint = import_module("scripts.gguf-convert-endian").main +gguf_dump_entrypoint = import_module("scripts.gguf-dump").main +gguf_set_metadata_entrypoint = import_module("scripts.gguf-set-metadata").main + 
+del import_module, os diff --git a/extensions/model-extension/scripts/gguf-py/scripts/gguf-convert-endian.py b/extensions/model-extension/scripts/gguf-py/scripts/gguf-convert-endian.py new file mode 100755 index 0000000000..10a16ad063 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/scripts/gguf-convert-endian.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +import numpy as np + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +import gguf + + +def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None: + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # Host is little endian + host_endian = "little" + swapped_endian = "big" + else: + # Sorry PDP or other weird systems that don't use BE or LE. + host_endian = "big" + swapped_endian = "little" + if reader.byte_order == "S": + file_endian = swapped_endian + else: + file_endian = host_endian + order = host_endian if args.order == "native" else args.order + print(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian") + if file_endian == order: + print(f"* File is already {order.upper()} endian. Nothing to do.") + sys.exit(0) + print("* Checking tensors for conversion compatibility") + for tensor in reader.tensors: + if tensor.tensor_type not in ( + gguf.GGMLQuantizationType.F32, + gguf.GGMLQuantizationType.F16, + gguf.GGMLQuantizationType.Q8_0, + ): + raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") + print(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}") + if args.dry_run: + return + print("\n*** Warning *** Warning *** Warning **") + print("* This conversion process may damage the file. Ensure you have a backup.") + if order != host_endian: + print("* Requested endian differs from host, you will not be able to load the model on this machine.") + print("* The file will be modified immediately, so if conversion fails or is interrupted") + print("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:") + response = input("YES, I am sure> ") + if response != "YES": + print("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + print(f"\n* Converting fields ({len(reader.fields)})") + for idx, field in enumerate(reader.fields.values()): + print(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}") + for part in field.parts: + part.byteswap(inplace=True) + print(f"\n* Converting tensors ({len(reader.tensors)})") + for idx, tensor in enumerate(reader.tensors): + print( + f" - {idx:4}: Converting tensor {repr(tensor.name)}, type={tensor.tensor_type.name}, " + f"elements={tensor.n_elements}... ", + end="", + ) + tensor_type = tensor.tensor_type + for part in tensor.field.parts: + part.byteswap(inplace=True) + if tensor_type != gguf.GGMLQuantizationType.Q8_0: + tensor.data.byteswap(inplace=True) + print() + continue + # A Q8_0 block consists of a f16 delta followed by 32 int8 quants, so 34 bytes + block_size = 34 + n_blocks = len(tensor.data) // block_size + for block_num in range(n_blocks): + block_offs = block_num * block_size + # I know I said f16, but it doesn't matter here - any simple 16 bit type works. 
+ delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + if block_num % 100000 == 0: + print(f"[{(n_blocks - block_num) // 1000}K]", end="") + sys.stdout.flush() + print() + print("* Completion") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert GGUF file byte order") + parser.add_argument( + "model", type=str, + help="GGUF format model filename", + ) + parser.add_argument( + "order", type=str, choices=['big', 'little', 'native'], + help="Requested byte order", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="Don't actually change anything", + ) + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + print(f'* Loading: {args.model}') + reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+') + convert_byteorder(reader, args) + + +if __name__ == "__main__": + main() diff --git a/extensions/model-extension/scripts/gguf-py/scripts/gguf-dump.py b/extensions/model-extension/scripts/gguf-py/scripts/gguf-dump.py new file mode 100755 index 0000000000..dbf8915089 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/scripts/gguf-dump.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path +from typing import Any + +import numpy as np + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf import GGUFReader, GGUFValueType # noqa: E402 + + +def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: + host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG' + if reader.byte_order == 'S': + file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE' + else: + file_endian = host_endian + return (host_endian, file_endian) + + +# For more information about what field.parts and field.data represent, +# please see the comments in the modify_gguf.py example. 
+def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + host_endian, file_endian = get_file_host_endian(reader) + print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') + print(f'\n* Dumping {len(reader.fields)} key/value pair(s)') + for n, field in enumerate(reader.fields.values(), 1): + if not field.types: + pretty_type = 'N/A' + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + else: + pretty_type = str(field.types[-1].name) + print(f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}', end = '') + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '') + elif field.types[0] in reader.gguf_scalar_to_np: + print(' = {0}'.format(field.parts[-1][0]), end = '') + print() + if args.no_tensors: + return + print(f'\n* Dumping {len(reader.tensors)} tensor(s)') + for n, tensor in enumerate(reader.tensors, 1): + prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape))) + print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') + + +def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: + import json + host_endian, file_endian = get_file_host_endian(reader) + metadata: dict[str, Any] = {} + tensors: dict[str, Any] = {} + result = { + "filename": args.model, + "endian": file_endian, + "metadata": metadata, + "tensors": tensors, + } + for idx, field in enumerate(reader.fields.values()): + curr: dict[str, Any] = { + "index": idx, + "type": field.types[0].name if field.types else 'UNKNOWN', + "offset": field.offset, + } + metadata[field.name] = curr + if field.types[:1] == [GGUFValueType.ARRAY]: + curr["array_types"] = [t.name for t in field.types][1:] + if not args.json_array: + continue + itype = field.types[-1] + if itype == GGUFValueType.STRING: + curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data] + else: + curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()] + elif field.types[0] == GGUFValueType.STRING: + curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8") + else: + curr["value"] = field.parts[-1].tolist()[0] + if not args.no_tensors: + for idx, tensor in enumerate(reader.tensors): + tensors[tensor.name] = { + "index": idx, + "shape": tensor.shape.tolist(), + "type": tensor.tensor_type.name, + "offset": tensor.field.offset, + } + json.dump(result, sys.stdout) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Dump GGUF file metadata") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") + parser.add_argument("--json", action="store_true", help="Produce JSON output") + parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + if not args.json: + print(f'* Loading: {args.model}') + reader = GGUFReader(args.model, 'r') + if args.json: + dump_metadata_json(reader, args) + else: + dump_metadata(reader, args) + + +if __name__ == '__main__': + main() diff --git a/extensions/model-extension/scripts/gguf-py/scripts/gguf-set-metadata.py 
b/extensions/model-extension/scripts/gguf-py/scripts/gguf-set-metadata.py new file mode 100755 index 0000000000..3ebdfa898a --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/scripts/gguf-set-metadata.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +import argparse +import os +import sys +from pathlib import Path + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf import GGUFReader # noqa: E402 + + +def minimal_example(filename: str) -> None: + reader = GGUFReader(filename, 'r+') + field = reader.fields['tokenizer.ggml.bos_token_id'] + if field is None: + return + part_index = field.data[0] + field.parts[part_index][0] = 2 # Set tokenizer.ggml.bos_token_id to 2 + # + # So what's this field.data thing? It's helpful because field.parts contains + # _every_ part of the GGUF field. For example, tokenizer.ggml.bos_token_id consists + # of: + # + # Part index 0: Key length (27) + # Part index 1: Key data ("tokenizer.ggml.bos_token_id") + # Part index 2: Field type (4, the id for GGUFValueType.UINT32) + # Part index 3: Field value + # + # Note also that each part is an NDArray slice, so even a part that + # is only a single value like the key length will be a NDArray of + # the key length type (numpy.uint32). + # + # The .data attribute in the Field is a list of relevant part indexes + # and doesn't contain internal GGUF details like the key length part. + # In this case, .data will be [3] - just the part index of the + # field value itself. + + +def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + field = reader.get_field(args.key) + if field is None: + print(f'! Field {repr(args.key)} not found', file = sys.stderr) + sys.exit(1) + # Note that field.types is a list of types. This is because the GGUF + # format supports arrays. For example, an array of UINT32 would + # look like [GGUFValueType.ARRAY, GGUFValueType.UINT32] + handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None + if handler is None: + print( + f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}', + file = sys.stderr, + ) + sys.exit(1) + current_value = field.parts[field.data[0]][0] + new_value = handler(args.value) + print(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}') + if current_value == new_value: + print(f'- Key {repr(args.key)} already set to requested value {current_value}') + sys.exit(0) + if args.dry_run: + sys.exit(0) + if not args.force: + print('*** Warning *** Warning *** Warning **') + print('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.') + print('* Enter exactly YES if you are positive you want to proceed:') + response = input('YES, I am sure> ') + if response != 'YES': + print("You didn't enter YES. Okay then, see ya!") + sys.exit(0) + field.parts[field.data[0]][0] = new_value + print('* Field changed. 
Successful completion.') + + +def main() -> None: + parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("key", type=str, help="Metadata key to set") + parser.add_argument("value", type=str, help="Metadata value to set") + parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything") + parser.add_argument("--force", action="store_true", help="Change the field without confirmation") + args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) + print(f'* Loading: {args.model}') + reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+') + set_metadata(reader, args) + + +if __name__ == '__main__': + main() diff --git a/extensions/model-extension/scripts/gguf-py/tests/test_gguf.py b/extensions/model-extension/scripts/gguf-py/tests/test_gguf.py new file mode 100644 index 0000000000..0adeb7d557 --- /dev/null +++ b/extensions/model-extension/scripts/gguf-py/tests/test_gguf.py @@ -0,0 +1,7 @@ +import gguf # noqa: F401 + +# TODO: add tests + + +def test_write_gguf() -> None: + pass diff --git a/extensions/model-extension/scripts/install_deps.py b/extensions/model-extension/scripts/install_deps.py new file mode 100644 index 0000000000..2dfabed077 --- /dev/null +++ b/extensions/model-extension/scripts/install_deps.py @@ -0,0 +1,14 @@ +import subprocess +import sys + +deps = [ + 'numpy~=1.24.4', + 'sentencepiece~=0.1.98', + 'transformers>=4.35.2,<5.0.0', + 'gguf>=0.1.0', + 'protobuf>=4.21.0,<5.0.0', + 'torch~=2.1.1', + 'packaging>=20.0', + 'tiktoken~=0.5.0' +] +subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--force-reinstall', *deps]) diff --git a/extensions/model-extension/scripts/version.txt b/extensions/model-extension/scripts/version.txt new file mode 100644 index 0000000000..f743d6c4a4 --- /dev/null +++ b/extensions/model-extension/scripts/version.txt @@ -0,0 +1 @@ +b2106 \ No newline at end of file diff --git a/extensions/model-extension/src/@types/InvalidHostError.ts b/extensions/model-extension/src/@types/InvalidHostError.ts new file mode 100644 index 0000000000..47262206ee --- /dev/null +++ b/extensions/model-extension/src/@types/InvalidHostError.ts @@ -0,0 +1,6 @@ +export class InvalidHostError extends Error { + constructor(message: string) { + super(message) + this.name = 'InvalidHostError' + } +} diff --git a/extensions/model-extension/src/@types/NotSupportModelError.ts b/extensions/model-extension/src/@types/NotSupportModelError.ts new file mode 100644 index 0000000000..0a19461763 --- /dev/null +++ b/extensions/model-extension/src/@types/NotSupportModelError.ts @@ -0,0 +1,6 @@ +export class NotSupportedModelError extends Error { + constructor(message: string) { + super(message) + this.name = 'NotSupportedModelError' + } +} diff --git a/extensions/model-extension/src/@types/global.d.ts b/extensions/model-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..3878d4bf25 --- /dev/null +++ b/extensions/model-extension/src/@types/global.d.ts @@ -0,0 +1,14 @@ +export {} +declare global { + declare const DEFAULT_MODEL: object + declare const NODE: string + + interface Core { + api: APIFunctions + events: EventEmitter + } + interface Window { + core?: Core | undefined + electronAPI?: any | undefined + } +} diff --git a/extensions/model-extension/src/helpers/path.ts b/extensions/model-extension/src/helpers/path.ts new file mode 100644 index 0000000000..cbb151aa6c --- 
/dev/null +++ b/extensions/model-extension/src/helpers/path.ts @@ -0,0 +1,11 @@ +/** + * try to retrieve the download file name from the source url + */ + +export function extractFileName(url: string, fileExtension: string): string { + const extractedFileName = url.split('/').pop() + const fileName = extractedFileName.toLowerCase().endsWith(fileExtension) + ? extractedFileName + : extractedFileName + fileExtension + return fileName +} diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts new file mode 100644 index 0000000000..7561ee6edf --- /dev/null +++ b/extensions/model-extension/src/index.ts @@ -0,0 +1,1043 @@ +import { + fs, + downloadFile, + abortDownload, + InferenceEngine, + joinPath, + ModelExtension, + Model, + getJanDataFolderPath, + events, + DownloadEvent, + DownloadRoute, + DownloadState, + OptionType, + ImportingModel, + LocalImportModelEvent, + baseName, + GpuSetting, + DownloadRequest, + executeOnMain, + HuggingFaceRepoData, + Quantization, + log, + getFileSize, + AllQuantizations, + ModelEvent, +} from '@janhq/core' + +import { extractFileName } from './helpers/path' +import { GGUFMetadata, gguf } from '@huggingface/gguf' +import { NotSupportedModelError } from './@types/NotSupportModelError' +import { InvalidHostError } from './@types/InvalidHostError' + +declare const SETTINGS: Array +enum Settings { + huggingFaceAccessToken = 'hugging-face-access-token', +} + +/** + * A extension for models + */ +export default class JanModelExtension extends ModelExtension { + private static readonly _homeDir = 'file://models' + private static readonly _modelMetadataFileName = 'model.json' + private static readonly _supportedModelFormat = '.gguf' + private static readonly _incompletedModelFileName = '.download' + private static readonly _offlineInferenceEngine = [ + InferenceEngine.nitro, + InferenceEngine.nitro_tensorrt_llm, + ] + private static readonly _tensorRtEngineFormat = '.engine' + private static readonly _supportedGpuArch = ['ampere', 'ada'] + private static readonly _safetensorsRegexs = [ + /model\.safetensors$/, + /model-[0-9]+-of-[0-9]+\.safetensors$/, + ] + private static readonly _pytorchRegexs = [ + /pytorch_model\.bin$/, + /consolidated\.[0-9]+\.pth$/, + /pytorch_model-[0-9]+-of-[0-9]+\.bin$/, + /.*\.pt$/, + ] + interrupted = false + + /** + * Called when the extension is loaded. + * @override + */ + async onLoad() { + // Handle Desktop Events + this.registerSettings(SETTINGS) + this.handleDesktopEvents() + } + + /** + * Called when the extension is unloaded. + * @override + */ + async onUnload() {} + + /** + * Downloads a machine learning model. + * @param model - The model to download. + * @param network - Optional object to specify proxy/whether to ignore SSL certificates. + * @returns A Promise that resolves when the model is downloaded. 
+   */
+  async downloadModel(
+    model: Model,
+    gpuSettings?: GpuSetting,
+    network?: { ignoreSSL?: boolean; proxy?: string }
+  ): Promise<void> {
+    // create corresponding directory
+    const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
+    if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
+    const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
+    if (!(await fs.existsSync(modelJsonPath))) {
+      await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
+      events.emit(ModelEvent.OnModelsUpdate, {})
+    }
+    if (model.engine === InferenceEngine.nitro_tensorrt_llm) {
+      if (!gpuSettings || gpuSettings.gpus.length === 0) {
+        console.error('No GPU found. Please check your GPU setting.')
+        return
+      }
+      const firstGpu = gpuSettings.gpus[0]
+      if (!firstGpu.name.toLowerCase().includes('nvidia')) {
+        console.error('No Nvidia GPU found. Please check your GPU setting.')
+        return
+      }
+      const gpuArch = firstGpu.arch
+      if (gpuArch === undefined) {
+        console.error(
+          'No GPU architecture found. Please check your GPU setting.'
+        )
+        return
+      }
+
+      if (!JanModelExtension._supportedGpuArch.includes(gpuArch)) {
+        console.debug(
+          `Your GPU: ${JSON.stringify(firstGpu)} is not supported. Only 30xx, 40xx series are supported.`
+        )
+        return
+      }
+
+      const os = 'windows' // TODO: remove this hard coded value
+
+      const newSources = model.sources.map((source) => {
+        const newSource = { ...source }
+        newSource.url = newSource.url
+          .replace(/<os>/g, os)
+          .replace(/<gpuarch>/g, gpuArch)
+        return newSource
+      })
+      model.sources = newSources
+    }
+
+    console.debug(`Download sources: ${JSON.stringify(model.sources)}`)
+
+    if (model.sources.length > 1) {
+      // path to model binaries
+      for (const source of model.sources) {
+        let path = extractFileName(
+          source.url,
+          JanModelExtension._supportedModelFormat
+        )
+        if (source.filename) {
+          path = await joinPath([modelDirPath, source.filename])
+        }
+        const downloadRequest: DownloadRequest = {
+          url: source.url,
+          localPath: path,
+        }
+        downloadFile(downloadRequest, network)
+      }
+      // TODO: handle multiple binaries for web later
+    } else {
+      const fileName = extractFileName(
+        model.sources[0]?.url,
+        JanModelExtension._supportedModelFormat
+      )
+      const path = await joinPath([modelDirPath, fileName])
+      const downloadRequest: DownloadRequest = {
+        url: model.sources[0]?.url,
+        localPath: path,
+      }
+      downloadFile(downloadRequest, network)
+
+      if (window && window.core?.api && window.core.api.baseApiUrl) {
+        this.startPollingDownloadProgress(model.id)
+      }
+    }
+  }
+
+  private toHuggingFaceUrl(repoId: string): string {
+    try {
+      const url = new URL(repoId)
+      if (url.host !== 'huggingface.co') {
+        throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
+      }
+
+      const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
+      if (paths.length < 2) {
+        throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
+      }
+
+      return `${url.origin}/api/models/${paths[0]}/${paths[1]}`
+    } catch (err) {
+      if (err instanceof InvalidHostError) {
+        throw err
+      }
+
+      if (repoId.startsWith('https')) {
+        throw new Error(`Cannot parse url: ${repoId}`)
+      }
+
+      return `https://huggingface.co/api/models/${repoId}`
+    }
+  }
+
+  async fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData> {
+    const sanitizedUrl = this.toHuggingFaceUrl(repoId)
+    console.debug('sanitizedUrl', sanitizedUrl)
+
+    const huggingFaceAccessToken = (
+      await this.getSetting(Settings.huggingFaceAccessToken, '')
+    ).trim()
+
+    const headers = {
+      Accept:
'application/json', + } + + if (huggingFaceAccessToken.length > 0) { + headers['Authorization'] = `Bearer ${huggingFaceAccessToken}` + } + + const res = await fetch(sanitizedUrl, { + headers: headers, + }) + const response = await res.json() + if (response['error'] != null) { + throw new Error(response['error']) + } + + const data = response as HuggingFaceRepoData + + if (data.tags.indexOf('gguf') === -1) { + throw new NotSupportedModelError( + `${repoId} is not supported. Only GGUF models are supported.` + ) + } + + const promises: Promise[] = [] + + // fetching file sizes + const url = new URL(sanitizedUrl) + const paths = url.pathname.split('/').filter((e) => e.trim().length > 0) + + for (const sibling of data.siblings) { + const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}` + sibling.downloadUrl = downloadUrl + promises.push(getFileSize(downloadUrl)) + } + + const result = await Promise.all(promises) + for (let i = 0; i < data.siblings.length; i++) { + data.siblings[i].fileSize = result[i] + } + + AllQuantizations.forEach((quantization) => { + data.siblings.forEach((sibling) => { + if (!sibling.quantization && sibling.rfilename.includes(quantization)) { + sibling.quantization = quantization + } + }) + }) + + data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}` + return data + } + + async fetchModelMetadata(url: string): Promise { + const { metadata } = await gguf(url) + return metadata + } + + /** + * Specifically for Jan server. + */ + private async startPollingDownloadProgress(modelId: string): Promise { + // wait for some seconds before polling + await new Promise((resolve) => setTimeout(resolve, 3000)) + + return new Promise((resolve) => { + const interval = setInterval(async () => { + fetch( + `${window.core.api.baseApiUrl}/v1/download/${DownloadRoute.getDownloadProgress}/${modelId}`, + { + method: 'GET', + headers: { contentType: 'application/json' }, + } + ).then(async (res) => { + const state: DownloadState = await res.json() + if (state.downloadState === 'end') { + events.emit(DownloadEvent.onFileDownloadSuccess, state) + clearInterval(interval) + resolve() + return + } + + if (state.downloadState === 'error') { + events.emit(DownloadEvent.onFileDownloadError, state) + clearInterval(interval) + resolve() + return + } + + events.emit(DownloadEvent.onFileDownloadUpdate, state) + }) + }, 1000) + }) + } + + /** + * Cancels the download of a specific machine learning model. + * + * @param {string} modelId - The ID of the model whose download is to be cancelled. + * @returns {Promise} A promise that resolves when the download has been cancelled. + */ + async cancelModelDownload(modelId: string): Promise { + const path = await joinPath([JanModelExtension._homeDir, modelId, modelId]) + try { + await abortDownload(path) + await fs.unlinkSync(path) + } catch (e) { + console.error(e) + } + } + + /** + * Deletes a machine learning model. + * @param filePath - The path to the model file to delete. + * @returns A Promise that resolves when the model is deleted. 
+ */ + async deleteModel(modelId: string): Promise { + try { + const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) + const jsonFilePath = await joinPath([ + dirPath, + JanModelExtension._modelMetadataFileName, + ]) + const modelInfo = JSON.parse( + await this.readModelMetadata(jsonFilePath) + ) as Model + + const isUserImportModel = + modelInfo.metadata?.author?.toLowerCase() === 'user' + if (isUserImportModel) { + // just delete the folder + return fs.rm(dirPath) + } + + // remove all files under dirPath except model.json + const files = await fs.readdirSync(dirPath) + const deletePromises = files.map(async (fileName: string) => { + if (fileName !== JanModelExtension._modelMetadataFileName) { + return fs.unlinkSync(await joinPath([dirPath, fileName])) + } + }) + await Promise.allSettled(deletePromises) + } catch (err) { + console.error(err) + } + } + + /** + * Saves a machine learning model. + * @param model - The model to save. + * @returns A Promise that resolves when the model is saved. + */ + async saveModel(model: Model): Promise { + const jsonFilePath = await joinPath([ + JanModelExtension._homeDir, + model.id, + JanModelExtension._modelMetadataFileName, + ]) + + try { + await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2)) + } catch (err) { + console.error(err) + } + } + + /** + * Gets all downloaded models. + * @returns A Promise that resolves with an array of all models. + */ + async getDownloadedModels(): Promise { + return await this.getModelsMetadata( + async (modelDir: string, model: Model) => { + if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) + return true + + // model binaries (sources) are absolute path & exist + const existFiles = await Promise.all( + model.sources.map( + (source) => + // Supposed to be a local file url + !source.url.startsWith(`http://`) && + !source.url.startsWith(`https://`) + ) + ) + if (existFiles.every((exist) => exist)) return true + + const result = await fs + .readdirSync(await joinPath([JanModelExtension._homeDir, modelDir])) + .then((files: string[]) => { + // Model binary exists in the directory + // Model binary name can match model ID or be a .gguf file and not be an incompleted model file + return ( + files.includes(modelDir) || + files.filter((file) => { + if ( + file.endsWith(JanModelExtension._incompletedModelFileName) + ) { + return false + } + return ( + file + .toLowerCase() + .includes(JanModelExtension._supportedModelFormat) || + file + .toLowerCase() + .includes(JanModelExtension._tensorRtEngineFormat) + ) + })?.length > 0 // TODO: find better way (can use basename to check the file name with source url) + ) + }) + + return result + } + ) + } + + private async getModelJsonPath( + folderFullPath: string + ): Promise { + // try to find model.json recursively inside each folder + if (!(await fs.existsSync(folderFullPath))) return undefined + const files: string[] = await fs.readdirSync(folderFullPath) + if (files.length === 0) return undefined + if (files.includes(JanModelExtension._modelMetadataFileName)) { + return joinPath([ + folderFullPath, + JanModelExtension._modelMetadataFileName, + ]) + } + // continue recursive + for (const file of files) { + const path = await joinPath([folderFullPath, file]) + const fileStats = await fs.fileStat(path) + if (fileStats.isDirectory) { + const result = await this.getModelJsonPath(path) + if (result) return result + } + } + } + + private async getModelsMetadata( + selector?: (path: string, model: Model) => Promise + ): Promise { + try { + 
if (!(await fs.existsSync(JanModelExtension._homeDir))) { + console.debug('Model folder not found') + return [] + } + + const files: string[] = await fs.readdirSync(JanModelExtension._homeDir) + + const allDirectories: string[] = [] + for (const file of files) { + if (file === '.DS_Store') continue + if (file === 'config') continue + allDirectories.push(file) + } + + const readJsonPromises = allDirectories.map(async (dirName) => { + // filter out directories that don't match the selector + // read model.json + const folderFullPath = await joinPath([ + JanModelExtension._homeDir, + dirName, + ]) + const jsonPath = await this.getModelJsonPath(folderFullPath) + + if (await fs.existsSync(jsonPath)) { + // if we have the model.json file, read it + let model = await this.readModelMetadata(jsonPath) + + model = typeof model === 'object' ? model : JSON.parse(model) + + // This to ensure backward compatibility with `model.json` with `source_url` + if (model['source_url'] != null) { + model['sources'] = [ + { + filename: model.id, + url: model['source_url'], + }, + ] + } + + if (selector && !(await selector?.(dirName, model))) { + return + } + return model + } else { + // otherwise, we generate our own model file + // TODO: we might have more than one binary file here. This will be addressed with new version of Model file + // which is the PR from Hiro on branch Jan can see + return this.generateModelMetadata(dirName) + } + }) + const results = await Promise.allSettled(readJsonPromises) + const modelData = results.map((result) => { + if (result.status === 'fulfilled' && result.value) { + try { + const model = + typeof result.value === 'object' + ? result.value + : JSON.parse(result.value) + return model as Model + } catch { + console.debug(`Unable to parse model metadata: ${result.value}`) + } + } + return undefined + }) + + return modelData.filter((e) => !!e) + } catch (err) { + console.error(err) + return [] + } + } + + private readModelMetadata(path: string) { + return fs.readFileSync(path, 'utf-8') + } + + /** + * Handle the case where we have the model directory but we don't have the corresponding + * model.json file associated with it. + * + * This function will create a model.json file for the model. + * It works only with single binary file model. + * + * @param dirName the director which reside in ~/jan/models but does not have model.json file. 
+ */ + private async generateModelMetadata(dirName: string): Promise { + const files: string[] = await fs.readdirSync( + await joinPath([JanModelExtension._homeDir, dirName]) + ) + + // sort files by name + files.sort() + + // find the first file which is not a directory + let binaryFileName: string | undefined = undefined + let binaryFileSize: number | undefined = undefined + + for (const file of files) { + if (file.endsWith(JanModelExtension._supportedModelFormat)) { + const path = await joinPath([JanModelExtension._homeDir, dirName, file]) + const fileStats = await fs.fileStat(path) + if (fileStats.isDirectory) continue + binaryFileSize = fileStats.size + binaryFileName = file + break + } + } + + if (!binaryFileName) { + console.warn(`Unable to find binary file for model ${dirName}`) + return + } + + const defaultModel = (await this.getDefaultModel()) as Model + if (!defaultModel) { + console.error('Unable to find default model') + return + } + + const model: Model = { + ...defaultModel, + // Overwrite default N/A fields + id: dirName, + name: dirName, + sources: [ + { + url: binaryFileName, + filename: binaryFileName, + }, + ], + settings: { + ...defaultModel.settings, + llama_model_path: binaryFileName, + }, + created: Date.now(), + description: '', + metadata: { + size: binaryFileSize, + author: 'User', + tags: [], + }, + } + + const modelFilePath = await joinPath([ + JanModelExtension._homeDir, + dirName, + JanModelExtension._modelMetadataFileName, + ]) + + await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) + + return model + } + + override async getDefaultModel(): Promise { + const defaultModel = DEFAULT_MODEL as Model + return defaultModel + } + + /** + * Gets all available models. + * @returns A Promise that resolves with an array of all models. 
+   */
+  async getConfiguredModels(): Promise<Model[]> {
+    return this.getModelsMetadata()
+  }
+
+  handleDesktopEvents() {
+    if (window && window.electronAPI) {
+      window.electronAPI.onFileDownloadUpdate(
+        async (_event: string, state: DownloadState | undefined) => {
+          if (!state) return
+          state.downloadState = 'downloading'
+          events.emit(DownloadEvent.onFileDownloadUpdate, state)
+        }
+      )
+      window.electronAPI.onFileDownloadError(
+        async (_event: string, state: DownloadState) => {
+          state.downloadState = 'error'
+          events.emit(DownloadEvent.onFileDownloadError, state)
+        }
+      )
+      window.electronAPI.onFileDownloadSuccess(
+        async (_event: string, state: DownloadState) => {
+          state.downloadState = 'end'
+          events.emit(DownloadEvent.onFileDownloadSuccess, state)
+        }
+      )
+    }
+  }
+
+  private async importModelSymlink(
+    modelBinaryPath: string,
+    modelFolderName: string,
+    modelFolderPath: string
+  ): Promise<Model> {
+    const fileStats = await fs.fileStat(modelBinaryPath, true)
+    const binaryFileSize = fileStats.size
+
+    // Just need to generate model.json there
+    const defaultModel = (await this.getDefaultModel()) as Model
+    if (!defaultModel) {
+      console.error('Unable to find default model')
+      return
+    }
+
+    const binaryFileName = await baseName(modelBinaryPath)
+
+    const model: Model = {
+      ...defaultModel,
+      id: modelFolderName,
+      name: modelFolderName,
+      sources: [
+        {
+          url: modelBinaryPath,
+          filename: binaryFileName,
+        },
+      ],
+      settings: {
+        ...defaultModel.settings,
+        llama_model_path: binaryFileName,
+      },
+      created: Date.now(),
+      description: '',
+      metadata: {
+        size: binaryFileSize,
+        author: 'User',
+        tags: [],
+      },
+    }
+
+    const modelFilePath = await joinPath([
+      modelFolderPath,
+      JanModelExtension._modelMetadataFileName,
+    ])
+
+    await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
+
+    return model
+  }
+
+  async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> {
+    const modelId = modelInfo.id
+    if (modelInfo.id == null) throw new Error('Model ID is required')
+
+    const janDataFolderPath = await getJanDataFolderPath()
+    const jsonFilePath = await joinPath([
+      janDataFolderPath,
+      'models',
+      modelId,
+      JanModelExtension._modelMetadataFileName,
+    ])
+    const model = JSON.parse(
+      await this.readModelMetadata(jsonFilePath)
+    ) as Model
+
+    const updatedModel: Model = {
+      ...model,
+      ...modelInfo,
+      metadata: {
+        ...model.metadata,
+        tags: modelInfo.metadata?.tags ?? [],
+      },
+    }
+
+    await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
+    return updatedModel
+  }
+
+  private async importModel(
+    model: ImportingModel,
+    optionType: OptionType
+  ): Promise<Model> {
+    const binaryName = (await baseName(model.path)).replace(/\s/g, '')
+
+    let modelFolderName = binaryName
+    if (binaryName.endsWith(JanModelExtension._supportedModelFormat)) {
+      modelFolderName = binaryName.replace(
+        JanModelExtension._supportedModelFormat,
+        ''
+      )
+    }
+
+    const modelFolderPath = await this.getModelFolderName(modelFolderName)
+    await fs.mkdir(modelFolderPath)
+
+    const uniqueFolderName = await baseName(modelFolderPath)
+    const modelBinaryFile = binaryName.endsWith(
+      JanModelExtension._supportedModelFormat
+    )
+      ?
binaryName + : `${binaryName}${JanModelExtension._supportedModelFormat}` + + const binaryPath = await joinPath([modelFolderPath, modelBinaryFile]) + + if (optionType === 'SYMLINK') { + return this.importModelSymlink( + model.path, + uniqueFolderName, + modelFolderPath + ) + } + + const srcStat = await fs.fileStat(model.path, true) + + // interval getting the file size to calculate the percentage + const interval = setInterval(async () => { + const destStats = await fs.fileStat(binaryPath, true) + const percentage = destStats.size / srcStat.size + events.emit(LocalImportModelEvent.onLocalImportModelUpdate, { + ...model, + percentage, + }) + }, 1000) + + await fs.copyFile(model.path, binaryPath) + + clearInterval(interval) + + // generate model json + return this.generateModelMetadata(uniqueFolderName) + } + + private async getModelFolderName( + modelFolderName: string, + count?: number + ): Promise { + const newModelFolderName = count + ? `${modelFolderName}-${count}` + : modelFolderName + + const janDataFolderPath = await getJanDataFolderPath() + const modelFolderPath = await joinPath([ + janDataFolderPath, + 'models', + newModelFolderName, + ]) + + const isFolderExist = await fs.existsSync(modelFolderPath) + if (!isFolderExist) { + return modelFolderPath + } else { + const newCount = (count ?? 0) + 1 + return this.getModelFolderName(modelFolderName, newCount) + } + } + + async importModels( + models: ImportingModel[], + optionType: OptionType + ): Promise { + const importedModels: Model[] = [] + + for (const model of models) { + events.emit(LocalImportModelEvent.onLocalImportModelUpdate, model) + try { + const importedModel = await this.importModel(model, optionType) + events.emit(LocalImportModelEvent.onLocalImportModelSuccess, { + ...model, + modelId: importedModel.id, + }) + importedModels.push(importedModel) + } catch (err) { + events.emit(LocalImportModelEvent.onLocalImportModelFailed, { + ...model, + error: err, + }) + } + } + + events.emit( + LocalImportModelEvent.onLocalImportModelFinished, + importedModels + ) + } + + private getGgufFileList( + repoData: HuggingFaceRepoData, + selectedQuantization: Quantization + ): string[] { + return repoData.siblings + .map((file) => file.rfilename) + .filter((file) => file.indexOf(selectedQuantization) !== -1) + .filter((file) => file.endsWith('.gguf')) + } + + private getFileList(repoData: HuggingFaceRepoData): string[] { + // SafeTensors first, if not, then PyTorch + const modelFiles = repoData.siblings + .map((file) => file.rfilename) + .filter((file) => + JanModelExtension._safetensorsRegexs.some((regex) => regex.test(file)) + ) + if (modelFiles.length === 0) { + repoData.siblings.forEach((file) => { + if ( + JanModelExtension._pytorchRegexs.some((regex) => + regex.test(file.rfilename) + ) + ) { + modelFiles.push(file.rfilename) + } + }) + } + + const vocabFiles = [ + 'tokenizer.model', + 'vocab.json', + 'tokenizer.json', + ].filter((file) => + repoData.siblings.some((sibling) => sibling.rfilename === file) + ) + + const etcFiles = repoData.siblings + .map((file) => file.rfilename) + .filter( + (file) => + (file.endsWith('.json') && !vocabFiles.includes(file)) || + file.endsWith('.txt') || + file.endsWith('.py') || + file.endsWith('.tiktoken') + ) + + return [...modelFiles, ...vocabFiles, ...etcFiles] + } + + private async getModelDirPath(repoID: string): Promise { + const modelName = repoID.split('/').slice(1).join('/') + return joinPath([await getJanDataFolderPath(), 'models', modelName]) + } + + private async 
getConvertedModelPath(repoID: string): Promise { + const modelName = repoID.split('/').slice(1).join('/') + const modelDirPath = await this.getModelDirPath(repoID) + return joinPath([modelDirPath, modelName + '.gguf']) + } + + private async getQuantizedModelPath( + repoID: string, + quantization: Quantization + ): Promise { + const modelName = repoID.split('/').slice(1).join('/') + const modelDirPath = await this.getModelDirPath(repoID) + return joinPath([ + modelDirPath, + modelName + `-${quantization.toLowerCase()}.gguf`, + ]) + } + private getCtxLength(config: { + max_sequence_length?: number + max_position_embeddings?: number + n_ctx?: number + }): number { + if (config.max_sequence_length) return config.max_sequence_length + if (config.max_position_embeddings) return config.max_position_embeddings + if (config.n_ctx) return config.n_ctx + return 2048 + } + + /** + * Converts a Hugging Face model to GGUF. + * @param repoID - The repo ID of the model to convert. + * @returns A promise that resolves when the conversion is complete. + */ + async convert(repoID: string): Promise { + if (this.interrupted) return + const modelDirPath = await this.getModelDirPath(repoID) + const modelOutPath = await this.getConvertedModelPath(repoID) + if (!(await fs.existsSync(modelDirPath))) { + throw new Error('Model dir not found') + } + if (await fs.existsSync(modelOutPath)) return + + await executeOnMain(NODE, 'installDeps') + if (this.interrupted) return + + try { + await executeOnMain( + NODE, + 'convertHf', + modelDirPath, + modelOutPath + '.temp' + ) + } catch (err) { + log(`[Conversion]::Debug: Error using hf-to-gguf.py, trying convert.py`) + + let ctx = 2048 + try { + const config = await fs.readFileSync( + await joinPath([modelDirPath, 'config.json']), + 'utf8' + ) + const configParsed = JSON.parse(config) + ctx = this.getCtxLength(configParsed) + configParsed.max_sequence_length = ctx + await fs.writeFileSync( + await joinPath([modelDirPath, 'config.json']), + JSON.stringify(configParsed, null, 2) + ) + } catch (err) { + log(`${err}`) + // ignore missing config.json + } + + const bpe = await fs.existsSync( + await joinPath([modelDirPath, 'vocab.json']) + ) + + await executeOnMain( + NODE, + 'convert', + modelDirPath, + modelOutPath + '.temp', + { + ctx, + bpe, + } + ) + } + await executeOnMain( + NODE, + 'renameSync', + modelOutPath + '.temp', + modelOutPath + ) + + for (const file of await fs.readdirSync(modelDirPath)) { + if ( + modelOutPath.endsWith(file) || + (file.endsWith('config.json') && !file.endsWith('_config.json')) + ) + continue + await fs.unlinkSync(await joinPath([modelDirPath, file])) + } + } + + /** + * Quantizes a GGUF model. + * @param repoID - The repo ID of the model to quantize. + * @param quantization - The quantization to use. + * @returns A promise that resolves when the quantization is complete. 
+ */ + async quantize(repoID: string, quantization: Quantization): Promise { + if (this.interrupted) return + const modelDirPath = await this.getModelDirPath(repoID) + const modelOutPath = await this.getQuantizedModelPath(repoID, quantization) + if (!(await fs.existsSync(modelDirPath))) { + throw new Error('Model dir not found') + } + if (await fs.existsSync(modelOutPath)) return + + await executeOnMain( + NODE, + 'quantize', + await this.getConvertedModelPath(repoID), + modelOutPath + '.temp', + quantization + ) + await executeOnMain( + NODE, + 'renameSync', + modelOutPath + '.temp', + modelOutPath + ) + + await fs.unlinkSync(await this.getConvertedModelPath(repoID)) + } + + /** + * Cancels the convert of current Hugging Face model. + * @param repoID - The repository ID to cancel. + * @param repoData - The repository data to cancel. + * @returns {Promise} A promise that resolves when the download has been cancelled. + */ + async cancelConvert( + repoID: string, + repoData: HuggingFaceRepoData + ): Promise { + this.interrupted = true + const modelDirPath = await this.getModelDirPath(repoID) + const files = this.getFileList(repoData) + for (const file of files) { + const filePath = file + const localPath = await joinPath([modelDirPath, filePath]) + await abortDownload(localPath) + } + + executeOnMain(NODE, 'killProcesses') + } +} diff --git a/extensions/model-extension/src/node/index.ts b/extensions/model-extension/src/node/index.ts new file mode 100644 index 0000000000..991548e001 --- /dev/null +++ b/extensions/model-extension/src/node/index.ts @@ -0,0 +1,182 @@ +import { PythonShell } from 'python-shell' +import { spawn, ChildProcess } from 'child_process' +import { resolve as presolve, join as pjoin } from 'path' +import { log, Quantization } from '@janhq/core/node' +import { statSync } from 'fs' +export { renameSync } from 'fs' + +let pythonShell: PythonShell | undefined = undefined +let quantizeProcess: ChildProcess | undefined = undefined + +export const getSize = (path: string): number => statSync(path).size + +export const killProcesses = () => { + if (pythonShell) { + pythonShell.kill() + pythonShell = undefined + } + if (quantizeProcess) { + quantizeProcess.kill() + quantizeProcess = undefined + } +} + +export const getQuantizeExecutable = (): string => { + let binaryFolder = pjoin(__dirname, '..', 'bin') // Current directory by default + let binaryName = 'quantize' + /** + * The binary folder is different for each platform. 
+ */ + if (process.platform === 'win32') { + binaryFolder = pjoin(binaryFolder, 'win') + binaryName = 'quantize.exe' + } else if (process.platform === 'darwin') { + /** + * For MacOS: mac-universal both Silicon and InteL + */ + binaryFolder = pjoin(binaryFolder, 'mac-universal') + } else { + binaryFolder = pjoin(binaryFolder, 'linux-cpu') + } + return pjoin(binaryFolder, binaryName) +} + +export const installDeps = (): Promise => { + return new Promise((resolve, reject) => { + const _pythonShell = new PythonShell( + presolve(__dirname, '..', 'scripts', 'install_deps.py') + ) + _pythonShell.on('message', (message) => { + log(`[Install Deps]::Debug: ${message}`) + }) + _pythonShell.on('stderr', (stderr) => { + log(`[Install Deps]::Error: ${stderr}`) + }) + _pythonShell.on('error', (err) => { + pythonShell = undefined + log(`[Install Deps]::Error: ${err}`) + reject(err) + }) + _pythonShell.on('close', () => { + const exitCode = _pythonShell.exitCode + pythonShell = undefined + log( + `[Install Deps]::Debug: Deps installation exited with code: ${exitCode}` + ) + exitCode === 0 ? resolve() : reject(exitCode) + }) + }) +} + +export const convertHf = async ( + modelDirPath: string, + outPath: string +): Promise => { + return await new Promise((resolve, reject) => { + const _pythonShell = new PythonShell( + presolve(__dirname, '..', 'scripts', 'convert-hf-to-gguf.py'), + { + args: [modelDirPath, '--outfile', outPath], + } + ) + pythonShell = _pythonShell + _pythonShell.on('message', (message) => { + log(`[Conversion]::Debug: ${message}`) + }) + _pythonShell.on('stderr', (stderr) => { + log(`[Conversion]::Error: ${stderr}`) + }) + _pythonShell.on('error', (err) => { + pythonShell = undefined + log(`[Conversion]::Error: ${err}`) + reject(err) + }) + _pythonShell.on('close', () => { + const exitCode = _pythonShell.exitCode + pythonShell = undefined + if (exitCode !== 0) { + log(`[Conversion]::Debug: Conversion exited with code: ${exitCode}`) + reject(exitCode) + } else { + resolve() + } + }) + }) +} + +export const convert = async ( + modelDirPath: string, + outPath: string, + { ctx, bpe }: { ctx?: number; bpe?: boolean } +): Promise => { + const args = [modelDirPath, '--outfile', outPath] + if (ctx) { + args.push('--ctx') + args.push(ctx.toString()) + } + if (bpe) { + args.push('--vocab-type') + args.push('bpe') + } + return await new Promise((resolve, reject) => { + const _pythonShell = new PythonShell( + presolve(__dirname, '..', 'scripts', 'convert.py'), + { + args, + } + ) + _pythonShell.on('message', (message) => { + log(`[Conversion]::Debug: ${message}`) + }) + _pythonShell.on('stderr', (stderr) => { + log(`[Conversion]::Error: ${stderr}`) + }) + _pythonShell.on('error', (err) => { + pythonShell = undefined + log(`[Conversion]::Error: ${err}`) + reject(err) + }) + _pythonShell.on('close', () => { + const exitCode = _pythonShell.exitCode + pythonShell = undefined + if (exitCode !== 0) { + log(`[Conversion]::Debug: Conversion exited with code: ${exitCode}`) + reject(exitCode) + } else { + resolve() + } + }) + }) +} + +export const quantize = async ( + modelPath: string, + outPath: string, + quantization: Quantization +): Promise => { + return await new Promise((resolve, reject) => { + const quantizeExecutable = getQuantizeExecutable() + const _quantizeProcess = spawn(quantizeExecutable, [ + modelPath, + outPath, + quantization, + ]) + quantizeProcess = _quantizeProcess + + _quantizeProcess.stdout?.on('data', (data) => { + log(`[Quantization]::Debug: ${data}`) + }) + 
_quantizeProcess.stderr?.on('data', (data) => { + log(`[Quantization]::Error: ${data}`) + }) + + _quantizeProcess.on('close', (code) => { + if (code !== 0) { + log(`[Quantization]::Debug: Quantization exited with code: ${code}`) + reject(code) + } else { + resolve() + } + }) + }) +} diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json new file mode 100644 index 0000000000..addd8e1274 --- /dev/null +++ b/extensions/model-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "esnext", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/monitoring-extension/README.md b/extensions/monitoring-extension/README.md new file mode 100644 index 0000000000..f9690da09d --- /dev/null +++ b/extensions/monitoring-extension/README.md @@ -0,0 +1,75 @@ +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. 
+ + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/monitoring-extension/bin/.gitkeep b/extensions/monitoring-extension/bin/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/extensions/monitoring-extension/download.bat b/extensions/monitoring-extension/download.bat new file mode 100644 index 0000000000..14e0aadd91 --- /dev/null +++ b/extensions/monitoring-extension/download.bat @@ -0,0 +1,2 @@ +@echo off +.\node_modules\.bin\download https://catalog.jan.ai/vulkaninfoSDK.exe -o ./bin \ No newline at end of file diff --git a/extensions/monitoring-extension/package.json b/extensions/monitoring-extension/package.json new file mode 100644 index 0000000000..e728b46291 --- /dev/null +++ b/extensions/monitoring-extension/package.json @@ -0,0 +1,52 @@ +{ + "name": "@janhq/monitoring-extension", + "productName": "System Monitoring", + "version": "1.0.10", + "description": "This extension provides system health and OS level data", + "main": "dist/index.js", + "node": "dist/node/index.cjs.js", + "author": "Jan ", + "license": "AGPL-3.0", + "scripts": { + "build": "tsc --module commonjs && rollup -c rollup.config.ts && npm run download-artifacts", + "download-artifacts": "run-script-os && cpx \"bin/**\" \"dist/bin\"", + "download-artifacts:darwin": "echo 'No artifacts to download for darwin'", + "download-artifacts:win32": "download.bat", + "download-artifacts:linux": "download https://catalog.jan.ai/vulkaninfo -o ./bin && chmod +x ./bin/vulkaninfo", + "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/node/index.cjs.js" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@types/node": "^20.11.4", + "@types/node-os-utils": "^1.3.4", + "run-script-os": "^1.1.6", + "cpx": "^1.5.0", + "rimraf": "^3.0.2", + "rollup": "^2.38.5", + "rollup-plugin-define": "^1.0.1", + "rollup-plugin-sourcemaps": "^0.6.3", + "rollup-plugin-typescript2": "^0.36.0", + "typescript": "^5.3.3", + "download-cli": "^1.1.1" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "@rollup/plugin-replace": "^5.0.5", + "node-os-utils": "^1.3.7" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "node-os-utils", + "@janhq/core" + ] +} diff --git a/extensions/monitoring-extension/resources/settings.json b/extensions/monitoring-extension/resources/settings.json new file mode 100644 index 0000000000..40b0b97f9a --- /dev/null +++ b/extensions/monitoring-extension/resources/settings.json @@ -0,0 +1,22 @@ +[ + { + "key": "log-enabled", + "title": "Enable App Logs", + "description": "Saves app logs locally on your computer. 
This enables you to send us crash reports.", + "controllerType": "checkbox", + "controllerProps": { + "value": true + } + }, + { + "key": "log-cleaning-interval", + "title": "Log Cleaning Interval", + "description": "Automatically delete local logs after a certain time interval (in milliseconds).", + "controllerType": "input", + "controllerProps": { + "value": "120000", + "placeholder": "Interval in milliseconds. E.g. 120000", + "textAlign": "right" + } + } +] \ No newline at end of file diff --git a/extensions/monitoring-extension/rollup.config.ts b/extensions/monitoring-extension/rollup.config.ts new file mode 100644 index 0000000000..b054d62916 --- /dev/null +++ b/extensions/monitoring-extension/rollup.config.ts @@ -0,0 +1,71 @@ +import resolve from '@rollup/plugin-node-resolve' +import commonjs from '@rollup/plugin-commonjs' +import sourceMaps from 'rollup-plugin-sourcemaps' +import typescript from 'rollup-plugin-typescript2' +import json from '@rollup/plugin-json' +import replace from '@rollup/plugin-replace' +const settingJson = require('./resources/settings.json') +const packageJson = require('./package.json') + +export default [ + { + input: `src/index.ts`, + output: [{ file: packageJson.main, format: 'es', sourcemap: true }], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: [], + watch: { + include: 'src/**', + }, + plugins: [ + replace({ + preventAssignment: true, + NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), + SETTINGS: JSON.stringify(settingJson), + }), + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Compile TypeScript files + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.js', '.ts', '.svelte'], + }), + + // Resolve source maps to the original source + sourceMaps(), + ], + }, + { + input: `src/node/index.ts`, + output: [ + { file: 'dist/node/index.cjs.js', format: 'cjs', sourcemap: true }, + ], + // Indicate here external modules you don't wanna include in your bundle (i.e.: 'lodash') + external: ['@janhq/core/node'], + watch: { + include: 'src/node/**', + }, + plugins: [ + // Allow json resolution + json(), + // Compile TypeScript files + typescript({ useTsconfigDeclarationDir: true }), + // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) + commonjs(), + // Allow node_modules resolution, so you can use 'external' to control + // which external modules to include in the bundle + // https://github.com/rollup/rollup-plugin-node-resolve#usage + resolve({ + extensions: ['.ts', '.js', '.json'], + }), + + // Resolve source maps to the original source + sourceMaps(), + ], + }, +] diff --git a/extensions/monitoring-extension/src/@types/global.d.ts b/extensions/monitoring-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..dfa96a0b1b --- /dev/null +++ b/extensions/monitoring-extension/src/@types/global.d.ts @@ -0,0 +1,18 @@ +declare const NODE: string + +type CpuGpuInfo = { + cpu: { + usage: number + } + gpu: GpuInfo[] +} + +type GpuInfo = { + id: string + name: string + temperature: string + utilization: string + memoryTotal: string + memoryFree: string + memoryUtilization: string +} diff --git 
a/extensions/monitoring-extension/src/index.ts b/extensions/monitoring-extension/src/index.ts new file mode 100644 index 0000000000..1d21fde775 --- /dev/null +++ b/extensions/monitoring-extension/src/index.ts @@ -0,0 +1,89 @@ +import { + GpuSetting, + MonitoringExtension, + OperatingSystemInfo, + executeOnMain, +} from '@janhq/core' + +declare const SETTINGS: Array + +enum Settings { + logEnabled = 'log-enabled', + logCleaningInterval = 'log-cleaning-interval', +} +/** + * JanMonitoringExtension is a extension that provides system monitoring functionality. + * It implements the MonitoringExtension interface from the @janhq/core package. + */ +export default class JanMonitoringExtension extends MonitoringExtension { + /** + * Called when the extension is loaded. + */ + async onLoad() { + // Register extension settings + this.registerSettings(SETTINGS) + + const logEnabled = await this.getSetting(Settings.logEnabled, true) + const logCleaningInterval = parseInt( + await this.getSetting(Settings.logCleaningInterval, '120000') + ) + // Register File Logger provided by this extension + await executeOnMain(NODE, 'registerLogger', { + logEnabled, + logCleaningInterval: isNaN(logCleaningInterval) + ? 120000 + : logCleaningInterval, + }) + + // Attempt to fetch nvidia info + await executeOnMain(NODE, 'updateNvidiaInfo') + } + + onSettingUpdate(key: string, value: T): void { + if (key === Settings.logEnabled) { + executeOnMain(NODE, 'updateLogger', { logEnabled: value }) + } else if (key === Settings.logCleaningInterval) { + executeOnMain(NODE, 'updateLogger', { logCleaningInterval: value }) + } + } + + /** + * Called when the extension is unloaded. + */ + onUnload(): void { + // Register File Logger provided by this extension + executeOnMain(NODE, 'unregisterLogger') + } + + /** + * Returns the GPU configuration. + * @returns A Promise that resolves to an object containing the GPU configuration. + */ + async getGpuSetting(): Promise { + return executeOnMain(NODE, 'getGpuConfig') + } + + /** + * Returns information about the system resources. + * @returns A Promise that resolves to an object containing information about the system resources. + */ + getResourcesInfo(): Promise { + return executeOnMain(NODE, 'getResourcesInfo') + } + + /** + * Returns information about the current system load. + * @returns A Promise that resolves to an object containing information about the current system load. 
+ */ + getCurrentLoad(): Promise { + return executeOnMain(NODE, 'getCurrentLoad') + } + + /** + * Returns information about the OS + * @returns + */ + getOsInfo(): Promise { + return executeOnMain(NODE, 'getOsInfo') + } +} diff --git a/extensions/monitoring-extension/src/node/index.ts b/extensions/monitoring-extension/src/node/index.ts new file mode 100644 index 0000000000..980ee75d1f --- /dev/null +++ b/extensions/monitoring-extension/src/node/index.ts @@ -0,0 +1,389 @@ +import { + GpuSetting, + GpuSettingInfo, + LoggerManager, + OperatingSystemInfo, + ResourceInfo, + SupportedPlatforms, + getJanDataFolderPath, + log, +} from '@janhq/core/node' +import { mem, cpu } from 'node-os-utils' +import { exec } from 'child_process' +import { writeFileSync, existsSync, readFileSync, mkdirSync } from 'fs' +import path from 'path' +import os from 'os' +import { FileLogger } from './logger' + +/** + * Path to the settings directory + **/ +export const SETTINGS_DIR = path.join(getJanDataFolderPath(), 'settings') +/** + * Path to the settings file + **/ +export const GPU_INFO_FILE = path.join(SETTINGS_DIR, 'settings.json') + +/** + * Default GPU settings + * TODO: This needs to be refactored to support multiple accelerators + **/ +const DEFAULT_SETTINGS: GpuSetting = { + notify: true, + run_mode: 'cpu', + nvidia_driver: { + exist: false, + version: '', + }, + cuda: { + exist: false, + version: '', + }, + gpus: [], + gpu_highest_vram: '', + gpus_in_use: [], + is_initial: true, + // TODO: This needs to be set based on user toggle in settings + vulkan: false, +} + +export const getGpuConfig = async (): Promise => { + if (process.platform === 'darwin') return undefined + if (existsSync(GPU_INFO_FILE)) + return JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) + return DEFAULT_SETTINGS +} + +export const getResourcesInfo = async (): Promise => { + const ramUsedInfo = await mem.used() + const totalMemory = ramUsedInfo.totalMemMb * 1024 * 1024 + const usedMemory = ramUsedInfo.usedMemMb * 1024 * 1024 + + const resourceInfo: ResourceInfo = { + mem: { + totalMemory, + usedMemory, + }, + } + + return resourceInfo +} + +export const getCurrentLoad = () => + new Promise(async (resolve, reject) => { + const cpuPercentage = await cpu.usage() + let data = { + run_mode: 'cpu', + gpus_in_use: [], + } + + if (process.platform !== 'darwin') { + data = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) + } + + if (data.run_mode === 'gpu' && data.gpus_in_use.length > 0) { + const gpuIds = data.gpus_in_use.join(',') + if (gpuIds !== '' && data['vulkan'] !== true) { + exec( + `nvidia-smi --query-gpu=index,name,temperature.gpu,utilization.gpu,memory.total,memory.free,utilization.memory --format=csv,noheader,nounits --id=${gpuIds}`, + (error, stdout, _) => { + if (error) { + console.error(`exec error: ${error}`) + throw new Error(error.message) + } + const gpuInfo: GpuInfo[] = stdout + .trim() + .split('\n') + .map((line) => { + const [ + id, + name, + temperature, + utilization, + memoryTotal, + memoryFree, + memoryUtilization, + ] = line.split(', ').map((item) => item.replace(/\r/g, '')) + return { + id, + name, + temperature, + utilization, + memoryTotal, + memoryFree, + memoryUtilization, + } + }) + + resolve({ + cpu: { usage: cpuPercentage }, + gpu: gpuInfo, + }) + } + ) + } else { + // Handle the case where gpuIds is empty + resolve({ + cpu: { usage: cpuPercentage }, + gpu: [], + }) + } + } else { + // Handle the case where run_mode is not 'gpu' or no GPUs are in use + resolve({ + cpu: { usage: cpuPercentage }, + gpu: [], + 
}) + } + }) + +/** + * This will retrieve GPU information and persist settings.json + * Will be called when the extension is loaded to turn on GPU acceleration if supported + */ +export const updateNvidiaInfo = async () => { + // ignore if macos + if (process.platform === 'darwin') return + + try { + JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) + } catch (error) { + if (!existsSync(SETTINGS_DIR)) { + mkdirSync(SETTINGS_DIR, { + recursive: true, + }) + } + writeFileSync(GPU_INFO_FILE, JSON.stringify(DEFAULT_SETTINGS, null, 2)) + } + + await updateNvidiaDriverInfo() + await updateGpuInfo() +} + +const updateNvidiaDriverInfo = async () => + new Promise((resolve, reject) => { + exec( + 'nvidia-smi --query-gpu=driver_version --format=csv,noheader', + (error, stdout) => { + const data: GpuSetting = JSON.parse( + readFileSync(GPU_INFO_FILE, 'utf-8') + ) + + if (!error) { + const firstLine = stdout.split('\n')[0].trim() + data.nvidia_driver.exist = true + data.nvidia_driver.version = firstLine + } else { + data.nvidia_driver.exist = false + } + + writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) + resolve({}) + } + ) + }) + +const getGpuArch = (gpuName: string): string => { + if (!gpuName.toLowerCase().includes('nvidia')) return 'unknown' + + if (gpuName.includes('30')) return 'ampere' + else if (gpuName.includes('40')) return 'ada' + else return 'unknown' +} + +const updateGpuInfo = async () => + new Promise((resolve, reject) => { + let data: GpuSetting = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) + + // Cuda + if (data.vulkan === true) { + // Vulkan + exec( + process.platform === 'win32' + ? `${__dirname}\\..\\bin\\vulkaninfoSDK.exe --summary` + : `${__dirname}/../bin/vulkaninfo --summary`, + async (error, stdout) => { + if (!error) { + const output = stdout.toString() + + log(output) + const gpuRegex = /GPU(\d+):(?:[\s\S]*?)deviceName\s*=\s*(.*)/g + + const gpus: GpuSettingInfo[] = [] + let match + while ((match = gpuRegex.exec(output)) !== null) { + const id = match[1] + const name = match[2] + const arch = getGpuArch(name) + gpus.push({ id, vram: '0', name, arch }) + } + data.gpus = gpus + + if (!data.gpus_in_use || data.gpus_in_use.length === 0) { + data.gpus_in_use = [data.gpus.length > 1 ? 
'1' : '0'] + } + + data = await updateCudaExistence(data) + writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) + log(`[APP]::${JSON.stringify(data)}`) + resolve({}) + } else { + reject(error) + } + } + ) + } else { + exec( + 'nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits', + async (error, stdout) => { + if (!error) { + log(`[SPECS]::${stdout}`) + // Get GPU info and gpu has higher memory first + let highestVram = 0 + let highestVramId = '0' + const gpus: GpuSettingInfo[] = stdout + .trim() + .split('\n') + .map((line) => { + let [id, vram, name] = line.split(', ') + const arch = getGpuArch(name) + vram = vram.replace(/\r/g, '') + if (parseFloat(vram) > highestVram) { + highestVram = parseFloat(vram) + highestVramId = id + } + return { id, vram, name, arch } + }) + + data.gpus = gpus + data.gpu_highest_vram = highestVramId + } else { + data.gpus = [] + data.gpu_highest_vram = '' + } + + if (!data.gpus_in_use || data.gpus_in_use.length === 0) { + data.gpus_in_use = [data.gpu_highest_vram] + } + + data = await updateCudaExistence(data) + console.log(data) + writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) + log(`[APP]::${JSON.stringify(data)}`) + resolve({}) + } + ) + } + }) + +/** + * Check if file exists in paths + */ +const checkFileExistenceInPaths = (file: string, paths: string[]): boolean => { + return paths.some((p) => existsSync(path.join(p, file))) +} + +/** + * Validate cuda for linux and windows + */ +const updateCudaExistence = async ( + data: GpuSetting = DEFAULT_SETTINGS +): Promise => { + let filesCuda12: string[] + let filesCuda11: string[] + let paths: string[] + let cudaVersion: string = '' + + if (process.platform === 'win32') { + filesCuda12 = ['cublas64_12.dll', 'cudart64_12.dll', 'cublasLt64_12.dll'] + filesCuda11 = ['cublas64_11.dll', 'cudart64_110.dll', 'cublasLt64_11.dll'] + paths = process.env.PATH ? process.env.PATH.split(path.delimiter) : [] + } else { + filesCuda12 = ['libcudart.so.12', 'libcublas.so.12', 'libcublasLt.so.12'] + filesCuda11 = ['libcudart.so.11.0', 'libcublas.so.11', 'libcublasLt.so.11'] + paths = process.env.LD_LIBRARY_PATH + ? 
process.env.LD_LIBRARY_PATH.split(path.delimiter) + : [] + paths.push('/usr/lib/x86_64-linux-gnu/') + } + + let cudaExists = filesCuda12.every( + (file) => existsSync(file) || checkFileExistenceInPaths(file, paths) + ) + + if (!cudaExists) { + cudaExists = filesCuda11.every( + (file) => existsSync(file) || checkFileExistenceInPaths(file, paths) + ) + if (cudaExists) { + cudaVersion = '11' + } + } else { + cudaVersion = '12' + } + + data.cuda.exist = cudaExists + data.cuda.version = cudaVersion + + console.debug(data.is_initial, data.gpus_in_use) + + if (cudaExists && data.is_initial && data.gpus_in_use.length > 0) { + data.run_mode = 'gpu' + } + + data.is_initial = false + + // Attempt to query CUDA using NVIDIA SMI + if (!cudaExists) { + await new Promise((resolve) => { + exec('nvidia-smi', (error, stdout) => { + if (!error) { + const regex = /CUDA\s*Version:\s*(\d+\.\d+)/g + const match = regex.exec(stdout) + if (match && match[1]) { + data.cuda.version = match[1] + } + } + console.log(data) + resolve() + }) + }) + } + return data +} + +export const getOsInfo = (): OperatingSystemInfo => { + const platform = + SupportedPlatforms.find((p) => p === process.platform) || 'unknown' + + const osInfo: OperatingSystemInfo = { + platform: platform, + arch: process.arch, + release: os.release(), + machine: os.machine(), + version: os.version(), + totalMem: os.totalmem(), + freeMem: os.freemem(), + } + + return osInfo +} + +export const registerLogger = ({ logEnabled, logCleaningInterval }) => { + const logger = new FileLogger(logEnabled, logCleaningInterval) + LoggerManager.instance().register(logger) + logger.cleanLogs() +} + +export const unregisterLogger = () => { + LoggerManager.instance().unregister('file') +} + +export const updateLogger = ({ logEnabled, logCleaningInterval }) => { + const logger = LoggerManager.instance().loggers.get('file') as FileLogger + if (logger && logEnabled !== undefined) logger.logEnabled = logEnabled + if (logger && logCleaningInterval) + logger.logCleaningInterval = logCleaningInterval + // Rerun + logger && logger.cleanLogs() +} diff --git a/extensions/monitoring-extension/src/node/logger.ts b/extensions/monitoring-extension/src/node/logger.ts new file mode 100644 index 0000000000..ca64ea2d97 --- /dev/null +++ b/extensions/monitoring-extension/src/node/logger.ts @@ -0,0 +1,142 @@ +import fs from 'fs' +import util from 'util' +import { + getAppConfigurations, + getJanDataFolderPath, + Logger, +} from '@janhq/core/node' +import path, { join } from 'path' + +export class FileLogger extends Logger { + name = 'file' + logCleaningInterval: number = 120000 + timeout: NodeJS.Timeout | null = null + appLogPath: string = './' + logEnabled: boolean = true + + constructor( + logEnabled: boolean = true, + logCleaningInterval: number = 120000 + ) { + super() + this.logEnabled = logEnabled + if (logCleaningInterval) this.logCleaningInterval = logCleaningInterval + + const appConfigurations = getAppConfigurations() + const logFolderPath = join(appConfigurations.data_folder, 'logs') + if (!fs.existsSync(logFolderPath)) { + fs.mkdirSync(logFolderPath, { recursive: true }) + } + + this.appLogPath = join(logFolderPath, 'app.log') + } + + log(args: any) { + if (!this.logEnabled) return + let message = args[0] + const scope = args[1] + if (!message) return + const path = this.appLogPath + if (!scope && !message.startsWith('[')) { + message = `[APP]::${message}` + } else if (scope) { + message = `${scope}::${message}` + } + + message = `${new Date().toISOString()} ${message}` + + 
writeLog(message, path) + } + + cleanLogs( + maxFileSizeBytes?: number | undefined, + daysToKeep?: number | undefined + ): void { + // clear existing timeout + // in case we rerun it with different values + if (this.timeout) clearTimeout(this.timeout) + this.timeout = undefined + + if (!this.logEnabled) return + + console.log( + 'Validating app logs. Next attempt in ', + this.logCleaningInterval + ) + + const size = maxFileSizeBytes ?? 1 * 1024 * 1024 // 1 MB + const days = daysToKeep ?? 7 // 7 days + const logDirectory = path.join(getJanDataFolderPath(), 'logs') + // Perform log cleaning + const currentDate = new Date() + if (fs.existsSync(logDirectory)) + fs.readdir(logDirectory, (err, files) => { + if (err) { + console.error('Error reading log directory:', err) + return + } + + files.forEach((file) => { + const filePath = path.join(logDirectory, file) + fs.stat(filePath, (err, stats) => { + if (err) { + console.error('Error getting file stats:', err) + return + } + + // Check size + if (stats.size > size) { + fs.unlink(filePath, (err) => { + if (err) { + console.error('Error deleting log file:', err) + return + } + console.debug( + `Deleted log file due to exceeding size limit: ${filePath}` + ) + }) + } else { + // Check age + const creationDate = new Date(stats.ctime) + const daysDifference = Math.floor( + (currentDate.getTime() - creationDate.getTime()) / + (1000 * 3600 * 24) + ) + if (daysDifference > days) { + fs.unlink(filePath, (err) => { + if (err) { + console.error('Error deleting log file:', err) + return + } + console.debug(`Deleted old log file: ${filePath}`) + }) + } + } + }) + }) + }) + + // Schedule the next execution with doubled delays + this.timeout = setTimeout( + () => this.cleanLogs(maxFileSizeBytes, daysToKeep), + this.logCleaningInterval + ) + } +} + +const writeLog = (message: string, logPath: string) => { + if (!fs.existsSync(logPath)) { + const logDirectory = path.join(getJanDataFolderPath(), 'logs') + if (!fs.existsSync(logDirectory)) { + fs.mkdirSync(logDirectory) + } + fs.writeFileSync(logPath, message) + } else { + const logFile = fs.createWriteStream(logPath, { + flags: 'a', + }) + logFile.write(util.format(message) + '\n') + logFile.close() + console.debug(message) + } +} diff --git a/extensions/monitoring-extension/tsconfig.json b/extensions/monitoring-extension/tsconfig.json new file mode 100644 index 0000000000..2477d58ce5 --- /dev/null +++ b/extensions/monitoring-extension/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/package.json b/extensions/package.json new file mode 100644 index 0000000000..b5b34bf79c --- /dev/null +++ b/extensions/package.json @@ -0,0 +1,11 @@ +{ + "private": true, + "workspaces": { + "packages": [ + "**" + ], + "nohoist": [ + "**" + ] + } +} diff --git a/extensions/tensorrt-llm-extension/README.md b/extensions/tensorrt-llm-extension/README.md new file mode 100644 index 0000000000..34a6705160 --- /dev/null +++ b/extensions/tensorrt-llm-extension/README.md @@ -0,0 +1,79 @@ +# Tensorrt-LLM Extension + +Created using Jan extension example + +# Create a Jan Extension using Typescript + +Use this template to bootstrap the creation of a TypeScript Jan extension. 
🚀 + +## Create Your Own Extension + +To create your own extension, you can use this repository as a template! Just follow the below instructions: + +1. Click the Use this template button at the top of the repository +2. Select Create a new repository +3. Select an owner and name for your new repository +4. Click Create repository +5. Clone your new repository + +## Initial Setup + +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. + +> [!NOTE] +> +> You'll need to have a reasonably modern version of +> [Node.js](https://nodejs.org) handy. If you are using a version manager like +> [`nodenv`](https://github.com/nodenv/nodenv) or +> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the +> root of your repository to install the version specified in +> [`package.json`](./package.json). Otherwise, 20.x or later should work! + +1. :hammer_and_wrench: Install the dependencies + + ```bash + npm install + ``` + +1. :building_construction: Package the TypeScript for distribution + + ```bash + npm run bundle + ``` + +1. :white_check_mark: Check your artifact + + There will be a tgz file in your extension directory now + +## Update the Extension Metadata + +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. + +When you copy this repository, update `package.json` with the name, description for your extension. + +## Update the Extension Code + +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the +contents of this directory with your own code. + +There are a few things to keep in mind when writing your extension code: + +- Most Jan Extension functions are processed asynchronously. + In `index.ts`, you will see that the extension function will return a `Promise`. + + ```typescript + import { events, MessageEvent, MessageRequest } from '@janhq/core' + + function onStart(): Promise { + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) + } + ``` + + For more information about the Jan Extension Core module, see the + [documentation](https://github.com/janhq/jan/blob/main/core/README.md). + +So, what are you waiting for? Go ahead and start customizing your extension! diff --git a/extensions/tensorrt-llm-extension/package.json b/extensions/tensorrt-llm-extension/package.json new file mode 100644 index 0000000000..c5cb548093 --- /dev/null +++ b/extensions/tensorrt-llm-extension/package.json @@ -0,0 +1,78 @@ +{ + "name": "@janhq/tensorrt-llm-extension", + "productName": "TensorRT-LLM Inference Engine", + "version": "0.0.3", + "description": "This extension enables Nvidia's TensorRT-LLM for the fastest GPU acceleration. 
See the [setup guide](https://jan.ai/guides/providers/tensorrt-llm/) for next steps.", + "main": "dist/index.js", + "node": "dist/node/index.cjs.js", + "author": "Jan ", + "license": "AGPL-3.0", + "config": { + "host": "127.0.0.1", + "port": "3929" + }, + "compatibility": { + "platform": [ + "win32" + ], + "app": [ + "0.1.0" + ] + }, + "tensorrtVersion": "0.1.8", + "provider": "nitro-tensorrt-llm", + "scripts": { + "build": "tsc --module commonjs && rollup -c rollup.config.ts", + "build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish:darwin": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", + "build:publish": "run-script-os" + }, + "exports": { + ".": "./dist/index.js", + "./main": "./dist/node/index.cjs.js" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^25.0.7", + "@rollup/plugin-json": "^6.1.0", + "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@types/decompress": "4.2.7", + "@types/node": "^20.11.4", + "@types/os-utils": "^0.0.4", + "@types/tcp-port-used": "^1.0.4", + "cpx": "^1.5.0", + "download-cli": "^1.1.1", + "rimraf": "^3.0.2", + "rollup": "^2.38.5", + "rollup-plugin-define": "^1.0.1", + "rollup-plugin-sourcemaps": "^0.6.3", + "rollup-plugin-typescript2": "^0.36.0", + "run-script-os": "^1.1.6", + "typescript": "^5.2.2" + }, + "dependencies": { + "@janhq/core": "file:../../core", + "decompress": "^4.2.1", + "fetch-retry": "^5.0.6", + "rxjs": "^7.8.1", + "tcp-port-used": "^1.0.2", + "terminate": "^2.6.1", + "ulidx": "^2.3.0" + }, + "engines": { + "node": ">=18.0.0" + }, + "files": [ + "dist/*", + "package.json", + "README.md" + ], + "bundleDependencies": [ + "tcp-port-used", + "fetch-retry", + "decompress", + "@janhq/core", + "terminate" + ] +} diff --git a/extensions/tensorrt-llm-extension/resources/models.json b/extensions/tensorrt-llm-extension/resources/models.json new file mode 100644 index 0000000000..387b711040 --- /dev/null +++ b/extensions/tensorrt-llm-extension/resources/models.json @@ -0,0 +1,156 @@ +[ + { + "sources": [ + { + "filename": "config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/config.json" + }, + { + "filename": "mistral_float16_tp1_rank0.engine", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/mistral_float16_tp1_rank0.engine" + }, + { + "filename": "tokenizer.model", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/tokenizer.model" + }, + { + "filename": "special_tokens_map.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/special_tokens_map.json" + }, + { + "filename": "tokenizer.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/tokenizer.json" + }, + { + "filename": "tokenizer_config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/tokenizer_config.json" + }, + { + "filename": "model.cache", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/LlamaCorn-1.1B-Chat-fp16/model.cache" + } + ], + "id": "llamacorn-1.1b-chat-fp16", + "object": "model", + "name": "LlamaCorn 1.1B Chat FP16", + "version": "1.0", + "description": 
"LlamaCorn is a refined version of TinyLlama-1.1B, optimized for conversational quality, running on consumer devices through TensorRT-LLM", + "format": "TensorRT-LLM", + "settings": { + "ctx_len": 2048, + "text_model": false + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "LLama", + "tags": ["TensorRT-LLM", "1B", "Finetuned"], + "size": 2151000000 + }, + "engine": "nitro-tensorrt-llm" + }, + { + "sources": [ + { + "filename": "config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/config.json" + }, + { + "filename": "mistral_float16_tp1_rank0.engine", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/mistral_float16_tp1_rank0.engine" + }, + { + "filename": "tokenizer.model", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer.model" + }, + { + "filename": "special_tokens_map.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/special_tokens_map.json" + }, + { + "filename": "tokenizer.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer.json" + }, + { + "filename": "tokenizer_config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer_config.json" + }, + { + "filename": "model.cache", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/model.cache" + } + ], + "id": "tinyjensen-1.1b-chat-fp16", + "object": "model", + "name": "TinyJensen 1.1B Chat FP16", + "version": "1.0", + "description": "Do you want to chat with Jensen Huan? Here you are", + "format": "TensorRT-LLM", + "settings": { + "ctx_len": 2048, + "text_model": false + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "LLama", + "tags": ["TensorRT-LLM", "1B", "Finetuned"], + "size": 2151000000 + }, + "engine": "nitro-tensorrt-llm" + }, + { + "sources": [ + { + "filename": "config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/config.json" + }, + { + "filename": "mistral_float16_tp1_rank0.engine", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/mistral_float16_tp1_rank0.engine" + }, + { + "filename": "tokenizer.model", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer.model" + }, + { + "filename": "special_tokens_map.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/special_tokens_map.json" + }, + { + "filename": "tokenizer.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer.json" + }, + { + "filename": "tokenizer_config.json", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer_config.json" + }, + { + "filename": "model.cache", + "url": "https://catalog.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/model.cache" + } + ], + "id": "mistral-7b-instruct-int4", + "object": "model", + "name": "Mistral 7B Instruct v0.1 INT4", + "version": "1.0", + "description": "Mistral 7B Instruct v0.1 INT4", + "format": "TensorRT-LLM", + "settings": { + "ctx_len": 2048, + "text_model": false, + "prompt_template": "[INST] {prompt} [/INST]" + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + 
"author": "MistralAI", + "tags": ["TensorRT-LLM", "7B", "Finetuned"], + "size": 3840000000 + }, + "engine": "nitro-tensorrt-llm" + } +] diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts new file mode 100644 index 0000000000..1fad0e711b --- /dev/null +++ b/extensions/tensorrt-llm-extension/rollup.config.ts @@ -0,0 +1,79 @@ +import resolve from '@rollup/plugin-node-resolve' +import commonjs from '@rollup/plugin-commonjs' +import sourceMaps from 'rollup-plugin-sourcemaps' +import typescript from 'rollup-plugin-typescript2' +import json from '@rollup/plugin-json' +import replace from '@rollup/plugin-replace' +const packageJson = require('./package.json') +const modelsJson = require('./resources/models.json') + +export default [ + { + input: `src/index.ts`, + output: [{ file: packageJson.main, format: 'es', sourcemap: true }], + watch: { + include: 'src/**', + }, + plugins: [ + replace({ + preventAssignment: true, + MODELS: JSON.stringify(modelsJson), + TENSORRT_VERSION: JSON.stringify(packageJson.tensorrtVersion), + PROVIDER: JSON.stringify(packageJson.provider), + DOWNLOAD_RUNNER_URL: + process.platform === 'win32' + ? JSON.stringify( + 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz' + ) + : JSON.stringify( + 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz' + ), + NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), + INFERENCE_URL: JSON.stringify( + process.env.INFERENCE_URL || + `${packageJson.config?.protocol ?? 'http'}://${packageJson.config?.host}:${packageJson.config?.port}/v1/chat/completions` + ), + COMPATIBILITY: JSON.stringify(packageJson.compatibility), + }), + json(), + typescript({ useTsconfigDeclarationDir: true }), + commonjs(), + resolve({ + extensions: ['.js', '.ts', '.svelte'], + }), + sourceMaps(), + ], + }, + { + input: `src/node/index.ts`, + output: [ + { file: 'dist/node/index.cjs.js', format: 'cjs', sourcemap: true }, + ], + external: ['@janhq/core/node'], + watch: { + include: 'src/node/**', + }, + plugins: [ + replace({ + preventAssignment: true, + TENSORRT_VERSION: JSON.stringify(packageJson.tensorrtVersion), + PROVIDER: JSON.stringify(packageJson.provider), + LOAD_MODEL_URL: JSON.stringify( + `${packageJson.config?.protocol ?? 'http'}://${packageJson.config?.host}:${packageJson.config?.port}/inferences/tensorrtllm/loadmodel` + ), + TERMINATE_ENGINE_URL: JSON.stringify( + `${packageJson.config?.protocol ?? 'http'}://${packageJson.config?.host}:${packageJson.config?.port}/processmanager/destroy` + ), + ENGINE_HOST: JSON.stringify(packageJson.config?.host ?? '127.0.0.1'), + ENGINE_PORT: JSON.stringify(packageJson.config?.port ?? 
'3928'), + }), + json(), + typescript({ useTsconfigDeclarationDir: true }), + commonjs(), + resolve({ + extensions: ['.ts', '.js', '.json'], + }), + sourceMaps(), + ], + }, +] diff --git a/extensions/tensorrt-llm-extension/src/@types/global.d.ts b/extensions/tensorrt-llm-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..b550080f74 --- /dev/null +++ b/extensions/tensorrt-llm-extension/src/@types/global.d.ts @@ -0,0 +1,11 @@ +declare const NODE: string +declare const INFERENCE_URL: string +declare const LOAD_MODEL_URL: string +declare const TERMINATE_ENGINE_URL: string +declare const ENGINE_HOST: string +declare const ENGINE_PORT: string +declare const DOWNLOAD_RUNNER_URL: string +declare const TENSORRT_VERSION: string +declare const COMPATIBILITY: object +declare const PROVIDER: string +declare const MODELS: Array diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts new file mode 100644 index 0000000000..189abc706a --- /dev/null +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -0,0 +1,199 @@ +/** + * @module tensorrt-llm-extension/src/index + */ + +import { + Compatibility, + DownloadEvent, + DownloadRequest, + DownloadState, + GpuSetting, + InstallationState, + Model, + baseName, + downloadFile, + events, + executeOnMain, + joinPath, + showToast, + systemInformation, + LocalOAIEngine, + fs, + MessageRequest, + ModelEvent, + getJanDataFolderPath, + SystemInformation, +} from '@janhq/core' + +/** + * TensorRTLLMExtension - Implementation of LocalOAIEngine + * @extends BaseOAILocalInferenceProvider + * Provide pre-populated models for TensorRTLLM + */ +export default class TensorRTLLMExtension extends LocalOAIEngine { + /** + * Override custom function name for loading and unloading model + * Which are implemented from node module + */ + override provider = PROVIDER + override inferenceUrl = INFERENCE_URL + override nodeModule = NODE + + private supportedGpuArch = ['ampere', 'ada'] + private supportedPlatform = ['win32', 'linux'] + + override compatibility() { + return COMPATIBILITY as unknown as Compatibility + } + + override async onLoad(): Promise { + super.onLoad() + + if ((await this.installationState()) === 'Installed') { + const models = MODELS as unknown as Model[] + this.registerModels(models) + } + } + + override async install(): Promise { + await this.removePopulatedModels() + + const info = await systemInformation() + + if (!this.isCompatible(info)) return + + const janDataFolderPath = await getJanDataFolderPath() + const engineVersion = TENSORRT_VERSION + + const executableFolderPath = await joinPath([ + janDataFolderPath, + 'engines', + this.provider, + engineVersion, + info.gpuSetting?.gpus[0].arch, + ]) + + if (!(await fs.existsSync(executableFolderPath))) { + await fs.mkdir(executableFolderPath) + } + + const placeholderUrl = DOWNLOAD_RUNNER_URL + const tensorrtVersion = TENSORRT_VERSION + + const url = placeholderUrl + .replace(//g, tensorrtVersion) + .replace(//g, info.gpuSetting!.gpus[0]!.arch!) 
+ + const tarball = await baseName(url) + + const tarballFullPath = await joinPath([executableFolderPath, tarball]) + const downloadRequest: DownloadRequest = { + url, + localPath: tarballFullPath, + extensionId: this.name, + downloadType: 'extension', + } + downloadFile(downloadRequest) + + const onFileDownloadSuccess = async (state: DownloadState) => { + // if other download, ignore + if (state.fileName !== tarball) return + events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + await executeOnMain( + this.nodeModule, + 'decompressRunner', + tarballFullPath, + executableFolderPath + ) + events.emit(DownloadEvent.onFileUnzipSuccess, state) + + // Prepopulate models as soon as it's ready + const models = MODELS as unknown as Model[] + this.registerModels(models).then(() => { + showToast( + 'Extension installed successfully.', + 'New models are added to Model Hub.' + ) + }) + } + events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + } + + private async removePopulatedModels(): Promise { + const models = MODELS as unknown as Model[] + console.debug(`removePopulatedModels`, JSON.stringify(models)) + const janDataFolderPath = await getJanDataFolderPath() + const modelFolderPath = await joinPath([janDataFolderPath, 'models']) + + for (const model of models) { + const modelPath = await joinPath([modelFolderPath, model.id]) + + try { + await fs.rm(modelPath) + } catch (err) { + console.error(`Error removing model ${modelPath}`, err) + } + } + events.emit(ModelEvent.OnModelsUpdate, {}) + } + + override async loadModel(model: Model): Promise { + if ((await this.installationState()) === 'Installed') + return super.loadModel(model) + + throw new Error('EXTENSION_IS_NOT_INSTALLED::TensorRT-LLM extension') + } + + override async installationState(): Promise { + const info = await systemInformation() + + if (!this.isCompatible(info)) return 'NotCompatible' + const firstGpu = info.gpuSetting?.gpus[0] + const janDataFolderPath = await getJanDataFolderPath() + const engineVersion = TENSORRT_VERSION + + const enginePath = await joinPath([ + janDataFolderPath, + 'engines', + this.provider, + engineVersion, + firstGpu.arch, + info.osInfo.platform === 'win32' ? 'nitro.exe' : 'nitro', + ]) + + // For now, we just check the executable of nitro x tensor rt + return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled' + } + + override stopInference() { + if (!this.loadedModel) return + showToast( + 'Unable to Stop Inference', + 'The model does not support stopping inference.' 
+ ) + return Promise.resolve() + } + + override async inference(data: MessageRequest) { + if (!this.loadedModel) return + // TensorRT LLM Extension supports streaming only + if (data.model) data.model.parameters.stream = true + super.inference(data) + } + + isCompatible(info: SystemInformation): info is Required & { + gpuSetting: { gpus: { arch: string }[] } + } { + const firstGpu = info.gpuSetting?.gpus[0] + return ( + !!info.osInfo && + !!info.gpuSetting && + !!firstGpu && + info.gpuSetting.gpus.length > 0 && + this.supportedPlatform.includes(info.osInfo.platform) && + !!firstGpu.arch && + firstGpu.name.toLowerCase().includes('nvidia') && + this.supportedGpuArch.includes(firstGpu.arch) + ) + } +} diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts new file mode 100644 index 0000000000..c8bc48459e --- /dev/null +++ b/extensions/tensorrt-llm-extension/src/node/index.ts @@ -0,0 +1,325 @@ +import path from 'path' +import { ChildProcessWithoutNullStreams, spawn } from 'child_process' +import tcpPortUsed from 'tcp-port-used' +import fetchRT from 'fetch-retry' +import { + log, + getJanDataFolderPath, + SystemInformation, + PromptTemplate, +} from '@janhq/core/node' +import decompress from 'decompress' +import terminate from 'terminate' + +// Polyfill fetch with retry +const fetchRetry = fetchRT(fetch) + +const supportedPlatform = (): string[] => ['win32', 'linux'] +const supportedGpuArch = (): string[] => ['ampere', 'ada'] +const PORT_CHECK_INTERVAL = 100 + +/** + * The response object for model init operation. + */ +interface ModelLoadParams { + engine_path: string + ctx_len: number +} + +// The subprocess instance for Engine +let subprocess: ChildProcessWithoutNullStreams | undefined = undefined + +/** + * Initializes a engine subprocess to load a machine learning model. + * @param params - The model load settings. + */ +async function loadModel( + params: any, + systemInfo?: SystemInformation +): Promise<{ error: Error | undefined }> { + // modelFolder is the absolute path to the running model folder + // e.g. ~/jan/models/llama-2 + let modelFolder = params.modelFolder + + if (params.model.settings.prompt_template) { + const promptTemplate = params.model.settings.prompt_template + const prompt = promptTemplateConverter(promptTemplate) + if (prompt?.error) { + return Promise.reject(prompt.error) + } + params.model.settings.system_prompt = prompt.system_prompt + params.model.settings.user_prompt = prompt.user_prompt + params.model.settings.ai_prompt = prompt.ai_prompt + } + + const settings: ModelLoadParams = { + engine_path: modelFolder, + ctx_len: params.model.settings.ctx_len ?? 2048, + ...params.model.settings, + } + if (!systemInfo) { + throw new Error('Cannot get system info. Unable to start nitro x tensorrt.') + } + return runEngineAndLoadModel(settings, systemInfo) +} + +/** + * Stops a Engine subprocess. 
+ */ +function unloadModel(): Promise { + const controller = new AbortController() + setTimeout(() => controller.abort(), 5000) + debugLog(`Request to kill engine`) + + const killRequest = () => { + return fetch(TERMINATE_ENGINE_URL, { + method: 'DELETE', + signal: controller.signal, + }) + .then(() => { + subprocess = undefined + }) + .catch(() => {}) // Do nothing with this attempt + .then(() => + tcpPortUsed.waitUntilFree( + parseInt(ENGINE_PORT), + PORT_CHECK_INTERVAL, + 5000 + ) + ) // Wait for port available + .then(() => debugLog(`Engine process is terminated`)) + .catch((err) => { + debugLog( + `Could not kill running process on port ${ENGINE_PORT}. Might be another process running on the same port? ${err}` + ) + throw 'PORT_NOT_AVAILABLE' + }) + } + + if (subprocess?.pid) { + log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + const pid = subprocess.pid + return new Promise((resolve, reject) => { + terminate(pid, function (err) { + if (err) { + return killRequest() + } else { + return tcpPortUsed + .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) + .then(() => resolve()) + .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .catch(() => { + killRequest() + }) + } + }) + }) + } else { + return killRequest() + } +} +/** + * 1. Spawn engine process + * 2. Load model into engine subprocess + * @returns + */ +async function runEngineAndLoadModel( + settings: ModelLoadParams, + systemInfo: SystemInformation +) { + return unloadModel() + .then(() => runEngine(systemInfo)) + .then(() => loadModelRequest(settings)) + .catch((err) => { + // TODO: Broadcast error so app could display proper error message + debugLog(`${err}`, 'Error') + return { error: err } + }) +} + +/** + * Loads a LLM model into the Engine subprocess by sending a HTTP POST request. + */ +async function loadModelRequest( + settings: ModelLoadParams +): Promise<{ error: Error | undefined }> { + debugLog(`Loading model with params ${JSON.stringify(settings)}`) + return fetchRetry(LOAD_MODEL_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(settings), + retries: 3, + retryDelay: 500, + }) + .then((res) => { + debugLog(`Load model success with response ${JSON.stringify(res)}`) + return Promise.resolve({ error: undefined }) + }) + .catch((err) => { + debugLog(`Load model failed with error ${err}`, 'Error') + return Promise.resolve({ error: err }) + }) +} + +/** + * Spawns engine subprocess. + */ +async function runEngine(systemInfo: SystemInformation): Promise { + debugLog(`Spawning engine subprocess...`) + if (systemInfo.gpuSetting == null) { + return Promise.reject( + 'No GPU information found. Please check your GPU setting.' + ) + } + + if (systemInfo.gpuSetting?.gpus.length === 0) { + return Promise.reject('No GPU found. Please check your GPU setting.') + } + + if (systemInfo.osInfo == null) { + return Promise.reject( + 'No OS information found. Please check your OS setting.' + ) + } + const platform = systemInfo.osInfo.platform + if (platform == null || supportedPlatform().includes(platform) === false) { + return Promise.reject( + 'No OS architecture found. Please check your OS setting.' + ) + } + + const gpu = systemInfo.gpuSetting?.gpus[0] + if (gpu.name.toLowerCase().includes('nvidia') === false) { + return Promise.reject('No Nvidia GPU found. 
Please check your GPU setting.') + } + const gpuArch = gpu.arch + if (gpuArch == null || supportedGpuArch().includes(gpuArch) === false) { + return Promise.reject( + `Your GPU: ${gpu.name} is not supported. Only ${supportedGpuArch().join( + ', ' + )} series are supported.` + ) + } + const janDataFolderPath = await getJanDataFolderPath() + const tensorRtVersion = TENSORRT_VERSION + const provider = PROVIDER + + return new Promise((resolve, reject) => { + // Current directory by default + + const executableFolderPath = path.join( + janDataFolderPath, + 'engines', + provider, + tensorRtVersion, + gpuArch + ) + const nitroExecutablePath = path.join( + executableFolderPath, + platform === 'win32' ? 'nitro.exe' : 'nitro' + ) + + const args: string[] = ['1', ENGINE_HOST, ENGINE_PORT] + // Execute the binary + debugLog(`Spawn nitro at path: ${nitroExecutablePath}, and args: ${args}`) + subprocess = spawn(nitroExecutablePath, args, { + cwd: executableFolderPath, + env: { + ...process.env, + }, + }) + + // Handle subprocess output + subprocess.stdout.on('data', (data: any) => { + debugLog(`${data}`) + }) + + subprocess.stderr.on('data', (data: any) => { + debugLog(`${data}`) + }) + + subprocess.on('close', (code: any) => { + debugLog(`Engine exited with code: ${code}`) + subprocess = undefined + reject(`child process exited with code ${code}`) + }) + + tcpPortUsed + .waitUntilUsed(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 30000) + .then(() => { + debugLog(`Engine is ready`) + resolve() + }) + }) +} + +function debugLog(message: string, level: string = 'Debug') { + log(`[TENSORRT_LLM_NITRO]::${level}:${message}`) +} + +const decompressRunner = async (zipPath: string, output: string) => { + console.debug(`Decompressing ${zipPath} to ${output}...`) + try { + const files = await decompress(zipPath, output) + console.debug('Decompress finished!', files) + } catch (err) { + console.error(`Decompress ${zipPath} failed: ${err}`) + } +} + +/** + * Parse prompt template into agrs settings + * @param promptTemplate Template as string + * @returns + */ +function promptTemplateConverter(promptTemplate: string): PromptTemplate { + // Split the string using the markers + const systemMarker = '{system_message}' + const promptMarker = '{prompt}' + + if ( + promptTemplate.includes(systemMarker) && + promptTemplate.includes(promptMarker) + ) { + // Find the indices of the markers + const systemIndex = promptTemplate.indexOf(systemMarker) + const promptIndex = promptTemplate.indexOf(promptMarker) + + // Extract the parts of the string + const system_prompt = promptTemplate.substring(0, systemIndex) + const user_prompt = promptTemplate.substring( + systemIndex + systemMarker.length, + promptIndex + ) + const ai_prompt = promptTemplate.substring( + promptIndex + promptMarker.length + ) + + // Return the split parts + return { system_prompt, user_prompt, ai_prompt } + } else if (promptTemplate.includes(promptMarker)) { + // Extract the parts of the string for the case where only promptMarker is present + const promptIndex = promptTemplate.indexOf(promptMarker) + const user_prompt = promptTemplate.substring(0, promptIndex) + const ai_prompt = promptTemplate.substring( + promptIndex + promptMarker.length + ) + + // Return the split parts + return { user_prompt, ai_prompt } + } + + // Return an error if none of the conditions are met + return { error: 'Cannot split prompt template' } +} + +export default { + supportedPlatform, + supportedGpuArch, + decompressRunner, + loadModel, + unloadModel, + dispose: unloadModel, +} 
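The `promptTemplateConverter` helper above splits a model's `prompt_template` into system, user, and AI segments around the `{system_message}` and `{prompt}` markers before the settings are posted to the engine. Below is a minimal standalone sketch of that splitting behavior; the names are illustrative and are not exports of the extension, and it assumes the markers appear as literal substrings in the template:

```typescript
// Illustrative sketch of the prompt-template split performed in src/node/index.ts.
// Assumption: '{system_message}' and '{prompt}' are literal markers in the template.
type PromptParts = {
  system_prompt?: string
  user_prompt?: string
  ai_prompt?: string
  error?: string
}

function splitPromptTemplate(template: string): PromptParts {
  const systemMarker = '{system_message}'
  const promptMarker = '{prompt}'

  const promptIndex = template.indexOf(promptMarker)
  if (promptIndex === -1) return { error: 'Cannot split prompt template' }

  const systemIndex = template.indexOf(systemMarker)
  if (systemIndex !== -1) {
    // Template carries both markers: split into system / user / ai segments.
    return {
      system_prompt: template.substring(0, systemIndex),
      user_prompt: template.substring(
        systemIndex + systemMarker.length,
        promptIndex
      ),
      ai_prompt: template.substring(promptIndex + promptMarker.length),
    }
  }

  // Only the prompt marker is present: everything before it is the user prefix.
  return {
    user_prompt: template.substring(0, promptIndex),
    ai_prompt: template.substring(promptIndex + promptMarker.length),
  }
}

// With the Mistral template shipped in resources/models.json:
// splitPromptTemplate('[INST] {prompt} [/INST]')
// => { user_prompt: '[INST] ', ai_prompt: ' [/INST]' }
```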
diff --git a/extensions/tensorrt-llm-extension/tsconfig.json b/extensions/tensorrt-llm-extension/tsconfig.json new file mode 100644 index 0000000000..478a057288 --- /dev/null +++ b/extensions/tensorrt-llm-extension/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "moduleResolution": "node", + "target": "es5", + "module": "ES2020", + "lib": ["es2015", "es2016", "es2017", "dom"], + "strict": true, + "sourceMap": true, + "declaration": true, + "allowSyntheticDefaultImports": true, + "experimentalDecorators": true, + "emitDecoratorMetadata": true, + "declarationDir": "dist/types", + "outDir": "dist", + "importHelpers": true, + "resolveJsonModule": true, + "typeRoots": ["node_modules/@types"] + }, + "include": ["src"] +} diff --git a/extensions/turbo.json b/extensions/turbo.json new file mode 100644 index 0000000000..b2e876d585 --- /dev/null +++ b/extensions/turbo.json @@ -0,0 +1,17 @@ +{ + "$schema": "https://turbo.build/schema.json", + "pipeline": { + "build": { + "dependsOn": ["^build"], + "outputs": ["dist/**"] + }, + "build:publish": { + "dependsOn": ["build"], + "outputs": ["**.tgz"] + }, + "dev": { + "cache": false + }, + "type-check": {} + } +} diff --git a/joi/package.json b/joi/package.json index 973d1942ea..3f1bd07f73 100644 --- a/joi/package.json +++ b/joi/package.json @@ -31,7 +31,6 @@ "dependencies": { "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-dialog": "^1.0.5", - "@radix-ui/react-dropdown-menu": "^2.1.1", "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", @@ -39,9 +38,8 @@ "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-tabs": "^1.0.4", "@radix-ui/react-tooltip": "^1.0.7", - "@radix-ui/react-visually-hidden": "^1.1.0", - "autoprefixer": "10.4.16", "tailwind-merge": "^2.2.0", + "autoprefixer": "10.4.16", "tailwindcss": "^3.4.1" }, "devDependencies": { diff --git a/joi/src/core/Button/styles.scss b/joi/src/core/Button/styles.scss index 1bd8288fc0..f7cdce6a46 100644 --- a/joi/src/core/Button/styles.scss +++ b/joi/src/core/Button/styles.scss @@ -1,6 +1,10 @@ .btn { @apply inline-flex items-center justify-center px-4 font-semibold transition-all; + &:focus, + &:focus-within { + @apply outline-2 outline-offset-4; + } &:hover { filter: brightness(95%); } @@ -9,7 +13,6 @@ &--primary { color: hsla(var(--primary-fg)); background-color: hsla(var(--primary-bg)) !important; - &:hover { filter: brightness(65%); } @@ -30,15 +33,14 @@ // Ghost &--ghost { - background-color: transparent; - + background-color: transparent !important; &.btn--soft { - background-color: transparent; + background-color: transparent !important; } // Variant outline ghost &.btn--outline { - background-color: transparent; + background-color: transparent !important; border: 1px solid hsla(var(--ghost-border)); } } @@ -47,7 +49,6 @@ &--destructive { color: hsla(var(--destructive-fg)); background-color: hsla(var(--destructive-bg)) !important; - &:hover { filter: brightness(65%); } @@ -82,7 +83,6 @@ width: 24px; height: 24px; padding: 2px; - &:hover { background-color: hsla(var(--icon-bg)) !important; } @@ -90,7 +90,6 @@ &.btn--outline { background-color: transparent !important; border: 1px solid hsla(var(--icon-border)); - &:hover { background-color: hsla(var(--icon-bg)) !important; } @@ -102,7 +101,6 @@ @apply h-6 px-2; font-size: 12px; border-radius: 4px; - &.btn--icon { width: 24px; height: 24px; @@ -113,7 +111,6 @@ &--medium { @apply h-8; border-radius: 6px; - &.btn--icon { width: 24px; height: 24px; @@ -124,7 +121,6 @@ &--large { 
@apply h-9; border-radius: 8px; - &.btn--icon { width: 24px; height: 24px; diff --git a/joi/src/core/DropdownMenu/index.tsx b/joi/src/core/DropdownMenu/index.tsx deleted file mode 100644 index 5c0aedbff0..0000000000 --- a/joi/src/core/DropdownMenu/index.tsx +++ /dev/null @@ -1,198 +0,0 @@ -import * as React from "react" -import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu" -import { Check, ChevronRight, Circle } from "lucide-react" -import { twMerge } from "tailwind-merge" - - -const DropdownMenu = DropdownMenuPrimitive.Root - -const DropdownMenuTrigger = DropdownMenuPrimitive.Trigger - -const DropdownMenuGroup = DropdownMenuPrimitive.Group - -const DropdownMenuPortal = DropdownMenuPrimitive.Portal - -const DropdownMenuSub = DropdownMenuPrimitive.Sub - -const DropdownMenuRadioGroup = DropdownMenuPrimitive.RadioGroup - -const DropdownMenuSubTrigger = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef & { - inset?: boolean - } ->(({ className, inset, children, ...props }, ref) => ( - - {children} - - -)) -DropdownMenuSubTrigger.displayName = - DropdownMenuPrimitive.SubTrigger.displayName - -const DropdownMenuSubContent = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)) -DropdownMenuSubContent.displayName = - DropdownMenuPrimitive.SubContent.displayName - -const DropdownMenuContent = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, sideOffset = 4, ...props }, ref) => ( - - - -)) -DropdownMenuContent.displayName = DropdownMenuPrimitive.Content.displayName - -const DropdownMenuItem = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef & { - inset?: boolean - } ->(({ className, inset, ...props }, ref) => ( - -)) -DropdownMenuItem.displayName = DropdownMenuPrimitive.Item.displayName - -const DropdownMenuCheckboxItem = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, children, checked, ...props }, ref) => ( - - - - - - - {children} - -)) -DropdownMenuCheckboxItem.displayName = - DropdownMenuPrimitive.CheckboxItem.displayName - -const DropdownMenuRadioItem = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, children, ...props }, ref) => ( - - - - - - - {children} - -)) -DropdownMenuRadioItem.displayName = DropdownMenuPrimitive.RadioItem.displayName - -const DropdownMenuLabel = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef & { - inset?: boolean - } ->(({ className, inset, ...props }, ref) => ( - -)) -DropdownMenuLabel.displayName = DropdownMenuPrimitive.Label.displayName - -const DropdownMenuSeparator = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef ->(({ className, ...props }, ref) => ( - -)) -DropdownMenuSeparator.displayName = DropdownMenuPrimitive.Separator.displayName - -const DropdownMenuShortcut = ({ - className, - ...props -}: React.HTMLAttributes) => { - return ( - - ) -} -DropdownMenuShortcut.displayName = "DropdownMenuShortcut" - -export { - DropdownMenu, - DropdownMenuTrigger, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuCheckboxItem, - DropdownMenuRadioItem, - DropdownMenuLabel, - DropdownMenuSeparator, - DropdownMenuShortcut, - DropdownMenuGroup, - DropdownMenuPortal, - DropdownMenuSub, - DropdownMenuSubContent, - DropdownMenuSubTrigger, - DropdownMenuRadioGroup, -} diff --git a/joi/src/core/Modal/index.tsx b/joi/src/core/Modal/index.tsx index 7e31671ede..923004b99b 100644 --- 
a/joi/src/core/Modal/index.tsx +++ b/joi/src/core/Modal/index.tsx @@ -4,7 +4,6 @@ import { Cross2Icon } from '@radix-ui/react-icons' import './styles.scss' import { twMerge } from 'tailwind-merge' -import { VisuallyHidden } from '../VisuallyHidden' type Props = { trigger?: ReactNode @@ -40,13 +39,7 @@ const Modal = ({ className )} > - - - - - - - {title &&
{title}
} +
{title}
{content} {!hideClose && ( diff --git a/joi/src/core/ScrollArea/index.tsx b/joi/src/core/ScrollArea/index.tsx index c7fe71572a..3a2ffaaa84 100644 --- a/joi/src/core/ScrollArea/index.tsx +++ b/joi/src/core/ScrollArea/index.tsx @@ -1,4 +1,4 @@ -import React from 'react' +import React, { PropsWithChildren, forwardRef } from 'react' import * as ScrollAreaPrimitive from '@radix-ui/react-scroll-area' import { twMerge } from 'tailwind-merge' @@ -9,7 +9,7 @@ const ScrollArea = React.forwardRef< React.ComponentPropsWithoutRef >(({ className, children, onScroll, ...props }, ref) => ( diff --git a/joi/src/core/ScrollArea/styles.scss b/joi/src/core/ScrollArea/styles.scss index 20d9665411..cb5832c53d 100644 --- a/joi/src/core/ScrollArea/styles.scss +++ b/joi/src/core/ScrollArea/styles.scss @@ -44,17 +44,17 @@ } .scroll-area__bar[data-orientation='vertical'] { - width: 10px; + width: 8px; } .scroll-area__bar[data-orientation='horizontal'] { flex-direction: column; - height: 10px; + height: 8px; } ::-webkit-scrollbar { - width: 10px; - height: 10px; + width: 6px; + height: 6px; } ::-webkit-scrollbar-track, ::-webkit-scrollbar-thumb { diff --git a/joi/src/core/Select/index.tsx b/joi/src/core/Select/index.tsx index a576caecfd..bce5473da1 100644 --- a/joi/src/core/Select/index.tsx +++ b/joi/src/core/Select/index.tsx @@ -1,7 +1,11 @@ import React, { ReactNode } from 'react' import * as SelectPrimitive from '@radix-ui/react-select' -import { CheckIcon, ChevronDownIcon } from '@radix-ui/react-icons' +import { + CheckIcon, + ChevronDownIcon, + ChevronUpIcon, +} from '@radix-ui/react-icons' import './styles.scss' import { twMerge } from 'tailwind-merge' diff --git a/joi/src/core/Table/index.tsx b/joi/src/core/Table/index.tsx deleted file mode 100644 index 7a26d3dd8c..0000000000 --- a/joi/src/core/Table/index.tsx +++ /dev/null @@ -1,123 +0,0 @@ -import * as React from 'react' -import { twMerge } from 'tailwind-merge' - -const Table = React.forwardRef< - HTMLTableElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
- - -)) -Table.displayName = 'Table' - -const TableHeader = React.forwardRef< - HTMLTableSectionElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( - -)) -TableHeader.displayName = 'TableHeader' - -const TableBody = React.forwardRef< - HTMLTableSectionElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( - -)) -TableBody.displayName = 'TableBody' - -const TableFooter = React.forwardRef< - HTMLTableSectionElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( - tr]:last:border-b-0', - className - )} - {...props} - /> -)) -TableFooter.displayName = 'TableFooter' - -const TableRow = React.forwardRef< - HTMLTableRowElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( - -)) -TableRow.displayName = 'TableRow' - -const TableHead = React.forwardRef< - HTMLTableCellElement, - React.ThHTMLAttributes ->(({ className, ...props }, ref) => ( -
-)) -TableHead.displayName = 'TableHead' - -const TableCell = React.forwardRef< - HTMLTableCellElement, - React.TdHTMLAttributes ->(({ className, ...props }, ref) => ( - -)) -TableCell.displayName = 'TableCell' - -const TableCaption = React.forwardRef< - HTMLTableCaptionElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)) -TableCaption.displayName = 'TableCaption' - -export { - Table, - TableHeader, - TableBody, - TableFooter, - TableHead, - TableRow, - TableCell, - TableCaption, -} diff --git a/joi/src/core/TextArea/index.tsx b/joi/src/core/TextArea/index.tsx index 791ff1430f..33d6744ada 100644 --- a/joi/src/core/TextArea/index.tsx +++ b/joi/src/core/TextArea/index.tsx @@ -1,7 +1,8 @@ -import React, { forwardRef } from 'react' +import React, { ReactNode, forwardRef } from 'react' import { twMerge } from 'tailwind-merge' import './styles.scss' +import { ScrollArea } from '../ScrollArea' export interface TextAreaProps extends React.TextareaHTMLAttributes {} diff --git a/joi/src/core/Tooltip/styles.scss b/joi/src/core/Tooltip/styles.scss index 1ec9a56994..04fb841c63 100644 --- a/joi/src/core/Tooltip/styles.scss +++ b/joi/src/core/Tooltip/styles.scss @@ -10,7 +10,7 @@ animation-timing-function: cubic-bezier(0.16, 1, 0.3, 1); will-change: transform, opacity; font-weight: 500; - z-index: 999999999; + z-index: 100; max-width: 240px; @apply text-sm leading-normal; } diff --git a/joi/src/core/VisuallyHidden/index.tsx b/joi/src/core/VisuallyHidden/index.tsx deleted file mode 100644 index 4d26b6e57e..0000000000 --- a/joi/src/core/VisuallyHidden/index.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import * as VisuallyHiddenPrimitive from '@radix-ui/react-visually-hidden' - -const VisuallyHidden = VisuallyHiddenPrimitive.Root - -export { VisuallyHidden } diff --git a/joi/src/index.ts b/joi/src/index.ts index d3389bddd6..5431627475 100644 --- a/joi/src/index.ts +++ b/joi/src/index.ts @@ -12,9 +12,6 @@ export * from './core/Select' export * from './core/TextArea' export * from './core/Tabs' export * from './core/Accordion' -export * from './core/DropdownMenu' -export * from './core/VisuallyHidden' -export * from './core/Table' export * from './hooks/useClipboard' export * from './hooks/usePageLeave' diff --git a/package.json b/package.json index b459984bf0..68c11c68cf 100644 --- a/package.json +++ b/package.json @@ -20,7 +20,7 @@ "scripts": { "lint": "yarn workspace jan lint && yarn workspace @janhq/web lint", "test:unit": "yarn workspace @janhq/core test", - "test": "yarn test:unit", + "test": "yarn workspace jan test:e2e", "test-local": "yarn lint && yarn build:test && yarn test", "pre-install:darwin": "find extensions -type f -path \"**/*.tgz\" -exec cp {} pre-install \\;", "pre-install:linux": "find extensions -type f -path \"**/*.tgz\" -exec cp {} pre-install \\;", @@ -29,11 +29,14 @@ "copy:assets": "cpx \"pre-install/*.tgz\" \"electron/pre-install/\" && cpx \"themes/**\" \"electron/themes\" && cpx \"docs/openapi/**\" \"electron/docs/openapi\"", "dev:electron": "yarn copy:assets && yarn workspace jan dev", "dev:web": "yarn workspace @janhq/web dev", - "dev": "turbo run dev --parallel", + "dev:server": "yarn copy:assets && yarn workspace @janhq/server dev", + "dev": "turbo run dev --parallel --filter=!@janhq/server", + "build:server": "yarn copy:assets && cd server && yarn install && yarn run build", "build:core": "cd core && yarn install && yarn run build", "build:web": "yarn workspace @janhq/web build && cpx \"web/out/**\" \"electron/renderer/\"", "build:electron": "yarn copy:assets && yarn workspace jan build", "build:electron:test": "yarn workspace jan build:test", + "build:extensions": "rimraf ./pre-install/*.tgz && turbo run @janhq/core#build && cd extensions && yarn install && turbo run build:publish && cd .. 
&& yarn pre-install", "build:test": "yarn copy:assets && turbo run @janhq/web#build && cpx \"web/out/**\" \"electron/renderer/\" && turbo run build:test", "build": "yarn build:web && yarn build:electron", "build:publish": "yarn copy:assets && yarn build:web && yarn workspace jan build:publish", diff --git a/server/.gitignore b/server/.gitignore new file mode 100644 index 0000000000..6320cd248d --- /dev/null +++ b/server/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/server/helpers/logger.ts b/server/helpers/logger.ts new file mode 100644 index 0000000000..2e61473867 --- /dev/null +++ b/server/helpers/logger.ts @@ -0,0 +1,58 @@ +import { log } from '@janhq/core/node' +import { FastifyBaseLogger } from 'fastify' +import { ChildLoggerOptions } from 'fastify/types/logger' +import pino from 'pino' + +export class Logger implements FastifyBaseLogger { + child( + bindings: pino.Bindings, + options?: ChildLoggerOptions | undefined + ): FastifyBaseLogger { + return new Logger() + } + level = 'info' + + silent = () => {} + + info = (obj?: any, msg?: string, ...args: any[]) => { + if (obj?.res?.raw?.statusCode || obj?.req?.url) { + log( + `[SERVER]::${JSON.stringify({ + level: obj?.level, + time: obj?.time, + hostname: obj?.hostname, + reqId: obj?.req?.id ?? obj?.res?.request?.id, + res: { + statusCode: obj?.res?.raw?.statusCode, + }, + req: { + method: obj?.req?.method, + url: obj?.req?.url, + path: obj?.req?.path, + hostname: obj?.req?.hostname, + remoteAddress: obj?.req?.remoteAddress, + remotePort: obj?.req?.remotePort, + }, + msg, + responseTime: obj?.responseTime, + ...args, + })}` + ) + } + } + error = function (message: any) { + log(`[SERVER]::${JSON.stringify(message)}`) + } + debug = function (message: any) { + log(`[SERVER]::${JSON.stringify(message)}`) + } + fatal = function (message: any) { + log(`[SERVER]::${JSON.stringify(message)}`) + } + warn = function (message: any) { + log(`[SERVER]::${JSON.stringify(message)}`) + } + trace = function (message: any) { + log(`[SERVER]::${JSON.stringify(message)}`) + } +} diff --git a/server/helpers/setup.ts b/server/helpers/setup.ts new file mode 100644 index 0000000000..41595d70c4 --- /dev/null +++ b/server/helpers/setup.ts @@ -0,0 +1,73 @@ +import { join, extname } from 'path' +import { existsSync, readdirSync, writeFileSync, mkdirSync } from 'fs' +import { init, installExtensions } from '@janhq/core/node' + +export async function setup() { + /** + * Setup Jan Data Directory + */ + const appDir = process.env.JAN_DATA_DIRECTORY ?? join(__dirname, '..', 'jan') + + console.debug(`Create app data directory at ${appDir}...`) + if (!existsSync(appDir)) mkdirSync(appDir) + //@ts-ignore + global.core = { + // Define appPath function for app to retrieve app path globally + appPath: () => appDir, + } + init({ + extensionsPath: join(appDir, 'extensions'), + }) + + /** + * Write app configurations. 
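An aside on the server/helpers/logger.ts file added above: it adapts Jan's log() helper to Fastify's FastifyBaseLogger interface so every request/response line lands in the same "[SERVER]" log stream. A minimal usage sketch, assuming the class is consumed the way server/index.ts later in this diff wires it up (the /healthz route is illustrative only, not part of the diff):

import fastify from 'fastify'
import { Logger } from './helpers/logger'

// Fastify accepts any FastifyBaseLogger implementation, so the custom Logger
// replaces the default pino transport wholesale.
const app = fastify({ logger: new Logger() })

// Illustrative route; each completed request is re-serialized by Logger.info()
// into a single `[SERVER]::{...}` line (method, url, status code, response time).
app.get('/healthz', async () => ({ status: 'ok' }))

app.listen({ port: 1337, host: '127.0.0.1' })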
See #1619 + */ + console.debug('Writing config file...') + writeFileSync( + join(appDir, 'settings.json'), + JSON.stringify({ + data_folder: appDir, + }), + 'utf-8' + ) + + if (!existsSync(join(appDir, 'settings'))) { + console.debug('Writing nvidia config file...') + mkdirSync(join(appDir, 'settings')) + writeFileSync( + join(appDir, 'settings', 'settings.json'), + JSON.stringify( + { + notify: true, + run_mode: 'cpu', + nvidia_driver: { + exist: false, + version: '', + }, + cuda: { + exist: false, + version: '', + }, + gpus: [], + gpu_highest_vram: '', + gpus_in_use: [], + is_initial: true, + }), + 'utf-8' + ) + } + + /** + * Install extensions + */ + + console.debug('Installing extensions...') + + const baseExtensionPath = join(__dirname, '../../..', 'pre-install') + const extensions = readdirSync(baseExtensionPath) + .filter((file) => extname(file) === '.tgz') + .map((file) => join(baseExtensionPath, file)) + + await installExtensions(extensions) + console.debug('Extensions installed') +} diff --git a/server/index.ts b/server/index.ts new file mode 100644 index 0000000000..f82c4f5bc6 --- /dev/null +++ b/server/index.ts @@ -0,0 +1,155 @@ +import fastify from 'fastify' +import dotenv from 'dotenv' +import { v1Router, log, getJanExtensionsPath } from '@janhq/core/node' +import { join } from 'path' +import tcpPortUsed from 'tcp-port-used' +import { Logger } from './helpers/logger' + +// Load environment variables +dotenv.config() + +// Define default settings +const JAN_API_HOST = process.env.JAN_API_HOST || '127.0.0.1' +const JAN_API_PORT = Number.parseInt(process.env.JAN_API_PORT || '1337') + +// Initialize server settings +let server: any | undefined = undefined +let hostSetting: string = JAN_API_HOST +let portSetting: number = JAN_API_PORT +let corsEnabled: boolean = true +let isVerbose: boolean = true + +/** + * Server configurations + * @param host - The host address for the server + * @param port - The port number for the server + * @param isCorsEnabled - Flag to enable or disable CORS + * @param isVerboseEnabled - Flag to enable or disable verbose logging + * @param schemaPath - Path to the OpenAPI schema file + * @param baseDir - Base directory for the OpenAPI schema file + */ +export interface ServerConfig { + host?: string + port?: number + isCorsEnabled?: boolean + isVerboseEnabled?: boolean + schemaPath?: string + baseDir?: string + prefix?: string + storageAdataper?: any +} + +/** + * Function to start the server + * @param configs - Server configurations + */ +export const startServer = async (configs?: ServerConfig): Promise => { + if (configs?.port && configs?.host) { + const inUse = await tcpPortUsed.check(Number(configs.port), configs.host) + if (inUse) { + const errorMessage = `Port ${configs.port} is already in use.` + log(errorMessage, '[SERVER]') + throw new Error(errorMessage) + } + } + + // Update server settings + isVerbose = configs?.isVerboseEnabled ?? true + hostSetting = configs?.host ?? JAN_API_HOST + portSetting = configs?.port ?? JAN_API_PORT + corsEnabled = configs?.isCorsEnabled ?? 
true + + // Start the server + try { + // Log server start + if (isVerbose) log(`Debug: Starting JAN API server...`, '[SERVER]') + + // Initialize Fastify server with logging + server = fastify({ + logger: new Logger(), + }) + + // Register CORS if enabled + if (corsEnabled) await server.register(require('@fastify/cors'), {}) + + // Register Swagger for API documentation + await server.register(require('@fastify/swagger'), { + mode: 'static', + specification: { + path: configs?.schemaPath ?? './../docs/openapi/jan.yaml', + baseDir: configs?.baseDir ?? './../docs/openapi', + postProcessor: function (swaggerObject: any) { + swaggerObject.servers[0].url = configs?.prefix ?? '/v1' + return swaggerObject + }, + }, + }) + + // Register Swagger UI + await server.register(require('@fastify/swagger-ui'), { + routePrefix: '/', + baseDir: configs?.baseDir ?? join(__dirname, '../..', './docs/openapi'), + uiConfig: { + docExpansion: 'full', + deepLinking: false, + }, + staticCSP: false, + transformSpecificationClone: true, + }) + + // Register static file serving for extensions + // TODO: Watch extension files changes and reload + await server.register( + (childContext: any, _: any, done: any) => { + childContext.register(require('@fastify/static'), { + root: getJanExtensionsPath(), + wildcard: false, + }) + + done() + }, + { prefix: 'extensions' } + ) + + // Register proxy middleware + if (configs?.storageAdataper) + server.addHook('preHandler', configs.storageAdataper) + + // Register API routes + await server.register(v1Router, { prefix: configs?.prefix ?? '/v1' }) + // Start listening for requests + await server + .listen({ + port: portSetting, + host: hostSetting, + }) + .then(() => { + // Log server listening + if (isVerbose) + log( + `Debug: JAN API listening at: http://${hostSetting}:${portSetting}`, + '[SERVER]' + ) + }) + return true + } catch (e) { + // Log any errors + if (isVerbose) log(`Error: ${e}`, '[SERVER]') + } + return false +} + +/** + * Function to stop the server + */ +export const stopServer = async () => { + try { + // Log server stop + if (isVerbose) log(`Debug: Server stopped`, '[SERVER]') + // Stop the server + await server?.close() + } catch (e) { + // Log any errors + if (isVerbose) log(`Error: ${e}`, '[SERVER]') + } +} diff --git a/server/main.ts b/server/main.ts new file mode 100644 index 0000000000..71fb111062 --- /dev/null +++ b/server/main.ts @@ -0,0 +1,7 @@ +import { s3 } from './middleware/s3' +import { setup } from './helpers/setup' +import { startServer as start } from './index' +/** + * Setup extensions and start the server + */ +setup().then(() => start({ storageAdataper: s3 })) diff --git a/server/middleware/s3.ts b/server/middleware/s3.ts new file mode 100644 index 0000000000..3024285a3d --- /dev/null +++ b/server/middleware/s3.ts @@ -0,0 +1,70 @@ +import { join } from 'path' + +// Middleware to intercept requests and proxy if certain conditions are met +const config = { + endpoint: process.env.AWS_ENDPOINT, + region: process.env.AWS_REGION, + credentials: { + accessKeyId: process.env.AWS_ACCESS_KEY_ID, + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, + }, +} + +const S3_BUCKET_NAME = process.env.S3_BUCKET_NAME + +const fs = require('@cyclic.sh/s3fs')(S3_BUCKET_NAME, config) +const PROXY_PREFIX = '/v1/fs' +const PROXY_ROUTES = ['/threads', '/messages'] + +export const s3 = (req: any, reply: any, done: any) => { + // Proxy FS requests to S3 using S3FS + if (req.url.startsWith(PROXY_PREFIX)) { + const route = req.url.split('/').pop() + const args = 
parseRequestArgs(req) + + // Proxy matched requests to the s3fs module + if (args.length && PROXY_ROUTES.some((route) => args[0].includes(route))) { + try { + // Handle customized route + // S3FS does not handle appendFileSync + if (route === 'appendFileSync') { + let result = handAppendFileSync(args) + + reply.status(200).send(result) + return + } + // Reroute the other requests to the s3fs module + const result = fs[route](...args) + reply.status(200).send(result) + return + } catch (ex) { + console.error(ex) + } + } + } + // Let other requests go through + done() +} + +const parseRequestArgs = (req: Request) => { + const { + getJanDataFolderPath, + normalizeFilePath, + } = require('@janhq/core/node') + + return JSON.parse(req.body as any).map((arg: any) => + typeof arg === 'string' && + (arg.startsWith(`file:/`) || arg.startsWith(`file:\\`)) + ? join(getJanDataFolderPath(), normalizeFilePath(arg)) + : arg + ) +} + +const handAppendFileSync = (args: any[]) => { + if (fs.existsSync(args[0])) { + const data = fs.readFileSync(args[0], 'utf-8') + return fs.writeFileSync(args[0], data + args[1]) + } else { + return fs.writeFileSync(args[0], args[1]) + } +} diff --git a/server/package.json b/server/package.json new file mode 100644 index 0000000000..b2c237c615 --- /dev/null +++ b/server/package.json @@ -0,0 +1,46 @@ +{ + "name": "@janhq/server", + "version": "0.1.3", + "main": "build/index.js", + "types": "build/index.d.ts", + "author": "Jan ", + "license": "AGPL-3.0", + "homepage": "https://jan.ai", + "description": "Use offline LLMs with your own data. Run open source models like Llama2 or Falcon on your internal computers/servers.", + "files": [ + "build/**" + ], + "scripts": { + "lint": "eslint . --ext \".js,.jsx,.ts,.tsx\"", + "test:e2e": "playwright test --workers=1", + "dev": "tsc --watch & node --watch build/main.js", + "build": "tsc" + }, + "dependencies": { + "@alumna/reflect": "^1.1.3", + "@cyclic.sh/s3fs": "^1.2.9", + "@fastify/cors": "^8.4.2", + "@fastify/static": "^6.12.0", + "@fastify/swagger": "^8.13.0", + "@fastify/swagger-ui": "2.0.1", + "@janhq/core": "link:./core", + "@npmcli/arborist": "^7.3.1", + "dotenv": "^16.3.1", + "fastify": "^4.24.3", + "fetch-retry": "^5.0.6", + "node-fetch": "2", + "request": "^2.88.2", + "request-progress": "^3.0.0", + "tcp-port-used": "^1.0.2" + }, + "devDependencies": { + "@types/body-parser": "^1.19.5", + "@types/npmcli__arborist": "^5.6.4", + "@types/tcp-port-used": "^1.0.4", + "@typescript-eslint/eslint-plugin": "^6.7.3", + "@typescript-eslint/parser": "^6.7.3", + "eslint-plugin-react": "^7.34.0", + "run-script-os": "^1.1.6", + "typescript": "^5.3.3" + } +} diff --git a/server/tsconfig.json b/server/tsconfig.json new file mode 100644 index 0000000000..dd27b89323 --- /dev/null +++ b/server/tsconfig.json @@ -0,0 +1,24 @@ +{ + "compilerOptions": { + "target": "es5", + "module": "commonjs", + "noImplicitAny": true, + "sourceMap": true, + "strict": true, + "outDir": "./build", + "rootDir": "./", + "noEmitOnError": true, + "esModuleInterop": true, + "baseUrl": ".", + "allowJs": true, + "skipLibCheck": true, + "paths": { "*": ["node_modules/*"] }, + "typeRoots": ["node_modules/@types"], + "ignoreDeprecations": "5.0", + "declaration": true + }, + // "sourceMap": true, + + "include": ["./**/*.ts"], + "exclude": ["core", "build", "dist", "tests", "node_modules", "extensions"] +} diff --git a/turbo.json b/turbo.json index 5f9bb0a5f5..bfb6b08937 100644 --- a/turbo.json +++ b/turbo.json @@ -12,10 +12,14 @@ "persistent": true, "dependsOn": 
["@janhq/core#build", "@janhq/joi#build"] }, + "@janhq/server#build": { + "outputs": ["dist/**"], + "dependsOn": ["@janhq/core#build"] + }, "jan#dev": { "cache": false, "persistent": true, - "dependsOn": ["@janhq/core#build"] + "dependsOn": ["@janhq/core#build", "@janhq/server#build"] }, "@janhq/core#build": { "outputs": ["dist/**"] @@ -28,6 +32,7 @@ "outputs": ["dist/**"], "dependsOn": [ "@janhq/core#build", + "@janhq/server#build", "@janhq/web#build" ] }, @@ -36,6 +41,7 @@ "cache": false, "dependsOn": [ "@janhq/core#build", + "@janhq/server#build", "@janhq/web#build" ] }, diff --git a/web/app/layout.tsx b/web/app/layout.tsx index ae5601e257..5f14d6f5cc 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -2,7 +2,6 @@ import { PropsWithChildren } from 'react' import { Metadata } from 'next' -import 'katex/dist/katex.min.css' import '@/styles/main.scss' export const metadata: Metadata = { @@ -14,7 +13,7 @@ export const metadata: Metadata = { export default function RootLayout({ children }: PropsWithChildren) { return ( - +
{children} diff --git a/web/app/search/SelectedText.tsx b/web/app/search/SelectedText.tsx index 4db14b6d5e..fdb24bbfb3 100644 --- a/web/app/search/SelectedText.tsx +++ b/web/app/search/SelectedText.tsx @@ -32,7 +32,7 @@ const SelectedText = ({ onCleared }: { onCleared?: () => void }) => { className="relative rounded-lg border border-[hsla(var(--app-border))] p-[10px]" >
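Returning to the new server package for a moment: server/main.ts above boots it with setup().then(() => start({ storageAdataper: s3 })). Below is a sketch of driving the same exports directly, using only options declared in the ServerConfig interface; the values shown are illustrative defaults, startServer throws if the requested port is already in use, and it resolves to false on other startup errors.

import { startServer, stopServer } from './index'

async function run() {
  const started = await startServer({
    host: '127.0.0.1',
    port: 1337,
    isCorsEnabled: true,
    isVerboseEnabled: true,
    prefix: '/v1', // Swagger UI and v1Router are mounted relative to this prefix
  })

  if (!started) {
    throw new Error('JAN API server failed to start')
  }

  // ...later, e.g. on a shutdown signal
  await stopServer()
}

run().catch(console.error)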
diff --git a/web/app/search/layout.tsx b/web/app/search/layout.tsx index 3b7a280a3f..dedbe22f53 100644 --- a/web/app/search/layout.tsx +++ b/web/app/search/layout.tsx @@ -2,6 +2,8 @@ import { useEffect } from 'react' +import { AppConfiguration, getUserHomePath, joinPath } from '@janhq/core' + import { useSetAtom } from 'jotai' import ClipboardListener from '@/containers/Providers/ClipboardListener' @@ -27,14 +29,17 @@ export default function RootLayout() { }, []) useEffect(() => { - window.electronAPI?.appDataFolder()?.then((path: string) => { - setJanDataFolderPath(path) - }) + window.core?.api + ?.getAppConfigurations() + ?.then((appConfig: AppConfiguration) => { + setJanDataFolderPath(appConfig.data_folder) + }) }, [setJanDataFolderPath]) useEffect(() => { async function getDefaultJanDataFolder() { - const defaultJanDataFolder = await window?.electronAPI.homePath() + const homePath = await getUserHomePath() + const defaultJanDataFolder = await joinPath([homePath, 'jan']) setJanDefaultDataFolder(defaultJanDataFolder) } @@ -48,8 +53,9 @@ export default function RootLayout() { - - + + + diff --git a/web/components/Discord.tsx b/web/components/Discord.tsx deleted file mode 100644 index 04935166e5..0000000000 --- a/web/components/Discord.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import React from 'react' - -const Discord: React.FC = () => ( - - - -) - -export default React.memo(Discord) diff --git a/web/components/GitHub.tsx b/web/components/GitHub.tsx deleted file mode 100644 index cc948f0ef7..0000000000 --- a/web/components/GitHub.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import React from 'react' - -const GitHub: React.FC = () => ( - - - -) - -export default React.memo(GitHub) diff --git a/web/components/UserAvatar.tsx b/web/components/UserAvatar.tsx deleted file mode 100644 index d885c45623..0000000000 --- a/web/components/UserAvatar.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import React from 'react' - -const UserAvatar: React.FC = () => { - return ( -
- - - -
- ) -} - -export default React.memo(UserAvatar) diff --git a/web/constants/Threads.ts b/web/constants/Threads.ts deleted file mode 100644 index a5cb995fa2..0000000000 --- a/web/constants/Threads.ts +++ /dev/null @@ -1 +0,0 @@ -export const defaultThreadTitle = 'New Thread' diff --git a/web/constants/screens.ts b/web/constants/screens.ts new file mode 100644 index 0000000000..cb12be3c25 --- /dev/null +++ b/web/constants/screens.ts @@ -0,0 +1,6 @@ +export enum MainViewState { + Hub, + Settings, + Thread, + LocalServer, +} diff --git a/web/constants/tagType.ts b/web/constants/tagType.ts new file mode 100644 index 0000000000..021dbee725 --- /dev/null +++ b/web/constants/tagType.ts @@ -0,0 +1,62 @@ +export enum ModelPerformance { + PerformancePositive = 'PerformancePositive', + + PerformanceNeutral = 'PerformanceNeutral', + + PerformanceNegative = 'PerformanceNegative', +} + +export enum HardwareCompatibility { + HardwareCompatible = 'HardwareCompatible', + + HardwareIncompatible = 'HardwareIncompatible', +} + +export enum ExpectedPerformance { + ExpectPerformanceMedium = 'ExpectPerformanceMedium', +} + +export enum ModelFormat { + GGUF = 'GGUF', +} + +export enum FreestyleTag { + FreeStyle = 'FreeStyle', +} + +export enum VersionTag { + Version = 'Version', +} + +export enum QuantMethodTag { + Default = 'Default', +} + +export enum NumOfBit { + Default = 'Default', +} + +export enum RamRequired { + RamDefault = 'RamDefault', +} + +export enum UsecaseTag { + UsecaseDefault = 'UsecaseDefault', +} + +export enum MiscellaneousTag { + MiscellaneousDefault = 'MiscellaneousDefault', +} + +export type TagType = + | ModelPerformance + | HardwareCompatibility + | ExpectedPerformance + | ModelFormat + | FreestyleTag + | VersionTag + | QuantMethodTag + | NumOfBit + | RamRequired + | UsecaseTag + | MiscellaneousTag diff --git a/web/containers/BlankState/index.tsx b/web/containers/BlankState/index.tsx deleted file mode 100644 index 0d953d45b9..0000000000 --- a/web/containers/BlankState/index.tsx +++ /dev/null @@ -1,24 +0,0 @@ -import { ReactNode } from 'react' - -import LogoMark from '@/containers/Brand/Logo/Mark' - -type Props = { - title: string - description?: string - action?: ReactNode -} - -const BlankState = ({ title, description, action }: Props) => { - return ( -
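One small pattern from the web/app/search/layout.tsx hunk above, restated as a standalone helper. This is a sketch; the 'jan' folder name and the @janhq/core calls are taken directly from that hunk, the wrapper function is illustrative.

import { getUserHomePath, joinPath } from '@janhq/core'

// Default Jan data folder: <user home>/jan, resolved through the core path helpers
// rather than the old window.electronAPI.homePath() bridge.
async function getDefaultJanDataFolder(): Promise<string> {
  const homePath = await getUserHomePath()
  return joinPath([homePath, 'jan'])
}

export default getDefaultJanDataFolder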
- -

{title}

- {description && ( -

{description}

- )} - {action && action} -
- ) -} - -export default BlankState diff --git a/web/containers/Brand/Logo/Mark.tsx b/web/containers/Brand/Logo/Mark.tsx index fe7cdc7d78..f26b9ee2ae 100644 --- a/web/containers/Brand/Logo/Mark.tsx +++ b/web/containers/Brand/Logo/Mark.tsx @@ -1,5 +1,3 @@ -import React from 'react' - import Image from 'next/image' type Props = { @@ -8,14 +6,15 @@ type Props = { className?: string } -const LogoMark: React.FC = ({ width = 24, height = 24, className }) => ( - Jan - Logo -) - -export default React.memo(LogoMark) +export default function LogoMark(props: Props) { + const { width = 24, height = 24, className } = props + return ( + Jan - Logo + ) +} diff --git a/web/containers/CenterPanelContainer/index.tsx b/web/containers/CenterPanelContainer/index.tsx index 8797b5ee64..dd8fa0ae4a 100644 --- a/web/containers/CenterPanelContainer/index.tsx +++ b/web/containers/CenterPanelContainer/index.tsx @@ -4,24 +4,13 @@ import { useAtomValue } from 'jotai' import { twMerge } from 'tailwind-merge' -import { mainViewStateAtom, MainViewState } from '@/helpers/atoms/App.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { reduceTransparentAtom } from '@/helpers/atoms/Setting.atom' const CenterPanelContainer = ({ children }: PropsWithChildren) => { const reduceTransparent = useAtomValue(reduceTransparentAtom) - const mainViewState = useAtomValue(mainViewStateAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) - return (
{ + const messages = useAtomValue(getCurrentChatMessagesAtom) + const { resendChatMessage } = useSendChatMessage() + const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) + const setMainState = useSetAtom(mainViewStateAtom) + const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) + const activeThread = useAtomValue(activeThreadAtom) + + const regenerateMessage = async () => { + const lastMessageIndex = messages.length - 1 + const message = messages[lastMessageIndex] + resendChatMessage(message) + } + + const getErrorTitle = () => { + switch (message.error_code) { + case ErrorCode.Unknown: + return 'Apologies, something’s amiss!' + case ErrorCode.InvalidApiKey: + case ErrorCode.AuthenticationError: + case ErrorCode.InvalidRequestError: + return ( + + Invalid API key. Please check your API key from{' '} + {' '} + and try again. + + ) + default: + return ( + <> + {message.content[0]?.text?.value && ( + + )} + + ) + } + } + + return ( +
+ {message.status === MessageStatus.Stopped && ( +
+ + Oops! The generation was interrupted. Let's give it another go! + + +
+ )} + {message.status === MessageStatus.Error && ( +
+ {getErrorTitle()} +

+ Jan’s in beta. Access  + setModalTroubleShooting(true)} + > + troubleshooting assistance + +  now. +

+ +
+ )} +
+ ) +} +export default ErrorMessage diff --git a/web/containers/Layout/BottomPanel/DownloadingState/index.tsx b/web/containers/Layout/BottomPanel/DownloadingState/index.tsx new file mode 100644 index 0000000000..ddc2eab913 --- /dev/null +++ b/web/containers/Layout/BottomPanel/DownloadingState/index.tsx @@ -0,0 +1,97 @@ +import { Fragment } from 'react' + +import { Progress, Modal, Button } from '@janhq/joi' + +import { useAtomValue } from 'jotai' + +import useDownloadModel from '@/hooks/useDownloadModel' +import { modelDownloadStateAtom } from '@/hooks/useDownloadState' + +import { formatDownloadPercentage } from '@/utils/converter' + +import { getDownloadingModelAtom } from '@/helpers/atoms/Model.atom' + +export default function DownloadingState() { + const downloadStates = useAtomValue(modelDownloadStateAtom) + const downloadingModels = useAtomValue(getDownloadingModelAtom) + const { abortModelDownload } = useDownloadModel() + + const totalCurrentProgress = Object.values(downloadStates) + .map((a) => a.size.transferred + a.size.transferred) + .reduce((partialSum, a) => partialSum + a, 0) + + const totalSize = Object.values(downloadStates) + .map((a) => a.size.total + a.size.total) + .reduce((partialSum, a) => partialSum + a, 0) + + const totalPercentage = + totalSize !== 0 ? ((totalCurrentProgress / totalSize) * 100).toFixed(2) : 0 + + return ( + + {Object.values(downloadStates)?.length > 0 && ( + + + +
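The DownloadingState component introduced above derives one overall percentage from all in-flight model downloads. A pure-function restatement of that computation follows; note that the hunk sums size.transferred + size.transferred and size.total + size.total, which doubles both numerator and denominator and therefore cancels out in the ratio, so the sketch below simply sums each value once.

type DownloadSize = { transferred: number; total: number }
type DownloadState = { size: DownloadSize }

const totalDownloadPercentage = (states: DownloadState[]): string => {
  const transferred = states.reduce((sum, s) => sum + s.size.transferred, 0)
  const total = states.reduce((sum, s) => sum + s.size.total, 0)
  // Same guard as the component: avoid dividing by zero before any sizes are known
  return total !== 0 ? ((transferred / total) * 100).toFixed(2) : '0'
}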
+ {totalPercentage}% +
+
+ } + content={ +
+ {Object.values(downloadStates).map((item, i) => ( +
+ +
+
+

+ {item?.modelId} +

+ + {formatDownloadPercentage(item?.percent)} + +
+ +
+
+ ))} +
+ } + /> + )} + + ) +} diff --git a/web/containers/Layout/BottomPanel/DownloadingStatus/index.tsx b/web/containers/Layout/BottomPanel/DownloadingStatus/index.tsx deleted file mode 100644 index 2302728a98..0000000000 --- a/web/containers/Layout/BottomPanel/DownloadingStatus/index.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import { Fragment } from 'react' - -import { DownloadItem } from '@janhq/core' -import { Progress, Modal, Button } from '@janhq/joi' -import { useAtomValue } from 'jotai' - -import useAbortDownload from '@/hooks/useAbortDownload' -import { downloadStateListAtom } from '@/hooks/useDownloadState' - -import { formatDownloadPercentage } from '@/utils/converter' - -const DownloadStatus: React.FC = () => { - const downloadStates = useAtomValue(downloadStateListAtom) - const { abortDownload } = useAbortDownload() - - const totalTransfferedSize = downloadStates.reduce( - (partialSum: number, state) => - partialSum + - state.children.reduce( - (partialSum: number, downloadItem: DownloadItem) => - partialSum + downloadItem.size.transferred, - 0 - ), - 0 - ) - - const totalDownloadSize = downloadStates.reduce( - (partialSum: number, state) => - partialSum + - state.children.reduce( - (partialSum: number, downloadItem: DownloadItem) => - partialSum + downloadItem.size.total, - 0 - ), - 0 - ) - - const totalPercentage = - totalDownloadSize !== 0 - ? ((totalTransfferedSize / totalDownloadSize) * 100).toFixed(2) - : 0 - - const downloadTitle = `Downloading ${downloadStates - .map((state) => state.type) - .filter((value, index, self) => self.indexOf(value) === index) - .join(', ') - .trim()}` - - return ( - - {Object.values(downloadStates)?.length > 0 && ( - - - -
- {totalPercentage}% -
-
- } - content={ -
- {Object.values(downloadStates).map((item, i) => { - // TODO: move this to another component - const transferred = item.children.reduce( - (sum: number, downloadItem: DownloadItem) => - sum + downloadItem.size.transferred, - 0 - ) - const total = item.children.reduce( - (sum: number, downloadItem: DownloadItem) => - sum + downloadItem.size.total, - 0 - ) - - return ( -
- -
-
-

- {item.title} -

- - {formatDownloadPercentage(transferred / total)} - -
- -
-
- ) - })} -
- } - /> - )} - - ) -} - -export default DownloadStatus diff --git a/web/containers/Layout/BottomPanel/InstallingExtension/InstallingExtensionModal.tsx b/web/containers/Layout/BottomPanel/InstallingExtension/InstallingExtensionModal.tsx index d6f37d84cc..0d5e4d4e32 100644 --- a/web/containers/Layout/BottomPanel/InstallingExtension/InstallingExtensionModal.tsx +++ b/web/containers/Layout/BottomPanel/InstallingExtension/InstallingExtensionModal.tsx @@ -1,10 +1,9 @@ import { useCallback, useEffect } from 'react' +import { abortDownload } from '@janhq/core' import { Button, Modal, Progress } from '@janhq/joi' import { atom, useAtom, useAtomValue } from 'jotai' -import useAbortDownload from '@/hooks/useAbortDownload' - import { formatDownloadPercentage, formatExtensionsName, @@ -21,7 +20,6 @@ const InstallingExtensionModal = () => { const [showInstallingExtensionModal, setShowInstallingExtensionModal] = useAtom(showInstallingExtensionModalAtom) const installingExtensions = useAtomValue(installingExtensionAtom) - const { abortDownload } = useAbortDownload() useEffect(() => { if (installingExtensions.length === 0) { @@ -35,7 +33,7 @@ const InstallingExtensionModal = () => { abortDownload(item.localPath) } }, - [abortDownload] + [] ) return ( diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx index b320a80ad2..5ea32558c1 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx @@ -1,80 +1,81 @@ -import { useCallback } from 'react' +import { Fragment } from 'react' -import { Model } from '@janhq/core' -import { Button, Badge } from '@janhq/joi' +import { Tooltip, Button, Badge } from '@janhq/joi' -import { useAtomValue } from 'jotai' +import { useAtom } from 'jotai' -import useModelStop from '@/hooks/useModelStop' +import { useActiveModel } from '@/hooks/useActiveModel' -import { - activeModelsAtom, - downloadedModelsAtom, -} from '@/helpers/atoms/Model.atom' +import { toGibibytes } from '@/utils/converter' -const Column = ['Name', 'Engine', ''] +import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' -const TableActiveModel: React.FC = () => { - const stopModelMutation = useModelStop() - const activeModels = useAtomValue(activeModelsAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) +const Column = ['Name', 'Size', ''] - const models: Model[] = [] - activeModels.forEach((m) => { - const model = downloadedModels.find((dm) => dm.model === m.model) - if (model) { - models.push(model) - } - }) - - const onStopModelClick = useCallback( - (modelId: string) => { - stopModelMutation.mutate(modelId) - }, - [stopModelMutation] - ) +const TableActiveModel = () => { + const { activeModel, stateModel, stopModel } = useActiveModel() + const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom) return (
- + - {Column.map((col, i) => ( - - ))} + {Column.map((col, i) => { + return ( + + ) + })} - {models.map((model) => ( - - - - - + + - - - ))} +

{activeModel.name}

+ + + + + + + )}
- {col} - + {col} +
-

{model.model}

-
- - {!model.engine ? '-' : `${model.engine}`} - - -
- Stop - -
+ + {toGibibytes(activeModel.metadata.size)} + + + { + stopModel() + window.core?.api?.stopServer() + setServerEnabled(false) + }} + > + Stop + + } + content="The API server is running. Stopping the model will + also stop the server" + disabled={!serverEnabled} + /> +
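The TableActiveModel rewrite above couples stopping the active model with stopping the local API server. A minimal standalone sketch of that flow; the hook, atom, and API names are the ones imported in the hunk, while the component itself is illustrative.

import { useSetAtom } from 'jotai'
import { Button } from '@janhq/joi'
import { useActiveModel } from '@/hooks/useActiveModel'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'

const StopActiveModelButton = () => {
  const { stopModel } = useActiveModel()
  const setServerEnabled = useSetAtom(serverEnabledAtom)

  return (
    <Button
      onClick={() => {
        stopModel() // unload the active model
        window.core?.api?.stopServer() // the API server cannot keep serving without it
        setServerEnabled(false) // keep the Local API Server screen in sync
      }}
    >
      Stop
    </Button>
  )
}

export default StopActiveModelButton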
diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx index 50c577c58a..9d6311e737 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/index.tsx @@ -1,10 +1,7 @@ -import { Fragment, useCallback, useEffect, useRef, useState } from 'react' +import { Fragment, useEffect, useState } from 'react' -import { ResourceStatus } from '@janhq/core' import { Progress } from '@janhq/joi' import { useClickOutside } from '@janhq/joi' - -import { fetchEventSource } from '@microsoft/fetch-event-source' import { useAtom, useAtomValue } from 'jotai' import { MonitorIcon, @@ -16,28 +13,32 @@ import { import { twMerge } from 'tailwind-merge' +import useGetSystemResources from '@/hooks/useGetSystemResources' + +import { usePath } from '@/hooks/usePath' + import { toGibibytes } from '@/utils/converter' import TableActiveModel from './TableActiveModel' import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom' -import { hostAtom } from '@/helpers/atoms/AppConfig.atom' import { reduceTransparentAtom } from '@/helpers/atoms/Setting.atom' import { cpuUsageAtom, gpusAtom, + ramUtilitizedAtom, totalRamAtom, usedRamAtom, } from '@/helpers/atoms/SystemBar.atom' -const SystemMonitor: React.FC = () => { - const host = useAtomValue(hostAtom) - const [usedRam, setUsedRam] = useAtom(usedRamAtom) - const [totalRam, setTotalRam] = useAtom(totalRamAtom) - const [cpuUsage, setCpuUsage] = useAtom(cpuUsageAtom) - const [gpus, setGpus] = useAtom(gpusAtom) - +const SystemMonitor = () => { + const totalRam = useAtomValue(totalRamAtom) + const usedRam = useAtomValue(usedRamAtom) + const cpuUsage = useAtomValue(cpuUsageAtom) + const gpus = useAtomValue(gpusAtom) + const { onRevealInFinder } = usePath() const [showFullScreen, setShowFullScreen] = useState(false) + const ramUtilitized = useAtomValue(ramUtilitizedAtom) const [showSystemMonitorPanel, setShowSystemMonitorPanel] = useAtom( showSystemMonitorPanelAtom ) @@ -46,44 +47,8 @@ const SystemMonitor: React.FC = () => { null ) const reduceTransparent = useAtomValue(reduceTransparentAtom) - const abortControllerRef = useRef(null) - - const onOpenAppLogClick = useCallback(() => { - window?.electronAPI?.openAppLog() - }, []) - - const register = useCallback(async () => { - if (abortControllerRef.current) return - abortControllerRef.current = new AbortController() - await fetchEventSource(`${host}/system/events/resources`, { - onmessage(ev) { - if (!ev.data || ev.data === '') return - try { - const resourceEvent = JSON.parse(ev.data) as ResourceStatus - setUsedRam(resourceEvent.mem.used) - setTotalRam(resourceEvent.mem.total) - setCpuUsage(resourceEvent.cpu.usage) - setGpus( - resourceEvent.gpus?.filter( - // Do not check vram used here - // since it could count 0 case - (gpu) => gpu.name && gpu.vram.total - ) ?? 
[] - ) - } catch (err) { - console.error(err) - } - }, - signal: abortControllerRef.current.signal, - }) - }, [host, setTotalRam, setUsedRam, setCpuUsage, setGpus]) - - const unregister = useCallback(() => { - if (!abortControllerRef.current) return - abortControllerRef.current.abort() - abortControllerRef.current = null - }, []) + const { watch, stopWatching } = useGetSystemResources() useClickOutside( () => { setShowSystemMonitorPanel(false) @@ -94,11 +59,14 @@ const SystemMonitor: React.FC = () => { ) useEffect(() => { - register() + // Watch for resource update + watch() + return () => { - unregister() + stopWatching() } - }, [register, unregister]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) return ( @@ -132,7 +100,7 @@ const SystemMonitor: React.FC = () => {
onOpenAppLogClick()} + onClick={() => onRevealInFinder('Logs')} > App Log
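The SystemMonitor hunk above replaces the old fetch-event-source subscription with the watch()/stopWatching() pair from useGetSystemResources, which presumably populates the SystemBar atoms the component reads. The effect it installs boils down to this sketch (extracted for clarity; the empty dependency array and eslint suppression mirror the hunk):

import { useEffect } from 'react'
import useGetSystemResources from '@/hooks/useGetSystemResources'

const useWatchSystemResources = () => {
  const { watch, stopWatching } = useGetSystemResources()

  useEffect(() => {
    watch() // start watching CPU / RAM / GPU resource updates
    return () => stopWatching() // stop when the panel unmounts
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])
}

export default useWatchSystemResources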
@@ -185,9 +153,7 @@ const SystemMonitor: React.FC = () => { className="w-full" size="small" /> - - {Math.round((usedRam / totalRam) * 100)}% - + {ramUtilitized}%
@@ -202,7 +168,8 @@ const SystemMonitor: React.FC = () => {
- {gpu.vram.used}/{gpu.vram.total} + {gpu.memoryTotal - gpu.memoryFree}/ + {gpu.memoryTotal} MB
@@ -211,17 +178,12 @@ const SystemMonitor: React.FC = () => {
- {Math.round( - (gpu.vram.used / Math.max(gpu.vram.total, 1)) * 100 - )} - % + {gpu.utilization}%
diff --git a/web/containers/Layout/BottomPanel/index.tsx b/web/containers/Layout/BottomPanel/index.tsx index 24c332c78d..cc0efd8056 100644 --- a/web/containers/Layout/BottomPanel/index.tsx +++ b/web/containers/Layout/BottomPanel/index.tsx @@ -1,13 +1,10 @@ import { Button, Tooltip } from '@janhq/joi' import { useAtomValue } from 'jotai' +import { FaGithub, FaDiscord } from 'react-icons/fa' import { twMerge } from 'tailwind-merge' -import Discord from '@/components/Discord' - -import GitHub from '@/components/GitHub' - -import DownloadingStatus from './DownloadingStatus' +import DownloadingState from './DownloadingState' import ImportingModelState from './ImportingModelState' import InstallingExtension from './InstallingExtension' @@ -21,12 +18,12 @@ import { reduceTransparentAtom } from '@/helpers/atoms/Setting.atom' const menuLinks = [ { name: 'Discord', - icon: , + icon: , link: 'https://discord.gg/FTk2MvZwJH', }, { name: 'Github', - icon: , + icon: , link: 'https://github.com/janhq/jan', }, ] @@ -43,18 +40,18 @@ const BottomPanel = () => { 'border-t border-[hsla(var(--app-border))] bg-[hsla(var(--bottom-panel-bg))]' )} > -
+
{progress && progress > 0 ? ( ) : null}
- +
-
+
Jan v{VERSION ?? ''} @@ -62,25 +59,26 @@ const BottomPanel = () => {
{menuLinks .filter((link) => !!link) - .map((link) => ( - - - {link.icon} - - - } - content={link.name} - /> + .map((link, i) => ( +
+ + + {link.icon} + + + } + content={link.name} + /> +
))}
diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx index 46cf96922c..b9b1434ae4 100644 --- a/web/containers/Layout/RibbonPanel/index.tsx +++ b/web/containers/Layout/RibbonPanel/index.tsx @@ -1,57 +1,73 @@ import { Tooltip, useMediaQuery } from '@janhq/joi' import { motion as m } from 'framer-motion' import { useAtom, useAtomValue, useSetAtom } from 'jotai' -import { MessageCircleIcon, SettingsIcon, LayoutGridIcon } from 'lucide-react' +import { + MessageCircleIcon, + SettingsIcon, + LayoutGridIcon, + SquareCodeIcon, +} from 'lucide-react' import { twMerge } from 'tailwind-merge' -import { - MainViewState, - mainViewStateAtom, - showLeftPanelAtom, -} from '@/helpers/atoms/App.atom' -import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' +import { MainViewState } from '@/constants/screens' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom' +import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' +import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { reduceTransparentAtom, selectedSettingAtom, } from '@/helpers/atoms/Setting.atom' -const RibbonNavMenus = [ - { - name: 'Thread', - icon: , - state: MainViewState.Thread, - }, - { - name: 'Hub', - icon: , - state: MainViewState.Hub, - }, - { - name: 'Settings', - icon: , - state: MainViewState.Settings, - }, -] - export default function RibbonPanel() { const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom) + const [serverEnabled] = useAtom(serverEnabledAtom) const setEditMessage = useSetAtom(editMessageAtom) const showLeftPanel = useAtomValue(showLeftPanelAtom) const matches = useMediaQuery('(max-width: 880px)') const reduceTransparent = useAtomValue(reduceTransparentAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) const onMenuClick = (state: MainViewState) => { if (mainViewState === state) return + if (serverEnabled && state === MainViewState.Thread) return if (state === MainViewState.Settings) setSelectedSetting('My Models') setMainViewState(state) setEditMessage('') } + const RibbonNavMenus = [ + { + name: 'Thread', + icon: ( + + ), + state: MainViewState.Thread, + }, + { + name: 'Hub', + icon: , + state: MainViewState.Hub, + }, + { + name: 'Local API Server', + icon: , + state: MainViewState.LocalServer, + }, + { + name: 'Settings', + icon: , + state: MainViewState.Settings, + }, + ] + return (
{RibbonNavMenus.filter((menu) => !!menu).map((menu, i) => { const isActive = mainViewState === menu.state return (
onMenuClick(menu.state)} + key={i} > -
- -
- {menu.icon} -
- {isActive && ( - + +
onMenuClick(menu.state)} + > + {menu.icon}
- } - content={menu.name} - /> -
+ {isActive && ( + + )} +
+ } + content={ + serverEnabled && menu.state === MainViewState.Thread + ? 'Threads are disabled while the server is running' + : menu.name + } + />
) })} diff --git a/web/containers/Layout/TopPanel/index.tsx b/web/containers/Layout/TopPanel/index.tsx index 7cb264f471..6dd9ba8a5d 100644 --- a/web/containers/Layout/TopPanel/index.tsx +++ b/web/containers/Layout/TopPanel/index.tsx @@ -1,4 +1,4 @@ -import { Fragment, useCallback, useEffect } from 'react' +import { Fragment } from 'react' import { Button } from '@janhq/joi' import { useAtom, useAtomValue, useSetAtom } from 'jotai' @@ -12,35 +12,22 @@ import { SquareIcon, PaletteIcon, XIcon, - PenSquareIcon, } from 'lucide-react' import { twMerge } from 'tailwind-merge' import LogoMark from '@/containers/Brand/Logo/Mark' -import { toaster } from '@/containers/Toast' - -import useAssistantQuery from '@/hooks/useAssistantQuery' -import useThreadCreateMutation from '@/hooks/useThreadCreateMutation' -import useThreads from '@/hooks/useThreads' - -import { copyOverInstructionEnabledAtom } from '@/screens/Thread/ThreadRightPanel/AssistantSettingContainer/components/CopyOverInstruction' +import { MainViewState } from '@/constants/screens' import { - MainViewState, mainViewStateAtom, showLeftPanelAtom, showRightPanelAtom, } from '@/helpers/atoms/App.atom' -import { - downloadedModelsAtom, - getSelectedModelAtom, -} from '@/helpers/atoms/Model.atom' import { reduceTransparentAtom, selectedSettingAtom, } from '@/helpers/atoms/Setting.atom' -import { threadsAtom, activeThreadAtom } from '@/helpers/atoms/Thread.atom' const TopPanel = () => { const [showLeftPanel, setShowLeftPanel] = useAtom(showLeftPanelAtom) @@ -48,52 +35,6 @@ const TopPanel = () => { const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom) const reduceTransparent = useAtomValue(reduceTransparentAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) - - const { setActiveThread } = useThreads() - const createThreadMutation = useThreadCreateMutation() - - const selectedModel = useAtomValue(getSelectedModelAtom) - const threads = useAtomValue(threadsAtom) - - const activeThread = useAtomValue(activeThreadAtom) - const { data: assistants } = useAssistantQuery() - const copyOverInstructionEnabled = useAtomValue( - copyOverInstructionEnabledAtom - ) - - useEffect(() => { - if (activeThread?.id) return - if (threads.length === 0) return - setActiveThread(threads[0].id) - }, [activeThread?.id, setActiveThread, threads]) - - const onCreateThreadClicked = useCallback(async () => { - if (!assistants || !assistants.length) { - toaster({ - title: 'No assistant available.', - description: `Could not create a new thread. Please add an assistant.`, - type: 'error', - }) - return - } - if (!selectedModel) return - let instructions: string | undefined = undefined - if (copyOverInstructionEnabled) { - instructions = activeThread?.assistants[0]?.instructions ?? undefined - } - await createThreadMutation.mutateAsync({ - modelId: selectedModel.model, - assistant: assistants[0], - instructions, - }) - }, [ - createThreadMutation, - selectedModel, - assistants, - activeThread, - copyOverInstructionEnabled, - ]) return (
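The RibbonPanel hunk above blocks navigation to the Thread view while the local API server is running and swaps the tooltip text accordingly. The guard itself, restated as a small pure helper (an editorial sketch, not code from the diff):

import { MainViewState } from '@/constants/screens'

const canNavigateTo = (
  target: MainViewState,
  current: MainViewState,
  serverEnabled: boolean
): boolean => {
  if (current === target) return false // already on that screen
  if (serverEnabled && target === MainViewState.Thread) return false // threads are disabled while the server runs
  return true
}

export { canNavigateTo }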
{ )} - {mainViewState !== MainViewState.Hub && - downloadedModels.length > 0 && ( - - {showLeftPanel ? ( - - ) : ( - - )} - - )} - {mainViewState === MainViewState.Thread && ( - + {mainViewState !== MainViewState.Hub && ( + + {showLeftPanel ? ( + + ) : ( + + )} + )}
{mainViewState !== MainViewState.Hub && - mainViewState !== MainViewState.Settings && - downloadedModels.length > 0 && ( + mainViewState !== MainViewState.Settings && ( {showRightPanel ? ( + ) : ( + + /> + + {formatDownloadPercentage(downloadState.percent)} + +
+ + ) } content={

Are you sure you want to cancel the download of  - {/* {downloadState?.modelId}? */} + {downloadState?.modelId}?

@@ -45,10 +76,7 @@ const ModalCancelDownload: React.FC = ({ model }) => { - diff --git a/web/containers/ModalTroubleShoot/AppLogs.tsx b/web/containers/ModalTroubleShoot/AppLogs.tsx index 3979e09606..7b0a31a5da 100644 --- a/web/containers/ModalTroubleShoot/AppLogs.tsx +++ b/web/containers/ModalTroubleShoot/AppLogs.tsx @@ -10,8 +10,6 @@ import { useClipboard } from '@/hooks/useClipboard' import { useLogs } from '@/hooks/useLogs' import { usePath } from '@/hooks/usePath' -import EmptyIcon from '@/screens/HubScreen2/components/EmptyIcon' - const AppLogs = () => { const { getLogs } = useLogs() const [logs, setLogs] = useState([]) @@ -88,7 +86,135 @@ const AppLogs = () => { ) : (
- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Empty logs

)} diff --git a/web/containers/ModalTroubleShoot/index.tsx b/web/containers/ModalTroubleShoot/index.tsx index 058307b140..67ccbe22fa 100644 --- a/web/containers/ModalTroubleShoot/index.tsx +++ b/web/containers/ModalTroubleShoot/index.tsx @@ -36,7 +36,7 @@ const ModalTroubleShooting = () => {
{!showLogFullSize && ( -
+

Step 1

Follow our  @@ -54,7 +54,7 @@ const ModalTroubleShooting = () => {

diff --git a/web/containers/ModelDropdown/ModelSection.tsx b/web/containers/ModelDropdown/ModelSection.tsx deleted file mode 100644 index dca851c1fc..0000000000 --- a/web/containers/ModelDropdown/ModelSection.tsx +++ /dev/null @@ -1,199 +0,0 @@ -import React, { useCallback, useEffect, useState } from 'react' - -import Image from 'next/image' - -import { - EngineStatus, - LlmEngine, - LocalEngine, - Model, - RemoteEngine, - RemoteEngines, -} from '@janhq/core' - -import { Button } from '@janhq/joi' -import { useAtom, useSetAtom } from 'jotai' -import { - SettingsIcon, - ChevronDownIcon, - ChevronUpIcon, - PlusIcon, -} from 'lucide-react' - -import { twMerge } from 'tailwind-merge' - -import useEngineQuery from '@/hooks/useEngineQuery' -import useGetModelsByEngine from '@/hooks/useGetModelsByEngine' - -import { - getLogoByLocalEngine, - getLogoByRemoteEngine, - getTitleByCategory, -} from '@/utils/model-engine' - -import ModelLabel from '../ModelLabel' - -import { showEngineListModelAtom } from '@/helpers/atoms/Model.atom' -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -type Props = { - engine: LlmEngine - searchText: string - onModelSelected: (model: Model) => void -} - -const ModelSection: React.FC = ({ - engine, - searchText, - onModelSelected, -}) => { - const [models, setModels] = useState([]) - const { getModelsByEngine } = useGetModelsByEngine() - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const { data: engineData } = useEngineQuery() - - const [showEngineListModel, setShowEngineListModel] = useAtom( - showEngineListModelAtom - ) - - const engineLogo: string | undefined = models.find( - (entry) => entry?.metadata?.logo != null - )?.metadata?.logo - - const apiKeyUrl: string | undefined = models.find( - (entry) => entry?.metadata?.api_key_url != null - )?.metadata?.api_key_url - - const onSettingClick = useCallback(() => { - setUpRemoteModelStage('SETUP_API_KEY', engine as unknown as RemoteEngine, { - logo: engineLogo, - api_key_url: apiKeyUrl, - }) - }, [apiKeyUrl, engine, engineLogo, setUpRemoteModelStage]) - - const isEngineReady = - engineData?.find((e) => e.name === engine)?.status === EngineStatus.Ready - - const getEngineStatusReady: LlmEngine[] | undefined = engineData - ?.filter((e) => e.status === EngineStatus.Ready) - .map((x) => x.name as LlmEngine) - - const showModel = showEngineListModel.includes(engine) - - const onClickChevron = useCallback(() => { - if (showModel) { - setShowEngineListModel((prev) => prev.filter((item) => item !== engine)) - } else { - setShowEngineListModel((prev) => [...prev, engine]) - } - }, [engine, setShowEngineListModel, showModel]) - - useEffect(() => { - const matchedModels = getModelsByEngine(engine, searchText) - setModels(matchedModels) - setShowEngineListModel((prev) => [ - ...prev, - ...(getEngineStatusReady as LlmEngine[]), - ]) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [getModelsByEngine, engine, searchText, setShowEngineListModel]) - - const engineName = getTitleByCategory(engine) - const localEngineLogo = getLogoByLocalEngine(engine as LocalEngine) - const remoteEngineLogo = getLogoByRemoteEngine(engine as RemoteEngine) - const isRemoteEngine = RemoteEngines.includes(engine as RemoteEngine) - - if (models.length === 0) return null - - return ( -
-
-
- {!isRemoteEngine && localEngineLogo && ( - logo - )} - - {remoteEngineLogo && ( - {`logo - )} -
- {engineName} -
-
-
- {isRemoteEngine && ( - - )} - {!showModel ? ( - - ) : ( - - )} -
-
-
    - {models.map((model) => { - if (!showModel) return null - return ( -
  • { - onModelSelected(model) - }} - > -
    -

    {model.name ?? model.model}

    -
    - -
  • - ) - })} -
-
- ) -} - -export default ModelSection diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 2cd09b329d..c19fb64bdf 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -1,87 +1,246 @@ -import { useState, useCallback, useEffect, useRef } from 'react' +import { useState, useMemo, useEffect, useCallback, useRef } from 'react' -import { LlmEngines, LocalEngines, Model, RemoteEngines } from '@janhq/core' +import { InferenceEngine } from '@janhq/core' import { Badge, Input, ScrollArea, Select, useClickOutside } from '@janhq/joi' -import { useAtomValue } from 'jotai' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' -import { ChevronDownIcon, XIcon } from 'lucide-react' +import { ChevronDownIcon, DownloadCloudIcon, XIcon } from 'lucide-react' import { twMerge } from 'tailwind-merge' -import useCortex from '@/hooks/useCortex' +import ProgressCircle from '@/containers/Loader/ProgressCircle' -import useSelectModel from '@/hooks/useSelectModel' +import ModelLabel from '@/containers/ModelLabel' -import ModelSection from './ModelSection' +import SetupRemoteModel from '@/containers/SetupRemoteModel' +import useDownloadModel from '@/hooks/useDownloadModel' +import { modelDownloadStateAtom } from '@/hooks/useDownloadState' +import useRecommendedModel from '@/hooks/useRecommendedModel' + +import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' + +import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' + +import { extensionManager } from '@/extension' + +import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { - downloadedModelsAtom, - getSelectedModelAtom, + configuredModelsAtom, + getDownloadingModelAtom, + selectedModelAtom, } from '@/helpers/atoms/Model.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' +import { + activeThreadAtom, + setThreadModelParamsAtom, +} from '@/helpers/atoms/Thread.atom' type Props = { chatInputMode?: boolean + strictedThread?: boolean + disabled?: boolean } -const ModelDropdown: React.FC = ({ chatInputMode }) => { - const downloadedModels = useAtomValue(downloadedModelsAtom) +const engineHasLogo = [ + InferenceEngine.anthropic, + InferenceEngine.cohere, + InferenceEngine.martian, + InferenceEngine.mistral, + InferenceEngine.openai, +] + +const ModelDropdown = ({ + disabled, + chatInputMode, + strictedThread = true, +}: Props) => { + const { downloadModel } = useDownloadModel() const [searchFilter, setSearchFilter] = useState('all') const [filterOptionsOpen, setFilterOptionsOpen] = useState(false) const [searchText, setSearchText] = useState('') - const { selectModel } = useSelectModel() - const [open, setOpen] = useState(false) const activeThread = useAtomValue(activeThreadAtom) + const downloadingModels = useAtomValue(getDownloadingModelAtom) const [toggle, setToggle] = useState(null) - const selectedModel = useAtomValue(getSelectedModelAtom) - const { createModel } = useCortex() - + const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) + const { recommendedModel, downloadedModels } = useRecommendedModel() const [dropdownOptions, setDropdownOptions] = useState( null ) + const downloadStates = useAtomValue(modelDownloadStateAtom) + const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + const { updateModelParameter } = useUpdateModelParameters() const searchInputRef = useRef(null) + const configuredModels = useAtomValue(configuredModelsAtom) + const featuredModel = 
configuredModels.filter((x) => + x.metadata.tags.includes('Featured') + ) useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ dropdownOptions, toggle, ]) + const filteredDownloadedModels = useMemo( + () => + configuredModels + .filter((e) => + e.name.toLowerCase().includes(searchText.toLowerCase().trim()) + ) + .filter((e) => { + if (searchFilter === 'all') { + return e.engine + } + if (searchFilter === 'local') { + return ( + e.engine === InferenceEngine.nitro || + e.engine === InferenceEngine.nitro_tensorrt_llm + ) + } + if (searchFilter === 'remote') { + return ( + e.engine !== InferenceEngine.nitro && + e.engine !== InferenceEngine.nitro_tensorrt_llm + ) + } + }) + .sort((a, b) => a.name.localeCompare(b.name)) + .sort((a, b) => { + const aInDownloadedModels = downloadedModels.some( + (item) => item.id === a.id + ) + const bInDownloadedModels = downloadedModels.some( + (item) => item.id === b.id + ) + if (aInDownloadedModels && !bInDownloadedModels) { + return -1 + } else if (!aInDownloadedModels && bInDownloadedModels) { + return 1 + } else { + return 0 + } + }), + [configuredModels, searchText, searchFilter, downloadedModels] + ) + useEffect(() => { if (open && searchInputRef.current) { searchInputRef.current.focus() } }, [open]) - const onModelSelected = useCallback( - async (model: Model) => { - const isModelAddedToCortex = downloadedModels.find( - (m) => m.model === model.model - ) - if (!isModelAddedToCortex) { - await createModel(model) - } + useEffect(() => { + if (!activeThread) return + let model = downloadedModels.find( + (model) => model.id === activeThread.assistants[0].model.id + ) + if (!model) { + model = recommendedModel + } + setSelectedModel(model) + }, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) - selectModel(model) + const onClickModelItem = useCallback( + async (modelId: string) => { + const model = downloadedModels.find((m) => m.id === modelId) + setSelectedModel(model) setOpen(false) + + if (activeThread) { + // Default setting ctx_len for the model for a better onboarding experience + // TODO: When Cortex support hardware instructions, we should remove this + const overriddenSettings = + model?.settings.ctx_len && model.settings.ctx_len > 2048 + ? { ctx_len: 2048 } + : {} + + const modelParams = { + ...model?.parameters, + ...model?.settings, + ...overriddenSettings, + } + + // Update model parameter to the thread state + setThreadModelParams(activeThread.id, modelParams) + + // Update model parameter to the thread file + if (model) + updateModelParameter(activeThread, { + params: modelParams, + modelId: model.id, + engine: model.engine, + }) + } }, - [selectModel, createModel, downloadedModels] + [ + downloadedModels, + activeThread, + setSelectedModel, + setThreadModelParams, + updateModelParameter, + ] ) - const engines = - searchFilter === 'local' - ? LocalEngines - : searchFilter === 'remote' - ? 
RemoteEngines - : LlmEngines + const [extensionHasSettings, setExtensionHasSettings] = useState< + { name?: string; setting: string; apiKey: string; provider: string }[] + >([]) + + const inActiveEngineProvider = useAtomValue(inActiveEngineProviderAtom) - if (!activeThread) return null + useEffect(() => { + const getAllSettings = async () => { + const extensionsMenu: { + name?: string + setting: string + apiKey: string + provider: string + }[] = [] + const extensions = extensionManager.getAll() + + for (const extension of extensions) { + if (typeof extension.getSettings === 'function') { + const settings = await extension.getSettings() - const modelId = selectedModel?.model ?? '' - const modelName = selectedModel?.name ?? modelId + if ( + (settings && settings.length > 0) || + (await extension.installationState()) !== 'NotRequired' + ) { + extensionsMenu.push({ + name: extension.productName, + setting: extension.name, + apiKey: + 'apiKey' in extension && typeof extension.apiKey === 'string' + ? extension.apiKey + : '', + provider: + 'provider' in extension && + typeof extension.provider === 'string' + ? extension.provider + : '', + }) + } + } + } + setExtensionHasSettings(extensionsMenu) + } + getAllSettings() + }, []) + + const findByEngine = filteredDownloadedModels + .filter((x) => !inActiveEngineProvider.includes(x.engine)) + .map((x) => x.engine) + + const groupByEngine = findByEngine.filter(function (item, index) { + if (findByEngine.indexOf(item) === index) + return item !== InferenceEngine.nitro + }) + + if (strictedThread && !activeThread) { + return null + } return ( -
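One detail from the onClickModelItem handler above: when a model is selected, its ctx_len is capped at 2048 as an onboarding default before the parameters are written to the thread. Restated as a pure function; the field name and cap come from the hunk, the type alias is illustrative.

type ModelSettings = { ctx_len?: number } & Record<string, unknown>

const withOnboardingContextLength = (settings: ModelSettings = {}): ModelSettings => ({
  ...settings,
  // Cap only when the model declares a larger context window; smaller values pass through untouched
  ...(settings.ctx_len && settings.ctx_len > 2048 ? { ctx_len: 2048 } : {}),
})

// withOnboardingContextLength({ ctx_len: 4096 }) -> { ..., ctx_len: 2048 }
// withOnboardingContextLength({ ctx_len: 1024 }) -> { ..., ctx_len: 1024 }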
+
{chatInputMode ? ( = ({ chatInputMode }) => { className="cursor-pointer" onClick={() => setOpen(!open)} > - {modelName} + {selectedModel?.name} ) : ( = ({ chatInputMode }) => {
= ({ chatInputMode }) => {
setSearchText(e.target.value)} suffixIcon={ @@ -152,14 +312,230 @@ const ModelDropdown: React.FC = ({ chatInputMode }) => {
- {engines.map((engine) => ( - - ))} + {searchFilter !== 'remote' && ( +
+
+
+ Cortex +
+
+ {filteredDownloadedModels + .filter((x) => { + if (searchText.length === 0) { + return downloadedModels.find((c) => c.id === x.id) + } else { + return x + } + }) + .filter((x) => x.engine === InferenceEngine.nitro).length !== + 0 ? ( +
    + {filteredDownloadedModels + ? filteredDownloadedModels + .filter((x) => x.engine === InferenceEngine.nitro) + .filter((x) => { + if (searchText.length === 0) { + return downloadedModels.find((c) => c.id === x.id) + } else { + return x + } + }) + .map((model) => { + const isDownloading = downloadingModels.some( + (md) => md.id === model.id + ) + const isdDownloaded = downloadedModels.some( + (c) => c.id === model.id + ) + return ( +
  • { + if (isdDownloaded) { + onClickModelItem(model.id) + } + }} + > +
    +

    + {model.name} +

    + +
    +
    + {!isdDownloaded && ( + + {toGibibytes(model.metadata.size)} + + )} + {!isDownloading && !isdDownloaded ? ( + downloadModel(model)} + /> + ) : ( + Object.values(downloadStates) + .filter((x) => x.modelId === model.id) + .map((item) => ( + + )) + )} +
    +
  • + ) + }) + : null} +
+ ) : ( +
    + {featuredModel.map((model) => { + const isDownloading = downloadingModels.some( + (md) => md.id === model.id + ) + return ( +
  • +
    +

    + {model.name} +

    + +
    +
    + + {toGibibytes(model.metadata.size)} + + {!isDownloading ? ( + downloadModel(model)} + /> + ) : ( + Object.values(downloadStates) + .filter((x) => x.modelId === model.id) + .map((item) => ( + + )) + )} +
    +
  • + ) + })} +
+ )} +
+ )} + + {groupByEngine.map((engine, i) => { + const apiKey = + extensionHasSettings.filter((x) => x.provider === engine)[0] + ?.apiKey.length > 1 + return ( +
+
+
+
+ {engine} +
+
+ +
+
+
    + {filteredDownloadedModels + .filter((x) => x.engine === engine) + .map((model) => { + return ( +
  • { + if ( + apiKey || + model.engine === + InferenceEngine.nitro_tensorrt_llm + ) { + onClickModelItem(model.id) + } + }} + > +
    + {engineHasLogo.map((x) => { + if (x === model.engine) { + return ( +
    + Model Provider +
    + ) + } + })} +

    + {model.name} +

    +
    +
  • + ) + })} +
+
+
+ ) + })}
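Note on the dropdown logic above (not part of the patch itself): the `findByEngine`/`groupByEngine` pair is a de-duplication pass that lists each remote engine once, skipping providers the user has deactivated and the local `InferenceEngine.nitro` engine, which is rendered in its own Cortex section. A minimal standalone sketch of that step in TypeScript, assuming only that the `Model` type from `@janhq/core` exposes an `engine` field; the helper name `listRemoteEngineGroups` is illustrative and not part of the patch:

import { InferenceEngine, Model } from '@janhq/core'

// Collect each engine once, in first-seen order, mirroring the
// indexOf-based filter used in the component above.
const listRemoteEngineGroups = (
  downloadedModels: Model[],
  inactiveProviders: InferenceEngine[]
): InferenceEngine[] => {
  const engines = downloadedModels
    .filter((model) => !inactiveProviders.includes(model.engine))
    .map((model) => model.engine)
    // Local (nitro) models get their own section, so they are excluded here.
    .filter((engine) => engine !== InferenceEngine.nitro)

  return [...new Set(engines)]
}

// e.g. engines [openai, groq, openai] with nothing deactivated -> [openai, groq]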
diff --git a/web/containers/ModelLabel/index.tsx b/web/containers/ModelLabel/index.tsx index ae11caf03e..2c32e288c0 100644 --- a/web/containers/ModelLabel/index.tsx +++ b/web/containers/ModelLabel/index.tsx @@ -1,9 +1,11 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ import React from 'react' +import { ModelMetadata } from '@janhq/core' import { Badge } from '@janhq/joi' import { useAtomValue } from 'jotai' +import { useActiveModel } from '@/hooks/useActiveModel' + import { useSettings } from '@/hooks/useSettings' import NotEnoughMemoryLabel from './NotEnoughMemoryLabel' @@ -12,7 +14,6 @@ import RecommendedLabel from './RecommendedLabel' import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel' -import { activeModelsAtom } from '@/helpers/atoms/Model.atom' import { availableVramAtom, totalRamAtom, @@ -20,10 +21,9 @@ import { } from '@/helpers/atoms/SystemBar.atom' type Props = { - metadata: Record | undefined + metadata: ModelMetadata compact?: boolean } - const UnsupportedModel = () => { return ( @@ -33,23 +33,18 @@ const UnsupportedModel = () => { } const ModelLabel = ({ metadata, compact }: Props) => { - const activeModels = useAtomValue(activeModelsAtom) + const { activeModel } = useActiveModel() const totalRam = useAtomValue(totalRamAtom) const usedRam = useAtomValue(usedRamAtom) const availableVram = useAtomValue(availableVramAtom) const { settings } = useSettings() const getLabel = (size: number) => { - const activeModelMemoryUsed = activeModels.reduce( - (acc, model) => acc + Number(model.metadata.size ?? 0), - 0 - ) - const minimumRamModel = size * 1.25 const availableRam = settings?.run_mode === 'gpu' ? availableVram * 1000000 // MB to bytes - : totalRam - usedRam + activeModelMemoryUsed + : totalRam - usedRam + (activeModel?.metadata.size ?? 0) if (minimumRamModel > totalRam) { return ( { return null } - return metadata?.tags?.includes('Coming Soon') ? ( + return metadata.tags.includes('Coming Soon') ? ( ) : ( - getLabel(metadata?.size ?? 0) + getLabel(metadata.size ?? 
0) ) } diff --git a/web/containers/Providers/AppUpdateListener.tsx b/web/containers/Providers/AppUpdateListener.tsx index 033000412c..77b39bb065 100644 --- a/web/containers/Providers/AppUpdateListener.tsx +++ b/web/containers/Providers/AppUpdateListener.tsx @@ -1,4 +1,4 @@ -import { useEffect } from 'react' +import { Fragment, PropsWithChildren, useEffect } from 'react' import { AppUpdateInfo } from '@janhq/core' import { useSetAtom } from 'jotai' @@ -8,7 +8,7 @@ import { updateVersionErrorAtom, } from '@/helpers/atoms/App.atom' -const AppUpdateListener: React.FC = () => { +const AppUpdateListener = ({ children }: PropsWithChildren) => { const setProgress = useSetAtom(appDownloadProgressAtom) const setUpdateVersionError = useSetAtom(updateVersionErrorAtom) @@ -39,7 +39,7 @@ const AppUpdateListener: React.FC = () => { } }, [setProgress, setUpdateVersionError]) - return null + return {children} } export default AppUpdateListener diff --git a/web/containers/Providers/ClipboardListener.tsx b/web/containers/Providers/ClipboardListener.tsx index bafec718fe..2d9910b9b5 100644 --- a/web/containers/Providers/ClipboardListener.tsx +++ b/web/containers/Providers/ClipboardListener.tsx @@ -1,8 +1,10 @@ +import { Fragment, PropsWithChildren } from 'react' + import { useSetAtom } from 'jotai' import { selectedTextAtom } from './Jotai' -const ClipboardListener: React.FC = () => { +const ClipboardListener = ({ children }: PropsWithChildren) => { const setSelectedText = useSetAtom(selectedTextAtom) if (typeof window !== 'undefined') { @@ -11,7 +13,7 @@ const ClipboardListener: React.FC = () => { }) } - return null + return {children} } export default ClipboardListener diff --git a/web/containers/Providers/DataLoader.tsx b/web/containers/Providers/DataLoader.tsx index 1f2b8c835d..269d2f8770 100644 --- a/web/containers/Providers/DataLoader.tsx +++ b/web/containers/Providers/DataLoader.tsx @@ -1,177 +1,70 @@ 'use client' -import { useEffect, useMemo } from 'react' +import { Fragment, ReactNode, useEffect } from 'react' -import { Engine } from '@cortexso/cortex.js/resources' -import { - EngineStatus, - LocalEngine, - LocalEngines, - Model, - RemoteEngine, - RemoteEngines, -} from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { AppConfiguration, getUserHomePath, joinPath } from '@janhq/core' +import { useSetAtom } from 'jotai' -import useAssistantCreate, { janAssistant } from '@/hooks/useAssistantCreate' -import useAssistantQuery from '@/hooks/useAssistantQuery' -import useCortex from '@/hooks/useCortex' -import useEngineQuery from '@/hooks/useEngineQuery' +import useAssistants from '@/hooks/useAssistants' +import useGetSystemResources from '@/hooks/useGetSystemResources' import { useLoadTheme } from '@/hooks/useLoadTheme' -import useModelHub from '@/hooks/useModelHub' -import useModelQuery from '@/hooks/useModelQuery' -import useThreadCreateMutation from '@/hooks/useThreadCreateMutation' -import useThreadQuery from '@/hooks/useThreadQuery' - -import { - getSelectedModelAtom, - updateSelectedModelAtom, -} from '@/helpers/atoms/Model.atom' -import { threadsAtom } from '@/helpers/atoms/Thread.atom' +import useModels from '@/hooks/useModels' +import useThreads from '@/hooks/useThreads' -const DataLoader: React.FC = () => { - const selectedModel = useAtomValue(getSelectedModelAtom) - const setSelectedModel = useSetAtom(updateSelectedModelAtom) - const allThreads = useAtomValue(threadsAtom) - const { data: assistants } = useAssistantQuery() - const { data: models } = useModelQuery() - const 
{ data: threads, isLoading: isFetchingThread } = useThreadQuery() - const { data: engineData } = useEngineQuery() - const { data: modelHubData } = useModelHub() - const createThreadMutation = useThreadCreateMutation() - const assistantCreateMutation = useAssistantCreate() - const { createModel } = useCortex() +import { SettingScreenList } from '@/screens/Settings' - useEffect(() => { - if (!assistants) return - if (assistants.length === 0 && assistantCreateMutation.isIdle) { - // empty assistant. create new one - console.debug('Empty assistants received. Create Jan Assistant...') - assistantCreateMutation.mutate(janAssistant) - } - }, [assistants, assistantCreateMutation]) +import { defaultJanDataFolderAtom } from '@/helpers/atoms/App.atom' +import { + janDataFolderPathAtom, + quickAskEnabledAtom, +} from '@/helpers/atoms/AppConfig.atom' +import { janSettingScreenAtom } from '@/helpers/atoms/Setting.atom' - const isAnyRemoteModelConfigured = useMemo(() => { - if (!engineData) return false +type Props = { + children: ReactNode +} - let result = false - for (const engine of engineData) { - if (RemoteEngines.includes(engine.name as RemoteEngine)) { - if (engine.status === EngineStatus.Ready) { - result = true - } - } - } - return result - }, [engineData]) +const DataLoader: React.FC = ({ children }) => { + const setJanDataFolderPath = useSetAtom(janDataFolderPathAtom) + const setQuickAskEnabled = useSetAtom(quickAskEnabledAtom) + const setJanDefaultDataFolder = useSetAtom(defaultJanDataFolderAtom) + const setJanSettingScreen = useSetAtom(janSettingScreenAtom) - const isAnyModelReady = useMemo(() => { - if (!models) return false - return models.length > 0 - }, [models]) + useModels() + useThreads() + useAssistants() + useGetSystemResources() + useLoadTheme() - // automatically create new thread if thread list is empty useEffect(() => { - if (isFetchingThread) return - if (allThreads.length > 0) return - if (!assistants || assistants.length === 0) return - const shouldCreateNewThread = isAnyRemoteModelConfigured || isAnyModelReady - - if (shouldCreateNewThread && !createThreadMutation.isPending) { - // if we already have selected model then can safely proceed - if (selectedModel) { - const assistant = assistants[0] - - console.debug( - 'Create new thread because user have no thread, with selected model', - selectedModel.model - ) - createThreadMutation.mutate({ - modelId: selectedModel.model, - assistant: assistant, - }) - return - } - - let modelToBeUsed: Model | undefined = undefined - // if we have a model registered already, try to use it and prioritize local model - if (models && models.length > 0) { - for (const model of models) { - if (!model.engine) continue - if (LocalEngines.includes(model.engine as LocalEngine)) { - modelToBeUsed = model - } - } - - // if we don't have it, then just take the first one - if (!modelToBeUsed) { - modelToBeUsed = models[0] - } - } else { - if (!engineData) return - // we don't have nay registered model, so will need to check the remote engine - const remoteEngineReadyList: Engine[] = [] - for (const engine of engineData) { - if (RemoteEngines.includes(engine.name as RemoteEngine)) { - if (engine.status === EngineStatus.Ready) { - remoteEngineReadyList.push(engine) - } - } - } - - if (remoteEngineReadyList.length === 0) { - console.debug("No remote engine ready, can't create thread") - return - } - // find the model from hub that using the engine - if (!modelHubData) return - const remoteEngineReadyNames = remoteEngineReadyList.map((e) => e.name) + 
window.core?.api + ?.getAppConfigurations() + ?.then((appConfig: AppConfiguration) => { + setJanDataFolderPath(appConfig.data_folder) + setQuickAskEnabled(appConfig.quick_ask) + }) + }, [setJanDataFolderPath, setQuickAskEnabled]) - console.log('remoteEngineReady:', remoteEngineReadyNames) - // loop through the modelHubData.modelCategories to find the model that using the engine - for (const [key, value] of modelHubData.modelCategories) { - if (remoteEngineReadyNames.includes(key) && value.length > 0) { - modelToBeUsed = value[0].model - if (modelToBeUsed) break - } - } - } + useEffect(() => { + async function getDefaultJanDataFolder() { + const homePath = await getUserHomePath() + const defaultJanDataFolder = await joinPath([homePath, 'jan']) - if (!modelToBeUsed) { - console.debug('No model to be used') - return - } - console.log( - 'Create new thread because user have no thread, model to be used:', - modelToBeUsed.model - ) - createModel(modelToBeUsed) - setSelectedModel(modelToBeUsed) - const assistant = assistants[0] - createThreadMutation.mutate({ - modelId: modelToBeUsed.model, - assistant: assistant, - }) + setJanDefaultDataFolder(defaultJanDataFolder) } - }, [ - assistants, - models, - isFetchingThread, - threads, - createThreadMutation, - allThreads, - selectedModel, - isAnyModelReady, - isAnyRemoteModelConfigured, - engineData, - modelHubData, - setSelectedModel, - createModel, - ]) + getDefaultJanDataFolder() + }, [setJanDefaultDataFolder]) - useLoadTheme() + useEffect(() => { + const janSettingScreen = SettingScreenList.filter( + (screen) => window.electronAPI || screen !== 'Extensions' + ) + setJanSettingScreen(janSettingScreen) + }, [setJanSettingScreen]) + + console.debug('Load Data...') - return null + return {children} } export default DataLoader diff --git a/web/containers/Providers/DeepLinkListener.tsx b/web/containers/Providers/DeepLinkListener.tsx index 3a628f1f32..d5941204f2 100644 --- a/web/containers/Providers/DeepLinkListener.tsx +++ b/web/containers/Providers/DeepLinkListener.tsx @@ -1,3 +1,5 @@ +import { Fragment, ReactNode } from 'react' + import { useSetAtom } from 'jotai' import { useDebouncedCallback } from 'use-debounce' @@ -11,8 +13,11 @@ import { importHuggingFaceModelStageAtom, importingHuggingFaceRepoDataAtom, } from '@/helpers/atoms/HuggingFace.atom' +type Props = { + children: ReactNode +} -const DeepLinkListener: React.FC = () => { +const DeepLinkListener: React.FC = ({ children }) => { const { getHfRepoData } = useGetHFRepoData() const setLoadingInfo = useSetAtom(loadingModalInfoAtom) const setImportingHuggingFaceRepoData = useSetAtom( @@ -64,7 +69,7 @@ const DeepLinkListener: React.FC = () => { handleDeepLinkAction(action) }) - return null + return {children} } type DeepLinkAction = { diff --git a/web/containers/Providers/DownloadEventListener.tsx b/web/containers/Providers/DownloadEventListener.tsx deleted file mode 100644 index bb60fa41dc..0000000000 --- a/web/containers/Providers/DownloadEventListener.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import { useCallback, useEffect, useRef } from 'react' - -import { DownloadState2 } from '@janhq/core' -import { fetchEventSource } from '@microsoft/fetch-event-source' -import { useQueryClient } from '@tanstack/react-query' -import { useAtomValue, useSetAtom } from 'jotai' - -import { downloadStateListAtom } from '@/hooks/useDownloadState' - -import { modelQueryKey } from '@/hooks/useModelQuery' - -import { waitingForCortexAtom } from '@/helpers/atoms/App.atom' -import { hostAtom } from 
'@/helpers/atoms/AppConfig.atom' -import { - setImportingModelSuccessAtom, - updateImportingModelProgressAtom, -} from '@/helpers/atoms/Model.atom' - -const DownloadEventListener: React.FC = () => { - const host = useAtomValue(hostAtom) - const isRegistered = useRef(false) - const abortController = useRef(new AbortController()) - const setDownloadStateList = useSetAtom(downloadStateListAtom) - const setWaitingForCortex = useSetAtom(waitingForCortexAtom) - - const updateImportingModelProgress = useSetAtom( - updateImportingModelProgressAtom - ) - const setImportingModelSuccess = useSetAtom(setImportingModelSuccessAtom) - const queryClient = useQueryClient() - - const handleLocalImportModels = useCallback( - (events: DownloadState2[]) => { - if (events.length === 0) return - for (const event of events) { - if (event.progress === 100) { - setImportingModelSuccess(event.id) - } else { - updateImportingModelProgress(event.id, event.progress) - } - } - - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - }, - [setImportingModelSuccess, updateImportingModelProgress, queryClient] - ) - - const subscribeDownloadEvent = useCallback(async () => { - if (isRegistered.current) return - await fetchEventSource(`${host}/system/events/download`, { - onmessage(ev) { - if (!ev.data || ev.data === '') return - try { - const downloadEvents = JSON.parse(ev.data) as DownloadState2[] - const remoteDownloadEvents: DownloadState2[] = [] - const localImportEvents: DownloadState2[] = [] - // filter out the import local events - for (const event of downloadEvents) { - if ( - isAbsolutePath(event.id) && - event.type === 'model' && - event.children.length === 0 - ) { - localImportEvents.push(event) - } else { - remoteDownloadEvents.push(event) - } - } - handleLocalImportModels(localImportEvents) - setDownloadStateList(remoteDownloadEvents) - } catch (err) { - console.error(err) - } - }, - onerror(err) { - if (err.message === 'Failed to fetch') { - setWaitingForCortex(true) - } - }, - async onopen() { - setWaitingForCortex(false) - }, - signal: abortController.current.signal, - }) - console.log('Download event subscribed') - isRegistered.current = true - }, [host, setDownloadStateList, setWaitingForCortex, handleLocalImportModels]) - - const unsubscribeDownloadEvent = useCallback(() => { - if (!isRegistered.current) return - - abortController.current.abort() - isRegistered.current = false - console.log('Download event unsubscribed') - }, []) - - useEffect(() => { - subscribeDownloadEvent() - return () => { - unsubscribeDownloadEvent() - } - }, [subscribeDownloadEvent, unsubscribeDownloadEvent]) - - return null -} - -const isAbsolutePath = (path: string): boolean => { - // Trim any leading or trailing whitespace - const trimmedPath = path.trim() - - // Check for Unix-like absolute path - if (trimmedPath.startsWith('/')) { - return true - } - - // Check for Windows absolute path (with drive letter) - if (/^[A-Za-z]:[/\\]/.test(trimmedPath)) { - return true - } - - // All other paths are not considered absolute local paths - return false -} - -export default DownloadEventListener diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx new file mode 100644 index 0000000000..e4c96aeb70 --- /dev/null +++ b/web/containers/Providers/EventHandler.tsx @@ -0,0 +1,299 @@ +import { Fragment, ReactNode, useCallback, useEffect, useRef } from 'react' + +import { + ChatCompletionMessage, + ChatCompletionRole, + events, + ThreadMessage, + ExtensionTypeEnum, + MessageStatus, + 
MessageRequest, + ConversationalExtension, + MessageEvent, + MessageRequestType, + ModelEvent, + Thread, + EngineManager, +} from '@janhq/core' +import { useAtomValue, useSetAtom } from 'jotai' +import { ulid } from 'ulidx' + +import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' + +import { toRuntimeParams } from '@/utils/modelParam' + +import { extensionManager } from '@/extension' +import { + getCurrentChatMessagesAtom, + addNewMessageAtom, + updateMessageAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { + updateThreadWaitingForResponseAtom, + threadsAtom, + isGeneratingResponseAtom, + updateThreadAtom, + getActiveThreadModelParamsAtom, +} from '@/helpers/atoms/Thread.atom' + +const maxWordForThreadTitle = 10 +const defaultThreadTitle = 'New Thread' + +export default function EventHandler({ children }: { children: ReactNode }) { + const messages = useAtomValue(getCurrentChatMessagesAtom) + const addNewMessage = useSetAtom(addNewMessageAtom) + const updateMessage = useSetAtom(updateMessageAtom) + const downloadedModels = useAtomValue(downloadedModelsAtom) + const activeModel = useAtomValue(activeModelAtom) + const setActiveModel = useSetAtom(activeModelAtom) + const setStateModel = useSetAtom(stateModelAtom) + + const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) + const threads = useAtomValue(threadsAtom) + const modelsRef = useRef(downloadedModels) + const threadsRef = useRef(threads) + const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) + const updateThread = useSetAtom(updateThreadAtom) + const messagesRef = useRef(messages) + const activeModelRef = useRef(activeModel) + const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) + const activeModelParamsRef = useRef(activeModelParams) + + useEffect(() => { + threadsRef.current = threads + }, [threads]) + + useEffect(() => { + modelsRef.current = downloadedModels + }, [downloadedModels]) + + useEffect(() => { + messagesRef.current = messages + }, [messages]) + + useEffect(() => { + activeModelRef.current = activeModel + }, [activeModel]) + + useEffect(() => { + activeModelParamsRef.current = activeModelParams + }, [activeModelParams]) + + const onNewMessageResponse = useCallback( + (message: ThreadMessage) => { + if (message.type === MessageRequestType.Thread) { + addNewMessage(message) + } + }, + [addNewMessage] + ) + + const onModelStopped = useCallback(() => { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: undefined }) + }, [setActiveModel, setStateModel]) + + const updateThreadTitle = useCallback( + (message: ThreadMessage) => { + // Update only when it's finished + if (message.status !== MessageStatus.Ready) { + return + } + + const thread = threadsRef.current?.find((e) => e.id == message.thread_id) + if (!thread) { + console.warn( + `Failed to update title for thread ${message.thread_id}: Thread not found!` + ) + return + } + + const messageContent = message.content[0]?.text?.value + if (!messageContent) { + console.warn( + `Failed to update title for thread ${message.thread_id}: Responded content is null!` + ) + return + } + + // The thread title should not be updated if the message is less than 10 words + // And no new line character is present + // And non-alphanumeric characters should be removed + if (messageContent.includes('\n')) { + console.warn( + `Failed to update title for thread ${message.thread_id}: Title can't contain new line 
character!` + ) + return + } + + // Remove non-alphanumeric characters + const cleanedMessageContent = messageContent + .replace(/[^a-z0-9\s]/gi, '') + .trim() + + // Split the message into words + const words = cleanedMessageContent.split(' ') + + if (words.length >= maxWordForThreadTitle) { + console.warn( + `Failed to update title for thread ${message.thread_id}: Title can't be greater than ${maxWordForThreadTitle} words!` + ) + return + } + + const updatedThread: Thread = { + ...thread, + + title: cleanedMessageContent, + metadata: thread.metadata, + } + + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.saveThread({ + ...updatedThread, + }) + .then(() => { + // Update the Thread title with the response of the inference on the 1st prompt + updateThread({ + ...updatedThread, + }) + }) + }, + [updateThread] + ) + + const updateThreadMessage = useCallback( + (message: ThreadMessage) => { + updateMessage( + message.id, + message.thread_id, + message.content, + message.status + ) + if (message.status === MessageStatus.Pending) { + if (message.content.length) { + setIsGeneratingResponse(false) + } + return + } + // Mark the thread as not waiting for response + updateThreadWaiting(message.thread_id, false) + + setIsGeneratingResponse(false) + + const thread = threadsRef.current?.find((e) => e.id == message.thread_id) + if (!thread) return + const messageContent = message.content[0]?.text?.value + const metadata = { + ...thread.metadata, + ...(messageContent && { lastMessage: messageContent }), + } + + updateThread({ + ...thread, + metadata, + }) + + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.saveThread({ + ...thread, + metadata, + }) + + // If this is not the summary of the Thread, don't need to add it to the Thread + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.addNewMessage(message) + + // Attempt to generate the title of the Thread when needed + generateThreadTitle(message, thread) + }, + [setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting] + ) + + const onMessageResponseUpdate = useCallback( + (message: ThreadMessage) => { + switch (message.type) { + case MessageRequestType.Summary: + updateThreadTitle(message) + break + default: + updateThreadMessage(message) + break + } + }, + [updateThreadMessage, updateThreadTitle] + ) + + const generateThreadTitle = (message: ThreadMessage, thread: Thread) => { + // If this is the first ever prompt in the thread + if (thread.title?.trim() !== defaultThreadTitle) { + return + } + + if (!activeModelRef.current) { + return + } + + // This is the first time message comes in on a new thread + // Summarize the first message, and make that the title of the Thread + // 1. Get the summary of the first prompt using whatever engine user is currently using + const threadMessages = messagesRef?.current + + if (!threadMessages || threadMessages.length === 0) return + + const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. 
"${threadMessages[0].content[0].text.value}"` + + // Prompt: Given this query from user {query}, return to me the summary in 10 words as the title + const msgId = ulid() + const messages: ChatCompletionMessage[] = [ + { + role: ChatCompletionRole.User, + content: summarizeFirstPrompt, + }, + ] + + const runtimeParams = toRuntimeParams(activeModelParamsRef.current) + + const messageRequest: MessageRequest = { + id: msgId, + threadId: message.thread_id, + type: MessageRequestType.Summary, + messages, + model: { + ...activeModelRef.current, + parameters: { + ...runtimeParams, + stream: false, + }, + }, + } + + // 2. Update the title with the result of the inference + setTimeout(() => { + const engine = EngineManager.instance().get( + messageRequest.model?.engine ?? activeModelRef.current?.engine ?? '' + ) + engine?.inference(messageRequest) + }, 1000) + } + + useEffect(() => { + if (window.core?.events) { + events.on(MessageEvent.OnMessageResponse, onNewMessageResponse) + events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) + events.on(ModelEvent.OnModelStopped, onModelStopped) + } + + return () => { + events.off(MessageEvent.OnMessageResponse, onNewMessageResponse) + events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) + events.off(ModelEvent.OnModelStopped, onModelStopped) + } + }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped]) + + return {children} +} diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 71dfccd38c..b35ab2e439 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -1,24 +1,119 @@ -import { Fragment } from 'react' +import { PropsWithChildren, useCallback, useEffect } from 'react' import React from 'react' +import { DownloadEvent, events, DownloadState, ModelEvent } from '@janhq/core' +import { useSetAtom } from 'jotai' + +import { setDownloadStateAtom } from '@/hooks/useDownloadState' + +import { formatExtensionsName } from '@/utils/converter' + +import { toaster } from '../Toast' + import AppUpdateListener from './AppUpdateListener' import ClipboardListener from './ClipboardListener' -import DeepLinkListener from './DeepLinkListener' -import DownloadEventListener from './DownloadEventListener' - -import KeyListener from './KeyListener' -import ModelEventListener from './ModelEventListener' - -const EventListenerWrapper: React.FC = () => ( - - - - - - - - -) +import EventHandler from './EventHandler' + +import ModelImportListener from './ModelImportListener' +import QuickAskListener from './QuickAskListener' + +import { + InstallingExtensionState, + removeInstallingExtensionAtom, + setInstallingExtensionAtom, +} from '@/helpers/atoms/Extension.atom' + +const EventListenerWrapper = ({ children }: PropsWithChildren) => { + const setDownloadState = useSetAtom(setDownloadStateAtom) + const setInstallingExtension = useSetAtom(setInstallingExtensionAtom) + const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom) + + const onFileDownloadUpdate = useCallback( + async (state: DownloadState) => { + console.debug('onFileDownloadUpdate', state) + if (state.downloadType === 'extension') { + const installingExtensionState: InstallingExtensionState = { + extensionId: state.extensionId!, + percentage: state.percent, + localPath: state.localPath, + } + setInstallingExtension(state.extensionId!, installingExtensionState) + } else { + setDownloadState(state) + } + }, + [setDownloadState, setInstallingExtension] + ) + + const 
onFileDownloadError = useCallback( + (state: DownloadState) => { + console.debug('onFileDownloadError', state) + if (state.downloadType === 'extension') { + removeInstallingExtension(state.extensionId!) + } else { + setDownloadState(state) + } + }, + [setDownloadState, removeInstallingExtension] + ) + + const onFileDownloadSuccess = useCallback( + (state: DownloadState) => { + console.debug('onFileDownloadSuccess', state) + if (state.downloadType !== 'extension') { + setDownloadState(state) + } + events.emit(ModelEvent.OnModelsUpdate, {}) + }, + [setDownloadState] + ) + + const onFileUnzipSuccess = useCallback( + (state: DownloadState) => { + console.debug('onFileUnzipSuccess', state) + toaster({ + title: 'Success', + description: `Install ${formatExtensionsName(state.extensionId!)} successfully.`, + type: 'success', + }) + removeInstallingExtension(state.extensionId!) + }, + [removeInstallingExtension] + ) + + useEffect(() => { + console.debug('EventListenerWrapper: registering event listeners...') + events.on(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate) + events.on(DownloadEvent.onFileDownloadError, onFileDownloadError) + events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + events.on(DownloadEvent.onFileUnzipSuccess, onFileUnzipSuccess) + + return () => { + console.debug('EventListenerWrapper: unregistering event listeners...') + events.off(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate) + events.off(DownloadEvent.onFileDownloadError, onFileDownloadError) + events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + events.off(DownloadEvent.onFileUnzipSuccess, onFileUnzipSuccess) + } + }, [ + onFileDownloadUpdate, + onFileDownloadError, + onFileDownloadSuccess, + onFileUnzipSuccess, + ]) + + return ( + + + + + {children} + + + + + ) +} export default EventListenerWrapper diff --git a/web/containers/Providers/KeyListener.tsx b/web/containers/Providers/KeyListener.tsx index 1e012501ca..2731846df2 100644 --- a/web/containers/Providers/KeyListener.tsx +++ b/web/containers/Providers/KeyListener.tsx @@ -1,74 +1,30 @@ 'use client' -import { useCallback, useEffect } from 'react' +import { Fragment, ReactNode, useEffect } from 'react' import { useAtomValue, useSetAtom } from 'jotai' -import useAssistantQuery from '@/hooks/useAssistantQuery' -import useThreads from '@/hooks/useThreads' +import { MainViewState } from '@/constants/screens' -import { copyOverInstructionEnabledAtom } from '@/screens/Thread/ThreadRightPanel/AssistantSettingContainer/components/CopyOverInstruction' - -import { toaster } from '../Toast' +import { useCreateNewThread } from '@/hooks/useCreateNewThread' import { - MainViewState, mainViewStateAtom, showLeftPanelAtom, showRightPanelAtom, } from '@/helpers/atoms/App.atom' -import { getSelectedModelAtom } from '@/helpers/atoms/Model.atom' +import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' +type Props = { + children: ReactNode +} -const KeyListener: React.FC = () => { +export default function KeyListener({ children }: Props) { const setShowLeftPanel = useSetAtom(showLeftPanelAtom) const setShowRightPanel = useSetAtom(showRightPanelAtom) const setMainViewState = useSetAtom(mainViewStateAtom) - const { createThread } = useThreads() - - const activeThread = useAtomValue(activeThreadAtom) - const copyOverInstructionEnabled = useAtomValue( - copyOverInstructionEnabledAtom - ) - const { data: assistants } = useAssistantQuery() - - const selectedModel = 
useAtomValue(getSelectedModelAtom) - - const createNewThread = useCallback(() => { - if (!selectedModel) { - toaster({ - title: 'No model selected.', - description: 'Please select a model to create a new thread.', - type: 'error', - }) - return - } - - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - - if (!selectedModel) return - let instructions: string | undefined = undefined - if (copyOverInstructionEnabled) { - instructions = activeThread?.assistants[0]?.instructions ?? undefined - } - createThread(selectedModel.model, assistants[0], instructions) - setMainViewState(MainViewState.Thread) - }, [ - createThread, - setMainViewState, - selectedModel, - assistants, - activeThread, - copyOverInstructionEnabled, - ]) + const { requestCreateNewThread } = useCreateNewThread() + const assistants = useAtomValue(assistantsAtom) useEffect(() => { const onKeyDown = (e: KeyboardEvent) => { @@ -80,7 +36,9 @@ const KeyListener: React.FC = () => { } if (e.key === 'n' && prefixKey) { - return createNewThread() + requestCreateNewThread(assistants[0]) + setMainViewState(MainViewState.Thread) + return } if (e.key === 'b' && prefixKey) { @@ -97,15 +55,11 @@ const KeyListener: React.FC = () => { return () => document.removeEventListener('keydown', onKeyDown) }, [ assistants, - setShowRightPanel, - selectedModel, - createThread, + requestCreateNewThread, setMainViewState, setShowLeftPanel, - createNewThread, + setShowRightPanel, ]) - return null + return {children} } - -export default KeyListener diff --git a/web/containers/Providers/ModalMigrations.tsx b/web/containers/Providers/ModalMigrations.tsx deleted file mode 100644 index 8f1746674d..0000000000 --- a/web/containers/Providers/ModalMigrations.tsx +++ /dev/null @@ -1,234 +0,0 @@ -import React, { Fragment, useCallback, useMemo, useState } from 'react' - -import { Button, Modal, Badge } from '@janhq/joi' - -import { useQueryClient } from '@tanstack/react-query' -import { atom, useAtom, useSetAtom } from 'jotai' -import { AlertTriangleIcon } from 'lucide-react' - -import { twMerge } from 'tailwind-merge' - -import Spinner from '@/containers/Loader/Spinner' - -import useMigratingData from '@/hooks/useMigratingData' - -import { modelQueryKey } from '@/hooks/useModelQuery' - -import { didShowMigrationWarningAtom } from '@/helpers/atoms/AppConfig.atom' - -export const showMigrationModalAtom = atom(false) - -const MigrationStates = ['idle', 'in_progress', 'failed', 'success'] as const -type MigrationState = (typeof MigrationStates)[number] - -const ModalMigrations = () => { - const setDidShowMigrationModal = useSetAtom(didShowMigrationWarningAtom) - const [showMigrationModal, setShowMigrationModal] = useAtom( - showMigrationModalAtom - ) - const [step, setStep] = React.useState(1) - const { migrateModels, migrateThreadsAndMessages } = useMigratingData() - const [threadAndMessageMigrationState, setThreadAndMessageMigrationState] = - useState('idle') - const [modelMigrationState, setModelMigrationState] = - useState('idle') - const queryClient = useQueryClient() - - const getStepTitle = () => { - switch (step) { - case 1: - return 'Important Update: Data Migration Needed' - - default: - return threadAndMessageMigrationState === 'in_progress' || - modelMigrationState === 'in_progress' - ? 
'Migrating' - : 'Migration Completed' - } - } - - const migrationThreadsAndMessages = useCallback(async () => { - setThreadAndMessageMigrationState('in_progress') - try { - await migrateThreadsAndMessages() - setThreadAndMessageMigrationState('success') - console.debug('Migrating threads and messages successfully!') - } catch (err) { - console.error('Migrating threads and messages error', err) - setThreadAndMessageMigrationState('failed') - } - }, [setThreadAndMessageMigrationState, migrateThreadsAndMessages]) - - const migratingModels = useCallback(async () => { - setModelMigrationState('in_progress') - try { - await migrateModels() - setModelMigrationState('success') - console.debug('Migrating models successfully!') - } catch (err) { - console.error('Migrating models error', err) - setModelMigrationState('failed') - } - }, [migrateModels, setModelMigrationState]) - - const onStartMigrationClick = useCallback(async () => { - setStep(2) - await migratingModels() - await migrationThreadsAndMessages() - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - }, [migratingModels, migrationThreadsAndMessages, queryClient]) - - const onDismiss = useCallback(() => { - setStep(1) - setShowMigrationModal(false) - setDidShowMigrationModal(true) - }, [setDidShowMigrationModal, setShowMigrationModal]) - - const disableDismissButton = useMemo( - () => - threadAndMessageMigrationState === 'in_progress' || - modelMigrationState === 'in_progress', - [threadAndMessageMigrationState, modelMigrationState] - ) - - return ( - - {step === 1 && ( - -

- {`We've made some exciting improvements to the app, but we need your - help to update your data.`} -

-
-
- -

What to expect:

-
- -
-
    -
  • - - Some threads or models{' '} - might be missing - after the migration. - -
  • -
  • - - This will take a few seconds and reload the app. - -
  • -
-
-
- -
- - -
-
- )} - {step === 2 && ( - -
-
- {threadAndMessageMigrationState !== 'in_progress' && ( - <> - {threadAndMessageMigrationState === 'success' ? ( - Success - ) : ( - Failed - )} - - )} -

Threads

-
- {threadAndMessageMigrationState === 'in_progress' ? ( - - ) : ( - threadAndMessageMigrationState !== 'success' && ( - - ) - )} -
-
-
- {modelMigrationState !== 'in_progress' && ( - <> - {modelMigrationState === 'success' ? ( - Success - ) : ( - Failed - )} - - )} -

Models

-
- {modelMigrationState === 'in_progress' ? ( - - ) : ( - modelMigrationState === 'failed' && ( - - ) - )} -
-
- -
-
- )} - - } - /> - ) -} - -export default ModalMigrations diff --git a/web/containers/Providers/ModelEventListener.tsx b/web/containers/Providers/ModelEventListener.tsx deleted file mode 100644 index b872898d16..0000000000 --- a/web/containers/Providers/ModelEventListener.tsx +++ /dev/null @@ -1,134 +0,0 @@ -import { useCallback, useEffect, useRef } from 'react' - -import { - EmptyModelEvent, - ModelEvent, - ModelStatus, - StatusAndEvent, -} from '@janhq/core' -import { fetchEventSource } from '@microsoft/fetch-event-source' -import { useQueryClient } from '@tanstack/react-query' -import { useAtomValue, useSetAtom } from 'jotai' - -import { removeDownloadSuccessItemAtom } from '@/hooks/useDownloadState' -import { modelQueryKey } from '@/hooks/useModelQuery' - -import { toaster } from '../Toast' - -import { hostAtom } from '@/helpers/atoms/AppConfig.atom' -import { activeModelsAtom } from '@/helpers/atoms/Model.atom' -import { isLoadingModelAtom } from '@/helpers/atoms/Thread.atom' - -function ModelEventListener() { - const setActiveModels = useSetAtom(activeModelsAtom) - const host = useAtomValue(hostAtom) - const abortController = useRef(null) - const removeDownloadSuccessItem = useSetAtom(removeDownloadSuccessItemAtom) - const setIsLoadingModel = useSetAtom(isLoadingModelAtom) - - const queryClient = useQueryClient() - - const handleModelEvent = useCallback( - (modelEvent: ModelEvent) => { - console.log('Model event:', modelEvent.event) - switch (modelEvent.event) { - case 'starting': - setIsLoadingModel(true) - break - - case 'started': - setIsLoadingModel(false) - toaster({ - title: 'Success!', - description: `Model ${modelEvent.model} has been started.`, - type: 'success', - }) - break - - case 'starting-failed': - setIsLoadingModel(false) - toaster({ - title: 'Failed!', - description: `Model ${modelEvent.model} failed to start.`, - type: 'error', - }) - break - - case 'stopped': - setIsLoadingModel(false) - toaster({ - title: 'Success!', - description: `Model ${modelEvent.model} has been stopped.`, - type: 'success', - }) - break - - case 'model-downloaded': - removeDownloadSuccessItem(modelEvent.model) - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - break - - case 'model-deleted': - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - break - - case 'stopping-failed': - setIsLoadingModel(false) - toaster({ - title: 'Failed!', - description: `Model ${modelEvent.model} failed to stop.`, - type: 'error', - }) - break - - default: - break - } - }, - [removeDownloadSuccessItem, setIsLoadingModel, queryClient] - ) - - const subscribeModelEvent = useCallback(async () => { - if (abortController.current) return - abortController.current = new AbortController() - - await fetchEventSource(`${host}/system/events/model`, { - onmessage(ev) { - if (!ev.data || ev.data === '') return - try { - const modelEvent = JSON.parse(ev.data) as StatusAndEvent - - const runningModels: ModelStatus[] = [] - Object.values(modelEvent.status).forEach((value) => { - runningModels.push(value) - }) - setActiveModels(runningModels) - - if (modelEvent.event === EmptyModelEvent) return - handleModelEvent(modelEvent.event as ModelEvent) - } catch (err) { - console.error(err) - } - }, - signal: abortController.current.signal, - }) - }, [host, setActiveModels, handleModelEvent]) - - const unsubscribeModelEvent = useCallback(() => { - if (!abortController.current) return - - abortController.current.abort() - abortController.current = null - }, []) - - useEffect(() => { - subscribeModelEvent() - return 
() => { - unsubscribeModelEvent() - } - }, [subscribeModelEvent, unsubscribeModelEvent]) - - return null -} - -export default ModelEventListener diff --git a/web/containers/Providers/ModelImportListener.tsx b/web/containers/Providers/ModelImportListener.tsx new file mode 100644 index 0000000000..f1ca2a7688 --- /dev/null +++ b/web/containers/Providers/ModelImportListener.tsx @@ -0,0 +1,109 @@ +import { Fragment, PropsWithChildren, useCallback, useEffect } from 'react' + +import { + ImportingModel, + LocalImportModelEvent, + Model, + ModelEvent, + events, +} from '@janhq/core' +import { useSetAtom } from 'jotai' + +import { snackbar } from '../Toast' + +import { + setImportingModelErrorAtom, + setImportingModelSuccessAtom, + updateImportingModelProgressAtom, +} from '@/helpers/atoms/Model.atom' + +const ModelImportListener = ({ children }: PropsWithChildren) => { + const updateImportingModelProgress = useSetAtom( + updateImportingModelProgressAtom + ) + const setImportingModelSuccess = useSetAtom(setImportingModelSuccessAtom) + const setImportingModelFailed = useSetAtom(setImportingModelErrorAtom) + + const onImportModelUpdate = useCallback( + async (state: ImportingModel) => { + if (!state.importId) return + updateImportingModelProgress(state.importId, state.percentage ?? 0) + }, + [updateImportingModelProgress] + ) + + const onImportModelFailed = useCallback( + async (state: ImportingModel) => { + if (!state.importId) return + setImportingModelFailed(state.importId, state.error ?? '') + }, + [setImportingModelFailed] + ) + + const onImportModelSuccess = useCallback( + (state: ImportingModel) => { + if (!state.modelId) return + events.emit(ModelEvent.OnModelsUpdate, {}) + setImportingModelSuccess(state.importId, state.modelId) + }, + [setImportingModelSuccess] + ) + + const onImportModelFinished = useCallback((importedModels: Model[]) => { + const modelText = importedModels.length === 1 ? 
'model' : 'models' + snackbar({ + description: `Successfully imported ${importedModels.length} ${modelText}`, + type: 'success', + }) + }, []) + + useEffect(() => { + console.debug('ModelImportListener: registering event listeners..') + + events.on( + LocalImportModelEvent.onLocalImportModelUpdate, + onImportModelUpdate + ) + events.on( + LocalImportModelEvent.onLocalImportModelSuccess, + onImportModelSuccess + ) + events.on( + LocalImportModelEvent.onLocalImportModelFinished, + onImportModelFinished + ) + events.on( + LocalImportModelEvent.onLocalImportModelFailed, + onImportModelFailed + ) + + return () => { + console.debug('ModelImportListener: unregistering event listeners...') + events.off( + LocalImportModelEvent.onLocalImportModelUpdate, + onImportModelUpdate + ) + events.off( + LocalImportModelEvent.onLocalImportModelSuccess, + onImportModelSuccess + ) + events.off( + LocalImportModelEvent.onLocalImportModelFinished, + onImportModelFinished + ) + events.off( + LocalImportModelEvent.onLocalImportModelFailed, + onImportModelFailed + ) + } + }, [ + onImportModelUpdate, + onImportModelSuccess, + onImportModelFinished, + onImportModelFailed, + ]) + + return {children} +} + +export default ModelImportListener diff --git a/web/containers/Providers/QuickAskListener.tsx b/web/containers/Providers/QuickAskListener.tsx index f1bba3e382..415fc19a63 100644 --- a/web/containers/Providers/QuickAskListener.tsx +++ b/web/containers/Providers/QuickAskListener.tsx @@ -1,25 +1,33 @@ +import { Fragment, ReactNode } from 'react' + import { useSetAtom } from 'jotai' import { useDebouncedCallback } from 'use-debounce' -import useSendMessage from '@/hooks/useSendMessage' +import { MainViewState } from '@/constants/screens' + +import useSendChatMessage from '@/hooks/useSendChatMessage' -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { mainViewStateAtom } from '@/helpers/atoms/App.atom' + +type Props = { + children: ReactNode +} -const QuickAskListener: React.FC = () => { - const { sendMessage } = useSendMessage() +const QuickAskListener: React.FC = ({ children }) => { + const { sendChatMessage } = useSendChatMessage() const setMainState = useSetAtom(mainViewStateAtom) const debounced = useDebouncedCallback((value) => { setMainState(MainViewState.Thread) - sendMessage(value) + sendChatMessage(value) }, 300) window.electronAPI?.onUserSubmitQuickAsk((_event: string, input: string) => { debounced(input) }) - return null + return {children} } export default QuickAskListener diff --git a/web/containers/Providers/index.tsx b/web/containers/Providers/index.tsx index f5ed40929c..4731c600b8 100644 --- a/web/containers/Providers/index.tsx +++ b/web/containers/Providers/index.tsx @@ -1,47 +1,93 @@ 'use client' -import { Fragment, PropsWithChildren, useEffect, useState } from 'react' +import { PropsWithChildren, useCallback, useEffect, useState } from 'react' import { Toaster } from 'react-hot-toast' -import { QueryClient, QueryClientProvider } from '@tanstack/react-query' - +import Loader from '@/containers/Loader' import EventListenerWrapper from '@/containers/Providers/EventListener' import JotaiWrapper from '@/containers/Providers/Jotai' import ThemeWrapper from '@/containers/Providers/Theme' import { setupCoreServices } from '@/services/coreService' +import { + isCoreExtensionInstalled, + setupBaseExtensions, +} from '@/services/extensionService' + +import Umami from '@/utils/umami' import DataLoader from './DataLoader' +import DeepLinkListener from './DeepLinkListener' +import 
KeyListener from './KeyListener' import Responsive from './Responsive' -const queryClient = new QueryClient() +import { extensionManager } from '@/extension' const Providers = ({ children }: PropsWithChildren) => { const [setupCore, setSetupCore] = useState(false) + const [activated, setActivated] = useState(false) + const [settingUp, setSettingUp] = useState(false) + + const setupExtensions = useCallback(async () => { + // Register all active extensions + await extensionManager.registerActive() + + setTimeout(async () => { + if (!isCoreExtensionInstalled()) { + setSettingUp(true) + await setupBaseExtensions() + return + } + + extensionManager.load() + setSettingUp(false) + setActivated(true) + }, 500) + }, []) // Services Setup useEffect(() => { setupCoreServices() setSetupCore(true) + return () => { + extensionManager.unload() + } }, []) + useEffect(() => { + if (setupCore) { + // Electron + if (window && window.core?.api) { + setupExtensions() + } else { + // Host + setActivated(true) + } + } + }, [setupCore, setupExtensions]) + return ( - - {/* */} - {setupCore && ( - - - - {children} - - - )} - + + {settingUp && } + {setupCore && activated && ( + <> + + + + + {children} + + + + + + + )} ) diff --git a/web/containers/SetupRemoteModel/index.tsx b/web/containers/SetupRemoteModel/index.tsx index 5ca141d95b..ab71240af2 100644 --- a/web/containers/SetupRemoteModel/index.tsx +++ b/web/containers/SetupRemoteModel/index.tsx @@ -1,75 +1,84 @@ -import { LlmEngine } from '@janhq/core' +import { useState, useEffect } from 'react' + +import { InferenceEngine } from '@janhq/core' + import { Button } from '@janhq/joi' +import { useSetAtom } from 'jotai' import { SettingsIcon } from 'lucide-react' +import { MainViewState } from '@/constants/screens' + +import { extensionManager } from '@/extension' +import { mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' + type Props = { - engine: LlmEngine + engine: InferenceEngine } const SetupRemoteModel = ({ engine }: Props) => { - console.log('SetupRemoteModel', engine) - // const setSelectedSetting = useSetAtom(selectedSettingAtom) - // const setMainViewState = useSetAtom(mainViewStateAtom) + const setSelectedSetting = useSetAtom(selectedSettingAtom) + const setMainViewState = useSetAtom(mainViewStateAtom) - // const [extensionHasSettings, setExtensionHasSettings] = useState< - // { name?: string; setting: string; apiKey: string; provider: string }[] - // >([]) + const [extensionHasSettings, setExtensionHasSettings] = useState< + { name?: string; setting: string; apiKey: string; provider: string }[] + >([]) - // useEffect(() => { - // const getAllSettings = async () => { - // const extensionsMenu: { - // name?: string - // setting: string - // apiKey: string - // provider: string - // }[] = [] - // const extensions = extensionManager.getAll() + useEffect(() => { + const getAllSettings = async () => { + const extensionsMenu: { + name?: string + setting: string + apiKey: string + provider: string + }[] = [] + const extensions = extensionManager.getAll() - // for (const extension of extensions) { - // if (typeof extension.getSettings === 'function') { - // const settings = await extension.getSettings() + for (const extension of extensions) { + if (typeof extension.getSettings === 'function') { + const settings = await extension.getSettings() - // if ( - // (settings && settings.length > 0) || - // (await extension.installationState()) !== 'NotRequired' - // ) { - // extensionsMenu.push({ - // 
name: extension.productName, - // setting: extension.name, - // apiKey: - // 'apiKey' in extension && typeof extension.apiKey === 'string' - // ? extension.apiKey - // : '', - // provider: - // 'provider' in extension && - // typeof extension.provider === 'string' - // ? extension.provider - // : '', - // }) - // } - // } - // } - // setExtensionHasSettings(extensionsMenu) - // } - // getAllSettings() - // }, []) + if ( + (settings && settings.length > 0) || + (await extension.installationState()) !== 'NotRequired' + ) { + extensionsMenu.push({ + name: extension.productName, + setting: extension.name, + apiKey: + 'apiKey' in extension && typeof extension.apiKey === 'string' + ? extension.apiKey + : '', + provider: + 'provider' in extension && + typeof extension.provider === 'string' + ? extension.provider + : '', + }) + } + } + } + setExtensionHasSettings(extensionsMenu) + } + getAllSettings() + }, []) - // const onSetupItemClick = (engine: LlmEngine) => { - // setMainViewState(MainViewState.Settings) - // setSelectedSetting( - // extensionHasSettings.filter((x) => - // x.provider?.toLowerCase().includes(engine) - // )[0]?.setting - // ) - // } + const onSetupItemClick = (setting: InferenceEngine) => { + setMainViewState(MainViewState.Settings) + setSelectedSetting( + extensionHasSettings.filter((x) => + x.provider.toLowerCase().includes(setting) + )[0]?.setting + ) + } return ( - )} -
- } - /> - ) -} - -export default WaitingForCortexModal diff --git a/web/extension/Extension.ts b/web/extension/Extension.ts new file mode 100644 index 0000000000..9438238ca5 --- /dev/null +++ b/web/extension/Extension.ts @@ -0,0 +1,40 @@ +/** + * Extension manifest object. + */ +class Extension { + /** @type {string} Name of the extension. */ + name: string + + /** @type {string} Product name of the extension. */ + productName?: string + + /** @type {string} The URL of the extension to load. */ + url: string + + /** @type {boolean} Whether the extension is activated or not. */ + active + + /** @type {string} Extension's description. */ + description + + /** @type {string} Extension's version. */ + version + + constructor( + url: string, + name: string, + productName?: string, + active?: boolean, + description?: string, + version?: string + ) { + this.name = name + this.productName = productName + this.url = url + this.active = active + this.description = description + this.version = version + } +} + +export default Extension diff --git a/web/extension/ExtensionManager.ts b/web/extension/ExtensionManager.ts new file mode 100644 index 0000000000..aa1a7674b4 --- /dev/null +++ b/web/extension/ExtensionManager.ts @@ -0,0 +1,202 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core' + +import Extension from './Extension' + +/** + * Manages the registration and retrieval of extensions. + */ +export class ExtensionManager { + // Registered extensions + private extensions = new Map() + + // Registered inference engines + private engines = new Map() + + /** + * Registers an extension. + * @param extension - The extension to register. + */ + register(name: string, extension: T) { + // Register for naming use + this.extensions.set(name, extension) + + // Register AI Engines + if ('provider' in extension && typeof extension.provider === 'string') { + this.engines.set( + extension.provider as unknown as string, + extension as unknown as AIEngine + ) + } + } + + /** + * Retrieves a extension by its type. + * @param type - The type of the extension to retrieve. + * @returns The extension, if found. + */ + get(type: ExtensionTypeEnum): T | undefined { + return this.getAll().findLast((e) => e.type() === type) as T | undefined + } + + /** + * Retrieves a extension by its type. + * @param type - The type of the extension to retrieve. + * @returns The extension, if found. + */ + getByName(name: string): BaseExtension | undefined { + return this.extensions.get(name) as BaseExtension | undefined + } + + /** + * Retrieves a extension by its type. + * @param type - The type of the extension to retrieve. + * @returns The extension, if found. + */ + getAll(): BaseExtension[] { + return Array.from(this.extensions.values()) + } + + /** + * Retrieves a extension by its type. + * @param engine - The engine name to retrieve. + * @returns The extension, if found. + */ + getEngine(engine: string): T | undefined { + return this.engines.get(engine) as T | undefined + } + + /** + * Loads all registered extension. + */ + load() { + this.listExtensions().forEach((ext) => { + ext.onLoad() + }) + } + + /** + * Unloads all registered extensions. + */ + unload() { + this.listExtensions().forEach((ext) => { + ext.onUnload() + }) + } + + /** + * Retrieves a list of all registered extensions. + * @returns An array of extensions. + */ + listExtensions() { + return [...this.extensions.values()] + } + + /** + * Retrieves a list of all registered extensions. 
+ * @returns An array of extensions. + */ + async getActive(): Promise { + const res = await window.core?.api?.getActiveExtensions() + if (!res || !Array.isArray(res)) return [] + + const extensions: Extension[] = res.map( + (ext: any) => + new Extension( + ext.url, + ext.name, + ext.productName, + ext.active, + ext.description, + ext.version + ) + ) + return extensions + } + + /** + * Register a extension with its class. + * @param {Extension} extension extension object as provided by the main process. + * @returns {void} + */ + async activateExtension(extension: Extension) { + // Import class + const extensionUrl = window.electronAPI + ? extension.url + : extension.url.replace( + 'extension://', + `${window.core?.api?.baseApiUrl ?? ''}/extensions/` + ) + await import(/* webpackIgnore: true */ extensionUrl).then( + (extensionClass) => { + // Register class if it has a default export + if ( + typeof extensionClass.default === 'function' && + extensionClass.default.prototype + ) { + this.register( + extension.name, + new extensionClass.default( + extension.url, + extension.name, + extension.productName, + extension.active, + extension.description, + extension.version + ) + ) + } + } + ) + } + + /** + * Registers all active extensions. + * @returns {void} + */ + async registerActive() { + // Get active extensions + const activeExtensions = await this.getActive() + // Activate all + await Promise.all( + activeExtensions.map((ext: Extension) => this.activateExtension(ext)) + ) + } + + /** + * Install a new extension. + * @param {Array.} extensions A list of NPM specifiers, or installation configuration objects. + * @returns {Promise. | false>} extension as defined by the main process. Has property cancelled set to true if installation was cancelled in the main process. + */ + async install(extensions: any[]) { + if (typeof window === 'undefined') { + return + } + const res = await window.core?.api?.installExtension(extensions) + if (res.cancelled) return false + return res.map(async (ext: any) => { + const extension = new Extension(ext.name, ext.url, ext.active) + await this.activateExtension(extension) + return extension + }) + } + + /** + * Uninstall provided extensions + * @param {Array.} extensions List of names of extensions to uninstall. + * @param {boolean} reload Whether to reload all renderers after updating the extensions. + * @returns {Promise.} Whether uninstalling the extensions was successful. + */ + uninstall(extensions: string[], reload = true) { + if (typeof window === 'undefined') { + return + } + return window.core?.api?.uninstallExtension(extensions, reload) + } +} + +/** + * The singleton instance of the ExtensionManager. 
+ */ +export const extensionManager = new ExtensionManager() diff --git a/web/extension/index.ts b/web/extension/index.ts new file mode 100644 index 0000000000..e2fbb5ad5e --- /dev/null +++ b/web/extension/index.ts @@ -0,0 +1 @@ +export { extensionManager } from './ExtensionManager' diff --git a/web/helpers/atoms/ApiServer.atom.ts b/web/helpers/atoms/ApiServer.atom.ts new file mode 100644 index 0000000000..ce37ba4ed3 --- /dev/null +++ b/web/helpers/atoms/ApiServer.atom.ts @@ -0,0 +1,20 @@ +import { atomWithStorage } from 'jotai/utils' + +export const hostOptions = [ + { name: '127.0.0.1', value: '127.0.0.1' }, + { name: '0.0.0.0', value: '0.0.0.0' }, +] + +export const apiServerPortAtom = atomWithStorage('apiServerPort', '1337') +export const apiServerHostAtom = atomWithStorage('apiServerHost', '127.0.0.1') +export const apiServerPrefix = atomWithStorage('apiServerPrefix', '/v1') + +export const apiServerCorsEnabledAtom = atomWithStorage( + 'apiServerCorsEnabled', + true +) + +export const apiServerVerboseLogEnabledAtom = atomWithStorage( + 'apiServerVerboseLogEnabled', + true +) diff --git a/web/helpers/atoms/App.atom.ts b/web/helpers/atoms/App.atom.ts index 0668f47631..8770b4bcd8 100644 --- a/web/helpers/atoms/App.atom.ts +++ b/web/helpers/atoms/App.atom.ts @@ -1,18 +1,11 @@ import { atom } from 'jotai' -export enum MainViewState { - Hub, - Settings, - Thread, - LocalServer, -} +import { MainViewState } from '@/constants/screens' export const mainViewStateAtom = atom(MainViewState.Thread) export const defaultJanDataFolderAtom = atom('') -export const waitingForCortexAtom = atom(true) - // Store panel atom export const showLeftPanelAtom = atom(true) export const showRightPanelAtom = atom(true) diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts index 1212582004..f4acc7dc22 100644 --- a/web/helpers/atoms/AppConfig.atom.ts +++ b/web/helpers/atoms/AppConfig.atom.ts @@ -6,8 +6,8 @@ const PROXY_FEATURE_ENABLED = 'proxyFeatureEnabled' const VULKAN_ENABLED = 'vulkanEnabled' const IGNORE_SSL = 'ignoreSSLFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' -//const QUICK_ASK_ENABLED = 'quickAskEnabled' -const MIGRATION_WARNING = 'didShowMigrationWarning' +const QUICK_ASK_ENABLED = 'quickAskEnabled' + export const janDataFolderPathAtom = atom('') export const experimentalFeatureEnabledAtom = atomWithStorage( @@ -20,14 +20,6 @@ export const proxyAtom = atomWithStorage(HTTPS_PROXY_FEATURE, '') export const ignoreSslAtom = atomWithStorage(IGNORE_SSL, false) export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false) -export const quickAskEnabledAtom = atom(false) //atomWithStorage(QUICK_ASK_ENABLED, false) -export const didShowMigrationWarningAtom = atomWithStorage( - MIGRATION_WARNING, - false, - undefined, - { - getOnInit: true, - } -) +export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) -export const hostAtom = atom('http://127.0.0.1:1338/v1') +export const hostAtom = atom('http://localhost:1337/') diff --git a/web/helpers/atoms/Assistant.atom.ts b/web/helpers/atoms/Assistant.atom.ts new file mode 100644 index 0000000000..d44703cf41 --- /dev/null +++ b/web/helpers/atoms/Assistant.atom.ts @@ -0,0 +1,4 @@ +import { Assistant } from '@janhq/core' +import { atom } from 'jotai' + +export const assistantsAtom = atom([]) diff --git a/web/helpers/atoms/BottomPanel.atom.ts b/web/helpers/atoms/BottomPanel.atom.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/helpers/atoms/ChatMessage.atom.ts 
b/web/helpers/atoms/ChatMessage.atom.ts index 289d8d3e8c..4da22d13aa 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -1,51 +1,94 @@ -import { Message, MessageContent } from '@janhq/core' +import { + ChatCompletionRole, + MessageStatus, + ThreadContent, + ThreadMessage, +} from '@janhq/core' import { atom } from 'jotai' -import { getActiveThreadIdAtom } from './Thread.atom' +import { + getActiveThreadIdAtom, + updateThreadStateLastMessageAtom, +} from './Thread.atom' -const chatMessages = atom>({}) - -export const disableStopInferenceAtom = atom(false) +/** + * Stores all chat messages for all threads + */ +export const chatMessages = atom>({}) -export const chunkCountAtom = atom>({}) +export const readyThreadsMessagesAtom = atom>({}) /** - * Return the chat messages for the current active thread + * Return the chat messages for the current active conversation */ -export const getCurrentChatMessagesAtom = atom((get) => { +export const getCurrentChatMessagesAtom = atom((get) => { const activeThreadId = get(getActiveThreadIdAtom) if (!activeThreadId) return [] const messages = get(chatMessages)[activeThreadId] return messages ?? [] }) -// TODO: rename this function to add instead of set -export const setThreadMessagesAtom = atom( +export const setConvoMessagesAtom = atom( null, - (get, set, threadId: string, messages: Message[]) => { - const newData: Record = { + (get, set, threadId: string, messages: ThreadMessage[]) => { + const newData: Record = { ...get(chatMessages), } - newData[threadId] = [...(newData.messages ?? []), ...messages.reverse()] + newData[threadId] = messages set(chatMessages, newData) + set(readyThreadsMessagesAtom, { + ...get(readyThreadsMessagesAtom), + [threadId]: true, + }) } ) -export const addNewMessageAtom = atom(null, (get, set, newMessage: Message) => { - const currentMessages = get(chatMessages)[newMessage.thread_id] ?? [] - const updatedMessages = [...currentMessages, newMessage] +/** + * Used for pagination. Add old messages to the current conversation + */ +export const addOldMessagesAtom = atom( + null, + (get, set, newMessages: ThreadMessage[]) => { + const currentConvoId = get(getActiveThreadIdAtom) + if (!currentConvoId) return + + const currentMessages = get(chatMessages)[currentConvoId] ?? [] + const updatedMessages = [...currentMessages, ...newMessages] - const newData: Record = { - ...get(chatMessages), + const newData: Record = { + ...get(chatMessages), + } + newData[currentConvoId] = updatedMessages + set(chatMessages, newData) } - newData[newMessage.thread_id] = updatedMessages - set(chatMessages, newData) -}) +) + +export const addNewMessageAtom = atom( + null, + (get, set, newMessage: ThreadMessage) => { + const currentMessages = get(chatMessages)[newMessage.thread_id] ?? 
[] + const updatedMessages = [...currentMessages, newMessage] + + const newData: Record = { + ...get(chatMessages), + } + newData[newMessage.thread_id] = updatedMessages + set(chatMessages, newData) + + // Update thread last message + if (newMessage.content.length) + set( + updateThreadStateLastMessageAtom, + newMessage.thread_id, + newMessage.content + ) + } +) export const deleteChatMessageAtom = atom( null, (get, set, threadId: string) => { - const newData: Record = { + const newData: Record = { ...get(chatMessages), } newData[threadId] = [] @@ -54,22 +97,26 @@ export const deleteChatMessageAtom = atom( ) export const cleanChatMessageAtom = atom(null, (get, set, id: string) => { - const newData: Record = { + const newData: Record = { ...get(chatMessages), } - newData[id] = [] + newData[id] = newData[id]?.filter((e) => e.role === ChatCompletionRole.System) set(chatMessages, newData) }) export const deleteMessageAtom = atom(null, (get, set, id: string) => { - const newData: Record = { + const newData: Record = { ...get(chatMessages), } const threadId = get(getActiveThreadIdAtom) - if (!threadId) return + if (threadId) { + // Should also delete error messages to clear out the error state + newData[threadId] = newData[threadId].filter( + (e) => e.id !== id && e.status !== MessageStatus.Error + ) - newData[threadId] = newData[threadId].filter((e) => e.id !== id) - set(chatMessages, newData) + set(chatMessages, newData) + } }) export const editMessageAtom = atom('') @@ -80,20 +127,25 @@ export const updateMessageAtom = atom( get, set, id: string, - threadId: string, - text: MessageContent[], - status: 'in_progress' | 'completed' | 'incomplete' + conversationId: string, + text: ThreadContent[], + status: MessageStatus ) => { - const messages = get(chatMessages)[threadId] ?? [] + const messages = get(chatMessages)[conversationId] ?? 
[] const message = messages.find((e) => e.id === id) - if (!message) return - message.content = text - message.status = status - const updatedMessages = [...messages] - const newData: Record = { - ...get(chatMessages), + if (message) { + message.content = text + message.status = status + const updatedMessages = [...messages] + + const newData: Record = { + ...get(chatMessages), + } + newData[conversationId] = updatedMessages + set(chatMessages, newData) + // Update thread last message + if (text.length) + set(updateThreadStateLastMessageAtom, conversationId, text) } - newData[threadId] = updatedMessages - set(chatMessages, newData) } ) diff --git a/web/helpers/atoms/DownloadLocalModel.atom.ts b/web/helpers/atoms/DownloadLocalModel.atom.ts deleted file mode 100644 index bcde069f57..0000000000 --- a/web/helpers/atoms/DownloadLocalModel.atom.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { atom } from 'jotai' - -export type DownloadLocalModelStage = 'NONE' | 'MODEL_LIST' - -const downloadLocalModelStageAtom = atom('NONE') -const modelHubSelectedModelHandle = atom(undefined) - -export const localModelModalStageAtom = atom( - (get) => ({ - stage: get(downloadLocalModelStageAtom), - modelHandle: get(modelHubSelectedModelHandle), - }), - ( - _get, - set, - stage: DownloadLocalModelStage, - modelHandle: string | undefined - ) => { - set(downloadLocalModelStageAtom, stage) - set(modelHubSelectedModelHandle, modelHandle) - } -) diff --git a/web/helpers/atoms/Extension.atom.ts b/web/helpers/atoms/Extension.atom.ts index 7af755e351..28b8a6bc13 100644 --- a/web/helpers/atoms/Extension.atom.ts +++ b/web/helpers/atoms/Extension.atom.ts @@ -1,4 +1,5 @@ import { atom } from 'jotai' +import { atomWithStorage } from 'jotai/utils' type ExtensionId = string @@ -38,3 +39,9 @@ export const removeInstallingExtensionAtom = atom( set(installingExtensionAtom, newCurrent) } ) + +const INACTIVE_ENGINE_PROVIDER = 'inActiveEngineProvider' +export const inActiveEngineProviderAtom = atomWithStorage( + INACTIVE_ENGINE_PROVIDER, + [] +) diff --git a/web/helpers/atoms/Hub.atom.ts b/web/helpers/atoms/Hub.atom.ts deleted file mode 100644 index 87eb155423..0000000000 --- a/web/helpers/atoms/Hub.atom.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { atom } from 'jotai' - -import { ModelFilter } from '@/screens/HubScreen2' - -export const hubFilterAtom = atom('All') diff --git a/web/helpers/atoms/HuggingFace.atom.ts b/web/helpers/atoms/HuggingFace.atom.ts index 09f7870a38..514efb186c 100644 --- a/web/helpers/atoms/HuggingFace.atom.ts +++ b/web/helpers/atoms/HuggingFace.atom.ts @@ -1,4 +1,4 @@ -import { HuggingFaceRepoData } from '@janhq/core' +import { HuggingFaceRepoData } from '@janhq/core/.' import { atom } from 'jotai' // modals diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 9d11e26cda..7ad65a15e6 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -1,17 +1,35 @@ -import { - ImportingModel, - LlmEngine, - LocalEngines, - Model, - ModelStatus, -} from '@janhq/core' +import { ImportingModel, Model } from '@janhq/core' import { atom } from 'jotai' -import { activeThreadAtom, threadsAtom } from './Thread.atom' - export const stateModel = atom({ state: 'start', loading: false, model: '' }) export const activeAssistantModelAtom = atom(undefined) +/** + * Stores the list of models which are being downloaded. 
+ */ +const downloadingModelsAtom = atom([]) + +export const getDownloadingModelAtom = atom((get) => get(downloadingModelsAtom)) + +export const addDownloadingModelAtom = atom(null, (get, set, model: Model) => { + const downloadingModels = get(downloadingModelsAtom) + if (!downloadingModels.find((e) => e.id === model.id)) { + set(downloadingModelsAtom, [...downloadingModels, model]) + } +}) + +export const removeDownloadingModelAtom = atom( + null, + (get, set, modelId: string) => { + const downloadingModels = get(downloadingModelsAtom) + + set( + downloadingModelsAtom, + downloadingModels.filter((e) => e.id !== modelId) + ) + } +) + export const downloadedModelsAtom = atom([]) export const removeDownloadedModelAtom = atom( @@ -21,11 +39,15 @@ export const removeDownloadedModelAtom = atom( set( downloadedModelsAtom, - downloadedModels.filter((m) => m.model !== modelId) + downloadedModels.filter((e) => e.id !== modelId) ) } ) +export const configuredModelsAtom = atom([]) + +export const defaultModelAtom = atom(undefined) + /// TODO: move this part to another atom // store the paths of the models that are being imported export const importingModelsAtom = atom([]) @@ -67,14 +89,14 @@ export const setImportingModelErrorAtom = atom( export const setImportingModelSuccessAtom = atom( null, - (get, set, importId: string) => { + (get, set, importId: string, modelId: string) => { const model = get(importingModelsAtom).find((x) => x.importId === importId) if (!model) return const newModel: ImportingModel = { ...model, - modelId: undefined, + modelId, status: 'IMPORTED', - percentage: 100, + percentage: 1, } const newList = get(importingModelsAtom).map((x) => x.importId === importId ? newModel : x @@ -109,30 +131,4 @@ export const updateImportingModelAtom = atom( } ) -const selectedModelAtom = atom(undefined) - -export const getSelectedModelAtom = atom((get) => get(selectedModelAtom)) - -export const updateSelectedModelAtom = atom(null, (get, set, model: Model) => { - const activeThread = get(activeThreadAtom) - if (activeThread) { - activeThread.assistants[0].model = model.model - // update threadsAtom - const allThreads = get(threadsAtom) - allThreads.forEach((t) => { - if (t.id === activeThread.id) { - t.assistants[0].model = model.model - } - }) - console.debug( - `Update threads state list: ${JSON.stringify(allThreads, null, 2)}` - ) - set(threadsAtom, allThreads) - } - console.debug('Set selected model:', JSON.stringify(model, null, 2)) - set(selectedModelAtom, model) -}) - -export const activeModelsAtom = atom([]) - -export const showEngineListModelAtom = atom([...LocalEngines]) +export const selectedModelAtom = atom(undefined) diff --git a/web/helpers/atoms/Setting.atom.ts b/web/helpers/atoms/Setting.atom.ts index 75a70c698a..ced0fbe377 100644 --- a/web/helpers/atoms/Setting.atom.ts +++ b/web/helpers/atoms/Setting.atom.ts @@ -6,16 +6,17 @@ import { SettingScreen } from '@/screens/Settings' export const selectedSettingAtom = atom('My Models') +export const janSettingScreenAtom = atom([]) + export const THEME = 'themeAppearance' export const REDUCE_TRANSPARENT = 'reduceTransparent' export const SPELL_CHECKING = 'spellChecking' export const themesOptionsAtom = atom<{ name: string; value: string }[]>([]) -export const selectedThemeIdAtom = atomWithStorage(THEME, null) +export const janThemesPathAtom = atom(undefined) +export const selectedThemeIdAtom = atomWithStorage(THEME, '') export const themeDataAtom = atom(undefined) export const reduceTransparentAtom = atomWithStorage( REDUCE_TRANSPARENT, 
false ) export const spellCheckAtom = atomWithStorage(SPELL_CHECKING, true) - -export const showSidbarFilterAtom = atom(false) diff --git a/web/helpers/atoms/SetupRemoteModel.atom.ts b/web/helpers/atoms/SetupRemoteModel.atom.ts deleted file mode 100644 index a3a07eda8c..0000000000 --- a/web/helpers/atoms/SetupRemoteModel.atom.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { RemoteEngine } from '@janhq/core' -import { atom } from 'jotai' - -export type SetupRemoteModelStage = 'NONE' | 'SETUP_INTRO' | 'SETUP_API_KEY' - -const remoteModelSetUpStageAtom = atom('NONE') -const engineBeingSetUpAtom = atom(undefined) -const remoteEngineBeingSetUpMetadataAtom = atom< - Record | undefined ->(undefined) - -export const setUpRemoteModelStageAtom = atom( - (get) => ({ - stage: get(remoteModelSetUpStageAtom), - remoteEngine: get(engineBeingSetUpAtom), - metadata: get(remoteEngineBeingSetUpMetadataAtom), - }), - ( - _get, - set, - stage: SetupRemoteModelStage, - remoteEngine: RemoteEngine | undefined, - metadata?: Record | undefined - ) => { - set(remoteModelSetUpStageAtom, stage) - set(engineBeingSetUpAtom, remoteEngine) - set(remoteEngineBeingSetUpMetadataAtom, metadata) - } -) - -export const navigateToSetUpApiKeyAtom = atom(null, (_get, set) => { - set(remoteModelSetUpStageAtom, 'SETUP_API_KEY') -}) diff --git a/web/helpers/atoms/SystemBar.atom.ts b/web/helpers/atoms/SystemBar.atom.ts index 2ece2a8498..ba91364ba7 100644 --- a/web/helpers/atoms/SystemBar.atom.ts +++ b/web/helpers/atoms/SystemBar.atom.ts @@ -1,12 +1,12 @@ -import { GpuInfo } from '@janhq/core/.' import { atom } from 'jotai' export const totalRamAtom = atom(0) export const usedRamAtom = atom(0) export const cpuUsageAtom = atom(0) +export const ramUtilitizedAtom = atom(0) -export const gpusAtom = atom([]) +export const gpusAtom = atom[]>([]) export const nvidiaTotalVramAtom = atom(0) export const availableVramAtom = atom(0) diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts index dc45018b45..c3fdb82607 100644 --- a/web/helpers/atoms/Thread.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -1,25 +1,14 @@ -import { ModelRuntimeParams, ModelSettingParams, Thread } from '@janhq/core' - -import { atom } from 'jotai' - import { - downloadedModelsAtom, - getSelectedModelAtom, - updateSelectedModelAtom, -} from './Model.atom' + ModelRuntimeParams, + ModelSettingParams, + Thread, + ThreadContent, + ThreadState, +} from '@janhq/core' -const threadIdShouldAnimateTitle = atom([]) - -export const getThreadIdsShouldAnimateTitleAtom = atom((get) => - get(threadIdShouldAnimateTitle) -) +import { atom } from 'jotai' -export const addThreadIdShouldAnimateTitleAtom = atom( - null, - (_get, set, threadId: string) => { - set(threadIdShouldAnimateTitle, (current) => [...current, threadId]) - } -) +export const engineParamsUpdateAtom = atom(false) /** * Stores the current active thread id. 
@@ -30,73 +19,91 @@ export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom)) export const setActiveThreadIdAtom = atom( null, - (get, set, threadId: string | undefined) => { - const thread = get(threadsAtom).find((t) => t.id === threadId) - if (!thread) { - console.error(`Thread ${threadId} not found in state`) - return - } + (_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId) +) - set(activeThreadIdAtom, threadId) - const modelId = thread.assistants[0]?.model - if (!modelId) { - console.error(`No model id ${modelId} found in thread`, thread) - return - } +export const waitingToSendMessage = atom(undefined) - const activeModelId = get(getSelectedModelAtom)?.model - if (activeModelId === modelId) { - console.debug('Model already selected:', modelId) - return - } +export const isGeneratingResponseAtom = atom(undefined) +/** + * Stores all thread states for the current user + */ +export const threadStatesAtom = atom>({}) - const model = get(downloadedModelsAtom).find((m) => m.model === modelId) - if (!model) { - console.warn(`Model ${modelId} removed or deleted`) - return - } +// Whether thread data is ready or not +export const threadDataReadyAtom = atom(false) + +export const activeThreadStateAtom = atom((get) => { + const threadId = get(activeThreadIdAtom) + if (!threadId) { + console.debug('Active thread id is undefined') + return undefined + } + + return get(threadStatesAtom)[threadId] +}) + +export const deleteThreadStateAtom = atom( + null, + (get, set, threadId: string) => { + const currentState = { ...get(threadStatesAtom) } + delete currentState[threadId] + set(threadStatesAtom, currentState) + } +) - console.debug('Set selected model:', model) - set(updateSelectedModelAtom, model) +export const updateThreadWaitingForResponseAtom = atom( + null, + (get, set, threadId: string, waitingForResponse: boolean) => { + const currentState = { ...get(threadStatesAtom) } + currentState[threadId] = { + ...currentState[threadId], + waitingForResponse, + error: undefined, + } + set(threadStatesAtom, currentState) + } +) +export const updateThreadStateLastMessageAtom = atom( + null, + (get, set, threadId: string, lastContent?: ThreadContent[]) => { + const currentState = { ...get(threadStatesAtom) } + const lastMessage = lastContent?.[0]?.text?.value ?? '' + currentState[threadId] = { + ...currentState[threadId], + lastMessage, + } + set(threadStatesAtom, currentState) } ) -export const isLoadingModelAtom = atom(undefined) +export const updateThreadAtom = atom( + null, + (get, set, updatedThread: Thread) => { + const threads: Thread[] = get(threadsAtom).map((c) => + c.id === updatedThread.id ? updatedThread : c + ) + + // sort new threads based on updated at + threads.sort((thread1, thread2) => { + const aDate = new Date(thread1.updated ?? 0) + const bDate = new Date(thread2.updated ?? 
0) + return bDate.getTime() - aDate.getTime() + }) -export const isGeneratingResponseAtom = atom(false) + set(threadsAtom, threads) + } +) /** * Stores all threads for the current user */ export const threadsAtom = atom([]) -export const deleteThreadAtom = atom(null, (get, set, threadId: string) => { - const allThreads = get(threadsAtom) - const filteredThreads = allThreads.filter((t) => t.id !== threadId) - if (filteredThreads.length > 0) { - const latestThread = allThreads[0] - set(activeThreadIdAtom, latestThread.id) - } - set(threadsAtom, filteredThreads) -}) - export const activeThreadAtom = atom((get) => get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom)) ) -export const updateThreadTitleAtom = atom( - null, - (_get, set, threadId: string, title: string) => { - set( - threadsAtom, - (threads) => - threads.map((t) => - t.id === threadId ? { ...t, title } : t - ) as Thread[] - ) - } -) - /** * Store model params at thread level settings */ @@ -104,6 +111,18 @@ export const threadModelParamsAtom = atom>({}) export type ModelParams = ModelRuntimeParams | ModelSettingParams +export const getActiveThreadModelParamsAtom = atom( + (get) => { + const threadId = get(activeThreadIdAtom) + if (!threadId) { + console.debug('Active thread id is undefined') + return undefined + } + + return get(threadModelParamsAtom)[threadId] + } +) + export const setThreadModelParamsAtom = atom( null, (get, set, threadId: string, params: ModelParams) => { diff --git a/web/hooks/useAbortDownload.ts b/web/hooks/useAbortDownload.ts deleted file mode 100644 index 7bb949f7ac..0000000000 --- a/web/hooks/useAbortDownload.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { useCallback } from 'react' - -import useCortex from './useCortex' - -const useAbortDownload = () => { - const { abortDownload: cancelDownload } = useCortex() - - const abortDownload = useCallback( - (downloadId: string) => { - cancelDownload(downloadId) - }, - [cancelDownload] - ) - - return { abortDownload } -} - -export default useAbortDownload diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts new file mode 100644 index 0000000000..9768ac4c4a --- /dev/null +++ b/web/hooks/useActiveModel.ts @@ -0,0 +1,174 @@ +import { useCallback, useEffect, useRef } from 'react' + +import { EngineManager, Model } from '@janhq/core' +import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' + +import { toaster } from '@/containers/Toast' + +import { LAST_USED_MODEL_ID } from './useRecommendedModel' + +import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' + +export const activeModelAtom = atom(undefined) +export const loadModelErrorAtom = atom(undefined) + +type ModelState = { + state: string + loading: boolean + model?: Model +} + +export const stateModelAtom = atom({ + state: 'start', + loading: false, + model: undefined, +}) + +const pendingModelLoadAtom = atom(false) + +export function useActiveModel() { + const [activeModel, setActiveModel] = useAtom(activeModelAtom) + const activeThread = useAtomValue(activeThreadAtom) + const [stateModel, setStateModel] = useAtom(stateModelAtom) + const downloadedModels = useAtomValue(downloadedModelsAtom) + const setLoadModelError = useSetAtom(loadModelErrorAtom) + const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom) + const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) + + const downloadedModelsRef = useRef([]) + + 
useEffect(() => { + downloadedModelsRef.current = downloadedModels + }, [downloadedModels]) + + const startModel = async (modelId: string, abortable: boolean = true) => { + if ( + (activeModel && activeModel.id === modelId) || + (stateModel.model?.id === modelId && stateModel.loading) + ) { + console.debug(`Model ${modelId} is already initialized. Ignore..`) + return Promise.resolve() + } + setPendingModelLoad(true) + + let model = downloadedModelsRef?.current.find((e) => e.id === modelId) + + const error = await stopModel().catch((error: Error) => error) + if (error) { + return Promise.reject(error) + } + + setLoadModelError(undefined) + + setActiveModel(undefined) + + setStateModel({ state: 'start', loading: true, model }) + + if (!model) { + toaster({ + title: `Model ${modelId} not found!`, + description: `Please download the model first.`, + type: 'warning', + }) + setStateModel(() => ({ + state: 'start', + loading: false, + model: undefined, + })) + + return Promise.reject(`Model ${modelId} not found!`) + } + + /// Apply thread model settings + if (activeThread?.assistants[0]?.model.id === modelId) { + model = { + ...model, + settings: { + ...model.settings, + ...activeThread.assistants[0].model.settings, + }, + } + } + + if (isVulkanEnabled) { + // @ts-expect-error flash_attn is newly added and will be migrate to cortex in the future + model.settings['flash_attn'] = false + } + + localStorage.setItem(LAST_USED_MODEL_ID, model.id) + const engine = EngineManager.instance().get(model.engine) + return engine + ?.loadModel(model) + .then(() => { + setActiveModel(model) + setStateModel(() => ({ + state: 'stop', + loading: false, + model, + })) + toaster({ + title: 'Success!', + description: `Model ${model.id} has been started.`, + type: 'success', + }) + }) + .catch((error) => { + setStateModel(() => ({ + state: 'start', + loading: false, + model, + })) + + if (!pendingModelLoad && abortable) { + return Promise.reject(new Error('aborted')) + } + + toaster({ + title: 'Failed!', + description: `Model ${model.id} failed to start.`, + type: 'error', + }) + setLoadModelError(error) + return Promise.reject(error) + }) + } + + const stopModel = useCallback(async () => { + const stoppingModel = activeModel || stateModel.model + if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading)) + return + + setStateModel({ state: 'stop', loading: true, model: stoppingModel }) + const engine = EngineManager.instance().get(stoppingModel.engine) + return engine + ?.unloadModel(stoppingModel) + .catch() + .then(() => { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: undefined }) + setPendingModelLoad(false) + }) + }, [ + activeModel, + setActiveModel, + setStateModel, + setPendingModelLoad, + stateModel, + ]) + + const stopInference = useCallback(async () => { + // Loading model + if (stateModel.loading) { + stopModel() + return + } + if (!activeModel) return + + const engine = EngineManager.instance().get(activeModel.engine) + engine?.stopInference() + }, [activeModel, stateModel, stopModel]) + + return { activeModel, startModel, stopModel, stopInference, stateModel } +} diff --git a/web/hooks/useAssistantCreate.ts b/web/hooks/useAssistantCreate.ts deleted file mode 100644 index 6b0e9a9b00..0000000000 --- a/web/hooks/useAssistantCreate.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { Assistant } from '@janhq/core' -import { useMutation, useQueryClient } from '@tanstack/react-query' - -import { assistantQueryKey } from './useAssistantQuery' - -import useCortex 
from './useCortex' - -export const janAssistant: Assistant = { - avatar: '', - id: 'jan', - object: 'assistant', - created_at: Date.now(), - name: 'Jan', - description: 'A default assistant that can use all downloaded models', - model: '*', - instructions: '', - tools: [ - // { - // type: 'retrieval', - // enabled: false, - // settings: { - // top_k: 2, - // chunk_size: 1024, - // chunk_overlap: 64, - // retrieval_template: - // "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\nCONTEXT: {CONTEXT}\n----------------\nQUESTION: {QUESTION}\n----------------\nHelpful Answer:", - // }, - // }, - ], - metadata: undefined, -} - -const useAssistantCreate = () => { - const { createAssistant } = useCortex() - const queryClient = useQueryClient() - - return useMutation({ - mutationFn: createAssistant, - - onSuccess(data) { - queryClient.setQueryData( - assistantQueryKey, - (oldData: Assistant[] | undefined) => [...(oldData ?? []), data] - ) - }, - - onError(error, variables) { - console.error( - `Error while creating assistant: ${JSON.stringify(variables)}. Error: ${error}` - ) - }, - }) -} - -export default useAssistantCreate diff --git a/web/hooks/useAssistantQuery.ts b/web/hooks/useAssistantQuery.ts deleted file mode 100644 index 889a652f0f..0000000000 --- a/web/hooks/useAssistantQuery.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export const assistantQueryKey = ['assistants'] - -const useAssistantQuery = () => { - const { fetchAssistants } = useCortex() - - return useQuery({ - queryKey: assistantQueryKey, - queryFn: fetchAssistants, - staleTime: Infinity, - }) -} - -export default useAssistantQuery diff --git a/web/hooks/useAssistants.ts b/web/hooks/useAssistants.ts new file mode 100644 index 0000000000..61679bce5e --- /dev/null +++ b/web/hooks/useAssistants.ts @@ -0,0 +1,39 @@ +import { useCallback, useEffect } from 'react' + +import { + Assistant, + AssistantEvent, + AssistantExtension, + ExtensionTypeEnum, + events, +} from '@janhq/core' + +import { useSetAtom } from 'jotai' + +import { extensionManager } from '@/extension' +import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' + +const useAssistants = () => { + const setAssistants = useSetAtom(assistantsAtom) + + const getData = useCallback(async () => { + const assistants = await getLocalAssistants() + setAssistants(assistants) + }, [setAssistants]) + + useEffect(() => { + getData() + + events.on(AssistantEvent.OnAssistantsUpdate, () => getData()) + return () => { + events.off(AssistantEvent.OnAssistantsUpdate, () => getData()) + } + }, [getData]) +} + +const getLocalAssistants = async (): Promise => + extensionManager + .get(ExtensionTypeEnum.Assistant) + ?.getAssistants() ?? 
[] + +export default useAssistants diff --git a/web/hooks/useCortex.ts b/web/hooks/useCortex.ts deleted file mode 100644 index a98026825e..0000000000 --- a/web/hooks/useCortex.ts +++ /dev/null @@ -1,367 +0,0 @@ -import { useCallback } from 'react' - -import { Cortex } from '@cortexso/cortex.js' -import { Engine } from '@cortexso/cortex.js/resources' -import { - Assistant, - Model, - Message, - Thread, - ChatCompletionCreateParamsNonStreaming, - ChatCompletionCreateParamsStreaming, - AssistantCreateParams, - AssistantUpdateParams, - LlmEngine, - LlmEngines, -} from '@janhq/core' - -import { useAtomValue } from 'jotai' - -import { defaultThreadTitle } from '@/constants/Threads' - -import { UpdateConfigMutationVariables } from './useEngineMutation' -import { MessageCreateMutationVariables } from './useMessageCreateMutation' -import { MessageDeleteMutationVariables } from './useMessageDeleteMutation' -import { MessageUpdateMutationVariables } from './useMessageUpdateMutation' -import { DownloadModelMutationVariables } from './useModelDownloadMutation' - -import { hostAtom } from '@/helpers/atoms/AppConfig.atom' - -const useCortex = () => { - const host = useAtomValue(hostAtom) - - const cortex = new Cortex({ - baseURL: host, - apiKey: '', - dangerouslyAllowBrowser: true, - }) - - const getEngineStatuses = useCallback(async (): Promise => { - const engineResponse = await cortex.engines.list() - // @ts-expect-error incompatible types - const engineStatuses: Engine[] = engineResponse.body.data.map( - (engine: Engine) => { - return { - name: engine.name, - description: engine.description, - version: engine.version, - productName: engine.productName, - status: engine.status, - } - } - ) - - return engineStatuses - }, [cortex.engines]) - - const initializeEngine = useCallback( - async (engine: LlmEngine) => { - try { - await cortex.engines.init(engine) - } catch (err) { - console.error(err) - } - }, - [cortex.engines] - ) - - const fetchAssistants = useCallback(async () => { - const assistants: Assistant[] = [] - const response = await cortex.beta.assistants.list() - response.data.forEach((assistant) => { - assistants.push(assistant) - }) - return assistants - }, [cortex.beta.assistants]) - - const fetchThreads = useCallback(async () => { - const threads: Thread[] = [] - for await (const thread of cortex.beta.threads.list()) { - // @ts-expect-error each thread must have associated assistants - const assistants = thread['assistants'] as Assistant[] - if (!assistants || assistants.length === 0) continue - - // @ts-expect-error each thread must have a title, else default to 'New Thread' - const title: string = thread['title'] ?? 
defaultThreadTitle - - threads.push({ - ...thread, - title: title, - assistants: assistants, - }) - } - return threads - }, [cortex.beta.threads]) - - const fetchModels = useCallback(async () => { - const models: Model[] = [] - for await (const model of cortex.models.list()) { - // @ts-expect-error model should not be empty - const modelId = model.model - if (!modelId || modelId.length === 0) { - console.debug('Model id is empty, skipping', model) - continue - } - const engine = LlmEngines.find((engine) => engine === model.engine) - if (!engine) { - console.error(`Model ${modelId} has an invalid engine ${model.engine}`) - continue - } - - models.push({ - ...model, - engine: engine, - model: modelId, - // @ts-expect-error each model must have associated files - files: model['files'], - }) - } - return models - }, [cortex.models]) - - const fetchMessages = useCallback( - async (threadId: string) => { - try { - const messages: Message[] = [] - const response = await cortex.beta.threads.messages.list(threadId) - response.data.forEach((message) => { - messages.push(message) - }) - return messages - } catch (error) { - return [] - } - }, - [cortex.beta.threads.messages] - ) - - const startModel = useCallback( - async (modelId: string, options?: Record) => { - await cortex.models.start(modelId, options ?? {}) - }, - [cortex.models] - ) - - const stopModel = useCallback( - async (modelId: string, options?: Record) => { - await cortex.models.stop(modelId, options ?? {}) - }, - [cortex.models] - ) - - const chatCompletionNonStreaming = useCallback( - async ( - chatCompletionCreateParams: ChatCompletionCreateParamsNonStreaming, - options?: Record - // @ts-expect-error incompatible types - ) => cortex.chat.completions.create(chatCompletionCreateParams, options), - [cortex.chat.completions] - ) - - const chatCompletionStreaming = useCallback( - async ( - chatCompletionCreateParams: ChatCompletionCreateParamsStreaming, - options?: Record - // @ts-expect-error incompatible types - ) => cortex.chat.completions.create(chatCompletionCreateParams, options), - [cortex.chat.completions] - ) - - const deleteModel = useCallback( - async (modelId: string) => { - await cortex.models.del(modelId) - }, - [cortex.models] - ) - - const cleanThread = useCallback( - async (threadId: string) => cortex.beta.threads.clean(threadId), - [cortex.beta.threads] - ) - - const deleteThread = useCallback( - async (threadId: string) => { - await cortex.beta.threads.del(threadId) - }, - [cortex.beta.threads] - ) - - const updateThread = useCallback( - async (thread: Thread) => { - const result = await cortex.beta.threads.update(thread.id, thread) - console.debug( - `Update thread ${thread.id}, result: ${JSON.stringify(result, null, 2)}` - ) - }, - [cortex.beta.threads] - ) - - const deleteMessage = useCallback( - async (params: MessageDeleteMutationVariables) => { - const { threadId, messageId } = params - return cortex.beta.threads.messages.del(threadId, messageId) - }, - [cortex.beta.threads] - ) - - const createMessage = useCallback( - async (params: MessageCreateMutationVariables) => { - const { threadId, createMessageParams } = params - return cortex.beta.threads.messages.create(threadId, createMessageParams) - }, - [cortex.beta.threads] - ) - - const updateMessage = useCallback( - async (params: MessageUpdateMutationVariables) => { - const { threadId, messageId, data } = params - return cortex.beta.threads.messages.update(threadId, messageId, data) - }, - [cortex.beta.threads] - ) - - const createThread = useCallback( - async 
(assistant: Assistant) => { - const createThreadResponse = await cortex.beta.threads.create({ - // @ts-expect-error customize so that each thread will have an assistant - assistants: [assistant], - }) - const thread: Thread = { - ...createThreadResponse, - // @ts-expect-error each thread will have a title, else default to 'New Thread' - title: createThreadResponse.title ?? defaultThreadTitle, - assistants: [assistant], - } - return thread - }, - [cortex.beta.threads] - ) - - const updateModel = useCallback( - async (modelId: string, options: Record) => { - try { - return await cortex.models.update(modelId, options) - } catch (err) { - console.error(err) - } - }, - [cortex.models] - ) - - const downloadModel = useCallback( - async (variables: DownloadModelMutationVariables) => { - const { modelId, fileName, persistedModelId } = variables - const response = await fetch(`${host}/models/${modelId}/pull`, { - method: 'POST', - headers: { - 'accept': 'application/json', - // eslint-disable-next-line @typescript-eslint/naming-convention - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - fileName: fileName, - persistedModelId: persistedModelId, - }), - }) - if (!response.ok) { - const responseJson = await response.json() - const errorMessage: string = - (responseJson.error?.message as string) ?? - `Failed to download model ${modelId}` - throw new Error(errorMessage) - } - }, - [host] - ) - - const abortDownload = useCallback( - async (downloadId: string) => { - try { - return await cortex.models.abortDownload(downloadId) - } catch (err) { - console.error(err) - } - }, - [cortex.models] - ) - - const createAssistant = useCallback( - async (createParams: AssistantCreateParams) => - cortex.beta.assistants.create(createParams), - [cortex.beta.assistants] - ) - - const updateAssistant = useCallback( - async (assistantId: string, updateParams: AssistantUpdateParams) => - cortex.beta.assistants.update(assistantId, updateParams), - [cortex.beta.assistants] - ) - - // TODO: add this to cortex-node - const registerEngineConfig = useCallback( - async (variables: UpdateConfigMutationVariables) => { - const { engine, config } = variables - try { - await cortex.engines.update(engine, config) - } catch (err) { - console.error(err) - } - }, - [cortex.engines] - ) - - // add this to cortex-node? 
- const createModel = useCallback( - (model: Model) => - fetch(`${host}/models`, { - method: 'POST', - headers: { - 'accept': 'application/json', - // eslint-disable-next-line @typescript-eslint/naming-convention - 'Content-Type': 'application/json', - }, - body: JSON.stringify(model), - }), - [host] - ) - - const isSystemAlive = useCallback(async () => { - try { - await cortex.system.status() - return true - } catch { - return false - } - }, [cortex.system]) - - return { - fetchAssistants, - fetchThreads, - fetchModels, - fetchMessages, - startModel, - stopModel, - chatCompletionStreaming, - deleteModel, - deleteThread, - deleteMessage, - cleanThread, - updateThread, - createMessage, - updateMessage, - createThread, - downloadModel, - abortDownload, - createAssistant, - updateAssistant, - updateModel, - chatCompletionNonStreaming, - registerEngineConfig, - createModel, - initializeEngine, - getEngineStatuses, - isSystemAlive, - } -} - -export default useCortex diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts new file mode 100644 index 0000000000..5a1a32cb1f --- /dev/null +++ b/web/hooks/useCreateNewThread.ts @@ -0,0 +1,166 @@ +import { useCallback } from 'react' + +import { + Assistant, + ConversationalExtension, + ExtensionTypeEnum, + Thread, + ThreadAssistantInfo, + ThreadState, + Model, + AssistantTool, +} from '@janhq/core' +import { atom, useAtomValue, useSetAtom } from 'jotai' + +import { fileUploadAtom } from '@/containers/Providers/Jotai' + +import { generateThreadId } from '@/utils/thread' + +import { useActiveModel } from './useActiveModel' +import useRecommendedModel from './useRecommendedModel' + +import useSetActiveThread from './useSetActiveThread' + +import { extensionManager } from '@/extension' + +import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' +import { + threadsAtom, + threadStatesAtom, + updateThreadAtom, + setThreadModelParamsAtom, + isGeneratingResponseAtom, +} from '@/helpers/atoms/Thread.atom' + +const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { + // create thread state for this new thread + const currentState = { ...get(threadStatesAtom) } + + const threadState: ThreadState = { + hasMore: false, + waitingForResponse: false, + lastMessage: undefined, + } + currentState[newThread.id] = threadState + set(threadStatesAtom, currentState) + + // add the new thread on top of the thread list to the state + const threads = get(threadsAtom) + set(threadsAtom, [newThread, ...threads]) +}) + +export const useCreateNewThread = () => { + const createNewThread = useSetAtom(createNewThreadAtom) + const { setActiveThread } = useSetActiveThread() + const updateThread = useSetAtom(updateThreadAtom) + const setFileUpload = useSetAtom(fileUploadAtom) + const setSelectedModel = useSetAtom(selectedModelAtom) + const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + + const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) + const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) + + const { recommendedModel, downloadedModels } = useRecommendedModel() + + const threads = useAtomValue(threadsAtom) + const { stopInference } = useActiveModel() + + const requestCreateNewThread = async ( + assistant: Assistant, + model?: Model | undefined + ) => { + // Stop generating if any + setIsGeneratingResponse(false) + stopInference() + + const defaultModel = model ?? recommendedModel ?? 
downloadedModels[0] + + if (!model) { + // if we have model, which means user wants to create new thread from Model hub. Allow them. + + // check last thread message, if there empty last message use can not create thread + const lastMessage = threads[0]?.metadata?.lastMessage + + if (!lastMessage && threads.length) { + return null + } + } + + // modify assistant tools when experimental on, retieval toggle enabled in default + const assistantTools: AssistantTool = { + type: 'retrieval', + enabled: true, + settings: assistant.tools && assistant.tools[0].settings, + } + + const overriddenSettings = + defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 + ? { ctx_len: 2048 } + : {} + + const overriddenParameters = + defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens + ? { max_tokens: 2048 } + : {} + + const createdAt = Date.now() + const assistantInfo: ThreadAssistantInfo = { + assistant_id: assistant.id, + assistant_name: assistant.name, + tools: experimentalEnabled ? [assistantTools] : assistant.tools, + model: { + id: defaultModel?.id ?? '*', + settings: { ...defaultModel?.settings, ...overriddenSettings } ?? {}, + parameters: + { ...defaultModel?.parameters, ...overriddenParameters } ?? {}, + engine: defaultModel?.engine, + }, + instructions: assistant.instructions, + } + + const threadId = generateThreadId(assistant.id) + const thread: Thread = { + id: threadId, + object: 'thread', + title: 'New Thread', + assistants: [assistantInfo], + created: createdAt, + updated: createdAt, + } + + // add the new thread on top of the thread list to the state + //TODO: Why do we have thread list then thread states? Should combine them + createNewThread(thread) + + setSelectedModel(defaultModel) + setThreadModelParams(thread.id, { + ...defaultModel?.settings, + ...defaultModel?.parameters, + ...overriddenSettings, + }) + + // Delete the file upload state + setFileUpload([]) + // Update thread metadata + await updateThreadMetadata(thread) + + setActiveThread(thread) + } + + const updateThreadMetadata = useCallback( + async (thread: Thread) => { + updateThread(thread) + + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.saveThread(thread) + }, + [updateThread] + ) + + return { + requestCreateNewThread, + updateThreadMetadata, + } +} diff --git a/web/hooks/useDeleteModel.ts b/web/hooks/useDeleteModel.ts new file mode 100644 index 0000000000..9736f82563 --- /dev/null +++ b/web/hooks/useDeleteModel.ts @@ -0,0 +1,32 @@ +import { useCallback } from 'react' + +import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core' + +import { useSetAtom } from 'jotai' + +import { toaster } from '@/containers/Toast' + +import { extensionManager } from '@/extension/ExtensionManager' +import { removeDownloadedModelAtom } from '@/helpers/atoms/Model.atom' + +export default function useDeleteModel() { + const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) + + const deleteModel = useCallback( + async (model: Model) => { + await localDeleteModel(model.id) + removeDownloadedModel(model.id) + toaster({ + title: 'Model Deletion Successful', + description: `Model ${model.name} has been successfully deleted.`, + type: 'success', + }) + }, + [removeDownloadedModel] + ) + + return { deleteModel } +} + +const localDeleteModel = async (id: string) => + extensionManager.get(ExtensionTypeEnum.Model)?.deleteModel(id) diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts new file mode 100644 index 0000000000..69e51228f1 --- /dev/null +++ 
b/web/hooks/useDeleteThread.ts @@ -0,0 +1,139 @@ +import { useCallback } from 'react' + +import { + ChatCompletionRole, + ExtensionTypeEnum, + ConversationalExtension, + fs, + joinPath, + Thread, +} from '@janhq/core' + +import { useAtom, useAtomValue, useSetAtom } from 'jotai' + +import { currentPromptAtom } from '@/containers/Providers/Jotai' + +import { toaster } from '@/containers/Toast' + +import { extensionManager } from '@/extension/ExtensionManager' + +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' +import { + chatMessages, + cleanChatMessageAtom as cleanChatMessagesAtom, + deleteChatMessageAtom as deleteChatMessagesAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { + threadsAtom, + setActiveThreadIdAtom, + deleteThreadStateAtom, + updateThreadStateLastMessageAtom, + updateThreadAtom, +} from '@/helpers/atoms/Thread.atom' + +export default function useDeleteThread() { + const [threads, setThreads] = useAtom(threadsAtom) + const messages = useAtomValue(chatMessages) + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) + + const setCurrentPrompt = useSetAtom(currentPromptAtom) + const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) + const deleteMessages = useSetAtom(deleteChatMessagesAtom) + const cleanMessages = useSetAtom(cleanChatMessagesAtom) + + const deleteThreadState = useSetAtom(deleteThreadStateAtom) + const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom) + const updateThread = useSetAtom(updateThreadAtom) + + const cleanThread = useCallback( + async (threadId: string) => { + cleanMessages(threadId) + const thread = threads.find((c) => c.id === threadId) + if (!thread) return + + const updatedMessages = (messages[threadId] ?? []).filter( + (msg) => msg.role === ChatCompletionRole.System + ) + + // remove files + try { + const threadFolderPath = await joinPath([ + janDataFolderPath, + 'threads', + threadId, + ]) + const threadFilesPath = await joinPath([threadFolderPath, 'files']) + const threadMemoryPath = await joinPath([threadFolderPath, 'memory']) + await fs.rm(threadFilesPath) + await fs.rm(threadMemoryPath) + } catch (err) { + console.warn('Error deleting thread files', err) + } + + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.writeMessages(threadId, updatedMessages) + + thread.metadata = { + ...thread.metadata, + } + + const updatedThread: Thread = { + ...thread, + title: 'New Thread', + metadata: { ...thread.metadata, lastMessage: undefined }, + } + + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.saveThread(updatedThread) + updateThreadLastMessage(threadId, undefined) + updateThread(updatedThread) + }, + [ + cleanMessages, + threads, + messages, + updateThreadLastMessage, + updateThread, + janDataFolderPath, + ] + ) + + const deleteThread = async (threadId: string) => { + if (!threadId) { + alert('No active thread') + return + } + try { + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.deleteThread(threadId) + const availableThreads = threads.filter((c) => c.id !== threadId) + setThreads(availableThreads) + + // delete the thread state + deleteThreadState(threadId) + + deleteMessages(threadId) + setCurrentPrompt('') + toaster({ + title: 'Thread successfully deleted.', + description: `Thread ${threadId} has been successfully deleted.`, + type: 'success', + }) + if (availableThreads.length > 0) { + setActiveThreadId(availableThreads[0].id) + } else { + setActiveThreadId(undefined) + } + } catch (err) { + console.error(err) + } + } + + 
return { + cleanThread, + deleteThread, + } +} diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts new file mode 100644 index 0000000000..d0d13d93b9 --- /dev/null +++ b/web/hooks/useDownloadModel.ts @@ -0,0 +1,115 @@ +import { useCallback } from 'react' + +import { + Model, + ExtensionTypeEnum, + ModelExtension, + abortDownload, + joinPath, + ModelArtifact, + DownloadState, + GpuSetting, +} from '@janhq/core' + +import { useAtomValue, useSetAtom } from 'jotai' + +import { setDownloadStateAtom } from './useDownloadState' + +import useGpuSetting from './useGpuSetting' + +import { extensionManager } from '@/extension/ExtensionManager' +import { + ignoreSslAtom, + proxyAtom, + proxyEnabledAtom, +} from '@/helpers/atoms/AppConfig.atom' +import { addDownloadingModelAtom } from '@/helpers/atoms/Model.atom' + +export default function useDownloadModel() { + const ignoreSSL = useAtomValue(ignoreSslAtom) + const proxy = useAtomValue(proxyAtom) + const proxyEnabled = useAtomValue(proxyEnabledAtom) + const setDownloadState = useSetAtom(setDownloadStateAtom) + const addDownloadingModel = useSetAtom(addDownloadingModelAtom) + + const { getGpuSettings } = useGpuSetting() + + const downloadModel = useCallback( + async (model: Model) => { + const childProgresses: DownloadState[] = model.sources.map( + (source: ModelArtifact) => ({ + fileName: source.filename, + modelId: model.id, + time: { + elapsed: 0, + remaining: 0, + }, + speed: 0, + percent: 0, + size: { + total: 0, + transferred: 0, + }, + downloadState: 'downloading', + }) + ) + + // set an initial download state + setDownloadState({ + fileName: '', + modelId: model.id, + time: { + elapsed: 0, + remaining: 0, + }, + speed: 0, + percent: 0, + size: { + total: 0, + transferred: 0, + }, + children: childProgresses, + downloadState: 'downloading', + }) + + addDownloadingModel(model) + const gpuSettings = await getGpuSettings() + await localDownloadModel( + model, + ignoreSSL, + proxyEnabled ? 
proxy : '', + gpuSettings + ) + }, + [ + ignoreSSL, + proxy, + proxyEnabled, + getGpuSettings, + addDownloadingModel, + setDownloadState, + ] + ) + + const abortModelDownload = useCallback(async (model: Model) => { + for (const source of model.sources) { + const path = await joinPath(['models', model.id, source.filename]) + await abortDownload(path) + } + }, []) + + return { + downloadModel, + abortModelDownload, + } +} + +const localDownloadModel = async ( + model: Model, + ignoreSSL: boolean, + proxy: string, + gpuSettings?: GpuSetting +) => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.downloadModel(model, gpuSettings, { ignoreSSL, proxy }) diff --git a/web/hooks/useDownloadState.ts b/web/hooks/useDownloadState.ts index 0833c428d1..03a8883cb5 100644 --- a/web/hooks/useDownloadState.ts +++ b/web/hooks/useDownloadState.ts @@ -1,47 +1,146 @@ -import { DownloadState2, DownloadStatus, DownloadType2 } from '@janhq/core' +import { DownloadState } from '@janhq/core' import { atom } from 'jotai' -export const downloadStateListAtom = atom([]) +import { toaster } from '@/containers/Toast' -export const addDownloadModelStateAtom = atom( - null, - (_get, set, modelId: string) => { - const state: DownloadState2 = { - id: modelId, - title: modelId, - type: DownloadType2.Model, - progress: 0, - status: DownloadStatus.Downloading, - children: [ - { - id: modelId, - time: { - elapsed: 0, - remaining: 0, - }, - size: { - total: 0, - transferred: 0, - }, - status: DownloadStatus.Downloading, - }, - ], - } - set(downloadStateListAtom, (old) => [...old, state]) - } -) +import { + configuredModelsAtom, + downloadedModelsAtom, + removeDownloadingModelAtom, +} from '@/helpers/atoms/Model.atom' + +// download states +export const modelDownloadStateAtom = atom>({}) /** - * Used to remove a download item from a list of downloading. - * - * @param downloadId The download id to be removed. If item is model then - * this is the modelId. + * Used to set the download state for a particular model. */ -export const removeDownloadSuccessItemAtom = atom( +export const setDownloadStateAtom = atom( null, - (_get, set, downloadId: string) => { - set(downloadStateListAtom, (old) => - old.filter((downloadState) => downloadState.id !== downloadId) - ) + (get, set, state: DownloadState) => { + try { + const currentState = { ...get(modelDownloadStateAtom) } + + if (state.downloadState === 'end') { + const modelDownloadState = currentState[state.modelId] + + const updatedChildren: DownloadState[] = ( + modelDownloadState.children ?? 
[] + ).filter((m) => m.fileName !== state.fileName) + updatedChildren.push(state) + modelDownloadState.children = updatedChildren + currentState[state.modelId] = modelDownloadState + + const isAllChildrenDownloadEnd = modelDownloadState.children?.every( + (m) => m.downloadState === 'end' + ) + + if (isAllChildrenDownloadEnd) { + // download successfully + delete currentState[state.modelId] + set(removeDownloadingModelAtom, state.modelId) + + const model = get(configuredModelsAtom).find( + (e) => e.id === state.modelId + ) + if (model) set(downloadedModelsAtom, (prev) => [...prev, model]) + toaster({ + title: 'Download Completed', + description: `Download ${state.modelId} completed`, + type: 'success', + }) + } + } else if (state.downloadState === 'error') { + // download error + delete currentState[state.modelId] + set(removeDownloadingModelAtom, state.modelId) + if (state.error === 'aborted') { + toaster({ + title: 'Cancel Download', + description: `Model ${state.modelId} download cancelled`, + type: 'warning', + }) + } else { + let error = state.error + if ( + typeof error?.includes === 'function' && + state.error?.includes('certificate') + ) { + error += + '. To fix enable "Ignore SSL Certificates" in Advanced settings.' + } + toaster({ + title: 'Download Failed', + description: `Model ${state.modelId} download failed: ${error}`, + type: 'error', + }) + } + } else { + // download in progress + if (state.size.total === 0) { + // this is initial state, just set the state + currentState[state.modelId] = state + set(modelDownloadStateAtom, currentState) + return + } + + const modelDownloadState = currentState[state.modelId] + if (!modelDownloadState) { + console.debug('setDownloadStateAtom: modelDownloadState not found') + return + } + + // delete the children if the filename is matched and replace the new state + const updatedChildren: DownloadState[] = ( + modelDownloadState.children ?? [] + ).filter((m) => m.fileName !== state.fileName) + + updatedChildren.push(state) + + // re-calculate the overall progress if we have all the children download data + const isAnyChildDownloadNotReady = updatedChildren.some( + (m) => + m.size.total === 0 && + !modelDownloadState.children?.some( + (e) => e.fileName === m.fileName && e.downloadState === 'end' + ) && + modelDownloadState.children?.some((e) => e.fileName === m.fileName) + ) + + modelDownloadState.children = updatedChildren + if (isAnyChildDownloadNotReady) { + // just update the children + currentState[state.modelId] = modelDownloadState + set(modelDownloadStateAtom, currentState) + return + } + + const parentTotalSize = modelDownloadState.size.total + if (parentTotalSize === 0) { + // calculate the total size of the parent by sum all children total size + const totalSize = updatedChildren.reduce( + (acc, m) => acc + m.size.total, + 0 + ) + + modelDownloadState.size.total = totalSize + } + + // calculate the total transferred size by sum all children transferred size + const transferredSize = updatedChildren.reduce( + (acc, m) => acc + m.size.transferred, + 0 + ) + modelDownloadState.size.transferred = transferredSize + modelDownloadState.percent = + parentTotalSize === 0 ? 
0 : transferredSize / parentTotalSize + currentState[state.modelId] = modelDownloadState + } + + set(modelDownloadStateAtom, currentState) + } catch (e) { + console.debug('setDownloadStateAtom: state', state) + console.debug('setDownloadStateAtom: error', e) + } } ) diff --git a/web/hooks/useDropModelBinaries.ts b/web/hooks/useDropModelBinaries.ts index 6e593304bc..d87e96627e 100644 --- a/web/hooks/useDropModelBinaries.ts +++ b/web/hooks/useDropModelBinaries.ts @@ -3,6 +3,8 @@ import { useCallback } from 'react' import { ImportingModel } from '@janhq/core' import { useSetAtom } from 'jotai' +import { v4 as uuidv4 } from 'uuid' + import { snackbar } from '@/containers/Toast' import { getFileInfoFromFile } from '@/utils/file' @@ -24,23 +26,17 @@ export default function useDropModelBinaries() { ) const supportedFiles = files.filter((file) => file.path.endsWith('.gguf')) - const importingModels: ImportingModel[] = supportedFiles.map((file) => { - const normalizedPath = isWindows - ? file.path.replace(/\\/g, '/') - : file.path - - return { - importId: normalizedPath, - modelId: undefined, - name: normalizedPath.replace('.gguf', ''), - description: '', - path: file.path, - tags: [], - size: file.size, - status: 'PREPARING', - format: 'gguf', - } - }) + const importingModels: ImportingModel[] = supportedFiles.map((file) => ({ + importId: uuidv4(), + modelId: undefined, + name: file.name.replace('.gguf', ''), + description: '', + path: file.path, + tags: [], + size: file.size, + status: 'PREPARING', + format: 'gguf', + })) if (unsupportedFiles.length > 0) { snackbar({ description: `Only files with .gguf extension can be imported.`, diff --git a/web/hooks/useEngineInit.ts b/web/hooks/useEngineInit.ts deleted file mode 100644 index fc89a72bfd..0000000000 --- a/web/hooks/useEngineInit.ts +++ /dev/null @@ -1,39 +0,0 @@ -import { Engine } from '@cortexso/cortex.js/resources' -import { EngineStatus } from '@janhq/core' -import { useMutation, useQueryClient } from '@tanstack/react-query' - -import useCortex from './useCortex' -import { engineQueryKey } from './useEngineQuery' - -const useEngineInit = () => { - const { initializeEngine } = useCortex() - const queryClient = useQueryClient() - - return useMutation({ - mutationFn: initializeEngine, - - onSuccess: async (data, engineName) => { - console.debug(`Engine ${engineName} initialized`, data) - - // optimistically set the engine status to 'ready' - const queryCacheData = await queryClient.getQueryData(engineQueryKey) - if (!queryCacheData) { - return queryClient.invalidateQueries({ queryKey: engineQueryKey }) - } - const engineStatuses = queryCacheData as Engine[] - engineStatuses.forEach((engine) => { - if (engine.name === engineName) { - engine.status = EngineStatus.Ready - } - }) - console.debug(`Updated engine status: ${engineStatuses}`) - await queryClient.setQueryData(engineQueryKey, engineStatuses) - }, - - onError(error, variables) { - console.error(`Engine ${variables} failed to initialize`, error) - }, - }) -} - -export default useEngineInit diff --git a/web/hooks/useEngineMutation.ts b/web/hooks/useEngineMutation.ts deleted file mode 100644 index 5d34aa1391..0000000000 --- a/web/hooks/useEngineMutation.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { RemoteEngine } from '@janhq/core' -import { useMutation, useQueryClient } from '@tanstack/react-query' - -import { useSetAtom } from 'jotai' - -import { toaster } from '@/containers/Toast' - -import useCortex from './useCortex' - -import { engineQueryKey } from './useEngineQuery' - -import { 
setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -export type UpdateConfigMutationVariables = { - engine: RemoteEngine - config: { config: string; value: string } -} - -const useEngineMutation = () => { - const { registerEngineConfig } = useCortex() - const queryClient = useQueryClient() - - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - - return useMutation({ - mutationFn: registerEngineConfig, - - onError: (err, variables) => { - console.error( - `Failed to register engine with variables: ${variables}, err: ${err}` - ) - }, - - onSuccess: async () => { - await queryClient.invalidateQueries({ queryKey: engineQueryKey }) - setUpRemoteModelStage('NONE', undefined) - toaster({ - title: 'Success!', - description: `Key added successfully`, - type: 'success', - }) - }, - }) -} - -export default useEngineMutation diff --git a/web/hooks/useEngineQuery.ts b/web/hooks/useEngineQuery.ts deleted file mode 100644 index 1280bf50df..0000000000 --- a/web/hooks/useEngineQuery.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export const engineQueryKey = ['getEngineStatuses'] - -const useEngineQuery = () => { - const { getEngineStatuses } = useCortex() - - return useQuery({ - queryKey: engineQueryKey, - queryFn: getEngineStatuses, - staleTime: 30 * 1000, - }) -} - -export default useEngineQuery diff --git a/web/hooks/useFactoryReset.ts b/web/hooks/useFactoryReset.ts index b44656068e..8364ca10d9 100644 --- a/web/hooks/useFactoryReset.ts +++ b/web/hooks/useFactoryReset.ts @@ -1,6 +1,11 @@ import { useCallback } from 'react' -import { atom, useSetAtom } from 'jotai' +import { fs, AppConfiguration } from '@janhq/core' +import { atom, useAtomValue, useSetAtom } from 'jotai' + +import { useActiveModel } from './useActiveModel' + +import { defaultJanDataFolderAtom } from '@/helpers/atoms/App.atom' export enum FactoryResetState { Idle = 'idle', @@ -12,56 +17,50 @@ export enum FactoryResetState { export const factoryResetStateAtom = atom(FactoryResetState.Idle) -const useFactoryReset = () => { - // const defaultJanDataFolder = useAtomValue(defaultJanDataFolderAtom) +export default function useFactoryReset() { + const defaultJanDataFolder = useAtomValue(defaultJanDataFolderAtom) + const { stopModel } = useActiveModel() const setFactoryResetState = useSetAtom(factoryResetStateAtom) const resetAll = useCallback( async (keepCurrentFolder?: boolean) => { - console.log('resetAll', keepCurrentFolder) setFactoryResetState(FactoryResetState.Starting) // read the place of jan data folder - // const appConfiguration: AppConfiguration | undefined = - // await window.core?.api?.getAppConfigurations() + const appConfiguration: AppConfiguration | undefined = + await window.core?.api?.getAppConfigurations() - // if (!appConfiguration) { - // console.debug('Failed to get app configuration') - // } + if (!appConfiguration) { + console.debug('Failed to get app configuration') + } - // // @james - delete the cortex data folder - // const janDataFolderPath = appConfiguration!.data_folder + const janDataFolderPath = appConfiguration!.data_folder - // if (!keepCurrentFolder) { - // // set the default jan data folder to user's home directory - // const configuration: AppConfiguration = { - // data_folder: defaultJanDataFolder, - // quick_ask: appConfiguration?.quick_ask ?? 
false, - // } - // await window.core?.api?.updateAppConfiguration(configuration) - // } + if (!keepCurrentFolder) { + // set the default jan data folder to user's home directory + const configuration: AppConfiguration = { + data_folder: defaultJanDataFolder, + quick_ask: appConfiguration?.quick_ask ?? false, + } + await window.core?.api?.updateAppConfiguration(configuration) + } - // setFactoryResetState(FactoryResetState.StoppingModel) - // for (const { model } of activeModels) { - // await stopModel(model) - // } + setFactoryResetState(FactoryResetState.StoppingModel) + await stopModel() + await new Promise((resolve) => setTimeout(resolve, 4000)) - // await new Promise((resolve) => setTimeout(resolve, 4000)) + setFactoryResetState(FactoryResetState.DeletingData) + await fs.rm(janDataFolderPath) - // setFactoryResetState(FactoryResetState.DeletingData) - // // await fs.rm(janDataFolderPath) + setFactoryResetState(FactoryResetState.ClearLocalStorage) + // reset the localStorage + localStorage.clear() - // setFactoryResetState(FactoryResetState.ClearLocalStorage) - // // reset the localStorage - // localStorage.clear() - - // await window.core?.api?.relaunch() + await window.core?.api?.relaunch() }, - [setFactoryResetState] + [defaultJanDataFolder, stopModel, setFactoryResetState] ) return { resetAll, } } - -export default useFactoryReset diff --git a/web/hooks/useGetFileSize.ts b/web/hooks/useGetFileSize.ts deleted file mode 100644 index 38dd1f036d..0000000000 --- a/web/hooks/useGetFileSize.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { getFileSize } from '@/utils/huggingface' - -const useGetFileSize = (url: string) => - useQuery({ - queryKey: ['fileSize', url], - queryFn: () => getFileSize(url), - }) - -export default useGetFileSize diff --git a/web/hooks/useGetHFRepoData.ts b/web/hooks/useGetHFRepoData.ts index 4415287027..3dab2c72e8 100644 --- a/web/hooks/useGetHFRepoData.ts +++ b/web/hooks/useGetHFRepoData.ts @@ -1,6 +1,12 @@ import { useCallback, useState } from 'react' -import { fetchHuggingFaceRepoData } from '@/utils/huggingface' +import { + ExtensionTypeEnum, + HuggingFaceRepoData, + ModelExtension, +} from '@janhq/core' + +import { extensionManager } from '@/extension' export const useGetHFRepoData = () => { const [error, setError] = useState(undefined) @@ -10,7 +16,7 @@ export const useGetHFRepoData = () => { try { setError(undefined) setLoading(true) - const data = await fetchHuggingFaceRepoData(repoId) + const data = await extensionGetHfRepoData(repoId) return data } catch (err) { console.error(err) @@ -25,3 +31,11 @@ export const useGetHFRepoData = () => { return { loading, error, getHfRepoData } } + +const extensionGetHfRepoData = async ( + repoId: string +): Promise => { + return extensionManager + .get(ExtensionTypeEnum.Model) + ?.fetchHuggingFaceRepoData(repoId) +} diff --git a/web/hooks/useGetModelsByEngine.ts b/web/hooks/useGetModelsByEngine.ts deleted file mode 100644 index 3c2c0962f7..0000000000 --- a/web/hooks/useGetModelsByEngine.ts +++ /dev/null @@ -1,65 +0,0 @@ -import { useCallback } from 'react' - -import { LlmEngine, LocalEngines, Model } from '@janhq/core' - -import { useQueryClient } from '@tanstack/react-query' - -import { useAtomValue } from 'jotai' - -import { HfModelEntry } from '@/utils/huggingface' - -import { cortexHubModelsQueryKey } from './useModelHub' - -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -const useGetModelsByEngine = () => { - const downloadedModels = 
useAtomValue(downloadedModelsAtom) - const queryClient = useQueryClient() - - // TODO: this function needs to be clean up - const getModelsByEngine = useCallback( - (engine: LlmEngine, searchText = ''): Model[] => { - if (LocalEngines.some((x) => x === engine)) { - return downloadedModels - .filter((m) => m.engine === engine) - .filter((m) => { - if (searchText.trim() === '') return true - return ( - m.model?.toLowerCase().includes(searchText) || - m.name?.toLowerCase().includes(searchText) - ) - }) - } - - const availableModels = downloadedModels.filter( - (m) => m.engine === engine - ) - // engine is remote engine - const data = queryClient.getQueryData(cortexHubModelsQueryKey) - if (!data) return availableModels - - const modelEntries = data as HfModelEntry[] - const models: Model[] = [...availableModels] - for (const entry of modelEntries) { - const entryModel = entry.model - if (!entryModel) continue - if (entry.engine !== engine) continue - if (models.some((m) => m.model === entryModel.model)) continue - models.push(entryModel) - } - - return models.filter((m) => { - if (searchText.trim() === '') return true - return ( - m.model?.toLowerCase().includes(searchText) || - m.name?.toLowerCase().includes(searchText) - ) - }) - }, - [queryClient, downloadedModels] - ) - - return { getModelsByEngine } -} - -export default useGetModelsByEngine diff --git a/web/hooks/useGetReadMeContent.ts b/web/hooks/useGetReadMeContent.ts deleted file mode 100644 index 3928166fb2..0000000000 --- a/web/hooks/useGetReadMeContent.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { tryGettingReadMeFile } from '@/utils/huggingface' - -const useGetReadMeContent = (repoName: string) => - useQuery({ - queryKey: ['useGetReadMeContent', repoName], - queryFn: () => tryGettingReadMeFile(repoName), - }) - -export default useGetReadMeContent diff --git a/web/hooks/useGetSystemResources.ts b/web/hooks/useGetSystemResources.ts new file mode 100644 index 0000000000..a05a6a7102 --- /dev/null +++ b/web/hooks/useGetSystemResources.ts @@ -0,0 +1,122 @@ +import { useCallback, useEffect, useState } from 'react' + +import { ExtensionTypeEnum, MonitoringExtension } from '@janhq/core' + +import { useSetAtom } from 'jotai' + +import { extensionManager } from '@/extension/ExtensionManager' +import { + cpuUsageAtom, + totalRamAtom, + usedRamAtom, + nvidiaTotalVramAtom, + gpusAtom, + ramUtilitizedAtom, + availableVramAtom, +} from '@/helpers/atoms/SystemBar.atom' + +export default function useGetSystemResources() { + const [intervalId, setIntervalId] = useState< + NodeJS.Timeout | number | undefined + >(undefined) + + const setTotalRam = useSetAtom(totalRamAtom) + const setGpus = useSetAtom(gpusAtom) + const setUsedRam = useSetAtom(usedRamAtom) + const setCpuUsage = useSetAtom(cpuUsageAtom) + const setTotalNvidiaVram = useSetAtom(nvidiaTotalVramAtom) + const setAvailableVram = useSetAtom(availableVramAtom) + const setRamUtilitized = useSetAtom(ramUtilitizedAtom) + + const getSystemResources = useCallback(async () => { + if ( + !extensionManager.get( + ExtensionTypeEnum.SystemMonitoring + ) + ) { + return + } + const monitoring = extensionManager.get( + ExtensionTypeEnum.SystemMonitoring + ) + const resourceInfor = await monitoring?.getResourcesInfo() + const currentLoadInfor = await monitoring?.getCurrentLoad() + + if (resourceInfor?.mem?.usedMemory) setUsedRam(resourceInfor.mem.usedMemory) + if (resourceInfor?.mem?.totalMemory) + setTotalRam(resourceInfor.mem.totalMemory) + + const 
ramUtilitized = + ((resourceInfor?.mem?.usedMemory ?? 0) / + (resourceInfor?.mem?.totalMemory ?? 1)) * + 100 + setRamUtilitized(Math.round(ramUtilitized)) + + setCpuUsage(Math.round(currentLoadInfor?.cpu?.usage ?? 0)) + + const gpus = currentLoadInfor?.gpu ?? [] + setGpus(gpus) + + let totalNvidiaVram = 0 + if (gpus.length > 0) { + totalNvidiaVram = gpus.reduce( + (total: number, gpu: { memoryTotal: string }) => + total + Number(gpu.memoryTotal), + 0 + ) + } + setTotalNvidiaVram(totalNvidiaVram) + setAvailableVram( + gpus.reduce( + (total: number, gpu: { memoryFree: string }) => + total + Number(gpu.memoryFree), + 0 + ) + ) + }, [ + setUsedRam, + setTotalRam, + setRamUtilitized, + setCpuUsage, + setGpus, + setTotalNvidiaVram, + setAvailableVram, + ]) + + const watch = () => { + getSystemResources() + + // Fetch interval - every 2s + const itv = setInterval(() => { + getSystemResources() + }, 2000) + setIntervalId(itv) + } + const stopWatching = useCallback(() => { + if (intervalId) clearInterval(intervalId) + }, [intervalId]) + + useEffect(() => { + getSystemResources() + // Component did unmount + // Stop watching if any + return () => { + stopWatching() + } + }, [getSystemResources, stopWatching]) + + return { + /** + * Fetch resource information once + */ + getSystemResources, + /** + * Fetch & watch for resource update + */ + watch, + /** + * Stop watching + */ + stopWatching, + } +} diff --git a/web/hooks/useGpuSetting.ts b/web/hooks/useGpuSetting.ts new file mode 100644 index 0000000000..36f51ed577 --- /dev/null +++ b/web/hooks/useGpuSetting.ts @@ -0,0 +1,21 @@ +import { useCallback } from 'react' + +import { ExtensionTypeEnum, MonitoringExtension } from '@janhq/core' + +import { extensionManager } from '@/extension' + +export default function useGpuSetting() { + const getGpuSettings = useCallback(async () => { + const gpuSetting = await extensionManager + ?.get(ExtensionTypeEnum.SystemMonitoring) + ?.getGpuSetting() + + if (!gpuSetting) { + console.debug('No GPU setting found') + return undefined + } + return gpuSetting + }, []) + + return { getGpuSettings } +} diff --git a/web/hooks/useHfEngineToBranchesQuery.ts b/web/hooks/useHfEngineToBranchesQuery.ts deleted file mode 100644 index d785ec39fb..0000000000 --- a/web/hooks/useHfEngineToBranchesQuery.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { getEngineAndBranches } from '@/utils/huggingface' - -const useHfEngineToBranchesQuery = (modelHandle: string) => - useQuery({ - queryKey: ['useHfEngineToBranchesQuery', modelHandle], - queryFn: () => getEngineAndBranches(modelHandle), - staleTime: 5 * 60 * 1000, - }) - -export default useHfEngineToBranchesQuery diff --git a/web/hooks/useHfModelFetchAndDownload.ts b/web/hooks/useHfModelFetchAndDownload.ts deleted file mode 100644 index 76918656dc..0000000000 --- a/web/hooks/useHfModelFetchAndDownload.ts +++ /dev/null @@ -1,102 +0,0 @@ -import { useCallback } from 'react' - -import { HuggingFaceRepoData } from '@janhq/core' -import { useQueryClient } from '@tanstack/react-query' - -import { toaster } from '@/containers/Toast' - -import { fetchHuggingFaceRepoData } from '@/utils/huggingface' - -import { fetchHuggingFaceRepoDataQueryKey } from './useHfRepoDataQuery' -import useModelDownloadMutation from './useModelDownloadMutation' - -/** - * Fetches the data for a Hugging Face model and downloads it. - * This function will query local cache data first, before send request - * to HuggingFace to prevent unnecessary requests. 
- * - * @param modelHandle The model handle to fetch and download. E.g: "NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF" - */ -const useHfModelFetchAndDownload = () => { - const downloadModelMutation = useModelDownloadMutation() - const queryClient = useQueryClient() - - const fetchData = useCallback( - async (modelHandle: string): Promise => { - const data = queryClient.getQueryData([ - ...fetchHuggingFaceRepoDataQueryKey, - modelHandle, - ]) - - if (data) return data as HuggingFaceRepoData - console.debug(`No cache data found for ${modelHandle}`) - const repoData = await fetchHuggingFaceRepoData(modelHandle) - await queryClient.setQueryData( - [...fetchHuggingFaceRepoDataQueryKey, data], - repoData - ) - return repoData - }, - [queryClient] - ) - - const fetchDataAndDownload = useCallback( - async (modelHandle: string) => { - const repoData = await fetchData(modelHandle) - if (!repoData) { - console.error(`Could not fetch data for repo ${modelHandle}`) - toaster({ - title: `Failed to get data`, - description: `Could not get data for repo ${modelHandle}`, - type: 'error', - }) - return - } - - const recommendedQuant = 'Q4_K_S' - let recommendedModel = repoData.siblings.find( - (sibling) => - sibling.quantization?.toLowerCase() === recommendedQuant.toLowerCase() - ) - - if (!recommendedModel) { - console.debug('Q4_K_S model not found. Try with smallest model') - // get filesize min from repoData.siblings - - repoData.siblings - .filter((sibling) => { - sibling.fileSize != null && sibling.quantization != null - }) - .sort((a, b) => a.fileSize! - b.fileSize!) - recommendedModel = repoData.siblings[0] - console.debug('Min size recommended model:', recommendedModel) - } - - if (!recommendedModel) { - toaster({ - title: `Failed to get recommended model`, - description: `Could not get recommended model for repo ${modelHandle}. 
Please open the details page and select model manually!`, - type: 'error', - }) - return - } - - const persistModelId = modelHandle - .replaceAll('/', '_') - .concat('_') - .concat(recommendedModel.rfilename) - - downloadModelMutation.mutate({ - modelId: modelHandle, - fileName: recommendedModel.rfilename, - persistedModelId: persistModelId, - }) - }, - - [fetchData, downloadModelMutation] - ) - - return { fetchDataAndDownload } -} - -export default useHfModelFetchAndDownload diff --git a/web/hooks/useHfRepoDataQuery.ts b/web/hooks/useHfRepoDataQuery.ts deleted file mode 100644 index 1114c39bc1..0000000000 --- a/web/hooks/useHfRepoDataQuery.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { fetchHuggingFaceRepoData } from '@/utils/huggingface' - -export const fetchHuggingFaceRepoDataQueryKey = ['fetchHuggingFaceRepoData'] - -const useHfRepoDataQuery = (repoId: string) => - useQuery({ - queryKey: [...fetchHuggingFaceRepoDataQueryKey, repoId], - queryFn: () => fetchHuggingFaceRepoData(repoId), - staleTime: 5 * 60 * 1000, - }) - -export default useHfRepoDataQuery diff --git a/web/hooks/useHfRevisionQuery.ts b/web/hooks/useHfRevisionQuery.ts deleted file mode 100644 index f77b1c7753..0000000000 --- a/web/hooks/useHfRevisionQuery.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { getBranches } from '@/utils/huggingface' - -const useHfRevisionQuery = (repoName: string) => - useQuery({ - queryKey: ['hfRevision', repoName], - queryFn: () => getBranches(repoName), - staleTime: 5 * 60 * 1000, - }) - -export default useHfRevisionQuery diff --git a/web/hooks/useImportModel.ts b/web/hooks/useImportModel.ts index 529ae7a8ec..170f03b5ea 100644 --- a/web/hooks/useImportModel.ts +++ b/web/hooks/useImportModel.ts @@ -1,8 +1,26 @@ import { useCallback } from 'react' -import { ImportingModel, Model, OptionType } from '@janhq/core' +import { + ExtensionTypeEnum, + ImportingModel, + Model, + ModelExtension, + OptionType, + baseName, + fs, + joinPath, +} from '@janhq/core' -import { atom } from 'jotai' +import { atom, useSetAtom } from 'jotai' + +import { v4 as uuidv4 } from 'uuid' + +import { snackbar } from '@/containers/Toast' + +import { FilePathWithSize } from '@/utils/file' + +import { extensionManager } from '@/extension' +import { importingModelsAtom } from '@/helpers/atoms/Model.atom' export type ImportModelStage = | 'NONE' @@ -31,82 +49,104 @@ export type ModelUpdate = { } const useImportModel = () => { - // const setImportModelStage = useSetAtom(setImportModelStageAtom) - // const setImportingModels = useSetAtom(importingModelsAtom) + const setImportModelStage = useSetAtom(setImportModelStageAtom) + const setImportingModels = useSetAtom(importingModelsAtom) const importModels = useCallback( - (models: ImportingModel[], optionType: OptionType) => { - console.log('importModels', models, optionType) - // return localImportModels(models, optionType) - }, + (models: ImportingModel[], optionType: OptionType) => + localImportModels(models, optionType), [] ) - const updateModelInfo = useCallback(async (modelInfo: Partial) => { - console.log('updateModelInfo', modelInfo) - // localUpdateModelInfo(modelInfo) - }, []) - - const sanitizeFilePaths = useCallback(async (filePaths: string[]) => { - console.log('sanitizeFilePaths', filePaths) - // if (!filePaths || filePaths.length === 0) return - // const sanitizedFilePaths: FilePathWithSize[] = [] - // for (const filePath of filePaths) { - // const fileStats = await 
fs.fileStat(filePath, true) - // if (!fileStats) continue - // if (!fileStats.isDirectory) { - // const fileName = await baseName(filePath) - // sanitizedFilePaths.push({ - // path: filePath, - // name: fileName, - // size: fileStats.size, - // }) - // } else { - // // allowing only one level of directory - // const files = await fs.readdirSync(filePath) - // for (const file of files) { - // const fullPath = await joinPath([filePath, file]) - // const fileStats = await fs.fileStat(fullPath, true) - // if (!fileStats || fileStats.isDirectory) continue - // sanitizedFilePaths.push({ - // path: fullPath, - // name: file, - // size: fileStats.size, - // }) - // } - // } - // } - // const unsupportedFiles = sanitizedFilePaths.filter( - // (file) => !file.path.endsWith('.gguf') - // ) - // const supportedFiles = sanitizedFilePaths.filter((file) => - // file.path.endsWith('.gguf') - // ) - // const importingModels: ImportingModel[] = supportedFiles.map( - // ({ path, name, size }: FilePathWithSize) => ({ - // importId: uuidv4(), - // modelId: undefined, - // name: name.replace('.gguf', ''), - // description: '', - // path: path, - // tags: [], - // size: size, - // status: 'PREPARING', - // format: 'gguf', - // }) - // ) - // if (unsupportedFiles.length > 0) { - // snackbar({ - // description: `Only files with .gguf extension can be imported.`, - // type: 'error', - // }) - // } - // if (importingModels.length === 0) return - // setImportingModels(importingModels) - // setImportModelStage('MODEL_SELECTED') - }, []) + const updateModelInfo = useCallback( + async (modelInfo: Partial) => localUpdateModelInfo(modelInfo), + [] + ) + + const sanitizeFilePaths = useCallback( + async (filePaths: string[]) => { + if (!filePaths || filePaths.length === 0) return + + const sanitizedFilePaths: FilePathWithSize[] = [] + for (const filePath of filePaths) { + const fileStats = await fs.fileStat(filePath, true) + if (!fileStats) continue + + if (!fileStats.isDirectory) { + const fileName = await baseName(filePath) + sanitizedFilePaths.push({ + path: filePath, + name: fileName, + size: fileStats.size, + }) + } else { + // allowing only one level of directory + const files = await fs.readdirSync(filePath) + + for (const file of files) { + const fullPath = await joinPath([filePath, file]) + const fileStats = await fs.fileStat(fullPath, true) + if (!fileStats || fileStats.isDirectory) continue + + sanitizedFilePaths.push({ + path: fullPath, + name: file, + size: fileStats.size, + }) + } + } + } + + const unsupportedFiles = sanitizedFilePaths.filter( + (file) => !file.path.endsWith('.gguf') + ) + const supportedFiles = sanitizedFilePaths.filter((file) => + file.path.endsWith('.gguf') + ) + + const importingModels: ImportingModel[] = supportedFiles.map( + ({ path, name, size }: FilePathWithSize) => ({ + importId: uuidv4(), + modelId: undefined, + name: name.replace('.gguf', ''), + description: '', + path: path, + tags: [], + size: size, + status: 'PREPARING', + format: 'gguf', + }) + ) + if (unsupportedFiles.length > 0) { + snackbar({ + description: `Only files with .gguf extension can be imported.`, + type: 'error', + }) + } + if (importingModels.length === 0) return + + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, + [setImportModelStage, setImportingModels] + ) return { importModels, updateModelInfo, sanitizeFilePaths } } +const localImportModels = async ( + models: ImportingModel[], + optionType: OptionType +): Promise => + extensionManager + .get(ExtensionTypeEnum.Model) + 
?.importModels(models, optionType) + +const localUpdateModelInfo = async ( + modelInfo: Partial +): Promise => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.updateModelInfo(modelInfo) + export default useImportModel diff --git a/web/hooks/useLoadTheme.ts b/web/hooks/useLoadTheme.ts index bea854b8b3..8afba27c46 100644 --- a/web/hooks/useLoadTheme.ts +++ b/web/hooks/useLoadTheme.ts @@ -2,11 +2,15 @@ import { useCallback, useEffect } from 'react' import { useTheme } from 'next-themes' -import { useAtom, useSetAtom } from 'jotai' +import { fs, joinPath } from '@janhq/core' + +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import cssVars from '@/utils/jsonToCssVariables' +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' import { + janThemesPathAtom, selectedThemeIdAtom, themeDataAtom, themesOptionsAtom, @@ -15,7 +19,9 @@ import { type NativeThemeProps = 'light' | 'dark' export const useLoadTheme = async () => { + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const setThemeOptions = useSetAtom(themesOptionsAtom) + const setThemePath = useSetAtom(janThemesPathAtom) const [themeData, setThemeData] = useAtom(themeDataAtom) const [selectedIdTheme, setSelectedIdTheme] = useAtom(selectedThemeIdAtom) const { setTheme } = useTheme() @@ -36,28 +42,46 @@ export const useLoadTheme = async () => { ) const getThemes = useCallback(async () => { - const themesOptions: { name: string; value: string }[] = - await window.electronAPI?.getThemes() + const folderPath = await joinPath([janDataFolderPath, 'themes']) + const installedThemes = await fs.readdirSync(folderPath) + + const themesOptions: { name: string; value: string }[] = installedThemes + .filter((x: string) => x !== '.DS_Store') + .map(async (x: string) => { + const y = await joinPath([`${folderPath}/${x}`, `theme.json`]) + const c: Theme = JSON.parse(await fs.readFileSync(y, 'utf-8')) + return { name: c?.displayName, value: c.id } + }) Promise.all(themesOptions).then((results) => { setThemeOptions(results) }) - // if (selectedIdTheme === null) return setSelectedIdTheme('joi-light') - - // console.log(typeof selectedIdTheme, 'selectedIdTheme') + if (janDataFolderPath.length > 0) { + if (!selectedIdTheme.length) return setSelectedIdTheme('joi-light') + setThemePath(folderPath) + const filePath = await joinPath([ + `${folderPath}/${selectedIdTheme}`, + `theme.json`, + ]) + const theme: Theme = JSON.parse(await fs.readFileSync(filePath, 'utf-8')) - const theme: Theme = await window.electronAPI.readTheme( - selectedIdTheme || 'joi-light' - ) - - setThemeData(theme) - setNativeTheme(theme.nativeTheme) - const variables = cssVars(theme.variables) - const headTag = document.getElementsByTagName('head')[0] - const styleTag = document.createElement('style') - styleTag.innerHTML = `:root {${variables}}` - headTag.appendChild(styleTag) - }, [selectedIdTheme, setNativeTheme, setThemeData, setThemeOptions]) + setThemeData(theme) + setNativeTheme(theme.nativeTheme) + const variables = cssVars(theme.variables) + const headTag = document.getElementsByTagName('head')[0] + const styleTag = document.createElement('style') + styleTag.innerHTML = `:root {${variables}}` + headTag.appendChild(styleTag) + } + }, [ + janDataFolderPath, + selectedIdTheme, + setNativeTheme, + setSelectedIdTheme, + setThemeData, + setThemeOptions, + setThemePath, + ]) useEffect(() => { getThemes() diff --git a/web/hooks/useLogs.ts b/web/hooks/useLogs.ts index 1594daf638..a391a22782 100644 --- a/web/hooks/useLogs.ts +++ b/web/hooks/useLogs.ts 
@@ -1,30 +1,36 @@ import { useCallback } from 'react' +import { fs, joinPath, openFileExplorer } from '@janhq/core' +import { useAtomValue } from 'jotai' + +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' + export const useLogs = () => { - const getLogs = useCallback(async (file: string): Promise => { - console.log('getLogs', file) - // const path = await joinPath(['file://logs', `${file}.log`]) - // if (!(await fs.existsSync(path))) return '' - // const logs = await fs.readFileSync(path, 'utf-8') + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) - // const sanitizedLogs = logs.replace( - // new RegExp(`${janDataFolderPath}\\/`, 'g'), - // 'jan-data-folder/' - // ) + const getLogs = useCallback( + async (file: string): Promise => { + const path = await joinPath(['file://logs', `${file}.log`]) + if (!(await fs.existsSync(path))) return '' + const logs = await fs.readFileSync(path, 'utf-8') - // return sanitizedLogs + const sanitizedLogs = logs.replace( + new RegExp(`${janDataFolderPath}\\/`, 'g'), + 'jan-data-folder/' + ) - // TODO: @james - read from cortex log - return Promise.resolve('') - }, []) + return sanitizedLogs + }, + [janDataFolderPath] + ) const openServerLog = useCallback(async () => { - // const fullPath = await joinPath([janDataFolderPath, 'logs', 'app.log']) - // return openFileExplorer(fullPath) - }, []) + const fullPath = await joinPath([janDataFolderPath, 'logs', 'app.log']) + return openFileExplorer(fullPath) + }, [janDataFolderPath]) const clearServerLog = useCallback(async () => { - // await fs.writeFileSync(await joinPath(['file://logs', 'app.log']), '') + await fs.writeFileSync(await joinPath(['file://logs', 'app.log']), '') }, []) return { getLogs, openServerLog, clearServerLog } diff --git a/web/hooks/useMessageCreateMutation.ts b/web/hooks/useMessageCreateMutation.ts deleted file mode 100644 index 493b3f0119..0000000000 --- a/web/hooks/useMessageCreateMutation.ts +++ /dev/null @@ -1,30 +0,0 @@ -import { MessageCreateParams } from '@janhq/core' -import { useMutation } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export type MessageCreateMutationVariables = { - threadId: string - createMessageParams: MessageCreateParams -} - -const useMessageCreateMutation = () => { - const { createMessage } = useCortex() - - return useMutation({ - mutationFn: (variables: MessageCreateMutationVariables) => - createMessage(variables), - - onSuccess: (data) => { - console.debug(`Successfully created message: ${JSON.stringify(data)}`) - }, - - onError: (err, variables) => { - console.error( - `Failed to create message with variables: ${JSON.stringify(variables, null, 2)}, err: ${err}` - ) - }, - }) -} - -export default useMessageCreateMutation diff --git a/web/hooks/useMessageDeleteMutation.ts b/web/hooks/useMessageDeleteMutation.ts deleted file mode 100644 index 340d96067b..0000000000 --- a/web/hooks/useMessageDeleteMutation.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { useMutation } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export type MessageDeleteMutationVariables = { - threadId: string - messageId: string -} - -const useMessageDeleteMutation = () => { - const { deleteMessage } = useCortex() - - return useMutation({ - mutationFn: (variables: MessageDeleteMutationVariables) => - deleteMessage(variables), - - onSuccess: (_data, variables) => { - console.debug( - `Successfully deleted message: ${JSON.stringify(variables)}` - ) - }, - - onError: (variables, err) => { - console.error( - `Failed to 
delete message: ${JSON.stringify(variables)}, err: ${err}` - ) - }, - }) -} - -export default useMessageDeleteMutation diff --git a/web/hooks/useMessageQuery.ts b/web/hooks/useMessageQuery.ts deleted file mode 100644 index 6fdef32f49..0000000000 --- a/web/hooks/useMessageQuery.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export const messageQueryKey = ['getMessages'] - -const useMessageQuery = (threadId: string) => { - const { fetchMessages } = useCortex() - - return useQuery({ - queryKey: [...messageQueryKey, threadId], - queryFn: () => fetchMessages(threadId), - staleTime: 30 * 1000, - }) -} - -export default useMessageQuery diff --git a/web/hooks/useMessageUpdateMutation.ts b/web/hooks/useMessageUpdateMutation.ts deleted file mode 100644 index bf7635a6df..0000000000 --- a/web/hooks/useMessageUpdateMutation.ts +++ /dev/null @@ -1,30 +0,0 @@ -import { useMutation } from '@tanstack/react-query' - -import useCortex from './useCortex' - -export type MessageUpdateMutationVariables = { - threadId: string - messageId: string - data: object -} - -const useMessageUpdateMutation = () => { - const { updateMessage } = useCortex() - - return useMutation({ - mutationFn: (variables: MessageUpdateMutationVariables) => - updateMessage(variables), - - onSuccess: (data) => { - console.debug(`Successfully updated message: ${JSON.stringify(data)}`) - }, - - onError: (err, variables) => { - console.error( - `Failed to update message with variables: ${JSON.stringify(variables, null, 2)}, err: ${err}` - ) - }, - }) -} - -export default useMessageUpdateMutation diff --git a/web/hooks/useMigratingData.ts b/web/hooks/useMigratingData.ts deleted file mode 100644 index f45688f255..0000000000 --- a/web/hooks/useMigratingData.ts +++ /dev/null @@ -1,99 +0,0 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { useCallback } from 'react' - -import { defaultThreadTitle } from '@/constants/Threads' - -import useAssistantQuery from './useAssistantQuery' - -import useCortex from './useCortex' -import useMessageCreateMutation from './useMessageCreateMutation' -import useThreads from './useThreads' - -const useMigratingData = () => { - const { createThread } = useThreads() - const { updateThread } = useCortex() - const createMessage = useMessageCreateMutation() - const { data: assistants } = useAssistantQuery() - - const getJanThreadsAndMessages = useCallback(async (): Promise<{ - messages: any[] - threads: any[] - }> => { - return window?.electronAPI?.getAllMessagesAndThreads() - }, []) - - const getJanLocalModels = useCallback(async (): Promise => { - // TODO: change the name of this function - return window?.electronAPI?.getAllLocalModels() - }, []) - - const migrateModels = useCallback(async () => { - return window?.electronAPI?.syncModelFileToCortex() - }, []) - - const migrateThreadsAndMessages = useCallback(async () => { - if (!assistants || assistants.length === 0) { - console.error('No assistant found') - return - } - const threadsAndMessages = await getJanThreadsAndMessages() - const janThreads = threadsAndMessages.threads - - for (const thread of janThreads) { - const modelId: string | undefined = thread.assistants[0]?.model?.id - if (!modelId || modelId.trim().length === 0 || modelId === '*') { - console.error(`Ignore thread ${thread.id} because modelId is not found`) - continue - } - const threadTitle: string = thread.title ?? defaultThreadTitle - const instructions: string = thread.assistants[0]?.instructions ?? 
'' - // currently, we don't have api support for creating thread with messages - const cortexThread = await createThread(modelId, assistants[0]) - - console.log('createThread', cortexThread) - // update instruction - cortexThread.assistants[0].instructions = instructions - cortexThread.title = threadTitle - - // update thread name - await updateThread(cortexThread) - console.log('updateThread', cortexThread) - - // we finished with thread, now continue with messages - const janMessages = threadsAndMessages.messages.filter( - (m) => m.thread_id === thread.id - ) - - for (let j = 0; j < janMessages.length; ++j) { - const janMessage = janMessages[j] - // filter out the system message if any - if (janMessage.role === 'system') continue - const messageContent: string = janMessage.content[0]?.text.value ?? '' - - // can speed up here with Promise.allSettled - await createMessage.mutateAsync({ - threadId: cortexThread.id, - createMessageParams: { - content: messageContent, - role: janMessage.role, - }, - }) - } - } - }, [ - assistants, - getJanThreadsAndMessages, - createThread, - updateThread, - createMessage, - ]) - - return { - migrateModels, - migrateThreadsAndMessages, - getJanThreadsAndMessages, - getJanLocalModels, - } -} - -export default useMigratingData diff --git a/web/hooks/useModelDownloadMutation.ts b/web/hooks/useModelDownloadMutation.ts deleted file mode 100644 index 5b75dea85d..0000000000 --- a/web/hooks/useModelDownloadMutation.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { useMutation } from '@tanstack/react-query' - -import { useSetAtom } from 'jotai' - -import { toaster } from '@/containers/Toast' - -import useCortex from './useCortex' -import { addDownloadModelStateAtom } from './useDownloadState' - -export type DownloadModelMutationVariables = { - modelId: string - fileName?: string - persistedModelId?: string -} - -const useModelDownloadMutation = () => { - const { downloadModel } = useCortex() - const addDownloadState = useSetAtom(addDownloadModelStateAtom) - - return useMutation({ - mutationFn: downloadModel, - - onMutate: (variables) => { - console.debug('Downloading model', variables) - }, - - onSuccess: (data, variables) => { - console.debug('Download response success', data, variables) - - const { persistedModelId, modelId } = variables - if (persistedModelId) { - addDownloadState(persistedModelId) - } else { - addDownloadState(modelId) - } - }, - - onError: (err, variables) => { - console.error('Failed to download model', err, variables) - toaster({ - title: 'Failed to download model', - description: err.message, - type: 'error', - }) - }, - }) -} - -export default useModelDownloadMutation diff --git a/web/hooks/useModelHub.ts b/web/hooks/useModelHub.ts deleted file mode 100644 index a87de67c8d..0000000000 --- a/web/hooks/useModelHub.ts +++ /dev/null @@ -1,191 +0,0 @@ -import { LlmEngine, LlmEngines } from '@janhq/core' - -import { useQueries } from '@tanstack/react-query' - -import { - fetchHuggingFaceModel, - HfModelEntry, - fetchCortexHubModels, -} from '@/utils/huggingface' - -type CuratedModelResponse = { - quickstart_models: QuickStartModel[] - popular_models: CuratedModel[] -} - -export type QuickStartModel = { - note: string - url: string - author: string - logo: string - model_name: string - model_logo: string - size: number - engine: LlmEngine -} - -export type CuratedModel = { note: string; url: string } - -const getFileSize = async (url: string): Promise => { - try { - const response = await fetch(url, { method: 'HEAD' }) - const size = 
response.headers.get('content-length') - return Number(size) - } catch (err) { - console.error('Getting file size failed for:', url, err) - return 0 - } -} - -const fetchBuiltInModels = async (): Promise => { - const response = await fetch( - 'https://raw.githubusercontent.com/janhq/cortex-web/main/static/huggingface/hub.json' - ) - const data = (await response.json()) as CuratedModelResponse - - const getFileSizePromises: Promise[] = data.quickstart_models.map( - (model) => { - const directDownloadUrl = model.url.replace('/blob/', '/resolve/') - return getFileSize(directDownloadUrl) - } - ) - - const sizes = await Promise.all(getFileSizePromises) - data.quickstart_models = data.quickstart_models.map((model, i) => { - const engine = (model.engine ?? 'cortex.llamacpp') as LlmEngine - return { - ...model, - engine, - size: sizes[i], - } - }) - - return data -} - -type BuiltInModels = { - popularModelEntries: HfModelEntry[] - quickStartModels: QuickStartModel[] -} - -const getBuiltInModelEntries = async (): Promise => { - const builtInModels = await fetchBuiltInModels() - const popularModelPaths = builtInModels.popular_models.map( - (model) => model.url - ) - - const result: HfModelEntry[] = [] - const promises: Promise[] = [] - popularModelPaths.forEach((path) => { - try { - const replacedUrl = path.replace('https://huggingface.co/', '') - const ownerName = replacedUrl.split('/')[0] - const repoName = replacedUrl.split('/')[1] - promises.push(fetchHuggingFaceModel(ownerName, repoName)) - } catch (err) { - console.error('Failed to getBuiltInModelEntries:', err) - } - }) - const promiseResult = await Promise.allSettled(promises) - // check if fulfilled or rejected - for (let i = 0; i < promiseResult.length; i++) { - if (promiseResult[i].status === 'fulfilled') { - const fulfillResult = promiseResult[i] as PromiseFulfilledResult< - HfModelEntry[] - > - const modelEntries: HfModelEntry[] = fulfillResult.value as HfModelEntry[] - result.push(...modelEntries) - } else { - console.error('Failed to getBuiltInModelEntries:', promiseResult[i]) - } - } - - return { - popularModelEntries: result, - quickStartModels: builtInModels.quickstart_models, - } -} - -export type ModelHubData = { - sliderData: QuickStartModel[] - modelCategories: Map -} - -export const ModelHubCategoryList = [ - 'BuiltInModels', - 'HuggingFace', - ...Object.values(LlmEngines), -] as const -export type ModelHubCategory = (typeof ModelHubCategoryList)[number] - -export const builtInModelsEntriesQueryKey = ['builtInModelsEntriesQueryKey'] -export const cortexHubModelsQueryKey = ['cortexHubModelsQueryKey'] - -const useModelHub = () => { - const results = useQueries({ - queries: [ - { - queryKey: builtInModelsEntriesQueryKey, - queryFn: getBuiltInModelEntries, - }, - { - queryKey: cortexHubModelsQueryKey, - queryFn: fetchCortexHubModels, - }, - ], - }) - - const isLoading = results.some((result) => result.isLoading) - const isError = results.some((result) => result.isError) - const error = results.find((result) => result.error)?.error - - const data: ModelHubData | undefined = (() => { - if (results.every((result) => result.isSuccess)) { - const data: ModelHubData = { - sliderData: [], - modelCategories: new Map(), - } - if (results[0].data) { - // quick start - data.sliderData = results[0].data.quickStartModels - - // popular - data.modelCategories.set( - 'HuggingFace', - results[0].data.popularModelEntries - ) - } - - if (results[1].data) { - data.modelCategories.set( - 'BuiltInModels', - results[1].data.filter( - (modelEntry) => 
modelEntry.remoteModel === false - ) - ) - - // for remote models - results[1].data.forEach((modelEntry) => { - const engine = modelEntry.engine - if (modelEntry.remoteModel === true && engine) { - if (data.modelCategories.has(engine)) { - data.modelCategories.set( - engine, - data.modelCategories.get(engine)!.concat(modelEntry) - ) - } else { - data.modelCategories.set(engine, [modelEntry]) - } - } - }) - } - - return data - } - return undefined - })() - - return { data, isLoading, isError, error } -} - -export default useModelHub diff --git a/web/hooks/useModelQuery.ts b/web/hooks/useModelQuery.ts deleted file mode 100644 index ddc77ad638..0000000000 --- a/web/hooks/useModelQuery.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { useSetAtom } from 'jotai' - -import useCortex from './useCortex' - -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -export const modelQueryKey = ['getModels'] - -const useModelQuery = () => { - const { fetchModels } = useCortex() - const setDownloadedModels = useSetAtom(downloadedModelsAtom) - - return useQuery({ - queryKey: modelQueryKey, - queryFn: async () => { - const models = await fetchModels() - setDownloadedModels(models) - return models - }, - staleTime: 30 * 1000, - }) -} - -export default useModelQuery diff --git a/web/hooks/useModelStart.ts b/web/hooks/useModelStart.ts deleted file mode 100644 index 17fff28f3d..0000000000 --- a/web/hooks/useModelStart.ts +++ /dev/null @@ -1,28 +0,0 @@ -import { useMutation } from '@tanstack/react-query' - -import { toaster } from '@/containers/Toast' - -import useCortex from './useCortex' - -const useModelStart = () => { - const { startModel } = useCortex() - - return useMutation({ - mutationFn: (modelId: string) => startModel(modelId), - - onSuccess: (data, variables) => { - console.debug('Model started', variables, data) - }, - - onError: (error, variables) => { - toaster({ - title: 'Failed to send message', - description: `Failed to start model ${variables}. 
Please try again!`, - type: 'error', - }) - console.error('Failed to start model', variables, error) - }, - }) -} - -export default useModelStart diff --git a/web/hooks/useModelStop.ts b/web/hooks/useModelStop.ts deleted file mode 100644 index 891f38b34c..0000000000 --- a/web/hooks/useModelStop.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { useMutation } from '@tanstack/react-query' - -import useCortex from './useCortex' - -const useModelStop = () => { - const { stopModel } = useCortex() - - return useMutation({ - mutationFn: stopModel, - - onSuccess: (data, modelId) => { - console.debug(`Model ${modelId} stopped successfully`, data) - }, - - onError: (error, modelId) => { - console.debug(`Stop model ${modelId} error`, error) - }, - }) -} - -export default useModelStop diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index 4d097a6da1..5a6f13e035 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -1,39 +1,76 @@ -import { useCallback } from 'react' +import { useCallback, useEffect } from 'react' + +import { + ExtensionTypeEnum, + Model, + ModelEvent, + ModelExtension, + events, +} from '@janhq/core' import { useSetAtom } from 'jotai' -import { toaster } from '@/containers/Toast' +import { extensionManager } from '@/extension' +import { + configuredModelsAtom, + defaultModelAtom, + downloadedModelsAtom, +} from '@/helpers/atoms/Model.atom' -import useCortex from './useCortex' +const useModels = () => { + const setDownloadedModels = useSetAtom(downloadedModelsAtom) + const setConfiguredModels = useSetAtom(configuredModelsAtom) + const setDefaultModel = useSetAtom(defaultModelAtom) -import { removeDownloadedModelAtom } from '@/helpers/atoms/Model.atom' + const getData = useCallback(() => { + const getDownloadedModels = async () => { + const models = await getLocalDownloadedModels() + setDownloadedModels(models) + } -const useModels = () => { - const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) - const { deleteModel: cortexDeleteModel, updateModel: cortexUpdateModel } = - useCortex() - - const deleteModel = useCallback( - async (modelId: string) => { - await cortexDeleteModel(modelId) - removeDownloadedModel(modelId) - - toaster({ - title: 'Model Deletion Successful', - description: `Model ${modelId} has been successfully deleted.`, - type: 'success', - }) - }, - [removeDownloadedModel, cortexDeleteModel] - ) - - const updateModel = useCallback( - async (modelId: string, modelSettings: Record) => - cortexUpdateModel(modelId, modelSettings), - [cortexUpdateModel] - ) - - return { deleteModel, updateModel } + const getConfiguredModels = async () => { + const models = await getLocalConfiguredModels() + setConfiguredModels(models) + } + + const getDefaultModel = async () => { + const defaultModel = await getLocalDefaultModel() + setDefaultModel(defaultModel) + } + + Promise.all([ + getDownloadedModels(), + getConfiguredModels(), + getDefaultModel(), + ]) + }, [setDownloadedModels, setConfiguredModels, setDefaultModel]) + + useEffect(() => { + // Try get data on mount + getData() + + // Listen for model updates + events.on(ModelEvent.OnModelsUpdate, async () => getData()) + return () => { + // Remove listener on unmount + events.off(ModelEvent.OnModelsUpdate, async () => {}) + } + }, [getData]) } +const getLocalDefaultModel = async (): Promise => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.getDefaultModel() + +const getLocalConfiguredModels = async (): Promise => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.getConfiguredModels() ?? 
[] + +const getLocalDownloadedModels = async (): Promise => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.getDownloadedModels() ?? [] + export default useModels diff --git a/web/hooks/usePath.ts b/web/hooks/usePath.ts index e8886bbfa4..98e3009b49 100644 --- a/web/hooks/usePath.ts +++ b/web/hooks/usePath.ts @@ -1,94 +1,100 @@ +import { openFileExplorer, joinPath, baseName } from '@janhq/core' import { useAtomValue } from 'jotai' +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' export const usePath = () => { - // const janDataFolderPath = useAtomValue(janDataFolderPathAtom) + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const activeThread = useAtomValue(activeThreadAtom) + const selectedModel = useAtomValue(selectedModelAtom) const onRevealInFinder = async (type: string) => { - console.log('onRevealInFinder', type) - // // TODO: this logic should be refactored. - // if (type !== 'Model' && !activeThread) return - // let filePath = undefined - // const assistantId = activeThread?.assistants[0]?.assistant_id - // switch (type) { - // case 'Engine': - // case 'Thread': - // filePath = await joinPath(['threads', activeThread?.id ?? '']) - // break - // case 'Model': - // if (!selectedModel) return - // filePath = await joinPath(['models', selectedModel.model]) - // break - // case 'Tools': - // case 'Assistant': - // if (!assistantId) return - // filePath = await joinPath(['assistants', assistantId]) - // break - // case 'Logs': - // filePath = 'logs' - // break - // default: - // break - // } - // if (!filePath) return - // const fullPath = await joinPath([janDataFolderPath, filePath]) - // openFileExplorer(fullPath) + // TODO: this logic should be refactored. + if (type !== 'Model' && !activeThread) return + + let filePath = undefined + const assistantId = activeThread?.assistants[0]?.assistant_id + switch (type) { + case 'Engine': + case 'Thread': + filePath = await joinPath(['threads', activeThread?.id ?? '']) + break + case 'Model': + if (!selectedModel) return + filePath = await joinPath(['models', selectedModel.id]) + break + case 'Tools': + case 'Assistant': + if (!assistantId) return + filePath = await joinPath(['assistants', assistantId]) + break + case 'Logs': + filePath = 'logs' + break + default: + break + } + + if (!filePath) return + const fullPath = await joinPath([janDataFolderPath, filePath]) + openFileExplorer(fullPath) } const onViewJson = async (type: string) => { - console.log('onViewJson', type) - // // TODO: this logic should be refactored. - // if (type !== 'Model' && !activeThread) return - // let filePath = undefined - // const assistantId = activeThread?.assistants[0]?.assistant_id - // switch (type) { - // case 'Engine': - // case 'Thread': - // filePath = await joinPath([ - // 'threads', - // activeThread?.id ?? '', - // 'thread.json', - // ]) - // break - // case 'Model': - // if (!selectedModel) return - // filePath = await joinPath(['models', selectedModel.model, 'model.json']) - // break - // case 'Assistant': - // case 'Tools': - // if (!assistantId) return - // filePath = await joinPath(['assistants', assistantId, 'assistant.json']) - // break - // default: - // break - // } - // if (!filePath) return - // const fullPath = await joinPath([janDataFolderPath, filePath]) - // openFileExplorer(fullPath) + // TODO: this logic should be refactored. 
+ if (type !== 'Model' && !activeThread) return + + let filePath = undefined + const assistantId = activeThread?.assistants[0]?.assistant_id + switch (type) { + case 'Engine': + case 'Thread': + filePath = await joinPath([ + 'threads', + activeThread?.id ?? '', + 'thread.json', + ]) + break + case 'Model': + if (!selectedModel) return + filePath = await joinPath(['models', selectedModel.id, 'model.json']) + break + case 'Assistant': + case 'Tools': + if (!assistantId) return + filePath = await joinPath(['assistants', assistantId, 'assistant.json']) + break + default: + break + } + + if (!filePath) return + const fullPath = await joinPath([janDataFolderPath, filePath]) + openFileExplorer(fullPath) } const onViewFile = async (id: string) => { if (!activeThread) return - console.log('onViewFile', id) - // let filePath = undefined - // id = await baseName(id) - // filePath = await joinPath(['threads', `${activeThread.id}/files`, `${id}`]) - // if (!filePath) return - // const fullPath = await joinPath([janDataFolderPath, filePath]) - // openFileExplorer(fullPath) + let filePath = undefined + + id = await baseName(id) + filePath = await joinPath(['threads', `${activeThread.id}/files`, `${id}`]) + if (!filePath) return + const fullPath = await joinPath([janDataFolderPath, filePath]) + openFileExplorer(fullPath) } const onViewFileContainer = async () => { if (!activeThread) return - // let filePath = undefined - // filePath = await joinPath(['threads', `${activeThread.id}/files`]) - // if (!filePath) return - // const fullPath = await joinPath([janDataFolderPath, filePath]) - // openFileExplorer(fullPath) + let filePath = undefined + filePath = await joinPath(['threads', `${activeThread.id}/files`]) + if (!filePath) return + const fullPath = await joinPath([janDataFolderPath, filePath]) + openFileExplorer(fullPath) } return { diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts new file mode 100644 index 0000000000..8122e2b77a --- /dev/null +++ b/web/hooks/useRecommendedModel.ts @@ -0,0 +1,103 @@ +import { useCallback, useEffect, useState } from 'react' + +import { Model, InferenceEngine } from '@janhq/core' + +import { atom, useAtomValue } from 'jotai' + +import { activeModelAtom } from './useActiveModel' + +import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' + +export const lastUsedModel = atom(undefined) + +export const LAST_USED_MODEL_ID = 'last-used-model-id' + +/** + * A hook that return the recommended model when user + * wants to create a new thread. + * + * The precedence is as follows: + * 1. Active model + * 2. If no active model(s), then the last used model + * 3. If no active or last used model, then the 1st model on the list + */ +export default function useRecommendedModel() { + const activeModel = useAtomValue(activeModelAtom) + const [sortedModels, setSortedModels] = useState([]) + const [recommendedModel, setRecommendedModel] = useState() + const activeThread = useAtomValue(activeThreadAtom) + const downloadedModels = useAtomValue(downloadedModelsAtom) + + const getAndSortDownloadedModels = useCallback(async (): Promise => { + const models = downloadedModels.sort((a, b) => + a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro + ? 
1 + : -1 + ) + setSortedModels(models) + return models + }, [downloadedModels]) + + const getRecommendedModel = useCallback(async (): Promise< + Model | undefined + > => { + const models = await getAndSortDownloadedModels() + if (!activeThread) return + const modelId = activeThread.assistants[0]?.model.id + const model = models.find((model) => model.id === modelId) + + if (model) { + setRecommendedModel(model) + } + + if (activeModel) { + // if we have active model alr, then we can just use that + console.debug(`Using active model ${activeModel.id}`) + setRecommendedModel(activeModel) + return + } + + // sort the model, for display purpose + + if (models.length === 0) { + // if we have no downloaded models, then can't recommend anything + console.debug("No downloaded models, can't recommend anything") + return + } + + // otherwise, get the last used model id + const lastUsedModelId = localStorage.getItem(LAST_USED_MODEL_ID) + + // if we don't have [lastUsedModelId], then we can just use the first model + // in the downloaded list + if (!lastUsedModelId) { + console.debug( + `No last used model, using first model in list ${models[0].id}}` + ) + setRecommendedModel(models[0]) + return + } + + const lastUsedModel = models.find((model) => model.id === lastUsedModelId) + if (!lastUsedModel) { + // if we can't find the last used model, then we can just use the first model + // in the downloaded list + console.debug( + `Last used model ${lastUsedModelId} not found, using first model in list ${models[0].id}}` + ) + setRecommendedModel(models[0]) + return + } + + console.debug(`Using last used model ${lastUsedModel.id}`) + setRecommendedModel(lastUsedModel) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [getAndSortDownloadedModels, activeThread]) + + useEffect(() => { + getRecommendedModel() + }, [getRecommendedModel]) + + return { recommendedModel, downloadedModels: sortedModels } +} diff --git a/web/hooks/useSelectModel.ts b/web/hooks/useSelectModel.ts deleted file mode 100644 index fefca4c6a3..0000000000 --- a/web/hooks/useSelectModel.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { useCallback } from 'react' - -import { Model } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' - -import useCortex from './useCortex' - -import { updateSelectedModelAtom } from '@/helpers/atoms/Model.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' - -const useSelectModel = () => { - const updateSelectedModel = useSetAtom(updateSelectedModelAtom) - const activeThread = useAtomValue(activeThreadAtom) - const { updateThread } = useCortex() - - const selectModel = useCallback( - (model: Model) => { - if (activeThread) { - console.debug( - `Set model id ${model.model} to active thread ${activeThread.id}` - ) - activeThread.assistants[0].model = model.model - updateThread(activeThread) - } - - updateSelectedModel(model) - }, - [activeThread, updateSelectedModel, updateThread] - ) - - return { selectModel } -} - -export default useSelectModel diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts new file mode 100644 index 0000000000..8c6013505b --- /dev/null +++ b/web/hooks/useSendChatMessage.ts @@ -0,0 +1,308 @@ +import { useEffect, useRef } from 'react' + +import { + ChatCompletionRole, + MessageRequestType, + ExtensionTypeEnum, + Thread, + ThreadMessage, + Model, + ConversationalExtension, + EngineManager, + ToolManager, + ChatCompletionMessage, +} from '@janhq/core' +import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' + +import { 
+ currentPromptAtom, + editPromptAtom, + fileUploadAtom, +} from '@/containers/Providers/Jotai' + +import { Stack } from '@/utils/Stack' +import { compressImage, getBase64 } from '@/utils/base64' +import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' +import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' + +import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' + +import { loadModelErrorAtom, useActiveModel } from './useActiveModel' + +import { extensionManager } from '@/extension/ExtensionManager' +import { + addNewMessageAtom, + deleteMessageAtom, + getCurrentChatMessagesAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' +import { + activeThreadAtom, + engineParamsUpdateAtom, + getActiveThreadModelParamsAtom, + isGeneratingResponseAtom, + updateThreadAtom, + updateThreadWaitingForResponseAtom, +} from '@/helpers/atoms/Thread.atom' + +export const queuedMessageAtom = atom(false) +export const reloadModelAtom = atom(false) + +export default function useSendChatMessage() { + const activeThread = useAtomValue(activeThreadAtom) + const addNewMessage = useSetAtom(addNewMessageAtom) + const updateThread = useSetAtom(updateThreadAtom) + const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) + const setCurrentPrompt = useSetAtom(currentPromptAtom) + const deleteMessage = useSetAtom(deleteMessageAtom) + const setEditPrompt = useSetAtom(editPromptAtom) + + const currentMessages = useAtomValue(getCurrentChatMessagesAtom) + const selectedModel = useAtomValue(selectedModelAtom) + const { activeModel, startModel } = useActiveModel() + const loadModelFailed = useAtomValue(loadModelErrorAtom) + + const modelRef = useRef() + const loadModelFailedRef = useRef() + const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) + const engineParamsUpdate = useAtomValue(engineParamsUpdateAtom) + + const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) + const setReloadModel = useSetAtom(reloadModelAtom) + const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) + const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) + const activeThreadRef = useRef() + const setQueuedMessage = useSetAtom(queuedMessageAtom) + + const selectedModelRef = useRef() + + useEffect(() => { + modelRef.current = activeModel + }, [activeModel]) + + useEffect(() => { + loadModelFailedRef.current = loadModelFailed + }, [loadModelFailed]) + + useEffect(() => { + activeThreadRef.current = activeThread + }, [activeThread]) + + useEffect(() => { + selectedModelRef.current = selectedModel + }, [selectedModel]) + + const normalizeMessages = ( + messages: ChatCompletionMessage[] + ): ChatCompletionMessage[] => { + const stack = new Stack() + for (const message of messages) { + if (stack.isEmpty()) { + stack.push(message) + continue + } + const topMessage = stack.peek() + + if (message.role === topMessage.role) { + // add an empty message + stack.push({ + role: + topMessage.role === ChatCompletionRole.User + ? 
ChatCompletionRole.Assistant + : ChatCompletionRole.User, + content: '.', // some models require a non-empty message + }) + } + stack.push(message) + } + + return stack.reverseOutput() + } + + const resendChatMessage = async (currentMessage: ThreadMessage) => { + if (!activeThreadRef.current) { + console.error('No active thread') + return + } + updateThreadWaiting(activeThreadRef.current.id, true) + + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + activeThreadRef.current.assistants[0].model ?? selectedModelRef.current, + activeThreadRef.current, + currentMessages + ) + .addSystemMessage(activeThreadRef.current.assistants[0]?.instructions) + .removeLastAssistantMessage() + + const modelId = + selectedModelRef.current?.id ?? + activeThreadRef.current.assistants[0].model.id + + if (modelRef.current?.id !== modelId) { + const error = await startModel(modelId).catch((error: Error) => error) + if (error) { + updateThreadWaiting(activeThreadRef.current.id, false) + return + } + } + + setIsGeneratingResponse(true) + + if (currentMessage.role !== ChatCompletionRole.User) { + // Delete last response before regenerating + deleteMessage(currentMessage.id ?? '') + if (activeThreadRef.current) { + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.writeMessages( + activeThreadRef.current.id, + currentMessages.filter((msg) => msg.id !== currentMessage.id) + ) + } + } + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? [] + ) + + request.messages = normalizeMessages(request.messages ?? []) + + const engine = + requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? '' + + EngineManager.instance().get(engine)?.inference(request) + } + + const sendChatMessage = async (message: string) => { + if (!message || message.trim().length === 0) return + + if (!activeThreadRef.current) { + console.error('No active thread') + return + } + + if (engineParamsUpdate) setReloadModel(true) + + const runtimeParams = toRuntimeParams(activeModelParams) + const settingParams = toSettingParams(activeModelParams) + + const prompt = message.trim() + + updateThreadWaiting(activeThreadRef.current.id, true) + setCurrentPrompt('') + setEditPrompt('') + + let base64Blob = fileUpload[0] + ? await getBase64(fileUpload[0].file) + : undefined + + if (base64Blob && fileUpload[0]?.type === 'image') { + // Compress image + base64Blob = await compressImage(base64Blob, 512) + } + + const modelRequest = + selectedModelRef?.current ?? 
activeThreadRef.current.assistants[0].model + + // Fallback support for previous broken threads + if (activeThreadRef.current?.assistants[0]?.model?.id === '*') { + activeThreadRef.current.assistants[0].model = { + id: modelRequest.id, + settings: modelRequest.settings, + parameters: modelRequest.parameters, + } + } + if (runtimeParams.stream == null) { + runtimeParams.stream = true + } + + // Build Message Request + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + { + ...modelRequest, + settings: settingParams, + parameters: runtimeParams, + }, + activeThreadRef.current, + currentMessages + ).addSystemMessage(activeThreadRef.current.assistants[0].instructions) + + requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type) + + // Build Thread Message to persist + const threadMessageBuilder = new ThreadMessageBuilder( + requestBuilder + ).pushMessage(prompt, base64Blob, fileUpload) + + const newMessage = threadMessageBuilder.build() + + // Push to states + addNewMessage(newMessage) + + // Update thread state + const updatedThread: Thread = { + ...activeThreadRef.current, + updated: newMessage.created, + metadata: { + ...(activeThreadRef.current.metadata ?? {}), + lastMessage: prompt, + }, + } + updateThread(updatedThread) + + // Add message + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.addNewMessage(newMessage) + + // Start Model if not started + const modelId = + selectedModelRef.current?.id ?? + activeThreadRef.current.assistants[0].model.id + + if (base64Blob) { + setFileUpload([]) + } + + if (modelRef.current?.id !== modelId) { + setQueuedMessage(true) + const error = await startModel(modelId).catch((error: Error) => error) + setQueuedMessage(false) + if (error) { + updateThreadWaiting(activeThreadRef.current.id, false) + return + } + } + setIsGeneratingResponse(true) + + // Process message request with Assistants tools + const request = await ToolManager.instance().process( + requestBuilder.build(), + activeThreadRef.current.assistants?.flatMap( + (assistant) => assistant.tools ?? [] + ) ?? [] + ) + request.messages = normalizeMessages(request.messages ?? []) + + // Request for inference + EngineManager.instance() + .get(requestBuilder.model?.engine ?? modelRequest.engine ?? 
'') + ?.inference(request) + + // Reset states + setReloadModel(false) + setEngineParamsUpdate(false) + } + + return { + sendChatMessage, + resendChatMessage, + } +} diff --git a/web/hooks/useSendMessage.ts b/web/hooks/useSendMessage.ts deleted file mode 100644 index bd81d367bd..0000000000 --- a/web/hooks/useSendMessage.ts +++ /dev/null @@ -1,813 +0,0 @@ -import { useCallback, useRef } from 'react' - -import { - ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam, - EngineStatus, - LocalEngines, - Message, - MessageContent, - RemoteEngines, - TextContentBlock, - Thread, -} from '@janhq/core' - -import { useAtomValue, useSetAtom } from 'jotai' - -import { currentPromptAtom, editPromptAtom } from '@/containers/Providers/Jotai' - -import { toaster } from '@/containers/Toast' - -import { defaultThreadTitle } from '@/constants/Threads' - -import { inferenceErrorAtom } from '@/screens/HubScreen2/components/InferenceErrorModal' - -import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' -import { concurrentModelWarningThreshold } from '@/screens/Settings/MyModels/ModelItem' - -import { Stack } from '@/utils/Stack' - -import useCortex from './useCortex' - -import useEngineInit from './useEngineInit' -import useEngineQuery from './useEngineQuery' -import useMessageCreateMutation from './useMessageCreateMutation' -import useMessageUpdateMutation from './useMessageUpdateMutation' - -import useModelStart from './useModelStart' - -import { - addNewMessageAtom, - chunkCountAtom, - disableStopInferenceAtom, - getCurrentChatMessagesAtom, - updateMessageAtom, -} from '@/helpers/atoms/ChatMessage.atom' -import { - activeModelsAtom, - getSelectedModelAtom, -} from '@/helpers/atoms/Model.atom' -import { - activeThreadAtom, - addThreadIdShouldAnimateTitleAtom, - isGeneratingResponseAtom, - updateThreadTitleAtom, -} from '@/helpers/atoms/Thread.atom' - -const normalizeMessages = ( - messages: ChatCompletionMessageParam[] -): ChatCompletionMessageParam[] => { - const stack = new Stack() - for (const message of messages) { - if (stack.isEmpty()) { - stack.push(message) - continue - } - const topMessage = stack.peek() - - if (message.role === topMessage.role) { - // add an empty message - stack.push({ - role: topMessage.role === 'user' ? 
'assistant' : 'user', - content: '.', // some model requires not empty message - }) - } - stack.push(message) - } - - return stack.reverseOutput() -} - -const useSendMessage = () => { - const createMessage = useMessageCreateMutation() - const updateMessage = useMessageUpdateMutation() - const initializeEngine = useEngineInit() - const addNewMessage = useSetAtom(addNewMessageAtom) - const { chatCompletionStreaming, chatCompletionNonStreaming, updateThread } = - useCortex() - const updateMessageState = useSetAtom(updateMessageAtom) - const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) - const setCurrentPrompt = useSetAtom(currentPromptAtom) - const setEditPrompt = useSetAtom(editPromptAtom) - const updateThreadTitle = useSetAtom(updateThreadTitleAtom) - const addThreadIdShouldAnimateTitle = useSetAtom( - addThreadIdShouldAnimateTitleAtom - ) - const { data: engineData } = useEngineQuery() - - const activeThread = useAtomValue(activeThreadAtom) - const activeModels = useAtomValue(activeModelsAtom) - const currentMessages = useAtomValue(getCurrentChatMessagesAtom) - const selectedModel = useAtomValue(getSelectedModelAtom) - const startModel = useModelStart() - - const abortControllerRef = useRef(undefined) - const didUserAborted = useRef(false) - const setInferenceErrorAtom = useSetAtom(inferenceErrorAtom) - const setShowWarningMultipleModelModal = useSetAtom( - showWarningMultipleModelModalAtom - ) - - const setDisableStopInference = useSetAtom(disableStopInferenceAtom) - const setChunkCount = useSetAtom(chunkCountAtom) - - const validatePrerequisite = useCallback(async (): Promise => { - const errorTitle = 'Failed to send message' - if (!activeThread) { - toaster({ - title: errorTitle, - description: 'No active thread! Please select a thread!', - type: 'error', - }) - return false - } - if (!selectedModel) { - toaster({ - title: errorTitle, - description: 'No model selected! Please select a model!', - type: 'error', - }) - return false - } - if (!engineData) { - toaster({ - title: errorTitle, - description: - 'Jan failed to fetch available engine data! 
Please try restart the app!', - type: 'error', - }) - return false - } - - try { - if (selectedModel.model !== activeThread.assistants[0].model) { - activeThread.assistants[0].model = selectedModel.model - await updateThread(activeThread) - } - } catch (err) { - toaster({ - title: errorTitle, - description: 'Please try select model for this thread again!', - type: 'error', - }) - console.error(`Failed to update thread ${activeThread.id}, error: ${err}`) - return false - } - - if (!selectedModel.engine) { - toaster({ - title: errorTitle, - description: `Model ${selectedModel.model} does not have an engine`, - type: 'error', - }) - console.error(`Model ${selectedModel.model} does not have an engine`) - return false - } - - const engineStatus = engineData.find((e) => e.name === selectedModel.engine) - if (!engineStatus) { - toaster({ - title: errorTitle, - description: `Engine ${selectedModel.engine} is not available`, - type: 'error', - }) - console.error(`Engine ${selectedModel.engine} is not available`) - return false - } - - if ( - RemoteEngines.find((e) => e === selectedModel.engine) != null && - engineStatus.status === 'missing_configuration' - ) { - toaster({ - title: errorTitle, - description: `Engine ${engineStatus.name} is missing configuration`, - type: 'error', - }) - console.error(`Engine ${engineStatus.name} is missing configuration`) - return false - } - - if ( - LocalEngines.find((e) => e === selectedModel.engine) != null && - engineStatus.status === 'not_initialized' - ) { - toaster({ - title: 'Please wait for engine to initialize', - description: `Please retry after engine ${engineStatus.name} is installed.`, - type: 'default', - }) - initializeEngine.mutate(selectedModel.engine) - return false - } - - if (engineStatus.status !== EngineStatus.Ready) { - toaster({ - title: errorTitle, - description: `Engine ${engineStatus.name} is not ready`, - type: 'error', - }) - console.error(`Engine ${engineStatus.name} is not ready`) - return false - } - - return true - }, [activeThread, selectedModel, engineData, initializeEngine, updateThread]) - - const stopInference = useCallback(() => { - abortControllerRef.current?.abort() - didUserAborted.current = true - }, []) - - const summarizeThread = useCallback( - async (messages: string[], modelId: string, thread: Thread) => { - // if its a local model, and is not started, skip summarization - if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { - if (!activeModels.map((model) => model.model).includes(modelId)) { - return - } - } - const maxWordForThreadTitle = 10 - const summarizeMessages: ChatCompletionMessageParam[] = [ - { - role: 'user', - content: `Summarize in a ${maxWordForThreadTitle}-word title the following conversation:\n\n${messages.join('\n')}`, - }, - ] - - const summarizeParams: ChatCompletionCreateParamsNonStreaming = { - messages: summarizeMessages, - model: modelId, - max_tokens: 150, - temperature: 0.5, - } - const summarizeStream = await chatCompletionNonStreaming(summarizeParams) - const summarizedText = ( - summarizeStream.choices[0].message.content ?? 
'New Thread' - ).replace(/"/g, '') - - addThreadIdShouldAnimateTitle(thread.id) - updateThread({ ...thread, title: summarizedText }) - updateThreadTitle(thread.id, summarizedText) - }, - [ - activeModels, - selectedModel, - addThreadIdShouldAnimateTitle, - chatCompletionNonStreaming, - updateThreadTitle, - updateThread, - ] - ) - - const resendMessage = useCallback(async () => { - const isValid = await validatePrerequisite() - if (!isValid) return - - const modelId = activeThread!.assistants[0].model - - try { - // start model if not yet started - if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { - // start model if local and not started - if (!activeModels.map((model) => model.model).includes(modelId)) { - if (activeModels.length >= concurrentModelWarningThreshold) { - // if max concurrent models reached, stop the first model - // display popup - setShowWarningMultipleModelModal(true) - } - await startModel.mutateAsync(modelId) - } - } - } catch (err) { - console.error(`Failed to start model ${modelId}, error: ${err}`) - toaster({ - title: 'Failed to start model', - description: `Failed to start model ${modelId}`, - type: 'error', - }) - } - - setIsGeneratingResponse(true) - - // building messages - const systemMessage: ChatCompletionMessageParam = { - role: 'system', - content: activeThread!.assistants[0].instructions ?? '', - } - - let messages: ChatCompletionMessageParam[] = currentMessages - .map((msg) => { - switch (msg.role) { - case 'user': - case 'assistant': - return { - role: msg.role, - content: - msg.content[0] != null - ? (msg.content[0] as TextContentBlock).text.value - : '', - } - - // we will need to support other roles in the future - default: - break - } - }) - .filter((msg) => msg != null) as ChatCompletionMessageParam[] - messages.unshift(systemMessage) - messages = normalizeMessages(messages) - const modelOptions: Record = {} - if (selectedModel!.frequency_penalty) { - modelOptions.frequency_penalty = selectedModel!.frequency_penalty - } - if (selectedModel!.presence_penalty) { - modelOptions.presence_penalty = selectedModel!.presence_penalty - } - try { - let assistantResponseMessage = '' - if (selectedModel!.stream === true) { - const stream = await chatCompletionStreaming({ - messages, - model: selectedModel!.model, - stream: true, - max_tokens: selectedModel!.max_tokens, - stop: selectedModel!.stop, - temperature: selectedModel!.temperature ?? 1, - top_p: selectedModel!.top_p ?? 
1, - ...modelOptions, - }) - - didUserAborted.current = false - abortControllerRef.current = stream.controller - - const assistantMessage = await createMessage.mutateAsync({ - threadId: activeThread!.id, - createMessageParams: { - role: 'assistant', - content: '', - }, - }) - - const responseMessage: Message = { - id: assistantMessage.id, - thread_id: activeThread!.id, - assistant_id: activeThread!.id, - role: 'assistant', - content: [], - status: 'in_progress', - created_at: assistantMessage.created_at, - metadata: undefined, - attachments: null, - completed_at: Date.now(), - incomplete_at: null, - incomplete_details: null, - object: 'thread.message', - run_id: null, - } - - addNewMessage(responseMessage) - - let chunkCount = 1 - for await (const chunk of stream) { - setChunkCount((prev) => ({ - ...prev, - [responseMessage.id]: chunkCount++, - })) - const content = chunk.choices[0]?.delta?.content || '' - assistantResponseMessage += content - const messageContent: MessageContent = { - type: 'text', - text: { - value: assistantResponseMessage, - annotations: [], - }, - } - responseMessage.content = [messageContent] - updateMessageState( - responseMessage.id, - responseMessage.thread_id, - responseMessage.content, - responseMessage.status - ) - } - - abortControllerRef.current = undefined - - responseMessage.status = 'completed' - updateMessageState( - responseMessage.id, - responseMessage.thread_id, - responseMessage.content, - responseMessage.status - ) - - updateMessage.mutateAsync({ - threadId: activeThread!.id, - messageId: responseMessage.id, - data: { - content: responseMessage.content, - }, - }) - } else { - didUserAborted.current = false - const abortController = new AbortController() - const response = await chatCompletionNonStreaming( - { - messages, - model: selectedModel!.model, - stream: false, - max_tokens: selectedModel!.max_tokens, - stop: selectedModel!.stop, - temperature: selectedModel!.temperature ?? 1, - top_p: selectedModel!.top_p ?? 1, - ...modelOptions, - }, - { - signal: abortController.signal, - } - ) - - assistantResponseMessage = response.choices[0].message.content ?? 
'' - const assistantMessage = await createMessage.mutateAsync({ - threadId: activeThread!.id, - createMessageParams: { - role: 'assistant', - content: assistantResponseMessage, - }, - }) - - const responseMessage: Message = { - id: assistantMessage.id, - thread_id: activeThread!.id, - assistant_id: activeThread!.id, - role: 'assistant', - content: [ - { - type: 'text', - text: { - value: assistantResponseMessage, - annotations: [], - }, - }, - ], - status: 'completed', - created_at: assistantMessage.created_at, - metadata: undefined, - attachments: null, - completed_at: Date.now(), - incomplete_at: null, - incomplete_details: null, - object: 'thread.message', - run_id: null, - } - updateMessage.mutate({ - threadId: activeThread!.id, - messageId: responseMessage.id, - data: { - content: responseMessage.content, - }, - }) - addNewMessage(responseMessage) - } - } catch (err) { - console.error(err) - // @ts-expect-error error message should be there - const errorMessage = err['message'] - if (errorMessage != null) { - setInferenceErrorAtom({ - engine: selectedModel!.engine, - message: errorMessage, - }) - } - - toaster({ - title: `Error with ${selectedModel!.model}`, - description: 'Failed to generate response', - type: 'error', - }) - } - - setIsGeneratingResponse(false) - }, [ - activeThread, - activeModels, - currentMessages, - selectedModel, - updateMessage, - createMessage, - startModel, - setInferenceErrorAtom, - validatePrerequisite, - updateMessageState, - addNewMessage, - chatCompletionNonStreaming, - chatCompletionStreaming, - setIsGeneratingResponse, - setShowWarningMultipleModelModal, - setChunkCount, - ]) - - const sendMessage = useCallback( - async (message: string) => { - const isValid = await validatePrerequisite() - if (!isValid) return - - let shouldSummarize = - activeThread!.title === defaultThreadTitle || - activeThread!.title.trim() === '' - const modelId = activeThread!.assistants[0].model - - setCurrentPrompt('') - setEditPrompt('') - - const userMessage = await createMessage.mutateAsync({ - threadId: activeThread!.id, - createMessageParams: { - role: 'user', - content: message, - }, - }) - // Push to states - addNewMessage(userMessage) - - try { - // start model if not yet started - if (LocalEngines.find((e) => e === selectedModel!.engine) != null) { - // start model if local and not started - if (!activeModels.map((model) => model.model).includes(modelId)) { - if (activeModels.length >= concurrentModelWarningThreshold) { - // if max concurrent models reached, stop the first model - // display popup - setShowWarningMultipleModelModal(true) - } - await startModel.mutateAsync(modelId) - } - } - } catch (err) { - console.error(`Failed to start model ${modelId}, error: ${err}`) - return - } - - setIsGeneratingResponse(true) - - // building messages - const systemMessage: ChatCompletionMessageParam = { - role: 'system', - content: activeThread!.assistants[0].instructions ?? '', - } - - let messages: ChatCompletionMessageParam[] = currentMessages - .map((msg) => { - switch (msg.role) { - case 'user': - case 'assistant': - return { - role: msg.role, - content: - msg.content[0] != null - ? 
(msg.content[0] as TextContentBlock).text.value - : '', - } - - // we will need to support other roles in the future - default: - break - } - }) - .filter((msg) => msg != null) as ChatCompletionMessageParam[] - messages.push({ - role: 'user', - content: message, - }) - messages.unshift(systemMessage) - messages = normalizeMessages(messages) - const modelOptions: Record = {} - if (selectedModel!.frequency_penalty) { - modelOptions.frequency_penalty = selectedModel!.frequency_penalty - } - if (selectedModel!.presence_penalty) { - modelOptions.presence_penalty = selectedModel!.presence_penalty - } - let assistantResponseMessage = '' - try { - if (selectedModel!.stream === true) { - setDisableStopInference(true) - const stream = await chatCompletionStreaming({ - messages, - model: selectedModel!.model, - stream: true, - max_tokens: selectedModel!.max_tokens, - stop: selectedModel!.stop, - temperature: selectedModel!.temperature ?? 1, - top_p: selectedModel!.top_p ?? 1, - ...modelOptions, - }) - didUserAborted.current = false - abortControllerRef.current = stream.controller - - const assistantMessage = await createMessage.mutateAsync({ - threadId: activeThread!.id, - createMessageParams: { - role: 'assistant', - content: '', - }, - }) - - const responseMessage: Message = { - id: assistantMessage.id, - thread_id: activeThread!.id, - assistant_id: activeThread!.id, - role: 'assistant', - content: [], - status: 'in_progress', - created_at: assistantMessage.created_at, - metadata: undefined, - attachments: null, - completed_at: Date.now(), - incomplete_at: null, - incomplete_details: null, - object: 'thread.message', - run_id: null, - } - - if (responseMessage) { - setIsGeneratingResponse(false) - } - - addNewMessage(responseMessage) - - let chunkCount = 1 - for await (const chunk of stream) { - setChunkCount((prev) => ({ - ...prev, - [responseMessage.id]: chunkCount++, - })) - // we have first chunk, enable the inference button - setDisableStopInference(false) - const content = chunk.choices[0]?.delta?.content || '' - assistantResponseMessage += content - const messageContent: MessageContent = { - type: 'text', - text: { - value: assistantResponseMessage, - annotations: [], - }, - } - responseMessage.content = [messageContent] - updateMessageState( - responseMessage.id, - responseMessage.thread_id, - responseMessage.content, - responseMessage.status - ) - } - - abortControllerRef.current = undefined - - responseMessage.status = 'completed' - updateMessageState( - responseMessage.id, - responseMessage.thread_id, - responseMessage.content, - responseMessage.status - ) - updateMessage.mutateAsync({ - threadId: activeThread!.id, - messageId: responseMessage.id, - data: { - content: responseMessage.content, - }, - }) - } else { - didUserAborted.current = false - const abortController = new AbortController() - abortControllerRef.current = abortController - - const response = await chatCompletionNonStreaming( - { - messages, - model: selectedModel!.model, - stream: false, - max_tokens: selectedModel!.max_tokens, - stop: selectedModel!.stop, - temperature: selectedModel!.temperature ?? 1, - top_p: selectedModel!.top_p ?? 1, - ...modelOptions, - }, - { - signal: abortController.signal, - } - ) - - assistantResponseMessage = response.choices[0].message.content ?? 
'' - const assistantMessage = await createMessage.mutateAsync({ - threadId: activeThread!.id, - createMessageParams: { - role: 'assistant', - content: assistantResponseMessage, - }, - }) - - const responseMessage: Message = { - id: assistantMessage.id, - thread_id: activeThread!.id, - assistant_id: activeThread!.id, - role: 'assistant', - content: [ - { - type: 'text', - text: { - value: assistantResponseMessage, - annotations: [], - }, - }, - ], - status: 'completed', - created_at: assistantMessage.created_at, - metadata: undefined, - attachments: null, - completed_at: Date.now(), - incomplete_at: null, - incomplete_details: null, - object: 'thread.message', - run_id: null, - } - updateMessage.mutateAsync({ - threadId: activeThread!.id, - messageId: responseMessage.id, - data: { - content: responseMessage.content, - }, - }) - abortControllerRef.current = undefined - if (responseMessage) { - setIsGeneratingResponse(false) - } - - addNewMessage(responseMessage) - } - } catch (err) { - console.error(err) - // @ts-expect-error error message should be there - const errorMessage = err['message'] - if (errorMessage != null) { - setInferenceErrorAtom({ - engine: selectedModel!.engine, - message: errorMessage, - }) - } - - setDisableStopInference(false) - setIsGeneratingResponse(false) - shouldSummarize = false - - toaster({ - title: `Error with ${selectedModel!.model}`, - description: 'Failed to generate response', - type: 'error', - }) - } - - try { - if (!shouldSummarize || didUserAborted.current === true) return - // summarize if needed - const textMessages: string[] = messages - .map((msg) => { - if (typeof msg.content === 'string') return msg.content - }) - .filter((msg) => msg != null) as string[] - textMessages.push(assistantResponseMessage) - summarizeThread(textMessages, modelId, activeThread!) 
- } catch (err) { - console.error(`Failed to summarize thread: ${err}`) - } - }, - [ - activeThread, - activeModels, - currentMessages, - selectedModel, - updateMessage, - createMessage, - startModel, - setInferenceErrorAtom, - validatePrerequisite, - setCurrentPrompt, - setEditPrompt, - setIsGeneratingResponse, - updateMessageState, - addNewMessage, - chatCompletionNonStreaming, - chatCompletionStreaming, - summarizeThread, - setShowWarningMultipleModelModal, - setDisableStopInference, - setChunkCount, - ] - ) - - return { resendMessage, sendMessage, stopInference } -} - -export default useSendMessage diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts new file mode 100644 index 0000000000..8e92680650 --- /dev/null +++ b/web/hooks/useSetActiveThread.ts @@ -0,0 +1,43 @@ +import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core' + +import { useAtomValue, useSetAtom } from 'jotai' + +import { extensionManager } from '@/extension' +import { + readyThreadsMessagesAtom, + setConvoMessagesAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { + ModelParams, + setActiveThreadIdAtom, + setThreadModelParamsAtom, +} from '@/helpers/atoms/Thread.atom' + +export default function useSetActiveThread() { + const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) + const setThreadMessage = useSetAtom(setConvoMessagesAtom) + const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom) + + const setActiveThread = async (thread: Thread) => { + // Load local messages only if there are no messages in the state + if (!readyMessageThreads[thread?.id]) { + const messages = await getLocalThreadMessage(thread?.id) + setThreadMessage(thread?.id, messages) + } + + setActiveThreadId(thread?.id) + const modelParams: ModelParams = { + ...thread?.assistants[0]?.model?.parameters, + ...thread?.assistants[0]?.model?.settings, + } + setThreadModelParams(thread?.id, modelParams) + } + + return { setActiveThread } +} + +const getLocalThreadMessage = async (threadId: string) => + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.getAllMessages(threadId) ?? [] diff --git a/web/hooks/useSettings.ts b/web/hooks/useSettings.ts index 87996c1663..8743813173 100644 --- a/web/hooks/useSettings.ts +++ b/web/hooks/useSettings.ts @@ -1,5 +1,7 @@ import { useCallback, useEffect, useState } from 'react' +import { fs, joinPath } from '@janhq/core' + type NvidiaDriver = { exist: boolean version: string @@ -25,14 +27,14 @@ export const useSettings = () => { }, []) const readSettings = useCallback(async () => { - // if (!window?.core?.api) { - // return - // } - // const settingsFile = await joinPath(['file://settings', 'settings.json']) - // if (await fs.existsSync(settingsFile)) { - // const settings = await fs.readFileSync(settingsFile, 'utf-8') - // return typeof settings === 'object' ? settings : JSON.parse(settings) - // } + if (!window?.core?.api) { + return + } + const settingsFile = await joinPath(['file://settings', 'settings.json']) + if (await fs.existsSync(settingsFile)) { + const settings = await fs.readFileSync(settingsFile, 'utf-8') + return typeof settings === 'object' ? 
settings : JSON.parse(settings) + } return {} }, []) @@ -47,22 +49,21 @@ export const useSettings = () => { gpusInUse?: string[] | undefined vulkan?: boolean | undefined }) => { - console.log('saveSettings', runMode, notify, gpusInUse, vulkan) - // const settingsFile = await joinPath(['file://settings', 'settings.json']) - // const settings = await readSettings() - // if (runMode != null) settings.run_mode = runMode - // if (notify != null) settings.notify = notify - // if (gpusInUse != null) settings.gpus_in_use = gpusInUse - // if (vulkan != null) { - // settings.vulkan = vulkan - // // GPU enabled, set run_mode to 'gpu' - // if (settings.vulkan === true) { - // settings.run_mode = 'gpu' - // } else { - // settings.run_mode = 'cpu' - // } - // } - // await fs.writeFileSync(settingsFile, JSON.stringify(settings)) + const settingsFile = await joinPath(['file://settings', 'settings.json']) + const settings = await readSettings() + if (runMode != null) settings.run_mode = runMode + if (notify != null) settings.notify = notify + if (gpusInUse != null) settings.gpus_in_use = gpusInUse + if (vulkan != null) { + settings.vulkan = vulkan + // GPU enabled, set run_mode to 'gpu' + if (settings.vulkan === true) { + settings.run_mode = 'gpu' + } else { + settings.run_mode = 'cpu' + } + } + await fs.writeFileSync(settingsFile, JSON.stringify(settings)) } return { diff --git a/web/hooks/useThreadCreateMutation.ts b/web/hooks/useThreadCreateMutation.ts deleted file mode 100644 index 4cff3c045b..0000000000 --- a/web/hooks/useThreadCreateMutation.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { Assistant } from '@janhq/core' -import { useMutation } from '@tanstack/react-query' - -import { useSetAtom } from 'jotai' - -import { toaster } from '@/containers/Toast' - -import useCortex from './useCortex' - -import { setThreadMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' -import { setActiveThreadIdAtom, threadsAtom } from '@/helpers/atoms/Thread.atom' - -export type ThreadCreateMutationVariables = { - modelId: string - assistant: Assistant - instructions?: string -} - -const useThreadCreateMutation = () => { - const { createThread } = useCortex() - const setThreads = useSetAtom(threadsAtom) - const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) - const setThreadMessage = useSetAtom(setThreadMessagesAtom) - - return useMutation({ - mutationFn: async (variables: ThreadCreateMutationVariables) => { - const { assistant, modelId, instructions } = variables - if (instructions) { - assistant.instructions = instructions - } - - return createThread({ - ...assistant, - model: modelId, - }) - }, - - onSuccess: (thread, variables, context) => { - console.log('New thread created', thread, variables, context) - setThreads((threads) => [thread, ...threads]) - setActiveThreadId(thread.id) - setThreadMessage(thread.id, []) - }, - - onError: (error, variables) => { - console.error( - `Failed to create new thread: ${JSON.stringify(variables)}, error: ${error}` - ) - toaster({ - title: 'Failed to create thread', - description: `Unexpected error while creating thread. 
Please try again!`, - type: 'error', - }) - }, - }) -} - -export default useThreadCreateMutation diff --git a/web/hooks/useThreadQuery.ts b/web/hooks/useThreadQuery.ts deleted file mode 100644 index 034490e9e8..0000000000 --- a/web/hooks/useThreadQuery.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { useQuery } from '@tanstack/react-query' - -import { useSetAtom } from 'jotai' - -import useCortex from './useCortex' - -import { threadsAtom } from '@/helpers/atoms/Thread.atom' - -export const threadQueryKey = ['getThreads'] - -const useThreadQuery = () => { - const { fetchThreads } = useCortex() - const setThreads = useSetAtom(threadsAtom) - - return useQuery({ - queryKey: threadQueryKey, - queryFn: async () => { - const threads = await fetchThreads() - setThreads(threads) - return threads - }, - staleTime: 30 * 1000, - }) -} - -export default useThreadQuery diff --git a/web/hooks/useThreads.ts b/web/hooks/useThreads.ts index 54128d9ddb..fd0b3456d4 100644 --- a/web/hooks/useThreads.ts +++ b/web/hooks/useThreads.ts @@ -1,87 +1,73 @@ -import { useCallback } from 'react' +import { useEffect } from 'react' -import { Assistant } from '@janhq/core' +import { + ExtensionTypeEnum, + Thread, + ThreadState, + ConversationalExtension, +} from '@janhq/core' import { useSetAtom } from 'jotai' -import useCortex from './useCortex' - -import { - cleanChatMessageAtom, - deleteChatMessageAtom, -} from '@/helpers/atoms/ChatMessage.atom' - -import { setThreadMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' +import { extensionManager } from '@/extension/ExtensionManager' import { - deleteThreadAtom, - setActiveThreadIdAtom, + ModelParams, + threadDataReadyAtom, + threadModelParamsAtom, + threadStatesAtom, threadsAtom, } from '@/helpers/atoms/Thread.atom' const useThreads = () => { + const setThreadStates = useSetAtom(threadStatesAtom) const setThreads = useSetAtom(threadsAtom) - const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) - const setThreadMessage = useSetAtom(setThreadMessagesAtom) - const deleteMessages = useSetAtom(deleteChatMessageAtom) - const deleteThreadState = useSetAtom(deleteThreadAtom) - const cleanMessages = useSetAtom(cleanChatMessageAtom) - const { - createThread, - fetchMessages, - deleteThread: deleteCortexThread, - cleanThread: cleanCortexThread, - } = useCortex() + const setThreadModelRuntimeParams = useSetAtom(threadModelParamsAtom) + const setThreadDataReady = useSetAtom(threadDataReadyAtom) + + useEffect(() => { + const getThreads = async () => { + const localThreads = await getLocalThreads() + const localThreadStates: Record = {} + const threadModelParams: Record = {} - const setActiveThread = useCallback( - async (threadId: string) => { - const messages = await fetchMessages(threadId) - setThreadMessage(threadId, messages) - setActiveThreadId(threadId) - }, - [fetchMessages, setThreadMessage, setActiveThreadId] - ) + localThreads.forEach((thread) => { + if (thread.id != null) { + const lastMessage = (thread.metadata?.lastMessage as string) ?? 
'' - const createNewThread = useCallback( - async (modelId: string, assistant: Assistant, instructions?: string) => { - assistant.model = modelId - if (instructions) { - assistant.instructions = instructions - } - const thread = await createThread(assistant) - setThreads((threads) => [thread, ...threads]) - setActiveThread(thread.id) - return thread - }, - [createThread, setActiveThread, setThreads] - ) + localThreadStates[thread.id] = { + hasMore: false, + waitingForResponse: false, + lastMessage, + } - const deleteThread = useCallback( - async (threadId: string) => { - try { - await deleteCortexThread(threadId) - deleteThreadState(threadId) - deleteMessages(threadId) - } catch (err) { - console.error(err) - } - }, - [deleteMessages, deleteCortexThread, deleteThreadState] - ) + const modelParams = thread.assistants?.[0]?.model?.parameters + const engineParams = thread.assistants?.[0]?.model?.settings + threadModelParams[thread.id] = { + ...modelParams, + ...engineParams, + } + } + }) - const cleanThread = useCallback( - async (threadId: string) => { - await cleanCortexThread(threadId) - cleanMessages(threadId) - }, - [cleanCortexThread, cleanMessages] - ) + // updating app states + setThreadStates(localThreadStates) + setThreads(localThreads) + setThreadModelRuntimeParams(threadModelParams) + setThreadDataReady(true) + } - return { - createThread: createNewThread, - setActiveThread, - deleteThread, - cleanThread, - } + getThreads() + }, [ + setThreadModelRuntimeParams, + setThreadStates, + setThreads, + setThreadDataReady, + ]) } +const getLocalThreads = async (): Promise => + (await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.getThreads()) ?? [] + export default useThreads diff --git a/web/hooks/useUpdateInstruction.ts b/web/hooks/useUpdateInstruction.ts deleted file mode 100644 index 378a3b1eca..0000000000 --- a/web/hooks/useUpdateInstruction.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { useCallback } from 'react' - -import { useAtomValue } from 'jotai' - -import useCortex from './useCortex' - -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' - -const useUpdateInstruction = () => { - const activeThread = useAtomValue(activeThreadAtom) - const { updateThread } = useCortex() - - const updateInstruction = useCallback( - (instructions: string) => { - if (!activeThread) return - activeThread.assistants[0].instructions = instructions - updateThread(activeThread) - }, - [activeThread, updateThread] - ) - - return { updateInstruction } -} - -export default useUpdateInstruction diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts new file mode 100644 index 0000000000..d819a85ff2 --- /dev/null +++ b/web/hooks/useUpdateModelParameters.ts @@ -0,0 +1,81 @@ +import { useCallback } from 'react' + +import { + ConversationalExtension, + ExtensionTypeEnum, + InferenceEngine, + Thread, + ThreadAssistantInfo, +} from '@janhq/core' + +import { useAtomValue, useSetAtom } from 'jotai' + +import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' + +import { extensionManager } from '@/extension' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' +import { + ModelParams, + getActiveThreadModelParamsAtom, + setThreadModelParamsAtom, +} from '@/helpers/atoms/Thread.atom' + +export type UpdateModelParameter = { + params?: ModelParams + modelId?: string + engine?: InferenceEngine +} + +export default function useUpdateModelParameters() { + const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) + const 
selectedModel = useAtomValue(selectedModelAtom) + const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + + const updateModelParameter = useCallback( + async (thread: Thread, settings: UpdateModelParameter) => { + const toUpdateSettings = processStopWords(settings.params ?? {}) + const updatedModelParams = settings.modelId + ? toUpdateSettings + : { ...activeModelParams, ...toUpdateSettings } + + // update the state + setThreadModelParams(thread.id, updatedModelParams) + + const assistants = thread.assistants.map( + (assistant: ThreadAssistantInfo) => { + const runtimeParams = toRuntimeParams(updatedModelParams) + const settingParams = toSettingParams(updatedModelParams) + + assistant.model.parameters = runtimeParams + assistant.model.settings = settingParams + if (selectedModel) { + assistant.model.id = settings.modelId ?? selectedModel?.id + assistant.model.engine = settings.engine ?? selectedModel?.engine + } + return assistant + } + ) + + // update thread + const updatedThread: Thread = { + ...thread, + assistants, + } + + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.saveThread(updatedThread) + }, + [activeModelParams, selectedModel, setThreadModelParams] + ) + + const processStopWords = (params: ModelParams): ModelParams => { + if ('stop' in params && typeof params['stop'] === 'string') { + // Input as string but stop words accept an array of strings (space as separator) + params['stop'] = (params['stop'] as string).split(' ') + } + return params + } + + return { updateModelParameter } +} diff --git a/web/package.json b/web/package.json index d22a60e2b1..53e04c3e6b 100644 --- a/web/package.json +++ b/web/package.json @@ -13,12 +13,9 @@ "compile": "tsc --noEmit -p . --pretty" }, "dependencies": { - "@tanstack/react-query": "^5.48.0", - "yaml": "^2.4.5", - "@huggingface/hub": "^0.15.1", - "embla-carousel-react": "^8.1.5", - "@cortexso/cortex.js": "^0.1.6", - "@microsoft/fetch-event-source": "^2.0.1", + "@headlessui/react": "^1.7.15", + "@heroicons/react": "^2.0.18", + "@hookform/resolvers": "^3.3.2", "@janhq/core": "link:./core", "@janhq/joi": "link:./joi", "autoprefixer": "10.4.16", @@ -27,7 +24,8 @@ "highlight.js": "^11.9.0", "jotai": "^2.6.0", "katex": "^0.16.10", - "lucide-react": "^0.352.0", + "lodash": "^4.17.21", + "lucide-react": "^0.291.0", "marked": "^9.1.2", "marked-highlight": "^2.0.6", "marked-katex-extension": "^5.0.2", @@ -39,17 +37,28 @@ "react-circular-progressbar": "^2.1.0", "react-dom": "18.2.0", "react-dropzone": "^14.2.3", + "react-hook-form": "^7.47.0", "react-hot-toast": "^2.4.1", + "csstype": "^3.0.10", + "react-icons": "^4.12.0", + "react-scroll-to-bottom": "^4.2.0", + "react-toastify": "^9.1.3", "sass": "^1.69.4", "tailwind-merge": "^2.0.0", "tailwindcss": "3.3.5", - "use-debounce": "^10.0.0" + "ulidx": "^2.3.0", + "uuid": "^9.0.1", + "use-debounce": "^10.0.0", + "zod": "^3.22.4" }, "devDependencies": { "@next/eslint-plugin-next": "^14.0.1", + "@types/lodash": "^4.14.200", "@types/node": "20.8.10", "@types/react": "18.2.34", "@types/react-dom": "18.2.14", + "@types/react-icons": "^3.0.0", + "@types/react-scroll-to-bottom": "^4.2.4", "@types/uuid": "^9.0.6", "@typescript-eslint/eslint-plugin": "^6.8.0", "@typescript-eslint/parser": "^6.8.0", diff --git a/web/public/icons/anthropic.svg b/web/public/icons/anthropic.svg deleted file mode 100644 index 1f3f18dcfc..0000000000 --- a/web/public/icons/anthropic.svg +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - diff --git a/web/public/icons/cohere.svg b/web/public/icons/cohere.svg 
deleted file mode 100644 index 0ff4f00290..0000000000 --- a/web/public/icons/cohere.svg +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - diff --git a/web/public/icons/dot.svg b/web/public/icons/dot.svg deleted file mode 100644 index f667c20b13..0000000000 --- a/web/public/icons/dot.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/web/public/icons/groq.svg b/web/public/icons/groq.svg deleted file mode 100644 index 9c2e0a34a8..0000000000 --- a/web/public/icons/groq.svg +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - diff --git a/web/public/icons/ic_cortex.svg b/web/public/icons/ic_cortex.svg deleted file mode 100644 index 8e75cfb348..0000000000 --- a/web/public/icons/ic_cortex.svg +++ /dev/null @@ -1,34 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/public/icons/ic_hugging_face.svg b/web/public/icons/ic_hugging_face.svg deleted file mode 100644 index f5d8a5d0dc..0000000000 --- a/web/public/icons/ic_hugging_face.svg +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - \ No newline at end of file diff --git a/web/public/icons/llamacpp.svg b/web/public/icons/llamacpp.svg deleted file mode 100644 index 44db362013..0000000000 --- a/web/public/icons/llamacpp.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/web/public/icons/martian.svg b/web/public/icons/martian.svg deleted file mode 100644 index b5ceacdf8c..0000000000 --- a/web/public/icons/martian.svg +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - - - - - diff --git a/web/public/icons/mistral.svg b/web/public/icons/mistral.svg deleted file mode 100644 index 22233c55cd..0000000000 --- a/web/public/icons/mistral.svg +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/web/public/icons/nvidia.svg b/web/public/icons/nvidia.svg deleted file mode 100644 index 09c2194ece..0000000000 --- a/web/public/icons/nvidia.svg +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - - - diff --git a/web/public/icons/openRouter.svg b/web/public/icons/openRouter.svg deleted file mode 100644 index 62ff2b424b..0000000000 --- a/web/public/icons/openRouter.svg +++ /dev/null @@ -1,14 +0,0 @@ - - - - - - - - - - - - - - diff --git a/web/public/icons/openai.svg b/web/public/icons/openai.svg deleted file mode 100644 index 8f07854155..0000000000 --- a/web/public/icons/openai.svg +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - diff --git a/web/public/icons/send.svg b/web/public/icons/send.svg deleted file mode 100644 index 28d30299fd..0000000000 --- a/web/public/icons/send.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/web/public/images/ModelProvider/anthropic.svg b/web/public/images/ModelProvider/anthropic.svg new file mode 100644 index 0000000000..7bb86df4a4 --- /dev/null +++ b/web/public/images/ModelProvider/anthropic.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/web/public/images/ModelProvider/cohere.svg b/web/public/images/ModelProvider/cohere.svg new file mode 100644 index 0000000000..543bc2d6ca --- /dev/null +++ b/web/public/images/ModelProvider/cohere.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/public/images/ModelProvider/martian.svg b/web/public/images/ModelProvider/martian.svg new file mode 100644 index 0000000000..f63ded55a2 --- /dev/null +++ b/web/public/images/ModelProvider/martian.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/web/public/images/ModelProvider/mistral.svg b/web/public/images/ModelProvider/mistral.svg new file mode 100644 index 
0000000000..2bb14b9bc0 --- /dev/null +++ b/web/public/images/ModelProvider/mistral.svg @@ -0,0 +1,32 @@ + + + Mistral AI + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/public/images/ModelProvider/openai.svg b/web/public/images/ModelProvider/openai.svg new file mode 100644 index 0000000000..433ae3d45e --- /dev/null +++ b/web/public/images/ModelProvider/openai.svg @@ -0,0 +1,24 @@ + + + + + + + + + diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx new file mode 100644 index 0000000000..b20977affe --- /dev/null +++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx @@ -0,0 +1,179 @@ +import { useCallback } from 'react' + +import { Model } from '@janhq/core' +import { Button, Badge, Tooltip } from '@janhq/joi' + +import { useAtomValue, useSetAtom } from 'jotai' + +import { ChevronDownIcon } from 'lucide-react' + +import { twMerge } from 'tailwind-merge' + +import ModalCancelDownload from '@/containers/ModalCancelDownload' + +import ModelLabel from '@/containers/ModelLabel' + +import { toaster } from '@/containers/Toast' + +import { MainViewState } from '@/constants/screens' + +import { useCreateNewThread } from '@/hooks/useCreateNewThread' +import useDownloadModel from '@/hooks/useDownloadModel' + +import { useSettings } from '@/hooks/useSettings' + +import { toGibibytes } from '@/utils/converter' + +import { mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' +import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' + +import { + downloadedModelsAtom, + getDownloadingModelAtom, +} from '@/helpers/atoms/Model.atom' +import { + nvidiaTotalVramAtom, + totalRamAtom, +} from '@/helpers/atoms/SystemBar.atom' + +type Props = { + model: Model + onClick: () => void + open: string +} + +const ModelItemHeader = ({ model, onClick, open }: Props) => { + const { downloadModel } = useDownloadModel() + const downloadingModels = useAtomValue(getDownloadingModelAtom) + const downloadedModels = useAtomValue(downloadedModelsAtom) + const { requestCreateNewThread } = useCreateNewThread() + const totalRam = useAtomValue(totalRamAtom) + const { settings } = useSettings() + // const [imageLoaded, setImageLoaded] = useState(true) + + const nvidiaTotalVram = useAtomValue(nvidiaTotalVramAtom) + const setMainViewState = useSetAtom(mainViewStateAtom) + + // Default nvidia returns vram in MB, need to convert to bytes to match the unit of totalRamW + let ram = nvidiaTotalVram * 1024 * 1024 + if (ram === 0 || settings?.run_mode === 'cpu') { + ram = totalRam + } + const serverEnabled = useAtomValue(serverEnabledAtom) + const assistants = useAtomValue(assistantsAtom) + + const onDownloadClick = useCallback(() => { + downloadModel(model) + }, [model, downloadModel]) + + const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null + + let downloadButton = ( + + ) + + const isDownloading = downloadingModels.some((md) => md.id === model.id) + + const onUseModelClick = useCallback(async () => { + if (assistants.length === 0) { + toaster({ + title: 'No assistant available.', + description: `Could not use Model ${model.name} as no assistant is available.`, + type: 'error', + }) + return + } + await requestCreateNewThread(assistants[0], model) + setMainViewState(MainViewState.Thread) + }, [assistants, model, requestCreateNewThread, setMainViewState]) + + if (isDownloaded) { + downloadButton = ( + + Use + + } + disabled={!serverEnabled} + content="Threads 
are disabled while the server is running" + /> + ) + } else if (isDownloading) { + downloadButton = + } + + return ( +
+ {/* TODO: @faisal are we still using cover? */} + {/* {model.metadata.cover && imageLoaded && ( +
+ setImageLoaded(false)} + src={model.metadata.cover} + className="h-[250px] w-full object-cover" + alt={`Cover - ${model.id}`} + /> +
+ )} */} +
+
+ + {model.name} + + +
+
+
+ + {toGibibytes(model.metadata.size)} + + +
+ {downloadButton} + +
+
+
+ ) +} + +type EngineBadgeProps = { + engine: string +} + +const EngineBadge = ({ engine }: EngineBadgeProps) => { + const title = 'TensorRT-LLM' + + switch (engine) { + case 'nitro-tensorrt-llm': + return {title} + default: + return null + } +} + +export default ModelItemHeader diff --git a/web/screens/Hub/ModelList/ModelItem/index.tsx b/web/screens/Hub/ModelList/ModelItem/index.tsx index 837cb53df2..c9b2f13294 100644 --- a/web/screens/Hub/ModelList/ModelItem/index.tsx +++ b/web/screens/Hub/ModelList/ModelItem/index.tsx @@ -3,8 +3,12 @@ import { useState } from 'react' import { Model } from '@janhq/core' import { Badge } from '@janhq/joi' +import { twMerge } from 'tailwind-merge' + import ModelLabel from '@/containers/ModelLabel' +import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader' + import { toGibibytes } from '@/utils/converter' type Props = { @@ -14,23 +18,30 @@ type Props = { const ModelItem: React.FC = ({ model }) => { const [open, setOpen] = useState('') - console.log('ModelItem', model, setOpen) + const handleToggle = () => { + if (open === model.id) { + setOpen('') + } else { + setOpen(model.id) + } + } return (
- {open === model.model && ( + + {open === model.id && (
- {toGibibytes(model.metadata?.size ?? 0)} + {toGibibytes(model.metadata.size)}
About

- {model.metadata?.description || '-'} + {model.description || '-'}

@@ -38,24 +49,24 @@ const ModelItem: React.FC = ({ model }) => { Author

- {model.metadata?.author} + {model.metadata.author}

Model ID

- {model.model} + {model.id}

Tags
- {model.metadata?.tags?.map((tag: string) => ( + {model.metadata.tags.map((tag: string) => ( {tag} @@ -68,7 +79,7 @@ const ModelItem: React.FC = ({ model }) => {
Format - {/*

= ({ model }) => { )} > {model.format} -

*/} +

diff --git a/web/screens/Hub/ModelList/index.tsx b/web/screens/Hub/ModelList/index.tsx index 17fd2265fa..f3f39d373b 100644 --- a/web/screens/Hub/ModelList/index.tsx +++ b/web/screens/Hub/ModelList/index.tsx @@ -19,20 +19,21 @@ const ModelList = ({ models }: Props) => { const remoteModels: Model[] = [] const localModels: Model[] = [] const remainingModels: Model[] = [] - models.forEach((m) => { if (m.metadata?.tags?.includes('Featured')) { featuredModels.push(m) - } else if (downloadedModels.map((x) => x.model).includes(m.model)) { + } else if (m.format === 'api') { + remoteModels.push(m) + } else if (downloadedModels.map((m) => m.id).includes(m.id)) { localModels.push(m) } else { remainingModels.push(m) } }) - featuredModels.sort((m1, m2) => m1.metadata?.size - m2.metadata?.size) - remoteModels.sort((m1, m2) => m1.model.localeCompare(m2.model)) - localModels.sort((m1, m2) => m1.metadata?.size - m2.metadata?.size) - remainingModels.sort((m1, m2) => m1.metadata?.size - m2.metadata?.size) + featuredModels.sort((m1, m2) => m1.metadata.size - m2.metadata.size) + remoteModels.sort((m1, m2) => m1.name.localeCompare(m2.name)) + localModels.sort((m1, m2) => m1.metadata.size - m2.metadata.size) + remainingModels.sort((m1, m2) => m1.metadata.size - m2.metadata.size) return [ ...featuredModels, ...remoteModels, @@ -43,9 +44,7 @@ const ModelList = ({ models }: Props) => { return (
- {sortedModels?.map((model) => ( - - ))} + {sortedModels?.map((model) => )}
) } diff --git a/web/screens/Hub/index.tsx b/web/screens/Hub/index.tsx index c527b31356..190efa136f 100644 --- a/web/screens/Hub/index.tsx +++ b/web/screens/Hub/index.tsx @@ -14,7 +14,10 @@ import { setImportModelStageAtom } from '@/hooks/useImportModel' import ModelList from '@/screens/Hub/ModelList' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { + configuredModelsAtom, + downloadedModelsAtom, +} from '@/helpers/atoms/Model.atom' const sortMenus = [ { @@ -32,25 +35,26 @@ const sortMenus = [ ] const HubScreen = () => { + const configuredModels = useAtomValue(configuredModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) const [searchValue, setsearchValue] = useState('') const [sortSelected, setSortSelected] = useState('all-models') const setImportModelStage = useSetAtom(setImportModelStageAtom) - const filteredModels = downloadedModels.filter((model) => { + const filteredModels = configuredModels.filter((x) => { if (sortSelected === 'downloaded') { return ( - model.model?.toLowerCase().includes(searchValue.toLowerCase()) && - downloadedModels.some((m) => m.model === model.model) + x.name.toLowerCase().includes(searchValue.toLowerCase()) && + downloadedModels.some((y) => y.id === x.id) ) } else if (sortSelected === 'featured') { return ( - model.model?.toLowerCase().includes(searchValue.toLowerCase()) && - model.metadata?.tags.includes('Featured') + x.name.toLowerCase().includes(searchValue.toLowerCase()) && + x.metadata.tags.includes('Featured') ) } else { - return model.model?.toLowerCase().includes(searchValue.toLowerCase()) + return x.name.toLowerCase().includes(searchValue.toLowerCase()) } }) diff --git a/web/screens/HubScreen2/components/BuiltInModelCard.tsx b/web/screens/HubScreen2/components/BuiltInModelCard.tsx deleted file mode 100644 index e96b866014..0000000000 --- a/web/screens/HubScreen2/components/BuiltInModelCard.tsx +++ /dev/null @@ -1,173 +0,0 @@ -import { useMemo, useCallback } from 'react' - -import { Button, Progress } from '@janhq/joi' -import { useAtomValue, useSetAtom } from 'jotai' -import { CloudDownload } from 'lucide-react' - -import { toaster } from '@/containers/Toast' - -import useAbortDownload from '@/hooks/useAbortDownload' -import useAssistantQuery from '@/hooks/useAssistantQuery' -import { downloadStateListAtom } from '@/hooks/useDownloadState' -import useModelDownloadMutation from '@/hooks/useModelDownloadMutation' -import useThreads from '@/hooks/useThreads' - -import { formatDownloadPercentage } from '@/utils/converter' -import { downloadProgress } from '@/utils/download' -import { HfModelEntry } from '@/utils/huggingface' -import { addThousandSeparator } from '@/utils/number' - -import ModelTitle from './ModelTitle' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -const BuiltInModelCard: React.FC = ({ - name, - downloads, - model, -}) => { - const setLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const onItemClick = useCallback(() => { - setLocalModelModalStage('MODEL_LIST', name) - }, [setLocalModelModalStage, name]) - - const owner = model?.metadata?.owned_by ?? '' - const logoUrl = model?.metadata?.logo ?? '' - - return ( -
-
- - {name.replaceAll('cortexso/', '')} - - -
-
- - - {addThousandSeparator(downloads)} - - -
-
- ) -} - -type DownloadContainerProps = { - modelHandle: string -} - -const DownloadContainer: React.FC = ({ - modelHandle, -}) => { - const downloadModelMutation = useModelDownloadMutation() - const { abortDownload } = useAbortDownload() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { createThread } = useThreads() - const setDownloadLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) - const { data: assistants } = useAssistantQuery() - - const modelId = useMemo(() => `${modelHandle.split('/')[1]}`, [modelHandle]) - const downloadState = allDownloadState.find( - (downloadState) => downloadState.id == modelId - ) - const downloadedModel = useMemo( - () => downloadedModels.find((m) => m.model.split(':')[0] === modelId), - [downloadedModels, modelId] - ) - - const onDownloadClick = useCallback(() => { - downloadModelMutation.mutate({ modelId }) - }, [downloadModelMutation, modelId]) - - const onUseModelClick = useCallback(async () => { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - await createThread(modelId, { - ...assistants[0], - model: modelId, - }) - setDownloadLocalModelModalStage('NONE', undefined) - setMainViewState(MainViewState.Thread) - }, [ - setDownloadLocalModelModalStage, - setMainViewState, - createThread, - modelId, - assistants, - ]) - - return ( -
- {downloadedModel ? ( - - ) : downloadState != null ? ( - - ) : ( - - )} -
- ) -} - -export default BuiltInModelCard diff --git a/web/screens/HubScreen2/components/BuiltInModelGroup.tsx b/web/screens/HubScreen2/components/BuiltInModelGroup.tsx deleted file mode 100644 index 983ed90020..0000000000 --- a/web/screens/HubScreen2/components/BuiltInModelGroup.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { Fragment } from 'react' - -import React from 'react' - -import Image from 'next/image' - -import { Button } from '@janhq/joi' - -import { useAtomValue } from 'jotai' - -import useModelHub from '@/hooks/useModelHub' - -import { HfModelEntry } from '@/utils/huggingface' - -import BuiltInModelCard from './BuiltInModelCard' - -import { hubFilterAtom } from '@/helpers/atoms/Hub.atom' - -type Props = { - onSeeAllClick: () => void -} - -const BuiltInModelGroup: React.FC = ({ onSeeAllClick }) => { - const { data } = useModelHub() - const activeFilter = useAtomValue(hubFilterAtom) - - if (!data) return null - - const models: HfModelEntry[] = ( - data.modelCategories.get('BuiltInModels') ?? [] - ).slice(0, activeFilter === 'On-device' ? 6 : 4) - if (models.length === 0) return null - - return ( - -
- Built-In Models -

Built-In Models

- -
- -
- {models.map((model) => ( - - ))} -
-
- ) -} - -export default React.memo(BuiltInModelGroup) diff --git a/web/screens/HubScreen2/components/Carousel.tsx b/web/screens/HubScreen2/components/Carousel.tsx deleted file mode 100644 index 44d0bafb55..0000000000 --- a/web/screens/HubScreen2/components/Carousel.tsx +++ /dev/null @@ -1,253 +0,0 @@ -'use client' - -import React from 'react' - -import { Button } from '@janhq/joi' -import useEmblaCarousel, { - type UseEmblaCarouselType, -} from 'embla-carousel-react' - -import { ArrowLeft, ArrowRight } from 'lucide-react' -import { twMerge } from 'tailwind-merge' - -type CarouselApi = UseEmblaCarouselType[1] -type UseCarouselParameters = Parameters -type CarouselOptions = UseCarouselParameters[0] -type CarouselPlugin = UseCarouselParameters[1] - -type CarouselProps = { - opts?: CarouselOptions - plugins?: CarouselPlugin - orientation?: 'horizontal' | 'vertical' - setApi?: (api: CarouselApi) => void -} - -type CarouselContextProps = { - carouselRef: ReturnType[0] - api: ReturnType[1] - scrollPrev: () => void - scrollNext: () => void - canScrollPrev: boolean - canScrollNext: boolean -} & CarouselProps - -const CarouselContext = React.createContext(null) - -function useCarousel() { - const context = React.useContext(CarouselContext) - - if (!context) { - throw new Error('useCarousel must be used within a ') - } - - return context -} - -const Carousel = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes & CarouselProps ->( - ( - { - orientation = 'horizontal', - opts, - setApi, - plugins, - className, - children, - ...props - }, - ref - ) => { - const [carouselRef, api] = useEmblaCarousel( - { - ...opts, - axis: orientation === 'horizontal' ? 'x' : 'y', - }, - plugins - ) - const [canScrollPrev, setCanScrollPrev] = React.useState(false) - const [canScrollNext, setCanScrollNext] = React.useState(false) - - const onSelect = React.useCallback((api: CarouselApi) => { - if (!api) { - return - } - - setCanScrollPrev(api.canScrollPrev()) - setCanScrollNext(api.canScrollNext()) - }, []) - - const scrollPrev = React.useCallback(() => { - api?.scrollPrev() - }, [api]) - - const scrollNext = React.useCallback(() => { - api?.scrollNext() - }, [api]) - - const handleKeyDown = React.useCallback( - (event: React.KeyboardEvent) => { - if (event.key === 'ArrowLeft') { - event.preventDefault() - scrollPrev() - } else if (event.key === 'ArrowRight') { - event.preventDefault() - scrollNext() - } - }, - [scrollPrev, scrollNext] - ) - - React.useEffect(() => { - if (!api || !setApi) { - return - } - - setApi(api) - }, [api, setApi]) - - React.useEffect(() => { - if (!api) { - return - } - - onSelect(api) - api.on('reInit', onSelect) - api.on('select', onSelect) - - return () => { - api?.off('select', onSelect) - } - }, [api, onSelect]) - - return ( - -
- {children} -
-
- ) - } -) -Carousel.displayName = 'Carousel' - -const CarouselContent = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => { - const { carouselRef } = useCarousel() - - return ( -
-
-
- ) -}) -CarouselContent.displayName = 'CarouselContent' - -const CarouselItem = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => { - const { orientation } = useCarousel() - - return ( -
- ) -}) -CarouselItem.displayName = 'CarouselItem' - -const CarouselPrevious = React.forwardRef< - HTMLButtonElement, - React.ComponentProps ->(({ className, ...props }, ref) => { - const { orientation, scrollPrev, canScrollPrev } = useCarousel() - - return ( - - ) -}) -CarouselPrevious.displayName = 'CarouselPrevious' - -const CarouselNext = React.forwardRef< - HTMLButtonElement, - React.ComponentProps ->(({ className, ...props }, ref) => { - const { orientation, scrollNext, canScrollNext } = useCarousel() - - return ( - - ) -}) -CarouselNext.displayName = 'CarouselNext' - -export { - type CarouselApi, - Carousel, - CarouselContent, - CarouselItem, - CarouselPrevious, - CarouselNext, -} diff --git a/web/screens/HubScreen2/components/DetailModelGroup.tsx b/web/screens/HubScreen2/components/DetailModelGroup.tsx deleted file mode 100644 index 9c86ab1f69..0000000000 --- a/web/screens/HubScreen2/components/DetailModelGroup.tsx +++ /dev/null @@ -1,150 +0,0 @@ -import { useCallback, useState } from 'react' - -import React from 'react' - -import { Input, ScrollArea } from '@janhq/joi' -import { ArrowLeft, Search } from 'lucide-react' - -import BlankState from '@/containers/BlankState' - -import CenterPanelContainer from '@/containers/CenterPanelContainer' - -import useModelHub, { ModelHubCategory } from '@/hooks/useModelHub' - -import { HfModelEntry } from '@/utils/huggingface' - -import { getLogoByCategory } from '@/utils/model-engine' - -import BuiltInModelCard from './BuiltInModelCard' -import GroupInfo from './GroupInfo' -import HuggingFaceModelCard from './HuggingFaceModelCard' -import RemoteModelCard from './RemoteModelCard' - -type Props = { - category: ModelHubCategory - onBackClicked: () => void -} - -const DetailModelGroup: React.FC = ({ category, onBackClicked }) => { - const [filter, setFilter] = useState('') - const { data } = useModelHub() - - const onFilterChange = useCallback( - (e: React.ChangeEvent) => { - setFilter(e.target.value) - }, - [] - ) - - if (!data) return null - - const modelEntries: HfModelEntry[] = [] - if (category === 'BuiltInModels') { - modelEntries.push(...(data.modelCategories.get('BuiltInModels') ?? [])) - } else if (category === 'HuggingFace') { - modelEntries.push(...(data.modelCategories.get('HuggingFace') ?? [])) - } else { - Object.entries(data.modelCategories).forEach(([key, value]) => { - if (key === category) { - modelEntries.push(...value) - } - }) - } - - const refinedImageUrl = - getLogoByCategory(category) ?? - modelEntries.find((entry) => entry.model?.metadata?.logo != null)?.model - ?.metadata?.logo - - const apiKeyUrl: string | undefined = modelEntries.find( - (entry) => entry.model?.metadata?.api_key_url != null - )?.model?.metadata?.api_key_url - - const filteredModels = - filter.trim().length > 0 - ? modelEntries.filter((model) => - model.name.toLowerCase().includes(filter.toLowerCase()) - ) - : modelEntries - - return ( - - -
- - -
- } - placeholder={ - category === 'HuggingFace' - ? 'Search or paste Hugging Face URL' - : 'Search' - } - value={filter} - onChange={onFilterChange} - /> -
- - {/*
- setMinRange(parseInt(e.target.value, 10))} - /> - setMaxRange(parseInt(e.target.value, 10))} - /> -
-
-
- from - -
-
- to - -
-
- - ) -} - -export default SlideRange diff --git a/web/screens/HubScreen2/components/DownloadLocalModelModal.tsx b/web/screens/HubScreen2/components/DownloadLocalModelModal.tsx deleted file mode 100644 index dd1a87e263..0000000000 --- a/web/screens/HubScreen2/components/DownloadLocalModelModal.tsx +++ /dev/null @@ -1,87 +0,0 @@ -import { Fragment, useEffect, useState, useCallback } from 'react' - -import { Modal } from '@janhq/joi' -import { useAtom } from 'jotai' - -import HeaderModal from './HeaderModal' -import HfListModel from './HfListModel' -import ListModel from './ListModel' - -import ModelInformation from './ModelInformation' -import Tab, { ModelTab } from './Tab' - -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' - -const DownloadLocalModelModal: React.FC = () => { - const [availableModels, setAvailableModels] = useState([]) - const [{ stage, modelHandle }, setLocalModelModalStage] = useAtom( - localModelModalStageAtom - ) - const [tab, setTab] = useState('Versions') - const [height, setHeight] = useState(0) - - useEffect(() => { - const updateHeight = () => { - setHeight(window.innerHeight - window.innerHeight * 0.4) - } - window.addEventListener('resize', updateHeight) - updateHeight() - return () => { - window.removeEventListener('resize', updateHeight) - } - }, []) - - const modelName = modelHandle?.split('/')[1] ?? '' - const isFromCortexHub = modelHandle?.includes('cortexso') ?? false - - const onModelBranchChanged = useCallback( - (models: string[]) => { - const isFromCortexHub = modelHandle?.includes('cortexso') ?? false - if (isFromCortexHub) { - setAvailableModels(models) - } else { - setAvailableModels(modelHandle != null ? [modelHandle] : []) - } - }, - [modelHandle] - ) - - if (!modelHandle) return null - - return ( - setLocalModelModalStage('NONE', undefined)} - content={ - - {}} - isLocalModel={true} - availableModels={availableModels} - /> - setTab(input as 'Versions' | 'Information')} - /> - {tab === 'Versions' && - (isFromCortexHub ? 
( - - ) : ( - - ))} - {tab === 'Information' && ( - - )} - - } - /> - ) -} - -export default DownloadLocalModelModal diff --git a/web/screens/HubScreen2/components/DropdownModal/index.tsx b/web/screens/HubScreen2/components/DropdownModal/index.tsx deleted file mode 100644 index 665d79ae6f..0000000000 --- a/web/screens/HubScreen2/components/DropdownModal/index.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import { - DropdownMenu, - DropdownMenuTrigger, - DropdownMenuContent, - DropdownMenuPortal, -} from '@janhq/joi' - -type Props = { - trigger: React.ReactNode - content: React.ReactNode - className?: string -} - -const DropdownModal: React.FC = ({ trigger, content, className }) => { - return ( - - {trigger} - - - {content} - - - - ) -} -export default DropdownModal diff --git a/web/screens/HubScreen2/components/EmptyIcon.tsx b/web/screens/HubScreen2/components/EmptyIcon.tsx deleted file mode 100644 index d2f2100a4f..0000000000 --- a/web/screens/HubScreen2/components/EmptyIcon.tsx +++ /dev/null @@ -1,134 +0,0 @@ -import React from 'react' - -const EmptyIcon: React.FC = () => ( - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -) -export default React.memo(EmptyIcon) diff --git a/web/screens/HubScreen2/components/Filter.tsx b/web/screens/HubScreen2/components/Filter.tsx deleted file mode 100644 index 5e148f2f06..0000000000 --- a/web/screens/HubScreen2/components/Filter.tsx +++ /dev/null @@ -1,33 +0,0 @@ -import React from 'react' - -import { Button } from '@janhq/joi' -import { twMerge } from 'tailwind-merge' - -import { ModelFilter, ModelFilters } from '..' - -type Props = { - currentFilter: ModelFilter - onFilterClicked: (filter: ModelFilter) => void - callback: () => void -} - -const Filter: React.FC = ({ currentFilter, onFilterClicked }) => ( -
- {ModelFilters.map((filter) => ( - - ))} -
-) - -export default Filter diff --git a/web/screens/HubScreen2/components/FormatSelect.tsx b/web/screens/HubScreen2/components/FormatSelect.tsx deleted file mode 100644 index 9f1d3ae7da..0000000000 --- a/web/screens/HubScreen2/components/FormatSelect.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import { useState } from 'react' - -import { ChevronDown } from 'lucide-react' - -const FormatSelect: React.FC = () => { - const [format, setFormat] = useState([ - 'GGUF', - 'ONNX', - 'TensorRT-LLM', - ]) - const checkBoxes = ['GGUF', 'ONNX', 'TensorRT-LLM'] - const [show, setShow] = useState(false) - - return ( -
- - {show && ( -
- {checkBoxes.map((item) => ( -
- { - if (e.target.checked) { - setFormat([...format, item]) - } else { - setFormat(format.filter((f) => f !== item)) - } - }} - /> - -
- ))} -
- )} -
- ) -} - -export default FormatSelect diff --git a/web/screens/HubScreen2/components/GroupInfo.tsx b/web/screens/HubScreen2/components/GroupInfo.tsx deleted file mode 100644 index 3ac585b833..0000000000 --- a/web/screens/HubScreen2/components/GroupInfo.tsx +++ /dev/null @@ -1,103 +0,0 @@ -import { useCallback, useMemo } from 'react' - -import Image from 'next/image' - -import { EngineStatus, RemoteEngine, RemoteEngines } from '@janhq/core' - -import { Button } from '@janhq/joi' -import { useSetAtom } from 'jotai' -import { Settings } from 'lucide-react' - -import useEngineQuery from '@/hooks/useEngineQuery' -import { ModelHubCategory } from '@/hooks/useModelHub' - -import { - getDescriptionByCategory, - getTitleByCategory, -} from '@/utils/model-engine' - -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -type Props = { - category: ModelHubCategory - imageUrl?: string - apiKeyUrl?: string -} - -const GroupInfo: React.FC = ({ category, imageUrl, apiKeyUrl }) => { - const title = getTitleByCategory(category) - const description = getDescriptionByCategory(category) - - const remoteEngine = RemoteEngines.find((engine) => engine === category) - - return ( -
-
- {imageUrl && ( - Group Logo - )} - {title} - {remoteEngine && ( - - )} -
- - {description} - -
- ) -} - -type SetUpProps = { - engine: RemoteEngine - imageUrl?: string - apiKeyUrl?: string -} - -const SetUpComponent: React.FC = ({ - imageUrl, - engine, - apiKeyUrl, -}) => { - const { data: engineData } = useEngineQuery() - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - - const isHasApiKey = useMemo( - () => - engineData == null - ? false - : engineData.find((e) => e.name === engine)?.status === - EngineStatus.Ready, - [engineData, engine] - ) - - const onSetUpClick = useCallback(() => { - setUpRemoteModelStage('SETUP_API_KEY', engine, { - logo: imageUrl, - api_key_url: apiKeyUrl, - }) - }, [setUpRemoteModelStage, engine, imageUrl, apiKeyUrl]) - - return ( -
- {isHasApiKey ? ( - - ) : ( - - )} -
- ) -} - -export default GroupInfo diff --git a/web/screens/HubScreen2/components/HeaderModal.tsx b/web/screens/HubScreen2/components/HeaderModal.tsx deleted file mode 100644 index 23a45244a1..0000000000 --- a/web/screens/HubScreen2/components/HeaderModal.tsx +++ /dev/null @@ -1,150 +0,0 @@ -import { Fragment, useCallback, useEffect, useRef, useState } from 'react' - -import Image from 'next/image' - -import { Button, Select } from '@janhq/joi' -import { ChevronsLeftRight, Copy, ExternalLink } from 'lucide-react' - -import { twMerge } from 'tailwind-merge' - -import DropdownModal from './DropdownModal' - -type Props = { - name: string - modelHandle?: string - availableModels: string[] - onActionClick: () => void - isLocalModel?: boolean -} - -const HeaderModal: React.FC = ({ - name, - modelHandle, - availableModels, - onActionClick, - isLocalModel = false, -}) => { - const [options, setOptions] = useState<{ name: string; value: string }[]>([]) - const [selectedVariant, setSelectedVariant] = useState() - const textRef = useRef(null) - - useEffect(() => { - const isFromCortexHub = modelHandle?.includes('cortexso') ?? false - if (!isLocalModel) { - setOptions( - availableModels.map((variant) => ({ - name: variant, - value: variant, - })) - ) - if (availableModels.length > 0) { - setSelectedVariant(availableModels[0]) - } - return - } - - if (isLocalModel && !isFromCortexHub) { - setOptions([ - { - name: modelHandle ?? '', - value: modelHandle ?? '', - }, - ]) - setSelectedVariant(modelHandle) - return - } - - setOptions( - availableModels.map((variant) => ({ - name: `${name}:${variant}`, - value: `${name}:${variant}`, - })) - ) - if (availableModels.length > 0) { - setSelectedVariant(`${name}:${availableModels[0]}`) - } - }, [availableModels, name, modelHandle, isLocalModel]) - - const onCopyClicked = useCallback(() => { - navigator.clipboard.writeText(textRef.current?.innerText ?? '') - }, []) - - const title = name.charAt(0).toUpperCase() + name.slice(1) - - if (!selectedVariant) return null - - return ( -
- {title} - - Cortex icon - - Cortex - - -
- } - content={ - - - - {downloadableModels.map((item) => ( - - - - - - ))} - -
-
- {item.quantization} -
-
{item.rfilename} -
- - {toGibibytes(item.fileSize)} - - -
-
-
- - ) -} - -type DownloadContainerProps = { - modelHandle: string - fileName: string -} - -const DownloadContainer: React.FC = ({ - modelHandle, - fileName, -}) => { - const downloadModelMutation = useModelDownloadMutation() - const { abortDownload } = useAbortDownload() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { createThread } = useThreads() - const { data: assistants } = useAssistantQuery() - - const setDownloadLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) - - const persistModelId = modelHandle - .replaceAll('/', '_') - .concat('_') - .concat(fileName) - - const downloadState = allDownloadState.find((s) => s.id === persistModelId) - - const downloadedModel = useMemo( - () => downloadedModels.find((m) => m.model === persistModelId), - [downloadedModels, persistModelId] - ) - - const onDownloadClick = useCallback(async () => { - downloadModelMutation.mutate({ - modelId: modelHandle, - fileName: fileName, - persistedModelId: persistModelId, - }) - }, [downloadModelMutation, modelHandle, fileName, persistModelId]) - - const onUseModelClick = useCallback(async () => { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - - await createThread(fileName, { - ...assistants[0], - model: fileName, - }) - setDownloadLocalModelModalStage('NONE', undefined) - setMainViewState(MainViewState.Thread) - }, [ - setDownloadLocalModelModalStage, - setMainViewState, - createThread, - fileName, - assistants, - ]) - - return ( -
- {downloadedModel ? ( - - ) : downloadState != null ? ( - - ) : ( - - )} -
- ) -} - -export default HfListModel diff --git a/web/screens/HubScreen2/components/HubModelCard.tsx b/web/screens/HubScreen2/components/HubModelCard.tsx deleted file mode 100644 index 667bc56c1c..0000000000 --- a/web/screens/HubScreen2/components/HubModelCard.tsx +++ /dev/null @@ -1,187 +0,0 @@ -import React, { useCallback, useMemo } from 'react' - -import { EngineStatus, LocalEngines, RemoteEngine } from '@janhq/core' - -import { Button } from '@janhq/joi' -import { useQueryClient } from '@tanstack/react-query' -import { useAtomValue, useSetAtom } from 'jotai' - -import { CloudDownload } from 'lucide-react' - -import { toaster } from '@/containers/Toast' - -import useAssistantQuery from '@/hooks/useAssistantQuery' - -import useCortex from '@/hooks/useCortex' - -import useEngineQuery from '@/hooks/useEngineQuery' -import { modelQueryKey } from '@/hooks/useModelQuery' -import useThreads from '@/hooks/useThreads' - -import { HfModelEntry } from '@/utils/huggingface' - -import { addThousandSeparator } from '@/utils/number' - -import ModelTitle from './ModelTitle' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -const HubModelCard: React.FC = ({ name, downloads, model }) => { - const downloadedModels = useAtomValue(downloadedModelsAtom) - const { data: assistants } = useAssistantQuery() - const { data: engineData } = useEngineQuery() - const queryClient = useQueryClient() - - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const setLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const { createThread } = useThreads() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { createModel } = useCortex() - - const isLocalModel = useMemo( - () => - model == null || - LocalEngines.filter((e) => e === model.engine).length > 0, - [model] - ) - - const actionLabel = useMemo(() => { - if (isLocalModel) return 'Download' - - const isEngineConfigured: boolean = - engineData == null || model?.engine == null - ? false - : engineData.find((e) => e.name === model.engine)?.status === - EngineStatus.Ready - - const isModelDownloaded = downloadedModels.find( - (m) => m.model === model!.model - ) - - if (isEngineConfigured && isModelDownloaded) return 'Use' - - if (!isEngineConfigured && !isModelDownloaded) return 'Setup' - - if (isModelDownloaded) return 'Setup API Key' - - return 'Add' - }, [model, isLocalModel, downloadedModels, engineData]) - - const onActionClick = useCallback(() => { - if (isLocalModel) { - setLocalModelModalStage('MODEL_LIST', name) - } else { - if (!model) return - - const isApiKeyAdded: boolean = - engineData == null || model?.engine == null - ? 
false - : engineData.find((e) => e.name === model.engine)?.status === - EngineStatus.Ready - - const isModelDownloaded = downloadedModels.find( - (m) => m.model === model.model - ) - - if (isApiKeyAdded && isModelDownloaded) { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - - // use - createThread(model.model, { - ...assistants[0], - model: model.model, - }) - .then(() => { - setMainViewState(MainViewState.Thread) - }) - .catch((err) => { - console.log('Error creating thread', err) - }) - return - } - - if (!isApiKeyAdded && !isModelDownloaded) { - setUpRemoteModelStage( - 'SETUP_INTRO', - model.engine as RemoteEngine, - model.metadata - ) - return - } - - if (isModelDownloaded) { - // when model is downloaded but key is not there or deleted, we need to setup api key - setUpRemoteModelStage( - 'SETUP_API_KEY', - model.engine as RemoteEngine, - model.metadata - ) - return - } - - if (isApiKeyAdded) { - createModel(model).then(() => { - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - }) - return - } - } - }, [ - createModel, - createThread, - setMainViewState, - setUpRemoteModelStage, - setLocalModelModalStage, - name, - model, - engineData, - isLocalModel, - downloadedModels, - assistants, - queryClient, - ]) - - const owner = model?.metadata?.owned_by ?? '' - const logoUrl = model?.metadata?.logo ?? '' - - return ( -
-        {name}
-        {addThousandSeparator(downloads)}
- ) -} - -export default HubModelCard diff --git a/web/screens/HubScreen2/components/HubScreenFilter/index.tsx b/web/screens/HubScreen2/components/HubScreenFilter/index.tsx deleted file mode 100644 index 90a811c790..0000000000 --- a/web/screens/HubScreen2/components/HubScreenFilter/index.tsx +++ /dev/null @@ -1,96 +0,0 @@ -import { Fragment } from 'react' - -import BlankState from '@/containers/BlankState' - -import useModelHub from '@/hooks/useModelHub' - -import { HfModelEntry } from '@/utils/huggingface' - -import BuiltInModelCard from '../BuiltInModelCard' -import HuggingFaceModelCard from '../HuggingFaceModelCard' -import RemoteModelCard from '../RemoteModelCard' - -type Props = { - queryText: string -} - -const HubScreenFilter: React.FC = ({ queryText }) => { - const { data } = useModelHub() - - if (!data) return null - const builtInModels = data.modelCategories.get('BuiltInModels') ?? [] - const huggingFaceModels = data.modelCategories.get('HuggingFace') ?? [] - const remoteModels: HfModelEntry[] = [] - - for (const [key, value] of data.modelCategories) { - if (key !== 'HuggingFace' && key !== 'BuiltInModels') { - remoteModels.push(...value) - } - } - - const filteredBuiltInModels = builtInModels.filter((model) => { - return model.name.toLowerCase().includes(queryText.toLowerCase()) - }) - - const filteredHuggingFaceModels = huggingFaceModels.filter((model) => { - return model.name.toLowerCase().includes(queryText.toLowerCase()) - }) - - const filteredRemoteModels = remoteModels.filter((model) => { - return model.name.toLowerCase().includes(queryText.toLowerCase()) - }) - - const isResultEmpty: boolean = - filteredBuiltInModels.length === 0 && - filteredHuggingFaceModels.length === 0 && - filteredRemoteModels.length === 0 - - const isOnDevice = - filteredBuiltInModels.length > 0 || filteredHuggingFaceModels.length > 0 - - return ( -
- {isResultEmpty ? ( -
- -
- ) : ( -
- {isOnDevice && ( - -
-            On-Device Models
-            {filteredBuiltInModels.map((hfModelEntry) => (
-            ))}
-            {filteredHuggingFaceModels.map((hfModelEntry) => (
-            ))}
- )} - - {filteredRemoteModels.length > 0 && ( - -
-            Cloud Models
-            {filteredRemoteModels.map((hfModelEntry) => (
-            ))}
- )} -
- )} -
- ) -} - -export default HubScreenFilter diff --git a/web/screens/HubScreen2/components/HuggingFaceModelCard.tsx b/web/screens/HubScreen2/components/HuggingFaceModelCard.tsx deleted file mode 100644 index 87b16c8403..0000000000 --- a/web/screens/HubScreen2/components/HuggingFaceModelCard.tsx +++ /dev/null @@ -1,189 +0,0 @@ -import { useMemo, useCallback } from 'react' - -import { Button, Progress } from '@janhq/joi' -import { useAtomValue, useSetAtom } from 'jotai' -import { CloudDownload } from 'lucide-react' - -import { toaster } from '@/containers/Toast' - -import useAbortDownload from '@/hooks/useAbortDownload' -import useAssistantQuery from '@/hooks/useAssistantQuery' - -import { downloadStateListAtom } from '@/hooks/useDownloadState' -import useHfModelFetchAndDownload from '@/hooks/useHfModelFetchAndDownload' -import useThreads from '@/hooks/useThreads' - -import { formatDownloadPercentage } from '@/utils/converter' -import { downloadProgress } from '@/utils/download' -import { HfModelEntry } from '@/utils/huggingface' -import { addThousandSeparator } from '@/utils/number' - -import ModelTitle from './ModelTitle' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -const HuggingFaceModelCard: React.FC = ({ - name, - downloads, - model, -}) => { - const setLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const onItemClick = useCallback(() => { - setLocalModelModalStage('MODEL_LIST', name) - }, [setLocalModelModalStage, name]) - - const owner = model?.metadata?.owned_by ?? '' - const logoUrl = model?.metadata?.logo ?? '' - - return ( -
-        {name}
-        {addThousandSeparator(downloads)}
- ) -} - -type DownloadContainerProps = { - modelHandle: string -} - -const DownloadContainer: React.FC = ({ - modelHandle, -}) => { - const { abortDownload } = useAbortDownload() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { createThread } = useThreads() - const setDownloadLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) - const { data: assistants } = useAssistantQuery() - const { fetchDataAndDownload } = useHfModelFetchAndDownload() - - const modelIdPrefix = modelHandle.replaceAll('/', '_') - - const downloadState = allDownloadState.find((s) => - s.id.startsWith(modelIdPrefix) - ) - - const downloadedModel = useMemo( - () => downloadedModels.find((m) => m.model.startsWith(modelIdPrefix)), - [downloadedModels, modelIdPrefix] - ) - - const onDownloadClick = useCallback( - async () => fetchDataAndDownload(modelHandle), - [fetchDataAndDownload, modelHandle] - ) - - const onUseModelClick = useCallback(async () => { - if (!downloadedModel) { - console.error('Downloaded model not found') - return - } - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - await createThread(downloadedModel.model, { - ...assistants[0], - model: downloadedModel.model, - }) - setDownloadLocalModelModalStage('NONE', undefined) - setMainViewState(MainViewState.Thread) - }, [ - setDownloadLocalModelModalStage, - setMainViewState, - createThread, - downloadedModel, - assistants, - ]) - - const onAbortDownloadClick = useCallback(() => { - if (!downloadState) { - console.error('Download state not found') - return - } - abortDownload(downloadState.id) - }, [abortDownload, downloadState]) - - return ( -
- {downloadedModel ? ( - - ) : downloadState != null ? ( - - ) : ( - - )} -
- ) -} - -export default HuggingFaceModelCard diff --git a/web/screens/HubScreen2/components/HuggingFaceModelGroup.tsx b/web/screens/HubScreen2/components/HuggingFaceModelGroup.tsx deleted file mode 100644 index 535ad81f0c..0000000000 --- a/web/screens/HubScreen2/components/HuggingFaceModelGroup.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { Fragment } from 'react' - -import React from 'react' - -import Image from 'next/image' - -import { Button } from '@janhq/joi' - -import { useAtomValue } from 'jotai' - -import useModelHub from '@/hooks/useModelHub' - -import { HfModelEntry } from '@/utils/huggingface' - -import HuggingFaceModelCard from './HuggingFaceModelCard' - -import { hubFilterAtom } from '@/helpers/atoms/Hub.atom' - -type Props = { - onSeeAllClick: () => void -} - -const HuggingFaceModelGroup: React.FC = ({ onSeeAllClick }) => { - const { data } = useModelHub() - const activeFilter = useAtomValue(hubFilterAtom) - - if (!data) return null - - const models: HfModelEntry[] = ( - data.modelCategories.get('HuggingFace') ?? [] - ).slice(0, activeFilter === 'On-device' ? 6 : 4) - if (models.length === 0) return null - - return ( - -
- Hugging Face -

Hugging Face

- -
- -
- {models.map((hfModelEntry) => ( - - ))} -
-
- ) -} - -export default React.memo(HuggingFaceModelGroup) diff --git a/web/screens/HubScreen2/components/InferenceErrorModal.tsx b/web/screens/HubScreen2/components/InferenceErrorModal.tsx deleted file mode 100644 index eef456e98a..0000000000 --- a/web/screens/HubScreen2/components/InferenceErrorModal.tsx +++ /dev/null @@ -1,45 +0,0 @@ -import { Fragment, useCallback } from 'react' - -import { LlmEngine } from '@janhq/core' -import { Button, Modal, ModalClose } from '@janhq/joi' -import { atom, useAtom } from 'jotai' - -export type InferenceError = { - message: string - engine?: LlmEngine -} - -export const inferenceErrorAtom = atom(undefined) - -const InferenceErrorModal: React.FC = () => { - const [inferenceError, setInferenceError] = useAtom(inferenceErrorAtom) - - const onClose = useCallback(() => { - setInferenceError(undefined) - }, [setInferenceError]) - - return ( - -

- {inferenceError?.message} -

-
- - - -
- - } - /> - ) -} - -export default InferenceErrorModal diff --git a/web/screens/HubScreen2/components/InputApiKey.tsx b/web/screens/HubScreen2/components/InputApiKey.tsx deleted file mode 100644 index 8c8ad6704f..0000000000 --- a/web/screens/HubScreen2/components/InputApiKey.tsx +++ /dev/null @@ -1,13 +0,0 @@ -type Props = { - className?: string - placeholder?: string -} - -const InputApiKey: React.FC = ({ className, placeholder }) => ( - -) -export default InputApiKey diff --git a/web/screens/HubScreen2/components/ListModel.tsx b/web/screens/HubScreen2/components/ListModel.tsx deleted file mode 100644 index 46cb951f20..0000000000 --- a/web/screens/HubScreen2/components/ListModel.tsx +++ /dev/null @@ -1,243 +0,0 @@ -import { Fragment, useCallback, useEffect, useMemo, useState } from 'react' - -import { Button, Progress, Select } from '@janhq/joi' -import { useAtomValue, useSetAtom } from 'jotai' - -import Spinner from '@/containers/Loader/Spinner' -import { toaster } from '@/containers/Toast' - -import useAbortDownload from '@/hooks/useAbortDownload' -import useAssistantQuery from '@/hooks/useAssistantQuery' - -import { downloadStateListAtom } from '@/hooks/useDownloadState' - -import useEngineQuery from '@/hooks/useEngineQuery' -import useHfEngineToBranchesQuery from '@/hooks/useHfEngineToBranchesQuery' - -import useModelDownloadMutation from '@/hooks/useModelDownloadMutation' -import useThreads from '@/hooks/useThreads' - -import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' - -import { downloadProgress } from '@/utils/download' - -import { CortexHubModel, EngineType } from '@/utils/huggingface' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -type Props = { - modelHandle: string - onBranchSelected?: (availableSelections: string[]) => void -} - -const ListModel: React.FC = ({ modelHandle, onBranchSelected }) => { - const { data: engineData } = useEngineQuery() - const { data, isLoading } = useHfEngineToBranchesQuery(modelHandle) - - const [engineFilter, setEngineFilter] = useState( - undefined - ) - - const engineSelections: { name: string; value: string }[] = useMemo(() => { - if (!data || !engineData) return [] - - const isSupportTensorRt = - engineData.find((engine) => engine.name === 'cortex.tensorrt-llm') - ?.status !== 'not_supported' ?? false - - const isSupportOnnx = - engineData.find((engine) => engine.name === 'cortex.onnx')?.status !== - 'not_supported' ?? false - - const result: { name: string; value: string }[] = [] - if (data.gguf.length > 0) result.push({ name: 'GGUF', value: 'gguf' }) - if (isSupportOnnx && data.onnx.length > 0) - result.push({ name: 'ONNX', value: 'onnx' }) - if (isSupportTensorRt && data.tensorrtllm.length > 0) - result.push({ name: 'TensorRT', value: 'tensorrtllm' }) - - return result - }, [data, engineData]) - - const modelBranches: CortexHubModel[] = useMemo((): CortexHubModel[] => { - if (!data) return [] - return (data[engineFilter as EngineType] as CortexHubModel[]) ?? 
[] - }, [data, engineFilter]) - - useEffect(() => { - if (engineSelections.length === 0) return - setEngineFilter(engineSelections[0].value as EngineType) - }, [engineSelections]) - - useEffect(() => { - const models = modelBranches.map((m) => m.name) - onBranchSelected?.(models) - }, [modelBranches, onBranchSelected]) - - const onSelectionChanged = useCallback( - (selectionValue: string) => { - setEngineFilter(selectionValue as EngineType) - const models = modelBranches.map((m) => m.name) - onBranchSelected?.(models) - }, - [setEngineFilter, onBranchSelected, modelBranches] - ) - - if (isLoading) - return ( -
- -
- ) - - return ( - -
- Format: - - - {modelBranches.map((item) => ( - - - - - - ))} - -
-
- {item.name} -
-
- {item.name} - -
- {item.fileSize && {toGibibytes(item.fileSize)}} - -
-
-
-
- ) -} - -type DownloadContainerProps = { - modelHandle: string - branch: string -} - -const DownloadContainer: React.FC = ({ - modelHandle, - branch, -}) => { - const downloadModelMutation = useModelDownloadMutation() - const { abortDownload } = useAbortDownload() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { createThread } = useThreads() - const setDownloadLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) - const { data: assistants } = useAssistantQuery() - - const modelId = useMemo( - () => `${modelHandle.split('/')[1]}:${branch}`, - [modelHandle, branch] - ) - const downloadState = allDownloadState.find((s) => s.id == modelId) - - const isDownloaded = useMemo( - () => downloadedModels.find((m) => m.model === modelId), - [downloadedModels, modelId] - ) - - const onDownloadClick = useCallback(() => { - downloadModelMutation.mutate({ modelId }) - }, [downloadModelMutation, modelId]) - - const onUseModelClick = useCallback(async () => { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - await createThread(modelId, { - ...assistants[0], - model: modelId, - }) - setDownloadLocalModelModalStage('NONE', undefined) - setMainViewState(MainViewState.Thread) - }, [ - setDownloadLocalModelModalStage, - setMainViewState, - createThread, - modelId, - assistants, - ]) - - const onAbortDownloadClick = useCallback(() => { - abortDownload(modelId) - }, [abortDownload, modelId]) - - return ( -
- {isDownloaded ? ( - - ) : downloadState != null ? ( - - ) : ( - - )} -
- ) -} - -export default ListModel diff --git a/web/screens/HubScreen2/components/LoadingIndicator.tsx b/web/screens/HubScreen2/components/LoadingIndicator.tsx deleted file mode 100644 index 447705ef1d..0000000000 --- a/web/screens/HubScreen2/components/LoadingIndicator.tsx +++ /dev/null @@ -1,12 +0,0 @@ -import { memo } from 'react' - -import { Loader } from 'lucide-react' - -const LoadingIndicator: React.FC = () => { - return ( -
- -
- ) -} -export default memo(LoadingIndicator) diff --git a/web/screens/HubScreen2/components/ModelInformation.tsx b/web/screens/HubScreen2/components/ModelInformation.tsx deleted file mode 100644 index 7524dd99ac..0000000000 --- a/web/screens/HubScreen2/components/ModelInformation.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import React from 'react' - -import { twMerge } from 'tailwind-merge' - -import Spinner from '@/containers/Loader/Spinner' - -import useGetReadMeContent from '@/hooks/useGetReadMeContent' - -type Props = { - modelHandle: string - maxHeight: number -} - -const ModelInformation: React.FC = ({ modelHandle, maxHeight }) => { - const { data, isLoading } = useGetReadMeContent(modelHandle) - if (isLoading) - return ( -
- -
- ) - - return ( -
- ) -} - -export default React.memo(ModelInformation) diff --git a/web/screens/HubScreen2/components/ModelSearchBar.tsx b/web/screens/HubScreen2/components/ModelSearchBar.tsx deleted file mode 100644 index 6ca3da884d..0000000000 --- a/web/screens/HubScreen2/components/ModelSearchBar.tsx +++ /dev/null @@ -1,97 +0,0 @@ -import React, { useCallback, useState } from 'react' - -import { Button, Input } from '@janhq/joi' -import { useSetAtom } from 'jotai' -import { SearchIcon } from 'lucide-react' -import { FoldersIcon } from 'lucide-react' -import { useDebouncedCallback } from 'use-debounce' - -import { toaster } from '@/containers/Toast' - -import { useGetHFRepoData } from '@/hooks/useGetHFRepoData' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { - importHuggingFaceModelStageAtom, - importingHuggingFaceRepoDataAtom, -} from '@/helpers/atoms/HuggingFace.atom' -import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' - -type Props = { - onSearchChanged: (query: string) => void -} - -const ModelSearchBar: React.FC = ({ onSearchChanged }) => { - const [searchText, setSearchText] = useState('') - const { getHfRepoData } = useGetHFRepoData() - const setMainViewState = useSetAtom(mainViewStateAtom) - const setSelectedSetting = useSetAtom(selectedSettingAtom) - - const setImportingHuggingFaceRepoData = useSetAtom( - importingHuggingFaceRepoDataAtom - ) - const setImportHuggingFaceModelStage = useSetAtom( - importHuggingFaceModelStageAtom - ) - - const debounced = useDebouncedCallback(async (searchText: string) => { - if (searchText.indexOf('/') === -1) { - // If we don't find / in the text, perform a local search - onSearchChanged?.(searchText) - return - } - - try { - const data = await getHfRepoData(searchText) - setImportingHuggingFaceRepoData(data) - setImportHuggingFaceModelStage('REPO_DETAIL') - } catch (err) { - let errMessage = 'Unexpected Error' - if (err instanceof Error) { - errMessage = err.message - } - toaster({ - title: 'Failed to get Hugging Face models', - description: errMessage, - type: 'error', - }) - console.error(err) - } - }, 300) - - const onQueryChanged = useCallback( - (e: React.ChangeEvent) => { - e.preventDefault() - e.stopPropagation() - const text = e.target.value - setSearchText(text) - debounced(text) - }, - [debounced] - ) - - return ( -
- } - placeholder="Search or paste Hugging Face URL" - value={searchText} - onChange={onQueryChanged} - /> - -
- ) -} - -export default ModelSearchBar diff --git a/web/screens/HubScreen2/components/ModelTitle.tsx b/web/screens/HubScreen2/components/ModelTitle.tsx deleted file mode 100644 index 269cfe9665..0000000000 --- a/web/screens/HubScreen2/components/ModelTitle.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import Image from 'next/image' - -type Props = { - name: string - image: string - className?: string -} - -const ModelTitle: React.FC = ({ name, image, className }) => ( -
- {image && bot} - {name} -
-) -export default ModelTitle diff --git a/web/screens/HubScreen2/components/RemoteModelCard.tsx b/web/screens/HubScreen2/components/RemoteModelCard.tsx deleted file mode 100644 index 0d2978307c..0000000000 --- a/web/screens/HubScreen2/components/RemoteModelCard.tsx +++ /dev/null @@ -1,133 +0,0 @@ -import React, { useCallback } from 'react' - -import { EngineStatus, RemoteEngine } from '@janhq/core' -import { useQueryClient } from '@tanstack/react-query' - -import { useAtomValue, useSetAtom } from 'jotai' - -import { toaster } from '@/containers/Toast' - -import useAssistantQuery from '@/hooks/useAssistantQuery' - -import useCortex from '@/hooks/useCortex' - -import useEngineQuery from '@/hooks/useEngineQuery' -import { modelQueryKey } from '@/hooks/useModelQuery' -import useThreads from '@/hooks/useThreads' - -import { HfModelEntry } from '@/utils/huggingface' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -const RemoteModelCard: React.FC = ({ name, model }) => { - const { createThread } = useThreads() - const setMainViewState = useSetAtom(mainViewStateAtom) - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const queryClient = useQueryClient() - - const { createModel } = useCortex() - const downloadedModels = useAtomValue(downloadedModelsAtom) - const { data: assistants } = useAssistantQuery() - const { data: engineData } = useEngineQuery() - - const modelDisplayName = model?.name ?? name - - const onClick = useCallback(async () => { - if (!model || !engineData) return - const isApiKeyAdded: boolean = - engineData == null || model?.engine == null - ? false - : engineData.find((e) => e.name === model.engine)?.status === - EngineStatus.Ready - - const isModelDownloaded = downloadedModels.find( - (m) => m.model === model.model - ) - - if (isApiKeyAdded && isModelDownloaded) { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - // use this model to create new thread - await createThread(model.model, { - ...assistants[0], - model: model.model, - }) - setMainViewState(MainViewState.Thread) - return - } - - if (!isApiKeyAdded) { - setUpRemoteModelStage('SETUP_INTRO', model.engine as RemoteEngine, { - ...model.metadata, - modelName: modelDisplayName, - modelId: model.model, - }) - return - } - - if (isModelDownloaded) { - // when model is downloaded but key is not there or deleted, we need to setup api key - setUpRemoteModelStage('SETUP_API_KEY', model.engine as RemoteEngine, { - ...model.metadata, - modelName: modelDisplayName, - }) - - return - } - - if (isApiKeyAdded) { - // TODO: useMutation reactQuery? 
- await createModel(model) - queryClient.invalidateQueries({ queryKey: modelQueryKey }) - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - - // use this model to create new thread - await createThread(model.model, { - ...assistants[0], - model: model.model, - }) - setMainViewState(MainViewState.Thread) - return - } - }, [ - assistants, - engineData, - createModel, - createThread, - downloadedModels, - model, - setMainViewState, - setUpRemoteModelStage, - modelDisplayName, - queryClient, - ]) - - return ( -
-

- {modelDisplayName} -

-
- ) -} - -export default RemoteModelCard diff --git a/web/screens/HubScreen2/components/RemoteModelGroup.tsx b/web/screens/HubScreen2/components/RemoteModelGroup.tsx deleted file mode 100644 index 52b1eb2c23..0000000000 --- a/web/screens/HubScreen2/components/RemoteModelGroup.tsx +++ /dev/null @@ -1,108 +0,0 @@ -import { Fragment, useCallback, useMemo } from 'react' - -import React from 'react' - -import Image from 'next/image' - -import { EngineStatus, RemoteEngine } from '@janhq/core' -import { Button } from '@janhq/joi' - -import { useSetAtom } from 'jotai' -import { Settings } from 'lucide-react' - -import useEngineQuery from '@/hooks/useEngineQuery' - -import { HfModelEntry } from '@/utils/huggingface' - -import { getLogoByRemoteEngine, getTitleByCategory } from '@/utils/model-engine' - -import RemoteModelCard from './RemoteModelCard' - -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -type Props = { - data: HfModelEntry[] - engine: RemoteEngine - onSeeAllClick: () => void -} - -const RemoteModelGroup: React.FC = ({ data, engine, onSeeAllClick }) => { - const { data: engineData } = useEngineQuery() - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - - const apiKeyUrl: string | undefined = data.find( - (entry) => entry.model?.metadata?.api_key_url != null - )?.model?.metadata?.api_key_url - - const remoteEngineLogo = getLogoByRemoteEngine(engine as RemoteEngine) - - // get maximum 4 items - const models = data.slice(0, 4) - const showSeeAll = models.length < data.length - const refinedTitle = getTitleByCategory(engine) - - const isHasApiKey = useMemo( - () => - engineData == null || engine == null - ? false - : engineData.find((e) => e.name === engine)?.status === - EngineStatus.Ready, - [engineData, engine] - ) - - const onSetUpClick = useCallback(() => { - setUpRemoteModelStage('SETUP_API_KEY', engine, { - logo: remoteEngineLogo, - api_key_url: apiKeyUrl, - }) - }, [setUpRemoteModelStage, engine, remoteEngineLogo, apiKeyUrl]) - - return ( - -
- {remoteEngineLogo && ( - Engine logo - )} -

{refinedTitle}

- - {isHasApiKey ? ( - - ) : ( - - )} - - {showSeeAll && ( - - )} -
-
- {models.map((model) => ( - - ))} -
-
- ) -} - -export default React.memo(RemoteModelGroup) diff --git a/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx b/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx deleted file mode 100644 index 05fb8a9fd9..0000000000 --- a/web/screens/HubScreen2/components/SetUpApiKeyModal.tsx +++ /dev/null @@ -1,116 +0,0 @@ -import { Fragment, useCallback, useEffect, useState } from 'react' - -import Image from 'next/image' - -import { EngineStatus } from '@janhq/core' -import { Button, Input, Modal } from '@janhq/joi' -import { useAtom } from 'jotai' -import { ArrowUpRight } from 'lucide-react' - -import useEngineMutation from '@/hooks/useEngineMutation' -import useEngineQuery from '@/hooks/useEngineQuery' - -import { getTitleByCategory } from '@/utils/model-engine' - -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -const SetUpApiKeyModal: React.FC = () => { - const updateEngineConfig = useEngineMutation() - const { data: engineData } = useEngineQuery() - - const [{ stage, remoteEngine, metadata }, setUpRemoteModelStage] = useAtom( - setUpRemoteModelStageAtom - ) - const [apiKey, setApiKey] = useState('') - - useEffect(() => { - if (!remoteEngine || !engineData) return - const isEngineReady = - engineData.find((e) => e.name === remoteEngine)?.status === - EngineStatus.Ready - const fakeApiKey = '******************************************' - setApiKey(isEngineReady ? fakeApiKey : '') - }, [remoteEngine, engineData]) - - const onSaveClicked = useCallback(async () => { - if (!remoteEngine) { - alert('Does not have engine') - return - } - const normalizedApiKey = apiKey.trim().replaceAll('*', '') - if (normalizedApiKey.length === 0) return - - updateEngineConfig.mutate({ - engine: remoteEngine, - config: { - config: 'apiKey', - value: apiKey, - }, - }) - }, [remoteEngine, updateEngineConfig, apiKey]) - - const onDismiss = useCallback(() => { - setUpRemoteModelStage('NONE', undefined) - }, [setUpRemoteModelStage]) - - if (remoteEngine == null) return null - const owner: string = getTitleByCategory(remoteEngine) - const logoUrl: string = (metadata?.logo ?? '') as string - const apiKeyUrl: string | undefined = (metadata?.api_key_url ?? '') as - | string - | undefined - - return ( - -
- {logoUrl && ( - Model owner - )} -

- {owner} -

-
- -
API Key
- - setApiKey(e.target.value)} - /> - - {apiKeyUrl && ( - - - Get your API key from {owner} - - - )} - - -
- - -
- - } - /> - ) -} - -export default SetUpApiKeyModal diff --git a/web/screens/HubScreen2/components/SetUpRemoteModelModal.tsx b/web/screens/HubScreen2/components/SetUpRemoteModelModal.tsx deleted file mode 100644 index 5f608f63d3..0000000000 --- a/web/screens/HubScreen2/components/SetUpRemoteModelModal.tsx +++ /dev/null @@ -1,53 +0,0 @@ -import { Fragment } from 'react' - -import { Modal } from '@janhq/joi' -import { useAtomValue, useSetAtom } from 'jotai' - -import HeaderModal from './HeaderModal' -import ModelTitle from './ModelTitle' - -import { - navigateToSetUpApiKeyAtom, - setUpRemoteModelStageAtom, -} from '@/helpers/atoms/SetupRemoteModel.atom' - -const SetUpRemoteModelModal: React.FC = () => { - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const navigateToSetUpApiKey = useSetAtom(navigateToSetUpApiKeyAtom) - const { stage, metadata } = useAtomValue(setUpRemoteModelStageAtom) - - const author: string = (metadata?.author ?? '') as string - const logoUrl: string = (metadata?.logo ?? '') as string - const description: string = (metadata?.description ?? '') as string - const modelName: string = (metadata?.modelName ?? '') as string - const modelId: string = (metadata?.modelId ?? '') as string - - return ( - setUpRemoteModelStage('NONE', undefined)} - content={ - - - - - {description && ( - - {description} - - )} - - } - /> - ) -} - -export default SetUpRemoteModelModal diff --git a/web/screens/HubScreen2/components/SidebarFilter.tsx b/web/screens/HubScreen2/components/SidebarFilter.tsx deleted file mode 100644 index 889303e49f..0000000000 --- a/web/screens/HubScreen2/components/SidebarFilter.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import { Checkbox } from '@janhq/joi' -import { useAtomValue } from 'jotai' - -import { twMerge } from 'tailwind-merge' - -import RangeSlider from './DoubleRange' -import Toggle from './Toggle' - -import { - reduceTransparentAtom, - showSidbarFilterAtom, -} from '@/helpers/atoms/Setting.atom' - -const SidebarFilter: React.FC = () => { - const reduceTransparent = useAtomValue(reduceTransparentAtom) - const showSidbarFilter = useAtomValue(showSidbarFilterAtom) - - return ( -
- Filter -
- - Compatible with my device -
- Format -
- - - -
- Model Size - -
- ) -} - -export default SidebarFilter diff --git a/web/screens/HubScreen2/components/Slider.tsx b/web/screens/HubScreen2/components/Slider.tsx deleted file mode 100644 index 337a756e75..0000000000 --- a/web/screens/HubScreen2/components/Slider.tsx +++ /dev/null @@ -1,80 +0,0 @@ -import React, { useEffect, useState } from 'react' - -import { twMerge } from 'tailwind-merge' - -import useModelHub, { QuickStartModel } from '@/hooks/useModelHub' - -import { - Carousel, - CarouselContent, - CarouselItem, - CarouselNext, - CarouselPrevious, -} from './Carousel' -import SliderItem from './SliderItem' - -const Slider: React.FC = () => { - const { data } = useModelHub() - const [width, setWidth] = useState(window.innerWidth) - - useEffect(() => { - window.addEventListener('resize', () => { - setWidth(window.innerWidth) - }) - return () => { - window.removeEventListener('resize', () => { - setWidth(window.innerWidth) - }) - } - }, []) - - if (!data) return null - const models = data.sliderData ?? [] - - const normalizedModelsList: QuickStartModel[][] = [] - - const getColumnCount = () => { - if (width <= 670) return 1 - if (width <= 1000) return 2 - return 3 - } - - models.forEach((model, index) => { - if (index % getColumnCount() === 0) { - normalizedModelsList.push([model]) - } else { - normalizedModelsList[normalizedModelsList.length - 1].push(model) - } - }) - - return ( - - - {normalizedModelsList.map((modelArray, index) => ( - 1000 && 'grid-cols-3', - width <= 1000 && 'grid-cols-2', - width <= 670 && 'grid-cols-1' - )} - key={index} - > - {modelArray.map((model) => ( - - ))} - - ))} - - - - - ) -} - -export default Slider diff --git a/web/screens/HubScreen2/components/SliderItem.tsx b/web/screens/HubScreen2/components/SliderItem.tsx deleted file mode 100644 index 43c2d40597..0000000000 --- a/web/screens/HubScreen2/components/SliderItem.tsx +++ /dev/null @@ -1,186 +0,0 @@ -import React, { useCallback, useMemo } from 'react' - -import Image from 'next/image' - -import { Button, Progress } from '@janhq/joi' - -import { useAtomValue, useSetAtom } from 'jotai' - -import { toaster } from '@/containers/Toast' - -import useAbortDownload from '@/hooks/useAbortDownload' -import useAssistantQuery from '@/hooks/useAssistantQuery' -import { downloadStateListAtom } from '@/hooks/useDownloadState' -import useModelDownloadMutation from '@/hooks/useModelDownloadMutation' -import { QuickStartModel } from '@/hooks/useModelHub' - -import useThreadCreateMutation from '@/hooks/useThreadCreateMutation' - -import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' -import { downloadProgress } from '@/utils/download' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' - -type Props = { - model: QuickStartModel -} - -const SliderItem: React.FC = ({ model }) => { - const url = new URL(model.url) - const pathArray = url.pathname.split('/').filter((segment) => segment !== '') - - const owner = pathArray[0] - const repo = pathArray[1] - const fileName = pathArray[pathArray.length - 1] - const repoId = `${owner}/${repo}` - - const shouldShowOwnerLogo = model.logo !== undefined && model.logo !== '' - - return ( -
-
- - {model.model_name} - -
- {shouldShowOwnerLogo && ( - {model.author} - )} - - {model.author} - -
-
-
- - {toGibibytes(model.size)} - - -
-
- ) -} - -type DownloadContainerProps = { - modelHandle: string - fileName: string -} - -const DownloadContainer: React.FC = ({ - modelHandle, - fileName, -}) => { - const downloadModelMutation = useModelDownloadMutation() - const createThreadMutation = useThreadCreateMutation() - const setMainViewState = useSetAtom(mainViewStateAtom) - const { data: assistants } = useAssistantQuery() - - const { abortDownload } = useAbortDownload() - - const setDownloadLocalModelModalStage = useSetAtom(localModelModalStageAtom) - - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) - - const persistModelId = modelHandle - .replaceAll('/', '_') - .concat('_') - .concat(fileName) - - const downloadState = allDownloadState.find((s) => s.id == persistModelId) - - const downloadedModel = useMemo( - () => downloadedModels.find((m) => m.model === persistModelId), - [downloadedModels, persistModelId] - ) - - const onDownloadClick = useCallback(() => { - downloadModelMutation.mutate({ - modelId: modelHandle, - fileName: fileName, - persistedModelId: persistModelId, - }) - }, [downloadModelMutation, modelHandle, fileName, persistModelId]) - - const onUseModelClick = useCallback(async () => { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) - return - } - - await createThreadMutation.mutateAsync({ - modelId: persistModelId, - assistant: { - ...assistants[0], - model: persistModelId, - }, - }) - setDownloadLocalModelModalStage('NONE', undefined) - setMainViewState(MainViewState.Thread) - }, [ - setDownloadLocalModelModalStage, - setMainViewState, - createThreadMutation, - persistModelId, - assistants, - ]) - - const onAbortDownloadClick = useCallback(() => { - abortDownload(persistModelId) - }, [abortDownload, persistModelId]) - - return ( -
- {downloadedModel ? ( - - ) : downloadState != null ? ( - - ) : ( - - )} -
- ) -} - -export default SliderItem diff --git a/web/screens/HubScreen2/components/Tab.tsx b/web/screens/HubScreen2/components/Tab.tsx deleted file mode 100644 index 2b628bf352..0000000000 --- a/web/screens/HubScreen2/components/Tab.tsx +++ /dev/null @@ -1,38 +0,0 @@ -import { twMerge } from 'tailwind-merge' - -type Props = { - tab: ModelTab - handleTab: (ModelTab: string) => void -} - -export const AvailableLocalModelTabs = ['Versions', 'Information'] as const -export type ModelTab = (typeof AvailableLocalModelTabs)[number] - -const Tab: React.FC = ({ tab, handleTab }) => { - return ( -
- {AvailableLocalModelTabs.map((item) => ( - - ))} -
- ) -} - -export default Tab diff --git a/web/screens/HubScreen2/components/Toggle.tsx b/web/screens/HubScreen2/components/Toggle.tsx deleted file mode 100644 index 8668136efe..0000000000 --- a/web/screens/HubScreen2/components/Toggle.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import { twMerge } from 'tailwind-merge' - -const Toggle: React.FC = () => { - return ( -
+
+ + + ) +} + +export default LocalServerLeftPanel diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx new file mode 100644 index 0000000000..309709c268 --- /dev/null +++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx @@ -0,0 +1,135 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' + +import { Accordion, AccordionItem } from '@janhq/joi' +import { useAtomValue, useSetAtom } from 'jotai' +import { AlertTriangleIcon, InfoIcon } from 'lucide-react' + +import EngineSetting from '@/containers/EngineSetting' +import { modalTroubleShootingAtom } from '@/containers/ModalTroubleShoot' +import ModelDropdown from '@/containers/ModelDropdown' +import ModelSetting from '@/containers/ModelSetting' +import RightPanelContainer from '@/containers/RightPanelContainer' + +import { loadModelErrorAtom } from '@/hooks/useActiveModel' + +import { getConfigurationsData } from '@/utils/componentSettings' + +import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' + +import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' + +const LocalServerRightPanel = () => { + const loadModelError = useAtomValue(loadModelErrorAtom) + const serverEnabled = useAtomValue(serverEnabledAtom) + const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) + + const selectedModel = useAtomValue(selectedModelAtom) + + const [currentModelSettingParams, setCurrentModelSettingParams] = useState( + toSettingParams(selectedModel?.settings) + ) + + useEffect(() => { + if (selectedModel) { + setCurrentModelSettingParams(toSettingParams(selectedModel?.settings)) + } + }, [selectedModel]) + + const modelRuntimeParams = toRuntimeParams(selectedModel?.settings) + + const componentDataRuntimeSetting = getConfigurationsData( + modelRuntimeParams, + selectedModel + ) + + const componentDataEngineSetting = getConfigurationsData( + currentModelSettingParams + ) + + const engineSettings = useMemo( + () => + componentDataEngineSetting.filter( + (x) => x.key !== 'prompt_template' && x.key !== 'embedding' + ), + [componentDataEngineSetting] + ) + + const modelSettings = useMemo(() => { + return componentDataRuntimeSetting.filter( + (x) => x.key !== 'prompt_template' + ) + }, [componentDataRuntimeSetting]) + + const onValueChanged = useCallback( + (key: string, value: string | number | boolean) => { + setCurrentModelSettingParams({ + ...currentModelSettingParams, + [key]: value, + }) + }, + [currentModelSettingParams] + ) + + return ( + +
+
+ +

+ You can concurrently send requests to one active local model and + multiple remote models. +

+
+ + + + {loadModelError && serverEnabled && ( +
+ + + Model failed to start. Access{' '} + setModalTroubleShooting(true)} + > + troubleshooting assistance + + +
+ )} +
+ + + {modelSettings.length !== 0 && ( + + + + )} + + {engineSettings.length !== 0 && ( + + + + )} + +
+ ) +} + +export default LocalServerRightPanel diff --git a/web/screens/LocalServer/index.tsx b/web/screens/LocalServer/index.tsx new file mode 100644 index 0000000000..529c799dbf --- /dev/null +++ b/web/screens/LocalServer/index.tsx @@ -0,0 +1,23 @@ +'use client' + +import ModalTroubleShooting from '@/containers/ModalTroubleShoot' + +import LocalServerCenterPanel from './LocalServerCenterPanel' +import LocalServerLeftPanel from './LocalServerLeftPanel' +import LocalServerRightPanel from './LocalServerRightPanel' + +const LocalServerScreen = () => { + return ( +
+ + + + +
+ ) +} + +export default LocalServerScreen diff --git a/web/screens/Settings/Advanced/DataFolder/index.tsx b/web/screens/Settings/Advanced/DataFolder/index.tsx index bbbcc6b8f0..3bb059a87c 100644 --- a/web/screens/Settings/Advanced/DataFolder/index.tsx +++ b/web/screens/Settings/Advanced/DataFolder/index.tsx @@ -1,9 +1,8 @@ -import { isAbsolute, relative } from 'path' - -import { Fragment, useCallback, useEffect, useState } from 'react' +import { Fragment, useCallback, useState } from 'react' +import { AppConfiguration, isSubdirectory } from '@janhq/core' import { Button, Input } from '@janhq/joi' -import { useAtom, useSetAtom } from 'jotai' +import { useAtomValue, useSetAtom } from 'jotai' import { PencilIcon, FolderOpenIcon } from 'lucide-react' import Loader from '@/containers/Loader' @@ -31,18 +30,8 @@ const DataFolder = () => { const setShowChangeFolderError = useSetAtom(showChangeFolderErrorAtom) const showDestNotEmptyConfirm = useSetAtom(showDestNotEmptyConfirmAtom) - const [janDataFolderPath, setJanDataFolderPath] = useAtom( - janDataFolderPathAtom - ) - const getAppDataFolder = useCallback(async () => { - return window.electronAPI?.appDataFolder().then(setJanDataFolderPath) - }, [setJanDataFolderPath]) - const [destinationPath, setDestinationPath] = useState(undefined) - - useEffect(() => { - getAppDataFolder() - }, [getAppDataFolder]) + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const onChangeFolderClick = useCallback(async () => { const destFolder = await window.core?.api?.selectDirectory() @@ -53,15 +42,11 @@ const DataFolder = () => { return } - const currentJanDataFolder = await window.electronAPI?.appDataFolder() - - const relativePath = relative(currentJanDataFolder, destFolder) + const appConfiguration: AppConfiguration = + await window.core?.api?.getAppConfigurations() + const currentJanDataFolder = appConfiguration.data_folder - if ( - relativePath && - !relativePath.startsWith('..') && - !isAbsolute(relativePath) - ) { + if (await isSubdirectory(currentJanDataFolder, destFolder)) { setShowSameDirectory(true) return } @@ -122,9 +107,7 @@ const DataFolder = () => { - window.electronAPI?.openFileExplorer(janDataFolderPath) - } + onClick={() => window.core?.api?.openAppDirectory()} />
-
-
-
-
-
-
Delete All Threads
-
-

- Multiple migrations may create duplicate threads. Use this button to - clean up if needed. -

-
-
- -
-
- - ) -} - -export default DataMigration diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index 2de42ebb3f..f132f81e77 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -2,20 +2,31 @@ import { useEffect, useState, useCallback, ChangeEvent } from 'react' -import { AppConfiguration } from '@janhq/core' +import { openExternalUrl, fs, AppConfiguration } from '@janhq/core' -import { ScrollArea, Switch, Input } from '@janhq/joi' +import { + ScrollArea, + Button, + Switch, + Input, + Tooltip, + Checkbox, + useClickOutside, +} from '@janhq/joi' import { useAtom, useAtomValue } from 'jotai' +import { ChevronDownIcon } from 'lucide-react' +import { AlertTriangleIcon, AlertCircleIcon } from 'lucide-react' + +import { twMerge } from 'tailwind-merge' -import { toaster } from '@/containers/Toast' +import { snackbar, toaster } from '@/containers/Toast' -import useModelStop from '@/hooks/useModelStop' +import { useActiveModel } from '@/hooks/useActiveModel' import { useSettings } from '@/hooks/useSettings' import DataFolder from './DataFolder' - -import DataMigration from './components/DataMigration' +import FactoryReset from './FactoryReset' import { experimentalFeatureEnabledAtom, @@ -26,7 +37,24 @@ import { quickAskEnabledAtom, } from '@/helpers/atoms/AppConfig.atom' -import { activeModelsAtom } from '@/helpers/atoms/Model.atom' +type GPU = { + id: string + vram: number | null + name: string +} + +const test = [ + { + id: 'test a', + vram: 2, + name: 'nvidia A', + }, + { + id: 'test', + vram: 2, + name: 'nvidia B', + }, +] const Advanced = () => { const [experimentalEnabled, setExperimentalEnabled] = useAtom( @@ -40,25 +68,24 @@ const Advanced = () => { const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom) const [partialProxy, setPartialProxy] = useState(proxy) - // const [gpuEnabled, setGpuEnabled] = useState(false) - // const [gpuList, setGpuList] = useState([]) - // const [gpusInUse, setGpusInUse] = useState([]) - // const [dropdownOptions, setDropdownOptions] = useState( - // null - // ) + const [gpuEnabled, setGpuEnabled] = useState(false) + const [gpuList, setGpuList] = useState(test) + const [gpusInUse, setGpusInUse] = useState([]) + const [dropdownOptions, setDropdownOptions] = useState( + null + ) - // const [toggle, setToggle] = useState(null) + const [toggle, setToggle] = useState(null) const { readSettings, saveSettings } = useSettings() - const activeModels = useAtomValue(activeModelsAtom) - // const [open, setOpen] = useState(false) - const stopModel = useModelStop() + const { stopModel } = useActiveModel() + const [open, setOpen] = useState(false) - // const selectedGpu = gpuList - // .filter((x) => gpusInUse.includes(x.id)) - // .map((y) => { - // return y['name'] - // }) + const selectedGpu = gpuList + .filter((x) => gpusInUse.includes(x.id)) + .map((y) => { + return y['name'] + }) const onProxyChange = useCallback( (event: ChangeEvent) => { @@ -79,7 +106,7 @@ const Advanced = () => { ) => { const appConfiguration: AppConfiguration = await window.core?.api?.getAppConfigurations() - appConfiguration.quickAsk = e + appConfiguration.quick_ask = e await window.core?.api?.updateAppConfiguration(appConfiguration) if (relaunch) window.core?.api?.relaunch() } @@ -89,11 +116,7 @@ const Advanced = () => { title: 'Reload', description: 'Vulkan settings updated. 
Reload now to apply the changes.', }) - - for (const model of activeModels) { - await stopModel.mutateAsync(model.model) - } - + stopModel() setVulkanEnabled(e) await saveSettings({ vulkan: e, gpusInUse: [] }) // Relaunch to apply settings @@ -114,57 +137,58 @@ const Advanced = () => { } useEffect(() => { - // const setUseGpuIfPossible = async () => { - // const settings = await readSettings() - // setGpuEnabled(settings.run_mode === 'gpu' && settings.gpus?.length > 0) - // setGpusInUse(settings.gpus_in_use || []) - // setVulkanEnabled(settings.vulkan || false) - // if (settings.gpus) { - // setGpuList(settings.gpus) - // } - // } - // setUseGpuIfPossible() - }, [readSettings, setVulkanEnabled]) - - // const clearLogs = async () => { - // // try { - // // await fs.rm(`file://logs`) - // // } catch (err) { - // // console.error('Error clearing logs: ', err) - // // } - // // toaster({ - // // title: 'Logs cleared', - // // description: 'All logs have been cleared.', - // // type: 'success', - // // }) - // } - - // const handleGPUChange = (gpuId: string) => { - // let updatedGpusInUse = [...gpusInUse] - // if (updatedGpusInUse.includes(gpuId)) { - // updatedGpusInUse = updatedGpusInUse.filter((id) => id !== gpuId) - // if (gpuEnabled && updatedGpusInUse.length === 0) { - // // Vulkan support only allow 1 active device at a time - // if (vulkanEnabled) { - // updatedGpusInUse = [] - // } - // updatedGpusInUse.push(gpuId) - // } - // } else { - // // Vulkan support only allow 1 active device at a time - // if (vulkanEnabled) { - // updatedGpusInUse = [] - // } - // updatedGpusInUse.push(gpuId) - // } - // setGpusInUse(updatedGpusInUse) - // saveSettings({ gpusInUse: updatedGpusInUse }) - // } - - // const gpuSelectionPlaceHolder = - // gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU" - - // useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) + const setUseGpuIfPossible = async () => { + const settings = await readSettings() + setGpuEnabled(settings.run_mode === 'gpu' && settings.gpus?.length > 0) + setGpusInUse(settings.gpus_in_use || []) + setVulkanEnabled(settings.vulkan || false) + if (settings.gpus) { + setGpuList(settings.gpus) + } + } + setUseGpuIfPossible() + }, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled]) + + const clearLogs = async () => { + try { + await fs.rm(`file://logs`) + } catch (err) { + console.error('Error clearing logs: ', err) + } + + toaster({ + title: 'Logs cleared', + description: 'All logs have been cleared.', + type: 'success', + }) + } + + const handleGPUChange = (gpuId: string) => { + let updatedGpusInUse = [...gpusInUse] + if (updatedGpusInUse.includes(gpuId)) { + updatedGpusInUse = updatedGpusInUse.filter((id) => id !== gpuId) + if (gpuEnabled && updatedGpusInUse.length === 0) { + // Vulkan support only allow 1 active device at a time + if (vulkanEnabled) { + updatedGpusInUse = [] + } + updatedGpusInUse.push(gpuId) + } + } else { + // Vulkan support only allow 1 active device at a time + if (vulkanEnabled) { + updatedGpusInUse = [] + } + updatedGpusInUse.push(gpuId) + } + setGpusInUse(updatedGpusInUse) + saveSettings({ gpusInUse: updatedGpusInUse }) + } + + const gpuSelectionPlaceHolder = + gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU" + + useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) return ( @@ -186,184 +210,180 @@ const Advanced = () => {
{/* CPU / GPU switching */} - {/* {!isMac && ( */} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
GPU Acceleration
*/} - {/*
*/} - {/*

*/} - {/* Enable to enhance model performance by utilizing your GPU */} - {/* devices for acceleration. Read{' '} */} - {/* */} - {/* {' '} */} - {/* */} - {/* // openExternalUrl( */} - {/* // 'https://jan.ai/guides/troubleshooting/gpu-not-used/' */} - {/* // ) */} - {/* // } */} - {/* > */} - {/* troubleshooting guide */} - {/* {' '} */} - {/* {' '} */} - {/* for further assistance. */} - {/*

*/} - {/*
*/} - {/**/} - {/*
*/} - {/* {gpuList.length > 0 && !gpuEnabled && ( */} - {/* */} - {/* } */} - {/* content="Disabling NVIDIA GPU Acceleration may result in reduced */} - {/* performance. It is recommended to keep this enabled for */} - {/* optimal user experience." */} - {/* /> */} - {/* )} */} - {/* { */} - {/* if (e.target.checked === true) { */} - {/* saveSettings({ runMode: 'gpu' }) */} - {/* setGpuEnabled(true) */} - {/* snackbar({ */} - {/* description: */} - {/* 'Successfully turned on GPU Acceleration', */} - {/* type: 'success', */} - {/* }) */} - {/* } else { */} - {/* saveSettings({ runMode: 'cpu' }) */} - {/* setGpuEnabled(false) */} - {/* snackbar({ */} - {/* description: */} - {/* 'Successfully turned off GPU Acceleration', */} - {/* type: 'success', */} - {/* }) */} - {/* } */} - {/* // Stop any running model to apply the changes */} - {/* if (e.target.checked !== gpuEnabled) { */} - {/* for (const activeModel of activeModels) { */} - {/* stopModel(activeModel.model) */} - {/* } */} - {/* } */} - {/* }} */} - {/* /> */} - {/* } */} - {/* content="Your current device does not have a compatible GPU for */} - {/* monitoring. To enable GPU monitoring, please ensure your */} - {/* device has a supported Nvidia or AMD GPU with updated */} - {/* drivers." */} - {/* disabled={gpuList.length > 0} */} - {/* /> */} - {/*
*/} - {/*
*/} - {/*
*/} - {/* */} - {/*
*/} - {/* */} - {/* } */} - {/* onClick={() => setOpen(!open)} */} - {/* /> */} - {/*
*/} - {/*
*/} - {/*

{vulkanEnabled ? 'Vulkan Supported GPUs' : 'Nvidia'}

*/} - {/*
*/} - {/*
*/} - {/* {gpuList */} - {/* .filter((gpu) => */} - {/* vulkanEnabled */} - {/* ? gpu.name */} - {/* : gpu.name?.toLowerCase().includes('nvidia') */} - {/* ) */} - {/* .map((gpu) => ( */} - {/*
*/} - {/* handleGPUChange(gpu.id)} */} - {/* label={ */} - {/* */} - {/* {gpu.name} */} - {/* {!vulkanEnabled && ( */} - {/* {gpu.vram}MB VRAM */} - {/* )} */} - {/* */} - {/* } */} - {/* /> */} - {/*
*/} - {/* ))} */} - {/*
*/} - {/* {gpuEnabled && gpusInUse.length > 1 && ( */} - {/*
*/} - {/* */} - {/*

*/} - {/* If multi-GPU is enabled with different GPU models or */} - {/* without NVLink, it could impact token speed. */} - {/*

*/} - {/*
*/} - {/* )} */} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
*/} - {/* )} */} + {!isMac && ( +
+
+
+
+
GPU Acceleration
+
+

+ Enable to enhance model performance by utilizing your GPU + devices for acceleration. Read{' '} + + {' '} + + openExternalUrl( + 'https://jan.ai/guides/troubleshooting/gpu-not-used/' + ) + } + > + troubleshooting guide + {' '} + {' '} + for further assistance. +

+
+ +
+ {gpuList.length > 0 && !gpuEnabled && ( + + } + content="Disabling NVIDIA GPU Acceleration may result in reduced + performance. It is recommended to keep this enabled for + optimal user experience." + /> + )} + { + if (e.target.checked === true) { + saveSettings({ runMode: 'gpu' }) + setGpuEnabled(true) + snackbar({ + description: + 'Successfully turned on GPU Acceleration', + type: 'success', + }) + } else { + saveSettings({ runMode: 'cpu' }) + setGpuEnabled(false) + snackbar({ + description: + 'Successfully turned off GPU Acceleration', + type: 'success', + }) + } + // Stop any running model to apply the changes + if (e.target.checked !== gpuEnabled) stopModel() + }} + /> + } + content="Your current device does not have a compatible GPU for + monitoring. To enable GPU monitoring, please ensure your + device has a supported Nvidia or AMD GPU with updated + drivers." + disabled={gpuList.length > 0} + /> +
+
+
+ +
+ + } + onClick={() => setOpen(!open)} + /> +
+
+

{vulkanEnabled ? 'Vulkan Supported GPUs' : 'Nvidia'}

+
+
+ {gpuList + .filter((gpu) => + vulkanEnabled + ? gpu.name + : gpu.name?.toLowerCase().includes('nvidia') + ) + .map((gpu) => ( +
+ handleGPUChange(gpu.id)} + label={ + + {gpu.name} + {!vulkanEnabled && ( + {gpu.vram}MB VRAM + )} + + } + /> +
+ ))} +
+ {gpuEnabled && gpusInUse.length > 1 && ( +
+ +

+ If multi-GPU is enabled with different GPU models or + without NVLink, it could impact token speed. +

+
+ )} +
+
+
+
+
+
+ )} {/* Vulkan for AMD GPU/ APU and Intel Arc GPU */} - {/* {!isMac && experimentalEnabled && ( */} - {/*
*/} - {/*
*/} - {/*
*/} - {/*
Vulkan Support
*/} - {/*
*/} - {/*

*/} - {/* Enable Vulkan with AMD GPU/APU and Intel Arc GPU for better */} - {/* model performance (reload needed). */} - {/*

*/} - {/*
*/} - {/**/} - {/* updateVulkanEnabled(e.target.checked)} */} - {/* /> */} - {/*
*/} - {/* )} */} + {!isMac && experimentalEnabled && ( +
+
+
+
Vulkan Support
+
+

+ Enable Vulkan with AMD GPU/APU and Intel Arc GPU for better + model performance (reload needed). +

+
+ + updateVulkanEnabled(e.target.checked)} + /> +
+ )} @@ -379,7 +399,7 @@ const Advanced = () => {

-
+
setProxyEnabled(!proxyEnabled)} @@ -413,7 +433,7 @@ const Advanced = () => { />
- {/* {experimentalEnabled && ( + {experimentalEnabled && (
@@ -439,10 +459,10 @@ const Advanced = () => { }} />
- )} */} + )} {/* Clear log */} - {/*
+
Clear logs
@@ -454,11 +474,10 @@ const Advanced = () => { -
*/} +
{/* Factory Reset */} - {/* */} - {experimentalEnabled && } +
) diff --git a/web/screens/Settings/Appearance/index.tsx b/web/screens/Settings/Appearance/index.tsx index a6571be994..837a7074c0 100644 --- a/web/screens/Settings/Appearance/index.tsx +++ b/web/screens/Settings/Appearance/index.tsx @@ -2,10 +2,12 @@ import { useCallback } from 'react' import { useTheme } from 'next-themes' +import { fs, joinPath } from '@janhq/core' import { Button, Select, Switch } from '@janhq/joi' import { useAtom, useAtomValue } from 'jotai' import { + janThemesPathAtom, reduceTransparentAtom, selectedThemeIdAtom, spellCheckAtom, @@ -17,6 +19,7 @@ export default function AppearanceOptions() { const [selectedIdTheme, setSelectedIdTheme] = useAtom(selectedThemeIdAtom) const themeOptions = useAtomValue(themesOptionsAtom) const { setTheme } = useTheme() + const janThemesPath = useAtomValue(janThemesPathAtom) const [themeData, setThemeData] = useAtom(themeDataAtom) const [reduceTransparent, setReduceTransparent] = useAtom( reduceTransparentAtom @@ -26,7 +29,8 @@ export default function AppearanceOptions() { const handleClickTheme = useCallback( async (e: string) => { setSelectedIdTheme(e) - const theme: Theme = await window.electronAPI.readTheme(e) + const filePath = await joinPath([`${janThemesPath}/${e}`, `theme.json`]) + const theme: Theme = JSON.parse(await fs.readFileSync(filePath, 'utf-8')) setThemeData(theme) setTheme(String(theme?.nativeTheme)) if (theme?.reduceTransparent) { @@ -36,6 +40,7 @@ export default function AppearanceOptions() { } }, [ + janThemesPath, reduceTransparent, setReduceTransparent, setSelectedIdTheme, @@ -52,11 +57,11 @@ export default function AppearanceOptions() {
Appearance

- Select a color theme. + Select a color theme

} -// placeholder="Search" -// onChange={(e) => setSearchText(e.target.value)} -// /> -//
-//
-// -// -//
-//
- -//
-// {engineActiveExtensions.length !== 0 && ( -//
-//
-// Model Providers -//
-//
-// )} -// {engineActiveExtensions -// .filter((x) => x.name.includes(searchText.toLowerCase().trim())) -// .sort((a, b) => a.provider.localeCompare(b.provider)) -// .map((item, i) => { -// return ( -//
-//
-//
-//
-//
-// {item.productName?.replace('Inference Engine', '') ?? -// formatExtensionsName(item.name)} -//
-// -// v{item.version} -// -//

{item.provider}

-//
-//
-// {/* {!inActiveEngineProvider.includes(item.provider) && ( -// -// )} */} -// {/* onSwitchChange(item.provider)} -// /> */} -//
-//
-// { -//
-// } -//
-//
-// ) -// })} - -// // {coreActiveExtensions.length > 0 && ( -// //
-// //
-// // Core Extention -// //
-// //
-// // )} -// // {coreActiveExtensions -// // .filter((x) => x.name.includes(searchText.toLowerCase().trim())) -// // .sort((a, b) => a.name.localeCompare(b.name)) -// // .map((item, i) => { -// // return ( -// //
-// //
-// //
-// //
-// // {item.productName ?? formatExtensionsName(item.name)} -// //
-// // -// // v{item.version} -// // -// //
-// // { -// //
-// // } -// //
-// //
-// // ) -// // })} -// //
-// // -// // {showLoading && } -// // -// // ) -// // } - -// // export default ExtensionCatalog +/* eslint-disable @typescript-eslint/no-explicit-any */ +import React, { useState, useEffect, useRef, useCallback } from 'react' + +import { InferenceEngine } from '@janhq/core' + +import { Button, ScrollArea, Badge, Switch, Input } from '@janhq/joi' +import { useAtom } from 'jotai' +import { SearchIcon } from 'lucide-react' +import { Marked, Renderer } from 'marked' + +import Loader from '@/containers/Loader' + +import SetupRemoteModel from '@/containers/SetupRemoteModel' + +import { formatExtensionsName } from '@/utils/converter' + +import { extensionManager } from '@/extension' +import Extension from '@/extension/Extension' +import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' + +type EngineExtension = { + provider: InferenceEngine +} & Extension + +const ExtensionCatalog = () => { + const [coreActiveExtensions, setCoreActiveExtensions] = useState( + [] + ) + const [engineActiveExtensions, setEngineActiveExtensions] = useState< + EngineExtension[] + >([]) + const [searchText, setSearchText] = useState('') + const [showLoading, setShowLoading] = useState(false) + const fileInputRef = useRef(null) + + useEffect(() => { + const getAllSettings = async () => { + const extensionsMenu = [] + const engineMenu = [] + const extensions = extensionManager.getAll() + + for (const extension of extensions) { + const settings = await extension.getSettings() + if ( + typeof extension.getSettings === 'function' && + 'provider' in extension && + typeof extension.provider === 'string' + ) { + if ( + (settings && settings.length > 0) || + (await extension.installationState()) !== 'NotRequired' + ) { + engineMenu.push({ + ...extension, + provider: + 'provider' in extension && + typeof extension.provider === 'string' + ? extension.provider + : '', + }) + } + } else { + extensionsMenu.push({ + ...extension, + }) + } + } + + setCoreActiveExtensions(extensionsMenu) + setEngineActiveExtensions(engineMenu as any) + } + getAllSettings() + }, []) + + /** + * Installs a extension by calling the `extensions.install` function with the extension file path. + * If the installation is successful, the application is relaunched using the `coreAPI.relaunch` function. + * @param e - The event object. + */ + const install = async (e: any) => { + e.preventDefault() + const extensionFile = e.target.files?.[0].path + + // Send the filename of the to be installed extension + // to the main process for installation + const installed = await extensionManager.install([extensionFile]) + if (installed) window.core?.api?.relaunch() + } + + /** + * Uninstalls a extension by calling the `extensions.uninstall` function with the extension name. + * If the uninstallation is successful, the application is relaunched using the `coreAPI.relaunch` function. + * @param name - The name of the extension to uninstall. + */ + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const uninstall = async (name: string) => { + // Send the filename of the to be uninstalled extension + // to the main process for removal + const res = await extensionManager.uninstall([name]) + if (res) window.core?.api?.relaunch() + } + + /** + * Handles the change event of the extension file input element by setting the file name state. + * Its to be used to display the extension file name of the selected file. + * @param event - The change event object. 
+ */ + const handleFileChange = (event: React.ChangeEvent) => { + const file = event.target.files?.[0] + if (file) { + setShowLoading(true) + install(event) + } + } + + const [inActiveEngineProvider, setInActiveEngineProvider] = useAtom( + inActiveEngineProviderAtom + ) + + const onSwitchChange = useCallback( + (name: string) => { + if (inActiveEngineProvider.includes(name)) { + setInActiveEngineProvider( + [...inActiveEngineProvider].filter((x) => x !== name) + ) + } else { + setInActiveEngineProvider([...inActiveEngineProvider, name]) + } + }, + [inActiveEngineProvider, setInActiveEngineProvider] + ) + + return ( + <> + +
+
+ } + placeholder="Search" + onChange={(e) => setSearchText(e.target.value)} + /> +
+
+ + +
+
+ +
+ {engineActiveExtensions.length !== 0 && ( +
+
+ Model Providers +
+
+ )} + {engineActiveExtensions + .filter((x) => x.name.includes(searchText.toLowerCase().trim())) + .sort((a, b) => a.provider.localeCompare(b.provider)) + .map((item, i) => { + return ( +
+
+
+
+
+ {item.productName?.replace('Inference Engine', '') ?? + formatExtensionsName(item.name)} +
+ + v{item.version} + +

{item.provider}

+
+
+ {!inActiveEngineProvider.includes(item.provider) && ( + + )} + onSwitchChange(item.provider)} + /> +
+
+ { +
+ } +
+
+ ) + })} + + {coreActiveExtensions.length > 0 && ( +
+
+ Core Extension +
+
+ )} + {coreActiveExtensions + .filter((x) => x.name.includes(searchText.toLowerCase().trim())) + .sort((a, b) => a.name.localeCompare(b.name)) + .map((item, i) => { + return ( +
+
+
+
+ {item.productName ?? formatExtensionsName(item.name)} +
+ + v{item.version} + +
+ { +
+ } +
+
+ ) + })} +
+ + {showLoading && } + + ) +} + +const marked: Marked = new Marked({ + renderer: { + link: (href, title, text) => { + return Renderer.prototype.link + ?.apply(this, [href, title, text]) + .replace( + ' { const updateImportingModel = useSetAtom(updateImportingModelAtom) const { updateModelInfo } = useImportModel() const [modelPath, setModelPath] = useState('') - console.log('EditModelInfoModal', setModelPath) + const editingModel = importingModels.find( (model) => model.importId === editingModelId ) @@ -63,7 +69,7 @@ const EditModelInfoModal = () => { const modelInfo: Partial = { id: editingModel.modelId, name: modelName, - // description, + description, metadata: { author: 'User', tags, @@ -72,7 +78,7 @@ const EditModelInfoModal = () => { } await updateModelInfo(modelInfo) - // events.emit(ModelEvent.OnModelsUpdate, {}) + events.emit(ModelEvent.OnModelsUpdate, {}) updateImportingModel(editingModel.importId, modelName, description, tags) setImportModelStage('IMPORTING_MODEL') @@ -83,15 +89,15 @@ const EditModelInfoModal = () => { const getModelPath = async () => { const modelId = editingModel?.modelId if (!modelId) return '' - // const path = await joinPath([janDataFolder, 'models', modelId]) - // setModelPath(path) + const path = await joinPath([janDataFolder, 'models', modelId]) + setModelPath(path) } getModelPath() }, [janDataFolder, editingModel]) const onShowInFinderClick = useCallback(() => { - // openFileExplorer(modelPath) - }, []) + openFileExplorer(modelPath) + }, [modelPath]) if (!editingModel) { setImportModelStage('IMPORTING_MODEL') diff --git a/web/screens/Settings/EngineSetting/index.tsx b/web/screens/Settings/EngineSetting/index.tsx deleted file mode 100644 index 9ae8298c71..0000000000 --- a/web/screens/Settings/EngineSetting/index.tsx +++ /dev/null @@ -1,102 +0,0 @@ -import { EngineStatus, LlmEngine, LocalEngines } from '@janhq/core' -import { - Button, - ScrollArea, - Table, - TableBody, - TableCell, - TableHead, - TableHeader, - TableRow, -} from '@janhq/joi' - -import useEngineInit from '@/hooks/useEngineInit' -import useEngineQuery from '@/hooks/useEngineQuery' - -import LoadingIndicator from '@/screens/HubScreen2/components/LoadingIndicator' - -const getStatusTitle = (status: string) => { - const normalized = status.charAt(0).toUpperCase() + status.slice(1) - return normalized.replaceAll('_', ' ') -} - -const EngineSetting: React.FC = () => { - const { isLoading, data } = useEngineQuery() - - const initializeEngine = useEngineInit() - - if (isLoading) { - return ( -
- -
- ) - } - - if (!data) { - return ( -
-
Failed to get engine statuses..
-
- ) - } - - return ( - -
- - - - Engine name - Description - Version - Status - Install - - - - {data.map((engineStatus) => { - return ( - - - {engineStatus.name} - - {engineStatus.description} - - {engineStatus.version} - - {getStatusTitle(engineStatus.status)} - - {LocalEngines.some((e) => e === engineStatus.name) && - [EngineStatus.Ready, EngineStatus.NotInitialized].includes( - engineStatus.status as EngineStatus - ) ? ( - - ) : ( - - )} - - - ) - })} - -
-
-
- ) -} - -export default EngineSetting diff --git a/web/screens/Settings/ExtensionSetting/index.tsx b/web/screens/Settings/ExtensionSetting/index.tsx index 56cdcaa5e2..4a8b140f34 100644 --- a/web/screens/Settings/ExtensionSetting/index.tsx +++ b/web/screens/Settings/ExtensionSetting/index.tsx @@ -1,37 +1,43 @@ import React, { Fragment, useEffect, useState } from 'react' -import { SettingComponentProps } from '@janhq/core' +import { + BaseExtension, + InstallationState, + SettingComponentProps, +} from '@janhq/core' import { useAtomValue } from 'jotai' +import ExtensionItem from '../CoreExtensions/ExtensionItem' import SettingDetailItem from '../SettingDetail/SettingDetailItem' +import { extensionManager } from '@/extension' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' const ExtensionSetting = () => { const selectedExtensionName = useAtomValue(selectedSettingAtom) const [settings, setSettings] = useState([]) - // const [installationState, setInstallationState] = - // useState('NotRequired') - // const [baseExtension, setBaseExtension] = useState( - // undefined - // ) + const [installationState, setInstallationState] = + useState('NotRequired') + const [baseExtension, setBaseExtension] = useState( + undefined + ) useEffect(() => { const getExtensionSettings = async () => { if (!selectedExtensionName) return - // const allSettings: SettingComponentProps[] = [] - // const baseExtension = extensionManager.getByName(selectedExtensionName) - // if (!baseExtension) return + const allSettings: SettingComponentProps[] = [] + const baseExtension = extensionManager.getByName(selectedExtensionName) + if (!baseExtension) return - // setBaseExtension(baseExtension) - // if (typeof baseExtension.getSettings === 'function') { - // const setting = await baseExtension.getSettings() - // if (setting) allSettings.push(...setting) - // } - // setSettings(allSettings) + setBaseExtension(baseExtension) + if (typeof baseExtension.getSettings === 'function') { + const setting = await baseExtension.getSettings() + if (setting) allSettings.push(...setting) + } + setSettings(allSettings) - // setInstallationState(await baseExtension.installationState()) + setInstallationState(await baseExtension.installationState()) } getExtensionSettings() }, [selectedExtensionName]) @@ -45,10 +51,10 @@ const ExtensionSetting = () => { if (setting.key !== key) return setting setting.controllerProps.value = value - // const extensionName = setting.extensionName - // if (extensionName) { - // extensionManager.getByName(extensionName)?.updateSettings([setting]) - // } + const extensionName = setting.extensionName + if (extensionName) { + extensionManager.getByName(extensionName)?.updateSettings([setting]) + } return setting }) @@ -64,9 +70,9 @@ const ExtensionSetting = () => { onValueUpdated={onValueChanged} /> )} - {/* {baseExtension && installationState !== 'NotRequired' && ( + {baseExtension && installationState !== 'NotRequired' && ( - )} */} + )} ) } diff --git a/web/screens/Settings/Hotkeys/index.tsx b/web/screens/Settings/Hotkeys/index.tsx index 3a416e7ce2..aa79ae11e1 100644 --- a/web/screens/Settings/Hotkeys/index.tsx +++ b/web/screens/Settings/Hotkeys/index.tsx @@ -23,11 +23,11 @@ const availableHotkeys = [ }, { combination: 'Enter', - description: 'Send a message (in input field)', + description: 'Send a message', }, { combination: 'Shift Enter', - description: 'Insert a new line (in input field)', + description: 'Insert new line in input box', }, { combination: 'Arrow Up', diff --git 
a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadList/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadList/index.tsx index f96cba9f57..3078a8c368 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadList/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadList/index.tsx @@ -7,7 +7,7 @@ import ModelDownloadRow from '../ModelDownloadRow' import { importingHuggingFaceRepoDataAtom } from '@/helpers/atoms/HuggingFace.atom' -const ModelDownloadList: React.FC = () => { +const ModelDownloadList = () => { const importingHuggingFaceRepoData = useAtomValue( importingHuggingFaceRepoDataAtom ) @@ -34,7 +34,8 @@ const ModelDownloadList: React.FC = () => { if (!model.downloadUrl) return null return ( = ({ - modelHandle, + repoData, + downloadUrl, fileName, fileSize = 0, quantization, }) => { - return ( -
-
- {quantization && ( - - {quantization} - - )} -

- {fileName} -

- {toGibibytes(fileSize)} -
- - -
- ) -} - -type DownloadContainerProps = { - modelHandle: string - fileName: string -} + const downloadedModels = useAtomValue(downloadedModelsAtom) + const { downloadModel, abortModelDownload } = useDownloadModel() + const allDownloadStates = useAtomValue(modelDownloadStateAtom) + const downloadState: DownloadState | undefined = allDownloadStates[fileName] -const DownloadContainer: React.FC = ({ - modelHandle, - fileName, -}) => { - const downloadModelMutation = useModelDownloadMutation() + const { requestCreateNewThread } = useCreateNewThread() const setMainViewState = useSetAtom(mainViewStateAtom) - const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) - const { createThread } = useThreads() - const { abortDownload } = useAbortDownload() - - const { data: assistants } = useAssistantQuery() + const assistants = useAtomValue(assistantsAtom) + const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null - const downloadedModels = useAtomValue(downloadedModelsAtom) - const allDownloadState = useAtomValue(downloadStateListAtom) + const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) + const defaultModel = useAtomValue(defaultModelAtom) - const persistModelId = modelHandle - .replaceAll('/', '_') - .concat('_') - .concat(fileName) + const model = useMemo(() => { + if (!defaultModel) { + return undefined + } - const downloadState = allDownloadState.find((s) => s.id === persistModelId) + const model: Model = { + ...defaultModel, + sources: [ + { + url: downloadUrl, + filename: fileName, + }, + ], + id: fileName, + name: fileName, + created: Date.now(), + metadata: { + author: 'User', + tags: repoData.tags, + size: fileSize, + }, + } + return model + }, [fileName, fileSize, repoData, downloadUrl, defaultModel]) - const downloadedModel = useMemo( - () => downloadedModels.find((m) => m.model === persistModelId), - [downloadedModels, persistModelId] - ) + const onAbortDownloadClick = useCallback(() => { + if (model) { + abortModelDownload(model) + } + }, [model, abortModelDownload]) - const onDownloadClick = useCallback(() => { - downloadModelMutation.mutate({ - modelId: modelHandle, - fileName: fileName, - persistedModelId: persistModelId, - }) - }, [downloadModelMutation, modelHandle, fileName, persistModelId]) + const onDownloadClick = useCallback(async () => { + if (model) { + downloadModel(model) + } + }, [model, downloadModel]) const onUseModelClick = useCallback(async () => { - if (!assistants || assistants.length === 0) { - toaster({ - title: 'No assistant available.', - description: 'Please create an assistant to create a new thread', - type: 'error', - }) + if (assistants.length === 0) { + alert('No assistant available') return } - - await createThread(fileName, { - ...assistants[0], - model: fileName, - }) - setHfImportingStage('NONE') + await requestCreateNewThread(assistants[0], model) setMainViewState(MainViewState.Thread) + setHfImportingStage('NONE') }, [ - setHfImportingStage, - setMainViewState, - createThread, - fileName, assistants, + model, + requestCreateNewThread, + setMainViewState, + setHfImportingStage, ]) - const onAbortDownloadClick = useCallback(() => { - abortDownload(persistModelId) - }, [abortDownload, persistModelId]) + if (!model) { + return null + } return ( -
- {downloadedModel ? ( +
+
+ {quantization && ( + + {quantization} + + )} +

+ {fileName} +

+ {toGibibytes(fileSize)} +
+ + {isDownloaded ? ( @@ -143,13 +145,13 @@ const DownloadContainer: React.FC = ({ - {formatDownloadPercentage(downloadProgress(downloadState))} + {formatDownloadPercentage(downloadState.percent)}
diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelSegmentInfo/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelSegmentInfo/index.tsx index 96a54fc155..5a63e59023 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelSegmentInfo/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelSegmentInfo/index.tsx @@ -33,9 +33,9 @@ const ModelSegmentInfo = () => { if (!importingHuggingFaceRepoData) return null return ( -
+
-

+

{modelName}

@@ -63,7 +63,7 @@ const ModelSegmentInfo = () => {
- + {downloads}
@@ -73,7 +73,7 @@ const ModelSegmentInfo = () => {
- {importingHuggingFaceRepoData.tags?.map((tag: string) => ( + {importingHuggingFaceRepoData.tags.map((tag) => ( {tag} diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/index.tsx index 74f58b9618..95c01d0cfa 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/index.tsx @@ -39,12 +39,11 @@ const HuggingFaceRepoDetailModal = () => { title={importingHuggingFaceRepoData.id} fullPage content={ -
-
+
+
-
- -
+
+
} diff --git a/web/screens/Settings/ImportModelOptionModal/index.tsx b/web/screens/Settings/ImportModelOptionModal/index.tsx index ed01c400af..5a2af2335f 100644 --- a/web/screens/Settings/ImportModelOptionModal/index.tsx +++ b/web/screens/Settings/ImportModelOptionModal/index.tsx @@ -20,12 +20,12 @@ const importOptions: ModelImportOption[] = [ description: 'You maintain your model files outside of Jan. Keeping your files where they are, and Jan will create a smart link to them.', }, - // { - // type: 'MOVE_BINARY_FILE', - // title: 'Move model binary file', - // description: - // 'Jan will move your model binary file from your current folder into Jan Data Folder.', - // }, + { + type: 'MOVE_BINARY_FILE', + title: 'Move model binary file', + description: + 'Jan will move your model binary file from your current folder into Jan Data Folder.', + }, ] const ImportModelOptionModal = () => { diff --git a/web/screens/Settings/ImportSuccessIcon/index.tsx b/web/screens/Settings/ImportSuccessIcon/index.tsx index a822ca4d2c..e574acbf0d 100644 --- a/web/screens/Settings/ImportSuccessIcon/index.tsx +++ b/web/screens/Settings/ImportSuccessIcon/index.tsx @@ -1,6 +1,6 @@ -import React, { useState } from 'react' +import React, { useCallback, useState } from 'react' -import { Check } from 'lucide-react' +import { Check, Pencil } from 'lucide-react' type Props = { onEditModelClick: () => void @@ -9,8 +9,6 @@ type Props = { const ImportSuccessIcon: React.FC = ({ onEditModelClick }) => { const [isHovered, setIsHovered] = useState(false) - console.log(isHovered, onEditModelClick) - const onMouseOver = () => { setIsHovered(true) } @@ -21,34 +19,34 @@ const ImportSuccessIcon: React.FC = ({ onEditModelClick }) => { return (
- {/* {isHovered ? ( + {isHovered ? ( - ) : ( */} - - {/* )} */} + ) : ( + + )}
) } const SuccessIcon = React.memo(() => ( -
+
)) -// const EditIcon: React.FC = React.memo(({ onEditModelClick }) => { -// const onClick = useCallback(() => { -// onEditModelClick() -// }, [onEditModelClick]) - -// return ( -//
-// -//
-// ) -// }) +const EditIcon: React.FC = React.memo(({ onEditModelClick }) => { + const onClick = useCallback(() => { + onEditModelClick() + }, [onEditModelClick]) + + return ( +
+ +
+ ) +}) export default ImportSuccessIcon diff --git a/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx b/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx index c1f13fe8b9..c7f6c35f0d 100644 --- a/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx +++ b/web/screens/Settings/ImportingModelModal/ImportingModelItem.tsx @@ -1,11 +1,15 @@ import { useCallback, useMemo } from 'react' import { ImportingModel } from '@janhq/core' +import { useSetAtom } from 'jotai' import { AlertCircle } from 'lucide-react' +import { setImportModelStageAtom } from '@/hooks/useImportModel' + import { toGibibytes } from '@/utils/converter' +import { editingModelIdAtom } from '../EditModelInfoModal' import ImportInProgressIcon from '../ImportInProgressIcon' import ImportSuccessIcon from '../ImportSuccessIcon' @@ -14,12 +18,16 @@ type Props = { } const ImportingModelItem = ({ model }: Props) => { + const setImportModelStage = useSetAtom(setImportModelStageAtom) + const setEditingModelId = useSetAtom(editingModelIdAtom) + const onEditModelInfoClick = useCallback(() => { - // setEditingModelId(model.importId) - // setImportModelStage('EDIT_MODEL_INFO') - }, []) + setEditingModelId(model.importId) + setImportModelStage('EDIT_MODEL_INFO') + }, [setImportModelStage, setEditingModelId, model.importId]) const onDeleteModelClick = useCallback(() => {}, []) + const displayStatus = useMemo(() => { if (model.status === 'FAILED') { return 'Failed' diff --git a/web/screens/Settings/ImportingModelModal/index.tsx b/web/screens/Settings/ImportingModelModal/index.tsx index 2581c76219..6932ee346b 100644 --- a/web/screens/Settings/ImportingModelModal/index.tsx +++ b/web/screens/Settings/ImportingModelModal/index.tsx @@ -1,57 +1,46 @@ -import { Fragment, useEffect } from 'react' +import { useCallback, useEffect, useState } from 'react' -import { Modal } from '@janhq/joi' +import { joinPath, openFileExplorer } from '@janhq/core' +import { Button, Modal } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' import { AlertCircle } from 'lucide-react' -import useCortex from '@/hooks/useCortex' import { getImportModelStageAtom, setImportModelStageAtom, } from '@/hooks/useImportModel' -import useModelDownloadMutation from '@/hooks/useModelDownloadMutation' +import { openFileTitle } from '@/utils/titleUtils' import ImportingModelItem from './ImportingModelItem' -import { - importingModelsAtom, - setImportingModelErrorAtom, -} from '@/helpers/atoms/Model.atom' +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' +import { importingModelsAtom } from '@/helpers/atoms/Model.atom' -const ImportingModelModal: React.FC = () => { - const downloadModelMutation = useModelDownloadMutation() - const { downloadModel } = useCortex() - const setImportModelStage = useSetAtom(setImportModelStageAtom) - const setImportModelError = useSetAtom(setImportingModelErrorAtom) +const ImportingModelModal = () => { const importingModels = useAtomValue(importingModelsAtom) const importModelStage = useAtomValue(getImportModelStageAtom) + const setImportModelStage = useSetAtom(setImportModelStageAtom) + const janDataFolder = useAtomValue(janDataFolderPathAtom) + + const [modelFolder, setModelFolder] = useState('') + + useEffect(() => { + const getModelPath = async () => { + const modelPath = await joinPath([janDataFolder, 'models']) + setModelFolder(modelPath) + } + getModelPath() + }, [janDataFolder]) const finishedImportModel = importingModels.filter( (model) => model.status === 'IMPORTED' ).length 
- useEffect(() => { - const importModels = async () => { - for (const model of importingModels) { - try { - await downloadModelMutation.mutateAsync({ - modelId: model.path, - }) - } catch (error) { - let errorMessage = String(error) - if (error instanceof Error) { - errorMessage = error.message - } - - setImportModelError(model.importId, errorMessage) - } - } - } - importModels() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [downloadModel]) + const onOpenModelFolderClick = useCallback(() => { + openFileExplorer(modelFolder) + }, [modelFolder]) return ( { onOpenChange={() => setImportModelStage('NONE')} title={`Importing model (${finishedImportModel}/${importingModels.length})`} content={ - -
+
+
+ + +
+ +
{importingModels.map((model) => ( ))}
-
- -

- Own your model configurations, use at your own risk. - Misconfigurations may result in lower quality or unexpected - outputs. -

+
+
+ +

+ Own your model configurations; use at your own risk. + Misconfigurations may result in lower quality or unexpected + outputs. +

+
- +
} /> ) diff --git a/web/screens/Settings/MyModels/ModelGroup/index.tsx b/web/screens/Settings/MyModels/ModelGroup/index.tsx deleted file mode 100644 index c72e57fe52..0000000000 --- a/web/screens/Settings/MyModels/ModelGroup/index.tsx +++ /dev/null @@ -1,174 +0,0 @@ -import React, { useCallback, useEffect, useState } from 'react' - -import Image from 'next/image' - -import { - EngineStatus, - LlmEngine, - LocalEngine, - Model, - RemoteEngine, - RemoteEngines, -} from '@janhq/core' - -import { Button } from '@janhq/joi' -import { useAtom, useSetAtom } from 'jotai' -import { - SettingsIcon, - ChevronDownIcon, - ChevronUpIcon, - PlusIcon, -} from 'lucide-react' - -import useEngineQuery from '@/hooks/useEngineQuery' -import useGetModelsByEngine from '@/hooks/useGetModelsByEngine' - -import { - getLogoByLocalEngine, - getLogoByRemoteEngine, - getTitleByCategory, -} from '@/utils/model-engine' - -import ModelItem from '../ModelItem' - -import { showEngineListModelAtom } from '@/helpers/atoms/Model.atom' -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -type Props = { - engine: LlmEngine - searchText: string -} - -const ModelGroup: React.FC = ({ engine, searchText }) => { - const [models, setModels] = useState([]) - const { getModelsByEngine } = useGetModelsByEngine() - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const { data: engineData } = useEngineQuery() - - const [showEngineListModel, setShowEngineListModel] = useAtom( - showEngineListModelAtom - ) - - const engineLogo: string | undefined = models.find( - (entry) => entry?.metadata?.logo != null - )?.metadata?.logo - - const apiKeyUrl: string | undefined = models.find( - (entry) => entry?.metadata?.api_key_url != null - )?.metadata?.api_key_url - - const onSettingClick = useCallback(() => { - setUpRemoteModelStage('SETUP_API_KEY', engine as unknown as RemoteEngine, { - logo: engineLogo, - api_key_url: apiKeyUrl, - }) - }, [apiKeyUrl, engine, engineLogo, setUpRemoteModelStage]) - - const isEngineReady = - engineData?.find((e) => e.name === engine)?.status === EngineStatus.Ready - - const getEngineStatusReady: LlmEngine[] | undefined = engineData - ?.filter((e) => e.status === EngineStatus.Ready) - .map((x) => x.name as LlmEngine) - - const showModel = showEngineListModel.includes(engine) - - const onClickChevron = useCallback(() => { - if (showModel) { - setShowEngineListModel((prev) => prev.filter((item) => item !== engine)) - } else { - setShowEngineListModel((prev) => [...prev, engine]) - } - }, [engine, setShowEngineListModel, showModel]) - - useEffect(() => { - const matchedModels = getModelsByEngine(engine, searchText) - setModels(matchedModels) - setShowEngineListModel((prev) => [ - ...prev, - ...(getEngineStatusReady as LlmEngine[]), - ]) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [getModelsByEngine, engine, searchText, setShowEngineListModel]) - - const engineName = getTitleByCategory(engine) - const localEngineLogo = getLogoByLocalEngine(engine as LocalEngine) - const remoteEngineLogo = getLogoByRemoteEngine(engine as RemoteEngine) - const isRemoteEngine = RemoteEngines.includes(engine as RemoteEngine) - - if (models.length === 0) return null - - return ( -
-
-
- {!isRemoteEngine && localEngineLogo && ( - logo - )} - {remoteEngineLogo && ( - logo - )} -
- {engineName} -
-
-
- {isRemoteEngine && ( - - )} - {!showModel ? ( - - ) : ( - - )} -
-
-
- {models.map((model) => { - if (!showModel) return null - - return - })} -
-
- ) -} - -export default ModelGroup diff --git a/web/screens/Settings/MyModels/ModelItem/index.tsx b/web/screens/Settings/MyModels/ModelItem/index.tsx deleted file mode 100644 index 021b2e3fb1..0000000000 --- a/web/screens/Settings/MyModels/ModelItem/index.tsx +++ /dev/null @@ -1,299 +0,0 @@ -import { memo, useCallback, useMemo, useState } from 'react' - -import { - EngineStatus, - LocalEngines, - Model, - RemoteEngine, - RemoteEngines, -} from '@janhq/core' -import { Badge, Button, useClickOutside } from '@janhq/joi' - -import { useAtomValue, useSetAtom } from 'jotai' -import { - MoreVerticalIcon, - PlayIcon, - StopCircleIcon, - Trash2Icon, -} from 'lucide-react' -import { twMerge } from 'tailwind-merge' - -import { toaster } from '@/containers/Toast' - -import useAssistantQuery from '@/hooks/useAssistantQuery' -import useEngineInit from '@/hooks/useEngineInit' -import useEngineQuery from '@/hooks/useEngineQuery' -import useModelStart from '@/hooks/useModelStart' -import useModelStop from '@/hooks/useModelStop' -import useModels from '@/hooks/useModels' - -import useThreadCreateMutation from '@/hooks/useThreadCreateMutation' - -import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { activeModelsAtom } from '@/helpers/atoms/Model.atom' - -type Props = { - model: Model -} - -// If more than this number of models are running, show a warning modal. -export const concurrentModelWarningThreshold = 2 - -const ModelItem: React.FC = ({ model }) => { - const activeModels = useAtomValue(activeModelsAtom) - const startModel = useModelStart() - const stopModel = useModelStop() - const [more, setMore] = useState(false) - const { deleteModel } = useModels() - const { data: engineData } = useEngineQuery() - const createThreadMutation = useThreadCreateMutation() - const { data: assistants } = useAssistantQuery() - const setMainViewState = useSetAtom(mainViewStateAtom) - const isRemoteEngine = RemoteEngines.includes(model.engine as RemoteEngine) - const isEngineReady = - engineData?.find((e) => e.name === model.engine)?.status === - EngineStatus.Ready - const initializeEngine = useEngineInit() - - const [menu, setMenu] = useState(null) - const [toggle, setToggle] = useState(null) - const setShowWarningMultipleModelModal = useSetAtom( - showWarningMultipleModelModalAtom - ) - useClickOutside(() => setMore(false), null, [menu, toggle]) - - const isActive = useMemo( - () => activeModels.map((m) => m.model).includes(model.model), - [activeModels, model.model] - ) - - const onModelActionClick = useCallback( - (modelId: string) => { - if (isActive) { - // if model already active, stop it - stopModel.mutate(modelId) - return - } - const modelEngine = model.engine - if (!modelEngine) { - toaster({ - title: 'Failed to start model', - description: `Engine for model ${model.model} is undefined`, - type: 'error', - }) - return - } - if (!engineData) { - toaster({ - title: 'Failed to start model', - description: `Engine data is not available. 
Please try again!`, - type: 'error', - }) - return - } - const engineStatus = engineData.find((e) => e.name === modelEngine) - if (!engineStatus) { - toaster({ - title: 'Failed to start model', - description: `Engine ${modelEngine} is not available`, - type: 'error', - }) - console.error(`Engine ${modelEngine} is not available`) - return - } - - if ( - LocalEngines.find((e) => e === modelEngine) != null && - engineStatus.status === 'not_initialized' - ) { - toaster({ - title: 'Please wait for engine to initialize', - description: `Please retry after engine ${engineStatus.name} is installed.`, - type: 'default', - }) - initializeEngine.mutate(modelEngine) - return - } - - if (activeModels.length >= concurrentModelWarningThreshold) { - // if max concurrent models reached, stop the first model - // display popup - setShowWarningMultipleModelModal(true) - } - startModel.mutate(modelId) - }, - [ - isActive, - startModel, - stopModel, - activeModels.length, - setShowWarningMultipleModelModal, - engineData, - initializeEngine, - model, - ] - ) - - const onDeleteModelClicked = useCallback( - async (modelId: string) => { - await stopModel.mutateAsync(modelId) - await deleteModel(modelId) - }, - [stopModel, deleteModel] - ) - - const isLocalModel = LocalEngines.find( - (e) => model.engine != null && e === model.engine - ) - - const onClickCloudModel = useCallback(async () => { - if (!isRemoteEngine) return null - if (!model || !engineData) return - if (!assistants || !assistants.length) { - toaster({ - title: 'No assistant available.', - description: `Could not create a new thread. Please add an assistant.`, - type: 'error', - }) - return - } - - await createThreadMutation.mutateAsync({ - modelId: model.model, - assistant: assistants[0], - }) - - setMainViewState(MainViewState.Thread) - }, [ - assistants, - createThreadMutation, - engineData, - isRemoteEngine, - model, - setMainViewState, - ]) - - return ( -
-
-
-
-
- {model.model} -
- {model.engine === 'cortex.llamacpp' && ( -
-

- {model.model} -

-
- )} -
-
- - {isLocalModel && ( -
- - {model.version != null ? `v${model.version}` : '-'} - - -
- {isActive ? ( - - - Active - - ) : ( - - - Inactive - - )} -
{ - setMore(!more) - }} - > - - {more && ( -
-
{ - onModelActionClick(model.model) - setMore(false) - }} - > - {isActive ? ( - - ) : ( - - )} - - {isActive ? 'Stop' : 'Start'} -  Model - -
-
onDeleteModelClicked(model.model)} - > - - - Delete Model - -
-
- )} -
-
-
- )} -
-
- ) -} - -export default memo(ModelItem) diff --git a/web/screens/Settings/MyModels/MyModelList/index.tsx b/web/screens/Settings/MyModels/MyModelList/index.tsx new file mode 100644 index 0000000000..045f454c0c --- /dev/null +++ b/web/screens/Settings/MyModels/MyModelList/index.tsx @@ -0,0 +1,226 @@ +import { memo, useState } from 'react' + +import { InferenceEngine, Model } from '@janhq/core' +import { Badge, Button, Tooltip, useClickOutside } from '@janhq/joi' +import { useAtom } from 'jotai' +import { + MoreVerticalIcon, + PlayIcon, + StopCircleIcon, + Trash2Icon, +} from 'lucide-react' +import { twMerge } from 'tailwind-merge' + +import { useActiveModel } from '@/hooks/useActiveModel' +import useDeleteModel from '@/hooks/useDeleteModel' + +import { toGibibytes } from '@/utils/converter' + +import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' + +type Props = { + model: Model + groupTitle?: string +} + +const MyModelList = ({ model }: Props) => { + const { activeModel, startModel, stopModel, stateModel } = useActiveModel() + const isActiveModel = stateModel.model?.id === model.id + const { deleteModel } = useDeleteModel() + const [more, setMore] = useState(false) + const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom) + + const [menu, setMenu] = useState(null) + const [toggle, setToggle] = useState(null) + useClickOutside(() => setMore(false), null, [menu, toggle]) + + const onModelActionClick = (modelId: string) => { + if (activeModel && activeModel.id === modelId) { + stopModel() + window.core?.api?.stopServer() + setServerEnabled(false) + } else if (!serverEnabled) { + startModel(modelId) + } + } + + const engineHasLogo = [ + InferenceEngine.anthropic, + InferenceEngine.cohere, + InferenceEngine.martian, + InferenceEngine.mistral, + InferenceEngine.openai, + ] + + return ( +
+
+
+ {engineHasLogo.map((x) => { + if (x === model.engine) { + return ( +
+ Model Provider +
+ ) + } + })} +
+
+ {model.name} +
+ {model.engine === InferenceEngine.nitro && ( +
+

+ {model.id} +

+
+ )} +
+
+ + {model.engine === InferenceEngine.nitro && ( +
+ + {toGibibytes(model.metadata.size)} + + +
+ {stateModel.loading && stateModel.model?.id === model.id ? ( + + + + {stateModel.state === 'start' + ? 'Starting...' + : 'Stopping...'} + + + ) : activeModel && activeModel.id === model.id ? ( + + + Active + + ) : ( + + + Inactive + + )} +
{ + setMore(!more) + }} + > + + {more && ( +
+ { + onModelActionClick(model.id) + setMore(false) + }} + > + {activeModel && activeModel.id === model.id ? ( + + ) : ( + + )} + + {isActiveModel ? stateModel.state : 'Start'} +  Model + +
+ } + disabled={!serverEnabled} + content={ + + {activeModel && activeModel.id === model.id + ? 'The API server is running, change model will stop the server' + : 'Threads are disabled while the server is running'} + + } + /> +
{ + setTimeout(async () => { + if (!serverEnabled) { + await stopModel() + deleteModel(model) + } + }, 500) + setMore(false) + }} + > + + + Delete Model + +
+
+ )} +
+
+
+ )} +
+
+ ) +} + +export default memo(MyModelList) diff --git a/web/screens/Settings/MyModels/index.tsx b/web/screens/Settings/MyModels/index.tsx index 060875256b..d90081b6c3 100644 --- a/web/screens/Settings/MyModels/index.tsx +++ b/web/screens/Settings/MyModels/index.tsx @@ -2,45 +2,39 @@ import { useCallback, useMemo, useState } from 'react' import { useDropzone } from 'react-dropzone' -import { LlmEngines } from '@janhq/core' +import { InferenceEngine } from '@janhq/core' + import { Button, ScrollArea } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' -import { UploadIcon, UploadCloudIcon } from 'lucide-react' +import { UploadCloudIcon, UploadIcon } from 'lucide-react' import { twMerge } from 'tailwind-merge' -import BlankState from '@/containers/BlankState' - import ModelSearch from '@/containers/ModelSearch' -import useDropModelBinaries from '@/hooks/useDropModelBinaries' +import SetupRemoteModel from '@/containers/SetupRemoteModel' +import useDropModelBinaries from '@/hooks/useDropModelBinaries' import { setImportModelStageAtom } from '@/hooks/useImportModel' -import ModelGroup from './ModelGroup' +import MyModelList from './MyModelList' -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' const MyModels = () => { - const setMainViewState = useSetAtom(mainViewStateAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) + const setImportModelStage = useSetAtom(setImportModelStageAtom) const { onDropModels } = useDropModelBinaries() const [searchText, setSearchText] = useState('') - const setImportModelStage = useSetAtom(setImportModelStageAtom) - - const onImportModelClick = useCallback(() => { - setImportModelStage('SELECTING_MODEL') - }, [setImportModelStage]) const filteredDownloadedModels = useMemo( () => downloadedModels - .filter((m) => - m.model.toLowerCase().includes(searchText.toLowerCase().trim()) + .filter((e) => + e.name.toLowerCase().includes(searchText.toLowerCase().trim()) ) - .sort((a, b) => a.model.localeCompare(b.model)), + .sort((a, b) => a.name.localeCompare(b.name)), [downloadedModels, searchText] ) @@ -50,10 +44,20 @@ const MyModels = () => { onDrop: onDropModels, }) + const onImportModelClick = useCallback(() => { + setImportModelStage('SELECTING_MODEL') + }, [setImportModelStage]) + const onSearchChange = useCallback((input: string) => { setSearchText(input) }, []) + const findByEngine = filteredDownloadedModels.map((x) => x.engine) + const groupByEngine = findByEngine.filter(function (item, index) { + if (findByEngine.indexOf(item) === index) + return item !== InferenceEngine.nitro + }) + return (
@@ -92,49 +96,48 @@ const MyModels = () => {
- {!filteredDownloadedModels.length ? ( - <> - {searchText.length > 0 ? ( - setMainViewState(MainViewState.Hub)} - > - Explore The Hub - - } - /> - ) : ( - setMainViewState(MainViewState.Hub)} - > - Explore The Hub - - } - /> - )} - - ) : ( -
- {LlmEngines.map((engine) => { - return ( - - ) - })} -
- )} +
+ {filteredDownloadedModels.filter( + (x) => x.engine === InferenceEngine.nitro + ).length !== 0 && ( +
+
+
Cortex
+
+
+ {filteredDownloadedModels + ? filteredDownloadedModels + .filter((x) => x.engine === InferenceEngine.nitro) + .map((model) => { + return + }) + : null} +
+
+ )} + + {groupByEngine.map((engine, i) => { + return ( +
+
+
+ {engine} +
+ +
+
+ {filteredDownloadedModels + ? filteredDownloadedModels + .filter((x) => x.engine === engine) + .map((model) => { + return + }) + : null} +
+
+ ) + })} +
diff --git a/web/screens/Settings/SelectingModelModal/index.tsx b/web/screens/Settings/SelectingModelModal/index.tsx index b84f72e3ea..6273d00323 100644 --- a/web/screens/Settings/SelectingModelModal/index.tsx +++ b/web/screens/Settings/SelectingModelModal/index.tsx @@ -1,72 +1,29 @@ import { useCallback } from 'react' import { useDropzone } from 'react-dropzone' -import { ImportingModel, SelectFileOption } from '@janhq/core' +import { SelectFileOption, systemInformation } from '@janhq/core' import { Modal } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' import { UploadCloudIcon } from 'lucide-react' -import { snackbar } from '@/containers/Toast' - import useDropModelBinaries from '@/hooks/useDropModelBinaries' -import { +import useImportModel, { getImportModelStageAtom, setImportModelStageAtom, } from '@/hooks/useImportModel' -import { importingModelsAtom } from '@/helpers/atoms/Model.atom' - -const SelectingModelModal: React.FC = () => { +const SelectingModelModal = () => { const setImportModelStage = useSetAtom(setImportModelStageAtom) - const setImportingModels = useSetAtom(importingModelsAtom) const importModelStage = useAtomValue(getImportModelStageAtom) const { onDropModels } = useDropModelBinaries() - - const onImportFileWindowsClick = useCallback(async () => { - const options: SelectFileOption = { - title: 'Select model files', - buttonLabel: 'Select', - allowMultiple: true, - filters: [ - { name: 'GGUF Files', extensions: ['gguf'] }, - { name: 'All Files', extensions: ['*'] }, - ], - } - const filePaths: string[] = await window.core?.api?.selectFiles(options) - if (!filePaths || filePaths.length === 0) return - - const importingModels: ImportingModel[] = filePaths - .filter((path) => path.endsWith('.gguf')) - .map((path) => { - const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path - - return { - importId: normalizedPath, - modelId: undefined, - name: normalizedPath.replace('.gguf', ''), - description: '', - path: path, - tags: [], - size: 0, - status: 'PREPARING', - format: 'gguf', - } - }) - if (importingModels.length < 1) { - snackbar({ - description: `Only files with .gguf extension can be imported.`, - type: 'error', - }) - return - } - setImportingModels(importingModels) - setImportModelStage('MODEL_SELECTED') - }, [setImportingModels, setImportModelStage]) + const { sanitizeFilePaths } = useImportModel() const onSelectFileClick = useCallback(async () => { - if (isWindows) { - return onImportFileWindowsClick() + const platform = (await systemInformation()).osInfo?.platform + if (platform === 'win32') { + setImportModelStage('CHOOSE_WHAT_TO_IMPORT') + return } const options: SelectFileOption = { title: 'Select model folders', @@ -74,36 +31,10 @@ const SelectingModelModal: React.FC = () => { allowMultiple: true, selectDirectory: true, } - const filePaths: string[] = await window.core?.api?.selectFiles(options) + const filePaths = await window.core?.api?.selectFiles(options) if (!filePaths || filePaths.length === 0) return - - const importingModels: ImportingModel[] = filePaths - .filter((path) => path.endsWith('.gguf')) - .map((path) => { - const normalizedPath = isWindows ? 
path.replace(/\\/g, '/') : path - - return { - importId: normalizedPath, - modelId: undefined, - name: normalizedPath.replace('.gguf', ''), - description: '', - path: path, - tags: [], - size: 0, - status: 'PREPARING', - format: 'gguf', - } - }) - if (importingModels.length < 1) { - snackbar({ - description: `Only files with .gguf extension can be imported.`, - type: 'error', - }) - return - } - setImportingModels(importingModels) - setImportModelStage('MODEL_SELECTED') - }, [setImportModelStage, setImportingModels, onImportFileWindowsClick]) + sanitizeFilePaths(filePaths) + }, [sanitizeFilePaths, setImportModelStage]) const { isDragActive, getRootProps } = useDropzone({ noClick: true, @@ -122,7 +53,9 @@ const SelectingModelModal: React.FC = () => { return ( setImportModelStage('NONE')} + onOpenChange={() => { + setImportModelStage('NONE') + }} title="Import Model" content={
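For context, the inline import logic removed above is what the `sanitizeFilePaths` helper obtained from `useImportModel` presumably takes over. A minimal sketch of that behaviour, reconstructed from the deleted code (the helper's real name, signature, and error handling may differ): it keeps only `.gguf` selections, normalizes Windows back-slashes, and builds `ImportingModel` entries in the `PREPARING` state. The removed inline version additionally showed a snackbar when no `.gguf` file was selected and then advanced the stage to `MODEL_SELECTED`.

import { ImportingModel } from '@janhq/core'

// Hypothetical sketch only — mirrors the inline logic deleted above, not the actual hook implementation.
const sanitizeFilePathsSketch = (
  filePaths: string[],
  isWindows: boolean
): ImportingModel[] =>
  filePaths
    // only GGUF files can be imported
    .filter((path) => path.endsWith('.gguf'))
    .map((path) => {
      // normalize Windows back-slashes so the derived id/name are stable across platforms
      const normalizedPath = isWindows ? path.replace(/\\/g, '/') : path
      return {
        importId: normalizedPath,
        modelId: undefined,
        name: normalizedPath.replace('.gguf', ''),
        description: '',
        path,
        tags: [],
        size: 0,
        status: 'PREPARING',
        format: 'gguf',
      }
    })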
diff --git a/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx b/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx index 96e48639e9..b6a204e2e8 100644 --- a/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx +++ b/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx @@ -8,14 +8,25 @@ import { import { Input } from '@janhq/joi' import { CopyIcon, EyeIcon, FolderOpenIcon } from 'lucide-react' - -import { markdownParser } from '@/utils/markdown-parser' +import { Marked, Renderer } from 'marked' type Props = { settingProps: SettingComponentProps onValueChanged?: (e: string) => void } +const marked: Marked = new Marked({ + renderer: { + link: (href, title, text) => + Renderer.prototype.link + ?.apply(this, [href, title, text]) + .replace( + ') => void } +const marked: Marked = new Marked({ + renderer: { + link: (href, title, text) => { + return Renderer.prototype.link + ?.apply(this, [href, title, text]) + .replace( + ' = ({ settingProps, onValueChanged, }) => { const { value } = settingProps.controllerProps as CheckboxComponentProps - const description = markdownParser.parse(settingProps.description ?? '', { + const description = marked.parse(settingProps.description ?? '', { async: false, }) diff --git a/web/screens/Settings/SettingDetail/index.tsx b/web/screens/Settings/SettingDetail/index.tsx index b19cf259a9..85feafbb34 100644 --- a/web/screens/Settings/SettingDetail/index.tsx +++ b/web/screens/Settings/SettingDetail/index.tsx @@ -2,19 +2,20 @@ import { useAtomValue } from 'jotai' import Advanced from '@/screens/Settings/Advanced' import AppearanceOptions from '@/screens/Settings/Appearance' - +import ExtensionCatalog from '@/screens/Settings/CoreExtensions' import ExtensionSetting from '@/screens/Settings/ExtensionSetting' import Hotkeys from '@/screens/Settings/Hotkeys' import MyModels from '@/screens/Settings/MyModels' -import EngineSetting from '../EngineSetting' - import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' const SettingDetail = () => { const selectedSetting = useAtomValue(selectedSettingAtom) switch (selectedSetting) { + case 'Extensions': + return + case 'Appearance': return @@ -27,9 +28,6 @@ const SettingDetail = () => { case 'My Models': return - case 'Engines': - return - default: return } diff --git a/web/screens/Settings/SettingLeftPanel/index.tsx b/web/screens/Settings/SettingLeftPanel/index.tsx index d6ee7da582..87ddde4d41 100644 --- a/web/screens/Settings/SettingLeftPanel/index.tsx +++ b/web/screens/Settings/SettingLeftPanel/index.tsx @@ -1,23 +1,71 @@ -import React from 'react' +import { memo, useEffect, useState } from 'react' import { useAtomValue } from 'jotai' import LeftPanelContainer from '@/containers/LeftPanelContainer' -import { SettingScreen, SettingScreenList } from '..' 
- import SettingItem from './SettingItem' -import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { extensionManager } from '@/extension' +import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' +import { janSettingScreenAtom } from '@/helpers/atoms/Setting.atom' + +const SettingLeftPanel = () => { + const settingScreens = useAtomValue(janSettingScreenAtom) + const inActiveEngineProvider = useAtomValue(inActiveEngineProviderAtom) + + const [extensionHasSettings, setExtensionHasSettings] = useState< + { name?: string; setting: string }[] + >([]) + + const [engineHasSettings, setEngineHasSettings] = useState< + { name?: string; setting: string; provider: string }[] + >([]) + + useEffect(() => { + const getAllSettings = async () => { + const extensionsMenu: { name?: string; setting: string }[] = [] + const engineMenu: { + name?: string + setting: string + provider: string + }[] = [] + const extensions = extensionManager.getAll() -const SettingLeftPanel: React.FC = () => { - const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) + for (const extension of extensions) { + const settings = await extension.getSettings() + if ( + typeof extension.getSettings === 'function' && + 'provider' in extension && + typeof extension.provider === 'string' + ) { + if ( + (settings && settings.length > 0) || + (await extension.installationState()) !== 'NotRequired' + ) { + engineMenu.push({ + name: extension.productName, + setting: extension.name, + provider: + 'provider' in extension && + typeof extension.provider === 'string' + ? extension.provider + : '', + }) + } + } else if (settings && settings.length > 0) { + extensionsMenu.push({ + name: extension.productName, + setting: extension.name, + }) + } + } - const screenList: SettingScreen[] = Array.isArray(SettingScreenList) - ? experimentalEnabled - ? SettingScreenList - : SettingScreenList.filter((screen) => screen !== 'Engines') - : [] + setExtensionHasSettings(extensionsMenu) + setEngineHasSettings(engineMenu) + } + getAllSettings() + }, []) return ( @@ -28,16 +76,55 @@ const SettingLeftPanel: React.FC = () => {
- {screenList.map((settingScreen) => ( + {settingScreens.map((settingScreen) => ( ))} + + {engineHasSettings.filter( + (x) => !inActiveEngineProvider.includes(x.provider) + ).length > 0 && ( +
+ +
+ )} + + {engineHasSettings + .sort((a, b) => a.provider.localeCompare(b.provider)) + .filter((x) => !inActiveEngineProvider.includes(x.provider)) + .map((item) => ( + + ))} + + {extensionHasSettings.length > 0 && ( +
+ +
+ )} + + {extensionHasSettings + .sort((a, b) => String(a.name).localeCompare(String(b.name))) + .map((item) => ( + + ))}
) } -export default React.memo(SettingLeftPanel) +export default memo(SettingLeftPanel) diff --git a/web/screens/Settings/index.tsx b/web/screens/Settings/index.tsx index fe441a452d..a90a37915f 100644 --- a/web/screens/Settings/index.tsx +++ b/web/screens/Settings/index.tsx @@ -16,10 +16,11 @@ export const SettingScreenList = [ 'Appearance', 'Keyboard Shortcuts', 'Advanced Settings', - 'Engines', + 'Extensions', ] as const -export type SettingScreen = (typeof SettingScreenList)[number] +export type SettingScreenTuple = typeof SettingScreenList +export type SettingScreen = SettingScreenTuple[number] const SettingsScreen = () => { const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx index 233308f51b..4dab6bfa82 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx @@ -1,14 +1,17 @@ import { useCallback } from 'react' import { SettingComponentProps } from '@janhq/core' -import { useAtomValue } from 'jotai' +import { useAtomValue, useSetAtom } from 'jotai' -import useModelStop from '@/hooks/useModelStop' +import { useActiveModel } from '@/hooks/useActiveModel' +import { useCreateNewThread } from '@/hooks/useCreateNewThread' import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent' -import { activeModelsAtom } from '@/helpers/atoms/Model.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' +import { + activeThreadAtom, + engineParamsUpdateAtom, +} from '@/helpers/atoms/Thread.atom' type Props = { componentData: SettingComponentProps[] @@ -16,47 +19,70 @@ type Props = { const AssistantSetting: React.FC = ({ componentData }) => { const activeThread = useAtomValue(activeThreadAtom) - const activeModels = useAtomValue(activeModelsAtom) - const stopModel = useModelStop() + const { updateThreadMetadata } = useCreateNewThread() + const { stopModel } = useActiveModel() + const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) const onValueChanged = useCallback( (key: string, value: string | number | boolean) => { if (!activeThread) return - console.log('onValueChanged', key, value) const shouldReloadModel = componentData.find((x) => x.key === key)?.requireModelReload ?? 
false if (shouldReloadModel) { - const model = activeModels.find( - (model) => activeThread.assistants[0]?.model === model.model - ) - if (model) stopModel.mutate(model.model) + setEngineParamsUpdate(true) + stopModel() } - // if ( - // activeThread.assistants[0].tools && - // (key === 'chunk_overlap' || key === 'chunk_size') - // ) { - // if ( - // activeThread.assistants[0].tools[0]?.settings.chunk_size < - // activeThread.assistants[0].tools[0]?.settings.chunk_overlap - // ) { - // activeThread.assistants[0].tools[0].settings.chunk_overlap = - // activeThread.assistants[0].tools[0].settings.chunk_size - // } - // if ( - // key === 'chunk_size' && - // value < activeThread.assistants[0].tools[0].settings.chunk_overlap - // ) { - // activeThread.assistants[0].tools[0].settings.chunk_overlap = value - // } else if ( - // key === 'chunk_overlap' && - // value > activeThread.assistants[0].tools[0].settings.chunk_size - // ) { - // activeThread.assistants[0].tools[0].settings.chunk_size = value - // } - // } + if ( + activeThread.assistants[0].tools && + (key === 'chunk_overlap' || key === 'chunk_size') + ) { + if ( + activeThread.assistants[0].tools[0]?.settings.chunk_size < + activeThread.assistants[0].tools[0]?.settings.chunk_overlap + ) { + activeThread.assistants[0].tools[0].settings.chunk_overlap = + activeThread.assistants[0].tools[0].settings.chunk_size + } + if ( + key === 'chunk_size' && + value < activeThread.assistants[0].tools[0].settings.chunk_overlap + ) { + activeThread.assistants[0].tools[0].settings.chunk_overlap = value + } else if ( + key === 'chunk_overlap' && + value > activeThread.assistants[0].tools[0].settings.chunk_size + ) { + activeThread.assistants[0].tools[0].settings.chunk_size = value + } + } + updateThreadMetadata({ + ...activeThread, + assistants: [ + { + ...activeThread.assistants[0], + tools: [ + { + type: 'retrieval', + enabled: true, + settings: { + ...(activeThread.assistants[0].tools && + activeThread.assistants[0].tools[0]?.settings), + [key]: value, + }, + }, + ], + }, + ], + }) }, - [activeModels, activeThread, componentData, stopModel] + [ + activeThread, + componentData, + setEngineParamsUpdate, + stopModel, + updateThreadMetadata, + ] ) if (!activeThread) return null diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/OnDeviceListStarter.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/OnDeviceListStarter.tsx deleted file mode 100644 index fc202a0c01..0000000000 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/OnDeviceListStarter.tsx +++ /dev/null @@ -1,193 +0,0 @@ -import React, { Fragment, useCallback, useState } from 'react' - -import Image from 'next/image' - -import { Model, RemoteEngine, RemoteEngines } from '@janhq/core' -import { Input } from '@janhq/joi' - -import { useSetAtom } from 'jotai' -import { SearchIcon, PlusIcon } from 'lucide-react' - -import { twMerge } from 'tailwind-merge' - -import Spinner from '@/containers/Loader/Spinner' - -import useModelHub from '@/hooks/useModelHub' - -import BuiltInModelCard from '@/screens/HubScreen2/components/BuiltInModelCard' - -import { HfModelEntry } from '@/utils/huggingface' - -import { getTitleByCategory } from '@/utils/model-engine' - -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' -import { localModelModalStageAtom } from '@/helpers/atoms/DownloadLocalModel.atom' -import { hubFilterAtom } from '@/helpers/atoms/Hub.atom' -import { setUpRemoteModelStageAtom } from '@/helpers/atoms/SetupRemoteModel.atom' - -const 
OnDeviceStarterScreen = () => { - const { data } = useModelHub() - const [searchValue, setSearchValue] = useState('') - const setLocalModelModalStage = useSetAtom(localModelModalStageAtom) - const setUpRemoteModelStage = useSetAtom(setUpRemoteModelStageAtom) - const setMainViewState = useSetAtom(mainViewStateAtom) - const setFilter = useSetAtom(hubFilterAtom) - - const onItemClick = useCallback( - (name: string) => { - setLocalModelModalStage('MODEL_LIST', name) - }, - [setLocalModelModalStage] - ) - - if (!data) - return ( -
- -
- ) - - const builtInModels: HfModelEntry[] = - data.modelCategories.get('BuiltInModels') || [] - const huggingFaceModels: HfModelEntry[] = - data.modelCategories.get('HuggingFace') || [] - - const engineModelMap = new Map() - for (const [key, value] of data.modelCategories) { - if (key !== 'HuggingFace' && key !== 'BuiltInModels') { - engineModelMap.set(key as unknown as typeof RemoteEngines, value) - } - } - - const models: HfModelEntry[] = builtInModels.concat(huggingFaceModels) - - const filteredModels = models.filter((model) => { - return model.name.toLowerCase().includes(searchValue.toLowerCase()) - }) - - const recommendModels = models.filter((model) => { - return ( - model.name.toLowerCase().includes('cortexso/tinyllama') || - model.name.toLowerCase().includes('cortexso/mistral') - ) - }) - - return ( - -
- setSearchValue(e.target.value)} - placeholder="Search..." - prefixIcon={} - /> -
- {!filteredModels.length ? ( -
-

- No Result Found -

-
- ) : ( - filteredModels.map((model) => ( -
onItemClick(model.name)} - > -

- {model.name.replaceAll('cortexso/', '')} -

-
- )) - )} -
-
-
-

On-device Models

-

{ - setFilter('On-device') - setMainViewState(MainViewState.Hub) - }} - > - See All -

-
- {recommendModels.map((model) => ( - - ))} - -
-

Cloud Models

-
- -
- {Array.from(engineModelMap.entries()) - .slice(0, 3) - .map(([engine, models]) => { - const engineLogo: string | undefined = models.find( - (entry) => entry.model?.metadata?.logo != null - )?.model?.metadata?.logo - const apiKeyUrl: string | undefined = models.find( - (entry) => entry.model?.metadata?.api_key_url != null - )?.model?.metadata?.api_key_url - const defaultModel: Model | undefined = models.find( - (entry) => entry.model != null - )?.model - return ( -
{ - setUpRemoteModelStage( - 'SETUP_API_KEY', - engine as unknown as RemoteEngine, - { - logo: engineLogo, - api_key_url: apiKeyUrl, - model: defaultModel, - } - ) - }} - > - {engineLogo ? ( - Engine logo - ) : ( -
- )} -

{getTitleByCategory(engine as unknown as RemoteEngine)}

-
- ) - })} - -
-
{ - setFilter('Cloud') - setMainViewState(MainViewState.Hub) - }} - > - -
-

See All

-
-
-
- ) -} - -export default OnDeviceStarterScreen diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/index.tsx index 2dcfd80416..77913c991b 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyModel/index.tsx @@ -1,30 +1,31 @@ import { memo } from 'react' +import { Button } from '@janhq/joi' +import { useSetAtom } from 'jotai' + import LogoMark from '@/containers/Brand/Logo/Mark' -import CenterPanelContainer from '@/containers/CenterPanelContainer' +import { MainViewState } from '@/constants/screens' -import OnDeviceStarterScreen from './OnDeviceListStarter' +import { mainViewStateAtom } from '@/helpers/atoms/App.atom' const EmptyModel = () => { + const setMainViewState = useSetAtom(mainViewStateAtom) + return ( - -
-
-
- -

Select a model to start

-
- -
-
-
-
-
+
+ +

Welcome!

+

+ You need to download your first model +

+ +
) } diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyThread/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyThread/index.tsx index 06c6154c3d..6fc05d44bb 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyThread/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/EmptyThread/index.tsx @@ -1,29 +1,28 @@ -import { Fragment, memo } from 'react' +import { memo } from 'react' -import { LocalEngines } from '@janhq/core' +import { InferenceEngine } from '@janhq/core' import { Button } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' import LogoMark from '@/containers/Brand/Logo/Mark' -import { MainViewState, mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { MainViewState } from '@/constants/screens' + +import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' -const EmptyThread: React.FC = () => { +const EmptyThread = () => { const downloadedModels = useAtomValue(downloadedModelsAtom) const setMainViewState = useSetAtom(mainViewStateAtom) - - const haveLocalModel = downloadedModels.filter( - (e) => LocalEngines.find((x) => x === e.engine) != null - ) + const showOnboardingStep = + downloadedModels.filter((e) => e.engine === InferenceEngine.nitro) + .length === 0 return (
- {haveLocalModel ? ( -

How can I help you?

- ) : ( - + {showOnboardingStep ? ( + <>

{`You don't have a local model yet.`}

@@ -34,7 +33,9 @@ const EmptyThread: React.FC = () => { > Explore The Hub -
+ + ) : ( +

How can I help you?

)}
) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx index 7ae4b1456c..5b5218bb9a 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx @@ -1,35 +1,48 @@ +import { MessageStatus } from '@janhq/core' + import { useAtomValue } from 'jotai' +import ErrorMessage from '@/containers/ErrorMessage' import ListContainer from '@/containers/ListContainer' -import SimpleTextMessage from '../SimpleTextMessage' +import { loadModelErrorAtom } from '@/hooks/useActiveModel' + +import ChatItem from '../ChatItem' + +import LoadModelError from '../LoadModelError' +import EmptyModel from './EmptyModel' import EmptyThread from './EmptyThread' import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' +import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' -type Props = { - onResendMessage: () => void -} - -const ChatBody: React.FC = ({ onResendMessage }) => { +const ChatBody = () => { const messages = useAtomValue(getCurrentChatMessagesAtom) + const downloadedModels = useAtomValue(downloadedModelsAtom) + const loadModelError = useAtomValue(loadModelErrorAtom) + if (!downloadedModels.length) return if (!messages.length) return return ( - {messages.map((message, index) => { - const isLatestMessage = index === messages.length - 1 - return ( - - ) - })} + {messages.map((message, index) => ( +
+ {message.status !== MessageStatus.Error && + message.content.length > 0 && ( + + )} + + {!loadModelError && + index === messages.length - 1 && + message.status !== MessageStatus.Pending && + message.status !== MessageStatus.Ready && ( + + )} +
+ ))} + {loadModelError && }
) } diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx deleted file mode 100644 index 0469338c30..0000000000 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/SendMessageButton.tsx +++ /dev/null @@ -1,47 +0,0 @@ -import { useMemo } from 'react' - -import React from 'react' - -import { Button } from '@janhq/joi' -import { useAtomValue } from 'jotai' - -import { currentPromptAtom } from '@/containers/Providers/Jotai' - -type Props = { - onSendMessageClick: (message: string) => void -} - -const SendMessageButton: React.FC = ({ onSendMessageClick }) => { - const currentPrompt = useAtomValue(currentPromptAtom) - - const showSendButton = useMemo(() => { - if (currentPrompt.trim().length === 0) return false - return true - }, [currentPrompt]) - - if (!showSendButton) return null - - return ( - - ) -} - -export default React.memo(SendMessageButton) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx deleted file mode 100644 index 7fe2764cd9..0000000000 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/StopInferenceButton.tsx +++ /dev/null @@ -1,29 +0,0 @@ -import React from 'react' - -import { Button } from '@janhq/joi' - -import { useAtomValue } from 'jotai' -import { StopCircle } from 'lucide-react' - -import { disableStopInferenceAtom } from '@/helpers/atoms/ChatMessage.atom' - -type Props = { - onStopInferenceClick: () => void -} - -const StopInferenceButton: React.FC = ({ onStopInferenceClick }) => { - const disabled = useAtomValue(disableStopInferenceAtom) - - return ( - - ) -} - -export default React.memo(StopInferenceButton) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx deleted file mode 100644 index a55d691411..0000000000 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatActionButton/index.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import { useMemo } from 'react' - -import { useAtomValue } from 'jotai' - -import SendMessageButton from './SendMessageButton' -import StopInferenceButton from './StopInferenceButton' - -import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' - -import { isGeneratingResponseAtom } from '@/helpers/atoms/Thread.atom' - -type Props = { - onStopInferenceClick: () => void - onSendMessageClick: (message: string) => void -} - -const ChatActionButton: React.FC = ({ - onStopInferenceClick, - onSendMessageClick, -}) => { - const messages = useAtomValue(getCurrentChatMessagesAtom) - const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) - - const showStopButton = useMemo(() => { - if (isGeneratingResponse) return true - - const lastMessage = messages[messages.length - 1] - if (!lastMessage) return false - if (lastMessage.status === 'in_progress') return true - return false - }, [isGeneratingResponse, messages]) - - if (showStopButton) { - return - } - - return -} - -export default ChatActionButton diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx deleted file mode 100644 index f43bf68924..0000000000 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/ChatTextInput/index.tsx +++ /dev/null @@ -1,91 +0,0 @@ 
-import { useCallback, useEffect, useMemo, useRef } from 'react' - -import { TextArea } from '@janhq/joi' -import { useAtom, useAtomValue } from 'jotai' - -import { twMerge } from 'tailwind-merge' - -import { currentPromptAtom } from '@/containers/Providers/Jotai' - -import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' - -import { spellCheckAtom } from '@/helpers/atoms/Setting.atom' -import { - getActiveThreadIdAtom, - isGeneratingResponseAtom, -} from '@/helpers/atoms/Thread.atom' - -type Props = { - isSettingActive: boolean - onSendMessageClick: (message: string) => void -} - -const ChatTextInput: React.FC = ({ - isSettingActive, - onSendMessageClick, -}) => { - const messages = useAtomValue(getCurrentChatMessagesAtom) - const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom) - const textareaRef = useRef(null) - const activeThreadId = useAtomValue(getActiveThreadIdAtom) - const spellCheck = useAtomValue(spellCheckAtom) - - const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) - - const disabled = useMemo(() => !activeThreadId, [activeThreadId]) - - const onChange = useCallback( - (e: React.ChangeEvent) => { - setCurrentPrompt(e.target.value) - }, - [setCurrentPrompt] - ) - - useEffect(() => { - textareaRef.current?.focus() - }) - - useEffect(() => { - if (textareaRef.current?.clientHeight) { - textareaRef.current.style.height = isSettingActive ? '100px' : '40px' - textareaRef.current.style.height = textareaRef.current.scrollHeight + 'px' - textareaRef.current.style.overflow = - textareaRef.current.clientHeight >= 390 ? 'auto' : 'hidden' - } - }, [textareaRef.current?.clientHeight, currentPrompt, isSettingActive]) - - const onKeyDown = useCallback( - (e: React.KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey && !e.nativeEvent.isComposing) { - e.preventDefault() - if (isGeneratingResponse) return - const lastMessage = messages[messages.length - 1] - if (!lastMessage || lastMessage.status !== 'in_progress') { - onSendMessageClick(currentPrompt) - return - } - } - }, - [messages, isGeneratingResponse, currentPrompt, onSendMessageClick] - ) - - return ( -