diff --git a/spoc/query_engine.py b/spoc/query_engine.py
index 33fd75c..03f2215 100644
--- a/spoc/query_engine.py
+++ b/spoc/query_engine.py
@@ -108,7 +108,10 @@ def __init__(
         # preprocess regions
         if isinstance(regions, list):
             self._regions, half_window_sizes = zip(
-                *[self._prepare_regions(region, half_window_size) for region in regions]
+                *[
+                    self._prepare_regions(region, half_window_size, index=index)
+                    for index, region in enumerate(regions)
+                ]
             )
             if not all(
                 half_window_size == half_window_sizes[0]
@@ -128,7 +131,7 @@ def __init__(
         self._anchor = anchor_mode

     def _prepare_regions(
-        self, regions: pd.DataFrame, half_window_size: Optional[int]
+        self, regions: pd.DataFrame, half_window_size: Optional[int], index: int = 0
     ) -> Tuple[pd.DataFrame, int]:
         """Preprocessing of regions including adding an id column."""
         if "id" not in regions.columns:
@@ -149,6 +152,9 @@ def _prepare_regions(
             preprocssed_regions = expanded_regions.drop(
                 columns=["midpoint"]
             ).add_prefix("region_")
+            preprocssed_regions = RegionSchema.validate(preprocssed_regions)
+            if index > 0:
+                preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}")
             return preprocssed_regions, half_window_size
         preprocssed_regions = RegionSchema.validate(regions.add_prefix("region_"))
         # infer window size -> variable regions will have largest possible window size
@@ -158,6 +164,9 @@ def _prepare_regions(
             ).max()
             // 2
         )
+        # add index
+        if index > 0:
+            preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}")
         return preprocssed_regions, calculated_half_window_size

     def validate(self, data_schema: GenomicDataSchema) -> None:
@@ -200,7 +209,9 @@ def _construct_query_multi_region(
         for index, region in enumerate(regions):
             snipped_df = snipped_df.join(
                 region.set_alias(f"regions_{index}"),
-                self._contstruct_filter(position_fields, f"regions_{index}"),
+                self._contstruct_filter(
+                    position_fields, f"regions_{index}", index=index
+                ),
                 how="left",
             )
         # filter regions based on region mode
@@ -208,7 +219,7 @@ def _construct_query_multi_region(
             return snipped_df.filter(
                 " and ".join(
                     [
-                        f"regions_{index}.region_chrom is not null"
+                        f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null"
                         for index in range(0, len(regions))
                     ]
                 )
@@ -217,7 +228,7 @@ def _construct_query_multi_region(
             return snipped_df.filter(
                 " or ".join(
                     [
-                        f"regions_{index}.region_chrom is not null"
+                        f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null"
                         for index in range(0, len(regions))
                     ]
                 )
@@ -236,7 +247,7 @@ def _constrcut_query_single_region(
         )

     def _contstruct_filter(
-        self, position_fields: Dict[int, List[str]], region_name: str
+        self, position_fields: Dict[int, List[str]], region_name: str, index: int = 0
     ) -> str:
         """Constructs the filter string.

@@ -251,6 +262,10 @@ def _contstruct_filter(
         """
         query_strings = []
         join_string = " or " if self._anchor.fragment_mode == "ANY" else " and "
+        if index > 0:
+            column_index = f"_{index}"
+        else:
+            column_index = ""
         # subset on anchor regions
         if self._anchor.positions is not None:
             subset_positions = [
@@ -260,11 +275,11 @@ def _contstruct_filter(
             subset_positions = list(position_fields.values())
         for fields in subset_positions:
             chrom, start, end = fields
-            output_string = f"""(data.{chrom} = {region_name}.region_chrom and
+            output_string = f"""(data.{chrom} = {region_name}.region_chrom{column_index} and
                                 (
-                                    data.{start} between {region_name}.region_start and {region_name}.region_end or
-                                    data.{end} between {region_name}.region_start and {region_name}.region_end or
-                                    {region_name}.region_start between data.{start} and data.{end}
+                                    data.{start} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or
+                                    data.{end} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or
+                                    {region_name}.region_start{column_index} between data.{start} and data.{end}
                                 )
                             )"""
             query_strings.append(output_string)
@@ -607,6 +622,12 @@ def validate(self, data_schema: GenomicDataSchema) -> None:
             raise ValueError(
                 "Binsize specified in data schema, but distance mode is not set to LEFT."
             )
+        # check whether only a single set of regions has been overlapped
+        region_number = data_schema.get_region_number()
+        if isinstance(region_number, list):
+            raise ValueError(
+                "Distance transformation requires only a single set of regions overlapped."
+            )

     def _create_transform_columns(
         self, genomic_df: duckdb.DuckDBPyRelation, input_schema: GenomicDataSchema
diff --git a/tests/query_engine/conftest.py b/tests/query_engine/conftest.py
index 40b3afd..ee2bd46 100644
--- a/tests/query_engine/conftest.py
+++ b/tests/query_engine/conftest.py
@@ -135,6 +135,18 @@ def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_regio
     )


+@pytest.fixture(name="contacts_with_multiple_regions_overlapped")
+def contacts_with_multiple_regions_overlapped_fixture(
+    contacts_without_regions, single_region, single_region_2
+):
+    """Contacts with multiple regions overlapped"""
+    return Overlap(
+        [single_region, single_region_2],
+        anchor_mode=Anchor(fragment_mode="ANY"),
+        half_window_size=100,
+    )(contacts_without_regions)
+
+
 @pytest.fixture(name="pixels_with_single_region")
 def pixels_with_single_region_fixture(pixels_without_regions, single_region):
     """Pixels with single region"""
diff --git a/tests/query_engine/test_distance_transformation.py b/tests/query_engine/test_distance_transformation.py
index 2278006..5532009 100644
--- a/tests/query_engine/test_distance_transformation.py
+++ b/tests/query_engine/test_distance_transformation.py
@@ -14,6 +14,7 @@
     [
         "contacts_without_regions",
         "pixels_without_regions",
+        "contacts_with_multiple_regions_overlapped",
     ],
 )
 def test_incompatible_input_rejected(genomic_data_fixture, request):
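
For context, here is a minimal usage sketch of the multi-region overlap path that the new `contacts_with_multiple_regions_overlapped` fixture exercises. It is not part of the patch: the import path, the raw region column names (`chrom`/`start`/`end`), and the `contacts` placeholder are assumptions inferred from the fixture and from `_prepare_regions` above.

```python
# Hypothetical sketch mirroring the new conftest fixture; not part of this patch.
# Assumes Overlap and Anchor are importable from spoc.query_engine and that raw
# region tables use chrom/start/end columns (prefixed with "region_" inside
# _prepare_regions and validated against RegionSchema).
import pandas as pd

from spoc.query_engine import Anchor, Overlap  # assumed import path

region_set_1 = pd.DataFrame({"chrom": ["chr1"], "start": [1_000_000], "end": [1_002_000]})
region_set_2 = pd.DataFrame({"chrom": ["chr1"], "start": [2_000_000], "end": [2_002_000]})

# Passing a list of region frames triggers the index-aware code path added above:
# columns of the second set are suffixed with "_1" (region_chrom_1, region_start_1, ...).
overlap = Overlap(
    [region_set_1, region_set_2],
    anchor_mode=Anchor(fragment_mode="ANY"),
    half_window_size=100,
)

# contacts stands in for an object like the contacts_without_regions fixture;
# applying the overlap annotates it with both region sets.
# result = overlap(contacts)
```

Per the new check in `DistanceTransformation.validate`, feeding such a multi-region result into a distance transformation raises a `ValueError`, which is what the added parametrization of `test_incompatible_input_rejected` appears to cover.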