Skip to content

Commit

Permalink
remove torchtext vectors (refs #829)
Browse files Browse the repository at this point in the history
  • Loading branch information
cwmeijer committed Sep 19, 2024
1 parent 983d287 commit 5c45841
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 415 deletions.
9 changes: 3 additions & 6 deletions tutorials/explainers/LIME/lime_text.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,13 @@
},
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import spacy\n",
"from torchtext.vocab import Vectors\n",
"from scipy.special import expit as sigmoid\n",
"\n",
"import dianna\n",
"from dianna import visualization\n",
"from dianna import utils\n",
"from dianna import visualization\n",
"from dianna.utils.downloader import download\n",
"from dianna.utils.tokenizers import SpacyTokenizer"
]
Expand Down Expand Up @@ -131,7 +128,7 @@
"class MovieReviewsModelRunner:\n",
" def __init__(self, model, word_vectors, max_filter_size):\n",
" self.run_model = utils.get_function(model)\n",
" self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))\n",
" self.keys = list(pd.read_csv(word_vector_path, header=None, delimiter=' ')[0])\n",
" self.max_filter_size = max_filter_size\n",
" \n",
" self.tokenizer = SpacyTokenizer(name='en_core_web_sm')\n",
Expand All @@ -149,7 +146,7 @@
" tokens += ['<pad>'] * (self.max_filter_size - len(tokens))\n",
" \n",
" # numericalize the tokens\n",
" tokens_numerical = [self.vocab.stoi[token] if token in self.vocab.stoi else self.vocab.stoi['<unk>']\n",
" tokens_numerical = [self.keys.index(token) if token in self.keys else self.keys.index('<unk>')\n",
" for token in tokens]\n",
"\n",
" # run the model, applying a sigmoid because the model outputs logits, remove any remaining batch axis\n",
Expand Down
Loading

0 comments on commit 5c45841

Please sign in to comment.