Skip to content

Commit

Permalink
Upstream native_batch_norm and native_batch_norm_backward shape infer…
Browse files Browse the repository at this point in the history
…ence functions (#978)

* Removed compute_shape_native_batch_norm

* Removed compute_shape_native_batch_norm_backward
  • Loading branch information
henrytwo committed Jul 30, 2022
1 parent 0cee0dc commit 1510eae
Showing 1 changed file with 0 additions and 65 deletions.
65 changes: 0 additions & 65 deletions python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,71 +29,6 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());

// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];

if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}

if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
return shapes;
}

std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
::std::array<bool, 3> output_mask) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());

// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];

// `weight` and `bias` are vectors of length C (number of channels)`
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});

return shapes;
}

std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
if (dtype.has_value()) {
return {Shape(*dtype, size)};
Expand Down

0 comments on commit 1510eae

Please sign in to comment.