
Commit

Fix prefetch list issues (thread-safe access in the forward function) (#810)

Summary:
Pull Request resolved: #810

## Root cause:

In inference, the same TBE class instance is shared by all sparse streams (different GPU sparse workers) across multiple threads. This means the variables in the class are effectively global state (e.g., https://fburl.com/phabricator/r5m9sq38). When the predictor handles multiple requests, the `forward` function is invoked from multiple threads concurrently, so mutable members (e.g., `self.timestep`, `self.timesteps_prefetched`, etc.) are subject to race conditions and produce the errors reported in the following post.
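
To make the hazard concrete, here is a minimal, hypothetical Python sketch (the class name and loop counts are made up; this is not the FBGEMM code) of the check-then-pop pattern that becomes unsafe once `forward` runs on several threads:

```python
import threading

class ToyModule:
    """Toy stand-in (not FBGEMM code) for the shared mutable state in the TBE module."""

    def __init__(self) -> None:
        self.timestep = 0
        self.timesteps_prefetched = []

    def prefetch(self) -> None:
        # Unsynchronized read-modify-write: concurrent callers can lose updates.
        self.timestep += 1
        self.timesteps_prefetched.append(self.timestep)

    def forward(self) -> None:
        if len(self.timesteps_prefetched) == 0:
            self.prefetch()
        # Another thread can drain the list between the check above and this
        # pop, which may raise "IndexError: pop from empty list".
        self.timesteps_prefetched.pop(0)

mod = ToyModule()
workers = [
    threading.Thread(target=lambda: [mod.forward() for _ in range(10_000)])
    for _ in range(8)
]
for t in workers:
    t.start()
for t in workers:
    t.join()
```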

## Solution:
- Create a thread-safe torch custom class `AtomicCounter` for the timestep counter.
- Remove the `timesteps_prefetched` list and use an atomic counter to track the number of outstanding prefetch steps.
- Create a thread-safe torch custom class `TensorQueue` for `lxu_cache_locations_list`; a condensed sketch of the resulting prefetch/forward flow follows below.
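
The following is a condensed sketch of the new bookkeeping, mirroring the Python diff below (cache-lookup details omitted; it assumes `fbgemm_gpu` is installed so that importing it registers the custom classes under `torch.classes.fbgemm`):

```python
import torch
import fbgemm_gpu  # noqa: F401  # loads the extension that registers the custom classes

class PrefetchBookkeeping(torch.nn.Module):
    """Simplified stand-in for the TBE prefetch/forward bookkeeping."""

    def __init__(self) -> None:
        super().__init__()
        self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
        self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()
        # pop() on an empty queue returns this sentinel instead of raising.
        self.lxu_cache_locations_list = torch.classes.fbgemm.TensorQueue(
            torch.empty(0, dtype=torch.int32).fill_(-1)
        )

    def prefetch(self) -> None:
        self.timestep_counter.increment()
        self.timestep_prefetch_size.increment()
        # The real module pushes torch.ops.fb.lxu_cache_lookup(...) results here.
        self.lxu_cache_locations_list.push(torch.zeros(4, dtype=torch.int32))

    def forward(self) -> torch.Tensor:
        if self.timestep_prefetch_size.get() <= 0:
            self.prefetch()
        self.timestep_prefetch_size.decrement()
        return self.lxu_cache_locations_list.pop()
```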

Reviewed By: yinghai

Differential Revision: D32954954

fbshipit-source-id: fc9cdd394c50832d4bdf455eb0336308eaff49fd
jianyuh authored and facebook-github-bot committed Dec 12, 2021
1 parent ff93b93 commit 60be17f
Showing 2 changed files with 156 additions and 27 deletions.
132 changes: 132 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp
@@ -288,3 +288,135 @@ static auto PrunedMapCPURegistry =
[](std::string data) -> c10::intrusive_ptr<PrunedMapCPU> {
return c10::make_intrusive<PrunedMapCPU>(data);
});

class AtomicCounter : public torch::jit::CustomClassHolder {
public:
AtomicCounter() {
counter_ = 0;
}
explicit AtomicCounter(std::string serialized) {
std::stringstream ss(serialized);
int64_t val;
ss >> val;
counter_ = val;
}
int64_t increment() {
return counter_++;
}
int64_t decrement() {
return counter_--;
}
void reset() {
counter_ = 0;
}
int64_t get() {
return counter_;
}
void set(int64_t val) {
counter_ = val;
}

std::string serialize() const {
std::ostringstream oss;
oss << counter_;
return oss.str();
}

private:
std::atomic<int64_t> counter_{0};
};

static auto AtomicCounterRegistry =
torch::class_<AtomicCounter>("fbgemm", "AtomicCounter")
.def(torch::init<>())
.def("increment", &AtomicCounter::increment)
.def("decrement", &AtomicCounter::decrement)
.def("reset", &AtomicCounter::reset)
.def("get", &AtomicCounter::get)
.def("set", &AtomicCounter::set)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<AtomicCounter>& self) -> std::string {
return self->serialize();
},
// __setstate__
[](std::string data) -> c10::intrusive_ptr<AtomicCounter> {
return c10::make_intrusive<AtomicCounter>(data);
});

// Thread-safe Tensor Queue
struct TensorQueue : torch::CustomClassHolder {
explicit TensorQueue(Tensor t) : init_tensor_(t) {}

explicit TensorQueue(std::string serialized) {
torch::serialize::InputArchive archive;
archive.load_from(serialized.data(), serialized.size());

archive.read(std::string("init_tensor"), init_tensor_);
std::string key = "queue";
Tensor size_tensor;
archive.read(std::string(key + "/size"), size_tensor);
const auto* size_tensor_acc = size_tensor.data_ptr<int64_t>();
int64_t queue_size = size_tensor_acc[0];

for (const auto index : c10::irange(queue_size)) {
Tensor val;
archive.read(key + "/" + c10::to_string(index), val);
queue_.push_back(val);
}
}

std::string serialize() const {
torch::serialize::OutputArchive archive(
std::make_shared<torch::jit::CompilationUnit>());
std::ostringstream oss;
archive.write(std::string("init_tensor"), init_tensor_);
std::string key = "queue";
archive.write(
key + "/size", torch::tensor(static_cast<int64_t>(queue_.size())));
for (const auto index : c10::irange(queue_.size())) {
archive.write(key + "/" + c10::to_string(index), queue_[index]);
}
archive.save_to(oss);
return oss.str();
}

void push(Tensor x) {
std::lock_guard<std::mutex> guard(mutex_);
queue_.push_back(x);
}
Tensor pop() {
std::lock_guard<std::mutex> guard(mutex_);
if (!queue_.empty()) {
auto val = queue_.front();
queue_.pop_front();
return val;
} else {
return init_tensor_;
}
}
int64_t size() {
std::lock_guard<std::mutex> guard(mutex_);
return queue_.size();
}

private:
std::deque<Tensor> queue_;
std::mutex mutex_;
Tensor init_tensor_;
};

static auto TensorQueueRegistry =
torch::class_<TensorQueue>("fbgemm", "TensorQueue")
.def(torch::init<Tensor>())
.def("push", &TensorQueue::push)
.def("pop", &TensorQueue::pop)
.def("size", &TensorQueue::size)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self) -> std::string {
return self->serialize();
},
// __setstate__
[](std::string data) -> c10::intrusive_ptr<TensorQueue> {
return c10::make_intrusive<TensorQueue>(data);
});
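
For reference, a brief sketch of how the two classes registered above behave when driven from Python (it assumes `fbgemm_gpu` is installed and imported so that the library is loaded):

```python
import torch
import fbgemm_gpu  # noqa: F401  # loads the extension that registers the custom classes

counter = torch.classes.fbgemm.AtomicCounter()
assert counter.increment() == 0   # post-increment: returns the previous value
assert counter.get() == 1
counter.set(5)
counter.reset()
assert counter.get() == 0

sentinel = torch.empty(0, dtype=torch.int32).fill_(-1)
queue = torch.classes.fbgemm.TensorQueue(sentinel)
queue.push(torch.arange(3, dtype=torch.int32))
first = queue.pop()               # tensor([0, 1, 2], dtype=torch.int32)
fallback = queue.pop()            # queue drained: returns the sentinel tensor
assert fallback.numel() == 0 and queue.size() == 0
```

Both classes also define `__getstate__`/`__setstate__` via `def_pickle`, so scripted modules holding them can be saved and reloaded.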
51 changes: 24 additions & 27 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -1600,9 +1600,6 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
"""

embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
lxu_cache_locations_list: List[Tensor]
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]

def __init__(
self,
@@ -1810,11 +1807,11 @@ def align_to_cacheline(a: int) -> int:
cache_sets,
cache_reserved_memory,
)
self.step = 0

@torch.jit.export
def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.timestep += 1
self.timesteps_prefetched.append(self.timestep)
self.timestep_counter.increment()
self.timestep_prefetch_size.increment()
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
# int], Tensor], Tensor, nn.Module]` is not a function.
@@ -1840,7 +1837,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
linear_cache_indices,
self.lxu_cache_state,
self.lxu_cache_weights,
self.timestep,
self.timestep_counter.get(),
self.lxu_state,
)
elif self.cache_algorithm == CacheAlgorithm.LFU:
@@ -1859,9 +1856,9 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
)

assert (
len(self.lxu_cache_locations_list) < self.max_prefetch_depth
), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
self.lxu_cache_locations_list.append(
self.lxu_cache_locations_list.size() < self.max_prefetch_depth
), f"self.lxu_cache_locations_list has grown to size: {self.lxu_cache_locations_list.size()}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"
self.lxu_cache_locations_list.push(
torch.ops.fb.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
@@ -1874,18 +1871,11 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
self.step += 1

if len(self.timesteps_prefetched) == 0:
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
self.timestep_prefetch_size.decrement()

self.timesteps_prefetched.pop(0)

lxu_cache_locations = (
self.lxu_cache_locations_empty
if len(self.lxu_cache_locations_list) == 0
else self.lxu_cache_locations_list.pop(0)
)
lxu_cache_locations = self.lxu_cache_locations_list.pop()

assert (
self.weight_initialized
@@ -2008,14 +1998,21 @@ def _apply_cache_state(
cache_reserved_memory: float,
) -> None:
self.cache_algorithm = cache_algorithm
self.timestep = 1
self.timesteps_prefetched = []
self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()

self.max_prefetch_depth = MAX_PREFETCH_DEPTH
self.lxu_cache_locations_list = []
self.lxu_cache_locations_empty = torch.empty(
0, device=self.current_device, dtype=torch.int32
).fill_(-1)

if self.current_device.type == "meta":
# To resolve "Cannot copy out of meta tensor; no data!" error
lxu_cache_locations_empty = torch.empty(0, dtype=torch.int32).fill_(-1)
else:
lxu_cache_locations_empty = torch.empty(
0, device=self.current_device, dtype=torch.int32
).fill_(-1)
self.lxu_cache_locations_list = torch.classes.fbgemm.TensorQueue(
lxu_cache_locations_empty
)

# NOTE: no cache for CPU mode!
if cache_state.total_cache_hash_size == 0 or self.use_cpu:
@@ -2151,7 +2148,7 @@ def reset_cache_states(self) -> None:
return
self.lxu_cache_state.fill_(-1)
self.lxu_state.fill_(0)
self.timestep = 1
self.timestep_counter.reset()

@torch.jit.export
def split_embedding_weights(
