Merge branch 'main' into add-reinforce
Quentin-Anthony authored Nov 19, 2024
2 parents c1839c7 + a8f7913 commit 82a479f
Showing 32 changed files with 499 additions and 87 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -100,7 +100,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA

### Host Setup

First make sure you are in an environment with Python 3.8 with an appropriate version of PyTorch 1.8 or later installed. **Note:** Some of the libraries that GPT-NeoX depends on have not been updated to be compatible with Python 3.10+. Python 3.9 appears to work, but this codebase has been developed and tested for Python 3.8.
This codebase has been developed and tested primarily for Python 3.8-3.10 and PyTorch 1.8-2.0. This is not a strict requirement; other versions and combinations of libraries may work.

To install the remaining basic dependencies, run:

@@ -679,7 +679,9 @@ We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memo
## Nsight Systems Profiling
To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop`. Launch training with:
To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
To populate nsys metrics, launch training with:
```
nsys profile -s none -t nvtx,cuda -o <path/to/profiling/output> --force-overwrite true \
@@ -689,22 +691,22 @@ $TRAIN_PATH/train.py --conf_dir configs <config files>
The generated output file can then be viewed with the Nsight Systems GUI:
![Alt text](images/nsight_profiling.png)
![nsight-prof](images/nsight_profiling.png)
## PyTorch Profiling
To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`.
To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).
![Alt text](images/pytorch_profiling.png)
![torch-prof](images/pytorch_profiling.png)
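In practice, viewing these traces usually amounts to installing the TensorBoard profiler plugin and pointing TensorBoard at the configured log directory. A minimal sketch, assuming `tensorboard_dir` is set to `tensorboard` as in the sample config added by this commit:
```
pip install torch-tb-profiler
tensorboard --logdir tensorboard
```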
## PyTorch Memory Profiling
To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
![Alt text](images/memory_profiling.png)
![mem-prof](images/memory_profiling.png)
View the generated profile with the [memory_viz.py](https://github.com/pytorch/pytorch/blob/main/torch/cuda/_memory_viz.py) script. Run with:
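The exact command is collapsed in this diff view; a plausible invocation, assuming the memory profiler wrote a pickled snapshot under `memory_profiling_path`, would look something like:
```
python _memory_viz.py trace_plot <snapshot>.pickle -o trace_plot.html
```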
2 changes: 2 additions & 0 deletions configs/llama/13B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 40,
"hidden_size": 5120,
"intermediate_size": 40960,
"num_attention_heads": 40,
"seq_length": 2048,
"max_position_embeddings": 2048,
@@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
2 changes: 2 additions & 0 deletions configs/llama/30B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 60,
"hidden_size": 6656,
"intermediate_size": 53248,
"num_attention_heads": 52,
"seq_length": 2048,
"max_position_embeddings": 2048,
@@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
2 changes: 2 additions & 0 deletions configs/llama/65B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 80,
"hidden_size": 8192,
"intermediate_size": 65536,
"num_attention_heads": 64,
"seq_length": 2048,
"max_position_embeddings": 2048,
@@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
2 changes: 2 additions & 0 deletions configs/llama/7B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 32,
"hidden_size": 4096,
"intermediate_size": 32768,
"num_attention_heads": 32,
"seq_length": 2048,
"max_position_embeddings": 2048,
@@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
2 changes: 1 addition & 1 deletion configs/llama/train_config.yml
@@ -70,5 +70,5 @@
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
"mlp_multiple_of": 256,

}
1 change: 1 addition & 0 deletions configs/llama2/13B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 40,
"hidden_size": 5120,
"intermediate_size": 41472,
"num_attention_heads": 40,
"seq_length": 4096,
"max_position_embeddings": 4096,
2 changes: 1 addition & 1 deletion configs/llama2/70B.yml
@@ -6,7 +6,7 @@
# model settings
"num_layers": 80,
"hidden_size": 8192,
"intermediate_size": 28672,
"intermediate_size": 86016,
"num_attention_heads": 64,
"num_kv_heads": 8,
"seq_length": 4096,
1 change: 1 addition & 0 deletions configs/llama2/7B.yml
@@ -6,6 +6,7 @@
# model settings
"num_layers": 32,
"hidden_size": 4096,
"intermediate_size": 32768,
"num_attention_heads": 32,
"seq_length": 4096,
"max_position_embeddings": 4096,
23 changes: 23 additions & 0 deletions configs/neox_arguments.md
@@ -843,6 +843,29 @@ Model Arguments
- **dim_att**: int
Default = None
Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
- **head_size**: int
Default = None
Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.
- **ffn_dim**: int
Default = None
Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
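As a quick worked example of these derived RWKV arguments (hypothetical values, not from any shipped config):
```
hidden_size = 2048
num_attention_heads = 16

dim_att = hidden_size                       # defaults to hidden_size when unset
head_size = dim_att // num_attention_heads  # 2048 // 16 = 128
```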
## NeoXArgsOptimizer
Optimizer Arguments
17 changes: 17 additions & 0 deletions configs/prof.yml
@@ -0,0 +1,17 @@
# Sample profiling config
{
# Turns on nsys and pytorch profiling
"profile": true,

# pytorch profiler options
"profile_step_start": 10,
"profile_step_stop": 12,

# pytorch memory profiler options
"memory_profiling": true,
"memory_profiling_path": tensorboard,


# All trace files (pytorch, nsys, tensorboard, etc) will be written here
"tensorboard_dir": "tensorboard",
}
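One way to exercise this sample config, assuming a model config named `your_model.yml` (a placeholder, not a file added by this commit), is to append it to the nsys launch command from the README section above:
```
nsys profile -s none -t nvtx,cuda -o <path/to/profiling/output> --force-overwrite true \
$TRAIN_PATH/train.py --conf_dir configs your_model.yml prof.yml
```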
71 changes: 25 additions & 46 deletions megatron/model/transformer.py
@@ -750,9 +750,13 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
rpe = self.rpe(query_layer.size(0), key_layer.size(0))
else:
rpe = None
return self.sparse_attn(
attn_scores = self.sparse_attn(
query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe
)
# apply dropout
if self.training:
attn_scores = self.attention_dropout(attn_scores)
return attn_scores

def gqa_project(self, hidden_states, attention_mask, layer_past=None):
# QKV projection and separation into separate Q/K/V layers for GQA,
@@ -763,51 +767,16 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
# pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# First: reshape so we have seqlen, batch, and num. query heads each as separate dims
# Final dim is not exactly head dim: the first (head dim) dims are query heads,
# The last (head dim * ratio of kv to q heads) each are the "k/v heads"
# (right now we treat like we have same num. heads, but smaller head dim)

# [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))]
new_qkv_shape = (
mixed_x_layer.shape[0],
mixed_x_layer.shape[1],
self.num_attention_heads_per_partition,
int(
self.hidden_size_per_attention_head
* (
1
+ 2
* (
self.num_kv_heads_per_partition
/ self.num_attention_heads_per_partition
)
)
),
)
mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape)

# Next: split our fake head dim. (last dim) so that the first (head dim) dimensions go to Q,
# the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately
# split the last dim, so that the first (q head * head dim) dimensions go to Q,
# the last smaller 2 * (kv head * head dim) each divided between K and V separately
split_sizes = (
self.hidden_size_per_attention_head,
int(
(
self.num_kv_heads_per_partition
/ self.num_attention_heads_per_partition
)
* self.hidden_size_per_attention_head
),
int(
(
self.num_kv_heads_per_partition
/ self.num_attention_heads_per_partition
)
* self.hidden_size_per_attention_head
),
self.num_attention_heads_per_partition
* self.hidden_size_per_attention_head,
self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
)

# [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))]
# [sq, b, ((np + 2 * kvp) * hn)] --> 1 x [sq, b, np * hn] , 2 x [sq, b, kvp * hn]
(query_layer, key_layer, value_layer) = [
x.contiguous()
for x in torch.split(
@@ -817,6 +786,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
)
]

# reshape Q to proper output shape (last dim = correct full "real" head size again)
# [sq, b, np * hn] --> [sq, b, np, hn]
new_query_shape = (
query_layer.size(0),
query_layer.size(1),
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)

query_layer = query_layer.view(*new_query_shape)

# reshape K/V to proper output shape (last dim = correct full "real" head size again)
# 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn]
new_kv_shape = (
@@ -1269,9 +1249,8 @@ def forward(self, x, attention_mask, layer_past=None):

with torch.enable_grad() if not self.eval else nullcontext():
if (
self.activation == "swiglu"
or self.num_experts > 1
and self.moe_type == "deepspeed"
mlp_bias == None,
self.num_experts > 1 and self.moe_type == "deepspeed",
):
# No dropout either
assert mlp_bias is None
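To make the reworked `gqa_project` split above concrete, here is a minimal standalone sketch with illustrative values (not code from the repository), showing how the fused QKV output is split into one query block and two smaller key/value blocks and then reshaped back to per-head form:
```
import torch

sq, b = 4, 2               # sequence length, batch size (illustrative)
np_, kvp, hn = 8, 2, 16    # query heads, kv heads, head dim per partition

# fused projection output: [sq, b, (np + 2 * kvp) * hn]
mixed_x_layer = torch.randn(sq, b, (np_ + 2 * kvp) * hn)

# split the last dim: the first np * hn columns go to Q, then kvp * hn each to K and V
split_sizes = (np_ * hn, kvp * hn, kvp * hn)
query_layer, key_layer, value_layer = [
    x.contiguous() for x in torch.split(mixed_x_layer, split_sizes, dim=-1)
]

# reshape so the last dim is the true head size again
query_layer = query_layer.view(sq, b, np_, hn)  # [sq, b, np, hn]
key_layer = key_layer.view(sq, b, kvp, hn)      # [sq, b, kvp, hn]
value_layer = value_layer.view(sq, b, kvp, hn)  # [sq, b, kvp, hn]
```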
39 changes: 25 additions & 14 deletions megatron/neox_arguments/arguments.py
@@ -50,16 +50,19 @@
ATTENTION_TYPE_CHOICES,
)

### Logging colors ###
### ANSI escape codes ###
END = "\033[0m"
GREEN = "\033[92m"
RED = "\033[91m"
YELLOW = "\033[93m"
END = "\033[0m"
SUCCESS = f"{GREEN} [SUCCESS] {END}"
OKAY = f"{GREEN}[OKAY]{END}"
WARNING = f"{YELLOW}[WARNING]{END}"

### Formatted logging prefixes ###
ERROR = f"{RED}[ERROR]{END} "
FAIL = f"{RED}[FAIL]{END}"
INFO = "[INFO]"
OKAY = f"{GREEN}[OKAY]{END}"
SUCCESS = f"{GREEN} [SUCCESS] {END}"
WARNING = f"{YELLOW}[WARNING]{END}"

# ZERO defaults by deepspeed
# These values should not be changed unless defaults in deepspeed are changed
@@ -953,12 +956,19 @@ def calculate_derived(self):
)

# derive precision
fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts"
if self.fp16 and self.fp16.get("enabled", False):
if self.precision is None:
self.update_value("precision", "fp16")
else:
fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts"
assert self.precision == "fp16", fp16_conflict

if self.bf16 and self.bf16.get("enabled", False):
if self.precision is None:
self.update_value("precision", "bfloat16")
else:
bf16_conflict = "DeepSpeed bf16 field was set but precision conflicts"
assert self.precision == "bfloat16", bf16_conflict

if self.precision == "fp16":
if isinstance(self.fp16, dict) and len(self.fp16) > 0:
@@ -968,14 +978,15 @@
fp16_args = {"type": "fp16", "enabled": True}
self.update_value("fp16", fp16_args)
elif self.precision == "bfloat16":
bf_config = {"bf16": {"enabled": True}}
# dt_config = {"grad_accum_dtype": "fp32"}
if self.deepspeed_extra_args is None:
self.update_value("deepspeed_extra_args", bf_config)
else:
extra_args = copy.deepcopy(self.deepspeed_extra_args)
extra_args.update(bf_config)
self.update_value("deepspeed_extra_args", extra_args)
if not self.bf16:
bf_config = {"bf16": {"enabled": True}}
# dt_config = {"grad_accum_dtype": "fp32"}
if self.deepspeed_extra_args is None:
self.update_value("deepspeed_extra_args", bf_config)
else:
extra_args = copy.deepcopy(self.deepspeed_extra_args)
extra_args.update(bf_config)
self.update_value("deepspeed_extra_args", extra_args)

zero_stage = self.zero_optimization["stage"]
if self.data_types is None:
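A minimal config sketch (hypothetical, not part of this commit) consistent with the new bf16 handling in `calculate_derived` above:
```
{
  # With the DeepSpeed bf16 block enabled and "precision" left unset,
  # calculate_derived() now derives precision = "bfloat16" automatically.
  "bf16": {
    "enabled": true,
  },
  # "precision": "bfloat16",  # optional; if set, it must agree with bf16 or the new assert fires
}
```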