Skip to content

Commit

Permalink
feat: experimental deepspeed on windows (#419)
Browse files Browse the repository at this point in the history
* add experimental deepspeed wheel for win

* add japanese and italian for bark voice

* fix stable audio layout

* create React UI proxy base

* update README
  • Loading branch information
rsxdalv authored Nov 14, 2024
1 parent 136e6e2 commit c1b4414
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 44 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@

## Changelog

Nov 14:
* Add experimental Windows deepspeed wheel.
* Add more languages to Bark voice clone.

Nov 11:
* Switch to a fixed fairseq version for windows reducing installation conflicts and speeding up updates.

Expand Down
80 changes: 37 additions & 43 deletions react-ui/src/pages/api/gradio/[name].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ import { Client } from "@gradio/client";
import type { NextApiRequest, NextApiResponse } from "next";
import { getFile } from "../../../backend-utils/getFile";
import { GradioFile } from "../../../types/GradioFile";
import { PayloadMessage, PredictFunction } from "@gradio/client/dist/types";
import {
GradioEvent,
PayloadMessage,
PredictFunction,
SubmitIterable,
} from "@gradio/client/dist/types";

type Data = { data: any };

Expand Down Expand Up @@ -53,11 +58,40 @@ const extractChoicesTuple = ({ choices }: GradioChoices) =>
const getChoices = (result: { data: GradioChoices[] }) =>
extractChoices(result?.data[0]);

const proxyGradioFile = (data: any) =>
// typeof data === "object" && data.__type__ === "file"
// // ? new GradioFile(data.url, data.name)
// : data;
data

const proxyGradioFiles = (data: any[]) =>
Array.isArray(data)
? data.map(proxyGradioFile)
: // : typeof data === "object"
// ? Object.fromEntries(
// Object.entries(data).map(([key, value]) => [
// key,
// proxyGradioFiles(value),
// ])
// )
data;

const gradioPredict = <T extends any[]>(...args: Parameters<PredictFunction>) =>
getClient().then((app) => app.predict(...args)) as Promise<{ data: T }>;
// getClient().then((app) => app.predict(...args)) as Promise<{ data: T }>;
getClient()
.then((app) => app.predict(...args) as Promise<{ data: T }>)
.then((result: { data: T }) => ({
...result,
data: proxyGradioFiles(result?.data) as T,
}));

const gradioSubmit = <T extends any[]>(...args: Parameters<PredictFunction>) =>
getClient().then((app) => app.submit(...args));
getClient().then(
(app) =>
app.submit(...args) as SubmitIterable<
({ data: T } & PayloadMessage) | GradioEvent
>
);

async function musicgen({ melody, model, ...params }) {
const melodyBlob = await getFile(melody);
Expand Down Expand Up @@ -155,11 +189,6 @@ async function bark({
};
}

const reload_old_generation_dropdown = () =>
gradioPredict<[GradioChoices]>("/reload_old_generation_dropdown").then(
getChoices
);

const bark_favorite = async ({ folder_root }) =>
gradioPredict<[Object]>("/bark_favorite", [folder_root]).then(
(result) => result?.data
Expand Down Expand Up @@ -237,15 +266,6 @@ async function tortoise({
return results.slice(0, -1);
}

const tortoise_refresh_models = () =>
gradioPredict<[GradioChoices]>("/tortoise_refresh_models").then(getChoices);

const tortoise_refresh_voices = () =>
gradioPredict<[GradioChoices]>("/tortoise_refresh_voices").then(getChoices);

const tortoise_open_models = () => gradioPredict<[]>("/tortoise_open_models");
const tortoise_open_voices = () => gradioPredict<[]>("/tortoise_open_voices");

async function tortoise_apply_model_settings({
model, // string (Option from: ['Default']) in 'parameter_2488' Dropdown component
kv_cache, // boolean in 'parameter_2493' Checkbox component
Expand Down Expand Up @@ -308,32 +328,6 @@ async function rvc({
const delete_generation = ({ folder_root }) =>
gradioPredict<[]>("/delete_generation", [folder_root]);

const save_to_voices = ({ history_npz }) =>
gradioPredict<[Object]>("/save_to_voices", [history_npz]);

const save_config_bark = ({
text_use_gpu,
text_use_small,
coarse_use_gpu,
coarse_use_small,
fine_use_gpu,
fine_use_small,
codec_use_gpu,
load_models_on_startup,
}) =>
gradioPredict<[string]>("/save_config_bark", [
text_use_gpu, // boolean in 'Use GPU' Checkbox component
text_use_small, // boolean in 'Use small model' Checkbox component
coarse_use_gpu, // boolean in 'Use GPU' Checkbox component
coarse_use_small, // boolean in 'Use small model' Checkbox component
fine_use_gpu, // boolean in 'Use GPU' Checkbox component
fine_use_small, // boolean in 'Use small model' Checkbox component
codec_use_gpu, // boolean in 'Use GPU for codec' Checkbox component
load_models_on_startup, // boolean in 'Load Bark models on startup' Checkbox component
]).then((result) => result?.data[0]);

// get_config_bark

async function get_config_bark() {
const result = await gradioPredict<
[
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ beartype>=0.16.1 # workaround for a bug
# no longer required directly # transformers==4.36.1 # cross-compatibility
iso639-lang==2.2.3
pillow==10.3.0 # for gradio, conda fix
deepspeed @ https://github.com/rsxdalv/DeepSpeed/releases/download/v0.15.5-test/deepspeed-0.15.5+unknown-cp310-cp310-win_amd64.whl ; sys_platform == 'win32' # Apache 2.0
2 changes: 2 additions & 0 deletions tts_webui/bark/clone/tab_voice_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def tab_voice_clone():
"es_tokenizer.pth @ Lancer1408/bark-es-tokenizer",
"portuguese-HuBERT-quantizer_24_epoch.pth @ MadVoyager/bark-voice-cloning-portuguese-HuBERT-quantizer",
"turkish_model_epoch_14.pth @ egeadam/bark-voice-cloning-turkish-HuBERT-quantizer",
"japanese-HuBERT-quantizer_24_epoch.pth @ junwchina/bark-voice-cloning-japanese-HuBERT-quantizer",
"it_tokenizer.pth @ gpwr/bark-it-tokenizer",
],
value="quantifier_hubert_base_ls960_14.pth @ GitMylo/bark-voice-cloning",
allow_custom_value=True,
Expand Down
2 changes: 1 addition & 1 deletion tts_webui/stable_audio/stable_audio_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def model_select_ui():
outputs=[model_select],
api_name="stable_audio_refresh_models",
)
load_model_button = gr.Button(value="Load model")
load_model_button = gr.Button(value="Load model")

with gr.Column():
gr.Markdown(
Expand Down

0 comments on commit c1b4414

Please sign in to comment.