diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index 80b37d5d..337f3a74 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -14,6 +14,7 @@ from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError import snowfakery # noQA from snowfakery.object_rows import NicknameSlot, SlotState, ObjectRow +from snowfakery.plugins import PluginContext, SnowfakeryPlugin OutputStream = "snowfakery.output_streams.OutputStream" VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition" @@ -297,6 +298,9 @@ def __init__( f"No template creating {stop_table_name}", ) + # make a plugin context for our Faker stuff to act like a plugin + self.faker_plugin_context = PluginContext(SnowfakeryPlugin(self)) + self.faker_template_libraries = {} # inject context into the standard functions @@ -316,9 +320,14 @@ def execute(self): return self.globals def faker_template_library(self, locale): + """Create a faker template library for locale, or retrieve it from a cache""" rc = self.faker_template_libraries.get(locale) if not rc: - rc = FakerTemplateLibrary(self.faker_providers, locale) + rc = FakerTemplateLibrary( + self.faker_providers, + locale, + self.faker_plugin_context, + ) self.faker_template_libraries[locale] = rc return rc @@ -359,6 +368,7 @@ class RuntimeContext: obj: Optional[ObjectRow] = None template_evaluator_recipe = JinjaTemplateEvaluatorFactory() current_template = None + local_vars = None def __init__( self, @@ -378,6 +388,7 @@ def __init__( self._plugin_context_vars = ChainMap() locale = self.variable_definitions().get("snowfakery_locale") self.faker_template_library = self.interpreter.faker_template_library(locale) + self.local_vars = {} # TODO: move this into the interpreter object def check_if_finished(self): @@ -454,6 +465,9 @@ def field_vars(self): return self.evaluation_namespace.field_vars() def context_vars(self, plugin_namespace): + """Variables which are inherited by child scopes""" + # This looks like a candidate for optimization. + # An unconditional object copy seems expensive. local_plugin_vars = self._plugin_context_vars.get(plugin_namespace, {}).copy() self._plugin_context_vars[plugin_namespace] = local_plugin_vars return local_plugin_vars diff --git a/snowfakery/data_generator_runtime_object_model.py b/snowfakery/data_generator_runtime_object_model.py index a0f7741a..13cf9b2d 100644 --- a/snowfakery/data_generator_runtime_object_model.py +++ b/snowfakery/data_generator_runtime_object_model.py @@ -136,7 +136,9 @@ def exception_handling(self, message: str): except DataGenError: raise except Exception as e: - raise DataGenError(f"{message} : {str(e)}", self.filename, self.line_num) + raise DataGenError( + f"{message} : {str(e)}", self.filename, self.line_num + ) from e def _evaluate_count(self, context: RuntimeContext) -> int: """Evaluate the count expression to an integer""" diff --git a/snowfakery/fakedata/fake_data_generator.py b/snowfakery/fakedata/fake_data_generator.py index 32031149..87112c7d 100644 --- a/snowfakery/fakedata/fake_data_generator.py +++ b/snowfakery/fakedata/fake_data_generator.py @@ -1,14 +1,39 @@ from difflib import get_close_matches import typing as T +import random +from snowfakery.plugins import PluginContext +from itertools import product +from datetime import datetime from faker import Faker, Generator +# .format language doesn't allow slicing. :( +first_name_patterns = ("{firstname}", "{firstname[0]}", "{firstname[0]}{firstname[1]}") +first_name_separators = ("", ".", "-", "_", "+") +year_patterns = ("{year}", "{year[2]}{year[3]}", "{year[3]}", "") + +email_templates = [ + f"{first_name}{first_name_separator}{{lastname}}{year}@{{domain}}" + for first_name, first_name_separator, year in product( + first_name_patterns, first_name_separators, year_patterns + ) +] + +this_year = datetime.today().year + class FakeNames(T.NamedTuple): f: Faker + faker_context: PluginContext = None - def user_name(self): + # "matching" allows us to turn off the behaviour of + # trying to incorporate one field into another if we + # need to. + def user_name(self, matching: bool = True): "Salesforce-style username in the form of an email address" + already_created = self._already_have(("firstname", "lastname")) + if matching and all(already_created): + return f"{already_created[0]}.{already_created[1]}_{self.f.uuid4()}@{self.f.safe_domain_name()}" return f"{self.f.first_name()}_{self.f.last_name()}_{self.f.uuid4()}@{self.f.hostname()}" def alias(self): @@ -17,8 +42,18 @@ def alias(self): numbers of them.""" return self.f.first_name()[0:8] - def email(self): + def email(self, matching: bool = True): """Email address using one of the "example" domains""" + already_created = self._already_have(("firstname", "lastname")) + if matching and all(already_created): + template = random.choice(email_templates) + + return template.format( + firstname=already_created[0].ljust(2, "_"), + lastname=already_created[1], + domain=self.f.safe_domain_name(), + year=str(random.randint(this_year - 80, this_year - 10)), + ) return self.f.ascii_safe_email() def realistic_maybe_real_email(self): @@ -28,6 +63,12 @@ def realistic_maybe_real_email(self): """ return self.f.email() + def _already_have(self, names: T.Sequence[str]): + """Get a list of field values that we've already generated""" + already_created = self.faker_context.local_vars() + vals = [already_created.get(name) for name in names] + return vals + def state(self): """Return a state, province or other appropriate administrative unit""" return self.f.administrative_unit() @@ -45,9 +86,20 @@ def postalcode(self): class FakeData: """Wrapper for Faker which adds Salesforce names and case insensitivity.""" - def __init__(self, faker: Faker): - fake_names = FakeNames(faker) - self.faker = faker + def __init__( + self, + faker_providers: T.Sequence[object], + locale: str = None, + faker_context: PluginContext = None, + ): + # access to persistent state + self.faker_context = faker_context + + faker = Faker(locale, use_weighting=False) + for provider in faker_providers: + faker.add_provider(provider) + + fake_names = FakeNames(faker, faker_context) def no_underscore_name(name): return name.lower().replace("_", "") @@ -72,13 +124,17 @@ def obj_to_func_list(obj: object, canonicalizer: T.Callable, ignore_list: set): } def _get_fake_data(self, origname, *args, **kwargs): + local_faker_vars = self.faker_context.local_vars() + # faker names are all lower-case name = origname.lower() meth = self.fake_names.get(name) if meth: - return meth(*args, **kwargs) + ret = meth(*args, **kwargs) + local_faker_vars[name.replace("_", "")] = ret + return ret msg = f"No fake data type named {origname}." match_list = get_close_matches(name, self.fake_names.keys(), n=1) diff --git a/snowfakery/plugins.py b/snowfakery/plugins.py index ae4fd32a..16f4555c 100644 --- a/snowfakery/plugins.py +++ b/snowfakery/plugins.py @@ -97,12 +97,18 @@ def field_vars(self): return self.interpreter.current_context.field_vars() def context_vars(self): - return self.interpreter.current_context.context_vars( - self.plugin.__class__.__name__ + return self.interpreter.current_context.context_vars(id(self.plugin)) + + def local_vars(self): + return self.interpreter.current_context.local_vars.setdefault( + id(self.plugin), {} ) def unique_context_identifier(self) -> str: - "An identifier that will be unique across iterations (but not portion invocations)" + """An identifier representing a template context that will be + unique across iterations (but not portion invocations). It + allows templates that do counting or iteration for a particular + template context.""" return self.interpreter.current_context.unique_context_identifier def evaluate_raw(self, field_definition): @@ -248,8 +254,8 @@ def __init__(self, name, typ): def convert(self, value): try: return self.type(value) - except TypeError as e: - raise TypeError( + except (TypeError, ValueError) as e: + raise exc.DataGenTypeError( f"{self.name} option is wrong type {type(value)} rather than {self.type}", *e.args, ) diff --git a/snowfakery/utils/template_utils.py b/snowfakery/utils/template_utils.py index 1ef845a2..29b7c1a3 100644 --- a/snowfakery/utils/template_utils.py +++ b/snowfakery/utils/template_utils.py @@ -1,11 +1,9 @@ -from functools import lru_cache from typing import Sequence import string - -from faker import Faker -from jinja2 import Template from snowfakery.fakedata.fake_data_generator import FakeData +from snowfakery.plugins import PluginContext + class StringGenerator: """Sometimes in templates you want a reference to a variable to @@ -41,17 +39,18 @@ def __radd__(self, other): class FakerTemplateLibrary: - """A Jinja template library to add the faker.xyz objects to templates""" - - def __init__(self, faker_providers: Sequence[object], locale=None): + """A Jinja template library to add the fake.xyz objects to templates""" + + def __init__( + self, + faker_providers: Sequence[object], + locale: str = None, + context: PluginContext = None, + ): self.locale = locale + self.context = context - # TODO: Push this all down into FakeData - faker = Faker(self.locale, use_weighting=False) - for provider in faker_providers: - faker.add_provider(provider) - - self.fake_data = FakeData(faker) + self.fake_data = FakeData(faker_providers, locale, self.context) def _get_fake_data(self, name): return self.fake_data._get_fake_data(name) @@ -62,9 +61,6 @@ def __getattr__(self, name): ) -Template = lru_cache(512)(Template) - - number_chars = set(string.digits + ".") diff --git a/tests/test_custom_plugins_and_providers.py b/tests/test_custom_plugins_and_providers.py index ace34444..b02869fd 100644 --- a/tests/test_custom_plugins_and_providers.py +++ b/tests/test_custom_plugins_and_providers.py @@ -4,7 +4,7 @@ from base64 import b64decode from snowfakery import SnowfakeryPlugin, lazy -from snowfakery.plugins import PluginResult +from snowfakery.plugins import PluginResult, PluginOption from snowfakery.data_gen_exceptions import ( DataGenError, DataGenTypeError, @@ -25,6 +25,15 @@ def row_values(write_row_mock, index, value): class SimpleTestPlugin(SnowfakeryPlugin): + allowed_options = [ + PluginOption( + "tests.test_custom_plugins_and_providers.SimpleTestPlugin.option_str", str + ), + PluginOption( + "tests.test_custom_plugins_and_providers.SimpleTestPlugin.option_int", int + ), + ] + class Functions: def double(self, value): return value * 2 @@ -200,6 +209,26 @@ def test_binary(self, generated_rows): assert rawdata.startswith(b"%PDF-1.3") assert b"Helvetica" in rawdata + def test_option__simple(self, generated_rows): + yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin""" + + generate_data(StringIO(yaml), plugin_options={"option_str": "AAA"}) + + def test_option__unknown(self, generated_rows): + yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin""" + + generate_data(StringIO(yaml), plugin_options={"option_str": "zzz"}) + + def test_option__bad_type(self, generated_rows): + yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin""" + with pytest.raises(DataGenTypeError): + generate_data(StringIO(yaml), plugin_options={"option_int": "abcd"}) + + def test_option_type_coercion_needed(self, generated_rows): + yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin""" + + generate_data(StringIO(yaml), plugin_options={"option_int": "5"}) + class PluginThatNeedsState(SnowfakeryPlugin): class Functions: diff --git a/tests/test_data_generator_runtime_dom.py b/tests/test_data_generator_runtime_dom.py index e0cb5db5..8b0cd286 100644 --- a/tests/test_data_generator_runtime_dom.py +++ b/tests/test_data_generator_runtime_dom.py @@ -19,9 +19,6 @@ from snowfakery.output_streams import DebugOutputStream -from snowfakery.utils.template_utils import FakerTemplateLibrary - -ftl = FakerTemplateLibrary([]) line = {"filename": "abc.yml", "line_num": 42} diff --git a/tests/test_faker.py b/tests/test_faker.py index 3caf6157..cfcf9048 100644 --- a/tests/test_faker.py +++ b/tests/test_faker.py @@ -229,3 +229,149 @@ def test_faker_internals_are_invisible(self): with pytest.raises(exc.DataGenError) as e: generate(StringIO(yaml), {}, None) assert "seed" in str(e.value) + + def test_context_aware(self, generated_rows): + yaml = """ + - object: X + fields: + FirstName: + fake: FirstName + LastName: + fake: LastName + Email: + fake: Email + """ + generate(StringIO(yaml)) + assert generated_rows.table_values( + "X", 0, "LastName" + ) in generated_rows.table_values("X", 0, "Email") + + def test_context_username_incorporates_fakes(self, generated_rows): + yaml = """ + - object: X + fields: + FirstName: + fake: FirstName + LastName: + fake: LastName + Username: + fake: Username + """ + generate(StringIO(yaml)) + assert generated_rows.table_values( + "X", 0, "FirstName" + ) in generated_rows.table_values("X", 0, "Username") + assert generated_rows.table_values( + "X", 0, "LastName" + ) in generated_rows.table_values("X", 0, "Username") + + def test_context_aware_multiple_values(self, generated_rows): + yaml = """ + - object: X + count: 3 + fields: + FirstName: + fake: FirstName + LastName: + fake: LastName + Email: + fake: Email + """ + generate(StringIO(yaml)) + assert ( + generated_rows.table_values("X", 2)["LastName"] + in generated_rows.table_values("X", 2)["Email"] + ) + + @mock.patch("faker.providers.person.en_US.Provider.first_name") + @mock.patch("faker.providers.internet.en_US.Provider.ascii_safe_email") + def test_context_aware_order_matters(self, email, first_name, generated_rows): + yaml = """ + - object: X + count: 3 + fields: + Email: + fake: Email + FirstName: + fake: FirstName + LastName: + fake: LastName + """ + generate(StringIO(yaml)) + assert first_name.mock_calls + assert email.mock_calls + + @mock.patch("faker.providers.person.en_US.Provider.first_name") + @mock.patch("faker.providers.internet.en_US.Provider.ascii_safe_email") + def test_context_aware_no_leakage_count(self, email, first_name, generated_rows): + yaml = """ + - object: X + count: 3 + fields: + FirstName: + fake: FirstName + LastName: + fake: LastName + Email: + fake: Email + """ + generate(StringIO(yaml)) + assert first_name.mock_calls + assert not email.mock_calls + + @mock.patch("faker.providers.person.en_US.Provider.first_name") + @mock.patch("faker.providers.internet.en_US.Provider.ascii_safe_email") + def test_context_aware_no_leakage_templates( + self, email, first_name, generated_rows + ): + # no leakage between templates + yaml = """ + - object: X + fields: + FirstName: + fake: FirstName + LastName: + fake: LastName + Email: + fake: Email + - object: Y + fields: + Email: + fake: Email + """ + generate(StringIO(yaml)) + assert first_name.mock_calls + email.assert_called_once() + + @mock.patch("faker.providers.person.en_US.Provider.first_name") + @mock.patch("faker.providers.internet.en_US.Provider.ascii_safe_email") + def test_context_aware_alernate_names(self, email, first_name, generated_rows): + yaml = """ + - object: X + fields: + FirstName: + fake: first_name + LastName: + fake: last_name + Email: + fake: Email + """ + generate(StringIO(yaml)) + assert first_name.mock_calls + assert not email.mock_calls + + @mock.patch("faker.providers.person.en_US.Provider.first_name") + @mock.patch("faker.providers.internet.en_US.Provider.ascii_safe_email") + def test_disable_matching(self, email, first_name, generated_rows): + yaml = """ + - object: X + fields: + FirstName: + fake: FirstName + LastName: + fake: last_name + Email: ${{fake.email(matching=False)}} + """ + generate(StringIO(yaml)) + assert first_name.mock_calls + assert email.mock_calls diff --git a/tests/test_locales.py b/tests/test_locales.py index 5adaa22f..764215eb 100644 --- a/tests/test_locales.py +++ b/tests/test_locales.py @@ -21,7 +21,7 @@ def test_locales(self, generated_rows): name: fake: name """ - with mock.patch("snowfakery.utils.template_utils.Faker") as f: + with mock.patch("snowfakery.fakedata.fake_data_generator.Faker") as f: class FakeFaker(Faker): def name(self): diff --git a/tools/faker_docs_utils/faker_markdown.py b/tools/faker_docs_utils/faker_markdown.py index 87dd2c8e..957dda3d 100644 --- a/tools/faker_docs_utils/faker_markdown.py +++ b/tools/faker_docs_utils/faker_markdown.py @@ -68,7 +68,7 @@ def generate_markdown_for_fakers(outfile, locale: str, header: str = standard_he "Generate the Markdown page for a locale" faker = Faker(locale) language = language_codes[locale.split("_")[0]] - fd = FakeData(faker) + fd = FakeData([], locale) all_fakers = summarize_all_fakers(fd)