Skip to content

Commit

Permalink
Adds documentation, configurable tags, years, memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Jose J. Martinez committed Sep 12, 2023
1 parent c5667be commit 902eda0
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 202 deletions.
144 changes: 141 additions & 3 deletions README.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/retag.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
grants-tagger retag mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FILE_HERE] \
--tags "Artificial Intelligence,HIV" \
--years 2017,2018,2019,2020,2021
17 changes: 12 additions & 5 deletions grants_tagger_light/augmentation/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,26 +156,33 @@ def augment(

@augment_app.command()
def augment_cli(
data_path: str = typer.Argument(..., help="Path to mesh.jsonl"),
data_path: str = typer.Argument(
...,
help="Path to mesh.jsonl"),
save_to_path: str = typer.Argument(
..., help="Path to save the serialized PyArrow dataset after preprocessing"
...,
help="Path to save the new jsonl data"
),
model_key: str = typer.Option(
"gpt-3.5-turbo",
help="LLM to use data augmentation. By now, only `openai` is supported",
),
num_proc: int = typer.Option(
os.cpu_count(), help="Number of processes to use for data augmentation"
os.cpu_count(),
help="Number of processes to use for data augmentation"
),
batch_size: int = typer.Option(
64, help="Preprocessing batch size (for dataset, filter, map, ...)"
64,
help="Preprocessing batch size (for dataset, filter, map, ...)"
),
min_examples: int = typer.Option(
None,
help="Minimum number of examples to require. "
"Less than that will trigger data augmentation.",
),
examples: int = typer.Option(25, help="Examples to generate per each tag."),
examples: int = typer.Option(
25,
help="Examples to generate per each tag."),
prompt_template: str = typer.Option(
"grants_tagger_light/augmentation/prompt.template",
help="File to use as a prompt. "
Expand Down
19 changes: 14 additions & 5 deletions grants_tagger_light/preprocessing/preprocess_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,26 +225,35 @@ def preprocess_mesh(

@preprocess_app.command()
def preprocess_mesh_cli(
data_path: str = typer.Argument(..., help="Path to mesh.jsonl"),
data_path: str = typer.Argument(
...,
help="Path to mesh.jsonl"
),
save_to_path: str = typer.Argument(
..., help="Path to save the serialized PyArrow dataset after preprocessing"
...,
help="Path to save the serialized PyArrow dataset after preprocessing"
),
model_key: str = typer.Argument(
...,
help="Key to use when loading tokenizer and label2id. "
"Leave blank if training from scratch", # noqa
),
test_size: float = typer.Option(
None, help="Fraction of data to use for testing in (0,1] or number of rows"
None,
help="Fraction of data to use for testing in (0,1] or number of rows"
),
num_proc: int = typer.Option(
os.cpu_count(), help="Number of processes to use for preprocessing"
os.cpu_count(),
help="Number of processes to use for preprocessing"
),
max_samples: int = typer.Option(
-1,
help="Maximum number of samples to use for preprocessing",
),
batch_size: int = typer.Option(256, help="Size of the preprocessing batch"),
batch_size: int = typer.Option(
256,
help="Size of the preprocessing batch"
),
tags: str = typer.Option(
None,
help="Comma-separated tags you want to include in the dataset "
Expand Down
80 changes: 51 additions & 29 deletions grants_tagger_light/retagging/retagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,9 @@
from sklearn.metrics import classification_report
import pyarrow.parquet as pq

spark = nlp.start(spark_conf={
'spark.driver.memory': '12g',
'spark.executor.memory': '6g',
# Fraction of heap space used for execution memory
'spark.memory.fraction': '0.6',
# Fraction of heap space used for storage memory
'spark.memory.storageFraction': '0.4',
# Enable off-heap storage (for large datasets)
'spark.memory.offHeap.enabled': 'true',
# Off-heap memory size (adjust as needed)
'spark.memory.offHeap.size': '6g',
'spark.shuffle.manager': 'sort',
'spark.shuffle.spill': 'true',
'spark.master': f'local[{os.cpu_count()}]',
'spark.default.parallelism': f'{os.cpu_count()*2}',
'spark.speculation': 'false',
'spark.task.maxFailures': '4',
'spark.local.dir': f"{os.path.join(os.getcwd())}",
'spark.eventLog.enabled': 'true',
'spark.eventLog.dir': f"{os.path.join(os.getcwd())}"
})
from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags

import numpy as np

retag_app = typer.Typer()

Expand All @@ -54,14 +36,15 @@ def _load_data(dset: Dataset, tag, limit=100, split=0.8):
return train_dset, test_dset


def _create_pipelines(save_to_path, batch_size, train_df, test_df, tag):
def _create_pipelines(save_to_path, batch_size, train_df, test_df, tag, spark):
"""
This method creates a Spark pipeline (to run on dataframes)
Args:
save_to_path: path where to save the final results.
batch_size: max size of the batch to train. Since data is small for training, I limit it to 8.
train_df: Spark Dataframe of the train data
test_df: Spark Dataframe of the test data
spark: the Spark Object
Returns:
a tuple of (pipeline, lightpipeline)
Expand Down Expand Up @@ -179,21 +162,39 @@ def _curate(save_to_path, pos_dset, neg_dset, tag, limit):
def retag(
data_path: str,
save_to_path: str,
spark_memory: int = 27,
num_proc: int = os.cpu_count(),
batch_size: int = 64,
tags: list = None,
tags_file_path: str = None,
threshold: float = 0.8,
train_examples: int = 100,
supervised: bool = True,
years: list = None,
):

spark = nlp.start(spark_conf={
'spark.driver.memory': f'{spark_memory}g',
'spark.executor.memory': f'{spark_memory}g',
})

# We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing
logging.info("Loading the MeSH jsonl...")
dset = load_dataset("json", data_files=data_path, num_proc=1)
if "train" in dset:
dset = dset["train"]

with open(tags_file_path, 'r') as f:
tags = [x.strip() for x in f.readlines()]
if years is not None:
logger.info(f"Removing all years which are not in {years}")
dset = dset.filter(
lambda x: any(np.isin(years, [str(x["year"])])), num_proc=num_proc
)

if tags_file_path is not None and os.path.isfile(tags_file_path):
with open(tags_file_path, 'r') as f:
tags = [x.strip() for x in f.readlines()]

logging.info(f"Total tags detected: {tags}")

for tag in tags:
logging.info(f"Retagging: {tag}")
Expand Down Expand Up @@ -245,7 +246,7 @@ def retag(
logging.info(f"- Test dataset size: {test_df.count()}")

logging.info(f"- Creating `sparknlp` pipelines...")
pipeline = _create_pipelines(save_to_path, batch_size, train_df, test_df, tag)
pipeline = _create_pipelines(save_to_path, batch_size, train_df, test_df, tag, spark)

logging.info(f"- Optimizing dataframe...")
data_in_parquet = f"{save_to_path}.data.parquet"
Expand Down Expand Up @@ -290,6 +291,10 @@ def retag_cli(
batch_size: int = typer.Option(
64, help="Preprocessing batch size (for dataset, filter, map, ...)"
),
tags: str = typer.Option(
None,
help="Comma separated list of tags to retag"
),
tags_file_path: str = typer.Option(
None,
help="Text file containing one line per tag to be considered. "
Expand All @@ -308,7 +313,14 @@ def retag_cli(
help="Use human curation, showing a `limit` amount of positive and negative examples to curate data"
" for training the retaggers. The user will be required to accept or reject. When the limit is reached,"
" the model will be train. All intermediary steps will be saved."
)
),
spark_memory: int = typer.Option(
20,
help="Gigabytes of memory to be used. Recommended at least 20 to run on MeSH."
),
years: str = typer.Option(
None, help="Comma-separated years you want to include in the retagging process"
),
):
if not data_path.endswith("jsonl"):
logger.error(
Expand All @@ -317,19 +329,29 @@ def retag_cli(
)
exit(-1)

if tags_file_path is None:
if tags_file_path is None and tags is None:
logger.error(
"To understand which tags need to be augmented, use --tags [tags separated by comma] or create a file with"
"a newline per tag and set the path in --tags-file-path"
)
exit(-1)

if tags_file_path is not None and not os.path.isfile(tags_file_path):
logger.error(
"To understand which tags need to be augmented set the path to the tags file in --tags-file-path"
f"{tags_file_path} not found"
)
exit(-1)

retag(
data_path,
save_to_path,
spark_memory=spark_memory,
num_proc=num_proc,
batch_size=batch_size,
tags=parse_tags(tags),
tags_file_path=tags_file_path,
threshold=threshold,
train_examples=train_examples,
supervised=supervised
supervised=supervised,
years=parse_years(years),
)
Loading

0 comments on commit 902eda0

Please sign in to comment.