fused linear cross entropy #134

Open. mayank31398 wants to merge 48 commits into main.

Commits (48)
8b743c0 cleanup (mayank31398, Feb 20, 2025)
138b4b8 cleanup (mayank31398, Feb 20, 2025)
78226dd cleanup (mayank31398, Feb 20, 2025)
74a74bc cleanup (mayank31398, Feb 28, 2025)
0dad48e cleanup (mayank31398, Feb 28, 2025)
471d165 cleanup (mayank31398, Feb 28, 2025)
4c9b6a0 cleanup (mayank31398, Feb 28, 2025)
894fd07 cleanup (mayank31398, Feb 28, 2025)
0ae39d0 cleanup (mayank31398, Feb 28, 2025)
0736f42 cleanup (mayank31398, Feb 28, 2025)
6b184c7 cleanup (mayank31398, Feb 28, 2025)
9c462a2 Merge branch 'main' into cleanup (mayank31398, Feb 28, 2025)
32c2e7c cleanup (mayank31398, Mar 1, 2025)
a4f82e2 cleanup (mayank31398, Mar 1, 2025)
faf623f cleanup (mayank31398, Mar 1, 2025)
2647eeb cleanup (mayank31398, Mar 1, 2025)
46bbf39 cleanup (mayank31398, Mar 1, 2025)
d025063 cleanup (mayank31398, Mar 1, 2025)
1b08c57 cleanup (mayank31398, Mar 1, 2025)
3e975b7 cleanup (mayank31398, Mar 1, 2025)
f2cc055 cleanup (mayank31398, Mar 1, 2025)
7a3ef56 cleanup (mayank31398, Mar 1, 2025)
fdc415e cleanup (mayank31398, Mar 1, 2025)
9b7bad5 cleanup (mayank31398, Mar 1, 2025)
f44be55 cleanup (mayank31398, Mar 1, 2025)
016db4a cleanup (mayank31398, Mar 1, 2025)
612dcef cleanup (mayank31398, Mar 1, 2025)
866d400 cleanup (mayank31398, Mar 1, 2025)
4dc80eb cleanup (mayank31398, Mar 1, 2025)
7e706e3 cleanup (mayank31398, Mar 1, 2025)
8e818ef cleanup (mayank31398, Mar 2, 2025)
20e6daa cleanup (mayank31398, Mar 2, 2025)
baf60fe cleanup (mayank31398, Mar 2, 2025)
93fe263 cleanup (mayank31398, Mar 2, 2025)
4ed518e cleanup (mayank31398, Mar 2, 2025)
5291ee8 cleanup (mayank31398, Mar 2, 2025)
28a12b7 cleanup (mayank31398, Mar 2, 2025)
2b11192 cleanup (mayank31398, Mar 2, 2025)
0c5fdc0 cleanup (mayank31398, Mar 3, 2025)
e45580e Merge branch 'main' into cleanup (mayank31398, Mar 4, 2025)
6ee1ec1 cleanup (mayank31398, Mar 4, 2025)
80978b3 cleanup (mayank31398, Mar 4, 2025)
f4856be cleanup (mayank31398, Mar 4, 2025)
632c94c cleanup (mayank31398, Mar 4, 2025)
101b20c cleanup (mayank31398, Mar 4, 2025)
5119703 cleanup (mayank31398, Mar 4, 2025)
032b7c8 cleanup (mayank31398, Mar 4, 2025)
65abd2f cleanup (mayank31398, Mar 4, 2025)
```diff
@@ -15,11 +15,6 @@ datasets:
 tokenizer_args:
   tokenizer_name: bigcode/starcoder

-# kernel_args:
-#   kernels:
-#     - cute_swiglu
-#     - cute_rmsnorm
-
 model_args:
   model_class: AutoModelForCausalLM
   pretrained_config:
```
2 changes: 1 addition & 1 deletion cute-kernels
Submodule cute-kernels updated 55 files
+4 −1 cute_kernels/__init__.py
+787 −0 cute_kernels/cache.yml
+21 −24 cute_kernels/cute_inductor/compiler.py
+1 −1 cute_kernels/cutotune/__init__.py
+38 −74 cute_kernels/cutotune/cache.py
+21 −28 cute_kernels/cutotune/tuner.py
+2 −0 cute_kernels/kernels/__init__.py
+0 −16 cute_kernels/kernels/add/add_scalar/forward.py->_forward.yml
+0 −16 cute_kernels/kernels/add/add_tensor/forward.py->_forward.yml
+11 −33 cute_kernels/kernels/continuous_count/__init__.py
+0 −169 cute_kernels/kernels/continuous_count/__init__.py->_continuous_count_cute.yml
+0 −60 cute_kernels/kernels/continuous_count/triton_implementation.py
+72 −0 cute_kernels/kernels/cross_entropy/__init__.py
+10 −0 cute_kernels/kernels/cross_entropy/torch_implementation.py
+135 −0 cute_kernels/kernels/cross_entropy/triton_implementation.py
+0 −19 cute_kernels/kernels/embedding/backward.py->_backward.yml
+0 −19 cute_kernels/kernels/embedding/forward.py->_forward.yml
+103 −0 cute_kernels/kernels/fused_linear_cross_entropy/__init__.py
+11 −0 cute_kernels/kernels/fused_linear_cross_entropy/torch_implementation.py
+0 −49 cute_kernels/kernels/gemm/__init__.py->gemm_cute.yml
+0 −61 cute_kernels/kernels/gemm/cuda_implementation/__init__.py->naive_gemm_cuda.yml
+0 −13 cute_kernels/kernels/gemm/cuda_implementation/__init__.py->shared_memory_gemm_cuda.yml
+2 −2 cute_kernels/kernels/gemm/cuda_implementation/naive_kernel.cu
+4 −4 cute_kernels/kernels/gemm/cuda_implementation/ops.cpp
+2 −2 cute_kernels/kernels/gemm/cuda_implementation/shared_memory_kernel.cu
+2 −2 cute_kernels/kernels/gemm/triton_implementation.py
+0 −97 cute_kernels/kernels/gemm/triton_implementation.py->gemm_triton.yml
+7 −1 cute_kernels/kernels/rmsnorm/backward.py
+0 −13 cute_kernels/kernels/rmsnorm/backward.py->_backward.yml
+0 −13 cute_kernels/kernels/rmsnorm/forward.py->_forward.yml
+5 −78 cute_kernels/kernels/rmsnorm/triton_implementation/backward.py
+0 −409 cute_kernels/kernels/rmsnorm/triton_implementation/backward.py->rmsnorm_backward_triton.yml
+0 −409 cute_kernels/kernels/rmsnorm/triton_implementation/forward.py->rmsnorm_forward_triton.yml
+6 −0 cute_kernels/kernels/softmax/__init__.py
+2 −0 cute_kernels/kernels/softmax/backward.py
+0 −13 cute_kernels/kernels/softmax/backward.py->_backward.yml
+8 −1 cute_kernels/kernels/softmax/forward.py
+0 −13 cute_kernels/kernels/softmax/forward.py->_forward.yml
+8 −2 cute_kernels/kernels/softmax/torch_implementation.py
+9 −1 cute_kernels/kernels/softmax/triton_implementation/backward.py
+0 −16 cute_kernels/kernels/softmax/triton_implementation/backward.py->softmax_backward_triton.yml
+15 −12 cute_kernels/kernels/softmax/triton_implementation/forward.py
+0 −16 cute_kernels/kernels/softmax/triton_implementation/forward.py->softmax_forward_triton.yml
+0 −16 cute_kernels/kernels/swiglu/backward.py->_backward.yml
+0 −16 cute_kernels/kernels/swiglu/forward.py->_forward.yml
+0 −19 cute_kernels/kernels/swiglu_unchunked/backward.py->_backward.yml
+0 −19 cute_kernels/kernels/swiglu_unchunked/forward.py->_forward.yml
+1 −1 cute_kernels/utils/__init__.py
+8 −2 cute_kernels/utils/custom_op.py
+45 −0 examples/cute_inductor.py
+0 −12 tests/kernels/continuous_count_test.py
+47 −0 tests/kernels/cross_entropy_test.py
+60 −0 tests/kernels/fused_linear_cross_entropy_test.py
+15 −3 tests/kernels/softmax_test.py
+14 −10 tools/build_cutotune_cache.py
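
The submodule bump is where the new kernels themselves land (note the new `cross_entropy` and `fused_linear_cross_entropy` kernel directories and their torch reference implementations). For orientation, here is a hedged sketch of what the torch reference for plain cross entropy plausibly looks like, inferred from the call signature used in `loss.py` below; the function name and body are assumptions, not the submodule's actual code:

```python
import torch
import torch.nn.functional as F


def cross_entropy_torch(
    x: torch.Tensor,        # (num_tokens, vocab_size) logits
    labels: torch.Tensor,   # (num_tokens,) with -100 at ignored positions
    reduction: str = "mean",
    logits_multiplier: float = 1,
) -> torch.Tensor:
    # upcast to fp32 for numerical stability, apply the logit scale, then standard CE
    return F.cross_entropy(x.float() * logits_multiplier, labels, reduction=reduction)
```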
6 changes: 4 additions & 2 deletions dolomite_engine/enums.py
```diff
@@ -74,7 +74,9 @@ class ExperimentsTrackerName(Enum):


 class Kernel(Enum):
-    cute_rmsnorm = "cute_rmsnorm"
-    cute_swiglu_unchunked = "cute_swiglu_unchunked"
+    rmsnorm_cute = "rmsnorm_cute"
+    swiglu_unchunked_cute = "swiglu_unchunked_cute"
     mamba2_ssm = "mamba2_ssm"
     scattermoe = "scattermoe"
+    cross_entropy_cute = "cross_entropy_cute"
+    fused_linear_cross_entropy_cute = "fused_linear_cross_entropy_cute"
```
3 changes: 2 additions & 1 deletion dolomite_engine/finetune.py
```diff
@@ -13,6 +13,7 @@
 from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
 from .enums import DatasetSplit, Mode, TuningMethod
 from .hf_models import disable_generation_cache
+from .kernels import enable_kernels
 from .model_wrapper import ModelWrapper, get_model_container
 from .optimization import get_optimizer_container, get_scheduler_container
 from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics
@@ -395,7 +396,7 @@ def main() -> None:
     experiments_tracker.log_args(args)

     # main training loop
-    with disable_generation_cache():
+    with disable_generation_cache(), enable_kernels(args.kernel_args.kernels):
         train(
             args,
             model_container=model_container,
```
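
For readers unfamiliar with the gating used here: `enable_kernels` installs the allow-list that `is_kernel_allowed` consults inside the model code. A minimal sketch of that pattern, assuming a module-level set; the real `dolomite_engine.kernels` implementation is not shown in this diff and may differ:

```python
from contextlib import contextmanager


_ALLOWED_KERNELS: set = set()


@contextmanager
def enable_kernels(kernels: list):
    # install the allow-set for the duration of the block, then restore the old one
    global _ALLOWED_KERNELS
    previous = _ALLOWED_KERNELS
    _ALLOWED_KERNELS = set(kernels)
    try:
        yield
    finally:
        _ALLOWED_KERNELS = previous


def is_kernel_allowed(kernel) -> bool:
    return kernel in _ALLOWED_KERNELS
```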
77 changes: 55 additions & 22 deletions dolomite_engine/hf_models/loss.py
```diff
@@ -7,47 +7,80 @@
 from torch.distributed.tensor.parallel import loss_parallel

 from ..distributed import tensor_to_dtensor
-from ..utils import ProcessGroupManager
+from ..enums import Kernel
+from ..kernels import is_kernel_allowed
+from ..utils import ProcessGroupManager, is_cute_kernels_available
+
+
+if is_cute_kernels_available():
+    from cute_kernels import cross_entropy_cute, fused_linear_cross_entropy_cute


 def get_autoregressive_language_modeling_loss(
     lm_logits: torch.Tensor,
     labels: torch.Tensor,
+    hidden_states: torch.Tensor | None = None,
+    vocab_weight: torch.Tensor | None = None,
+    logits_multiplier: float = 1,
     cu_seqlens: torch.Tensor | None = None,
     use_padding_free_transformer: bool = False,
     reduction: str = "mean",
+    shift_logits_and_labels: bool = True,
+    tensor_parallel_enabled: bool = False,
 ) -> torch.Tensor | DTensor:
-    if use_padding_free_transformer:
-        assert cu_seqlens is not None
+    if shift_logits_and_labels:
+        lm_logits = lm_logits[..., :-1, :]
+        labels = labels[..., 1:]

-        shift_logits = lm_logits[:-1, :]
-        shift_labels = labels[1:].to(shift_logits.device)
+    if use_padding_free_transformer:
+        if shift_logits_and_labels:
+            assert cu_seqlens is not None

-        # this is needed so that the last token of current example doesn't predict first token of next example
-        drop_loss_positions = cu_seqlens[1:-1] - 1
-        shift_labels[drop_loss_positions] = -100
+            # this is needed so that the last token of current example doesn't predict first token of next example
+            drop_loss_positions = cu_seqlens[1:-1] - 1
+            labels[drop_loss_positions] = -100
     else:
         assert cu_seqlens is None

-        # Shift so that tokens < n predict n
-        shift_logits = lm_logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+    if is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute):
+        assert lm_logits is None
+        assert not tensor_parallel_enabled

-    loss_context = nullcontext
+        loss = fused_linear_cross_entropy_cute(
+            x=hidden_states.reshape(-1, hidden_states.size(-1)),
+            weight=vocab_weight,
+            labels=labels.reshape(-1),
+            reduction=reduction,
+            logits_multiplier=logits_multiplier,
+        )
+    elif is_kernel_allowed(Kernel.cross_entropy_cute):
+        assert hidden_states is None
+        assert vocab_weight is None
+        assert not tensor_parallel_enabled

+        loss = cross_entropy_cute(
+            x=lm_logits.reshape(-1, lm_logits.size(-1)),
+            labels=labels.reshape(-1),
+            reduction=reduction,
+            logits_multiplier=logits_multiplier,
+        )
+    else:
+        assert logits_multiplier == 1
+        loss_context = nullcontext

-    if ProcessGroupManager.is_initialized() and ProcessGroupManager.is_tensor_parallel_enabled():
-        loss_context = loss_parallel
-        tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()
+        if tensor_parallel_enabled:
+            loss_context = loss_parallel
+            tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()

-        shift_logits = tensor_to_dtensor(shift_logits, device_mesh=tp_mesh, current_placement=Shard(-1))
-        shift_labels = tensor_to_dtensor(shift_labels, device_mesh=tp_mesh, current_placement=Replicate())
+            lm_logits = tensor_to_dtensor(lm_logits, device_mesh=tp_mesh, current_placement=Shard(-1))
+            labels = tensor_to_dtensor(labels, device_mesh=tp_mesh, current_placement=Replicate())

-        shift_logits = shift_logits.float()
+            lm_logits = lm_logits.float()

-    with loss_context():
-        loss = F.cross_entropy(
-            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction=reduction
-        )
+        with loss_context():
+            loss = F.cross_entropy(
+                input=lm_logits.reshape(-1, lm_logits.size(-1)), target=labels.reshape(-1), reduction=reduction
+            )

     return loss
```
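
Two things are worth noting in this rewrite. First, the padding-free masking now runs on the pre-shifted labels: with packed sequences and, for example, `cu_seqlens = [0, 3, 7]`, we get `drop_loss_positions = cu_seqlens[1:-1] - 1 = [2]`, which sets the label to -100 exactly where the last token of the first packed sequence would otherwise predict the first token of the second. Second, the fused path never materializes the full `(num_tokens, vocab_size)` logits tensor, which is where the memory saving comes from. A hedged, chunked PyTorch sketch of that idea follows; the function name, chunking strategy, and signature details are illustrative, and the real `fused_linear_cross_entropy_cute` fuses the projection and the loss, including the backward pass, in Triton:

```python
import torch
import torch.nn.functional as F


def fused_linear_cross_entropy_reference(
    x: torch.Tensor,        # (num_tokens, hidden_size) final hidden states
    weight: torch.Tensor,   # (vocab_size, hidden_size) lm_head weight
    labels: torch.Tensor,   # (num_tokens,) with -100 at ignored positions
    reduction: str = "mean",
    logits_multiplier: float = 1,
    chunk_size: int = 4096,
) -> torch.Tensor:
    # only a (chunk_size, vocab_size) slice of the logits exists at any moment
    losses = []
    for start in range(0, x.size(0), chunk_size):
        chunk = slice(start, start + chunk_size)
        logits = F.linear(x[chunk], weight).float() * logits_multiplier
        # reduction="sum" ignores -100 labels by default (ignore_index=-100)
        losses.append(F.cross_entropy(logits, labels[chunk], reduction="sum"))
    loss = torch.stack(losses).sum()
    if reduction == "mean":
        loss = loss / (labels != -100).sum()
    return loss
```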
1 change: 1 addition & 0 deletions dolomite_engine/hf_models/mixins/__init__.py
```diff
@@ -1,3 +1,4 @@
 from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin
 from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP
+from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .moe_TP import CausalLMMoEModelMixin_TP
```
4 changes: 1 addition & 3 deletions dolomite_engine/hf_models/mixins/dense/base.py
```diff
@@ -1,15 +1,14 @@
 import torch
 import torch.nn as nn
 from transformers import DynamicCache, PreTrainedModel
-from transformers.modeling_outputs import BaseModelOutputWithPast

 from ....utils import divide_if_divisible
 from ...cache import HybridMambaAttentionDynamicCache
 from ...config import CommonConfig
 from ...enums import PositionEmbeddingType
-from ...loss import clear_aux_loss
 from ...modeling_utils import ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function
 from ...utils import convert_padding_free_lists_to_tensors, is_generation_cache_enabled
+from ..modeling_outputs import BaseModelOutputWithPast


 class PreTrainedModelMixin(PreTrainedModel):
@@ -196,7 +195,6 @@ def forward(
             self._get_empty_cache(input_ids) if use_cache and past_key_values is None else past_key_values
         )

-        clear_aux_loss()
         mamba_mask = None
         mamba_mask_computed = False

```
45 changes: 34 additions & 11 deletions dolomite_engine/hf_models/mixins/dense/main.py
```diff
@@ -1,11 +1,13 @@
 import torch
 import torch.nn.functional as F
 from transformers import DynamicCache, GenerationMixin
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

+from ....enums import Kernel
+from ....kernels import is_kernel_allowed
 from ...config import CommonConfig
-from ...loss import get_autoregressive_language_modeling_loss, get_aux_loss
+from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
 from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear
+from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .base import PreTrainedModelMixin
@@ -39,8 +41,7 @@ def set_input_embeddings(self, value: ParameterizedEmbedding) -> None:
         self.transformer.wte = value

     def get_output_embeddings(self) -> ParameterizedLinear:
-        if not self._tied_word_embeddings:
-            return self.lm_head
+        return self.transformer.wte if self._tied_word_embeddings else self.lm_head

     def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None:
         if not self._tied_word_embeddings:
@@ -60,6 +61,8 @@ def forward(
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,
         reduction: str = "mean",
+        apply_output_projection: bool = True,
+        apply_logits_multiplier: bool = True,
     ) -> CausalLMOutputWithPast:
         assert return_dict
@@ -87,6 +90,8 @@ def forward(
         # position_ids -> None or (batch_size, key_length)
         # ==========================================================================================

+        clear_aux_loss()
+
         transformer_outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
             past_key_values=past_key_values,
@@ -99,19 +104,38 @@ def forward(
             max_seqlen=max_seqlen,
         )

-        lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+        lm_logits = None
+        loss = None

-        if self.m_width is not None:
-            lm_logits = lm_logits / self.m_width
+        if labels is None:
+            if apply_output_projection:
+                lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+
+            if apply_logits_multiplier and self.m_width is not None:
+                assert apply_output_projection
+                lm_logits = lm_logits / self.m_width
+        else:
+            assert apply_output_projection
+            assert apply_logits_multiplier
+            assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute)
+            assert not is_kernel_allowed(Kernel.cross_entropy_cute)
+
+            lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+
+            if self.m_width is not None:
+                lm_logits = lm_logits / self.m_width

-        loss = None
-        if labels is not None:
             loss = get_autoregressive_language_modeling_loss(
                 lm_logits=lm_logits,
                 labels=labels,
+                hidden_states=None,
+                vocab_weight=None,
+                logits_multiplier=1,
                 cu_seqlens=cu_seqlens,
                 use_padding_free_transformer=self._use_padding_free_transformer,
                 reduction=reduction,
+                shift_logits_and_labels=True,
+                tensor_parallel_enabled=False,
             )

         aux_loss = get_aux_loss()
@@ -125,8 +149,7 @@ def forward(
             loss=loss,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
+            last_hidden_state=transformer_outputs.last_hidden_state,
         )

     def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
```
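
Caller-side, the new flags let the trainer skip the vocab projection and feed hidden states straight into the fused loss. A hedged sketch of that call pattern; the wrapper that actually does this is not part of the hunks shown, so the manual shift and the `1 / m_width` multiplier below are assumptions based on the `loss.py` changes above:

```python
# run the model without materializing logits
output = model(
    input_ids=input_ids,
    return_dict=True,
    apply_output_projection=False,   # leaves output.logits as None
    apply_logits_multiplier=False,   # scaling moves into logits_multiplier below
)

# shift manually: the shift path in loss.py indexes lm_logits, which is None here
loss = get_autoregressive_language_modeling_loss(
    lm_logits=None,
    labels=labels[..., 1:],
    hidden_states=output.last_hidden_state[..., :-1, :],
    vocab_weight=model.get_output_embeddings().weight,
    logits_multiplier=1 if model.m_width is None else 1 / model.m_width,
    cu_seqlens=None,
    use_padding_free_transformer=False,
    reduction="mean",
    shift_logits_and_labels=False,
    tensor_parallel_enabled=False,
)
```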
2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/mixins/dense_TP/base.py
```diff
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 from transformers import DynamicCache
-from transformers.modeling_outputs import BaseModelOutputWithPast

 from ....utils import ProcessGroupManager, divide_if_divisible
 from ...config import CommonConfig
@@ -11,6 +10,7 @@
 from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP
 from ...utils import is_generation_cache_enabled
 from ..dense import BaseModelMixin, PreTrainedModelMixin
+from ..modeling_outputs import BaseModelOutputWithPast


 class PreTrainedModelMixin_TP(PreTrainedModelMixin):
```
16 changes: 11 additions & 5 deletions dolomite_engine/hf_models/mixins/dense_TP/main.py
```diff
@@ -3,7 +3,6 @@
 import torch
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from transformers import DynamicCache
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

 from ....distributed import dtensor_to_tensor, tensor_to_dtensor
 from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
@@ -12,6 +11,7 @@
 from ...loss import get_autoregressive_language_modeling_loss
 from ...modeling_utils_TP import LMHead_TP
 from ..dense import CausalLMModelMixin
+from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .base import PreTrainedModelMixin_TP
@@ -54,6 +54,8 @@ def forward(
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,
         reduction: str = "mean",
+        apply_output_projection: bool = True,
+        apply_logits_multiplier: bool = True,
     ) -> CausalLMOutputWithPast | torch.Tensor:
         assert return_dict
@@ -84,9 +86,12 @@ def forward(
         )

         if not self.is_pipeline_parallel_enabled or self.is_last_stage:
-            lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+            lm_logits = None

-            if self.m_width is not None:
+            if apply_output_projection:
+                lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+
+            if apply_logits_multiplier and self.m_width is not None:
                 lm_logits = lm_logits / self.m_width

         if not self.is_pipeline_parallel_enabled:
@@ -98,6 +103,8 @@ def forward(
                 cu_seqlens=cu_seqlens,
                 use_padding_free_transformer=self._use_padding_free_transformer,
                 reduction=reduction,
+                shift_logits_and_labels=True,
+                tensor_parallel_enabled=True,
             )

         if (not self.is_pipeline_parallel_enabled or self.is_last_stage) and not output_parallel_lm_logits:
@@ -110,8 +117,7 @@ def forward(
             loss=loss,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
+            last_hidden_state=transformer_outputs.last_hidden_state,
         )
     elif self.is_last_stage:
         output = lm_logits
```
18 changes: 18 additions & 0 deletions dolomite_engine/hf_models/mixins/modeling_outputs.py
```diff
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+import torch
+from transformers.modeling_outputs import ModelOutput
+
+
+@dataclass
+class BaseModelOutputWithPast(ModelOutput):
+    last_hidden_state: torch.Tensor | None = None
+    past_key_values: tuple[tuple[torch.Tensor]] | None = None
+
+
+@dataclass
+class CausalLMOutputWithPast(ModelOutput):
+    loss: torch.Tensor | None = None
+    logits: torch.Tensor | None = None
+    past_key_values: tuple[tuple[torch.Tensor]] | None = None
+    last_hidden_state: torch.Tensor | None = None
```
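
A note on why these replace the `transformers` output classes: subclassing `ModelOutput` keeps both attribute and dict-style access, so existing call sites reading `output.logits` keep working, while the slimmer fields drop the unused `hidden_states`/`attentions` tuples and add `last_hidden_state` for the fused-loss path. Quick illustrative usage, with placeholder values:

```python
import torch

# assuming the dataclasses above are importable from dolomite_engine.hf_models.mixins
output = CausalLMOutputWithPast(loss=None, logits=None, last_hidden_state=torch.zeros(4, 8))

print(output.last_hidden_state.shape)     # attribute access
print(output["last_hidden_state"].shape)  # ModelOutput also allows key access
```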