-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model.py
78 lines (59 loc) · 2.25 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.nn import Sigmoid
import sys
import os
from collections import Counter
import string
import numpy as np
import argparse
from LoadData import LoadData
from MyLoss import MyLoss
class Model(nn.Module):
    """Stacked LSTM followed by a per-timestep linear layer.

    The recurrent core is an ``nn.LSTM`` built with PyTorch's default
    ``batch_first=False``, so inputs to :meth:`forward` are laid out as
    ``(seq_len, batch, inputSize)``. The linear head maps each
    ``hiddenSize``-wide LSTM output to ``hiddenSize`` features.
    """

    def __init__(self, dataset, args):
        """Build the LSTM and linear head from hyperparameters on ``args``.

        Args:
            dataset: Accepted for interface compatibility; not used here.
            args: Object exposing ``inputSize``, ``hiddenSize`` and
                ``numLayers`` attributes (e.g. an ``argparse.Namespace``).
        """
        super(Model, self).__init__()
        # Number of features per timestep fed to the LSTM.
        # NOTE(review): an earlier comment listed "cores, memory, disk, time",
        # which suggests four input features -- confirm against the data loader.
        self.inputSize = args.inputSize
        # Width of the LSTM hidden/cell state; also the output width of fc.
        self.hiddenSize = args.hiddenSize
        # Number of stacked LSTM layers.
        self.numLayers = args.numLayers
        self.rnnUnit = nn.LSTM(
            input_size=self.inputSize,
            hidden_size=self.hiddenSize,
            num_layers=self.numLayers,
        )
        self.fc = nn.Linear(self.hiddenSize, self.hiddenSize)

    def forward(self, X, prevState):
        """Run the LSTM over ``X`` and project every timestep through ``fc``.

        Args:
            X: Input tensor, ``(seq_len, batch, inputSize)`` (batch_first=False).
            prevState: ``(h0, c0)`` tuple as produced by :meth:`initState`,
                or ``None`` to let ``nn.LSTM`` use zero states.

        Returns:
            ``(output, state)``, where ``output`` is the ``fc`` projection of
            every LSTM timestep and ``state`` is the LSTM's final ``(h_n, c_n)``.
        """
        output, state = self.rnnUnit(X, prevState)
        output = self.fc(output)
        return output, state

    def initState(self, seqLength, device=None):
        """Return a zero-filled ``(h0, c0)`` pair for the LSTM.

        Args:
            seqLength: Size of the *batch* dimension of the returned states.
                Despite the name, ``nn.LSTM`` expects ``h0``/``c0`` shaped
                ``(numLayers, batch, hiddenSize)``, so this must equal the batch
                size of the input later passed to :meth:`forward`.
            device: Optional device for the new tensors. ``None`` (the default)
                keeps the previous behaviour of allocating on PyTorch's default
                device.

        Returns:
            Tuple of two independent zero tensors of shape
            ``(numLayers, seqLength, hiddenSize)``.
        """
        shape = (self.numLayers, seqLength, self.hiddenSize)
        stateHidden = torch.zeros(shape, device=device)
        stateCurrent = torch.zeros(shape, device=device)
        return (stateHidden, stateCurrent)

    def evaluate(self, model, validate_data, device, prevState, args):
        """Evaluate ``model`` on ``validate_data`` and return the lowest batch loss.

        The model is switched to eval mode for the duration of the call and
        restored to its previous train/eval mode afterwards. Gradients are
        disabled, so no autograd graph is built or retained.

        Args:
            model: Module called as ``model(inputs, prevState)``; element ``[0]``
                of its result is treated as the predictions.
            validate_data: Iterable of ``(inputs, labels)`` pairs.
            device: Currently unused. Inputs, labels and ``prevState`` are used
                on whatever device they already live on.
            prevState: Initial recurrent state reused, unchanged, for every
                batch. NOTE(review): its batch dimension must match every batch,
                including a possibly smaller final one -- confirm with the loader.
            args: Passed to ``MyLoss`` to build the loss function.

        Returns:
            The smallest per-batch loss seen, or ``float('inf')`` if
            ``validate_data`` yields no batches.
        """
        # Build the loss once instead of once per batch.
        loss_f = MyLoss(args)
        was_training = model.training
        model.eval()
        best_loss = float('inf')
        try:
            # Validation needs no gradients; this avoids building and retaining
            # an autograd graph for every batch.
            with torch.no_grad():
                for inputs, labels in validate_data:
                    predictions = model(inputs, prevState)[0]
                    loss = loss_f(predictions, labels)
                    if loss < best_loss:
                        best_loss = loss
        finally:
            # Don't leave a training caller stuck in eval mode.
            model.train(was_training)
        print(f'Best loss value per batch across validation dataset is {best_loss}')
        return best_loss

    def predict(self, X, state):
        """Return only the output sequence of :meth:`forward`, discarding the new state."""
        output, _ = self.forward(X, state)
        return output