Add support for resuming shard positions and saving them in ckpt. #66

Merged 1 commit on Sep 3, 2024.
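In short, this change makes the webdataset loader yield the worker id and shard id alongside each sample, aggregates them into a per-worker position tensor during training, gathers those tensors across ranks, and writes the result into the checkpoint so a resumed run can skip shards it has already consumed. A rough illustration of the resulting checkpoint layout (values are invented; only the "positions" entry is new in this PR):

import torch

# Hypothetical example of the dict written by save_checkpoint() below.
# "positions" maps "{rank}_{worker_id}" to the highest shard id that dataloader
# worker had reached (stored as 0-dim tensors by collect_positions()); workers
# that never produced a batch keep the fill value -1.
checkpoint_dict = {
    "step": 1999,
    "name": "example_run",        # args.name
    "state_dict": {},             # model weights (omitted here)
    "optimizer": {},              # optimizer state (omitted here)
    "positions": {
        "0_0": torch.tensor(40),  # rank 0, worker 0
        "0_1": torch.tensor(41),  # rank 0, worker 1
        "1_0": torch.tensor(42),  # rank 1, worker 0
        "1_1": torch.tensor(43),  # rank 1, worker 1
    },
}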
src/training/metaclip_wds.py (2 changes: 1 addition & 1 deletion)
@@ -128,7 +128,7 @@ def __iter__(self):
image = img.convert("RGB")
image = self.transform(image)

- yield image, txt #, worker_id, shard_id
+ yield image, txt, worker_id, shard_id
json_uuid, img_uuid = None, None

shard_id = self._get_next_shard_id(shard_id)
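Because the dataset now yields two extra integer fields per sample, the DataLoader's default collate function stacks them into 1-D LongTensors of length batch_size; that is the shape agg_positions() in train.py below relies on. A minimal sketch of that collation behavior (sample values are invented; the public default_collate import needs PyTorch 1.11+):

import torch
from torch.utils.data import default_collate

samples = [
    (torch.zeros(3, 224, 224), "a caption", 1, 37),  # (image, txt, worker_id, shard_id)
    (torch.zeros(3, 224, 224), "a caption", 1, 38),
]
images, txts, worker_ids, shard_ids = default_collate(samples)
print(worker_ids)       # tensor([1, 1]) -- every sample in a batch comes from one worker
print(shard_ids.max())  # tensor(38) -- the value agg_positions() records for that worker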
src/training/train.py (56 changes: 49 additions & 7 deletions)
@@ -19,10 +19,34 @@
from src.open_clip.factory import unwrap_model
from src.training.distributed import is_master, world_info_from_env
from src.training.zero_shot import zero_shot_eval
+ from src.training.detect import detect_unused_parameters


- def save_checkpoint(args, model, optimizer, scaler, step, checkpoint_fn="epoch_latest.pt", positions=None):
+ def agg_positions(positions, worker_ids, shard_ids):
+ if positions is None or worker_ids is None or shard_ids is None:
+ return None
+ assert sum(worker_ids) == worker_ids[0] * worker_ids.shape[0]  # the PyTorch dataloader draws each batch from a single worker, so all worker_ids in a batch should be equal
+ positions[worker_ids[0]] = shard_ids.max()
+ return positions


+ def collect_positions(args, positions):
+ if positions is None:
+ return None
+ if args.distributed:
+ import torch.distributed as dist

+ _, _, world_size = world_info_from_env()

+ gathered_tensors = [torch.zeros_like(positions, device=args.device) for _ in range(world_size)]
+ dist.all_gather(gathered_tensors, positions.to(args.device))
+ else:
+ gathered_tensors = [positions]
+ gathered_tensors = [gathered_tensor.cpu() for gathered_tensor in gathered_tensors]
+ positions = {f"{rank}_{worker_id}": shard_id for rank, gathered_tensor in enumerate(gathered_tensors) for worker_id, shard_id in enumerate(gathered_tensor)}
+ return positions


+ def save_checkpoint(args, model, optimizer, scaler, step, checkpoint_fn="epoch_latest.pt", positions_dict=None):
checkpoint_dict = {
"step": step,
"name": args.name,
@@ -33,8 +57,8 @@ def save_checkpoint(args, model, optimizer, scaler, step, checkpoint_fn="epoch_l
if scaler is not None:
checkpoint_dict["scaler"] = scaler.state_dict()

- if positions is not None:
- checkpoint_dict["positions"] = positions
+ if positions_dict is not None:
+ checkpoint_dict["positions"] = positions_dict

# Saving checkpoints. use eval_steps to save a checkpoint.
if args.save_logs: # master_only.
@@ -118,13 +142,22 @@ def train_one_epoch_ex(args, model, data, start_step, total_steps, optimizer, sc
data_time_m = AverageMeter()
end = time.time()

+ positions = torch.full((args.workers,), fill_value=-1, dtype=torch.long)

batch_iter = iter(dataloader)

for step in range(start_step, total_steps):
batch = next(batch_iter)
scheduler(step)

- images, texts = to_device(batch, device)
+ if len(batch) == 2:
+ (images, texts), worker_ids, shard_ids = batch, None, None
+ else:
+ images, texts, worker_ids, shard_ids = batch

+ images, texts = to_device((images, texts), device)

+ positions = agg_positions(positions, worker_ids, shard_ids)

data_time_m.update(time.time() - end)
optimizer.zero_grad()
@@ -136,13 +169,20 @@ def train_one_epoch_ex(args, model, data, start_step, total_steps, optimizer, sc
if torch.isfinite(total_loss).all():
if scaler is not None:
scaler.scale(total_loss).backward()
+ # if args.world_size == 1:
+ # from src.training.detect import detect_unused_parameters
+ # detect_unused_parameters(model)
if args.norm_gradient_clip is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
total_loss.backward()
+ # if args.world_size == 1:
+ # from src.training.detect import detect_unused_parameters
+ # detect_unused_parameters(model)
+ # detect_nan(model, optimizer)
if args.norm_gradient_clip is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0)
optimizer.step()
@@ -196,14 +236,16 @@ def train_one_epoch_ex(args, model, data, start_step, total_steps, optimizer, sc
data_time_m.reset()

if hasattr(args, "save_steps") and (step + 1) % args.save_steps == 0:
- save_checkpoint(args, model, optimizer, scaler, step)
+ positions_dict = collect_positions(args, positions)
+ save_checkpoint(args, model, optimizer, scaler, step, positions_dict=positions_dict)

# TODO: copied from main.py, wrap as a function call.
if hasattr(args, "eval_steps") and (step + 1) % args.eval_steps == 0: # TODO (huxu): put eval on master only?
if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
evaluate_ex(args, model, data, step, tb_writer) # completed_epoch -> epoch, writer -> tb_writer
model.train() # evaluate won't turn model back to train.
- save_checkpoint(args, model, optimizer, scaler, step)
+ positions_dict = collect_positions(args, positions)
+ save_checkpoint(args, model, optimizer, scaler, step, positions_dict=positions_dict)
# end for


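The diff above covers the saving side only; the loader changes that consume the stored positions on restart are not shown here. A hypothetical sketch of what reading them back could look like (restore_start_shard and its "resume at the next shard" convention are assumptions, not code from this PR):

import torch

def restore_start_shard(checkpoint_path, rank, worker_id, default=0):
    # Hypothetical helper: look up the last shard id a given (rank, dataloader
    # worker) pair reached, as saved under "positions" by collect_positions().
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    positions = ckpt.get("positions")
    if positions is None:
        return default  # checkpoint predates this change
    shard_id = positions.get(f"{rank}_{worker_id}")
    if shard_id is None or int(shard_id) < 0:
        return default  # -1 is the fill value for workers that never yielded a batch
    return int(shard_id) + 1  # assume the recorded shard was fully consumed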