Make print_summary true by default #2

base: master

Conversation
By default, we will show all inputs and outputs, including the model's i/o:

```python
>>> import torch
>>> import torch.nn as nn
>>> from pytorch_model_summary import summary
>>> class myNN(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.model1 = nn.Sequential(
...             nn.Linear(100, 200),
...             nn.ReLU(inplace=True),
...             nn.Linear(200, 50),
...             nn.Softmax(-1)
...         )
...         self.model2 = nn.Sequential(
...             nn.Linear(10, 20),
...             nn.ReLU(inplace=True),
...             nn.Linear(20, 5),
...             nn.Softmax(-1)
...         )
...     def forward(self, x1, x2):
...         y1 = self.model1(x1)
...         y2 = self.model2(x2)
...         return y1, y2
>>> mynn = myNN()
>>> summary(mynn, torch.zeros(1, 100), torch.zeros(1, 10))
```
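To make the proposed default concrete, here is a minimal sketch of the signature change; `build_summary_text` is a hypothetical stand-in for the library's real formatting code, not its actual API:

```python
def build_summary_text(model, *inputs):
    # Hypothetical stand-in for the library's internals,
    # just so this sketch runs end-to-end.
    return f"summary of {model.__class__.__name__} on {len(inputs)} input(s)"

def summary(model, *inputs, print_summary=True):
    # Proposed default: print unless the caller opts out.
    text = build_summary_text(model, *inputs)
    if print_summary:
        print(text)
    return text  # the string is still returned either way
```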
Hi @sizhky! Thank you so much for your PR! After years of using summary in Keras, we needed a similar version in PyTorch 😄

I agree, printing by default is better. In general, when we call this method we do indeed want to print.

When I created this module, I thought about printing both input and output shapes, but considering Keras's behavior I realized that, excluding the first and last layers, the information is duplicated, because the output of one layer is (in general) the input to the next. Maybe an option to drop one of them when desired could accommodate all programmers who [do/do not] want to see so much information. On the other hand, the version with both sounds good, especially in your example.

For your example, maybe a version showing the parent layer would be better than printing both input and output shapes. Keras has an option for that, but I haven't implemented it yet.

Did you test your version with the lib examples? What do you think about the points/questions above?
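As a sketch of what such a column-dropping option might look like (the `show_input` flag and the row layout here are illustrative, not the library's current API):

```python
# Illustrative only: render one shape column per row instead of both.
def format_row(name, in_shape, out_shape, n_params, show_input=False):
    shape = in_shape if show_input else out_shape
    return f"{name:>20}  {str(shape):>25}  {n_params:>12,}"

# Example row, with shapes taken from the CNN summary below:
print(format_row("Linear-5", [1, 320], [1, 50], 16050))
```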
I understand your points about keeping the table succinct. I want to add a couple of points:

See the outputs for the lib examples below.

I didn't understand.

CNN

```
-----------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
=======================================================================
Input [1, 1, 28, 28] -1
Conv2d-2 [1, 1, 28, 28] [1, 10, 24, 24] 260
Conv2d-3 [1, 10, 12, 12] [1, 20, 8, 8] 5,020
Dropout2d-4 [1, 20, 8, 8] [1, 20, 8, 8] 0
Linear-5 [1, 320] [1, 50] 16,050
Linear-6 [1, 50] [1, 10] 510
Output [1, 10] -1
=======================================================================
Total params: 21,838
Trainable params: 21,838
Non-trainable params: 0
-----------------------------------------------------------------------
=========================== Hierarchical Summary ===========================
Net(
(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)), 260 params
(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)), 5,020 params
(conv2_drop): Dropout2d(p=0.5, inplace=False), 0 params
(fc1): Linear(in_features=320, out_features=50, bias=True), 16,050 params
(fc2): Linear(in_features=50, out_features=10, bias=True), 510 params
), 21,840 params
============================================================================
```

Transformer

```
-----------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===================================================================================
Input [1, 5], [1, 5] -1
Encoder-2 [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 17,332,224
Decoder-3 [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 22,060,544
Linear-4 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
===================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
-----------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Encoder-2 [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 17,332,224
Decoder-3 [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 22,060,544
Linear-4 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
Batch size: 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
================================ Hierarchical Summary ================================
Transformer(
(encoder): Encoder(
(src_emb): Embedding(6, 512), 3,072 params
(pos_emb): Embedding(6, 512), 3,072 params
(layers): ModuleList(
(0): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(1): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(2): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(3): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(4): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(5): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
), 17,326,080 params
), 17,332,224 params
(decoder): Decoder(
(tgt_emb): Embedding(7, 512), 3,584 params
(pos_emb): Embedding(6, 512), 3,072 params
(layers): ModuleList(
(0): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(1): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(2): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(3): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(4): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(5): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
), 22,053,888 params
), 22,060,544 params
(projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params
), 39,396,352 params
======================================================================================
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Embedding-2 [1, 5] [1, 5, 512] 3,072
Embedding-3 [1, 5] [1, 5, 512] 3,072
EncoderLayer-4 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-5 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-6 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-7 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-8 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-9 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
Embedding-10 [1, 5] [1, 5, 512] 3,584
Embedding-11 [1, 5] [1, 5, 512] 3,072
DecoderLayer-12 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-13 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-14 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-15 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-16 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-17 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
Linear-18 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Embedding-2 [1, 5] [1, 5, 512] 3,072
Embedding-3 [1, 5] [1, 5, 512] 3,072
Linear-4 [1, 5, 512] [1, 5, 512] 262,656
Linear-5 [1, 5, 512] [1, 5, 512] 262,656
Linear-6 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-7 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-8 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-9 [1, 5, 512] [1, 5, 512] 262,656
Linear-10 [1, 5, 512] [1, 5, 512] 262,656
Linear-11 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-12 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-13 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-14 [1, 5, 512] [1, 5, 512] 262,656
Linear-15 [1, 5, 512] [1, 5, 512] 262,656
Linear-16 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-17 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-18 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-19 [1, 5, 512] [1, 5, 512] 262,656
Linear-20 [1, 5, 512] [1, 5, 512] 262,656
Linear-21 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-22 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-23 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-24 [1, 5, 512] [1, 5, 512] 262,656
Linear-25 [1, 5, 512] [1, 5, 512] 262,656
Linear-26 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-27 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-28 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-29 [1, 5, 512] [1, 5, 512] 262,656
Linear-30 [1, 5, 512] [1, 5, 512] 262,656
Linear-31 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-32 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-33 [1, 2048, 5] [1, 512, 5] 1,049,088
Embedding-34 [1, 5] [1, 5, 512] 3,584
Embedding-35 [1, 5] [1, 5, 512] 3,072
Linear-36 [1, 5, 512] [1, 5, 512] 262,656
Linear-37 [1, 5, 512] [1, 5, 512] 262,656
Linear-38 [1, 5, 512] [1, 5, 512] 262,656
Linear-39 [1, 5, 512] [1, 5, 512] 262,656
Linear-40 [1, 5, 512] [1, 5, 512] 262,656
Linear-41 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-42 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-43 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-44 [1, 5, 512] [1, 5, 512] 262,656
Linear-45 [1, 5, 512] [1, 5, 512] 262,656
Linear-46 [1, 5, 512] [1, 5, 512] 262,656
Linear-47 [1, 5, 512] [1, 5, 512] 262,656
Linear-48 [1, 5, 512] [1, 5, 512] 262,656
Linear-49 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-50 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-51 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-52 [1, 5, 512] [1, 5, 512] 262,656
Linear-53 [1, 5, 512] [1, 5, 512] 262,656
Linear-54 [1, 5, 512] [1, 5, 512] 262,656
Linear-55 [1, 5, 512] [1, 5, 512] 262,656
Linear-56 [1, 5, 512] [1, 5, 512] 262,656
Linear-57 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-58 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-59 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-60 [1, 5, 512] [1, 5, 512] 262,656
Linear-61 [1, 5, 512] [1, 5, 512] 262,656
Linear-62 [1, 5, 512] [1, 5, 512] 262,656
Linear-63 [1, 5, 512] [1, 5, 512] 262,656
Linear-64 [1, 5, 512] [1, 5, 512] 262,656
Linear-65 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-66 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-67 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-68 [1, 5, 512] [1, 5, 512] 262,656
Linear-69 [1, 5, 512] [1, 5, 512] 262,656
Linear-70 [1, 5, 512] [1, 5, 512] 262,656
Linear-71 [1, 5, 512] [1, 5, 512] 262,656
Linear-72 [1, 5, 512] [1, 5, 512] 262,656
Linear-73 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-74 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-75 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-76 [1, 5, 512] [1, 5, 512] 262,656
Linear-77 [1, 5, 512] [1, 5, 512] 262,656
Linear-78 [1, 5, 512] [1, 5, 512] 262,656
Linear-79 [1, 5, 512] [1, 5, 512] 262,656
Linear-80 [1, 5, 512] [1, 5, 512] 262,656
Linear-81 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-82 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-83 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-84 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Parent Layers Layer (type) Input Shape Output Shape Param #
======================================================================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Transformer/Encoder Embedding-2 [1, 5] [1, 5, 512] 3,072
Transformer/Encoder Embedding-3 [1, 5] [1, 5, 512] 3,072
Transformer/Encoder/EncoderLayer MultiHeadAttention-4 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-5 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-6 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-7 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-8 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-9 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-10 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-11 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-12 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-13 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-14 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-15 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder Embedding-16 [1, 5] [1, 5, 512] 3,584
Transformer/Decoder Embedding-17 [1, 5] [1, 5, 512] 3,072
Transformer/Decoder/DecoderLayer MultiHeadAttention-18 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-19 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-20 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-21 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-22 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-23 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-24 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-25 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-26 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-27 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-28 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-29 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-30 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-31 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-32 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-33 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-34 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-35 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer Linear-36 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
======================================================================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
@sizhky the table width is breaking in the transformer example? I'm going to merge your PR once the table width is validated, and add an optional parameter to drop some columns.
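On the width point, one possible fix (a sketch only, not something implemented here) is to cap the shape column and elide overflow, since attention layers can return many tensors:

```python
# Sketch: truncate over-wide shape columns so the table stays readable.
def fit_column(text, width=60):
    return text if len(text) <= width else text[: width - 3] + "..."

# e.g. the 19-tensor output row of the Transformer example above:
shapes = ", ".join(str([1, 8, 5, 5]) for _ in range(19))
print(fit_column(shapes))
```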
Hi,

I've been using this module for a month, and my experience with it has been largely pleasant. Thank you for your contribution 😄

One thing I observed is that I was always setting `print_summary` to True, so I feel that True as the default is more user friendly. In fact, defaulting to False assumes that needing the string rather than a printed report is the more common case, when really a user would otherwise have to print it every time...

Edit: I added a few more edits to i/o. I think the changes might have been too aggressive as a design change, but we can test it out more and try to break it...
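For completeness, a usage sketch of the proposed default, reusing the `myNN` example from the PR description (assumes `torch`, `summary`, and `mynn` from that snippet are in scope):

```python
# Common case: the summary prints without any extra flag.
summary(mynn, torch.zeros(1, 100), torch.zeros(1, 10))

# Opt out when only the string is wanted:
report = summary(mynn, torch.zeros(1, 100), torch.zeros(1, 10), print_summary=False)
```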