Skip to content

Commit

Permalink
feat: Ability to specify a custom model for ChatGPT. resolved #167
Browse files Browse the repository at this point in the history
  • Loading branch information
bookfere committed Nov 15, 2023
1 parent c783801 commit e1bd8fb
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 36 deletions.
32 changes: 14 additions & 18 deletions engines/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def __init__(self):
Base.__init__(self)
self.endpoint = self.config.get('endpoint', self.endpoint)
self.prompt = self.config.get('prompt', self.prompt)
self.model = self.config.get('model', self.model)
if self.model is not None:
self.model = self.config.get('model', self.model)
self.sampling = self.config.get('sampling', self.sampling)
self.temperature = self.config.get('temperature', self.temperature)
self.top_p = self.config.get('top_p', self.top_p)
Expand Down Expand Up @@ -73,18 +74,20 @@ def _get_headers(self):
'User-Agent': 'Ebook-Translator/%s' % EbookTranslator.__version__
}

def _get_body(self, text):
return {
def _get_data(self, text):
    """Build the JSON payload for a chat-completion request.

    The ``model`` field is included only when the engine defines one;
    some deployments (e.g. Azure) reject or ignore that parameter, and
    signal this by setting ``self.model`` to ``None``.
    """
    payload = {
        'stream': self.stream,
        'messages': [
            {'role': 'system', 'content': self._get_prompt()},
            {'role': 'user', 'content': text}
        ]
    }
    if self.model is not None:
        payload['model'] = self.model
    return payload

def translate(self, text):
data = self._get_body(text)
data = self._get_data(text)
sampling_value = getattr(self, self.sampling)
data.update({self.sampling: sampling_value})

Expand Down Expand Up @@ -119,23 +122,16 @@ def _parse_stream(self, data):
class AzureChatgptTranslate(ChatgptTranslate):
name = 'ChatGPT(Azure)'
alias = 'ChatGPT (Azure)'
endpoint = ('https://{your-resource-name}.openai.azure.com/openai/'
'deployments/{deployment-id}/chat/completions'
'?api-version={api-version}')
models = ['gpt-35-turbo', 'gpt-4', 'gpt-4-32k']
model = 'gpt-35-turbo'
endpoint = (
'$AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/'
'completions?api-version=2023-05-15')
model = None

def _get_headers(self):
    """Return request headers for the Azure OpenAI endpoint.

    Azure authenticates with an ``api-key`` header instead of the
    ``Authorization: Bearer`` scheme used by the parent class.
    """
    headers = {'Content-Type': 'application/json'}
    headers['api-key'] = self.api_key
    return headers

def _get_body(self, text):
data = ChatgptTranslate._get_body(self, text)
# Some versions do not support the `model` parameter.
for version in ('2023-03-15-preview', '2023-05-15'):
if self.endpoint.endswith(version):
del data['model']
break
return data
def _get_data(self, text):
    # NOTE(review): this override delegates to the parent implementation
    # unchanged, so it appears redundant — the method would be inherited
    # anyway. Confirm it is not kept as an extension point before removing.
    return ChatgptTranslate._get_data(self, text)
64 changes: 50 additions & 14 deletions setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,18 +393,24 @@ def layout_engine(self):
# ChatGPT Setting
chatgpt_group = QGroupBox(_('Tune ChatGPT'))
chatgpt_group.setVisible(False)
endpoint_layout = QFormLayout(chatgpt_group)
self.set_form_layout_policy(endpoint_layout)
chatgpt_layout = QFormLayout(chatgpt_group)
self.set_form_layout_policy(chatgpt_layout)

self.prompt = QPlainTextEdit()
self.prompt.setMinimumHeight(80)
self.prompt.setMaximumHeight(80)
endpoint_layout.addRow(_('Prompt'), self.prompt)
chatgpt_layout.addRow(_('Prompt'), self.prompt)
self.chatgpt_endpoint = QLineEdit()
endpoint_layout.addRow(_('Endpoint'), self.chatgpt_endpoint)
chatgpt_layout.addRow(_('Endpoint'), self.chatgpt_endpoint)

chatgpt_model = QComboBox()
endpoint_layout.addRow(_('Model'), chatgpt_model)
chatgpt_model = QWidget()
chatgpt_model_layout = QHBoxLayout(chatgpt_model)
chatgpt_model_layout.setContentsMargins(0, 0, 0, 0)
chatgpt_select = QComboBox()
chatgpt_custom = QLineEdit()
chatgpt_model_layout.addWidget(chatgpt_select)
chatgpt_model_layout.addWidget(chatgpt_custom)
chatgpt_layout.addRow(_('Model'), chatgpt_model)

self.disable_wheel_event(chatgpt_model)

Expand All @@ -427,13 +433,13 @@ def layout_engine(self):
sampling_layout.addWidget(top_p)
sampling_layout.addWidget(top_p_value)
sampling_layout.addStretch(1)
endpoint_layout.addRow(_('Sampling'), sampling_widget)
chatgpt_layout.addRow(_('Sampling'), sampling_widget)

self.disable_wheel_event(temperature_value)
self.disable_wheel_event(top_p_value)

stream_enabled = QCheckBox(_('Enable streaming text like in ChatGPT'))
endpoint_layout.addRow(_('Stream'), stream_enabled)
chatgpt_layout.addRow(_('Stream'), stream_enabled)

sampling_btn_group = QButtonGroup(sampling_widget)
sampling_btn_group.addButton(temperature, 0)
Expand All @@ -460,12 +466,42 @@ def show_chatgpt_preferences():
self.chatgpt_endpoint.setText(
config.get('endpoint', self.current_engine.endpoint))
# Model
chatgpt_model.clear()
chatgpt_model.addItems(self.current_engine.models)
chatgpt_model.setCurrentText(
config.get('model', self.current_engine.model))
chatgpt_model.currentTextChanged.connect(
lambda model: self.current_engine.config.update(model=model))
if self.current_engine.model is not None:
chatgpt_layout.setRowVisible(chatgpt_model, True)
chatgpt_select.clear()
chatgpt_select.addItems(self.current_engine.models)
chatgpt_select.addItem(_('Custom'))
model = config.get('model', self.current_engine.model)
chatgpt_select.setCurrentText(
model if model in self.current_engine.models
else _('Custom'))

def setup_chatgpt_model(model):
    # Show the free-text "custom model" field only when the current
    # model is not one of the engine's predefined choices.
    if model in self.current_engine.models:
        chatgpt_custom.setVisible(False)
    else:
        chatgpt_custom.setVisible(True)
        # A concrete custom model name (not the "Custom" placeholder)
        # was stored; restore it into the free-text field.
        if model != _('Custom'):
            chatgpt_custom.setText(model)
setup_chatgpt_model(model)

def update_chatgpt_model(model):
    # Persist the chosen model into the engine config. An empty value
    # or the "Custom" placeholder falls back to the engine's first
    # predefined model so the config never stores the placeholder.
    # NOTE(review): `_(model)` passes the value through gettext before
    # comparing — presumably to match the localized "Custom" label;
    # confirm this is intentional rather than a plain `model ==` check.
    if not model or _(model) == _('Custom'):
        model = self.current_engine.models[0]
    config.update(model=model)

def change_chatgpt_model(model):
    # Combo-box selection changed: first sync the custom-field
    # visibility/contents, then persist the resulting model choice.
    setup_chatgpt_model(model)
    update_chatgpt_model(model)

chatgpt_custom.textChanged.connect(
lambda model: update_chatgpt_model(model=model.strip()))
chatgpt_select.currentTextChanged.connect(change_chatgpt_model)
self.save_config.connect(
lambda: chatgpt_select.setCurrentText(config.get('model')))
else:
chatgpt_layout.setRowVisible(chatgpt_model, False)

# Sampling
sampling = config.get('sampling', self.current_engine.sampling)
btn_id = self.current_engine.samplings.index(sampling)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def test_translate_stream(self, mock_browser, mock_request, mock_et):
'question-like content.')
data = json.dumps({
'stream': True,
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': prompt},
{'role': 'user', 'content': 'Hello World!'}
],
'model': 'gpt-3.5-turbo',
'temperature': 1,
})
mock_et.__version__ = '1.0.0'
Expand Down Expand Up @@ -174,7 +174,6 @@ def test_translate(self, mock_browser, mock_request):
'question-like content.')
data = json.dumps({
'stream': True,
# 'model': 'gpt-35-turbo',
'messages': [
{'role': 'system', 'content': prompt},
{'role': 'user', 'content': 'Hello World!'}
Expand All @@ -192,8 +191,8 @@ def test_translate(self, mock_browser, mock_request):
template % i.encode() for i in '你好世界!'] \
+ ['data: [DONE]'.encode()]
mock_browser.return_value.response.return_value = mock_response
url = 'https://test.openai.azure.com/openai/deployments/test/' \
'chat/completions?api-version=2023-05-15'
url = ('https://docs-test-001.openai.azure.com/openai/deployments/'
'gpt-35-turbo/chat/completions?api-version=2023-05-15')
self.translator.endpoint = url
result = self.translator.translate('Hello World!')
mock_request.assert_called_with(
Expand Down

0 comments on commit e1bd8fb

Please sign in to comment.