From a2493104bc621bdadc41f885f7848485a17099db Mon Sep 17 00:00:00 2001 From: Amy Steier Date: Thu, 29 Oct 2020 15:24:41 -0700 Subject: [PATCH] (Draft) Auto-Balance Blueprint and Code (#7) * (Draft) Auto-Balance Blueprint and Code * Auto-bal blueprint, fix bug blueprint, update colors * (Draft) Auto-Balance Blueprint updates * (Draft) Final touches on Auto-Balance Blueprint * (Draft) auto-bal manifest updates * (Draft) Auto-bal blueprint more cleanup --- gretel/auto_balance_dataset/README.md | 17 + gretel/auto_balance_dataset/bias_bp_data.py | 136 +++ .../auto_balance_dataset/bias_bp_generate.py | 203 ++++ gretel/auto_balance_dataset/bias_bp_graphs.py | 188 ++++ gretel/auto_balance_dataset/bias_bp_inputs.py | 69 ++ gretel/auto_balance_dataset/blueprint.ipynb | 890 ++++++++++++++++++ gretel/auto_balance_dataset/manifest.json | 9 + 7 files changed, 1512 insertions(+) create mode 100644 gretel/auto_balance_dataset/README.md create mode 100644 gretel/auto_balance_dataset/bias_bp_data.py create mode 100644 gretel/auto_balance_dataset/bias_bp_generate.py create mode 100644 gretel/auto_balance_dataset/bias_bp_graphs.py create mode 100644 gretel/auto_balance_dataset/bias_bp_inputs.py create mode 100644 gretel/auto_balance_dataset/blueprint.ipynb create mode 100644 gretel/auto_balance_dataset/manifest.json diff --git a/gretel/auto_balance_dataset/README.md b/gretel/auto_balance_dataset/README.md new file mode 100644 index 00000000..6d6d1d42 --- /dev/null +++ b/gretel/auto_balance_dataset/README.md @@ -0,0 +1,17 @@ +# Auto-Balance Dataset + +In this blueprint, we will use Gretel-Synthetics to produce a balanced, privacy preserving version of your dataset. This blueprint can be used to support fair AI as well as generally any imbalanced dataset. Information Systems (IS) utilizing Artificial Intelligence (AI) are now ubiquitous in our culture. They are often responsible for critical decisions such as who to hire and at what salary, who to give a loan or insurance policy to, and who is at risk for cancer or heart decease. Fair AI strives to eliminate IS discrimination against demographic groups. This blueprint can help you achieve fair AI by eliminating the bias in your data. All it takes is one pass through the data, and bias will be completely removed from as many fields as you like. Correlations and distributions in non-bias fields will, as always, transfer from your training data to your synthetic data. + + +## Objective +In this blueprint, we will remove bias by training a generative synthetic data model to create a balanced dataset. The blueprint supports two different modes for balancing your data. The first (mode="full"), is the scenario where you'd like to generate a complete synthetic dataset with bias removed. The second (mode="additive"), is the scenario where you only want to generate synthetic samples, such that when added to the original set will remove bias. The blueprint takes data from an existing Gretel project and first shows graphically the field distributions to help you narrow in on fields containing bias. After choosing the fields you would like to balance, a synthetic data model is then trained and data is generated as needed to balance your dataset. At the conclusion of the blueprint, the option is given to generate a full Synthetic Performance Report. 
An example report created after balancing the columns "gender", "race" and "income_bracket" in the Kaggle US Adult Income dataset is located [here](https://gretel-public-website.s3-us-west-2.amazonaws.com/blueprints/data_balancing/Auto_Balance_Performance_Report.html). + + +## Steps +1. Click "Transform" on the project NavBar. +2. Copy your Project URI key from the Console. +3. Select the "Auto-Balance Dataset" notebook. +4. If using Colab, click Runtime->Change project runtime and change to "GPU". +5. Step through the notebook cells, answering questions and adjusting params as needed. +6. Once your new data is generated, choose either the cell to save to a CSV or to a new Gretel project. +7. Finally, if desired, generate a full Synthetic Performance Report. diff --git a/gretel/auto_balance_dataset/bias_bp_data.py b/gretel/auto_balance_dataset/bias_bp_data.py new file mode 100644 index 00000000..6ee5b1d3 --- /dev/null +++ b/gretel/auto_balance_dataset/bias_bp_data.py @@ -0,0 +1,136 @@ +import os +from typing import List, Dict, Optional + +import numpy as np +import pandas as pd +from scipy.spatial.distance import jensenshannon + +from gretel_auto_xf.facts import ProjectFacts +from gretel_auto_xf.pipeline import Config +from gretel_client.projects import Project + +F_RATIO_THRESHOLD = .7 + + +def get_field_type(types: List[dict]) -> Optional[str]: + """ + If the field contains any numeric values, return numeric. + We want to return type string only when the field contains strings + but no numeric. + """ + + found_categorical = False + for next_type in types: + if next_type["type"] == "numeric": + return "numeric" + if next_type["type"] == "string": + found_categorical = True + + if found_categorical: + return "string" + else: + return None + + +def get_entities(entities: dict) -> List[str]: + """ + We only want to list an entity with a field if it is pervasively + tagged in the column. + """ + + entity_list = [] + for next_entity in entities: + if next_entity["f_ratio"] > F_RATIO_THRESHOLD: + entity_list.append(next_entity["label"]) + + return entity_list + + +def get_distrib(class_cnts: Dict[str, int], count: int) -> Dict[str, float]: + + distribution = {} + for k in class_cnts.keys(): + distribution[k] = class_cnts[k] / count + + sorted_distrib = {k: v for k, v in sorted(distribution.items(), key=lambda item: item[1], reverse=True)} + + return sorted_distrib + + +def get_field_cnts(field: pd.Series) -> Dict[str, int]: + + distribution = {} + field_clean = field.dropna() + for v in field_clean: + distribution[str(v)] = distribution.get(str(v), 0) + 1 + + return distribution + + +def get_project_facts(project: Project, num_records: int) -> "ProjectFacts": + """ + This function borrows a method from the gretel_auto_xf module to retrieve + information about a given project + """ + + config = Config() + config.max_records = num_records + + return ProjectFacts.from_project(project, config) + + +def get_project_info(project: Project, mode = "full", num_records = 5000, gen_lines = None) -> dict: + """ + This gathers the necessary information from a Project to support synthetic auto-balance + + Arguments: + project: Reference to a project's API client + mode: Can be either "full" or "additive". Mode "full" means generate a full synthetic balanced dataset. + Mode "additive" means only generate enough synthetic samples, such that when added to the + original set, the categorical classes are balanced. + num_records: How many records to retrieve from the project. 
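+            Defaults to 5000, matching the function signature.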
+ gen_lines: In mode "full", this is the number of synthetic records you'd like generated. + + Returns: + A data structure that is used throughout the synthetic auto-balance notebook + """ + + project_info = {} + facts = get_project_facts(project, num_records) + project_info["mode"] = mode + project_info["gen_lines"] = gen_lines + project_info["records"] = facts.as_df + project_info["num_records"] = len(project_info["records"].index) + project_info["field_stats"] = {} + + for field in facts.stats: + if get_field_type(facts.stats[field]["types"]) != "string": + continue + entities = get_entities(facts.stats[field]["entities"]) + if "date" in entities: + continue + project_info["field_stats"][field] = {} + project_info["field_stats"][field]["count"] = facts.stats[field]["count"] + project_info["field_stats"][field]["cardinality"] = facts.stats[field]["approx_cardinality"] + project_info["field_stats"][field]["pct_missing"] = facts.stats[field]["pct_missing"] + project_info["field_stats"][field]["use"] = False + project_info["field_stats"][field]["entities"] = entities + project_info["field_stats"][field]["class_cnts"] = get_field_cnts(project_info["records"][field]) + project_info["field_stats"][field]["distrib"] = get_distrib(project_info["field_stats"][field]["class_cnts"], + project_info["field_stats"][field]["count"]) + + return project_info + + +def bias_fields(project_info: dict) -> List[str]: + """ + This functions returns the list of fields that were chosen + by the user in the notebook to remove bias from + """ + + use_fields = [] + for field in project_info["field_stats"]: + if project_info["field_stats"][field]["use"]: + use_fields.append(field) + + return use_fields diff --git a/gretel/auto_balance_dataset/bias_bp_generate.py b/gretel/auto_balance_dataset/bias_bp_generate.py new file mode 100644 index 00000000..3b721e8d --- /dev/null +++ b/gretel/auto_balance_dataset/bias_bp_generate.py @@ -0,0 +1,203 @@ +import random +from typing import List, Dict + +import pandas as pd +import itertools + + +def get_mode_full_seeds(project_info: dict) -> List[dict]: + """ + This function gets the smarts seeds needed to generate synthetic data + when the user has chosen mode "full" (generate a full synthetic dataset). + To get the number of synthetic lines needed per seed, it first compute + the ratio needed for each field's values. Then the global ratio needed + per each combo seed is the product of each fields ratio. 
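+    For example (illustrative only), balancing one field with 2 classes and another with 3 classes yields 2 x 3 = 6 combination seeds, and each seed is asked to generate roughly gen_lines * (1/2) * (1/3) = gen_lines / 6 records.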
+ """ + + even_percents = {} + categ_val_lists = [] + seed_percent = 1 + balance_columns = [] + gen_lines = project_info["gen_lines"] + for field in project_info["field_stats"]: + if project_info["field_stats"][field]["use"]: + values = set(pd.Series(project_info["records"][field].dropna())) + category_cnt = len(values) + even_percents[field] = 1/category_cnt + categ_val_lists.append(list(values)) + seed_percent = seed_percent * even_percents[field] + balance_columns.append(field) + + seed_gen_cnt = seed_percent * gen_lines + seed_fields = [] + for combo in itertools.product(*categ_val_lists): + seed_dict = {} + i = 0 + for field in balance_columns: + seed_dict[field] = combo[i] + i += 1 + seed = {} + seed["seed"] = seed_dict + seed["cnt"] = seed_gen_cnt + seed_fields.append(seed) + + return seed_fields + + +def get_seed_amts(field: dict) -> Dict[str, int]: + + seed_needs = {} + max = 0 + for class_value in field["class_cnts"]: + if field["class_cnts"][class_value] > max: + max = field["class_cnts"][class_value] + for class_value in field["class_cnts"]: + seed_needs[class_value] = max - field["class_cnts"][class_value] + + return seed_needs + + +def get_mode_additive_seeds(project_info: dict) -> List[dict]: + """ + This function gets the smarts seeds needed to generate synthetic data + when the user has chosen mode "additive" (generate only enough synthetic data, such + that when added to the original data creates a balanced set). + """ + + #First get the seed needs relative to that field + seed_amts = {} + for field in project_info["field_stats"]: + if project_info["field_stats"][field]["use"]: + seed_amts[field] = get_seed_amts(project_info["field_stats"][field]) + + #Now determine the field with the highest need count, and bring the other + #seeds up to that count, keeping each field balanced. + + max_need = 0 + field_needs = {} + for field in seed_amts: + need = 0 + for class_value in seed_amts[field]: + need += seed_amts[field][class_value] + field_needs[field] = need + if need > max_need: + max_need = need + + for field in seed_amts: + if field_needs[field] < max_need: + diff = max_need - field_needs[field] + more = True + used = 0 + #Idea is to keep looping through the field's class values, incrementing the corresponding seed + #amount by one each time, until the "diff" has been used up. Once the "diff" is used up, we + #must exit the loop immediately so that all field lists are of the same length. + while more: + for class_value in seed_amts[field]: + if used == diff: + more = False + continue + seed_amts[field][class_value] += 1 + used += 1 + + #Now create the needed combo seeds and their counts + #For each field, we'll create a list of all values needed and sort it randomly. + #The length of these lists will be the same, as we brought the sum of each field's + #seed needs up to the max need. 
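+    #For example, if one chosen field needs 40 synthetic records in total and another needs 100, the first field's per-class seed amounts are incremented one at a time (cycling over its classes) until its total also reaches 100.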
+ #Then we'll create combo seeds by taking one value from each field list + + field_lists = {} + for field in seed_amts: + curr_list = [] + for class_value in seed_amts[field]: + for i in range(seed_amts[field][class_value]): + curr_list.append(class_value) + random.shuffle(curr_list) + field_lists[field] = curr_list + + all_seeds = {} + for i in range(max_need): + seed_dict = {} + seed = "" + for field in field_lists: + seed_dict[field] = field_lists[field].pop() + seed = seed + "::" + seed_dict[field] + + if seed in all_seeds: + all_seeds[seed]["cnt"] += 1 + else: + all_seeds[seed] = {} + all_seeds[seed]["info"] = seed_dict + all_seeds[seed]["cnt"] = 1 + + #reformat to be like other mode seeds + + seed_fields = [] + for next in all_seeds: + seed = {} + seed["seed"] = all_seeds[next]["info"] + seed["cnt"] = all_seeds[next]["cnt"] + seed_fields.append(seed) + + return seed_fields + + +def gen_smart_seeds(project_info: dict) -> dict: + + if project_info["mode"] == "full": + seeds = get_mode_full_seeds(project_info) + else: + seeds = get_mode_additive_seeds(project_info) + + project_info["seeds"] = seeds + + return project_info + + +def compute_synth_needs(project_info: dict) -> dict: + + project_info = gen_smart_seeds(project_info) + + if project_info["mode"] == "additive": + synth_needs = 0 + for seed in project_info["seeds"]: + synth_needs += seed["cnt"] + + print("Total synthetic records required to fix bias is: " + str(synth_needs)) + + return project_info + + +def gen_synth_nobias(bundle, project_info: dict) -> pd.DataFrame: + """ + This is the main routine called in the synth auto-balance notebook for + generating balanced synthetic data. It returns the final synthetic + dataframe. + """ + + seeds = project_info["seeds"] + seed_cnt = len(seeds) + bias_cnt = 0 + for field in project_info["field_stats"]: + if project_info["field_stats"][field]["use"]: + bias_cnt += 1 + + print("Balancing synthetic generation for " + str(seed_cnt) + " value combinations from " + \ + str(bias_cnt) + " bias fields.\n") + + synth_df = pd.DataFrame(columns=project_info["records"].columns) + + max_invalid = 0 + if project_info["mode"] == "full": + max_invalid = 1000 * project_info["gen_lines"] + else: + max_invalid = 1000 * project_info["num_records"] + + cnt = 1 + for seed in seeds: + print("Balancing combination " + str(cnt) + " of " + str(seed_cnt) + ":") + cnt += 1 + bundle.generate(num_lines=int(seed["cnt"]), max_invalid=max_invalid, seed_fields=seed["seed"]) + tempdf = bundle.get_synthetic_df() + synth_df = synth_df.append(tempdf, ignore_index=True) + + return synth_df diff --git a/gretel/auto_balance_dataset/bias_bp_graphs.py b/gretel/auto_balance_dataset/bias_bp_graphs.py new file mode 100644 index 00000000..336becc8 --- /dev/null +++ b/gretel/auto_balance_dataset/bias_bp_graphs.py @@ -0,0 +1,188 @@ +import math +from typing import Tuple, Dict + +from plotly.subplots import make_subplots +import plotly.graph_objects as go +import pandas as pd + +_GRETEL_PALETTE = ['#C18DFC', '#47E0B3'] +_GRAPH_OPACITY = 0.75 +_GRAPH_BARGAP = 0.2 +_GRAPH_BARGROUPGAP = .1 +_GRAPH_MAX_BARS = 1000 + + +def get_graph_dimen(fields: dict, uniq_cnt_threshold: int) -> Tuple[int,int]: + """ + Helper function to first figure out how many graphs we'll be + displaying, and then based on that determine the appropriate + row and column count for display + """ + + graph_cnt = 0 + for field in fields: + if fields[field]["cardinality"] <= uniq_cnt_threshold: + graph_cnt += 1 + + col_cnt = 1 + if uniq_cnt_threshold <= 50: + col_cnt = 
min(3, graph_cnt) + elif uniq_cnt_threshold <= 100: + col_cnt = min(2, graph_cnt) + else: + col_cnt = 1 + row_cnt = math.ceil(graph_cnt/col_cnt) + + return row_cnt, col_cnt + + +def get_distrib_show(distrib: Dict[str, float]) -> Dict[str, float]: + """ + Plotly slighly freaks with more than 1000 bars, so in the remote + chance they chose to see graphs with more than 1000 unique values + limit the graph bars to the highest 1000 values + """ + + if len(distrib) <= _GRAPH_MAX_BARS: + return distrib + + cnt = 0 + new_distrib = {} + for field in distrib: + new_distrib[field] = distrib[field] + cnt += 1 + if cnt == _GRAPH_MAX_BARS: + return new_distrib + + +def show_field_graphs(fields: dict, uniq_cnt_threshold=10): + """ + This function takes the categorical fields in a project that have + a unique value count less than or equal to the parameter + "uniq_cnt_threshold" and displays their current distributions + using plotly bar charts. The number of columns used to display + the graphs will depend on this value as well. + """ + + row_cnt, col_cnt = get_graph_dimen(fields, uniq_cnt_threshold) + titles = [] + for field in fields: + if fields[field]["cardinality"] <= uniq_cnt_threshold: + titles.append(field) + + shared_yaxes = True + if col_cnt == 1: + shared_yaxes = False + + fig = make_subplots(rows=row_cnt, cols=col_cnt, shared_yaxes=shared_yaxes, subplot_titles=titles) + + row = 1 + col = 1 + for field in fields: + if fields[field]["cardinality"] <= uniq_cnt_threshold: + distrib = get_distrib_show(fields[field]["distrib"]) + fig.add_trace( + go.Bar( + x=list(distrib.keys()), + y=list(distrib.values()), + name=field + ), + row, + col + ) + if col == col_cnt: + col = 1 + row += 1 + else: + col += 1 + + height = (700 / col_cnt) * row_cnt + + fig.update_layout( + height=height, + width=900, + showlegend=False, + title='Existing Categorical Field Distributions', + font=dict( + size=8, + color="RebeccaPurple" + ) + ) + fig.show() + + +def get_new_distrib(field: pd.Series) -> Dict[str, float]: + """ + Even though we know what the new distribution will be, here + we compute it fresh from the new data as a sanity check + """ + + distribution = {} + for v in field: + distribution[str(v)] = distribution.get(str(v), 0) + 1 + series_len = float(len(field)) + for k in distribution.keys(): + distribution[k] = distribution[k] / series_len + + return distribution + + +def show_bar_chart(orig: Dict[str, float], new: Dict[str, float], field: str, mode: str): + """ + This function takes two distributions (orig and new), along + with the name of the field and mode and plots the + distributions on the same graph + """ + + fig = go.Figure() + fig.add_trace( + go.Bar( + x=list(orig.keys()), + y=list(orig.values()), + name='Training', + marker_color=_GRETEL_PALETTE[0], + opacity=_GRAPH_OPACITY + ) + ) + + name = "Synthetic" + if mode == "additive": + name = "Training + Synthetic" + + fig.add_trace( + go.Bar( + x=list(new.keys()), + y=list(new.values()), + name=name, + marker_color=_GRETEL_PALETTE[1], + opacity=_GRAPH_OPACITY + ) + ) + fig.update_layout( + title='Field: ' + field + '', + yaxis_title_text='Percentage', + bargap=_GRAPH_BARGAP, + bargroupgap=_GRAPH_BARGROUPGAP, + barmode='group' + ) + fig.show() + + +def show_new_graphs(project_info: dict, synth_df: pd.DataFrame): + """ + This function is called at the conclusion of the synth auto-balance notebook to take + a look at how the new distributions compare to the original + """ + + new_df = pd.DataFrame() + if project_info["mode"] == "additive": + new_df = 
pd.concat([project_info["records"], synth_df], ignore_index=True) + else: + new_df = synth_df + + for field in project_info["field_stats"]: + if project_info["field_stats"][field]["use"]: + new = pd.Series(new_df[field]).dropna() + new_distrib = get_new_distrib(new) + show_bar_chart(project_info["field_stats"][field]["distrib"], new_distrib, field, project_info["mode"]) + \ No newline at end of file diff --git a/gretel/auto_balance_dataset/bias_bp_inputs.py b/gretel/auto_balance_dataset/bias_bp_inputs.py new file mode 100644 index 00000000..29f2d70e --- /dev/null +++ b/gretel/auto_balance_dataset/bias_bp_inputs.py @@ -0,0 +1,69 @@ +from functools import partial + +from tabulate import tabulate +from ipywidgets import HBox, VBox, widgets, Layout, HTML + + +def choose_bias_fields(project_info: dict) -> dict: + """ + This is the function called from the synthetic auto-balance notebook that + enables the user to pick which fields they want to remove bias from. + It displays a table consisting of the categorical fields in the project, + and for each field shows its cardinality, %missing and + prevalent entities associated with the column. + Approach/code hugely borrowed from the gretel_auto_xf module. + """ + + table = {} + field_names = [] + field_cardinality = [] + field_pct_missing = [] + field_entities = [] + for field in project_info["field_stats"]: + field_names.append(field) + field_cardinality.append(project_info["field_stats"][field]["cardinality"]) + field_pct_missing.append(project_info["field_stats"][field]["pct_missing"]) + field_entities.append(project_info["field_stats"][field]["entities"]) + table["Field"] = field_names + table["Unique Value Cnt"] = field_cardinality + table["% Missing"] = field_pct_missing + table["Entities"] = field_entities + + report_str = tabulate(table, headers="keys", tablefmt="simple") + + line_height = "1.5rem" + padding_margin = {"margin": "0 0 0 0", "padding": "0 0 0 0"} + + layout = Layout(width="30px", height=line_height, **padding_margin) # type: ignore + + def on_check(field, evt: dict): + if evt.get("new"): + project_info["field_stats"][field]["use"] = True + else: + project_info["field_stats"][field]["use"] = False + + def build_checkbox(field): + check_box = widgets.Checkbox(value=False, indent=False, layout=layout) # type: ignore + check_box.observe(partial(on_check, field), names=["value"]) + return check_box + + buttons = [build_checkbox(field) for field in project_info["field_stats"]] + + display( + HBox( + [ + VBox( + [ + widgets.Label(layout=layout), # type: ignore + widgets.Label(layout=layout), # type: ignore + *buttons, + ], + layout=Layout(**padding_margin), # type: ignore + ), + HTML(f'
<pre>{report_str}</pre>
'), + ], + layout=Layout(**padding_margin), # type: ignore + ) + ) + + return project_info diff --git a/gretel/auto_balance_dataset/blueprint.ipynb b/gretel/auto_balance_dataset/blueprint.ipynb new file mode 100644 index 00000000..9efc936e --- /dev/null +++ b/gretel/auto_balance_dataset/blueprint.ipynb @@ -0,0 +1,890 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Gretel Blueprint: Auto-Balance Dataset\n", + "Use Gretel-Synthetics to automatically balance your project data. This blueprint can be used in support of fair AI and generally any imbalanced dataset to boost minority classes. In one pass, bias will be completely removed from as many fields as you like." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y9_t84MuNLFi" + }, + "source": [ + "# Install Packages\n", + "Install open source and premium packages from Gretel.ai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-N2jfpPqsgZ7" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install numpy pandas \n", + "!pip install -U gretel-client \"gretel-synthetics>=0.14.0\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Be sure to use your Gretel URI here, which is available from the Integration menu in the Console\n", + "\n", + "import getpass\n", + "import os\n", + "\n", + "gretel_uri = os.getenv(\"GRETEL_URI\") or getpass.getpass(\"Your Gretel URI\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install Gretel SDKs\n", + "\n", + "from gretel_client import project_from_uri\n", + "\n", + "project = project_from_uri(gretel_uri)\n", + "client = project.client\n", + "project.client.install_packages()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Import Blueprint Modules\n", + "If you are running on Google Colab, use the first cell to download files from our blueprint repo into a Colab notebook's working directory. Remember to change colab to a GPU runtime." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!curl -sL https://get.gretel.cloud/blueprint.sh | bash -s gretel/auto_balance_dataset/*.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import bias_bp_inputs as bpi\n", + "import bias_bp_generate as bpgen\n", + "import bias_bp_graphs as bpg\n", + "import bias_bp_data as bpd" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T0yn3yBU9ukF" + }, + "source": [ + "# Gather Project Data\n", + "There are two different modes for balancing your data. The first (mode=\"full\"), is the scenario where you'd like to generate a complete synthetic dataset with bias removed. The second (mode=\"additive\"), is the scenario where you only want to generate synthetic samples, such that when added to the original set will remove bias.\n", + "\n", + "In the below command to gather project data, specifiy the appropriate mode, as well as the number of records from your project that you'd like to use (num_records). If you are running in mode \"full\", please also specify \n", + "the number of synthetic data records you'd like generated (gen_lines). If you are running in mode \"additive\", we will tell you the number of synthetic data records that will need to be generated to balance your dataset after you have chosen the fields to balance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "project_info = bpd.get_project_info(project, mode=\"full\", num_records=14000, gen_lines=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "project_info[\"records\"].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Look at Current Categorical Field Distributions\n", + "Graphs are shown for categorical fields having a unique value count less than or equal \n", + "to the parameter \"uniq_cnt_threshold\". Adjust this parameter to fit your needs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bpg.show_field_graphs(project_info[\"field_stats\"], uniq_cnt_threshold=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Choose Which Fields to Fix Bias In" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "project_info = bpi.choose_bias_fields(project_info)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compute Records Needed to Fix Bias\n", + "\n", + "If you are running in mode \"additive\", this command will also tell you the total number of synthetic\n", + "records that will need to be generated to fix the bias in your chosen fields. After viewing this, if you\n", + "would like to go back and adjust your bias field selections, you may." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "project_info = bpgen.compute_synth_needs(project_info)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train Your Synthetic Model\n", + "\n", + "- See [our documentation](https://gretel-synthetics.readthedocs.io/en/stable/api/config.html) for additional config options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the Gretel Synthtetics Training / Model Configuration\n", + "from pathlib import Path\n", + "\n", + "checkpoint_dir = str(Path.cwd() / \"checkpoints\")\n", + "\n", + "config_template = {\n", + " \"checkpoint_dir\": checkpoint_dir,\n", + " \"overwrite\": True\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Create the Synthetic Training Bundle\n", + "from gretel_helpers.synthetics import SyntheticDataBundle\n", + "\n", + "bundle = SyntheticDataBundle(\n", + " header_prefix=bpd.bias_fields(project_info),\n", + " training_df=project_info[\"records\"],\n", + " delimiter=\",\", # Specify the appropriate delimeter in your data\n", + " auto_validate=True, \n", + " synthetic_config=config_template, \n", + ")\n", + "\n", + "bundle.build()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now train your model\n", + "bundle.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generate Balanced Synthetic Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "synth_df = bpgen.gen_synth_nobias(bundle, project_info)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Take a Look At Your Synthetic Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
{}, + "outputs": [], + "source": [ + "synth_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Combine Your Original and New Synthetic Data\n", + "Relevant if you are using mode=\"additive\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "new_df = pd.concat([synth_df,project_info[\"records\"]],ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save to CSV" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "synth_df.to_csv('synthetic-data.csv', index=False, header=True)\n", + "#new_df.to_csv('synth-plus-orig-data.csv', index=False, header=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save to New Gretel Project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_project = client.get_project(create=True)\n", + "new_project.send_dataframe(synth_df, detection_mode='fast') #alternatively use new_df\n", + "print(f\"Access your project at {new_project.get_console_url()}\") " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Delete project if you don't need it\n", + "new_project.delete()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Show New Distributions\n", + "When running in \"full\" mode, graphs will be shown comparing training data to synthetic data. When running in \"additive\" mode, still pass in the synth_df and the graphs will automatically compare training data to training plus synthetic records." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bpg.show_new_graphs(project_info, synth_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generate a Full Synthetic Performance Report\n", + "Correlations and distributions in non-bias fields should, as always, transfer from training data to synthetic data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gretel_helpers.reports.correlation import generate_report\n", + "from IPython.core.display import display\n", + "from IPython.display import IFrame\n", + "\n", + "generate_report(project_info[\"records\"], synth_df, report_path=\"./report.html\") #alternatively use new_df\n", + "display(IFrame(\"./report.html\", 1000, 600))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "blueprint-synthetics-massive-imbalance", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "13e6f99bfa184a3e87bbf184c9327880": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_dcb9d1ff9d934d3a86169738a35627a1", + "IPY_MODEL_48736cd4b9ec480d8ac3ac75f628bd1c" + ], + "layout": "IPY_MODEL_a0758b455a8e46dcafc056c6754f19ef" + } + }, + "28253b42e1364bbf8538951277b1b7a4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_eb1c1019d18c460b9b967ef27a67136b", + "placeholder": "​", + "style": "IPY_MODEL_98b909083d00408d9d659c72b117903d", + "value": " 2500/2500 [01:17<00:00, 32.46it/s]" + } + }, + "2be98d0186d44b939451e305bf201c48": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "40b0069f29694447936296f714f22313": { + 
"model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "46f9b8b40d5b409aa7cbaab1ca55f782": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "48736cd4b9ec480d8ac3ac75f628bd1c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_df549aad36ac4c6fa538cc5718af664c", + "placeholder": "​", + "style": "IPY_MODEL_46f9b8b40d5b409aa7cbaab1ca55f782", + "value": " 146/2500 [01:17<20:41, 1.90it/s]" + } + }, + "4d23dbaafca04afdb879e7ca708930d8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e4b79fa2d813427c8e96b1d741ba011f", + "IPY_MODEL_28253b42e1364bbf8538951277b1b7a4" + ], + "layout": "IPY_MODEL_b9ec65dba8b2406181272a71251c2828" + } + }, + "98b909083d00408d9d659c72b117903d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a0758b455a8e46dcafc056c6754f19ef": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b39867c2eb224dd18488237d11534a0d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + }, + "b8107f1374cc41a485d89aff1bcfa071": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + }, + "b9ec65dba8b2406181272a71251c2828": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dcb9d1ff9d934d3a86169738a35627a1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "Invalid record count : 6%", + 
"description_tooltip": null, + "layout": "IPY_MODEL_40b0069f29694447936296f714f22313", + "max": 2500, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b39867c2eb224dd18488237d11534a0d", + "value": 146 + } + }, + "df549aad36ac4c6fa538cc5718af664c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4b79fa2d813427c8e96b1d741ba011f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "Valid record count : 100%", + "description_tooltip": null, + "layout": "IPY_MODEL_2be98d0186d44b939451e305bf201c48", + "max": 2500, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b8107f1374cc41a485d89aff1bcfa071", + "value": 2500 + } + }, + "eb1c1019d18c460b9b967ef27a67136b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/gretel/auto_balance_dataset/manifest.json b/gretel/auto_balance_dataset/manifest.json new file mode 100644 index 
00000000..cb156da1 --- /dev/null +++ b/gretel/auto_balance_dataset/manifest.json @@ -0,0 +1,9 @@ +{ + "name": "Automatically balance your data", + "description": "Automatically balance minority classes in a dataset. Easily remove bias from as many fields as you like.", + "tags": ["synthetic-data", "ai"], + "sample_data_key": "us-adult-income", + "blog_url": "", + "language": "python", + "featured": true +}