Skip to content

Commit

Permalink
add ref_impl
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Dec 12, 2024
1 parent 2bccee4 commit 13237f4
Show file tree
Hide file tree
Showing 16 changed files with 468 additions and 231 deletions.
8 changes: 6 additions & 2 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -29095,8 +29095,12 @@ This version of the operator has been available since version 23 of the default

# Inserted rotated embeddings back to the original input
if interleaved:
x_rotate[:, :, :, 0::2] = real
x_rotate[:, :, :, 1::2] = imag
# x_rotate[:, :, :, 0::2] = real
# x_rotate[:, :, :, 1::2] = imag
real = np.expand_dims(real, axis=-1)
imag = np.expand_dims(imag, axis=-1)
x_rotate_concat = np.concatenate((real, imag), axis=-1)
x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
else:
x_rotate = np.concatenate((real, imag), axis=-1)
output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
Expand Down
72 changes: 52 additions & 20 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -27380,8 +27380,12 @@ expect(

# Inserted rotated embeddings back to the original input
if interleaved:
x_rotate[:, :, :, 0::2] = real
x_rotate[:, :, :, 1::2] = imag
# x_rotate[:, :, :, 0::2] = real
# x_rotate[:, :, :, 1::2] = imag
real = np.expand_dims(real, axis=-1)
imag = np.expand_dims(imag, axis=-1)
x_rotate_concat = np.concatenate((real, imag), axis=-1)
x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
else:
x_rotate = np.concatenate((real, imag), axis=-1)
output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
Expand Down Expand Up @@ -27444,21 +27448,23 @@ This version of the operator has been available since version 23 of the default
node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"]
outputs=["output"],
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data)
expected_output = compute_rotary_embedding(
input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding"
name="test_rotary_embedding",
)
```

Expand All @@ -27474,21 +27480,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
num_heads=num_heads
num_heads=num_heads,
)

input_data = np.random.rand(2, 3, 32).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, num_heads=num_heads)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
num_heads=num_heads,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_3d_input"
name="test_rotary_embedding_3d_input",
)
```

Expand All @@ -27503,21 +27515,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
interleaved=1
interleaved=1,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, interleaved=1)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
interleaved=1,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_interleaved"
name="test_rotary_embedding_interleaved",
)
```

Expand All @@ -27530,22 +27548,23 @@ expect(
```python
node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", ""],
inputs=["input", "cos_cache", "sin_cache"],
outputs=["output"],
interleaved=1,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
sin_cache_data = np.random.rand(2, 3, 4).astype(np.float32)
cos_cache_data = np.random.rand(2, 3, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, interleaved=1)
expected_output = compute_rotary_embedding(
input_data, cos_cache_data, sin_cache_data
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data],
outputs=[expected_output],
name="test_rotary_embedding_no_position_ids"
name="test_rotary_embedding_no_position_ids",
)
```

Expand All @@ -27569,13 +27588,20 @@ position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, interleaved=1, rotary_embedding_dim=4)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
interleaved=1,
rotary_embedding_dim=4,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_with_interleaved_rotary_dim"
name="test_rotary_embedding_with_interleaved_rotary_dim",
)
```

Expand All @@ -27590,21 +27616,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
rotary_embedding_dim=4
rotary_embedding_dim=4,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, rotary_embedding_dim=4)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
rotary_embedding_dim=4,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_with_rotary_dim"
name="test_rotary_embedding_with_rotary_dim",
)
```

Expand Down
64 changes: 46 additions & 18 deletions docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -19233,21 +19233,23 @@ There are 6 test cases, listed as following:
node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"]
outputs=["output"],
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data)
expected_output = compute_rotary_embedding(
input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding"
name="test_rotary_embedding",
)
```

Expand All @@ -19261,21 +19263,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
num_heads=num_heads
num_heads=num_heads,
)

input_data = np.random.rand(2, 3, 32).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, num_heads=num_heads)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
num_heads=num_heads,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_3d_input"
name="test_rotary_embedding_3d_input",
)
```

Expand All @@ -19288,21 +19296,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
interleaved=1
interleaved=1,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, interleaved=1)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
interleaved=1,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_interleaved"
name="test_rotary_embedding_interleaved",
)
```

Expand All @@ -19313,22 +19327,23 @@ expect(
```python
node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", ""],
inputs=["input", "cos_cache", "sin_cache"],
outputs=["output"],
interleaved=1,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
sin_cache_data = np.random.rand(2, 3, 4).astype(np.float32)
cos_cache_data = np.random.rand(2, 3, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, interleaved=1)
expected_output = compute_rotary_embedding(
input_data, cos_cache_data, sin_cache_data
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data],
outputs=[expected_output],
name="test_rotary_embedding_no_position_ids"
name="test_rotary_embedding_no_position_ids",
)
```

Expand All @@ -19350,13 +19365,20 @@ position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, interleaved=1, rotary_embedding_dim=4)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
interleaved=1,
rotary_embedding_dim=4,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_with_interleaved_rotary_dim"
name="test_rotary_embedding_with_interleaved_rotary_dim",
)
```

Expand All @@ -19369,21 +19391,27 @@ node = onnx.helper.make_node(
"RotaryEmbedding",
inputs=["input", "cos_cache", "sin_cache", "position_ids"],
outputs=["output"],
rotary_embedding_dim=4
rotary_embedding_dim=4,
)

input_data = np.random.rand(2, 3, 4, 8).astype(np.float32)
position_ids_data = np.random.rand(2, 3).astype(np.int64)
sin_cache_data = np.random.rand(50, 4).astype(np.float32)
cos_cache_data = np.random.rand(50, 4).astype(np.float32)

expected_output = compute_rotary_embedding(input_data, cos_cache_data, sin_cache_data, position_ids=position_ids_data, rotary_embedding_dim=4)
expected_output = compute_rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids=position_ids_data,
rotary_embedding_dim=4,
)

expect(
node,
inputs=[input_data, cos_cache_data, sin_cache_data, position_ids_data],
outputs=[expected_output],
name="test_rotary_embedding_with_rotary_dim"
name="test_rotary_embedding_with_rotary_dim",
)
```

Expand Down
Loading

0 comments on commit 13237f4

Please sign in to comment.