diff --git a/edvart/report_sections/bivariate_analysis.py b/edvart/report_sections/bivariate_analysis.py index 0d9a5c6..9c59791 100644 --- a/edvart/report_sections/bivariate_analysis.py +++ b/edvart/report_sections/bivariate_analysis.py @@ -813,16 +813,22 @@ def contingency_tables( Data based on which to create a contingency table. columns : List[str], optional Which columns to generate pair-wise contingency tables for. - All columns which contain more than 1 unique value are used by default. - Columns which contain only null values are always excluded. + Columns with more than `table_threshold` unique values are excluded. + Columns which contain only null values are excluded. + To override the excluded columns, specify `columns_pairs`. + Ignored if `columns_x` and `columns_y` or `columns_pairs` is specified. columns_x : List[str], optional If specified, contingency tables are plotted for each pair in the cartesian product of `columns_x` and `columns_y`. + Columns with more than `table_threshold` unique values are excluded. + Columns which contain only null values are excluded. If `columns_x` is specified, then `columns_y` must also be specified. Ignored if `columns_pairs` is specified. columns_y : List[str], optional If specified, contingency tables are plotted for each pair in the cartesian product of `columns_x` and `columns_y`. + Columns with more than `table_threshold` unique values are excluded. + Columns which contain only null values are excluded. If `columns_y` is specified, then `columns_x` must also be specified. Ignored if `columns_pairs` is specified. columns_pairs : List[Tuple[str, str]], optional @@ -842,10 +848,12 @@ def include_column(col: str) -> bool: if (columns_x is None) != (columns_y is None): raise ValueError("Either both or neither of columns_x, columns_y must be specified.") + if columns is None: + columns = list(df.columns) + columns = [col for col in df.columns if include_column(col)] + if columns_pairs is None: if columns_x is None: - if columns is None: - columns = [col for col in df.columns if include_column(col)] columns_pairs = list(itertools.combinations(columns, 2)) else: columns_pairs = [ @@ -853,7 +861,7 @@ def include_column(col: str) -> bool: for (col_x, col_y) in itertools.product(columns_x, columns_y) # Filter out pairs of columns which contain the same column since # they make no sense in a contingency table - if col_x != col_y + if col_x != col_y and include_column(col_x) and include_column(col_y) ] for column1, column2 in columns_pairs: