Skip to content

Commit

Permalink
new awq kernels paths (#2572)
Browse files Browse the repository at this point in the history
* new awq kernels paths
  • Loading branch information
vince62s authored Mar 18, 2024
1 parent 39c984f commit 3f0c5f7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion onmt/modules/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def forward(self, x):
y = torch.empty_like(x)
for i, expert in enumerate(self.experts):
if torch.any(flat_expert_indices == i):
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
y[flat_expert_indices == i] = expert(
x[flat_expert_indices == i].unsqueeze(0)
)
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
dim=1
)
Expand Down
10 changes: 5 additions & 5 deletions onmt/modules/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import torch.nn as nn

try:
import awq_inference_engine
import awq_ext

AWQ_INFERENCE_ENGINE = True
AWQ_EXT = True
except ImportError:
AWQ_INFERENCE_ENGINE = False
AWQ_EXT = False


class RMSNorm(torch.nn.Module):
Expand All @@ -24,12 +24,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
self.weight = nn.Parameter(torch.ones(hidden_size))

def forward(self, hidden_states):
if AWQ_INFERENCE_ENGINE and not self.training:
if AWQ_EXT and not self.training:
inp_type = hidden_states.dtype
output = torch.empty_like(hidden_states).to(inp_type)
if hidden_states.dim() == 2: # patch for multi experts
hidden_states = hidden_states.unsqueeze(0)
awq_inference_engine.layernorm_forward_cuda(
awq_ext.layernorm_forward_cuda(
hidden_states.half(), self.weight.half(), output.half(), self.eps
)
if hidden_states.dim() == 2: # patch for multi experts
Expand Down

0 comments on commit 3f0c5f7

Please sign in to comment.