diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 0000000..5e69f6e
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,33 @@
+version: 2.1
+
+orbs:
+  python: circleci/python@0.2.1
+
+jobs:
+  build-and-test:
+    executor: python/default
+    steps:
+      - checkout
+      - python/load-cache
+      - run:
+          name: Install cython/numpy/bhtsne
+          command: |
+            pip install Cython
+            pip install numpy
+            pip install bhtsne
+      - python/install-deps
+      - python/save-cache
+      - run:
+          name: Install seqc
+          command: pip install .
+      - run:
+          name: Test
+          command: |
+            export TMPDIR="/tmp"
+            python -m nose2 -s src/seqc/tests test_run_rmt_correction
+
+
+workflows:
+  main:
+    jobs:
+      - build-and-test
diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
new file mode 100644
index 0000000..2a869f0
--- /dev/null
+++ b/.github/workflows/python-app.yml
@@ -0,0 +1,38 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Python application
+
+on: [push, pull_request]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest
+        pip install Cython
+        pip install numpy
+        pip install bhtsne
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Install SEQC
+      run: pip install .
+    - name: Test with nose2
+      run: |
+        export TMPDIR="/tmp"
+        nose2 -s src/seqc/tests test_run_rmt_correction
diff --git a/.gitignore b/.gitignore
index df56e36..6482b17 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,6 @@ dist/*
 .project
 .pydevproject
 .c9/
+test-data/
+dask-worker-space/
+
diff --git a/README.md b/README.md
index 29f3df7..86fbe93 100644
--- a/README.md
+++ b/README.md
@@ -1,83 +1,150 @@
-## SEquence Quality Control (SEQC -- /sek-si:/)
+# SEquence Quality Control (SEQC -- /sek-si:/)
 
 ## Overview:
 
 SEQC is a python package that processes single-cell sequencing data in the cloud and analyzes it interactively on your local machine.
 
-To faciliate easy installation and use, we have made available Amazon Machine Images (AMIs) that come with all of SEQC's dependencies pre-installed. In addition, we have uploaded common genome indices (`-i/--index parameter`) and barcode data (`--barcode-files`) to public amazon s3 repositories. These links can be provided to SEQC and it will automatically fetch them prior to initiating an analysis run. Finally, it can fetch input data directly from BaseSpace or amazon s3 for analysis.
+To facilitate easy installation and use, we have made available Amazon Machine Images (AMIs) that come with all of SEQC's dependencies pre-installed. In addition, we have uploaded common genome indices (`-i/--index parameter`) and barcode data (`--barcode-files`) to public Amazon S3 repositories. These links can be provided to SEQC and it will automatically fetch them prior to initiating an analysis run. Finally, it can fetch input data directly from BaseSpace or Amazon S3 for analysis.
 
-For users with access to in-house compute clusters, SEQC can be installed on your systems and run using the --local parameter.
+For users with access to in-house compute clusters, SEQC can be installed on your systems and run using the `--local` parameter.
 
-### Dependencies:
+## Dependencies:
 
+### Python 3
 
-#### Python3
-Python must be installed on your local machine to run SEQC. We recommend installing python3 through your unix operating system's package manager. For Mac OSX users we recommend homebrew. Typical installation commands would be:
+Python 3 must be installed on your local machine to run SEQC. We recommend installing Python 3 through Miniconda (https://docs.conda.io/en/latest/miniconda.html).
 
-    brew install python3 # mac
-    apt-get install python3 # debian
-    yum install python3 # rpm-based
+### Python 3 Libraries
 
-#### Python3 Libraries
+We recommend creating a virtual environment before installing anything:
 
-Installing these libraries is necessary before installing SEQC.
+```bash
+conda create -n seqc python=3.7.7 pip
+conda activate seqc
+```
 
-    pip3 install Cython
-    pip3 install numpy
-    pip3 install bhtsne
+```bash
+pip install Cython
+pip install numpy
+pip install bhtsne
+```
 
-#### STAR
-To process data locally using SEQC, you must install the STAR Aligner, Samtools, and hdf5. If you only intend to use SEQC to trigger remote processing on AWS, these dependencies are optional. We recommend installing samtools and hdf5 through your package manager, if possible.
-
-#### Hardware Requirements:
-For processing a single lane (~200M reads) against human- and mouse-scale genomes, SEQC requires 30GB RAM, approximately 200GB free hard drive space, and scales linearly with additional compute cores. If running on AWS (see below), jobs are automatically scaled up or down according to the size of the input. There are no hardware requirements for the computer used to launch remote instances.
-
-
-#### Amazon Web Services:
-SEQC can be run on any unix-based operating system, however it also features the ability to automatically spawn Amazon Web Services instances to process your data. If you wish to take advantage of AWS, you will need to follow their instructions to:
-
-1. Set up an AWS account
-2. Install and configure AWS CLI
-3. Create and upload an rsa-key for AWS
-
-
-### SEQC Installation:
+### STAR, Samtools, and HDF5
 
-Once all dependencies have been installed, SEQC can be installed on any machine by typing:
-
-    $> git clone https://github.com/dpeerlab/seqc.git
-    $> cd seqc && python3 setup.py install
-
-Please note that to avoid passing the -k/--rsa-key command when you execute SEQC runs, you can also set the environment variable `AWS_RSA_KEY` to the path to your newly created key.
-
-### Testing SEQC:
-
-All the unit tests in class `TestSEQC` in `test.py` have been tested. Currently, only two platforms `ten_x_v2` and `in_drop_v2` have been tested. Old unit tests from these two platforms together with other platforms are stored at `s3://dp-lab-data/seqc-old-unit-test/`.
-
-### Running SEQC:
-
-After SEQC is installed, help can be listed:
+To process data locally using SEQC, you must install the STAR Aligner, Samtools, and hdf5. If you only intend to use SEQC to trigger remote processing on AWS, these dependencies are optional. We recommend installing samtools and hdf5 through your package manager, if possible.
 
-    SEQC [-h] [-v] {run,progress,terminate,instances,start,index} ...
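+
+As one example (not the only route), all three tools can be installed through conda, using the same channels referenced in `docs/install-SUSE.md`:
+
+```bash
+conda install -c anaconda hdf5
+conda install -c bioconda samtools
+conda install -c bioconda star
+```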
+## SEQC Installation
 
-    Processing Tools for scRNA-seq Experiments
+Once all dependencies have been installed, SEQC can be installed by running:
 
-    positional arguments:
-      {run,progress,terminate,instances,start,index}
-        run                 initiate SEQC runs
-        progress            check SEQC run progress
-        terminate           terminate SEQC runs
-        instances           list all running instances
-        start               initialize a seqc-ready instance
-        index               create a SEQC index
+```bash
+export SEQC_VERSION="0.2.6"
+wget https://github.com/hisplan/seqc/archive/v${SEQC_VERSION}.tar.gz
+tar xvzf v${SEQC_VERSION}.tar.gz
+cd seqc-${SEQC_VERSION}
+pip install .
+```
 
-    optional arguments:
-      -h, --help            show this help message and exit
-      -v, --version         show program's version number and exit
+## Hardware Requirements:
 
-In addition to processing sequencing experiments, SEQC.py provides some convenience tools to create indices for use with SEQC and STAR, and tools to check the progress of remote runs, list current runs, start instances, and terminate them.
+For processing a single lane (~200M reads) against human- and mouse-scale genomes, SEQC requires 30GB RAM, approximately 200GB free hard drive space, and scales linearly with additional compute cores. If running on AWS (see below), jobs are automatically scaled up or down according to the size of the input. There are no hardware requirements for the computer used to launch remote instances.
 
-To seamlessly start an AWS instance with automatic installation of SEQC from your local machine you can run:
+## Running SEQC on Local Machine:
+
+Download an example dataset (1k PBMCs from a healthy donor; freely available at 10x Genomics https://support.10xgenomics.com/single-cell-gene-expression/datasets/3.0.0/pbmc_1k_v3):
+
+```bash
+wget https://cf.10xgenomics.com/samples/cell-exp/3.0.0/pbmc_1k_v3/pbmc_1k_v3_fastqs.tar
+tar xvf pbmc_1k_v3_fastqs.tar
+```
+
+Move R1 FASTQ files to the `barcode` folder and R2 FASTQ files to the `genomic` folder:
+
+```bash
+mkdir barcode
+mkdir genomic
+mv ./pbmc_1k_v3_fastqs/*R1*.fastq.gz barcode/
+mv ./pbmc_1k_v3_fastqs/*R2*.fastq.gz genomic/
+```
+
+Download the 10x barcode whitelist file:
+
+```bash
+mkdir whitelist
+wget https://seqc-public.s3.amazonaws.com/barcodes/ten_x_v3/flat/3M-february-2018.txt
+mv 3M-february-2018.txt ./whitelist/
+```
+
+The resulting directory structure should look something like this:
+
+```
+.
+├── barcode
+│   ├── pbmc_1k_v3_S1_L001_R1_001.fastq.gz
+│   └── pbmc_1k_v3_S1_L002_R1_001.fastq.gz
+├── genomic
+│   ├── pbmc_1k_v3_S1_L001_R2_001.fastq.gz
+│   └── pbmc_1k_v3_S1_L002_R2_001.fastq.gz
+├── pbmc_1k_v3_fastqs
+│   ├── pbmc_1k_v3_S1_L001_I1_001.fastq.gz
+│   └── pbmc_1k_v3_S1_L002_I1_001.fastq.gz
+├── pbmc_1k_v3_fastqs.tar
+└── whitelist
+    └── 3M-february-2018.txt
+```
+
+Create a reference package (STAR index + gene annotation):
+
+```bash
+SEQC index \
+    --organism homo_sapiens \
+    --ensemble-release 93 \
+    --valid-biotypes protein_coding lincRNA antisense IG_V_gene IG_D_gene IG_J_gene IG_C_gene TR_V_gene TR_D_gene TR_J_gene TR_C_gene \
+    --read-length 101 \
+    --folder index \
+    --local
+```
+
+Run SEQC:
+
+```bash
+export AWS_DEFAULT_REGION=us-east-1
+export SEQC_MAX_WORKERS=7
+
+SEQC run ten_x_v3 \
+    --index ./index/ \
+    --barcode-files ./whitelist/ \
+    --barcode-fastq ./barcode/ \
+    --genomic-fastq ./genomic/ \
+    --upload-prefix ./seqc-results/ \
+    --output-prefix PBMC \
+    --no-filter-low-coverage \
+    --min-poly-t 0 \
+    --star-args runRNGseed=0 \
+    --local
+```
+
+## Running SEQC on Amazon Web Services:
+
+SEQC can be run on any unix-based operating system; however, it also features the ability to automatically spawn Amazon Web Services instances to process your data. If you wish to take advantage of AWS, you will need to:
 
-    SEQC start
+
+1. Set up an AWS account
+2. Install and configure AWS CLI
+3. Create and upload an rsa-key for AWS
+
+Run SEQC:
+
+```bash
+SEQC run ten_x_v2 \
+    --ami-id ami-08652ee2477761403 \
+    --user-tags Job:Test,Project:PBMC-Test,Sample:pbmc_1k_v3 \
+    --index s3://seqc-public/genomes/hg38_long_polya/ \
+    --barcode-files s3://seqc-public/barcodes/ten_x_v2/flat/ \
+    --genomic-fastq s3://.../genomic/ \
+    --barcode-fastq s3://.../barcode/ \
+    --upload-prefix s3://.../seqc-results/ \
+    --output-prefix PBMC \
+    --no-filter-low-coverage \
+    --min-poly-t 0 \
+    --star-args runRNGseed=0
+```
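+
+As a minimal sketch (these subcommands and flags are defined in `src/seqc/core/parser.py`; the key path and instance ID below are placeholders), a remote run can then be monitored and cleaned up from your local machine:
+
+```bash
+# list running SEQC instances associated with your account
+SEQC instances -k ~/.ssh/my-aws-key.pem
+
+# check the progress of a specific run
+SEQC progress -i <instance-id> -k ~/.ssh/my-aws-key.pem
+
+# terminate the instance once the run has finished
+SEQC terminate -i <instance-id>
+```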
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..eb09dd2
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,44 @@
+# docs
+
+## Developers
+
+- [Environment setup for development](./install-dev.md)
+- [Running test](./run-test.md)
+
+
+## Generating Reference Packages
+
+This generates a reference package (STAR index and GTF) using SEQC v0.2.6.
+
+- Ensembl 86
+- Gene annotation file that contains only the reference chromosomes (no scaffolds, no patches)
+- Only these biotypes: 'protein_coding', 'lincRNA', 'IG_V_gene', 'IG_C_gene', 'IG_J_gene', 'TR_C_gene', 'TR_J_gene', 'TR_V_gene', 'TR_D_gene', 'IG_D_gene'
+- Not passing anything to `--additional-id-types`
+- Setting the read length to 101 (internally, this becomes 100)
+
+### Local
+
+```bash
+SEQC index \
+    -o homo_sapiens \
+    -f homo_sapiens \
+    --ensemble-release 93 \
+    --valid-biotypes protein_coding lincRNA antisense IG_V_gene IG_D_gene IG_J_gene IG_C_gene TR_V_gene TR_D_gene TR_J_gene TR_C_gene \
+    --read-length 101 \
+    --folder ./test-data/index/ \
+    --local
+```
+
+### AWS
+
+```bash
+SEQC index \
+    -o homo_sapiens \
+    -f homo_sapiens \
+    --ensemble-release 93 \
+    --valid-biotypes protein_coding lincRNA antisense IG_V_gene IG_D_gene IG_J_gene IG_C_gene TR_V_gene TR_D_gene TR_J_gene TR_C_gene \
+    --read-length 101 \
+    --upload-prefix s3://dp-lab-test/seqc/index/86/ \
+    --rsa-key ~/dpeerlab-chunj.pem \
+    --ami-id ami-037cc8c1417e197c1
+```
diff --git a/docs/install-SUSE.md b/docs/install-SUSE.md
new file mode 100644
index 0000000..869fbf1
--- /dev/null
+++ b/docs/install-SUSE.md
@@ -0,0 +1,50 @@
+# Installation for SUSE
+
+This was tested with AWS SUSE Linux Enterprise Server 15 SP1 (HVM).
+
+## Install gcc & c++
+
+```bash
+sudo zypper in gcc-c++
+```
+
+## Install Miniconda
+
+```bash
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+bash Miniconda3-latest-Linux-x86_64.sh
+```
+
+For more information:
+- https://docs.conda.io/en/latest/miniconda.html
+- https://conda.io/projects/conda/en/latest/user-guide/install/linux.html#install-linux-silent
+
+Log out and log back in.
+
+## Create a Virtual Environment
+
+```bash
+conda create -n seqc python=3.7.7 pip
+conda activate seqc
+```
+
+## Install dependencies
+
+```
+pip install Cython
+pip install numpy
+pip install bhtsne
+
+conda install -c anaconda hdf5
+conda install -c bioconda samtools
+conda install -c bioconda star
+```
+
+## Install SEQC
+
+```
+wget https://github.com/dpeerlab/seqc/archive/v0.2.6.tar.gz
+tar xvzf v0.2.6.tar.gz
+cd seqc-0.2.6/
+pip install .
+```
diff --git a/docs/install-dev.md b/docs/install-dev.md
new file mode 100644
index 0000000..9e0978d
--- /dev/null
+++ b/docs/install-dev.md
@@ -0,0 +1,59 @@
+# Setup for Development
+
+Last verified: Jun 4, 2020
+
+## Create Conda Environment
+
+```bash
+conda create -n seqc-dev python=3.7.7 pip
+conda activate seqc-dev
+```
+
+## Install Dependencies
+
+```bash
+pip install Cython
+pip install numpy
+pip install bhtsne
+```
+
+For Mac (Mojave 10.14.6), install the following additional components. You must have `brew` installed.
+
+```
+brew install cairo
+brew install pango
+```
+
+## Install SEQC (editable mode)
+
+```bash
+pip install --editable .
+``` + +## Install STAR + +```bash +curl -OL https://github.com/alexdobin/STAR/archive/2.5.3a.tar.gz +tar -xf 2.5.3a.tar.gz +cp STAR-2.5.3a/bin/MacOSX_x86_64/STAR /usr/local/bin/ +``` + +## Install samtools + +```bash +conda install -c bioconda samtools=1.3.1 +``` + +## Install Packages for Testing + +```bash +pip install nose +``` + +## Install Packages for Linting and Formating + +```bash +pip install pylint +pip install autopep8 +pip install black +``` diff --git a/docs/run-test.md b/docs/run-test.md new file mode 100644 index 0000000..866e5be --- /dev/null +++ b/docs/run-test.md @@ -0,0 +1,75 @@ +# Running Test + +## Setup + +Set the following environment variables: + +```bash +export SEQC_TEST_RSA_KEY=/Users/chunj/dpeerlab-chunj.pem +export SEQC_TEST_EMAIL=jaeyoung.chun@gmail.com +export SEQC_TEST_AMI_ID=ami-037cc8c1417e197c1 +``` + +For local test, download test data in S3 to your test machine: + +``` +aws s3 sync s3://seqc-public/test/ten_x_v2/ ./test-data/datasets/ten_x_v2/ +aws s3 sync s3://seqc-public/barcodes/ten_x_v2/ ./test-data/datasets/barcodes/ten_x_v2/ +aws s3 sync s3://seqc-public/genomes/hg38_chr19/ ./test-data/datasets/genomes/hg38_chr19/ +``` + +## Test Everything + +Runs tests based on `nose2.cfg`: + +```bash +nose2 +``` + +## SEQC index + +```bash +nose2 -s src/seqc/tests test_index +``` + +Besides the nose2 test results, actual SEQC output files can be found here, for example: + +``` +s3://dp-lab-cicd/seqc/index-ciona_intestinalis-0d19e818-7623-4a1d-bac3-a8c9e3be1e3e/ +``` + +## SEQC run + +### Local + +SEQC will run with `--local`. + +```bash +nose2 -s src/seqc/tests test_run_e2e_local +``` + +### Remote + +SEQC will run on AWS. + +The following will generate a package that can be uploaded to AWS EC2 for testing: + +```bash +python repackage.py +``` + +```bash +nose2 -s src/seqc/tests test_run_e2e_remote +``` + +Besides the nose2 test results, actual SEQC output files can be found here, for example: + +``` +s3://dp-lab-cicd/seqc/run-in_drop_v2-a997b408-f883-4ba2-9941-8b541e319850/ +``` + +### Clean Up + +```bash +aws s3 rm s3://dp-lab-cicd/seqc/ --recursive +``` diff --git a/nose2.cfg b/nose2.cfg new file mode 100644 index 0000000..5af2d96 --- /dev/null +++ b/nose2.cfg @@ -0,0 +1,4 @@ +[unittest] +start-dir = src/seqc/tests +test-file-pattern = test_*.py +test-method-prefix = test diff --git a/repackage.py b/repackage.py index e89c3fe..49bb41a 100644 --- a/repackage.py +++ b/repackage.py @@ -11,7 +11,11 @@ def ignore_test_and_tools(dir_, files): :param files: output of os.listdir(), files to be subjected to filtering :return list: list of files that should be filtered, and not copied. 
""" - return [f for f in files if (f == "test" or f.startswith("."))] + return [ + f + for f in files + if (f == "test" or f == "test-data" or f == "__pycache__" or f.startswith(".")) + ] setup_dir = os.path.dirname(os.path.realpath(__file__)) @@ -27,5 +31,5 @@ def ignore_test_and_tools(dir_, files): # copy the SEQC files in the working directory to ~/.seqc/seqc shutil.copytree(setup_dir, seqc_dir, ignore=ignore_test_and_tools) -# create .tag.gz of ~/.seqc/seqc/* +# create .tar.gz of ~/.seqc/seqc/* shutil.make_archive(base_name=seqc_dir, format="gztar", root_dir=seqc_dir) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fe9c99a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +Cython>0.14 +numpy>=1.10.0 +bhtsne +wikipedia +awscli +numexpr>=2.4 +pandas>=1.0.4 +paramiko>=2.0.2 +regex +requests +nose2 +scipy>=1.5.1 +boto3 +intervaltree +matplotlib +tinydb +tables +fastcluster +statsmodels==0.11.1 +ecdsa +jupyter +jinja2 +pycrypto +cairocffi==0.8.0 +weasyprint==0.42.2 +scikit_learn>=0.17 +tqdm +pendulum +dask>=2.25.0 +distributed>=2.25.0 +dill>=0.3.2 +bokeh>=2.1.1 +numba~=0.51.2 +PhenoGraph>=1.5.7 +magic@https://github.com/dpeerlab/magic/archive/v0.1.1.tar.gz diff --git a/setup.py b/setup.py index 4df4f31..0bcc1b3 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ from setuptools import setup from warnings import warn import py_compile +from pathlib import Path # Replace py_compile.compile with a function that calls it with doraise=True @@ -19,73 +20,51 @@ def doraise_py_compile(file, cfile=None, dfile=None, doraise=False): py_compile.compile = doraise_py_compile if sys.version_info.major != 3: - raise RuntimeError('SEQC requires Python 3') + raise RuntimeError("SEQC requires Python 3") if sys.version_info.minor < 5: - warn('Multiprocessing analysis methods may not function on Python versions < 3.5') + warn("Multiprocessing analysis methods may not function on Python versions < 3.5") + +main_ns = {} # get version -with open('src/seqc/version.py') as f: - exec(f.read()) +with open("src/seqc/version.py") as f: + exec(f.read(), main_ns) setup( - name='seqc', - version=__version__, # read in from the exec of version.py; ignore error - description='Single Cell Sequencing Processing and QC Suite', - author='Ambrose J. Carr', - author_email='mail@ambrosejcarr.com', - package_dir={'': 'src'}, - package_data={'': ['*.r', '*.R']}, - packages=['seqc', 'seqc.sequence', 'seqc.alignment', 'seqc.core', 'seqc.stats', - 'seqc.summary', 'seqc.notebooks'], + name="seqc", + version=main_ns["__version__"], + description="Single Cell Sequencing Processing and QC Suite", + author="Ambrose J. 
Carr", + author_email="mail@ambrosejcarr.com", + package_dir={"": "src"}, + package_data={"": ["*.r", "*.R"]}, + packages=[ + "seqc", + "seqc.sequence", + "seqc.alignment", + "seqc.core", + "seqc.stats", + "seqc.summary", + "seqc.notebooks", + ], install_requires=[ - 'numpy>=1.10.0', - 'bhtsne', - 'wikipedia', - 'awscli', - 'Cython>0.14', - 'numexpr>=2.4', - 'pandas>=0.18.1', - 'paramiko>=2.0.2', - 'regex', - 'requests', - 'nose2', - 'scipy>=0.14.0', - 'boto3', - 'intervaltree', - 'matplotlib', - 'tinydb', - 'tables', - 'fastcluster', - 'statsmodels', - 'ecdsa', - 'dill', - 'jupyter', - 'multiprocessing_on_dill', - 'jinja2', - 'pycrypto', - 'cairocffi==0.8.0', - 'weasyprint==0.42.2', - 'scikit_learn>=0.17', - 'PhenoGraph@https://github.com/dpeerlab/PhenoGraph/archive/v1.5.2.tar.gz', - 'magic@https://github.com/dpeerlab/magic/archive/v0.1.1.tar.gz' + dep.strip() for dep in Path("requirements.txt").read_text("utf-8").splitlines() ], - scripts=['src/scripts/SEQC'], - extras_require={ - 'GSEA_XML': ['html5lib', 'lxml', 'BeautifulSoup4'], - }, - include_package_data=True + scripts=["src/scripts/SEQC"], + extras_require={"GSEA_XML": ["html5lib", "lxml", "BeautifulSoup4"]}, + include_package_data=True, ) # look for star -if not shutil.which('STAR'): - warn('SEQC: STAR is not installed. SEQC will not be able to align files.') +if not shutil.which("STAR"): + warn("SEQC: STAR is not installed. SEQC will not be able to align files.") # get location of setup.py setup_dir = os.path.dirname(os.path.realpath(__file__)) -seqc_dir = os.path.expanduser('~/.seqc/seqc') +seqc_dir = os.path.expanduser("~/.seqc/seqc") -print('setup_dir: {}'.format(setup_dir)) -print('seqc_dir: {}'.format(seqc_dir)) +print("setup_dir: {}".format(setup_dir)) +print("seqc_dir: {}".format(seqc_dir)) if os.path.isdir(seqc_dir): shutil.rmtree(seqc_dir) @@ -98,10 +77,10 @@ def ignore_test_and_tools(dir_, files): :param files: output of os.listdir(), files to be subjected to filtering :return list: list of files that should be filtered, and not copied. """ - return [f for f in files if (f == 'test' or f.startswith('.'))] + return [f for f in files if (f == "test" or f.startswith("."))] # install tools and a local copy of seqc. # copy seqc repository shutil.copytree(setup_dir, seqc_dir, ignore=ignore_test_and_tools) -shutil.make_archive(base_name=seqc_dir, format='gztar', root_dir=seqc_dir) +shutil.make_archive(base_name=seqc_dir, format="gztar", root_dir=seqc_dir) diff --git a/src/seqc/alignment/sam.py b/src/seqc/alignment/sam.py index 7d5fc09..d434105 100644 --- a/src/seqc/alignment/sam.py +++ b/src/seqc/alignment/sam.py @@ -4,22 +4,42 @@ import gzip +def get_version(): + + proc = Popen(["samtools", "--version"], stderr=PIPE, stdout=PIPE) + out, err = proc.communicate() + if err: + raise ChildProcessError(err) + + # e.g. + # samtools 1.9 + # Using htslib 1.9 + # Copyright (C) 2018 Genome Research Ltd. 
+ # --> 'samtools 1.9' + version = out.decode().strip().split("\n")[0] + + # --> '1.9' + version = version.split(" ")[1] + + return version + + class SamRecord: """Simple record object allowing access to Sam record properties""" - __slots__ = ['_record', '_parsed_name_field'] + __slots__ = ["_record", "_parsed_name_field"] - NameField = namedtuple('NameField', ['pool', 'cell', 'rmt', 'poly_t', 'name']) + NameField = namedtuple("NameField", ["pool", "cell", "rmt", "poly_t", "name"]) def __init__(self, record): self._record = record self._parsed_name_field = None def __repr__(self): - return ''.format('\t'.join(self._record)) + return "".format("\t".join(self._record)) def __bytes__(self): - return '\t'.join(self._record) + '\n' + return "\t".join(self._record) + "\n" @property def qname(self) -> str: @@ -69,13 +89,13 @@ def qual(self) -> str: def optional_fields(self): flags_ = {} for f in self._record[11:]: - k, _, v = f.split(':') + k, _, v = f.split(":") flags_[k] = int(v) return flags_ def _parse_name_field(self): - fields, name = self.qname.split(';') - processed_fields = fields.split(':') + fields, name = self.qname.split(";") + processed_fields = fields.split(":") processed_fields.append(name) self._parsed_name_field = self.NameField(*processed_fields) @@ -129,16 +149,16 @@ def is_unmapped(self): @property def is_multimapped(self): - return True if self.optional_fields['NH'] > 1 else False + return True if self.optional_fields["NH"] > 1 else False @property def is_uniquely_mapped(self): - return True if self.optional_fields['NH'] == 1 else False + return True if self.optional_fields["NH"] == 1 else False @property def strand(self): minus_strand = int(self.flag) & 16 - return '-' if minus_strand else '+' + return "-" if minus_strand else "+" # # todo this takes up 66% of the processing time for parsing the sam record # @property @@ -181,8 +201,9 @@ def __init__(self, samfile: str): except RuntimeError as ex: raise ex except: - raise ValueError('%s is an invalid samfile. Please check file formatting.' % - samfile) + raise ValueError( + "%s is an invalid samfile. Please check file formatting." 
% samfile + ) @property def samfile(self): @@ -193,15 +214,15 @@ def _open(self): seamlessly open self._samfile, whether gzipped or uncompressed :returns: open file object """ - if self.samfile.endswith('.gz'): - fobj = gzip.open(self.samfile, 'rb') - elif self.samfile.endswith('.bam'): - if not shutil.which('samtools'): - raise RuntimeError('samtools utility must be installed to run bamfiles') - p = Popen(['samtools', 'view', self.samfile], stdout=PIPE) + if self.samfile.endswith(".gz"): + fobj = gzip.open(self.samfile, "rb") + elif self.samfile.endswith(".bam"): + if not shutil.which("samtools"): + raise RuntimeError("samtools utility must be installed to run bamfiles") + p = Popen(["samtools", "view", self.samfile], stdout=PIPE) fobj = p.stdout else: - fobj = open(self.samfile, 'rb') + fobj = open(self.samfile, "rb") return fobj def __len__(self): @@ -214,9 +235,9 @@ def __iter__(self): for line in fobj: line = line.decode() # todo move this if statement to execute only until header is exhausted - if line.startswith('@'): + if line.startswith("@"): continue - yield SamRecord(line.strip().split('\t')) + yield SamRecord(line.strip().split("\t")) finally: fobj.close() diff --git a/src/seqc/alignment/star.py b/src/seqc/alignment/star.py index e15985d..52348a1 100644 --- a/src/seqc/alignment/star.py +++ b/src/seqc/alignment/star.py @@ -4,8 +4,27 @@ import shlex +def get_version(): + + proc = Popen(["STAR", "--version"], stderr=PIPE, stdout=PIPE) + out, err = proc.communicate() + if err: + raise ChildProcessError(err) + + version = out.decode().strip() + + if version.startswith("STAR_"): + # e.g. STAR_2.5.3a + # --> 2.5.3a + return out.decode().strip().split("_")[1] + else: + # e.g. 2.7.3a + return version + + def default_alignment_args( - fastq_records: str, n_threads: int or str, index: str, output_dir: str) -> dict: + fastq_records: str, n_threads: int or str, index: str, output_dir: str +) -> dict: """default arguments for STAR alignment To report unaligned reads, add '--outSAMunmapped': 'Within', @@ -17,30 +36,36 @@ def default_alignment_args( :return: dict, default alignment arguments """ default_align_args = { - '--runMode': 'alignReads', - '--runThreadN': str(n_threads), - '--genomeDir': index, - '--outFilterType': 'BySJout', - '--outFilterMultimapNmax': '10', # require unique alignments - '--limitOutSJcollapsed': '2000000', # deal with many splice variants - '--alignSJDBoverhangMin': '8', - '--outFilterMismatchNoverLmax': '0.04', - '--alignIntronMin': '20', - '--alignIntronMax': '1000000', - '--readFilesIn': fastq_records, - '--outSAMprimaryFlag': 'AllBestScore', # all equal-scoring reads are primary - '--outSAMtype': 'BAM Unsorted', - '--outFileNamePrefix': output_dir, + "--runMode": "alignReads", + "--runThreadN": str(n_threads), + "--genomeDir": index, + "--outFilterType": "BySJout", + "--outFilterMultimapNmax": "10", # require unique alignments + "--limitOutSJcollapsed": "2000000", # deal with many splice variants + "--alignSJDBoverhangMin": "8", + "--outFilterMismatchNoverLmax": "0.04", + "--alignIntronMin": "20", + "--alignIntronMax": "1000000", + "--readFilesIn": fastq_records, + "--outSAMprimaryFlag": "AllBestScore", # all equal-scoring reads are primary + "--outSAMtype": "BAM Unsorted", + "--outFileNamePrefix": output_dir, } - if fastq_records.endswith('.gz'): - default_align_args['--readFilesCommand'] = 'gunzip -c' - if fastq_records.endswith('.bz2'): - default_align_args['--readFilesCommand'] = 'bunzip2 -c' + if fastq_records.endswith(".gz"): + 
default_align_args["--readFilesCommand"] = "gunzip -c" + if fastq_records.endswith(".bz2"): + default_align_args["--readFilesCommand"] = "bunzip2 -c" return default_align_args -def align(fastq_file: str, index: str, n_threads: int, alignment_dir: str, - reverse_fastq_file: str or bool=None, **kwargs) -> str: +def align( + fastq_file: str, + index: str, + n_threads: int, + alignment_dir: str, + reverse_fastq_file: str or bool = None, + **kwargs +) -> str: """align a fastq file, or a paired set of fastq files :param fastq_file: str, location of a fastq file @@ -52,27 +77,26 @@ def align(fastq_file: str, index: str, n_threads: int, alignment_dir: str, :return: str, .sam file location """ - runtime_args = default_alignment_args( - fastq_file, n_threads, index, alignment_dir) + runtime_args = default_alignment_args(fastq_file, n_threads, index, alignment_dir) for k, v in kwargs.items(): # overwrite or add any arguments passed from cmdline if not isinstance(k, str): try: k = str(k) except ValueError: - raise ValueError('arguments passed to STAR must be strings') + raise ValueError("arguments passed to STAR must be strings") if not isinstance(v, str): try: v = str(v) except ValueError: - raise ValueError('arguments passed to STAR must be strings') - runtime_args['--' + k] = v + raise ValueError("arguments passed to STAR must be strings") + runtime_args["--" + k] = v # construct command line arguments for STAR - cmd = ['STAR'] + cmd = ["STAR"] if reverse_fastq_file: for key, value in runtime_args.items(): - if key == '--readFilesIn': + if key == "--readFilesIn": cmd.extend((key, value)) cmd.append(reverse_fastq_file) else: @@ -81,20 +105,18 @@ def align(fastq_file: str, index: str, n_threads: int, alignment_dir: str, for pair in runtime_args.items(): cmd.extend(pair) - cmd = shlex.split(' '.join(cmd)) + cmd = shlex.split(" ".join(cmd)) aln = Popen(cmd, stderr=PIPE, stdout=PIPE) - out, err = aln.communicate() + _, err = aln.communicate() if err: raise ChildProcessError(err) - return alignment_dir + 'Aligned.out.bam' + return alignment_dir + "Aligned.out.bam" def create_index( - fasta: str, - gtf: str, - genome_dir: str, - read_length: int=75, **kwargs) -> None: + fasta: str, gtf: str, genome_dir: str, read_length: int = 75, **kwargs +) -> None: """Create a new STAR index :param fasta: complete filepath to fasta file @@ -109,19 +131,34 @@ def create_index( makedirs(genome_dir, exist_ok=True) overhang = str(read_length - 1) - cmd = ( - 'STAR ' - '--runMode genomeGenerate ' - '--runThreadN {ncpu} ' - '--genomeDir {genome_dir} ' - '--genomeFastaFiles {fasta} ' - '--sjdbGTFfile {gtf} ' - '--sjdbOverhang {overhang} '.format( - ncpu=ncpu, genome_dir=genome_dir, fasta=fasta, gtf=gtf, overhang=overhang) - ) + # Popen is hard to work as far as process substitution is concerned. + # let's just gunzip it before passing to STAR. 
+ if fasta.endswith(".gz"): + proc_gunzip = Popen(["gunzip", fasta]) + out, err = proc_gunzip.communicate() + if err: + raise ChildProcessError(err) + fasta = fasta.replace(".gz", "") + + cmd = [ + "STAR", + "--runMode", + "genomeGenerate", + "--runThreadN", + ncpu, + "--genomeDir", + genome_dir, + "--genomeFastaFiles", + fasta, + "--sjdbGTFfile", + gtf, + "--sjdbOverhang", + overhang, + ] for k, v in kwargs.items(): - cmd += '--{k} {v} '.format(k=k, v=v) + cmd.append("--{}".format(k)) + cmd.append(v) p = Popen(cmd, stderr=PIPE, stdout=PIPE) out, err = p.communicate() diff --git a/src/seqc/barcode_correction.py b/src/seqc/barcode_correction.py index e17ea98..2aaaa91 100644 --- a/src/seqc/barcode_correction.py +++ b/src/seqc/barcode_correction.py @@ -6,25 +6,27 @@ from seqc import log -def ten_x_barcode_correction(ra, platform, barcode_files, max_ed=2, - default_error_rate=0.02): - ''' +def ten_x_barcode_correction( + ra, platform, barcode_files, max_ed=2, default_error_rate=0.02 +): + """ Correct reads with incorrect barcodes according to the correct barcodes files. - Reads with barcodes that have too many errors are filtered out. + Reads with barcodes that have too many errors are filtered out. :param ra: seqc.read_array.ReadArray object :param platform: the platform object :param barcode_files: the list of the paths of barcode files - :param max_ed: maximum allowed Hamming distance from known cell barcodes + :param max_ed: maximum allowed Hamming distance from known cell barcodes :param default_error_rate: assumed sequencing error rate :return: - ''' + """ # Read the barcodes into lists valid_barcodes = set() for barcode_file in barcode_files: - with open(barcode_file, 'r') as f: - valid_barcodes = set([DNA3Bit.encode(line.strip()) for line in - f.readlines()]) + with open(barcode_file, "r") as f: + valid_barcodes = set( + [DNA3Bit.encode(line.strip()) for line in f.readlines()] + ) # Group reads by cells indices_grouped_by_cells = ra.group_indices_by_cell(multimapping=True) @@ -33,42 +35,51 @@ def ten_x_barcode_correction(ra, platform, barcode_files, max_ed=2, valid_barcode_count = dict() for inds in indices_grouped_by_cells: # Extract barcodes for one of the reads - barcode = platform.extract_barcodes(ra.data['cell'][inds[0]])[0] + barcode = platform.extract_barcodes(ra.data["cell"][inds[0]])[0] if barcode in valid_barcodes: valid_barcode_count[barcode] = len(inds) # Find the set of invalid barcodes and check out whether they can be corrected + mapping = [] for inds in indices_grouped_by_cells: # Extract barcodes for one of the reads - barcode = platform.extract_barcodes(ra.data['cell'][inds[0]])[0] + barcode = platform.extract_barcodes(ra.data["cell"][inds[0]])[0] if barcode not in valid_barcode_count: - # Identify correct barcode as one Hamming distance away with most reads - hammind_dist_1_barcodes = seqc.sequence.barcodes.generate_hamming_dist_1(barcode) + # Identify correct barcode as one Hamming distance away with most reads + hammind_dist_1_barcodes = seqc.sequence.barcodes.generate_hamming_dist_1( + barcode + ) fat_bc = -1 fat_bc_count = 0 for bc in hammind_dist_1_barcodes: - if (bc in valid_barcode_count) and (valid_barcode_count[bc] > fat_bc_count): + if (bc in valid_barcode_count) and ( + valid_barcode_count[bc] > fat_bc_count + ): fat_bc = bc fat_bc_count = valid_barcode_count[bc] if fat_bc < 0: - ra.data['status'][inds] |= ra.filter_codes['cell_error'] + ra.data["status"][inds] |= ra.filter_codes["cell_error"] else: + # record pre-/post-correction + for reported_cb in 
np.unique(ra.data["cell"][inds]): + mapping.append((reported_cb, fat_bc)) + # Update the read array with the correct barcode - ra.data['cell'][inds] = fat_bc + ra.data["cell"][inds] = fat_bc + return default_error_rate, pd.DataFrame(mapping, columns=["CR", "CB"]) -def in_drop(ra, platform, barcode_files, max_ed=2, - default_error_rate=0.02): +def in_drop(ra, platform, barcode_files, max_ed=2, default_error_rate=0.02): """ Correct reads with incorrect barcodes according to the correct barcodes files. Reads with barcodes that have too many errors are filtered out. :param ra: seqc.read_array.ReadArray object :param platform: the platform object :param barcode_files: the list of the paths of barcode files - :param max_ed: maximum allowed Hamming distance from known cell barcodes + :param max_ed: maximum allowed Hamming distance from known cell barcodes :param default_error_rate: assumed sequencing error rate :return: """ @@ -76,11 +87,12 @@ def in_drop(ra, platform, barcode_files, max_ed=2, # Read the barcodes into lists valid_barcodes = [] for barcode_file in barcode_files: - with open(barcode_file, 'r') as f: - valid_barcodes.append(set(DNA3Bit.encode(line.strip()) for line in - f.readlines())) - - # Containers + with open(barcode_file, "r") as f: + valid_barcodes.append( + set(DNA3Bit.encode(line.strip()) for line in f.readlines()) + ) + + # Containers num_barcodes = platform.num_barcodes correct = [None] * num_barcodes edit_dist = [None] * num_barcodes @@ -88,9 +100,10 @@ def in_drop(ra, platform, barcode_files, max_ed=2, # Error table container errors = [p for p in permutations(DNA3Bit.bin2strdict.keys(), r=2)] error_table = dict(zip(errors, np.zeros(len(errors)))) - cor_instance_table = dict(zip(DNA3Bit.bin2strdict.keys(), - np.zeros(len(DNA3Bit.bin2strdict)))) - + cor_instance_table = dict( + zip(DNA3Bit.bin2strdict.keys(), np.zeros(len(DNA3Bit.bin2strdict))) + ) + # Check if the barcode has to be an exact match exact_match = False if max_ed == 0: @@ -102,12 +115,13 @@ def in_drop(ra, platform, barcode_files, max_ed=2, for inds in indices_grouped_by_cells: # Extract barcodes for one of the reads - barcodes = platform.extract_barcodes(ra.data['cell'][inds[0]]) + barcodes = platform.extract_barcodes(ra.data["cell"][inds[0]]) # Identify correct barcode for i in range(num_barcodes): correct[i], edit_dist[i] = seqc.sequence.barcodes.find_correct_barcode( - barcodes[i], valid_barcodes[i], exact_match) + barcodes[i], valid_barcodes[i], exact_match + ) # 1. If all edit distances are 0, barcodes are correct, # update the correct instance table @@ -124,17 +138,17 @@ def in_drop(ra, platform, barcode_files, max_ed=2, tmp_bc >>= 3 elif max(edit_dist) > max_ed: - ra.data['status'][inds] |= ra.filter_codes['cell_error'] + ra.data["status"][inds] |= ra.filter_codes["cell_error"] continue else: # These barcodes can be corrected, Count the number of correct bases - # Update the error table if there was only one error across the barcodes + # Update the error table if there was only one error across the barcodes tmp_bc = DNA3Bit.ints2int(barcodes) tmp_cor = DNA3Bit.ints2int(correct) # Update the read array with the correct barcode - ra.data['cell'][inds] = tmp_cor + ra.data["cell"][inds] = tmp_cor # Iterating through the sequences while tmp_bc > 0: @@ -147,35 +161,42 @@ def in_drop(ra, platform, barcode_files, max_ed=2, # Create error rate table if sum(error_table.values()) == 0: - log.info('No errors were detected or barcodes do not support error ' - 'correction, using %f uniform error chance.' 
% default_error_rate) + log.info( + "No errors were detected or barcodes do not support error " + "correction, using %f uniform error chance." % default_error_rate + ) err_rate = dict(zip(errors, [default_error_rate] * len(errors))) # todo @Manu bug here, we're always setting the error rate even if there are # no detected errors. should the following line be in an "else" clause? err_rate = dict(zip(errors, [0.0] * len(errors))) for k, v in error_table.items(): - if DNA3Bit.decode(k[0]) in b'Nn': + if DNA3Bit.decode(k[0]) in b"Nn": continue try: - err_rate[k] = v / (sum(n for err_type, n in error_table.items() - if err_type[0] == k[0]) + cor_instance_table[k[0]]) + err_rate[k] = v / ( + sum(n for err_type, n in error_table.items() if err_type[0] == k[0]) + + cor_instance_table[k[0]] + ) except ZeroDivisionError: - log.info('Warning: too few reads to estimate error rate for %s, setting ' - 'default rate of %f' % - (str(DNA3Bit.decode(k)), default_error_rate)) + log.info( + "Warning: too few reads to estimate error rate for %s, setting " + "default rate of %f" % (str(DNA3Bit.decode(k)), default_error_rate) + ) err_rate[k] = default_error_rate - return err_rate + return err_rate, None -def drop_seq(ra, min_rmt_cutoff=10, rmt_error_frequency=0.8, barcode_base_shift_threshold=0.9): +def drop_seq( + ra, min_rmt_cutoff=10, rmt_error_frequency=0.8, barcode_base_shift_threshold=0.9 +): """Drop-seq barcode correction suggested by Ashley 1. Barcodes can be truncated to 11 bases because of synthesis error. Therefore a single barcode can be potentially be split to 4 barcodes Solution: Fix barcode: At the 8th position of RMT, if the fraction of T > 80%, replace the 12th position of the cell barcode with N - Fix RMT: Remove the T in the last position of the RMT and prepend the + Fix RMT: Remove the T in the last position of the RMT and prepend the first base from the uncorrected cell barcode 2. If a particular base dominates any of the positions of the RMT, remove that cell barcode @@ -185,69 +206,75 @@ def drop_seq(ra, min_rmt_cutoff=10, rmt_error_frequency=0.8, barcode_base_shift_ :param min_rmt_cutoff: Minimum number of RMTs to apply barcode correction :param rmt_error_frequency: If a base appears with this frequency across the RMTs associated with the barcode in any position, the barcode is removed - :param barcode_base_shift_threshold: Thresholds for detecting barcode shift + :param barcode_base_shift_threshold: Thresholds for detecting barcode shift :return: """ - + # Cell header [First 11 bases only - this should be parametrized] - cell_header = ra.data['cell'] >> 3 - idx = np.argsort( cell_header ) + cell_header = ra.data["cell"] >> 3 + idx = np.argsort(cell_header) # Active reads - passing = ra.data['status'][idx] == 0 + passing = ra.data["status"][idx] == 0 idx = idx[passing] breaks = np.where(np.diff(cell_header[idx]))[0] + 1 indices_grouped_by_cell_headers = np.split(idx, breaks) # RMT length - rmt_length = DNA3Bit.seq_len( ra.data['rmt'][idx[0]] ) + rmt_length = DNA3Bit.seq_len(ra.data["rmt"][idx[0]]) # 1. 
Barcode synthesis errors for header_group in indices_grouped_by_cell_headers: # RMT set # todo this could potentially be used in RMT correction / barcode correction in indrop - all_rmts = list(set(ra.data['rmt'][header_group])) + all_rmts = list(set(ra.data["rmt"][header_group])) if len(all_rmts) < min_rmt_cutoff: continue # Count Ts in the last RMT position - nuc_counts = dict(zip(DNA3Bit.bin2strdict.keys(), np.zeros(len(DNA3Bit.bin2strdict)))) + nuc_counts = dict( + zip(DNA3Bit.bin2strdict.keys(), np.zeros(len(DNA3Bit.bin2strdict))) + ) for rmt in all_rmts: nuc_counts[rmt & 0b0111] += 1 # Correct cell barcode if necessary - if nuc_counts[DNA3Bit.str2bindict['T']] > barcode_base_shift_threshold * len(all_rmts): + if nuc_counts[DNA3Bit.str2bindict["T"]] > barcode_base_shift_threshold * len( + all_rmts + ): # Correct the RMTs [This needs to done for each cell/RMT combination] - idx = header_group[np.argsort(ra.data['cell'][header_group])] - breaks = np.where(np.diff(ra.data['cell'][idx]))[0] + 1 + idx = header_group[np.argsort(ra.data["cell"][header_group])] + breaks = np.where(np.diff(ra.data["cell"][idx]))[0] + 1 cell_groups = np.split(idx, breaks) for cell_group in cell_groups: - last_base = ra.data['cell'][cell_group[0]] & 0b111 + last_base = ra.data["cell"][cell_group[0]] & 0b111 # Correct the RMTs - idx = cell_group[np.argsort(ra.data['rmt'][cell_group])] - breaks = np.where(np.diff(ra.data['rmt'][cell_group]))[0] + 1 + idx = cell_group[np.argsort(ra.data["rmt"][cell_group])] + breaks = np.where(np.diff(ra.data["rmt"][cell_group]))[0] + 1 rmt_groups = np.split(idx, breaks) for rmt_group in rmt_groups: # Skip the last base - new_rmt = ra.data['rmt'][rmt_group[0]] >> 3 + new_rmt = ra.data["rmt"][rmt_group[0]] >> 3 # Get the last base from the cell barcode - new_rmt = DNA3Bit.ints2int([last_base, new_rmt ]) - ra.data['rmt'][rmt_group] = new_rmt + new_rmt = DNA3Bit.ints2int([last_base, new_rmt]) + ra.data["rmt"][rmt_group] = new_rmt # Append N to the cell header - correct_barcode = DNA3Bit.ints2int([cell_header[header_group[0]], DNA3Bit.str2bindict['N']]) - ra.data['cell'][header_group] = correct_barcode + correct_barcode = DNA3Bit.ints2int( + [cell_header[header_group[0]], DNA3Bit.str2bindict["N"]] + ) + ra.data["cell"][header_group] = correct_barcode # 2. 
Single UMI error indices_grouped_by_cells = ra.group_indices_by_cell() for cell_group in indices_grouped_by_cells: # RMT set - all_rmts = list(set(ra.data['rmt'][cell_group])) + all_rmts = list(set(ra.data["rmt"][cell_group])) if len(all_rmts) < min_rmt_cutoff: continue @@ -257,7 +284,7 @@ def drop_seq(ra, min_rmt_cutoff=10, rmt_error_frequency=0.8, barcode_base_shift_ base_frequencies[i] = np.zeros(rmt_length) for i in range(len(all_rmts)): rmt = all_rmts[i] - position = rmt_length-1 + position = rmt_length - 1 while rmt > 0: base_frequencies[rmt & 0b111][position] += 1 rmt >>= 3 @@ -265,10 +292,13 @@ def drop_seq(ra, min_rmt_cutoff=10, rmt_error_frequency=0.8, barcode_base_shift_ # Chuck N base_frequencies = pd.DataFrame(base_frequencies).T - base_frequencies.ix[DNA3Bit.str2bindict['N']] = 0 + base_frequencies.ix[DNA3Bit.str2bindict["N"]] = 0 # Identify incorrect UMIs - if any( base_frequencies.iloc[:, 0:(rmt_length-1)].max() > rmt_error_frequency * len(all_rmts)): - ra.data['status'][cell_group] |= ra.filter_codes['cell_error'] - + if any( + base_frequencies.iloc[:, 0 : (rmt_length - 1)].max() + > rmt_error_frequency * len(all_rmts) + ): + ra.data["status"][cell_group] |= ra.filter_codes["cell_error"] + return None, None diff --git a/src/seqc/core/download.py b/src/seqc/core/download.py index f224596..d3c3b36 100644 --- a/src/seqc/core/download.py +++ b/src/seqc/core/download.py @@ -11,13 +11,14 @@ def s3_data(files_or_links, output_prefix): """ files = [] for f in files_or_links: - if not f.startswith('s3://'): - if f.endswith('/'): + if not f.startswith("s3://"): + if f.endswith("/"): files.extend(f + subfile for subfile in os.listdir(f)) else: files.append(f) else: - recursive = True if f.endswith('/') else False - files.extend(io.S3.download(f, output_prefix, overwrite=True, - recursive=recursive)) + recursive = True if f.endswith("/") else False + files.extend( + io.S3.download(f, output_prefix, overwrite=True, recursive=recursive) + ) return files diff --git a/src/seqc/core/index.py b/src/seqc/core/index.py index 4e79b42..cf57461 100644 --- a/src/seqc/core/index.py +++ b/src/seqc/core/index.py @@ -1,4 +1,3 @@ - def index(args): """create an index for SEQC. 
@@ -7,10 +6,56 @@ def index(args): """ # functions to be pickled and run remotely must import all their own modules - from seqc import ec2, log + import sys + import logging + from seqc import ec2, log, io from seqc.sequence.index import Index + from seqc.alignment import star + from seqc import version + + logging.basicConfig( + level=logging.DEBUG, + handlers=[ + logging.FileHandler(args.log_name), + logging.StreamHandler(sys.stdout), + ], + ) + + log.info("SEQC=v{}".format(version.__version__)) + log.info("STAR=v{}".format(star.get_version())) + log.args(args) + + with ec2.instance_clean_up( + email=args.email, + upload=args.upload_prefix, + log_name=args.log_name, + debug=args.debug, + terminate=args.terminate, + running_remote=args.remote, + ): + + idx = Index(args.organism, args.ids, args.folder) + idx.create_index( + s3_location=args.upload_prefix, + ensemble_release=args.ensemble_release, + read_length=args.read_length, + valid_biotypes=args.valid_biotypes, + ) + + # upload the log file (seqc_log.txt, nohup.log, Log.out) + if args.upload_prefix: + bucket, key = io.S3.split_link(args.upload_prefix) + for item in [args.log_name, "./nohup.log", "./Log.out"]: + try: + ec2.Retry(retries=5)(io.S3.upload_file)(item, bucket, key) + log.info( + "Successfully uploaded {} to {}".format( + item, args.upload_prefix + ) + ) + except FileNotFoundError: + log.notify( + "Item {} was not found! Continuing with upload...".format(item) + ) - log.setup_logger(args.log_name) - with ec2.instance_clean_up(args.email, args.upload, log_name=args.log_name): - idx = Index(args.organism, args.additional_id_types) - idx.create_index(args.upload_location) + log.info("DONE.") diff --git a/src/seqc/core/main.py b/src/seqc/core/main.py index d70ca26..a67326f 100755 --- a/src/seqc/core/main.py +++ b/src/seqc/core/main.py @@ -22,7 +22,7 @@ def clean_up_security_groups(): ) # get security groups associated with instances unused_sgs = all_sgs - all_inst_sgs # get ones without instance association - if len(unused_sgs) >= 300: + if len(unused_sgs) >= 100: print("Cleaning up the unused security groups:") client = boto3.client("ec2") for g in unused_sgs: @@ -67,13 +67,22 @@ def main(argv): "volume_size", "user_tags", "remote_update", - "ami_id" + "ami_id", ) if getattr(verified_args, k) } + + # store the command-line arguments supplied by the user + # the same aguments will be used to run SEQC on EC2 + remote_args["argv"] = argv + + # clean up AWS security groups clean_up_security_groups() + + # start EC2 instance and run the function ec2.AWSInstance(synchronous=False, **remote_args)(func)(verified_args) else: + # run the function locally func(arguments) diff --git a/src/seqc/core/parser.py b/src/seqc/core/parser.py index a3263d2..3d81542 100644 --- a/src/seqc/core/parser.py +++ b/src/seqc/core/parser.py @@ -16,264 +16,447 @@ def parse_args(args): """ meta = argparse.ArgumentParser( - description='Processing Tools for scRNA-seq Experiments') - meta.add_argument('-v', '--version', action='version', - version='{} v{}'.format(meta.prog, version.__version__)) - subparsers = meta.add_subparsers(dest='subparser_name') + description="Processing Tools for scRNA-seq Experiments" + ) + meta.add_argument( + "-v", + "--version", + action="version", + version="{} v{}".format(meta.prog, version.__version__), + ) + subparsers = meta.add_subparsers(dest="subparser_name") # subparser for running experiments # can use to make prettier: formatter_class=partial(argparse.HelpFormatter, width=200) - p = subparsers.add_parser('run', 
help='initiate SEQC runs') + p = subparsers.add_parser("run", help="initiate SEQC runs") # Platform choices - choices = [x[0] for x in inspect.getmembers(platforms, inspect.isclass) if - issubclass(x[1], platforms.AbstractPlatform)][1:] - p.add_argument('platform', - choices=choices, - help='which platform are you merging annotations from?') + choices = [ + x[0] + for x in inspect.getmembers(platforms, inspect.isclass) + if issubclass(x[1], platforms.AbstractPlatform) + ][1:] + p.add_argument( + "platform", + choices=choices, + help="which platform are you merging annotations from?", + ) - a = p.add_argument_group('required arguments') - a.add_argument('-o', '--output-prefix', metavar='O', required=True, - help='filename prefix for all seqc output. Should not be a directory.') - a.add_argument('-i', '--index', metavar='I', required=True, - help='Local folder or s3 link to a directory containing the STAR ' - 'index used for alignment.') - a.add_argument('--barcode-files', nargs='*', metavar='BF', default=[], - help='Either (a) an s3 link to a folder containing only barcode ' - 'files, or (b) the full file path of each file on the local ' - 'machine.') + a = p.add_argument_group("required arguments") + a.add_argument( + "-o", + "--output-prefix", + metavar="O", + required=True, + help="filename prefix for all seqc output. Should not be a directory.", + ) + a.add_argument( + "-i", + "--index", + metavar="I", + required=True, + help="Local folder or s3 link to a directory containing the STAR " + "index used for alignment.", + ) + a.add_argument( + "--barcode-files", + nargs="*", + metavar="BF", + default=[], + help="Either (a) an s3 link to a folder containing only barcode " + "files, or (b) the full file path of each file on the local " + "machine.", + ) - i = p.add_argument_group('input arguments') - i.add_argument('-g', '--genomic-fastq', nargs='*', metavar='G', default=[], - help='List of fastq file(s) containing genomic information, or an s3 ' - 'link to a directory containing only genomic fastq file(s).') - i.add_argument('-b', '--barcode-fastq', nargs='*', metavar='B', default=[], - help='List of fastq file(s) containing barcode information, or an s3 ' - 'link to a directory containing only barcode fastq file(s).') - i.add_argument('-m', '--merged-fastq', nargs='?', metavar='M', default='', - help='Filename or s3 link to a fastq file containing genomic ' - 'information annotated with barcode data.') - i.add_argument('-a', '--alignment-file', nargs='?', metavar='A', default='', - help='Filename or s3 link to a .sam or .bam file containing aligned, ' - 'merged sequence records.') - i.add_argument('-r', '--read-array', nargs='?', metavar='RA', default='', - help='Filename or s3 link to a ReadArray (.h5) archive containing ' - 'processed sam records.') - i.add_argument('--basespace', metavar='BS', - help='BaseSpace sample ID. The string of numbers indicating the id ' - 'of the BaseSpace sample. (e.g. if the link to the sample is ' - 'https://basespace.illumina.com/sample/34000253/0309, ' - 'then --basespace would be 34000253.') - i.add_argument('--basespace-token', metavar='BST', default=None, - help='OAuth token for basespace access. 
Required if BaseSpace input ' - 'is used.') + i = p.add_argument_group("input arguments") + i.add_argument( + "-g", + "--genomic-fastq", + nargs="*", + metavar="G", + default=[], + help="List of fastq file(s) containing genomic information, or an s3 " + "link to a directory containing only genomic fastq file(s).", + ) + i.add_argument( + "-b", + "--barcode-fastq", + nargs="*", + metavar="B", + default=[], + help="List of fastq file(s) containing barcode information, or an s3 " + "link to a directory containing only barcode fastq file(s).", + ) + i.add_argument( + "-m", + "--merged-fastq", + nargs="?", + metavar="M", + default="", + help="Filename or s3 link to a fastq file containing genomic " + "information annotated with barcode data.", + ) + i.add_argument( + "-a", + "--alignment-file", + nargs="?", + metavar="A", + default="", + help="Filename or s3 link to a .sam or .bam file containing aligned, " + "merged sequence records.", + ) + i.add_argument( + "-r", + "--read-array", + nargs="?", + metavar="RA", + default="", + help="Filename or s3 link to a ReadArray (.h5) archive containing " + "processed sam records.", + ) + i.add_argument( + "--basespace", + metavar="BS", + help="BaseSpace sample ID. The string of numbers indicating the id " + "of the BaseSpace sample. (e.g. if the link to the sample is " + "https://basespace.illumina.com/sample/34000253/0309, " + "then --basespace would be 34000253.", + ) + i.add_argument( + "--basespace-token", + metavar="BST", + default=None, + help="OAuth token for basespace access. Required if BaseSpace input " + "is used.", + ) - f = p.add_argument_group('filter arguments') - f.add_argument('--max-insert-size', metavar='F', type=int, - help='the maximum fragment size in bp. Aligments that are further ' - 'than this distance from a TTS are discarded. Default=1000', - default=1000) - f.add_argument('--min-poly-t', metavar='T', - help='minimum size of poly-T tail that is required for a barcode to ' - 'be considered a valid record (default=None, automatically ' - 'estimates the parameter from the sequence length)', - default=None, type=int) + f = p.add_argument_group("filter arguments") + f.add_argument( + "--max-insert-size", + metavar="F", + type=int, + help="the maximum fragment size in bp. Aligments that are further " + "than this distance from a TTS are discarded. Default=1000", + default=1000, + ) + f.add_argument( + "--min-poly-t", + metavar="T", + help="minimum size of poly-T tail that is required for a barcode to " + "be considered a valid record (default=None, automatically " + "estimates the parameter from the sequence length)", + default=None, + type=int, + ) # f.add_argument('--max-dust-score', metavar='D', default=10, type=int, # help='maximum complexity score for a read to be considered valid. ' # '(default=10, higher scores indicate lower complexity.)') - f.add_argument('--singleton-weight', metavar='SW', - help='Weight to apply to singletons in the count matrix. Float ' - 'between 0 and 1, default=1 (all molecules get full weight)', - default=1.0, type=float) + f.add_argument( + "--singleton-weight", + metavar="SW", + help="Weight to apply to singletons in the count matrix. 
Float " + "between 0 and 1, default=1 (all molecules get full weight)", + default=1.0, + type=float, + ) f.set_defaults(filter_mitochondrial_rna=True) - f.add_argument('--no-filter-mitochondrial-rna', action='store_false', - dest='filter_mitochondrial_rna', - help='Do not filter cells with greater than 20 percent mitochondrial ' - 'RNA ') + f.add_argument( + "--no-filter-mitochondrial-rna", + action="store_false", + dest="filter_mitochondrial_rna", + help="Do not filter cells with greater than 20 percent mitochondrial " "RNA ", + ) f.set_defaults(filter_low_coverage=True) - f.add_argument('--no-filter-low-coverage', action='store_false', - dest='filter_low_coverage', - help='Do not filter cells with low coverage') + f.add_argument( + "--no-filter-low-coverage", + action="store_false", + dest="filter_low_coverage", + help="Do not filter cells with low coverage", + ) f.set_defaults(filter_low_gene_abundance=True) - f.add_argument('--no-filter-low-gene-abundance', action='store_false', - dest='filter_low_gene_abundance', - help='Do not filter cells with low coverage') - f.add_argument('--low-coverage-alpha', metavar='LA', - help='FDR rate for low coverage reads filter in mars-seq datasets. ' - 'Float between 0 and 1, default=0.25', - default=0.25, type=float) + f.add_argument( + "--no-filter-low-gene-abundance", + action="store_false", + dest="filter_low_gene_abundance", + help="Do not filter cells with low coverage", + ) + f.add_argument( + "--low-coverage-alpha", + metavar="LA", + help="FDR rate for low coverage reads filter in mars-seq datasets. " + "Float between 0 and 1, default=0.25", + default=0.25, + type=float, + ) # right now, it doesn't do much except you can override the default value for `--max-insert-size` f.add_argument( - '--filter-mode', dest='filter_mode', type=str, default="scRNA-seq", - help='Either "scRNA-seq" or "snRNA-seq"' + "--filter-mode", + dest="filter_mode", + type=str, + default="scRNA-seq", + help='Either "scRNA-seq" or "snRNA-seq"', ) - s = p.add_argument_group('alignment arguments') - s.add_argument('--star-args', default=None, nargs='*', - help='additional arguments that should be passed to the STAR ' - 'aligner. For example, to set the maximum allowable times for a ' - 'read to align to 20, one would set ' - '--star-args outFilterMultimapNmax=20. Additional arguments can ' - 'be provided as a white-space separated list.') + s = p.add_argument_group("alignment arguments") + s.add_argument( + "--star-args", + default=None, + nargs="*", + help="additional arguments that should be passed to the STAR " + "aligner. For example, to set the maximum allowable times for a " + "read to align to 20, one would set " + "--star-args outFilterMultimapNmax=20. 
Additional arguments can " + "be provided as a white-space separated list.", + ) # PROGRESS PARSER - progress = subparsers.add_parser('progress', help='check SEQC run progress') + progress = subparsers.add_parser("progress", help="check SEQC run progress") progress.set_defaults(remote=False) progress.add_argument( - '-i', '--instance-ids', help='check the progress of run(s)', nargs='+') + "-i", "--instance-ids", help="check the progress of run(s)", nargs="+" + ) progress.add_argument( - '-k', '--rsa-key', help='RSA key registered to your aws account', - default=None) + "-k", "--rsa-key", help="RSA key registered to your aws account", default=None + ) # TERMINATE PARSER - terminate = subparsers.add_parser('terminate', help='terminate SEQC runs') + terminate = subparsers.add_parser("terminate", help="terminate SEQC runs") terminate.set_defaults(remote=False) terminate.add_argument( - '-i', '--instance-ids', help='terminate these instance(s)', nargs='+') + "-i", "--instance-ids", help="terminate these instance(s)", nargs="+" + ) # INSTANCES PARSER - instances = subparsers.add_parser('instances', help='list all running instances') + instances = subparsers.add_parser("instances", help="list all running instances") instances.set_defaults(remote=False) instances.add_argument( - '-k', '--rsa-key', help='RSA key registered to your aws account', - default=None) + "-k", "--rsa-key", help="RSA key registered to your aws account", default=None + ) # START PARSER - start = subparsers.add_parser( - 'start', help='initialize a seqc-ready instance') + start = subparsers.add_parser("start", help="initialize a seqc-ready instance") start.set_defaults(remote=False) start.add_argument( - '-s', '--volume-size', help='size of volume (Gb) to attach to instance', - default=5, type=int) + "-s", + "--volume-size", + help="size of volume (Gb) to attach to instance", + default=5, + type=int, + ) start.add_argument( - '-b', '--spot-bid', help='amount to bid for instance in fractions of dollars', - type=float, default=None) + "-b", + "--spot-bid", + help="amount to bid for instance in fractions of dollars", + type=float, + default=None, + ) start.add_argument( - '-t', '--instance-type', default='r5.2xlarge', - help='AWS instance type to initialize. ' - 'See https://aws.amazon.com/ec2/instance-types/ for valid types') + "-t", + "--instance-type", + default="r5.2xlarge", + help="AWS instance type to initialize. 
" + "See https://aws.amazon.com/ec2/instance-types/ for valid types", + ) start.add_argument( - '-k', '--rsa-key', help='RSA key registered to your aws account', - default=None) + "-k", "--rsa-key", help="RSA key registered to your aws account", default=None + ) start.add_argument( - '--ami-id', dest='ami_id', required=False, - help='ID of the SEQC AMI to use' + "--ami-id", dest="ami_id", required=False, help="ID of the SEQC AMI to use" ) # NOTEBOOK PARSERS - notebook_sp = subparsers.add_parser('notebook', help='notebook tools') - _nb_parser = notebook_sp.add_subparsers(dest='subsubparser_name') + notebook_sp = subparsers.add_parser("notebook", help="notebook tools") + _nb_parser = notebook_sp.add_subparsers(dest="subsubparser_name") # NOTEBOOK MERGE PARSER merge = _nb_parser.add_parser( - 'merge', help='merge multiple datasets prior to running an analysis notebook') + "merge", help="merge multiple datasets prior to running an analysis notebook" + ) merge.add_argument( - '-o', '--output-filename', help='name for merged fastq file', required=True) + "-o", "--output-filename", help="name for merged fastq file", required=True + ) merge.add_argument( - '-i', '--input-data', nargs='+', help='count matrices to merge', required=True) + "-i", "--input-data", nargs="+", help="count matrices to merge", required=True + ) # NOTEBOOK GENERATE PARSER - generate = _nb_parser.add_parser('generate', help='generate a notebook from a dataset') + generate = _nb_parser.add_parser( + "generate", help="generate a notebook from a dataset" + ) generate.add_argument( - '-i', '--input-count-matrix', help='count matrix file', required=True) + "-i", "--input-count-matrix", help="count matrix file", required=True + ) generate.add_argument( - '-o', '--output-stem', help='directory and filestem for output', required=True) + "-o", "--output-stem", help="directory and filestem for output", required=True + ) - pindex = subparsers.add_parser('index', help='create a SEQC index') + pindex = subparsers.add_parser("index", help="create a SEQC index") pindex.add_argument( - '-o', '--organism', required=True, - help='organism to create index for. Must be formatted as genus_species in all ' - 'lower-case. e.g. human is homo_sapiens.') + "-o", + "--organism", + required=True, + help="organism to create index for. Must be formatted as genus_species in all " + "lower-case. e.g. human is homo_sapiens.", + ) pindex.add_argument( - '-f', '--folder', default=None, - help='folder in which to create the index. Defaults to the name of the organism, ' - 'which is created in the current directory.') + "-f", + "--folder", + default=None, + help="folder in which to create the index. Defaults to the name of the organism, " + "which is created in the current directory.", + ) pindex.add_argument( - '--ids', '--additional-id-types', nargs='*', - help='names of additional ids from other consortia to check against. If ' - 'provided, each ENSEMBL gene id must also be annotated by at least one of ' - 'these consortia to be considered valid and appear in the final SEQC count ' - 'matrix.') + "--ids", + "--additional-id-types", + nargs="*", + help="names of additional ids from other consortia to check against. If " + "provided, each ENSEMBL gene id must also be annotated by at least one of " + "these consortia to be considered valid and appear in the final SEQC count " + "matrix.", + ) pindex.add_argument( - '-b', '--valid-biotypes', default=('protein_coding', 'lincRNA'), - help='list of gene biotypes that are considered valid. 
Defaults are '
-             'protein_coding and lincRNA. In most cases, other biotypes are not expected '
-             'to be captured by SEQC, and should be excluded')
+        "-b",
+        "--valid-biotypes",
+        nargs="+",
+        default=[
+            "protein_coding",
+            "lincRNA",
+            "antisense",
+            "IG_V_gene",
+            "IG_D_gene",
+            "IG_J_gene",
+            "IG_C_gene",
+            "TR_V_gene",
+            "TR_D_gene",
+            "TR_J_gene",
+            "TR_C_gene",
+        ],
+        help="list of gene biotypes that are considered valid. Defaults are "
+        "protein_coding, lincRNA, antisense, and the IG/TR receptor gene biotypes. "
+        "In most cases, other biotypes are not expected to be captured by SEQC, "
+        "and should be excluded",
+    )
+    pindex.add_argument(
+        "--ensemble-release",
+        default=None,
+        type=int,
+        help="Ensembl release number to be used (e.g. 85)",
+    )
+    pindex.add_argument(
+        "--read-length",
+        default=100,
+        type=int,
+        help="length of reads that will be aligned against this index (will be used for STAR --sjdbOverhang)",
+    )
 
     for parser in [pindex, p]:
-        r = parser.add_argument_group('Amazon Web Services arguments')
+        r = parser.add_argument_group("Amazon Web Services arguments")
         r.set_defaults(remote=True)
         r.set_defaults(terminate=True)
-        r.set_defaults(log_name='seqc_log.txt')  # changed to .txt for email
+        r.set_defaults(log_name="seqc_log.txt")  # changed to .txt for email
+        r.add_argument(
+            "--local",
+            dest="remote",
+            action="store_false",
+            help="Run locally instead of on an aws instance",
+        )
         r.add_argument(
-            '--local', dest="remote", action="store_false",
-            help='Run locally instead of on an aws instance')
+            "-u",
+            "--upload-prefix",
+            metavar="U",
+            default=None,
+            help="s3 location for data to be uploaded.",
+        )
         r.add_argument(
-            '-u', '--upload-prefix', metavar='U', default=None,
-            help='s3 location for data to be uploaded.')
+            "--instance-type",
+            default="r5.2xlarge",
+            help="AWS instance type to initialize for this job. "
+            "See https://aws.amazon.com/ec2/instance-types/ for valid types",
+        )
         r.add_argument(
-            '--instance-type', default='r5.2xlarge',
-            help='AWS instance type to initialize for this job. '
-                 'See https://aws.amazon.com/ec2/instance-types/ for valid types')
+            "--spot-bid",
+            type=float,
+            default=None,
+            help="float, Amount to bid for a spot instance. Default=None (will reserve a "
+            "non-spot instance). WARNING: using spot instances will cause your "
+            "instance to terminate if instance prices exceed your spot bid during "
+            "runtime.",
+        )
         r.add_argument(
-            '--spot-bid', type=float, default=None,
-            help='float, Amount to bid for a spot instance. Default=None (will reserve a '
-                 'non-spot instance). WARNING: using spot instances will cause your '
-                 'instance to terminate if instance prices exceed your spot bid during '
-                 'runtime.')
+            "--volume-size",
+            type=int,
+            default=None,
+            help="size in Gb required to execute the requested process. If not provided, "
+            "it will be estimated from passed parameters.",
+        )
         r.add_argument(
-            '--volume-size', type=int, default=None,
-            help='size in Gb required to execute the requested process. If not provided, '
-                 'it will be estimated from passed parameters.')
+            "-e",
+            "--email",
+            metavar="E",
+            default=None,
+            help="Email address to receive run summary or errors when running remotely. "
+            "Optional only if running locally.",
+        )
         r.add_argument(
-            '-e', '--email', metavar='E', default=None,
-            help='Email address to receive run summary or errors when running remotely. 
' - 'Optional only if running locally.') - r.add_argument('--debug', default=False, action='store_true', - help='If debug is set, runs that throw errors do not ' - 'terminate the instance they were run on.') + "--debug", + default=False, + action="store_true", + help="If debug is set, runs that throw errors do not " + "terminate the instance they were run on.", + ) terminate_parser = r.add_mutually_exclusive_group(required=False) terminate_parser.add_argument( - '--terminate', dest='terminate', action='store_true', - help='Terminate the ec2 instance upon completion.' + "--terminate", + dest="terminate", + action="store_true", + help="Terminate the ec2 instance upon completion.", ) terminate_parser.add_argument( - '--no-terminate', dest='terminate', action='store_false', - help='Do not terminate the ec2 instance upon completion.' + "--no-terminate", + dest="terminate", + action="store_false", + help="Do not terminate the ec2 instance upon completion.", ) terminate_parser.set_defaults(terminate=True) r.add_argument( - '--ami-id', dest='ami_id', required=False, - help='ID of the SEQC AMI to use' + "--ami-id", dest="ami_id", required=False, help="ID of the SEQC AMI to use" ) r.add_argument( - '--user-tags', dest='user_tags', required=False, - help='comma-separated key-value pairs for tagging ec2 instance (e.g. k1:v1,k2:v2).' + "--user-tags", + dest="user_tags", + required=False, + help="comma-separated key-value pairs for tagging ec2 instance (e.g. k1:v1,k2:v2).", ) r.add_argument( - '--remote-update', dest='remote_update', action='store_true', default=False, - help='whether to use the local SEQC installation package to update the remote instance' + "--remote-update", + dest="remote_update", + action="store_true", + default=False, + help="whether to use the local SEQC installation package to update the remote instance", ) r.add_argument( - '-k', '--rsa-key', metavar='K', default=None, - help='RSA key registered to your aws account that allowed access to ec2 ' - 'resources. Required if running instance remotely.') + "-k", + "--rsa-key", + metavar="K", + default=None, + help="RSA key registered to your aws account that allowed access to ec2 " + "resources. 
Required if running instance remotely.", + ) # custom help handling if len(args) == 0: # print help if no args are passed meta.print_help() sys.exit(1) - if args == ['run', '-h']: # send help for run to less, is too long - pipe = Popen(['less'], stdin=PIPE) + if args == ["run", "-h"]: # send help for run to less, is too long + pipe = Popen(["less"], stdin=PIPE) pipe.communicate(p.format_help().encode()) sys.exit(1) parsed = meta.parse_args(args) - if hasattr(parsed, 'rsa_key'): + if hasattr(parsed, "rsa_key"): if parsed.rsa_key is None: try: - parsed.rsa_key = os.environ['AWS_RSA_KEY'] + parsed.rsa_key = os.environ["AWS_RSA_KEY"] except KeyError: pass diff --git a/src/seqc/core/run.py b/src/seqc/core/run.py index d07d9c0..0a1afee 100644 --- a/src/seqc/core/run.py +++ b/src/seqc/core/run.py @@ -1,5 +1,3 @@ - - def run(args) -> None: """Run SEQC on the files provided in args, given specifications provided on the command line @@ -13,9 +11,10 @@ def run(args) -> None: import os import multiprocessing - from seqc import log, ec2, platforms, io + from seqc import log, ec2, platforms, io, version from seqc.sequence import fastq from seqc.alignment import star + from seqc.alignment import sam from seqc.email_ import email_user from seqc.read_array import ReadArray from seqc.core import verify, download @@ -25,13 +24,17 @@ def run(args) -> None: import numpy as np import scipy.io from shutil import copyfile + from shutil import move as movefile from seqc.summary.summary import MiniSummary from seqc.stats.mast import run_mast import logging - logger = logging.getLogger('weasyprint') - logger.handlers = [] # Remove the default stderr handler - logger.setLevel(100) - logger.addHandler(logging.FileHandler('weasyprint.log')) + import pickle + import pendulum + + # logger = logging.getLogger('weasyprint') + # logger.handlers = [] # Remove the default stderr handler + # logger.setLevel(100) + # logger.addHandler(logging.FileHandler('weasyprint.log')) def determine_start_point(arguments) -> (bool, bool, bool): """ @@ -61,45 +64,67 @@ def download_input(dir_, arguments): # download basespace data if necessary if arguments.basespace: arguments.barcode_fastq, arguments.genomic_fastq = io.BaseSpace.download( - arguments.platform, arguments.basespace, dir_, arguments.basespace_token) + arguments.platform, arguments.basespace, dir_, arguments.basespace_token + ) - # check for remote fastq file links + # get a list of input FASTQ files + # download from AWS S3 if the URI is prefixed with s3:// arguments.genomic_fastq = download.s3_data( - arguments.genomic_fastq, dir_ + '/genomic_fastq/') + arguments.genomic_fastq, dir_ + "/genomic_fastq/" + ) arguments.barcode_fastq = download.s3_data( - arguments.barcode_fastq, dir_ + '/barcode_fastq/') + arguments.barcode_fastq, dir_ + "/barcode_fastq/" + ) # get merged fastq file, unzip if necessary arguments.merged_fastq = ( - download.s3_data([arguments.merged_fastq], dir_ + '/')[0] if - arguments.merged_fastq is not None else None) + download.s3_data([arguments.merged_fastq], dir_ + "/")[0] + if arguments.merged_fastq is not None + else None + ) - # check if the index must be downloaded + # get a path to the STAR index files + # download from AWS S3 if the URI is prefixed with s3:// if any((arguments.alignment_file, arguments.read_array)): - index_link = arguments.index + 'annotations.gtf' + index_link = arguments.index + "annotations.gtf" else: index_link = arguments.index - download.s3_data([index_link], dir_ + '/index/') - arguments.index = dir_ + '/index/' - - # check 
if barcode files must be downloaded
+        index_files = download.s3_data([index_link], dir_ + "/index/")
+        # use the first filename in the list to get the index directory
+        # add a trailing slash to make the rest of the code not break
+        # e.g. test-data/index/chrStart.txt --> test-data/index/
+        arguments.index = os.path.dirname(index_files[0]) + "/"
+
+        # get a list of whitelisted barcode files
+        # download from AWS S3 if the URI is prefixed with s3://
         arguments.barcode_files = download.s3_data(
-            arguments.barcode_files, dir_ + '/barcodes/')
+            arguments.barcode_files, dir_ + "/barcodes/"
+        )
 
-        # check if alignment_file needs downloading
+        # check if `alignment_file` is specified
         if arguments.alignment_file:
+            # get the alignment filename (*.bam)
+            # download from AWS S3 if the URI is prefixed with s3://
             arguments.alignment_file = download.s3_data(
-                [arguments.alignment_file], dir_ + '/')[0]
+                [arguments.alignment_file], dir_ + "/"
+            )[0]
 
-        # check if readarray needs downloading
+        # check if `read_array` is specified
         if arguments.read_array:
-            arguments.read_array = download.s3_data([arguments.read_array], dir_ + '/')[0]
+            # get the readarray filename (*.h5)
+            # download from AWS S3 if the URI is prefixed with s3://
+            arguments.read_array = download.s3_data([arguments.read_array], dir_ + "/")[
+                0
+            ]
         return arguments
 
     def merge_fastq_files(
-            technology_platform, barcode_fastq: [str], output_stem: str,
-            genomic_fastq: [str]) -> (str, int):
+        technology_platform,
+        barcode_fastq: [str],
+        output_stem: str,
+        genomic_fastq: [str],
+    ) -> (str, int):
         """annotates genomic fastq with barcode information; merging the two files.
 
         :param technology_platform: class from platforms.py that defines the
@@ -112,23 +137,24 @@ def merge_fastq_files(
         :returns str merged_fastq: name of merged fastq file
         """
-        log.info('Merging genomic reads and barcode annotations.')
+        log.info("Merging genomic reads and barcode annotations.")
         merged_fastq = fastq.merge_paired(
             merge_function=technology_platform.merge_function,
-            fout=output_stem + '_merged.fastq',
+            fout=output_stem + "_merged.fastq",
             genomic=genomic_fastq,
-            barcode=barcode_fastq)
+            barcode=barcode_fastq,
+        )
 
         # delete genomic/barcode fastq files after merged.fastq creation
-        log.info('Removing original fastq file for memory management.')
-        delete_fastq = ' '.join(['rm'] + genomic_fastq + barcode_fastq)
-        io.ProcessManager(delete_fastq).run_all()
+        # log.info('Removing original fastq file for memory management.')
+        # delete_fastq = ' '.join(['rm'] + genomic_fastq + barcode_fastq)
+        # io.ProcessManager(delete_fastq).run_all()
 
         return merged_fastq
 
     def align_fastq_records(
-            merged_fastq, dir_, star_args, star_index, n_proc,
-            aws_upload_key) -> (str, str, io.ProcessManager):
+        merged_fastq, dir_, star_args, star_index, n_proc, aws_upload_key
+    ) -> (str, str, io.ProcessManager):
         """
         Align fastq records.
@@ -143,43 +169,45 @@ def align_fastq_records( name of .sam file containing aligned reads, indicator of which data was used as input, and a ProcessManager for merged fastq files """ - log.info('Aligning merged fastq records.') - alignment_directory = dir_ + '/alignments/' + log.info("Aligning merged fastq records.") + alignment_directory = dir_ + "/alignments/" os.makedirs(alignment_directory, exist_ok=True) if star_args is not None: - star_kwargs = dict(a.strip().split('=') for a in star_args) + star_kwargs = dict(a.strip().split("=") for a in star_args) else: star_kwargs = {} bamfile = star.align( - merged_fastq, star_index, n_proc, alignment_directory, - **star_kwargs) + merged_fastq, star_index, n_proc, alignment_directory, **star_kwargs + ) + + log.info("Gzipping merged fastq file.") + if pigz: + pigz_zip = "pigz --best -f {fname}".format(fname=merged_fastq) + else: + pigz_zip = "gzip -f {fname}".format(fname=merged_fastq) + pigz_proc = io.ProcessManager(pigz_zip) + pigz_proc.run_all() + pigz_proc.wait_until_complete() # prevents slowing down STAR alignment + merged_fastq += ".gz" # reflect gzipped nature of file if aws_upload_key: - log.info('Gzipping merged fastq file.') - if pigz: - pigz_zip = "pigz --best -k -f {fname}".format(fname=merged_fastq) - else: - pigz_zip = "gzip -kf {fname}".format(fname=merged_fastq) - pigz_proc = io.ProcessManager(pigz_zip) - pigz_proc.run_all() - pigz_proc.wait_until_complete() # prevents slowing down STAR alignment - merged_fastq += '.gz' # reflect gzipped nature of file - - log.info('Uploading gzipped merged fastq file to S3.') - merge_upload = 'aws s3 mv {fname} {s3link}'.format( - fname=merged_fastq, s3link=aws_upload_key) + log.info("Uploading gzipped merged fastq file to S3.") + merge_upload = "aws s3 mv {fname} {s3link}".format( + fname=merged_fastq, s3link=aws_upload_key + ) upload_manager = io.ProcessManager(merge_upload) upload_manager.run_all() else: - log.info('Removing merged fastq file for memory management.') - rm_merged = 'rm %s' % merged_fastq - io.ProcessManager(rm_merged).run_all() + # log.info('Removing merged fastq file for memory management.') + # rm_merged = 'rm %s' % merged_fastq + # io.ProcessManager(rm_merged).run_all() upload_manager = None return bamfile, upload_manager - def create_read_array(bamfile, index, aws_upload_key, min_poly_t, - max_transcript_length): + def create_read_array( + bamfile, index, aws_upload_key, min_poly_t, max_transcript_length + ): """Create or download a ReadArray object. 
:param max_transcript_length: @@ -189,55 +217,85 @@ def create_read_array(bamfile, index, aws_upload_key, min_poly_t, :param int min_poly_t: minimum number of poly_t nucleotides for a read to be valid :returns ReadArray, UploadManager: ReadArray object, bamfile ProcessManager """ - log.info('Filtering aligned records and constructing record database.') + log.info("Filtering aligned records and constructing record database.") # Construct translator translator = GeneIntervals( - index + 'annotations.gtf', max_transcript_length=max_transcript_length) - read_array = ReadArray.from_alignment_file( - bamfile, translator, min_poly_t) + index + "annotations.gtf", max_transcript_length=max_transcript_length + ) + read_array, read_names = ReadArray.from_alignment_file( + bamfile, translator, min_poly_t + ) # converting sam to bam and uploading to S3, else removing bamfile if aws_upload_key: - log.info('Uploading bam file to S3.') - upload_bam = 'aws s3 mv {fname} {s3link}{prefix}_Aligned.out.bam'.format( - fname=bamfile, s3link=aws_upload_key, prefix=args.output_prefix) + log.info("Uploading bam file to S3.") + upload_bam = "aws s3 mv {fname} {s3link}{prefix}_Aligned.out.bam".format( + fname=bamfile, s3link=aws_upload_key, prefix=args.output_prefix + ) print(upload_bam) upload_manager = io.ProcessManager(upload_bam) upload_manager.run_all() else: - log.info('Removing bamfile for memory management.') - rm_bamfile = 'rm %s' % bamfile - io.ProcessManager(rm_bamfile).run_all() + if os.path.exists(bamfile): + movefile(bamfile, args.output_prefix + "_Aligned.out.bam") + # log.info('Removing bamfile for memory management.') + # rm_bamfile = 'rm %s' % bamfile + # io.ProcessManager(rm_bamfile).run_all() upload_manager = None - return read_array, upload_manager + return read_array, upload_manager, read_names # ######################## MAIN FUNCTION BEGINS HERE ################################ - log.setup_logger(args.log_name) + log.setup_logger(args.log_name, args.debug) with ec2.instance_clean_up( - email=args.email, upload=args.upload_prefix, log_name=args.log_name, - debug=args.debug, terminate=args.terminate + email=args.email, + upload=args.upload_prefix, + log_name=args.log_name, + debug=args.debug, + terminate=args.terminate, + running_remote=args.remote, ): - pigz, mutt = verify.executables('pigz', 'mutt') + + start_run_time = pendulum.now() + + log.notify("SEQC=v{}".format(version.__version__)) + log.notify("STAR=v{}".format(star.get_version())) + log.notify("samtools=v{}".format(sam.get_version())) + + pigz, mutt = verify.executables("pigz", "mutt") if mutt: - log.notify('mutt executable identified, email will be sent when run ' - 'terminates. ') + log.notify( + "mutt executable identified, email will be sent when run " + "terminates. " + ) else: - log.notify('mutt was not found on this machine; an email will not be sent to ' - 'the user upon termination of SEQC run.') + log.notify( + "mutt was not found on this machine; an email will not be sent to " + "the user upon termination of SEQC run." 
+ ) # turn off lower coverage filter for 10x - if (args.platform == "ten_x") or (args.platform == "ten_x_v2") or (args.platform == "ten_x_v3"): + if ( + (args.platform == "ten_x") + or (args.platform == "ten_x_v2") + or (args.platform == "ten_x_v3") + ): args.filter_low_coverage = False max_insert_size = args.max_insert_size if args.filter_mode == "scRNA-seq": # for scRNA-seq - if (args.platform == "ten_x") or (args.platform == "ten_x_v2") or (args.platform == "ten_x_v3"): + if ( + (args.platform == "ten_x") + or (args.platform == "ten_x_v2") + or (args.platform == "ten_x_v3") + ): # set max_transcript_length (max_insert_size) = 10000 max_insert_size = 10000 - log.notify("Full length transcripts are used for read mapping in 10x data.") + log.notify( + "Full length transcripts are used for read mapping in 10x data." + ) elif args.filter_mode == "snRNA-seq": # for snRNA-seq # e.g. 2304700 # hg38 @@ -251,9 +309,15 @@ def create_read_array(bamfile, index, aws_upload_key, min_poly_t, log.args(args) + # e.g. + # --output-prefix=test-data/_outs/test + # output_dir=test-data + # output_prefix=test output_dir, output_prefix = os.path.split(args.output_prefix) if not output_dir: - output_dir = '.' + output_dir = "." + else: + os.makedirs(output_dir, exist_ok=True) # check if the platform name provided is supported by seqc # todo move into verify for run @@ -273,149 +337,238 @@ def create_read_array(bamfile, index, aws_upload_key, min_poly_t, if merge: if args.min_poly_t is None: # estimate min_poly_t if it was not provided args.min_poly_t = filter.estimate_min_poly_t( - args.barcode_fastq, platform) - log.notify('Estimated min_poly_t={!s}'.format(args.min_poly_t)) + args.barcode_fastq, platform + ) + log.notify("Estimated min_poly_t={!s}".format(args.min_poly_t)) args.merged_fastq = merge_fastq_files( - platform, args.barcode_fastq, args.output_prefix, args.genomic_fastq) + platform, args.barcode_fastq, args.output_prefix, args.genomic_fastq + ) # SEQC was started from input other than fastq files if args.min_poly_t is None: args.min_poly_t = 0 - log.notify('Warning: SEQC started from step other than unmerged fastq with ' - 'empty --min-poly-t parameter. Continuing with --min-poly-t 0.') + log.warn( + "Warning: SEQC started from step other than unmerged fastq with " + "empty --min-poly-t parameter. Continuing with --min-poly-t 0." + ) if align: upload_merged = args.upload_prefix if merge else None args.alignment_file, manage_merged = align_fastq_records( - args.merged_fastq, output_dir, args.star_args, - args.index, n_processes, upload_merged) + args.merged_fastq, + output_dir, + args.star_args, + args.index, + n_processes, + upload_merged, + ) else: manage_merged = None if process_bamfile: + # if the starting point was a BAM file (i.e. 
args.alignment_file=*.bam & align=False) + # do not upload by setting this to None upload_bamfile = args.upload_prefix if align else None - ra, manage_bamfile, = create_read_array( - args.alignment_file, args.index, upload_bamfile, args.min_poly_t, - max_insert_size) - + ra, manage_bamfile, read_names = create_read_array( + args.alignment_file, + args.index, + upload_bamfile, + args.min_poly_t, + max_insert_size, + ) else: manage_bamfile = None ra = ReadArray.load(args.read_array) + # fixme: the old read_array doesn't have read_names + read_names = None # create the first summary section here - status_filters_section = Section.from_status_filters(ra, 'initial_filtering.html') + status_filters_section = Section.from_status_filters( + ra, "initial_filtering.html" + ) sections = [status_filters_section] # Skip over the corrections if read array is specified by the user if not args.read_array: # Correct barcodes - log.info('Correcting barcodes and estimating error rates.') - error_rate = platform.apply_barcode_correction(ra, args.barcode_files) + log.info("Correcting barcodes and estimating error rates.") + error_rate, df_cb_correction = platform.apply_barcode_correction( + ra, args.barcode_files + ) + if df_cb_correction is not None and len(df_cb_correction) > 0: + df_cb_correction.to_csv( + args.output_prefix + "_cb-correction.csv.gz", + index=False, + compression="gzip", + ) # Resolve multimapping - log.info('Resolving ambiguous alignments.') + log.info("Resolving ambiguous alignments.") mm_results = ra.resolve_ambiguous_alignments() + + + # 121319782799149 / 614086965 / pos=49492038 / AAACATAACG + # 121319782799149 / 512866590 / pos=49490848 / TCAATTAATC (1 hemming dist away from TCAATTAATT) + # ra.data["rmt"][91490] = 512866590 + # ra.positions[91490] = 49492038 + + + # correct errors - log.info('Identifying RMT errors.') - platform.apply_rmt_correction(ra, error_rate) + log.info("Identifying RMT errors.") + df_umi_correction = platform.apply_rmt_correction(ra, error_rate) + if df_umi_correction is not None and len(df_umi_correction) > 0: + df_umi_correction.to_csv( + args.output_prefix + "_umi-correction.csv.gz", + index=False, + compression="gzip", + ) # Apply low coverage filter if platform.filter_lonely_triplets: - log.info('Filtering lonely triplet reads') + log.info("Filtering lonely triplet reads") ra.filter_low_coverage(alpha=args.low_coverage_alpha) - log.info('Saving read array.') - ra.save(args.output_prefix + '.h5') + log.info("Saving read array.") + ra.save(args.output_prefix + ".h5") + + # generate a file with read_name, corrected cb, corrected umi + # read_name already has pre-corrected cb & umi + # log.info("Saving correction information.") + # ra.create_readname_cb_umi_mapping( + # read_names, args.output_prefix + "_correction.csv.gz" + # ) # Summary sections # create the sections for the summary object sections += [ - Section.from_cell_barcode_correction(ra, 'cell_barcode_correction.html'), - Section.from_rmt_correction(ra, 'rmt_correction.html'), - Section.from_resolve_multiple_alignments(mm_results, 'multialignment.html')] + Section.from_cell_barcode_correction( + ra, "cell_barcode_correction.html" + ), + Section.from_rmt_correction(ra, "rmt_correction.html"), + Section.from_resolve_multiple_alignments( + mm_results, "multialignment.html" + ), + ] # create a dictionary to store output parameters mini_summary_d = dict() # filter non-cells - log.info('Creating counts matrix.') + log.info("Creating counts matrix.") sp_reads, sp_mols = ra.to_count_matrix( - 
sparse_frame=True, genes_to_symbols=args.index + 'annotations.gtf') + sparse_frame=True, genes_to_symbols=args.index + "annotations.gtf" + ) # Save sparse matrices - log.info('Saving sparse matrices') - scipy.io.mmwrite(args.output_prefix + '_sparse_read_counts.mtx', sp_reads.data) - scipy.io.mmwrite(args.output_prefix + '_sparse_molecule_counts.mtx', sp_mols.data) + log.info("Saving sparse matrices") + scipy.io.mmwrite(args.output_prefix + "_sparse_read_counts.mtx", sp_reads.data) + scipy.io.mmwrite( + args.output_prefix + "_sparse_molecule_counts.mtx", sp_mols.data + ) # Indices df = np.array([np.arange(sp_reads.shape[0]), sp_reads.index]).T np.savetxt( - args.output_prefix + '_sparse_counts_barcodes.csv', df, - fmt='%d', delimiter=',') + args.output_prefix + "_sparse_counts_barcodes.csv", + df, + fmt="%d", + delimiter=",", + ) # Columns df = np.array([np.arange(sp_reads.shape[1]), sp_reads.columns]).T np.savetxt( - args.output_prefix + '_sparse_counts_genes.csv', df, - fmt='%s', delimiter=',') + args.output_prefix + "_sparse_counts_genes.csv", df, fmt="%s", delimiter="," + ) - log.info('Creating filtered counts matrix.') - cell_filter_figure = args.output_prefix + '_cell_filters.png' + log.info("Creating filtered counts matrix.") + cell_filter_figure = args.output_prefix + "_cell_filters.png" # By pass low count filter for mars seq - sp_csv, total_molecules, molecules_lost, cells_lost, cell_description = ( - filter.create_filtered_dense_count_matrix( - sp_mols, sp_reads, mini_summary_d, plot=True, figname=cell_filter_figure, - filter_low_count=platform.filter_low_count, - filter_mitochondrial_rna=args.filter_mitochondrial_rna, - filter_low_coverage=args.filter_low_coverage, - filter_low_gene_abundance=args.filter_low_gene_abundance)) + ( + sp_csv, + total_molecules, + molecules_lost, + cells_lost, + cell_description, + ) = filter.create_filtered_dense_count_matrix( + sp_mols, + sp_reads, + mini_summary_d, + plot=True, + figname=cell_filter_figure, + filter_low_count=platform.filter_low_count, + filter_mitochondrial_rna=args.filter_mitochondrial_rna, + filter_low_coverage=args.filter_low_coverage, + filter_low_gene_abundance=args.filter_low_gene_abundance, + ) # Output files - files = [cell_filter_figure, - args.output_prefix + '.h5', - args.output_prefix + '_sparse_read_counts.mtx', - args.output_prefix + '_sparse_molecule_counts.mtx', - args.output_prefix + '_sparse_counts_barcodes.csv', - args.output_prefix + '_sparse_counts_genes.csv'] + files = [ + cell_filter_figure, + args.output_prefix + ".h5", + args.output_prefix + "_sparse_read_counts.mtx", + args.output_prefix + "_sparse_molecule_counts.mtx", + args.output_prefix + "_sparse_counts_barcodes.csv", + args.output_prefix + "_sparse_counts_genes.csv", + ] + + if os.path.exists(args.output_prefix + "_cb-correction.csv.gz"): + files.append(args.output_prefix + "_cb-correction.csv.gz") + if os.path.exists(args.output_prefix + "_umi-correction.csv.gz"): + files.append(args.output_prefix + "_umi-correction.csv.gz") # Summary sections # create the sections for the summary object sections += [ - Section.from_cell_filtering(cell_filter_figure, 'cell_filtering.html'), - Section.from_run_time(args.log_name, 'seqc_log.html')] + Section.from_cell_filtering(cell_filter_figure, "cell_filtering.html"), + Section.from_run_time(args.log_name, "seqc_log.html"), + ] # get alignment summary - if os.path.isfile(output_dir + '/alignments/Log.final.out'): - os.rename(output_dir + '/alignments/Log.final.out', - output_dir + '/' + args.output_prefix + 
'_alignment_summary.txt') + if os.path.isfile(output_dir + "/alignments/Log.final.out"): + os.rename( + output_dir + "/alignments/Log.final.out", + args.output_prefix + "_alignment_summary.txt", + ) # Upload files and summary sections - files += [output_dir + '/' + args.output_prefix + '_alignment_summary.txt'] + files += [args.output_prefix + "_alignment_summary.txt"] sections.insert( - 0, Section.from_alignment_summary( - output_dir + '/' + args.output_prefix + '_alignment_summary.txt', - 'alignment_summary.html')) - - cell_size_figure = 'cell_size_distribution.png' + 0, + Section.from_alignment_summary( + args.output_prefix + "_alignment_summary.txt", + "alignment_summary.html", + ), + ) + + cell_size_figure = args.output_prefix + "_cell_size_distribution.png" index_section = Section.from_final_matrix( - sp_csv, cell_size_figure, 'cell_distribution.html') - seqc_summary = Summary( - output_dir + '/' + args.output_prefix + '_summary', sections, index_section) + sp_csv, cell_size_figure, "cell_distribution.html" + ) + seqc_summary = Summary(args.output_prefix + "_summary", sections, index_section) seqc_summary.prepare_archive() seqc_summary.import_image(cell_filter_figure) seqc_summary.import_image(cell_size_figure) seqc_summary.render() + + # create a .tar.gz with `test_summary/*` summary_archive = seqc_summary.compress_archive() files += [summary_archive] # Create a mini summary section - alignment_summary_file = output_dir + '/' + args.output_prefix + '_alignment_summary.txt' + alignment_summary_file = args.output_prefix + "_alignment_summary.txt" seqc_mini_summary = MiniSummary( - args.output_prefix, mini_summary_d, alignment_summary_file, cell_filter_figure, - cell_size_figure) + output_dir, + output_prefix, + mini_summary_d, + alignment_summary_file, + cell_filter_figure, + cell_size_figure, + ) seqc_mini_summary.compute_summary_fields(ra, sp_csv) seqc_mini_summary_json, seqc_mini_summary_pdf = seqc_mini_summary.render() files += [seqc_mini_summary_json, seqc_mini_summary_pdf] @@ -423,13 +576,17 @@ def create_read_array(bamfile, index, aws_upload_key, min_poly_t, # Running MAST for differential analysis # file storing the list of differentially expressed genes for each cluster de_gene_list_file = run_mast( - seqc_mini_summary.get_counts_filtered(), seqc_mini_summary.get_clustering_result(), - args.output_prefix) + seqc_mini_summary.get_counts_filtered(), + seqc_mini_summary.get_clustering_result(), + args.output_prefix, + ) files += [de_gene_list_file] # adding the cluster column and write down gene-cell count matrix - dense_csv = args.output_prefix + '_dense.csv' - sp_csv.insert(loc=0, column='CLUSTER', value=seqc_mini_summary.get_clustering_result()) + dense_csv = args.output_prefix + "_dense.csv" + sp_csv.insert( + loc=0, column="CLUSTER", value=seqc_mini_summary.get_clustering_result() + ) sp_csv.to_csv(dense_csv) files += [dense_csv] @@ -439,46 +596,66 @@ def create_read_array(bamfile, index, aws_upload_key, min_poly_t, for item in files: try: ec2.Retry(retries=5)(io.S3.upload_file)(item, bucket, key) - item_name = item.split('/')[-1] - log.info('Successfully uploaded %s to the specified S3 location ' - '"%s%s".' % (item, args.upload_prefix, item_name)) + item_name = item.split("/")[-1] + log.info( + 'Successfully uploaded %s to "%s%s".' + % (item, args.upload_prefix, item_name) + ) except FileNotFoundError: - log.notify('Item %s was not found! Continuing with upload...' % item) + log.notify( + "Item %s was not found! Continuing with upload..." 
% item + ) if manage_merged: manage_merged.wait_until_complete() - log.info('Successfully uploaded %s to the specified S3 location "%s"' % - (args.merged_fastq, args.upload_prefix)) + log.info( + 'Successfully uploaded %s to "%s"' + % (args.merged_fastq, args.upload_prefix) + ) if manage_bamfile: manage_bamfile.wait_until_complete() - log.info('Successfully uploaded %s to the specified S3 location "%s"' - % (args.alignment_file, args.upload_prefix)) + log.info( + 'Successfully uploaded %s to "%s"' + % (args.alignment_file, args.upload_prefix) + ) - log.info('SEQC run complete. Cluster will be terminated') + log.info("SEQC run complete. Cluster will be terminated") # upload logs if args.upload_prefix: - # Upload count matrices files, logs, and return + # upload logs (seqc_log.txt, nohup.log) bucket, key = io.S3.split_link(args.upload_prefix) - for item in [args.log_name, './nohup.log']: + for item in [args.log_name, "./nohup.log"]: try: # Make a copy of the file with the output prefix - copyfile(item, args.output_prefix + '_' + item) - print(args.output_prefix + '_' + item) + copyfile(item, args.output_prefix + "_" + item) + print(args.output_prefix + "_" + item) ec2.Retry(retries=5)(io.S3.upload_file)( - args.output_prefix + '_' + item, bucket, key) - log.info('Successfully uploaded %s to the specified S3 location ' - '"%s".' % (item, args.upload_prefix)) + args.output_prefix + "_" + item, bucket, key + ) + log.info( + 'Successfully uploaded %s to "%s".' % (item, args.upload_prefix) + ) except FileNotFoundError: - log.notify('Item %s was not found! Continuing with upload...' % item) + log.notify( + "Item %s was not found! Continuing with upload..." % item + ) + else: + # move the log to output directory + movefile(args.log_name, args.output_prefix + "_" + args.log_name) # todo local test does not send this email if mutt: email_body = ( '' - 'SEQC RUN COMPLETE.\n\n' - 'The run log has been attached to this email and ' - 'results are now available in the S3 location you specified: ' - '"%s"\n\n' % args.upload_prefix) - email_body = email_body.replace('\n', '
').replace('\t', ' ') + "SEQC RUN COMPLETE.\n\n" + "The run log has been attached to this email and " + "results are now available in the S3 location you specified: " + '"%s"\n\n' % args.upload_prefix + ) + email_body = email_body.replace("\n", "
').replace("\t", " ")
             email_user(summary_archive, email_body, args.email)
+
+        end_run_time = pendulum.now()
+        running_time = end_run_time - start_run_time
+        log.info("Running Time={}".format(running_time.in_words()))
diff --git a/src/seqc/core/start.py b/src/seqc/core/start.py
index de22478..9f02d9e 100644
--- a/src/seqc/core/start.py
+++ b/src/seqc/core/start.py
@@ -15,7 +15,7 @@ def start(args):
         instance_type=args.instance_type,
         spot_bid=args.spot_bid,
         volume_size=args.volume_size,
-        ami_id=args.ami_id
+        ami_id=args.ami_id,
     )
-
+
     instance.start()
diff --git a/src/seqc/core/terminate.py b/src/seqc/core/terminate.py
index 784818e..2667d75 100644
--- a/src/seqc/core/terminate.py
+++ b/src/seqc/core/terminate.py
@@ -8,11 +8,11 @@ def terminate(args):
     :param args: namespace object from argparse, must include rsa-key and instance-id
     :return None:
     """
-    ec2 = boto3.resource('ec2')
+    ec2 = boto3.resource("ec2")
     for id_ in args.instance_ids:
         instance = ec2.Instance(id=id_)
         try:
             response = instance.terminate()
-            print('termination signal sent:\n%s' % response)
+            print("termination signal sent:\n%s" % response)
         except ClientError:
-            print('instance %s does not exist')
+            print("instance %s does not exist" % id_)
diff --git a/src/seqc/core/verify.py b/src/seqc/core/verify.py
index b42e506..fb2d637 100644
--- a/src/seqc/core/verify.py
+++ b/src/seqc/core/verify.py
@@ -20,17 +20,17 @@ def validate_and_return_size(filename):
     :param str filename: filepath or s3 link
     :return None: raises errors if path or link is invalid.
     """
-    if filename.startswith('s3://'):
+    if filename.startswith("s3://"):
         io.S3.check_links([filename])
         return io.S3.obtain_size(filename)
     else:
         if os.path.isfile(filename):
             return filesize(filename)
-        elif os.path.isdir(filename.rstrip('/')):
+        elif os.path.isdir(filename.rstrip("/")):
             return sum(filesize(filename + f) for f in os.listdir(filename))
         else:
             print(filename)
-            raise ValueError('%s does not point to a valid file')
+            raise ValueError("%s does not point to a valid file" % filename)
 
 
 def estimate_required_volume_size(args):
@@ -44,8 +44,12 @@ def estimate_required_volume_size(args):
     # todo stopped here; remove aws dependency
 
     if args.barcode_fastq and args.genomic_fastq:
-        total += sum(validate_and_return_size(f) for f in args.barcode_fastq) * 14 + 9e10
-        total += sum(validate_and_return_size(f) for f in args.genomic_fastq) * 14 + 9e10
+        total += (
+            sum(validate_and_return_size(f) for f in args.barcode_fastq) * 14 + 9e10
+        )
+        total += (
+            sum(validate_and_return_size(f) for f in args.genomic_fastq) * 14 + 9e10
+        )
         total += validate_and_return_size(args.index)
 
     elif args.alignment_file:
@@ -60,13 +64,16 @@ def estimate_required_volume_size(args):
         total += validate_and_return_size(args.read_array)
 
     if args.basespace:
-        if not args.basespace_token or args.basespace_token == 'None':
+        if not args.basespace_token or args.basespace_token == "None":
             raise ValueError(
-                'If the --basespace argument is used, the basespace token must be '
-                'specified in the seqc config file or passed as --basespace-token')
+                "If the --basespace argument is used, the basespace token must be "
+                "specified in the seqc config file or passed as --basespace-token"
+            )
         io.BaseSpace.check_sample(args.basespace, args.basespace_token)
-        total += io.BaseSpace.check_size(args.basespace, args.basespace_token) * 14 + 9e10
+        total += (
+            io.BaseSpace.check_size(args.basespace, args.basespace_token) * 14 + 9e10
+        )
 
     return ceil(total * 1e-9)
 
 
@@ -84,13 +91,13 @@ def run(args) -> float:
     """
 
     if args.rsa_key is None:
-        raise 
ValueError('-k/--rsa-key does not point to a valid file object. ') + raise ValueError("-k/--rsa-key does not point to a valid file object. ") if not os.path.isfile(args.rsa_key): - raise ValueError('-k/--rsa-key does not point to a valid file object. ') + raise ValueError("-k/--rsa-key does not point to a valid file object. ") - if args.output_prefix.endswith('/'): - raise ValueError('output_stem should not be a directory.') - if not args.index.endswith('/'): + if args.output_prefix.endswith("/"): + raise ValueError("output_stem should not be a directory.") + if not args.index.endswith("/"): raise ValueError('index must be a directory, and must end with "/"') # check platform name; raises ValueError if invalid @@ -98,7 +105,7 @@ def run(args) -> float: # check to make sure that --email-status is passed with remote run if args.remote and not args.email: - raise ValueError('Please supply the --email-status flag for a remote SEQC run.') + raise ValueError("Please supply the --email-status flag for a remote SEQC run.") # if args.instance_type not in ['c3', 'c4', 'r3']: # todo fix this instance check # raise ValueError('All AWS instance types must be either c3, c4, or r3.') # if args.terminate not in ['True', 'true', 'False', 'false', 'on-success']: @@ -107,48 +114,66 @@ def run(args) -> float: # make sure at least one input has been passed valid_inputs = ( - args.barcode_fastq, args.genomic_fastq, args.merged_fastq, args.alignment_file, - args.basespace, args.read_array) + args.barcode_fastq, + args.genomic_fastq, + args.merged_fastq, + args.alignment_file, + args.basespace, + args.read_array, + ) if not any(valid_inputs): raise ValueError( - 'At least one input argument (-b/-g, -m, -s, -r, --basespace) must be passed ' - 'to SEQC.') + "At least one input argument (-b/-g, -m, -s, -r, --basespace) must be passed " + "to SEQC." + ) if not args.barcode_files: # todo clean this up and fold into platform somehow - if args.platform != 'drop_seq': - raise ValueError('--barcode-files is required for this platform.') + if args.platform != "drop_seq": + raise ValueError("--barcode-files is required for this platform.") # make sure at most one input type has been passed num_inputs = 0 if args.barcode_fastq or args.genomic_fastq: if not all((args.barcode_fastq, args.genomic_fastq)): raise ValueError( - 'if either genomic or barcode fastq are provided, both must be provided') + "if either genomic or barcode fastq are provided, both must be provided" + ) num_inputs += 1 - num_inputs += sum(1 for i in (args.merged_fastq, args.alignment_file, - args.basespace, args.read_array) if i) + num_inputs += sum( + 1 + for i in ( + args.merged_fastq, + args.alignment_file, + args.basespace, + args.read_array, + ) + if i + ) if num_inputs > 1: raise ValueError( - 'user should provide at most one input argument (-b/-g, -m, -s, -r, ' - '--basespace') + "user should provide at most one input argument (-b/-g, -m, -s, -r, " + "--basespace" + ) # if basespace is being used, make sure there is a valid basespace token - if args.basespace and not hasattr(args, 'basespace_token'): - raise RuntimeError('if --basespace input is selected, user must provide an OAuth ' - 'token using the --basespace-token parameter.') + if args.basespace and not hasattr(args, "basespace_token"): + raise RuntimeError( + "if --basespace input is selected, user must provide an OAuth " + "token using the --basespace-token parameter." 
+ ) # check that spot-bid is correct if args.spot_bid is not None: if args.spot_bid < 0: - raise ValueError('bid %f must be a non-negative float.' % args.spot_bid) + raise ValueError("bid %f must be a non-negative float." % args.spot_bid) - if args.upload_prefix and not args.upload_prefix.startswith('s3://'): - raise ValueError('upload_prefix should be an s3 address beginning with s3://') + if args.upload_prefix and not args.upload_prefix.startswith("s3://"): + raise ValueError("upload_prefix should be an s3 address beginning with s3://") - if args.upload_prefix.startswith('s3://'): + if args.upload_prefix.startswith("s3://"): ec2.check_bucket(args.upload_prefix) if args.volume_size is None: - setattr(args, 'volume_size', estimate_required_volume_size(args)) + setattr(args, "volume_size", estimate_required_volume_size(args)) return args @@ -160,7 +185,7 @@ def index(args): :return: updated namespace object with volume_size set. """ if args.volume_size is None: - setattr(args, 'volume_size', 100) + setattr(args, "volume_size", 100) return args @@ -182,12 +207,17 @@ def platform_name(name: str): :param name: string of platform name to check :return: name (if supported by seqc). """ - choices = [x[0] for x in inspect.getmembers(platforms, inspect.isclass) if - issubclass(x[1], platforms.AbstractPlatform)][1:] + choices = [ + x[0] + for x in inspect.getmembers(platforms, inspect.isclass) + if issubclass(x[1], platforms.AbstractPlatform) + ][1:] if name not in choices: - raise ValueError('Please specify a valid platform name for SEQC. The available ' - 'options are: {}'.format(choices)) + raise ValueError( + "Please specify a valid platform name for SEQC. The available " + "options are: {}".format(choices) + ) # throw error for mars1_seq since we don't have the appropriate primer length yet - if name == 'mars1_seq': - raise ValueError('Mars1-seq is currently not stable in this version of SEQC.') + if name == "mars1_seq": + raise ValueError("Mars1-seq is currently not stable in this version of SEQC.") return name diff --git a/src/seqc/ec2.py b/src/seqc/ec2.py index 51db42b..dd1a680 100644 --- a/src/seqc/ec2.py +++ b/src/seqc/ec2.py @@ -4,7 +4,6 @@ import configparser import traceback import types -import dill from functools import wraps from contextlib import closing from paramiko.ssh_exception import NoValidConnectionsError @@ -14,13 +13,12 @@ from subprocess import Popen, PIPE from seqc import log, io from seqc.core import verify -from seqc.exceptions import ( - RetryLimitExceeded, InstanceNotRunningError, EC2RuntimeError) +from seqc.exceptions import RetryLimitExceeded, InstanceNotRunningError, EC2RuntimeError from botocore.exceptions import ClientError # change some logging defaults -log.logging.getLogger('paramiko').setLevel(log.logging.CRITICAL) -log.logging.getLogger('boto3').setLevel(log.logging.CRITICAL) +log.logging.getLogger("paramiko").setLevel(log.logging.CRITICAL) +log.logging.getLogger("boto3").setLevel(log.logging.CRITICAL) def _get_ec2_configuration(): @@ -28,29 +26,24 @@ def _get_ec2_configuration(): for credentials will be searchable!""" defaults = {} config = configparser.ConfigParser() - config.read(os.path.expanduser('~/.aws/config')) - defaults['region'] = config['default']['region'] - config.read(os.path.expanduser('~/.aws/credentials')) - defaults['aws_access_key_id'] = config['default']['aws_access_key_id'] - defaults['aws_secret_access_key'] = config['default']['aws_secret_access_key'] + config.read(os.path.expanduser("~/.aws/config")) + defaults["region"] = 
config["default"]["region"] + config.read(os.path.expanduser("~/.aws/credentials")) + defaults["aws_access_key_id"] = config["default"]["aws_access_key_id"] + defaults["aws_secret_access_key"] = config["default"]["aws_secret_access_key"] return defaults class Retry: - def __init__( - self, - retries: int=10, - catch=(ClientError,), - delay: int=1, - verbose=False): + self, retries: int = 10, catch=(ClientError,), delay: int = 1, verbose=False + ): self.retries = retries self.exceptions_to_catch = catch self.delay_retry = delay self.verbose = verbose def __call__(self, function): - @wraps(function) def wrapper(*args, **kwargs): retries = self.retries @@ -62,17 +55,25 @@ def wrapper(*args, **kwargs): retries -= 1 if self.verbose: log.notify( - 'Non fatal error in function {} (retrying in ' - '{!s}s):\n{}'.format( - function.__qualname__, self.delay_retry, - traceback.format_exc())) + "Non fatal error in function {} (retrying in " + "{!s}s):\n{}".format( + function.__qualname__, + self.delay_retry, + traceback.format_exc(), + ) + ) time.sleep(self.delay_retry) else: raise RetryLimitExceeded( - 'fatal error in function {} occurred {} times at {!s}s call ' - 'interval:\n{}'.format( - function.__qualname__, self.retries, self.delay_retry, - traceback.format_exc())) + "fatal error in function {} occurred {} times at {!s}s call " + "interval:\n{}".format( + function.__qualname__, + self.retries, + self.delay_retry, + traceback.format_exc(), + ) + ) + return wrapper @@ -83,14 +84,21 @@ class AWSInstance(object): of commands on the remote server """ - ec2 = boto3.resource('ec2') - client = boto3.client('ec2') + ec2 = boto3.resource("ec2") + client = boto3.client("ec2") def __init__( self, - rsa_key, instance_type, instance_id=None, security_group_id=None, - spot_bid=None, synchronous=False, volume_size=5, - user_tags=None, remote_update=False, ami_id=None, + rsa_key, + instance_type, + instance_id=None, + security_group_id=None, + spot_bid=None, + synchronous=False, + volume_size=5, + user_tags=None, + remote_update=False, + ami_id=None, **kwargs ): """ @@ -117,9 +125,9 @@ def __init__( # todo allow overwriting of these arguments with **kwargs defaults = _get_ec2_configuration() - self.aws_public_access_key = defaults['aws_access_key_id'] - self.aws_secret_access_key = defaults['aws_secret_access_key'] - self.region = defaults['region'] + self.aws_public_access_key = defaults["aws_access_key_id"] + self.aws_secret_access_key = defaults["aws_secret_access_key"] + self.region = defaults["region"] self._rsa_key = rsa_key if not ami_id or not ami_id.startswith("ami-"): raise ValueError("You must specify a valid ID for the SEQC AMI to be used.") @@ -134,12 +142,15 @@ def __init__( self.user_tags = user_tags if not isinstance(volume_size, int) or not 1 <= volume_size < 16384: - raise ValueError('volume size must be an integer.') + raise ValueError("volume size must be an integer.") self.volume_size = volume_size # additional properties self._ssh_connection = None + # store the command-line arguments supplied by the user + self.argv = kwargs["argv"] + # todo define def __repr__(self): @property @@ -149,8 +160,8 @@ def instance_id(self): @instance_id.setter def instance_id(self, value): if not isinstance(value, str): - raise ValueError('instance must be a string instance id') - if not value.startswith('i-'): + raise ValueError("instance must be a string instance id") + if not value.startswith("i-"): raise ValueError('valid instance identifiers must start with "i-"') self._instance_id = value @@ -161,8 
+172,8 @@ def security_group_id(self):
     @security_group_id.setter
     def security_group_id(self, value):
         if not isinstance(value, str):
-            raise ValueError('instance must be a string instance id')
-        if not value.startswith('sg-'):
+            raise ValueError("instance must be a string instance id")
+        if not value.startswith("sg-"):
             raise ValueError('valid instance identifiers must start with "i-"')
         self._security_group_id = value
@@ -173,7 +184,7 @@ def rsa_key(self):
     @rsa_key.setter
     def rsa_key(self, value):
         if not isinstance(value, str):
-            raise ValueError('rsa_key_path must be type str')
+            raise ValueError("rsa_key_path must be type str")
         self._rsa_key = os.path.expanduser(value)
     @property
@@ -192,9 +203,9 @@ def create_security_group(cls, name=None):
         # todo get list of existing groups; check against
         if name is None:
-            name = 'SEQC-%07d' % random.randint(1, int(1e7))
+            name = "SEQC-%07d" % random.randint(1, int(1e7))
         sg = cls.ec2.create_security_group(GroupName=name, Description=name)
-        log.notify('Created new security group: %s (name=%s).' % (sg.id, name))
+        log.notify("Created new security group: %s (name=%s)." % (sg.id, name))
         return sg.id
     @classmethod
@@ -203,14 +214,17 @@ def enable_ssh(cls, security_group_id):
         security_group = cls.ec2.SecurityGroup(security_group_id)
         try:
             security_group.authorize_ingress(
-                IpProtocol="tcp", CidrIp="0.0.0.0/0", FromPort=22, ToPort=22)
+                IpProtocol="tcp", CidrIp="0.0.0.0/0", FromPort=22, ToPort=22
+            )
             security_group.authorize_ingress(
-                SourceSecurityGroupName=security_group.description)
+                SourceSecurityGroupName=security_group.description
+            )
         except ClientError as e:  # todo figure out why this is happening
-            if 'InvalidPermission.Duplicate' not in e.args[0]:
+            if "InvalidPermission.Duplicate" not in e.args[0]:
                 raise
-        log.notify('Enabled ssh access via port 22 for security group %s' %
-                   security_group_id)
+        log.notify(
+            "Enabled ssh access via port 22 for security group %s" % security_group_id
+        )
     @classmethod
     @Retry(retries=20, delay=0.5)
@@ -224,8 +238,7 @@ def verify_security_group(cls, security_group_id) -> None:
     @Retry(retries=10, delay=0.5)
     def remove_security_group(cls, security_group_id) -> None:
         cls.ec2.SecurityGroup(security_group_id).delete()
-        log.notify('security group %s successfully removed.' % (
-            security_group_id))
+        log.notify("security group %s successfully removed." % (security_group_id))
     def launch_specification(self) -> dict:
         """return the specification for launching an instance with parameters defined
@@ -240,14 +253,20 @@ def launch_specification(self) -> dict:
         sg_id = self.security_group_id
         spec = {
-            'ImageId': self.image_id,
-            'KeyName': self.rsa_key.split('/')[-1].split('.')[0],
-            'InstanceType': self.instance_type,
-            'SecurityGroupIds': [sg_id],
-            'BlockDeviceMappings': [{'DeviceName': '/dev/xvdf',
-                                     'Ebs': {'VolumeSize': self.volume_size,
-                                             'VolumeType': 'gp2',
-                                             'DeleteOnTermination': True}}],
+            "ImageId": self.image_id,
+            "KeyName": self.rsa_key.split("/")[-1].split(".")[0],
+            "InstanceType": self.instance_type,
+            "SecurityGroupIds": [sg_id],
+            "BlockDeviceMappings": [
+                {
+                    "DeviceName": "/dev/xvdf",
+                    "Ebs": {
+                        "VolumeSize": self.volume_size,
+                        "VolumeType": "gp2",
+                        "DeleteOnTermination": True,
+                    },
+                }
+            ],
         }
         return spec
@@ -255,27 +274,26 @@ def launch_specification(self) -> dict:
     def verify_instance_running(self, instance_id):
         """wait for instance to reach 'running' state, then return"""
         instance = self.ec2.Instance(id=instance_id)
-        if not instance.state['Name'] == 'running':
+        if not instance.state["Name"] == "running":
             raise InstanceNotRunningError
-        log.notify('instance %s in running state' % instance_id)
+        log.notify("Instance %s in running state" % instance_id)
     def create_instance(self) -> None:
         if self.instance_id is not None:
-            raise RuntimeError('instance %s already exists.' % self.instance_id)
+            raise RuntimeError("instance %s already exists." % self.instance_id)
         if self.spot_bid:
             self.create_spot_instance()
         else:
             specification = self.launch_specification()
-            specification['MinCount'] = specification['MaxCount'] = 1
+            specification["MinCount"] = specification["MaxCount"] = 1
             instance = self.ec2.create_instances(**specification)[0]
             self.instance_id = instance.id
-            log.notify('instance %s created, waiting until running' %
-                       self.instance_id)
+            log.notify("Instance %s created, waiting until running" % self.instance_id)
             instance.wait_until_running()
-            log.notify('instance %s in running state' % self.instance_id)
+            log.notify("Instance %s in running state" % self.instance_id)
     @staticmethod
-    def mount_volume(ssh, directory='/home/ec2-user'):
+    def mount_volume(ssh, directory="/home/ec2-user"):
         """mount /dev/xvdf to /data given an ssh client with access to an instance
         :param str directory: directory to mount the drive to. Note that odd behavior may
@@ -287,36 +305,43 @@ def mount_volume(ssh, directory='/home/ec2-user'):
            ssh.execute("sudo mkfs -t ext4 /dev/xvdf 2>&1")  # redir; errors invisible
            ssh.execute("sudo cp -a %s/. /tmp/directory/" % directory)  # copy original
            ssh.execute("sudo mkdir -p %s" % directory)
-            ssh.execute("sudo mount /dev/xvdf %s && sudo cp -a /tmp/directory/. %s/"
-                        % (directory, directory))
-            ssh.execute("sudo chown ec2-user:ec2-user %s/lost+found && "
-                        "chmod 755 %s/lost+found" % (directory, directory))
+            ssh.execute(
+                "sudo mount /dev/xvdf %s && sudo cp -a /tmp/directory/. %s/"
+                % (directory, directory)
+            )
+            ssh.execute(
+                "sudo chown ec2-user:ec2-user %s/lost+found && "
+                "chmod 755 %s/lost+found" % (directory, directory)
+            )
            log.notify("Successfully mounted new volume onto %s." % directory)
        except ChildProcessError as e:
-            if not ('mount: according to mtab, /dev/xvdf is already mounted on %s'
-                    % directory in ' '.join(e.args[0])):
+            if not (
+                "mount: according to mtab, /dev/xvdf is already mounted on %s"
+                % directory
+                in " ".join(e.args[0])
+            ):
                raise
     def set_credentials(self, ssh):
         """sets aws credentials on remote instance from user's local config file"""
-        ssh.execute('aws configure set aws_access_key_id %s' % self.aws_public_access_key)
         ssh.execute(
-            'aws configure set aws_secret_access_key %s' % self.aws_secret_access_key)
-        ssh.execute('aws configure set region %s' % self.region)
+            "aws configure set aws_access_key_id %s" % self.aws_public_access_key
+        )
+        ssh.execute(
+            "aws configure set aws_secret_access_key %s" % self.aws_secret_access_key
+        )
+        ssh.execute("aws configure set region %s" % self.region)
     def construct_ec2_tags(self):
         """construct tags for ec2 instance"""
         # for the owner tag, we will just use the RSA key filename
         tags = [
-            {
-                "Key": "Name",
-                "Value": "SEQC"
-            },
+            {"Key": "Name", "Value": "SEQC"},
             {
                 "Key": "Owner",
-                "Value": os.path.splitext(os.path.basename(self.rsa_key))[0]
+                "Value": os.path.splitext(os.path.basename(self.rsa_key))[0],
             },
         ]
@@ -325,14 +350,11 @@ def construct_ec2_tags(self):
         try:
             # user_tags come in k1:v1,k2:v2 format
             # convert to a dictionary
-            user_tags_dict = dict(kv.split(':') for kv in self.user_tags.split(','))
+            user_tags_dict = dict(kv.split(":") for kv in self.user_tags.split(","))
             # convert the dictionary to something suitable for EC2 tag format
             for k, v in user_tags_dict.items():
-                kv = {
-                    "Key": k,
-                    "Value": v
-                }
+                kv = {"Key": k, "Value": v}
                 tags.append(kv)
         except:
             # ignore if invalid/not parseable
@@ -348,155 +370,161 @@ def setup_seqc(self):
         # tag the instance
         tags = self.construct_ec2_tags()
-        self.ec2.create_tags(
-            Resources=[self.instance_id],
-            Tags=tags
-        )
+        self.ec2.create_tags(Resources=[self.instance_id], Tags=tags)
-        with SSHConnection(
-            instance_id=self.instance_id, rsa_key=self.rsa_key
-        ) as ssh:
+        with SSHConnection(instance_id=self.instance_id, rsa_key=self.rsa_key) as ssh:
            self.mount_volume(ssh)
-            log.notify('setting aws credentials.')
+            log.notify("Setting aws credentials.")
            self.set_credentials(ssh)
            # use the local SEQC package (.tar.gz) to update the remote instance
            # this will overwrite whatever SEQC version exists in the remote instance
            if self.remote_update:
-                log.notify('uploading local SEQC installation to remote instance.')
-                seqc_distribution = os.path.expanduser('~/.seqc/seqc.tar.gz')
-                ssh.execute('mkdir -p software/seqc')
-                ssh.put_file(seqc_distribution, 'software/seqc.tar.gz')
+                log.notify("Uploading local SEQC installation to remote instance.")
+                seqc_distribution = os.path.expanduser("~/.seqc/seqc.tar.gz")
+                ssh.execute("mkdir -p software/seqc")
+                ssh.put_file(seqc_distribution, "software/seqc.tar.gz")
                ssh.execute(
-                    'tar -m -xvf software/seqc.tar.gz -C software/seqc --strip-components 1'
+                    "tar -m -xvf software/seqc.tar.gz -C software/seqc --strip-components 1"
                )
                log.notify("Sources are uploaded and decompressed, installing seqc.")
                try:
-                    ssh.execute('sudo -H pip3 install software/seqc/')
+                    ssh.execute("sudo -H pip3 install software/seqc/")
                except ChildProcessError as e:
-                    if 'pip install --upgrade pip' in str(e):
+                    if "pip install --upgrade pip" in str(e):
                        pass
                    else:
                        raise
            try:
                # test the installation
-                ssh.execute('SEQC -h')
+                ssh.execute("SEQC -h")
            except:
-                log.notify('SEQC installation failed.')
+                log.notify("SEQC installation failed.")
                log.exception()
                raise
            try:
                # retrieves the SEQC version information
-                seqc_version, _ = ssh.execute('SEQC --version')
+                seqc_version, _ = ssh.execute("SEQC --version")
                # this returns an array
                seqc_version = seqc_version[0]
                # update the Name tag (e.g. SEQC 0.2.3)
                self.ec2.create_tags(
                    Resources=[self.instance_id],
-                    Tags=[
-                        {
-                            "Key": "Name",
-                            "Value": seqc_version
-                        }
-                    ]
+                    Tags=[{"Key": "Name", "Value": seqc_version}],
                )
            except:
                # just warn and proceed
                log.notify("Unable to retrieve SEQC version.")
-            log.notify('SEQC setup complete.')
-            log.notify('instance login: %s' % ssh.obscure_login_command())
+            log.notify("SEQC setup complete.")
+            log.notify("Instance login: %s" % ssh.obscure_login_command())
     def start(self):
         self.setup_seqc()
-        log.notify('Instance set-up complete.')
+        log.notify("Instance set-up complete.")
     def stop(self):
         """stops a running instance"""
         if self.instance_id is None:
-            raise RuntimeError('Instance not yet created, nothing to be stopped.')
+            raise RuntimeError("Instance not yet created, nothing to be stopped.")
         instance = self.ec2.Instance(self.instance_id)
-        if instance.state['Name'] not in (
-                'stopped', 'terminated', 'shutting-down'):
-            log.notify('requesting termination of instance {id}'.format(
-                id=self.instance_id))
+        if instance.state["Name"] not in ("stopped", "terminated", "shutting-down"):
+            log.notify(
+                "requesting termination of instance {id}".format(id=self.instance_id)
+            )
            instance.stop()
            instance.wait_until_stopped()
-            log.notify('instance {id} stopped.'.format(id=self.instance_id))
+            log.notify("instance {id} stopped.".format(id=self.instance_id))
        else:
-            log.notify('instance is not running')
+            log.notify("instance is not running")
     def restart(self):
         """restarts a stopped instance"""
         if self.instance_id is None:
-            raise RuntimeError('Instance not yet created, nothing to be restarted.')
+            raise RuntimeError("Instance not yet created, nothing to be restarted.")
         instance = self.ec2.Instance(self.instance_id)
-        if instance.state['Name'] == 'stopped':
+        if instance.state["Name"] == "stopped":
            instance.start()
            instance.wait_until_running()
-            log.notify('Stopped instance %s has restarted.' % self.instance_id)
+            log.notify("Stopped instance %s has restarted." % self.instance_id)
        else:
-            log.notify('Instance %s in state "%s" must be in a stopped state to be '
-                       'restarted.' % (self.instance_id, instance.state['Name']))
+            log.notify(
+                'Instance %s in state "%s" must be in a stopped state to be '
+                "restarted." % (self.instance_id, instance.state["Name"])
+            )
     def terminate(self):
         """terminates an instance in any state (including stopped)"""
         if self.instance_id is None:
-            raise RuntimeError('Instance not yet created, nothing to be restarted.')
+            raise RuntimeError("Instance not yet created, nothing to be restarted.")
         instance = self.ec2.Instance(self.instance_id)
-        if instance.state['Name'] not in ('terminated', 'shutting-down'):
-            log.notify('requesting termination of instance {id}'.format(
-                id=self.instance_id))
+        if instance.state["Name"] not in ("terminated", "shutting-down"):
+            log.notify(
+                "requesting termination of instance {id}".format(id=self.instance_id)
+            )
            instance.terminate()
            instance.wait_until_terminated()
-            log.notify('instance {id} terminated.'.format(id=self.instance_id))
+            log.notify("instance {id} terminated.".format(id=self.instance_id))
        else:
-            log.notify('Instance %s in state "%s" must be running to be stopped.' %
-                       (self.instance_id, instance.state['Name']))
+            log.notify(
+                'Instance %s in state "%s" must be running to be stopped.'
+                % (self.instance_id, instance.state["Name"])
+            )
     @classmethod
     @Retry(retries=40, delay=5, catch=(InstanceNotRunningError, ClientError))
     def verify_spot_bid_fulfilled(cls, sir_id):
         result = cls.client.describe_spot_instance_requests(
-            SpotInstanceRequestIds=[sir_id])
-        status = result['SpotInstanceRequests'][0]['Status']['Code']
-        if status not in ['pending-evaluation', 'pending-fulfillment', 'fulfilled']:
-            raise EC2RuntimeError('spot request bad-status: %s' % status)
-        elif status != 'fulfilled':
+            SpotInstanceRequestIds=[sir_id]
+        )
+        status = result["SpotInstanceRequests"][0]["Status"]["Code"]
+        if status not in ["pending-evaluation", "pending-fulfillment", "fulfilled"]:
+            raise EC2RuntimeError("spot request bad-status: %s" % status)
+        elif status != "fulfilled":
            raise InstanceNotRunningError
-        return result['SpotInstanceRequests'][0]['InstanceId']
+        return result["SpotInstanceRequests"][0]["InstanceId"]
     def create_spot_instance(self):
         if not self.spot_bid:
-            raise ValueError('must pass constructor spot_bid price (float) to create a '
-                             'spot bid request.')
+            raise ValueError(
+                "must pass constructor spot_bid price (float) to create a "
+                "spot bid request."
+            )
         response = self.client.request_spot_instances(
-            DryRun=False, SpotPrice=str(self.spot_bid),
-            LaunchSpecification=self.launch_specification())
-        sir_id = response['SpotInstanceRequests'][0]['SpotInstanceRequestId']
+            DryRun=False,
+            SpotPrice=str(self.spot_bid),
+            LaunchSpecification=self.launch_specification(),
+        )
+        sir_id = response["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
         log.notify(
-            'spot instance requested (%s), waiting for bid to be accepted.' % sir_id)
+            "Spot instance requested (%s), waiting for bid to be accepted." % sir_id
+        )
         self.instance_id = self.verify_spot_bid_fulfilled(sir_id)
         if self.instance_id is None:
            raise InstanceNotRunningError(
-                'spot bid of %f was not fulfilled, please try a higher bid or ')
-        log.notify('spot bid accepted, waiting for instance (id=%s) to attain running '
-                   'state.' % self.instance_id)
+                "Spot bid of %f was not fulfilled, please try a higher bid or "
+            )
+        log.notify(
+            "Spot bid accepted, waiting for instance (id=%s) to attain running "
+            "state." % self.instance_id
+        )
         self.ec2.Instance(self.instance_id).wait_until_running()
-        log.notify('spot instance (id=%s) in running state' % self.instance_id)
+        log.notify("Spot instance (id=%s) in running state" % self.instance_id)
     def __enter__(self):
         try:
            self.setup_seqc()
         except:
            if self.synchronous and self.instance_id:
-                log.notify('error occurred during setup, attemption instance termination')
+                log.notify(
+                    "error occurred during setup, attempting instance termination"
+                )
                log.exception()
                try:
                    self.terminate()
@@ -511,107 +539,71 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         if not exc_type:
            return True
-    @staticmethod
-    def pickle_function(function: object, args, kwargs) -> str:
-        """ pickle and function and its arguments
-
-        :param object function: function to be pickled
-        :param tuple args: positional arguments for the function
-        :param dict kwargs: keyword arguments for the function
-        :return str: filename of the pickled function
-        """
-        filename = '{}{!s}_{}.p'.format(
-            os.environ['TMPDIR'], random.randint(0, 1e9), function.__name__)
-
-        with open(filename, 'wb') as f:
-            dill.dump(dict(function=function, args=args, kwargs=kwargs), f)
-        return filename
-
-    # todo this doesn't work; it gets THIS module's imports, but not the calling module!
- @staticmethod - def get_imports(): - for alias, val in globals().items(): - if isinstance(val, types.ModuleType): - yield (val.__name__, alias) - @classmethod - def format_importlist(cls): - importlist = '' - for name, alias in cls.get_imports(): - if name != alias: - importlist += 'import {name} as {alias}\n'.format(name=name, alias=alias) - else: - importlist += 'import {name}\n'.format(name=name) - return importlist - - @classmethod - def write_script(cls, function) -> str: - """generate a python script that calls function after importing required modules + def write_script(cls, argv, function) -> str: + """generate a bash script that runs SEQC + :param list argv: the original command-line arguments supplied by user :param object function: function to be called :return str: filename of the python script """ - script_name = '{}{!s}_{}.py'.format( - os.environ['TMPDIR'], random.randint(0, 1e9), function.__name__) + script_name = "{}{!s}_{}.py".format( + os.environ["TMPDIR"], random.randint(0, 1e9), function.__name__ + ) script_body = ( - '{imports}' - 'with open("func.p", "rb") as fin:\n' - ' data = dill.load(fin)\n' - 'results = data["function"](*data["args"], **data["kwargs"])\n' - 'with open("results.p", "wb") as f:\n' - ' dill.dump(results, f)\n' + "#!/bin/bash -x" + "\n" + "\n" + "SEQC " + + " ".join(argv) + + " --local" + + (" --terminate" if "--no-terminate" not in argv else "") + + "\n" ) - script_body = script_body.format(imports=cls.format_importlist()) - with open(script_name, 'w') as f: - # log.notify('writing script to file:\n%s' % script_body) + with open(script_name, "wt") as f: + log.notify("writing script to file:\n%s" % script_body) f.write(script_body) return script_name def __call__(self, function): - def function_executed_on_aws(*args, **kwargs): - # dump original function to file - script = self.write_script(function) - func = self.pickle_function(function, args, kwargs) + # create a bash script running SEQC + script = self.write_script(self.argv, function) # create an instance, or ensure the passed instance has the necessary # packages installed self.setup_seqc() + # connect to EC2 instance + # run the bash script which will run SEQC with SSHConnection(self.instance_id, self.rsa_key) as ssh: - ssh.put_file(script, 'script.py') - ssh.put_file(func, 'func.p') + ssh.put_file(script, "script.sh") + ssh.execute("chmod +x ./script.sh") if self.synchronous: - ssh.execute('python3 script.py') - results_name = os.environ['TMPDIR'] + function.__name__ + '_results.p' - ssh.get_file('results.p', results_name) - with open(results_name, 'rb') as f: - results = dill.load(f) + ssh.execute("./script.sh") else: - ssh.execute('nohup python3 script.py > nohup.log 2>&1 &') - results = None + ssh.execute("nohup ./script.sh > nohup.log 2>&1 &") if self.synchronous: self.terminate() - return results + return None return function_executed_on_aws class SSHConnection: - _error_msg = ('You need to specify a valid RSA key to connect to Amazon EC2 ' - 'instances, see https://github.com/ambrosejcarr/seqc#create-an-rsa-key' - '-to-allow-you-to-launch-a-cluster') + _error_msg = ( + "You need to specify a valid RSA key to connect to Amazon EC2 instances" + ) - ec2 = boto3.resource('ec2') + ec2 = boto3.resource("ec2") def __init__(self, instance_id, rsa_key): if not isinstance(instance_id, str): - raise ValueError('instance must be a string instance id') - if not instance_id.startswith('i-'): + raise ValueError("instance must be a string instance id") + if not instance_id.startswith("i-"): 
raise ValueError('valid instance identifiers must start with "i-"') self._instance_id = instance_id self.rsa_key = os.path.expanduser(rsa_key) @@ -629,31 +621,37 @@ def instance_id(self): @instance_id.setter def instance_id(self, value): if isinstance(value, str): - raise ValueError('instance must be a string instance id') - if not value.startswith('i-'): + raise ValueError("instance must be a string instance id") + if not value.startswith("i-"): raise ValueError('valid instance identifiers must start with "i-"') self._instance_id = value def check_key_file(self): """Checks the rsa file is present""" if not self.rsa_key: - log.notify('The key %s was not found!' % self.rsa_key) - raise FileNotFoundError(self._error_msg, 'The key file %s does not exist' % - self.rsa_key) + log.notify("The key %s was not found!" % self.rsa_key) + raise FileNotFoundError( + self._error_msg, "The key file %s does not exist" % self.rsa_key + ) @Retry(retries=40, delay=2.5, catch=(NoValidConnectionsError, socket.error)) def connect(self): """connects to a remote instance""" instance = self.ec2.Instance(self.instance_id) try: - self.ssh.connect(instance.public_dns_name, username='ec2-user', - key_filename=self.rsa_key, timeout=3.0) + self.ssh.connect( + instance.public_dns_name, + username="ec2-user", + key_filename=self.rsa_key, + timeout=3.0, + ) except NoValidConnectionsError: - state = instance.state['Name'] - if state not in ['running', 'pending']: + state = instance.state["Name"] + if state not in ["running", "pending"]: raise InstanceNotRunningError( - 'instance %s in state %s. Only running instances can be connected to.' - % (self.instance_id, state)) + "instance %s in state %s. Only running instances can be connected to." + % (self.instance_id, state) + ) else: raise @@ -688,8 +686,9 @@ def put_file(self, local_file, remote_file): self.connect() with closing(self.ssh.open_sftp()) as ftp: ftp.put(local_file, remote_file) - log.info('placed {lfile} at {rfile}.'.format( - lfile=local_file, rfile=remote_file)) + log.info( + "placed {lfile} at {rfile}.".format(lfile=local_file, rfile=remote_file) + ) def execute(self, args): """executes the specified arguments remotely on an AWS instance @@ -705,25 +704,26 @@ def execute(self, args): data = stdout.read().decode().splitlines() errs = stderr.read().decode().splitlines() if errs: - raise ChildProcessError('\n'.join(errs)) + raise ChildProcessError("\n".join(errs)) return data, errs def login_command(self): instance = self.ec2.Instance(self.instance_id) - return ('ssh -i {rsa_path} ec2-user@{dns_name}'.format( - rsa_path=self.rsa_key, dns_name=instance.public_ip_address)) + return "ssh -i {rsa_path} ec2-user@{dns_name}".format( + rsa_path=self.rsa_key, dns_name=instance.public_ip_address + ) def obscure_login_command(self): """ same as login_command() except it hides the key file location """ instance = self.ec2.Instance(self.instance_id) - return ('ssh -i ec2-user@{dns_name}'.format( - dns_name=instance.public_ip_address) + return "ssh -i ec2-user@{dns_name}".format( + dns_name=instance.public_ip_address ) def __enter__(self): - log.notify('connecting to instance %s via ssh' % self.instance_id) + log.notify("Connecting to instance %s via ssh" % self.instance_id) self.connect() return self @@ -734,9 +734,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): class instance_clean_up: - def __init__( - self, email=None, upload=None, log_name='seqc.log', terminate=True, debug=False + self, + email=None, + upload=None, + log_name="seqc.log", + terminate=True, + 
debug=False, + running_remote=False, ): """Execution context for on-server code execution with defined clean-up practices. @@ -762,8 +767,9 @@ def __init__( self.terminate = terminate # only terminate if no errors occur self.aws_upload_key = upload self.err_status = False - self.mutt = verify.executables('mutt')[0] # unpacking necessary for singleton + self.mutt = verify.executables("mutt")[0] # unpacking necessary for singleton self.debug = debug + self.running_remote = running_remote @staticmethod def email_user(attachment: str, email_body: str, email_address: str) -> None: @@ -781,7 +787,9 @@ def email_user(attachment: str, email_body: str, email_address: str) -> None: email_args = ( 'echo "{b}" | mutt -e "set content_type="text/html"" -s ' '"Remote Process" {e} -a "{a}"'.format( - b=email_body, a=attachment, e=email_address)) + b=email_body, a=attachment, e=email_address + ) + ) email_process = Popen(email_args, shell=True, stderr=PIPE, stdout=PIPE) out, err = email_process.communicate(email_body) if err: @@ -789,8 +797,6 @@ def email_user(attachment: str, email_body: str, email_address: str) -> None: def __enter__(self): pass - # log.setup_logger(self.log_name) - # log.notify('Beginning protected execution') # todo only run if verbose @staticmethod def _get_instance_id(): @@ -798,14 +804,22 @@ def _get_instance_id(): p = Popen( "curl --silent http://169.254.169.254/latest/meta-data/instance-id", - shell=True, stdout=PIPE, stderr=PIPE + shell=True, + stdout=PIPE, + stderr=PIPE, ) instance_id, err = p.communicate() - if err: # not an ec2 linux instance, nothing to terminate - return + if err: + # not an ec2 linux instance, nothing to terminate + return None + + instance_id = instance_id.decode().strip() + if instance_id == "": + # not an ec2 linux instance, nothing to terminate + return None - return instance_id.decode().strip() + return instance_id def __exit__(self, exc_type, exc_val, exc_tb): """If an exception occurs, log the exception, email if possible, then terminate @@ -820,47 +834,57 @@ def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: log.exception() - email_body = 'Process interrupted -- see attached error message' + email_body = "Process interrupted -- see attached error message" elif self.terminate: - email_body = 'Process completed successfully -- see attached log' - log.info('Execution completed successfully, instance will be terminated.') + email_body = "Process completed successfully -- see attached log" + log.info("Execution completed successfully, instance will be terminated.") else: - email_body = 'Process completed successfully -- see attached log' - log.info('Execution completed successfully, but user requested no ' - 'termination. Instance will continue to run.') + email_body = "Process completed successfully -- see attached log" + log.info( + "Execution completed successfully, but user requested no " + "termination. Instance will continue to run." + ) # todo this is the source of the second email for successful runs # email user if possible; catch exceptions if email fails. 
if self.email and self.mutt: - log.notify('Emailing user.') + log.notify("Emailing user.") try: self.email_user( - attachment=self.log_name, email_body=email_body, - email_address=self.email) + attachment=self.log_name, + email_body=email_body, + email_address=self.email, + ) except ChildProcessError: log.exception() # upload data if requested if self.aws_upload_key: - log.notify('Uploading log to {}'.format(self.aws_upload_key)) + log.notify("Uploading log to {}".format(self.aws_upload_key)) bucket, key = io.S3.split_link(self.aws_upload_key) @Retry(catch=Exception) def upload_file(): io.S3.upload_file(self.log_name, bucket, key) + upload_file() # terminate if no errors and debug is False if self.terminate: if exc_type and self.debug: - return # don't terminate if an error was raised and debug was set + # don't terminate if an error was raised and debug was set + return instance_id = self._get_instance_id() if instance_id is None: - return # todo notify if verbose - ec2 = boto3.resource('ec2') + # probably not an ec2 instance + # return without attempting to terminate the instance + return + ec2 = boto3.resource("ec2") instance = ec2.Instance(instance_id) - log.notify('instance %s termination requested. If successful, this is the ' - 'final log entry.' % instance_id) + log.notify( + "instance %s termination requested. If successful, this is the " + "final log entry." % instance_id + ) instance.terminate() instance.wait_until_terminated() @@ -871,7 +895,7 @@ def remove_inactive_security_groups(): This function finds all inactive security groups. Note that it is NOT limited to your user account """ - ec2 = boto3.resource('ec2') + ec2 = boto3.resource("ec2") for s in ec2.security_groups.all(): try: s.delete() @@ -884,11 +908,12 @@ def check_bucket(s3_uri): :param str s3_uri: name of uri in a bucket to check """ - if not s3_uri.startswith('s3://'): - raise ValueError('%s is not a valid s3 URI' % s3_uri) - bucket = s3_uri[5:].split('/')[0] - s3 = boto3.resource('s3') + if not s3_uri.startswith("s3://"): + raise ValueError("%s is not a valid s3 URI" % s3_uri) + bucket = s3_uri[5:].split("/")[0] + s3 = boto3.resource("s3") try: s3.meta.client.head_bucket(Bucket=bucket) except ClientError: - raise ValueError('Bucket %s for s3 URI %s does not exist' % (bucket, s3_uri)) \ No newline at end of file + raise ValueError("Bucket %s for s3 URI %s does not exist" % (bucket, s3_uri)) + diff --git a/src/seqc/filter.py b/src/seqc/filter.py index 084b922..3fe2dd2 100644 --- a/src/seqc/filter.py +++ b/src/seqc/filter.py @@ -27,13 +27,14 @@ def estimate_min_poly_t(fastq_files: list, platform) -> int: primer_length = platform.primer_length() if primer_length is None: raise RuntimeError( - 'provided platform does not have a defined primer length, and thus the ' - 'min_poly_t parameter cannot be estimated. Please provide --min-poly-t ' - 'explicitly in process_experiment.py.') + "provided platform does not have a defined primer length, and thus the " + "min_poly_t parameter cannot be estimated. Please provide --min-poly-t " + "explicitly in process_experiment.py." 
+ ) for f in fastq_files: mean = Reader(f).estimate_sequence_length()[0] available_nucleotides = max(0, mean - primer_length) - min_vals.append(floor(min(available_nucleotides * .8, 20))) + min_vals.append(floor(min(available_nucleotides * 0.8, 20))) return min(min_vals) @@ -65,12 +66,15 @@ def low_count(molecules, is_invalid, plot=False, ax=None): # these cells are empirically determined to have "transition" library sizes # that confound downstream analysis inflection_pt = np.min(np.where(np.abs(d2) == 0)[0]) - inflection_pt = int(inflection_pt * .9) + inflection_pt = int(inflection_pt * 0.9) except ValueError as e: - if e.args[0] == ('zero-size array to reduction operation minimum which has no ' - 'identity'): - log.notify('Low count filter passed-through; too few cells to estimate ' - 'inflection point.') + if e.args[0] == ( + "zero-size array to reduction operation minimum which has no " "identity" + ): + log.notify( + "Low count filter passed-through; too few cells to estimate " + "inflection point." + ) return is_invalid # can't estimate validity else: raise @@ -83,13 +87,13 @@ def low_count(molecules, is_invalid, plot=False, ax=None): if plot and ax: cms /= np.max(cms) # normalize to one ax.plot(np.arange(len(cms))[:inflection_pt], cms[:inflection_pt]) - ax.plot(np.arange(len(cms))[inflection_pt:], cms[inflection_pt:], c='indianred') - ax.hlines(cms[inflection_pt], *ax.get_xlim(), linestyle='--') - ax.vlines(inflection_pt, *ax.get_ylim(), linestyle='--') + ax.plot(np.arange(len(cms))[inflection_pt:], cms[inflection_pt:], c="indianred") + ax.hlines(cms[inflection_pt], *ax.get_xlim(), linestyle="--") + ax.vlines(inflection_pt, *ax.get_ylim(), linestyle="--") ax.set_xticklabels([]) - ax.set_xlabel('putative cell') - ax.set_ylabel('ECDF (Cell Size)') - ax.set_title('Cell Size') + ax.set_xlabel("putative cell") + ax.set_ylabel("ECDF (Cell Size)") + ax.set_title("Cell Size") ax.set_ylim((0, 1)) ax.set_xlim((0, len(cms))) @@ -118,8 +122,9 @@ def low_coverage(molecules, reads, is_invalid, plot=False, ax=None, filter_on=Tr if ms.shape[0] < 10 or rs.shape[0] < 10: log.notify( - 'Low coverage filter passed-through; too few cells to calculate ' - 'mixture model.') + "Low coverage filter passed-through; too few cells to calculate " + "mixture model." 
+ ) return is_invalid # get read / cell ratio, filter out low coverage cells @@ -133,7 +138,7 @@ def low_coverage(molecules, reads, is_invalid, plot=False, ax=None, filter_on=Tr gmm2.fit(col_ratio) if filter_on: - # check if adding a second component is necessary; if not, filter is pass-through + # check if adding a second component is necessary; if not, filter is pass-through filter_on = gmm2.bic(col_ratio) / gmm1.bic(col_ratio) < 0.95 if filter_on: @@ -154,15 +159,19 @@ def low_coverage(molecules, reads, is_invalid, plot=False, ax=None, filter_on=Tr try: seqc.plot.scatter.continuous(logms, ratio, colorbar=False, ax=ax, s=3) except LinAlgError: - warnings.warn('SEQC: Insufficient number of cells to calculate density for ' - 'coverage plot') + warnings.warn( + "SEQC: Insufficient number of cells to calculate density for " + "coverage plot" + ) ax.scatter(logms, ratio, s=3) - ax.set_xlabel('log10(molecules)') - ax.set_ylabel('reads / molecule') + ax.set_xlabel("log10(molecules)") + ax.set_ylabel("reads / molecule") if filter_on: - ax.set_title('Coverage: {:.2}%'.format(np.sum(failing) / len(failing) * 100)) + ax.set_title( + "Coverage: {:.2}%".format(np.sum(failing) / len(failing) * 100) + ) else: - ax.set_title('Coverage') + ax.set_title("Coverage") xmin, xmax = np.min(logms), np.max(logms) ymax = np.max(ratio) ax.set_xlim((xmin, xmax)) @@ -175,14 +184,25 @@ def low_coverage(molecules, reads, is_invalid, plot=False, ax=None, filter_on=Tr # plot the discarded cells in red, like other filters if filter_on: ax.scatter( - logms[res == np.argmin(means)], ratio[res == np.argmin(means)], - s=4, c='indianred') + logms[res == np.argmin(means)], + ratio[res == np.argmin(means)], + s=4, + c="indianred", + ) return is_invalid -def high_mitochondrial_rna(molecules, gene_ids, is_invalid, mini_summary_d, max_mt_content=0.2, - plot=False, ax=None, filter_on=True): +def high_mitochondrial_rna( + molecules, + gene_ids, + is_invalid, + mini_summary_d, + max_mt_content=0.2, + plot=False, + ax=None, + filter_on=True, +): """ Sets any cell with a fraction of mitochondrial mRNA greater than max_mt_content to invalid. @@ -199,9 +219,10 @@ def high_mitochondrial_rna(molecules, gene_ids, is_invalid, mini_summary_d, max_ :return: is_invalid, np.ndarray(dtype=bool), updated valid and invalid cells """ # identify % genes that are mitochondrial - mt_genes = np.fromiter(map(lambda x: x.startswith('MT-'), gene_ids), dtype=np.bool) - mt_molecules = np.ravel(molecules.tocsr()[~is_invalid, :].tocsc()[:, mt_genes].sum( - axis=1)) + mt_genes = np.fromiter(map(lambda x: x.startswith("MT-"), gene_ids), dtype=np.bool) + mt_molecules = np.ravel( + molecules.tocsr()[~is_invalid, :].tocsc()[:, mt_genes].sum(axis=1) + ) ms = np.ravel(molecules.tocsr()[~is_invalid, :].sum(axis=1)) ratios = mt_molecules / ms @@ -217,30 +238,37 @@ def high_mitochondrial_rna(molecules, gene_ids, is_invalid, mini_summary_d, max_ try: seqc.plot.scatter.continuous(ms, ratios, colorbar=False, ax=ax, s=3) except LinAlgError: - log.notify('Inadequate number of cells or MT gene abundance to plot MT ' - 'filter, no visual will be produced, but filter has been ' - 'applied.') + log.notify( + "Inadequate number of cells or MT gene abundance to plot MT " + "filter, no visual will be produced, but filter has been " + "applied." 
+ ) return is_invalid else: return is_invalid # nothing else to do here if filter_on and (np.sum(failing) != 0): - ax.scatter(ms[failing], ratios[failing], c='indianred', s=3) # failing cells + ax.scatter( + ms[failing], ratios[failing], c="indianred", s=3 + ) # failing cells xmax = np.max(ms) ymax = np.max(ratios) ax.set_xlim((0, xmax)) ax.set_ylim((0, ymax)) - ax.hlines(max_mt_content, *ax.get_xlim(), linestyle='--', colors='indianred') - ax.set_xlabel('total molecules') - ax.set_ylabel('mtRNA fraction') + ax.hlines(max_mt_content, *ax.get_xlim(), linestyle="--", colors="indianred") + ax.set_xlabel("total molecules") + ax.set_ylabel("mtRNA fraction") if filter_on: ax.set_title( - 'mtRNA Fraction: {:.2}%'.format(np.sum(failing) / len(failing) * 100)) - mini_summary_d['mt_rna_fraction'] = (np.sum(failing) *1.0 / len(failing)) * 100.0 + "mtRNA Fraction: {:.2}%".format(np.sum(failing) / len(failing) * 100) + ) + mini_summary_d["mt_rna_fraction"] = ( + np.sum(failing) * 1.0 / len(failing) + ) * 100.0 else: - ax.set_title('mtRNA Fraction') - mini_summary_d['mt_rna_fraction'] = 0.0 + ax.set_title("mtRNA Fraction") + mini_summary_d["mt_rna_fraction"] = 0.0 seqc.plot.xtick_vertical(ax=ax) - + return is_invalid @@ -268,14 +296,14 @@ def low_gene_abundance(molecules, is_invalid, plot=False, ax=None, filter_on=Tru # get line of best fit with warnings.catch_warnings(): # ignore scipy LinAlg warning about LAPACK bug. - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") regr = LinearRegression() regr.fit(x, y) # mark large residuals as failing yhat = regr.predict(x) residuals = yhat - y - failing = residuals > .15 + failing = residuals > 0.15 is_invalid = is_invalid.copy() if filter_on: @@ -286,33 +314,45 @@ def low_gene_abundance(molecules, is_invalid, plot=False, ax=None, filter_on=Tru try: seqc.plot.scatter.continuous(x, y, ax=ax, colorbar=False, s=3) except LinAlgError: - log.notify('Inadequate number of cells to plot low coverage filter no visual ' - 'will be produced, but filter has been applied.') + log.notify( + "Inadequate number of cells to plot low coverage filter no visual " + "will be produced, but filter has been applied." 
+ ) return is_invalid xmin, xmax = np.min(x), np.max(x) ymin, ymax = np.min(y), np.max(y) lx = np.linspace(xmin, xmax, 200) ly = m * lx + b - ax.plot(lx, np.ravel(ly), linestyle='--', c='indianred') + ax.plot(lx, np.ravel(ly), linestyle="--", c="indianred") if filter_on: - ax.scatter(x[failing], y[failing], c='indianred', s=3) + ax.scatter(x[failing], y[failing], c="indianred", s=3) ax.set_ylim((ymin, ymax)) ax.set_xlim((xmin, xmax)) - ax.set_xlabel('molecules (cell)') - ax.set_ylabel('genes (cell)') + ax.set_xlabel("molecules (cell)") + ax.set_ylabel("genes (cell)") if filter_on: - ax.set_title('Low Complexity: {:.2}%'.format(np.sum(failing) / len(failing) * 100)) + ax.set_title( + "Low Complexity: {:.2}%".format(np.sum(failing) / len(failing) * 100) + ) else: - ax.set_title('Low Complexity') + ax.set_title("Low Complexity") seqc.plot.xtick_vertical(ax=ax) return is_invalid def create_filtered_dense_count_matrix( - molecules: SparseFrame, reads: SparseFrame, mini_summary_d, max_mt_content=0.2, plot=False, - figname=None, filter_mitochondrial_rna: bool=True, filter_low_count: bool=True, - filter_low_coverage: bool=True, filter_low_gene_abundance: bool=True): + molecules: SparseFrame, + reads: SparseFrame, + mini_summary_d, + max_mt_content=0.2, + plot=False, + figname=None, + filter_mitochondrial_rna: bool = True, + filter_low_count: bool = True, + filter_low_coverage: bool = True, + filter_low_gene_abundance: bool = True, +): """ filter cells with low molecule counts, low read coverage, high mitochondrial content, and low gene detection. Returns a dense pd.DataFrame of filtered counts, the total @@ -334,17 +374,18 @@ def create_filtered_dense_count_matrix( cells_lost = OrderedDict() molecules_lost = OrderedDict() - if not molecules.columns.dtype.char == 'U': + if not molecules.columns.dtype.char == "U": if molecules.sum().sum() == 0: - raise EmptyMatrixError('Matrix is empty, cannot create dense matrix') + raise EmptyMatrixError("Matrix is empty, cannot create dense matrix") else: raise RuntimeError( - 'non-string column names detected. Please convert column names into ' - 'string gene symbols before calling this function.') + "non-string column names detected. Please convert column names into " + "string gene symbols before calling this function." 
+ ) if not isinstance(max_mt_content, float): - raise TypeError('Parameter max_mt_content must be of type float.') + raise TypeError("Parameter max_mt_content must be of type float.") if not 0 <= max_mt_content <= 1: - raise ValueError('Parameter max_mt_content must be in the interval [0, 1]') + raise ValueError("Parameter max_mt_content must be in the interval [0, 1]") # set data structures and original molecule counts molecules_data = molecules.data @@ -369,31 +410,48 @@ def additional_loss(new_filter, old_filter, data_matrix): ms = np.ravel(molecules_data.tocsr()[~is_invalid, :].sum(axis=1)).sum() rs = np.ravel(reads_data.tocsr()[~is_invalid, :].sum(axis=1)).sum() - mini_summary_d['avg_reads_per_molc'] = rs / ms + mini_summary_d["avg_reads_per_molc"] = rs / ms # filter low counts if filter_low_count: count_invalid = low_count(molecules_data, is_invalid, plot, ax_count) - cells_lost['low_count'], molecules_lost['low_count'] = additional_loss( - count_invalid, is_invalid, molecules_data) + cells_lost["low_count"], molecules_lost["low_count"] = additional_loss( + count_invalid, is_invalid, molecules_data + ) else: count_invalid = is_invalid # filter low coverage - cov_invalid = low_coverage(molecules_data, reads_data, count_invalid, plot, ax_cov, filter_low_coverage) - cells_lost['low_coverage'], molecules_lost['low_coverage'] = additional_loss(cov_invalid, count_invalid, molecules_data) + cov_invalid = low_coverage( + molecules_data, reads_data, count_invalid, plot, ax_cov, filter_low_coverage + ) + cells_lost["low_coverage"], molecules_lost["low_coverage"] = additional_loss( + cov_invalid, count_invalid, molecules_data + ) # filter high_mt_content if requested - mt_invalid = high_mitochondrial_rna(molecules_data, molecules_columns, cov_invalid, mini_summary_d, max_mt_content, - plot, ax_mt, filter_mitochondrial_rna) - cells_lost['high_mt'], molecules_lost['high_mt'] = additional_loss(mt_invalid, cov_invalid, molecules_data) - + mt_invalid = high_mitochondrial_rna( + molecules_data, + molecules_columns, + cov_invalid, + mini_summary_d, + max_mt_content, + plot, + ax_mt, + filter_mitochondrial_rna, + ) + cells_lost["high_mt"], molecules_lost["high_mt"] = additional_loss( + mt_invalid, cov_invalid, molecules_data + ) # filter low gene abundance - gene_invalid = low_gene_abundance(molecules_data, mt_invalid, plot, ax_gene, filter_low_gene_abundance) - cells_lost['low_gene_detection'], molecules_lost[ - 'low_gene_detection'] = additional_loss( - gene_invalid, mt_invalid, molecules_data) + gene_invalid = low_gene_abundance( + molecules_data, mt_invalid, plot, ax_gene, filter_low_gene_abundance + ) + ( + cells_lost["low_gene_detection"], + molecules_lost["low_gene_detection"], + ) = additional_loss(gene_invalid, mt_invalid, molecules_data) # construct dense matrix dense = molecules_data.tocsr()[~gene_invalid, :].todense() @@ -402,9 +460,10 @@ def additional_loss(new_filter, old_filter, data_matrix): dense = pd.DataFrame( dense, index=molecules.index[~gene_invalid], - columns=molecules.columns[nonzero_gene_count]) + columns=molecules.columns[nonzero_gene_count], + ) - mini_summary_d['avg_reads_per_cell'] = rs / len(dense.index) + mini_summary_d["avg_reads_per_cell"] = rs / len(dense.index) # describe cells cell_description = dense.sum(axis=1).describe() diff --git a/src/seqc/io.py b/src/seqc/io.py index 330af16..69026c2 100644 --- a/src/seqc/io.py +++ b/src/seqc/io.py @@ -125,7 +125,7 @@ def download_awscli(cls, link, prefix='./', overwrite=True, recursive=False): :param recursive: :return list: 
all downloaded filenames """ - if prefix is '': + if prefix == '': prefix = './' if overwrite is False: @@ -177,10 +177,10 @@ def download(cls, link, prefix='', overwrite=True, recursive=False): @staticmethod def upload_file(filename, bucket, key, boto=False): """upload filename to aws at s3://bucket/key/filename - :param key: key of S3 bucket to download + :param key: key of S3 bucket to upload :param bucket: name of S3 bucket - :param filename: name of file to download - :param boto: True if download using boto3 (default=False, uses awscli) + :param filename: name of file to upload + :param boto: True if upload using boto3 (default=False, uses awscli) """ if key.startswith('/'): @@ -196,8 +196,8 @@ def upload_file(filename, bucket, key, boto=False): if not boto: s3link = 's3://' + bucket + '/' + key cmd = 'aws s3 cp {fname} {s3link}'.format(fname=filename, s3link=s3link) - download_cmd = shlex.split(cmd) - Popen(download_cmd).wait() + upload_cmd = shlex.split(cmd) + Popen(upload_cmd).wait() else: client = boto3.client('s3') client.upload_file(filename, bucket, key) diff --git a/src/seqc/log.py b/src/seqc/log.py index 5db7280..eeec75f 100644 --- a/src/seqc/log.py +++ b/src/seqc/log.py @@ -8,32 +8,51 @@ import re -def setup_logger(filename): +def setup_logger(filename, is_debug): """create a simple log file in the cwd to track progress and any errors""" - logging.basicConfig(filename=filename, level=logging.DEBUG, filemode='w') + logging.basicConfig( + filename=filename, + level=logging.DEBUG if is_debug else logging.INFO, + filemode="w", + ) def info(message): """print a timestamped update for the user. :param message: """ - logging.info(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ':' + message) + logging.info(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ":" + message) + + +def warn(message): + """print a timestamped update for the user. + :param message: + """ + logging.warn(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ":" + message) def exception(): """log the most recent exception to an initialized logger""" - logging.exception(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ':main:') + logging.exception(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ":main:") def notify(message): """print a timestamped update for the user and log it to file""" info(message) - print('SEQC: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ': %s' % message) + print("SEQC: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ": %s" % message) + +def debug(message, module_name=None, func_name=None): -def debug(message): - logging.debug(datetime.now().strftime("%Y-%m-%d %H:%M:%S") + - ':%(module)s:%(funcName)s:' + ': %s' % message) + module_name = f" [{module_name}]" if module_name else "" + func_name = f" [{func_name}]" if func_name else "" + logging.debug( + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + module_name + + func_name + + " " + + message + ) def args(arguments): @@ -44,8 +63,11 @@ def args(arguments): :return: None """ arguments = vars(arguments) - info('Passed command line arguments: {}'.format( - json.dumps(arguments, separators=(',', ': '), indent=4, sort_keys=True))) + info( + "Passed command line arguments: {}".format( + json.dumps(arguments, separators=(",", ": "), indent=4, sort_keys=True) + ) + ) class LogData: @@ -59,35 +81,36 @@ class LogData: stored in. 
""" - _oldver = ('{divide}\nINPUT\n{divide}\n' - 'Total input reads:\t{n_fastq}\n' - '{divide}\nALIGNMENT (% FROM INPUT)\n{divide}\n' - 'Total reads aligned:\t{n_sam} ({prop_al}%)\n' - ' - Genomic alignments:\t{genomic} ({prop_gen}%)\n' - ' - PhiX alignments:\t{phi_x} ({prop_phix}%)\n' - ' - Transcriptome alignments:\t{trans} ({prop_trans}%)\n' - '{divide}\nFILTERING (% FROM ALIGNMENT)\n{divide}\n' - 'Genomic alignments:\t{genomic} ({bad_gen}%)\n' - 'PhiX alignments:\t{phi_x} ({bad_phi}%)\n' - 'Incorrect barcodes:\t{wrong_cb} ({bad_cb}%)\n' - 'Missing cell barcodes:\t{no_cell} ({bad_cell}%)\n' - 'Missing RMTs (same as above):\t{no_cell} ({bad_cell}%)\n' - 'N present in RMT:\t{rmt_N} ({bad_rmtN}%)\n' - 'Insufficient poly(T):\t{poly_t} ({bad_polyt}%)\n' - '{divide}\nCELL/MOLECULE COUNT DISTRIBUTION\n{divide}\n' - 'Total molecules:\t\t{tot_mc}\n' - 'Molecules lost:\t{mols_lost}\n' - 'Cells lost:\t{cells_lost}\n' - 'Cell description:\n{cell_desc}\n' - '{divide}\nSUMMARY\n{divide}\n' - 'Total retained reads:\t{n_good} ({prop_good}%)\n' - 'Total reads unaligned:\t{lost_al} ({prop_un}%)\n' - 'Total reads filtered:\t{n_bad} ({prop_bad}%)\n' - '{divide}\n' - ) + _oldver = ( + "{divide}\nINPUT\n{divide}\n" + "Total input reads:\t{n_fastq}\n" + "{divide}\nALIGNMENT (% FROM INPUT)\n{divide}\n" + "Total reads aligned:\t{n_sam} ({prop_al}%)\n" + " - Genomic alignments:\t{genomic} ({prop_gen}%)\n" + " - PhiX alignments:\t{phi_x} ({prop_phix}%)\n" + " - Transcriptome alignments:\t{trans} ({prop_trans}%)\n" + "{divide}\nFILTERING (% FROM ALIGNMENT)\n{divide}\n" + "Genomic alignments:\t{genomic} ({bad_gen}%)\n" + "PhiX alignments:\t{phi_x} ({bad_phi}%)\n" + "Incorrect barcodes:\t{wrong_cb} ({bad_cb}%)\n" + "Missing cell barcodes:\t{no_cell} ({bad_cell}%)\n" + "Missing RMTs (same as above):\t{no_cell} ({bad_cell}%)\n" + "N present in RMT:\t{rmt_N} ({bad_rmtN}%)\n" + "Insufficient poly(T):\t{poly_t} ({bad_polyt}%)\n" + "{divide}\nCELL/MOLECULE COUNT DISTRIBUTION\n{divide}\n" + "Total molecules:\t\t{tot_mc}\n" + "Molecules lost:\t{mols_lost}\n" + "Cells lost:\t{cells_lost}\n" + "Cell description:\n{cell_desc}\n" + "{divide}\nSUMMARY\n{divide}\n" + "Total retained reads:\t{n_good} ({prop_good}%)\n" + "Total reads unaligned:\t{lost_al} ({prop_un}%)\n" + "Total reads filtered:\t{n_bad} ({prop_bad}%)\n" + "{divide}\n" + ) @staticmethod - def string_to_regex(summary: str=None) -> str: + def string_to_regex(summary: str = None) -> str: """ converts the contents of seqc.stats.ExperimentalYield.output into a regex object that may contain duplicate definitions @@ -97,11 +120,11 @@ def string_to_regex(summary: str=None) -> str: if not summary: summary = ExperimentalYield.output replacements = [ - ('{divide}', '-*?'), - ('(', '\('), - (')', '\)'), - ('{', '(?P<'), - ('}', '>.*?)') + ("{divide}", "-*?"), + ("(", "\("), + (")", "\)"), + ("{", "(?P<"), + ("}", ">.*?)"), ] for r in replacements: summary = summary.replace(r[0], r[1]) @@ -117,7 +140,7 @@ def identify_duplicate_patterns(regex: str) -> dict: equal to the number of times each was replicated """ - name_pattern = '(\(\?P<)(.*?)(>\.\*\?\))' + name_pattern = "(\(\?P<)(.*?)(>\.\*\?\))" patterns = set() replicates = defaultdict(int) for mo in re.finditer(name_pattern, regex): @@ -136,8 +159,8 @@ def replace_replicated_patterns(regex: str, duplicated_pattern: str) -> str: :param duplicated_pattern: pattern_id :return regex: str, pattern without duplicate group definitions """ - old = '(?P<{}>.*?)'.format(duplicated_pattern) - new = '(?P={})'.format(duplicated_pattern) + 
old = "(?P<{}>.*?)".format(duplicated_pattern) + new = "(?P={})".format(duplicated_pattern) idx = regex.find(old) + len(old) return regex[:idx] + regex[idx:].replace(old, new) @@ -154,31 +177,47 @@ def dictionary_to_dataframe(cls, groupdict, col_label) -> pd.DataFrame: :return: pd.DataFrame containing log data """ index = ( - ('total', 'input_reads'), - ('total', 'reads_aligned'), - ('aligned', 'genomic'), - ('aligned', 'phi_x'), - ('aligned', 'transcriptome'), - ('filtered', 'genomic'), - ('filtered', 'phi_x'), - ('filtered', 'incorrect_barcodes'), - ('filtered', 'no_barcodes'), - ('filtered', 'CB_contains_N'), - ('filtered', 'RMT_contains_N'), - ('filtered', 'broken_capture_primer'), - ('filtered', 'low_complexity'), - ('summary', 'reads_retained'), - ('summary', 'reads_not_aligned'), - ('summary', 'reads_filtered'), - ('summary', 'total_molecules') + ("total", "input_reads"), + ("total", "reads_aligned"), + ("aligned", "genomic"), + ("aligned", "phi_x"), + ("aligned", "transcriptome"), + ("filtered", "genomic"), + ("filtered", "phi_x"), + ("filtered", "incorrect_barcodes"), + ("filtered", "no_barcodes"), + ("filtered", "CB_contains_N"), + ("filtered", "RMT_contains_N"), + ("filtered", "broken_capture_primer"), + ("filtered", "low_complexity"), + ("summary", "reads_retained"), + ("summary", "reads_not_aligned"), + ("summary", "reads_filtered"), + ("summary", "total_molecules"), + ) + data_list = ( + "n_fastq", + "n_sam", + "genomic", + "phi_x", + "trans", + "genomic", + "phi_x", + "wrong_cb", + "no_cell", + "cell_N", + "rmt_N", + "poly_t", + "dust", + "n_good", + "lost_al", + "n_bad", + "tot_mc", ) - data_list = ('n_fastq', 'n_sam', 'genomic', 'phi_x', 'trans', 'genomic', 'phi_x', - 'wrong_cb', 'no_cell', 'cell_N', 'rmt_N', 'poly_t', 'dust', 'n_good', - 'lost_al', 'n_bad', 'tot_mc') # account for older log version - if groupdict['wrong_cb'] == 'NA': - groupdict['wrong_cb'] = 0 + if groupdict["wrong_cb"] == "NA": + groupdict["wrong_cb"] = 0 data = list(map(lambda x: float(groupdict[x]), data_list)) @@ -202,7 +241,8 @@ def parse_special_fields(groupdict: dict) -> (tuple, list): "^\[\('low_count', (?P[0-9]+)\), " "\('low_coverage', (?P[0-9]+)\), " "\('high_mt', (?P[0-9]+)\), " - "\('low_gene_detection', (?P[0-9]+)\)\]$") + "\('low_gene_detection', (?P[0-9]+)\)\]$" + ) summary_pattern = ( "^count\s+(?P[0-9]+\.[0-9]+)\s" @@ -212,45 +252,60 @@ def parse_special_fields(groupdict: dict) -> (tuple, list): "25%\s+(?P[0-9]+\.[0-9]+)\s" "50%\s+(?P[0-9]+\.[0-9]+)\s" "75%\s+(?P[0-9]+\.[0-9]+)\s" - "max\s+(?P[0-9]+\.[0-9]+)\s?") + "max\s+(?P[0-9]+\.[0-9]+)\s?" + ) - cell = re.match(lost_pattern, groupdict['cells_lost'], re.M).groupdict() - mols = re.match(lost_pattern, groupdict['mols_lost'], re.M).groupdict() - desc = re.match(summary_pattern, groupdict['cell_desc'], re.M).groupdict() + cell = re.match(lost_pattern, groupdict["cells_lost"], re.M).groupdict() + mols = re.match(lost_pattern, groupdict["mols_lost"], re.M).groupdict() + desc = re.match(summary_pattern, groupdict["cell_desc"], re.M).groupdict() if not all((cell, mols, desc)): - raise ValueError('Regex failed to match log. Please check that you are using ' - 'a matched log/seqc pair.') + raise ValueError( + "Regex failed to match log. Please check that you are using " + "a matched log/seqc pair." 
+ ) index = ( - ('molecules_lost', 'low_count'), - ('molecules_lost', 'low_coverage'), - ('molecules_lost', 'high_mt'), - ('molecules_lost', 'low_gene_detection'), - ('cells_lost', 'low_count'), - ('cells_lost', 'low_coverage'), - ('cells_lost', 'high_mt'), - ('cells_lost', 'low_gene_detection'), - ('cell_summary', 'count'), - ('cell_summary', 'mean'), - ('cell_summary', 'std'), - ('cell_summary', 'min'), - ('cell_summary', '25%'), - ('cell_summary', '50%'), - ('cell_summary', '75%'), - ('cell_summary', 'max') + ("molecules_lost", "low_count"), + ("molecules_lost", "low_coverage"), + ("molecules_lost", "high_mt"), + ("molecules_lost", "low_gene_detection"), + ("cells_lost", "low_count"), + ("cells_lost", "low_coverage"), + ("cells_lost", "high_mt"), + ("cells_lost", "low_gene_detection"), + ("cell_summary", "count"), + ("cell_summary", "mean"), + ("cell_summary", "std"), + ("cell_summary", "min"), + ("cell_summary", "25%"), + ("cell_summary", "50%"), + ("cell_summary", "75%"), + ("cell_summary", "max"), ) data_list = ( - mols['low_count'], mols['low_coverage'], mols['high_mt'], - mols['low_gene_detection'], cell['low_count'], cell['low_coverage'], - cell['high_mt'], cell['low_gene_detection'], desc['count'], desc['mean'], - desc['std'], desc['min'], desc['low_quartile'], desc['median'], - desc['high_quartile'], desc['max']) + mols["low_count"], + mols["low_coverage"], + mols["high_mt"], + mols["low_gene_detection"], + cell["low_count"], + cell["low_coverage"], + cell["high_mt"], + cell["low_gene_detection"], + desc["count"], + desc["mean"], + desc["std"], + desc["min"], + desc["low_quartile"], + desc["median"], + desc["high_quartile"], + desc["max"], + ) data_list = list(map(lambda x: float(x), data_list)) return index, data_list @classmethod - def match_log(cls, log_file: str, pattern: str=None) -> dict: + def match_log(cls, log_file: str, pattern: str = None) -> dict: """ create a dictionary to hold data from SEQC summary. @@ -269,8 +324,8 @@ def get_match_object(pattern_): pattern_ = cls.replace_replicated_patterns(pattern_, k) # add beginning and end wildcards - pattern_ = '^.*?' + pattern_ + '.*?$' - with open(log_file, 'r') as f: + pattern_ = "^.*?" 
+ pattern_ + ".*?$" + with open(log_file, "r") as f: summary_data = f.read() mo = re.match(pattern_, summary_data, re.M | re.DOTALL) match_results = mo.groupdict() @@ -293,10 +348,11 @@ def parse_log(cls, logfile: str) -> pd.DataFrame: """ mo = LogData.match_log(logfile) return LogData.dictionary_to_dataframe( - mo, logfile.split('/')[-1].replace('.log', '')) + mo, logfile.split("/")[-1].replace(".log", "") + ) @classmethod - def parse_multiple(cls, directory: str, exclude: str='') -> pd.DataFrame: + def parse_multiple(cls, directory: str, exclude: str = "") -> pd.DataFrame: """ parse multiple SEQC logs into a pd.DataFrame object with a multi-index corresponding to the RUN SUMMARY section of the seqc.log object and a column @@ -313,13 +369,13 @@ def parse_multiple(cls, directory: str, exclude: str='') -> pd.DataFrame: for path, subdirs, files in os.walk(directory): for name in files: filepath = os.path.join(path, name) - if filepath.endswith('.log') and re.match(exclude, filepath) is None: + if filepath.endswith(".log") and re.match(exclude, filepath) is None: logs.append(filepath) frames = [cls.parse_log(f) for f in logs] # create column index - cols = pd.MultiIndex.from_tuples(list(map(lambda p: tuple(p.split('/')), logs))) + cols = pd.MultiIndex.from_tuples(list(map(lambda p: tuple(p.split("/")), logs))) df = pd.concat(frames, 1) df.columns = cols return df diff --git a/src/seqc/multialignment.py b/src/seqc/multialignment.py index dd63795..8c0809b 100644 --- a/src/seqc/multialignment.py +++ b/src/seqc/multialignment.py @@ -84,6 +84,7 @@ def find_component(self, iterable): """ return self[next(iter(iterable))] + def intersection(set_l): res = set_l[0] for s in set_l: @@ -109,7 +110,7 @@ def intersection(set_l): # if len(temp)>0: # res[g] = temp # return res - + # #def strip(genes): # # return tuple(sorted([int(g[2:]) for g in genes])) # def strip(genes): @@ -125,7 +126,7 @@ def intersection(set_l): # uf = UnionFind() # uf.union_all(obs.keys()) # set_membership, sets = uf.find_all(obs.keys()) - + # for s in sets: # d = {} # for k in np.array(list(obs.keys()))[set_membership == s]: @@ -143,26 +144,26 @@ def intersection(set_l): # for g in model: # if model[g]==1: # return g - - + + # def get_combinations(l): # res = [] # for i in range(len(l)): # res += itertools.combinations(l,i+1) # return res - + # # rank the different possible models by their scores # def best_fit_model(obs_s, coalignment_mat): # #obs_s = strip_model(obs) # gene_l = single_gene_list(obs_s) # From the list of observation create a list of unique single genes from which different models can be inferred - - + + # if len(obs_s) == 1: # if len(list(obs_s.keys())[0]) == 1: # return [{gene_l[0]:1}], NO_DISAMBIGUATION - + # possible_genes = intersection(list(obs_s.keys())) - + # #There is one gene that resolve the disambiguation # if len(possible_genes) == 1: # model = {} @@ -170,24 +171,24 @@ def intersection(set_l): # model[g] = 0 # model[list(possible_genes)[0]] = 1 # return [model], RESOLVED_GENE - + # #There is more than one gene that can explain it, no model can be decided # if len(possible_genes) > 1: # return [], NO_GENE_RESOLVED # #There are multiple competing models. 
For now we don't decide bewteen them # return [], MULTIPLE_MODELS -# # mod_score_list = [] +# # mod_score_list = [] # # for mod in get_combinations(gene_l): # # model = {} -# # for k in gene_l: +# # for k in gene_l: # # if k in mod: # # model[k] = 1 # # else: # # model[k] = 0 # # score = model_score(model, obs_s, coalignment_mat) # # mod_score_list.append((model,score)) - + # #Here to decide if there is one model that's obviously better # # return mod_score_list, MULTIPLE_MODELS @@ -217,16 +218,16 @@ def intersection(set_l): # tot[gene] = model[gene]*(observed[gene,]/coalignment_mat[gene][gene,]) # keys = get_combinations(model.keys()) #get a list of all possible molecule combinations - + # # key is a set of genes and the expected number of reads for it is the sum of expected reads from all genes shared by the key, # # these in turn are the total reads for a gene (extrapoletaed from the uniqely mapped) multiplied by the coalignment factor (present in the coalignment matrix) # # e.g. if A has 20% coalignment with B and there are 80 reads mapped uniquely to A, we expect 80/0.8 * 0.2 = 20 reads to be mapped to AB from A (and more from B) -# for k in keys: +# for k in keys: # k = tuple(sorted(k)) # sum = 0 # for gene in k: # #Patch for SC000 -# if gene==0: +# if gene==0: # if k==(0,): # sum=1 # else: @@ -235,7 +236,7 @@ def intersection(set_l): # elif k in coalignment_mat[gene]: # sum += tot[gene]*coalignment_mat[gene][k] # exp[k] = sum - + # score = calc_score(observed, exp) # return score @@ -260,4 +261,3 @@ def intersection(set_l): # l.append(g) # return list(set(l)) - diff --git a/src/seqc/notebooks/analysis_template.json b/src/seqc/notebooks/analysis_template.json index 122eccb..c6394d0 100644 --- a/src/seqc/notebooks/analysis_template.json +++ b/src/seqc/notebooks/analysis_template.json @@ -112,7 +112,7 @@ "outputs": [], "source": [ "# load counts \n", - "columns = pd.read_csv(DATA, nrows=1, header=None).as_matrix()\n", + "columns = pd.read_csv(DATA, nrows=1, header=None).values\n", "if columns[0][0] == 'sample_number':\n", " counts = pd.read_csv(DATA, index_col=[0, 1])\n", "else:\n", diff --git a/src/seqc/platforms.py b/src/seqc/platforms.py index bc16f52..4f8f6d2 100644 --- a/src/seqc/platforms.py +++ b/src/seqc/platforms.py @@ -14,7 +14,9 @@ class AbstractPlatform: __metaclass__ = ABCMeta - def __init__(self, barcodes_len, filter_lonely_triplets=False, filter_low_count=True): + def __init__( + self, barcodes_len, filter_lonely_triplets=False, filter_low_count=True + ): """ Ctor for the abstract class. 
barcodes_len is a list of barcodes lengths, check_barcodes is a flag signalling whether or not the barcodes are known apriori @@ -26,6 +28,7 @@ def __init__(self, barcodes_len, filter_lonely_triplets=False, filter_low_count= self._filter_lonely_triplets = filter_lonely_triplets self._filter_low_count = filter_low_count + @staticmethod def factory(type): if type == "in_drop": return in_drop() @@ -128,7 +131,6 @@ def apply_rmt_correction(self, ra, error_rate): class in_drop(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [-1, 8]) @@ -139,30 +141,30 @@ def check_spacer(cls, sequence): :param sequence: fastq sequence data :returns: (cell_barcode, rmt, poly_t) """ - assert ~sequence.endswith(b'\n') + assert ~sequence.endswith(b"\n") identifier = sequence[24:28] - if identifier == b'CGCC': + if identifier == b"CGCC": cb1 = sequence[:8] cb2 = sequence[30:38] rmt = sequence[38:44] poly_t = sequence[44:] - elif identifier == b'ACGC': + elif identifier == b"ACGC": cb1 = sequence[:9] cb2 = sequence[31:39] rmt = sequence[39:45] poly_t = sequence[45:] - elif identifier == b'GACG': + elif identifier == b"GACG": cb1 = sequence[:10] cb2 = sequence[32:40] rmt = sequence[40:46] poly_t = sequence[46:] - elif identifier == b'TGAC': + elif identifier == b"TGAC": cb1 = sequence[:11] cb2 = sequence[33:41] rmt = sequence[41:47] poly_t = sequence[47:] else: - return b'', b'', b'' + return b"", b"", b"" cell = cb1 + cb2 return cell, rmt, poly_t @@ -182,16 +184,19 @@ def merge_function(self, g, b): :param b: barcode fastq sequence data :return: annotated genomic sequence. """ - pattern = re.compile(b'(.{8,11}?)(GAGTGATTGCTTGTGACGCCTT){s<=2}(.{8})(.{6})(.*?)') + pattern = re.compile( + b"(.{8,11}?)(GAGTGATTGCTTGTGACGCCTT){s<=2}(.{8})(.{6})(.*?)" + ) cell, rmt, poly_t = self.check_spacer(b.sequence[:-1]) if not cell: try: cell1, spacer, cell2, rmt, poly_t = re.match( - pattern, b.sequence[:-1]).groups() + pattern, b.sequence[:-1] + ).groups() cell = cell1 + cell2 except AttributeError: - cell, rmt, poly_t = b'', b'', b'' - g.add_annotation((b'', cell, rmt, poly_t)) + cell, rmt, poly_t = b"", b"", b"" + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -200,11 +205,13 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=1) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=1 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -214,11 +221,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate) + return rmt_correction.in_drop(ra, error_rate) class in_drop_v2(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [-1, 8]) @@ -229,30 +235,30 @@ def check_spacer(cls, sequence): :param sequence: fastq sequence data :returns: (cell_barcode, rmt, poly_t) """ - assert ~sequence.endswith(b'\n') + assert ~sequence.endswith(b"\n") identifier = sequence[24:28] - if identifier == b'CGCC': + if identifier == b"CGCC": cb1 = sequence[:8] cb2 = sequence[30:38] rmt = sequence[38:46] poly_t = sequence[46:] - elif identifier == b'ACGC': + elif identifier == b"ACGC": cb1 = sequence[:9] cb2 = 
sequence[31:39] rmt = sequence[39:47] poly_t = sequence[47:] - elif identifier == b'GACG': + elif identifier == b"GACG": cb1 = sequence[:10] cb2 = sequence[32:40] rmt = sequence[40:48] poly_t = sequence[48:] - elif identifier == b'TGAC': + elif identifier == b"TGAC": cb1 = sequence[:11] cb2 = sequence[33:41] rmt = sequence[41:49] poly_t = sequence[49:] else: - return b'', b'', b'' + return b"", b"", b"" cell = cb1 + cb2 return cell, rmt, poly_t @@ -272,16 +278,19 @@ def merge_function(self, g, b): :param b: barcode fastq sequence data :return: annotated genomic sequence. """ - pattern = re.compile(b'(.{8,11}?)(GAGTGATTGCTTGTGACGCCAA){s<=2}(.{8})(.{8})(.*?)') + pattern = re.compile( + b"(.{8,11}?)(GAGTGATTGCTTGTGACGCCAA){s<=2}(.{8})(.{8})(.*?)" + ) cell, rmt, poly_t = self.check_spacer(b.sequence[:-1]) if not cell: try: cell1, spacer, cell2, rmt, poly_t = re.match( - pattern, b.sequence[:-1]).groups() + pattern, b.sequence[:-1] + ).groups() cell = cell1 + cell2 except AttributeError: - cell, rmt, poly_t = b'', b'', b'' - g.add_annotation((b'', cell, rmt, poly_t)) + cell, rmt, poly_t = b"", b"", b"" + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -290,11 +299,13 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=2) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=2 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -304,11 +315,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate) + return rmt_correction.in_drop(ra, error_rate) class in_drop_v3(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [-1, 8]) @@ -337,7 +347,7 @@ def merge_function(self, g, b): poly_t = seq[16:] # bc is in a fixed position in the name; assumes 8bp indices. 
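The spacer patterns used by the in_drop merge functions above rely on the `{s<=2}` fuzzy-substitution syntax, which is provided by the third-party `regex` package rather than the standard-library `re` module. The sketch below is illustrative only (made-up read bytes, `regex` imported directly) and shows how such a pattern still matches a read whose spacer carries a sequencing error:

```python
# Minimal sketch, not the SEQC implementation: fuzzy spacer matching with the
# third-party `regex` package, which supports the {s<=2} syntax used above.
import regex

pattern = regex.compile(b"(.{8,11}?)(GAGTGATTGCTTGTGACGCCAA){s<=2}(.{8})(.{8})(.*?)")

# made-up read: 8 bp cb1, spacer with one substitution, 8 bp cb2, 8 bp RMT, poly-T
read = b"AAACCCGG" + b"GAGTGATTGCTTGTGACGCCAT" + b"TTACGGCA" + b"ACGTACGT" + b"TTTT"

m = pattern.match(read)
cell1, spacer, cell2, rmt, poly_t = m.groups()
cell = cell1 + cell2  # match succeeds despite the single spacer mismatch
# note: the trailing non-greedy group matches lazily, so poly_t is b"" here
```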
cell1 = g.name.strip()[-17:-9] - g.add_annotation((b'', cell1 + cell2, rmt, poly_t)) + g.add_annotation((b"", cell1 + cell2, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -348,7 +358,6 @@ def apply_rmt_correction(self, ra, error_rate): class in_drop_v4(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [-1, 8]) @@ -376,7 +385,7 @@ def merge_function(self, g, b): cell2 = seq[12:20] rmt = seq[20:28] poly_t = seq[28:] - g.add_annotation((b'', cell1 + cell2, rmt, poly_t)) + g.add_annotation((b"", cell1 + cell2, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -385,10 +394,12 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=2) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=2 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -397,16 +408,17 @@ def apply_rmt_correction(self, ra, error_rate): :param ra: Read array :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate) + return rmt_correction.in_drop(ra, error_rate) class in_drop_v5(AbstractPlatform): - def __init__(self, potential_barcodes=None): AbstractPlatform.__init__(self, [-1, 8]) self.potential_barcodes = potential_barcodes if self.potential_barcodes is not None: - self.potential_encoded_bcs = set(DNA3Bit.encode(pb) for pb in self.potential_barcodes) + self.potential_encoded_bcs = set( + DNA3Bit.encode(pb) for pb in self.potential_barcodes + ) @classmethod def check_spacer(cls, sequence): @@ -415,22 +427,22 @@ def check_spacer(cls, sequence): :param sequence: fastq sequence data :returns: (cb1, rest), where rest includes cb2, rmt, poly_t """ - assert ~sequence.endswith(b'\n') + assert ~sequence.endswith(b"\n") identifier = sequence[24:28] - if identifier == b'CGCC': + if identifier == b"CGCC": cb1 = sequence[:8] rest = sequence[30:] - elif identifier == b'ACGC': + elif identifier == b"ACGC": cb1 = sequence[:9] rest = sequence[31:] - elif identifier == b'GACG': + elif identifier == b"GACG": cb1 = sequence[:10] rest = sequence[32:] - elif identifier == b'TGAC': + elif identifier == b"TGAC": cb1 = sequence[:11] rest = sequence[33:] else: - return b'', b'' + return b"", b"" return cb1, rest @@ -451,7 +463,7 @@ def check_cb2(self, rest): rmt = rest[9:17] poly_t = rest[17:] else: - return b'', b'', b'' + return b"", b"", b"" return cb2, rmt, poly_t @@ -469,7 +481,7 @@ def build_cb2_barcodes(cls, barcode_files, max_ed=1): # Build set of all potential correct and incorrect cb2 potential_barcodes = set() cb2_file = barcode_files[1] - with open(cb2_file, 'r') as f: + with open(cb2_file, "r") as f: valid_barcodes = set([line.strip() for line in f.readlines()]) # This will work for any number of allowable mismatches for bc in valid_barcodes: @@ -478,9 +490,11 @@ def build_cb2_barcodes(cls, barcode_files, max_ed=1): invalid_bc = [[nt] for nt in bc] for ind in inds: valid_nt = bc[ind] - invalid_bc[ind] = [nt for nt in ['A', 'C', 'G', 'T', 'N'] if nt != valid_nt] + invalid_bc[ind] = [ + nt for nt in ["A", "C", "G", "T", "N"] if nt != valid_nt + ] for mut in itertools.product(*invalid_bc): - potential_barcodes.add(''.join(mut)) + potential_barcodes.add("".join(mut)) 
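The loop in `build_cb2_barcodes` above expands every valid CB2 barcode into all sequences within the allowed number of mismatches before encoding them. A small self-contained sketch of the same idea for a single barcode (the helper name is hypothetical, not part of SEQC):

```python
# Toy sketch of the single-mismatch expansion used by build_cb2_barcodes:
# at each position, substitute every other nucleotide (including N).
import itertools

def one_mismatch_variants(bc, alphabet="ACGTN"):
    variants = set()
    for i, valid_nt in enumerate(bc):
        choices = [[nt] for nt in bc]
        choices[i] = [nt for nt in alphabet if nt != valid_nt]
        for mut in itertools.product(*choices):
            variants.add("".join(mut))
    return variants

# an 8 bp barcode yields 8 positions x 4 alternative letters = 32 variants
assert len(one_mismatch_variants("ACGTACGT")) == 32
```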
potential_barcodes = set([pb.encode() for pb in potential_barcodes]) return cls(potential_barcodes=potential_barcodes) @@ -503,15 +517,15 @@ def merge_function(self, g, b): """ cb1, rest = self.check_spacer(b.sequence[:-1]) if not cb1: - cell, rmt, poly_t = b'', b'', b'' + cell, rmt, poly_t = b"", b"", b"" else: cb2, rmt, poly_t = self.check_cb2(rest) if not cb2: - cell = b'' + cell = b"" else: cell = cb1 + cb2 - g.add_annotation((b'', cell, rmt, poly_t)) + g.add_annotation((b"", cell, rmt, poly_t)) return g def extract_barcodes(self, seq): @@ -541,11 +555,13 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=1) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=1 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -555,11 +571,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate) + return rmt_correction.in_drop(ra, error_rate) class drop_seq(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [12]) @@ -582,7 +597,7 @@ def merge_function(self, g, b): cell = b.sequence[:12] rmt = b.sequence[12:20] poly_t = b.sequence[20:-1] - g.add_annotation((b'', cell, rmt, poly_t)) + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -591,11 +606,11 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ barcode_correction.drop_seq(ra) - return None + return None, None def apply_rmt_correction(self, ra, error_rate): """ @@ -605,11 +620,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - log.info('Drop-seq barcodes do not support RMT correction') + log.info("Drop-seq barcodes do not support RMT correction") class mars1_seq(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [4, 8], True, False) @@ -632,9 +646,14 @@ def merge_function(self, g, b): :return: annotated genomic sequence. 
""" - *name_fields, pool, cell, rmt = g.name[1:-1].split(b':') - g.name = (b'@' + b':'.join((pool, cell, rmt, b'')) + b';' + - b':'.join(name_fields) + b'\n') + *name_fields, pool, cell, rmt = g.name[1:-1].split(b":") + g.name = ( + b"@" + + b":".join((pool, cell, rmt, b"")) + + b";" + + b":".join(name_fields) + + b"\n" + ) return g def apply_barcode_correction(self, ra, barcode_files): @@ -643,12 +662,14 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ # todo: verify max edit distance - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -658,11 +679,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - log.info('Mars-seq barcodes do not support RMT correction') + log.info("Mars-seq barcodes do not support RMT correction") class mars2_seq(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [4, 8], True, False) @@ -687,7 +707,7 @@ def merge_function(self, g, b): cell = seq[:7] rmt = seq[7:15] poly_t = seq[15:] - g.add_annotation((b'', pool + cell, rmt, poly_t)) + g.add_annotation((b"", pool + cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -696,12 +716,14 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ # todo: verify max edit distance - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -711,11 +733,10 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - log.info('Mars-seq barcodes do not support RMT correction') + log.info("Mars-seq barcodes do not support RMT correction") class mars_germany(AbstractPlatform): - def __init__(self): AbstractPlatform.__init__(self, [10], True, False) @@ -725,13 +746,13 @@ def primer_length(self): def merge_function(self, g, b): pool = g.sequence.strip()[3:7] # 4 bp # strip() is necessary in case there is a truncated read. 
\n=good, \n\n=bad - g.sequence = g.sequence.strip()[7:] + b'\n' + g.sequence = g.sequence.strip()[7:] + b"\n" # Need to skip over the quality as well - g.quality = g.quality.strip()[7:] + b'\n' + g.quality = g.quality.strip()[7:] + b"\n" seq = b.sequence.strip() cell = seq[:6] # 6 bp rmt = seq[6:12] # 6 bp - g.add_annotation((b'', pool + cell, rmt, b'')) + g.add_annotation((b"", pool + cell, rmt, b"")) return g def apply_barcode_correction(self, ra, barcode_files): @@ -740,12 +761,14 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ # todo: verify max edit distance - error_rate = barcode_correction.in_drop(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.in_drop( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -755,7 +778,7 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - log.info('Mars-seq barcodes do not support RMT correction') + log.info("Mars-seq barcodes do not support RMT correction") class ten_x(AbstractPlatform): @@ -788,15 +811,17 @@ def merge_function(self, g, b): # bc is in a fixed position in the name; assumes 10bp indices. cell = g.name.strip()[-23:-9] poly_t = combined[10:] - g.add_annotation((b'', cell, rmt, poly_t)) + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): - error_rate = barcode_correction.ten_x_barcode_correction(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.ten_x_barcode_correction( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): - rmt_correction.in_drop(ra, error_rate=0.02) + return rmt_correction.in_drop(ra, error_rate=0.02) class ten_x_v2(AbstractPlatform): @@ -826,7 +851,7 @@ def merge_function(self, g, b): cell = combined[0:16] # v2 chemistry has 16bp barcodes rmt = combined[16:26] # 10 baselength RMT poly_t = combined[26:] - g.add_annotation((b'', cell, rmt, poly_t)) + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -835,12 +860,14 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ # todo: verify max edit distance - error_rate = barcode_correction.ten_x_barcode_correction(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.ten_x_barcode_correction( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -850,7 +877,7 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate=0.02) + return rmt_correction.in_drop(ra, error_rate=0.02) class ten_x_v3(AbstractPlatform): @@ -881,10 +908,10 @@ def merge_function(self, g, b): :return: annotated genomic sequence. 
""" combined = b.sequence.strip() - cell = combined[0:self.cb_len] - rmt = combined[self.cb_len:self.cb_len + self.mb_len] - poly_t = combined[self.cb_len + self.mb_len:] - g.add_annotation((b'', cell, rmt, poly_t)) + cell = combined[0 : self.cb_len] + rmt = combined[self.cb_len : self.cb_len + self.mb_len] + poly_t = combined[self.cb_len + self.mb_len :] + g.add_annotation((b"", cell, rmt, poly_t)) return g def apply_barcode_correction(self, ra, barcode_files): @@ -893,12 +920,14 @@ def apply_barcode_correction(self, ra, barcode_files): :param ra: Read array :param barcode_files: Valid barcodes files - :returns: Error rate table + :returns: Error rate table and pre-/post-corrected barcodes """ # todo: verify max edit distance - error_rate = barcode_correction.ten_x_barcode_correction(ra, self, barcode_files, max_ed=0) - return error_rate + error_rate, df_correction = barcode_correction.ten_x_barcode_correction( + ra, self, barcode_files, max_ed=0 + ) + return error_rate, df_correction def apply_rmt_correction(self, ra, error_rate): """ @@ -908,4 +937,4 @@ def apply_rmt_correction(self, ra, error_rate): :param error_rate: Error rate table from apply_barcode_correction """ - rmt_correction.in_drop(ra, error_rate=0.02) + return rmt_correction.in_drop(ra, error_rate=0.02) diff --git a/src/seqc/plot.py b/src/seqc/plot.py index 70b4edf..4bbacec 100644 --- a/src/seqc/plot.py +++ b/src/seqc/plot.py @@ -11,74 +11,63 @@ # make matplotlib logger less verbose import logging -logging.getLogger('matplotlib').setLevel(logging.WARNING) + +logging.getLogger("matplotlib").setLevel(logging.WARNING) try: - os.environ['DISPLAY'] + os.environ["DISPLAY"] except KeyError: - matplotlib.use('Agg') + matplotlib.use("Agg") import matplotlib.pyplot as plt + with warnings.catch_warnings(): - warnings.simplefilter('ignore') # catch warnings that system can't find fonts + warnings.simplefilter("ignore") # catch warnings that system can't find fonts fm = font_manager.fontManager - fm.findfont('Raleway') - fm.findfont('Lato') + fm.findfont("Raleway") + fm.findfont("Lato") warnings.filterwarnings(action="ignore", module="matplotlib", message="^tight_layout") -dark_gray = '.15' +dark_gray = ".15" -_colors = ['#4C72B0', '#55A868', '#C44E52', - '#8172B2', '#CCB974', '#64B5CD'] +_colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#CCB974", "#64B5CD"] style_dictionary = { - 'figure.figsize': (3, 3), - 'figure.facecolor': 'white', - - 'figure.dpi': 200, - 'savefig.dpi': 200, - - 'text.color': 'k', - + "figure.figsize": (3, 3), + "figure.facecolor": "white", + "figure.dpi": 200, + "savefig.dpi": 200, + "text.color": "k", "legend.frameon": False, "legend.numpoints": 1, "legend.scatterpoints": 1, - - 'font.family': ['sans-serif'], - 'font.serif': ['Computer Modern Roman', 'serif'], - 'font.monospace': ['Inconsolata', 'Computer Modern Typewriter', 'Monaco'], - 'font.sans-serif': ['Helvetica', 'Lato', 'sans-serif'], - - 'patch.facecolor': _colors[0], - 'patch.edgecolor': 'none', - - 'grid.linestyle': "-", - - 'axes.labelcolor': dark_gray, - 'axes.facecolor': 'white', - 'axes.linewidth': 1., - 'axes.grid': False, - 'axes.axisbelow': False, - 'axes.edgecolor': dark_gray, - 'axes.prop_cycle': cycler('color', _colors), - - 'lines.solid_capstyle': 'round', - 'lines.color': _colors[0], - 'lines.markersize': 4, - - 'image.cmap': 'viridis', - 'image.interpolation': 'none', - - 'xtick.direction': 'in', - 'xtick.major.size': 4, - 'xtick.minor.size': 2, - 'xtick.color': dark_gray, - - 'ytick.direction': 'in', - 'ytick.major.size': 4, - 
'ytick.minor.size': 2, + "font.family": ["sans-serif"], + "font.serif": ["Computer Modern Roman", "serif"], + "font.monospace": ["Inconsolata", "Computer Modern Typewriter", "Monaco"], + "font.sans-serif": ["Helvetica", "Lato", "sans-serif"], + "patch.facecolor": _colors[0], + "patch.edgecolor": "none", + "grid.linestyle": "-", + "axes.labelcolor": dark_gray, + "axes.facecolor": "white", + "axes.linewidth": 1.0, + "axes.grid": False, + "axes.axisbelow": False, + "axes.edgecolor": dark_gray, + "axes.prop_cycle": cycler("color", _colors), + "lines.solid_capstyle": "round", + "lines.color": _colors[0], + "lines.markersize": 4, + "image.cmap": "viridis", + "image.interpolation": "none", + "xtick.direction": "in", + "xtick.major.size": 4, + "xtick.minor.size": 2, + "xtick.color": dark_gray, + "ytick.direction": "in", + "ytick.major.size": 4, + "ytick.minor.size": 2, "ytick.color": dark_gray, - } matplotlib.rcParams.update(style_dictionary) @@ -86,7 +75,7 @@ def refresh_rc(): matplotlib.rcParams.update(style_dictionary) - print('rcParams updated') + print("rcParams updated") class FigureGrid: @@ -153,7 +142,7 @@ def detick(self, x=True, y=True): for ax in self: detick(ax, x=x, y=y) - def savefig(self, filename, pad_inches=0.1, bbox_inches='tight', *args, **kwargs): + def savefig(self, filename, pad_inches=0.1, bbox_inches="tight", *args, **kwargs): """ wrapper for savefig, including necessary paramters to avoid cut-off @@ -165,7 +154,8 @@ def savefig(self, filename, pad_inches=0.1, bbox_inches='tight', *args, **kwargs :return: """ self.figure.savefig( - filename, pad_inches=pad_inches, bbox_inches=bbox_inches, *args, **kwargs) + filename, pad_inches=pad_inches, bbox_inches=bbox_inches, *args, **kwargs + ) def detick(ax=None, x=True, y=True): @@ -185,27 +175,27 @@ def despine(ax=None, top=True, right=True, bottom=False, left=False) -> None: # set spines if top: - ax.spines['top'].set_visible(False) + ax.spines["top"].set_visible(False) if right: - ax.spines['right'].set_visible(False) + ax.spines["right"].set_visible(False) if bottom: - ax.spines['bottom'].set_visible(False) + ax.spines["bottom"].set_visible(False) if left: - ax.spines['left'].set_visible(False) + ax.spines["left"].set_visible(False) # set ticks if top and bottom: - ax.xaxis.set_ticks_position('none') + ax.xaxis.set_ticks_position("none") elif top: - ax.xaxis.set_ticks_position('bottom') + ax.xaxis.set_ticks_position("bottom") elif bottom: - ax.xaxis.set_ticks_position('top') + ax.xaxis.set_ticks_position("top") if left and right: - ax.yaxis.set_ticks_position('none') + ax.yaxis.set_ticks_position("none") elif left: - ax.yaxis.set_ticks_position('right') + ax.yaxis.set_ticks_position("right") elif right: - ax.yaxis.set_ticks_position('left') + ax.yaxis.set_ticks_position("left") def xtick_vertical(ax=None): @@ -215,7 +205,7 @@ def xtick_vertical(ax=None): xt = ax.get_xticks() if np.all(xt.astype(int) == xt): # ax.get_xticks() returns floats xt = xt.astype(int) - ax.set_xticklabels(xt, rotation='vertical') + ax.set_xticklabels(xt, rotation="vertical") def equalize_numerical_tick_number(ax=None): @@ -250,7 +240,7 @@ def map_categorical_to_cmap(data: np.ndarray, cmap=plt.get_cmap()): """ categories = np.unique(data) n = len(categories) - if isinstance(cmap, str) and 'random' in cmap: + if isinstance(cmap, str) and "random" in cmap: colors = np.random.rand(n, 3) else: colors = cmap(np.linspace(0, 1, n)) @@ -259,8 +249,13 @@ def map_categorical_to_cmap(data: np.ndarray, cmap=plt.get_cmap()): def add_legend_to_categorical_vector( - 
colors: np.ndarray, labels, ax, loc='best', # bbox_to_anchor=(0.98, 0.5), - markerscale=0.75, **kwargs): + colors: np.ndarray, + labels, + ax, + loc="best", # bbox_to_anchor=(0.98, 0.5), + markerscale=0.75, + **kwargs +): """ Add a legend to a plot where the color scale was set by discretizing a colormap. @@ -272,18 +267,31 @@ def add_legend_to_categorical_vector( """ artists = [] for c in colors: - artists.append(plt.Line2D((0, 1), (0, 0), color=c, marker='o', linestyle='')) + artists.append(plt.Line2D((0, 1), (0, 0), color=c, marker="o", linestyle="")) ax.legend( - artists, labels, loc=loc, markerscale=markerscale, # bbox_to_anchor=bbox_to_anchor, - **kwargs) + artists, + labels, + loc=loc, + markerscale=markerscale, # bbox_to_anchor=bbox_to_anchor, + **kwargs + ) class scatter: - @staticmethod def categorical( - x, y, c, ax=None, cmap=plt.get_cmap(), legend=True, legend_kwargs=None, - randomize=True, remove_ticks=False, *args, **kwargs): + x, + y, + c, + ax=None, + cmap=plt.get_cmap(), + legend=True, + legend_kwargs=None, + randomize=True, + remove_ticks=False, + *args, + **kwargs + ): """ wrapper for scatter wherein the output should be colored by a categorical vector c @@ -313,21 +321,31 @@ def categorical( else: ind = np.argsort(np.ravel(c)) - ax.scatter(np.ravel(x)[ind], np.ravel(y)[ind], c=color_vector[ind], *args, - **kwargs) + ax.scatter( + np.ravel(x)[ind], np.ravel(y)[ind], c=color_vector[ind], *args, **kwargs + ) if remove_ticks: ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) labels, colors = zip(*sorted(category_to_color.items())) if legend: - add_legend_to_categorical_vector(colors, labels, ax, markerscale=2, - **legend_kwargs) + add_legend_to_categorical_vector( + colors, labels, ax, markerscale=2, **legend_kwargs + ) return ax @staticmethod - def continuous(x, y, c=None, ax=None, colorbar=True, randomize=True, - remove_ticks=False, **kwargs): + def continuous( + x, + y, + c=None, + ax=None, + colorbar=True, + randomize=True, + remove_ticks=False, + **kwargs + ): """ wrapper for scatter wherein the coordinates x and y are colored according to a continuous vector c @@ -379,16 +397,15 @@ def tatarize(n): :return: """ - with open(os.path.expanduser('~/.seqc/tools/tatarize_269.txt')) as f: + with open(os.path.expanduser("~/.seqc/tools/tatarize_269.txt")) as f: s = f.read().split('","') - s[0] = s[0].replace('{"', '') - s[-1] = s[-1].replace('"}', '') + s[0] = s[0].replace('{"', "") + s[-1] = s[-1].replace('"}', "") s = [hex2color(s) for s in s] return s[:n] class Diagnostics: - @staticmethod def mitochondrial_fraction(data: pd.DataFrame, ax=None): """plot the fraction of mRNA that are of mitochondrial origin for each cell. 
@@ -398,7 +415,7 @@ def mitochondrial_fraction(data: pd.DataFrame, ax=None): :return: ax """ - mt_genes = data.molecules.columns[data.molecules.columns.str.contains('MT-')] + mt_genes = data.molecules.columns[data.molecules.columns.str.contains("MT-")] mt_counts = data.molecules[mt_genes].sum(axis=1) library_size = data.molecules.sum(axis=1) @@ -406,9 +423,9 @@ def mitochondrial_fraction(data: pd.DataFrame, ax=None): ax = plt.gca() scatter.continuous(library_size, mt_counts / library_size) - ax.set_title('Mitochondrial Fraction') - ax.set_xlabel('Total Gene Expression') - ax.set_ylabel('Mitochondrial Gene Expression') + ax.set_title("Mitochondrial Fraction") + ax.set_xlabel("Total Gene Expression") + ax.set_ylabel("Mitochondrial Gene Expression") _, xmax = ax.get_xlim() ax.set_xlim((None, xmax)) _, ymax = ax.get_ylim() @@ -418,30 +435,30 @@ def mitochondrial_fraction(data: pd.DataFrame, ax=None): @staticmethod def pca_components(fig_name, variance_ratio, pca_comps): - ''' + """ :param fig_name: name for the figure :param variance_ratio: variance ratios of at least 20 pca components :param pca_comps: pca components of cells - ''' + """ fig = FigureGrid(4, max_cols=2) ax_pca, ax_pca12, ax_pca13, ax_pca23 = iter(fig) - ax_pca.plot(variance_ratio[0:20]*100.0, c = '#1f77b4') - ax_pca.set_xlabel('pca components') - ax_pca.set_ylabel('explained variance') - ax_pca.set_xlim([0,20.5]) + ax_pca.plot(variance_ratio[0:20] * 100.0, c="#1f77b4") + ax_pca.set_xlabel("pca components") + ax_pca.set_ylabel("explained variance") + ax_pca.set_xlim([0, 20.5]) - ax_pca12.scatter(pca_comps[:, 0], pca_comps[:, 1], s=3, c = '#1f77b4') + ax_pca12.scatter(pca_comps[:, 0], pca_comps[:, 1], s=3, c="#1f77b4") ax_pca12.set_xlabel("pca 1") ax_pca12.set_ylabel("pca 2") xtick_vertical(ax=ax_pca12) - ax_pca13.scatter(pca_comps[:, 0], pca_comps[:, 2], s=3, c = '#1f77b4') + ax_pca13.scatter(pca_comps[:, 0], pca_comps[:, 2], s=3, c="#1f77b4") ax_pca13.set_xlabel("pca 1") ax_pca13.set_ylabel("pca 3") xtick_vertical(ax=ax_pca13) - ax_pca23.scatter(pca_comps[:, 1], pca_comps[:, 2], s=3, c = '#1f77b4') + ax_pca23.scatter(pca_comps[:, 1], pca_comps[:, 2], s=3, c="#1f77b4") ax_pca23.set_xlabel("pca 2") ax_pca23.set_ylabel("pca 3") xtick_vertical(ax=ax_pca23) @@ -456,31 +473,95 @@ def phenograph_clustering(fig_name, cell_sizes, clust_info, tsne_comps): ax_tsne, ax_phenograph = iter(fig) cl = np.log10(cell_sizes) - splot = ax_tsne.scatter(tsne_comps[:, 0], tsne_comps[:, 1], - c=cl, s=3, cmap=plt.cm.coolwarm, vmin = np.min(cl), - vmax=np.percentile(cl, 98)) + splot = ax_tsne.scatter( + tsne_comps[:, 0], + tsne_comps[:, 1], + c=cl, + s=3, + cmap=plt.cm.coolwarm, + vmin=np.min(cl), + vmax=np.percentile(cl, 98), + ) ax_tsne.set_title("UMI Counts (log10)") ax_tsne.set_xticks([]) ax_tsne.set_yticks([]) divider = make_axes_locatable(ax_tsne) - cax = divider.append_axes('right', size='3%', pad=0.04) - fig.figure.colorbar(splot, cax=cax, orientation='vertical') + cax = divider.append_axes("right", size="3%", pad=0.04) + fig.figure.colorbar(splot, cax=cax, orientation="vertical") # this is a list of contrast colors for clutering - cmap=["#010067","#D5FF00","#FF0056","#9E008E","#0E4CA1","#FFE502","#005F39","#00FF00","#95003A", - "#FF937E","#A42400","#001544","#91D0CB","#620E00","#6B6882","#0000FF","#007DB5","#6A826C", - "#00AE7E","#C28C9F","#BE9970","#008F9C","#5FAD4E","#FF0000","#FF00F6","#FF029D","#683D3B", - "#FF74A3","#968AE8","#98FF52","#A75740","#01FFFE","#FFEEE8","#FE8900","#BDC6FF","#01D0FF", - 
"#BB8800","#7544B1","#A5FFD2","#FFA6FE","#774D00","#7A4782","#263400","#004754","#43002C", - "#B500FF","#FFB167","#FFDB66","#90FB92","#7E2DD2","#BDD393","#E56FFE","#DEFF74","#00FF78", - "#009BFF","#006401","#0076FF","#85A900","#00B917","#788231","#00FFC6","#FF6E41","#E85EBE"] + cmap = [ + "#010067", + "#D5FF00", + "#FF0056", + "#9E008E", + "#0E4CA1", + "#FFE502", + "#005F39", + "#00FF00", + "#95003A", + "#FF937E", + "#A42400", + "#001544", + "#91D0CB", + "#620E00", + "#6B6882", + "#0000FF", + "#007DB5", + "#6A826C", + "#00AE7E", + "#C28C9F", + "#BE9970", + "#008F9C", + "#5FAD4E", + "#FF0000", + "#FF00F6", + "#FF029D", + "#683D3B", + "#FF74A3", + "#968AE8", + "#98FF52", + "#A75740", + "#01FFFE", + "#FFEEE8", + "#FE8900", + "#BDC6FF", + "#01D0FF", + "#BB8800", + "#7544B1", + "#A5FFD2", + "#FFA6FE", + "#774D00", + "#7A4782", + "#263400", + "#004754", + "#43002C", + "#B500FF", + "#FFB167", + "#FFDB66", + "#90FB92", + "#7E2DD2", + "#BDD393", + "#E56FFE", + "#DEFF74", + "#00FF78", + "#009BFF", + "#006401", + "#0076FF", + "#85A900", + "#00B917", + "#788231", + "#00FFC6", + "#FF6E41", + "#E85EBE", + ] colors = [] for i in range(len(clust_info)): colors.append(cmap[clust_info[i]]) - for ci in range(np.min(clust_info),np.max(clust_info)+1): + for ci in range(np.min(clust_info), np.max(clust_info) + 1): x1 = [] y1 = [] for i in range(len(clust_info)): @@ -488,11 +569,13 @@ def phenograph_clustering(fig_name, cell_sizes, clust_info, tsne_comps): x1.append(tsne_comps[i, 0]) y1.append(tsne_comps[i, 1]) cl = colors[i] - ax_phenograph.scatter(x1, y1, c=cl, s=3, label="C"+str(ci+1)) - ax_phenograph.set_title('Phenograph Clustering') + ax_phenograph.scatter(x1, y1, c=cl, s=3, label="C" + str(ci + 1)) + ax_phenograph.set_title("Phenograph Clustering") ax_phenograph.set_xticks([]) ax_phenograph.set_yticks([]) - ax_phenograph.legend(bbox_to_anchor=(1, 1), loc=2, borderaxespad=0., markerscale=2) + ax_phenograph.legend( + bbox_to_anchor=(1, 1), loc=2, borderaxespad=0.0, markerscale=2 + ) fig.tight_layout() fig.savefig(fig_name, dpi=300, transparent=True) @@ -507,14 +590,15 @@ def cell_size_histogram(data, f=None, ax=None, save=None): cell_size = data.sum(axis=1) plt.hist(np.log10(cell_size), bins=25, log=True) - ax.set_xlabel('log10(cell size)') - ax.set_ylabel('frequency') + ax.set_xlabel("log10(cell size)") + ax.set_ylabel("frequency") despine(ax) xtick_vertical(ax) if save is not None: if not isinstance(save, str): - raise TypeError('save must be the string filename of the ' - 'figure-to-be-saved') + raise TypeError( + "save must be the string filename of the " "figure-to-be-saved" + ) plt.tight_layout() f.savefig(save, dpi=300) diff --git a/src/seqc/read_array.py b/src/seqc/read_array.py index 6028a9f..105655a 100644 --- a/src/seqc/read_array.py +++ b/src/seqc/read_array.py @@ -16,10 +16,11 @@ class ReadArray: _dtype = [ - ('status', np.uint8), # if > 8 tests, change to int16 - ('cell', np.int64), - ('rmt', np.int32), - ('n_poly_t', np.uint8)] + ("status", np.uint8), # if > 8 tests, change to int16 + ("cell", np.int64), + ("rmt", np.int64), + ("n_poly_t", np.uint8), + ] def __init__(self, data, genes, positions): """ @@ -34,21 +35,26 @@ def __init__(self, data, genes, positions): if not isinstance(genes, (csr_matrix, np.ndarray)): raise TypeError( - 'genes must be a scipy csr_matrix or np.array, not %s' - % repr(type(genes))) + "genes must be a scipy csr_matrix or np.array, not %s" + % repr(type(genes)) + ) self._genes = genes if not isinstance(positions, (csr_matrix, np.ndarray)): raise TypeError( - 
'positions must be a scipy csr_matrix or np.array, not %s' - % repr(type(positions))) + "positions must be a scipy csr_matrix or np.array, not %s" + % repr(type(positions)) + ) self._positions = positions if not isinstance(data, np.ndarray): raise TypeError( - 'data must be a structured np.array object, not %s' % repr(type(data))) + "data must be a structured np.array object, not %s" % repr(type(data)) + ) self._data = data - if isinstance(genes, csr_matrix): # track if genes/positions are csr or np.array + if isinstance( + genes, csr_matrix + ): # track if genes/positions are csr or np.array self._ambiguous_genes = True else: self._ambiguous_genes = False @@ -96,13 +102,13 @@ def __iter__(self): yield self.data[i], self.genes[i], self.positions[i] filter_codes = { - 'no_gene': 0b1, - 'rmt_error': 0b10, - 'cell_error': 0b100, # todo this must execute before multialignment - 'low_polyt': 0b1000, - 'gene_not_unique': 0b10000, - 'primer_missing': 0b100000, - 'lonely_triplet': 0b1000000, # todo could call this low coverage? + "no_gene": 0b1, + "rmt_error": 0b10, + "cell_error": 0b100, # todo this must execute before multialignment + "low_polyt": 0b1000, + "gene_not_unique": 0b10000, + "primer_missing": 0b100000, + "lonely_triplet": 0b1000000, # todo could call this low coverage? } def initial_filtering(self, required_poly_t=1): @@ -121,17 +127,19 @@ def initial_filtering(self, required_poly_t=1): # genes are dealt with differently depending on the state of the array if self._ambiguous_genes: nnz = self.genes.getnnz(axis=1) - failing[nnz == 0] |= self.filter_codes['no_gene'] - failing[nnz > 1] |= self.filter_codes['gene_not_unique'] + failing[nnz == 0] |= self.filter_codes["no_gene"] + failing[nnz > 1] |= self.filter_codes["gene_not_unique"] else: # multiple gene filter is empty - failing[self.genes == 0] |= self.filter_codes['no_gene'] + failing[self.genes == 0] |= self.filter_codes["no_gene"] # todo add logic for "primer_missing" - failing[self.data['rmt'] == 0] |= self.filter_codes['primer_missing'] - failing[self.data['cell'] == 0] |= self.filter_codes['primer_missing'] - failing[self.data['n_poly_t'] < required_poly_t] |= self.filter_codes['low_polyt'] + failing[self.data["rmt"] == 0] |= self.filter_codes["primer_missing"] + failing[self.data["cell"] == 0] |= self.filter_codes["primer_missing"] + failing[self.data["n_poly_t"] < required_poly_t] |= self.filter_codes[ + "low_polyt" + ] - self.data['status'] = np.bitwise_or(self.data['status'], failing) + self.data["status"] = np.bitwise_or(self.data["status"], failing) def filtering_mask(self, *ignore): """return a filtering mask that, when compared to status, will return False for @@ -145,8 +153,10 @@ def filtering_mask(self, *ignore): try: mask ^= self.filter_codes[filter_] # mask filter except KeyError: - raise KeyError('%s is not a valid filter. Please select from %s' % - (filter_, repr(self.filter_codes.keys()))) + raise KeyError( + "%s is not a valid filter. 
Please select from %s" + % (filter_, repr(self.filter_codes.keys())) + ) return mask def iter_active(self, *ignore): @@ -160,13 +170,13 @@ def iter_active(self, *ignore): if not ignore: # save a bit of work by not &ing with mask for i, (data, gene, position) in enumerate(self): - if data['status'] == 0: + if data["status"] == 0: yield i, data, gene, position else: # create the appropriate mask for filters we want to ignore mask = self.filtering_mask(*ignore) for i, (data, gene, position) in enumerate(self): - if not data['status'] & mask: # ignores fields in ignore + if not data["status"] & mask: # ignores fields in ignore yield i, data, gene, position @classmethod @@ -186,12 +196,14 @@ def from_alignment_file(cls, alignment_file, translator, required_poly_t): # todo add a check for @GO query header (file matches sorting assumptions) - reader = sam.Reader(alignment_file) # todo swap to pysam reader, probably faster + reader = sam.Reader( + alignment_file + ) # todo swap to pysam reader, probably faster # todo allow reading of this from alignment summary num_reads = 0 num_unique = 0 - prev_alignment_name = '' + prev_alignment_name = "" for alignment in reader: num_reads += 1 if alignment.qname != prev_alignment_name: @@ -205,6 +217,8 @@ def from_alignment_file(cls, alignment_file, translator, required_poly_t): position = np.zeros(num_reads, dtype=np.int32) gene = np.zeros(num_reads, dtype=np.int32) + read_names = [] + # loop over multialignments row_idx = 0 # identifies the read index arr_idx = 0 # identifies the alignment index across all reads @@ -225,24 +239,34 @@ def from_alignment_file(cls, alignment_file, translator, required_poly_t): col_idx += 1 max_ma = max(max_ma, col_idx) + # items in ma all must have the same read name + # ma[0]==ma[1]==... + read_names.append(ma[0].qname) + cell = seqc.sequence.encodings.DNA3Bit.encode(a.cell) rmt = seqc.sequence.encodings.DNA3Bit.encode(a.rmt) - n_poly_t = a.poly_t.count('T') + a.poly_t.count('N') + n_poly_t = a.poly_t.count("T") + a.poly_t.count("N") data[row_idx] = (0, cell, rmt, n_poly_t) row_idx += 1 # some reads will not have aligned, throw away excess allocated space before # creating the ReadArray row, col, position, gene = ( - row[:arr_idx], col[:arr_idx], position[:arr_idx], gene[:arr_idx]) + row[:arr_idx], + col[:arr_idx], + position[:arr_idx], + gene[:arr_idx], + ) gene = coo_matrix((gene, (row, col)), shape=(row_idx, max_ma), dtype=np.int32) - position = coo_matrix((position, (row, col)), shape=(row_idx, max_ma), - dtype=np.int32) + position = coo_matrix( + (position, (row, col)), shape=(row_idx, max_ma), dtype=np.int32 + ) ra = cls(data, gene.tocsr(), position.tocsr()) ra.initial_filtering(required_poly_t=required_poly_t) - return ra + + return ra, read_names def group_indices_by_cell(self, multimapping=False): """group the reads in ra.data by cell. @@ -253,23 +277,22 @@ def group_indices_by_cell(self, multimapping=False): reads that correspond to a group, defined as a unique combination of the columns specified in parameter by. 
""" - idx = np.argsort(self.data['cell']) + idx = np.argsort(self.data["cell"]) # filter the index for reads that if multimapping: - mask = self.filtering_mask('gene_not_unique') - passing = (self.data['status'][idx] & mask) == 0 + mask = self.filtering_mask("gene_not_unique") + passing = (self.data["status"][idx] & mask) == 0 else: - passing = self.data['status'][idx] == 0 + passing = self.data["status"][idx] == 0 idx = idx[passing] # determine which positions in idx are the start of new groups (boolean, True) # convert boolean positions to indices, add start and end points. - breaks = np.where(np.diff(self.data['cell'][idx]))[0] + 1 + breaks = np.where(np.diff(self.data["cell"][idx]))[0] + 1 # use these break points to split the filtered index according to "by" return np.split(idx, breaks) - def save(self, archive_name): """save a ReadArray object as an hdf5 archive @@ -283,25 +306,26 @@ def store_carray(archive, array, name): store[:] = array store.flush() - if not archive_name.endswith('.h5'): - archive_name += '.h5' + if not archive_name.endswith(".h5"): + archive_name += ".h5" # construct container - blosc5 = tb.Filters(complevel=5, complib='blosc') - f = tb.open_file(archive_name, mode='w', title='Data for seqc.ReadArray', - filters=blosc5) + blosc5 = tb.Filters(complevel=5, complib="blosc") + f = tb.open_file( + archive_name, mode="w", title="Data for seqc.ReadArray", filters=blosc5 + ) - f.create_table(f.root, 'data', self.data) + f.create_table(f.root, "data", self.data) if self._ambiguous_genes: # each array is data, indices, indptr - store_carray(f, self.genes.indices, 'indices') - store_carray(f, self.genes.indptr, 'indptr') - store_carray(f, self.genes.data, 'gene_data') - store_carray(f, self.positions.data, 'positions_data') + store_carray(f, self.genes.indices, "indices") + store_carray(f, self.genes.indptr, "indptr") + store_carray(f, self.genes.data, "gene_data") + store_carray(f, self.positions.data, "positions_data") else: - store_carray(f, self.genes, 'genes') - store_carray(f, self.positions, 'positions') + store_carray(f, self.genes, "genes") + store_carray(f, self.positions, "positions") f.close() @@ -314,11 +338,11 @@ def load(cls, archive_name): :return ReadArray: """ - f = tb.open_file(archive_name, mode='r') + f = tb.open_file(archive_name, mode="r") data = f.root.data.read() try: - f.get_node('/genes') + f.get_node("/genes") genes = f.root.genes.read() positions = f.root.positions.read() except tb.NoSuchNodeError: @@ -345,7 +369,7 @@ def resolve_ambiguous_alignments(self): # Reset genes and positions to be an array self._ambiguous_genes = False - self.genes = np.ravel(self.genes.tocsc()[:, 0].todense()) + self.genes = np.ravel(self.genes.tocsc()[:, 0].todense()) self.positions = np.ravel(self.positions.tocsc()[:, 0].todense()) return mm_results @@ -355,7 +379,7 @@ def _resolve_alignments(self, indices_grouped_by_cells): Resolve ambiguously aligned molecules and edit the ReadArray data structures in-place to reflect the more specific gene assignments. - After loading the co alignment matrix we group the reads of the ra by cell/rmt. + After loading the co alignment matrix we group the reads of the ra by cell/rmt. In each group we look at the different disjoint subsetes of genes reads are aligned to. 
@@ -366,24 +390,26 @@ def _resolve_alignments(self, indices_grouped_by_cells): :return dict results: dictionary containing information on how many molecules were resolved by this algorithm """ - + # Mask for reseting status on resolved genes - mask = self.filtering_mask('gene_not_unique') + mask = self.filtering_mask("gene_not_unique") # results dictionary for tracking effect of algorithm - results = OrderedDict(( - ('unique molecules', 0), - ('cell/rmt barcode collisions', 0), - ('resolved molecules: disjoint', 0), - ('resolved molecules: model', 0), - ('ambiguous molecules', 0) - )) + results = OrderedDict( + ( + ("unique molecules", 0), + ("cell/rmt barcode collisions", 0), + ("resolved molecules: disjoint", 0), + ("resolved molecules: model", 0), + ("ambiguous molecules", 0), + ) + ) for cell_group in indices_grouped_by_cells: # Sort by molecules - inds = cell_group[np.argsort(self.data['rmt'][cell_group])] - breaks = np.where(np.diff(self.data['rmt'][inds]))[0] + 1 + inds = cell_group[np.argsort(self.data["rmt"][cell_group])] + breaks = np.where(np.diff(self.data["rmt"][inds]))[0] + 1 indices_grouped_by_molecule = np.split(inds, breaks) # Each molecule group @@ -400,11 +426,11 @@ def _resolve_alignments(self, indices_grouped_by_cells): # Return if there is only one gene group if len(gene_groups) == 1: - results['unique molecules'] += 1 + results["unique molecules"] += 1 continue # if it was not unique, there is a collision - results['cell/rmt barcode collisions'] += 1 + results["cell/rmt barcode collisions"] += 1 # Divide into disjoint sets uf = multialignment.UnionFind() @@ -412,13 +438,13 @@ def _resolve_alignments(self, indices_grouped_by_cells): set_membership, sets = uf.find_all(gene_groups.keys()) # Disambiguate each set - keys = np.array(list(gene_groups.keys())) + keys = np.array(list(gene_groups.keys()), dtype=object) for s in sets: set_groups = keys[set_membership == s] # Return if the set contains only one group if len(set_groups) == 1: - results['resolved molecules: disjoint'] += 1 + results["resolved molecules: disjoint"] += 1 continue # Disambiguate if possible @@ -426,131 +452,161 @@ def _resolve_alignments(self, indices_grouped_by_cells): # Resolved if there is only one common gene if len(common) == 1: - results['resolved molecules: model'] += 1 + results["resolved molecules: model"] += 1 for group in set_groups: # Update status - self.data['status'][gene_groups[tuple(group)]] &= mask + self.data["status"][gene_groups[tuple(group)]] &= mask for ind in gene_groups[tuple(group)]: # Update gene and position if self.genes[ind, 0] == common[0]: continue - gene_index = ( - self.genes[ind] == common[0]).nonzero()[1][0] + gene_index = (self.genes[ind] == common[0]).nonzero()[ + 1 + ][0] self.genes[ind, 0] = self.genes[ind, gene_index] self.positions[ind, 0] = self.positions[ind, gene_index] else: - results['ambiguous molecules'] += 1 + results["ambiguous molecules"] += 1 # Todo: Likelihood model goes here return results + def create_readname_cb_umi_mapping(self, read_names, path_filename): + + if read_names == None: + return + + # index with no cell error & no rmt error + noerr_idx = np.where(self.data["status"] == 0)[0] + rnames = np.array(read_names)[noerr_idx] + cell = self.data["cell"][noerr_idx] + rmt = self.data["rmt"][noerr_idx] + + df = pd.DataFrame({"read_name": rnames, "CB": cell, "UB": rmt}) + df.set_index("read_name", inplace=True) + + df.to_csv(path_filename, index=True, compression="gzip") + # todo : document me # Triplet filter from Adam def 
filter_low_coverage(self, alpha=0.25): - - use_inds = np.where( self.data['status'] == 0 )[0] - cell = self.data['cell'][use_inds] + + use_inds = np.where(self.data["status"] == 0)[0] + cell = self.data["cell"][use_inds] position = self.positions[use_inds] - rmt = self.data['rmt'][use_inds] + rmt = self.data["rmt"][use_inds] genes = self.genes[use_inds] - + # A triplet is a (cell, position, rmt) triplet in each gene - df = pd.DataFrame({'gene': genes, 'cell': cell, 'position': position, - 'rmt': rmt}) - grouped = df.groupby(['gene', 'position']) + df = pd.DataFrame( + {"gene": genes, "cell": cell, "position": position, "rmt": rmt} + ) + grouped = df.groupby(["gene", "position"]) # This gives the gene followed by the number of triplets at each position # Summing across each gene will give the number of total triplets in gene - num_per_position = (grouped['position'].agg({ - 'Num Triplets at Pos': np.count_nonzero})).reset_index() - - + num_per_position = ( + grouped["position"].agg({"Num Triplets at Pos": np.count_nonzero}) + ).reset_index() + # Total triplets in each gene - trips_in_gene = (num_per_position.groupby(['gene']) - )['Num Triplets at Pos'].agg({'Num Triplets at Gene': np.sum}) - + trips_in_gene = (num_per_position.groupby(["gene"]))["Num Triplets at Pos"].agg( + {"Num Triplets at Gene": np.sum} + ) + trips_in_gene = trips_in_gene.reset_index() - - num_per_position = num_per_position.merge(trips_in_gene,how = 'left') - - - # for each (c,rmt) in df check in grouped2 if it is lonely - # determine number of lonely triplets at each position - grouped2 = df.groupby(['gene','cell','rmt']) - # lonely_triplets = grouped2["position"].apply(lambda x: len(x.unique())) + + num_per_position = num_per_position.merge(trips_in_gene, how="left") + + # for each (c,rmt) in df check in grouped2 if it is lonely + # determine number of lonely triplets at each position + grouped2 = df.groupby(["gene", "cell", "rmt"]) + # lonely_triplets = grouped2["position"].apply(lambda x: len(x.unique())) # This is a list of each gene, cell, rmt combo and the positions with that criteria - lonely_triplets = grouped2['position'].apply(np.unique) + lonely_triplets = grouped2["position"].apply(np.unique) lonely_triplets = pd.DataFrame(lonely_triplets) - + # if the length is one, this is a lonely triplet - lonely_triplets_u = lonely_triplets['position'].apply(len) + lonely_triplets_u = lonely_triplets["position"].apply(len) lonely_triplets_u = pd.DataFrame(lonely_triplets_u) - + lonely_triplets_u = lonely_triplets_u.reset_index() lonely_triplets = lonely_triplets.reset_index() - + # Rename the columns - lonely_triplets = lonely_triplets.rename(columns=lambda x: x.replace( - 'position', 'lonely position')) - lonely_triplets_u = lonely_triplets_u.rename(columns=lambda x: x.replace( - 'position', 'num')) - + lonely_triplets = lonely_triplets.rename( + columns=lambda x: x.replace("position", "lonely position") + ) + lonely_triplets_u = lonely_triplets_u.rename( + columns=lambda x: x.replace("position", "num") + ) + # merge the column that is the length of the positions array # take the ones with length 1 - lonely_triplets = lonely_triplets.merge(lonely_triplets_u,how = 'left') - lonely_triplets = lonely_triplets.loc[lonely_triplets.loc[:,'num'] == 1,:] - + lonely_triplets = lonely_triplets.merge(lonely_triplets_u, how="left") + lonely_triplets = lonely_triplets.loc[lonely_triplets.loc[:, "num"] == 1, :] + # This is the gene, cell, rmt combo and the position that is lonely - # We need to convert the array to a scalar + # We 
need to convert the array to a scalar scalar = lonely_triplets["lonely position"].apply(np.asscalar) lonely_triplets["lonely position"] = scalar # Now if we group as such, we can determine how many (c, rmt) paris exist at each position # This would be the number of lonely pairs at a position - grouped3 = lonely_triplets.groupby(["gene","lonely position"]) - l_num_at_position = (grouped3["cell"].agg(['count'])).reset_index() - l_num_at_position = l_num_at_position.rename(columns=lambda x: x.replace( - 'count', 'lonely triplets at pos')) - l_num_at_position = l_num_at_position.rename(columns=lambda x: x.replace( - 'lonely position', 'position')) + grouped3 = lonely_triplets.groupby(["gene", "lonely position"]) + l_num_at_position = (grouped3["cell"].agg(["count"])).reset_index() + l_num_at_position = l_num_at_position.rename( + columns=lambda x: x.replace("count", "lonely triplets at pos") + ) + l_num_at_position = l_num_at_position.rename( + columns=lambda x: x.replace("lonely position", "position") + ) # lonely pairs in each gene - l_num_at_gene = (lonely_triplets.groupby(["gene"]))['lonely position'].agg( - ['count']) + l_num_at_gene = (lonely_triplets.groupby(["gene"]))["lonely position"].agg( + ["count"] + ) l_num_at_gene = l_num_at_gene.reset_index() - l_num_at_gene = l_num_at_gene.rename(columns=lambda x: x.replace( - 'count', 'lonely triplets at gen')) + l_num_at_gene = l_num_at_gene.rename( + columns=lambda x: x.replace("count", "lonely triplets at gen") + ) + + # aggregate + total = l_num_at_position.merge(l_num_at_gene, how="left") + total = total.merge(num_per_position, how="left") - # aggregate - total = l_num_at_position.merge(l_num_at_gene,how='left') - total = total.merge(num_per_position, how = 'left') - # scipy hypergeom p = total.apply(self._hypergeom_wrapper, axis=1) - p = 1-p - + p = 1 - p + from statsmodels.sandbox.stats.multicomp import multipletests as mt - adj_p = mt(p,alpha = alpha, method='fdr_bh') - + + adj_p = mt(p, alpha=alpha, method="fdr_bh") + keep = pd.DataFrame(adj_p[0]) - total['remove'] = keep - - remove = total[total['remove'] == True] - + total["remove"] = keep + + remove = total[total["remove"] == True] + final = df.merge(remove, how="left") final = final[final["remove"] == True] # Indicies to remove remove_inds = use_inds[final.index.values] - - self.data['status'][remove_inds] |= self.filter_codes['lonely_triplet'] + self.data["status"][remove_inds] |= self.filter_codes["lonely_triplet"] def _hypergeom_wrapper(self, x): - + from scipy.stats import hypergeom - p = hypergeom.cdf(x['lonely triplets at pos'],x['Num Triplets at Gene'], - x['lonely triplets at gen'],x['Num Triplets at Pos']) - return p + p = hypergeom.cdf( + x["lonely triplets at pos"], + x["Num Triplets at Gene"], + x["lonely triplets at gen"], + x["Num Triplets at Pos"], + ) + return p - def to_count_matrix(self, csv_path=None, sparse_frame=False, genes_to_symbols=False): + def to_count_matrix( + self, csv_path=None, sparse_frame=False, genes_to_symbols=False + ): """Convert the ReadArray into a count matrix. 
Since the matrix is sparse we represent it with 3 columns: row (cell), @@ -571,29 +627,34 @@ def to_count_matrix(self, csv_path=None, sparse_frame=False, genes_to_symbols=Fa for i, data, gene, pos in self.iter_active(): try: - reads_mat[data['cell'], gene] += 1 + reads_mat[data["cell"], gene] += 1 except KeyError: - reads_mat[data['cell'], gene] = 1 + reads_mat[data["cell"], gene] = 1 try: - rmt = data['rmt'] - if rmt not in mols_mat[data['cell'], gene]: - mols_mat[data['cell'], gene].append(rmt) + rmt = data["rmt"] + if rmt not in mols_mat[data["cell"], gene]: + mols_mat[data["cell"], gene].append(rmt) except KeyError: - mols_mat[data['cell'], gene] = [rmt] + mols_mat[data["cell"], gene] = [rmt] if sparse_frame: - return (SparseFrame.from_dict(reads_mat, genes_to_symbols=genes_to_symbols), - SparseFrame.from_dict( - {k: len(v) for k, v in mols_mat.items()}, - genes_to_symbols=genes_to_symbols)) + return ( + SparseFrame.from_dict(reads_mat, genes_to_symbols=genes_to_symbols), + SparseFrame.from_dict( + {k: len(v) for k, v in mols_mat.items()}, + genes_to_symbols=genes_to_symbols, + ), + ) if csv_path is None: return reads_mat, mols_mat # todo convert gene integers to symbols before saving csv - f = open(csv_path+'reads_count.csv', 'w') - for data['cell'], gene in reads_mat: - f.write('{},{},{}\n'.format(data['cell'], gene, reads_mat[data['cell'], gene])) + f = open(csv_path + "reads_count.csv", "w") + for data["cell"], gene in reads_mat: + f.write( + "{},{},{}\n".format(data["cell"], gene, reads_mat[data["cell"], gene]) + ) f.close() diff --git a/src/seqc/rmt_correction.py b/src/seqc/rmt_correction.py index e69a372..d3c479b 100644 --- a/src/seqc/rmt_correction.py +++ b/src/seqc/rmt_correction.py @@ -1,30 +1,69 @@ -from scipy.special import gammainc -from seqc.sequence.encodings import DNA3Bit +import os +import pickle +import math +import time +import psutil +import pandas as pd import numpy as np +from tqdm import tqdm +from scipy.special import gammainc from seqc import log from seqc.read_array import ReadArray -import time -import pandas as pd -import multiprocessing as multi -from itertools import repeat -import ctypes -from contextlib import closing -from functools import partial +import dask +from distributed import Client, LocalCluster +from dask.distributed import wait, performance_report +from tlz import partition_all +from numba import jit, njit +from collections import defaultdict -# todo document me + +log.logging.getLogger("asyncio").setLevel(log.logging.WARNING) + + +@njit +def DNA3Bit_seq_len(i: int) -> int: + """ + Return the length of an encoded sequence based on its binary representation + + :param i: int, encoded sequence + """ + l = 0 + while i > 0: + l += 1 + i >>= 3 + return l + + +@njit def generate_close_seq(seq): - """ Return a list of all sequences that are up to 2 hamm distance from seq + """Return a list of all sequences that are up to 2 hamm distance from seq :param seq: """ - res = [] - l = DNA3Bit.seq_len(seq) + DNA3Bit_bin2strdict = { + 0b100: b"A", + 0b110: b"C", + 0b101: b"G", + 0b011: b"T", + 0b111: b"N", + } + + # res = [] + res = np.empty(0, dtype=np.int64) + + l = DNA3Bit_seq_len(seq) # generate all sequences that are dist 1 for i in range(l): mask = 0b111 << (i * 3) cur_chr = (seq & mask) >> (i * 3) - res += [seq & (~mask) | (new_chr << (i * 3)) - for new_chr in DNA3Bit.bin2strdict.keys() if new_chr != cur_chr] + res = np.append( + res, + [ + seq & (~mask) | (new_chr << (i * 3)) + for new_chr in DNA3Bit_bin2strdict.keys() + if new_chr != cur_chr + ], 
+ ) # generate all sequences that are dist 2 for i in range(l): mask_i = 0b111 << (i * 3) @@ -33,74 +72,95 @@ def generate_close_seq(seq): mask_j = 0b111 << (j * 3) chr_j = (seq & mask_j) >> (j * 3) mask = mask_i | mask_j - res += [seq & (~mask) | (new_chr_i << (i * 3)) | (new_chr_j << (j * 3)) for - new_chr_i in DNA3Bit.bin2strdict.keys() if new_chr_i != chr_i for - new_chr_j in DNA3Bit.bin2strdict.keys() if new_chr_j != chr_j] + res = np.append( + res, + [ + seq & (~mask) | (new_chr_i << (i * 3)) | (new_chr_j << (j * 3)) + for new_chr_i in DNA3Bit_bin2strdict.keys() + if new_chr_i != chr_i + for new_chr_j in DNA3Bit_bin2strdict.keys() + if new_chr_j != chr_j + ], + ) - return res + return list(res) + + +@njit +def probability_for_convert_d_to_r_float(d_seq, r_seq, err_rate): + """ + Return the probability of d_seq turning into r_seq based on the err_rate table + (all binary) + :param err_rate: for 10x e.g. 0.02 + :param r_seq: + :param d_seq: + """ + + if DNA3Bit_seq_len(d_seq) != DNA3Bit_seq_len(r_seq): + return 1 + + p = 1.0 + while d_seq > 0: + if d_seq & 0b111 != r_seq & 0b111: + p *= err_rate + d_seq >>= 3 + r_seq >>= 3 + return p -# todo document me -def probability_for_convert_d_to_r(d_seq, r_seq, err_rate): + +def probability_for_convert_d_to_r_dict(d_seq, r_seq, err_rate): """ Return the probability of d_seq turning into r_seq based on the err_rate table (all binary) - :param err_rate: + :param err_rate: for indrop e.g. {(4, 6): 0.0007813189167391277, (4, 5): 0.0013484052272755914, ...} :param r_seq: :param d_seq: """ - if DNA3Bit.seq_len(d_seq) != DNA3Bit.seq_len(r_seq): + if DNA3Bit_seq_len(d_seq) != DNA3Bit_seq_len(r_seq): return 1 p = 1.0 while d_seq > 0: if d_seq & 0b111 != r_seq & 0b111: - if isinstance(err_rate,float): - p *= err_rate - else: - p *= err_rate[(d_seq & 0b111, r_seq & 0b111)] + p *= err_rate[(d_seq & 0b111, r_seq & 0b111)] d_seq >>= 3 r_seq >>= 3 return p def in_drop(read_array, error_rate, alpha=0.05): - """ Tag any RMT errors + """Tag any RMT errors :param read_array: Read array :param error_rate: Sequencing error rate determined during barcode correction :param alpha: Tolerance for errors """ - global ra - global indices_grouped_by_cells - - ra = read_array - indices_grouped_by_cells = ra.group_indices_by_cell() - _correct_errors(error_rate, alpha) + return _correct_errors(read_array, error_rate, alpha) # a method called by each process to correct RMT for each cell -def _correct_errors_by_cell_group(err_rate, p_value, cell_index): +def _correct_errors_by_cell_group(ra, cell_group, err_rate, p_value): - cell_group = indices_grouped_by_cells[cell_index] # Breaks for each gene gene_inds = cell_group[np.argsort(ra.genes[cell_group])] breaks = np.where(np.diff(ra.genes[gene_inds]))[0] + 1 splits = np.split(gene_inds, breaks) - rmt_groups = {} + + del gene_inds + del breaks + + rmt_groups = defaultdict(list) res = [] for inds in splits: # RMT groups for ind in inds: - rmt = ra.data['rmt'][ind] - try: - rmt_groups[rmt].append(ind) - except KeyError: - rmt_groups[rmt] = [ind] + rmt = ra.data["rmt"][ind] + rmt_groups[rmt].append(ind) if len(rmt_groups) == 1: continue @@ -117,23 +177,32 @@ def _correct_errors_by_cell_group(err_rate, p_value, cell_index): for donor_rmt in generate_close_seq(rmt): # Check if donor is detected - try: + if donor_rmt in rmt_groups: donor_count = len(rmt_groups[donor_rmt]) - except KeyError: + else: continue # Build likelihood # Probability of converting donor to target - p_dtr = probability_for_convert_d_to_r(donor_rmt, rmt, err_rate) + 
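The `probability_for_convert_d_to_r_float` variant introduced above treats the error rate as a single per-base probability, so a donor RMT differing from the target at k bases converts with probability `err_rate**k`. A minimal sketch of that calculation (the 3-bit codes are the ones listed in `DNA3Bit_bin2strdict` above; the `encode` helper is illustrative, not the seqc API, and the length check is omitted):

```python
# Flat-error-rate conversion probability: one factor of err_rate per mismatch.
CODE = {"A": 0b100, "C": 0b110, "G": 0b101, "T": 0b011, "N": 0b111}

def encode(seq):
    e = 0
    for base in seq:
        e = (e << 3) | CODE[base]
    return e

def p_convert(d_seq, r_seq, err_rate):
    p = 1.0
    while d_seq > 0:
        if d_seq & 0b111 != r_seq & 0b111:
            p *= err_rate
        d_seq >>= 3
        r_seq >>= 3
    return p

# one mismatch at the 10x rate of 0.02 -> 0.02; two mismatches -> 0.02 ** 2
assert abs(p_convert(encode("ACGT"), encode("ACGA"), 0.02) - 0.02) < 1e-12
```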
if type(err_rate) is float: + # e.g. 10x: err_rate=0.02 + p_dtr = probability_for_convert_d_to_r_float( + donor_rmt, rmt, err_rate + ) + else: + # e.g. indrop: err_rate={(4, 6): 0.0007813189167391277, (4, 5): 0.0013484052272755914, ...} + p_dtr = probability_for_convert_d_to_r_dict( + donor_rmt, rmt, err_rate + ) # Number of occurrences expected_errors += donor_count * p_dtr # Check if jaitin correction is feasible - if not jaitin_corrected: + if not jaitin_corrected: ref_positions = ra.positions[rmt_groups[rmt]] donor_positions = ra.positions[rmt_groups[donor_rmt]] - # Is reference a subset of the donor ? + # Is reference a subset of the donor? (in terms of position) if (set(ref_positions)).issubset(donor_positions): jaitin_corrected = True jaitin_donor = donor_rmt @@ -146,27 +215,231 @@ def _correct_errors_by_cell_group(err_rate, p_value, cell_index): # Save the RMT donor # save the index of the read and index of donor rmt read for i in rmt_groups[rmt]: - res.append(i) - res.append(rmt_groups[jaitin_donor][0]) + res.append((i, rmt_groups[jaitin_donor][0])) rmt_groups.clear() return res -def _correct_errors(err_rate, p_value=0.05): - #Calculate and correct errors in RMTs - with multi.Pool(processes=multi.cpu_count()) as p: - p = multi.Pool(processes=multi.cpu_count()) - results = p.starmap(_correct_errors_by_cell_group, - zip(repeat(err_rate), repeat(p_value), range(len(indices_grouped_by_cells)))) - p.close() - p.join() - - # iterate through the list of returned read indices and donor rmts - for i in range(len(results)): - res = results[i] - if len(res) > 0: - for i in range(0, len(res), 2): - ra.data['rmt'][res[i]] = ra.data['rmt'][res[i+1]] - ra.data['status'][res[i]] |= ra.filter_codes['rmt_error'] \ No newline at end of file +def _correct_errors_by_cell_group_chunks(ra, cell_group_chunks, err_rate, p_value): + + if ra == None: + with open("pre-correction-ra.pickle", "rb") as fin: + ra = pickle.load(fin) + + return [ + _correct_errors_by_cell_group(ra, cell_group, err_rate, p_value) + for cell_group in cell_group_chunks + ] + + +def _get_cpu_count(): + # this will give the total CPU count that your (virtual) machine is equipped with. + # with LSF, this number is NOT the CPU count your job is allocated with. + + return psutil.cpu_count() + + +def _get_total_memory(): + # this will give the total memory that your (virtual) machine is equipped with. + # with LSF, this number is NOT the memory amount your job is allocated with. + + return psutil.virtual_memory().total + + +def _get_available_memory(): + # the memory that can be given instantly to processes without the system going into swap. + # with LSF, this number is probably inaccurate. + + return psutil.virtual_memory().available + + +def _calc_max_workers(ra): + # calculate based on avail memory & readarray size. + # just increasing memory won't help. lack of cpu will make each process fight for cpu time. 
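To make the worker-sizing heuristic of `_calc_max_workers` concrete before the arithmetic below, here is a small worked example; the memory figures are assumptions for illustration, not values taken from this PR:

```python
# Worked example of the estimate computed below:
#   n_workers = floor(available_memory / (ra_size + extra)), floored at 1.
# All figures here are assumptions, not measurements from SEQC.
import math

available_memory = 8 * 1024 ** 3   # assume ~8 GB can be handed to workers
ra_size = int(1.5 * 1024 ** 3)     # assume the ReadArray arrays total ~1.5 GB
extra = 2 * 1024 ** 3              # the 2 GB per-worker overhead used below

n = math.floor(available_memory / (ra_size + extra))
print(max(n, 1))                   # -> 2 workers; an estimate of 0 is bumped to 1
```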
+
+    # ra.data, ra.genes, and ra.positions are all numpy arrays
+    ra_size = ra.data.nbytes + ra.genes.nbytes + ra.positions.nbytes
+
+    # extra bytes needed
+    extra = 2 * 1024 ** 3
+
+    n = math.floor(_get_available_memory() / (ra_size + extra))
+
+    return 1 if n == 0 else n
+
+
+def _correct_errors(ra, err_rate, p_value=0.05):
+
+    # True: use Dask's broadcast (ra transfer via inproc/tcp)
+    # False: each worker reads ra.pickle from disk
+    use_dask_broadcast = False
+
+    log.debug(
+        "Available CPU / RAM: {} / {} GB".format(
+            _get_cpu_count(), int(_get_available_memory() / 1024 ** 3)
+        ),
+        module_name="rmt_correction",
+    )
+
+    n_workers = _calc_max_workers(ra)
+
+    log.debug(
+        "Estimated optimum n_workers: {}".format(n_workers),
+        module_name="rmt_correction",
+    )
+
+    if int(os.environ.get("SEQC_MAX_WORKERS", 0)) > 0:
+        n_workers = int(os.environ.get("SEQC_MAX_WORKERS"))
+        log.debug(
+            "n_workers overridden with SEQC_MAX_WORKERS: {}".format(n_workers),
+            module_name="rmt_correction",
+        )
+
+    # n_workers = 1
+    # p_value = 0.005
+
+    # configure dask.distributed
+    # memory_terminate_fraction doesn't work for some reason
+    # https://github.com/dask/distributed/issues/3519
+    # https://docs.dask.org/en/latest/setup/single-distributed.html#localcluster
+    # https://docs.dask.org/en/latest/scheduling.html#local-threads
+    worker_kwargs = {
+        "n_workers": n_workers,
+        "threads_per_worker": 1,
+        "processes": True,
+        "memory_limit": "64G",
+        "memory_target_fraction": 0.95,
+        "memory_spill_fraction": 0.99,
+        "memory_pause_fraction": False,
+        # "memory_terminate_fraction": False,
+    }
+
+    # do not kill worker at 95% memory level
+    dask.config.set({"distributed.worker.memory.terminate": False})
+    dask.config.set({"distributed.scheduler.allowed-failures": 50})
+
+    # setup Dask distributed client
+    cluster = LocalCluster(**worker_kwargs)
+    client = Client(cluster)
+
+    # debug message
+    log.debug(
+        "Dask processes={} threads={}".format(
+            len(client.nthreads().values()), np.sum(list(client.nthreads().values()))
+        ),
+        module_name="rmt_correction",
+    )
+    log.debug(
+        "Dask worker_kwargs "
+        + " ".join([f"{k}={v}" for k, v in worker_kwargs.items()]),
+        module_name="rmt_correction",
+    )
+    log.debug("Dask Dashboard=" + client.dashboard_link, module_name="rmt_correction")
+
+    # group by cells (same cell barcodes as one group)
+    log.debug("Grouping...", module_name="rmt_correction")
+    indices_grouped_by_cells = ra.group_indices_by_cell()
+
+    if use_dask_broadcast:
+        # send readarray in advance to all workers (i.e. broadcast=True)
+        # this way, we reduce the serialization time
+        log.debug("Scattering ReadArray...", module_name="rmt_correction")
+        [future_ra] = client.scatter([ra], broadcast=True)
+    else:
+        # write ra to pickle which will be used later to parallel process rmt correction
+        with open("pre-correction-ra.pickle", "wb") as fout:
+            pickle.dump(ra, fout, protocol=4)
+
+    # correct errors per cell group in parallel
+    log.debug("Submitting jobs to Dask...", module_name="rmt_correction")
+    with performance_report(filename="dask-report.html"):
+        futures = []
+
+        # distribute chunks to workers evenly
+        n_chunks = math.ceil(len(indices_grouped_by_cells) / n_workers)
+        chunks = partition_all(n_chunks, indices_grouped_by_cells)
+
+        for chunk in tqdm(chunks, disable=None):
+
+            future = client.submit(
+                _correct_errors_by_cell_group_chunks,
+                future_ra if use_dask_broadcast else None,
+                chunk,
+                err_rate,
+                p_value,
+            )
+            futures.append(future)
+
+    # wait until all done
+    log.debug("Waiting until all tasks complete...", module_name="rmt_correction")
+    completed, not_completed = wait(futures)
+
+    if len(not_completed) > 0:
+        raise Exception("There are uncompleted tasks!")
+
+    # gather the results and release
+    log.debug(
+        "Collecting the task results from the workers...", module_name="rmt_correction"
+    )
+    results = []
+    for future in tqdm(completed, disable=None):
+        # this returns a list of lists
+        # len(result) should be the number of chunks e.g. 50
+        result = future.result()
+
+        # remove empty lists
+        result = list(filter(lambda x: len(x) > 0, result))
+
+        # aggregate and release
+        results.extend(result)
+        future.release()
+
+    # clean up
+    del futures
+    del completed
+    del not_completed
+
+    client.shutdown()
+    client.close()
+
+    # iterate through the list of returned read indices and donor rmts
+    # create a mapping table of pre-/post-correction
+    mapping = set()
+    for result in results:
+        for idx, idx_corrected_rmt in result:
+
+            # record pre-/post-correction
+            # skip if this (cell, original rmt, corrected rmt) triple is already recorded
+            if (
+                ra.data["cell"][idx],
+                ra.data["rmt"][idx],
+                ra.data["rmt"][idx_corrected_rmt],
+            ) in mapping:
+                continue
+
+            mapping.add(
+                (
+                    ra.data["cell"][idx],
+                    ra.data["rmt"][idx],
+                    ra.data["rmt"][idx_corrected_rmt],
+                )
+            )
+
+    # iterate through the list of returned read indices and donor rmts
+    # actually, update the read array object with corrected UMI
+    for result in results:
+        for idx, idx_corrected_rmt in result:
+
+            # skip if it's already marked as rmt error
+            if ra.data["status"][idx_corrected_rmt] & ra.filter_codes["rmt_error"]:
+                continue
+
+            # correct
+            ra.data["rmt"][idx] = ra.data["rmt"][idx_corrected_rmt]
+
+            # report error
+            ra.data["status"][idx] |= ra.filter_codes["rmt_error"]
+
+    return pd.DataFrame(mapping, columns=["CB", "UR", "UB"])
diff --git a/src/seqc/sequence/fastq.py b/src/seqc/sequence/fastq.py
index 3481279..c14d435 100644
--- a/src/seqc/sequence/fastq.py
+++ b/src/seqc/sequence/fastq.py
@@ -18,7 +18,7 @@ class FastqRecord:
     :property average_quality: return the mean quality of FastqRecord
     """
-    __slots__ = ['_data']
+    __slots__ = ["_data"]

     def __init__(self, record: [bytes, bytes, bytes, bytes]):
         self._data = list(record)
@@ -56,7 +56,7 @@ def quality(self, value: bytes):
         self._data[3] = value

     def __bytes__(self) -> bytes:
-        return b''.join(self._data)
+        return b"".join(self._data)

     def __str__(self) -> str:
         return bytes(self).decode()
@@ -72,8 +72,8 @@ def annotations(self) -> list:
         list of annotations present in the fastq header
         """
         try:
-            end =
self.name.index(b';') - return self.name[:end].split(b':') + end = self.name.index(b";") + return self.name[:end].split(b":") except ValueError: return [] @@ -84,12 +84,12 @@ def metadata(self) -> dict: -------- dictionary of annotations and fields, if any are present""" try: - start = self.name.rindex(b'|') + start = self.name.rindex(b"|") except ValueError: return {} fields = {} - for field in self.name[start + 1:].split(b':'): - k, v = field.split(b'=') + for field in self.name[start + 1 :].split(b":"): + k, v = field.split(b"=") fields[k] = v return fields @@ -97,18 +97,22 @@ def add_annotation(self, values) -> None: """prepends a list of annotations to the name field of self.name :param values: """ - self._data[0] = b'@' + b':'.join(values) + b';' + self.name[1:] + self._data[0] = b"@" + b":".join(values) + b";" + self.name[1:] def add_metadata(self, values) -> None: """appends a list of metadata fields to the name field of self.name :param values: """ - self.name += b'|' + b':'.join(k + '=' + v for k, v in values.items()) + self.name += b"|" + b":".join(k + "=" + v for k, v in values.items()) def average_quality(self) -> int: """""" - return np.mean(np.frombuffer(self.quality, dtype=np.int8, count=len(self)))\ - .astype(int) - 33 + return ( + np.mean(np.frombuffer(self.quality, dtype=np.int8, count=len(self))).astype( + int + ) + - 33 + ) class Reader(reader.Reader): @@ -157,8 +161,8 @@ def estimate_sequence_length(self): data[i] = len(seq) - 1 # last character is a newline i += 1 return np.mean(data), np.std(data), np.unique(data, return_counts=True) - - + + def merge_paired(merge_function, fout, genomic, barcode=None) -> (str, int): """ General function to annotate genomic fastq with barcode information from reverse read. @@ -178,12 +182,12 @@ def merge_paired(merge_function, fout, genomic, barcode=None) -> (str, int): genomic = Reader(genomic) if barcode: barcode = Reader(barcode) - with open(fout, 'wb') as f: + with open(fout, "wb") as f: for g, b in zip(genomic, barcode): r = merge_function(g, b) f.write(bytes(r)) else: - with open(fout, 'wb') as f: + with open(fout, "wb") as f: for g in genomic: r = merge_function(g) f.write(bytes(r)) @@ -205,7 +209,7 @@ def truncate(fastq_file, lengths): length = len(record.sequence) break - print('sequence length in file is %d' % length) + print("sequence length in file is %d" % length) # remove any lengths longer than sequence length of file lengths = sorted([l for l in lengths if l < length])[::-1] # largest to smallest @@ -213,8 +217,10 @@ def truncate(fastq_file, lengths): # open a bunch of files files = [] for l in lengths: - name = fastq_file.replace('.gz', '').replace('.fastq', '') + '_%d_' % l + '.fastq' - files.append(open(name, 'wb')) + name = ( + fastq_file.replace(".gz", "").replace(".fastq", "") + "_%d_" % l + ".fastq" + ) + files.append(open(name, "wb")) i = 0 indices = list(range(len(lengths))) @@ -222,8 +228,8 @@ def truncate(fastq_file, lengths): if i > 10e6: break for j in indices: - record.sequence = record.sequence[:-1][:lengths[j]] + b'\n' - record.quality = record.quality[:-1][:lengths[j]] + b'\n' + record.sequence = record.sequence[:-1][: lengths[j]] + b"\n" + record.quality = record.quality[:-1][: lengths[j]] + b"\n" files[j].write(bytes(record)) i += 1 diff --git a/src/seqc/sequence/gtf.py b/src/seqc/sequence/gtf.py index 85578ae..7487c5c 100644 --- a/src/seqc/sequence/gtf.py +++ b/src/seqc/sequence/gtf.py @@ -1,3 +1,4 @@ +import os import re import fileinput import string @@ -12,11 +13,10 @@ class Record: to create 
records specific to exons, transcripts, and genes """ - __slots__ = ['_fields', '_attribute'] + __slots__ = ["_fields", "_attribute"] - _del_letters = string.ascii_letters.encode() - _del_non_letters = ''.join(set(string.printable).difference(string.ascii_letters))\ - .encode() + _del_letters = string.ascii_letters + _del_non_letters = "".join(set(string.printable).difference(string.ascii_letters)) def __init__(self, fields: list): @@ -24,34 +24,36 @@ def __init__(self, fields: list): self._attribute = {} def __repr__(self) -> str: - return '' % bytes(self).decode() + return "".format("\t".join(self._fields)) - def __bytes__(self) -> bytes: - return b'\t'.join(self._fields) + def __bytes__(self) -> str: + return "\t".join(self._fields) def _parse_attribute(self) -> None: - for field in self._fields[8].rstrip(b';\n').split(b';'): + for field in self._fields[8].rstrip(";\n").split(";"): key, *value = field.strip().split() - self._attribute[key] = b' '.join(value).strip(b'"') + self._attribute[key] = " ".join(value).strip('"') def __hash__(self) -> int: """concatenate strand, start, end, and chromosome and hash the resulting bytes""" - return hash(self._fields[6] + self._fields[3] + self._fields[4] + self._fields[0]) + return hash( + self._fields[6] + self._fields[3] + self._fields[4] + self._fields[0] + ) @property - def seqname(self) -> bytes: + def seqname(self) -> str: return self._fields[0] @property - def chromosome(self) -> bytes: + def chromosome(self) -> str: return self._fields[0] # synonym for seqname @property - def source(self) -> bytes: + def source(self) -> str: return self._fields[1] @property - def feature(self) -> bytes: + def feature(self) -> str: return self._fields[2] @property @@ -63,15 +65,15 @@ def end(self) -> int: return int(self._fields[4]) @property - def score(self) -> bytes: + def score(self) -> str: return self._fields[5] @property - def strand(self) -> bytes: + def strand(self) -> str: return self._fields[6] @property - def frame(self) -> bytes: + def frame(self) -> str: return self._fields[7] @property @@ -95,36 +97,38 @@ def attribute(self, item): self._parse_attribute() return self._attribute[item] else: - raise KeyError('%s is not a stored attribute of this gtf record' % - repr(item)) + raise KeyError( + "%s is not a stored attribute of this gtf record" % repr(item) + ) @property def integer_gene_id(self) -> int: """ENSEMBL gene id without the organism specific prefix, encoded as an integer""" - return int(self.attribute(b'gene_id').split(b'.')[0] - .translate(None, self._del_letters)) + return int( + self.attribute("gene_id").split(".")[0].translate(None, self._del_letters) + ) @property - def organism_prefix(self) -> bytes: + def organism_prefix(self) -> str: """Organism prefix of ENSEMBL gene id (e.g. 
ENSG for human, ENSMUSG)""" - return self.attribute(b'gene_id').translate(None, self._del_non_letters) + return self.attribute("gene_id").translate(None, self._del_non_letters) @property - def string_gene_id(self) -> bytes: + def string_gene_id(self) -> str: """ENSEMBL gene id, including organism prefix.""" - return self.attribute(b'gene_id') + return self.attribute("gene_id") @staticmethod - def int2str_gene_id(integer_id: int, organism_prefix: bytes) -> bytes: + def int2str_gene_id(integer_id: int, organism_prefix: str) -> str: """ converts an integer gene id (suffix) to a string gene id (including organism- specific suffix) - :param organism_prefix: bytes + :param organism_prefix: str :param integer_id: int """ - bytestring = str(integer_id).encode() + bytestring = str(integer_id) diff = 11 - len(bytestring) - return organism_prefix + (b'0' * diff) + bytestring + return organism_prefix + ("0" * diff) + bytestring def __eq__(self, other): """equivalent to testing if start, end, chrom and strand are the same.""" @@ -163,7 +167,9 @@ def __init__(self, gtf: str, max_transcript_length=1000): transcript sizes indicates that the majority of non-erroneous fragments of mRNA molecules should align within this region. """ - self._chromosomes_to_genes = self.construct_translator(gtf, max_transcript_length) + self._chromosomes_to_genes = self.construct_translator( + gtf, max_transcript_length + ) @staticmethod def iterate_adjusted_exons(exons, strand, max_transcript_length): @@ -180,7 +186,7 @@ def iterate_adjusted_exons(exons, strand, max_transcript_length): start, end = int(exon[3]), int(exon[4]) size = end - start if size >= max_transcript_length: - if strand == '+': + if strand == "+": yield end - max_transcript_length, end else: yield start, start + max_transcript_length @@ -255,18 +261,23 @@ def construct_translator(self, gtf, max_transcript_length): chromosome -> strand -> position which returns a gene. """ results_dictionary = defaultdict(dict) - for (tx_chromosome, tx_strand, gene_id), exons in Reader(gtf).iter_transcripts(): + for (tx_chromosome, tx_strand, gene_id), exons in Reader( + gtf + ).iter_transcripts(): for start, end in self.iterate_adjusted_exons( - exons, tx_strand, max_transcript_length): + exons, tx_strand, max_transcript_length + ): if start == end: continue # zero-length exons apparently occur in the gtf try: results_dictionary[tx_chromosome][tx_strand].addi( - start, end, gene_id) + start, end, gene_id + ) except KeyError: results_dictionary[tx_chromosome][tx_strand] = IntervalTree() results_dictionary[tx_chromosome][tx_strand].addi( - start, end, gene_id) + start, end, gene_id + ) return dict(results_dictionary) def translate(self, chromosome, strand, pos): @@ -275,16 +286,17 @@ def translate(self, chromosome, strand, pos): Uses the IntervalTree data structure to rapidly search for the corresponding identifier. 
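To make the translator's lookup concrete, here is a minimal standalone sketch of the chromosome -> strand -> IntervalTree pattern that construct_translator()/translate() use above; the chromosome name, coordinates, and gene ids are invented, and only the `intervaltree` package already used by this module is assumed:

```python
# Sketch of the chromosome -> strand -> IntervalTree lookup behind translate().
from intervaltree import IntervalTree

trees = {"chr19": {"+": IntervalTree()}}
trees["chr19"]["+"].addi(1000, 2000, 101)   # exon interval annotated with gene id 101
trees["chr19"]["+"].addi(1500, 2500, 102)   # overlapping exon of gene id 102

def lookup(chromosome, strand, pos):
    hits = {interval.data for interval in trees[chromosome][strand][pos]}
    return hits.pop() if len(hits) == 1 else None  # a unique gene, or None if 0 or >1

print(lookup("chr19", "+", 1200))  # 101 (only one gene covers this position)
print(lookup("chr19", "+", 1700))  # None (ambiguous: genes 101 and 102 overlap here)
```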
- :param bytes chromosome: chromosome for this alignment - :param bytes strand: strand for this alignment (one of ['+', '-']) + :param str chromosome: chromosome for this alignment + :param str strand: strand for this alignment (one of ['+', '-']) :param int pos: position of the alignment within the chromosome :return int|None: Returns either an integer gene_id if a unique gene was found at the specified position, or None otherwise """ # todo remove duplicate exons during construction to save time try: - result = set(x.data for x in - self._chromosomes_to_genes[chromosome][strand][pos]) + result = set( + x.data for x in self._chromosomes_to_genes[chromosome][strand][pos] + ) if len(result) == 1: return first(result) # just right else: @@ -299,49 +311,71 @@ class Reader(reader.Reader): methods. :method __iter__: Iterator over all non-header records in gtf; yields Record objects. - :method iter_genes: Iterator over all genes in gtf; yields Gene objects. + :method iter_transcripts: Iterate over transcripts in a gtf file, returning a transcripts's + chromosome strand, gene_id, and a list of tab-split exon records """ def __iter__(self): """return an iterator over all non-header records in gtf""" - hook = fileinput.hook_compressed - with fileinput.input(self._files, openhook=hook, mode='r') as f: + + # fixme: workaround for https://bugs.python.org/issue36865 (as of 2019-10-24) + # force to "rt" instead of using `mode` being passed + # this will let us avoid using `.decode()` all over the place + def hook_compressed(filename, mode): + ext = os.path.splitext(filename)[1] + if ext == ".gz": + import gzip + + return gzip.open(filename, "rt") + elif ext == ".bz2": + import bz2 + + return bz2.BZ2File(filename, "rt") + else: + return open(filename, mode) + + with fileinput.input(self._files, openhook=hook_compressed, mode="r") as f: # get rid of header lines file_iterator = iter(f) first_record = next(file_iterator) - while first_record.startswith('#'): + while first_record.startswith("#"): first_record = next(file_iterator) - yield first_record.split('\t') # avoid loss of first non-comment line + # avoid loss of first non-comment line + yield first_record.split("\t") for record in file_iterator: # now, run to exhaustion - yield record.split('\t') + yield record.split("\t") @staticmethod def strip_gene_num(attribute_str): try: - gene_start = attribute_str.index('gene_id') + gene_start = attribute_str.index("gene_id") except ValueError: raise ValueError( - 'Gene_id field is missing in annotations file: {}'.format(attribute_str)) + "Gene_id field is missing in annotations file: {}".format(attribute_str) + ) try: gene_end = attribute_str.index('";', gene_start) except ValueError: raise ValueError( 'no "; in gene_id attribute, gtf file might be corrupted: {}'.format( - attribute_str)) + attribute_str + ) + ) try: - id_start = attribute_str.index('0', gene_start) + id_start = attribute_str.index("0", gene_start) except ValueError: raise ValueError( - 'Corrupt gene_id field in annotations file - {}'.format(attribute_str)) + "Corrupt gene_id field in annotations file - {}".format(attribute_str) + ) # ignore the gene version, which is located after a decimal in some gtf files try: - gene_end = attribute_str.index('.', id_start, gene_end) + gene_end = attribute_str.index(".", id_start, gene_end) except ValueError: pass @@ -357,7 +391,7 @@ def iter_transcripts(self): record = next(iterator) # skip to first transcript record and store chromosome and strand - while record[2] != 'transcript': + while record[2] != 
"transcript": record = next(iterator) transcript_chromosome = record[0] transcript_strand = record[6] @@ -384,13 +418,15 @@ def iter_transcripts(self): exons = [] for record in iterator: - if record[2] == 'exon': + if record[2] == "exon": exons.append(record) - elif record[2] == 'transcript': + elif record[2] == "transcript": # we want exons in inverse order exons = exons[::-1] yield ( - (transcript_chromosome, transcript_strand, transcript_gene_id), exons) + (transcript_chromosome, transcript_strand, transcript_gene_id), + exons, + ) exons = [] transcript_chromosome = record[0] transcript_strand = record[6] @@ -411,23 +447,24 @@ def create_phix_annotation(phix_fasta): """ import numpy as np - with open(phix_fasta, 'r') as f: + with open(phix_fasta, "r") as f: header = f.readline() # phiX has only one chromosome data = f.readlines() # concatenate data - contig = '' + contig = "" for line in data: contig += line.strip() # get chromosome - chromosome = header.split()[0].strip('>') - source = 'seqc' - score = '.' - frame = '.' + chromosome = header.split()[0].strip(">") + source = "seqc" + score = "." + frame = "." gene_meta = 'gene_id "PHIXG00{NUM}"; gene_name "PHIX{NAME!s}";' - exon_meta = ('gene_id "PHIXG00{NUM}"; gene_name "PHIX{NAME!s}"; ' - 'exon_id "PHIX{NAME!s}";') + exon_meta = ( + 'gene_id "PHIXG00{NUM}"; gene_name "PHIX{NAME!s}"; ' 'exon_id "PHIX{NAME!s}";' + ) # SEQC truncates genes at 1000b from the end of each transcript. However, phiX DNA # that is spiked into an experiment is not subject to library construction. Thus, @@ -438,24 +475,60 @@ def create_phix_annotation(phix_fasta): transcript_starts = np.arange(length // 1000 + 1) * 1000 transcript_ends = np.array([min(s + 1000, length) for s in transcript_starts]) - phix_gtf = phix_fasta.replace('.fa', '.gtf') + phix_gtf = phix_fasta.replace(".fa", ".gtf") - with open(phix_gtf, 'w') as f: + with open(phix_gtf, "w") as f: for i, (s, e) in enumerate(zip(transcript_starts, transcript_ends)): # add forward strand gene - gene = [chromosome, source, 'gene', str(s), str(e), score, '+', frame, - gene_meta.format(NUM=str(i + 1) * 9, NAME=i + 1)] - f.write('\t'.join(gene) + '\n') - exon = [chromosome, source, 'exon', str(s), str(e), score, '+', frame, - exon_meta.format(NUM=str(i + 1) * 9, NAME=i + 1)] - f.write('\t'.join(exon) + '\n') + gene = [ + chromosome, + source, + "gene", + str(s), + str(e), + score, + "+", + frame, + gene_meta.format(NUM=str(i + 1) * 9, NAME=i + 1), + ] + f.write("\t".join(gene) + "\n") + exon = [ + chromosome, + source, + "exon", + str(s), + str(e), + score, + "+", + frame, + exon_meta.format(NUM=str(i + 1) * 9, NAME=i + 1), + ] + f.write("\t".join(exon) + "\n") # add reverse strand gene - gene = [chromosome, source, 'gene', str(s), str(e), score, '-', frame, - gene_meta.format(NUM=str(i + 1) * 9, NAME=i + 1)] - f.write('\t'.join(gene) + '\n') - exon = [chromosome, source, 'exon', str(s), str(e), score, '-', frame, - exon_meta.format(NUM=str(i + 1) * 9, NAME=i + 1)] - f.write('\t'.join(exon) + '\n') + gene = [ + chromosome, + source, + "gene", + str(s), + str(e), + score, + "-", + frame, + gene_meta.format(NUM=str(i + 1) * 9, NAME=i + 1), + ] + f.write("\t".join(gene) + "\n") + exon = [ + chromosome, + source, + "exon", + str(s), + str(e), + score, + "-", + frame, + exon_meta.format(NUM=str(i + 1) * 9, NAME=i + 1), + ] + f.write("\t".join(exon) + "\n") def create_gene_id_to_official_gene_symbol_map(gtf: str): @@ -466,16 +539,17 @@ def create_gene_id_to_official_gene_symbol_map(gtf: str): :param gtf: str, 
filename of gtf file from which to create the map. """ pattern = re.compile( - r'(^.*?gene_id "[^0-9]*)([0-9]*)(\.?.*?gene_name ")(.*?)(".*?$)') + r'(^.*?gene_id "[^0-9]*)([0-9]*)(\.?.*?gene_name ")(.*?)(".*?$)' + ) gene_id_map = defaultdict(set) - with open(gtf, 'r') as f: + with open(gtf, "r") as f: for line in f: # Skip comment lines - if line.startswith('#'): + if line.startswith("#"): continue - fields = line.split('\t') # speed-up, only run regex on gene lines - if fields[2] != 'gene': + fields = line.split("\t") # speed-up, only run regex on gene lines + if fields[2] != "gene": continue match = re.match(pattern, line) # run regex @@ -493,4 +567,5 @@ def ensembl_gene_id_to_official_gene_symbol(ids, gene_id_map): objects, it is much faster to only construct the map a single time. :return list: converted ids """ - return ['-'.join(gene_id_map[i]) for i in ids] + return ["-".join(gene_id_map[i]) for i in ids] + diff --git a/src/seqc/sequence/index.py b/src/seqc/sequence/index.py index 23b2c72..b3fa7b7 100644 --- a/src/seqc/sequence/index.py +++ b/src/seqc/sequence/index.py @@ -6,11 +6,11 @@ from seqc.sequence import gtf from seqc.alignment import star from seqc.io import S3 +from seqc import log class Index: - - def __init__(self, organism, additional_id_types=None, index_folder_name='.'): + def __init__(self, organism, additional_id_types=None, index_folder_name="."): """Create an Index object for organism, requiring that a valid annotation have both an ENSEMBL id and at least one additional id provided by an additional_id_field (if provided) @@ -42,20 +42,25 @@ def __init__(self, organism, additional_id_types=None, index_folder_name='.'): # check organism input if not organism: raise ValueError( - 'organism must be formatted as genus_species in all lower case') + "organism must be formatted as genus_species in all lower case" + ) elif not isinstance(organism, str): - raise TypeError('organism must be a string') - elif any([('_' not in organism) or (organism.lower() != organism)]): + raise TypeError("organism must be a string") + elif any([("_" not in organism) or (organism.lower() != organism)]): raise ValueError( - 'organism must be formatted as genus_species in all lower case') + "organism must be formatted as genus_species in all lower case" + ) self._organism = organism # check additional_id_fields argument - if not (isinstance(additional_id_types, (list, tuple, np.ndarray)) or - additional_id_types is None): + if not ( + isinstance(additional_id_types, (list, tuple, np.ndarray)) + or additional_id_types is None + ): raise TypeError( - 'if provided, additional id fields must be a list, tuple, or numpy ' - 'array') + "if provided, additional id fields must be a list, tuple, or numpy " + "array" + ) if additional_id_types: self._additional_id_types = additional_id_types else: @@ -64,6 +69,11 @@ def __init__(self, organism, additional_id_types=None, index_folder_name='.'): # todo type checks self.index_folder_name = index_folder_name + if self.index_folder_name != ".": + os.makedirs( + os.path.join(self.index_folder_name, self.organism), exist_ok=True + ) + @property def organism(self) -> str: return self._organism @@ -77,20 +87,22 @@ def _converter_xml(self) -> str: """Generate The xml query to download an ENSEMBL BioMART file mapping ENSEMBL gene ids to any identifiers implemented in self.additional_id_fields """ - attributes = ''.join( - '' % f for f in self.additional_id_types) - genus, species = self.organism.split('_') + attributes = "".join( + '' % f for f in 
self.additional_id_types + ) + genus, species = self.organism.split("_") genome_name = genus[0] + species xml = ( '' - '' + "" '' '' '' - '{attr}' - '' - '\''.format(genome=genome_name, attr=attributes)) + "{attr}" + "" + "'".format(genome=genome_name, attr=attributes) + ) return xml @staticmethod @@ -102,20 +114,23 @@ def _identify_genome_file(files: [str]) -> str: :param files: list of fasta files obtained from the ENSEMBL ftp server :return str: name of the correct genome file""" for f in files: - if '.dna_sm.primary_assembly' in f: + if ".dna_sm.primary_assembly" in f: return f for f in files: - if f.endswith('.dna_sm.toplevel.fa.gz'): + if f.endswith(".dna_sm.toplevel.fa.gz"): return f - raise FileNotFoundError('could not find the correct fasta file in %r' % files) + raise FileNotFoundError("could not find the correct fasta file in %r" % files) @staticmethod - def _identify_gtf_file(files: [str], newest: int) -> str: + def _identify_gtf_file(files: [str], release_num: int) -> str: """Identify and return the basic gtf file from a list of annotation files""" + search_pattern = ".%d.chr.gtf.gz" % release_num for f in files: - if f.endswith('.%d.gtf.gz' % newest): + if f.endswith(search_pattern): return f + raise FileNotFoundError("Unable to find *.{}".format(search_pattern)) + @staticmethod def _identify_newest_release(open_ftp: FTP) -> int: """Identify the most recent genome release given an open link to ftp.ensembl.org @@ -125,34 +140,56 @@ def _identify_newest_release(open_ftp: FTP) -> int: :param FTP open_ftp: open FTP link to ftp.ensembl.org """ - open_ftp.cwd('/pub') - releases = [f for f in open_ftp.nlst() if 'release' in f] - newest = max(int(r[r.find('-') + 1:]) for r in releases) + open_ftp.cwd("/pub") + releases = [f for f in open_ftp.nlst() if f.startswith("release-")] + newest = max(int(r[r.find("-") + 1 :]) for r in releases) + return newest - 1 - def _download_fasta_file(self, ftp: FTP, download_name: str) -> None: + def _download_fasta_file( + self, ftp: FTP, download_name: str, ensemble_release: int + ) -> None: """download the fasta file for cls.organism from ftp, an open Ensembl FTP server :param FTP ftp: open FTP link to ENSEMBL :param str download_name: filename for downloaded fasta file """ - newest = self._identify_newest_release(ftp) - ftp.cwd('/pub/release-%d/fasta/%s/dna' % (newest, self.organism)) + + release_num = ( + ensemble_release if ensemble_release else self._identify_newest_release(ftp) + ) + work_dir = "/pub/release-%d/fasta/%s/dna" % (release_num, self.organism) + ftp.cwd(work_dir) ensembl_fasta_filename = self._identify_genome_file(ftp.nlst()) - with open(download_name, 'wb') as f: - ftp.retrbinary('RETR %s' % ensembl_fasta_filename, f.write) - def _download_gtf_file(self, ftp, download_name) -> None: + log.info("FASTA Ensemble Release {}".format(release_num)) + log.info("ftp://{}{}/{}".format(ftp.host, work_dir, ensembl_fasta_filename)) + + with open(download_name, "wb") as f: + ftp.retrbinary("RETR %s" % ensembl_fasta_filename, f.write) + + def _download_gtf_file( + self, ftp, download_name: str, ensemble_release: int + ) -> None: """download the gtf file for cls.organism from ftp, an open Ensembl FTP server :param FTP ftp: open FTP link to ENSEMBL :param str download_name: filename for downloaded gtf file """ - newest = self._identify_newest_release(ftp) - ftp.cwd('/pub/release-%d/gtf/%s/' % (newest, self.organism)) - ensembl_gtf_filename = self._identify_gtf_file(ftp.nlst(), newest) - with open(download_name, 'wb') as f: - ftp.retrbinary('RETR 
%s' % ensembl_gtf_filename, f.write) + release_num = ( + ensemble_release if ensemble_release else self._identify_newest_release(ftp) + ) + work_dir = "/pub/release-%d/gtf/%s/" % (release_num, self.organism) + ftp.cwd(work_dir) + ensembl_gtf_filename = self._identify_gtf_file(ftp.nlst(), release_num) + + log.info("GTF Ensemble Release {}".format(release_num)) + log.info( + "ftp://{}{}".format(ftp.host, os.path.join(work_dir, ensembl_gtf_filename)) + ) + + with open(download_name, "wb") as f: + ftp.retrbinary("RETR %s" % ensembl_gtf_filename, f.write) # todo remove wget dependency def _download_conversion_file(self, download_name: str) -> None: @@ -161,44 +198,54 @@ def _download_conversion_file(self, download_name: str) -> None: :param download_name: name for the downloaded file """ - cmd = ('wget -O %s \'http://www.ensembl.org/biomart/martservice?query=%s > ' - '/dev/null 2>&1' % (download_name, self._converter_xml)) + cmd = ( + "wget -O %s 'http://www.ensembl.org/biomart/martservice?query=%s > " + "/dev/null 2>&1" % (download_name, self._converter_xml) + ) err = check_call(cmd, shell=True) if err: - raise ChildProcessError('conversion file download failed: %s' % err) + raise ChildProcessError("conversion file download failed: %s" % err) def _download_ensembl_files( - self, fasta_name: str=None, gtf_name: str=None, - conversion_name: str=None) -> None: + self, + ensemble_release: int, + fasta_name: str = None, + gtf_name: str = None, + conversion_name: str = None, + ) -> None: """download the fasta, gtf, and id_mapping file for the organism defined in cls.organism + :param ensemble_release: Ensemble release number :param fasta_name: name for the downloaded fasta file :param gtf_name: name for the downloaded gtf file :param conversion_name: name for the downloaded conversion file """ if fasta_name is None: - fasta_name = '%s/%s.fa.gz' % (self.index_folder_name, self.organism) + fasta_name = "%s/%s.fa.gz" % (self.index_folder_name, self.organism) if gtf_name is None: - gtf_name = '%s/%s.gtf.gz' % (self.index_folder_name, self.organism) + gtf_name = "%s/%s.gtf.gz" % (self.index_folder_name, self.organism) if conversion_name is None: - conversion_name = '%s/%s_ids.csv' % (self.index_folder_name, self.organism) + conversion_name = "%s/%s_ids.csv" % (self.index_folder_name, self.organism) - with FTP(host='ftp.ensembl.org') as ftp: + ensemble_ftp_address = "ftp.ensembl.org" + + with FTP(host=ensemble_ftp_address) as ftp: ftp.login() - self._download_fasta_file(ftp, fasta_name) - self._download_gtf_file(ftp, gtf_name) + self._download_fasta_file(ftp, fasta_name, ensemble_release) + self._download_gtf_file(ftp, gtf_name, ensemble_release) self._download_conversion_file(conversion_name) def _subset_genes( - self, - conversion_file: str=None, - gtf_file: str=None, - truncated_annotation: str=None, - valid_biotypes=(b'protein_coding', b'lincRNA')): + self, + conversion_file: str = None, + gtf_file: str = None, + truncated_annotation: str = None, + valid_biotypes=("protein_coding", "lincRNA"), + ): """ Remove any annotation from the annotation_file that is not also defined by at least one additional identifer present in conversion file. @@ -219,47 +266,66 @@ def _subset_genes( :param conversion_file: file location of the conversion file :param gtf_file: file location of the annotation file :param truncated_annotation: name for the generated output file - :param list(bytes) valid_biotypes: only accept genes of this biotype. + :param list(str) valid_biotypes: only accept genes of this biotype. 
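A compact sketch of the record-level rule that `_subset_genes` applies in the hunk below: keep a GTF record only when its gene_biotype is whitelisted and, if an ENSEMBL-id whitelist was built from the conversion file, its gene_id appears in that whitelist. The ids and biotypes here are examples only:

```python
# Keep a record only if its biotype is whitelisted and, when an id whitelist
# exists, its gene_id is in that whitelist. Values below are illustrative.
valid_biotypes = {"protein_coding", "lincRNA"}
valid_ensembl_ids = {"ENSG00000141510"}   # hypothetical; None would mean "accept all ids"

def keep(gene_id: str, gene_biotype: str) -> bool:
    if gene_biotype not in valid_biotypes:
        return False
    return valid_ensembl_ids is None or gene_id in valid_ensembl_ids

print(keep("ENSG00000141510", "protein_coding"))  # True
print(keep("ENSG00000141510", "miRNA"))           # False: biotype not whitelisted
```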
""" - if not (self.additional_id_types or valid_biotypes): # nothing to be done - return - - # change to set for efficiency - if all(isinstance(t, str) for t in valid_biotypes): - valid_biotypes = set((t.encode() for t in valid_biotypes)) - elif all(isinstance(t, bytes) for t in valid_biotypes): - valid_biotypes = set(valid_biotypes) - else: - raise TypeError('mixed-type biotypes detected. Please pass valid_biotypes ' - 'as strings or bytes objects (but not both).') if gtf_file is None: - gtf_file = '%s/%s.gtf.gz' % (self.index_folder_name, self.organism) + gtf_file = os.path.join( + self.index_folder_name, "{}.gtf.gz".format(self.organism) + ) if conversion_file is None: - conversion_file = '%s/%s_ids.csv' % (self.index_folder_name, self.organism) + conversion_file = os.path.join( + self.index_folder_name, "{}_ids.csv".format(self.organism) + ) if truncated_annotation is None: - truncated_annotation = '%s/%s_multiconsortia.gtf' % ( - self.index_folder_name, self.organism) + truncated_annotation = os.path.join( + self.index_folder_name, self.organism, "annotations.gtf" + ) + + if not (self.additional_id_types or valid_biotypes): # nothing to be done + # no need to truncate the annotation file + # let's just make a copy of the original file so that it can be added to the final output directory + cmd = "gunzip -c {} > {}".format(gtf_file, truncated_annotation) + err = check_call(cmd, shell=True) + if err: + raise ChildProcessError("conversion file download failed: %s" % err) + return + + # change to set for efficiency + valid_biotypes = set(valid_biotypes) # extract valid ensembl ids from the conversion file c = pd.read_csv(conversion_file, index_col=[0]) - valid_ensembl_ids = set(c[np.any(~c.isnull().values, axis=1)].index) + + if c.shape[1] == 1: + # index == ensembl_gene_id & col 1 == hgnc_symbol + valid_ensembl_ids = set(c[np.any(~c.isnull().values, axis=1)].index) + elif c.shape[1] == 0: + # index == ensembl_gene_id & no columns + # set to none to take all IDs + valid_ensembl_ids = None + else: + raise Exception("Not implemented/supported shape={}".format(c.shape)) # remove any invalid ids from the annotation file gr = gtf.Reader(gtf_file) - with open(truncated_annotation, 'wb') as f: + with open(truncated_annotation, "wt") as f: for line_fields in gr: record = gtf.Record(line_fields) - if (record.attribute(b'gene_id').decode() in valid_ensembl_ids and - record.attribute(b'gene_biotype') in valid_biotypes): - f.write(bytes(record)) + # include only biotypes of interest + if record.attribute("gene_biotype") in valid_biotypes: + if (valid_ensembl_ids is None) or ( + record.attribute("gene_id") in valid_ensembl_ids + ): + f.write("\t".join(line_fields)) def _create_star_index( - self, - fasta_file: str=None, - gtf_file: str=None, - genome_dir: str=None, - read_length: int=75) -> None: + self, + fasta_file: str = None, + gtf_file: str = None, + genome_dir: str = None, + read_length: int = 75, + ) -> None: """Create a new STAR index for the associated genome :param fasta_file: @@ -269,16 +335,23 @@ def _create_star_index( :return: """ if fasta_file is None: - fasta_file = '%s/%s.fa.gz' % (self.index_folder_name, self.organism) + fasta_file = os.path.join( + self.index_folder_name, "{}.fa.gz".format(self.organism) + ) if gtf_file is None: - if os.path.isfile('%s/%s_multiconsortia.gtf' % ( - self.index_folder_name, self.organism)): - gtf_file = '%s/%s_multiconsortia.gtf' % ( - self.index_folder_name, self.organism) + if os.path.isfile( + os.path.join(self.index_folder_name, self.organism, 
"annotations.gtf") + ): + gtf_file = os.path.join( + self.index_folder_name, self.organism, "annotations.gtf" + ) else: - gtf_file = '%s/%s.gtf.gz' % (self.index_folder_name, self.organism) + gtf_file = os.path.join( + self.index_folder_name, "{}.gtf.gz".format(self.organism) + ) if genome_dir is None: - genome_dir = '%s/%s' % (self.index_folder_name, self.organism) + genome_dir = os.path.join(self.index_folder_name, self.organism) + star.create_index(fasta_file, gtf_file, genome_dir, read_length) @staticmethod @@ -288,16 +361,23 @@ def _upload_index(index_directory: str, s3_upload_location: str) -> None: :param index_directory: folder containing index :param s3_upload_location: location to upload index on s3 """ - if not index_directory.endswith('/'): - index_directory += '/' - if not s3_upload_location.endswith('/'): - s3_upload_location += '/' - bucket, *dirs = s3_upload_location.replace('s3://', '').split('/') - key_prefix = '/'.join(dirs) - S3.upload_files(file_prefix=index_directory, bucket=bucket, key_prefix=key_prefix) + if not index_directory.endswith("/"): + index_directory += "/" + if not s3_upload_location.endswith("/"): + s3_upload_location += "/" + bucket, *dirs = s3_upload_location.replace("s3://", "").split("/") + key_prefix = "/".join(dirs) + S3.upload_files( + file_prefix=index_directory, bucket=bucket, key_prefix=key_prefix + ) def create_index( - self, valid_biotypes=('protein_coding', 'lincRNA'), s3_location: str=None): + self, + ensemble_release: int, + read_length: int, + valid_biotypes=("protein_coding", "lincRNA"), + s3_location: str = None, + ): """create an optionally upload an index :param valid_biotypes: gene biotypes that do not match values in this list will @@ -305,11 +385,19 @@ def create_index( :param s3_location: optional, s3 location to upload the index to. 
:return: """ - if self.index_folder_name is not '.': - os.makedirs(self.index_folder_name, exist_ok=True) - self._download_ensembl_files() + + log.info("Downloading Ensemble files...") + self._download_ensembl_files(ensemble_release) + + log.info("Subsetting genes...") self._subset_genes(valid_biotypes=valid_biotypes) - self._create_star_index() + + log.info("Creating STAR index...") + self._create_star_index(read_length=read_length) + if s3_location: - self._upload_index('%s/%s' % (self.index_folder_name, self.organism), - s3_location) + log.info("Uploading...") + self._upload_index( + "%s/%s" % (self.index_folder_name, self.organism), s3_location + ) + diff --git a/src/seqc/sparse_frame.py b/src/seqc/sparse_frame.py index 4aecb79..d0b82e1 100644 --- a/src/seqc/sparse_frame.py +++ b/src/seqc/sparse_frame.py @@ -7,7 +7,6 @@ class SparseFrame: - def __init__(self, data, index, columns): """ lightweight wrapper of scipy.stats.coo_matrix to provide pd.DataFrame-like access @@ -25,11 +24,11 @@ def __init__(self, data, index, columns): """ if not isinstance(data, coo_matrix): - raise TypeError('data must be type coo_matrix') + raise TypeError("data must be type coo_matrix") if not isinstance(index, np.ndarray): - raise TypeError('index must be type np.ndarray') + raise TypeError("index must be type np.ndarray") if not isinstance(columns, np.ndarray): - raise TypeError('columns must be type np.ndarray') + raise TypeError("columns must be type np.ndarray") self._data = data self._index = index @@ -42,7 +41,7 @@ def data(self): @data.setter def data(self, item): if not isinstance(item, coo_matrix): - raise TypeError('data must be type coo_matrix') + raise TypeError("data must be type coo_matrix") self._data = item @property @@ -54,7 +53,7 @@ def index(self, item): try: self._index = np.array(item) except: - raise TypeError('self.index must be convertible into a np.array object') + raise TypeError("self.index must be convertible into a np.array object") @property def columns(self): @@ -65,7 +64,7 @@ def columns(self, item): try: self._columns = np.array(item) except: - raise TypeError('self.columns must be convertible into a np.array object') + raise TypeError("self.columns must be convertible into a np.array object") @property def shape(self): @@ -106,18 +105,22 @@ def from_dict(cls, dictionary, genes_to_symbols=False): i_inds = np.fromiter((imap[v] for v in i), dtype=int) j_inds = np.fromiter((jmap[v] for v in j), dtype=int) - coo = coo_matrix((data, (i_inds, j_inds)), shape=(len(imap), len(jmap)), - dtype=np.int32) + coo = coo_matrix( + (data, (i_inds, j_inds)), shape=(len(imap), len(jmap)), dtype=np.int32 + ) index = np.fromiter(imap.keys(), dtype=int) columns = np.fromiter(jmap.keys(), dtype=int) if genes_to_symbols: if not os.path.isfile(genes_to_symbols): - raise ValueError('genes_to_symbols argument %s is not a valid annotation ' - 'file' % repr(genes_to_symbols)) + raise ValueError( + "genes_to_symbols argument %s is not a valid annotation " + "file" % repr(genes_to_symbols) + ) gmap = create_gene_id_to_official_gene_symbol_map(genes_to_symbols) - columns = np.array(ensembl_gene_id_to_official_gene_symbol( - columns, gene_id_map=gmap)) + columns = np.array( + ensembl_gene_id_to_official_gene_symbol(columns, gene_id_map=gmap) + ) return cls(coo, index, columns) diff --git a/src/seqc/summary/summary.py b/src/seqc/summary/summary.py index a3ac004..9bc59e0 100644 --- a/src/seqc/summary/summary.py +++ b/src/seqc/summary/summary.py @@ -98,25 +98,25 @@ def from_status_filters(cls, ra, filename): 
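Before moving into the summary.py hunks below, one usage note on the index.py changes above: the rebuilt create_index() now takes an explicit Ensembl release and read length. A hedged sketch of how it might be called; the organism, id types, release, read length, and S3 path are placeholders, not values taken from this PR:

```python
# Assumed usage of the new create_index() signature (ensemble_release, read_length, ...).
from seqc.sequence.index import Index

idx = Index("homo_sapiens", ["hgnc_symbol"], index_folder_name="./index-build")
idx.create_index(
    ensemble_release=93,                            # pin a specific Ensembl release
    read_length=101,                                # read length of the sequencing data
    valid_biotypes=("protein_coding", "lincRNA"),   # biotype whitelist passed to _subset_genes
    s3_location="s3://my-bucket/genomes/hg38/",     # optional; omit to skip the upload
)
```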
         :param str filename: html file name for this section
         :return cls: Section containing initial filtering results
         """
-        # todo replace whitespace characters with html equiv, add space b/w lines
-        description = (
-            'Initial filters are run over the sam file while our ReadArray database is '
-            'being constructed. These filters indicate heuristic reasons why reads '
-            'should be omitted from downstream operations:<br><br>'
-            'no gene: Regardless of the read\'s genomic alignment status, there was no '
-            'transcriptomic alignment for this read.<br>'
-            'gene not unique: this indicates that more than one alignment was recovered '
-            'for this read. We attempt to resolve these multi-alignments downstream.<br>'
-            'primer missing: This is an in-drop specific filter, it indices that the '
-            'spacer sequence could not be identified, and thus neither a cell barcode '
-            'nor an rmt were recorded for this read.<br>'
-            'low poly t: the primer did not display enough t-sequence in the primer '
-            'tail, where these nucleotides are expected. This indicates an increased '
-            'probability that this primer randomly primed, instead of hybridizing with '
-            'the poly-a tail of an mRNA molecule.')
+        description = (
+            'Initial filters are run over the sam file while our ReadArray database is '
+            'being constructed. These filters indicate heuristic reasons why reads '
+            'should be omitted from downstream operations:<br><br>'
+            '<ul>'
+            '<li>no gene: Regardless of the read\'s genomic alignment status, there was no '
+            'transcriptomic alignment for this read.</li>'
+            '<li>gene not unique: This indicates that more than one alignment was recovered '
+            'for this read. We attempt to resolve these multi-alignments downstream.</li>'
+            '<li>primer missing: This is an in-drop specific filter, it indicates that the '
+            'spacer sequence could not be identified, and thus neither a cell barcode '
+            'nor an rmt were recorded for this read.</li>'
+            '<li>low poly t: The primer did not display enough t-sequence in the primer '
+            'tail, where these nucleotides are expected. This indicates an increased '
+            'probability that this primer randomly primed, instead of hybridizing with '
+            'the poly-a tail of an mRNA molecule.</li>'
+            '</ul>')
         description_section = TextContent(description)

-        # Get counts
+        # Get counts
         no_gene = np.sum(ra.data['status'] & ra.filter_codes['no_gene'] > 0)
         gene_not_unique = np.sum(ra.data['status'] & ra.filter_codes['gene_not_unique'] > 0)
         primer_missing = np.sum(ra.data['status'] & ra.filter_codes['primer_missing'] > 0)
@@ -209,10 +209,23 @@ def from_cell_filtering(cls, figure_path, filename):
         :param str filename: html file name for this section
         :return:
         """
-        description = 'description for cell filtering'  # todo implement
+        description = [
+            "Top Left: Cells whose molecule counts are below the inflection point of an ecdf constructed from cell molecule counts",
+            "Top Right: Fits a two-component gaussian mixture model to the data. If a component is found to fit a low-coverage fraction of the data, this fraction is set as invalid.",
+            "Bottom Left: Sets any cell with a fraction of mitochondrial mRNA greater than 20% to invalid.",
+            "Bottom Right: Fits a linear model to the relationship between number of genes detected and number of molecules detected. Cells with a lower than expected number of detected genes are set as invalid."
+        ]
+
+        description = list(map(lambda text: f"<li>{text}</li>", description))
+        description = "<ul>" + "".join(description) + "</ul>"
         description_section = TextContent(description)
-        image_legend = 'image legend'  # todo implement
-        image_section = ImageContent(figure_path, 'cell filtering figure', image_legend)
+        image_legend = "Cell Filtering"
+        # use basename in the HTML file
+        image_section = ImageContent(
+            os.path.basename(figure_path),
+            'cell filtering figure',
+            image_legend
+        )
         return cls(
             'Cell Filtering',
             {'Description': description_section, 'Results': image_section},
@@ -229,8 +242,9 @@ def from_final_matrix(cls, counts_matrix, figure_path, filename):
         :param str figure_path:
         :param str filename: html file name for this section
         :param pd.DataFrame counts_matrix:
         :return:
         """
+        # use full path to generate an image
         plot.Diagnostics.cell_size_histogram(counts_matrix, save=figure_path)
-
+        # Number of cells and molecule count distributions
         image_legend = "Number of cells: {}
    ".format(counts_matrix.shape[0]) ms = counts_matrix.sum(axis=1) @@ -239,7 +253,12 @@ def from_final_matrix(cls, counts_matrix, figure_path, filename): image_legend += '{}th percentile: {}
    '.format(prctile, np.percentile(ms, prctile)) image_legend += "Max number of molecules: {}
    ".format(ms.max()) - image_section = ImageContent(figure_path, 'cell size figure', image_legend) + # use basename in the HTML file + image_section = ImageContent( + os.path.basename(figure_path), + 'cell size figure', + image_legend + ) return cls('Cell Summary', {'Library Size Distribution': image_section}, filename) @@ -320,26 +339,33 @@ def render(self): def compress_archive(self): root_dir, _, base_dir = self.archive_name.rpartition('/') + if root_dir == "": + # make_archive doesn't like an empty string, should be None + root_dir = None shutil.make_archive( - self.archive_name, 'gztar', root_dir, base_dir) + self.archive_name, 'gztar', root_dir, base_dir + ) return self.archive_name + '.tar.gz' class MiniSummary: - def __init__(self, output_prefix, mini_summary_d, alignment_summary_file, filter_fig, cellsize_fig): + def __init__(self, output_dir, output_prefix, mini_summary_d, alignment_summary_file, filter_fig, cellsize_fig): """ :param mini_summary_d: dictionary containing output parameters :param count_mat: count matrix after filtered :param filter_fig: filtering figure :param cellsize_fig: cell size figure """ + self.output_dir = output_dir self.output_prefix = output_prefix self.mini_summary_d = mini_summary_d self.alignment_summary_file = alignment_summary_file + + # use the full path for figures self.filter_fig = filter_fig self.cellsize_fig = cellsize_fig - self.pca_fig = output_prefix+"_pca.png" - self.tsne_and_phenograph_fig = output_prefix+"_phenograph.png" + self.pca_fig = os.path.join(output_dir, output_prefix + "_pca.png") + self.tsne_and_phenograph_fig = os.path.join(output_dir, output_prefix + "_phenograph.png") def compute_summary_fields(self, read_array, count_mat): self.count_mat = pd.DataFrame(count_mat) @@ -384,7 +410,7 @@ def compute_summary_fields(self, read_array, count_mat): # Doing PCA transformation pcaModel = PCA(n_components=min(20, counts_normalized.shape[1])) - counts_pca_reduced = pcaModel.fit_transform(counts_normalized.as_matrix()) + counts_pca_reduced = pcaModel.fit_transform(counts_normalized.values) # taking at most 20 components or total variance is greater than 80% num_comps = 0 @@ -396,7 +422,7 @@ def compute_summary_fields(self, read_array, count_mat): self.counts_after_pca = counts_pca_reduced[:, :num_comps] self.explained_variance_ratio = pcaModel.explained_variance_ratio_ - # regressed library size out of principal components + # regressed library size out of principal components for c in range(num_comps): lm = LinearRegression(normalize=False) X = self.counts_filtered.sum(1).values.reshape(len(self.counts_filtered), 1) @@ -424,7 +450,7 @@ def get_counts_filtered(self): def render(self): plot.Diagnostics.pca_components(self.pca_fig, self.explained_variance_ratio, self.counts_after_pca) - plot.Diagnostics.phenograph_clustering(self.tsne_and_phenograph_fig, self.counts_filtered.sum(1), + plot.Diagnostics.phenograph_clustering(self.tsne_and_phenograph_fig, self.counts_filtered.sum(1), self.clustering_communities, self.counts_after_tsne) self.mini_summary_d['seq_sat_rate'] = ((self.mini_summary_d['avg_reads_per_molc'] - 1.0) * 100.0 @@ -438,22 +464,39 @@ def render(self): warning_d["High percentage of cell death"] = "No" warning_d["Noisy first few principle components"] = "Yes" if (self.explained_variance_ratio[0]<=0.05) else "No" if self.mini_summary_d['seq_sat_rate'] <= 5.00: - warning_d["Low sequencing saturation rate"] = ("Yes (%.2f%%)" % (self.mini_summary_d['seq_sat_rate'])) + warning_d["Low sequencing saturation rate"] = ("Yes 
(%.2f%%)" % (self.mini_summary_d['seq_sat_rate'])) else: warning_d["Low sequencing saturation rate"] = "No" env = Environment(loader=PackageLoader('seqc.summary', 'templates')) section_template = env.get_template('mini_summary_base.html') - rendered_section = section_template.render(output_prefix = self.output_prefix, warning_d = warning_d, - mini_summary_d = self.mini_summary_d, cellsize_fig = self.cellsize_fig, - pca_fig = self.pca_fig, filter_fig = self.filter_fig, - tsne_and_phenograph_fig = self.tsne_and_phenograph_fig) - with open(self.output_prefix + "_mini_summary.html", 'w') as f: + + # use the basename (i.e. not full path) when rendering figures + rendered_section = section_template.render( + output_prefix = self.output_prefix, + warning_d = warning_d, + mini_summary_d = self.mini_summary_d, + cellsize_fig = os.path.basename(self.cellsize_fig), + pca_fig = os.path.basename(self.pca_fig), + filter_fig = os.path.basename(self.filter_fig), + tsne_and_phenograph_fig = os.path.basename(self.tsne_and_phenograph_fig) + ) + + # construct path for mini summary in HTML, JSON, and PDF + path_mini_html = os.path.join(self.output_dir, self.output_prefix + "_mini_summary.html") + path_mini_json = os.path.join(self.output_dir, self.output_prefix + "_mini_summary.json") + path_mini_pdf = os.path.join(self.output_dir, self.output_prefix + "_mini_summary.pdf") + + # save html + with open(path_mini_html, 'w') as f: f.write(rendered_section) - HTML(self.output_prefix + "_mini_summary.html").write_pdf(self.output_prefix + "_mini_summary.pdf") + # save pdf + HTML(path_mini_html).write_pdf(path_mini_pdf) - with open(self.output_prefix + "_mini_summary.json","w") as f: + # save json + with open(path_mini_json, "w") as f: json.dump(self.mini_summary_d, f) - return self.output_prefix + "_mini_summary.json", self.output_prefix + "_mini_summary.pdf" \ No newline at end of file + # return path to mini summary in JSON & PDF + return path_mini_json, path_mini_pdf diff --git a/src/seqc/test.py b/src/seqc/test.py deleted file mode 100644 index b9e68af..0000000 --- a/src/seqc/test.py +++ /dev/null @@ -1,388 +0,0 @@ -import os -import unittest -import gzip -import pandas as pd -import numpy as np -import ftplib -import nose2 -from nose2.tools import params -from seqc.sequence import index, gtf -from seqc.core import main -from seqc.read_array import ReadArray -import seqc -import logging - -logging.basicConfig() - -seqc_dir = '/'.join(seqc.__file__.split('/')[:-3]) + '/' - -# fill and uncomment these variables to avoid having to provide input to tests -TEST_BUCKET = "seqc-public" # None -EMAIL = None -# RSA_KEY = None - -# define some constants for testing -BARCODE_FASTQ = 's3://seqc-public/test/%s/barcode/' # platform -GENOMIC_FASTQ = 's3://seqc-public/test/%s/genomic/' # platform -MERGED = 's3://seqc-public/test/%s/%s_merged.fastq.gz' # platform, platform -SAMFILE = 's3://seqc-public/test/%s/Aligned.out.bam' # platform -INDEX = 's3://seqc-public/genomes/hg38_chr19/' -LOCAL_OUTPUT = os.environ['TMPDIR'] + 'seqc/%s/test' # test_name -REMOTE_OUTPUT = './test' -UPLOAD = 's3://%s/seqc_test/%s/' # bucket_name, test_folder -PLATFORM_BARCODES = 's3://seqc-public/barcodes/%s/flat/' # platform - - -def makedirs(output): - """make directories based on OUTPUT""" - stem, prefix = os.path.split(output) - os.makedirs(stem, exist_ok=True) - - -class TestSEQC(unittest.TestCase): - - email = globals()['EMAIL'] if 'EMAIL' in globals() else None - bucket = globals()['TEST_BUCKET'] if 'TEST_BUCKET' in globals() else None - rsa_key = 
globals()['RSA_KEY'] if 'RSA_KEY' in globals() else None - if rsa_key is None: - try: - rsa_key = os.environ['AWS_RSA_KEY'] - except KeyError: - pass - - def check_parameters(self): - if self.email is None: - self.email = input('please provide an email address for SEQC to mail ' - 'results: ') - if self.bucket is None: - self.bucket = input('please provide an amazon s3 bucket to upload test ' - 'results: ') - self.bucket.rstrip('/') - if self.rsa_key is None: - self.rsa_key = input('please provide an RSA key with permission to create ' - 'aws instances: ') - - # @params('in_drop', 'in_drop_v2', 'drop_seq', 'ten_x', 'mars_seq') - def test_local(self, platform='in_drop_v2'): - """test seqc after pre-downloading all files""" - with open('seqc_log.txt', 'w') as f: - f.write('Dummy log. nose2 captures input, so no log is produced. This causes ' - 'pipeline errors.\n') - test_name = 'test_no_aws_%s' % platform - makedirs(LOCAL_OUTPUT % test_name) - if self.email is None: - self.email = input('please provide an email address for SEQC to mail results: ') - argv = [ - 'run', - platform, - '-o', test_name, - '-i', INDEX, - '-b', BARCODE_FASTQ % platform, - '-g', GENOMIC_FASTQ % platform, - '--barcode-files', PLATFORM_BARCODES % platform, - '-e', self.email, - '--local'] - main.main(argv) - os.remove('./seqc_log.txt') # clean up the dummy log we made. - - # @params('in_drop', 'in_drop_v2', 'drop_seq', 'ten_x', 'mars_seq') - def test_remote_from_raw_fastq(self, platform='ten_x_v2'): - test_name = 'test_remote_%s' % platform - self.check_parameters() - argv = [ - 'run', - platform, - '-o', REMOTE_OUTPUT, - '-u', UPLOAD % (self.bucket, test_name), - '-i', INDEX, - '-e', self.email, - '-b', BARCODE_FASTQ % platform, - '-g', GENOMIC_FASTQ % platform, - '--instance-type', 'c4.large', - '--spot-bid', '1.0', - '-k', self.rsa_key, '--debug'] - if platform != 'drop_seq': - argv += ['--barcode-files', PLATFORM_BARCODES % platform] - main.main(argv) - - # @params('in_drop', 'in_drop_v2', 'drop_seq', 'ten_x', 'mars_seq') - def test_remote_from_merged(self, platform='in_drop_v2'): - test_name = 'test_remote_%s' % platform - self.check_parameters() - argv = [ - 'run', - platform, - '-o', REMOTE_OUTPUT, - '-u', UPLOAD % (self.bucket, test_name), - '-i', INDEX, - '-e', self.email, - '-m', MERGED % (platform, platform), - '-k', self.rsa_key, - '--instance-type', 'c4.large', - # '--spot-bid', '1.0' - ] - if platform != 'drop_seq': - argv += ['--barcode-files', PLATFORM_BARCODES % platform] - main.main(argv) - - # @params('in_drop', 'in_drop_v2', 'drop_seq', 'ten_x', 'mars_seq') - def test_remote_from_samfile(self, platform='in_drop_v2'): - test_name = 'test_remote_%s' % platform - self.check_parameters() - argv = [ - 'run', - platform, - '-o', REMOTE_OUTPUT, - '-u', UPLOAD % (self.bucket, test_name), - '-i', INDEX, - '-e', self.email, - '-a', SAMFILE % platform, - '-k', self.rsa_key, - '--instance-type', 'r5.2xlarge', - '--debug', - # '--spot-bid', '1.0' - ] - if platform != 'drop_seq': - argv += ['--barcode-files', PLATFORM_BARCODES % platform] - main.main(argv) - - -class TestIndex(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.outdir = os.environ['TMPDIR'] - - def test_Index_raises_ValueError_when_organism_is_not_provided(self): - self.assertRaises(ValueError, index.Index, organism='', additional_id_fields=[]) - - def test_Index_raises_ValueError_when_organism_isnt_lower_case(self): - self.assertRaises(ValueError, index.Index, organism='Homo_sapiens', - additional_id_fields=[]) - 
self.assertRaises(ValueError, index.Index, organism='Homo_Sapiens', - additional_id_fields=[]) - self.assertRaises(ValueError, index.Index, organism='hoMO_Sapiens', - additional_id_fields=[]) - - def test_Index_raises_ValueError_when_organism_has_no_underscore(self): - self.assertRaises(ValueError, index.Index, organism='homosapiens', - additional_id_fields=[]) - - def test_Index_raises_TypeError_when_additional_id_fields_is_not_correct_type(self): - self.assertRaises(TypeError, index.Index, organism='homo_sapiens', - additional_id_fields='not_an_array_tuple_or_list') - self.assertRaises(TypeError, index.Index, organism='homo_sapiens', - additional_id_fields='') - - def test_False_evaluating_additional_id_fields_are_accepted_but_set_empty_list(self): - idx = index.Index('homo_sapiens', []) - self.assertEqual(idx.additional_id_types, []) - idx = index.Index('homo_sapiens', tuple()) - self.assertEqual(idx.additional_id_types, []) - idx = index.Index('homo_sapiens', np.array([])) - self.assertEqual(idx.additional_id_types, []) - - def test_converter_xml_contains_one_attribute_line_per_gene_list(self): - idx = index.Index('homo_sapiens', ['hgnc_symbol', 'mgi_symbol']) - self.assertEqual(idx._converter_xml.count('Attribute name'), 3) - idx = index.Index('homo_sapiens', []) - self.assertEqual(idx._converter_xml.count('Attribute name'), 1) - - def test_converter_xml_formats_genome_as_first_initial_plus_species(self): - idx = index.Index('homo_sapiens', ['hgnc_symbol', 'mgi_symbol']) - self.assertIn('hsapiens', idx._converter_xml) - idx = index.Index('mus_musculus') - self.assertIn('mmusculus', idx._converter_xml) - - def test_can_login_to_ftp_ensembl(self): - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - - def test_download_converter_gets_output_and_is_pandas_loadable(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - filename = self.outdir + 'ci.csv' - idx._download_conversion_file(filename) - converter = pd.read_csv(filename, index_col=0) - self.assertGreaterEqual(len(converter), 10) - self.assertEqual(converter.shape[1], 1) - os.remove(filename) # cleanup - - def test_identify_newest_release_finds_a_release_which_is_gt_eq_85(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - newest = idx._identify_newest_release(ftp) - self.assertGreaterEqual(int(newest), 85) # current=85, they only get bigger - - def test_identify_genome_file_finds_primary_assembly_when_present(self): - idx = index.Index('homo_sapiens', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - newest = idx._identify_newest_release(ftp) - ftp.cwd('/pub/release-%d/fasta/%s/dna' % (newest, idx.organism)) - filename = idx._identify_genome_file(ftp.nlst()) - self.assertIn('primary_assembly', filename) - - def test_identify_genome_file_defaults_to_toplevel_when_no_primary_assembly(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - newest = idx._identify_newest_release(ftp) - ftp.cwd('/pub/release-%d/fasta/%s/dna' % (newest, idx.organism)) - filename = idx._identify_genome_file(ftp.nlst()) - self.assertIn('toplevel', filename) - - def test_download_fasta_file_gets_a_properly_formatted_file(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - filename = self.outdir + 'ci.fa.gz' - idx._download_fasta_file(ftp, filename) - with 
gzip.open(filename, 'rt') as f: - self.assertIs(f.readline()[0], '>') # starting character for genome fa record - os.remove(filename) - - def test_identify_annotation_file_finds_a_gtf_file(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - newest = idx._identify_newest_release(ftp) - ftp.cwd('/pub/release-%d/gtf/%s/' % (newest, idx.organism)) - filename = idx._identify_gtf_file(ftp.nlst(), newest) - self.assertIsNotNone(filename) - - def test_download_gtf_file_gets_a_file_readable_by_seqc_gtf_reader(self): - idx = index.Index('ciona_intestinalis', ['entrezgene']) - with ftplib.FTP(host='ftp.ensembl.org') as ftp: - ftp.login() - filename = self.outdir + 'ci.gtf.gz' - idx._download_gtf_file(ftp, filename) - rd = gtf.Reader(filename) - rc = next(rd.iter_genes()) - self.assertIsInstance(rc, gtf.Gene) - os.remove(filename) - - def test_subset_genes_does_nothing_if_no_additional_fields_or_valid_biotypes(self): - idx = index.Index('ciona_intestinalis') - fasta_name = self.outdir + 'ci.fa.gz' - gtf_name = self.outdir + 'ci.gtf.gz' - conversion_name = self.outdir + 'ci_ids.csv' - idx._download_ensembl_files(fasta_name, gtf_name, conversion_name) - idx._subset_genes(conversion_name, gtf_name, self.outdir + 'test.csv', - valid_biotypes=None) - self.assertFalse(os.path.isfile(self.outdir)) - - def test_subset_genes_produces_a_reduced_annotation_file_when_passed_fields(self): - organism = 'ciona_intestinalis' - idx = index.Index(organism, ['entrezgene']) - os.chdir(self.outdir) - idx._download_ensembl_files() - self.assertTrue(os.path.isfile('%s.fa.gz' % organism), 'fasta file not found') - self.assertTrue(os.path.isfile('%s.gtf.gz' % organism), 'gtf file not found') - self.assertTrue(os.path.isfile('%s_ids.csv' % organism), 'id file not found') - - idx._subset_genes() - self.assertTrue(os.path.isfile('%s_multiconsortia.gtf' % organism)) - gr_subset = gtf.Reader('%s_multiconsortia.gtf' % organism) - gr_complete = gtf.Reader('%s.gtf.gz' % organism) - self.assertLess( - len(gr_subset), len(gr_complete), - 'Subset annotation was not smaller than the complete annotation') - - # make sure only valid biotypes are returned - complete_invalid = False - valid_biotypes = {b'protein_coding', b'lincRNA'} - for r in gr_complete.iter_genes(): - if r.attribute(b'gene_biotype') not in valid_biotypes: - complete_invalid = True - break - self.assertTrue(complete_invalid) - subset_invalid = False - for r in gr_subset.iter_genes(): - if r.attribute(b'gene_biotype') not in valid_biotypes: - subset_invalid = True - break - self.assertFalse(subset_invalid) - self.assertGreater(len(gr_subset), 0) - - def test_create_star_index_produces_an_index(self): - organism = 'ciona_intestinalis' - idx = index.Index(organism, ['entrezgene']) - os.chdir(self.outdir) - idx._download_ensembl_files() - idx._subset_genes() - print(os.getcwd()) - idx._create_star_index() - self.assertTrue(os.path.isfile('{outdir}/{organism}/SAindex'.format( - outdir=self.outdir, organism=organism))) - - def test_upload_star_index_correctly_places_index_on_s3(self): - os.chdir(self.outdir) - if 'TEST_BUCKET' in globals(): - bucket = globals()['TEST_BUCKET'] - else: - bucket = input('please provide an amazon s3 bucket to upload test results: ') - organism = 'ciona_intestinalis' - idx = index.Index(organism, ['entrezgene']) - index_directory = organism + '/' - idx._download_ensembl_files() - idx._subset_genes() - idx._create_star_index() - idx._upload_index(index_directory, 
's3://%s/genomes/ciona_intestinalis/' % bucket) - - def test_create_index_produces_and_uploads_an_index(self): - if 'TEST_BUCKET' in globals(): - bucket = globals()['TEST_BUCKET'] - else: - bucket = input('please provide an amazon s3 bucket to upload test results: ') - organism = 'ciona_intestinalis' - idx = index.Index(organism, ['entrezgene'], self.outdir) - idx.create_index(s3_location='s3://%s/genomes/%s/' % (bucket, idx.organism)) - - -class TestReadArrayCreation(unittest.TestCase): - - @classmethod - def setUpClass(cls): - platform = 'in_drop_v2' - # cls.bamfile = LOCAL_OUTPUT % platform + '_bamfile.bam' - # cls.annotation = LOCAL_OUTPUT % platform + '_annotations.gtf' - # S3.download(SAMFILE % platform, cls.bamfile, recursive=False) - # S3.download(INDEX + 'annotations.gtf', cls.annotation, recursive=False) - - cls.bamfile = os.path.expanduser('~/Downloads/mm_test_short.bam') - cls.annotation = os.path.expanduser('~/Downloads/annotations.gtf') - cls.summary = os.path.expanduser('~/Downloads/mm_test_summary.txt') - cls.total_input_reads = 12242659 - cls.translator = gtf.GeneIntervals(cls.annotation, 10000) - - def test_read_array_creation(self): - ra = ReadArray.from_alignment_file(self.bamfile, self.translator, - required_poly_t=0) - print(repr(ra.data.shape[0])) - print(repr(ra.genes)) - ra.save(os.path.expanduser('~/Downloads/test.ra')) - - -class TestTranslator(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.annotation = os.path.expanduser('~/Downloads/annotations.gtf') - - def test_construct_translator(self): - translator = gtf.GeneIntervals(self.annotation) - print(len(translator._chromosomes_to_genes)) - - def get_length_of_gtf(self): - rd = gtf.Reader(self.annotation) - # print(len(rd)) - print(sum(1 for _ in rd.iter_transcripts())) - - -######################################################################################### - -if __name__ == "__main__": - nose2.main() - -######################################################################################### diff --git a/src/seqc/tests/__init__.py b/src/seqc/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/seqc/tests/test_args.py b/src/seqc/tests/test_args.py new file mode 100644 index 0000000..de7f96f --- /dev/null +++ b/src/seqc/tests/test_args.py @@ -0,0 +1,78 @@ +import nose2 +import unittest + +import seqc +from seqc.core import main + + +# class TestSEQC(unittest.TestCase): +# def setUp(self): +# pass + +# def tearDown(self): +# pass + +# def test_args(self): + +# argv = ["start", "-k", "/Users/dchun/dpeerlab-chunj.pem", "-t", "t2.micro"] + +# self.assertRaises(ValueError, lambda: main.main(argv)) + +# class MyUnitTest(unittest.TestCase): +# def setUp(self): +# pass + +# def tearDown(self): +# pass + +# def test_args(self): + +# # argv = [ +# # "run", "ten_x_v2", "--local", +# # "--index", "s3://seqc-public/genomes/hg38_chr19/", +# # "--barcode-files", "s3://seqc-public/barcodes/ten_x_v2/flat/", +# # "--genomic-fastq", "./test-data/genomic/", +# # "--barcode-fastq", "./test-data/barcode/", +# # "--output-prefix", "./test-data/seqc-results/", +# # "--email", "jaeyoung.chun@gmail.com", +# # "--star-args", "\"runRNGseed=0\"" +# # ] + +# argv = [ +# "run" +# ] + +# try: +# main.main(argv) +# # self.assertRaises(BaseException, lambda: main.main(argv)) +# except: +# pass +# # self.assertRaises(ValueError, lambda: main.main(argv)) + + +# class TestSEQC(unittest.TestCase): +# def setUp(self): +# pass + +# def tearDown(self): +# pass + +# def test_args(self): + +# from 
seqc.sequence import gtf + +# # remove any invalid ids from the annotation file +# gr = gtf.Reader("./test-data/homo_sapiens.gtf.gz") + +# for line_fields in gr: +# record = gtf.Record(line_fields) +# print(record) +# biotype = record.attribute("gene_biotype") +# print(biotype) + +# # self.assertRaises(ValueError, lambda: main.main(argv)) + + +if __name__ == "__main__": + + unittest.main() diff --git a/src/seqc/tests/test_dataset.py b/src/seqc/tests/test_dataset.py new file mode 100644 index 0000000..d269b2b --- /dev/null +++ b/src/seqc/tests/test_dataset.py @@ -0,0 +1,24 @@ +from collections import namedtuple + +TestDataset = namedtuple( + "datasets", + ["barcode_fastq", "genomic_fastq", "merged_fastq", "bam", "index", "barcodes",], +) + +dataset_s3 = TestDataset( + barcode_fastq="s3://seqc-public/test/%s/barcode/", # platform + genomic_fastq="s3://seqc-public/test/%s/genomic/", # platform + merged_fastq="s3://seqc-public/test/%s/%s_merged.fastq.gz", # platform, platform + bam="s3://seqc-public/test/%s/Aligned.out.bam", # platform + index="s3://seqc-public/genomes/hg38_chr19/", + barcodes="s3://seqc-public/barcodes/%s/flat/", # platform +) + +dataset_local = TestDataset( + barcode_fastq="test-data/datasets/%s/barcode/", # platform + genomic_fastq="test-data/datasets/%s/genomic/", # platform + merged_fastq=None, + bam="test-data/datasets/%s/Aligned.out.bam", # platform + index="test-data/datasets/genomes/hg38_chr19/", + barcodes="test-data/datasets/barcodes/%s/flat/", # platform +) diff --git a/src/seqc/tests/test_index.py b/src/seqc/tests/test_index.py new file mode 100644 index 0000000..b45f58f --- /dev/null +++ b/src/seqc/tests/test_index.py @@ -0,0 +1,393 @@ +import os +import shutil +import uuid +import unittest +import gzip +import pandas as pd +import numpy as np +import ftplib +import nose2 +from nose2.tools import params +import seqc +from seqc.sequence import index, gtf +from seqc import io + + +def expected_output_files(): + + files = set( + [ + "Genome", + "SA", + "SAindex", + "annotations.gtf", + "chrLength.txt", + "chrName.txt", + "chrNameLength.txt", + "chrStart.txt", + "exonGeTrInfo.tab", + "exonInfo.tab", + "geneInfo.tab", + "genomeParameters.txt", + "sjdbInfo.txt", + "sjdbList.fromGTF.out.tab", + "sjdbList.out.tab", + "transcriptInfo.tab", + ] + ) + + return files + + +class TestIndexRemote(unittest.TestCase): + + s3_bucket = "dp-lab-cicd" + + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.outdir = os.path.join(os.environ["TMPDIR"], "seqc-test", cls.test_id) + os.makedirs(cls.outdir, exist_ok=True) + + @classmethod + def tearDown(self): + if os.path.isdir(self.outdir): + shutil.rmtree(self.outdir, ignore_errors=True) + + def test_upload_star_index_correctly_places_index_on_s3(self): + organism = "ciona_intestinalis" + # must end with a slash + test_folder = f"seqc/index-{organism}-{self.test_id}/" + + idx = index.Index( + organism, ["external_gene_name"], index_folder_name=self.outdir + ) + index_directory = os.path.join(self.outdir, organism) + "/" + idx._download_ensembl_files(ensemble_release=None) + idx._subset_genes() + idx._create_star_index() + idx._upload_index(index_directory, f"s3://{self.s3_bucket}/{test_folder}") + + # check files generated in S3 + files = io.S3.listdir(self.s3_bucket, test_folder) + + # extract only filenames (i.e. 
remove directory hierarchy) + # convert to a set for easy comparison + files = set(map(lambda filename: filename.replace(test_folder, ""), files)) + + # check for the exact same filenames + self.assertSetEqual(files, expected_output_files()) + + def test_create_index_produces_and_uploads_an_index(self): + organism = "ciona_intestinalis" + # must end with a slash + test_folder = f"seqc/index-{organism}-{self.test_id}/" + + idx = index.Index( + organism, ["external_gene_name"], index_folder_name=self.outdir + ) + idx.create_index( + s3_location=f"s3://{self.s3_bucket}/{test_folder}", + ensemble_release=None, + read_length=101, + ) + + # check files generated in S3 + files = io.S3.listdir(self.s3_bucket, test_folder) + + # extract only filenames (i.e. remove directory hierarchy) + # convert to a set for easy comparison + files = set(map(lambda filename: filename.replace(test_folder, ""), files)) + + # check for the exact same filenames + self.assertSetEqual(files, expected_output_files()) + + +class MyUnitTest(unittest.TestCase): + + s3_bucket = "dp-lab-cicd" + + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.outdir = os.path.join(os.environ["TMPDIR"], "seqc-test", cls.test_id) + os.makedirs(cls.outdir, exist_ok=True) + + @classmethod + def tearDown(self): + if os.path.isdir(self.outdir): + shutil.rmtree(self.outdir, ignore_errors=True) + + def test_Index_raises_ValueError_when_organism_is_not_provided(self): + self.assertRaises(ValueError, index.Index, organism="", additional_id_types=[]) + + def test_Index_raises_ValueError_when_organism_isnt_lower_case(self): + self.assertRaises( + ValueError, index.Index, organism="Homo_sapiens", additional_id_types=[] + ) + self.assertRaises( + ValueError, index.Index, organism="Homo_Sapiens", additional_id_types=[] + ) + self.assertRaises( + ValueError, index.Index, organism="hoMO_Sapiens", additional_id_types=[] + ) + + def test_Index_raises_ValueError_when_organism_has_no_underscore(self): + self.assertRaises( + ValueError, index.Index, organism="homosapiens", additional_id_types=[] + ) + + def test_Index_raises_TypeError_when_additional_id_fields_is_not_correct_type(self): + self.assertRaises( + TypeError, + index.Index, + organism="homo_sapiens", + additional_id_types="not_an_array_tuple_or_list", + ) + self.assertRaises( + TypeError, index.Index, organism="homo_sapiens", additional_id_types="" + ) + + def test_False_evaluating_additional_id_fields_are_accepted_but_set_empty_list( + self, + ): + idx = index.Index("homo_sapiens", []) + self.assertEqual(idx.additional_id_types, []) + idx = index.Index("homo_sapiens", tuple()) + self.assertEqual(idx.additional_id_types, []) + idx = index.Index("homo_sapiens", np.array([])) + self.assertEqual(idx.additional_id_types, []) + + def test_converter_xml_contains_one_attribute_line_per_gene_list(self): + idx = index.Index("homo_sapiens", ["hgnc_symbol", "mgi_symbol"]) + self.assertEqual(idx._converter_xml.count("Attribute name"), 3) + idx = index.Index("homo_sapiens", []) + self.assertEqual(idx._converter_xml.count("Attribute name"), 1) + + def test_converter_xml_formats_genome_as_first_initial_plus_species(self): + idx = index.Index("homo_sapiens", ["hgnc_symbol", "mgi_symbol"]) + self.assertIn("hsapiens", idx._converter_xml) + idx = index.Index("mus_musculus") + self.assertIn("mmusculus", idx._converter_xml) + + def test_identify_gtf_file_should_return_correct_file(self): + + files = [ + "CHECKSUMS", + "Homo_sapiens.GRCh38.86.abinitio.gtf.gz", + "Homo_sapiens.GRCh38.86.chr.gtf.gz", + 
"Homo_sapiens.GRCh38.85.chr.gtf.gz", + "Homo_sapiens.GRCh38.86.chr_patch_hapl_scaff.gtf.gz", + "Homo_sapiens.GRCh38.86.gtf.gz", + "README", + ] + release_num = 86 + + filename = index.Index._identify_gtf_file(files, release_num) + + self.assertEqual(filename, "Homo_sapiens.GRCh38.86.chr.gtf.gz") + + def test_identify_gtf_file_should_throw_exception(self): + + files = [ + "CHECKSUMS", + "Homo_sapiens.GRCh38.86.abinitio.gtf.gz", + "Homo_sapiens.GRCh38.86.chr_patch_hapl_scaff.gtf.gz", + "Homo_sapiens.GRCh38.86.gtf.gz", + "README", + ] + release_num = 86 + + self.assertRaises( + FileNotFoundError, + index.Index._identify_gtf_file, + files=files, + release_num=release_num, + ) + + def test_can_login_to_ftp_ensembl(self): + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + + def test_download_converter_gets_output_and_is_pandas_loadable(self): + idx = index.Index("ciona_intestinalis", ["external_gene_name"]) + filename = os.path.join(self.outdir, "ci.csv") + idx._download_conversion_file(filename) + converter = pd.read_csv(filename, index_col=0) + self.assertGreaterEqual(len(converter), 10) + self.assertEqual(converter.shape[1], 1) + + def test_identify_newest_release_finds_a_release_which_is_gt_eq_85(self): + idx = index.Index("ciona_intestinalis", ["external_gene_name"]) + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + newest = idx._identify_newest_release(ftp) + self.assertGreaterEqual(int(newest), 85) # current=85, they only get bigger + + def test_identify_genome_file_finds_primary_assembly_when_present(self): + idx = index.Index("homo_sapiens", ["entrezgene"]) + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + newest = idx._identify_newest_release(ftp) + ftp.cwd("/pub/release-%d/fasta/%s/dna" % (newest, idx.organism)) + filename = idx._identify_genome_file(ftp.nlst()) + self.assertIn("primary_assembly", filename) + + def test_identify_genome_file_defaults_to_toplevel_when_no_primary_assembly(self): + idx = index.Index("ciona_intestinalis", ["entrezgene"]) + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + newest = idx._identify_newest_release(ftp) + ftp.cwd("/pub/release-%d/fasta/%s/dna" % (newest, idx.organism)) + filename = idx._identify_genome_file(ftp.nlst()) + self.assertIn("toplevel", filename) + + def test_download_fasta_file_gets_a_properly_formatted_file(self): + idx = index.Index( + "ciona_intestinalis", ["external_gene_name"], index_folder_name=self.outdir + ) + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + filename = os.path.join(self.outdir, "ci.fa.gz") + idx._download_fasta_file(ftp, filename, ensemble_release=None) + with gzip.open(filename, "rt") as f: + self.assertIs( + f.readline()[0], ">" + ) # starting character for genome fa record + + def test_identify_annotation_file_finds_a_gtf_file(self): + idx = index.Index("ciona_intestinalis", ["external_gene_name"]) + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + newest = idx._identify_newest_release(ftp) + ftp.cwd("/pub/release-%d/gtf/%s/" % (newest, idx.organism)) + filename = idx._identify_gtf_file(ftp.nlst(), newest) + self.assertIsNotNone(filename) + + def test_download_gtf_file_gets_a_file_readable_by_seqc_gtf_reader(self): + + idx = index.Index("ciona_intestinalis", ["entrezgene"]) + + with ftplib.FTP(host="ftp.ensembl.org") as ftp: + ftp.login() + filename = self.outdir + "ci.gtf.gz" + idx._download_gtf_file(ftp, filename, ensemble_release=99) + + rd = gtf.Reader(filename) + (transcript_chromosome, transcript_strand, 
transcript_gene_id), exons = next( + rd.iter_transcripts() + ) + + # (('1', '+', 17842), [['1', 'ensembl', 'exon', '1636', '1902', '.', '+', '.', 'gene_id "ENSCING00000017842"; gene_version "1"; transcript_id "ENSCINT00000030147"; transcript_version "1"; exon_number "1"; gene_name "RNaseP_nuc"; gene_source "ensembl"; gene_biotype "misc_RNA"; transcript_name "RNaseP_nuc-201"; transcript_source "ensembl"; transcript_biotype "misc_RNA"; exon_id "ENSCINE00000207263"; exon_version "1";\n']]) + self.assertEqual(transcript_chromosome, "1") + self.assertEqual(transcript_strand, "+") + self.assertEqual(transcript_gene_id, 17842) + self.assertEqual(len(exons), 1) + + def test_subset_genes_should_returns_original_if_no_additional_fields_or_valid_biotypes( + self, + ): + + fasta_name = os.path.join(self.outdir, "ci.fa.gz") + gtf_name = os.path.join(self.outdir, "ci.gtf.gz") + conversion_name = os.path.join(self.outdir, "ci_ids.csv") + + idx = index.Index("ciona_intestinalis", index_folder_name=self.outdir) + + idx._download_ensembl_files( + ensemble_release=None, + fasta_name=fasta_name, + gtf_name=gtf_name, + conversion_name=conversion_name, + ) + truncated_gtf = os.path.join(self.outdir, "test.gtf") + idx._subset_genes(conversion_name, gtf_name, truncated_gtf, valid_biotypes=None) + + # expect the same file as the original file + self.assertTrue(os.path.isfile(truncated_gtf)) + + # the current implementation of GTF Reader doesn't allow this: + # for gr1, gr2 in zip(gtf.Reader(gtf_name), gtf.Reader(truncated_gtf)): + + records = [] + for gr in gtf.Reader(gtf_name): + records.append(gtf.Record(gr)) + + for i, gr in enumerate(gtf.Reader(truncated_gtf)): + rec1 = records[i] + rec2 = gtf.Record(gr) + self.assertEqual(rec1, rec2) + + def test_subset_genes_produces_a_reduced_annotation_file_when_passed_fields(self): + organism = "ciona_intestinalis" + idx = index.Index( + organism, ["external_gene_name"], index_folder_name=self.outdir + ) + idx._download_ensembl_files(ensemble_release=None) + self.assertTrue( + os.path.isfile(os.path.join(self.outdir, "%s.fa.gz" % organism)), + "fasta file not found", + ) + self.assertTrue( + os.path.isfile(os.path.join(self.outdir, "%s.gtf.gz" % organism)), + "gtf file not found", + ) + self.assertTrue( + os.path.isfile(os.path.join(self.outdir, "%s_ids.csv" % organism)), + "id file not found", + ) + + valid_biotypes = {"protein_coding", "lincRNA"} + idx._subset_genes(valid_biotypes=valid_biotypes) + + self.assertTrue( + os.path.isfile(os.path.join(self.outdir, organism, "annotations.gtf")) + ) + gr_subset = gtf.Reader(os.path.join(self.outdir, organism, "annotations.gtf")) + gr_complete = gtf.Reader(os.path.join(self.outdir, "%s.gtf.gz" % organism)) + self.assertLess( + len(gr_subset), + len(gr_complete), + "Subset annotation was not smaller than the complete annotation", + ) + + # make sure only valid biotypes are returned + complete_invalid = False + + for r in gr_complete: + record = gtf.Record(r) + if record.attribute("gene_biotype") not in valid_biotypes: + complete_invalid = True + break + self.assertTrue(complete_invalid) + + subset_invalid = False + for r in gr_subset: + record = gtf.Record(r) + if record.attribute("gene_biotype") not in valid_biotypes: + subset_invalid = True + break + self.assertFalse(subset_invalid) + self.assertGreater(len(gr_subset), 0) + + def test_create_star_index_produces_an_index(self): + organism = "ciona_intestinalis" + idx = index.Index( + organism, ["external_gene_name"], index_folder_name=self.outdir + ) + 
idx._download_ensembl_files(ensemble_release=None) + idx._subset_genes() + idx._create_star_index() + expected_file = os.path.join( + self.outdir, + "{outdir}/{organism}/SAindex".format(outdir=self.outdir, organism=organism), + ) + self.assertTrue(os.path.isfile(expected_file)) + + +######################################################################################### + +if __name__ == "__main__": + nose2.main() + +######################################################################################### diff --git a/src/seqc/tests/test_run_e2e_local.py b/src/seqc/tests/test_run_e2e_local.py new file mode 100644 index 0000000..0c7770a --- /dev/null +++ b/src/seqc/tests/test_run_e2e_local.py @@ -0,0 +1,130 @@ +import unittest +import os +import uuid +import shutil +import subprocess +import re +from nose2.tools import params +from seqc.core import main +from test_dataset import dataset_local, dataset_s3 + + +def get_output_file_list(test_id, test_folder): + + proc = subprocess.Popen( + ["find", test_folder, "-type", "f"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, _ = proc.communicate() + files = stdout.decode().splitlines() + + # extract only filenames (i.e. remove directory hierarchy) + # convert to a set for easy comparison + files = set(map(lambda filename: filename.replace(test_folder + "/", ""), files)) + + return files + + +def expected_output_files(file_prefix): + + files = set( + [ + f"{file_prefix}.h5", + f"{file_prefix}_alignment_summary.txt", + f"{file_prefix}_cell_filters.png", + f"{file_prefix}_de_gene_list.txt", + f"{file_prefix}_dense.csv", + f"{file_prefix}_merged.fastq.gz", + f"{file_prefix}_mini_summary.json", + f"{file_prefix}_mini_summary.pdf", + f"{file_prefix}_seqc_log.txt", + f"{file_prefix}_sparse_counts_barcodes.csv", + f"{file_prefix}_sparse_counts_genes.csv", + f"{file_prefix}_sparse_molecule_counts.mtx", + f"{file_prefix}_sparse_read_counts.mtx", + f"{file_prefix}_summary.tar.gz", + f"{file_prefix}_Aligned.out.bam", + ] + ) + + return files + + +class TestRunLocal(unittest.TestCase): + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.path_temp = os.path.join( + os.environ["TMPDIR"], "seqc-test", str(uuid.uuid4()) + ) + os.makedirs(cls.path_temp, exist_ok=True) + with open("seqc_log.txt", "wt") as f: + f.write("Dummy log.\n") + f.write("nose2 captures input, so no log is produced.\n") + f.write("This causes pipeline errors.\n") + + @classmethod + def tearDown(self): + if os.path.isdir(self.path_temp): + shutil.rmtree(self.path_temp, ignore_errors=True) + + def test_using_dataset_in_s3(self, platform="ten_x_v2"): + # must NOT end with a slash + file_prefix = "test" + output_prefix = os.path.join(self.path_temp, file_prefix) + + params = [ + ("run", platform), + ("--local",), + ("--output-prefix", output_prefix), + ("--index", dataset_s3.index), + ("--barcode-files", dataset_s3.barcodes % platform), + ("--barcode-fastq", dataset_s3.barcode_fastq % platform), + ("--genomic-fastq", dataset_s3.genomic_fastq % platform), + ("--star-args", "runRNGseed=0"), + ] + + argv = [element for tupl in params for element in tupl] + + if platform != "drop_seq": + argv += ["--barcode-files", dataset_s3.barcodes % platform] + + main.main(argv) + + # get output file list + files = get_output_file_list(self.test_id, self.path_temp) + + # check if each expected file is found in the list of files generated + for file in expected_output_files(file_prefix): + self.assertIn(file, files) + + def test_using_local_dataset(self, 
platform="ten_x_v2"): + # must NOT end with a slash + file_prefix = "test" + output_prefix = os.path.join(self.path_temp, file_prefix) + + params = [ + ("run", platform), + ("--local",), + ("--output-prefix", output_prefix), + ("--index", dataset_local.index), + ("--barcode-files", dataset_local.barcodes % platform), + ("--barcode-fastq", dataset_local.barcode_fastq % platform), + ("--genomic-fastq", dataset_local.genomic_fastq % platform), + ("--star-args", "runRNGseed=0"), + ] + + argv = [element for tupl in params for element in tupl] + + if platform != "drop_seq": + argv += ["--barcode-files", dataset_local.barcodes % platform] + + main.main(argv) + + # get output file list + files = get_output_file_list(self.test_id, self.path_temp) + + # check if each expected file is found in the list of files generated + for file in expected_output_files(file_prefix): + self.assertIn(file, files) diff --git a/src/seqc/tests/test_run_e2e_remote.py b/src/seqc/tests/test_run_e2e_remote.py new file mode 100644 index 0000000..f2d9bfd --- /dev/null +++ b/src/seqc/tests/test_run_e2e_remote.py @@ -0,0 +1,269 @@ +import unittest +import os +import uuid +import shutil +import re +from seqc.core import main +from seqc import io +import boto3 +from nose2.tools import params +from test_dataset import dataset_s3 + + +def get_instance_by_test_id(test_id): + + ec2 = boto3.resource("ec2") + instances = ec2.instances.filter( + Filters=[{"Name": "tag:TestID", "Values": [test_id]}] + ) + instances = list(instances) + + if len(instances) != 1: + raise Exception("Test ID is not found or not unique!") + + return instances[0] + + +def expected_output_files(output_prefix): + + files = set( + [ + f"{output_prefix}.h5", + f"{output_prefix}_Aligned.out.bam", + f"{output_prefix}_alignment_summary.txt", + f"{output_prefix}_cell_filters.png", + f"{output_prefix}_de_gene_list.txt", + f"{output_prefix}_dense.csv", + f"{output_prefix}_merged.fastq.gz", + f"{output_prefix}_mini_summary.json", + f"{output_prefix}_mini_summary.pdf", + f"{output_prefix}_seqc_log.txt", + f"{output_prefix}_sparse_counts_barcodes.csv", + f"{output_prefix}_sparse_counts_genes.csv", + f"{output_prefix}_sparse_molecule_counts.mtx", + f"{output_prefix}_sparse_read_counts.mtx", + f"{output_prefix}_summary.tar.gz", + f"seqc_log.txt", + ] + ) + + return files + + +def expected_output_files_run_from_merged(output_prefix): + + files = expected_output_files(output_prefix) + + excludes = set([f"{output_prefix}_merged.fastq.gz"]) + + return files - excludes + + +def expected_output_files_run_from_bam(output_prefix): + + files = expected_output_files(output_prefix) + + excludes = set( + [ + f"{output_prefix}_Aligned.out.bam", + f"{output_prefix}_alignment_summary.txt", + f"{output_prefix}_merged.fastq.gz", + ] + ) + + return files - excludes + + +def get_output_file_list(test_id, s3_bucket, test_folder): + + # get instance and wait until terminated + instance = get_instance_by_test_id(test_id) + instance.wait_until_terminated() + + # check files generated in S3 + files = io.S3.listdir(s3_bucket, test_folder) + + # extract only filenames (i.e. 
remove directory hierarchy) + # convert to a set for easy comparison + files = set(map(lambda filename: filename.replace(test_folder, ""), files)) + + return files + + +def check_for_success_msg(s3_seqc_log_uri, path_temp): + + # download seqc_log.txt + io.S3.download( + link=s3_seqc_log_uri, prefix=path_temp, overwrite=True, recursive=False + ) + + # check if seqc_log.txt has a successful message + with open(os.path.join(path_temp, "seqc_log.txt"), "rt") as fin: + logs = fin.read() + match = re.search(r"Execution completed successfully", logs, re.MULTILINE) + + return True if match else False + + +class TestRunRemote(unittest.TestCase): + + email = os.environ["SEQC_TEST_EMAIL"] + rsa_key = os.environ["SEQC_TEST_RSA_KEY"] + ami_id = os.environ["SEQC_TEST_AMI_ID"] + + s3_bucket = "dp-lab-cicd" + + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.path_temp = os.path.join( + os.environ["TMPDIR"], "seqc-test", str(uuid.uuid4()) + ) + os.makedirs(cls.path_temp, exist_ok=True) + + @classmethod + def tearDown(self): + if os.path.isdir(self.path_temp): + shutil.rmtree(self.path_temp, ignore_errors=True) + + @params("in_drop_v2", "ten_x_v2") + def test_remote_from_raw_fastq(self, platform="ten_x_v2"): + output_prefix = "from-raw-fastq" + # must end with a slash + test_folder = f"seqc/run-{platform}-{self.test_id}/" + + params = [ + ("run", platform), + ("--output-prefix", "from-raw-fastq"), + ("--upload-prefix", f"s3://{self.s3_bucket}/{test_folder}"), + ("--index", dataset_s3.index), + ("--email", self.email), + ("--barcode-fastq", dataset_s3.barcode_fastq % platform), + ("--genomic-fastq", dataset_s3.genomic_fastq % platform), + ("--instance-type", "r5.2xlarge"), + ("--spot-bid", "1.0"), + ("--rsa-key", self.rsa_key), + ("--debug",), + ("--remote-update",), + ("--ami-id", self.ami_id), + ("--user-tags", f"TestID:{self.test_id}"), + ] + + argv = [element for tupl in params for element in tupl] + + if platform != "drop_seq": + argv += ["--barcode-files", dataset_s3.barcodes % platform] + + main.main(argv) + + # wait until terminated + # get output file list + files = get_output_file_list(self.test_id, self.s3_bucket, test_folder) + + # check for the exact same filenames + self.assertSetEqual(files, expected_output_files(output_prefix)) + + # check for success message in seqc_log.txt + has_success_msg = check_for_success_msg( + s3_seqc_log_uri="s3://{}/{}".format( + self.s3_bucket, os.path.join(test_folder, "seqc_log.txt") + ), + path_temp=self.path_temp, + ) + + self.assertTrue( + has_success_msg, msg="Unable to find the success message in the log" + ) + + def test_remote_from_merged(self, platform="in_drop_v2"): + output_prefix = "from-merged" + # must end with a slash + test_folder = f"seqc/run-{platform}-{self.test_id}/" + + params = [ + ("run", platform), + ("--output-prefix", output_prefix), + ("--upload-prefix", f"s3://{self.s3_bucket}/{test_folder}"), + ("--index", dataset_s3.index), + ("--email", self.email), + ("--merged-fastq", dataset_s3.merged_fastq % (platform, platform)), + ("--rsa-key", self.rsa_key), + ("--instance-type", "r5.2xlarge"), + ("--ami-id", self.ami_id), + ("--remote-update",), + ("--user-tags", f"TestID:{self.test_id}") + # ('--spot-bid', '1.0') + ] + + argv = [element for tupl in params for element in tupl] + + if platform != "drop_seq": + argv += ["--barcode-files", dataset_s3.barcodes % platform] + + main.main(argv) + + # wait until terminated + # get output file list + files = get_output_file_list(self.test_id, self.s3_bucket, test_folder) + + # 
check for the exact same filenames + self.assertSetEqual(files, expected_output_files_run_from_merged(output_prefix)) + + # check for success message in seqc_log.txt + has_success_msg = check_for_success_msg( + s3_seqc_log_uri="s3://{}/{}".format( + self.s3_bucket, os.path.join(test_folder, "seqc_log.txt") + ), + path_temp=self.path_temp, + ) + + self.assertTrue( + has_success_msg, msg="Unable to find the success message in the log" + ) + + def test_remote_from_bamfile(self, platform="in_drop_v2"): + output_prefix = "from-bamfile" + # must end with a slash + test_folder = f"seqc/run-{platform}-{self.test_id}/" + + params = [ + ("run", platform), + ("--output-prefix", output_prefix), + ("--upload-prefix", f"s3://{self.s3_bucket}/{test_folder}"), + ("--index", dataset_s3.index), + ("--email", self.email), + ("--alignment-file", dataset_s3.bam % platform), + ("--rsa-key", self.rsa_key), + ("--instance-type", "r5.2xlarge"), + ("--debug",), + ("--ami-id", self.ami_id), + ("--remote-update",), + ("--user-tags", f"TestID:{self.test_id}") + # ('--spot-bid', '1.0') + ] + + argv = [element for tupl in params for element in tupl] + + if platform != "drop_seq": + argv += ["--barcode-files", dataset_s3.barcodes % platform] + + main.main(argv) + + # wait until terminated + # get output file list + files = get_output_file_list(self.test_id, self.s3_bucket, test_folder) + + # check for the exact same filenames + self.assertSetEqual(files, expected_output_files_run_from_bam(output_prefix)) + + # check for success message in seqc_log.txt + has_success_msg = check_for_success_msg( + s3_seqc_log_uri="s3://{}/{}".format( + self.s3_bucket, os.path.join(test_folder, "seqc_log.txt") + ), + path_temp=self.path_temp, + ) + + self.assertTrue( + has_success_msg, msg="Unable to find the success message in the log" + ) diff --git a/src/seqc/tests/test_run_gtf.py b/src/seqc/tests/test_run_gtf.py new file mode 100644 index 0000000..78c5c7f --- /dev/null +++ b/src/seqc/tests/test_run_gtf.py @@ -0,0 +1,66 @@ +from unittest import TestCase, mock +import os +import uuid +import shutil +import nose2 +from seqc.sequence import gtf +from test_dataset import dataset_local + + +class TestGtf(TestCase): + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.path_temp = os.path.join( + os.environ["TMPDIR"], "seqc-test", str(uuid.uuid4()) + ) + cls.annotation = os.path.join(dataset_local.index, "annotations.gtf") + + @classmethod + def tearDown(self): + if os.path.isdir(self.path_temp): + shutil.rmtree(self.path_temp, ignore_errors=True) + + def test_construct_translator(self): + translator = gtf.GeneIntervals(self.annotation) + self.assertIsNotNone(translator) + + def test_num_of_transcripts(self): + rd = gtf.Reader(self.annotation) + num_transcripts = sum(1 for _ in rd.iter_transcripts()) + # awk -F'\t' '$3=="transcript" { print $0 }' annotations.gtf | wc -l + self.assertEqual(num_transcripts, 12747) + + def test_iter_transcripts(self): + rd = gtf.Reader(self.annotation) + (transcript_chromosome, transcript_strand, transcript_gene_id), exons = next( + rd.iter_transcripts() + ) + + # this should give us 3 exons of the first transcript of the first gene found in inverse order: + # + # chr19 HAVANA gene 60951 71626 . - . gene_id "ENSG00000282458.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; level 2; havana_gene "OTTHUMG00000180466.8"; + # chr19 HAVANA transcript 60951 70976 . - . 
gene_id "ENSG00000282458.1"; transcript_id "ENST00000632506.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; transcript_type "processed_transcript"; transcript_status "KNOWN"; transcript_name "WASH5P-008"; level 2; tag "basic"; transcript_support_level "1"; havana_gene "OTTHUMG00000180466.8"; havana_transcript "OTTHUMT00000471217.2"; + # chr19 HAVANA exon 70928 70976 . - . gene_id "ENSG00000282458.1"; transcript_id "ENST00000632506.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; transcript_type "processed_transcript"; transcript_status "KNOWN"; transcript_name "WASH5P-008"; exon_number 1; exon_id "ENSE00003781173.1"; level 2; tag "basic"; transcript_support_level "1"; havana_gene "OTTHUMG00000180466.8"; havana_transcript "OTTHUMT00000471217.2"; + # chr19 HAVANA exon 66346 66499 . - . gene_id "ENSG00000282458.1"; transcript_id "ENST00000632506.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; transcript_type "processed_transcript"; transcript_status "KNOWN"; transcript_name "WASH5P-008"; exon_number 2; exon_id "ENSE00003783498.1"; level 2; tag "basic"; transcript_support_level "1"; havana_gene "OTTHUMG00000180466.8"; havana_transcript "OTTHUMT00000471217.2"; + # chr19 HAVANA exon 60951 61894 . - . gene_id "ENSG00000282458.1"; transcript_id "ENST00000632506.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; transcript_type "processed_transcript"; transcript_status "KNOWN"; transcript_name "WASH5P-008"; exon_number 3; exon_id "ENSE00003783010.1"; level 2; tag "basic"; transcript_support_level "1"; havana_gene "OTTHUMG00000180466.8"; havana_transcript "OTTHUMT00000471217.2"; + + self.assertEqual(transcript_chromosome, "chr19") + self.assertEqual(transcript_strand, "-") + self.assertEqual(transcript_gene_id, 282458) + self.assertEqual(len(exons), 3) + + # 8th column has exon ID + self.assertIn("ENSE00003783010.1", exons[0][8]) # exon number 3 + self.assertIn("ENSE00003783498.1", exons[1][8]) # exon number 2 + self.assertIn("ENSE00003781173.1", exons[2][8]) # exon number 1 + + def test_translate(self): + translator = gtf.GeneIntervals(self.annotation) + # chr19 HAVANA gene 60951 71626 . - . 
gene_id "ENSG00000282458.1"; gene_type "transcribed_processed_pseudogene"; gene_status "KNOWN"; gene_name "WASH5P"; level 2; havana_gene "OTTHUMG00000180466.8"; + gene_id = translator.translate("chr19", "-", 60951) + self.assertEqual(gene_id, 282458) + + +if __name__ == "__main__": + nose2.main() diff --git a/src/seqc/tests/test_run_readarray.py b/src/seqc/tests/test_run_readarray.py new file mode 100644 index 0000000..417278d --- /dev/null +++ b/src/seqc/tests/test_run_readarray.py @@ -0,0 +1,65 @@ +from unittest import TestCase, mock +import os +import uuid +import shutil +import nose2 +from test_dataset import dataset_local +from seqc.sequence.encodings import DNA3Bit +from seqc.read_array import ReadArray +from seqc.sequence import gtf + + +class TestReadArray(TestCase): + @classmethod + def setUp(cls): + cls.test_id = str(uuid.uuid4()) + cls.path_temp = os.path.join( + os.environ["TMPDIR"], "seqc-test", str(uuid.uuid4()) + ) + cls.annotation = os.path.join(dataset_local.index, "annotations.gtf") + cls.translator = gtf.GeneIntervals(cls.annotation, 10000) + + @classmethod + def tearDown(self): + if os.path.isdir(self.path_temp): + shutil.rmtree(self.path_temp, ignore_errors=True) + + def test_read_array_creation(self, platform="ten_x_v2"): + ra, _ = ReadArray.from_alignment_file( + dataset_local.bam % platform, self.translator, required_poly_t=0 + ) + self.assertIsNotNone(ra) + + def test_read_array_rmt_decode_10x_v2(self): + platform = "ten_x_v2" + + # create a readarray + ra, _ = ReadArray.from_alignment_file( + dataset_local.bam % platform, self.translator, required_poly_t=0 + ) + + # see if we can decode numeric UMI back to nucleotide sequence + dna3bit = DNA3Bit() + for rmt in ra.data["rmt"]: + decoded = dna3bit.decode(rmt).decode() + # ten_x_v2 UMI length = 10 nt + self.assertEqual(len(decoded), 10) + + def test_read_array_rmt_decode_10x_v3(self): + platform = "ten_x_v3" + + # create a readarray + ra, _ = ReadArray.from_alignment_file( + dataset_local.bam % platform, self.translator, required_poly_t=0 + ) + + # see if we can decode numeric UMI back to nucleotide sequence + dna3bit = DNA3Bit() + for rmt in ra.data["rmt"]: + decoded = dna3bit.decode(rmt).decode() + # ten_x_v3 UMI length = 12 nt + self.assertEqual(len(decoded), 12) + + +if __name__ == "__main__": + nose2.main() diff --git a/src/seqc/tests/test_run_rmt_correction.py b/src/seqc/tests/test_run_rmt_correction.py new file mode 100644 index 0000000..294fcb1 --- /dev/null +++ b/src/seqc/tests/test_run_rmt_correction.py @@ -0,0 +1,97 @@ +from unittest import TestCase, mock +import nose2 +import os +import numpy as np +from seqc.read_array import ReadArray +from seqc import rmt_correction + + +class TestRmtCorrection(TestCase): + @classmethod + def setUp(self): + # pre-allocate arrays + n_barcodes = 183416337 + data = np.recarray((n_barcodes,), ReadArray._dtype) + genes = np.zeros(n_barcodes, dtype=np.int32) + positions = np.zeros(n_barcodes, dtype=np.int32) + self.ra = ReadArray(data, genes, positions) + + @classmethod + def tearDown(self): + pass + + def test_should_return_correct_ra_size(self): + + ra_size = self.ra.data.nbytes + self.ra.genes.nbytes + self.ra.positions.nbytes + + self.assertEqual(4768824762, ra_size) + + # 64GB + @mock.patch( + "seqc.rmt_correction._get_available_memory", return_value=50 * 1024 ** 3 + ) + def test_should_return_correct_max_workers(self, mock_mem): + + n_workers = rmt_correction._calc_max_workers(self.ra) + + self.assertEqual(n_workers, 7) + + # 1TB + 
@mock.patch("seqc.rmt_correction._get_available_memory", return_value=1079354630144) + def test_should_return_correct_max_workers2(self, mock_mem): + + n_workers = rmt_correction._calc_max_workers(self.ra) + + self.assertEqual(n_workers, 156) + + # having less memory than ra size + @mock.patch("seqc.rmt_correction._get_available_memory") + def test_should_return_one_if_ra_larger_than_mem(self, mock_mem): + + ra_size = self.ra.data.nbytes + self.ra.genes.nbytes + self.ra.positions.nbytes + + # assume the available memory is a half of ra + mock_mem.return_value = int(ra_size) / 2 + + n_workers = rmt_correction._calc_max_workers(self.ra) + + self.assertEqual(n_workers, 1) + + +class TestRmtCorrection2(TestCase): + @classmethod + def setUp(self): + # pre-allocate arrays + n_barcodes = 183416337 + data = np.recarray((n_barcodes,), ReadArray._dtype) + genes = np.zeros(n_barcodes, dtype=np.int32) + positions = np.zeros(n_barcodes, dtype=np.int32) + self.ra = ReadArray(data, genes, positions) + + import pickle + + with open("pre-correction-ra.pickle", "wb") as fout: + pickle.dump(self.ra, fout) + + @classmethod + def tearDown(self): + import os + + try: + os.remove("pre-correction-ra.pickle") + except: + pass + + @mock.patch("seqc.rmt_correction._correct_errors_by_cell_group", return_value=0) + def test_correct_errors_by_chunks(self, mock_correct): + cell_group = [1, 2, 3] + x = rmt_correction._correct_errors_by_cell_group_chunks( + self.ra, cell_group, 0.02, 0.05 + ) + mock_correct.assert_called() + self.assertEquals(len(cell_group), mock_correct.call_count) + self.assertEquals([0, 0, 0], x) + + +if __name__ == "__main__": + nose2.main() diff --git a/src/seqc/version.py b/src/seqc/version.py index fe404ae..01ef120 100644 --- a/src/seqc/version.py +++ b/src/seqc/version.py @@ -1 +1 @@ -__version__ = "0.2.5" +__version__ = "0.2.6"