SubsetRagged & PruneRagged #919
Changes from all commits
6afb7e7
7341450
280e2c2
294976f
2ba62e1
efb786d
f9a07de
ccb29e0
d9c7f08
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,27 +111,48 @@ class Renumbering { | |
(pre-renumbering) indexes. Its dimension is the number of | ||
new indexes (i.e. the number of 1 in keep_), but internally | ||
it has one extra element which contains the number of old | ||
elements, so it's OK to read one past the end. (We may | ||
later make it possible to access the array with the one-larger | ||
dimension). | ||
elements, so it's OK to read one past the end. | ||
*/ | ||
Array1<int32_t> &New2Old() { | ||
NVTX_RANGE(K2_FUNC); | ||
if (!new2old_.IsValid()) ComputeNew2Old(); | ||
return new2old_; | ||
} | ||
|
||
/* Return a mapping from new index to old index, with one extra element | ||
containing the total number of old elements if extra_element == true. | |
If Keep() can be interpreted as a tails vector, i.e. with 1 at the end | ||
of sub-lists of elements, then New2Old(true) would correspond to a | |
row-splits array and Old2New(false) would correspond to a row-ids | ||
array. | ||
*/ | ||
Array1<int32_t> New2Old(bool extra_element) { | ||
Should we make the interface the same as
It is the same in effect; there is an overloaded version taking no arg, that doesn't return the extra element, but
I mean, removing the one with no argument.
RE removing the one with no argument: it is designed this way to avoid redundant copies of tensors, which could take time at runtime and also bulk up the compiled code with unnecessary instructions. New2Old() is called quite a lot, so I thought it was worth saving that time.
||
Array1<int32_t> &new2old_part = New2Old(); | ||
if (!extra_element) { | ||
return new2old_part; | ||
} else { | ||
// This is a little perverse, using low-level interfaces to increase the | ||
// dimension of the array; but we know it does have one more element. | ||
// Because we normally use New2Old() with no arg (equivalent to false), | ||
// the overloaded version of this function returns a reference for | ||
// efficiency. | ||
return Array1<int32_t>(new2old_part.Dim() + 1, | ||
new2old_part.GetRegion(), 0); | ||
} | ||
} | ||
|
||
/* Return a mapping from old index to new index. This is created on demand | ||
(must only be called after the Keep() array has been populated). | ||
|
||
@param [in] extra_element If true, will return the array of size | ||
NumOldElems() + 1, which includes one more element; | ||
otherwise it will return an array of size NumOldElems(). | ||
|
||
|
||
@return Returns an array mapping the old indexes to the new indexes. | ||
This array is just the exclusive sum of Keep(). | ||
It gives the mapping for indexes that are kept; element | ||
i is kept if `Old2New()[i+1] > Old2New()[i]`. | ||
|
||
@return Returns an array mapping the old indexes to the new indexes. | ||
*/ | ||
Array1<int32_t> Old2New(bool extra_element = false) { | ||
NVTX_RANGE(K2_FUNC); | ||
|
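For readers following the New2Old()/Old2New() comments above, here is a minimal standalone sketch of the two mappings. It uses plain `std::vector` rather than k2's `Array1`, and `RenumberingSketch` and `ComputeMappings` are hypothetical names invented for this illustration, not part of k2. It assumes, following the comment on `New2Old()`, that the extra element of the new2old array holds the number of old elements, and that old2new is the exclusive sum of the keep flags.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the mappings held by Renumbering; not k2 code.
struct RenumberingSketch {
  std::vector<int32_t> old2new;  // size num_old + 1: exclusive sum of keep
  std::vector<int32_t> new2old;  // size num_new + 1: last entry is num_old
};

// keep[i] != 0 means that old index i is kept.
inline RenumberingSketch ComputeMappings(const std::vector<char> &keep) {
  const int32_t num_old = static_cast<int32_t>(keep.size());
  RenumberingSketch r;
  r.old2new.resize(num_old + 1);
  int32_t sum = 0;
  for (int32_t i = 0; i < num_old; ++i) {
    r.old2new[i] = sum;  // number of kept elements strictly before i
    if (keep[i]) ++sum;
  }
  r.old2new[num_old] = sum;  // extra element: total number of kept elements

  r.new2old.resize(sum + 1);
  for (int32_t i = 0; i < num_old; ++i) {
    // Element i is kept iff old2new[i + 1] > old2new[i].
    if (r.old2new[i + 1] > r.old2new[i]) r.new2old[r.old2new[i]] = i;
  }
  r.new2old[sum] = num_old;  // extra element: number of old elements
  assert(r.new2old.back() == num_old);
  return r;
}
```

With keep = [1, 0, 1, 1, 0] this gives old2new = [0, 1, 1, 2, 3, 3] and new2old = [0, 2, 3, 5]; dropping the last entry of each gives the arrays returned when extra_element is false.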
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -759,14 +759,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, | |
int32_t max_num_elements = 2000); | ||
|
||
/* | ||
Return ragged shape with only a subset of the bottom-level elements kept. | ||
Require renumbering.NumOldElems() == src.NumElements(). Note: all | ||
dimensions and tot-sizes preceding the final axis will remain the same, which | ||
might give rise to empty lists. | ||
Return ragged shape with only a subset of the elements or sub-lists | ||
on the specified axis kept. (This is not regular sampling, it is | ||
irregular subsampling with specified elements kept). | ||
|
||
@param [in] src The ragged shape that we are subsampling | ||
@param [in] renumbering The renumbering object that dictates | ||
which elements of `src` we keep; we require | ||
renumbering.NumOldElems() == src.TotSize(axis2) | ||
where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). | ||
@param [in] axis The axis to subsample; if negative, will be | ||
interpreted as an offset from src.NumAxes(). | ||
@param [out] elems_new2old If supplied, this function will | ||
output to this location a new2old vector that | ||
dictates how the elements of a ragged tensor | ||
with shape `src` would be renumbered. | ||
@return Returns the subsampled shape. All dimensions and tot-sizes | ||
preceding the axis `axis` will remain the same, which might give | ||
rise to empty lists on those axes; these can be removed if | ||
necessary with RemoveEmptyLists(). | ||
|
||
Notice the other version of this function below. | ||
*/ | ||
RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); | ||
RaggedShape SubsetRaggedShape(RaggedShape &src, | ||
Renumbering &renumbering, | ||
int32_t axis = -1, | ||
Array1<int32_t> *elems_new2old = nullptr); | ||
|
||
/* | ||
Return ragged shape with only a subset of the elements on the last | ||
|
@@ -777,9 +795,9 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); | |
Note: all dimensions and tot-sizes preceding the last two axes will remain the | ||
same, which might give rise to empty lists. | ||
*/ | ||
RaggedShape SubsampleRaggedShape(RaggedShape &src, | ||
Renumbering &renumbering_before_last, | ||
Renumbering &renumbering_last); | ||
RaggedShape SubsetRaggedShape(RaggedShape &src, | ||
Renumbering &renumbering_before_last, | ||
Renumbering &renumbering_last); | ||
|
||
/* | ||
Removes empty lists on a particular axis (not last axis) of a RaggedShape, | ||
|
@@ -866,17 +884,82 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, | |
|
||
|
||
/* | ||
Return ragged array with only a subset of the bottom-level elements kept. | ||
Require renumbering.NumOldElems() == src.NumElements(). Note: all | ||
dimensions and tot-sizes preceding the final axis will remain the same, which | ||
might give rise to empty lists. | ||
Return ragged array with only a subset of the elements or sub-lists | ||
on the specified axis kept. (This is not regular sampling, it is | ||
irregular subsampling with specified elements kept). | ||
|
||
@param [in] src The ragged array that we are subsampling | |
@param [in] renumbering The renumbering object that dictates | ||
which elements of `src` we keep; we require | ||
renumbering.NumOldElems() == src.TotSize(axis2) | ||
where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). | |
@param [in] axis The axis to subsample; if negative, will be | ||
interpreted as an offset from src.NumAxes(). | ||
@param [out] elems_new2old If supplied, this function will | ||
output to this location a new2old array that | ||
dictates how the elements of a ragged tensor | ||
with shape `src` would be renumbered. | ||
@return Returns the subsampled array. All dimensions and tot-sizes | |
Same comments as on lines 710 and 718.
||
preceding the axis `axis` will remain the same, which might give | ||
rise to empty lists on those axes; these can be removed if | ||
necessary with RemoveEmptyLists(). | ||
*/ | ||
template <typename T> | ||
Ragged<T> SubsampleRagged(Ragged<T> &src, Renumbering &renumbering) { | ||
return Ragged<T>(SubsampleRaggedShape(src.shape, renumbering), | ||
src.values[renumbering.New2Old()]); | ||
Ragged<T> SubsetRagged(Ragged<T> &src, Renumbering &renumbering, | ||
int32_t axis = -1, | ||
Array1<int32_t> *elems_new2old = nullptr) { | ||
Array1<int32_t> tmp; | ||
if (elems_new2old == nullptr) | ||
elems_new2old = &tmp; | ||
RaggedShape shape = SubsetRaggedShape(src.shape, renumbering, | ||
axis, elems_new2old); | ||
return Ragged<T>(shape, src.values[*elems_new2old]); | ||
} | ||
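To make the "might give rise to empty lists" remark concrete, the following is a rough standalone sketch of subsetting the last axis of a two-axis ragged array. It is not the k2 implementation: `Ragged2` and `SubsetLastAxis` are hypothetical names, the keep flags are passed directly instead of through a Renumbering object, and only the axis == -1 case on two axes is covered.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical two-axis ragged array: row_splits has num_rows + 1 entries and
// values[row_splits[r]] .. values[row_splits[r + 1] - 1] belong to row r.
struct Ragged2 {
  std::vector<int32_t> row_splits;
  std::vector<float> values;
};

// Keeps only the elements with keep[i] != 0. The number of rows is unchanged,
// so a row that loses all of its elements stays behind as an empty list.
inline Ragged2 SubsetLastAxis(const Ragged2 &src,
                              const std::vector<char> &keep,
                              std::vector<int32_t> *elems_new2old = nullptr) {
  assert(keep.size() == src.values.size());
  assert(!src.row_splits.empty() &&
         src.row_splits.back() == static_cast<int32_t>(src.values.size()));
  const int32_t num_old = static_cast<int32_t>(keep.size());

  // old2new is the exclusive sum of keep, with one extra element at the end.
  std::vector<int32_t> old2new(num_old + 1, 0);
  for (int32_t i = 0; i < num_old; ++i)
    old2new[i + 1] = old2new[i] + (keep[i] ? 1 : 0);

  Ragged2 ans;
  // Each old row boundary maps to the number of kept elements before it.
  for (int32_t split : src.row_splits) ans.row_splits.push_back(old2new[split]);

  std::vector<int32_t> new2old;
  for (int32_t i = 0; i < num_old; ++i) {
    if (keep[i]) {
      new2old.push_back(i);
      ans.values.push_back(src.values[i]);
    }
  }
  if (elems_new2old != nullptr) *elems_new2old = new2old;
  return ans;
}
```

For src = [ [0 -1 -2 -3], [ -10, -20 ], [ ] ] and keep = [1, 1, 1, 0, 0, 0], the result has row_splits [0, 3, 3, 3], i.e. [ [0 -1 -2], [ ], [ ] ]: the second row becomes empty but is not removed.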
|
||
/* | ||
This function creates a Renumbering object that can be used to obtain subsets | ||
of ragged arrays via SubsetRaggedShape(). It implements beam pruning as | ||
used in pruned Viterbi search and similar algorithms, where there is both a | ||
beam and a max-active (`max_elems`) constraint. T will probably be float or | |
double, interpreted in a "positive-is-better" sense, i.e. as scores. | |
|
||
@param [in] src The ragged object to be subsampled. | ||
@param [in] axis The axis to be subsampled, must satisfy | ||
0 <= axis < src.NumAxes(). The axis before `axis`, if axis > 0, | ||
will be interpreted as a "batch" axis. | ||
@param [in] beam The main pruning beam. The sub-lists of elements on axis | ||
`axis` will be removed if their maximum element (or the element | ||
itself, if axis + 1 == src.NumAxes()) is less than | ||
this_best_elem - beam, where this_best_elem | ||
is the maximum element taken over axis `axis-1` (or over the | ||
entire array, if axis == 0). Think of axis `axis-1`, if | ||
axis > 0, as the "batch" axis, and axis `axis` as the axis that we | ||
actually remove elements or sub-lists on. Empty sub-lists on axis | ||
`axis` will always be pruned, as their score would be treated | ||
as -infinity. | ||
@param [in] max_elems If max_elems > 0, it is the maximum number of | ||
sub-lists or elements that are allowed within any sub-list | ||
on axis `axis-1` (or the maximum number of top-level sub-lists | ||
after subsampling, if axis == 0). We keep the best ones. | ||
If max_elems <= 0, there is no such constraint. | ||
@return Returns the renumbering object to be used to actually | ||
prune/subsample the specified axis. | ||
|
||
Example: | ||
PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3) | ||
would create a Renumbering object that would prune the | ||
ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ] | ||
|
||
PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0) | ||
would create a Renumbering object that would prune the | ||
ragged tensor to [ [0 -1 -2 -3] ] | ||
*/ | ||
template <typename T> | ||
Renumbering PruneRagged(Ragged<T> &src, | ||
int32_t axis, | ||
T beam, | ||
int32_t max_elems); | ||
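The first example in the comment above can be reproduced with a small CPU-only sketch. This is only an illustration of the beam and max_elems rules as documented, restricted to pruning the last axis of a two-axis ragged array; `PruneLastAxis` is a hypothetical name, it returns plain keep flags rather than a Renumbering, and the tie-breaking (stable sort, earlier element wins) is an assumption that the documentation does not specify.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns keep flags for the elements of a two-axis ragged array described by
// row_splits and values. Within each row an element is dropped if it is less
// than (best element of the row - beam), or if max_elems > 0 and max_elems
// better elements of that row have already been kept.
inline std::vector<char> PruneLastAxis(const std::vector<int32_t> &row_splits,
                                       const std::vector<float> &values,
                                       float beam, int32_t max_elems) {
  assert(!row_splits.empty() &&
         row_splits.back() == static_cast<int32_t>(values.size()));
  std::vector<char> keep(values.size(), 0);
  for (size_t r = 0; r + 1 < row_splits.size(); ++r) {
    const int32_t begin = row_splits[r], end = row_splits[r + 1];
    if (begin == end) continue;  // empty row: nothing to keep

    // Indexes of this row's elements, sorted from best to worst score.
    std::vector<int32_t> order(end - begin);
    std::iota(order.begin(), order.end(), begin);
    std::stable_sort(
        order.begin(), order.end(),
        [&values](int32_t a, int32_t b) { return values[a] > values[b]; });

    const float cutoff = values[order[0]] - beam;
    int32_t num_kept = 0;
    for (int32_t idx : order) {
      if (values[idx] < cutoff) break;                    // outside the beam
      if (max_elems > 0 && num_kept >= max_elems) break;  // max-active limit
      keep[idx] = 1;
      ++num_kept;
    }
  }
  return keep;
}
```

With row_splits = [0, 4, 6, 6], values = [0, -1, -2, -3, -10, -20], beam = 5.0 and max_elems = 3, the keep flags are [1, 1, 1, 0, 1, 0], which corresponds to the pruned tensor [ [0 -1 -2], [ -10 ], [ ] ] in the first example.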
|
||
/* | ||
Stack a list of Ragged arrays to create a Ragged array with one more axis. | ||
Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All | ||
|
@@ -974,8 +1057,7 @@ void Unstack(Ragged<T> src, int32_t axis, std::vector<Ragged<T>> *out, | |
/* | ||
Concatenate a list of Ragged<T> to form a single Ragged<T>. | ||
|
||
@param [in] axis Axis to append them on. Currently | ||
we only support axis == 0 or axis == 1. | ||
@param [in] axis Axis to append them on. | ||
Previous axes must | ||
have the same shape, i.e. if axis == 1 | ||
then `src[i]->Dim0()` must all have the | ||
|
@@ -1368,7 +1450,7 @@ Ragged<T> Merge(int32_t num_srcs, Ragged<T> **src, | |
/* | ||
Returns a ragged tensor after removing all 'values' that were <= a provided | ||
cutoff. Leaves all layers of the shape except for the last one unaffected. | ||
Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] <= | ||
Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] <= | ||
cutoff). | ||
*/ | ||
template <typename T> | ||
|
@@ -1377,7 +1459,7 @@ Ragged<T> RemoveValuesLeq(Ragged<T> &src, T cutoff); | |
/* | ||
Returns a ragged tensor after removing all 'values' that equal a provided | ||
target. Leaves all layers of the shape except for the last one unaffected. | ||
Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] == | ||
Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] == | ||
target). | ||
*/ | ||
template <typename T> | ||
|
The comment on line 115 also needs to be updated.