
Commit

Back to FSDP
2015aroras committed Dec 20, 2024
1 parent d7ed30e commit 0c47992
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/scripts/train/OLMo2-32B.py
@@ -34,20 +34,20 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
         compile=compile,
         fused_ops=False,
         use_flash=not compile,
-        # dp_config=TransformerDataParallelConfig(
-        #     name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
-        # ),
         dp_config=TransformerDataParallelConfig(
-            name=DataParallelType.hsdp,
-            param_dtype=DType.bfloat16,
-            reduce_dtype=DType.float32,
-            num_replicas=64 // 16,  # common.launch.num_nodes // 2,
+            name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
         ),
-        ac_config=TransformerActivationCheckpointingConfig(TransformerActivationCheckpointingMode.full),
-        # ac_config=TransformerActivationCheckpointingConfig(
-        #     mode=TransformerActivationCheckpointingMode.selected_modules,
-        #     modules=[f"blocks.{i}.feed_forward" for i in range(64)],
+        # dp_config=TransformerDataParallelConfig(
+        #     name=DataParallelType.hsdp,
+        #     param_dtype=DType.bfloat16,
+        #     reduce_dtype=DType.float32,
+        #     num_replicas=64 // 16,  # common.launch.num_nodes // 2,
         # ),
+        # ac_config=TransformerActivationCheckpointingConfig(TransformerActivationCheckpointingMode.full),
+        ac_config=TransformerActivationCheckpointingConfig(
+            mode=TransformerActivationCheckpointingMode.selected_modules,
+            modules=[f"blocks.{i}.feed_forward" for i in range(64)],
+        ),
         float8_config=Float8Config(compile=compile, enabled=False),
     )
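This commit reverts the data-parallel strategy from hybrid sharded data parallelism (HSDP), which shards parameters within groups of ranks and replicates across groups, back to plain FSDP, which shards every parameter across all ranks. The removed num_replicas=64 // 16 fixed the replica count at 4 (the inline comment records the intended general form, common.launch.num_nodes // 2). As a rough sketch of the distinction at the PyTorch device-mesh level (the world size and mesh shapes below are illustrative assumptions, not values read from this script):

# Illustrative sketch only: how FSDP and HSDP differ at the device-mesh
# level in PyTorch. The shapes are assumed, not taken from OLMo2-32B.py,
# and a distributed process group must already be initialized.
import os
from torch.distributed.device_mesh import init_device_mesh

world_size = int(os.environ.get("WORLD_SIZE", "64"))  # assumed 64 ranks

# FSDP: one flat mesh; every parameter is sharded across all ranks, and
# gradients are reduce-scattered over the whole world in one step.
fsdp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("shard",))

# HSDP: a 2-D mesh; parameters are sharded within each group and
# replicated across groups (here 4 replicas of 16 shards each, mirroring
# the removed num_replicas=64 // 16). Gradients are reduce-scattered
# within a group, then all-reduced across the replica dimension.
num_replicas = 64 // 16
hsdp_mesh = init_device_mesh(
    "cuda",
    (num_replicas, world_size // num_replicas),
    mesh_dim_names=("replicate", "shard"),
)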

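The activation-checkpointing change goes the other way: instead of recomputing every transformer block in the backward pass (full mode), only the feed-forward sublayers named by modules=[f"blocks.{i}.feed_forward" for i in range(64)] are checkpointed, trading back some memory savings for less recomputation. A minimal sketch of what selective checkpointing means in plain PyTorch (the TransformerBlock class below is an illustrative stand-in, not OLMo-core's actual module; the library's TransformerActivationCheckpointingConfig applies this wrapping by module name internally):

# Illustrative sketch: checkpoint only the feed-forward sublayer of each
# block. The block structure is an assumption for demonstration purposes.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int = 512):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attention(x, x, x)
        x = x + attn_out
        # Only feed-forward activations are dropped and recomputed during
        # backward; attention activations stay resident in memory.
        x = x + checkpoint(self.feed_forward, x, use_reentrant=False)
        return x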
