diff --git a/containers/crosswalking/context/main.py b/containers/crosswalking/context/main.py index d382455..c01faf1 100644 --- a/containers/crosswalking/context/main.py +++ b/containers/crosswalking/context/main.py @@ -9,32 +9,6 @@ from src.util.ids import create_temp_asctb_id -def filter_crosswalk_table( - table: pd.DataFrame, - organ_id: str, - organ_level: str, - organ_id_column: str, - organ_level_column: str, - table_label_column: str, -) -> pd.DataFrame: - """Filter the crosswalk table to only include rows with organ id and level. - - Also removes empty rows. - - Args: - table (pd.DataFrame): Original full crosswalk table - - Returns: - pd.DataFrame: Filtered table - """ - organ_id_rows = table[organ_id_column].str.lower() == organ_id.lower() - organ_level_rows = table[organ_level_column].str.lower() == organ_level.lower() - filtered_table = table[organ_id_rows & organ_level_rows] - normalized_table = filtered_table.dropna(how="all") - unique_table = normalized_table.drop_duplicates(table_label_column) - return unique_table - - def crosswalk( matrix: anndata.AnnData, organ_id: str, @@ -74,7 +48,8 @@ def crosswalk( table_clid_column: data_clid_column, table_match_column: data_match_column, } - table = filter_crosswalk_table( + matrix = _filter_invalid_rows(matrix, data_label_column) + table = _filter_crosswalk_table( table, organ_id, organ_level, @@ -103,6 +78,47 @@ def crosswalk( return result +def _filter_invalid_rows(matrix: anndata.AnnData, column: str) -> anndata.AnnData: + """Filter out rows where the obs column's value is NaN or the empty string ''. + + Args: + matrix (anndata.AnnData): Matrix to filter + column (str): Column in obs + + Returns: + anndata.AnnData: Filtered subset matrix + """ + obs_subset = matrix.obs[column] + mask = obs_subset.notna() & (obs_subset != "") + return matrix[mask, :] + + +def _filter_crosswalk_table( + table: pd.DataFrame, + organ_id: str, + organ_level: str, + organ_id_column: str, + organ_level_column: str, + table_label_column: str, +) -> pd.DataFrame: + """Filter the crosswalk table to only include rows with organ id and level. + + Also removes empty rows. + + Args: + table (pd.DataFrame): Original full crosswalk table + + Returns: + pd.DataFrame: Filtered table + """ + organ_id_rows = table[organ_id_column].str.lower() == organ_id.lower() + organ_level_rows = table[organ_level_column].str.lower() == organ_level.lower() + filtered_table = table[organ_id_rows & organ_level_rows] + normalized_table = filtered_table.dropna(how="all") + unique_table = normalized_table.drop_duplicates(table_label_column) + return unique_table + + def _set_defaults( obs: pd.DataFrame, column: str, defaults: t.Union[pd.Series, str] ) -> None: