Skip to content

Commit

Permalink
fix: remove previously assigned extensions before extracting new metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse committed Jan 5, 2023
1 parent 16f9f3f commit 1a7ca00
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 104 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
default_stages: [commit, push]

ci:
autofix_prs: false

repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
Expand Down
2 changes: 1 addition & 1 deletion src/textdescriptives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .components.quality import QualityThresholds # noqa: F401
from .extractors import extract_df, extract_dict, extract_metrics # noqa: F401
from .load_components import TextDescriptives # noqa: F401
from .utils import get_assigns, get_valid_metrics # noqa: F401
from .utils import get_doc_assigns, get_valid_metrics # noqa: F401
2 changes: 2 additions & 0 deletions src/textdescriptives/components/descriptive_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def __call__(self, doc):
"doc._.syllables",
"doc._.counts",
"doc._.descriptive_stats",
"span._._n_tokens",
"span._._n_syllables",
"span._.token_length",
"span._.counts",
"span._.descriptive_stats",
Expand Down
40 changes: 1 addition & 39 deletions src/textdescriptives/components/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Utility functions for calculating various text descriptives."""
from typing import Any, Callable, Union
from typing import Union

from pyphen import Pyphen
from spacy.tokens import Doc, Span, Token
Expand Down Expand Up @@ -37,41 +37,3 @@ def count_syl(token: Token):
return max(1, word_hyphenated.count("-") + 1)

return [count_syl(token) for token in filter_tokens(doc)]


def span_getter_to_token_getter(
span_getter: Callable[[Span], Any],
) -> Callable[[Token], Any]:
"""Converts a span getter to a token getter.
Args:
span_getter (Callable[[Span], Any]):
The span getter function.
Returns:
Callable[[Token], Any]: The token getter function.
"""

def token_getter(token):
return span_getter(token.doc[token.i : token.i + 1])

return token_getter


def span_getter_to_doc_getter(
span_getter: Callable[[Span], Any],
) -> Callable[[Doc], Any]:
"""Converts a span getter to a document getter.
Args:
span_getter (Callable[[Span], Any]):
The span getter function.
Returns:
Callable[[Doc], Any]: The document getter function.
"""

def doc_getter(doc):
return span_getter(doc[:])

return doc_getter
114 changes: 61 additions & 53 deletions src/textdescriptives/extractors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""Extract metrics as Pandas DataFrame."""
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union

import pandas as pd
import spacy
import spacy.cli
from spacy import Language
from spacy.tokens import Doc
from spacy.tokens import Doc, Span, Token
from wasabi import msg

from textdescriptives.utils import get_assigns, get_valid_metrics
from textdescriptives.utils import (
get_doc_assigns,
get_span_assigns,
get_token_assigns,
get_valid_metrics,
)


def __get_quality(doc: Doc) -> dict:
Expand Down Expand Up @@ -114,35 +119,14 @@ def download_spacy_model(lang: str, size: str) -> str:
Returns:
str: Name of the downloaded model.
"""
if isinstance(metrics, str):
metrics = [metrics]

if spacy_model is None and lang is None:
raise ValueError("Either a spacy model or a language must be provided.")

if metrics is None:
metrics = get_valid_metrics()

# load spacy model if any component requires it
nlp = load_spacy_model(
spacy_model=spacy_model,
lang=lang,
metrics=metrics,
spacy_model_size=spacy_model_size,
)

# add pipeline components
for component in metrics:
nlp.add_pipe(f"textdescriptives/{component}")

if isinstance(text, str):
text = [text]
docs = nlp.pipe(text)

df = extract_df(docs)
_clean_doc_extensions(metrics=metrics)
data_source = "news" if lang != "en" else "web"
spacy_model = f"{lang}_core_{data_source}_{size}"
# don't download if model already exists
if spacy_model in spacy.cli.info()["pipelines"]: # type: ignore
return spacy_model
spacy.cli.download(spacy_model)
return spacy_model

return df

def load_spacy_model(
spacy_model: Optional[str],
Expand All @@ -165,8 +149,14 @@ def load_spacy_model(
Language: a spacy pipeline
"""

