Multi region handling #11

Open · wants to merge 17 commits into master
Changes from all commits
266 changes: 241 additions & 25 deletions notebooks/query_engine_usage.ipynb

Large diffs are not rendered by default.

47 changes: 26 additions & 21 deletions spoc/contacts.py
@@ -49,6 +49,7 @@ def __init__(
label_sorted: bool = False,
binary_labels_equal: bool = False,
symmetry_flipped: bool = False,
label_values: Optional[List[str]] = None,
) -> None:
self.contains_metadata = (
"metadata_1" in contact_frame.columns
@@ -61,7 +62,6 @@ def __init__(
number_fragments=self.number_fragments,
contains_metadata=self.contains_metadata,
)
# TODO: make this work for duckdb pyrelation -> switch to mode
if isinstance(contact_frame, pd.DataFrame):
self.data_mode = DataMode.PANDAS
elif isinstance(contact_frame, dd.DataFrame):
@@ -75,6 +75,7 @@ def __init__(
self.label_sorted = label_sorted
self.binary_labels_equal = binary_labels_equal
self.symmetry_flipped = symmetry_flipped
self.label_values = label_values

@staticmethod
def from_uri(uri, mode=DataMode.PANDAS):
@@ -116,15 +117,19 @@ def get_label_values(self) -> List[str]:
# TODO: This could be put in global metadata of parquet file
if not self.contains_metadata:
raise ValueError("Contacts do not contain metadata!")
output = set()
for i in range(self.number_fragments):
if self.data_mode == DataMode.DASK:
output.update(self.data[f"metadata_{i+1}"].unique().compute())
elif self.data_mode == DataMode.PANDAS:
output.update(self.data[f"metadata_{i+1}"].unique())
else:
raise ValueError("Label values not supported for duckdb!")
return list(output)
if self.label_values is None:
output = set()
for i in range(self.number_fragments):
if self.data_mode == DataMode.DASK:
output.update(self.data[f"metadata_{i+1}"].unique().compute())
elif self.data_mode == DataMode.PANDAS:
output.update(self.data[f"metadata_{i+1}"].unique())
else:
raise ValueError("Label values not supported for duckdb!")
# cache the computed label values and return them
self.label_values = list(output)
return list(output)
return self.label_values

def get_chromosome_values(self) -> List[str]:
"""Returns all chromosome values"""
@@ -345,16 +350,16 @@ def sort_labels(self, contacts: Contacts) -> Contacts:
)
# determine which method to use for concatenation
if contacts.data_mode == DataMode.DASK:
# this is a bit of a hack to get the index sorted. Dask does not support index sorting
result = (
dd.concat(subsets).reset_index().sort_values("index").set_index("index")
)
result = dd.concat(subsets)
elif contacts.data_mode == DataMode.PANDAS:
result = pd.concat(subsets).sort_index()
result = pd.concat(subsets)
else:
raise ValueError("Sorting labels for duckdb relations is not implemented.")
return Contacts(
result, number_fragments=contacts.number_fragments, label_sorted=True
result,
number_fragments=contacts.number_fragments,
label_sorted=True,
label_values=label_values,
)

def _sort_chromosomes(self, df: DataFrame, number_fragments: int) -> DataFrame:
@@ -452,12 +457,9 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts:
subsets.append(subset)
# determine which method to use for concatenation
if contacts.data_mode == DataMode.DASK:
# this is a bit of a hack to get the index sorted. Dask does not support index sorting
result = (
dd.concat(subsets).reset_index().sort_values("index").set_index("index")
)
result = dd.concat(subsets)
elif contacts.data_mode == DataMode.PANDAS:
result = pd.concat(subsets).sort_index()
result = pd.concat(subsets)
else:
raise ValueError(
"Equate binary labels for duckdb relations is not implemented."
@@ -467,6 +469,7 @@ def equate_binary_labels(self, contacts: Contacts) -> Contacts:
number_fragments=contacts.number_fragments,
label_sorted=True,
binary_labels_equal=True,
label_values=label_values,
)

def subset_on_metadata(
@@ -506,6 +509,7 @@ def subset_on_metadata(
label_sorted=contacts.label_sorted,
binary_labels_equal=contacts.binary_labels_equal,
symmetry_flipped=contacts.symmetry_flipped,
label_values=label_values,
)

def flip_symmetric_contacts(
@@ -534,6 +538,7 @@ def flip_symmetric_contacts(
label_sorted=True,
binary_labels_equal=contacts.binary_labels_equal,
symmetry_flipped=True,
label_values=label_values,
)
result = self._flip_unlabelled_contacts(contacts.data)
if sort_chromosomes:
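
> **Note:** Each transformation in this file (`sort_labels`, `equate_binary_labels`, `subset_on_metadata`, `flip_symmetric_contacts`) now forwards `label_values` into the `Contacts` it returns, so a pipeline pays for the metadata scan at most once. The Dask branches also drop the old reset_index/sort_values workaround and return `dd.concat(subsets)` directly, meaning output row order is no longer tied to the input index. A hypothetical pipeline (the `preprocessor` handle and variable names are illustrative, not part of the diff):

```python
# Hypothetical usage; `contacts` is a Contacts instance with metadata columns.
sorted_contacts = preprocessor.sort_labels(contacts)          # scans and caches labels
equated = preprocessor.equate_binary_labels(sorted_contacts)  # reuses the cache
flipped = preprocessor.flip_symmetric_contacts(equated)       # reuses the cache
# Every intermediate Contacts inherits the same cached label values:
assert flipped.label_values == sorted_contacts.label_values
```
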
38 changes: 31 additions & 7 deletions spoc/io.py
@@ -34,7 +34,7 @@ class FileManager:
data_mode (DataMode, optional): Data mode. Defaults to DataMode.PANDAS.
"""

def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None:
def __init__(self, data_mode: DataMode = DataMode.PANDAS, **kwargs) -> None:
if data_mode == DataMode.DUCKDB:
self._parquet_reader_func = partial(
duckdb.read_parquet, connection=DUCKDB_CONNECTION
@@ -45,6 +45,17 @@ def __init__(self, data_mode: DataMode = DataMode.PANDAS) -> None:
self._parquet_reader_func = pd.read_parquet
else:
raise ValueError(f"Data mode {data_mode} not supported!")
# store data mode
self._data_mode = data_mode
# set duckdb parameters if they are there
if "duckdb_max_memory" in kwargs and data_mode == DataMode.DUCKDB:
DUCKDB_CONNECTION.execute(
f"PRAGMA memory_limit = '{kwargs['duckdb_max_memory']}'"
)
if "duckdb_max_threads" in kwargs and data_mode == DataMode.DUCKDB:
DUCKDB_CONNECTION.execute(
f"PRAGMA threads = {kwargs['duckdb_max_threads']}"
)

@staticmethod
def write_label_library(path: str, data: Dict[str, bool]) -> None:
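
> **Note:** The new keyword arguments let callers cap DuckDB's resources when the manager is created; they are silently ignored in other data modes. A usage sketch (the import path of `DataMode` and the limit values are assumptions):

```python
from spoc.io import FileManager
from spoc.models.dataframe_models import DataMode  # assumed import path

# Applies PRAGMA memory_limit and PRAGMA threads to the shared DuckDB connection.
file_manager = FileManager(
    data_mode=DataMode.DUCKDB,
    duckdb_max_memory="4GB",
    duckdb_max_threads=4,
)
```

Because both pragmas run against the module-level `DUCKDB_CONNECTION`, the limits apply to every `FileManager` sharing that connection.
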
@@ -254,9 +265,19 @@ def load_pixels(
)
# rewrite path to contain parent folder
full_pixel_path = Path(path) / pixel_path
df = self._parquet_reader_func(full_pixel_path)
if self._data_mode == DataMode.DUCKDB:
df = self._parquet_reader_func(self._get_duckdb_path(full_pixel_path))
else:
df = self._parquet_reader_func(full_pixel_path)
return Pixels(df, **matched_parameters.dict())

def _get_duckdb_path(self, path: Path) -> str:
"""Constructs duckdb path string to handle cases
where parquet files are stored in a directory with multiple files."""
if Path(path).is_dir():
return f"{path}/*.parquet"
return str(path)

def load_contacts(
self, path: str, global_parameters: Optional[ContactsParameters] = None
) -> Contacts:
@@ -287,7 +308,10 @@ def load_contacts(
)
# rewrite path to contain parent folder
full_contacts_path = Path(path) / contacts_path
df = self._parquet_reader_func(full_contacts_path)
if self._data_mode == DataMode.DUCKDB:
df = self._parquet_reader_func(self._get_duckdb_path(full_contacts_path))
else:
df = self._parquet_reader_func(str(full_contacts_path))
return Contacts(df, **matched_parameters.dict())

@staticmethod
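
> **Note:** The glob rewrite in `_get_duckdb_path` is what lets DuckDB read datasets written as a directory of parquet files (as Dask does) rather than a single file. A runnable sketch of the same logic in isolation:

```python
import tempfile
from pathlib import Path


def get_duckdb_path(path: Path) -> str:
    """Mirrors FileManager._get_duckdb_path from this diff."""
    if Path(path).is_dir():
        return f"{path}/*.parquet"
    return str(path)


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)  # stands in for a partitioned parquet dataset
    assert get_duckdb_path(directory) == f"{directory}/*.parquet"
    # A plain file path is passed through unchanged.
    assert get_duckdb_path(directory / "a.parquet") == str(directory / "a.parquet")
```
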
@@ -305,7 +329,7 @@ def _get_object_hash_path(path: str, data_object: Union[Pixels, Contacts]) -> str:
)
return md5(hash_string.encode(encoding="utf-8")).hexdigest() + ".parquet"

def write_pixels(self, path: str, pixels: Pixels) -> None:
def write_pixels(self, path: str, pixels: Pixels, **kwargs) -> None:
"""Write pixels

Args:
@@ -331,13 +355,13 @@ def write_pixels(self, path: str, pixels: Pixels) -> None:
raise ValueError(
"Writing pixels only suppported for pixels hodling dataframes!"
)
pixels.data.to_parquet(write_path, row_group_size=1024 * 1024)
pixels.data.to_parquet(str(write_path), **kwargs)
# write metadata
current_metadata[write_path.name] = pixels.get_global_parameters().dict()
with open(metadata_path, "w", encoding="UTF-8") as f:
json.dump(current_metadata, f)

def write_contacts(self, path: str, contacts: Contacts) -> None:
def write_contacts(self, path: str, contacts: Contacts, **kwargs) -> None:
"""Write contacts"""
# check whether path exists
metadata_path = Path(path) / "metadata.json"
@@ -354,7 +378,7 @@ def write_contacts(self, path: str, contacts: Contacts) -> None:
raise ValueError(
"Writing contacts only suppported for contacts hodling dataframes!"
)
contacts.data.to_parquet(write_path, row_group_size=1024 * 1024)
contacts.data.to_parquet(str(write_path), **kwargs)
# write metadata
current_metadata[write_path.name] = contacts.get_global_parameters().dict()
with open(metadata_path, "w") as f:
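
> **Note:** With the hard-coded `row_group_size=1024 * 1024` removed, `write_pixels` and `write_contacts` now pass any extra keyword arguments straight to `to_parquet`, so callers that depended on the old row-group size must request it explicitly. A hypothetical call (the path and `pixels` object are assumed to exist):

```python
# Restore the previous row-group size and choose a compression codec explicitly.
file_manager.write_pixels(
    "pixels_store",  # directory containing metadata.json and the parquet files
    pixels,
    row_group_size=1024 * 1024,
    compression="snappy",
)
```
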
23 changes: 11 additions & 12 deletions spoc/models/dataframe_models.py
@@ -68,7 +68,7 @@ def get_schema(self) -> pa.DataFrameSchema:
def get_binsize(self) -> Optional[int]:
"""Returns the binsize of the genomic data"""

def get_region_number(self) -> Optional[int]:
def get_region_number(self) -> Optional[Union[int, List[int]]]:
"""Returns the number of regions in the genomic data
if present."""

@@ -89,7 +89,7 @@ def __init__(
position_fields: Dict[int, List[str]],
contact_order: int,
binsize: Optional[int] = None,
region_number: Optional[int] = None,
region_number: Optional[Union[int, List[int]]] = None,
half_window_size: Optional[int] = None,
) -> None:
self._columns = columns
@@ -127,7 +127,7 @@ def get_binsize(self) -> Optional[int]:
"""Returns the binsize of the genomic data"""
return self._binsize

def get_region_number(self) -> Optional[int]:
def get_region_number(self) -> Optional[Union[int, List[int]]]:
"""Returns the number of regions in the genomic data
if present."""
return self._region_number
@@ -215,8 +215,8 @@ def validate_header(self, data_frame: DataFrame) -> None:
Args:
data_frame (DataFrame): The DataFrame to validate.
"""
for column in data_frame.columns:
if column not in self._schema.columns:
for column in self._schema.columns:
if column not in data_frame.columns:
raise pa.errors.SchemaError(
self._schema, data_frame, "Header is invalid!"
)
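
> **Note:** The inverted loop changes what `validate_header` means: instead of rejecting any frame column the schema does not know about, it now requires every schema column to be present and tolerates extras. A toy reproduction of the new check (the schema here is a stand-in, not spoc's real one):

```python
import pandas as pd
import pandera as pa

schema = pa.DataFrameSchema({"chrom": pa.Column(str), "start_1": pa.Column(int)})


def validate_header(data_frame: pd.DataFrame) -> None:
    # New behaviour: every schema column must exist in the frame.
    for column in schema.columns:
        if column not in data_frame.columns:
            raise pa.errors.SchemaError(schema, data_frame, "Header is invalid!")


validate_header(pd.DataFrame(columns=["chrom", "start_1", "extra"]))  # extras pass
try:
    validate_header(pd.DataFrame(columns=["chrom"]))  # missing start_1 raises
except pa.errors.SchemaError:
    pass
```
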
@@ -258,7 +258,7 @@ def get_binsize(self) -> Optional[int]:
"""Returns the binsize of the genomic data"""
return None

def get_region_number(self) -> Optional[int]:
def get_region_number(self) -> Optional[Union[int, List[int]]]:
"""Returns the number of regions in the genomic data
if present."""
return None
@@ -350,17 +350,16 @@ def get_position_fields(self) -> Dict[int, List[str]]:
return {
i: ["chrom", f"start_{i}"] for i in range(1, self._number_fragments + 1)
}
else:
return {
i: [f"chrom_{i}", f"start_{i}"]
for i in range(1, self._number_fragments + 1)
}
return {
i: [f"chrom_{i}", f"start_{i}"]
for i in range(1, self._number_fragments + 1)
}

def get_binsize(self) -> Optional[int]:
"""Returns the binsize of the genomic data"""
return self._binsize

def get_region_number(self) -> Optional[int]:
def get_region_number(self) -> Optional[Union[int, List[int]]]:
"""Returns the number of regions in the genomic data
if present."""
return None
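
> **Note:** Widening the `get_region_number` return type to `Optional[Union[int, List[int]]]` is the schema-level hook for multi-region handling: a schema can now report either a single region count or one count per region set. A hypothetical consumer that handles both shapes (this helper is illustrative, not part of the diff):

```python
from typing import List, Optional, Union


def total_regions(region_number: Optional[Union[int, List[int]]]) -> int:
    """Sum region counts whether a schema reports one value or one per region set."""
    if region_number is None:
        return 0  # schema carries no region information
    if isinstance(region_number, int):
        return region_number
    return sum(region_number)


assert total_regions(None) == 0
assert total_regions(3) == 3
assert total_regions([2, 4]) == 6
```
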