Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement enum keys code style setting #106

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@ pip install pydantic-avro

### Pydantic class to avro schema

```python
```python title="example.py"
import json
from typing import Optional


from pydantic_avro.base import AvroBase


class TestModel(AvroBase):
key1: str
key2: int
key2: Optional[str]


schema_dict: dict = TestModel.avro_schema()
print(json.dumps(schema_dict))

Expand All @@ -35,10 +38,10 @@ print(json.dumps(schema_dict))

```shell
# Print to stdout
pydantic-avro avro_to_pydantic --asvc /path/to/schema.asvc
pydantic-avro avro_to_pydantic --avsc /path/to/schema.avsc

# Save it to a file
pydantic-avro avro_to_pydantic --asvc /path/to/schema.asvc --output /path/to/output.py
pydantic-avro avro_to_pydantic --avsc /path/to/schema.avsc --output /path/to/output.py
```


Expand Down
6 changes: 4 additions & 2 deletions src/pydantic_avro/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
from typing import List

from pydantic_avro.avro_to_pydantic import convert_file
from pydantic_avro.helpers import ENUM_KEY_STYLES


def main(input_args: List[str]):
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="sub_command", required=True)

parser_cache = subparsers.add_parser("avro_to_pydantic")
parser_cache.add_argument("--asvc", type=str, dest="avsc", required=True)
parser_cache.add_argument("--avsc", type=str, dest="avsc", required=True)
parser_cache.add_argument("--output", type=str, dest="output")
parser_cache.add_argument("--enum-key-style", type=str, dest="enum_key_style", choices=ENUM_KEY_STYLES)

args = parser.parse_args(input_args)

if args.sub_command == "avro_to_pydantic":
convert_file(args.avsc, args.output)
convert_file(args.avsc, args.output, args.enum_key_style)


def root_main():
Expand Down
33 changes: 21 additions & 12 deletions src/pydantic_avro/avro_to_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import json
import logging
import sys
from pathlib import Path
from typing import Optional, Union

from pydantic_avro.helpers import convert_enum_key

def avsc_to_pydantic(schema: dict) -> str:
logging.basicConfig(format="{asctime} - {levelname} - {message}", level=logging.INFO, style="{", stream=sys.stdout)
log = logging.getLogger(__name__)


def avsc_to_pydantic(schema: dict, enum_key_style: Optional[str] = None) -> str:
"""Generate python code of pydantic of given Avro Schema"""
if "type" not in schema or schema["type"] != "record":
raise AttributeError("Type not supported")
Expand All @@ -14,7 +22,7 @@ def avsc_to_pydantic(schema: dict) -> str:
classes = {}

def get_python_type(t: Union[str, dict]) -> str:
"""Returns python type for given avro type"""
"""Return python type for given avro type"""
optional = False
if isinstance(t, str):
if t == "string":
Expand All @@ -41,9 +49,9 @@ def get_python_type(t: Union[str, dict]) -> str:
py_type = get_python_type(c[0])
else:
if "null" in t:
py_type = f"Optional[Union[{','.join([ get_python_type(e) for e in t if e != 'null'])}]]"
py_type = f"Optional[Union[{','.join([get_python_type(e) for e in t if e != 'null'])}]]"
else:
py_type = f"Union[{','.join([ get_python_type(e) for e in t])}]"
py_type = f"Union[{','.join([get_python_type(e) for e in t])}]"
elif t.get("logicalType") == "uuid":
py_type = "UUID"
elif t.get("logicalType") == "decimal":
Expand All @@ -59,7 +67,7 @@ def get_python_type(t: Union[str, dict]) -> str:
if enum_name not in classes:
enum_class = f"class {enum_name}(str, Enum):\n"
for s in t.get("symbols"):
enum_class += f' {s} = "{s}"\n'
enum_class += f' {convert_enum_key(s, enum_key_style)} = "{s}"\n'
classes[enum_name] = enum_class
py_type = enum_name
elif t.get("type") == "string":
Expand Down Expand Up @@ -127,12 +135,13 @@ def record_type_to_pydantic(schema: dict):
return file_content


def convert_file(avsc_path: str, output_path: Optional[str] = None):
with open(avsc_path, "r") as fh:
def convert_file(avsc_path: str, output_path: Optional[str] = None, enum_key_style: Optional[str] = None) -> None:
with open(avsc_path) as fh:
avsc_dict = json.load(fh)
file_content = avsc_to_pydantic(avsc_dict)
if output_path is None:
print(file_content)
file_content = avsc_to_pydantic(avsc_dict, enum_key_style)
if not output_path:
log.info(f"Converted file content:\n{file_content}")
else:
with open(output_path, "w") as fh:
fh.write(file_content)
fh = Path(output_path)
fh.parent.mkdir(parents=True, exist_ok=True)
fh.write_text(file_content)
13 changes: 13 additions & 0 deletions src/pydantic_avro/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from functools import reduce
from typing import Optional

ENUM_KEY_STYLES = {"snake_case", "snake_case_upper"}


def convert_enum_key(key: str, style: Optional[str] = None) -> str:
if not style:
return key
if style not in ENUM_KEY_STYLES:
raise NotImplementedError(f"Invalid enum key style: {key}. Supported styles: {ENUM_KEY_STYLES}")
snaked = reduce(lambda x, y: x + ("_" if y.isupper() else "") + y, key)
return snaked.lower() if style == "snake_case" else snaked.upper()
25 changes: 25 additions & 0 deletions tests/test_from_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,31 @@ def test_enums_reuse():
assert "class Status(str, Enum):\n" ' passed = "passed"\n' ' failed = "failed"' in pydantic_code


def test_enums_keys_style():
schema = {
"name": "Test",
"type": "record",
"fields": [
{"name": "c1", "type": {"type": "enum", "symbols": ["FirstValue", "secondValue"], "name": "Status"}},
{"name": "c2", "type": "Status"},
],
}
no_style = avsc_to_pydantic(schema)

assert "class Test(BaseModel):\n" " c1: Status\n" " c2: Status" in no_style
assert "class Status(str, Enum):\n" ' FirstValue = "FirstValue"\n' ' secondValue = "secondValue"' in no_style

snake = avsc_to_pydantic(schema, enum_key_style="snake_case")

assert "class Test(BaseModel):\n" " c1: Status\n" " c2: Status" in snake
assert "class Status(str, Enum):\n" ' first_value = "FirstValue"\n' ' second_value = "secondValue"' in snake

snake_upper = avsc_to_pydantic(schema, enum_key_style="snake_case_upper")

assert "class Test(BaseModel):\n" " c1: Status\n" " c2: Status" in snake_upper
assert "class Status(str, Enum):\n" ' FIRST_VALUE = "FirstValue"\n' ' SECOND_VALUE = "secondValue"' in snake_upper


def test_unions():
pydantic_code = avsc_to_pydantic(
{
Expand Down
21 changes: 21 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from pydantic_avro.helpers import convert_enum_key


@pytest.mark.parametrize(
("key", "style", "expected"),
[
("MyKeyName", "snake_case", "my_key_name"),
("MyKeyName", "snake_case_upper", "MY_KEY_NAME"),
("MyKeyName", None, "MyKeyName"),
]
)
def test_convert_enum_key(key: str, style: str, expected: str) -> None:
assert convert_enum_key(key, style) == expected


def test_convert_enum_key_wrong_style() -> None:
with pytest.raises(NotImplementedError) as e:
convert_enum_key("MyKey", "test")
assert "Invalid enum key style: MyKey. Supported styles: {'snake_case_upper', 'snake_case'}" in str(e.value)