-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[webgpu-native] Add transpose op #21986
[webgpu-native] Add transpose op #21986
Conversation
|
||
class TransposeProgram final : public Program<TransposeProgram> { | ||
public: | ||
TransposeProgram(const std::string& kernel_name, const gsl::span<const size_t>& permutations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name can remove because it should always be "Transpose"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
Transpose); | ||
|
||
const std::string permFunctionBody(const std::string& input_name, const std::string& output_name, const gsl::span<const size_t>& perm) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use a std::ostringstream& ss
as the first parameter.
body in this context (at least in MainFunctionBody
) means the part between brackets. maybe just use Function
.
may be changed to "AppendPermFunction"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ORT prefer PascalCase (first letter upper case) for function name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
usually can safely replace const std::string&
with std::string_view
in function parameters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
if (!status.IsOK()) | ||
return status; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use ORT_RETURN_IF_ERROR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Status Transpose::ComputeInternal(ComputeContext& context) const { | ||
const auto* input_tensor = context.Input(0); | ||
const TensorShape& input_shape = input_tensor->Shape(); | ||
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may need a discussion: which is better? we can choose one to use then keep consistent
- gsl::narrow_cast<>
- SafeInt<>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done gsl::narrow_cast<>
TensorShape output_shape(output_dims); | ||
auto* output_tensor = context.Output(0, output_shape); | ||
|
||
SafeInt<uint32_t> vec_size = input_tensor->Shape().Size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should not name vec_size
since the input/output is not vec.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
SafeInt<uint32_t> vec_size = input_tensor->Shape().Size(); | ||
TransposeProgram program{"Transpose", *p_perm}; | ||
program |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perm should either be a part of cache hint or uniform. currently it seems different perm may use the same shader...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
f1449e1
to
ea145ce
Compare
ea145ce
to
ed4134f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, @fs-eire , ptal
Status Transpose::ComputeInternal(ComputeContext& context) const { | ||
const auto* input_tensor = context.Input(0); | ||
const TensorShape& input_shape = input_tensor->Shape(); | ||
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done gsl::narrow_cast<>
|
||
SafeInt<uint32_t> vec_size = input_tensor->Shape().Size(); | ||
TransposeProgram program{"Transpose", *p_perm}; | ||
program |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), | ||
Transpose); | ||
|
||
const std::string permFunctionBody(const std::string& input_name, const std::string& output_name, const gsl::span<const size_t>& perm) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
class TransposeProgram final : public Program<TransposeProgram> { | ||
public: | ||
TransposeProgram(const std::string& kernel_name, const gsl::span<const size_t>& permutations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
if (!status.IsOK()) | ||
return status; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
TensorShape output_shape(output_dims); | ||
auto* output_tensor = context.Output(0, output_shape); | ||
|
||
SafeInt<uint32_t> vec_size = input_tensor->Shape().Size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Description
Motivation and Context