inference.py
import os
import argparse
import torch
from diffusers.utils.import_utils import is_xformers_available
from datasets import load_dataset
from tqdm.auto import tqdm
from scipy.io.wavfile import write
from auffusion_pipeline import AuffusionPipeline
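
# Text-to-audio inference: generate one clip per caption in the test set with
# the Auffusion pipeline and save each clip as a WAV file named after its
# reference audio.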

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of an inference script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="auffusion/auffusion",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--test_data_dir",
        type=str,
        default="./data/test_audiocaps.raw.json",
        help="Path to the test dataset in a JSON file.",
    )
    parser.add_argument(
        "--audio_column", type=str, default="audio_path", help="The column of the dataset containing the audio file path."
    )
    parser.add_argument(
        "--caption_column", type=str, default="text", help="The column of the dataset containing a caption."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output/auffusion_hf",
        help="The output directory where the generated audio files will be written.",
    )
    parser.add_argument(
        "--sample_rate", type=int, default=16000, help="The sample rate of the audio."
    )
    parser.add_argument(
        "--duration", type=int, default=10, help="The duration (in seconds) of the audio."
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible inference.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5, help="The scale of classifier-free guidance."
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=100, help="Number of denoising steps to perform."
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="Width of the generated spectrogram image."
    )
    parser.add_argument(
        "--height", type=int, default=256, help="Height of the generated spectrogram image."
    )
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    pipeline = AuffusionPipeline.from_pretrained(args.pretrained_model_name_or_path)
    pipeline = pipeline.to(device, weight_dtype)
    pipeline.set_progress_bar_config(disable=True)

    if is_xformers_available() and args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = torch.Generator(device=device).manual_seed(args.seed)
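
    # NOTE: one generator is shared across all prompts, so the full run is
    # reproducible for a given --seed while each prompt still draws fresh noise.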

    # load dataset
    audio_column, caption_column = args.audio_column, args.caption_column
    data_files = {"test": args.test_data_dir}
    dataset = load_dataset("json", data_files=data_files, split="test")
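    # Expected record layout (an assumption based on the defaults above; any
    # JSON that `load_dataset("json", ...)` accepts works, e.g. JSON Lines):
    #   {"audio_path": "path/to/reference.wav", "text": "a dog barking"}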

    # output dir
    audio_output_dir = os.path.join(args.output_dir, "audios")
    os.makedirs(audio_output_dir, exist_ok=True)

    # generating
    audio_length = args.sample_rate * args.duration
    for i in tqdm(range(len(dataset)), desc="Generating"):
        prompt = dataset[i][caption_column]
        audio_name = os.path.basename(dataset[i][audio_column])
        audio_path = os.path.join(audio_output_dir, audio_name)

        # skip clips already generated in a previous run
        if os.path.exists(audio_path):
            continue

        with torch.autocast(device):
            output = pipeline(
                prompt=prompt,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
                width=args.width,
                height=args.height,
            )
        # trim to the requested duration before saving
        audio = output.audios[0][:audio_length]
        write(audio_path, args.sample_rate, audio)

if __name__ == "__main__":
    main()