
Fix openai not conditioned imports #806

Merged 6 commits on Dec 15, 2023
41 changes: 41 additions & 0 deletions .github/workflows/smoketest.yaml
@@ -0,0 +1,41 @@
name: Smoketest
on:
  push:
    branches:
    - main
    - release/*
  pull_request:
    branches:
    - main
    - release/*
  workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
defaults:
  run:
    working-directory: .
jobs:
  smoketest:
    runs-on: ubuntu-20.04
    timeout-minutes: 10
    strategy:
      matrix:
        python_version:
        - "3.9"
- "3.10"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Setup
run: |
set -ex
python -m pip install --upgrade 'pip<23' wheel
python -m pip install --upgrade .
python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
- name: Run checks
run: |
pytest tests/test_smoketest.py
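
To reproduce the `Run checks` step locally without GitHub Actions, the same smoke test can be driven from Python. This is a sketch, not part of the PR; it assumes llm-foundry and pytest are already installed in the active environment:

```python
import sys

import pytest

# Run the same test file the workflow runs and propagate pytest's exit code.
if __name__ == '__main__':
    sys.exit(pytest.main(['tests/test_smoketest.py', '-q']))
```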
14 changes: 8 additions & 6 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -23,9 +23,11 @@
     'OpenAICausalLMEvalWrapper',
     'OpenAIChatAPIEvalWrapper',
 ]
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.completion import Completion
-from openai.types.completion_choice import Logprobs
+
+if TYPE_CHECKING:
+    from openai.types.chat.chat_completion import ChatCompletion
+    from openai.types.completion import Completion
+    from openai.types.completion_choice import Logprobs
 
 MAX_RETRIES = 10
 
@@ -99,7 +101,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
                 'role':
                     'system',
                 'content':
-                    model_cfg.get('sytsem_role_prompt',
+                    model_cfg.get('system_role_prompt',
                                   'Please complete the following text: ')
             }, {
                 'role': 'user',
@@ -201,7 +203,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
 
         return torch.stack(output_logits_batch).to(batch['input_ids'].device)
 
-    def process_result(self, completion: Optional[ChatCompletion]):
+    def process_result(self, completion: Optional['ChatCompletion']):
         if completion is None:
             raise ValueError("Couldn't generate model output")
 
@@ -234,7 +236,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
             logprobs=5,
             temperature=0.0)
 
-    def process_result(self, completion: Optional[Completion]):
+    def process_result(self, completion: Optional['Completion']):
         if completion is None:
             raise ValueError("Couldn't generate model output")
 
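For context, the change in this file is the standard `typing.TYPE_CHECKING` pattern: the `openai` type imports run only under static type checkers, and the annotations that reference them are quoted so they are never evaluated at runtime. A minimal sketch of the idea, using an illustrative free function rather than the actual llm-foundry wrapper class:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers (mypy, pyright); never executed at runtime,
    # so `openai` does not need to be installed just to import this module.
    from openai.types.completion import Completion


def process_result(completion: Optional['Completion']) -> str:
    # The quoted annotation keeps 'Completion' out of runtime evaluation.
    if completion is None:
        raise ValueError("Couldn't generate model output")
    return completion.choices[0].text  # assumes a text-completion response
```
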
16 changes: 16 additions & 0 deletions tests/test_smoketest.py
@@ -0,0 +1,16 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry import callbacks, data, models, optim, tokenizers, utils


# This very simple test is just to use the above imports, which check and make sure we can import all the top-level
# modules from foundry. This is mainly useful for checking that we have correctly conditionally imported all optional
# dependencies.
def test_smoketest():
    assert callbacks
    assert data
    assert models
    assert optim
    assert tokenizers
    assert utils
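
The smoke test only exercises top-level imports, so it fails whenever an optional dependency is imported unconditionally at module scope. The runtime counterpart to the `TYPE_CHECKING` guard above is a lazy import inside the code path that actually needs the package; the following is a generic sketch (function name, model name, and error message are illustrative, not llm-foundry's actual wrapper code):

```python
def lazy_complete(prompt: str, model: str = 'gpt-3.5-turbo-instruct') -> str:
    # Import lazily so that importing the package never requires the
    # optional `openai` dependency to be installed.
    try:
        from openai import OpenAI
    except ImportError as e:
        raise ImportError('openai must be installed to use the inference API '
                          'wrappers, e.g. `pip install openai`.') from e

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response = client.completions.create(model=model,
                                         prompt=prompt,
                                         max_tokens=32)
    return response.choices[0].text
```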