Skip to content

Commit

Permalink
new: refactor async client generator
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Oct 13, 2023
1 parent f228567 commit e72ed0f
Show file tree
Hide file tree
Showing 25 changed files with 625 additions and 570 deletions.
76 changes: 21 additions & 55 deletions tools/async_client_generator/base_client_generator.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,40 @@
# type: ignore
import ast
from typing import Optional
from typing import Dict, List, Optional

from tools.async_client_generator.base_generator import BaseGenerator
from tools.async_client_generator.transformers import (
ClassDefTransformer,
ConstantTransformer,
FunctionDefTransformer,
)

class AsyncAwaitTransformer(ast.NodeTransformer):

class BaseClientGenerator(BaseGenerator):
def __init__(
self,
keep_sync: Optional[list[str]] = None,
class_replace_map: Optional[dict] = None,
import_replace_map: Optional[dict] = None,
constant_replace_map: Optional[dict] = None,
keep_sync: Optional[List[str]] = None,
class_replace_map: Optional[Dict[str, str]] = None,
constant_replace_map: Optional[Dict[str, str]] = None,
):
self._async_methods = None
self.keep_sync = keep_sync if keep_sync is not None else []
self.class_replace_map = class_replace_map if class_replace_map is not None else {}
self.import_replace_map = import_replace_map if import_replace_map is not None else {}
self.constant_replace_map = (
constant_replace_map if constant_replace_map is not None else {}
)

def _keep_sync(self, name: str) -> bool:
return name in self.keep_sync

def visit_FunctionDef(self, sync_node: ast.FunctionDef):
if self._keep_sync(sync_node.name):
return self.generic_visit(sync_node)
super().__init__()

async_node = ast.AsyncFunctionDef(
name=sync_node.name,
args=sync_node.args,
body=sync_node.body,
decorator_list=sync_node.decorator_list,
returns=sync_node.returns,
type_comment=sync_node.type_comment,
)
async_node.lineno = sync_node.lineno
async_node.col_offset = sync_node.col_offset
async_node.end_lineno = sync_node.end_lineno
async_node.end_col_offset = sync_node.end_col_offset
return self.generic_visit(async_node)

def visit_ClassDef(self, node: ast.ClassDef):
# update class name
for old_value, new_value in self.class_replace_map.items():
node.name = node.name.replace(old_value, new_value)
return self.generic_visit(node)

def visit_Constant(self, node: ast.arg):
for old_value, new_value in self.constant_replace_map.items():
if isinstance(node.value, str):
node.value = node.value.replace(old_value, new_value)
return self.generic_visit(node)
self.transformers.append(FunctionDefTransformer(keep_sync))
self.transformers.append(ClassDefTransformer(class_replace_map))
self.transformers.append(ConstantTransformer(constant_replace_map))


if __name__ == "__main__":
from .config import CODE_DIR
from tools.async_client_generator.config import CLIENT_DIR, CODE_DIR

with open(CODE_DIR / "client_base.py", "r") as source_file:
with open(CLIENT_DIR / "client_base.py", "r") as source_file:
code = source_file.read()

# Parse the code into an AST
parsed_code = ast.parse(code)

