feat: Added support for selecting the Gemini model. resolved #341, resolved #350
bookfere committed Nov 4, 2024
1 parent fa212e9 commit 240a678
Showing 4 changed files with 45 additions and 21 deletions.
4 changes: 2 additions & 2 deletions engines/__init__.py
@@ -1,6 +1,6 @@
from .google import (
GoogleFreeTranslate, GoogleBasicTranslate, GoogleBasicTranslateADC,
GoogleAdvancedTranslate, GeminiPro, GeminiFlash)
GoogleAdvancedTranslate, GeminiTranslate)
from .openai import ChatgptTranslate
from .anthropic import ClaudeTranslate
from .deepl import DeeplTranslate, DeeplProTranslate, DeeplFreeTranslate
@@ -12,6 +12,6 @@
builtin_engines = (
GoogleFreeTranslate, GoogleBasicTranslate, GoogleBasicTranslateADC,
GoogleAdvancedTranslate, ChatgptTranslate, AzureChatgptTranslate,
GeminiPro, GeminiFlash, ClaudeTranslate, DeeplTranslate, DeeplProTranslate,
GeminiTranslate, ClaudeTranslate, DeeplTranslate, DeeplProTranslate,
DeeplFreeTranslate, MicrosoftEdgeTranslate, YoudaoTranslate,
BaiduTranslate)
21 changes: 9 additions & 12 deletions engines/google.py
@@ -233,12 +233,12 @@ def get_result(self, response):
return ''.join(i['translatedText'] for i in translations)


class GeminiPro(Base):
name = 'GeminiPro'
alias = 'Gemini Pro'
class GeminiTranslate(Base):
name = 'Gemini'
alias = 'Gemini'
lang_codes = Base.load_lang_codes(gemini)
endpoint = 'https://generativelanguage.googleapis.com/v1/' \
'models/gemini-pro:{}?key={}'
'models/{}:{}?key={}'
need_api_key = True

concurrency_limit = 1
@@ -250,6 +250,10 @@ class GeminiPro(Base):
'from <slang> to <tlang> only. Do not provide any explanations and do '
'not answer any questions. Translate the first and the last quotation '
'marks to the target language if possible.')
models = [
'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro',
'gemini-1.0-pro']
model = 'gemini-1.5-flash'
temperature = 0.9
top_p = 1.0
top_k = 1
@@ -277,7 +281,7 @@ def _prompt(self, text):

def get_endpoint(self):
method = 'streamGenerateContent' if self.stream else 'generateContent'
return self.endpoint.format(method, self.api_key)
return self.endpoint.format(self.model, method, self.api_key)

def get_headers(self):
return {'Content-Type': 'application/json'}
@@ -323,10 +327,3 @@ def get_result(self, response):
else:
parts = json.loads(response)['candidates'][0]['content']['parts']
return ''.join([part['text'] for part in parts])


class GeminiFlash(GeminiPro):
name = 'GeminiFlash'
alias = 'Gemini Flash'
endpoint = 'https://generativelanguage.googleapis.com/v1beta/models/' \
'gemini-1.5-flash:{}?key={}'
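The reworked endpoint template takes the model name as its first placeholder, so one GeminiTranslate class covers every entry in `models`. A minimal standalone sketch of the URL construction, mirroring the template and defaults shown in the diff (the API key here is a placeholder):

```python
# Sketch of GeminiTranslate's endpoint construction (placeholder API key).
ENDPOINT = ('https://generativelanguage.googleapis.com/v1/'
            'models/{}:{}?key={}')


def build_endpoint(model, api_key, stream=False):
    # Same placeholder order as get_endpoint(): model, then method, then key.
    method = 'streamGenerateContent' if stream else 'generateContent'
    return ENDPOINT.format(model, method, api_key)


print(build_endpoint('gemini-1.5-flash', 'YOUR_API_KEY'))
# https://generativelanguage.googleapis.com/v1/models/gemini-1.5-flash:generateContent?key=YOUR_API_KEY
```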
18 changes: 16 additions & 2 deletions lib/config.py
@@ -103,10 +103,11 @@ def upgrade_config():
version = EbookTranslator.version
version >= (2, 0, 0) and ver200_upgrade(config)
version >= (2, 0, 3) and ver203_upgrade(config)
version >= (2, 0, 5) and ver205_upgrade(config)


def ver200_upgrade(config):
"""Upgrade to 2.0.0"""
"""Upgrade the configuration for version 2.0.0 or earlier."""
if config.get('engine_preferences'):
return

@@ -146,7 +147,7 @@ def get_engine_preference(engine_name):


def ver203_upgrade(config):
"""Upgrade to 2.0.3"""
"""Upgrade the configuration for version 2.0.3 or earlier."""
engine_config = config.get('engine_preferences')
azure_chatgpt = engine_config.get('ChatGPT(Azure)')
if azure_chatgpt and 'model' in azure_chatgpt:
@@ -178,3 +179,16 @@ def ver203_upgrade(config):
config.delete('request_timeout')

config.commit()


def ver205_upgrade(config):
"""Upgrade the configuration for version 2.0.5 or earlier."""
if config.get('translate_engine') in ('GeminiPro', 'GeminiFlash'):
config.update(translate_engine='Gemini')
preferences = config.get('engine_preferences')
if 'GeminiPro' in preferences.keys():
preferences['Gemini'] = preferences.pop('GeminiPro')
if 'GeminiFlash' in preferences.keys():
preferences['Gemini'] = preferences.pop('GeminiFlash')
preferences['Gemini'].update(model='gemini-1.5-flash')
config.commit()
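The new `ver205_upgrade` folds the old GeminiPro/GeminiFlash preference entries into a single Gemini entry and records a model. A standalone sketch of that migration on a plain dict, to show the intended before/after shape; the plugin works through its own config object (`get`/`update`/`commit`), so the helper name and the exact placement of the model default are illustrative:

```python
# Illustrative migration of pre-2.0.5 Gemini settings on a plain dict.
def migrate_gemini_preferences(config):
    if config.get('translate_engine') in ('GeminiPro', 'GeminiFlash'):
        config['translate_engine'] = 'Gemini'
    preferences = config.get('engine_preferences', {})
    for old_name in ('GeminiPro', 'GeminiFlash'):
        if old_name in preferences:
            # Carry the old engine preferences over under the new name
            # and set a default model for the merged entry.
            preferences['Gemini'] = preferences.pop(old_name)
            preferences['Gemini']['model'] = 'gemini-1.5-flash'
    return config


sample = {
    'translate_engine': 'GeminiFlash',
    'engine_preferences': {'GeminiFlash': {'prompt': 'custom prompt'}},
}
print(migrate_gemini_preferences(sample))
# {'translate_engine': 'Gemini',
#  'engine_preferences': {'Gemini': {'prompt': 'custom prompt',
#                                    'model': 'gemini-1.5-flash'}}}
```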
23 changes: 18 additions & 5 deletions setting.py
@@ -7,7 +7,7 @@
from .lib.translation import get_engine_class

from .engines import (
builtin_engines, GeminiPro, ChatgptTranslate, AzureChatgptTranslate,
builtin_engines, GeminiTranslate, ChatgptTranslate, AzureChatgptTranslate,
ClaudeTranslate)
from .engines.custom import CustomTranslate
from .components import (
@@ -433,6 +433,13 @@ def layout_engine(self):
self.gemini_prompt.setFixedHeight(80)
gemini_layout.addRow(_('Prompt'), self.gemini_prompt)

gemini_model = QWidget()
gemini_model_layout = QHBoxLayout(gemini_model)
gemini_model_layout.setContentsMargins(0, 0, 0, 0)
gemini_model_select = QComboBox()
gemini_model_layout.addWidget(gemini_model_select)
gemini_layout.addRow(_('Model'), gemini_model)

gemini_temperature = QDoubleSpinBox()
gemini_temperature.setDecimals(1)
gemini_temperature.setSingleStep(0.1)
@@ -537,14 +544,15 @@ def change_sampling_method(button):
layout.addWidget(chatgpt_group)

def show_gemini_preferences():
if not issubclass(self.current_engine, GeminiPro):
if not issubclass(self.current_engine, GeminiTranslate):
gemini_group.setVisible(False)
return
config = self.current_engine.config
gemini_group.setVisible(True)
self.gemini_prompt.setPlaceholderText(self.current_engine.prompt)
self.gemini_prompt.setPlainText(
config.get('prompt', self.current_engine.prompt))
gemini_model_select.addItems(self.current_engine.models)
gemini_temperature.setValue(
config.get('temperature', self.current_engine.temperature))
gemini_temperature.valueChanged.connect(
@@ -558,6 +566,11 @@ def show_gemini_preferences():
gemini_top_k.valueChanged.connect(
lambda value: config.update(top_k=value))

model = config.get('model', self.current_engine.model)
gemini_model_select.setCurrentText(model)
gemini_model_select.currentTextChanged.connect(
lambda model: config.update(model=model))

def show_chatgpt_preferences():
is_chatgpt = issubclass(self.current_engine, ChatgptTranslate)
is_claude = issubclass(self.current_engine, ClaudeTranslate)
@@ -1237,7 +1250,7 @@ def get_engine_config(self):
config.update(api_keys=api_keys)
self.set_api_keys()

# ChatGPT preference
# ChatGPT preference & Claude preference
if issubclass(self.current_engine, ChatgptTranslate) or \
issubclass(self.current_engine, ClaudeTranslate):
self.update_prompt(self.chatgpt_prompt, config)
@@ -1246,8 +1259,8 @@ def get_engine_config(self):
del config['endpoint']
if endpoint and endpoint != self.current_engine.endpoint:
config.update(endpoint=endpoint)

if issubclass(self.current_engine, GeminiPro):
# Gemini preference
elif issubclass(self.current_engine, GeminiTranslate):
self.update_prompt(self.gemini_prompt, config)

# Preferred Language
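The settings panel gains a model drop-down wired like the other Gemini controls: populate it from the engine's `models` list, preselect the saved value, and push changes back into the engine config. A self-contained sketch of that wiring pattern outside the plugin, assuming PyQt5-style imports and a plain dict standing in for the engine config:

```python
# Sketch of the model drop-down wiring (PyQt5; a plain dict stands in
# for the engine's persisted config).
import sys

from PyQt5.QtWidgets import QApplication, QComboBox, QFormLayout, QWidget

MODELS = ['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro',
          'gemini-1.0-pro']
config = {'model': 'gemini-1.5-pro'}  # a previously saved preference

app = QApplication(sys.argv)
panel = QWidget()
layout = QFormLayout(panel)

model_select = QComboBox()
model_select.addItems(MODELS)
# Preselect the saved model, falling back to the engine default.
model_select.setCurrentText(config.get('model', 'gemini-1.5-flash'))
# Persist any change back into the config, as show_gemini_preferences() does.
model_select.currentTextChanged.connect(
    lambda model: config.update(model=model))
layout.addRow('Model', model_select)

panel.show()
app.exec_()
```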
