Text modules : adding CUDA support for eligible modules
Woolverine94 committed Nov 27, 2023
1 parent 05d49a3 commit 630c975
Showing 4 changed files with 27 additions and 12 deletions.
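
All four modules replace a hard-coded torch.device('cuda' ...) / torch.float32 pair with a shared detect_device() helper imported from ressources/common.py. That helper is not part of this commit, so the following is only a minimal sketch of what it plausibly returns, inferred from the call sites below: a device label usable by torch.device() plus a matching torch dtype. The names and dtype choices here are assumptions, not the actual biniou implementation.

import torch

def detect_device():
    # Assumed behaviour: prefer CUDA with half precision when a GPU is
    # available, otherwise fall back to CPU with full precision.
    if torch.cuda.is_available():
        return "cuda", torch.float16
    return "cpu", torch.float32

# Call sites in this commit then follow the pattern:
# device_label_nllb, model_arch = detect_device()
# device_nllb = torch.device(device_label_nllb)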
6 changes: 4 additions & 2 deletions ressources/img2txt_git.py
@@ -7,7 +7,8 @@
 import time
 from ressources.common import *
 
-device_img2txt_git = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device_label_img2txt_git, model_arch = detect_device()
+device_img2txt_git = torch.device(device_label_img2txt_git)
 
 # Gestion des modèles
 model_path_img2txt_git = "./models/GIT"
@@ -34,11 +35,12 @@ def text_img2txt_git(
     pipe_img2txt_git = AutoModelForCausalLM.from_pretrained(
         modelid_img2txt_git,
         cache_dir=model_path_img2txt_git,
-        torch_dtype=torch.float32,
+        torch_dtype=model_arch,
         use_safetensors=True,
         resume_download=True,
         local_files_only=True if offline_test() else None
     )
+
     pipe_img2txt_git = pipe_img2txt_git.to(device_img2txt_git)
     inpipe_img2txt_git = processor_img2txt_git(images=img_img2txt_git, return_tensors="pt").to(device_img2txt_git)

11 changes: 9 additions & 2 deletions ressources/nllb.py
@@ -7,7 +7,8 @@
 from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, AutoTokenizer
 from ressources.common import *
 
-device_nllb = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device_label_nllb, model_arch = detect_device()
+device_nllb = torch.device(device_label_nllb)
 
 model_path_nllb = "./models/nllb/"
 os.makedirs(model_path_nllb, exist_ok=True)
@@ -273,18 +274,24 @@ def text_nllb(
         resume_download=True,
         local_files_only=True if offline_test() else None
     )
 
     tokenizer_nllb = NllbTokenizer.from_pretrained(
         model_nllb,
+        torch_dtype=model_arch,
         src_lang=source_language_nllb,
         tgt_lang=output_language_nllb
     )
-    automodel_nllb = AutoModelForSeq2SeqLM.from_pretrained(model_nllb)
+
+    automodel_nllb = AutoModelForSeq2SeqLM.from_pretrained(model_nllb).to(device_nllb)
     inputs_nllb = tokenizer_nllb(prompt_nllb, return_tensors="pt").to(device_nllb)
+    automodel_nllb = automodel_nllb.to_bettertransformer()
+
     translated_tokens = automodel_nllb.generate(
         **inputs_nllb,
         forced_bos_token_id=tokenizer_nllb.lang_code_to_id[output_language_nllb],
         max_new_tokens=max_tokens_nllb,
     )
+
     output_nllb = tokenizer_nllb.batch_decode(translated_tokens, skip_special_tokens=True)[0]
     filename_nllb = write_file(output_nllb)
+

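For context, a self-contained sketch of the flow the updated nllb.py follows: pick the device and dtype, load the tokenizer and model, move the model and the tokenized inputs to the same device, then generate. The checkpoint, the language codes, and passing torch_dtype to the model loader (the diff passes it to the tokenizer) are illustrative assumptions rather than biniou's exact settings.

import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

device_label, model_arch = ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
device = torch.device(device_label)

checkpoint = "facebook/nllb-200-distilled-600M"  # illustrative checkpoint
tokenizer = NllbTokenizer.from_pretrained(checkpoint, src_lang="eng_Latn", tgt_lang="fra_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=model_arch).to(device)

inputs = tokenizer("The model and its inputs must live on the same device.", return_tensors="pt").to(device)
tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"],  # same lookup nllb.py uses
    max_new_tokens=50,
)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])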
14 changes: 9 additions & 5 deletions ressources/txt2prompt.py
@@ -6,7 +6,8 @@
 from transformers import pipeline, set_seed
 from ressources.common import *
 
-device_txt2prompt = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device_label_txt2prompt, model_arch = detect_device()
+device_txt2prompt = torch.device(device_label_txt2prompt)
 
 model_path_txt2prompt = "./models/Prompt_generator/"
 os.makedirs(model_path_txt2prompt, exist_ok=True)
@@ -67,20 +68,23 @@ def text_txt2prompt(
     if seed_txt2prompt == 0:
         seed_txt2prompt = random.randint(0, 4294967295)
 
-    pipeline_txt2prompt = pipeline(
+    pipe_txt2prompt = pipeline(
         task="text-generation",
         model=modelid_txt2prompt,
-        torch_dtype=torch.float32,
+        torch_dtype=model_arch,
         device=device_txt2prompt,
         local_files_only=True if offline_test() else None
     )
+
     set_seed(seed_txt2prompt)
-    generator_txt2prompt = pipeline_txt2prompt(
+
+    generator_txt2prompt = pipe_txt2prompt(
         prompt_txt2prompt,
         do_sample=True,
         max_new_tokens=max_tokens_txt2prompt,
         num_return_sequences=num_prompt_txt2prompt,
     )
+
     for i in range(len(generator_txt2prompt)):
         output_txt2prompt_int = generator_txt2prompt[i]["generated_text"]
         if output_type_txt2prompt == "ChatGPT":
@@ -99,7 +103,7 @@ def text_txt2prompt(
         f"Seed={seed_txt2prompt}"
     print(reporting_txt2prompt)
 
-    del pipeline_txt2prompt
+    del pipe_txt2prompt
     clean_ram()
 
     print(f">>>[Prompt generator 📝 ]: leaving module")

8 changes: 5 additions & 3 deletions ressources/whisper.py
@@ -12,7 +12,8 @@
 from transformers import AutoModel, AutoTokenizer, AutoFeatureExtractor
 from ressources.common import *
 
-device_whisper = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device_label_whisper, model_arch = detect_device()
+device_whisper = torch.device(device_label_whisper)
 
 model_path_whisper = "./models/whisper/"
 os.makedirs(model_path_whisper, exist_ok=True)
@@ -137,6 +138,7 @@ def text_whisper(
     model_whisper = WhisperForConditionalGeneration.from_pretrained(
         modelid_whisper,
         cache_dir=model_path_whisper,
+        torch_dtype=model_arch,
         low_cpu_mem_usage=True,
         resume_download=True,
         local_files_only=True if offline_test() else None
@@ -164,9 +166,9 @@
         feature_extractor=feat_ex_whisper,
         chunk_length_s=30,
         device=device_whisper,
-        torch_dtype=torch.float32
+        torch_dtype=model_arch,
     )
 
     if srt_output_whisper == False :
         transcription_whisper_final = pipe_whisper(
             audio_whisper.copy(),
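A quick way to confirm that the detected device and dtype actually take effect on a pipeline built this way; the tiny distilgpt2 checkpoint is only an illustration and is not a biniou default, and the inlined device/dtype selection stands in for the detect_device() helper sketched near the top.

import torch
from transformers import pipeline

device_label, model_arch = ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)

pipe = pipeline(
    task="text-generation",
    model="distilgpt2",                        # illustrative model, not a biniou default
    torch_dtype=model_arch,
    device=torch.device(device_label),
)
print(pipe.model.device)                       # cuda:0 on a GPU machine, cpu otherwise
print(next(pipe.model.parameters()).dtype)     # torch.float16 on GPU, torch.float32 on CPU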
