Optimizer state key names differ between data_parallel and row_wise embedding sharding #2394

Open
tiankongdeguiji opened this issue Sep 14, 2024 · 2 comments

@tiankongdeguiji
Contributor

This problem can be reproduced with the following command: torchrun --master_addr=127.0.0.1 --master_port=1234 --nnodes=1 --nproc-per-node=1 --node_rank=0 test_optimizer_state.py --sharding_type $SHARDING_TYPE, in an environment with torchrec==0.8.0+cu121, torch==2.4.0+cu121, and fbgemm-gpu==0.8.0+cu121.

When SHARDING_TYPE=row_wise, it prints:

['state.sparse.ebc.embedding_bags.table_0.weight.table_0.momentum1', 'state.sparse.ebc.embedding_bags.table_0.weight.table_0.exp_avg_sq', ...]

When SHARDING_TYPE=data_parallel, it prints:

['state.sparse.ebc.embedding_bags.table_0.weight.step', 'state.sparse.ebc.embedding_bags.table_0.weight.exp_avg', 'state.sparse.ebc.embedding_bags.table_0.weight.exp_avg_sq', ...]

The corresponding keys map as follows:

xxx.weight.table_0.momentum1 -> xxx.weight.exp_avg
xxx.weight.table_0.exp_avg_sq -> xxx.weight.exp_avg_sq

We may load the model to continue training on clusters of different scales, which can lead to different sharding plans and, consequently, to the optimizer state not being loaded correctly.
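
One possible workaround is to remap the key names when restoring a checkpoint across sharding types. Below is a minimal sketch of that idea; the helper remap_rowwise_to_dp is hypothetical (not a torchrec API), it only rewrites key strings in a flattened optimizer state dict according to the mapping above, and the data_parallel-only step entries would still need to be handled separately.

from typing import Any, Dict, List

def remap_rowwise_to_dp(flat_state: Dict[str, Any], table_names: List[str]) -> Dict[str, Any]:
    # hypothetical helper: rewrite row_wise-style (fused optimizer) key names
    # into data_parallel-style (dense Adam) names in a flattened state dict
    remapped = {}
    for key, value in flat_state.items():
        new_key = key
        for table in table_names:
            # e.g. "...weight.table_0.momentum1"  -> "...weight.exp_avg"
            new_key = new_key.replace(f"weight.{table}.momentum1", "weight.exp_avg")
            # e.g. "...weight.table_0.exp_avg_sq" -> "...weight.exp_avg_sq"
            new_key = new_key.replace(f"weight.{table}.exp_avg_sq", "weight.exp_avg_sq")
        remapped[new_key] = value
    # note: the dense Adam state also has "<param>.step", which has no direct
    # counterpart in the fused optimizer state and still needs to be filled in
    return remapped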

test_optimizer_state.py

import os
import torch
import argparse

from torch import distributed as dist
from torch.distributed.checkpoint._nested_dict import flatten_state_dict

from torchrec.distributed.comm import get_local_size
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders
)
from torchrec.distributed.planner.types import Topology, ParameterConstraints
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.test_utils.test_model import TestSparseNN, ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.optim import optimizers
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.apply_optimizer_in_backward import (
    apply_optimizer_in_backward,  # NOQA
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sharding_type",
    type=str,
    default="data_parallel"
)
args, extra_args = parser.parse_known_args()


BATCH_SIZE = 8196
# one process per GPU: bind this rank to its device and initialize NCCL
rank = int(os.environ.get("LOCAL_RANK", 0))
device: torch.device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl")

# four small embedding tables, one feature each
tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name="table_" + str(i),
        feature_names=["feature_" + str(i)],
    )
    for i in range(4)
]
topology = Topology(
    local_world_size=get_local_size(),
    world_size=dist.get_world_size(),
    compute_device=device.type
)
# constrain every table to the sharding type requested on the command line
constraints = {
    t.name: ParameterConstraints(sharding_types=[args.sharding_type])
    for t in tables
}
planner = EmbeddingShardingPlanner(
    topology=topology,
    batch_size=BATCH_SIZE,
    debug=True,
    storage_reservation=HeuristicalStorageReservation(percentage=0.7),
    constraints=constraints,
)
model = TestSparseNN(tables=tables, num_float_features=10, sparse_device=torch.device("meta"))

# fuse an Adam optimizer into the backward pass for the sparse (embedding) parameters
apply_optimizer_in_backward(
    optimizers.Adam, model.sparse.parameters(), {"lr": 0.01}
)
plan = planner.collective_plan(
    model, get_default_sharders(), dist.GroupMember.WORLD
)
# print(plan)
model = DistributedModelParallel(module=model, device=device, plan=plan)
# plain Adam for the dense parameters that are not optimized in backward
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])

# run one forward/backward/step so that the optimizer state is populated
_, local_batches = ModelInput.generate(
    batch_size=BATCH_SIZE,
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
    num_float_features=10,
    tables=tables,
    weighted_tables=[],
)
loss, _ = model.forward(local_batches[rank].to(device))
torch.sum(loss).backward()
optimizer.step()

# flatten the nested optimizer state dict and print its key names
print(flatten_state_dict(optimizer.state_dict())[0].keys())
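
For illustration, feeding the row_wise-style keys printed above through the hypothetical remap_rowwise_to_dp sketch yields the data_parallel-style names (except for the step entries). This only shows the string-level relationship between the two layouts, not a supported loading path; it assumes the helper defined in the sketch above.

row_wise_flat = {
    "state.sparse.ebc.embedding_bags.table_0.weight.table_0.momentum1": None,
    "state.sparse.ebc.embedding_bags.table_0.weight.table_0.exp_avg_sq": None,
}
print(list(remap_rowwise_to_dp(row_wise_flat, ["table_0", "table_1", "table_2", "table_3"]).keys()))
# -> ['state.sparse.ebc.embedding_bags.table_0.weight.exp_avg',
#     'state.sparse.ebc.embedding_bags.table_0.weight.exp_avg_sq']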
@tiankongdeguiji
Contributor Author

Hi @sarckk, @TroyGarden, could you take a look at this problem?
