
Batching improvements for GEMM/TRSM operators and full MKL usage docs. #8846

Merged
merged 2 commits into apache:master from meissnereric:gemm_trsm_batching on Jan 15, 2018

Conversation

@meissnereric (Contributor) commented Nov 28, 2017

Description

During some benchmarking I discovered that the CUDA internal batching implementations for the trsm and gemm operators were slow for large matrices. For gemm, I found that the gemmStridedBatched implementation is faster at all matrix sizes, so we should use it when possible (CUDA 8+). Otherwise, since most use cases for these operators involve relatively large matrices, we use a simple for loop over per-matrix calls instead of the dedicated batched CUDA implementation.
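As a rough illustration of that dispatch, here is a minimal sketch assuming plain column-major cuBLAS conventions and contiguously packed batches; batch_gemm_sketch and its parameters are hypothetical helpers, not the PR's actual mshadow macro code:

// Hedged sketch, not the PR code: use cublasSgemmStridedBatched on CUDA 8+,
// otherwise fall back to a simple loop of single-matrix cublasSgemm calls.
#include <cuda.h>          // CUDA_VERSION
#include <cublas_v2.h>

void batch_gemm_sketch(cublasHandle_t handle, int batch, int m, int n, int k,
                       float alpha, const float* A, const float* B,
                       float beta, float* C) {
#if CUDA_VERSION >= 8000
  // One call covers the whole batch; strides are the per-matrix element counts.
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                            A, m, static_cast<long long>(m) * k,
                            B, k, static_cast<long long>(k) * n, &beta,
                            C, m, static_cast<long long>(m) * n, batch);
#else
  // Manual batching: one gemm per matrix; for the large matrices that dominate
  // real use cases this is faster than the pre-CUDA-8 batched kernels.
  for (int i = 0; i < batch; ++i) {
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                A + static_cast<size_t>(i) * m * k, m,
                B + static_cast<size_t>(i) * k * n, k, &beta,
                C + static_cast<size_t>(i) * m * n, m);
  }
#endif
}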

Also added instructions for how to compile with a full MKL installation instead of just the MKL2017 subset.

Checklist

Essentials

  • [ X ] Passed code style checking (make lint)
  • [ X ] Changes are complete (i.e. I finished coding on this PR)
  • [ X ] All changes have test coverage
  • [ X ] For user-facing API changes, API doc string has been updated. For new C++ functions in header files, their functionalities and arguments are well-documented.
  • [ X ] To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Changed the GEMM operator to use the gemmStridedBatched CUDA implementation when the CUDA version is 8 or higher, and to do the batching manually otherwise.
  • Changed the TRSM operator to no longer use the CUDA batching functionality, as it is slower for large matrices; the batching is now done manually (see the sketch after this list).
  • Added instructions for using a full MKL installation instead of just MKL2017.
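A minimal sketch of the manual trsm batching mentioned above, assuming column-major cuBLAS conventions, a left-side lower-triangular solve, and contiguously packed batches; batch_trsm_sketch is a hypothetical helper, not the PR's macro code:

// Hedged sketch, not the PR code: batch trsm as a plain loop of cublasStrsm
// calls, solving op(A_i) * X_i = alpha * B_i and overwriting B_i with X_i.
#include <cublas_v2.h>

void batch_trsm_sketch(cublasHandle_t handle, int batch, int m, int n,
                       float alpha, const float* A, float* B) {
  for (int i = 0; i < batch; ++i) {
    cublasStrsm(handle, CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
                CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m, n, &alpha,
                A + static_cast<size_t>(i) * m * m, m,   // m x m triangular A_i
                B + static_cast<size_t>(i) * m * n, m);  // m x n right-hand side B_i
  }
}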


Reviewers

linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
using namespace mshadow::cuda; \
int ngrid = std::min(kMaxGridNum, \
Contributor:

ngrid is not needed anymore

using namespace mshadow::cuda; \
int ngrid = std::min(kMaxGridNum, \
static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \
linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
Contributor:

Please remove the function linalgCollectBatchOffsetsGPU from the file, as it should not be needed anymore.

Contributor Author:

Done.
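For context on why the offset-collection kernel and the ngrid computation become dead code: the pointer-array cuBLAS API (cublas<t>gemmBatched) consumes device arrays of per-matrix pointers, which that kernel assembled on the GPU. Below is a hedged sketch of that now-removed pattern, with the pointers collected on the host for simplicity; it is not the original linalg_impl.h code.

// Hedged illustration of the removed pattern. cublasSgemmBatched needs an
// array of per-matrix device pointers; the strided-batched call or a plain
// loop needs none, so linalgCollectBatchOffsetsGPU and ngrid can go away.
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

void old_pointer_array_gemm_sketch(cublasHandle_t handle, int batch,
                                   int m, int n, int k, float alpha,
                                   const float* A, const float* B,
                                   float beta, float* C) {
  std::vector<const float*> hA(batch), hB(batch);
  std::vector<float*> hC(batch);
  for (int i = 0; i < batch; ++i) {  // per-matrix base pointers
    hA[i] = A + static_cast<size_t>(i) * m * k;
    hB[i] = B + static_cast<size_t>(i) * k * n;
    hC[i] = C + static_cast<size_t>(i) * m * n;
  }
  const float **dA = nullptr, **dB = nullptr;
  float **dC = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dA), batch * sizeof(*dA));
  cudaMalloc(reinterpret_cast<void**>(&dB), batch * sizeof(*dB));
  cudaMalloc(reinterpret_cast<void**>(&dC), batch * sizeof(*dC));
  cudaMemcpy(dA, hA.data(), batch * sizeof(*dA), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), batch * sizeof(*dB), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), batch * sizeof(*dC), cudaMemcpyHostToDevice);
  cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                     dA, m, dB, k, &beta, dC, m, batch);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
}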

@@ -1,3 +1,21 @@
# Full MKL Installation
@asmushetzel (Contributor) commented Nov 28, 2017:

You should mention the purpose of doing so, i.e. that this will enable MKL for all operators in the linalg-namespace.

What about this piece of code (I guess it is still in the config):
ifeq ($(USE_BLAS), mkl)
USE_LAPACK = 0
endif

Guess this has to be changed as well.

And unfortunately this does not work exactly as planned. With the suggested settings, a user would get MKL for BLAS/LAPACK, but at the same time setting USE_MKL2017=0 would internally switch off the use of MKLML for a lot of NN operators. Setting USE_MKL2017=0 was just a shortcut for our experiments with the linalg operators.

Ideally, the mechanism should work such that the user just sets USE_BLAS=mkl and that is it. In addition, they can set USE_MKL2017, and then some other operators will also start using MKL's NN functions.

Contributor Author:

Okay, I believe I changed the config file to adhere to the behavior we want.

Contributor:

There seem to be no config changes now except the removal of the piece of code above. Is this enough? I.e., have you tried that specifically the variant
USE_BLAS=mkl
USE_MKL2017=1
works?

@@ -103,8 +103,8 @@ void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<
LINALG_CPU_GEMM(sgemm, float)
LINALG_CPU_GEMM(dgemm, double)

LINALG_CPU_BATCH_GEMM(float)
LINALG_CPU_BATCH_GEMM(double)
LINALG_XPU_BATCH_GEMM(cpu, float)
Contributor:

This causes an error in the amalgamation build, since there MSHADOW_USE_CBLAS==0 && MSHADOW_USE_MKL==0. As a consequence, the dummy stubs starting at line 84 will be generated, but they are named "CPU_GEMM" and not "XPU_GEMM". So you may have to change the names of these dummy stubs such that a call to LINALG_XPU_BATCH_GEMM(cpu, ...) correctly generates these stubs instead.
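A hedged illustration of the rename the reviewer is asking for; this is approximate, assumes the surrounding mshadow/dmlc declarations from linalg_impl.h, and is not the exact stub in the file:

// When MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0 (e.g. the amalgamation
// build), the dummy stub must use the same macro name as the call sites, so
// that LINALG_XPU_BATCH_GEMM(cpu, float) still expands to a definition.
#define LINALG_XPU_BATCH_GEMM(xpu, DType) \
  template<> inline \
  void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
                                     const Tensor<xpu, 3, DType>& B, \
                                     const Tensor<xpu, 3, DType>& C, \
                                     DType alpha, DType beta, \
                                     bool tA, bool tB, Stream<xpu>* s) { \
    LOG(FATAL) << "linalg_batch_gemm requires CBLAS or MKL."; \
  }

LINALG_XPU_BATCH_GEMM(cpu, float)
LINALG_XPU_BATCH_GEMM(cpu, double)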

Contributor Author:

Changed, let's see if it builds correctly.

@piiswrong (Contributor):
@asmushetzel Any updates?

@meissnereric (Contributor Author):
Apologies for the excess merge commits; I'm not sure how to clean those up.

@meissnereric (Contributor Author):
Hey, this was rebased against master just before pushing. I'm not sure why the pr-merge build isn't succeeding; is this alright?

@asmushetzel (Contributor):
The pr-merge error seems fishy, i.e. not related to your changes. There should never be any non-checked-in changes in a build.

MKL_README.md Outdated

1.1 Set ADD_LDFLAGS=-L<path/to/mkl/lib/folder> (ex. ADD_LDFLAGS=-L/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/lib)

1.1 Set ADD_CFLAGS=-L<path/to/mkl/include/folder> (ex. ADD_CFLAGS=-L/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/include)
Contributor:

Guess you mean the "-I" flag.

Contributor Author:

The examples there are what I ran with successfully, using "-L" not "-l". I'm not sure exactly which things would need to be included if I used "-l".

@meissnereric (Contributor Author) commented Jan 10, 2018:

Hmm, the pr-head build failed with:

docker: Error response from daemon: create nvidia_driver_384.111: VolumeDriver.Create: internal error, check logs for details.
See 'docker run --help'.
script returned exit code 125

I don't think that has to do with my changes, but is there anything I can do to help find/fix whatever the issue is?

Apart from that failure, once this passes I think these changes are good to go in.

@asmushetzel (Contributor):
@marcoabreu Marco, this seems like a CI problem. Any advice on how Eric can get this going again?

@meissnereric (Contributor Author):
Looks like something else went wrong this time; the build has been hanging for 18 hours now on most of the GPU builds. Is there an easy way to restart/rerun the CI tests?

@marcoabreu (Contributor):
The nvidia_driver_384 issue was due to a security patch for Spectre. We reinstalled all slaves yesterday evening and it should be working now.
Regarding the hang: we're still using the Windows slaves from the old system and they are pretty unreliable, but we're going to set them up from scratch soon.

In order to trigger a new build, just make a new commit. The old build will time out after 24h.

@meissnereric (Contributor Author):
Cool, thanks Marco. I thought it might be related to the Spectre change. I just pushed a new commit; hopefully it doesn't hit the same hang this time.

@asmushetzel (Contributor):
IMO, this can be integrated now.

@piiswrong merged commit 3ac5376 into apache:master on Jan 15, 2018
CodingCat pushed a commit referencing this pull request (apache#8846) to CodingCat/mxnet on Jan 16, 2018.
@meissnereric deleted the gemm_trsm_batching branch on Jan 16, 2018.
larroy pushed a commit referencing this pull request (apache#8846) to larroy/mxnet on Jan 18, 2018.
yuxiangw pushed a commit referencing this pull request (apache#8846) to yuxiangw/incubator-mxnet on Jan 25, 2018.
rahul003 pushed a commit referencing this pull request (apache#8846) to rahul003/mxnet on Jun 4, 2018.
zheng-da pushed a commit referencing this pull request (apache#8846) to zheng-da/incubator-mxnet on Jun 28, 2018.