This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Fix collect_params().zero_grad() in gluon numpy interface #16716

Merged
merged 4 commits into apache:master on Nov 13, 2019

Conversation

sxjscience
Member

@sxjscience sxjscience commented Nov 4, 2019

Description

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, the expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

Changes

  • Fix zero_grad in the MXNet numpy interface and add tests

@sxjscience sxjscience requested a review from szha as a code owner November 4, 2019 02:19
    for ele in arr:
        ele[:] = 0
else:
    mx.nd.reset_arrays(*arr, num_arrays=len(arr))
Member

why not always use in-place assign?

Member Author

I'm not sure why we used reset_arrays before. I guess it would be faster when there are many arrays.

Member

if that's the case, then we need its equivalent in the npx namespace @reminisce

Contributor

Can you add an alias _npi_reset_arrays in reset_arrays.cc?

Member Author

I've checked the source code. The new approach should be fine as long as we use cudaMemsetAsync for implementing ele[()] = 0. In fact, reset_arrays.cc lies in the contrib folder and there is no need to add it to numpy.

Contributor

Sounds good to me.

Member

I don't think this is worth having diverging implementations. If reset_arrays is not useful, then we should stay away from it.

Member

Shall we move away from reset_arrays in the old ndarray too?

Member Author

I actually do not know why we've used reset_arrays. This op should be in contrib, yet it's now in the main API. I think this is somewhat out of the scope of this PR.

@sxjscience sxjscience changed the title [Numpy][WIP] Fix collect_params().zero_grad() in gluon numpy interface [Numpy] Fix collect_params().zero_grad() in gluon numpy interface Nov 7, 2019
@sxjscience
Member Author

@reminisce @szha I've added the test. It should be ready for review.

@szha
Member

szha commented Nov 7, 2019

Shall we move away from reset_arrays in the old ndarray too?

@sxjscience This concern is not addressed yet

Member

@szha szha left a comment


I'm not convinced that we should have divergence in implementation for such a simple task as resetting the gradients.

@sxjscience
Member Author

Because we need to use a[:] = 0 for the original ndarray and a[()] = 0 for the new numpy ndarray, we are not able to share the implementation.
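For illustration, here is a minimal sketch of that dispatch (not the exact PR diff; it assumes the is_np_array() helper from mxnet.util):

import mxnet as mx
from mxnet.util import is_np_array

def zero_all_grads(arrays):
    # Zero a list of gradient arrays using the indexing idiom each array type supports.
    if is_np_array():
        # numpy-style mx.np.ndarray: the full-array index () also covers 0-d arrays
        for a in arrays:
            a[()] = 0
    else:
        # legacy mx.nd.NDArray: slice assignment
        for a in arrays:
            a[:] = 0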

@sxjscience
Member Author

reset_arrays was introduced in #16446. Maybe we should ask @ptrendx. For the numpy array, the current approach should be fine.

@sxjscience
Member Author

I guess the main purpose is to accelerate the initialization of a huge number of NDArrays. Adding a reset_arrays op is appropriate as a short-term solution. However, numpy serves as the first step of our longer-term evolution and we should consider solving it in a different way.

@szha
Member

szha commented Nov 7, 2019

I've checked the source code. The new approach should be fine as long as we use cudaMemsetAsync for implementing ele[()] = 0. In fact, reset_arrays.cc lies in the contrib folder and there is no need to add it to numpy.

I think reset_arrays also tries to help reduce the launch overhead from the multiple kernel launches of the for-loop implementation.

@sxjscience
Member Author

sxjscience commented Nov 7, 2019

@szha If these operators are executed in bulk mode, there will be no StreamSynchronize in between. I'm actually investigating the root cause of the overhead and I think it can be solved. Also, we could later fuse these initialization kernels into a single kernel. Adding a new operator is not the best way and would not suit our long-term goal, i.e., MXNet 2.0.

@reminisce
Contributor

reminisce commented Nov 7, 2019

I'm personally not fond of using the operator reset_arrays for such a specific purpose. First of all, it does not feel as natural as using a for-loop to assign zeros to a list of ndarrays. Performance-wise, the operator still launches a cuda kernel for each ndarray in the same stream, while in the for-loop approach the assignment of zeros could happen in different streams without much overhead from stream dependencies. In addition, the execution of reset_arrays depends on all the input arrays being write-ready, which is not ideal since the assignment of zeros happens individually for each ndarray.

@ptrendx
Member

ptrendx commented Nov 7, 2019

@reminisce No, that is not how it works, and please do not undo performance optimizations because of such false presumptions (especially since this is not user-facing code; zero_grad is, so what "feels more natural" should not really matter here).

Let me give you a little bit of data and context: when pretraining BERT, zeroing the gradient arrays (which happened once every few iterations) took ~5% of the total time because of this approach of launching each zeroing as a new operator. That is a ridiculously high overhead. The actual GPU cost of the zeroing was minuscule, but the cost of the operator launch and the sync at the end is 10+x that.

About the fact that the new implementation uses cudaMemsetAsync per array: frankly, that was a "good enough" fix for that (BERT) use case. The actual full fix would be writing an aggregated operator that zeroes all the arrays inside a single kernel; I bet this would improve the performance of zeroing gradients by an additional 2x or more.

This comment:

However, numpy serves as the first step of our longer-term evolution and we should consider solving it in a different way.

together with

First of all, it does not feel as natural as using a for-loop to assign zeros to a list of ndarrays.

scare me, frankly. Doing numpy-like, fully imperative execution has no chance of being actually fast, and leaning in that direction will not make MXNet any better. Gluon seems like a good compromise: sure, stay imperative while you debug, but hybridize when you are actually serious about getting something deployable. And the fix proposed in this PR ("if this is a NumPy-like array, then do the slow thing") is super bad practice, especially since, as I understand it, the numpy style will be promoted going forward.

@sxjscience
Member Author

@ptrendx I think you have misunderstood my comment. What I mean is that we could try to fuse these zeroing operators into a single one instead of writing a new operator to do that.

@ptrendx
Member

ptrendx commented Nov 7, 2019

@sxjscience How would you like to fuse them without writing a new operator?

@sxjscience
Member Author

@ptrendx For example, we can later try to hybridize part of the imperative code with the nvrtc approach that you used in the fused op, or rely on TVM. I'm not here to blame the reset_arrays op. My concern is that this feature is in contrib and has not been promoted to the official API yet. We can support that if we find it appropriate, but that's out of the scope of this fix.

@sxjscience
Member Author

@ptrendx The problem with the reset_arrays approach is that users may later require a fill_arrays op, which would fill multiple NDArrays with the same value. We could also write a large kernel to fuse that, but this is not scalable.

@reminisce
Contributor

@ptrendx The performance overhead in your benchmark really comes from the FFI and from pushing ops to the async engine. It becomes more obvious when the kernel execution time is negligible. We are working on reducing the operator calling overhead. Besides that, just from pure code analysis, reset_arrays requires all input ndarrays to be write-ready in order to run in the same cuda stream, while the other way has no such restriction. Your benchmark case might be especially favorable to reset_arrays, but this op is exposed as a public API and we cannot prevent users from using it in other cases.

sync at the end is 10+x that.

I'm not sure what sync you are referring to here. Could you explain?

@sxjscience
Member Author

scare me, frankly. Doing numpy-like, fully imperative execution has no chance of being actually fast, and leaning in that direction will not make MXNet any better. Gluon seems like a good compromise: sure, stay imperative while you debug, but hybridize when you are actually serious about getting something deployable. And the fix proposed in this PR ("if this is a NumPy-like array, then do the slow thing") is super bad practice, especially since, as I understand it, the numpy style will be promoted going forward.

Leaning towards numpy does not mean throwing away the mix of symbolic and imperative. It's more about making the front-end interface numpy-like. Also, doing slow things is not always bad. Sometimes it improves readability and makes the package more researcher-friendly.

@reminisce
Contributor

Another big factor that may contribute to the slowdown of assigning zeros is that a[:] = 0 has to go through ndarray.__setitem__ indexing analysis to find the appropriate workflow, which is known to be very slow. Nevertheless, this is still something we can improve through other avenues.
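As a rough, hypothetical micro-benchmark of that point (assuming a GPU context is available; mx.nd.zeros(shape, out=a) pushes the fill operator directly and skips the __setitem__ indexing analysis):

import mxnet as mx
import time

a = mx.nd.ones((100, 100), ctx=mx.gpu())
mx.nd.waitall()

start = time.time()
for _ in range(1000):
    a[:] = 0  # goes through NDArray.__setitem__ indexing analysis
mx.nd.waitall()
print("__setitem__ path: Elapsed ", time.time() - start)

start = time.time()
for _ in range(1000):
    mx.nd.zeros(a.shape, out=a)  # direct operator call, no indexing analysis
mx.nd.waitall()
print("out= path: Elapsed ", time.time() - start)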

@ptrendx
Member

ptrendx commented Nov 7, 2019

OK, let me address those comments one point at a time :-).

  • usage of TVM/nvrtc - I am generally in favor of that (even though it is harder than it looks, because those arrays do not have the same shape and the imperative nature of the code makes it tricky to know when such horizontal fusion can happen), but this is not a short-term solution for this problem
  • other cases that look similar - I agree with you, a longer-term general solution is needed
  • reset_arrays is in the contrib directory - that is unfortunate placement, I agree.
  • source of the performance overhead - no. I strongly encourage you to look at a profiler (something like nvprof, not the MXNet profiler, as it only tells you how much time an operator takes and not how much of that time was actually spent on the GPU) and see for yourself. I agree that FFI and creating (and destroying) engine ops take some time (which could be reduced by e.g. having a pool of ThreadedOpr). The main source of the overhead in the GPU case, however, is the fact that each operation needs to synchronize after calling the kernel (as the GPU is asynchronous with respect to the host CPU) in order to update engine dependencies. For super-short operations like zeroing an array, this not only slows things down because of the overhead of the sync itself (cudaStreamSynchronize at the end of the operator), but also completely exposes the kernel-launch overhead of the next operator (because the benefit of the GPU being asynchronous, letting you queue multiple launches, is completely lost if you need to sync after every one of them).
  • "Also, doing slow things is not always bad." - my HPC soul screams in terror when reading this :-P. I am not against having simple abstractions for the user - in fact, I am all for it. The role of the framework, though, is to internally take those simple abstractions and transform them into efficient execution.

@sxjscience
Member Author

@ptrendx
Let me clarify a little bit:

  1. In the nd interface, using reset_arrays as a short-term solution is acceptable

    I agree that using reset_arrays as the short-term solution is appropriate. That's the reason why I haven't revised the reset_arrays part. I understand the underlying logic of accelerating the code via reset_arrays, which tries to reduce the number of cudaStreamSynchronize calls. In fact, the correct way to guarantee the order of the launched device kernels is to launch them in the same stream without intermediate cudaStreamSynchronize calls. Currently, there is a bulk mode in the MXNet engine which pushes multiple operators asynchronously and waits at the end:

    https://github.com/apache/incubator-mxnet/blob/c38b52784325380f79cafb9f0407ad327554fe6b/src/engine/threaded_engine.h#L534-L546

    Thus, if we redesign the engine and analyze the dependencies of the computation nodes, we should be able to do better in the hybridized case.

  2. Numpy serves as the first step to the longer-term goal

    As stated in [RFC] Apache MXNet 2.0 Roadmap #16167, the trend is to move towards a numpy-like programming experience (not exactly compatible, but we will make sure that the incompatible cases have good practical reasons). This means we have the chance to solve the issue in a more systematic manner.

  3. What a better way may look like

    I think in the future, we should be able to support hybridizing an arbitrary function:

    @hybridize
    def zero_grad(arrays):
        for arr in arrays:
            arr[()] = 0

    Thus, I suggest we revisit the decision later, after we have obtained some profiling numbers for the models trained via numpy. If we find that reset_arrays is the right solution, we should promote it to the main API and also support the numpy interface.

  4. "Doing slow things is not always bad."

    I'm not an HPC person; in my PhD years I mainly focused on studying and developing machine-learning models for spatio-temporal problems. However, I do know that speed matters a lot: when working on the DKVMN paper (WWW 2017) with MXNet, I found that our embedding layer was much slower than Torch's and I accelerated the code (Accelerate AddTakeGrad + Support Sorting dmlc/mshadow#153), which helped me run the experiments on the limited K20 cards (5 GB GPU memory) in my lab at HKUST. I have also accelerated other parts of MXNet, such as the recent GPU LayerNorm(axis=-1) speedup ([OP] Accelerate GPU version of LayerNorm(axis=-1) #14935), and I have used MXNet to train graph-attention-based models on relatively large-scale graphs (GaAN paper, UAI 2018).

    Speed matters a lot, but I haven't revised reset_arrays in this PR because I think we could potentially do better in the numpy interface. I'll consider promoting reset_arrays to the main API and adding numpy support if we later find it's worthwhile. But that's out of the scope of this simple fix.

@reminisce
Contributor

reminisce commented Nov 8, 2019

To avoid overly frequent cuda stream synchronization when zeroing the arrays, without introducing a new operator, I think we can put the assignment loop inside a bulk scope so that there is only one stream synchronization at the end.

with mx.engine.bulk(len(arrays)):
    for arr in arrays:
        arr[:] = 0

@ptrendx
Member

ptrendx commented Nov 8, 2019

@reminisce Huh, I did not know about that way of trying to bulk the imperative execution. You are right that if it worked well then that would solve this issue. Unfortunately, I tested it with this script:

import mxnet as mx
import time

arrays = [mx.nd.ones((100,100), ctx=mx.gpu()) for _ in range(500)]

for a in arrays:
    a[:] = 0
mx.nd.waitall()

start = time.time()

for _ in range(10):
    for a in arrays:
        a[:] = 0

mx.nd.waitall()
end = time.time()
print("normal: Elapsed ", end - start)

mx.nd.waitall()

start = time.time()

with mx.engine.bulk(len(arrays)):
    for _ in range(10):
        for a in arrays:
            a[:] = 0

mx.nd.waitall()
end = time.time()

print("bulk: Elapsed ", end - start)

mx.nd.waitall()

start = time.time()

for _ in range(10):
    mx.nd.reset_arrays(*arrays, num_arrays=len(arrays))

mx.nd.waitall()

end = time.time()

print("reset_arrays: Elapsed ", end - start)

and got those results:

# python test.py
normal: Elapsed  0.8372836112976074
bulk: Elapsed  0.6354436874389648
reset_arrays: Elapsed  0.016309261322021484

(I also tried with the mx.engine.bulk() line inside the for _ in range(10) loop; the results were similar.)
Looking at the profile, bulking does work (as in, it removes the synchronization between ops), but it introduces HUGE gaps between bulks (over 65 ms in this example). If we can fix that, then the reset_arrays approach will not be needed.

@sxjscience
Member Author

@ptrendx What's your suggestion for this PR? I think we can create another PR to promote reset_arrays to the main API and also add numpy support if we find that it's the best solution.

In fact, a similar problem also occurs for global gradient clipping, in which we need to calculate the L2 norms of all the parameters. Currently, it's implemented as a summation of the individual L2 norms:
https://github.com/dmlc/gluon-nlp/blob/61ec27064545ae6350d984b0f7a0944e6e28ed47/src/gluonnlp/utils/parameter.py#L71-L102

I have been planning to write a large global_norm kernel to accelerate this part. The problem with this kind of approach in general is that it is not scalable. Consider an arbitrary optimization algorithm, e.g., ADAM. ADAM does not use the block-wise structure of the parameters and updates them by treating all parameters + gradients as a single vector. We could certainly accelerate it by concatenating the parameters together and writing a kernel directly for that. The reset_arrays solution is not applicable in this case because we have multiple element-wise operators in the optimization process. Thus, I think we should try to provide a more general solution instead of always relying on the reset_arrays approach.
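For reference, a minimal sketch of the summation-of-individual-norms style of global clipping described above (illustrative only, not the gluon-nlp implementation; the function names are hypothetical):

import mxnet as mx

def global_grad_norm(arrays):
    # One nd.norm kernel per array, then a small reduction over the partial results.
    # This per-array pattern is what a fused global_norm kernel would replace.
    total = mx.nd.add_n(*[mx.nd.square(mx.nd.norm(a)) for a in arrays])
    return mx.nd.sqrt(total)

def clip_global_norm(arrays, max_norm):
    norm = global_grad_norm(arrays).asscalar()
    if norm > max_norm:
        scale = max_norm / norm
        for a in arrays:
            a *= scale
    return norm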

@reminisce
Contributor

@ptrendx Thanks for the script. I think a large part of the overhead for zeroing ndarrays individually in Python comes from ndarray indexing, FFI, and pushing operators to the async engine. I modified your script a little bit to demonstrate the point.

import mxnet as mx
import time

arrays = [mx.nd.ones((100,100), ctx=mx.gpu()) for _ in range(500)]

for a in arrays:
    a[:] = 0

num_repeats = 10

mx.nd.waitall()
start = time.time()
#for _ in range(num_repeats):
for a in arrays:
    mx.nd.zeros(a.shape, out=a)
end = time.time()
#print("async push per `mx.nd.zeros`: Elapsed ", (end - start) / num_repeats / len(arrays))
print("async push per `mx.nd.zeros`: Elapsed ", (end - start) / len(arrays))

mx.nd.waitall()
start = time.time()
for _ in range(num_repeats):
    for a in arrays:
        mx.nd.zeros(a.shape, out=a)
mx.nd.waitall()
end = time.time()
#print("normal: Elapsed ", (end - start) / num_repeats)
print("normal: Elapsed ", (end - start))

mx.nd.waitall()
start = time.time()
for _ in range(num_repeats):
    with mx.engine.bulk(len(arrays)):
        for a in arrays:
            mx.nd.zeros(a.shape, out=a)
mx.nd.waitall()
end = time.time()
#print("bulk: Elapsed ", (end - start) / num_repeats)
print("bulk: Elapsed ", (end - start))

mx.nd.waitall()
start = time.time()
for _ in range(100):
    mx.nd.reset_arrays(*arrays, num_arrays=len(arrays))
end = time.time()
print("async push per `reset_arrays`: Elapsed ", (end - start) / 100)
#print("reset_arrays: Elapsed ", (end - start) / num_repeats)

mx.nd.waitall()
start = time.time()
for _ in range(num_repeats):
    mx.nd.reset_arrays(*arrays, num_arrays=len(arrays))
mx.nd.waitall()
end = time.time()
print("reset_arrays: Elapsed ", (end - start))
#print("reset_arrays: Elapsed ", (end - start) / num_repeats)

and got these results:

async push per `mx.nd.zeros`: Elapsed  7.888364791870118e-05
normal: Elapsed  0.3912644386291504
bulk: Elapsed  0.3276066780090332
async push per `reset_arrays`: Elapsed  0.0005680346488952637
reset_arrays: Elapsed  0.019466638565063477

If you calculate the overhead of invoking the zeroing of 500 ndarrays with 10 repeats (roughly excluding the kernel execution time), it's 8.108711242675781e-05 * 500 * 10 = 0.40543556213378906 seconds. This is just an estimate, but it shows how significant the accumulated overhead of invoking operators is for small ops.

I agree that in this situation we should keep reset_arrays as an intermediate solution to keep performance on par, and we will continue to optimize the latency of invoking operators.

@sxjscience
Member Author

@szha Is it good to merge?

@reminisce
Contributor

@szha @ptrendx From the discussion, I think we have aligned on keeping nd.reset_arrays as-is for the legacy mode and favoring the NumPy-like programming style in the new mode. We will continue to optimize operator invocation latency so that zeroing a list of np.ndarrays can be as performant as zeroing them all through a single operator. What do you think?

@sxjscience sxjscience merged commit e88e97f into apache:master Nov 13, 2019
@sxjscience sxjscience deleted the fix_zero_grad_gluon branch November 13, 2019 19:16
ptrendx pushed a commit to ptrendx/mxnet that referenced this pull request Nov 15, 2019
ptrendx added a commit that referenced this pull request Nov 16, 2019
…, #16792) (#16832)

* Fix nightly build (#16773)

* Remove dependency on tvmop.conf

* Fix binaries dependencies for ni nightly

* Add comments

* Update tvmop.py

* Fix rebase

* Fix (#16781)

* Speed fused_op compilation by caching ptx and jit-compiled functions (#16783)

* [Numpy] Fix collect_params().zero_grad() in gluon numpy interface (#16716)

* fix zero_grad

* Update parameter.py

* add test

* fix

* Mixed data type binary ops (#16699)

* support mixed-precision binary operations

* improvement for documentations and error messages

* Support boolean elemwise/broadcast binary add, multiply and true_divide (#16728)

* support pure boolean elemwise/broadcast binary op

* switch to unique_tpr

* fix the test error

* Fix rtrue_divide grad (#16769)

* Fix rtrue_divide_scalar

* More tests

* Fix numpy-compatible mean output type for integer inputs (#16792)

* fix mean output type for integer inputs

* enable for windows