Skip to content

Commit

Permalink
general solution
Browse files Browse the repository at this point in the history
  • Loading branch information
centwang committed Feb 27, 2024
1 parent 2d5369d commit fee3476
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 890 deletions.
318 changes: 193 additions & 125 deletions onnxruntime/core/optimizer/gather_fusion.cc

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions onnxruntime/core/optimizer/gather_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@
namespace onnxruntime {

/**
@Class GatherToSplitFusion
@Class GatherSliceToSplitFusion
Fuse multiple Gather nodes that comsuming one output to one Split node.
Fuse multiple Gather/Slice nodes that comsuming one output to one Split node.
*/
class GatherToSplitFusion : public GraphTransformer {
class GatherSliceToSplitFusion : public GraphTransformer {
public:
GatherToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {}
GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept

Check warning on line 17 in onnxruntime/core/optimizer/gather_fusion.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Constructors callable with one argument should be marked explicit. [runtime/explicit] [5] Raw Output: onnxruntime/core/optimizer/gather_fusion.h:17: Constructors callable with one argument should be marked explicit. [runtime/explicit] [5]
: GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, bool& need_squeeze) const;

bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, int64_t& end) const;
};

/**
Expand Down
333 changes: 0 additions & 333 deletions onnxruntime/core/optimizer/gather_slice_fusion.cc

This file was deleted.

31 changes: 0 additions & 31 deletions onnxruntime/core/optimizer/gather_slice_fusion.h

This file was deleted.

Loading

0 comments on commit fee3476

Please sign in to comment.