metrics_requiring_spacy_model = {"dependency_distance", "pos_stats", "coherence"}
# if no spacy model is necesarry for the metrics, return a blank model for the language
metrics_requiring_spacy_model = {
"dependency_distance",
"pos_stats",
"coherence",
"pos_proportions",
}
# if no spacy model is necesarry for the metrics, return a blank model
# for the language
if not bool(metrics_requiring_spacy_model.intersection(metrics)):
if lang is not None:
return spacy.blank(lang)
Expand All @@ -189,10 +179,37 @@ def load_spacy_model(
return spacy.load(spacy_model)


def _remove_spacy_extension(
language: Union[Type[Doc], Type[Span], Type[Token]],
extension: str,
) -> None:
"""Remove spacy extension from a Language object if it exists."""
if language.has_extension(extension):
language.remove_extension(extension)


def _remove_textdescriptives_extensions() -> None:
"""Remove spacy extensions added by textdescriptives.
This is necessary to avoid errors if running `extract_metrics`
multiple times with different metrics
"""
for metric in get_valid_metrics():
doc_assigns = get_doc_assigns(metric)
for assigned in doc_assigns:
_remove_spacy_extension(language=Doc, extension=assigned)
span_assigns = get_span_assigns(metric)
for assigned in span_assigns:
_remove_spacy_extension(language=Span, extension=assigned)
token_assings = get_token_assigns(metric)
for assigned in token_assings:
_remove_spacy_extension(language=Token, extension=assigned)


def extract_metrics(
text: Union[str, List[str]],
spacy_model=None,
lang: str = None,
spacy_model: Optional[Language] = None,
lang: Optional[str] = None,
metrics: Optional[Iterable[str]] = None,
spacy_model_size: str = "lg",
) -> pd.DataFrame:
Expand All @@ -216,23 +233,6 @@ def extract_metrics(
Returns:
pd.DataFrame: DataFrame with a row for each text and column for each metric.
"""
data_source = "news" if lang != "en" else "web"
spacy_model = f"{lang}_core_{data_source}_{size}"
# don't download if model already exists
if spacy_model in spacy.cli.info()["pipelines"]:
return spacy_model
spacy.cli.download(spacy_model)
return spacy_model


def _clean_doc_extensions(metrics: Iterable[str]) -> None:
"""Remove doc extensions added by textdescriptives. This is necesarry to avoid
errors if running `extract_metrics` multiple times with different metrics"""
for metric in metrics:
assigns = get_assigns(metric)
for assigned in assigns:
Doc.remove_extension(assigned)
=======
if isinstance(metrics, str):
metrics = [metrics]

Expand All @@ -242,8 +242,16 @@ def _clean_doc_extensions(metrics: Iterable[str]) -> None:
if metrics is None:
metrics = get_valid_metrics()

# remove previously set metrics to avoid conflicts
_remove_textdescriptives_extensions()

# load spacy model if any component requires it
nlp = load_spacy_model(spacy_model, lang, metrics, spacy_model_size)
nlp = load_spacy_model(
spacy_model=spacy_model,
lang=lang,
metrics=metrics,
spacy_model_size=spacy_model_size,
)

# add pipeline components
for component in metrics:
Expand Down
30 changes: 28 additions & 2 deletions src/textdescriptives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from spacy import Language


def get_assigns(metric: str) -> List[str]:
"""Get columns for a given metric.
def get_doc_assigns(metric: str) -> List[str]:
"""Get doc extension attributes for a given metric.
Args:
metric (str): Metric to get columns for
Expand All @@ -20,6 +20,32 @@ def get_assigns(metric: str) -> List[str]:
]


def get_span_assigns(metric: str) -> List[str]:
"""Get span extension attributes for a given metric.
Args:
metric (str): Metric to get columns for
"""
return [
col[7:]
for col in Language.get_factory_meta(f"textdescriptives/{metric}").assigns
if col.startswith("span._.")
]


def get_token_assigns(metric: str) -> List[str]:
"""Get token extension attributes for a given metric.
Args:
metric (str): Metric to get columns for
"""
return [
col[8:]
for col in Language.get_factory_meta(f"textdescriptives/{metric}").assigns
if col.startswith("token._.")
]


def get_valid_metrics() -> set:
"""Get valid metrics for extractor.
Expand Down
34 changes: 27 additions & 7 deletions tests/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def test_extract_with_lang():
@pytest.mark.parametrize(
"text",
[
"This is just a cute little text. Actually, it's two sentences. No, it's three",
"This is just a cute little text. Actually, it's two sentences. "
+ "No, it's three",
[
"This is just a cute little text. Actually, it's two sentences. No, it's three.",
"This is just a cute little text. Actually, it's two sentences. "
+ "No, it's three.",
"Two documents in this bad boy. Let's see how it works.",
],
],
Expand All @@ -133,6 +135,24 @@ def test_extract_similar_extract_df(text):
assert df.equals(df2)


def test_extract_df_then_extract_metric():
text = [
"This is just a cute little text. Actually, it's two sentences. "
+ "No, it's three.",
"Two documents in this bad boy. Let's see how it works.",
]
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textdescriptives/coherence")
docs = nlp.pipe(text)
td.extract_df(docs)

td.extract_metrics(
text,
spacy_model="en_core_web_sm",
metrics="quality",
)


def test_extract_model_not_needed():
df = td.extract_metrics(
"This is just a cute little text. Actually, it's two sentences.",
Expand All @@ -144,13 +164,13 @@ def test_extract_model_not_needed():

def test_extract_metrics_twice():
text = "Just a small test"
df = td.extract_metrics(
td.extract_metrics(
text,
metrics="coherence",
metrics="descriptive_stats",
lang="en",
)
df2 = td.extract_metrics(
td.extract_metrics(
text,
metrics="descriptive_stats",
metrics="pos_proportions",
lang="en",
)
)
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def test_get_valid_metrics():
def test_get_columns():
columns_ = set()
for metric in td.get_valid_metrics():
columns = td.get_assigns(metric)
columns = td.get_doc_assigns(metric)
assert isinstance(columns, list)
assert len(columns) > 0
assert isinstance(columns[0], str)
columns_.update(columns)
columns_all = td.get_assigns("all")
columns_all = td.get_doc_assigns("all")
assert set(columns_all) == columns_
assert set(columns_all) == columns_

0 comments on commit 1a7ca00

Please sign in to comment.