Skip to content

Commit

Permalink
Optional dependencies (stanford-crfm#1798)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored and danielz02 committed Sep 7, 2023
1 parent 5e65806 commit 91863a8
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 37 deletions.
4 changes: 3 additions & 1 deletion install-dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ pip install --no-binary=protobuf protobuf==3.20.2
# Install all pinned dependencies
pip install -r requirements-freeze.txt
# Install HELM in edit mode
pip install -e .
pip install -e .[all]
# Check dependencies
pip check
2 changes: 0 additions & 2 deletions pre-commit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ if [ "$valid_version" == "False" ]; then
exit 1
fi

pip check

# Python style checks and linting
black --check --diff src scripts || (
echo ""
Expand Down
38 changes: 24 additions & 14 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ install_requires=
# sqlitedict==2.0.0 is slow! https://github.com/RaRe-Technologies/sqlitedict/issues/152
# Keep sqlitedict version at 1.7.0.
sqlitedict~=1.7.0
bottle~=0.12.23
# TODO: Remove these from common
protobuf~=3.20.2 # Can't use 4.21.0 due to backward incompatibility
pymongo~=4.2.0
Expand All @@ -56,15 +57,6 @@ install_requires=
# TODO: Remove after this issue is resolved
scikit-learn~=1.1.2

# Server Extras
bottle~=0.12.23
gunicorn~=20.1.0

# Scenario Extras
gdown~=4.4.0 # For opinions_qa_scenario
sympy~=1.11.1 # For numeracy_scenario
xlrd~=2.0.1 # For ice_scenario: used by pandas.read_excel

# Model Extras
aleph-alpha-client~=2.14.0
anthropic~=0.2.5
Expand All @@ -84,20 +76,38 @@ install_requires=

# Metrics Extras
google-api-python-client~=2.64.0 # For perspective_api_client via toxicity_metrics

[options.extras_require]
proxy-server =
gunicorn~=20.1.0

human-evaluation =
scaleapi~=2.13.0
surge-api~=1.1.0

scenarios =
gdown~=4.4.0 # For disinformation_scenario, med_mcqa_scenario, med_qa_scenario: used by ensure_file_downloaded()
sympy~=1.11.1 # For numeracy_scenario
xlrd~=2.0.1 # For ice_scenario: used by pandas.read_excel()

metrics =
numba~=0.56.4 # For copyright_metrics
pytrec_eval==0.5 # For ranking_metrics
sacrebleu~=2.2.1 # For disinformation_metrics, machine_translation_metrics
summ-eval~=0.892 # For summarization_metrics

# Human Evaluation Extras
scaleapi~=2.13.0
surge-api~=1.1.0

# Plots Extras
plots =
colorcet~=3.0.1
matplotlib~=3.6.0
seaborn~=0.11.0

all =
crfm-helm[server]
crfm-helm[human-evaluation]
crfm-helm[scenarios]
crfm-helm[metrics]
crfm-helm[plots]

[options.entry_points]
console_scripts =
helm-run = helm.benchmark.run:main
Expand Down
7 changes: 6 additions & 1 deletion src/helm/benchmark/metrics/copyright_metrics.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import re
from typing import List, Optional

import numba
import numpy as np
from nltk.tokenize.treebank import TreebankWordTokenizer

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.scenarios.scenario import Reference
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import RequestResult
from .metric import Metric
from .metric_name import MetricName
from .metric_service import MetricService
from .statistic import Stat

try:
import numba
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


def _longest_common_prefix_length(s1: np.ndarray, s2: np.ndarray, previous_best: Optional[float] = None) -> float:
"""Compute the length of the longest common prefix."""
Expand Down
7 changes: 6 additions & 1 deletion src/helm/benchmark/metrics/disinformation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Dict, List, Optional

import numpy as np
from sacrebleu.metrics import BLEU

from helm.common.general import ensure_file_downloaded
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import RequestResult, Sequence
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
Expand All @@ -16,6 +16,11 @@
from .metric_service import MetricService
from .statistic import Stat

try:
from sacrebleu.metrics import BLEU
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


HUMAN_EVAL_CODALAB_LINK: str = (
"https://worksheets.codalab.org/rest/bundles/0xd8c577022f584f27aead3f00aa771da5/contents/blob/{file_name}"
Expand Down
7 changes: 6 additions & 1 deletion src/helm/benchmark/metrics/machine_translation_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import List
from sacrebleu import BLEU

from helm.benchmark.adaptation.request_state import RequestState
from helm.common.optional_dependencies import handle_module_not_found_error
from .metric import Metric
from .metric_name import MetricName
from .statistic import Stat

try:
from sacrebleu.metrics import BLEU
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


class MachineTranslationMetric(Metric):
"""
Expand Down
8 changes: 6 additions & 2 deletions src/helm/benchmark/metrics/ranking_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple, Optional

import pytrec_eval

from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_RANKING_BINARY
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.benchmark.scenarios.scenario import unpack_tag, CORRECT_TAG, Reference
from helm.common.request import RequestResult
from helm.common.general import binarize_dict
Expand All @@ -14,6 +13,11 @@
from .metric_service import MetricService
from .statistic import Stat

try:
import pytrec_eval
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


@dataclass
class RankingObject:
Expand Down
8 changes: 7 additions & 1 deletion src/helm/benchmark/metrics/summarization_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.common.hierarchical_logger import hlog
from helm.common.general import ensure_file_downloaded
from helm.common.optional_dependencies import handle_module_not_found_error
from .metric import Metric, MetricResult
from .metric_name import MetricName
from .metric_service import MetricService
Expand All @@ -21,6 +22,7 @@
from .summac.model_summac import SummaCZS
from bert_score import BERTScorer


QAFACTEVAL_CODALAB_LINK: str = (
"https://worksheets.codalab.org/rest/bundles/0xf4de83c1f0d34d7999480223e8f5ab87/contents/blob/"
)
Expand Down Expand Up @@ -52,7 +54,11 @@ def __init__(self, task: str, device: str = "cpu"):
# `NameError: name 'stderr' is not defined`
if not spacy.util.is_package("en_core_web_sm"):
spacy.cli.download("en_core_web_sm") # type: ignore
from summ_eval.data_stats_metric import DataStatsMetric

try:
from summ_eval.data_stats_metric import DataStatsMetric
except ModuleNotFoundError as e:
handle_module_not_found_error(e)

self.data_stats_metric = DataStatsMetric()
self.task: str = task
Expand Down
14 changes: 10 additions & 4 deletions src/helm/benchmark/presentation/create_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import os
from typing import List, Dict, Optional, Any, Callable, Union, Mapping, Tuple, Set

import colorcet
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr
import seaborn as sns

from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.benchmark.presentation.schema import read_schema
from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN

try:
import colorcet
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


sns.set_style("whitegrid")

DOWN_ARROW = "\u2193"
Expand Down
3 changes: 2 additions & 1 deletion src/helm/benchmark/run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .scenarios.scenario import ScenarioSpec
from .scenarios.big_bench_scenario import BIGBenchScenario
from .scenarios.msmarco_scenario import MSMARCOScenario
from .scenarios.numeracy_scenario import get_numeracy_adapter_spec, RELTYPE_INFO
from .scenarios.copyright_scenario import datatag2hash_code
from .scenarios.raft_scenario import get_raft_instructions
from .scenarios.lextreme_scenario import (
Expand Down Expand Up @@ -1043,6 +1042,8 @@ def get_raft_spec(subset: str) -> RunSpec:
def get_numeracy_spec(
relation_type: str = "linear", mode: str = "function", seed: str = "0", run_solver: str = "False"
) -> RunSpec:
from .scenarios.numeracy_scenario import get_numeracy_adapter_spec, RELTYPE_INFO

run_solver: bool = True if run_solver == "True" else False # type: ignore
random_seed = int(seed)
scenario_spec = ScenarioSpec(
Expand Down
7 changes: 7 additions & 0 deletions src/helm/benchmark/scenarios/ice_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@
from enum import Enum
import pandas as pd

from helm.common.optional_dependencies import handle_module_not_found_error
from .ice_scenario_pinned_file_order import listdir_with_pinned_file_order
from .scenario import Scenario, Instance, TEST_SPLIT, Input

try:
# pd.read_excel() uses xlrd
import xlrd # noqa
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


class ICESubset(Enum):
CANADA = "can"
Expand Down
11 changes: 8 additions & 3 deletions src/helm/benchmark/scenarios/numeracy_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@
import numpy as np
import numpy.typing as npt
import random
import sympy
from sympy import Symbol, Poly, diff
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application
from typing import List, Optional, Tuple, Dict

from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_GENERATION
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.window_services.tokenizer_service import TokenizerService
from helm.common.authentication import Authentication
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.proxy.services.server_service import ServerService
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output

try:
import sympy
from sympy import Symbol, Poly, diff
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


# TODO: we shouldn't create an Adapter and TokenizerService in a scenario
# The Adapter and Scenarios should be completely decoupled.
Expand Down
5 changes: 5 additions & 0 deletions src/helm/common/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses import asdict, is_dataclass

from helm.common.hierarchical_logger import hlog, htrack, htrack_block
from helm.common.optional_dependencies import handle_module_not_found_error


_CREDENTIALS_FILE_NAME = "credentials.conf"
Expand Down Expand Up @@ -82,6 +83,10 @@ def ensure_file_downloaded(
# gdown is used to download large files/zip folders from Google Drive.
# It bypasses security warnings which wget cannot handle.
if source_url.startswith("https://drive.google.com"):
try:
import gdown # noqa
except ModuleNotFoundError as e:
handle_module_not_found_error(e)
downloader_executable = "gdown"
tmp_path: str = f"{target_path}.tmp"
shell([downloader_executable, source_url, "-O", tmp_path])
Expand Down
10 changes: 10 additions & 0 deletions src/helm/common/optional_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class OptionalDependencyNotInstalled(Exception):
pass


def handle_module_not_found_error(e: ModuleNotFoundError):
# TODO: Ask user to install more specific optional dependencies
# e.g. crfm-helm[plots] or crfm-helm[server]
raise OptionalDependencyNotInstalled(
f"Optional dependency {e.name} is not installed. " "Please run `pip install helm-crfm[all]` to install it."
) from e
11 changes: 8 additions & 3 deletions src/helm/proxy/clients/scale_critique_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from typing import Dict, List, Union, Set, Any

from cattrs import unstructure
import scaleapi
from scaleapi.tasks import TaskType, TaskStatus
from scaleapi.exceptions import ScaleDuplicateResource

from helm.common.hierarchical_logger import hlog
from helm.common.cache import Cache, CacheConfig
Expand All @@ -17,8 +14,16 @@
CritiqueTaskTemplate,
CritiqueResponse,
)
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.proxy.clients.critique_client import CritiqueClient

try:
import scaleapi
from scaleapi.tasks import TaskType, TaskStatus
from scaleapi.exceptions import ScaleDuplicateResource
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


class ScaleCritiqueClientError(Exception):
pass
Expand Down
10 changes: 7 additions & 3 deletions src/helm/proxy/clients/surge_ai_critique_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
import threading
from typing import Dict, List

import surge
from surge import questions as surge_questions

from helm.common.cache import Cache, CacheConfig
from helm.common.critique_request import (
CritiqueQuestionTemplate,
Expand All @@ -14,8 +11,15 @@
CritiqueTaskTemplate,
)
from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.proxy.clients.critique_client import CritiqueClient

try:
import surge
from surge import questions as surge_questions
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


_surge_cache_lock = threading.Lock()

Expand Down
7 changes: 7 additions & 0 deletions src/helm/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@

from helm.common.authentication import Authentication
from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import Request
from helm.common.perspective_api_request import PerspectiveAPIRequest
from helm.common.tokenization_request import TokenizationRequest, DecodeRequest
from .accounts import Account
from .services.server_service import ServerService
from .query import Query

try:
import gunicorn # noqa
except ModuleNotFoundError as e:
handle_module_not_found_error(e)


bottle.BaseRequest.MEMFILE_MAX = 1024 * 1024

app = bottle.default_app()
Expand Down

0 comments on commit 91863a8

Please sign in to comment.