diff --git a/dp_wizard/app/components/column_module.py b/dp_wizard/app/components/column_module.py index 148a86a..f899ae7 100644 --- a/dp_wizard/app/components/column_module.py +++ b/dp_wizard/app/components/column_module.py @@ -189,4 +189,5 @@ def column_plot(): histogram, error=accuracy, cutoff=0, # TODO + title=f"Simulated {name}, assuming normal distribution", ) diff --git a/dp_wizard/utils/code_generators/__init__.py b/dp_wizard/utils/code_generators/__init__.py index d538a6d..8d4fb72 100644 --- a/dp_wizard/utils/code_generators/__init__.py +++ b/dp_wizard/utils/code_generators/__init__.py @@ -2,6 +2,9 @@ from abc import ABC, abstractmethod from pathlib import Path import re + +import black + from dp_wizard.utils.csv_helper import name_to_identifier from dp_wizard.utils.code_generators._template import Template from dp_wizard.utils.dp_helper import confidence @@ -37,35 +40,24 @@ def _make_extra_blocks(self): return {} def make_py(self): - return str( - Template(self.root_template).fill_blocks( + code = ( + Template(self.root_template) + .fill_blocks( IMPORTS_BLOCK=_make_imports(), COLUMNS_BLOCK=self._make_columns(), CONTEXT_BLOCK=self._make_context(), QUERIES_BLOCK=self._make_queries(), **self._make_extra_blocks(), ) + .finish() ) + return black.format_str(code, mode=black.Mode()) def _make_margins_dict(self, bin_names: Iterable[str]): - # TODO: Don't worry too much about the formatting here. - # Plan to run the output through black for consistency. - # https://github.com/opendp/dp-creator-ii/issues/50 - margins = ( - [ - """ - (): dp.polars.Margin( - public_info="lengths", - ),""" - ] - + [ - f""" - ("{bin_name}",): dp.polars.Margin( - public_info="keys", - ),""" - for bin_name in bin_names - ] - ) + margins = ["(): dp.polars.Margin(public_info='lengths',),"] + [ + f"('{bin_name}',): dp.polars.Margin(public_info='keys',)," + for bin_name in bin_names + ] margins_dict = "{" + "".join(margins) + "\n }" return margins_dict @@ -81,14 +73,79 @@ def _make_columns(self): for name, col in self.columns.items() ) + def _make_pre(self) -> str: + """ + If generating a notebook, this will open a new code paragraph. + """ + return "" + + def _make_post(self) -> str: + """ + If generating a notebook, this will close a new code paragraph. + """ + return "" + + def _make_confidence_note(self): + return f"{int(confidence * 100)}% confidence interval" + def _make_queries(self): - confidence_note = ( - "The actual value is within the shown range " - f"with {int(confidence * 100)}% confidence." - ) + pre = self._make_pre() + post = self._make_post() column_names = self.columns.keys() - return f"confidence = {confidence} # {confidence_note}\n\n" + "\n".join( - _make_query(column_name) for column_name in column_names + return ( + f"{pre}confidence = {confidence} # {self._make_confidence_note()}\n{post}" + + "\n".join( + f"{pre}{self._make_query(column_name)}{post}" + for column_name in column_names + ) + ) + + # def _make_queries(self): + # confidence_note = ( + # "The actual value is within the shown range " + # f"with {int(confidence * 100)}% confidence." + # ) + # column_names = self.columns.keys() + # return f"confidence = {confidence} # {confidence_note}\n\n" + "\n".join( + # _make_query(column_name) for column_name in column_names + + def _make_query(self, column_name): + indentifier = name_to_identifier(column_name) + title = f"DP counts for {column_name}" + accuracy_name = f"{indentifier}_accuracy" + histogram_name = f"{indentifier}_histogram" + return ( + Template("query") + .fill_values( + BIN_NAME=f"{indentifier}_bin", + ) + .fill_expressions( + QUERY_NAME=f"{indentifier}_query", + ACCURACY_NAME=accuracy_name, + HISTOGRAM_NAME=histogram_name, + ) + .fill_blocks( + OUTPUT_BLOCK=self._make_output( + title=title, + accuracy_name=accuracy_name, + histogram_name=histogram_name, + ) + ) + .finish() + ) + + def _make_output(self, title: str, accuracy_name: str, histogram_name: str): + return ( + Template(f"{self.root_template}_output") + .fill_values( + TITLE=title, + ) + .fill_expressions( + ACCURACY_NAME=accuracy_name, + HISTOGRAM_NAME=histogram_name, + CONFIDENCE_NOTE=self._make_confidence_note(), + ) + .finish() ) def _make_partial_context(self): @@ -118,29 +175,34 @@ class NotebookGenerator(_CodeGenerator): root_template = "notebook" def _make_context(self): - return str(self._make_partial_context().fill_values(CSV_PATH=self.csv_path)) + return self._make_partial_context().fill_values(CSV_PATH=self.csv_path).finish() + + def _make_pre(self): + return "# +\n" + + def _make_post(self): + return "# -\n" def _make_extra_blocks(self): outputs_expression = ( "{" + ",".join( - str( - Template("report_kv") - .fill_values( - NAME=name, - CONFIDENCE=confidence, - ) - .fill_expressions( - IDENTIFIER_HISTOGRAM=f"{name_to_identifier(name)}_histogram", - IDENTIFIER_ACCURACY=f"{name_to_identifier(name)}_accuracy", - ) + Template("report_kv") + .fill_values( + NAME=name, + CONFIDENCE=confidence, ) + .fill_expressions( + IDENTIFIER_HISTOGRAM=f"{name_to_identifier(name)}_histogram", + IDENTIFIER_ACCURACY=f"{name_to_identifier(name)}_accuracy", + ) + .finish() for name in self.columns.keys() ) + "}" ) tmp_path = Path(__file__).parent.parent.parent / "tmp" - reports_block = str( + reports_block = ( Template("reports") .fill_expressions( OUTPUTS=outputs_expression, @@ -152,6 +214,7 @@ def _make_extra_blocks(self): TXT_REPORT_PATH=str(tmp_path / "report.txt"), CSV_REPORT_PATH=str(tmp_path / "report.csv"), ) + .finish() ) return {"REPORTS_BLOCK": reports_block} @@ -160,7 +223,14 @@ class ScriptGenerator(_CodeGenerator): root_template = "script" def _make_context(self): - return str(self._make_partial_context().fill_expressions(CSV_PATH="csv_path")) + return ( + self._make_partial_context().fill_expressions(CSV_PATH="csv_path").finish() + ) + + def _make_confidence_note(self): + # In the superclass, the string is unquoted so it can be + # used in comments: It needs to be wrapped here. + return repr(super()._make_confidence_note()) # Public functions used to generate code snippets in the UI; @@ -168,11 +238,11 @@ def _make_context(self): def make_privacy_unit_block(contributions: int): - return str(Template("privacy_unit").fill_values(CONTRIBUTIONS=contributions)) + return Template("privacy_unit").fill_values(CONTRIBUTIONS=contributions).finish() def make_privacy_loss_block(epsilon: float): - return str(Template("privacy_loss").fill_values(EPSILON=epsilon)) + return Template("privacy_loss").fill_values(EPSILON=epsilon).finish() def make_column_config_block( @@ -202,7 +272,7 @@ def make_column_config_block( """ snake_name = _snake_case(name) - return str( + return ( Template("column_config") .fill_expressions( CUT_LIST_NAME=f"{snake_name}_cut_points", @@ -215,6 +285,7 @@ def make_column_config_block( COLUMN_NAME=name, BIN_COLUMN_NAME=f"{snake_name}_bin", ) + .finish() ) @@ -223,31 +294,22 @@ def make_column_config_block( # so it's better to keep them out of the class. -def _make_query(column_name): - indentifier = name_to_identifier(column_name) - return str( - Template("query") - .fill_values( - BIN_NAME=f"{indentifier}_bin", - ) - .fill_expressions( - QUERY_NAME=f"{indentifier}_query", - ACCURACY_NAME=f"{indentifier}_accuracy", - HISTOGRAM_NAME=f"{indentifier}_histogram", - ) - ) - - def _snake_case(name: str): """ >>> _snake_case("HW GRADE") 'hw_grade' + >>> _snake_case("123") + '_123' """ - return re.sub(r"\W+", "_", name.lower()) + snake = re.sub(r"\W+", "_", name.lower()) + # TODO: More validation in UI so we don't get zero-length strings. + if snake == "" or not re.match(r"[a-z]", snake[0]): + snake = f"_{snake}" + return snake def _make_imports(): return ( - str(Template("imports").fill_values()) + Template("imports").fill_values().finish() + (Path(__file__).parent.parent / "shared.py").read_text() ) diff --git a/dp_wizard/utils/code_generators/_template.py b/dp_wizard/utils/code_generators/_template.py index 35d2d97..619f672 100644 --- a/dp_wizard/utils/code_generators/_template.py +++ b/dp_wizard/utils/code_generators/_template.py @@ -26,6 +26,9 @@ def _find_slots(self): return set(re.findall(slot_re, self._template)) def fill_expressions(self, **kwargs): + """ + Fill in variable names, or dicts or lists represented as strings. + """ for k, v in kwargs.items(): k_re = re.escape(k) self._template, count = re.subn(rf"\b{k_re}\b", str(v), self._template) @@ -37,6 +40,9 @@ def fill_expressions(self, **kwargs): return self def fill_values(self, **kwargs): + """ + Fill in string or numeric values. `repr` is called before filling. + """ for k, v in kwargs.items(): k_re = re.escape(k) self._template, count = re.subn(rf"\b{k_re}\b", repr(v), self._template) @@ -48,6 +54,9 @@ def fill_values(self, **kwargs): return self def fill_blocks(self, **kwargs): + """ + Fill in code blocks. Slot must be alone on line. + """ for k, v in kwargs.items(): def match_indent(match): @@ -76,7 +85,7 @@ def match_indent(match): raise Exception(base_message) return self - def __str__(self): + def finish(self): unfilled_slots = self._initial_slots & self._find_slots() if unfilled_slots: slots_str = ", ".join(sorted(f"'{slot}'" for slot in unfilled_slots)) diff --git a/dp_wizard/utils/code_generators/no-tests/_notebook.py b/dp_wizard/utils/code_generators/no-tests/_notebook.py index a6e8948..d3342d2 100644 --- a/dp_wizard/utils/code_generators/no-tests/_notebook.py +++ b/dp_wizard/utils/code_generators/no-tests/_notebook.py @@ -31,7 +31,6 @@ # # Finally, we run the queries and plot the results. -# + QUERIES_BLOCK # - diff --git a/dp_wizard/utils/code_generators/no-tests/_notebook_output.py b/dp_wizard/utils/code_generators/no-tests/_notebook_output.py new file mode 100644 index 0000000..0d288c3 --- /dev/null +++ b/dp_wizard/utils/code_generators/no-tests/_notebook_output.py @@ -0,0 +1,2 @@ +# CONFIDENCE_NOTE +plot_histogram(HISTOGRAM_NAME, error=ACCURACY_NAME, cutoff=0, title=TITLE) diff --git a/dp_wizard/utils/code_generators/no-tests/_query.py b/dp_wizard/utils/code_generators/no-tests/_query.py index 0d78f0d..c6ed0da 100644 --- a/dp_wizard/utils/code_generators/no-tests/_query.py +++ b/dp_wizard/utils/code_generators/no-tests/_query.py @@ -1,4 +1,4 @@ QUERY_NAME = context.query().group_by(BIN_NAME).agg(pl.len().dp.noise()) ACCURACY_NAME = QUERY_NAME.summarize(alpha=1 - confidence)["accuracy"].item() HISTOGRAM_NAME = QUERY_NAME.release().collect().sort(BIN_NAME) -plot_histogram(HISTOGRAM_NAME, ACCURACY_NAME, 0) +OUTPUT_BLOCK diff --git a/dp_wizard/utils/code_generators/no-tests/_script.py b/dp_wizard/utils/code_generators/no-tests/_script.py index ab43f5c..ff669b7 100644 --- a/dp_wizard/utils/code_generators/no-tests/_script.py +++ b/dp_wizard/utils/code_generators/no-tests/_script.py @@ -14,7 +14,9 @@ def get_context(csv_path): parser = ArgumentParser( description="Creates a differentially private release from a csv" ) - parser.add_argument("--csv", help="Path to csv containing private data") + parser.add_argument( + "--csv", required=True, help="Path to csv containing private data" + ) args = parser.parse_args() context = get_context(csv_path=args.csv) diff --git a/dp_wizard/utils/code_generators/no-tests/_script_output.py b/dp_wizard/utils/code_generators/no-tests/_script_output.py new file mode 100644 index 0000000..a47a251 --- /dev/null +++ b/dp_wizard/utils/code_generators/no-tests/_script_output.py @@ -0,0 +1,3 @@ +print(TITLE) +print(CONFIDENCE_NOTE, ACCURACY_NAME) +print(HISTOGRAM_NAME) diff --git a/dp_wizard/utils/shared.py b/dp_wizard/utils/shared.py index b6cc03a..99ea164 100644 --- a/dp_wizard/utils/shared.py +++ b/dp_wizard/utils/shared.py @@ -40,7 +40,7 @@ def df_to_columns(df: DataFrame): def plot_histogram( - histogram_df: DataFrame, error: float, cutoff: float + histogram_df: DataFrame, error: float, cutoff: float, title: str ): # pragma: no cover """ Given a Dataframe for a histogram, plot the data. @@ -58,3 +58,4 @@ def plot_histogram( axes.set_xticks(minors, ["" for _ in minors], minor=True) axes.axhline(cutoff, color="lightgrey", zorder=-1) axes.set_ylim(bottom=0) + axes.set_title(title) diff --git a/pyproject.toml b/pyproject.toml index 26cd7fc..49c7cd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "jupyter-client", "nbconvert", "ipykernel", + "black", "pyyaml", ] diff --git a/tests/utils/test_code_generators.py b/tests/utils/test_code_generators.py index aad08d4..f09ddb4 100644 --- a/tests/utils/test_code_generators.py +++ b/tests/utils/test_code_generators.py @@ -23,74 +23,62 @@ def test_param_conflict(): def test_fill_expressions(): template = Template(None, template="No one VERB the ADJ NOUN!") - filled = str( - template.fill_expressions( - VERB="expects", - ADJ="Spanish", - NOUN="Inquisition", - ) - ) + filled = template.fill_expressions( + VERB="expects", + ADJ="Spanish", + NOUN="Inquisition", + ).finish() assert filled == "No one expects the Spanish Inquisition!" def test_fill_expressions_missing_slot_in_template(): template = Template(None, template="No one ... the ADJ NOUN!") with pytest.raises(Exception, match=r"No 'VERB' slot to fill with 'expects'"): - str( - template.fill_expressions( - VERB="expects", - ADJ="Spanish", - NOUN="Inquisition", - ) - ) + template.fill_expressions( + VERB="expects", + ADJ="Spanish", + NOUN="Inquisition", + ).finish() def test_fill_expressions_extra_slot_in_template(): template = Template(None, template="No one VERB ARTICLE ADJ NOUN!") with pytest.raises(Exception, match=r"'ARTICLE' slot not filled"): - str( - template.fill_expressions( - VERB="expects", - ADJ="Spanish", - NOUN="Inquisition", - ) - ) + template.fill_expressions( + VERB="expects", + ADJ="Spanish", + NOUN="Inquisition", + ).finish() def test_fill_values(): template = Template(None, template="assert [STRING] * NUM == LIST") - filled = str( - template.fill_values( - STRING="🙂", - NUM=3, - LIST=["🙂", "🙂", "🙂"], - ) - ) + filled = template.fill_values( + STRING="🙂", + NUM=3, + LIST=["🙂", "🙂", "🙂"], + ).finish() assert filled == "assert ['🙂'] * 3 == ['🙂', '🙂', '🙂']" def test_fill_values_missing_slot_in_template(): template = Template(None, template="assert [STRING] * ... == LIST") with pytest.raises(Exception, match=r"No 'NUM' slot to fill with '3'"): - str( - template.fill_values( - STRING="🙂", - NUM=3, - LIST=["🙂", "🙂", "🙂"], - ) - ) + template.fill_values( + STRING="🙂", + NUM=3, + LIST=["🙂", "🙂", "🙂"], + ).finish() def test_fill_values_extra_slot_in_template(): template = Template(None, template="CMD [STRING] * NUM == LIST") with pytest.raises(Exception, match=r"'CMD' slot not filled"): - str( - template.fill_values( - STRING="🙂", - NUM=3, - LIST=["🙂", "🙂", "🙂"], - ) - ) + template.fill_values( + STRING="🙂", + NUM=3, + LIST=["🙂", "🙂", "🙂"], + ).finish() def test_fill_blocks(): @@ -113,7 +101,7 @@ def test_fill_blocks(): THIRD="\n".join(f"{i}()" for i in "xyz"), ) assert ( - str(template) + template.finish() == """# MixedCase is OK import a @@ -135,7 +123,7 @@ def test_fill_blocks(): def test_fill_blocks_missing_slot_in_template_alone(): template = Template(None, template="No block slot") with pytest.raises(Exception, match=r"No 'SLOT' slot"): - str(template.fill_blocks(SLOT="placeholder")) + template.fill_blocks(SLOT="placeholder").finish() def test_fill_blocks_missing_slot_in_template_not_alone(): @@ -143,13 +131,13 @@ def test_fill_blocks_missing_slot_in_template_not_alone(): with pytest.raises( Exception, match=r"Block slots must be alone on line; No 'SLOT' slot" ): - str(template.fill_blocks(SLOT="placeholder")) + template.fill_blocks(SLOT="placeholder").finish() def test_fill_blocks_extra_slot_in_template(): template = Template(None, template="EXTRA\nSLOT") with pytest.raises(Exception, match=r"'EXTRA' slot not filled"): - str(template.fill_blocks(SLOT="placeholder")) + template.fill_blocks(SLOT="placeholder").finish() def test_make_notebook(): @@ -194,9 +182,21 @@ def test_make_script(): ).make_py() print(script) + # Make sure jupytext formatting doesn't bleed into the script. + # https://jupytext.readthedocs.io/en/latest/formats-scripts.html#the-light-format + assert "# -" not in script + assert "# +" not in script + with NamedTemporaryFile(mode="w") as fp: fp.write(script) fp.flush() - result = subprocess.run(["python", fp.name, "--csv", fake_csv]) + result = subprocess.run( + ["python", fp.name, "--csv", fake_csv], capture_output=True + ) assert result.returncode == 0 + output = result.stdout.decode() + print(output) + assert "DP counts for hw-number" in output + assert "95% confidence interval 3.3" in output + assert "hw_number_bin" in output