-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new: refactor async client generator
- Loading branch information
Showing
25 changed files
with
625 additions
and
570 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,40 @@ | ||
# type: ignore | ||
import ast | ||
from typing import Optional | ||
from typing import Dict, List, Optional | ||
|
||
from tools.async_client_generator.base_generator import BaseGenerator | ||
from tools.async_client_generator.transformers import ( | ||
ClassDefTransformer, | ||
ConstantTransformer, | ||
FunctionDefTransformer, | ||
) | ||
|
||
class AsyncAwaitTransformer(ast.NodeTransformer): | ||
|
||
class BaseClientGenerator(BaseGenerator): | ||
def __init__( | ||
self, | ||
keep_sync: Optional[list[str]] = None, | ||
class_replace_map: Optional[dict] = None, | ||
import_replace_map: Optional[dict] = None, | ||
constant_replace_map: Optional[dict] = None, | ||
keep_sync: Optional[List[str]] = None, | ||
class_replace_map: Optional[Dict[str, str]] = None, | ||
constant_replace_map: Optional[Dict[str, str]] = None, | ||
): | ||
self._async_methods = None | ||
self.keep_sync = keep_sync if keep_sync is not None else [] | ||
self.class_replace_map = class_replace_map if class_replace_map is not None else {} | ||
self.import_replace_map = import_replace_map if import_replace_map is not None else {} | ||
self.constant_replace_map = ( | ||
constant_replace_map if constant_replace_map is not None else {} | ||
) | ||
|
||
def _keep_sync(self, name: str) -> bool: | ||
return name in self.keep_sync | ||
|
||
def visit_FunctionDef(self, sync_node: ast.FunctionDef): | ||
if self._keep_sync(sync_node.name): | ||
return self.generic_visit(sync_node) | ||
super().__init__() | ||
|
||
async_node = ast.AsyncFunctionDef( | ||
name=sync_node.name, | ||
args=sync_node.args, | ||
body=sync_node.body, | ||
decorator_list=sync_node.decorator_list, | ||
returns=sync_node.returns, | ||
type_comment=sync_node.type_comment, | ||
) | ||
async_node.lineno = sync_node.lineno | ||
async_node.col_offset = sync_node.col_offset | ||
async_node.end_lineno = sync_node.end_lineno | ||
async_node.end_col_offset = sync_node.end_col_offset | ||
return self.generic_visit(async_node) | ||
|
||
def visit_ClassDef(self, node: ast.ClassDef): | ||
# update class name | ||
for old_value, new_value in self.class_replace_map.items(): | ||
node.name = node.name.replace(old_value, new_value) | ||
return self.generic_visit(node) | ||
|
||
def visit_Constant(self, node: ast.arg): | ||
for old_value, new_value in self.constant_replace_map.items(): | ||
if isinstance(node.value, str): | ||
node.value = node.value.replace(old_value, new_value) | ||
return self.generic_visit(node) | ||
self.transformers.append(FunctionDefTransformer(keep_sync)) | ||
self.transformers.append(ClassDefTransformer(class_replace_map)) | ||
self.transformers.append(ConstantTransformer(constant_replace_map)) | ||
|
||
|
||
if __name__ == "__main__": | ||
from .config import CODE_DIR | ||
from tools.async_client_generator.config import CLIENT_DIR, CODE_DIR | ||
|
||
with open(CODE_DIR / "client_base.py", "r") as source_file: | ||
with open(CLIENT_DIR / "client_base.py", "r") as source_file: | ||
code = source_file.read() | ||
|
||
# Parse the code into an AST | ||
parsed_code = ast.parse(code) | ||
|
||
await_transformer = AsyncAwaitTransformer( | ||
base_client_generator = BaseClientGenerator( | ||
keep_sync=["__init__", "upload_records", "upload_collection", "migrate"], | ||
class_replace_map={"QdrantBase": "AsyncQdrantBase"}, | ||
constant_replace_map={"QdrantBase": "AsyncQdrantBase"}, | ||
) | ||
modified_code_ast = await_transformer.visit(parsed_code) | ||
modified_code = ast.unparse(modified_code_ast) | ||
modified_code = base_client_generator.generate(code) | ||
|
||
with open("async_client_base.py", "w") as target_file: | ||
with open(CODE_DIR / "async_client_base.py", "w") as target_file: | ||
target_file.write(modified_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# type: ignore | ||
import ast | ||
from typing import List | ||
|
||
|
||
class BaseGenerator: | ||
def __init__(self): | ||
self.transformers: List[ast.NodeTransformer] = [] | ||
|
||
def generate(self, code: str) -> str: | ||
nodes = ast.parse(code) | ||
|
||
for transformer in self.transformers: | ||
nodes = transformer.visit(nodes) | ||
|
||
return ast.unparse(nodes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from pathlib import Path | ||
|
||
ROOT_DIR = Path(__file__).parent.parent.parent | ||
CODE_DIR = ROOT_DIR / "qdrant_client" | ||
CODE_DIR = Path(__file__).parent | ||
ROOT_DIR = CODE_DIR.parent.parent | ||
CLIENT_DIR = ROOT_DIR / "qdrant_client" |
Oops, something went wrong.