ds_mlp.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.utils.types import GATED_ACTIVATION_TYPES
from deepspeed.accelerator import get_accelerator
from .op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp

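
# DeepSpeedMLP implements the feed-forward (MLP) block of a transformer layer
# for DeepSpeed inference: an optional post-attention LayerNorm, the
# intermediate GEMM plus activation (optionally gated), the output GEMM, the
# residual/bias add, and an all-reduce when the weights are partitioned across
# a tensor-parallel (mp) group.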
class DeepSpeedMLP(nn.Module):
    _inter_w_buffers = []

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
        super(DeepSpeedMLP, self).__init__()

        self.config = config

        data_type = torch.int8 if self.config.dtype == torch.int8 else self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        device = get_accelerator().current_device_name()
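
        # Gated activations (e.g. gated GELU / gated SiLU) need both an "up" and
        # a "gate" projection, so the intermediate weight is twice as wide per
        # partition.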
        proj_factor = 2 if self.config.mlp_act_func_type in GATED_ACTIVATION_TYPES else 1
        self.config.intermediate_size = self.config.intermediate_size if self.config.intermediate_size > 0 else 4 * self.config.hidden_size
        self.intm_w_sz_per_partition = self.config.intermediate_size * proj_factor // self.config.mp_size
        self.intm_o_sz_per_partition = self.config.intermediate_size // self.config.mp_size
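
        # With set_empty_params the parameters stay None here; they are expected
        # to be attached later (e.g. by DeepSpeed's module-injection containers).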
        if self.config.set_empty_params:
            self.attn_nw = None
            self.attn_nb = None
            self.inter_w = None
            self.inter_b = None
            self.inter_up_w = None
            self.inter_up_b = None
            self.inter_gate_w = None
            self.inter_gate_b = None
            self.output_w = None
            self.output_b = None
        else:
            self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                    self.intm_w_sz_per_partition,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)
            self.inter_b = nn.Parameter(torch.empty(self.intm_w_sz_per_partition, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            self.output_w = nn.Parameter(torch.empty(self.intm_o_sz_per_partition,
                                                     self.config.hidden_size,
                                                     dtype=data_type,
                                                     device=device),
                                         requires_grad=False)
            self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                         requires_grad=False)

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
        self.merge_count = int(math.log2(merge_count))
        self.mp_group = mp_group
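
        # Bind the fused inference ops used below (MLP GEMM, vector matmul,
        # GEMM+GeLU, residual add); the op_binding layer decides which kernel
        # implementation backs each one.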
        self.mlp_gemm_func = MLPGemmOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
        self.fused_gemm_gelu = GELUGemmOp(config)
        self.residual_add_func = ResidualAddOp(config)

        if len(DeepSpeedMLP._inter_w_buffers) == 0:
            DeepSpeedMLP._inter_w_buffers = [
                torch.empty(self.intm_w_sz_per_partition, self.config.hidden_size, dtype=data_type, device=device),
                torch.empty(self.intm_w_sz_per_partition, dtype=data_type_fp, device=device)
            ]
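
    # Copies the separate up- and gate-projection weights (and biases, if present)
    # into the shared class-level buffers so the kernels see a single fused
    # intermediate weight; returns [merged_weight, merged_bias].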
    def _merge_inter_w(self):
        inter_w = DeepSpeedMLP._inter_w_buffers[0]
        inter_w[:self.intm_w_sz_per_partition // 2, :] = self.inter_up_w  # type: ignore
        inter_w[self.intm_w_sz_per_partition // 2:, :] = self.inter_gate_w  # type: ignore

        if self.inter_up_b is not None:
            inter_b = DeepSpeedMLP._inter_w_buffers[1]
            inter_b[:self.intm_w_sz_per_partition // 2] = self.inter_up_b  # type: ignore
            inter_b[self.intm_w_sz_per_partition // 2:] = self.inter_gate_b  # type: ignore

        return DeepSpeedMLP._inter_w_buffers
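
    # Forward pass: pick the merged (up+gate) or plain intermediate weight, run
    # either the fused GEMM+activation path (no post-attention LayerNorm params)
    # or the fused LayerNorm+MLP GEMM path, apply the residual/bias add, and
    # all-reduce across the tensor-parallel group if one is configured.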
    def forward(self, input, residual, residual_norm, bias):
        if self.inter_w is None:
            self._inter_w, self._inter_b = self._merge_inter_w()
        else:
            self._inter_w = self.inter_w
            self._inter_b = self.inter_b
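
        # If there are no post-attention LayerNorm parameters (attn_nw is None),
        # feed the already-normalized input through the fused GEMM + activation +
        # GEMM op; otherwise the fused MLP GEMM op applies the LayerNorm itself
        # (gamma=attn_nw, beta=attn_nb) and also returns a residual_add tensor
        # consumed below.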
        residual_add = None
        if self.attn_nw is None:
            output = self.fused_gemm_gelu(input=residual_norm,
                                          weight=self._inter_w,
                                          bias=self._inter_b,
                                          weight_out=self.output_w)
        else:
            output, residual_add = self.mlp_gemm_func(input=input,
                                                      residual=residual,
                                                      weight_interm=self._inter_w,
                                                      weight_out=self.output_w,
                                                      input_bias=bias,
                                                      bias=self._inter_b,
                                                      gamma=self.attn_nw,
                                                      beta=self.attn_nb)
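
        # Combine the MLP output with the residual stream, the attention output,
        # and the relevant biases in a single fused residual-add op.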
        residual = self.residual_add_func(hidden_state=output,
                                          residual=residual,
                                          add_bias=bias is not None,
                                          attention_output=input,
                                          attention_bias=bias if bias is not None else self.output_b,
                                          final_bias=self.output_b,
                                          residual_add=residual_add)
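
        # With tensor (model) parallelism, each rank holds a slice of the output
        # projection, so the partial results are summed across the mp group.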
        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(residual, group=self.mp_group)

        return residual