moe.py
import dataclasses
from typing import List

import torch
import torch.nn.functional as F
from simple_parsing.helpers import Serializable
from torch import nn


@dataclasses.dataclass
class MoeArgs(Serializable):
    num_experts: int
    num_experts_per_tok: int


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Router scores: one logit per expert for every token, shape (num_tokens, num_experts).
        gate_logits = self.gate(inputs)
        # Keep only the top-k experts per token, along with their logits.
        weights, selected_experts = torch.topk(
            gate_logits, self.args.num_experts_per_tok
        )
        # Normalize the selected logits into per-token mixing weights.
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # Gather the tokens routed to expert i, run the expert, and
            # accumulate its output scaled by the corresponding gate weight.
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
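
A minimal usage sketch (not part of moe.py): it wires MoeLayer to toy nn.Linear experts and a linear gate so the routing can be run end to end. The sizes, the nn.Linear modules, and the variable names below are illustrative assumptions, not values from the repository.

# Illustrative only: toy sizes and nn.Linear stand-ins are assumptions, not repo code.
import torch
from torch import nn

dim, num_experts = 16, 4
args = MoeArgs(num_experts=num_experts, num_experts_per_tok=2)
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]  # stand-in expert networks
gate = nn.Linear(dim, num_experts, bias=False)               # router: one logit per expert
moe = MoeLayer(experts=experts, gate=gate, moe_args=args)

x = torch.randn(10, dim)   # forward expects tokens flattened to (num_tokens, dim)
y = moe(x)
print(y.shape)             # torch.Size([10, 16]) -- same shape as the input

Each token is processed by only its top-2 experts here, so compute per token stays roughly constant even as num_experts grows.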