RFC: Support pushing custom ops through backend-contract using torch.operator
#1959
Conversation
goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)
Same thing as the "classical" torch custom op registration; this `torch.jit.trace`s to:
graph(%self : __torch__.CustomOpExampleModule,
%a : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%4 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]()
%5 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %4)
%6 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = goofy::identity(%5)
return (%6)
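For reference, here is a self-contained sketch of the full flow. Only the three registration lines above come from the diff; the `identity` implementation, the module body, and the tracing call are filled in here as assumptions consistent with the traced graph shown above.

```python
import torch

# Hypothetical identity kernel backing the registered op.
def identity(t: torch.Tensor) -> torch.Tensor:
    return t

# Register the custom op under the "goofy" namespace, as in the diff above.
goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)

class CustomOpExampleModule(torch.nn.Module):
    def forward(self, a):
        # Matches the traced graph: aten::mul by 2 followed by goofy::identity.
        return torch.ops.goofy.identity(a * 2)

traced = torch.jit.trace(CustomOpExampleModule(), torch.rand(3, 4))
print(traced.graph)  # prints a graph containing goofy::identity, as shown above
```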
  if (!libFunc)
    return success();
} else {
  libFuncNamesUsed.push_back(libFuncName);
`libFuncNamesUsed` is ultimately used to "import" shape functions - since the user-provided shape functions are already in the user module, this isn't necessary (and causes a segfault somewhere around ReifyAbstractInterpCalculationsUtils.cpp#L159).
if (isa<OperatorOp>(op)) {
  auto opOp = cast<OperatorOp>(op);
  auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
  name_ = "operator." + opName.str();
Shape/dtype functions for `operator` ops should be namespaced one level deeper.
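A hedged illustration of the namespacing this produces: the `__torch_mlir_shape_fn.`/`__torch_mlir_dtype_fn.` prefixes are assumed from the existing abstract-interp library naming convention, not taken from this diff.

```python
# Sketch of how the lookup name for a torch.operator op could be derived,
# mirroring the C++ above: the op's "name" attribute is prefixed with
# "operator.", one level deeper than regular aten/prim ops.
def library_function_name(operator_name: str, kind: str) -> str:
    assert kind in ("shape", "dtype")
    namespaced = "operator." + operator_name   # e.g. "operator.goofy.identity"
    prefix = "__torch_mlir_%s_fn." % kind      # assumed library prefix
    return prefix + namespaced

print(library_function_name("goofy.identity", "shape"))
# -> __torch_mlir_shape_fn.operator.goofy.identity (assumed naming)
```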
Force-pushed from cc5e2c2 to 7e69413.
         resultType.isa<Torch::NoneType>() ||
         (resultType.isa<Torch::ListType>() && cast<Torch::ListType>(resultType)
                                                   .getContainedType()
                                                   .isa<Torch::IntType>());
Shape functions return `list<int>`.
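For context, shape functions in the shape library are written against `List[int]` shapes, so a user-provided one for the example op would return exactly the `list<int>` value being special-cased above. A minimal sketch, with an illustrative name only:

```python
from typing import List

# Minimal sketch of a shape function for the identity example: the input
# shape arrives as a List[int] and the result shape is returned as a
# List[int], i.e. a list<int> value at the Torch dialect level.
def goofy_identity_shape(t: List[int]) -> List[int]:
    return t
```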
This is potentially a way to support LLaMA 4-bit quant via Torch-MLIR.
Okay, but the approach proposed here works today with minimal expansion of the API surface. Is it possible to merge this in before the other roadmap is complete and then remove it afterwards? In fact, given this …, if we instead just simply add …
… in `ReduceOpVariants`
Force-pushed from ebbf9a5 to ae35567.
No, I don't think that's a wise engineering decision. Removing workarounds is usually 10x the work of adding them. It is better to push forward and complete the plan as originally specified. I think we will probably end up with something kind of similar to this in some aspects, but until we have migrated to the dtype functions there isn't much point, since we won't be able to compile real models (see the explanation in #1807 for why it happens on a branch). An equally big problem is how this feature is exposed to users -- the approach in this patch "happens to work", but does not rely on a supported API surface area.
And how about with the change I just pushed?
It doesn't fundamentally address any of the issues I mentioned.
It does; there is nothing here that won't have to look exactly the same to support your proposal - it just frontloads …
@@ -191,7 +191,10 @@ static bool isValidNonContainerResultType(Type resultType) {
          resultType.isa<Torch::FloatType>() ||
          resultType.isa<Torch::IntType>() ||
          resultType.isa<Torch::BoolType>() ||
-         resultType.isa<Torch::NoneType>();
+         resultType.isa<Torch::NoneType>() ||
+         (resultType.isa<Torch::ListType>() && cast<Torch::ListType>(resultType)
We shouldn't be doing this. This function was created with the goal of preventing something like a `ListType` return from reaching the backend contract. This would lead to invalid IR being generated.
That's fine, but if you want to enable user-provided shape and dtype functions in the same parent module then there needs to be special casing for them. The alternative is to provide some mechanism for passing handles to a `ModuleOp` all the way down into `wrapWithCalculateOpIfLibraryFunctionAvailable`.
User-provided shape and dtype functions will be handled exactly the same way that the current shape and dtype functions are handled. The plan is to load them from a `.mlir` file when the shape or dtype pass is taking place, then remove them once the pass is done. This is all outlined in the RFC: #1462. There shouldn't be any coupling between the shape and dtype pipelines and the rest of the passes in torch-mlir.
We will likely require all custom ops to be value-semantics. However, this does not mean we want all …

I agree with Sean's comments that it would be really appreciated if we could focus our work efforts on getting the custom op support RFC published a few months ago across the finish line rather than duplicating efforts.
If you check the first commit in this stack, this is literally exactly what I had:

static bool operatorOpHasValueSemantics(OperatorOp opOp) {
  if (!opOp->hasAttr("has_value_semantics"))
    return false;
  auto hasValueSemantics =
      opOp->getAttr("has_value_semantics").cast<BoolAttr>().getValue();
  return hasValueSemantics;
}

and then special case checks in the relevant passes. I'm happy to develop more safety and docs in order to get it to cross the threshold from "happens to work" to "supported API surface area."
Today torch-mlir lacks the ability to lower state-of-the-art quantized models (llama 4-bit, stable diffusion 8-bit, and other custom models). IIUC this PR allows for a pragmatic but short-term way to support 4-bit llama and 8-bit SD today for our customers. It is forward progress. We can continue to invest in bringing the best approach and can get dedicated resources to help with the right approach, but we do have to ship something and can't wait O(months), since customers will just bypass torch-mlir.
Sorry, this isn't what I meant. Users shouldn't have to modify the IR. We should have a very simple pass that checks if the …
I'm sorry, I'm not trying to be argumentative, but I'm not sure why … is an appropriate contract/model/invariant; can't (conceptually) non-valsem ops support dtype and shape refinement? Like, I don't see what type refinement even has to do (conceptually) with value semantics. And having to implement those functions isn't "hidden API surface" (combined with a whole pass), but … is hidden API surface? Or onerous? Or bad practice? There are ample examples of just this pattern in …
Not exactly. It's more like …

The RFC then makes the proposal of assuming …

Indeed, type refinement on its own has nothing to do with value semantics. Once again, this is a fundamental assumption of the RFC. Since it was posted several months ago, no member of the community has disagreed with this assumption in the design, but if you have concerns, we should definitely discuss them in the RFC.

Not sure I understand what your point is here. Implementing shape+dtype functions is the design proposed in the RFC, which so far no one has had issues with. It uses an approach people are already familiar with thanks to the shape library. Again, if you have design concerns, do bring them up on the RFC.
Okay folks, I want to take a step back here. For each custom op, we need shape/dtype functions and to know that it has value semantics. This will be specified in the library. The mechanism for specifying value semantics is TBD, but a simple way to do it is a dummy function …

The biggest part that this PR doesn't address is how users interact with the custom ops API, and that is the part that currently needs design work. Torch-MLIR's only supported "stable" interface is the …

I will point out that this custom ops RFC has been in place since October and was in large part developed for (and signed off by) the Torch-MLIR users in this thread asking for short-term workarounds. We have not seen contributions to the RFC by these users, not even at the baseline level of pings for "when will this be done? we need this", so the RFC has languished. We do not want to set the precedent that community members can deprioritize contributing and then expect short-term workarounds to be approved. That is not in line with LLVM community engineering practices and does not scale to a vibrant community upholding high engineering standards.

Going forward, we should be more intentional about indicating and following up on the relative priority and timelines for features that are expected to be needed, so that they can be ready at the right time. This is a reality of open-source development across a variety of industry partners -- unlike within a single monolithic organization, the community doesn't have a direct sense of the importance of different work, so it falls on everyone to drive/champion/follow up on the features that are important to them, especially when work is being done for them by other community members. (There are lots of times where users ask for things but then drop off the radar and end up not needing the feature they asked for, so slowing down work on things people don't actively ask for is "a feature, not a bug", though here it seems like the outcome wasn't great.)

Concretely, to move forward here, what would be useful is if we can do the following:

This is definitely an O(weeks) type thing and not O(months). We originally specced the RFC as a 1-2 month project and significant work has already happened, so expect less than that.
A few notes to consider:

1: Ouch. That sounds like "yeah, the community can do all the grunt/mechanical work; let Sean/Ramiro handle the architectural work". Besides, that is what we have been doing since September 2022, when we created this branch https://github.com/llvm/torch-mlir/tree/custom-op-example for existing custom op users and the RFC was floated for the new way forward. We are six months in, and so we see this PR as a pragmatic stopgap. We are happy to hash out what the compile API should look like.

2: Ouch again. As the largest contributor in terms of commits to torch-mlir, including the grunt work of keeping up to date with RollPyTorch, the CI, and LLVM rotations, we would like to know how we can do better. Besides, we don't think this is a short-term workaround - it accelerates the custom op support (happy to discuss any API surface changes); see point three below.

3: Most importantly, the existing dtypes port is not required for custom op support. All that said, like I offered in late January for dtype porting, we are happy to get someone to help on the mechanical dtype work, but we shouldn't conflate it as a mandatory requirement for custom op support, as proven by this PR (modulo API refinements).
I wanna quickly comment on this since I am one of the people most familiar with this issue, and I think it is important for guiding future decisions on this. This is a blocker for custom op support. Any op that has to be handled by the new shape+dtype inference pipeline is affected by the dtype function transition. The reason is explained in detail in #1807, but I will try to give a different explanation, since it seems to still be a point of a lot of confusion.

First, when we get to the type refinement level, custom ops and regular ops are essentially the same. We need a pass that goes to that op and uses the information from the inputs of the op to calculate the shape and dtype of the outputs. For dtypes, this currently can happen in two places: … If your graph contains some ops that are handled by … Because the information for dtypes flows in order down the graph, when …

Note: this has nothing to do with …
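To illustrate the "flows in order down the graph" point, here is a minimal sketch (not torch-mlir code, purely an assumed toy model) of forward dtype propagation: each op's dtype can only be computed once the dtypes of its operands are known, which is why ops handled in different places interact badly.

```python
from typing import Dict, List

# Toy forward propagation of dtypes through a graph in topological order.
# Each "op" carries a dtype_fn that computes its result dtype from its
# operand dtypes; the data structures here are purely illustrative.
def propagate_dtypes(ops: List[dict], input_dtypes: Dict[str, str]) -> Dict[str, str]:
    dtypes = dict(input_dtypes)  # value name -> dtype
    for op in ops:               # visited in graph order
        operand_dtypes = [dtypes[v] for v in op["operands"]]
        dtypes[op["result"]] = op["dtype_fn"](*operand_dtypes)
    return dtypes

# Example: %1 = mul(%a); %2 = goofy.identity(%1)
ops = [
    {"operands": ["a"], "result": "v1", "dtype_fn": lambda d: d},
    {"operands": ["v1"], "result": "v2", "dtype_fn": lambda d: d},
]
print(propagate_dtypes(ops, {"a": "f32"}))  # {'a': 'f32', 'v1': 'f32', 'v2': 'f32'}
```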
Okay, it's true that wasn't clear to me, so thank you for the clarification, but … is a perfectly conservative and explicit graceful fail for anyone attempting to use this PR (if it were merged) - users who are willing to pay the price will bump the …
@powderluv what is your specific timeline that you would like to have this done by? If 2-3 weeks is okay with your timeline, then let's work together, prioritize this, and implement it as we already had specced and planned. If this is needed sooner, then I really can't think of a way to move forward with this on the main branch while sticking to our engineering principles and overarching LLVM community practices.
Sorry, I didn't intend it to come off like "grunt work", but there is a separation between the more mechanical changes here and the parts that need a deeper analysis of the architectural implications, and Ramiro and I happen to be the ones with more extensive knowledge here. I am happy to help other folks who want to work on this, but since delivery time is the primary criterion here, having Ramiro and I work on it seems most practical.
I think that pinging PR's or RFC's that aren't moving but that are important would be my specific request, and giving expected timelines for things. This PR as-is is definitely a short-term workaround. Undoing these things is really, really difficult. Once you start punching layering/API holes things fall apart very quickly and repairing that is always 10x more work.
I agree, but the "modulo API refinements" is probably 2-3 weeks end-to-end time to implement. And as Ramiro said, you are likely to hit a significant compilation time issue if we don't port the dtype functions too, so it is kind of risky to try to ignore them (this could result in a 100x+ compilation time blowup for a model like llama with a lot of layers).
2-3 weeks is great. Let's split up the tasks and prioritize them. @gpetters94 is back on Friday and can help with dtype porting.
That's only a 100x (at worst) blowup in compile time from PyTorch to … I don't have SD or llama on hand, but I timed …
Hey folks, I've created another issue to shepherd this feature to completion at the level of concrete action items: #1963. I know some folks are on a deadline, and hopefully this will help show the remaining work and how on track we are, for transparency.
Pitch
I think lots of people want to be able to push opaque kernels through backend contract (#1519, #1947, #1514). Indeed, there is also a similarly themed proposal that goes through `backend-legal-ops` (and is specific to TOSA). I believe that there's also a need for supporting custom ops, i.e., ops that are effectively placeholders for some possible lowering/implementation on the other side of backend contract (the posterchild here is quantized ops).

Approach

The approach here is an adaptation of the other approach - we go through `backend-legal-ops` using `torch.operator`. The catch is that because `OperatorOp` doesn't possess the `HasValueSemantics` trait, `ReduceOpVariants` and `MaximizeValueSemantics` will stumble and ultimately backend contract won't be satisfied. Thus, we add `HasValueSemantics` to `torch.operator`. In addition, we extend `wrapWithCalculateOpIfLibraryFunctionAvailable` to support user-provided shape and dtype functions; in particular we provide these two (trivial) refinement functions: … With the example/demo here I am able to fully lower to …
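The two refinement functions themselves are not reproduced above; as a hedged sketch (names and signatures are assumptions, not the PR's actual code), "trivial" shape and dtype functions for the `goofy.identity` example could look roughly like this:

```python
from typing import List

# Identity-style shape refinement: output shape equals input shape.
def goofy_identity_shape(t: List[int]) -> List[int]:
    return t

# Identity-style dtype refinement: output dtype equals input dtype
# (dtypes are passed around as integer torch dtype codes in this sketch).
def goofy_identity_dtype(t_rank: int, t_dtype: int) -> int:
    return t_dtype
```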
Certainly there are other ways to do this, so I'm open to suggestions/advice.
cc @powderluv @qedawkins @AmosLewis