Validate Retagging Experimentation #19

Open

wants to merge 26 commits into base: main

Changes from all commits (26 commits)
3ae973e  :see_no_evil: add .DS_Store and related (agombert, Oct 5, 2023)
5b48387  :wrench: update gradient accumulation (Oct 9, 2023)
624a321  :wrench: increase gradient accumulation (Oct 9, 2023)
c5c1526  :wrench: decrease batch size (Oct 9, 2023)
beb5da7  :wrench: changing batch size (Oct 9, 2023)
98972dd  :wrench: reset gradient accumulation to 2 (Oct 9, 2023)
477407c  :beers: making tests on models (Oct 17, 2023)
233c60b  :disk: load data with HF (agombert, Oct 17, 2023)
52f69ef  :chart: metrics for both results (agombert, Oct 17, 2023)
23a66aa  :beers: add the label2id for getting data with different mapping (agombert, Oct 17, 2023)
8605854  :beers: add the label2id as parameter if different mapping (agombert, Oct 17, 2023)
6709d3e  :charts: charts for last iterations (agombert, Oct 17, 2023)
dd9de53  :see_no_evil: add preprocessed results (agombert, Oct 17, 2023)
f4b1ea6  :gear: save model bertmesh training 102023 (agombert, Oct 17, 2023)
a6f9c4b  :wrench: put back default value for gradient accumulation (agombert, Oct 19, 2023)
a14315d  :wrench: modify gradient accumulation in the yaml (agombert, Oct 19, 2023)
8fcf4fb  :bug: fix labels (agombert, Oct 19, 2023)
1095bd9  :wrench: add gradient accumulation in dvc lock (agombert, Oct 19, 2023)
b3b7efc  :wrench: add eval each 100 steps to check the evaluation (agombert, Oct 19, 2023)
82e8769  :wrench: add max sample size (agombert, Oct 19, 2023)
9cbc840  :memo: add logs (agombert, Oct 19, 2023)
d2eb48d  :wrench: add sigmoid (agombert, Oct 24, 2023)
2b625e2  :alembic: make test on lower sample (agombert, Oct 24, 2023)
b2057a1  :memo: add logs and metrics (agombert, Oct 24, 2023)
20308ba  :wrench: modify paramters for training (agombert, Oct 24, 2023)
fe958f9  :floppy_disk: quit jsonl data (agombert, Oct 24, 2023)
4 changes: 4 additions & 0 deletions .gitignore
@@ -2,6 +2,9 @@
__pycache__/
*.py[cod]
*$py.class
.DS_Store
*/.DS_Store
/*/.DS_Store

# C extensions
*.so
@@ -164,3 +167,4 @@ cython_debug/
bertmesh_outs/
wandb/
/bertmesh_before_retagging
/preprocessed_results
4 changes: 0 additions & 4 deletions data/raw/allMeSH_2021.jsonl.dvc

This file was deleted.

69 changes: 52 additions & 17 deletions grants_tagger_light/evaluation/evaluate_model.py
@@ -2,20 +2,24 @@
Evaluate model performance on test set
"""
import json
import logging
import os
from pathlib import Path
from typing import Optional
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY

import scipy.sparse as sp
import numpy as np
import typer
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer
from wasabi import row, table
from grants_tagger_light.utils import load_data, load_train_test_data
from grants_tagger_light.models.bert_mesh import BertMesh, BertMeshPipeline

from tqdm import tqdm

PIPELINE_REGISTRY.register_pipeline(
"grants-tagging", pipeline_class=BertMeshPipeline, pt_model=BertMesh
@@ -29,11 +33,13 @@ def evaluate_model(
split_data=True,
results_path=None,
full_report_path=None,
batch_size=10,
):
model = BertMesh.from_pretrained(model_path)

label_binarizer = MultiLabelBinarizer()
label_binarizer.fit([list(model.id2label.values())])
label_binarizer = MultiLabelBinarizer(classes=list(model.id2label.keys()))
label_binarizer.fit([list(model.id2label.keys())])
model.label2id = {value: key for key, value in model.id2label.items()}

pipe = pipeline(
"grants-tagging",
@@ -48,12 +54,33 @@
)
_, X_test, _, Y_test = load_train_test_data(data_path, label_binarizer)
else:
X_test, Y_test, _ = load_data(data_path, label_binarizer)

X_test, Y_test, _ = load_data(data_path, label_binarizer, model_label2id=model.label2id)

logging.info('data loaded')
X_test = X_test[:10]
Y_test = Y_test[:10]

top_10_index = np.argsort(np.sum(Y_test, axis=0))[::-1][:10]
print(top_10_index)
print(np.sum(Y_test, axis=0)[top_10_index])

logging.info(f'beginning of evaluation - {len(X_test)} items')

outputs = pipe(['This grant is about malaria and HIV',
'My name is Arnault and I live in barcelona'], return_labels=False)
print(outputs)
for output in outputs:
argmax = np.argmax(output[0])
print(f'argmax: {argmax}')
print(argmax, output[0][argmax])
Y_pred_proba = pipe(X_test, return_labels=False)


Y_pred_proba = [torch.sigmoid(proba) for proba in Y_pred_proba]
print(Y_pred_proba)
Y_pred_proba = torch.vstack(Y_pred_proba)

print('loss')
print(F.binary_cross_entropy_with_logits(torch.tensor(Y_pred_proba), torch.tensor(Y_test).float()))
print(Y_pred_proba.shape, Y_test.shape)
Y_pred_proba = sp.csr_matrix(Y_pred_proba)

if not isinstance(threshold, list):
@@ -66,15 +93,23 @@
results = []
for th in threshold:
Y_pred = Y_pred_proba > th

p, r, f1, _ = precision_recall_fscore_support(Y_test, Y_pred, average="micro")
full_report = classification_report(Y_test, Y_pred, output_dict=True)
print(np.sum(Y_pred, axis=0)[:, top_10_index])
print(np.sum(Y_pred))
print(Y_pred_proba.shape)
print(np.max(Y_pred_proba.toarray(), axis=0)[top_10_index])
print(Y_pred.sum(axis=0))
top_10_index = np.argsort(np.array(Y_pred.sum(axis=0))[0])[::-1][:10]
print(top_10_index)
print(np.array(Y_pred.sum(axis=0))[0, top_10_index])

p, r, f1, _ = precision_recall_fscore_support(Y_test, Y_pred.toarray(), average="micro")
full_report = classification_report(Y_test, Y_pred.toarray(), output_dict=True)

# Gets averages
averages = {idx: report for idx, report in full_report.items() if "avg" in idx}
averages = {str(idx): report for idx, report in full_report.items() if "avg" in idx}
# Gets class reports and converts index to class names for readability
full_report = {
label_binarizer.classes_[int(idx)]: report
str(label_binarizer.classes_[int(idx)]): report
for idx, report in full_report.items()
if "avg" not in idx
}
@@ -99,10 +134,10 @@
print(row(row_data, widths=widths))

if results_path:
with open(results_path, "w") as f:
with open(os.path.join(results_path, 'metrics.json'), "w") as f:
f.write(json.dumps(results, indent=4))
if full_report_path:
with open(full_report_path, "w") as f:
with open(os.path.join(results_path, 'full_report_metrics.json'), "w") as f:
f.write(json.dumps(full_report, indent=4))


@@ -121,9 +156,9 @@ def evaluate_model_cli(
"0.5", help="threshold or comma separated thresholds used to assign tags"
),
results_path: Optional[str] = typer.Option(None, help="path to save results"),
full_report_path: Optional[str] = typer.Option(
None,
help="Path to save full report, i.e. "
full_report_path: Optional[bool] = typer.Option(
False,
help="Whether to save full report or not, i.e. "
"more comprehensive results than the ones saved in results_path",
),
split_data: bool = typer.Option(
1 change: 1 addition & 0 deletions grants_tagger_light/models/bert_mesh/pipeline.py
@@ -33,6 +33,7 @@ def postprocess(
return_labels: bool,
threshold: float = 0.5,
):
print('I AM IN THE POSTPROCESS MY GOD')
if return_labels:
outs = [
[
5 changes: 3 additions & 2 deletions grants_tagger_light/preprocessing/preprocess_mesh.py
@@ -62,6 +62,7 @@ def preprocess_mesh(
train_years: list = None,
test_years: list = None,
):

if max_samples != -1:
logger.info(f"Filtering examples to {max_samples}")
data_path = create_sample_file(data_path, max_samples)
@@ -70,7 +71,7 @@
label2id = None
id2label = None
tokenizer = AutoTokenizer.from_pretrained(
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
"Wellcome/WellcomeBertMesh"#"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
else:
# Load the model to get its label2id
@@ -106,7 +107,7 @@
dset = dset.filter(
lambda x: any(np.isin(tags, x["meshMajor"])), num_proc=num_proc
)

# Remove unused columns to save space & time
dset = dset.remove_columns(["journal", "pmid", "title"])

36 changes: 33 additions & 3 deletions grants_tagger_light/training/train.py
@@ -25,6 +25,7 @@
from pprint import pformat
import typer
import numpy as np

import os
import transformers
import json
@@ -80,7 +81,7 @@ def train_bertmesh(
train_years=train_years,
test_years=test_years,
)

train_dset, val_dset = dset["train"], dset["test"]

metric_labels = []
@@ -160,15 +161,35 @@ def train_bertmesh(
def sklearn_metrics(prediction: EvalPrediction):
# This is a batch, so it's an array (rows) of array (labels)
# Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ]
y_pred = prediction.predictions
def sigmoid(x):
return 1 / (1 + np.exp(-x))
y_pred = np.array([sigmoid(y) for y in prediction.predictions])
top_indexes = np.argsort(np.sum(prediction.label_ids, axis=0))[::-1][:5]
#print(y_pred.shape)
#print('mean probabilties and standard deviations for 10 first examples')
#print(np.mean(y_pred[:11, :],axis=1), np.std(y_pred[:11, :], axis=1))
#print('ids')
#print(np.where(prediction.label_ids[10, :] == 1)[0])
#print('Probabilities for true value')
#print(y_pred[10, np.where(prediction.label_ids[10, :] == 1)[0]])
#print('Higher predictions')
#print(np.sort(y_pred[10, :])[::-1])
# Transformed to 0-1 if bigger than threshold [[0 1 0...] [0 0 1...] ... ]
y_pred = np.int64(y_pred > training_args.threshold)
#print('Number of predictions and outputs')
#print(sum(y_pred[10, :]), y_pred[10, :])
logger.info("predictions made")
logger.info(np.sum(y_pred))

# Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ]
y_true = prediction.label_ids

# report = classification_report(y_pred, y_true, output_dict=True)

reports = {index: {'pred_max': np.max(y_pred[[0, 1000, 2500, 10000], :], axis=1)}
for index in top_indexes}
logger.info("highest pred calculated")

if training_args.prune_labels_in_evaluation:
mask = np.zeros(y_pred.shape, dtype=bool)
mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True
@@ -178,10 +199,14 @@ def sklearn_metrics(prediction: EvalPrediction):
else:
filtered_y_pred = y_pred
filtered_y_true = y_true

logger.info("pruned labels in evaluation")

report = classification_report(
filtered_y_pred, filtered_y_true, output_dict=True
filtered_y_true, filtered_y_pred, output_dict=True,
zero_division=0.0,
)
logger.info("classification report computed")

metric_dict = {
"micro_avg": report["micro avg"],
@@ -190,6 +215,11 @@
"samples_avg": report["samples avg"],
}

metric_dict = {**metric_dict,
**{f'index_{index}': reports.get(index)
for index in top_indexes}}


return metric_dict

logger.info("Collating labels...")
29 changes: 28 additions & 1 deletion grants_tagger_light/utils/utils.py
@@ -1,11 +1,14 @@
import json
import logging
import os
from datasets import load_from_disk

# encoding: utf-8
import pickle
from functools import partial

import pandas as pd
import numpy as np
import requests
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
@@ -37,10 +40,34 @@ def yield_tags(data_path, label_binarizer=None):
yield item["tags"]


def load_data(data_path, label_binarizer=None, X_format="List"):
def load_data(data_path, label_binarizer=None, X_format="List", model_label2id=None):
"""Load data from the dataset."""
print("Loading data...")


if os.path.isdir(data_path):
logger.info(
"Train/test data found in a folder, which means you preprocessed and "
"save the data before. Loading that split from disk..."
)
dset = load_from_disk(os.path.join(data_path, "dataset"))
with open(os.path.join(data_path, "label2id"), "r") as f:
label2id = json.load(f)
with open(os.path.join(data_path, "id2label"), "r") as f:
id2label = json.load(f)

train_dset, val_dset = dset["train"], dset["test"]
texts = val_dset['abstractText']
tags = val_dset['label_ids']

if model_label2id:
tags = [[model_label2id.get(id2label.get(str(x), None), None) for x in val] for val in tags]

if label_binarizer:
tags = label_binarizer.transform(tags)
meta = val_dset['meshMajor']
return texts, tags, meta

texts = []
tags = []
meta = []