Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[op] added support for TensorScatterUpdate op #7189

Merged
merged 6 commits into from
Jan 4, 2023
Merged

Conversation

pyu10055
Copy link
Collaborator

@pyu10055 pyu10055 commented Dec 19, 2022

ref #6709
This is the backend side of implementing TensorScatterUpdate op support across cpu/wasm/webgl/node

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

Copy link
Collaborator

@Linchenn Linchenn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks!

let defaultValuesString = '';
if (defaultIsTensor) {
defaultValuesString = 'coords[0], coords[1]';
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! This bypasses flattening and de-flattening the index.

Copy link
Member

@mattsoulanille mattsoulanille left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a nit

Comment on lines 45 to 50
for (size_t k = 0; k < slice_size; ++k) {
*out_buf_ptr = static_cast<T>(*updates_ptr);

out_buf_ptr++;
updates_ptr++;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This is somewhat confusing. Can we use memcpy to copy the slices instead?

Suggested change
for (size_t k = 0; k < slice_size; ++k) {
*out_buf_ptr = static_cast<T>(*updates_ptr);
out_buf_ptr++;
updates_ptr++;
}
memcpy(out_buf_ptr, updates_ptr, slice_size * dtype_size);
out_buf_ptr += slice_size;
updates_ptr += slice_size;

It may be possible to further simplify the index calculations as well.

Copy link
Collaborator Author

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @mattsoulanille)


tfjs-backend-wasm/src/cc/kernels/TensorScatterUpdate.cc line 50 at r3 (raw file):

memcpy(out_buf_ptr, updates_ptr, slice_size * dtype_size);
out_buf_ptr += slice_size;
updates_ptr += slice_size;
thanks

@pyu10055 pyu10055 merged commit 9eb0229 into master Jan 4, 2023
@pyu10055 pyu10055 deleted the tensor_scatter_update branch January 4, 2023 18:30
Linchenn pushed a commit to Linchenn/tfjs that referenced this pull request Jan 9, 2023
FEATURE
* added support for TensorScatterUpdate op

* fix snippet error

* fix license year

* addressed comments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants