diff --git a/README.md b/README.md
index b0a537d0..cc0553a9 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,9 @@ pip install git+https://github.com/ramanathanlab/genslm
 GenSLMs were trained on the [Polaris](https://www.alcf.anl.gov/polaris) and [Perlmutter](https://perlmutter.carrd.co/) supercomputers. For installation on these systems, please see [`INSTALL.md`](https://github.com/ramanathanlab/genslm/blob/main/docs/INSTALL.md).
 
 ## Usage
+> :warning: **Model weights will be unavailable from May 5, 2023 to May 12, 2023.**
+
+> :warning: **Model weights downloaded prior to May 3, 2023 have a small namespace issue. Please re-download the models to get the fix.**
 
 Our pre-trained models and datasets can be downloaded from this [Globus Endpoint](https://app.globus.org/file-manager?origin_id=25918ad0-2a4e-4f37-bcfc-8183b19c3150&origin_path=%2F).
 
@@ -34,9 +37,14 @@ import numpy as np
 from torch.utils.data import DataLoader
 from genslm import GenSLM, SequenceDataset
 
+# Load model
 model = GenSLM("genslm_25M_patric", model_cache_dir="/content/gdrive/MyDrive")
 model.eval()
 
+# Select GPU device if it is available, else use CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
 # Input data is a list of gene sequences
 sequences = [
     "ATGAAAGTAACCGTTGTTGGAGCAGGTGCAGTTGGTGCAAGTTGCGCAGAATATATTGCA",
@@ -50,9 +58,14 @@ dataloader = DataLoader(dataset)
 embeddings = []
 with torch.no_grad():
     for batch in dataloader:
-        outputs = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
+        outputs = model(
+            batch["input_ids"].to(device),
+            batch["attention_mask"].to(device),
+            output_hidden_states=True,
+        )
         # outputs.hidden_states shape: (layers, batch_size, sequence_length, hidden_size)
-        emb = outputs.hidden_states[0].detach().cpu().numpy()
+        # Use the embeddings of the last layer
+        emb = outputs.hidden_states[-1].detach().cpu().numpy()
         # Compute average over sequence length
         emb = np.mean(emb, axis=1)
         embeddings.append(emb)
@@ -67,11 +80,17 @@ embeddings.shape
 ```python
+import torch
 from genslm import GenSLM
 
+# Load model
 model = GenSLM("genslm_25M_patric", model_cache_dir="/content/gdrive/MyDrive")
 model.eval()
 
+# Select GPU device if it is available, else use CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
 # Prompt the language model with a start codon
-prompt = model.tokenizer.encode("ATG", return_tensors="pt")
+prompt = model.tokenizer.encode("ATG", return_tensors="pt").to(device)
 
 tokens = model.model.generate(
     prompt,
diff --git a/docs/COMMANDS.md b/docs/COMMANDS.md
index fed1bb9c..806fc506 100644
--- a/docs/COMMANDS.md
+++ b/docs/COMMANDS.md
@@ -29,7 +29,7 @@ python -m genslm.cmdline.remove_neox_attention_bias \
 2.
Setup a config file that looks like this: ``` load_pt_checkpoint: /home/hippekp/CVD-Mol-AI/hippekp/model_training/25m_genome_embeddings/model-epoch69-val_loss0.01.pt -tokenizer_file: /home/hippekp/github/genslm/genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: /home/hippekp/github/genslm/genslm/tokenizer_files/codon_wordlevel_69vocab.json data_file: $DATA.h5 embeddings_out_path: /home/hippekp/CVD-Mol-AI/hippekp/model_training/25m_genome_embeddings/train_embeddings/ model_config_json: /lus/eagle/projects/CVD-Mol-AI/hippekp/model_training/genome_finetuning_25m/config/neox_25,290,752.json @@ -64,7 +64,7 @@ Converting a directory of fasta files into a directory of h5 files (Step one of python -m genslm.cmdline.fasta_to_h5 \ --fasta $PATH_TO_FASTA_DIR \ --h5_dir $PATH_TO_OUTDIR \ - --tokenizer_file ~/github/genslm/genslm/tokenizer_files/codon_wordlevel_100vocab.json + --tokenizer_file ~/github/genslm/genslm/tokenizer_files/codon_wordlevel_69vocab.json ``` Converting a directory of h5 files into a single h5 file (Step two of data preprocessing for pretraining, output of this step is what we use for pretraining) @@ -83,7 +83,7 @@ Converting individual fasta files into individual h5 files (Useful for getting e python -m genslm.cmdline.single_fasta_to_h5 \ -f $PATH_TO_SINGLE_FASTA \ --h5 $PATH_TO_SINGLE_H5 \ - -t ~/github/genslm/genslm/tokenizer_files/codon_wordlevel_100vocab.json \ + -t ~/github/genslm/genslm/tokenizer_files/codon_wordlevel_69vocab.json \ -b 10240 \ -n 16\ --train_val_test_split diff --git a/examples/embedding.ipynb b/examples/embedding.ipynb index 15e5e86d..7243cf5d 100644 --- a/examples/embedding.ipynb +++ b/examples/embedding.ipynb @@ -202,9 +202,14 @@ "from torch.utils.data import DataLoader\n", "from genslm import GenSLM, SequenceDataset\n", "\n", + "# Load model\n", "model = GenSLM(\"genslm_25M_patric\", model_cache_dir=\"/content/gdrive/MyDrive\")\n", "model.eval()\n", "\n", + "# Select GPU device if it is available, else use CPU\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", "# Input data is a list of gene sequences\n", "sequences = [\n", " \"ATGAAAGTAACCGTTGTTGGAGCAGGTGCAGTTGGTGCAAGTTGCGCAGAATATATTGCA\",\n", @@ -219,10 +224,13 @@ "with torch.no_grad():\n", " for batch in dataloader:\n", " outputs = model(\n", - " batch[\"input_ids\"], batch[\"attention_mask\"], output_hidden_states=True\n", + " batch[\"input_ids\"].to(device),\n", + " batch[\"attention_mask\"].to(device),\n", + " output_hidden_states=True,\n", " )\n", " # outputs.hidden_states shape: (layers, batch_size, sequence_length, hidden_size)\n", - " emb = outputs.hidden_states[0].detach().cpu().numpy()\n", + " # Use the embeddings of the last layer\n", + " emb = outputs.hidden_states[-1].detach().cpu().numpy()\n", " # Compute average over sequence length\n", " emb = np.mean(emb, axis=1)\n", " embeddings.append(emb)\n", @@ -241,7 +249,7 @@ "outputs": [], "source": [ "# NOTE: This is not the best performance you can get. For a scalable implementation,\n", - "# refer to genslm.cmdline.inference_outputs for an example of how to utilize multiple\n", + "# refer to genslm.cmdline.run_inference for an example of how to utilize multiple\n", "# GPUs for parallel inference." 
] } diff --git a/examples/generate.ipynb b/examples/generate.ipynb index 14baa45f..eef09206 100644 --- a/examples/generate.ipynb +++ b/examples/generate.ipynb @@ -195,13 +195,19 @@ } ], "source": [ + "import torch\n", "from genslm import GenSLM\n", "\n", + "# Load model\n", "model = GenSLM(\"genslm_25M_patric\", model_cache_dir=\"/content/gdrive/MyDrive\")\n", "model.eval()\n", "\n", + "# Select GPU device if it is available, else use CPU\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", "# Prompt the language model with a start codon\n", - "prompt = model.tokenizer.encode(\"ATG\", return_tensors=\"pt\")\n", + "prompt = model.tokenizer.encode(\"ATG\", return_tensors=\"pt\").to(device)\n", "\n", "tokens = model.model.generate(\n", " prompt,\n", diff --git a/examples/training/covid_models/250M_finetune_first_year.yaml b/examples/training/covid_models/250M_finetune_first_year.yaml index d3a616c4..01ed8bb3 100644 --- a/examples/training/covid_models/250M_finetune_first_year.yaml +++ b/examples/training/covid_models/250M_finetune_first_year.yaml @@ -16,7 +16,7 @@ limit_val_batches: 32 check_val_every_n_epoch: 1 checkpoint_every_n_train_steps: 500 checkpoint_every_n_epochs: null -tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/first_year/first_year_train.h5 val_file: /path/to/data/first_year/first_year_val.h5 test_file: /path/to/data/first_year/first_year_val.h5 diff --git a/examples/training/covid_models/25M_finetune_first_year.yaml b/examples/training/covid_models/25M_finetune_first_year.yaml index 4e52424b..81ae29fd 100644 --- a/examples/training/covid_models/25M_finetune_first_year.yaml +++ b/examples/training/covid_models/25M_finetune_first_year.yaml @@ -16,7 +16,7 @@ limit_val_batches: 32 check_val_every_n_epoch: 1 checkpoint_every_n_train_steps: 500 checkpoint_every_n_epochs: null -tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/first_year/first_year_train.h5 val_file: /path/to/data/first_year/first_year_val.h5 test_file: /path/to/data/first_year/first_year_val.h5 diff --git a/examples/training/foundation_models/250M_foundation.yaml b/examples/training/foundation_models/250M_foundation.yaml index 905943c1..75c7a573 100644 --- a/examples/training/foundation_models/250M_foundation.yaml +++ b/examples/training/foundation_models/250M_foundation.yaml @@ -15,7 +15,7 @@ limit_val_batches: 32 check_val_every_n_epoch: 1 checkpoint_every_n_train_steps: 500 checkpoint_every_n_epochs: null -tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_train.h5 val_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_val.h5 test_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_test.h5 diff --git a/examples/training/foundation_models/25B_foundation.yaml b/examples/training/foundation_models/25B_foundation.yaml index ef11d56d..42edac10 100644 --- a/examples/training/foundation_models/25B_foundation.yaml +++ b/examples/training/foundation_models/25B_foundation.yaml @@ -16,7 +16,7 @@ limit_val_batches: 32 check_val_every_n_epoch: 1 checkpoint_every_n_train_steps: 50 checkpoint_every_n_epochs: null -tokenizer_file: 
../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_train.h5 val_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_val.h5 test_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_test.h5 diff --git a/examples/training/foundation_models/25M_foundation.yaml b/examples/training/foundation_models/25M_foundation.yaml index f6ff257b..336aae39 100644 --- a/examples/training/foundation_models/25M_foundation.yaml +++ b/examples/training/foundation_models/25M_foundation.yaml @@ -15,7 +15,7 @@ limit_val_batches: 32 check_val_every_n_epoch: 1 checkpoint_every_n_train_steps: 500 checkpoint_every_n_epochs: null -tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_train.h5 val_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_val.h5 test_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_test.h5 diff --git a/examples/training/foundation_models/2B_foundation.yaml b/examples/training/foundation_models/2B_foundation.yaml index 499f4acd..bc69f42a 100644 --- a/examples/training/foundation_models/2B_foundation.yaml +++ b/examples/training/foundation_models/2B_foundation.yaml @@ -2,7 +2,7 @@ wandb_active: true wandb_project_name: codon_transformer wandb_entity_name: gene_mdh_gan checkpoint_dir: patric_2.5B_pretraining/checkpoints_v2/ -tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_100vocab.json +tokenizer_file: ../../genslm/tokenizer_files/codon_wordlevel_69vocab.json train_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_train.h5 val_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_val.h5 test_file: /path/to/data/patric_89M/pgfam_30k_h5_tts/combined_test.h5 diff --git a/genslm/__init__.py b/genslm/__init__.py index 6731c875..f88bb556 100644 --- a/genslm/__init__.py +++ b/genslm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.3a1" +__version__ = "0.0.4a1" # Public imports from genslm.dataset import SequenceDataset # noqa diff --git a/genslm/cmdline/process_single_family_file.py b/genslm/cmdline/process_single_family_file.py index 45b365cb..bd34396c 100644 --- a/genslm/cmdline/process_single_family_file.py +++ b/genslm/cmdline/process_single_family_file.py @@ -31,7 +31,7 @@ def main(input_fasta: Path, output_h5: Path, tokenizer_path: Path, block_size: i "--tokenizer_file", help="Path to tokenizer file", default=( - fp.parent.parent / "genslm/tokenizer_files/codon_wordlevel_100vocab.json" + fp.parent.parent / "genslm/tokenizer_files/codon_wordlevel_69vocab.json" ), ) parser.add_argument( diff --git a/genslm/cmdline/run_inference.py b/genslm/cmdline/run_inference.py index 0c87e25f..4e2fb56d 100644 --- a/genslm/cmdline/run_inference.py +++ b/genslm/cmdline/run_inference.py @@ -1,10 +1,12 @@ import functools import hashlib import os +import time import uuid from argparse import ArgumentParser +from concurrent.futures import ProcessPoolExecutor from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple import h5py import numpy as np @@ -34,8 +36,8 @@ class InferenceConfig(BaseSettings): """Directory to write embeddings, attentions, logits to.""" # Which outputs to generate - layer_bounds: Union[Tuple[int, int], List[int]] = (0, -1) - """Which layers to generate data for, all by default.""" + layers: 
List[int] = [-1] + """Which layers to generate data for, last only by default.""" output_embeddings: bool = True """Whether or not to generate and save embeddings.""" output_attentions: bool = False @@ -125,11 +127,163 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: return sample +def _read_average_embedding_process_fn( + chunk_idxs: Tuple[int, int], + h5_file_path: Path, + hidden_dim: int, + model_seq_len: int, +) -> np.ndarray: + num_embs = chunk_idxs[1] - chunk_idxs[0] + embs = np.empty(shape=(num_embs, hidden_dim), dtype=np.float32) + emb = np.zeros((model_seq_len, hidden_dim), dtype=np.float32) + with h5py.File(h5_file_path, "r") as f: + group = f["embeddings"] + for i, idx in enumerate(map(str, range(*chunk_idxs))): + seqlen = group[idx].shape[0] + f[f"embeddings/{idx}"].read_direct(emb, dest_sel=np.s_[:seqlen]) + embs[i] = emb[:seqlen].mean(axis=0) + return embs + + +def read_average_embeddings( + h5_file_path: Path, + hidden_dim: int = 512, + seq_len: int = 2048, + num_workers: int = 4, + return_md5: bool = False, +) -> Dict[str, np.ndarray]: + """Read average embeddings from an HDF5 file. + + Parameters + ---------- + h5_file_path : Path + path to h5 file + hidden_dim : int, optional + hidden dimension of model that generated embeddings, by default 512 + seq_len : int, optional + sequence length of the model, by default 2048 + num_workers : int, optional + number of workers to use, by default 4 + + Returns + ------- + Dict[str, np.ndarray] + embeddings averaged into hidden_dim under the 'embeddings' key, and if specified, the hashes under 'na-hashes' + """ + os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + + out_data = {} + + with h5py.File(h5_file_path, "r") as h5_file: + total_embeddings = len(h5_file["embeddings"]) + if return_md5: + out_data["na-hashes"] = h5_file["na-hashes"][...] + + chunk_size = max(1, total_embeddings // num_workers) + chunk_idxs = [ + (i, min(i + chunk_size, total_embeddings)) + for i in range(0, total_embeddings, chunk_size) + ] + + read_func = functools.partial( + _read_average_embedding_process_fn, + h5_file_path=h5_file_path, + hidden_dim=hidden_dim, + model_seq_len=seq_len, + ) + out_array = np.empty(shape=(total_embeddings, hidden_dim), dtype=np.float32) + with ProcessPoolExecutor(max_workers=num_workers) as executor: + for chunk_emb, chunk_range in zip( + executor.map(read_func, chunk_idxs), chunk_idxs + ): + out_array[chunk_range[0] : chunk_range[1]] = chunk_emb + + out_data["embeddings"] = out_array + return out_data + + +def _read_full_embeddings_process_fn( + chunk_idxs: Tuple[int, int], + h5_file_path: Path, + hidden_dim: int, + model_seq_len: int, +) -> np.ndarray: + num_embs = chunk_idxs[1] - chunk_idxs[0] + embs = np.zeros(shape=(num_embs, model_seq_len, hidden_dim), dtype=np.float32) + emb = np.zeros((model_seq_len, hidden_dim), dtype=np.float32) + with h5py.File(h5_file_path, "r") as f: + group = f["embeddings"] + for i, idx in enumerate(map(str, range(*chunk_idxs))): + seqlen = group[idx].shape[0] + emb[:] = 0 # reset + f[f"embeddings/{idx}"].read_direct(emb, dest_sel=np.s_[:seqlen]) + embs[i] = emb + return embs + + +def read_full_embeddings( + h5_file_path: Path, + hidden_dim: int = 512, + seq_len: int = 2048, + num_workers: int = 4, + return_md5: bool = False, +) -> Dict[str, np.ndarray]: + """Read token level embeddings from an HDF5 file. 
+ + Parameters + ---------- + h5_file_path : Path + path to h5 file + hidden_dim : int, optional + hidden dimension of the model that generated embeddings, by default 512 + seq_len : int, optional + sequence length of the model, by default 2048 + num_workers : int, optional + number of workers to use, by default 4 + + Returns + ------- + Dict[str, np.ndarray] + token level embeddings under the 'embeddings' key, and if specified, the hashes under 'na-hashes' + """ + os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + out_data = {} + with h5py.File(h5_file_path, "r") as h5_file: + total_embeddings = len(h5_file["embeddings"]) + if return_md5: + out_data["na-hashes"] = h5_file["na-hashes"][...] + + chunk_size = max(1, total_embeddings // num_workers) + chunk_idxs = [ + (i, min(i + chunk_size, total_embeddings)) + for i in range(0, total_embeddings, chunk_size) + ] + + read_func = functools.partial( + _read_full_embeddings_process_fn, + h5_file_path=h5_file_path, + hidden_dim=hidden_dim, + model_seq_len=seq_len, + ) + + out_array = np.empty( + shape=(total_embeddings, seq_len, hidden_dim), dtype=np.float32 + ) + with ProcessPoolExecutor(max_workers=num_workers) as executor: + for chunk_emb, chunk_range in zip( + executor.map(read_func, chunk_idxs), chunk_idxs + ): + out_array[chunk_range[0] : chunk_range[1]] = chunk_emb + + out_data["embeddings"] = out_array + return out_data + + class OutputsCallback(Callback): def __init__( self, save_dir: Path = Path("./outputs"), - layer_bounds: Tuple[int, int] = (0, -1), + layers: List[int] = [-1], output_embeddings: bool = True, output_attentions: bool = False, output_logits: bool = False, @@ -142,12 +296,7 @@ def __init__( self.save_dir = save_dir self.save_dir.mkdir(parents=True, exist_ok=True) - if isinstance(layer_bounds, tuple): - self.layer_lb, self.layer_ub = layer_bounds - self.layers = None - elif isinstance(layer_bounds, list): - self.layer_lb, self.layer_ub = None, None - self.layers = layer_bounds + self.layers = layers # Embeddings: Key layer-id, value embedding array self.attentions, self.indices, self.na_hashes = [], [], [] @@ -158,26 +307,21 @@ def __init__( self.h5_kwargs = { # "compression": "gzip", # "compression_opts": 4, Compression is too slow for current impl - "fletcher32": True, + # "fletcher32": True, } + self.io_time = 0 + def on_predict_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: - # Plus one for embedding layer + # Plus one for initial embedding layer num_hidden_layers = pl_module.model.model.config.num_hidden_layers + 1 - if self.layer_lb is not None and self.layer_lb < 0: - self.layer_lb = num_hidden_layers + self.layer_lb - if self.layer_ub is not None and self.layer_ub < 0: - self.layer_ub = num_hidden_layers + self.layer_ub - - if self.layers is None: - self.layers = list(range(self.layer_lb, self.layer_ub)) - for ind in range(len(self.layers)): layer_num = self.layers[ind] if layer_num < 0: + # e.g -1 turns into model_layers + -1 (e.g. 
12 + -1 = 11 last layer for 0 indexed arrays) self.layers[ind] = num_hidden_layers + layer_num if self.output_logits: @@ -204,15 +348,16 @@ def on_predict_batch_end( self.attentions.append(attend) if self.output_logits: + start = time.time() logits = outputs.logits.detach().cpu().numpy() for logit, seq_len, fasta_ind in zip(logits, seq_lens, fasta_inds): self.h5logit_file["logits"].create_dataset( - f"{fasta_ind}", - data=logit[:seq_len], - **self.h5_kwargs, + f"{fasta_ind}", data=logit[:seq_len], **self.h5_kwargs ) + self.io_time += time.time() - start if self.output_embeddings: + start = time.time() for layer, embeddings in enumerate(outputs.hidden_states): # User specified list of layers to take if layer not in self.layers: @@ -230,13 +375,11 @@ def on_predict_batch_end( embed = embeddings.detach().cpu().numpy() for emb, seq_len, fasta_ind in zip(embed, seq_lens, fasta_inds): h5_file["embeddings"].create_dataset( - f"{fasta_ind}", - data=emb[:seq_len], - **self.h5_kwargs, + f"{fasta_ind}", data=emb[:seq_len], **self.h5_kwargs ) h5_file.flush() - + self.io_time += time.time() - start self.na_hashes.extend(batch["na_hash"]) self.indices.append(batch["indices"].detach().cpu()) @@ -244,18 +387,22 @@ def on_predict_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: - self.indices = torch.cat(self.indices).numpy().squeeze() + self.indices = torch.cat(self.indices).numpy().reshape(-1) if self.output_logits: + start = time.time() self.h5logit_file.create_dataset( "fasta-indices", data=self.indices, **self.h5_kwargs ) + print(self.na_hashes, flush=True) self.h5logit_file.create_dataset( "na-hashes", data=self.na_hashes, **self.h5_kwargs ) self.h5logit_file.close() + self.io_time += time.time() - start if self.output_embeddings: + start = time.time() # Write indices to h5 files to map embeddings back to fasta file for h5_file in self.h5embeddings_open.values(): h5_file.create_dataset( @@ -268,6 +415,9 @@ def on_predict_end( # Close all h5 files for h5_file in self.h5embeddings_open.values(): h5_file.close() + self.io_time += time.time() - start + + print("IO time:\t", self.io_time) class LightningGenSLM(pl.LightningModule): diff --git a/genslm/config.py b/genslm/config.py index 9450cf3d..93a3620d 100644 --- a/genslm/config.py +++ b/genslm/config.py @@ -131,7 +131,7 @@ class ModelSettings(BaseSettings): tokenizer_file: Path = ( Path(genslm.__file__).parent / "tokenizer_files" - / "codon_wordlevel_100vocab.json" + / "codon_wordlevel_69vocab.json" ) """Path to the tokenizer file.""" train_file: Path diff --git a/genslm/inference.py b/genslm/inference.py index 51a8b6e4..7d0cc338 100644 --- a/genslm/inference.py +++ b/genslm/inference.py @@ -22,25 +22,25 @@ class GenSLM(nn.Module): MODELS: Dict[str, Dict[str, str]] = { "genslm_25M_patric": { "config": str(__architecture_path / "neox" / "neox_25,290,752.json"), - "tokenizer": str(__tokenizer_path / "codon_wordlevel_100vocab.json"), + "tokenizer": str(__tokenizer_path / "codon_wordlevel_69vocab.json"), "weights": "patric_25m_epoch01-val_loss_0.57_bias_removed.pt", "seq_length": "2048", }, "genslm_250M_patric": { "config": str(__architecture_path / "neox" / "neox_244,464,576.json"), - "tokenizer": str(__tokenizer_path / "codon_wordlevel_100vocab.json"), + "tokenizer": str(__tokenizer_path / "codon_wordlevel_69vocab.json"), "weights": "patric_250m_epoch00_val_loss_0.48_attention_removed.pt", "seq_length": "2048", }, "genslm_2.5B_patric": { "config": str(__architecture_path / "neox" / "neox_2,533,931,008.json"), - "tokenizer": 
str(__tokenizer_path / "codon_wordlevel_100vocab.json"), + "tokenizer": str(__tokenizer_path / "codon_wordlevel_69vocab.json"), "weights": "patric_2.5b_epoch00_val_los_0.29_bias_removed.pt", "seq_length": "2048", }, "genslm_25B_patric": { "config": str(__architecture_path / "neox" / "neox_25,076,188,032.json"), - "tokenizer": str(__tokenizer_path / "codon_wordlevel_100vocab.json"), + "tokenizer": str(__tokenizer_path / "codon_wordlevel_69vocab.json"), "weights": "model-epoch00-val_loss0.70-v2.pt", "seq_length": "2048", }, diff --git a/genslm/tokenizer_files/codon_wordlevel_100vocab.json b/genslm/tokenizer_files/codon_wordlevel_69vocab.json similarity index 100% rename from genslm/tokenizer_files/codon_wordlevel_100vocab.json rename to genslm/tokenizer_files/codon_wordlevel_69vocab.json diff --git a/genslm/tokenizer_files/codon_wordlevel_71vocab.json b/genslm/tokenizer_files/codon_wordlevel_71vocab.json new file mode 100644 index 00000000..388b7120 --- /dev/null +++ b/genslm/tokenizer_files/codon_wordlevel_71vocab.json @@ -0,0 +1,209 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "[BOS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "[EOS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 5, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 6, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "[BOS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[EOS]", + "type_id": 0 + } + } + ], + "pair": [ + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "[BOS]": { + "id": "[BOS]", + "ids": [ + 2 + ], + "tokens": [ + "[BOS]" + ] + }, + "[EOS]": { + "id": "[EOS]", + "ids": [ + 3 + ], + "tokens": [ + "[EOS]" + ] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "[UNK]": 0, + "[CLS]": 1, + "[BOS]": 2, + "[EOS]": 3, + "[SEP]": 4, + "[PAD]": 5, + "[MASK]": 6, + "GGC": 7, + "GCC": 8, + "ATC": 9, + "GAC": 10, + "GAA": 11, + "ATG": 12, + "GTG": 13, + "CTG": 14, + "GTC": 15, + "GCG": 16, + "GAT": 17, + "AAA": 18, + "GGT": 19, + "AAG": 20, + "GAG": 21, + "ACC": 22, + "AAC": 23, + "GTT": 24, + "ATT": 25, + "GCA": 26, + "CTC": 27, + "CGC": 28, + "GCT": 29, + "CAG": 30, + "CCG": 31, + "TTC": 32, + "GTA": 33, + "TCG": 34, + "GGA": 35, + "AAT": 36, + "TAC": 37, + "CTT": 38, + "TTG": 39, + "ACG": 40, + "TCC": 41, + "GGG": 42, + "AGC": 43, + "CCC": 44, + "ACA": 45, + "ACT": 46, + "TCT": 47, + "TTA": 48, + "CGT": 49, + "TAT": 50, + 
"CAA": 51, + "CGG": 52, + "TTT": 53, + "CAC": 54, + "CCT": 55, + "CCA": 56, + "TGG": 57, + "ATA": 58, + "TCA": 59, + "TGC": 60, + "AGT": 61, + "AGA": 62, + "CAT": 63, + "TGT": 64, + "CTA": 65, + "AGG": 66, + "TAA": 67, + "CGA": 68, + "TGA": 69, + "TAG": 70 + }, + "unk_token": "[UNK]" + } +} \ No newline at end of file