diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index f74dd490da0ac2..ef19af41e3a679 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -930,9 +930,9 @@ int GraphDataGenerator::GenerateBatch() { stream_); } if (sage_mode_) { - size_t temp_storage_bytes = slot_instance * slot_num_ * sizeof(uint64_t); + size_t temp_storage_bytes = slot_instance * fea_num_per_node_ * sizeof(uint64_t); // No need to allocate a new d_feature_buf_ if the old one is enough. - if (d_feature_buf_->size() < temp_storage_bytes) { + if (d_feature_buf_ == NULL || d_feature_buf_->size() < temp_storage_bytes) { d_feature_buf_ = memory::AllocShared(place_, temp_storage_bytes); } } @@ -1537,7 +1537,9 @@ void GraphDataGenerator::AllocResource( if (slot_num_ > 0) { if (!sage_mode_) { d_feature_buf_ = memory::AllocShared( - place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t)); + place_, (batch_size_ * 2 * 2) * fea_num_per_node_ * sizeof(uint64_t)); + } else { + d_feature_buf_ = NULL; } } d_pair_num_ = memory::AllocShared(place_, sizeof(int));