From f8d71f18fa1f4fdeb3d2e283ca70cc07ffd390ef Mon Sep 17 00:00:00 2001 From: funsimple Date: Thu, 27 May 2021 15:34:05 +0800 Subject: [PATCH] add GPU operator: sparse_fill_empty_rows and sparse_reshape --- .../dynamic_embedding/core/BUILD | 10 +- .../core/kernels/sparse_fill_empty_rows_op.cc | 61 ++++ .../kernels/sparse_fill_empty_rows_op.cu.cc | 276 ++++++++++++++++++ .../core/kernels/sparse_fill_empty_rows_op.h | 38 +++ .../core/kernels/sparse_reshape_op.cc | 49 ++++ .../core/kernels/sparse_reshape_op.cu.cc | 173 +++++++++++ .../core/kernels/sparse_reshape_op.h | 41 +++ .../dynamic_embedding/core/ops/math_ops.cc | 59 ++++ .../python/kernel_tests/math_grad_test.py | 32 ++ .../python/kernel_tests/math_ops_test.py | 68 ++++- .../dynamic_embedding/python/ops/BUILD | 2 +- .../dynamic_embedding/python/ops/math_grad.py | 17 +- .../dynamic_embedding/python/ops/math_ops.py | 214 +++++++++++++- 13 files changed, 1030 insertions(+), 10 deletions(-) create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cc create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.h create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.h diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD b/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD index 965e5b313..0931d51bb 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/BUILD @@ -23,15 +23,23 @@ custom_op_library( ) custom_op_library( - name = "_segment_reduction_ops.so", + name = "_math_ops.so", srcs = [ "kernels/segment_reduction_ops.h", "kernels/segment_reduction_ops_impl.cc", "kernels/segment_reduction_ops_impl.h", + "kernels/sparse_fill_empty_rows_op.cc", + "kernels/sparse_fill_empty_rows_op.h", + "kernels/sparse_reshape_op.cc", + "kernels/sparse_reshape_op.h", "ops/math_ops.cc", ], cuda_srcs = [ "kernels/segment_reduction_ops.h", "kernels/segment_reduction_ops_gpu.cu.cc", + "kernels/sparse_fill_empty_rows_op.h", + "kernels/sparse_fill_empty_rows_op.cu.cc", + "kernels/sparse_reshape_op.h", + "kernels/sparse_reshape_op.cu.cc", ], ) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cc new file mode 100644 index 000000000..c5fe30c6d --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cc @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "sparse_fill_empty_rows_op.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +template +class SparseFillEmptyRowsOp : public OpKernel { + public: + explicit SparseFillEmptyRowsOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + functor::SparseFillEmptyRowsFunctor()(context); + } +}; + +#if GOOGLE_CUDA +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("TFRA>SparseFillEmptyRows") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + SparseFillEmptyRowsOp) +TF_CALL_int32(REGISTER_KERNELS); +TF_CALL_int64(REGISTER_KERNELS); +TF_CALL_float(REGISTER_KERNELS); +TF_CALL_double(REGISTER_KERNELS); +#undef REGISTER_KERNELS +#endif +} // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc new file mode 100644 index 000000000..9f8085d40 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.cu.cc @@ -0,0 +1,276 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include +#include +#include +#include + +#include "cub/device/device_scan.cuh" +#include "sparse_fill_empty_rows_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +// calculate how many rows are empty and record their location +__global__ void SparseFillEmptyRowCountKernel( + const int64* indices, const int nnz, const int64* input_shape, + int* row_nnz_count, // size: num_rows + int64* input_row_offset, // size: num_rows + 1 + int64* output_row_offset // size: num_rows + 1 +) { + GPU_1D_KERNEL_LOOP(idx, nnz) { + const int64 num_rows = input_shape[0]; + + int64 _row = indices[idx * 2]; + atomicAdd(row_nnz_count + _row, 1); + } +} + +__global__ void SparseFillEmptyRowAddOneKernel(const int64* input_shape, + int* row_nnz_count) { + const int64 num_rows = input_shape[0]; + GPU_1D_KERNEL_LOOP(id_row, num_rows) { + if (row_nnz_count[id_row] == 0) { + row_nnz_count[id_row] += 1; + } + } +} + +// copy the original data to output data address and fill default value to empty +// rows +template +__global__ void SparseFillEmptyRowFillKernel( + // inputs + const int64* input_indices, const T* input_values, const int64* input_shape, + const T* default_value, const int64* input_row_offset, + const int64* output_row_offset, + // outputs + int64* output_indices, T* output_values, bool* empty_row_indicator, + int64* reverse_index_map) { + const int64 num_rows = input_shape[0]; + GPU_1D_KERNEL_LOOP(id_row, num_rows) { +#pragma unroll + for (int i = 0; i < input_row_offset[id_row + 1] - input_row_offset[id_row]; + i++) { + output_values[output_row_offset[id_row] + i] = + input_values[input_row_offset[id_row] + i]; + output_indices[2 * (output_row_offset[id_row] + i) + 0] = + id_row; // no need to read indices from input again; + output_indices[2 * (output_row_offset[id_row] + i) + 1] = + input_indices[2 * (input_row_offset[id_row] + i) + 1]; + if (reverse_index_map) { + reverse_index_map[input_row_offset[id_row] + i] = + output_row_offset[id_row] + i; + } + } + + // for empty rows + if (input_row_offset[id_row + 1] == input_row_offset[id_row]) { + // insert default value + output_values[output_row_offset[id_row]] = *default_value; + output_indices[2 * output_row_offset[id_row] + 0] = id_row; + output_indices[2 * output_row_offset[id_row] + 1] = 0; + + // mark as empty + if (empty_row_indicator) { + empty_row_indicator[id_row] = true; + } + } + } + return; +} + +namespace functor { +template +void SparseFillEmptyRowsGpuImpl(OpKernelContext* context, + const int64* input_indices, + const T* input_values, const int64 nnz, + const int64* input_shape, + const T* default_value) { + auto d = context->eigen_gpu_device(); + auto OpStream = d.stream(); + int64 dense_row_number; + + // get the dense shape, which is stored in GPU. + // If the dense shape is already in CPU, we don't need to do the copy here. 
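The host code below turns a per-row non-zero count into two offset tables via inclusive prefix sums. A minimal NumPy sketch of that bookkeeping, using the same [3, 4, 0, 0, 6] example as the comments further down, and assuming (as the kernels do) that the input indices are sorted in row-major order:

    import numpy as np

    # Per-row non-zero counts, as produced by SparseFillEmptyRowCountKernel.
    row_nnz_count = np.array([3, 4, 0, 0, 6])

    # input_row_offset: start of each input row, with the total nnz at the end.
    input_row_offset = np.concatenate([[0], np.cumsum(row_nnz_count)])
    # -> [0, 3, 7, 7, 7, 13]

    # SparseFillEmptyRowAddOneKernel reserves one output slot per empty row.
    output_row_count = np.where(row_nnz_count == 0, 1, row_nnz_count)
    # -> [3, 4, 1, 1, 6]

    # output_row_offset: start of each row in the filled output; the last
    # element is the output nnz that gets copied back to the host.
    output_row_offset = np.concatenate([[0], np.cumsum(output_row_count)])
    # -> [0, 3, 7, 8, 9, 15]
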
+ cudaMemcpyAsync(&dense_row_number, input_shape, sizeof(int64), + cudaMemcpyDeviceToHost, OpStream); + cudaStreamSynchronize(OpStream); + + // temp vector to store start index of each row + Tensor input_row_offset; + Tensor output_row_offset; + Tensor row_nnz_count; // temp buffer for the count kernel, count number of + // non-zero values on each row. + + // the size of input_row_offset and output_row_offset is dense_row_number+1, + // because we need one extra place to store the initial value of the offset 0 + OP_REQUIRES_OK(context, context->allocate_temp( + DT_INT64, TensorShape({dense_row_number + 1}), + &input_row_offset)); + + OP_REQUIRES_OK(context, context->allocate_temp( + DT_INT64, TensorShape({dense_row_number + 1}), + &output_row_offset)); + + OP_REQUIRES_OK( + context, context->allocate_temp( + // use DT_INT32 instead of DT_INT64, because CUDA atomic_add + // only support int32 + DT_INT32, TensorShape({dense_row_number}), &row_nnz_count)); + + cudaMemset(row_nnz_count.flat().data(), 0, + sizeof(int) * dense_row_number); + cudaMemset(input_row_offset.flat().data(), 0, sizeof(int64)); + cudaMemset(output_row_offset.flat().data(), 0, sizeof(int64)); + + // Get the number of rows in each row + GpuLaunchConfig count_kernel_config = GetGpuLaunchConfig(nnz, d); + TF_CHECK_OK(GpuLaunchKernel( + SparseFillEmptyRowCountKernel, count_kernel_config.block_count, + count_kernel_config.thread_per_block, 0, d.stream(), input_indices, nnz, + input_shape, row_nnz_count.flat().data(), + input_row_offset.flat().data(), + output_row_offset.flat().data())); + + /* Calculate the offset of each row of input + * example: the number of rows in each row: [3, 4, 0, 0, 6] + * the offset of each row of input: [0, 3, 7, 7, 7, 13] + */ + // Determine temporary device storage requirements for inclusive prefix sum + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum( + NULL, temp_storage_bytes, row_nnz_count.flat().data(), + input_row_offset.flat().data() + 1, dense_row_number); + + // Allocate temporary storage for inclusive prefix sum + Tensor temp_storage; + OP_REQUIRES_OK( + context, + context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + void* d_temp_storage = temp_storage.flat().data(); + + // Run inclusive prefix sum + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, row_nnz_count.flat().data(), + input_row_offset.flat().data() + 1, dense_row_number); + + /* Add 1 to the row whose row count is 0 + * example: the number of rows in each row(row_nnz_count): [3, 4, 0, 0, 6] + * row_nnz_count after the kernel: [3, 4, 1, 1, 6] + */ + GpuLaunchConfig add_kernel_config = GetGpuLaunchConfig(nnz, d); + TF_CHECK_OK(GpuLaunchKernel( + SparseFillEmptyRowAddOneKernel, count_kernel_config.block_count, + count_kernel_config.thread_per_block, 0, d.stream(), input_shape, + row_nnz_count.flat().data())); + + // Calculate the offset of each row of output + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, row_nnz_count.flat().data(), + output_row_offset.flat().data() + 1, dense_row_number); + + // Read the output size from GPU, which is result of the first kernel. + // copy nnz + num_of_empty_row = output_nnz to CPU + int64 output_nnz; + cudaMemcpyAsync(&output_nnz, + output_row_offset.flat().data() + dense_row_number, + sizeof(int64), cudaMemcpyDeviceToHost, OpStream); + cudaStreamSynchronize(OpStream); + + // Allocate output tensors. 
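Before the allocation and fill-kernel launch below, here is a host-side NumPy reference for what SparseFillEmptyRowFillKernel writes per row, given the two offset tables. It is a sketch only, under the same row-sorted assumption; the inputs are plain NumPy arrays and the names are local to the sketch, not part of the op:

    import numpy as np

    def fill_rows_reference(indices, values, default_value,
                            input_row_offset, output_row_offset):
      # CPU reference for the fill kernel; `indices` must be sorted by row.
      num_rows = len(input_row_offset) - 1
      output_nnz = output_row_offset[-1]
      out_indices = np.zeros((output_nnz, 2), dtype=np.int64)
      out_values = np.full(output_nnz, default_value, dtype=values.dtype)
      empty_row_indicator = np.zeros(num_rows, dtype=bool)
      reverse_index_map = np.zeros(len(values), dtype=np.int64)
      for row in range(num_rows):
        in_start, in_end = input_row_offset[row], input_row_offset[row + 1]
        out_start = output_row_offset[row]
        if in_start == in_end:
          # Empty row: a single default_value entry at column 0.
          out_indices[out_start] = [row, 0]
          empty_row_indicator[row] = True
        else:
          # Copy the row's existing entries to their new positions.
          n = in_end - in_start
          out_indices[out_start:out_start + n, 0] = row
          out_indices[out_start:out_start + n, 1] = indices[in_start:in_end, 1]
          out_values[out_start:out_start + n] = values[in_start:in_end]
          reverse_index_map[in_start:in_end] = np.arange(out_start, out_start + n)
      return out_indices, out_values, empty_row_indicator, reverse_index_map
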
+ Tensor* output_indices; + Tensor* output_values; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({output_nnz, 2}), + &output_indices)); + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({output_nnz}), + &output_values)); + + bool* empty_row_indicator = nullptr; + if (context->output_required(2)) { + Tensor* empty_row_indicator_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(2, TensorShape({dense_row_number}), + &empty_row_indicator_t)); + empty_row_indicator = empty_row_indicator_t->vec().data(); + // assume row not empty first + cudaMemset(empty_row_indicator, false, sizeof(bool) * dense_row_number); + } + + int64* reverse_index_map = nullptr; + if (context->output_required(3)) { + Tensor* reverse_index_map_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(3, TensorShape({nnz}), + &reverse_index_map_t)); + reverse_index_map = reverse_index_map_t->vec().data(); + } + + // Launch the second Kernel to move data and insert value to empty rows. + GpuLaunchConfig config = GetGpuLaunchConfig(dense_row_number, d); + TF_CHECK_OK(GpuLaunchKernel( + SparseFillEmptyRowFillKernel, config.block_count, + config.thread_per_block, 0, d.stream(), input_indices, input_values, + input_shape, default_value, input_row_offset.flat().data(), + output_row_offset.flat().data(), + output_indices->flat().data(), output_values->flat().data(), + empty_row_indicator, reverse_index_map)); +} + +template +struct SparseFillEmptyRowsFunctor { + void operator()(OpKernelContext* context) { + auto input_indices = context->input(0); + auto input_values = context->input(1); + auto input_shape = context->input(2); + auto default_value = context->input(3); + + const int64 nnz = input_indices.shape().dim_size(0); + + SparseFillEmptyRowsGpuImpl(context, input_indices.flat().data(), + input_values.flat().data(), nnz, + input_shape.flat().data(), + default_value.flat().data()); + } +}; + +#define DEFINE_GPU_KERNELS(type) \ + template struct SparseFillEmptyRowsFunctor; + +TF_CALL_int32(DEFINE_GPU_KERNELS); +TF_CALL_int64(DEFINE_GPU_KERNELS); +TF_CALL_float(DEFINE_GPU_KERNELS); +TF_CALL_double(DEFINE_GPU_KERNELS); + +} // namespace functor +} // namespace tensorflow + +#endif diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.h new file mode 100644 index 000000000..08b0bf1f9 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_fill_empty_rows_op.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace functor { + +template +struct SparseFillEmptyRowsFunctor { + void operator()(OpKernelContext* ctx); +}; + +} // namespace functor +} // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc new file mode 100644 index 000000000..9aa31ce88 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cc @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "sparse_reshape_op.h" + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +template +class SparseReshapeOp : public OpKernel { + public: + explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + functor::SparseReshapeFunctor()(context); + } +}; + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("TFRA>SparseReshape").Device(DEVICE_GPU), + SparseReshapeOp); +#endif +} // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc new file mode 100644 index 000000000..5a3baf901 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.cu.cc @@ -0,0 +1,173 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include +#include +#include +#include + +#include "sparse_reshape_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +template +__global__ void SparseReshapeKernel(const IndexType* input_indices_in, + const int nnz, const int input_dim, + const IndexType* input_shape_in, + const IndexType* target_shape_in, + IndexType* output_indices, + const int output_dim) { +#define RESHAPE_KERNEL_MAX_DIM 32 + int64 input_strides[RESHAPE_KERNEL_MAX_DIM]; + int64 output_strides[RESHAPE_KERNEL_MAX_DIM]; +#undef RESHAPE_KERNEL_MAX_DIM + // compute input strides + input_strides[input_dim - 1] = 1; + for (int i = input_dim - 2; i >= 0; i--) { + input_strides[i] = input_strides[i + 1] * input_shape_in[i + 1]; + } + // compute output strides + output_strides[output_dim - 1] = 1; + for (int i = output_dim - 2; i >= 0; i--) { + output_strides[i] = output_strides[i + 1] * target_shape_in[i + 1]; + } + + GPU_1D_KERNEL_LOOP(idx, nnz) { + IndexType id = 0; +#pragma unroll + for (int i = 0; i < input_dim; i++) { + id = id + input_strides[i] * input_indices_in[idx * input_dim + i]; + } +#pragma unroll + for (int i = 0; i < output_dim; i++) { + output_indices[idx * output_dim + i] = id / output_strides[i]; + id = id % output_strides[i]; + } + } +} + +template +__global__ void DetermineOutputShapeKernel(const int input_dim, + const IndexType* input_shape, + const IndexType* target_shape, + const int output_dim, + IndexType* output_shape) { + int64 dense_size = 1; + for (int d = 0; d < input_dim; ++d) { + dense_size *= input_shape[d]; + } + int64 product = 1; + int unknown_index = -1; + for (int d = 0; d < output_dim; ++d) { + const int64 size = target_shape[d]; + if (size == -1) { + unknown_index = d; + output_shape[d] = 1; + } else { + product *= size; + output_shape[d] = size; + } + } + + if (unknown_index != -1) { + const int64 missing = dense_size / product; + output_shape[unknown_index] = missing; + } +} + +namespace functor { + +template +struct SparseReshapeImpl { + public: + static Status Compute(OpKernelContext* context, + const IndexType* input_indices_in, const int nnz, + const int input_dim, const IndexType* input_shape_in, + const IndexType* target_shape_in, + IndexType* output_indices, int output_dim, + IndexType* output_shape) { + auto d = context->eigen_gpu_device(); + DetermineOutputShapeKernel<<<1, 1, 0, d.stream()>>>( + input_dim, input_shape_in, target_shape_in, output_dim, output_shape); + + GpuLaunchConfig config = GetGpuLaunchConfig(nnz, d); + TF_CHECK_OK(GpuLaunchKernel( + SparseReshapeKernel, config.block_count, + config.thread_per_block, 0, d.stream(), input_indices_in, nnz, + input_dim, input_shape_in, output_shape, output_indices, output_dim)); + + return Status::OK(); + } +}; + +template class SparseReshapeFunctor; +template class SparseReshapeFunctor; + +void SparseReshapeFunctor::operator()(OpKernelContext* context) { + auto input_indices = context->input(0); + auto input_shape_in = context->input(1); + auto target_shape_in = context->input(2); + + OP_REQUIRES(context, 
TensorShapeUtils::IsMatrix(input_indices.shape()), + errors::InvalidArgument( + "Input indices should be a matrix but received shape ", + input_indices.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()), + errors::InvalidArgument( + "Input shape should be a vector but received shape ", + input_shape_in.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()), + errors::InvalidArgument( + "Target shape should be a vector but received shape ", + target_shape_in.shape().DebugString())); + + const int nnz = input_indices.shape().dim_size(0); + const int rank = input_shape_in.NumElements(); + const int output_rank = target_shape_in.NumElements(); + + Tensor* output_indices = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({nnz, output_rank}), + &output_indices)); + + Tensor* result_shape = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, TensorShape({output_rank}), &result_shape)); + + OP_REQUIRES_OK(context, SparseReshapeImpl::Compute( + context, input_indices.flat().data(), nnz, + rank, input_shape_in.flat().data(), + target_shape_in.flat().data(), + output_indices->flat().data(), output_rank, + result_shape->flat().data())); +} +} // namespace functor +} // namespace tensorflow + +#endif diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.h new file mode 100644 index 000000000..8bebdc994 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/sparse_reshape_op.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TFRA_CORE_KERNELS_SPARSE_RESHAPE_OP_H_ +#define TFRA_CORE_KERNELS_SPARSE_RESHAPE_OP_H_ + +// Functor definition for SparseReshapeOp, must be compilable by nvcc. 
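The reshape kernels above linearize each index with row-major input strides and re-expand it with the output strides, after resolving at most one -1 in the target shape. A NumPy reference sketch of that remapping (the helper names are illustrative only, not part of the op):

    import numpy as np

    def _row_major_strides(shape):
      # Same stride computation as in SparseReshapeKernel.
      strides = np.ones(len(shape), dtype=np.int64)
      for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
      return strides

    def sparse_reshape_reference(input_indices, input_shape, target_shape):
      input_shape = np.asarray(input_shape, dtype=np.int64)
      output_shape = np.array(target_shape, dtype=np.int64)

      # DetermineOutputShapeKernel: resolve at most one -1 in the target shape.
      if np.any(output_shape == -1):
        known = int(np.prod(output_shape[output_shape != -1]))
        output_shape[output_shape == -1] = int(np.prod(input_shape)) // known

      in_strides = _row_major_strides(input_shape)
      out_strides = _row_major_strides(output_shape)

      # SparseReshapeKernel: flatten each index, then re-expand it.
      flat = (np.asarray(input_indices, dtype=np.int64) * in_strides).sum(axis=1)
      output_indices = np.empty((len(flat), len(output_shape)), dtype=np.int64)
      for i, stride in enumerate(out_strides):
        output_indices[:, i] = flat // stride
        flat = flat % stride
      return output_indices, output_shape

    # e.g. a 5x6 tensor reshaped to [-1, 15]:
    # sparse_reshape_reference([[0, 0], [1, 3], [3, 2]], [5, 6], [-1, 15])
    # -> indices [[0, 0], [0, 9], [1, 5]], shape [2, 15]
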
+ +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace functor { + +template +struct SparseReshapeFunctor { + void operator()(OpKernelContext* context); +}; + +#if GOOGLE_CUDA +template <> +struct SparseReshapeFunctor { + void operator()(OpKernelContext* context); +}; +#endif + +} // namespace functor +} // namespace tensorflow + +#endif // TFRA_CORE_KERNELS_SPARSE_RESHAPE_OP_H_ diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/math_ops.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/math_ops.cc index c6e4fa9aa..0ed1038e9 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/math_ops.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/math_ops.cc @@ -113,6 +113,65 @@ REGISTER_OP("TFRA>SparseSegmentSumWithNumSegments") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tsegmentids: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); + +REGISTER_OP("TFRA>SparseFillEmptyRows") + .Input("indices: int64") + .Input("values: T") + .Input("dense_shape: int64") + .Input("default_value: T") + .Output("output_indices: int64") + .Output("output_values: T") + .Output("empty_row_indicator: bool") + .Output("reverse_index_map: int64") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input_indices = c->input(0); + TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices)); + ShapeHandle input_values = c->input(1); + TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values)); + ShapeHandle input_shape = c->input(2); + TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape)); + ShapeHandle default_value = c->input(3); + TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value)); + DimensionHandle N = c->Dim(input_indices, 0); + TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1), + c->Dim(input_shape, 0), &unused_dim)); + ShapeHandle output_indices = + c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape)); + ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim); + ShapeHandle constant_input_shape; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape)); + ShapeHandle empty_row_indicator = + c->Vector(c->Dim(constant_input_shape, 0)); + ShapeHandle reverse_index_map = c->Vector(N); + c->set_output(0, output_indices); + c->set_output(1, output_values); + c->set_output(2, empty_row_indicator); + c->set_output(3, reverse_index_map); + return Status::OK(); + }); + +REGISTER_OP("TFRA>SparseReshape") + .Input("input_indices: int64") + .Input("input_shape: int64") + .Input("new_shape: int64") + .Output("output_indices: int64") + .Output("output_shape: int64") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices; + ShapeHandle unused; + ShapeHandle new_shape; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape)); + + c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0))); + c->set_output(1, new_shape); + return Status::OK(); + }); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_grad_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_grad_test.py index 871fd30c6..2b09a69eb 100644 --- 
a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_grad_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_grad_test.py @@ -18,14 +18,18 @@ from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_ops as de_math from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test default_config = config_pb2.ConfigProto( @@ -87,5 +91,33 @@ def test_value(self): self.assertAllEqual(self.evaluate(result), self.evaluate(expected)) +class SparseFillEmptyRowsGpuGradTest(test.TestCase): + + def _SparseTensor_5x6(self): + ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]) + val = np.array([0, 10, 13, 14, 32, 33]) + shape = np.array([5, 6]) + return sparse_tensor.SparseTensor(constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + def backward_compute(self, sp_input, default_value): + with backprop.GradientTape(persistent=True) as tape: + tape.watch(sp_input) + result_output, result_indicator = de_math.sparse_fill_empty_rows( + sp_input, default_value) + expected_output, expected_indicator = sparse_ops.sparse_fill_empty_rows( + sp_input, default_value) + result = tape.gradient(result_output.values, sp_input.values) + expected = tape.gradient(expected_output.values, sp_input.values) + return result, expected + + @test_util.run_in_graph_and_eager_modes + def test_value(self): + with self.session(use_gpu=use_gpu, config=default_config): + result, expected = self.backward_compute(self._SparseTensor_5x6(), -1) + self.assertAllEqual(self.evaluate(result), self.evaluate(expected)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_ops_test.py index 8b39596e2..9660a5690 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/math_ops_test.py @@ -18,15 +18,19 @@ from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_ops as de_math from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test default_config = config_pb2.ConfigProto( @@ -108,5 +112,67 @@ def test_value(self): self.assertAllEqual(self.evaluate(result), self.evaluate(expected)) +class SparseFillEmptyRowsGpuTest(test.TestCase): + + def _SparseTensor_5x6(self): + ind = np.array([[0, 
0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]) + val = np.array([0, 10, 13, 14, 32, 33]) + shape = np.array([5, 6]) + return sparse_tensor.SparseTensor(constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) + + def forward_compute(self, sp_input, default_value): + result_output, result_indicator = de_math.sparse_fill_empty_rows( + sp_input, default_value) + expected_output, expected_indicator = sparse_ops.sparse_fill_empty_rows( + sp_input, default_value) + return result_output, result_indicator, expected_output, expected_indicator + + @test_util.run_in_graph_and_eager_modes + def test_value(self): + with self.session(use_gpu=use_gpu, config=default_config): + result_output, result_indicator, expected_output, expected_indicator = self.forward_compute( + self._SparseTensor_5x6(), -1) + result_output, result_indicator, expected_output, expected_indicator = self.evaluate( + [ + result_output, result_indicator, expected_output, + expected_indicator + ]) + self.assertAllEqual(result_output.indices, expected_output.indices) + self.assertAllEqual(result_output.values, expected_output.values) + self.assertAllEqual(result_output.dense_shape, + expected_output.dense_shape) + self.assertAllEqual(result_indicator, expected_indicator) + + +class SparseReshapeGpuTest(test.TestCase): + + def _SparseTensor_5x6(self): + ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]]) + val = np.array([0, 10, 13, 14, 32, 33]) + shape = np.array([5, 6]) + return sparse_tensor.SparseTensor(constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) + + def forward_compute(self, sp_input, shape): + result_output = de_math.sparse_reshape(sp_input, shape) + expected_output = sparse_ops.sparse_reshape(sp_input, shape) + return result_output, expected_output + + @test_util.run_in_graph_and_eager_modes + def test_value(self): + with self.session(use_gpu=use_gpu, config=default_config): + result_output, expected_output = self.forward_compute( + self._SparseTensor_5x6(), (2, 15)) + result_output, expected_output = self.evaluate( + [result_output, expected_output]) + self.assertAllEqual(result_output.indices, expected_output.indices) + self.assertAllEqual(result_output.values, expected_output.values) + self.assertAllEqual(result_output.dense_shape, + expected_output.dense_shape) + + if __name__ == "__main__": test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD index ae0264856..7a78735a8 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/BUILD @@ -13,7 +13,7 @@ py_library( srcs = glob(["*.py"]), data = [ "//tensorflow_recommenders_addons/dynamic_embedding/core:_cuckoo_hashtable_ops.so", - "//tensorflow_recommenders_addons/dynamic_embedding/core:_segment_reduction_ops.so", + "//tensorflow_recommenders_addons/dynamic_embedding/core:_math_ops.so", ], srcs_version = "PY2AND3", ) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_grad.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_grad.py index 1e4072342..5d7a81864 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_grad.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_grad.py @@ -19,8 +19,9 @@ from __future__ import print_function from 
tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import math_ops @ops.RegisterGradient("TFRA>SparseSegmentSum") @@ -38,3 +39,17 @@ def _TfraSparseSegmentSumWithNumSegmentsGrad(op, grad): return (math_ops.unsorted_segment_sum(array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None, None, None) + + +@ops.RegisterGradient("TFRA>SparseFillEmptyRows") +def _SparseFillEmptyRowsGrad(op, unused_grad_output_indices, output_grad_values, + unused_grad_empty_row_indicator, + unused_grad_reverse_index_map): + """Gradients for TFRA>SparseFillEmptyRows.""" + reverse_index_map = op.outputs[3] + + d_values, d_default_value = gen_sparse_ops.sparse_fill_empty_rows_grad( + reverse_index_map=reverse_index_map, grad_values=output_grad_values) + + # d_indices, d_values, d_dense_shape, d_default_value. + return [None, d_values, None, d_default_value] \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py index 39a446576..ab179618e 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/math_ops.py @@ -19,20 +19,45 @@ from __future__ import print_function import functools +import numpy as np import tensorflow as tf from tensorflow.python.eager import context from tensorflow.python.framework import config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow_recommenders_addons.utils.resource_loader import LazySO from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_grad -segment_reduction_ops = LazySO( - "dynamic_embedding/core/_segment_reduction_ops.so").ops +tfra_math_ops = LazySO("dynamic_embedding/core/_math_ops.so").ops + + +def _convert_to_sparse_tensor(sp_input): + """Convert `sp_input` to `SparseTensor` and return it. + + Args: + sp_input: `SparseTensor` or `SparseTensorValue`. + + Returns: + `sp_input` converted to `SparseTensor`. + + Raises: + ValueError: if `sp_input` is neither `SparseTensor` nor `SparseTensorValue`. + """ + if isinstance(sp_input, sparse_tensor.SparseTensorValue): + return sparse_tensor.SparseTensor.from_value(sp_input) + if not isinstance(sp_input, sparse_tensor.SparseTensor): + raise TypeError("Input must be a SparseTensor.") + return sp_input def sparse_segment_sum(data, @@ -117,7 +142,7 @@ def _sparse_segment_sum_gpu(data, segment_ids, name=None, num_segments=None): - if not hasattr(segment_reduction_ops, 'tfra_sparse_segment_sum'): + if not hasattr(tfra_math_ops, 'tfra_sparse_segment_sum'): tf_logging.warn('`tfra.dynamic_embedding.sparse_segment_sum` is not' ' found. 
Use tf.sparse.segment_sum instead.') return tf.sparse.segment_sum(data, @@ -127,12 +152,189 @@ def _sparse_segment_sum_gpu(data, num_segments=num_segments) if num_segments is not None: - return segment_reduction_ops.tfra_sparse_segment_sum_with_num_segments( + return tfra_math_ops.tfra_sparse_segment_sum_with_num_segments( data=data, indices=indices, segment_ids=segment_ids, name=name, num_segments=num_segments) else: - return segment_reduction_ops.tfra_sparse_segment_sum( - data=data, indices=indices, segment_ids=segment_ids, name=name) + return tfra_math_ops.tfra_sparse_segment_sum(data=data, + indices=indices, + segment_ids=segment_ids, + name=name) + + +def sparse_fill_empty_rows(sp_input, default_value, name=None): + """Fills empty rows in the input 2-D `SparseTensor` with a default value. + + It do same things as `tf.sparse.fill_empty_rows`. Here we provide GPU impl. + + Go [tf api](https://www.tensorflow.org/api_docs/python/tf/sparse/fill_empty_rows) + for more details. + + Args: + sp_input: A `SparseTensor` with shape `[N, M]`. + default_value: The value to fill for empty rows, with the same type as + `sp_input.` + name: A name prefix for the returned tensors (optional) + + Returns: + sp_ordered_output: A `SparseTensor` with shape `[N, M]`, and with all empty + rows filled in with `default_value`. + empty_row_indicator: A bool vector of length `N` indicating whether each + input row was empty. + """ + gpu_devices = config.list_physical_devices('GPU') + if gpu_devices: + if context.executing_eagerly(): + try: + return _sparse_fill_empty_rows_gpu(sp_input, default_value, name=name) + except errors.NotFoundError: + tf_logging.warn('`tfra.dynamic_embedding.sparse_fill_empty_rows` is not' + ' found. Use tf.sparse.fill_empty_rows instead.') + return tf.sparse.fill_empty_rows(sp_input, default_value, name=name) + + else: + predef = _sparse_fill_empty_rows_gpu(sp_input, default_value, name=name) + use_origin = True + if predef[0].values.device == '': + tf_logging.warn( + 'Haven\'t specify devices while GPU devices are' + 'available: {}, use CPU by default.'.format(gpu_devices)) + else: + device_type = predef[0].values.device.split(':')[-2][-3:].lower() + if device_type == 'gpu': + use_origin = False + + if use_origin: + return tf.sparse.fill_empty_rows(sp_input, default_value, name=name) + return predef + + else: + return tf.sparse.fill_empty_rows(sp_input, default_value, name=name) + + +def _sparse_fill_empty_rows_gpu(sp_input, default_value, name=None): + if not hasattr(tfra_math_ops, 'tfra_sparse_fill_empty_rows'): + tf_logging.warn('`tfra.dynamic_embedding.sparse_fill_empty_rows` is not' + ' found. Use tf.sparse.fill_empty_rows instead.') + return tf.sparse.fill_empty_rows(sp_input, default_value, name=name) + + sp_input = _convert_to_sparse_tensor(sp_input) + with ops.name_scope(name, "SparseFillEmptyRows", [sp_input]): + default_value = ops.convert_to_tensor(default_value, + dtype=sp_input.values.dtype) + (output_indices, output_values, empty_row_indicator, + unused_reverse_index_map) = tfra_math_ops.tfra_sparse_fill_empty_rows( + indices=sp_input.indices, + values=sp_input.values, + dense_shape=sp_input.dense_shape, + default_value=default_value) + return (sparse_tensor.SparseTensor(indices=output_indices, + values=output_values, + dense_shape=sp_input.dense_shape), + empty_row_indicator) + + +def sparse_reshape(sp_input, shape, name=None): + """Reshapes a `SparseTensor` to represent values in a new dense shape. + + It do same things as `tf.sparse.reshape`. Here we provide GPU impl. 
+ + Go [tf api](https://www.tensorflow.org/api_docs/python/tf/sparse/reshape) + for more details. + + Args: + sp_input: The input `SparseTensor`. + shape: A 1-D (vector) int64 `Tensor` specifying the new dense shape of the + represented `SparseTensor`. + name: A name prefix for the returned tensors (optional) + + Returns: + A `SparseTensor` with the same non-empty values but with indices calculated + by the new dense shape. + """ + gpu_devices = config.list_physical_devices('GPU') + if gpu_devices: + if context.executing_eagerly(): + try: + return _sparse_reshape_gpu(sp_input, shape, name=name) + except errors.NotFoundError: + tf_logging.warn('`tfra.dynamic_embedding.sparse_reshape` is not' + ' found. Use tf.sparse.reshape instead.') + return tf.sparse.reshape(sp_input, shape, name=name) + + else: + predef = _sparse_reshape_gpu(sp_input, shape, name=name) + use_origin = True + if predef.values.device == '': + tf_logging.warn( + 'Haven\'t specify devices while GPU devices are' + 'available: {}, use CPU by default.'.format(gpu_devices)) + else: + device_type = predef.values.device.split(':')[-2][-3:].lower() + if device_type == 'gpu': + use_origin = False + + if use_origin: + return tf.sparse.reshape(sp_input, shape, name=name) + return predef + + else: + return tf.sparse.reshape(sp_input, shape, name=name) + + +def _sparse_reshape_gpu(sp_input, shape, name=None): + if not hasattr(tfra_math_ops, 'tfra_sparse_reshape'): + tf_logging.warn('`tfra.dynamic_embedding.sparse_reshape` is not' + ' found. Use tf.sparse.reshape instead.') + return tf.sparse.reshape(sp_input, shape, name=name) + + sp_input = _convert_to_sparse_tensor(sp_input) + shape = math_ops.cast(shape, dtype=dtypes.int64) + with ops.name_scope(name, "SparseReshape", [sp_input]): + # shape = ops.convert_to_tensor(shape, dtype=sp_input.values.dtype) + reshaped_ind, reshaped_shape = tfra_math_ops.tfra_sparse_reshape( + sp_input.indices, sp_input.dense_shape, shape, name=name) + + reshaped_shape_const = tensor_util.constant_value_as_shape(shape) + reshaped_shape_const = (reshaped_shape_const.as_list() + if reshaped_shape_const.ndims is not None else None) + + if (reshaped_shape_const is not None and sp_input.shape.is_fully_defined()): + # constant_value_as_shape tends to get more information about the partial + # shape values, but here we specifically need to know if the *user* passed + # a shape with 2+ unknown dimensions; and for that constant_value + # provides either the user's direct value or None if only partial elements + # are known via the python shape inference code. + shape_const_by_user = tensor_util.constant_value(shape) + if shape_const_by_user is not None: + num_implied_by_user = sum(d == -1 for d in shape_const_by_user) + if num_implied_by_user > 1: + raise ValueError( + "At most one dimension can be inferred (-1). Found: %s" % + shape_const_by_user) + original_reshaped_shape = list(reshaped_shape_const) # A copy + in_shape_size = np.prod(sp_input.shape.as_list()) + num_implied = sum(dim is None for dim in reshaped_shape_const) + if num_implied == 1: + implied_idx = original_reshaped_shape.index(None) + non_implied_idx = (original_reshaped_shape[:implied_idx] + + original_reshaped_shape[implied_idx + 1:]) + reshaped_shape_const[implied_idx] = int(in_shape_size // + np.prod(non_implied_idx)) + if num_implied <= 1: + reshaped_size = np.prod(reshaped_shape_const) + if reshaped_size != in_shape_size: + raise ValueError( + "Cannot reshape a tensor with %d elements to shape %s " + "(%d elements)." 
% + (in_shape_size, original_reshaped_shape, reshaped_size)) + reshaped_shape = constant_op.constant(reshaped_shape_const, + dtype=dtypes.int64) + + return sparse_tensor.SparseTensor(indices=reshaped_ind, + values=array_ops.identity( + sp_input.values), + dense_shape=reshaped_shape)
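
For reference, a short usage sketch of the two new Python wrappers, in the same style as the tests above. It assumes the module is imported as de_math the way the tests do; on machines where the custom GPU ops are unavailable, the wrappers fall back to the stock tf.sparse implementations:

    import numpy as np
    import tensorflow as tf

    from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_ops as de_math

    # A 5x6 sparse tensor whose rows 2 and 4 are empty.
    sp = tf.sparse.SparseTensor(
        indices=np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
                         dtype=np.int64),
        values=tf.constant([0, 10, 13, 14, 32, 33], dtype=tf.float32),
        dense_shape=[5, 6])

    # Fill the empty rows with -1; empty_rows is a length-5 bool vector.
    filled, empty_rows = de_math.sparse_fill_empty_rows(sp, -1.0)

    # Reshape 5x6 -> 2x15; only the indices and dense_shape change.
    reshaped = de_math.sparse_reshape(sp, (2, 15))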