
Support FP8 constant #4222

Merged: htyu merged 4 commits into triton-lang:main on Jun 28, 2024
Conversation

htyu (Collaborator) commented Jun 27, 2024

To unblock the compilation of kernels like the one below, which doesn't operate arithmetically on FP8.

```
@triton.jit
def triton_poi_fused__scaled_mm__to_copy_constant_pad_nd_lift_fresh_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 400624
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 784
    x1 = (xindex // 784)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 769, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0 + (769*x1)), tmp2 & xmask, other=0.0)
    tmp4 = tmp3.to(tl.float8e4nv)
    tmp5 = tl.full(tmp4.shape, 0.0, tmp4.dtype)
    tmp6 = tl.where(tmp2, tmp4, tmp5)
    tl.store(out_ptr0 + (x2), tmp6, xmask)
```
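
For context, here is a minimal standalone sketch (not part of the PR) that distills the failing pattern: an FP8 constant produced by `tl.full` feeding `tl.where`, with no FP8 arithmetic. It assumes a CUDA device with Triton FP8 support and a PyTorch build that provides `torch.float8_e4m3fn`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fp8_constant_kernel(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(in_ptr + offs, mask=mask, other=0.0)
    x_fp8 = x.to(tl.float8e4nv)                    # fp32 -> fp8 conversion
    zero = tl.full(x_fp8.shape, 0.0, x_fp8.dtype)  # fp8 constant, previously rejected
    tl.store(out_ptr + offs, tl.where(mask, x_fp8, zero), mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fn)
fp8_constant_kernel[(1,)](x, out, 1024, BLOCK=1024)
```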

htyu requested a review from ptillet as a code owner on June 27, 2024 at 17:29.
htyu (Collaborator, Author) commented Jun 27, 2024

I'll add a lit test if this sounds like a reasonable fix.

htyu (Collaborator, Author) commented Jun 27, 2024

I was hitting

```
%2 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> f8E4M3FNUZ loc(#loc1)
```

error: 'llvm.mlir.constant' op result #0 must be LLVM dialect-compatible type, but got 'f8E4M3FNUZ'

ThomasRaoux (Collaborator) commented:

> I was hitting
>
> ```
> %2 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> f8E4M3FNUZ loc(#loc1)
> ```
>
> error: 'llvm.mlir.constant' op result #0 must be LLVM dialect-compatible type, but got 'f8E4M3FNUZ'

You're right, makes sense. Does the conversion to the LLVM dialect work for scalar constants?

htyu (Collaborator, Author) commented Jun 27, 2024

> You're right, makes sense. Does the conversion to the LLVM dialect work for scalar constants?

Good point. Probably not. Actually, the conversion being changed here is for scalar constants, which are then broadcast to a tensor in registers.

htyu (Collaborator, Author) commented Jun 27, 2024

Oh, I see your point. The original constant in TTGIR is a tensor:

```
%cst = arith.constant dense<0.000000e+00> : tensor<1024xf8E4M3FNUZ, #blocked> loc(#loc1)
```

I'll add support for scalar constants.

ThomasRaoux (Collaborator) commented:

> Good point. Probably not. Actually, the conversion being changed here is for scalar constants, which are then broadcast to a tensor in registers.

I believe we rely on the upstream pattern for scalar `arith.constant`. It would be worth checking whether it works; maybe we need to upstream a fix.

htyu (Collaborator, Author) commented Jun 27, 2024

Confirmed that scalar constant lowering to LLVM works:

```
%cst = arith.constant 0.000000e+00 : f8E4M3FNUZ
```

=>

```
%0 = llvm.mlir.constant(0.000000e+00 : f8E4M3FNUZ) : i8
```

Inline review comment on the diff changing the `binary_op_type_checking_impl` call:

```diff
-x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+# Bypass arithmetic type check for FP8 types where they are not supported.
+is_fp8 = x.type == y.type and x.type.is_fp8() and y.type.is_fp8()
+x, y = binary_op_type_checking_impl(x, y, builder, True, True, not is_fp8)
```
Review comment (Collaborator):

So fp8 would not get auto-promotion but other types would? That seems a bit odd. I would add support for fp8 in `binary_op_type_checking_impl` instead.

Reply (Collaborator, Author):

The current logic does not support fp8 because fp8 arithmetic is not available on the hardware (except for dot). Do I understand that correctly?

Also, I'm not sure we need promotion here: in the example kernel, fp8 is not used for arithmetic; rather, the kernel loads fp32, converts it to fp8, and conditionally stores it out.

Reply (Collaborator, Author):

Or do you think it's safe to not promote anything when `x.type == y.type`?

Reply (Collaborator):

I meant that, in general, if we have a mixed-mode op we promote to the highest format. We could do that for fp8 as well, right?

Reply (Collaborator, Author):

Oh yeah we should do that.
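
For illustration only, here is a tiny toy sketch of the "promote to the highest format" idea extended to fp8; the names and ranks below are made up, and the real logic lives in Triton's `binary_op_type_checking_impl`.

```python
# Toy promotion rule with hypothetical ranks: fp8 < fp16/bf16 < fp32 < fp64.
_FP_RANK = {"fp8e4nv": 0, "fp8e5": 0, "fp16": 1, "bf16": 1, "fp32": 2, "fp64": 3}

def promote(lhs: str, rhs: str) -> str:
    """Return the wider of two float formats; equal ranks keep the left-hand type."""
    return lhs if _FP_RANK[lhs] >= _FP_RANK[rhs] else rhs

assert promote("fp8e4nv", "fp32") == "fp32"  # mixed-mode op promotes upward
assert promote("fp8e5", "fp8e5") == "fp8e5"  # same-typed fp8 stays fp8, no promotion needed
```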

htyu (Collaborator, Author) commented Jun 27, 2024

Bypass arithmetic type check when inputs are same-typed.
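
Below is a rough, hypothetical sketch of that bypass, based only on the review excerpt and this description; it is not the actual committed diff, and the wrapper name is made up.

```python
# Hypothetical wrapper around the helper shown in the review excerpt above.
# Only request the arithmetic type check (and the implicit promotion it does)
# when the operand types differ; same-typed operands, including fp8, are
# passed through unchanged.
def check_where_operands(x, y, builder):
    arithmetic_check = x.type != y.type
    return binary_op_type_checking_impl(x, y, builder, True, True, arithmetic_check)
```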

htyu merged commit 938e388 into triton-lang:main on Jun 28, 2024. 6 checks passed.
Jokeren pushed a commit that referenced this pull request Jul 1, 2024
Jokeren added a commit that referenced this pull request Jul 3, 2024
Add a more meaningful check to make sure we are not merging blocks (#4186)

This is a follow-up to #4176 (comment).

I am now counting the number of blocks with (17) and without (31) block merging. I double-checked to make sure this does not pass when we use an aggressive region simplification strategy.

[AMD] Skip mfma layout in maybeDuplicate (#4170)

The workaround introduced in #4048 "forgot" to skip the mfma layout.

[TEST] Merge duplicate `max_num_imprecise_acc` tests and improve code (#4191)

[DOCS][NFC] Fix doc formatting problems (#4195)

1. f-strings cannot be used as docstrings in Python.
2. URLs should follow the reStructuredText format.
3. Code snippets in a code block should be indented.

Tested and passed on a local machine.

[BACKEND] Fix regression in pipeliner pre-checks. (#4196)

During some previous refactoring we changed the logic and started pipelining cases that had incompatible shared encodings. This was missed because one of the lit tests had not been updated :(

Remove tl.multiple_of call from tma persistent kernel (#4198)

[AMD] Guard against null in `BypassEpilogueSMEM` (#4203)

`val.getDefiningOp()` can return `nullptr`. In this case, we must fail
the `BypassEpilogueSMEM` rewrite pass for the given op. This prevents
run-time crashes.

[FRONTEND][NFC] Fix type checking, conditional logic, and loop structures for improved readability and performance (#4208)

Document TRITON_HOME (#4210)

Document the existence of `TRITON_HOME` environment variable.

The `TRITON_HOME` variable controls the location of the `.triton`
directory that stores, among other things, the files downloaded during a
`pip install -e python` virtualenv build. By default, this is located in
the user's home directory, at `~/.triton`.

I was trying to build Triton on my system on a large local disk, but with limited network home directory space, and the `pip` command kept failing with out-of-disk-space errors. It turned out that during installation, large files were downloaded to the `~/.triton` directory, causing the failure.

After checking that it was not `pip` doing this, I found the `TRITON_HOME` variable, which allowed me to work around the issue and build Triton successfully. After seconding #4007, I decided to contribute this documentation fix.

Co-authored-by: sree <sree@buckyball>

[BACKEND] Fix regression in i1 reduction (#4215)

Recent refactoring broke i1 shared memory load.

[BUILD] update URL for LLVM tarballs (#4216)

[BACKEND] Fix divisibility analysis for shift ops (#4221)

Divisibility does not ensure that a value is not 0; therefore we cannot use divisibility as a minimum for shifted values.

Support FP8 constant (#4222)


[INTERPRETER] Implement implicit tensor conversion for assignment operators (#4214)

ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 16, 2024
y-sq added a commit to y-sq/ao that referenced this pull request Sep 9, 2024
Summary:
Add test cases to verify that compiling with inner-padding works with the Triton PR triton-lang/triton#4222.

Before the Triton PR, the inductor code-gen kernel fails at
```
tmp10 = tl.where(tmp6, tmp8, tmp9)

TypeError: unexpected type fp8e5 and fp8e5
```

Reviewed By: irobert0126

Differential Revision: D62003827
y-sq added a commit to y-sq/ao that referenced this pull request Sep 10, 2024
y-sq added a commit to y-sq/ao that referenced this pull request Sep 10, 2024
y-sq added a commit to y-sq/ao that referenced this pull request Sep 16, 2024
Summary:
Pull Request resolved: pytorch#858

The diff modifies the `padding` option and adds tests with `compile`:

* For a scaled_mm of shape MxKxN, the current `inner_padding` option only pads the `K` dimension. However, if `N` is not divisible by 16, we also get the error
```
E       RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( ltHandle, computeDesc.descriptor(), Adesc.descriptor(), Bdesc.descriptor(), Cdesc.descriptor(), Ddesc.descriptor(), preference.descriptor(), 1, &heuristicResult, &returnedResult)`
```
So, the `pad_inner` option was modified to also pad the `K` dimensions.

-----
* The compile of inner-padding only works with the Triton PR triton-lang/triton#4222.

Before the Triton PR, the inductor code-gen kernel fails at
```
tmp10 = tl.where(tmp6, tmp8, tmp9)

TypeError: unexpected type fp8e5 and fp8e5
```

Reviewed By: irobert0126

Differential Revision: D62003827
quanta42 pushed a commit to quanta42/triton that referenced this pull request Nov 22, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024