From 85f7f3261c6fe2ce291817087ae4c7ee8f99a9c7 Mon Sep 17 00:00:00 2001 From: Paul Prescod Date: Mon, 21 Mar 2022 20:15:25 -0700 Subject: [PATCH] Allow decimal datatype --- snowfakery/data_generator_runtime.py | 6 ++++-- .../data_generator_runtime_object_model.py | 6 +++++- snowfakery/output_streams.py | 2 ++ snowfakery/utils/template_utils.py | 3 +++ tests/decimal.yml | 17 +++++++++++++++++ tests/test_faker.py | 9 +++++++-- tests/test_types.py | 11 +++++++++++ 7 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 tests/decimal.yml diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index 5ca82df7..4d3dcfd1 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -344,8 +344,10 @@ def __init__( "snowfakery.standard_plugins.SnowfakeryVersion.snowfakery_version", 2 ) assert snowfakery_version in (2, 3) - native_types = snowfakery_version == 3 - self.template_evaluator_factory = JinjaTemplateEvaluatorFactory(native_types) + self.native_types = snowfakery_version == 3 + self.template_evaluator_factory = JinjaTemplateEvaluatorFactory( + self.native_types + ) def execute(self): self.current_context = RuntimeContext(interpreter=self) diff --git a/snowfakery/data_generator_runtime_object_model.py b/snowfakery/data_generator_runtime_object_model.py index 5a335d40..0c12dba1 100644 --- a/snowfakery/data_generator_runtime_object_model.py +++ b/snowfakery/data_generator_runtime_object_model.py @@ -333,6 +333,8 @@ def render(self, context: RuntimeContext) -> FieldValue: if evaluator: try: val = evaluator(context) + if hasattr(val, "render"): + val = val.render() except jinja2.exceptions.UndefinedError as e: raise DataGenNameError(e.message, self.filename, self.line_num) from e except Exception as e: @@ -340,7 +342,9 @@ def render(self, context: RuntimeContext) -> FieldValue: else: val = self.definition context.unique_context_identifier = old_context_identifier - return look_for_number(val) if isinstance(val, str) else val + if isinstance(val, str) and not context.interpreter.native_types: + val = look_for_number(val) + return val def __repr__(self): return f"<{self.__class__.__name__ , self.definition}>" diff --git a/snowfakery/output_streams.py b/snowfakery/output_streams.py index da5a9783..9bac0d93 100644 --- a/snowfakery/output_streams.py +++ b/snowfakery/output_streams.py @@ -5,6 +5,7 @@ import subprocess import datetime import sys +from decimal import Decimal from pathlib import Path from collections import namedtuple, defaultdict from typing import Dict, Union, Optional, Mapping, Callable, Sequence @@ -54,6 +55,7 @@ class OutputStream(ABC): datetime.datetime: format_datetime, type(None): noop, bool: int, + Decimal: str, } uses_folder = False uses_path = False diff --git a/snowfakery/utils/template_utils.py b/snowfakery/utils/template_utils.py index b3f2a112..c452e71d 100644 --- a/snowfakery/utils/template_utils.py +++ b/snowfakery/utils/template_utils.py @@ -37,6 +37,9 @@ def __add__(self, other): def __radd__(self, other): return str(other) + str(self) + def render(self): + return self.func() + class FakerTemplateLibrary: """A Jinja template library to add the fake.xyz objects to templates""" diff --git a/tests/decimal.yml b/tests/decimal.yml new file mode 100644 index 00000000..5a75d90a --- /dev/null +++ b/tests/decimal.yml @@ -0,0 +1,17 @@ +- snowfakery_version: 3 +- object: Foo + fields: + lat: ${{fake.latitude | string}} # jinja2 will still make a number + long: ${{fake.longitude | string}} # https://github.com/pallets/jinja/issues/1200 + +- object: Bar + fields: + lat2: ${{fake.latitude}} + long2: ${{fake.longitude}} + +- object: Baz + fields: + lat3: + fake: latitude + long3: + fake: longitude diff --git a/tests/test_faker.py b/tests/test_faker.py index 127a59da..ec8cb224 100644 --- a/tests/test_faker.py +++ b/tests/test_faker.py @@ -50,14 +50,19 @@ def test_fake_block_one_param(self, write_row_mock): generate(StringIO(yaml), {}) assert len(row_values(write_row_mock, 0, "country")) == 2 + @pytest.mark.parametrize("snowfakery_version", (2, 3)) @mock.patch(write_row_path) - def test_fake_inline(self, write_row_mock): + def test_fake_inline(self, write_row_mock, snowfakery_version): yaml = """ - object: OBJ fields: country: ${{fake.country_code(representation='alpha-2')}} """ - generate(StringIO(yaml), {}, None) + generate( + StringIO(yaml), + {}, + plugin_options={"snowfakery_version": snowfakery_version}, + ) assert len(row_values(write_row_mock, 0, "country")) == 2 @mock.patch(write_row_path) diff --git a/tests/test_types.py b/tests/test_types.py index 6e373cb1..abadabaf 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,3 +1,4 @@ +import pytest from unittest import mock from io import StringIO @@ -36,3 +37,13 @@ def test_float(self, generated_rows): generate(StringIO(yaml)) assert generated_rows.row_values(0, "foo") == 0.1 assert generated_rows.row_values(0, "foo2") == 0.1 + + @pytest.mark.parametrize("snowfakery_version", (2, 3)) + def test_decimal(self, generated_rows, snowfakery_version): + with open("tests/decimal.yml") as f: + generate(f, plugin_options={"snowfakery_version": snowfakery_version}) + assert isinstance( + generated_rows.table_values("Foo", 0)["lat"], float + ) # Jinja quirk + assert isinstance(generated_rows.table_values("Bar", 0)["lat2"], str) + assert isinstance(generated_rows.table_values("Baz", 0)["lat3"], str)