
Commit: fix gan grad accum
autumn-2-net committed Dec 9, 2023
1 parent 08ae459 commit afc88dd
Showing 3 changed files with 29 additions and 21 deletions.
38 changes: 21 additions & 17 deletions model_trainer/basic_lib/gan_training_strategy.py
@@ -134,27 +134,28 @@ def train_strategy(self, prent_obj,
 
             generator_cache = []
             grad_accum_number = len(self.batchs)
-            for i in range(grad_accum_number):
-                generator_output = generator_model.training_step(batch=self.batchs[i][0], batch_idx=self.batchs[i][1])
+            batchs_cache=self.batchs
+            for i in batchs_cache:
+                generator_output = generator_model.training_step(batch=i[0], batch_idx=i[1])
                 generator_cache.append(generator_output)
 
             discriminator_optimizers.zero_grad()
 
             log_cache = []
-            for i in range(grad_accum_number):
+            for idx,i in enumerate(batchs_cache):
                 discriminator_fake_output = discriminator_model.training_step(
-                    batch=generator_output_warp(generator_cache[i]),
-                    batch_idx=self.batchs[i][1])
-                discriminator_true_output = discriminator_model.training_step(batch=self.batchs[i][0],
-                                                                              batch_idx=self.batchs[i][1])
+                    batch=generator_output_warp(generator_cache[idx]),
+                    batch_idx=i[1])
+                discriminator_true_output = discriminator_model.training_step(batch=i[0],
+                                                                              batch_idx=i[1])
                 discriminator_losses, discriminator_logs = discriminator_model.model_loss.discriminator_loss_fn(
                     discriminator_fake=discriminator_fake_output,
                     discriminator_true=discriminator_true_output)
-                log_cache.append(apply_to_collection(discriminator_logs, dtype=torch.Tensor,
-                                                     function=lambda x: x.detach().cpu().item()))
+                log_cache.append(discriminator_logs)
                 prent_obj.fabric.backward(discriminator_losses / grad_accum_number)
 
+            loges.update(abs_log(log_cache))
             # loges.update(log_cache[0])
 
             discriminator_optimizers.step()
             discriminator_optimizers.zero_grad()
@@ -163,26 +164,29 @@ def train_strategy(self, prent_obj,
             prent_obj.global_step += 1
             generator_optimizers.zero_grad()
             log_cache = []
-            for i in range(grad_accum_number):
-                generator_discriminator_fake_output = discriminator_model.training_step(batch=generator_cache[i],
-                                                                                         batch_idx=self.batchs[i][1])
-                generator_discriminator_true_output = discriminator_model.training_step(batch=self.batchs[i][0],
-                                                                                         batch_idx=self.batchs[i][1])
+            for idx,i in enumerate(batchs_cache):
+                generator_discriminator_fake_output = discriminator_model.training_step(batch=generator_cache[idx],
+                                                                                         batch_idx=i[1])
+                generator_discriminator_true_output = discriminator_model.training_step(batch=i[0],
+                                                                                         batch_idx=i[1])
                 generator_discriminator_losses, generator_discriminator_logs = discriminator_model.model_loss.generator_discriminator_loss_fn(
                     generator_discriminator_fake=generator_discriminator_fake_output,
                     generator_discriminator_true=generator_discriminator_true_output)
-                log_cache.append(apply_to_collection(generator_discriminator_logs, dtype=torch.Tensor,
-                                                     function=lambda x: x.detach().cpu().item()))
+                log_cache.append(generator_discriminator_logs)
 
                 prent_obj.fabric.backward(generator_discriminator_losses / grad_accum_number)
 
+            loges.update(abs_log(log_cache))
+            # loges.update(log_cache[0])
 
             generator_optimizers.step()
             generator_optimizers.zero_grad()
             prent_obj.step_scheduler(generator_model, generator_schedulers, level="step",
                                      current_value=prent_obj.get_state_step())
             prent_obj.global_step += 1
-            prent_obj.train_log = loges
+            prent_obj.train_log = apply_to_collection(loges, dtype=torch.Tensor,
+                                                      function=lambda x: x.detach().cpu().item())
+            self.batchs=[]
 
         else:
             pass
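
For context on the pattern in this file, the following is a minimal sketch of gradient accumulation in a GAN training loop. It uses hypothetical stand-in modules and a simple non-saturating loss rather than this repository's training_step/model_loss plumbing: micro-batches are cached, each backward pass is scaled by the number of cached batches, the discriminator and then the generator are stepped once per accumulation window, and the cache is cleared afterwards so the next window starts empty.

    import torch
    import torch.nn.functional as F
    from torch import nn

    # Hypothetical stand-in models; the real trainer wires up generator/discriminator
    # objects with training_step() and dedicated loss modules instead.
    generator = nn.Linear(8, 8)
    discriminator = nn.Linear(8, 1)
    g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

    GRAD_ACCUM_STEPS = 2   # assumed accumulation factor
    batch_cache = []       # micro-batches collected until an accumulated step is due

    def accumulated_gan_step(batch):
        batch_cache.append(batch)
        if len(batch_cache) < GRAD_ACCUM_STEPS:
            return  # keep caching; no optimizer step yet
        grad_accum_number = len(batch_cache)

        # Discriminator update over all cached micro-batches.
        d_opt.zero_grad()
        for real in batch_cache:
            fake = generator(real).detach()
            d_loss = (F.softplus(discriminator(fake)).mean()
                      + F.softplus(-discriminator(real)).mean())
            # Scale each micro-batch loss so the accumulated gradient matches
            # a single large-batch update.
            (d_loss / grad_accum_number).backward()
        d_opt.step()

        # Generator update over the same cached micro-batches.
        g_opt.zero_grad()
        for real in batch_cache:
            fake = generator(real)
            g_loss = F.softplus(-discriminator(fake)).mean()
            (g_loss / grad_accum_number).backward()
        g_opt.step()

        # Clearing the cache is what keeps accumulation bounded across windows.
        batch_cache.clear()

    for _ in range(4):
        accumulated_gan_step(torch.randn(16, 8))

Dividing each micro-batch loss by the accumulation count is what makes the summed gradients equivalent to one update on the concatenated batch.
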
8 changes: 6 additions & 2 deletions model_trainer/trainer/gan_trainer.py
@@ -107,6 +107,7 @@ def __init__(self, accelerator: Union[str, Accelerator] = "auto",
         # self.skip_save = True
         self.skip_val = True
         self.ModelSummary = ModelSummary(max_depth=max_depth)
+        self.last_val_step = 0
         if self.fabric.is_global_zero:
             self.bar_obj = Adp_bar(bar_type=progress_bar_type)
         else:
@@ -319,11 +320,11 @@ def get_local_ckpt_name(self, state_type: Literal['G', 'D']):
                 step = int(search.group(0)[6:])
                 ckpt_list.append((step, str(ckpt.name)))
         if len(ckpt_list) < self.keep_ckpt_num:
-            return remove_list, f'model_ckpt_steps_{str(self.global_step)}.ckpt', work_dir
+            return remove_list, f'model_ckpt_steps_{str(self.get_state_step())}.ckpt', work_dir
         num_remove = len(ckpt_list) + 1 - self.keep_ckpt_num
         ckpt_list.sort(key=lambda x: x[0])
         remove_list = ckpt_list[:num_remove]
-        return remove_list, f'model_ckpt_steps_{str(self.global_step)}.ckpt', work_dir
+        return remove_list, f'model_ckpt_steps_{str(self.get_state_step())}.ckpt', work_dir
         # for i in ckpt_list:
         # todo
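
To make the rotation logic above easier to follow, here is a rough, self-contained sketch of the same idea with an assumed file-name pattern and a made-up default keep count: parse the step out of every model_ckpt_steps_<N>.ckpt in the work directory, and once the keep limit would be exceeded, return the oldest entries for removal along with the name of the new checkpoint.

    import re
    from pathlib import Path

    def plan_checkpoint_rotation(work_dir: Path, current_step: int, keep_ckpt_num: int = 5):
        """Keep only the newest `keep_ckpt_num` checkpoints, judged by step number."""
        ckpt_list = []
        for ckpt in work_dir.glob('model_ckpt_steps_*.ckpt'):
            match = re.search(r'steps_(\d+)', ckpt.name)
            if match:
                ckpt_list.append((int(match.group(1)), ckpt.name))

        new_name = f'model_ckpt_steps_{current_step}.ckpt'
        if len(ckpt_list) < keep_ckpt_num:
            return [], new_name                      # nothing needs deleting yet
        ckpt_list.sort(key=lambda x: x[0])           # oldest first
        num_remove = len(ckpt_list) + 1 - keep_ckpt_num
        return ckpt_list[:num_remove], new_name      # (step, filename) pairs to delete
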

@@ -503,13 +504,16 @@ def fit_loop(
                     )
 
                 if self.get_state_step() % self.val_step == 0 and self.fabric.is_global_zero and not self.without_val:  # todo need add
+                    if self.last_val_step == self.get_state_step():
+                        self.skip_val = True
                     if self.skip_val:
                         self.skip_val = False
                     else:
                         self.val_loop(model=generator_model, val_loader=val_loader)
                         generator_model.train()
                         discriminator_model.train()
                         can_save = True
+                    self.last_val_step = self.get_state_step()
                 if self.fabric.is_global_zero and can_save:
                     self.save_checkpoint(self.generator_state, state_type='G')
                     self.save_checkpoint(self.discriminator_state, state_type='D')
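
The last_val_step bookkeeping in the hunk above guards against validating twice at the same step. A condensed sketch of that guard, with simplified names and the step counter supplied by the caller, could look like this:

    class ValGuard:
        """Skip validation when the step counter has not advanced since the last run."""

        def __init__(self, val_every: int = 4000):
            self.val_every = val_every
            self.last_val_step = 0

        def should_validate(self, current_step: int) -> bool:
            if current_step % self.val_every != 0:
                return False
            if current_step == self.last_val_step:
                return False  # already validated at this exact step
            self.last_val_step = current_step
            return True

Because gradient accumulation can bring the loop back to this check before the step counter advances, the equality test keeps validation from re-running, which in turn keeps can_save, and the checkpoint save gated on it, from re-triggering at the same step.
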
4 changes: 2 additions & 2 deletions test_trainer/gan_testmodel.py
@@ -215,11 +215,11 @@ def train(Gmodel,Dmodel):
     trainer = GanTrainer(
         accelerator=accelerator, devices="auto", limit_train_batches=None, limit_val_batches=10, max_epochs=130,
         loggers=TensorBoardLogger(
-            save_dir=str('./ckpt/gan1'),
+            save_dir=str('./ckpt/gan2'),
             name='lightning_logs',
             version='lastest',
 
-        ), checkpoint_dir='./ckpt/gan1',progress_bar_type='rich',val_step=4000,grad_accum_steps=2
+        ), checkpoint_dir='./ckpt/gan2',progress_bar_type='rich',val_step=4000,grad_accum_steps=1
     )
     trainer.fit(generator_model=Gmodel,discriminator_model=Dmodel)
