Skip to content

Commit

Permalink
updated requirements.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
RakeshRaj97 committed Dec 7, 2020
1 parent 65a30b5 commit 0eba623
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
40 changes: 20 additions & 20 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
#from torch.cuda import amp
from apex import amp
# from torch.cuda import amp
# from apex import amp
from sklearn import metrics

from wtfml.data_loaders.image import ClassificationLoader
Expand Down Expand Up @@ -38,9 +38,9 @@ def forward(self, image, targets):
return out, loss

def train(fold):
training_data_path = ""
model_path = ""
df = pd.read_csv("")
training_data_path = "/train"
model_path = "/models"
df = pd.read_csv("train_folds.csv")
device = "cuda"
epochs = 50
train_bs = 32
Expand Down Expand Up @@ -113,12 +113,12 @@ def train(fold):

#scaler = amp.GradScaler()

model, optimizer = amp.initialize(
model,
optimizer,
opt_level="O1",
verbosity=0
)
# model, optimizer = amp.initialize(
# model,
# optimizer,
# opt_level="O1",
# verbosity=0
# )

es = EarlyStopping(patience=5, mode="max")
for epoch in range(epochs):
Expand Down Expand Up @@ -196,15 +196,15 @@ def predict(fold):

if __name__ == "__main__":
train(0)
# train(1)
# train(2)
# train(3)
# train(4)
# train(5)
# train(6)
# train(7)
# train(8)
# train(9)
train(1)
train(2)
train(3)
train(4)
train(5)
train(6)
train(7)
train(8)
train(9)

# p1 = predict(0)
# p2 = predict(1)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ albumentations
pretrainedmodels
torch==1.5.0
wtfml==0.0.2
pandas==1.1.4

0 comments on commit 0eba623

Please sign in to comment.