-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[op] added support for TensorScatterUpdate op #7189
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks!
let defaultValuesString = ''; | ||
if (defaultIsTensor) { | ||
defaultValuesString = 'coords[0], coords[1]'; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great idea! This bypasses flattening and de-flattening the index.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a nit
for (size_t k = 0; k < slice_size; ++k) { | ||
*out_buf_ptr = static_cast<T>(*updates_ptr); | ||
|
||
out_buf_ptr++; | ||
updates_ptr++; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: This is somewhat confusing. Can we use memcpy to copy the slices instead?
for (size_t k = 0; k < slice_size; ++k) { | |
*out_buf_ptr = static_cast<T>(*updates_ptr); | |
out_buf_ptr++; | |
updates_ptr++; | |
} | |
memcpy(out_buf_ptr, updates_ptr, slice_size * dtype_size); | |
out_buf_ptr += slice_size; | |
updates_ptr += slice_size; |
It may be possible to further simplify the index calculations as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @mattsoulanille)
tfjs-backend-wasm/src/cc/kernels/TensorScatterUpdate.cc
line 50 at r3 (raw file):
memcpy(out_buf_ptr, updates_ptr, slice_size * dtype_size);
out_buf_ptr += slice_size;
updates_ptr += slice_size;
thanks
FEATURE * added support for TensorScatterUpdate op * fix snippet error * fix license year * addressed comments
ref #6709
This is the backend side of implementing TensorScatterUpdate op support across cpu/wasm/webgl/node
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is