Commit: Total refactor: XLinear

Jose J. Martinez committed Sep 12, 2023
1 parent 902eda0 commit 22fad69
Showing 5 changed files with 97 additions and 583 deletions.
4 changes: 3 additions & 1 deletion examples/retag.sh
@@ -1,3 +1,5 @@
# run in c5.9xlarge with at least 72GB of RAM
grants-tagger retag mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FILE_HERE] \
--tags "Artificial Intelligence,HIV" \
--years 2017,2018,2019,2020,2021
--years 2016,2017,2018,2019,2020,2021 \
--supervised
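
The same run can also be started from Python. Below is a minimal sketch (not part of this commit) built around the retag() entry point whose signature appears further down in retagging.py; the output directory, tag list, years and train_examples value are placeholders, and the keyword names are assumed from that signature and the CLI options.

from grants_tagger_light.retagging.retagging import retag

retag(
    data_path="data/raw/allMeSH_2021.jsonl",
    save_to_path="data/retagged",                            # placeholder output directory
    tags=["Artificial Intelligence", "HIV"],
    years=["2016", "2017", "2018", "2019", "2020", "2021"],  # values as produced by parse_years; exact type assumed
    supervised=True,                                         # interactive curation, as in the script above
    train_examples=100,                                      # illustrative value only
)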
13 changes: 12 additions & 1 deletion grants_tagger_light/models/xlinear/model.py
@@ -68,6 +68,13 @@ def __init__(
# Those are MeshXLinear params
self.threshold = threshold

self.model_path = None
self.xlinear_model_ = None
self.vectorizer_ = None

self.label_binarizer_path = label_binarizer_path
self.label_binarizer_ = None

if label_binarizer_path is not None:
self.load_label_binarizer(label_binarizer_path)

@@ -167,7 +174,6 @@ def predict_tags(
"""
X: list or numpy array of texts
model_path: path to trained model
label_binarizer_path: path to trained label_binarizer
probabilities: bool, default False. When true probabilities
are returned along with tags
threshold: float, default 0.5. Probability threshold to be used to assign tags.
@@ -217,6 +223,9 @@ def load(self, model_path, is_predict_only=True):
with open(params_path, "r") as f:
self.__dict__.update(json.load(f))

self.load_label_binarizer(self.label_binarizer_path)
self.model_path = model_path

if self.vectorizer_library == "sklearn":
self.vectorizer_ = load_pickle(vectorizer_path)
else:
@@ -229,6 +238,8 @@ def load(self, model_path, is_predict_only=True):
model_path, is_predict_only=is_predict_only
)

return self

def load_label_binarizer(self, label_binarizer_path):
with open(label_binarizer_path, "rb") as f:
self.label_binarizer_ = pickle.loads(f.read())
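
A minimal predict-only sketch (not part of the diff) of the refactored MeshXLinear API shown above: load() now restores the label binarizer recorded in the saved params and returns self, so loading and predicting can be chained. The model path is a placeholder and the default constructor arguments are assumed to suffice.

from grants_tagger_light.models.xlinear import MeshXLinear

model = MeshXLinear().load("models/xlinear", is_predict_only=True)    # placeholder path

tags = model.predict_tags(
    ["We study HIV drug resistance mutations in large cohorts."],     # list of texts
    threshold=0.5,            # probability threshold, per the docstring above
    probabilities=True,       # also return the probability of each assigned tag
)
print(tags)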
237 changes: 81 additions & 156 deletions grants_tagger_light/retagging/retagging.py
@@ -6,17 +6,19 @@
import typer
from loguru import logger

from datasets import Dataset, load_dataset, concatenate_datasets
from johnsnowlabs import nlp
from datasets import Dataset, load_dataset, concatenate_datasets, load_from_disk

[GitHub Actions / ruff] F401 at grants_tagger_light/retagging/retagging.py:9:67: `datasets.load_from_disk` imported but unused

import os

from sklearn.metrics import classification_report
import pyarrow.parquet as pq
from sklearn import preprocessing

from grants_tagger_light.models.xlinear import MeshXLinear
from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags
import scipy
import pickle as pkl

import numpy as np
import tqdm

retag_app = typer.Typer()

@@ -36,94 +38,22 @@ def _load_data(dset: Dataset, tag, limit=100, split=0.8):
return train_dset, test_dset


def _create_pipelines(save_to_path, batch_size, train_df, test_df, tag, spark):
"""
This method creates a Spark pipeline (to run on dataframes)
Args:
save_to_path: path where to save the final results.
batch_size: max size of the batch to train. Since data is small for training, I limit it to 8.
train_df: Spark Dataframe of the train data
test_df: Spark Dataframe of the test data
spark: the Spark Object
Returns:
a tuple of (pipeline, lightpipeline)
"""
document_assembler = nlp.DocumentAssembler() \
.setInputCol("abstractText") \
.setOutputCol("document")

# Biobert Sentence Embeddings (clinical)
embeddings = nlp.BertSentenceEmbeddings.pretrained("sent_biobert_clinical_base_cased", "en") \
.setInputCols(["document"]) \
.setOutputCol("sentence_embeddings")

retrain = True
clf_dir = f"{save_to_path}.{tag.replace(' ', '')}_clf"
if os.path.isdir(clf_dir):
answer = input("Classifier already trained. Do you want to reuse it? [y|n]: ")
while answer not in ['y', 'n']:
answer = input("Classifier already trained. Do you want to reuse it? [y|n]: ")
if answer == 'y':
retrain = False

if retrain:
# I'm limiting the batch size to 8 since there are not many examples and big batch sizes will decrease accuracy
classifierdl = nlp.ClassifierDLApproach() \
.setInputCols(["sentence_embeddings"]) \
.setOutputCol("label") \
.setLabelColumn("featured_tag") \
.setMaxEpochs(25) \
.setLr(0.001) \
.setBatchSize(max(batch_size, 8)) \
.setEnableOutputLogs(True)
# .setOutputLogsPath('logs')

clf_pipeline = nlp.Pipeline(stages=[document_assembler,
embeddings,
classifierdl])

fit_clf_pipeline = clf_pipeline.fit(train_df)
preds = fit_clf_pipeline.transform(test_df)
preds_df = preds.select('featured_tag', 'abstractText', 'label.result').toPandas()
preds_df['result'] = preds_df['result'].apply(lambda x: x[0])
logging.info(classification_report(preds_df['featured_tag'], preds_df['result']))

logging.info("- Loading the model for prediction...")
fit_clf_pipeline.stages[-1].write().overwrite().save(clf_dir)

fit_clf_model = nlp.ClassifierDLModel.load(clf_dir)

pred_pipeline = nlp.Pipeline(stages=[document_assembler,
embeddings,
fit_clf_model])
pred_df = spark.createDataFrame([['']]).toDF("text")
fit_pred_pipeline = pred_pipeline.fit(pred_df)

return fit_pred_pipeline


def _annotate(save_to_path, dset, tag, limit, is_positive):
human_supervision = {}
curation_file = f"{save_to_path}.{tag.replace(' ', '')}.curation.json"
def _annotate(curation_file, dset, tag, limit, is_positive):
field = 'positive' if is_positive else 'negative'
human_supervision = {tag: {'positive': [], 'negative': []}}
if os.path.isfile(curation_file):
with open(curation_file, 'r') as f:
human_supervision = json.load(f)
prompt = f"File `{curation_file}` found. Do you want to reuse previous work? [y|n]: "

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:45:89: Line too long (93 > 88 characters)
answer = input(prompt)
while answer not in ['y', 'n']:
answer = input(prompt)
if answer == 'n':
human_supervision[tag][is_positive] = []

if tag not in human_supervision:
human_supervision[tag] = {'positive': [], 'negative': []}
if answer == 'y':
with open(curation_file, 'r') as f:
human_supervision = json.load(f)

field = 'positive' if is_positive else 'negative'
count = len(human_supervision[tag][field])
logging.info(f"[{tag}] Annotated: {count} Required: {limit} Available: {len(dset) - count}")

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:54:89: Line too long (96 > 88 characters)
finished = False
while count <= limit:
while count < limit:
tries = 0
random.seed(time.time())
random_pos_row = random.randint(0, len(dset))
@@ -148,7 +78,7 @@ def _annotate(save_to_path, dset, tag, limit, is_positive):
human_supervision[tag][field].append(dset[random_pos_row])
with open(curation_file, 'w') as f:
json.dump(human_supervision, f)
count = len(human_supervision[tag][field])
count = len(human_supervision[tag])
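
For orientation, a minimal sketch (not part of the commit) of reading back the curation file that _annotate (and the unsupervised branch further down) writes: one JSON object per featured tag, with the accepted dataset rows stored under 'positive' and 'negative'. The path and the "HIV" tag are placeholders; the row fields follow the dataset columns used elsewhere in this diff, e.g. abstractText.

import json

with open("data/retagged/HIV/curation") as f:        # placeholder: <save_to_path>/<tag without spaces>/curation
    curation = json.load(f)

print(list(curation["HIV"].keys()))                  # ['positive', 'negative']
print(curation["HIV"]["positive"][0]["abstractText"][:80])   # first curated positive example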


def _curate(save_to_path, pos_dset, neg_dset, tag, limit):
@@ -162,9 +92,8 @@ def _curate(save_to_path, pos_dset, neg_dset, tag, limit):
def retag(
data_path: str,
save_to_path: str,
spark_memory: int = 27,
num_proc: int = os.cpu_count(),
batch_size: int = 64,
batch_size: int = 1024,
tags: list = None,
tags_file_path: str = None,
threshold: float = 0.8,
@@ -173,11 +102,6 @@ def retag(
years: list = None,
):

spark = nlp.start(spark_conf={
'spark.driver.memory': f'{spark_memory}g',
'spark.executor.memory': f'{spark_memory}g',
})

# We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing
logging.info("Loading the MeSH jsonl...")
dset = load_dataset("json", data_files=data_path, num_proc=1)
@@ -194,18 +118,19 @@
with open(tags_file_path, 'r') as f:
tags = [x.strip() for x in f.readlines()]

logging.info(f"Total tags detected: {tags}")
logging.info(f"- Total tags detected: {tags}.")
logging.info("- Training classifiers (retaggers)")

for tag in tags:
logging.info(f"Retagging: {tag}")

os.makedirs(os.path.join(save_to_path, tag.replace(" ", "")), exist_ok=True)
logging.info(f"- Obtaining positive examples for {tag}...")
positive_dset = dset.filter(
lambda x: tag in x["meshMajor"], num_proc=num_proc
)

if len(positive_dset['abstractText']) < 50:
logging.info(f"Skipping {tag}: low examples ({len(positive_dset['abstractText'])}. "
if len(positive_dset['abstractText']) < train_examples:
logging.info(f"Skipping {tag}: low examples ({len(positive_dset['abstractText'])} vs "

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:132:89: Line too long (98 > 88 characters)
f"expected {train_examples}). "
f"Check {save_to_path}.err for more information about skipped tags.")

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:134:89: Line too long (94 > 88 characters)
with open(f"{save_to_path}.err", 'a') as f:
f.write(tag)
@@ -216,80 +141,85 @@ def retag(
lambda x: tag not in x["meshMajor"], num_proc=num_proc
)

curation_file = os.path.join(save_to_path, tag.replace(' ', ''), "curation")
if supervised:
logging.info(f"- Curating data...")
_curate(save_to_path, positive_dset, negative_dset, tag, train_examples)
logging.info(f"- Curating {tag}...")
_curate(curation_file, positive_dset, negative_dset, tag, train_examples)
else:
with open(curation_file, 'w') as f:
json.dump({tag: {'positive': [positive_dset[i] for i in range(train_examples)],

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:150:89: Line too long (95 > 88 characters)
'negative': [negative_dset[i] for i in range(train_examples)]

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:151:89: Line too long (94 > 88 characters)
}
}, f)

logging.info("- Retagging...")

curation_file = f"{save_to_path}.{tag.replace(' ', '')}.curation.json"
if os.path.isfile(curation_file):
with open(curation_file, "r") as fr:
# I load the curated data file
human_supervision = json.load(fr)
positive_dset = Dataset.from_list(human_supervision[tag]['positive'])
negative_dset = Dataset.from_list(human_supervision[tag]['negative'])
models = {}
for tag in tags:
curation_file = os.path.join(save_to_path, tag.replace(' ', ''), "curation")
if not os.path.isfile(curation_file):
logger.info(f"Skipping `{tag}` retagging as no curation data was found. "
f"Maybe there were too little examples? (check {save_to_path}.err)")

[GitHub Actions / ruff] E501 at grants_tagger_light/retagging/retagging.py:162:89: Line too long (92 > 88 characters)
continue
with open(curation_file, "r") as fr:
data = json.load(fr)
positive_dset = Dataset.from_list(data[tag]['positive'])
negative_dset = Dataset.from_list(data[tag]['negative'])

pos_x_train, pos_x_test = _load_data(positive_dset, tag, limit=train_examples, split=0.8)
neg_x_train, neg_x_test = _load_data(negative_dset, "other", limit=train_examples, split=0.8)

pos_x_train = pos_x_train.add_column("featured_tag", [tag] * len(pos_x_train))
pos_x_test = pos_x_test.add_column("featured_tag", [tag] * len(pos_x_test))
neg_x_train = neg_x_train.add_column("featured_tag", ["other"] * len(neg_x_train))
neg_x_test = neg_x_test.add_column("featured_tag", ["other"] * len(neg_x_test))
pos_x_train = pos_x_train.add_column("tag", [tag] * len(pos_x_train))
pos_x_test = pos_x_test.add_column("tag", [tag] * len(pos_x_test))
neg_x_train = neg_x_train.add_column("tag", ["other"] * len(neg_x_train))
neg_x_test = neg_x_test.add_column("tag", ["other"] * len(neg_x_test))

logging.info(f"- Creating train/test sets...")
train = concatenate_datasets([pos_x_train, neg_x_train])
train_df = spark.createDataFrame(train)
test = concatenate_datasets([pos_x_test, neg_x_test])
test_df = spark.createDataFrame(test)

logging.info(f"- Train dataset size: {train_df.count()}")
logging.info(f"- Test dataset size: {test_df.count()}")

logging.info(f"- Creating `sparknlp` pipelines...")
pipeline = _create_pipelines(save_to_path, batch_size, train_df, test_df, tag, spark)

logging.info(f"- Optimizing dataframe...")
data_in_parquet = f"{save_to_path}.data.parquet"
optimize=True
if os.path.isfile(data_in_parquet):
answer = input("Optimized dataframe found. Do you want to use it? [y|n]: ")
while answer not in ['y', 'n']:
answer = input("Optimized dataframe found. Do you want to use it? [y|n]: ")
if answer == 'y':
optimize = False

if optimize:
dset = dset.remove_columns(["title", "journal", "year"])

pq.write_table(dset.data.table, data_in_parquet)
del dset, train, train_df, test, test_df, pos_x_train, pos_x_test, neg_x_train, neg_x_test, positive_dset,\
negative_dset
sdf = spark.read.load(data_in_parquet)

logging.info(f"- Repartitioning...")
sdf = sdf.repartition(num_proc)

logging.info(f"- Retagging {tag}...")
pipeline.transform(sdf).write.mode('overwrite').save(f"{save_to_path}.{tag.replace(' ', '')}.prediction")

# 1) We load
# 2) We filter to get those results where the predicted tag was not initially in meshMajor
# 3) We filter by confidence > threshold
# predictions = spark.read.load(f"{save_to_path}.{tag}.prediction").\
# filter(~array_contains(col('meshMajor'), tag)).\
label_binarizer = preprocessing.LabelBinarizer()
label_binarizer_path = os.path.join(save_to_path, tag.replace(" ", ""), 'labelbinarizer')
labels = [1 if x == tag else 0 for x in train["tag"]]
label_binarizer.fit(labels)
with open(label_binarizer_path, 'wb') as f:
pkl.dump(label_binarizer, f)

model = MeshXLinear(label_binarizer_path=label_binarizer_path)
model.fit(train["abstractText"], scipy.sparse.csr_matrix(label_binarizer.transform(labels)))
models[tag] = model
model_path = os.path.join(save_to_path, tag.replace(" ", ""), "clf")
os.makedirs(model_path, exist_ok=True)
model.save(model_path)

logging.info("- Predicting all tags")
for b in tqdm.tqdm(range(int(len(dset) / batch_size))):
start = b * batch_size
end = min(len(dset), (b+1) * batch_size)
batch = dset[start:end]["abstractText"]
for tag in tags:
if tag not in models:
logger.info(f"Skipping {tag} - classifier not trained. Maybe there were little data?")
continue
models[tag](batch, threshold=threshold)
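
A minimal sketch (not part of the commit) of how the batched predictions above could be collected, assuming each retagger exposes the predict_tags() method from model.py; the committed loop calls the model object directly, discards its return value, and floors the batch count, so a trailing partial batch is skipped. dset, models, batch_size and threshold are the surrounding names from retag(); predictions is illustrative.

import numpy as np
import tqdm

predictions = {t: [] for t in models}
n_batches = int(np.ceil(len(dset) / batch_size))        # ceil keeps the final partial batch
for b in tqdm.tqdm(range(n_batches)):
    start = b * batch_size
    end = min(len(dset), (b + 1) * batch_size)
    batch = dset[start:end]["abstractText"]
    for t, clf in models.items():
        # accumulate the predicted tags for this slice so they can be written out later
        predictions[t].extend(clf.predict_tags(batch, threshold=threshold))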


@retag_app.command()
def retag_cli(
data_path: str = typer.Argument(..., help="Path to mesh.jsonl"),
data_path: str = typer.Argument(
...,
help="Path to allMeSH_2021.jsonl"),
save_to_path: str = typer.Argument(
..., help="Path where to save the retagged data"
...,
help="Path where to save the retagged data"
),
num_proc: int = typer.Option(
os.cpu_count(), help="Number of processes to use for data augmentation"
os.cpu_count(),
help="Number of processes to use for data augmentation"
),
batch_size: int = typer.Option(
64, help="Preprocessing batch size (for dataset, filter, map, ...)"
1024,
help="Preprocessing batch size (for dataset, filter, map, ...)"
),
tags: str = typer.Option(
None,
@@ -309,15 +239,11 @@ def retag_cli(
help="Number of examples to use for training the retaggers"
),
supervised: bool = typer.Option(
True,
False,
help="Use human curation, showing a `limit` amount of positive and negative examples to curate data"
" for training the retaggers. The user will be required to accept or reject. When the limit is reached,"
" the model will be train. All intermediary steps will be saved."
),
spark_memory: int = typer.Option(
20,
help="Gigabytes of memory to be used. Recommended at least 20 to run on MeSH."
),
years: str = typer.Option(
None, help="Comma-separated years you want to include in the retagging process"
),
@@ -345,7 +271,6 @@ def retag_cli(
retag(
data_path,
save_to_path,
spark_memory=spark_memory,
num_proc=num_proc,
batch_size=batch_size,
tags=parse_tags(tags),

0 comments on commit 22fad69
