Features/714 trace #718
Conversation
rerun tests
heat/core/linalg/basics.py (Outdated)
sum_along_diagonals = factories.array(
    sum_along_diagonals_t, dtype=dtype, split=gather_axis, comm=a.comm, device=a.device
)
To clarify, you will return a split object even if the input is split=None?
No, this won't happen. If a.split=None, the first part of the and-joined if condition is already false and the else case is always executed. Thus the result will also not be distributed.
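For illustration, a minimal toy stand-in for the short-circuit behaviour described here (the names and the second operand are placeholders, not the actual code in heat/core/linalg/basics.py):

```python
# Toy stand-in for the and-joined condition discussed above; not heat's code.
def takes_distributed_branch(a_split, second_condition=True):
    if a_split is not None and second_condition:
        return True   # if branch: distributed handling
    return False      # else branch: result is not distributed

# a.split=None short-circuits the `and`, so the else branch always runs:
assert takes_distributed_branch(None) is False
assert takes_distributed_branch(0) is True
```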
rerun tests
Description
Implementation of ht.trace, which computes the sum along the diagonal, analogous to np.trace.
Used torch functions:

Issue/s resolved: #714
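As a hedged usage sketch (the signature is assumed to mirror np.trace with offset/axis1/axis2/out, as described in the strategy below; not copied from the final API docs):

```python
import numpy as np
import heat as ht

x = ht.arange(9, dtype=ht.float32).reshape((3, 3))
print(ht.trace(x))                           # 2D input: scalar sum of the main diagonal
print(np.trace(np.arange(9).reshape(3, 3)))  # NumPy reference: 12

y = ht.arange(24, dtype=ht.float32).reshape((2, 3, 4))
print(ht.trace(y, axis1=1, axis2=2))         # > 2D input: returns a DNDArray
```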
Strategy (Algorithm)
CASE: 2D input
Returns: Scalar
- MPI_Allreduce operation to obtain the final result
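To illustrate the communication pattern of this 2D case only (a mock-up with mpi4py and NumPy, not heat's implementation): each rank sums the diagonal entries that fall into its local chunk, and an MPI_Allreduce produces the final scalar on every rank.

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

n = 8
full = np.arange(n * n, dtype=np.float64).reshape(n, n)
my_rows = np.array_split(np.arange(n), size)[rank]     # row-split chunk, like split=0
local_chunk = full[my_rows, :]

# local partial trace: the diagonal entry of global row g sits in column g
local_sum = float(sum(local_chunk[i, g] for i, g in enumerate(my_rows)))
trace = comm.allreduce(local_sum, op=MPI.SUM)          # MPI_Allreduce

assert trace == np.trace(full)                         # identical on every rank
```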
CASE: > 2D input
Returns: DNDArray
(Hint: as torch.trace is only implemented for 2-dimensional tensors, I decided to do a workaround by first extracting the diagonal entries via torch.diagonal, followed by a summation along the last axis of the resulting array.)

I differentiated between two cases: the split axis being or not being within the "trace axes" (axis1, axis2), as this leads to the required diagonal entries of the last axis being distributed on varying processes or not.
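To make the hint concrete, a small stand-alone check of the workaround (plain torch/NumPy, independent of heat):

```python
import numpy as np
import torch

x = torch.arange(2 * 3 * 4, dtype=torch.float64).reshape(2, 3, 4)

# torch.trace only accepts 2D input, but diagonal extraction followed by a sum
# over the last axis reproduces np.trace for higher-dimensional arrays as well
result = torch.diagonal(x, offset=0, dim1=1, dim2=2).sum(dim=-1)
expected = np.trace(x.numpy(), offset=0, axis1=1, axis2=2)

assert np.allclose(result.numpy(), expected)
```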
CASE: split axis NOT in (axis1, axis2)
- torch.sum as explained above
- the result is split along the "gather_axis"
CASE: split axis IN (axis1, axis2)
- the diagonal entries available on each process are extracted (torch.diagonal provides such even with unmodified offset)
- MPI_Allgather
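For the communication side only, a mock-up of the Allgather step (mpi4py/NumPy; how each rank obtains its local diagonal slice is simplified away here):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

diag_len = 12
full_diag = np.arange(diag_len, dtype=np.float64)
local_piece = np.array_split(full_diag, size)[rank]   # diagonal slice owned by this rank

gathered = comm.allgather(local_piece)                 # MPI_Allgather
diag = np.concatenate(gathered)

assert np.array_equal(diag, full_diag)                 # every rank now holds the full diagonal
print(rank, diag.sum())
```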
If out is not None, provide the correct split configuration to store the result in the given variable.
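A hedged sketch of what that looks like from the caller's side (shapes and keyword names are assumptions based on this description, not verified API details):

```python
import heat as ht

a = ht.arange(2 * 3 * 3, dtype=ht.float32).reshape((2, 3, 3))
buf = ht.zeros((2,), dtype=ht.float32)   # one trace value per leading index

ht.trace(a, axis1=1, axis2=2, out=buf)   # result is written into buf
print(buf)
```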
Changes proposed:
- ht.trace
Type of change
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
no