Marlin Mixed Input Kernel Productionization (#3008)
Summary:
Pull Request resolved: #3008

X-link: facebookresearch/FBGEMM#103

This diff gives our [Marlin](https://github.com/IST-DASLab/marlin) BF16 x INT4 kernels a substantial facelift. The improvements include:

* Upgrades the kernel with the latest improvements from VLLM, which helps quite a bit with stability and fixes issues with group scaling.
* Adds template specializations so that the Marlin kernel supports both BF16 and FP16 with a single implementation.
* Fixes a BF16 dequantization issue.
* Exposes a simplified torch custom op `torch.ops.marlin.marlin_gemm` and convenient helpers for quantizing to the Marlin format, e.g. `marlin_quantize` (see the usage sketch after this list).
* Adds these new ops to our quantize benchmarks.
* Adds new tests and a better directory structure.
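
A minimal usage sketch of the new surface (shapes, dtypes, and the `[N, K]` weight layout are illustrative; the two calls mirror the benchmark op added in this diff):

```python
import torch
from marlin.quantize import marlin_quantize  # internal-only helper

torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")

# Illustrative problem size: BF16 activations and weights on a CUDA device.
M, N, K = 16, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# marlin_quantize expects weights in [K, N] layout; 128 is the quantization group size.
_, wq, scale = marlin_quantize(w.t().contiguous(), 128)

# Mixed-input GEMM: BF16 activations against Marlin-packed INT4 weights.
y = torch.ops.marlin.marlin_gemm(x, wq, scale)
```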

One downside of this work is that we have diverged a bit from VLLM, so it may be harder to stay in sync going forward. However, I think the benefits of the improvements in this diff outweigh the potential sync costs.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D61408771

fbshipit-source-id: 66b651ce794309a408f30244cac20a3c9ab0ce5a
jwfromm authored and facebook-github-bot committed Aug 20, 2024
1 parent e19bde1 commit 162cc69
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -24,6 +24,16 @@
if torch.cuda.is_available() and torch.version.cuda:
    torch.ops.load_library("//tinygemm:tinygemm")

# Marlin is currently only supported internally at Meta.
try:
    from marlin.quantize import marlin_quantize

    torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
    MARLIN_ENABLED = True
except ImportError:
    MARLIN_ENABLED = False


quantize_op_registry = []


@@ -670,3 +680,36 @@ def hip(self) -> bool:
    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class MarlinBF16I4(QuantizeOpBase):
    """
    Mixed Precision BF16 Activations with Int4 Weights using Marlin.
    """

    def quantize(self, x, w):
        # Marlin quantize expects weights in [K, N] layout.
        _, wq, scale = marlin_quantize(w.t().contiguous(), 128)
        return x, wq, scale

    def compute(self, x, wq, scale):
        return torch.ops.marlin.marlin_gemm(x, wq, scale)

    def quantize_and_compute(self, x, w):
        x, wq, scale = self.quantize(x, w)
        return self.compute(x, wq, scale)

    @property
    def name(self) -> str:
        return "marlin_bf16i4"

    @property
    def hip(self) -> bool:
        # Marlin is only supported on CUDA.
        return False

    @property
    def cuda(self) -> bool:
        # Only enabled when the Marlin extension imported successfully.
        return MARLIN_ENABLED
