Skip to content

Commit

Permalink
Fix openai not conditioned imports (#806)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Dec 15, 2023
1 parent 5cc4dd4 commit f56f122
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/smoketest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Smoketest
on:
push:
branches:
- main
- release/*
pull_request:
branches:
- main
- release/*
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
defaults:
run:
working-directory: .
jobs:
smoketest:
runs-on: ubuntu-20.04
timeout-minutes: 10
strategy:
matrix:
python_version:
- "3.9"
- "3.10"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Setup
run: |
set -ex
python -m pip install --upgrade 'pip<23' wheel
python -m pip install --upgrade .
python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
- name: Run checks
run: |
pytest tests/test_smoketest.py
14 changes: 8 additions & 6 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
'OpenAICausalLMEvalWrapper',
'OpenAIChatAPIEvalWrapper',
]
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion_choice import Logprobs

if TYPE_CHECKING:
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion_choice import Logprobs

MAX_RETRIES = 10

Expand Down Expand Up @@ -99,7 +101,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
'role':
'system',
'content':
model_cfg.get('sytsem_role_prompt',
model_cfg.get('system_role_prompt',
'Please complete the following text: ')
}, {
'role': 'user',
Expand Down Expand Up @@ -201,7 +203,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def process_result(self, completion: Optional[ChatCompletion]):
def process_result(self, completion: Optional['ChatCompletion']):
if completion is None:
raise ValueError("Couldn't generate model output")

Expand Down Expand Up @@ -234,7 +236,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
logprobs=5,
temperature=0.0)

def process_result(self, completion: Optional[Completion]):
def process_result(self, completion: Optional['Completion']):
if completion is None:
raise ValueError("Couldn't generate model output")

Expand Down
16 changes: 16 additions & 0 deletions tests/test_smoketest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry import callbacks, data, models, optim, tokenizers, utils


# This very simple test is just to use the above imports, which check and make sure we can import all the top-level
# modules from foundry. This is mainly useful for checking that we have correctly conditionally imported all optional
# dependencies.
def test_smoketest():
assert callbacks
assert data
assert models
assert optim
assert tokenizers
assert utils

0 comments on commit f56f122

Please sign in to comment.