await_transformer = AsyncAwaitTransformer(
base_client_generator = BaseClientGenerator(
keep_sync=["__init__", "upload_records", "upload_collection", "migrate"],
class_replace_map={"QdrantBase": "AsyncQdrantBase"},
constant_replace_map={"QdrantBase": "AsyncQdrantBase"},
)
modified_code_ast = await_transformer.visit(parsed_code)
modified_code = ast.unparse(modified_code_ast)
modified_code = base_client_generator.generate(code)

with open("async_client_base.py", "w") as target_file:
with open(CODE_DIR / "async_client_base.py", "w") as target_file:
target_file.write(modified_code)
16 changes: 16 additions & 0 deletions tools/async_client_generator/base_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# type: ignore
import ast
from typing import List


class BaseGenerator:
    """Run source code through an ordered pipeline of AST transformers.

    Subclasses are expected to fill ``self.transformers`` (after calling
    ``super().__init__()``); ``generate`` then applies them in list order.
    """

    def __init__(self) -> None:
        # Starts empty; the order of this list is the order of application.
        self.transformers: List[ast.NodeTransformer] = []

    def generate(self, code: str) -> str:
        """Parse *code*, apply every registered transformer, and unparse.

        Args:
            code: Python source text to transform.

        Returns:
            Source text regenerated from the transformed tree via
            ``ast.unparse``. The text is rebuilt from the AST, so it is
            not guaranteed to match the input formatting.

        Raises:
            SyntaxError: If *code* cannot be parsed by ``ast.parse``.
        """
        tree = ast.parse(code)

        # Each transformer receives the tree returned by the previous one,
        # so later transformers observe earlier rewrites.
        for transformer in self.transformers:
            tree = transformer.visit(tree)

        return ast.unparse(tree)
213 changes: 48 additions & 165 deletions tools/async_client_generator/client_generator.py
Original file line number Diff line number Diff line change
@@ -1,190 +1,73 @@
# type: ignore

import ast
import inspect
from typing import Optional

from .async_client_base import AsyncQdrantBase


class AsyncAwaitTransformer(ast.NodeTransformer):
from typing import Dict, List, Optional

from tools.async_client_generator.async_client_base import AsyncQdrantBase
from tools.async_client_generator.base_generator import BaseGenerator
from tools.async_client_generator.transformers import (
AnnAssignTransformer,
CallTransformer,
ClassDefTransformer,
ImportFromTransformer,
ImportTransformer,
NameTransformer,
)
from tools.async_client_generator.transformers.client import (
ClientFunctionDefTransformer,
)


class ClientGenerator(BaseGenerator):
def __init__(
self,
keep_sync: Optional[list[str]] = None,
class_replace_map: Optional[dict] = None,
import_replace_map: Optional[dict] = None,
exclude_methods: Optional[list[str]] = None,
rename_methods: Optional[dict[str, str]] = None,
keep_sync: Optional[List[str]] = None,
class_replace_map: Optional[Dict[str, str]] = None,
import_replace_map: Optional[Dict[str, str]] = None,
exclude_methods: Optional[List[str]] = None,
rename_methods: Optional[Dict[str, str]] = None,
):
super().__init__()
self._async_methods = None
self.keep_sync = keep_sync if keep_sync is not None else []
self.class_replace_map = class_replace_map if class_replace_map is not None else {}
self.import_replace_map = import_replace_map if import_replace_map is not None else {}
self.exclude_methods = exclude_methods if exclude_methods is not None else []
self.rename_methods = rename_methods if rename_methods is not None else {}

def _keep_sync(self, name: str) -> bool:
return name in self.keep_sync
self.transformers.append(ImportTransformer(import_replace_map))
self.transformers.append(ImportFromTransformer(import_replace_map))
self.transformers.append(
ClientFunctionDefTransformer(
keep_sync, class_replace_map, exclude_methods, rename_methods, self.async_methods
)
)
self.transformers.append(ClassDefTransformer(class_replace_map))
self.transformers.append(AnnAssignTransformer(class_replace_map))

# call_transformer should be after function_def_transformer
self.transformers.append(CallTransformer(class_replace_map, self.async_methods))
# name_transformer should be after function_def, class_def and ann_assign transformers
self.transformers.append(
NameTransformer(class_replace_map, import_replace_map, rename_methods)
)

@property
def async_methods(self):
def async_methods(self) -> List[str]:
if self._async_methods is None:
self._async_methods = self.get_async_methods(AsyncQdrantBase)
return self._async_methods

@staticmethod
def get_async_methods(class_obj):
def get_async_methods(class_obj: type) -> List[str]:
async_methods = []
for name, method in inspect.getmembers(class_obj):
if inspect.iscoroutinefunction(method):
async_methods.append(name)
return async_methods

def visit_Name(self, node: ast.Name):
if node.id in self.class_replace_map:
node.id = self.class_replace_map[node.id]
elif node.id in self.import_replace_map:
node.id = self.import_replace_map[node.id]
elif node.id in self.rename_methods:
node.id = self.rename_methods[node.id]
return self.generic_visit(node)

def visit_Call(self, node):
if isinstance(node.func, ast.Name):
if node.func.id in self.class_replace_map:
node.func.id = self.class_replace_map[node.func.id]

if isinstance(node.func, ast.Attribute):
if node.func.attr in self.async_methods:
return ast.Await(value=node)

return self.generic_visit(node)

def visit_AnnAssign(self, node: ast.AnnAssign):
for old_value, new_value in self.class_replace_map.items():
if isinstance(node.annotation, ast.Name):
node.annotation.id = node.annotation.id.replace(old_value, new_value)
return self.generic_visit(node)

def visit_FunctionDef(self, sync_node: ast.FunctionDef):
if sync_node.name in self.exclude_methods:
return None

if sync_node.name == "__init__":
return self.generate_init(sync_node)

if self._keep_sync(sync_node.name):
return self.generic_visit(sync_node)

async_node = ast.AsyncFunctionDef(
name=sync_node.name,
args=sync_node.args,
body=sync_node.body,
decorator_list=sync_node.decorator_list,
returns=sync_node.returns,
type_comment=sync_node.type_comment,
)
async_node.lineno = sync_node.lineno
async_node.col_offset = sync_node.col_offset
async_node.end_lineno = sync_node.end_lineno
async_node.end_col_offset = sync_node.end_col_offset
return self.generic_visit(async_node)

def visit_Import(self, node: ast.Import):
for old_value, new_value in self.import_replace_map.items():
for alias in node.names:
alias.name = alias.name.replace(old_value, new_value)
return self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom):
# update module name
for old_value, new_value in self.import_replace_map.items():
node.module = node.module.replace(old_value, new_value)

