Accelerated Backends (CPU & GPU) #8

Open · 11 of 18 tasks
TiarkRompf opened this issue Sep 20, 2018 · 8 comments
Labels: feature (New feature or request)

TiarkRompf (Collaborator) commented Sep 20, 2018

High priority:

  • cuBLAS
  • cuDNN

CUDA backend todos:

  • Implement cuDNN ops (forward and backward). Prioritize ones used in models.
    • Done. 2-D convolution.
    • Done. Activation functions (relu, tanh, sigmoid).
    • Done. Softmax and log-softmax.
    • Done. Max-pooling.
    • Reduction ops (e.g. sum).
      • Done. Implement cudnnReduceTensor, which supports many reduction operations (sum/product/avg/argmax/norm).
      • Implement reduction backward functions. This may be facilitated by custom kernel generation. cuDNN doesn't seem to provide a backward function for reduction ops.
    • Loss functions (e.g. nll, mse).
      • Done. nll. Currently, GPU impl uses CPU impl + copying. Revisit when custom kernel generation is possible.
    • Batch normalization (lower priority if it's not used in models).
  • Revisit GPU elementwise broadcasting op implementation.
    • The current approach comes from PyTorch ATen and involves multiple template functions (gpu_binary_kernel -> launch_kernel -> elementwise_kernel). We should be able to generate simpler, specialized kernels that potentially skip bounds checking (see the sketch below this list). This will likely improve performance (though it might not matter too much for full models).
  • Implement custom kernel generation. (WIP by @dan-zheng)
    • Learn from UninlinedFunctionOps in snek. Basically, duplicate the LMS Lambda infrastructure and add a gpuFunc DSL function mimicking fun.
    • Custom kernels/device functions will facilitate some things. Revisit GPU elementwise op and reduction op implementations.
  • Fix memory management. (assigned to @feiwang3311)
    • Currently, myGpuMalloc uses a memory arena, much like myMalloc on CPU. However, this produced some errors. A more robust approach that prevents memory leaks is necessary for training loops.
      • GPU backend cleanup fails: CUDA_CALL(cudaFree(gpuMallocAddr)) produces CUDA error occurred: invalid device pointer. This doesn't manifest in GPU tests containing backend = BackendCPU() because backend isn't reset, so the cudaFree cleanup is never called.
      • cuDNN dropout failed due to some alignment issue: CUDA error occurred: misaligned address. Using cudaMalloc fixed the issue.
  • Add Tensor helper method to produce cudnnTensorDescriptor_t.
    • Currently, cudnnTensorDescriptor_t is constructed on-demand at call sites, which is inefficient/redundant.
    • We can learn from PyTorch's various Descriptor wrappers.
    • Should we add descriptor nodes for robust lowering? Currently, descriptors are constructed within unchecked calls, which leads to name conflicts. The current workaround is to use ad-hoc block scopes: { cudnnTensorDescriptor_t x_desc; ... }.
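
For the elementwise broadcasting item above, here is a minimal sketch (not Lantern's actual codegen; the emitter name and fixed block size are illustrative assumptions) of what a shape-specialized kernel emitter could look like:

object SpecializedKernelSketch {
  // Emit an elementwise-add kernel specialized to a fixed element count.
  // The launch configuration is derived from the same constant, so the kernel
  // body needs no bounds check (unlike the generic ATen-style template).
  def emitAddKernel(numElements: Int, blockSize: Int = 256): String = {
    require(numElements % blockSize == 0, "pad, or fall back to a bounds-checked kernel")
    val numBlocks = numElements / blockSize
    s"""__global__ void add_$numElements(const float* a, const float* b, float* out) {
       |  int i = blockIdx.x * blockDim.x + threadIdx.x;
       |  out[i] = a[i] + b[i];
       |}
       |// launch: add_$numElements<<<$numBlocks, $blockSize>>>(a, b, out);
       |""".stripMargin
  }

  def main(args: Array[String]): Unit = println(emitAddKernel(1 << 20))
}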

Medium priority:

  • OpenBLAS/MKL (CPU)

Other potential back-ends to look into: TVM, XLA, ...

dan-zheng self-assigned this Sep 20, 2018
TiarkRompf added this to the M1 milestone Sep 20, 2018
TiarkRompf added the feature label Sep 20, 2018
dan-zheng (Collaborator) commented:

A robust solution for GPU support needs to reconcile CPU-only operations (e.g. printf) and operations that should be run on GPU (e.g. cuBLAS library functions).

For cuBLAS/cuDNN support, I think I'll start with the naive implementation of allocating all tensors on GPU memory. This is the shortest path to testing GPU code generation.

However, this essentially breaks all operations that aren't defined with GPU memory in mind: printf certainly won't work (unless it's modified to copy tensors to CPU memory) and even ops like elementwise addition need to be rewritten using library functions like cublasSaxpy.

Redefining many ops for GPU support greatly increases the surface area of the Backend trait, which is not ideal. If you have ideas for avoiding this, or if you have other ideas/feedback about backend support, please share!
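
To make the surface-area concern concrete, here is a minimal sketch; the trait and method names are illustrative assumptions, not Lantern's actual Backend trait:

// Every op that reads or writes tensor memory needs one definition per backend,
// which is what inflates the trait as GPU support grows.
case class Tensor(data: Array[Float])

trait Backend {
  def add(a: Tensor, b: Tensor): Tensor   // GPU version must go through cublasSaxpy or a custom kernel
  def show(t: Tensor): Unit               // GPU version must copy to host memory before printing
}

class BackendNative extends Backend {
  def add(a: Tensor, b: Tensor): Tensor =
    Tensor(a.data.zip(b.data).map { case (x, y) => x + y })
  def show(t: Tensor): Unit = println(t.data.mkString(" "))
}

Each additional op (matmul, conv, softmax, ...) adds another method that every backend must implement or explicitly reject.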

dan-zheng (Collaborator) commented Oct 3, 2018

@jmd1011 had the idea of running all ops on the CPU by default, and only using GPU ops within an explicitly demarcated section of code (e.g. a training loop). I feel like this design is facilitated by the flexible Backend trait implementation: simply change the backend value to change the codegen target.

This approach leads to a better separation of concerns: rather than handling arbitrary interleavings of CPU and GPU ops (which effectively requires each op to worry about the device allocation of its arguments and result), only whole "chunks" of CPU or GPU code are handled, and within a chunk ops can assume that all tensors live on the same device. This means the backend-swapping code, rather than every single op, is responsible for copying tensors between devices.

// Adapted from mnistCNN.scala.
val mnist = new DslDriverC[String, Unit] with TensorExp {
  def snippet(a: Rep[String]): Rep[Unit] = {
    // The backend is initially CPU (`var backend: BackendNative`).
    val data = new DataLoader("mnist")
    ...

    // Start training loop. Generate GPU ops!
    backend = new BackendCudnn
    for (epoch <- 0 until epochCount: Rep[Range]) {
       data foreach { (input: Tensor, target: Rep[Int]) =>
         // It's nice to have a way to print values within the training loop.
         // Some ad-hoc mechanism for communication would be good.
         // Strawman syntax:
         // `printf("Loss: %f\n", loss.toCPU())`
         ...
       }
    }

    // Change backend back to CPU.
    backend = new BackendNative
    printf(...)
  }
}

This idea seems similar to "device placement" in TensorFlow:

results = []
a = tf.get_variable("a", ...)
b = tf.get_variable("b", ...)

# GPU 0 performs matmul.
with tf.device('/gpu:0'):
    results.append(tf.matmul(a, b))

# GPU 1 performs addition.
with tf.device('/gpu:1'):
    results.append(a + b)

# TensorFlow handles copying tensors between devices.
with tf.device('/cpu:0'):
    sum = tf.add_n(results)

Here's the equivalent feature in Swift for TensorFlow. It should be possible to implement a similar API in Lantern:

Original incomplete prototype:
// Not sure what the type of `f` should be. Any tips?
def withBackend(b: Backend, f: ??? -> ???) = {
  val originalBackend = backend
  // Copy tensors to the new backend.
  // Question: which tensors need to be copied?
  // Answer: the ones that are passed as arguments to `f`.
  // Change the backend (i.e. codegen target).
  backend = b
  // Call `f`.
  val result = f(...)
  // Copy `result` to the old backend, then reset the backend.
  backend = originalBackend
}

// Revised based on @GSAir's suggestion below.
def withBackend[T, U](b: Backend, input: T)(f: T => U) = {
  val originalBackend = backend
  // Transfer input to the new backend.
  transferBetweenBackends(originalBackend, b, input)

  // Change the backend (i.e. codegen target), then call `f`.
  backend = b
  val result = f(input)

  // Transfer `result` to the old backend, then reset the backend.
  transferBetweenBackends(b, originalBackend, result)
  backend = originalBackend
  result
}

// Usage:
def withGPU[T, U](input: T)(f: T => U) = withBackend(BackendCudnn, input)(f)

// Type-inference: `withGPU[Tensor, Tensor]` below.
withGPU(Tensor.ones(2, 3)) { x => x + x }

If you have feedback or related ideas, please share!

GSAir (Collaborator) commented Oct 4, 2018

For the type of f, you may want to be flexible:

def withBackend[T,U](b: Backend, input: T)(f: T => U) = {
}

The curried form allows you to do:

withBackend[Int, Unit](CPU, 0) { in =>
    printf("%d\n", in)
}

dan-zheng added a commit to dan-zheng/Lantern that referenced this issue Oct 6, 2018
`withBackend` explicitly demarcates code that should be run on a
different backend. It transfers inputs/results between backends
automatically.

Design info: feiwang3311#8 (comment)
dan-zheng (Collaborator) commented Oct 7, 2018

I propose to change the cuDNN backend into a cuBLAS+cuDNN backend.

cuDNN by itself defines high-level NN operations, like convolutions and activation functions.
However, it doesn't define lower-level primitives, like matrix multiplication or basic elementwise ops.
Thus, a standalone cuDNN backend is not particularly useful.

A cuBLAS+cuDNN backend can use cuBLAS for low-level linear algebra primitives and cuDNN for optimized high-level NN ops.
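
A rough sketch of that layering (the class names match the BackendCudnn extends BackendCublas arrangement mentioned in the reply below; the bodies are placeholders, not real implementations):

object BackendLayeringSketch {
  type Tensor = Array[Float]  // stand-in for Lantern's Tensor type

  trait Backend {
    def matmul(a: Tensor, b: Tensor): Tensor       // low-level linear algebra
    def conv2D(x: Tensor, kernel: Tensor): Tensor  // high-level NN op
  }

  // cuBLAS backend: provides low-level primitives (e.g. via cublasSgemm).
  class BackendCublas extends Backend {
    def matmul(a: Tensor, b: Tensor): Tensor = ???
    def conv2D(x: Tensor, kernel: Tensor): Tensor = ???
  }

  // cuBLAS+cuDNN backend: inherits the cuBLAS primitives and overrides only
  // the ops that cuDNN accelerates (e.g. via cudnnConvolutionForward).
  class BackendCudnn extends BackendCublas {
    override def conv2D(x: Tensor, kernel: Tensor): Tensor = ???
  }
}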

dan-zheng added a commit to dan-zheng/Lantern that referenced this issue Oct 7, 2018
Rationale here: feiwang3311#8 (comment)

- Move GPU test utilities to `LanternFunSuite`.
- Add cuDNN test.
dan-zheng added a commit to dan-zheng/Lantern that referenced this issue Oct 7, 2018
Rationale here: feiwang3311#8 (comment)

- Move GPU test utilities to `LanternFunSuite`.
- Improve CUDA/cuBLAS/cuDNN error messages.
  - Example: "cuBLAS error occurred: 7 (lantern-snippet.cpp:150)"
- Add cuDNN test.
feiwang3311 (Owner) commented:

https://github.com/feiwang3311/Lantern/blob/master/src/main/scala/lantern/ad_lms_vector.scala#L522

@dan-zheng should this line be comparing this.shape(1) with that.shape(0)?
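
(For context: in a 2-D matrix product (m x k) * (k x n) the inner dimensions must agree, so the check should indeed compare this.shape(1) with that.shape(0). A standalone illustration with hypothetical names, not the actual Lantern code:)

object MatmulDimCheckSketch {
  // Validate the inner dimensions of a 2-D matrix product (m x k) * (k x n).
  def checkMatmulDims(aShape: Seq[Int], bShape: Seq[Int]): Unit =
    require(aShape(1) == bShape(0),
      s"matmul dimension mismatch: ${aShape(1)} (cols of lhs) != ${bShape(0)} (rows of rhs)")
}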

dan-zheng (Collaborator) commented:

Thanks for the catch! Fixed in db0a80f.

TiarkRompf (Collaborator, Author) commented:

I propose to change the cuDNN backend into a cuBLAS+cuDNN backend.

Absolutely makes sense. The use case I had in mind was cuBLAS without cuDNN, but that's covered with BackendCudnn extends BackendCublas.

dan-zheng (Collaborator) commented:

FYI: I added a concrete todo list to the issue description.
Preliminary MNIST CNN support is nearly done. Afterwards, we can evaluate performance and optimize.
