Skip to content

Commit

Permalink
Update accel_sdxl_gen_img.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DKnight54 authored Feb 1, 2025
1 parent eaf4a9a commit cc6266d
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions accel_sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,21 +2883,25 @@ def scale_and_round(x):
for index in res:
templist.append(prompt_data_list[index])
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
sublist = []
if(len(split_into_batches) % distributed_state.num_processes != 0):
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
sublist = []
for j in range(len(split_into_batches) % distributed_state.num_processes):
if len(split_into_batches) > 1 :
sublist.extend(split_into_batches.pop(-1))
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
split_into_batches = []
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
popnum = (len(split_into_batches) % distributed_state.num_processes
else:
#force distribution check on last round of batches
popnum = distributed_state.num_processes

for j in range(popnum):
if len(split_into_batches) > 1 :
sublist.extend(split_into_batches.pop(-1))
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
split_into_batches = []

n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
ext_separated_list_of_batches.append(split_into_batches)
if distributed_state.num_processes > 1:


for x in range(len(ext_separated_list_of_batches)):
temp_list = []
logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}")
Expand Down

0 comments on commit cc6266d

Please sign in to comment.