# update imported item name
for alias in node.names:
if hasattr(alias, "name"):
for old_value, new_value in self.import_replace_map.items():
alias.name = alias.name.replace(old_value, new_value)

return self.generic_visit(node)

def visit_ClassDef(self, node: ast.ClassDef):
# update class name
for old_value, new_value in self.class_replace_map.items():
node.name = node.name.replace(old_value, new_value)

# update parent classes names
for base in node.bases:
for old_value, new_value in self.class_replace_map.items():
base.id = base.id.replace(old_value, new_value)
return self.generic_visit(node)

def generate_init(self, sync_node: ast.FunctionDef):
def traverse(node):
assignment_nodes = []

if isinstance(node, ast.Assign):
assignment_nodes.append(node)
for field_name, field_value in ast.iter_fields(node):
if isinstance(field_value, ast.AST):
assignment_nodes.extend(traverse(field_value))
elif isinstance(field_value, list):
for item in field_value:
if isinstance(item, ast.AST):
assignment_nodes.extend(traverse(item))
return assignment_nodes

def unwrap_orelse_assignment(assign_node: ast.Assign):
for target in assign_node.targets:
if isinstance(target, ast.Attribute) and target.attr == "_client":
if isinstance(assign_node.value, ast.Call):
if assign_node.value.func.id in self.class_replace_map:
assign_node.value.func.id = self.class_replace_map[
assign_node.value.func.id
]
return assign_node

args, defaults = [sync_node.args.args[0]], []
for arg, default in zip(sync_node.args.args[1:], sync_node.args.defaults):
if arg.arg not in ("location", "path"):
args.append(arg)
defaults.append(default)
sync_node.args.args = args
sync_node.args.defaults = defaults

for i, child_node in enumerate(sync_node.body):
if isinstance(child_node, ast.If):
orelse_assignment_nodes = traverse(child_node)
assignments = list(
filter(
lambda x: x,
[
unwrap_orelse_assignment(assign_node)
for assign_node in orelse_assignment_nodes
],
)
)
if len(assignments) == 1:
sync_node.body[i] = assignments[0]
break
return self.generic_visit(sync_node)


class ClientAsyncAwaitTransformer(AsyncAwaitTransformer):
def _keep_sync(self, name: str) -> bool:
return name in self.keep_sync or name not in self.async_methods


if __name__ == "__main__":
from .config import CODE_DIR
from tools.async_client_generator.config import CLIENT_DIR, CODE_DIR

with open(CODE_DIR / "qdrant_client.py", "r") as source_file:
with open(CLIENT_DIR / "qdrant_client.py", "r") as source_file:
code = source_file.read()

parsed_code = ast.parse(code)

await_transformer = ClientAsyncAwaitTransformer(
generator = ClientGenerator(
class_replace_map={
"QdrantBase": "AsyncQdrantBase",
"QdrantFastembedMixin": "AsyncQdrantFastembedMixin",
Expand All @@ -208,8 +91,8 @@ def _keep_sync(self, name: str) -> bool:
"async_grpc_points",
],
)
modified_code_ast = await_transformer.visit(parsed_code)
modified_code = ast.unparse(modified_code_ast)

with open("async_qdrant_client.py", "w") as target_file:
modified_code = generator.generate(code)

with open(CODE_DIR / "async_qdrant_client.py", "w") as target_file:
target_file.write(modified_code)
5 changes: 3 additions & 2 deletions tools/async_client_generator/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path

ROOT_DIR = Path(__file__).parent.parent.parent
CODE_DIR = ROOT_DIR / "qdrant_client"
CODE_DIR = Path(__file__).parent
ROOT_DIR = CODE_DIR.parent.parent
CLIENT_DIR = ROOT_DIR / "qdrant_client"
Loading

0 comments on commit e72ed0f

Please sign in to comment.