Skip to content

Commit

Permalink
⚡ Added cache to regen tries (#60)
Browse files (browse the repository at this point in the history)
* ⚡ Added cache to regen tries

Signed-off-by: Marcos Martinez <[email protected]>

* 🎨 Fix flake8 E721 (compare types with `is` instead of `==`)

Signed-off-by: Marcos Martinez <[email protected]>

---------

Signed-off-by: Marcos Martinez <[email protected]>
  • Loading branch information
marmg authored Aug 24, 2023
1 parent d2e97f2 commit 8d12ed1
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
13 changes: 9 additions & 4 deletions zshot/linker/linker_regen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from huggingface_hub import hf_hub_download

from zshot.config import MODELS_CACHE_PATH
from zshot.linker.linker_regen.trie import Trie
from zshot.utils.data_models import Span

Expand Down Expand Up @@ -36,7 +37,8 @@ def load_wikipedia_trie() -> Trie: # pragma: no cover
"""
wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=WIKIPEDIA_TRIE_FILE_NAME)
filename=WIKIPEDIA_TRIE_FILE_NAME,
cache_dir=MODELS_CACHE_PATH)
with open(wikipedia_trie_file, "rb") as f:
wikipedia_trie = pickle.load(f)
return wikipedia_trie
Expand All @@ -49,7 +51,8 @@ def load_wikipedia_mapping() -> Dict[str, str]: # pragma: no cover
"""
wikipedia_map = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=WIKIPEDIA_MAP)
filename=WIKIPEDIA_MAP,
cache_dir=MODELS_CACHE_PATH)
with open(wikipedia_map, "r") as f:
wikipedia_map = json.load(f)
return wikipedia_map
Expand Down Expand Up @@ -77,7 +80,8 @@ def load_dbpedia_trie() -> Trie: # pragma: no cover
"""
dbpedia_trie_file = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=DBPEDIA_TRIE_FILE_NAME)
filename=DBPEDIA_TRIE_FILE_NAME,
cache_dir=MODELS_CACHE_PATH)
with open(dbpedia_trie_file, "rb") as f:
dbpedia_trie = pickle.load(f)
return dbpedia_trie
Expand All @@ -90,7 +94,8 @@ def load_dbpedia_mapping() -> Dict[str, str]: # pragma: no cover
"""
dbpedia_map = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=DBPEDIA_MAP)
filename=DBPEDIA_MAP,
cache_dir=MODELS_CACHE_PATH)
with open(dbpedia_map, "r") as f:
dbpedia_map = json.load(f)
return dbpedia_map
Expand Down
2 changes: 1 addition & 1 deletion zshot/linker/linker_tars.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def flat_entities(self):
if isinstance(self.entities, dict):
self._entities = list(self.entities.keys())
if isinstance(self.entities, list):
self._entities = [e.name if type(e) == Entity else e for e in self.entities]
self._entities = [e.name if type(e) is Entity else e for e in self.entities]
if self.entities is None:
self._entities = []

Expand Down
2 changes: 1 addition & 1 deletion zshot/mentions_extractor/mentions_extractor_tars.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def flat_entities(self):
if isinstance(self._mentions, dict):
self._mentions = list(self._mentions.keys())
if isinstance(self._mentions, list):
self._mentions = [e.name if type(e) == Entity else e for e in self._mentions]
self._mentions = [e.name if type(e) is Entity else e for e in self._mentions]
if self._mentions is None:
self._mentions = []

Expand Down
10 changes: 5 additions & 5 deletions zshot/tests/test_zshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def get_mentions() -> List[Entity]:
zshot_component: Zshot = [comp for name, comp in nlp.pipeline if name == 'zshot'][0]
assert len(zshot_component.entities) == len(EX_ENTITIES)
assert len(zshot_component.mentions) == len(EX_ENTITIES)
assert type(zshot_component.entities[0]) == Entity
assert type(zshot_component.mentions[0]) == Entity
assert type(zshot_component.entities[0]) is Entity
assert type(zshot_component.mentions[0]) is Entity


def test_call_pipe_with_pipeline_configuration():
Expand All @@ -107,11 +107,11 @@ def test_call_pipe_with_pipeline_configuration():
assert "zshot" in nlp.pipe_names
zshot_component: Zshot = [comp for name, comp in nlp.pipeline if name == 'zshot'][0]
assert len(zshot_component.entities) == len(EX_ENTITIES)
assert type(zshot_component.entities[0]) == Entity
assert type(zshot_component.entities[0]) is Entity
assert len(zshot_component.mentions) == len(EX_ENTITIES)
assert type(zshot_component.mentions[0]) == Entity
assert type(zshot_component.mentions[0]) is Entity
assert len(zshot_component.mentions) == len(EX_ENTITIES)
assert type(zshot_component.mentions[0]) == Entity
assert type(zshot_component.mentions[0]) is Entity


def test_process_single_document():
Expand Down
2 changes: 1 addition & 1 deletion zshot/utils/data_models/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __hash__(self):
return zlib.crc32(self.__repr__().encode())

def __eq__(self, other: Any):
return (type(other) == type(self)
return (type(other) is type(self)
and self.start == other.start
and self.end == other.end
and self.label == other.label
Expand Down

0 comments on commit 8d12ed1

Please sign in to comment.