Skip to content

Commit

Permalink
fix: support async context, fix list ops, and fix pronunciation extra…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
ssut committed Nov 23, 2024
1 parent a0eb858 commit 15ecbc0
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions googletrans/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
You can translate text using this module.
"""

import asyncio
import random
import re
import typing
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
proxies: typing.Optional[ProxiesTypes] = None,
timeout: typing.Optional[Timeout] = None,
http2: bool = True,
list_operation_max_concurrency: int = 2,
):
self.client = httpx.AsyncClient(
http2=http2,
Expand All @@ -86,7 +88,7 @@ def __init__(
# default way of working: use the defined values from user app
self.service_urls = service_urls
self.client_type = "webapp"
self.tok1en_acquirer = TokenAcquirer(
self.token_acquirer = TokenAcquirer(
client=self.client, host=self.service_urls[0]
)

Expand All @@ -99,12 +101,19 @@ def __init__(
break

self.raise_exception = raise_exception
self.list_operation_max_concurrency = list_operation_max_concurrency

def _pick_service_url(self) -> str:
if len(self.service_urls) == 1:
return self.service_urls[0]
return random.choice(self.service_urls)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()

async def _translate(
self, text: str, dest: str, src: str, override: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.List[typing.Any], Response]:
Expand Down Expand Up @@ -266,10 +275,17 @@ async def translate(
raise ValueError("invalid destination language")

if isinstance(text, list):
result = []
for item in text:
translated = await self.translate(item, dest=dest, src=src, **kwargs)
result.append(translated)
concurrency_limit = kwargs.pop(
"list_operation_max_concurrency", self.list_operation_max_concurrency
)
semaphore = asyncio.Semaphore(concurrency_limit)

async def translate_with_semaphore(item):
async with semaphore:
return await self.translate(item, dest=dest, src=src, **kwargs)

tasks = [translate_with_semaphore(item) for item in text]
result = await asyncio.gather(*tasks)
return result

origin = text
Expand All @@ -289,17 +305,16 @@ async def translate(

pron = origin
try:
# Get pronunciation from [0][1][3] which contains romanized pronunciation
if data[0][1] and len(data[0][1]) > 3:
pron = data[0][1][3]
# Fallback to previous methods if not found
elif data[0][1] and len(data[0][1]) > 2:
pron = data[0][1][2]
elif data[0][1] and len(data[0][1]) >= 2:
pron = data[0][1][-2]
pron = data[0][1][-2]
except Exception: # pragma: nocover
pass

if pron is None:
try:
pron = data[0][1][2]
except: # pragma: nocover # noqa: E722
pass

if dest in EXCLUDES and pron == origin:
pron = translated

Expand Down Expand Up @@ -358,10 +373,17 @@ async def detect(
fr 0.043500196
"""
if isinstance(text, list):
result = []
for item in text:
lang = self.detect(item)
result.append(lang)
concurrency_limit = kwargs.pop(
"list_operation_max_concurrency", self.list_operation_max_concurrency
)
semaphore = asyncio.Semaphore(concurrency_limit)

async def detect_with_semaphore(item):
async with semaphore:
return await self.detect(item, **kwargs)

tasks = [detect_with_semaphore(item) for item in text]
result = await asyncio.gather(*tasks)
return result

data, response = await self._translate(text, "en", "auto", kwargs)
Expand Down

0 comments on commit 15ecbc0

Please sign in to comment.