fix bert #71

Open · wants to merge 1 commit into base: 9-9
58 changes: 58 additions & 0 deletions aten/src/ATen/native/Checkpoint.cpp
@@ -137,6 +137,30 @@ Tensor checkpoint_to(at::Tensor const& a, c10::TensorOptions const& b, bool c, b
return CheckpointTensorImpl::make("to", rt, {a})[0];
}

Tensor checkpoint_to(at::Tensor const& a, at::Tensor const& b, bool c, bool d, c10::optional<c10::MemoryFormat> e) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {vec.at(0).to(b, c, d, e)};
    };
  return CheckpointTensorImpl::make("to", rt, {a})[0];
}

Tensor checkpoint_to(at::Tensor const& a, c10::ScalarType b, bool c, bool d, c10::optional<c10::MemoryFormat> e) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {vec.at(0).to(b, c, d, e)};
    };
  return CheckpointTensorImpl::make("to", rt, {a})[0];
}

Tensor checkpoint_to(at::Tensor const& a, c10::Device b, c10::ScalarType c, bool d, bool e, c10::optional<c10::MemoryFormat> f) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {vec.at(0).to(b, c, d, e, f)};
    };
  return CheckpointTensorImpl::make("to", rt, {a})[0];
}
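
Note: all three `to` overloads follow the same pattern used throughout this file: non-tensor arguments are captured by value, and the checkpointing layer receives a closure that can recompute the result from its saved inputs at any later point. For illustration only (not part of this diff; `at::relu` is a stand-in op), the same pattern applied to another functional op would look like:

Tensor checkpoint_relu(at::Tensor const& a) {
  // Capture by value: the closure may run long after the original call
  // returns, whenever a freed result has to be rematerialized.
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {at::relu(vec.at(0))};
    };
  return CheckpointTensorImpl::make("relu", rt, {a})[0];
}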

Tensor checkpoint_div(const Tensor& a, const Tensor& b) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
@@ -703,6 +727,40 @@ Tensor checkpoint_sum_dim_IntList(const Tensor& a, c10::ArrayRef<long> b, bool c
return CheckpointTensorImpl::make("sum_dim_IntList", rt, {a})[0];
}

Tensor& checkpoint_transpose_(at::Tensor& a, long b, long c) {
  mutate_function_t mt =
    [=](const Tensors& vec) {
      Tensor a_ = vec.at(0);
      a_.transpose_(b, c);
    };
  CheckpointTensorImpl::mutate("transpose_", mt, {a}, {0});
  return a;
}
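
Note: in-place ops use `mutate_function_t` rather than `rematerialize_function_t`: the closure performs the mutation, and the trailing `{0}` passed to `CheckpointTensorImpl::mutate` lists which argument indices are written, so their cached values can be invalidated. A hypothetical in-place wrapper in the same style (illustrative only; `relu_` is a stand-in op):

Tensor& checkpoint_relu_(at::Tensor& a) {
  mutate_function_t mt =
    [=](const Tensors& vec) {
      // Copy the handle first, as checkpoint_transpose_ above does,
      // then mutate through it.
      Tensor a_ = vec.at(0);
      a_.relu_();
    };
  CheckpointTensorImpl::mutate("relu_", mt, {a}, {0});
  return a;
}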

Tensor checkpoint_transpose(at::Tensor const& a, long b, long c) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {at::transpose(vec.at(0), b, c)};
    };
  return CheckpointTensorImpl::make("transpose", rt, {a})[0];
}

Tensor checkpoint_gelu(at::Tensor const& a) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {at::gelu(vec.at(0))};
    };
  return CheckpointTensorImpl::make("gelu", rt, {a})[0];
}

Tensor checkpoint_matmul(at::Tensor const& a, at::Tensor const& b) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
      return {at::matmul(vec.at(0), vec.at(1))};
    };
  return CheckpointTensorImpl::make("matmul", rt, {a, b})[0];
}
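
Note: combined with the dispatch entries added to native_functions.yaml below, an ordinary `at::matmul` call on checkpointed tensors now routes to `checkpoint_matmul` with no caller-side changes. A rough sketch of the intended call path (the `checkpoint()` conversion is assumed here from the DTR fork's API and is not part of this diff):

// Illustrative only; assumes the DTR fork's checkpoint() conversion exists.
#include <ATen/ATen.h>

void example() {
  at::Tensor a = at::randn({128, 64}).checkpoint();  // assumed DTR API
  at::Tensor b = at::randn({64, 32}).checkpoint();
  // The dispatcher sees the Checkpoint key on the inputs and routes the
  // call to checkpoint_matmul; c can be freed and recomputed on demand.
  at::Tensor c = at::matmul(a, b);
}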

Tensor checkpoint_threshold(const Tensor& a, c10::Scalar b, c10::Scalar c) {
  rematerialize_function_t rt =
    [=](const Tensors& vec) -> Tensors {
25 changes: 25 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2157,8 +2157,17 @@
- func: matmul(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
  dispatch:
    CPU, CUDA: matmul
    Checkpoint: checkpoint_matmul

- func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  # use_c10_dispatcher: full
  # variants: function, method
  # dispatch:
  #   CPU, CUDA: matmul_out
  #   Checkpoint: checkpoint_matmul_out
  # A PyTorch bug currently disallows the dispatch entries above.

- func: matrix_rank.tol(Tensor self, float tol, bool symmetric=False) -> Tensor
  use_c10_dispatcher: full
@@ -2906,6 +2915,7 @@
  dispatch:
    CPU: gelu_cpu
    CUDA: gelu_cuda
    Checkpoint: checkpoint_gelu
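
Note: each `Checkpoint:` entry makes the codegen register the named C++ function on the `Checkpoint` dispatch key. The effect is roughly what the following manual registration would do (illustrative sketch only; the real bindings are generated from this YAML):

// Sketch of the equivalent manual registration (not part of this diff;
// checkpoint_gelu and checkpoint_matmul are the functions added above).
#include <torch/library.h>

TORCH_LIBRARY_IMPL(aten, Checkpoint, m) {
  m.impl("gelu", checkpoint_gelu);
  m.impl("matmul", checkpoint_matmul);
}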

- func: gelu_backward(Tensor grad, Tensor self) -> Tensor
  use_c10_dispatcher: full
@@ -3426,6 +3436,9 @@
  use_c10_dispatcher: full
  variants: function, method
  device_guard: False
  dispatch:
    CPU, CUDA: transpose
    Checkpoint: checkpoint_transpose

- func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)
  variants: function, method
@@ -3441,6 +3454,9 @@
  use_c10_dispatcher: full
  variants: method
  device_guard: False
  dispatch:
    CPU, CUDA: transpose_
    Checkpoint: checkpoint_transpose_

- func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
  use_c10_dispatcher: full
@@ -4461,16 +4477,25 @@
  use_c10_dispatcher: full
  variants: method
  device_guard: False
  dispatch:
    CPU, CUDA: to
    Checkpoint: checkpoint_to

- func: to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor
  use_c10_dispatcher: full
  variants: method
  device_guard: False
  dispatch:
    CPU, CUDA: to
    Checkpoint: checkpoint_to

- func: to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor
  use_c10_dispatcher: full
  variants: method
  device_guard: False
  dispatch:
    CPU, CUDA: to
    Checkpoint: checkpoint_to

- func: meshgrid(Tensor[] tensors) -> Tensor[]
  use_c10_dispatcher: full