Skip to content

Commit

Permalink
🐛 fix labels
Browse files Browse the repository at this point in the history
  • Loading branch information
agombert committed Oct 19, 2023
1 parent a14315d commit 8fcf4fb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions grants_tagger_light/evaluation/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def evaluate_model(
):
model = BertMesh.from_pretrained(model_path)

label_binarizer = MultiLabelBinarizer()
label_binarizer = MultiLabelBinarizer(classes=list(model.id2label.keys()))
label_binarizer.fit([list(model.id2label.keys())])
model.label2id = {value: key for key, value in model.id2label.items()}

Expand All @@ -57,8 +57,8 @@ def evaluate_model(
X_test, Y_test, _ = load_data(data_path, label_binarizer, model_label2id=model.label2id)

logging.info('data loaded')
X_test = X_test[:100]
Y_test = Y_test[:100]
X_test = X_test[:10]
Y_test = Y_test[:10]

top_10_index = np.argsort(np.sum(Y_test, axis=0))[::-1][:10]
print(top_10_index)
Expand All @@ -74,9 +74,8 @@ def evaluate_model(
print(f'argmax: {argmax}')
print(argmax, output[0][argmax])
Y_pred_proba = pipe(X_test, return_labels=False)
#Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba]

#Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba]
Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba]
print(Y_pred_proba)
Y_pred_proba = torch.vstack(Y_pred_proba)
print('loss')
Expand Down

0 comments on commit 8fcf4fb

Please sign in to comment.