Skip to content

Commit

Permalink
Follow up on BC issue for open sourcing TBE inplace update op
Browse files Browse the repository at this point in the history
Reviewed By: jspark1105

Differential Revision: D41717190

fbshipit-source-id: 5a32a5c183db1364082638ee945d19c33a0de950
  • Loading branch information
jianyuh authored and facebook-github-bot committed Dec 4, 2022
1 parent 4cd267c commit 79eb9f9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,22 @@ void embedding_inplace_update_cuda(
c10::optional<Tensor> lxu_cache_weights = c10::nullopt,
c10::optional<Tensor> lxu_cache_locations = c10::nullopt);

void embedding_inplace_update_cpu(
Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor weights_tys,
Tensor D_offsets,
Tensor update_weights,
Tensor update_table_idx,
Tensor update_row_idx,
Tensor update_offsets,
const int64_t row_alignment,
c10::optional<Tensor> lxu_cache_weights =
c10::nullopt, // Not used, to match cache interface for CUDA op
c10::optional<Tensor> lxu_cache_locations =
c10::nullopt // Not used, to match cache interface for CUDA op
);

} // namespace fbgemm_gpu
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/embedding_inplace_update.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <c10/cuda/CUDAGuard.h>

#include "embedding_inplace_update.h"
#include "fbgemm_gpu/embedding_inplace_update.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"

using Tensor = at::Tensor;
Expand Down
9 changes: 3 additions & 6 deletions fbgemm_gpu/src/embedding_inplace_update_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "embedding_inplace_update.h"
#include "fbgemm_gpu/embedding_inplace_update.h"

using Tensor = at::Tensor;

Expand Down Expand Up @@ -72,11 +72,8 @@ void embedding_inplace_update_cpu(
Tensor update_row_idx,
Tensor update_offsets,
const int64_t row_alignment,
c10::optional<Tensor> lxu_cache_weights =
c10::nullopt, // Not used, to match cache interface for CUDA op
c10::optional<Tensor> lxu_cache_locations =
c10::nullopt // Not used, to match cache interface for CUDA op
) {
c10::optional<Tensor> lxu_cache_weights,
c10::optional<Tensor> lxu_cache_locations) {
TENSOR_ON_CPU(dev_weights);
TENSOR_ON_CPU(uvm_weights);
TENSOR_ON_CPU(weights_placements);
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/embedding_inplace_update_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>
#include "embedding_inplace_update.h"
#include "fbgemm_gpu/embedding_inplace_update.h"

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
DISPATCH_TO_CUDA(
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/embedding_inplace_update_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/
#include <folly/Random.h>
#include <gtest/gtest.h>
#include "deeplearning/fbgemm/fbgemm_gpu/src/embedding_inplace_update.h"
#include "fbgemm_gpu/embedding_inplace_update.h"

using namespace ::testing;
using namespace fbgemm_gpu;
Expand Down

0 comments on commit 79eb9f9

Please sign in to comment.