Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu-native] Add transpose op #21986

Merged
merged 8 commits into from
Sep 11, 2024

Conversation

axinging
Copy link
Contributor

@axinging axinging commented Sep 4, 2024

Description

Motivation and Context

@fs-eire fs-eire self-assigned this Sep 5, 2024
@axinging axinging marked this pull request as ready for review September 5, 2024 01:43

class TransposeProgram final : public Program<TransposeProgram> {
public:
TransposeProgram(const std::string& kernel_name, const gsl::span<const size_t>& permutations)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name can remove because it should always be "Transpose"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

const std::string permFunctionBody(const std::string& input_name, const std::string& output_name, const gsl::span<const size_t>& perm) {
Copy link
Contributor

@fs-eire fs-eire Sep 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use a std::ostringstream& ss as the first parameter.

body in this context (at least in MainFunctionBody) means the part between brackets. maybe just use Function.

may be changed to "AppendPermFunction"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ORT prefer PascalCase (first letter upper case) for function name

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually can safely replace const std::string& with std::string_view in function parameters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 79 to 80
if (!status.IsOK())
return status;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ORT_RETURN_IF_ERROR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may need a discussion: which is better? we can choose one to use then keep consistent

  • gsl::narrow_cast<>
  • SafeInt<>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done gsl::narrow_cast<>

TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

SafeInt<uint32_t> vec_size = input_tensor->Shape().Size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not name vec_size since the input/output is not vec.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


SafeInt<uint32_t> vec_size = input_tensor->Shape().Size();
TransposeProgram program{"Transpose", *p_perm};
program
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perm should either be a part of cache hint or uniform. currently it seems different perm may use the same shader...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@axinging axinging force-pushed the transpose_webgpunative branch 2 times, most recently from f1449e1 to ea145ce Compare September 10, 2024 02:05
@axinging axinging force-pushed the transpose_webgpunative branch from ea145ce to ed4134f Compare September 10, 2024 04:26
Copy link
Contributor Author

@axinging axinging left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, @fs-eire , ptal

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done gsl::narrow_cast<>


SafeInt<uint32_t> vec_size = input_tensor->Shape().Size();
TransposeProgram program{"Transpose", *p_perm};
program
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

const std::string permFunctionBody(const std::string& input_name, const std::string& output_name, const gsl::span<const size_t>& perm) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


class TransposeProgram final : public Program<TransposeProgram> {
public:
TransposeProgram(const std::string& kernel_name, const gsl::span<const size_t>& permutations)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 79 to 80
if (!status.IsOK())
return status;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

SafeInt<uint32_t> vec_size = input_tensor->Shape().Size();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

onnxruntime/core/providers/webgpu/tensor/transpose.cc Outdated Show resolved Hide resolved
@fs-eire fs-eire merged commit d4a963d into microsoft:fs-eire/webgpu-ep Sep 11, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants