Skip to content

Commit

Permalink
Merge pull request #2 from boostcampaitech2/mingu
Browse files Browse the repository at this point in the history
add test
  • Loading branch information
deokgu1994 authored Sep 4, 2021
2 parents 9c25136 + b0f3b37 commit ecfc90c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 19 deletions.
2 changes: 1 addition & 1 deletion config_multi_label.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"beta": 0.8
},
"test":{
"path": "/opt/ml/code/src/logs/log/Maske/0901_122235/checkpoint-epoch4.pth"
"path": "/opt/ml/code/src_final/save_model/epoch_1_loss_0.008717534132301807.pth"
}

}
Expand Down
Binary file not shown.
55 changes: 37 additions & 18 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,39 @@
#
from collections import Counter

def func_labels(outputs_mask, outputs_gender, outputs_age, device):
"""mask, gender, age의 조합을 통해 최종 18개의 class 중 하나를 반환한다.
Args:
outputs_mask: model의 forward 결과물중 mask에 대한 output
outputs_gender: model의 forward 결과물중 gender에 대한 output
outputs_age: model의 forward 결과물중 age에 대한 output
device: torch.device('cuda') 혹은 torch.device('cpu')
Returns:
outputs_label.to(device): ex) torch.Tensor([0, 17, 1, 3, ...]).to(device)
"""
outputs_label = torch.Tensor([])
len_outputs = len(outputs_mask)
for i in range(len_outputs):
mask_class = outputs_mask[i]
_, mask_class = mask_class.max(dim=0)

gender_class = outputs_gender[i]
_, gender_class = gender_class.max(dim=0)

age_class = outputs_age[i]
_, age_class = age_class.max(dim=0)

label = mask_class * 6 + gender_class * 3 + age_class

#label: int -> [[1, 0, 0, 0]]
one_hot = torch.zeros((1,18))
one_hot[0][label] = 1
label = one_hot
outputs_label = torch.cat([outputs_label, label])
return outputs_label.to(device)

class TestDataset(Dataset):
def __init__(self, img_paths):
self.img_paths = img_paths
Expand Down Expand Up @@ -67,36 +100,22 @@ def main(config):
import sys
pth = config["test"]["path"]
checkpoint = torch.load(pth)
state_dict = checkpoint['state_dict']
# state_dict = checkpoint['state_dict']
if config['n_gpu'] > 1:
model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)
model.load_state_dict(checkpoint)

# prepare model for testing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

all_predictions = []
# for images in tqdm(loader):
# count =1
# colloct_pred = []
# for transformer in transform.transformations["eavl"]:
# # print("-----" * 10, transformer, count, "-"*20)
# with torch.no_grad():
# augmented_image = transformer.augment_image(images)
# images = images.to(device)
# pred = model(images)
# pred = pred.argmax(1).item()
# colloct_pred.append(pred)
# count +=1
# pred = Counter(colloct_pred).most_common()[0][0]
# all_predictions.append(pred)
for images in tqdm(loader):
with torch.no_grad():
images = images.to(device)
pred = model(images)
pred = pred.argmax(1).item()
pred_outputs_mask, pred_outputs_gender, pred_outputs_age = model(images)
pred = func_labels(pred_outputs_mask, pred_outputs_gender, pred_outputs_age, device)
all_predictions.append(pred)
submission['ans'] = all_predictions
submission.to_csv(os.path.join(test_dir, f"{config['test_name']}_submission.csv"), index=False)
Expand Down

0 comments on commit ecfc90c

Please sign in to comment.