reinstate cuda rmsnorm (much faster in fp16/awq) + ct2 enc/dec config (#167)

* reinstate cuda rmsnorm (much faster in fp16/awq) + ct2 enc/dec config
* cast systematically to fp16 for rmsnorm cuda kernel
* comments in validator
vince62s authored Dec 19, 2024
1 parent 34e9c94 commit faa8917
Showing 4 changed files with 45 additions and 18 deletions.
14 changes: 11 additions & 3 deletions eole/bin/run/predict.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from eole.inference_engine import InferenceEnginePY, InferenceEngineCT2

from eole.constants import ModelType
from argparse import ArgumentParser
from eole.utils.misc import use_gpu, set_random_seed
from torch.profiler import profile, record_function, ProfilerActivity
@@ -11,13 +11,21 @@
from time import time


def model_type(config) -> ModelType:
    if config.decoder is None:
        return ModelType.ENCODER
    elif config.encoder is None:
        return ModelType.DECODER
    else:
        return ModelType.ENCODER_DECODER


def predict(config):
    set_random_seed(config.seed, use_gpu(config))

    if config.engine == "eole":
        engine = InferenceEnginePY(config)
    elif config.engine == "ct2":
        engine = InferenceEngineCT2(config, "decoder")
        engine = InferenceEngineCT2(config, model_type(config.model))
    else:
        raise ValueError("You need to use --engine with 'eole' or 'ct2'")
    _, _, _ = engine.infer_file()
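A minimal sketch (not part of the commit) of the engine-mode selection that the new model_type() helper enables for the CT2 engine. The FakeModelConfig namedtuple and the ModelType values below are assumptions made for illustration only; in eole the enum lives in eole.constants and the config is config.model.

from collections import namedtuple
from enum import Enum


class ModelType(Enum):  # assumed stand-in for eole.constants.ModelType
    ENCODER = "encoder"
    DECODER = "decoder"
    ENCODER_DECODER = "encoder_decoder"


# hypothetical minimal config: only the two attributes the helper inspects
FakeModelConfig = namedtuple("FakeModelConfig", ["encoder", "decoder"])


def model_type(config) -> ModelType:
    # same branching as in the diff above
    if config.decoder is None:
        return ModelType.ENCODER
    elif config.encoder is None:
        return ModelType.DECODER
    else:
        return ModelType.ENCODER_DECODER


assert model_type(FakeModelConfig(encoder=None, decoder=object())) is ModelType.DECODER
assert model_type(FakeModelConfig(encoder=object(), decoder=object())) is ModelType.ENCODER_DECODER

With an encoder-decoder model, InferenceEngineCT2 now receives the matching mode instead of the previously hard-coded "decoder".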
28 changes: 26 additions & 2 deletions eole/modules/rmsnorm.py
@@ -4,6 +4,14 @@
import torch.nn as nn


try:
    import awq_ext

    AWQ_EXT = True
except ImportError:
    AWQ_EXT = False


class RMSNorm(torch.nn.Module):
"""RMSNorm: https://arxiv.org/abs/1910.07467
Args:
@@ -17,9 +25,25 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile(dynamic=True)
    def forward(self, hidden_states):
    def compute_rms(self, hidden_states, dtype):
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states.to(self.weight.dtype)
        hidden_states = hidden_states.to(dtype)
        return hidden_states * self.weight

    def forward(self, hidden_states):
        inp_dtype = hidden_states.dtype
        if AWQ_EXT and not self.training:
            # cuda kernel support only fp16 - need to cast
            output = torch.empty_like(hidden_states).to(torch.float16)
            if hidden_states.dim() == 2:  # patch for multi experts
                hidden_states = hidden_states.unsqueeze(0)
            awq_ext.layernorm_forward_cuda(
                hidden_states.half(), self.weight.half(), output, self.eps
            )
            if hidden_states.dim() == 2:  # patch for multi experts
                output = output.unsqueeze(0)
            return output.to(inp_dtype)
        else:
            return self.compute_rms(hidden_states, inp_dtype)
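As an illustration of the "cast systematically to fp16" bullet, the sketch below (not part of the commit) re-derives the compute_rms() math in plain PyTorch and, only when a GPU and the optional awq_ext extension are present, compares it against the fp16 kernel using the same call signature as in the diff above. The shapes, random input, and tolerances are arbitrary choices for illustration.

import torch


def rms_norm_reference(x, weight, eps=1e-6):
    # same recipe as compute_rms(): statistics in fp32, then cast back
    orig_dtype = x.dtype
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)
    return x32.to(orig_dtype) * weight


hidden = 64
weight = torch.ones(hidden, dtype=torch.float16)
x = torch.randn(2, 5, hidden, dtype=torch.float16)
ref = rms_norm_reference(x, weight)

if torch.cuda.is_available():
    try:
        import awq_ext  # optional CUDA extension used by the fast path

        x_gpu, w_gpu = x.cuda(), weight.cuda()
        out = torch.empty_like(x_gpu)  # fp16 output buffer, as in the diff
        awq_ext.layernorm_forward_cuda(x_gpu, w_gpu, out, 1e-6)
        # fp16 kernel vs fp32-accumulated reference: close, not bit-identical
        print("kernel matches reference:", torch.allclose(out.cpu(), ref, atol=1e-2, rtol=1e-2))
    except ImportError:
        print("awq_ext not installed; only the PyTorch path was exercised")
else:
    print("no CUDA device; only the PyTorch path was exercised")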
4 changes: 3 additions & 1 deletion eole/predict/generator.py
@@ -6,6 +6,8 @@
from eole.predict.beam_search import BeamSearchLM
from eole.utils.misc import tile

# from time import time


class GeneratorLM(Inference):
    @classmethod
@@ -164,7 +166,7 @@ def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
# select indexes in model state/cache
self.model.decoder.map_state(lambda state, dim: state[select_indices])
# if step == 0:
# print("step0 time: ", time() - beg_time)
# print("step0 time: ", time() - beg_time)

if self.add_estimator:
# Prepare estimator input = decoder out of each pred with initial enc_out
17 changes: 5 additions & 12 deletions recipes/model-validator/run.sh
@@ -4,28 +4,21 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

# Define the models table
models=(
# Validated
"mistralai/Ministral-8B-Instruct-2410"
"mistralai/Mistral-7B-v0.3"
"mistralai/Mistral-7B-Instruct-v0.3"
"mistralai/Mistral-7B-Instruct-v0.2"
"mistralai/Mistral-7B-Instruct-v0.1"
"mistralai/Mistral-7B-v0.1"
"mistralai/Mathstral-7B-v0.1"
"meta-llama/Llama-3.2-1B"
"meta-llama/Llama-3.2-3B"
"meta-llama/Llama-3.2-3B-Instruct"
"meta-llama/Llama-3.2-1B-Instruct"
"meta-llama/Llama-3.1-8B"
"meta-llama/Llama-3.1-8B-Instruct"
"meta-llama/Meta-Llama-3-8B"
"meta-llama/Meta-Llama-3-8B-Instruct"
"meta-llama/CodeLlama-7b-hf"
"microsoft/Phi-3.5-mini-instruct"
"microsoft/Phi-3.5-MoE-instruct"
"microsoft/Phi-3-mini-4k-instruct"
"microsoft/Phi-3-mini-128k-instruct"
"microsoft/Phi-3-small-8k-instruct"
"microsoft/Phi-3-small-128k-instruct"
# to work on
# "mistralai/Mathstral-7B-v0.1" # fp32 !
# "microsoft/Phi-3.5-MoE-instruct" # convert_HF not set for PhiMoEForCausalLM
# "microsoft/Phi-3-small-128k-instruct" # tokenizer to be taken from another model
)

# Log file for errors
