diff --git a/grants_tagger_light/evaluation/evaluate_model.py b/grants_tagger_light/evaluation/evaluate_model.py index 93356aa8..edfd749b 100644 --- a/grants_tagger_light/evaluation/evaluate_model.py +++ b/grants_tagger_light/evaluation/evaluate_model.py @@ -37,7 +37,7 @@ def evaluate_model( ): model = BertMesh.from_pretrained(model_path) - label_binarizer = MultiLabelBinarizer() + label_binarizer = MultiLabelBinarizer(classes=list(model.id2label.keys())) label_binarizer.fit([list(model.id2label.keys())]) model.label2id = {value: key for key, value in model.id2label.items()} @@ -57,8 +57,8 @@ def evaluate_model( X_test, Y_test, _ = load_data(data_path, label_binarizer, model_label2id=model.label2id) logging.info('data loaded') - X_test = X_test[:100] - Y_test = Y_test[:100] + X_test = X_test[:10] + Y_test = Y_test[:10] top_10_index = np.argsort(np.sum(Y_test, axis=0))[::-1][:10] print(top_10_index) @@ -74,9 +74,8 @@ def evaluate_model( print(f'argmax: {argmax}') print(argmax, output[0][argmax]) Y_pred_proba = pipe(X_test, return_labels=False) - #Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba] - #Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba] + Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba] print(Y_pred_proba) Y_pred_proba = torch.vstack(Y_pred_proba) print('loss')