Skip to content

Commit

Permalink
Added test case for distance transformation and multiple regions
Browse files Browse the repository at this point in the history
  • Loading branch information
Mittmich committed Mar 17, 2024
1 parent 0792aed commit 027e944
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
41 changes: 31 additions & 10 deletions spoc/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def __init__(
# preprocess regions
if isinstance(regions, list):
self._regions, half_window_sizes = zip(
*[self._prepare_regions(region, half_window_size) for region in regions]
*[
self._prepare_regions(region, half_window_size, index=index)
for index, region in enumerate(regions)
]
)
if not all(
half_window_size == half_window_sizes[0]
Expand All @@ -128,7 +131,7 @@ def __init__(
self._anchor = anchor_mode

def _prepare_regions(
self, regions: pd.DataFrame, half_window_size: Optional[int]
self, regions: pd.DataFrame, half_window_size: Optional[int], index: int = 0
) -> Tuple[pd.DataFrame, int]:
"""Preprocessing of regions including adding an id column."""
if "id" not in regions.columns:
Expand All @@ -149,6 +152,9 @@ def _prepare_regions(
preprocssed_regions = expanded_regions.drop(
columns=["midpoint"]
).add_prefix("region_")
preprocssed_regions = RegionSchema.validate(preprocssed_regions)
if index > 0:
preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}")
return preprocssed_regions, half_window_size
preprocssed_regions = RegionSchema.validate(regions.add_prefix("region_"))
# infer window size -> variable regions will have largest possible window size
Expand All @@ -158,6 +164,9 @@ def _prepare_regions(
).max()
// 2
)
# add index
if index > 0:
preprocssed_regions = preprocssed_regions.add_suffix(f"_{index}")
return preprocssed_regions, calculated_half_window_size

def validate(self, data_schema: GenomicDataSchema) -> None:
Expand Down Expand Up @@ -200,15 +209,17 @@ def _construct_query_multi_region(
for index, region in enumerate(regions):
snipped_df = snipped_df.join(
region.set_alias(f"regions_{index}"),
self._contstruct_filter(position_fields, f"regions_{index}"),
self._contstruct_filter(
position_fields, f"regions_{index}", index=index
),
how="left",
)
# filter regions based on region mode
if self._anchor.region_mode == "ALL":
return snipped_df.filter(
" and ".join(
[
f"regions_{index}.region_chrom is not null"
f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null"
for index in range(0, len(regions))
]
)
Expand All @@ -217,7 +228,7 @@ def _construct_query_multi_region(
return snipped_df.filter(
" or ".join(
[
f"regions_{index}.region_chrom is not null"
f"regions_{index}.region_chrom{'_' + str(index) if index > 0 else ''} is not null"
for index in range(0, len(regions))
]
)
Expand All @@ -236,7 +247,7 @@ def _constrcut_query_single_region(
)

def _contstruct_filter(
self, position_fields: Dict[int, List[str]], region_name: str
self, position_fields: Dict[int, List[str]], region_name: str, index: int = 0
) -> str:
"""Constructs the filter string.
Expand All @@ -251,6 +262,10 @@ def _contstruct_filter(
"""
query_strings = []
join_string = " or " if self._anchor.fragment_mode == "ANY" else " and "
if index > 0:
column_index = f"_{index}"
else:
column_index = ""
# subset on anchor regions
if self._anchor.positions is not None:
subset_positions = [
Expand All @@ -260,11 +275,11 @@ def _contstruct_filter(
subset_positions = list(position_fields.values())
for fields in subset_positions:
chrom, start, end = fields
output_string = f"""(data.{chrom} = {region_name}.region_chrom and
output_string = f"""(data.{chrom} = {region_name}.region_chrom{column_index} and
(
data.{start} between {region_name}.region_start and {region_name}.region_end or
data.{end} between {region_name}.region_start and {region_name}.region_end or
{region_name}.region_start between data.{start} and data.{end}
data.{start} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or
data.{end} between {region_name}.region_start{column_index} and {region_name}.region_end{column_index} or
{region_name}.region_start{column_index} between data.{start} and data.{end}
)
)"""
query_strings.append(output_string)
Expand Down Expand Up @@ -607,6 +622,12 @@ def validate(self, data_schema: GenomicDataSchema) -> None:
raise ValueError(
"Binsize specified in data schema, but distance mode is not set to LEFT."
)
# check whether only a single set of regions has been overlapped
region_number = data_schema.get_region_number()
if isinstance(region_number, list):
raise ValueError(
"Distance transformation requires only a single set of regions overlapped."
)

def _create_transform_columns(
self, genomic_df: duckdb.DuckDBPyRelation, input_schema: GenomicDataSchema
Expand Down
12 changes: 12 additions & 0 deletions tests/query_engine/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ def contacts_with_multiple_regions_fixture(contacts_without_regions, multi_regio
)


@pytest.fixture(name="contacts_with_multiple_regions_overlapped")
def contacts_with_multiple_regions_overlapped_fixture(
    contacts_without_regions, single_region, single_region_2
):
    """Contacts overlapped with two separate region sets via ``Overlap``.

    Builds an ``Overlap`` query step over both region sets
    (``Anchor(fragment_mode="ANY")``, half-window size 100) and applies it
    to the region-free contacts fixture. Used by
    ``test_incompatible_input_rejected`` to exercise the rejection of
    multi-region-overlapped input in the distance transformation.
    """
    # NOTE(review): Overlap(...) appears to return a callable query step
    # that is applied to the data via __call__ — confirm against the
    # Overlap implementation in spoc/query_engine.py.
    return Overlap(
        [single_region, single_region_2],
        anchor_mode=Anchor(fragment_mode="ANY"),
        half_window_size=100,
    )(contacts_without_regions)


@pytest.fixture(name="pixels_with_single_region")
def pixels_with_single_region_fixture(pixels_without_regions, single_region):
"""Pixels with single region"""
Expand Down
1 change: 1 addition & 0 deletions tests/query_engine/test_distance_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
[
"contacts_without_regions",
"pixels_without_regions",
"contacts_with_multiple_regions_overlapped",
],
)
def test_incompatible_input_rejected(genomic_data_fixture, request):
Expand Down

0 comments on commit 027e944

Please sign in to comment.