Skip to content

Commit

Permalink
annotate triton constexpr #56
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianlim committed Jul 31, 2024
1 parent 80edec8 commit ba579e3
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _cross_entropy_backward(
pass


- MAX_FUSED_SIZE = 65536 # 2**16
+ MAX_FUSED_SIZE: tl.constexpr = 65536 # 2**16

class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
from .utils import calculate_settings

- ROPE_GROUP_SIZE = 4
+ ROPE_GROUP_SIZE: tl.constexpr = 4

@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
Expand Down
6 changes: 0 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ deps =
packaging # this is required for flash-attn dep as fms_hf_tuning did not specify
-e {toxinidir}/plugins/framework # install the framework here as the flash attention deps requires torch
passenv = * # will pass the parent env, otherwise there are too many envs e.g. TRANSFORMERS that need to be set
- setenv =
- # Need to be set in new versions of triton that don't allow for access to global variable in the JIT compile
- # Subsequently, consider changing triton kernels to access global variables that are annotated as constexpr
- # source: https://github.com/triton-lang/triton/blob/7b617bcc35c4cf06f61dd267fc049fe33b2851f9/python/triton/compiler/code_generator.py#L280
- # Tracking this as an issue here # https://github.com/foundation-model-stack/fms-acceleration/issues/56
- TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1
commands =
# need a version of fms-hf-tuning that has integrated the framework
# NOTE: have to install this first because it hasn't been merged yet
Expand Down

0 comments on commit ba579e3

Please sign in to comment.