Shrink the Gather's output shape #5

Merged: 1 commit into fdwr:dml_sd on May 17, 2023
Conversation

mingmingtasd

Shrink the Gather's output shape to avoid a crash in AddGemm(). For example:

A Gather output shape of {1, 1, 256} is shrunk to the 2-D shape {1, 256}, since Gemm currently supports only 2-D inputs.
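A minimal sketch of the shrink described above, assuming a hypothetical helper (ShrinkTo2D is an illustrative name, not the PR's actual function):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the PR): drop leading dimensions of size 1
// until the shape is 2-D, so a Gather output like {1, 1, 256} becomes
// {1, 256} and can feed a Gemm that only accepts 2-D inputs.
std::vector<uint32_t> ShrinkTo2D(std::vector<uint32_t> shape) {
    while (shape.size() > 2 && shape.front() == 1) {
        shape.erase(shape.begin());
    }
    return shape;
}
```

Shapes that are already 2-D, or whose leading dimensions are not 1, pass through unchanged.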

Owner

@fdwr fdwr left a comment


Thanks Mingming.

Interesting 🤔. There are a number of places where I'm passing the modified output tensor desc to CreateNodeOutput. I'll need to review each of these cases...

GraphDMLImpl::AddInstanceNormalization(...)
...
  output_tensor_desc.PermuteDimensions(tensor_dimensions_permutation, TensorDesc::Alignment::kTrailing);
...
  auto node_output = graph_desc_builder_->CreateNodeOutput(
      node, 0, std::move(output_tensor_desc));
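For reference, the effect of a dimension permutation like the call above can be sketched as follows (a simplified stand-in, not the actual TensorDesc::PermuteDimensions implementation, and ignoring the alignment argument):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for a dimension permutation: reorder the dimensions of
// a shape according to a permutation index list, e.g. permuting an NCHW
// shape {1, 3, 5, 5} with {0, 2, 3, 1} yields the NHWC shape {1, 5, 5, 3}.
std::vector<uint32_t> PermuteDimensions(const std::vector<uint32_t>& shape,
                                        const std::vector<uint32_t>& perm) {
    std::vector<uint32_t> result(shape.size());
    for (size_t i = 0; i < perm.size(); ++i) {
        result[i] = shape[perm[i]];
    }
    return result;
}
```

The concern raised here is that once such a modified desc is moved into CreateNodeOutput, the recorded output shape may no longer match the graph-builder-level shape.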

@fdwr fdwr merged commit 1e09d98 into fdwr:dml_sd May 17, 2023
@mingmingtasd
Author

mingmingtasd commented May 18, 2023

> Thanks Mingming.
>
> Interesting 🤔. There are a number of places where I'm passing the modified output tensor desc to CreateNodeOutput. I'll need to review each of these cases...
>
> GraphDMLImpl::AddInstanceNormalization(...)
> ...
>   output_tensor_desc.PermuteDimensions(tensor_dimensions_permutation, TensorDesc::Alignment::kTrailing);
> ...
>   auto node_output = graph_desc_builder_->CreateNodeOutput(
>       node, 0, std::move(output_tensor_desc));

Yes, all nodes need to be checked again. If the output shape has been reset in graph_dml_impl.cc and mismatches the graph-builder level, it's dangerous and will make device_->CreateOperator(&op_desc, IID_PPV_ARGS(&op)) fail. You can refer to https://github.com/mingmingtasd/webnn-native/blob/dml_optimize/src/webnn/native/dml/GraphDML.cpp, where I ensure the output shape aligns with the graph-builder level.
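One way to sanity-check this alignment can be sketched as follows (the helper name is hypothetical, not taken from GraphDML.cpp): an internal DML tensor shape may legally differ in rank from the WebNN graph-builder shape (e.g. {1, 256} vs. {1, 1, 256}), but it must still describe the same number of elements, otherwise operator creation will fail.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical consistency check (illustrative only): verify that a reshaped
// internal DML tensor still holds exactly as many elements as the shape the
// WebNN graph builder recorded for that node.
bool ShapesAreCompatible(const std::vector<uint32_t>& dml_shape,
                         const std::vector<uint32_t>& builder_shape) {
    auto element_count = [](const std::vector<uint32_t>& s) {
        return std::accumulate(s.begin(), s.end(), uint64_t{1},
                               std::multiplies<uint64_t>());
    };
    return element_count(dml_shape) == element_count(builder_shape);
}
```

A check like this would catch a Gather output that was shrunk incorrectly before the mismatch reaches CreateOperator.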

@mingmingtasd
Author

I mean that aligning with the graph-builder level is equivalent to following the WebNN spec. Either way, we need to keep a specific shape for each node, no matter what the backend/implementation actually produces.

@fdwr
Owner

fdwr commented May 18, 2023

Indeed. The reason it works as-is for so many operators is that the input and output shapes were the same anyway, or the output shapes between DML and WebNN/ONNX were the same. So this really hits operators where the necessary DML shape and the WebNN/ONNX shape differ. My single-operator tests would not have caught these because they show the client-side builder shape, not the internal shape, and this is the first time we've glued together so many operators.
