QPyTorch Support (BFP quantization) #910
Comments
Here is the stack trace of the error I'm hitting:
Hi, this issue is not specific to Torch-MLIR. This is a general issue with TorchScript'ing the model, and it comes from not having a type annotation on the (note that the issue is not the assertion triggering, but a failure of the TorchScript compiler to compile the line of code which contains the assertion).

I then get this IR, which looks like a normal unquantized module... is there something I need to do to enable quantization on the module? (I applied your PR verbatim, except for adding
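To illustrate the kind of TorchScript failure described above, here is a minimal, hedged sketch (not QPyTorch's actual code): torch.jit.script assumes un-annotated arguments are Tensors, so a line that tests a string-valued argument can fail to compile even though the assert itself would never trigger at runtime.

```python
import torch

# Without an annotation, TorchScript types `rounding` as Tensor, so the
# membership test against a list of strings fails to compile (the assert
# itself never fires in eager mode).
def block_quantize_stub(x: torch.Tensor, rounding):
    assert rounding in ["nearest", "stochastic"], "invalid rounding mode"
    return x

# torch.jit.script(block_quantize_stub)  # fails in the TorchScript compiler

# Annotating the argument lets the same line compile.
def block_quantize_annotated(x: torch.Tensor, rounding: str):
    assert rounding in ["nearest", "stochastic"], "invalid rounding mode"
    return x

scripted = torch.jit.script(block_quantize_annotated)
```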
Hmm... it seems like an issue with torch.jit.trace
I will investigate how to debug this.

So, to summarize, the issues so far are not specific to Torch-MLIR -- we are having trouble even torch.jit.script/torch.jit.trace'ing it, which is a prerequisite for importing into Torch-MLIR.
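As a general debugging step (a minimal sketch with a placeholder model, not the QPyTorch one), checking scriptability and traceability in isolation separates TorchScript problems from Torch-MLIR problems:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example = torch.randn(4, 16)

scripted = torch.jit.script(model)        # surfaces TorchScript compiler errors
traced = torch.jit.trace(model, example)  # surfaces tracing errors
print(scripted.graph)                     # inspect the TorchScript IR directly
```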
Hi, so it looks like the source of the issue is that

As a next step, I would recommend that you define your ops with
Assigning back to @Svoch as the next step is on the QPyTorch side.
Thank you @silvasean! This is very helpful, and it makes sense why the op is not being picked up. Let me modify the registration method somehow and I'll update this issue with my findings.
I was able to modify QPyTorch operation bindings such that
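As a rough illustration of what dispatcher-level registration enables (the library path, namespace, and op schema below are hypothetical, not the actual QPyTorch bindings): once an op is registered with the PyTorch dispatcher rather than exposed only through pybind11, it is callable as torch.ops.&lt;namespace&gt;.&lt;op&gt; and visible to torch.jit.script.

```python
import torch

# Hypothetical path to a shared library that registers the op with the dispatcher.
torch.ops.load_library("build/libqtorch_ops.so")

def quantize(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical namespace, op name, and argument list; the real schema is
    # whatever the modified QPyTorch bindings declare.
    return torch.ops.qtorch.block_quantize(x, 8, -1, "nearest")

scripted = torch.jit.script(quantize)  # the custom op now resolves in TorchScript
```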
More details on steps to reproduce this error:

1. QPyTorch Modifications

The modifications for the QPyTorch operator bindings are present on the
2. The Results

Below are the logs from running the Torch-MLIR experiment script. The custom operators seem to be successfully built and linked, but as seen in the IR in the previous comment, the shape information for Ops like
Do I need to add additional info to enable shape inference for the custom Ops in Torch-MLIR?
You will want to do something like what Bob does in #895 to add the shape and dtype inference. That will need to be done in a fork of Torch-MLIR for now, but I would consider including QPyTorch support as first class if we can get to a good solution here for our customers (there seem to be a LOT of hardware vendors that would love to have this be really well supported, and upstream PyTorch doesn't seem to be providing a good solution, so I'm interested in incubating something in Torch-MLIR).

We discussed in the Torch-MLIR developer hour that one of the nod.ai folks was going to be building a PoC of QPyTorch lowering into TOSA. Was that you @nithinsubbiah that was going to work on it? I'm happy to provide architectural guidance here to get a really great PoC and deliver really good first-class support here.

cc @powderluv
Hi @silvasean, yes I'll work on this integration. Adding shape inference for this QPyTorch op and checking if that can lower to TOSA would be the first step, I think (please correct me if I'm wrong).
Sounds great @silvasean! This is a very exciting update on the custom Op support front, and it looks like very good timing. cc @rdadolf
You typically need to update RefineTypes.cpp too -- you can see all the steps here: https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation
Just to reiterate what Sean said: yes, the process for extensions is intended to follow the same 5-step process that Sean linked. There's a bit more information on the differences in this readme co-located with the example code. You should not need to write anything that looks like what's in
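For context, here is a hedged sketch of the shape-transfer function such an op needs, roughly following the shape-library conventions described in the wiki page linked above (the op name and signature are hypothetical, and the exact file and registration mechanism have shifted between torch-mlir versions); the dtype side (RefineTypes.cpp at the time) would likewise propagate the input element type unchanged.

```python
from typing import List

# Block quantization is elementwise, so the output shape equals the input shape.
# "qtorch" / "block_quantize" are hypothetical names; in torch-mlir's shape
# library the "〇" character stands in for "." in the op name.
def qtorch〇block_quantize(self: List[int], wl: int, dim: int, rounding: str) -> List[int]:
    return self
```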
Thank you @silvasean and @rdadolf for the details! Quick update: I kept hitting some KeyErrors in
Will update this issue with the findings; @nithinsubbiah is helping with this step.
That sounds accurate. That was changed recently in #915. I've been working through some of the bumps with Nithin on Discord, including his
Quick update: I registered one qtorch op with dtype and shape inference following the instructions in #895. I got the following Torch IR:
But lowering to TOSA fails with an exception:
Reviving discussion on this - I was able to use the custom op extension to register a qtorch op and did a rewrite from Torch -> TOSA as
What is the action item here, @nithinsubbiah?
@Svoch @nithinsubbiah can we close this issue?
I think we can close this issue. @silvasean - I do wonder how the Custom Ops support RFC affects this path, however; we can discuss it there for better visibility.
Hi folks! My team and I are looking into having compiler support for block floating point (BFP) in Torch-MLIR. Wondering what you think about extending the Torch-MLIR support for these cases. Below is a dummy test network I used as an experiment to compile a PyTorch model with BFP additions from qtorch via torch_mlir.

Input Description

The model is a basic MatMul followed by a BFP cast and a ReLU activation.

Observed Behaviour

I used the torch_mlir.compile() API to compile the module into TOSA IR. While the module seems to run forward propagation fine, the compilation seems to hit an assert in qtorch.quant_function.block_quantize() for not having a valid "rounding mode". Also, removing the BFP quantization from the forward propagation of the module yields a successful compilation. Lastly, if I quantize the input tensors of the module and then call torch_mlir.compile() on them, there doesn't seem to be any issue - are the casts optimized out in this case?

Script to Reproduce
For convenience, I made a draft PR with a minimal script to reproduce the issue I'm hitting here: #909
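For reference, a hedged sketch of what such a minimal repro might look like (this is not the script from #909; the qtorch and torch_mlir calls are assumptions based on their public APIs at the time):

```python
import torch
import torch_mlir
from qtorch.quant import block_quantize  # assumed import path for the BFP cast

class BfpMatMulRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(16, 16))

    def forward(self, x):
        y = torch.matmul(x, self.weight)
        y = block_quantize(y, wl=8, dim=-1, rounding="nearest")  # BFP cast
        return torch.relu(y)

model = BfpMatMulRelu().eval()
example = torch.randn(4, 16)
model(example)  # forward propagation runs fine in eager mode

# Compiling to TOSA is where the failure described above is reported.
compiled = torch_mlir.compile(model, example, output_type="tosa")
```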
FYI @silvasean