From f4b1f32b5b7aef8bd2b3bb309b9d61379a75afd0 Mon Sep 17 00:00:00 2001
From: "Jose J. Martinez"
Date: Wed, 13 Sep 2023 12:42:52 +0100
Subject: [PATCH] Black

---
 grants_tagger_light/augmentation/augment.py |  19 +-
 .../preprocessing/preprocess_mesh.py        |  21 +--
 .../retagging/cnn_gpu_config.cfg            | 127 -------------
 grants_tagger_light/retagging/config.cfg    | 124 -------------
 grants_tagger_light/retagging/retagging.py  | 167 ++++++++++--------
 5 files changed, 102 insertions(+), 356 deletions(-)
 delete mode 100644 grants_tagger_light/retagging/cnn_gpu_config.cfg
 delete mode 100644 grants_tagger_light/retagging/config.cfg

diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py
index e97ff46d..900f8817 100644
--- a/grants_tagger_light/augmentation/augment.py
+++ b/grants_tagger_light/augmentation/augment.py
@@ -156,33 +156,24 @@ def augment(
 
 @augment_app.command()
 def augment_cli(
-    data_path: str = typer.Argument(
-        ...,
-        help="Path to mesh.jsonl"),
-    save_to_path: str = typer.Argument(
-        ...,
-        help="Path to save the new jsonl data"
-    ),
+    data_path: str = typer.Argument(..., help="Path to mesh.jsonl"),
+    save_to_path: str = typer.Argument(..., help="Path to save the new jsonl data"),
     model_key: str = typer.Option(
         "gpt-3.5-turbo",
         help="LLM to use data augmentation. By now, only `openai` is supported",
     ),
     num_proc: int = typer.Option(
-        os.cpu_count(),
-        help="Number of processes to use for data augmentation"
+        os.cpu_count(), help="Number of processes to use for data augmentation"
     ),
     batch_size: int = typer.Option(
-        64,
-        help="Preprocessing batch size (for dataset, filter, map, ...)"
+        64, help="Preprocessing batch size (for dataset, filter, map, ...)"
     ),
     min_examples: int = typer.Option(
         None,
         help="Minimum number of examples to require. "
         "Less than that will trigger data augmentation.",
     ),
-    examples: int = typer.Option(
-        25,
-        help="Examples to generate per each tag."),
+    examples: int = typer.Option(25, help="Examples to generate per tag."),
     prompt_template: str = typer.Option(
         "grants_tagger_light/augmentation/prompt.template",
" diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 586783cf..92a06eb8 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -117,7 +117,7 @@ def preprocess_mesh( batch_size=batch_size, num_proc=num_proc, desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"} + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, ) logger.info("Time taken to tokenize: {}".format(time.time() - t1)) @@ -225,13 +225,9 @@ def preprocess_mesh( @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument( - ..., - help="Path to mesh.jsonl" - ), + data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), save_to_path: str = typer.Argument( - ..., - help="Path to save the serialized PyArrow dataset after preprocessing" + ..., help="Path to save the serialized PyArrow dataset after preprocessing" ), model_key: str = typer.Argument( ..., @@ -239,21 +235,16 @@ def preprocess_mesh_cli( "Leave blank if training from scratch", # noqa ), test_size: float = typer.Option( - None, - help="Fraction of data to use for testing in (0,1] or number of rows" + None, help="Fraction of data to use for testing in (0,1] or number of rows" ), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, help="Maximum number of samples to use for preprocessing", ), - batch_size: int = typer.Option( - 256, - help="Size of the preprocessing batch" - ), + batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), tags: str = typer.Option( None, help="Comma-separated tags you want to include in the dataset " diff --git a/grants_tagger_light/retagging/cnn_gpu_config.cfg b/grants_tagger_light/retagging/cnn_gpu_config.cfg deleted file mode 100644 index e71cf867..00000000 --- a/grants_tagger_light/retagging/cnn_gpu_config.cfg +++ /dev/null @@ -1,127 +0,0 @@ -[paths] -train = "" -dev = "" -raw = null -init_tok2vec = null -vectors = null - -[system] -seed = 42 -gpu_allocator = "pytorch" - -[nlp] -lang = "en" -pipeline = ["textcat"] -tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"} -disabled = [] -before_creation = null -after_creation = null -after_pipeline_creation = null -batch_size = 1000 - -[components] - -[components.textcat] -factory = "textcat_multilabel" -threshold = 0.5 - -[components.textcat.model] -@architectures = "spacy.TextCatCNN.v1" -exclusive_classes = false -nO = null - -[components.textcat.model.tok2vec] -@architectures = "spacy.Tok2Vec.v2" - -[components.textcat.model.tok2vec.embed] -@architectures = "spacy.MultiHashEmbed.v1" -width = ${components.textcat.model.tok2vec.encode:width} -rows = [10000,5000,5000,5000] -attrs = ["NORM","PREFIX","SUFFIX","SHAPE"] -include_static_vectors = false - -[components.textcat.model.tok2vec.encode] -@architectures = "spacy.MaxoutWindowEncoder.v2" -width = 96 -depth = 4 -window_size = 1 -maxout_pieces = 3 - -[corpora] - -[corpora.dev] -@readers = "spacy.Corpus.v1" -path = ${paths:dev} -gold_preproc = ${corpora.train.gold_preproc} -max_length = 0 -limit = 0 -augmenter = null - -[corpora.train] -@readers = "spacy.Corpus.v1" -path = ${paths:train} -gold_preproc = false -max_length = 0 -limit = 0 -augmenter = null - -[training] -train_corpus = "corpora.train" -dev_corpus = "corpora.dev" -seed = ${system.seed} 
-gpu_allocator = ${system.gpu_allocator}
-dropout = 0.2
-patience = 1600
-max_epochs = 0
-max_steps = 20000
-eval_frequency = 200
-accumulate_gradient = 1
-frozen_components = []
-before_to_disk = null
-
-[training.batcher]
-@batchers = "spacy.batch_by_sequence.v1"
-size = 32
-get_length = null
-
-[training.logger]
-@loggers = "spacy.ConsoleLogger.v1"
-progress_bar = false
-
-[training.optimizer]
-@optimizers = "Adam.v1"
-beta1 = 0.9
-beta2 = 0.999
-L2_is_weight_decay = true
-L2 = 0.01
-grad_clip = 1.0
-eps = 0.00000001
-learn_rate = 0.001
-use_averages = true
-
-[training.score_weights]
-cats_score_desc = null
-cats_micro_p = null
-cats_micro_r = null
-cats_micro_f = null
-cats_macro_p = null
-cats_macro_r = null
-cats_macro_f = null
-cats_macro_auc = null
-cats_f_per_type = null
-cats_macro_auc_per_type = null
-cats_score = 1.0
-
-[pretraining]
-
-[initialize]
-vectors = ${paths.vectors}
-init_tok2vec = ${paths.init_tok2vec}
-vocab_data = null
-lookups = null
-before_init = null
-after_init = null
-
-[initialize.components]
-
-[initialize.tokenizer]
\ No newline at end of file
diff --git a/grants_tagger_light/retagging/config.cfg b/grants_tagger_light/retagging/config.cfg
deleted file mode 100644
index a5fb381e..00000000
--- a/grants_tagger_light/retagging/config.cfg
+++ /dev/null
@@ -1,124 +0,0 @@
-[paths]
-train = null
-dev = null
-vectors = null
-init_tok2vec = null
-
-[system]
-gpu_allocator = null
-seed = 0
-
-[nlp]
-lang = "en"
-pipeline = ["textcat"]
-batch_size = 1000
-disabled = []
-before_creation = null
-after_creation = null
-after_pipeline_creation = null
-tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
-
-[components]
-
-[components.textcat]
-factory = "textcat"
-scorer = {"@scorers":"spacy.textcat_scorer.v1"}
-threshold = 0.5
-
-[components.textcat.model]
-@architectures = "spacy.TextCatBOW.v2"
-exclusive_classes = true
-ngram_size = 1
-no_output_layer = false
-nO = null
-
-[corpora]
-
-[corpora.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths.dev}
-max_length = 0
-gold_preproc = false
-limit = 0
-augmenter = null
-
-[corpora.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths.train}
-max_length = 0
-gold_preproc = false
-limit = 0
-augmenter = null
-
-[training]
-dev_corpus = "corpora.dev"
-train_corpus = "corpora.train"
-seed = ${system.seed}
-gpu_allocator = ${system.gpu_allocator}
-# dropout = 0.1
-dropout = 0.0
-accumulate_gradient = 1
-# patience = 1600
-patience = 0
-max_epochs = 15
-# max_steps = 20000
-eval_frequency = 200
-frozen_components = []
-annotating_components = []
-before_to_disk = null
-
-[training.batcher]
-@batchers = "spacy.batch_by_words.v1"
-discard_oversize = false
-tolerance = 0.2
-get_length = null
-
-[training.batcher.size]
-@schedules = "compounding.v1"
-start = 100
-stop = 1000
-compound = 1.001
-t = 0.0
-
-[training.logger]
-@loggers = "spacy.ConsoleLogger.v1"
-progress_bar = false
-
-[training.optimizer]
-@optimizers = "Adam.v1"
-beta1 = 0.9
-beta2 = 0.999
-L2_is_weight_decay = true
-L2 = 0.01
-grad_clip = 1.0
-use_averages = false
-eps = 0.00000001
-#learn_rate = 0.001
-learn_rate = 0.005
-
-[training.score_weights]
-cats_score = 1.0
-cats_score_desc = null
-cats_micro_p = null
-cats_micro_r = null
-cats_micro_f = null
-cats_macro_p = null
-cats_macro_r = null
-cats_macro_f = null
-cats_macro_auc = null
-cats_f_per_type = null
-cats_macro_auc_per_type = null
-
-[pretraining]
-
-[initialize]
-vectors = ${paths.vectors}
-init_tok2vec = ${paths.init_tok2vec}
-vocab_data = null
-lookups = null
-before_init = null
-after_init = null
-
-[initialize.components]
-
-[initialize.tokenizer]
\ No newline at end of file
diff --git a/grants_tagger_light/retagging/retagging.py b/grants_tagger_light/retagging/retagging.py
index 3ca81b1b..eb73f793 100644
--- a/grants_tagger_light/retagging/retagging.py
+++ b/grants_tagger_light/retagging/retagging.py
@@ -30,7 +30,7 @@ def _load_data(dset: Dataset, tag, limit=100, split=0.8):
         dset = dset.select([x for x in range(limit)])
     # Not in parallel since the data is very small and it's worse to divide and conquer
     dset.map(
-        lambda x: {'featured_tag': tag},
+        lambda x: {"featured_tag": tag},
         desc=f"Adding featured tag ({tag})",
     )
     train_size = int(split * min_limit)
@@ -40,44 +40,52 @@ def _annotate(curation_file, dset, tag, limit, is_positive):
-    field = 'positive' if is_positive else 'negative'
-    human_supervision = {tag: {'positive': [], 'negative': []}}
+    field = "positive" if is_positive else "negative"
+    human_supervision = {tag: {"positive": [], "negative": []}}
     if os.path.isfile(curation_file):
-        prompt = f"File `{curation_file}` found. Do you want to reuse previous work? [y|n]: "
+        prompt = (
+            f"File `{curation_file}` found. Do you want to reuse previous work? [y|n]: "
+        )
         answer = input(prompt)
-        while answer not in ['y', 'n']:
+        while answer not in ["y", "n"]:
             answer = input(prompt)
-        if answer == 'y':
-            with open(curation_file, 'r') as f:
+        if answer == "y":
+            with open(curation_file, "r") as f:
                 human_supervision = json.load(f)
 
     count = len(human_supervision[tag][field])
-    logging.info(f"[{tag}] Annotated: {count} Required: {limit} Available: {len(dset) - count}")
+    logging.info(
+        f"[{tag}] Annotated: {count} Required: {limit} Available: {len(dset) - count}"
+    )
 
     finished = False
     while count < limit:
         tries = 0
         random.seed(time.time())
         random_pos_row = random.randint(0, len(dset))
-        id_ = dset[random_pos_row]['pmid']
-        while id_ in [x['pmid'] for x in human_supervision[tag][field]]:
+        id_ = dset[random_pos_row]["pmid"]
+        while id_ in [x["pmid"] for x in human_supervision[tag][field]]:
             random_pos_row = random.randint(0, len(dset))
-            id_ = dset[random_pos_row]['pmid']
+            id_ = dset[random_pos_row]["pmid"]
             tries += 1
             if tries >= 10:
-                logger.error(f"Unable to find more examples for {field} {tag} which are not already tagged. "
-                             f"Continuing with {count} examples...")
+                logger.error(
+                    f"Unable to find more examples for {field} {tag} which are not already tagged. "
+                    f"Continuing with {count} examples..."
+                )
                 finished = True
                 break
         if finished:
             break
-        print("="*50)
-        print(dset[random_pos_row]['abstractText'])
         print("=" * 50)
-        res = input(f'[{count}/{limit}]> Is this {"NOT " if not is_positive else ""} a `{tag}` text? '
-                    f'[a to accept]: ')
-        if res == 'a':
+        print(dset[random_pos_row]["abstractText"])
+        print("=" * 50)
+        res = input(
+            f'[{count}/{limit}]> Is this {"NOT " if not is_positive else ""} a `{tag}` text? '
+            f"[a to accept]: "
+        )
+        if res == "a":
             human_supervision[tag][field].append(dset[random_pos_row])
-            with open(curation_file, 'w') as f:
+            with open(curation_file, "w") as f:
                 json.dump(human_supervision, f)
             count = len(human_supervision[tag])
@@ -102,7 +110,6 @@ def retag(
     supervised: bool = True,
     years: list = None,
 ):
-
     # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing
     logging.info("Loading the MeSH jsonl...")
     dset = load_dataset("json", data_files=data_path, num_proc=1)
@@ -116,7 +123,7 @@ def retag(
     )
 
     if tags_file_path is not None and os.path.isfile(tags_file_path):
-        with open(tags_file_path, 'r') as f:
+        with open(tags_file_path, "r") as f:
             tags = [x.strip() for x in f.readlines()]
 
     logging.info(f"- Total tags detected: {tags}.")
@@ -125,15 +132,15 @@
     for tag in tags:
         os.makedirs(os.path.join(save_to_path, tag.replace(" ", "")), exist_ok=True)
         logging.info(f"- Obtaining positive examples for {tag}...")
-        positive_dset = dset.filter(
-            lambda x: tag in x["meshMajor"], num_proc=num_proc
-        )
-
-        if len(positive_dset['abstractText']) < train_examples:
-            logging.info(f"Skipping {tag}: low examples ({len(positive_dset['abstractText'])} vs "
-                         f"expected {train_examples}). "
-                         f"Check {save_to_path}.err for more information about skipped tags.")
-            with open(f"{save_to_path}.err", 'a') as f:
+        positive_dset = dset.filter(lambda x: tag in x["meshMajor"], num_proc=num_proc)
+
+        if len(positive_dset["abstractText"]) < train_examples:
+            logging.info(
+                f"Skipping {tag}: too few examples ({len(positive_dset['abstractText'])} vs "
+                f"expected {train_examples}). "
+                f"Check {save_to_path}.err for more information about skipped tags."
+            )
+            with open(f"{save_to_path}.err", "a") as f:
                 f.write(tag)
             continue
 
@@ -142,33 +149,48 @@
             lambda x: tag not in x["meshMajor"], num_proc=num_proc
         )
 
-        curation_file = os.path.join(save_to_path, tag.replace(' ', ''), "curation")
+        curation_file = os.path.join(save_to_path, tag.replace(" ", ""), "curation")
         if supervised:
             logging.info(f"- Curating {tag}...")
             _curate(curation_file, positive_dset, negative_dset, tag, train_examples)
         else:
-            with open(curation_file, 'w') as f:
-                json.dump({tag: {'positive': [positive_dset[i] for i in range(train_examples)],
-                                 'negative': [negative_dset[i] for i in range(train_examples)]
-                                 }
-                           }, f)
+            with open(curation_file, "w") as f:
+                json.dump(
+                    {
+                        tag: {
+                            "positive": [
+                                positive_dset[i] for i in range(train_examples)
+                            ],
+                            "negative": [
+                                negative_dset[i] for i in range(train_examples)
+                            ],
+                        }
+                    },
+                    f,
+                )
 
     logging.info("- Retagging...")
     models = {}
     for tag in tags:
-        curation_file = os.path.join(save_to_path, tag.replace(' ', ''), "curation")
+        curation_file = os.path.join(save_to_path, tag.replace(" ", ""), "curation")
         if not os.path.isfile(curation_file):
-            logger.info(f"Skipping `{tag}` retagging as no curation data was found. "
-                        f"Maybe there were too little examples? (check {save_to_path}.err)")
+            logger.info(
+                f"Skipping `{tag}` retagging as no curation data was found. "
+                f"Maybe there were too few examples? (check {save_to_path}.err)"
(check {save_to_path}.err)" + ) continue with open(curation_file, "r") as fr: data = json.load(fr) - positive_dset = Dataset.from_list(data[tag]['positive']) - negative_dset = Dataset.from_list(data[tag]['negative']) + positive_dset = Dataset.from_list(data[tag]["positive"]) + negative_dset = Dataset.from_list(data[tag]["negative"]) - pos_x_train, pos_x_test = _load_data(positive_dset, tag, limit=train_examples, split=0.8) - neg_x_train, neg_x_test = _load_data(negative_dset, "other", limit=train_examples, split=0.8) + pos_x_train, pos_x_test = _load_data( + positive_dset, tag, limit=train_examples, split=0.8 + ) + neg_x_train, neg_x_test = _load_data( + negative_dset, "other", limit=train_examples, split=0.8 + ) pos_x_train = pos_x_train.add_column("tag", [tag] * len(pos_x_train)) pos_x_test = pos_x_test.add_column("tag", [tag] * len(pos_x_test)) @@ -180,14 +202,19 @@ def retag( test = concatenate_datasets([pos_x_test, neg_x_test]) label_binarizer = preprocessing.LabelBinarizer() - label_binarizer_path = os.path.join(save_to_path, tag.replace(" ", ""), 'labelbinarizer') + label_binarizer_path = os.path.join( + save_to_path, tag.replace(" ", ""), "labelbinarizer" + ) labels = [1 if x == tag else 0 for x in train["tag"]] label_binarizer.fit(labels) - with open(label_binarizer_path, 'wb') as f: + with open(label_binarizer_path, "wb") as f: pkl.dump(label_binarizer, f) model = MeshXLinear(label_binarizer_path=label_binarizer_path) - model.fit(train["abstractText"], scipy.sparse.csr_matrix(label_binarizer.transform(labels))) + model.fit( + train["abstractText"], + scipy.sparse.csr_matrix(label_binarizer.transform(labels)), + ) models[tag] = model model_path = os.path.join(save_to_path, tag.replace(" ", ""), "clf") os.makedirs(model_path, exist_ok=True) @@ -195,71 +222,61 @@ def retag( logging.info("- Predicting all tags") dset = dset.add_column("changes", [[]] * len(dset)) - with open(os.path.join(save_to_path, 'corrections'), 'w') as f: + with open(os.path.join(save_to_path, "corrections"), "w") as f: for b in tqdm.tqdm(range(int(len(dset) / batch_size))): start = b * batch_size - end = min(len(dset), (b+1) * batch_size) + end = min(len(dset), (b + 1) * batch_size) batch = dset.select([i for i in range(start, end)]) batch_buffer = [x for x in batch] for tag in models.keys(): batch_preds = models[tag](batch["abstractText"], threshold=threshold) for i, bp in enumerate(batch_preds): is_predicted = bp == [0] - is_expected = tag in batch[i]['meshMajor'] + is_expected = tag in batch[i]["meshMajor"] if is_predicted != is_expected: if is_predicted: - batch_buffer[i]['meshMajor'].append(tag) - batch_buffer[i]['changes'].append(f"+{tag}") + batch_buffer[i]["meshMajor"].append(tag) + batch_buffer[i]["changes"].append(f"+{tag}") else: - batch_buffer[i]['meshMajor'].remove(tag) - batch_buffer[i]['changes'].append(f"-{tag}") + batch_buffer[i]["meshMajor"].remove(tag) + batch_buffer[i]["changes"].append(f"-{tag}") # batch = Dataset.from_list(batch_buffer) # buffer = io.BytesIO() # batch.to_json(buffer) # f.write(buffer.getvalue().decode('utf-8')) batch_buffer = [json.dumps(x) for x in batch_buffer] - f.write('\n'.join(batch_buffer)) + f.write("\n".join(batch_buffer)) @retag_app.command() def retag_cli( - data_path: str = typer.Argument( - ..., - help="Path to allMeSH_2021.jsonl"), + data_path: str = typer.Argument(..., help="Path to allMeSH_2021.jsonl"), save_to_path: str = typer.Argument( - ..., - help="Path where to save the retagged data" + ..., help="Path where to save the retagged data" ), num_proc: 
-        os.cpu_count(),
-        help="Number of processes to use for data augmentation"
+        os.cpu_count(), help="Number of processes to use for retagging"
     ),
     batch_size: int = typer.Option(
-        1024,
-        help="Preprocessing batch size (for dataset, filter, map, ...)"
+        1024, help="Preprocessing batch size (for dataset, filter, map, ...)"
     ),
-    tags: str = typer.Option(
-        None,
-        help="Comma separated list of tags to retag"
-    ),
+    tags: str = typer.Option(None, help="Comma-separated list of tags to retag"),
     tags_file_path: str = typer.Option(
         None,
         help="Text file containing one line per tag to be considered. "
         "The rest will be discarded.",
     ),
     threshold: float = typer.Option(
-        0.9,
-        help="Minimum threshold of confidence to retag a model. Default: 0.9"
+        0.9, help="Minimum confidence threshold for retagging. Default: 0.9"
     ),
     train_examples: int = typer.Option(
-        100,
-        help="Number of examples to use for training the retaggers"
+        100, help="Number of examples to use for training the retaggers"
     ),
     supervised: bool = typer.Option(
         False,
         help="Use human curation, showing a `limit` amount of positive and negative examples to curate data"
-        " for training the retaggers. The user will be required to accept or reject. When the limit is reached,"
-        " the model will be train. All intermediary steps will be saved."
+        " for training the retaggers. The user will be required to accept or reject. When the limit is reached,"
+        " the model will be trained. All intermediary steps will be saved.",
    ),
     years: str = typer.Option(
         None, help="Comma-separated years you want to include in the retagging process"
@@ -280,9 +297,7 @@
         exit(-1)
 
     if tags_file_path is not None and not os.path.isfile(tags_file_path):
-        logger.error(
-            f"{tags_file_path} not found"
-        )
+        logger.error(f"{tags_file_path} not found")
         exit(-1)
 
     retag(