diff --git a/CHANGELOG.md b/CHANGELOG.md
index 205cc5e..825c32f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1 @@
 # Changelog
-
-## Version 0.1 (development)
-
-- Feature A added
-- FIX: nasty bug #1729 fixed
-- add your changes here!
diff --git a/README.md b/README.md
index ca9432a..da904a9 100644
--- a/README.md
+++ b/README.md
@@ -14,37 +14,77 @@
 > utility for using transformers summarization models on text docs
 
-A continuation of the [document summarization]() space on huggingface.
+An extension/generalization of the [document summarization](https://hf.co/spaces/pszemraj/document-summarization) space on huggingface. The purpose of this package is to provide a simple interface for using summarization models on text documents of arbitrary length.
+
+⚠️ **WARNING**: _This package is a WIP and is not ready for production use. Some things may not work yet._ ⚠️
 
 ## Installation
 
+Install the package using pip:
+
 ```bash
-pip install -e .
+# create a virtual environment (optional)
+pip install git+https://github.com/pszemraj/textsum.git
 ```
 
+The textsum package is now installed in your environment. You can use the CLI or the UI demo (see [Usage](#usage)).
+
+### Full Installation _(PDF OCR, gradio UI demo)_
+
 To install all the dependencies _(includes PDF OCR, gradio UI demo)_, run:
 
 ```bash
+git clone https://github.com/pszemraj/textsum.git
+cd textsum
+# create a virtual environment (optional)
 pip install -e .[all]
 ```
 
 ## Usage
 
+### CLI
+
+To summarize a directory of text files, run the following command:
+
+```bash
+textsum-dir /path/to/dir
+```
+
+The following options are available:
+
+```
+usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS]
+                   [-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY] [--no_cuda] [-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH]
+                   [-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE] [--no_early_stopping] [--shuffle]
+                   [--lowercase] [-v] [-vv] [-lf LOGFILE]
+                   input_dir
+```
+
+For more information, run:
+
+```bash
+textsum-dir --help
+```
+
 ### UI Demo
 
-Simply run the following command to start the UI demo:
+For convenience, a UI demo is provided using [gradio](https://gradio.app/). To start the demo, run:
 
 ```bash
-ts-ui
+textsum-ui
 ```
 
-Other args to be added soon
+This is currently a minimal demo; it will be expanded to accept additional arguments and options.
+
+---
 
 ## Roadmap
 
 - [ ] add argparse CLI for UI demo
-- [ ] add CLI for summarization of all text files in a directory
-- [ ] API for summarization of text docs
+- [x] add CLI for summarization of all text files in a directory
+- [ ] python API for summarization of text docs
+- [ ] optimum inference integration
+- [ ] better documentation, details on improving performance (speed, quality, memory usage, etc.) and other things I haven't thought of yet
diff --git a/setup.cfg b/setup.cfg
index d07aa01..f72c610 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -85,7 +85,8 @@ testing =
 [options.entry_points]
 # Add here console scripts like:
 console_scripts =
-    ts-ui = textsum.app:run
+    textsum-ui = textsum.app:run
+    textsum-dir = textsum.cli:run
 # For example:
 # console_scripts =
 #     fibonacci = textsum.skeleton:run
diff --git a/src/textsum/__init__.py b/src/textsum/__init__.py
index e451f10..243e7cf 100644
--- a/src/textsum/__init__.py
+++ b/src/textsum/__init__.py
@@ -1,5 +1,11 @@
+"""
+textsum - a package for summarizing text
+
+"""
 import sys
 
+from . 
import cli, utils + if sys.version_info[:2] >= (3, 8): # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` from importlib.metadata import PackageNotFoundError, version # pragma: no cover diff --git a/src/textsum/app.py b/src/textsum/app.py index 57d5c36..faab016 100644 --- a/src/textsum/app.py +++ b/src/textsum/app.py @@ -1,3 +1,6 @@ +""" +app.py - a module to run the text summarization app (gradio interface) +""" import contextlib import logging import os @@ -19,7 +22,7 @@ from textsum.pdf2text import convert_PDF_to_Text from textsum.summarize import load_model_and_tokenizer, summarize_via_tokenbatches -from textsum.utils import load_example_filenames, saves_summary, truncate_word_count +from textsum.utils import save_summary, truncate_word_count _here = Path(__file__).parent @@ -137,7 +140,7 @@ def proc_submission( html += "" # save to file - saved_file = saves_summary(_summaries) + saved_file = save_summary(_summaries) return html, sum_text_out, scores_out, saved_file @@ -156,6 +159,7 @@ def load_uploaded_file(file_obj, max_pages=20): # file_path = Path(file_obj[0].name) # check if mysterious file object is a list + global ocr_model if isinstance(file_obj, list): file_obj = file_obj[0] file_path = Path(file_obj.name) diff --git a/src/textsum/cli.py b/src/textsum/cli.py new file mode 100644 index 0000000..b181536 --- /dev/null +++ b/src/textsum/cli.py @@ -0,0 +1,367 @@ +""" +cli.py - a module containing functions for the command line interface (to run the summarization on a directory of files) + #TODO: add a function to summarize a single file + +usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS] + [-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY] [--no_cuda] [-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH] + [-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE] [--no_early_stopping] [--shuffle] + [--lowercase] [-v] [-vv] [-lf LOGFILE] + input_dir + +Summarize text files in a directory + +positional arguments: + input_dir the directory containing the input files + +""" +import argparse +import logging +import pprint as pp +import random +import sys +import warnings +from pathlib import Path + +import torch +from cleantext import clean +from tqdm.auto import tqdm + +from textsum.summarize import ( + load_model_and_tokenizer, + save_params, + summarize_via_tokenbatches, +) +from textsum.utils import get_mem_footprint, postprocess_booksummary, setup_logging + + +def summarize_text_file( + file_path: str or Path, + model, + tokenizer, + batch_length: int = 4096, + batch_stride: int = 16, + lowercase: bool = False, + **kwargs, +) -> dict: + """ + summarize_text_file - given a file path, summarize the text in the file + + :param str or Path file_path: the path to the file to summarize + :param model: the model to use for summarization + :param tokenizer: the tokenizer to use for summarization + :param int batch_length: length of each batch in tokens to summarize, defaults to 4096 + :param int batch_stride: stride between batches in tokens, defaults to 16 + :param bool lowercase: whether to lowercase the text before summarizing, defaults to False + :return: a dictionary containing the summary and other information + """ + file_path = Path(file_path) + ALLOWED_EXTENSIONS = [".txt", ".md", ".rst", ".py", ".ipynb"] + assert ( + file_path.exists() and file_path.suffix in ALLOWED_EXTENSIONS + ), f"File {file_path} does not exist or is not a text file" + + logging.info(f"Summarizing 
{file_path}") + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + text = clean(f.read(), lower=lowercase, no_line_breaks=True) + logging.debug( + f"Text length: {len(text)}. batch length: {batch_length} batch stride: {batch_stride}" + ) + summary_data = summarize_via_tokenbatches( + input_text=text, + model=model, + tokenizer=tokenizer, + batch_length=batch_length, + batch_stride=batch_stride, + **kwargs, + ) + logging.info(f"Finished summarizing {file_path}") + return summary_data + + +def process_summarization( + summary_data: dict, + target_file: str or Path, + custom_phrases: list = None, + save_scores: bool = True, +) -> None: + """ + process_summarization - given a dictionary of summary data, save the summary to a file + + :param dict summary_data: a dictionary containing the summary and other information (output from summarize_text_file) + :param str or Path target_file: the path to the file to save the summary to + :param list custom_phrases: a list of custom phrases to remove from each summary (relevant for dataset specific repeated phrases) + :param bool save_scores: whether to write the scores to a file + """ + target_file = Path(target_file).resolve() + if target_file.exists(): + warnings.warn(f"File {target_file} exists, overwriting") + + sum_text = [ + postprocess_booksummary( + s["summary"][0], + custom_phrases=custom_phrases, + ) + for s in summary_data + ] + sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summary_data] + scores_text = "\n".join(sum_scores) + full_summary = "\n\t".join(sum_text) + + with open( + target_file, + "w", + ) as fo: + + fo.writelines(full_summary) + + if save_scores: + with open( + target_file, + "a", + ) as fo: + + fo.write("\n" * 3) + fo.write(f"\n\nSection Scores for {target_file.stem}:\n") + fo.writelines(scores_text) + fo.write("\n\n---\n") + + logging.info(f"Saved summary to {target_file.resolve()}") + + +def get_parser(): + """ + get_parser - a function that returns an argument parser for the sum_files script + + :return argparse.ArgumentParser: the argument parser + """ + parser = argparse.ArgumentParser( + description="Summarize text files in a directory", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-o", + "--output_dir", + type=str, + default=None, + dest="output_dir", + help="directory to write the output files (if None, writes to input_dir/summarized)", + ) + parser.add_argument( + "-m", + "--model_name", + type=str, + default="pszemraj/long-t5-tglobal-base-16384-book-summary", + help="the name of the model to use for summarization", + ) + parser.add_argument( + "-batch", + "--batch_length", + dest="batch_length", + type=int, + default=4096, + help="the length of each batch", + ) + parser.add_argument( + "-stride", + "--batch_stride", + type=int, + default=16, + help="the stride of each batch", + ) + parser.add_argument( + "-nb", + "--num_beams", + type=int, + default=4, + help="the number of beams to use for beam search", + ) + parser.add_argument( + "-l2", + "--length_penalty", + type=float, + default=0.8, + help="the length penalty to use for decoding", + ) + parser.add_argument( + "-r2", + "--repetition_penalty", + type=float, + default=2.5, + help="the repetition penalty to use for beam search", + ) + parser.add_argument( + "--no_cuda", + action="store_true", + help="flag to not use cuda if available", + ) + parser.add_argument( + "-length_ratio", + "--max_length_ratio", + dest="max_length_ratio", + type=int, + default=0.25, + help="the maximum length of 
+
+
+def main(args):
+    """
+    main - the main function for the script
+
+    :param argparse.Namespace args: the arguments for the script
+    """
+    setup_logging(args.loglevel, args.logfile)
+    logging.info("starting summarization")
+    logging.info(f"args: {pp.pformat(args)}")
+
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+    )
+    logging.info(f"using device: {device}")
+    # load the model and tokenizer
+    model, tokenizer = load_model_and_tokenizer(
+        args.model_name, use_cuda=not args.no_cuda
+    )
+
+    logging.info(f"model size: {get_mem_footprint(model)}")
+    # move the model to the device
+    model.to(device)
+
+    params = {
+        "min_length": args.min_length,
+        "max_length": int(args.max_length_ratio * args.batch_length),
+        "encoder_no_repeat_ngram_size": args.encoder_no_repeat_ngram_size,
+        "no_repeat_ngram_size": args.no_repeat_ngram_size,
+        "repetition_penalty": args.repetition_penalty,
+        "num_beams": args.num_beams,
+        "num_beam_groups": 1,
+        "length_penalty": args.length_penalty,
+        "early_stopping": args.early_stopping,
+        "do_sample": False,
+    }
+    # get the input files
+    input_files = list(Path(args.input_dir).glob("*.txt"))
+
+    if args.shuffle:
+        logging.info("shuffling input files")
+        random.SystemRandom().shuffle(input_files)
+
+    # get the output directory
+    output_dir = (
+        Path(args.output_dir)
+        if args.output_dir
+        else Path(args.input_dir) / "summarized"
+    )
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    # summarize each input file
+    for f in tqdm(input_files):
+
+        outpath = output_dir / f"{f.stem}.summary.txt"
+        summary_data = summarize_text_file(
+            file_path=f,
+            model=model,
+            tokenizer=tokenizer,
+            batch_length=args.batch_length,
+            batch_stride=args.batch_stride,
+            lowercase=args.lowercase,
+            **params,
+        )
+        process_summarization(
+            
summary_data=summary_data, target_file=outpath, save_scores=True + ) + + logging.info(f"finished summarization loop - output dir: {output_dir.resolve()}") + save_params(params=params, output_dir=output_dir, hf_tag=args.model_name) + + logging.info("finished summarizing files") + + +def run(): + """ + run - main entry point for the script + """ + + parser = get_parser() + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + run() diff --git a/src/textsum/pdf2text.py b/src/textsum/pdf2text.py index 88cc02e..cbdf31e 100644 --- a/src/textsum/pdf2text.py +++ b/src/textsum/pdf2text.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- """ - -easyocr.py - A wrapper for easyocr to convert pdf to images to text +pdf2text.py - convert pdf files to text files (OCR). helper functions for textsum """ import logging diff --git a/src/textsum/skeleton.py b/src/textsum/skeleton.py deleted file mode 100644 index e43156e..0000000 --- a/src/textsum/skeleton.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -This is a skeleton file that can serve as a starting point for a Python -console script. To run this script uncomment the following lines in the -``[options.entry_points]`` section in ``setup.cfg``:: - - console_scripts = - fibonacci = textsum.skeleton:run - -Then run ``pip install .`` (or ``pip install -e .`` for editable mode) -which will install the command ``fibonacci`` inside your current environment. - -Besides console scripts, the header (i.e. until ``_logger``...) of this file can -also be used as template for Python modules. - -Note: - This file can be renamed depending on your needs or safely removed if not needed. - -References: - - https://setuptools.pypa.io/en/latest/userguide/entry_point.html - - https://pip.pypa.io/en/stable/reference/pip_install -""" - -import argparse -import logging -import sys - -from textsum import __version__ - -__author__ = "peter szemraj" -__copyright__ = "peter szemraj" -__license__ = "Apache-2.0" - -_logger = logging.getLogger(__name__) - - -# ---- Python API ---- -# The functions defined in this section can be imported by users in their -# Python scripts/interactive interpreter, e.g. via -# `from textsum.skeleton import fib`, -# when using this Python module as a library. - - -def fib(n): - """Fibonacci example function - - Args: - n (int): integer - - Returns: - int: n-th Fibonacci number - """ - assert n > 0 - a, b = 1, 1 - for _i in range(n - 1): - a, b = b, a + b - return a - - -# ---- CLI ---- -# The functions defined in this section are wrappers around the main Python -# API allowing them to be called directly from the terminal as a CLI -# executable/script. - - -def parse_args(args): - """Parse command line parameters - - Args: - args (List[str]): command line parameters as list of strings - (for example ``["--help"]``). 
- - Returns: - :obj:`argparse.Namespace`: command line parameters namespace - """ - parser = argparse.ArgumentParser(description="Just a Fibonacci demonstration") - parser.add_argument( - "--version", - action="version", - version="textsum {ver}".format(ver=__version__), - ) - parser.add_argument(dest="n", help="n-th Fibonacci number", type=int, metavar="INT") - parser.add_argument( - "-v", - "--verbose", - dest="loglevel", - help="set loglevel to INFO", - action="store_const", - const=logging.INFO, - ) - parser.add_argument( - "-vv", - "--very-verbose", - dest="loglevel", - help="set loglevel to DEBUG", - action="store_const", - const=logging.DEBUG, - ) - return parser.parse_args(args) - - -def setup_logging(loglevel): - """Setup basic logging - - Args: - loglevel (int): minimum loglevel for emitting messages - """ - logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s" - logging.basicConfig( - level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S" - ) - - -def main(args): - """Wrapper allowing :func:`fib` to be called with string arguments in a CLI fashion - - Instead of returning the value from :func:`fib`, it prints the result to the - ``stdout`` in a nicely formatted message. - - Args: - args (List[str]): command line parameters as list of strings - (for example ``["--verbose", "42"]``). - """ - args = parse_args(args) - setup_logging(args.loglevel) - _logger.debug("Starting crazy calculations...") - print("The {}-th Fibonacci number is {}".format(args.n, fib(args.n))) - _logger.info("Script ends here") - - -def run(): - """Calls :func:`main` passing the CLI arguments extracted from :obj:`sys.argv` - - This function can be used as entry point to create console scripts with setuptools. - """ - main(sys.argv[1:]) - - -if __name__ == "__main__": - # ^ This is a guard statement that will prevent the following code from - # being executed in the case someone imports this file instead of - # executing it as a script. - # https://docs.python.org/3/library/__main__.html - - # After installing your project with pip, users can also run your Python - # modules as scripts via the ``-m`` flag, as defined in PEP 338:: - # - # python -m textsum.skeleton 42 - # - run() diff --git a/src/textsum/summarize.py b/src/textsum/summarize.py index 28167df..3036d91 100644 --- a/src/textsum/summarize.py +++ b/src/textsum/summarize.py @@ -1,29 +1,37 @@ +""" +summarize.py - a module that contains functions for summarizing text +""" +import json import logging +from pathlib import Path import torch from tqdm.auto import tqdm from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from textsum.utils import get_timestamp -def load_model_and_tokenizer(model_name): + +def load_model_and_tokenizer(model_name: str, use_cuda: bool = True): """ load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface Args: - model_name (str): the name of the model to load + model_name (str): the name of the model to load from huggingface + use_cuda (bool, optional): whether to use cuda. Defaults to True. 
Returns: AutoModelForSeq2SeqLM: the model AutoTokenizer: the tokenizer """ - device = "cuda" if torch.cuda.is_available() else "cpu" + logger = logging.getLogger(__name__) + device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + logger.debug(f"loading model {model_name} to {device}") model = AutoModelForSeq2SeqLM.from_pretrained( model_name, - # low_cpu_mem_usage=True, - # use_cache=False, ).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name) - logging.info(f"Loaded model {model_name} to {device}") + logger.info(f"Loaded model {model_name} to {device}") return model, tokenizer @@ -63,6 +71,7 @@ def summarize_and_score( **kwargs, ) else: + # this is for LED etc. summary_pred_ids = model.generate( input_ids, attention_mask=attention_mask, @@ -85,7 +94,7 @@ def summarize_via_tokenbatches( input_text: str, model, tokenizer, - batch_length=2048, + batch_length=4096, batch_stride=16, **kwargs, ): @@ -96,18 +105,20 @@ def summarize_via_tokenbatches( input_text (str): the text to summarize model (): the model to use for summarizationz tokenizer (): the tokenizer to use for summarization - batch_length (int, optional): the length of each batch. Defaults to 2048. + batch_length (int, optional): the length of each batch. Defaults to 4096. batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches. Returns: str: the summary """ + + logger = logging.getLogger(__name__) # log all input parameters if batch_length < 512: batch_length = 512 - print("WARNING: batch_length was set to 512") - print( - f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}" + logger.warning("WARNING: batch_length was set to 512") + logger.debug( + f"batch_length: {batch_length} batch_stride: {batch_stride}, kwargs: {kwargs}" ) encoded_input = tokenizer( input_text, @@ -141,9 +152,41 @@ def summarize_via_tokenbatches( "summary_score": score, } gen_summaries.append(_sum) - print(f"\t{result[0]}\nScore:\t{score}") + logger.debug(f"\n\t{result[0]}\nScore:\t{score}") pbar.update() pbar.close() return gen_summaries + + +def save_params( + params: dict, + output_dir: str or Path, + hf_tag: str = None, + verbose: bool = False, +) -> None: + """ + save_params - save the parameters of the run to a json file + + :param dict params: parameters to save + :param str or Path output_dir: directory to save the parameters to + :param str hf_tag: the model tag on huggingface + :param bool verbose: whether to log the parameters + + :return: None + """ + output_dir = Path(output_dir) if output_dir is not None else Path.cwd() + session_settings = params + session_settings["huggingface-model-tag"] = "" if hf_tag is None else hf_tag + session_settings["date-run"] = get_timestamp() + + metadata_path = output_dir / "summarization-parameters.json" + logging.info(f"Saving parameters to {metadata_path}") + with open(metadata_path, "w") as write_file: + json.dump(session_settings, write_file) + + logging.debug(f"Saved parameters to {metadata_path}") + if verbose: + # log the parameters + logging.info(f"parameters: {session_settings}") diff --git a/src/textsum/utils.py b/src/textsum/utils.py index e7c2fe2..73b5ad6 100644 --- a/src/textsum/utils.py +++ b/src/textsum/utils.py @@ -5,6 +5,7 @@ import logging import re import subprocess +import sys from datetime import datetime from pathlib import Path @@ -13,7 +14,6 @@ format="%(asctime)s %(levelname)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S", ) -from natsort import 
natsorted
 
 
 # ------------------------- #
 
@@ -28,24 +28,98 @@ def get_timestamp() -> str:
     """
     get_timestamp - get a timestamp for the current time
 
-    Returns:
-        str, the timestamp
     """
     return datetime.now().strftime("%Y%m%d_%H%M%S")
 
 
+def regex_gpu_name(input_text: str):
+    """backup if not a100"""
+
+    pattern = re.compile(r"(\s([A-Za-z0-9]+\s)+)(\s([A-Za-z0-9]+\s)+)", re.IGNORECASE)
+    return pattern.search(input_text).group()
+
+
+def check_GPU(verbose=False):
+    """
+    check_GPU - call the `nvidia-smi` command via subprocess and use regex to identify the available GPU, returning True if it is an A100 and False otherwise
+
+    :param verbose: if true, print out which GPU was found if it is not an A100
+    """
+    # call nvidia-smi
+    nvidia_smi = subprocess.run(
+        ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
+    )
+    # convert to string
+    nvidia_smi = nvidia_smi.stdout.decode("utf-8")
+    search_past = "==============================="
+    # use regex to find the GPU name. search in the first newline underneath
+    output_lines = nvidia_smi.split("\n")
+    for i, line in enumerate(output_lines):
+        if search_past in line:
+            break
+    # get the next line
+    next_line = output_lines[i + 1]
+    if verbose:
+        print(next_line)
+    # use regex to find the GPU name
+    try:
+        gpu_name = re.search(r"\w+-\w+-\w+", next_line).group()
+    except AttributeError:
+        logging.debug("Could not find GPU name with initial regex")
+        gpu_name = None
+
+    if gpu_name is None:
+        # try alternates
+        try:
+            gpu_name = regex_gpu_name(next_line)
+        except Exception as e:
+            logging.error(f"Could not find GPU name: {e}")
+            return False
+
+    if verbose:
+        print(f"GPU found: {gpu_name}")
+    # check if it is an A100
+    return bool("A100" in gpu_name)
+
+
+def cstr(s, color="black"):
+    """styles a string with a color (as HTML, for display in gradio)"""
+    return "<text style=color:{}>{}</text>".format(color, s)
+
+
+def color_print(text: str, c_id="pink"):
+    """helper function to print colored text to the terminal"""
+
+    colormap = {
+        "red": "\033[91m",
+        "green": "\033[92m",
+        "yellow": "\033[93m",
+        "blue": "\033[94m",
+        "pink": "\033[95m",
+        "teal": "\033[96m",
+        "grey": "\033[97m",
+    }
+
+    print(f"{colormap[c_id]}{text}\033[0m")  # reset the color afterwards
+
+
+def get_mem_footprint(test_model):
+    """
+    get_mem_footprint - a helper function to get the memory footprint of a model (for huggingface models), returned as a formatted string
+    """
+    fp = test_model.get_memory_footprint() * (10**-9)
+    return f"approx {round(fp, 2)} GB"
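+
+
+# illustrative usage for the helpers above: the CLI logs the footprint at
+# startup via `logging.info(f"model size: {get_mem_footprint(model)}")` (see
+# cli.py), which is why the helper returns the formatted string rather than
+# printing it. check_GPU is not called in this diff; a hypothetical use:
+#   batch_length = 8192 if check_GPU(verbose=True) else 4096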
+
+
 def truncate_word_count(text, max_words=512):
     """
-    truncate_word_count - a helper function for the gradio module
-    Parameters
-    ----------
-    text : str, required, the text to be processed
-    max_words : int, optional, the maximum number of words, default=512
-    Returns
-    -------
-    dict, the text and whether it was truncated
+    truncate_word_count - a helper function for the gradio module to truncate the text to a max number of words
+
+    :param str text: the text to truncate
+    :param int max_words: the max number of words to truncate to (default 512)
+    :return dict: a dictionary with the truncated text and a boolean indicating whether the text was truncated
     """
-    # split on whitespace with regex
+
     words = re.split(r"\s+", text)
     processed = {}
     if len(words) > max_words:
@@ -57,31 +131,6 @@
     return processed
 
 
-def load_pdf_examples(src, filetypes=[".txt", ".pdf"]):
-    """
-    load_examples - a helper function for the gradio module to load examples
-    Returns:
-        list of str, the examples
-    """
-    src = Path(src)
-    src.mkdir(exist_ok=True)
-
-    pdf_url = (
-        "https://www.dropbox.com/s/y92xy7o5qb88yij/all_you_need_is_attention.pdf?dl=1"
-    )
-    subprocess.run(["wget", pdf_url, "-O", src / "all_you_need_is_attention.pdf"])
-    examples = [f for f in src.iterdir() if f.suffix in filetypes]
-    examples = natsorted(examples)
-    # load the examples into a list
-    text_examples = []
-    for example in examples:
-        with open(example, "r") as f:
-            text = f.read()
-        text_examples.append([text, "base", 2, 1024, 0.7, 3.5, 3])
-
-    return text_examples
-
-
 def load_text_examples(
     urls: dict = TEXT_EXAMPLE_URLS, target_dir: str or Path = None
 ) -> Path:
@@ -89,7 +138,7 @@
     load_text_examples - load the text examples from the web to a directory
 
     :param dict urls: the urls to the text examples, defaults to TEXT_EXAMPLE_URLS
-    :param strorPath target_dir: the path to the target directory, defaults to the current working directory
+    :param str or Path target_dir: the path to the target directory, defaults to the current working directory
 
     :return Path: the path to the directory containing the text examples
     """
     target_dir = Path.cwd() if target_dir is None else Path(target_dir)
@@ -101,7 +150,12 @@
     return target_dir
 
 
-def load_example_filenames(example_path: str or Path, ext: list = [".txt", ".md"]):
+TEXT_EX_EXTENSIONS = [".txt", ".md"]
+
+
+def load_example_filenames(
+    example_path: str or Path, ext: list = TEXT_EX_EXTENSIONS
+) -> dict:
     """
     load_example_filenames - load the example filenames from a directory
 
@@ -121,17 +175,26 @@
     return examples
 
 
-def saves_summary(summarize_output, outpath: str or Path = None, add_signature=True):
+def save_summary(
+    summarize_output, outpath: str or Path = None, write_scores=True
+) -> Path:
     """
-    saves_summary - save the summary generated from summarize_via_tokenbatches() to a text file
+    save_summary - save the summary generated from summarize_via_tokenbatches() to a text file
+
+    :param list summarize_output: the output from summarize_via_tokenbatches()
+    :param str or Path outpath: the path to the output file, defaults to the current working directory
+    :param bool write_scores: whether to write the scores to the output file, defaults to True
+    :return Path: the path to the output file
 
     Example in use:
         _summaries = summarize_via_tokenbatches(
            text,
            batch_length=token_batch_length,
            batch_stride=batch_stride,
            **settings,
        )
+       save_summary(_summaries, outpath=outpath, write_scores=True)
     """
 
     outpath = (
@@ -144,23 +207,87 @@
     scores_text = "\n".join(sum_scores)
     full_summary = "\n\t".join(sum_text)
 
-    with open(
-        outpath,
-        "w",
-    ) as fo:
-        if add_signature:
-            fo.write(
-                "Generated with the Document Summarization space :) https://hf.co/spaces/pszemraj/document-summarization\n\n"
-            )
+    with open(outpath, "w", encoding="utf-8", errors="ignore") as fo:
         fo.writelines(full_summary)
 
-    with open(
-        outpath,
-        "a",
-    ) as fo:
+    if write_scores:
+        with open(outpath, "a", encoding="utf-8", errors="ignore") as fo:
 
-        fo.write("\n" * 3)
-        fo.write(f"\n\nSection Scores:\n")
-        fo.writelines(scores_text)
-        fo.write("\n\n---\n")
+            fo.write("\n" * 3)
+            fo.write("\n\nSection Scores:\n")
+            fo.writelines(scores_text)
+            fo.write("\n\n---\n")
 
     return outpath
+
+
+def setup_logging(loglevel, logfile=None) -> None:
+    """Setup basic logging
+    you will need something like this in your main script:
+        parser.add_argument(
+            "-v",
+            "--verbose",
+            dest="loglevel",
+            help="set loglevel to INFO",
+            action="store_const",
+            const=logging.INFO,
+        )
+        parser.add_argument(
+            "-vv",
+            "--very-verbose",
+            dest="loglevel",
+            help="set loglevel to DEBUG",
+            action="store_const",
+            const=logging.DEBUG,
+        )
+    Args:
+        loglevel (int): minimum loglevel for emitting messages
+        logfile (str): path to logfile. If None, log to stdout.
+    """
+    # remove any existing handlers
+    root = logging.getLogger()
+    if root.handlers:
+        for handler in root.handlers:
+            root.removeHandler(handler)
+
+    logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
+    if logfile is None:
+        logging.basicConfig(
+            level=loglevel,
+            stream=sys.stdout,
+            format=logformat,
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    else:
+        loglevel = (
+            logging.INFO if loglevel not in [logging.DEBUG, logging.INFO] else loglevel
+        )
+        logging.basicConfig(
+            level=loglevel,
+            filename=logfile,
+            filemode="w",
+            format=logformat,
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+
+
+def postprocess_booksummary(text: str, custom_phrases: list = None) -> str:
+    """
+    postprocess_booksummary - postprocess the book summary
+
+    :param str text: the text to postprocess
+    :param list custom_phrases: custom phrases to remove from the text, defaults to None
+    :return str: the postprocessed text
+    """
+    REMOVAL_PHRASES = [
+        "In this section, ",
+        "In this lecture, ",
+        "In this chapter, ",
+        "In this paper, ",
+    ]  # the default phrases to remove (from booksum dataset)
+
+    if custom_phrases is not None:
+        REMOVAL_PHRASES.extend(custom_phrases)
+    for pr in REMOVAL_PHRASES:
+        text = text.replace(pr, "")
+    return text
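+
+
+# a minimal sketch of postprocess_booksummary in use (illustrative strings):
+#   postprocess_booksummary("In this chapter, the protagonist returns home.")
+#   -> "the protagonist returns home."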