Bug/sync optimization logger across ranks #1970

Merged

Changes from 34 commits (37 commits total)
c711d61  add pct underflow (bmosaicml, Feb 12, 2023)
07a8626  add pct underflow (bmosaicml, Feb 13, 2023)
f45f800  sync logger (bmosaicml, Feb 15, 2023)
765c59f  sync logger (bmosaicml, Feb 15, 2023)
10bd309  sync logger (bmosaicml, Feb 15, 2023)
cc18e4d  sync logger (bmosaicml, Feb 15, 2023)
3cb1304  sync logger (bmosaicml, Feb 15, 2023)
bae92be  sync logger (bmosaicml, Feb 15, 2023)
d3e340b  sync logger (bmosaicml, Feb 15, 2023)
dc054df  sync logger (bmosaicml, Feb 15, 2023)
e49741e  add layerwise lr (bmosaicml, Feb 18, 2023)
d2375e3  fix lp prevention (bmosaicml, Feb 21, 2023)
60f3d3f  fix lp prevention (bmosaicml, Feb 21, 2023)
c766c92  fix lp prevention (bmosaicml, Feb 21, 2023)
a7b3999  fix lp prevention (bmosaicml, Feb 21, 2023)
4d7773b  fix lp prevention (bmosaicml, Feb 22, 2023)
2fa81f0  fix lp prevention (bmosaicml, Feb 22, 2023)
54d9fb9  fix lp prevention (bmosaicml, Feb 22, 2023)
ad2044d  fix lp prevention (bmosaicml, Feb 22, 2023)
a5f53f9  fix lp prevention (bmosaicml, Feb 22, 2023)
ff4831a  Merge branch 'dev' into bug/sync_optimization_logger_across_ranks (mvpatel2000, Feb 24, 2023)
5763d96  add layerwise lr (bmosaicml, Feb 18, 2023)
4b0bd6b  fix lp prevention (bmosaicml, Feb 22, 2023)
062521c  add layerwise lr (bmosaicml, Feb 18, 2023)
a84c832  fix lp prevention (bmosaicml, Feb 22, 2023)
187200f  rebase (bmosaicml, Feb 24, 2023)
543ba48  Update tests/callbacks/test_optimizer_monitor.py (bmosaicml, Feb 24, 2023)
55223bc  Update composer/optim/decoupled_weight_decay.py (bmosaicml, Feb 24, 2023)
a11126f  Update composer/optim/decoupled_weight_decay.py (bmosaicml, Feb 24, 2023)
445ea11  rebase (bmosaicml, Feb 24, 2023)
eeb1992  Update tests/callbacks/test_optimizer_monitor.py (bmosaicml, Feb 24, 2023)
bb63598  Update tests/callbacks/test_optimizer_monitor.py (bmosaicml, Feb 24, 2023)
32a5326  rebase (bmosaicml, Feb 24, 2023)
18ba0c6  rebase (bmosaicml, Feb 24, 2023)
211022f  Update composer/optim/decoupled_weight_decay.py (bmosaicml, Feb 24, 2023)
966dc57  Update composer/callbacks/optimizer_monitor.py (bmosaicml, Feb 24, 2023)
e8fb131  rebase (bmosaicml, Feb 24, 2023)
41 changes: 32 additions & 9 deletions composer/callbacks/optimizer_monitor.py
@@ -7,6 +7,7 @@

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['OptimizerMonitor']

@@ -87,21 +88,43 @@ def __init__(self, log_optimizer_metrics: bool = True):

def batch_end(self, state: State, logger: Logger):
norm = 0.0
default_metrics = {}
optimizer_metrics = {}

for name, p in state.model.named_parameters():
if p.grad is not None and p.requires_grad:
param_grad_norm = torch.linalg.vector_norm(p.grad)
default_metrics[f'l2_norm/grad/{name}'] = param_grad_norm

norm += param_grad_norm**2
metric_reporter = getattr(state.optimizers[0], 'report_per_parameter_metrics', None)
if callable(metric_reporter) and self.log_optimizer_metrics:
optimizer_metrics = metric_reporter(p, name, optimizer_metrics)

default_metrics['l2_norm/grad/global'] = norm**0.5

logger.log_metrics(default_metrics)
if self.log_optimizer_metrics:
logger.log_metrics(optimizer_metrics)
if f'l2_norm/grad/{name}' not in optimizer_metrics:
param_grad_norm = torch.linalg.vector_norm(p.grad)
optimizer_metrics[f'l2_norm/grad/{name}'] = param_grad_norm

if state.fsdp_enabled and dist.get_world_size() > 0 and self.log_optimizer_metrics:
# If FSDP is enabled, the optimizer state lives on different ranks and must be reduced
# and combined before we can compute metrics.
# Each metric has a different way of being reduced, so the optimizer is responsible for implementing
# the reduction process.
# It occurs first via a pre-reduce, where the metric on each rank is modified and prepared
# then an all-reduce where the modified metric on each rank is combined into the correct metric across all ranks.
#
# For example, L2 norms are squared on each rank before we apply all_reduce(SUM) and take the sqrt on each rank
pre_reduce_metrics = getattr(state.optimizers[0], 'pre_reduce_metrics', None)
if callable(pre_reduce_metrics) and self.log_optimizer_metrics:
optimizer_metrics = pre_reduce_metrics(optimizer_metrics)

dist_reduce_metrics = getattr(state.optimizers[0], 'dist_reduce_metrics', None)
if callable(dist_reduce_metrics) and self.log_optimizer_metrics:
optimizer_metrics = dist_reduce_metrics(optimizer_metrics)

for metric in optimizer_metrics:
if metric.startswith('l2_norm/grad'):
norm += optimizer_metrics[metric]**2

optimizer_metrics['l2_norm/grad/global'] = norm**0.5

for metric in optimizer_metrics:
if isinstance(optimizer_metrics[metric], torch.Tensor):
optimizer_metrics[metric] = optimizer_metrics[metric].item()
logger.log_metrics(optimizer_metrics)
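The comment block in the hunk above describes a square, all_reduce(SUM), then sqrt protocol for L2 norms under FSDP. A minimal sketch of that protocol in isolation, assuming each rank holds only its shard of a gradient; the helper name and `shard_grad` argument are illustrative and not part of the PR, and the callback itself delegates these steps to the optimizer's `pre_reduce_metrics` / `dist_reduce_metrics` hooks rather than calling all_reduce directly:

```python
# Illustrative sketch of the square -> all_reduce(SUM) -> sqrt pattern for sharded L2 norms.
import torch
import torch.distributed as torch_dist


def sharded_grad_l2_norm(shard_grad: torch.Tensor) -> torch.Tensor:
    local_sq = torch.linalg.vector_norm(shard_grad)**2  # pre-reduce: squared norms are summable
    if torch_dist.is_available() and torch_dist.is_initialized() and torch_dist.get_world_size() > 1:
        torch_dist.all_reduce(local_sq, op=torch_dist.ReduceOp.SUM)  # combine shards across ranks
    return local_sq.sqrt()  # sqrt recovers the L2 norm of the full, unsharded gradient
```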
78 changes: 63 additions & 15 deletions composer/optim/decoupled_weight_decay.py
@@ -17,6 +17,8 @@
from torch.optim import SGD, AdamW
from torch.optim.optimizer import required # type: ignore

from composer.utils import dist

log = logging.getLogger(__name__)

__all__ = ['DecoupledSGDW', 'DecoupledAdamW']
@@ -189,24 +191,20 @@ class DecoupledAdamW(AdamW):
metric_functions = {
'l2_norm/moment':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(optim_state['exp_avg']),
'l2_norm_ratio/moment_grad':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(param.grad) / torch.linalg.vector_norm(
optim_state['exp_avg']),
'cosine/moment_grad':
lambda param, optim_state, step_tensor: torch.nn.functional.cosine_similarity(
param.grad.flatten(), optim_state['exp_avg'].flatten(), dim=0),
'l2_norm/param':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(param.data),
'l2_norm/second_moment_sqrt':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(optim_state['exp_avg_sq']).sqrt(),
'l2_norm/update':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(step_tensor),
'l2_norm/grad':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(param.grad),
'cosine/update_grad':
lambda param, optim_state, step_tensor: torch.nn.functional.cosine_similarity(
param.grad.flatten(), step_tensor.flatten(), dim=0),
'l2_norm_ratio/update_param':
lambda param, optim_state, step_tensor: torch.linalg.vector_norm(step_tensor) / torch.linalg.vector_norm(
param.data),
'cosine/moment_grad':
lambda param, optim_state, step_tensor: torch.nn.functional.cosine_similarity(
param.grad.flatten(), optim_state['exp_avg'].flatten(), dim=0)
}

def __init__(self,
@@ -223,6 +221,8 @@ def __init__(self,
super().__init__(params=params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
for group in self.param_groups:
group['initial_lr'] = group['lr']
self.layer_to_scale = {}
self.amsgrad = amsgrad

@staticmethod
def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
@@ -303,7 +303,7 @@ def step(self, closure=None):
weight_decay = group['weight_decay']

for p in group['params']:
if p.grad is None:
if p.grad is None or not p.requires_grad:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
@@ -313,7 +313,7 @@
state = self.state[p]

# State initialization
if len(state) == 0:
if 'step' not in state:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
@@ -325,13 +325,12 @@

exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])

if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])

# update the steps for each param group update
# Update the steps for each param group update
state['step'] += 1
# record the step after step update
# Record the step after step update
state_steps.append(state['step'])

self.adamw(params_with_grad,
@@ -350,6 +349,54 @@

return loss

def dist_reduce_metrics(self, optimizer_metrics):
for metric in optimizer_metrics:
if metric.startswith('l2_norm'):
reduced = optimizer_metrics[metric]
if dist.get_world_size() > 1:
dist.all_reduce(reduced, reduce_operation='SUM')

optimizer_metrics[metric] = math.sqrt(reduced)
elif metric.startswith('cosine'):
reduced = optimizer_metrics[metric]
if dist.get_world_size() > 1:
dist.all_reduce(reduced, reduce_operation='SUM')

_, vectors, layer = tuple(metric.split('/'))

A, B = tuple(vectors.split('_'))

A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
optimizer_metrics[metric] = reduced / (A_reduced_norm * B_reduced_norm)
else:
reduced = optimizer_metrics[metric]
if dist.get_world_size() > 1:
dist.all_reduce(reduced, reduce_operation='SUM')
optimizer_metrics[metric] = reduced / dist.get_world_size()

return optimizer_metrics

def pre_reduce_metrics(self, optimizer_metrics):
# Some metrics require preprocessing to reduce across ranks correctly

for metric in optimizer_metrics:
if metric.startswith('l2_norm'):
# L2 norms need to be squared, before they are reduced via summation
optimizer_metrics[metric] = optimizer_metrics[metric]**2
elif metric.startswith('cosine'):
_, vectors, layer = tuple(metric.split('/'))

A, B = tuple(vectors.split('_'))

# L2 norm would've been squared in previous branch
A_rank_subset_norm = math.sqrt(optimizer_metrics[f'l2_norm/{A}/{layer}'])
B_rank_subset_norm = math.sqrt(optimizer_metrics[f'l2_norm/{B}/{layer}'])

optimizer_metrics[metric] *= A_rank_subset_norm * B_rank_subset_norm

return optimizer_metrics

def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer_metrics: dict):
lr = self.param_groups[0]['lr']
eps = self.param_groups[0]['eps']
@@ -369,5 +416,6 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer
step_tensor.add_(param, alpha=-weight_decay * decay_factor)
for metric in self.metric_functions:
optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[metric](param, param_optim_state,
step_tensor).item()
step_tensor)

return optimizer_metrics
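The cosine handling in `pre_reduce_metrics` and `dist_reduce_metrics` above relies on cos(A, B) = (A · B) / (||A|| ||B||): multiplying each rank's local cosine by its local norms recovers the local dot product, which is summable across ranks, and dividing the summed dot product by the globally reduced norms yields the cosine of the full vectors. A small self-contained sketch of that algebra, assuming two ranks each holding one shard; the function and variable names are illustrative only, and plain `sum()` stands in for `dist.all_reduce(..., reduce_operation='SUM')`:

```python
# Illustrative only: the algebra behind pre_reduce_metrics / dist_reduce_metrics for a 'cosine/*' metric.
import math

import torch


def reduced_cosine(a_shards, b_shards):
    # What each rank reports before any reduction (mirrors metric_functions):
    local_cos = [torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)
                 for a, b in zip(a_shards, b_shards)]
    local_sq_a = [torch.linalg.vector_norm(a)**2 for a in a_shards]  # squared by pre_reduce_metrics
    local_sq_b = [torch.linalg.vector_norm(b)**2 for b in b_shards]
    # pre_reduce_metrics: cosine * ||A_local|| * ||B_local|| == A_local . B_local, which is summable.
    local_dots = [c * math.sqrt(sa) * math.sqrt(sb)
                  for c, sa, sb in zip(local_cos, local_sq_a, local_sq_b)]
    dot = sum(local_dots)                 # emulates all_reduce(SUM) over ranks
    norm_a = math.sqrt(sum(local_sq_a))   # sqrt after summation, as in dist_reduce_metrics
    norm_b = math.sqrt(sum(local_sq_b))
    # dist_reduce_metrics: divide the reduced dot product by the reduced norms.
    return dot / (norm_a * norm_b)


a, b = torch.randn(8), torch.randn(8)
sharded = reduced_cosine([a[:4], a[4:]], [b[:4], b[4:]])
full = torch.nn.functional.cosine_similarity(a, b, dim=0)
assert torch.allclose(sharded, full, atol=1e-5)
```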
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -116,8 +116,8 @@ filterwarnings = [
'ignore:Positional args are being deprecated, use kwargs instead.*:UserWarning',
'ignore:yahp-based workflows are deprecated and will be removed:DeprecationWarning',
'ignore:Torchmetrics v0.9 introduced a new argument class property:UserWarning',
'ignore:torch.distributed._all_gather_base is a private function and will be deprecated:UserWarning',
'ignore:torch.distributed._reduce_scatter_base is a private function and will be deprecated:UserWarning',
'ignore:torch.distributed._all_gather_base is a private function and will be deprecated.*:UserWarning',
'ignore:torch.distributed._reduce_scatter_base is a private function and will be deprecated.*:UserWarning',
# Ignore tensorboard deprecation warnings
'ignore:Call to deprecated create function Descriptor().*:DeprecationWarning:tensorboard',
'ignore:Call to deprecated create function EnumDescriptor().*:DeprecationWarning:tensorboard',
51 changes: 47 additions & 4 deletions tests/callbacks/test_optimizer_monitor.py
@@ -8,6 +8,7 @@
from composer.loggers import InMemoryLogger
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from tests.common import device, world_size
from tests.common.datasets import RandomClassificationDataset
from tests.common.models import SimpleModel

@@ -34,17 +35,59 @@ def test_optimizer_monitor(log_optimizer_metrics: bool):
grad_norm_calls = len(in_memory_logger.data['l2_norm/grad/global'])
layer_norm_calls = [len(calls) for (k, calls) in in_memory_logger.data.items() if 'l2_norm/grad' in k]
assert 'l2_norm/grad/module.2.weight' in in_memory_logger.data.keys()

if log_optimizer_metrics:
assert 'l2_norm/moment/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm_ratio/moment_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'cosine/moment_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm/second_moment_sqrt/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm/update/module.2.weight' in in_memory_logger.data.keys()
assert 'cosine/update_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm_ratio/update_param/module.2.weight' in in_memory_logger.data.keys()

# expected to log gradient norm once per step (total batch)
# Expected to log gradient norm once per step (total batch)
assert grad_norm_calls == num_train_steps
for num_calls in layer_norm_calls:
assert num_calls == num_train_steps


@device('gpu')
@world_size(1, 2)
def test_fsdp_optimizer_monitor(device, world_size):
# Construct the callback
grad_monitor = OptimizerMonitor(log_optimizer_metrics=True)
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
model = SimpleModel()
# Construct the trainer and train
trainer = Trainer(model=model,
callbacks=grad_monitor,
loggers=in_memory_logger,
train_dataloader=DataLoader(RandomClassificationDataset()),
optimizers=DecoupledAdamW(model.parameters()),
max_duration='3ba',
fsdp_config={
'sharding_strategy': 'FULL_SHARD',
'min_params': 10,
'cpu_offload': False,
'mixed_precision': 'PURE',
'backward_prefetch': 'BACKWARD_PRE',
'activation_checkpointing': False,
'activation_cpu_offload': False,
'verbose': False
})
trainer.fit()
num_train_steps = int(trainer.state.timestamp.batch)

# Count the logged steps
grad_norm_calls = len(in_memory_logger.data['l2_norm/grad/global'])
layer_norm_calls = [len(calls) for (k, calls) in in_memory_logger.data.items() if 'l2_norm/grad' in k]
assert 'l2_norm/grad/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm/moment/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm_ratio/moment_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'cosine/moment_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm/second_moment_sqrt/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm/update/module.2.weight' in in_memory_logger.data.keys()
assert 'cosine/update_grad/module.2.weight' in in_memory_logger.data.keys()
assert 'l2_norm_ratio/update_param/module.2.weight' in in_memory_logger.data.keys()

# Expected to log gradient norm once per step (total batch)
assert grad_norm_calls == num_train_steps
for num_calls in layer_norm_calls:
assert num_calls == num_train_steps