forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch_workflow_test.py
34 lines (26 loc) · 996 Bytes
/
torch_workflow_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from keras.src import layers
from keras.src import testing
from keras.src.backend.common import KerasVariable
class Net(torch.nn.Module):
    """Minimal ``torch.nn.Module`` that wraps a single Keras ``Dense`` layer.

    Used to check that a Keras layer can be embedded in, and called from,
    a plain PyTorch module.
    """

    def __init__(self):
        super().__init__()
        # One output unit; the input size is not given here and is
        # presumably inferred by Keras on the first call.
        self.fc1 = layers.Dense(1)

    def forward(self, x):
        """Run ``x`` through the wrapped Keras layer and return its output."""
        return self.fc1(x)
class TorchWorkflowTest(testing.TestCase):
    """Checks that Keras layers interoperate with plain PyTorch code."""

    def test_keras_layer_in_nn_module(self):
        """A Keras ``Dense`` layer works inside a ``torch.nn.Module``.

        Covers the forward pass, exposure of the layer's weights through
        ``nn.Module.parameters()``, and passing a ``KerasVariable`` directly
        to torch ops.
        """
        # Test using Keras layer in a nn.Module.
        net = Net()

        # Forward pass. Use a deterministic input: torch.empty would hand the
        # layer uninitialized memory. Only the output shape is checked.
        outputs = net(torch.ones(100, 10))
        self.assertAllEqual(list(outputs.shape), [100, 1])

        # KerasVariables are registered as nn.Parameter: the Dense layer's
        # two weights (presumably kernel and bias).
        self.assertLen(list(net.parameters()), 2)

        # A KerasVariable can be passed to torch ops like a torch tensor.
        kernel = net.fc1.kernel
        self.assertIsInstance(kernel, KerasVariable)
        # Dense(1) applied to 10 input features yields a (10, 1) kernel.
        self.assertEqual(tuple(kernel.shape), (10, 1))
        transposed_kernel = torch.transpose(kernel, 0, 1)
        product = torch.mul(kernel, transposed_kernel)
        self.assertIsInstance(product, torch.Tensor)
        # (10, 1) * (1, 10) broadcasts to (10, 10); this confirms the
        # variable's values were actually used by the torch op.
        self.assertEqual(tuple(product.shape), (10, 10))