Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow fakers to build on previously populated fields. #420

Merged
merged 14 commits into from
Aug 16, 2021
16 changes: 15 additions & 1 deletion snowfakery/data_generator_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError
import snowfakery # noQA
from snowfakery.object_rows import NicknameSlot, SlotState, ObjectRow
from snowfakery.plugins import PluginContext, SnowfakeryPlugin

OutputStream = "snowfakery.output_streams.OutputStream"
VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
Expand Down Expand Up @@ -297,6 +298,9 @@ def __init__(
f"No template creating {stop_table_name}",
)

# make a plugin context for our Faker stuff to act like a plugin
self.faker_plugin_context = PluginContext(SnowfakeryPlugin(self))

self.faker_template_libraries = {}

# inject context into the standard functions
Expand All @@ -316,9 +320,14 @@ def execute(self):
return self.globals

def faker_template_library(self, locale):
"""Create a faker template library for locale, or retrieve it from a cache"""
rc = self.faker_template_libraries.get(locale)
if not rc:
rc = FakerTemplateLibrary(self.faker_providers, locale)
rc = FakerTemplateLibrary(
self.faker_providers,
locale,
self.faker_plugin_context,
)
self.faker_template_libraries[locale] = rc
return rc

Expand Down Expand Up @@ -359,6 +368,7 @@ class RuntimeContext:
obj: Optional[ObjectRow] = None
template_evaluator_recipe = JinjaTemplateEvaluatorFactory()
current_template = None
local_vars = None

