-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathquant_cuda.cpp
87 lines (74 loc) · 3.37 KB
/
quant_cuda.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <torch/torch.h>
#include "quant_cuda.h"
#include <tuple>
using namespace at;
// Input-validation helpers used by the wrapper functions in this file.
// Each check expands to a TORCH_CHECK whose message starts with the
// spelled-out argument expression (#x).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks on x. Wrapped in do { ... } while (0) so the macro behaves
// as ONE statement: without the wrapper, `if (cond) CHECK_INPUT(t);` would
// guard only the CUDA check and always execute the contiguity check.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
// Entry point for nearest-neighbor fixed-point quantization.
// Rejects `a` unless it is a contiguous CUDA tensor, then forwards every
// argument unchanged to fixed_point_quantize_nearest_cuda and returns its result.
// NOTE(review): wl / fl are presumably the word length and fractional length
// in bits — confirm against quant_cuda.h.
Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool use_clamp, bool symmetric)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = fixed_point_quantize_nearest_cuda(a, wl, fl, use_clamp, symmetric);
    return quantized;
}
// Entry point for the mask-returning variant of nearest-neighbor fixed-point
// quantization. Requires `a` to be a contiguous CUDA tensor, then hands all
// arguments to fixed_point_quantize_nearest_mask_cuda and returns its pair of
// tensors as-is.
// NOTE(review): the two tensors are presumably (quantized values, mask) —
// confirm the order against the CUDA implementation.
std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask(Tensor a, int wl, int fl,
                                                             bool symmetric)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    std::tuple<Tensor, Tensor> result = fixed_point_quantize_nearest_mask_cuda(a, wl, fl, symmetric);
    return result;
}
// Entry point for nearest-neighbor block floating point quantization.
// Verifies that `a` is a contiguous CUDA tensor before delegating to
// block_quantize_nearest_cuda with the same wl and dim.
// NOTE(review): dim presumably selects the axis that shares an exponent —
// confirm against the CUDA implementation.
Tensor block_quantize_nearest(Tensor a, int wl, int dim)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = block_quantize_nearest_cuda(a, wl, dim);
    return quantized;
}
// Entry point for the "sim" variant of nearest-neighbor block quantization.
// Verifies that `a` is a contiguous CUDA tensor, then returns whatever
// block_quantize_sim_nearest_cuda produces for (a, wl).
// NOTE(review): how the "sim" variant differs from block_quantize_nearest is
// not visible in this file — see the CUDA implementation.
Tensor block_quantize_sim_nearest(Tensor a, int wl)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = block_quantize_sim_nearest_cuda(a, wl);
    return quantized;
}
// Entry point for nearest-neighbor low-bitwidth floating point quantization.
// Requires `a` to be a contiguous CUDA tensor; man_bits and exp_bits are
// passed through untouched to float_quantize_nearest_cuda.
// NOTE(review): man_bits / exp_bits are presumably the mantissa and exponent
// widths of the target format — confirm against quant_cuda.h.
Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = float_quantize_nearest_cuda(a, man_bits, exp_bits);
    return quantized;
}
// Entry point for stochastic fixed-point quantization.
// Rejects `a` unless it is a contiguous CUDA tensor, then forwards every
// argument unchanged to fixed_point_quantize_stochastic_cuda and returns its
// result.
// NOTE(review): wl / fl are presumably the word length and fractional length
// in bits — confirm against quant_cuda.h.
Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool use_clamp, bool symmetric)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = fixed_point_quantize_stochastic_cuda(a, wl, fl, use_clamp, symmetric);
    return quantized;
}
// Entry point for the mask-returning variant of stochastic fixed-point
// quantization. Requires `a` to be a contiguous CUDA tensor, then hands all
// arguments to fixed_point_quantize_stochastic_mask_cuda and returns its pair
// of tensors as-is.
// NOTE(review): the two tensors are presumably (quantized values, mask) —
// confirm the order against the CUDA implementation.
std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl, int fl,
                                                                bool symmetric)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    std::tuple<Tensor, Tensor> result = fixed_point_quantize_stochastic_mask_cuda(a, wl, fl, symmetric);
    return result;
}
// Entry point for stochastic block floating point quantization.
// Verifies that `a` is a contiguous CUDA tensor before delegating to
// block_quantize_stochastic_cuda with the same wl and dim.
// NOTE(review): dim presumably selects the axis that shares an exponent —
// confirm against the CUDA implementation.
Tensor block_quantize_stochastic(Tensor a, int wl, int dim)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = block_quantize_stochastic_cuda(a, wl, dim);
    return quantized;
}
// Entry point for the "sim" variant of stochastic block quantization.
// Verifies that `a` is a contiguous CUDA tensor, then returns whatever
// block_quantize_sim_stochastic_cuda produces for (a, wl).
// NOTE(review): how the "sim" variant differs from block_quantize_stochastic
// is not visible in this file — see the CUDA implementation.
Tensor block_quantize_sim_stochastic(Tensor a, int wl)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = block_quantize_sim_stochastic_cuda(a, wl);
    return quantized;
}
// Entry point for stochastic low-bitwidth floating point quantization.
// Requires `a` to be a contiguous CUDA tensor; man_bits and exp_bits are
// passed through untouched to float_quantize_stochastic_cuda.
// NOTE(review): man_bits / exp_bits are presumably the mantissa and exponent
// widths of the target format — confirm against quant_cuda.h.
Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits)
{
    CHECK_CUDA(a);
    CHECK_CONTIGUOUS(a);
    Tensor quantized = float_quantize_stochastic_cuda(a, man_bits, exp_bits);
    return quantized;
}
// Registers the Python-visible functions of this extension module. Each
// m.def call binds a Python name to the same-named C++ wrapper in this file
// and attaches the docstring given as the third argument.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Stochastic variants.
    m.def("fixed_point_quantize_stochastic", &fixed_point_quantize_stochastic, "Fixed Point Number Stochastic Quantization (CUDA)");
    m.def("fixed_point_quantize_stochastic_mask", &fixed_point_quantize_stochastic_mask, "Fixed Point Number Stochastic Quantization (CUDA)");
    m.def("block_quantize_stochastic", &block_quantize_stochastic, "Block Floating Point Number Stochastic Quantization (CUDA)");
    m.def("block_quantize_sim_stochastic", &block_quantize_sim_stochastic, "Block Floating Point Number Stochastic Quantization (CUDA)");
    m.def("float_quantize_stochastic", &float_quantize_stochastic, "Low-Bitwidth Floating Point Number Stochastic Quantization (CUDA)");

    // Nearest-neighbor variants.
    m.def("fixed_point_quantize_nearest", &fixed_point_quantize_nearest, "Fixed Point Number Nearest Neighbor Quantization (CUDA)");
    m.def("fixed_point_quantize_nearest_mask", &fixed_point_quantize_nearest_mask, "Fixed Point Number Nearest Neighbor Quantization (CUDA)");
    m.def("block_quantize_nearest", &block_quantize_nearest, "Block Floating Point Number Nearest Neighbor Quantization (CUDA)");
    // The docstring here must say "Nearest Neighbor": this binding wraps the
    // nearest variant, not the stochastic one.
    m.def("block_quantize_sim_nearest", &block_quantize_sim_nearest, "Block Floating Point Number Nearest Neighbor Quantization (CUDA)");
    m.def("float_quantize_nearest", &float_quantize_nearest, "Low-Bitwidth Floating Point Number Nearest Neighbor Quantization (CUDA)");
}