Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new interfaces and tests for rotary_emb #1360

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9076,6 +9076,37 @@
],
),
),

'apply_rotary': dict(
name=['apply_rotary'],
interface=['CustomizedTest'],
dtype=[np.float64, np.float32, np.float16],
para=dict(
conj=[True, False, False, False, True, True, False, True, False, True],
interleaved=[True, False, True, False, True, False, False, True, True, False]
),
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input1'],
"shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)),
},
{
"ins": ['input2'],
"shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)),
},
{
"ins": ['cos'],
"shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)),
},
{
"ins": ['sin'],
"shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)),
},
],
),
),

'rotary_emb_empty_tensor': dict(
name=['rotary_emb'],
Expand Down
16 changes: 16 additions & 0 deletions diopi_test/python/conformance/customized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,22 @@ def rotary_emb(input, cos, sin, conj, interleaved):
out2 = out2.to(data_type)
out = torch.cat((out1, out2), dim=-1)
return out

def apply_rotary(input1, input2, cos, sin, conj, interleaved):
data_type = input1.dtype
input1 = input1.to(torch.float32)
ipnut2 = input2.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
if not conj:
out1 = input1 * cos - input2 * sin
out2 = input1 * sin + input2 * cos
else:
out1 = input1 * cos + input2 * sin
out2 = -input1 * sin + input2 * cos
out1 = out1.to(data_type)
out2 = out2.to(data_type)
return (out1, out2)

def rms_norm(input, normalized_shape, weight, bias, eps):
if normalized_shape is not None:
Expand Down
8 changes: 8 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7036,6 +7036,14 @@ def rotary_emb(input, cos, sin, conj, interleaved):
check_returncode(ret)
return out

def apply_rotary(input1, input2, cos, sin, conj, interleaved):
call = "diopiApplyRotary"
func = check_function(call)
out1 = Tensor(list(input1.size().data), input1.get_dtype())
out2 = Tensor(list(input2.size().data), input2.get_dtype())
ret = func(input1.context(), out1, out2, input1, input2, cos, sin, conj, interleaved)
check_returncode(ret)
return (out1, out2)

def rms_norm(input, normalized_shape, weight, bias, eps):
if bias is not None:
Expand Down
19 changes: 19 additions & 0 deletions impl/torch/functions/functions_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t
return diopiSuccess;
}

diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1,
diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj,
const bool interleaved = false) {
if (interleaved) {
set_last_error_string("interleaved rotary embedding is not supported yet");
return diopiNoImplement;
}
impl::aten::setCurStream(ctx);
auto atX1 = impl::aten::buildATen(x1);
auto atX2 = impl::aten::buildATen(x2);
auto atCos = impl::aten::buildATen(cos);
auto atSin = impl::aten::buildATen(sin);
auto atOut1 = impl::aten::buildATen(out1);
auto atOut2 = impl::aten::buildATen(out2);
ext::ops::apply_rotary_cuda(atX1, atX2, atCos, atSin, atOut1, atOut2, conj);

return diopiSuccess;
}

diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRMS, diopiConstTensorHandle_t input,
diopiSize_t normalized_shape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) {
impl::aten::setCurStream(ctx);
Expand Down
19 changes: 19 additions & 0 deletions proto/include/diopi/functions_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ extern "C" {
DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos,
diopiConstTensorHandle_t sin, const bool conj, const bool interleaved);

/**
* @brief Apply rotary embedding operation to an input tensor.
* @param[in] ctx The diopi context.
* @param[out] out1 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64].
* @param[out] out2 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64].
* @param[in] x1 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64].
* @param[in] x2 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64].
* @param[in] cos The cosine values. type = [bfloat16, float16, float32, float64].
* @param[in] sin The sine values. type = [bfloat16, float16, float32, float64].
* @param[in] conj bool: If `false`, compute rotary embeddings for forward. If `true`, computes the backward of rotary embeddings according to the conjugate of
* the rotary matrix.
* @param[in] interleaved bool:
* - When set to `false`, rotary embedding is applied by splitting 'x' in half and separately applying sine and cosine to each half.
* - When set to `true`, rotary embedding is applied by pairing every two elements in 'x' and applying sine and cosine to each pair.
*/
DIOPI_API diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1,
diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj,
const bool interleaved);

/**
* @brief Apply Root Mean Square (RMS) Normalization to the input tensor.
* @param[in] ctx The diopi context.
Expand Down
Loading