def __init__(
self,
Expand All @@ -378,6 +388,7 @@ def __init__(
self._plugin_context_vars = ChainMap()
locale = self.variable_definitions().get("snowfakery_locale")
self.faker_template_library = self.interpreter.faker_template_library(locale)
self.local_vars = {}

# TODO: move this into the interpreter object
def check_if_finished(self):
Expand Down Expand Up @@ -454,6 +465,9 @@ def field_vars(self):
return self.evaluation_namespace.field_vars()

def context_vars(self, plugin_namespace):
"""Variables which are inherited by child scopes"""
# This looks like a candidate for optimization.
# An unconditional object copyseems expensive.
local_plugin_vars = self._plugin_context_vars.get(plugin_namespace, {}).copy()
self._plugin_context_vars[plugin_namespace] = local_plugin_vars
return local_plugin_vars
Expand Down
4 changes: 3 additions & 1 deletion snowfakery/data_generator_runtime_object_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def exception_handling(self, message: str):
except DataGenError:
raise
except Exception as e:
raise DataGenError(f"{message} : {str(e)}", self.filename, self.line_num)
raise DataGenError(
f"{message} : {str(e)}", self.filename, self.line_num
) from e

def _evaluate_count(self, context: RuntimeContext) -> int:
"""Evaluate the count expression to an integer"""
Expand Down
59 changes: 53 additions & 6 deletions snowfakery/fakedata/fake_data_generator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,46 @@
from difflib import get_close_matches
import typing as T
import random
from snowfakery.plugins import PluginContext


email_templates = [ # .format language doesn't allow slicing. :(
f"{first_name}{first_name_separator}{{lastname}}{year}@{{domain}}"
for first_name in ("{firstname}", "{firstname[0]}", "{firstname[0]}{firstname[1]}")
boakley marked this conversation as resolved.
Show resolved Hide resolved
for first_name_separator in ("", ".", "-", "_", "+")
for year in ("{year}", "{year[2]}{year[3]}", "{year[3]}", "")
]

from faker import Faker, Generator


class FakeNames(T.NamedTuple):
f: Faker
faker_context: PluginContext = None

def user_name(self):
def user_name(self, matching: bool = True):
boakley marked this conversation as resolved.
Show resolved Hide resolved
"Salesforce-style username in the form of an email address"
already_created = self._already_have(("firstname", "lastname"), matching)
if all(already_created):
return f"{already_created[0]}.{already_created[1]}_{self.f.uuid4()}@{self.f.safe_domain_name()}"
return f"{self.f.first_name()}_{self.f.last_name()}_{self.f.uuid4()}@{self.f.hostname()}"

def alias(self):
"Salesforce-style 8-character alias"
return self.f.first_name()[0:8]

def email(self):
def email(self, matching: bool = True):
"""Email address using one of the "example" domains"""
already_created = self._already_have(("firstname", "lastname"), matching)
if all(already_created):
template = random.choice(email_templates)

return template.format(
firstname=already_created[0],
lastname=already_created[1],
domain=self.f.safe_domain_name(),
year=str(random.randint(1955, 2020)),
boakley marked this conversation as resolved.
Show resolved Hide resolved
)
return self.f.ascii_safe_email()

def realistic_maybe_real_email(self):
Expand All @@ -26,6 +50,14 @@ def realistic_maybe_real_email(self):
"""
return self.f.email()

def _already_have(self, names: T.Sequence[str], matching: bool):
"""Get a list of field values that we've already generated"""
if not matching:
return [None]
already_created = self.faker_context.local_vars()
vals = [already_created.get(name) for name in names]
return vals

def state(self):
"""Return a state, province or other appropriate administrative unit"""
return self.f.administrative_unit()
Expand All @@ -43,9 +75,20 @@ def postalcode(self):
class FakeData:
"""Wrapper for Faker which adds Salesforce names and case insensitivity."""

def __init__(self, faker: Faker):
fake_names = FakeNames(faker)
self.faker = faker
def __init__(
self,
faker_providers: T.Sequence[object],
locale: str = None,
faker_context: PluginContext = None,
):
# access to persistent state
self.faker_context = faker_context

faker = Faker(locale, use_weighting=False)
for provider in faker_providers:
faker.add_provider(provider)

fake_names = FakeNames(faker, faker_context)

def no_underscore_name(name):
return name.lower().replace("_", "")
Expand All @@ -70,13 +113,17 @@ def obj_to_func_list(obj: object, canonicalizer: T.Callable, ignore_list: set):
}

def _get_fake_data(self, origname, *args, **kwargs):
local_faker_vars = self.faker_context.local_vars()

# faker names are all lower-case
name = origname.lower()

meth = self.fake_names.get(name)

if meth:
return meth(*args, **kwargs)
ret = meth(*args, **kwargs)
local_faker_vars[name.replace("_", "")] = ret
return ret

msg = f"No fake data type named {origname}."
match_list = get_close_matches(name, self.fake_names.keys(), n=1)
Expand Down
16 changes: 11 additions & 5 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ def field_vars(self):
return self.interpreter.current_context.field_vars()

def context_vars(self):
return self.interpreter.current_context.context_vars(
self.plugin.__class__.__name__
return self.interpreter.current_context.context_vars(id(self.plugin))

def local_vars(self):
return self.interpreter.current_context.local_vars.setdefault(
id(self.plugin), {}
)

def unique_context_identifier(self) -> str:
"An identifier that will be unique across iterations (but not portion invocations)"
"""An identifier representing a template context that will be
unique across iterations (but not portion invocations). It
allows templates that do counting or iteration for a particular
template context."""
return self.interpreter.current_context.unique_context_identifier

def evaluate_raw(self, field_definition):
Expand Down Expand Up @@ -248,8 +254,8 @@ def __init__(self, name, typ):
def convert(self, value):
try:
return self.type(value)
except TypeError as e:
raise TypeError(
except (TypeError, ValueError) as e:
raise exc.DataGenTypeError(
f"{self.name} option is wrong type {type(value)} rather than {self.type}",
*e.args,
)
Expand Down
24 changes: 10 additions & 14 deletions snowfakery/utils/template_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from functools import lru_cache
from typing import Sequence
import string

from faker import Faker
from jinja2 import Template
from snowfakery.fakedata.fake_data_generator import FakeData

from snowfakery.plugins import PluginContext


class StringGenerator:
"""Sometimes in templates you want a reference to a variable to
Expand Down Expand Up @@ -43,15 +41,16 @@ def __radd__(self, other):
class FakerTemplateLibrary:
"""A Jinja template library to add the faker.xyz objects to templates"""

def __init__(self, faker_providers: Sequence[object], locale=None):
def __init__(
self,
faker_providers: Sequence[object],
locale: str = None,
context: PluginContext = None,
):
self.locale = locale
self.context = context

# TODO: Push this all down into FakeData
faker = Faker(self.locale, use_weighting=False)
for provider in faker_providers:
faker.add_provider(provider)

self.fake_data = FakeData(faker)
self.fake_data = FakeData(faker_providers, locale, self.context)

def _get_fake_data(self, name):
return self.fake_data._get_fake_data(name)
Expand All @@ -62,9 +61,6 @@ def __getattr__(self, name):
)


Template = lru_cache(512)(Template)


number_chars = set(string.digits + ".")


Expand Down
31 changes: 30 additions & 1 deletion tests/test_custom_plugins_and_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from base64 import b64decode

from snowfakery import SnowfakeryPlugin, lazy
from snowfakery.plugins import PluginResult
from snowfakery.plugins import PluginResult, PluginOption
from snowfakery.data_gen_exceptions import (
DataGenError,
DataGenTypeError,
Expand All @@ -25,6 +25,15 @@ def row_values(write_row_mock, index, value):


class SimpleTestPlugin(SnowfakeryPlugin):
allowed_options = [
PluginOption(
"tests.test_custom_plugins_and_providers.SimpleTestPlugin.option_str", str
),
PluginOption(
"tests.test_custom_plugins_and_providers.SimpleTestPlugin.option_int", int
),
]

class Functions:
def double(self, value):
return value * 2
Expand Down Expand Up @@ -200,6 +209,26 @@ def test_binary(self, generated_rows):
assert rawdata.startswith(b"%PDF-1.3")
assert b"Helvetica" in rawdata

def test_option__simple(self, generated_rows):
yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin"""

generate_data(StringIO(yaml), plugin_options={"option_str": "AAA"})

def test_option__unknown(self, generated_rows):
yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin"""

generate_data(StringIO(yaml), plugin_options={"option_str": "zzz"})

def test_option__bad_type(self, generated_rows):
yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin"""
with pytest.raises(DataGenTypeError):
generate_data(StringIO(yaml), plugin_options={"option_int": "abcd"})

def test_option_type_coercion_needed(self, generated_rows):
yaml = """- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin"""

generate_data(StringIO(yaml), plugin_options={"option_int": "5"})


class PluginThatNeedsState(SnowfakeryPlugin):
class Functions:
Expand Down
3 changes: 0 additions & 3 deletions tests/test_data_generator_runtime_dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

from snowfakery.output_streams import DebugOutputStream

from snowfakery.utils.template_utils import FakerTemplateLibrary

ftl = FakerTemplateLibrary([])

line = {"filename": "abc.yml", "line_num": 42}

Expand Down
Loading