diff --git a/setup.cfg b/setup.cfg
index ff43492..ba57239 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,7 +26,9 @@ packages = find:
 python_requires = >= 3.8
 include_package_data = True
 install_requires =
+    numpy
     pandas
+    typing-extensions
 
 [options.extras_require]
diff --git a/starfile/functions.py b/starfile/functions.py
index 6b6dea0..e952d29 100644
--- a/starfile/functions.py
+++ b/starfile/functions.py
@@ -8,6 +8,7 @@
 
 from .parser import StarParser
 from .writer import StarWriter
+from .typing import DataBlock
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -21,22 +22,26 @@ def read(filename: PathLike, read_n_blocks: int = None, always_dict: bool = False):
     default behaviour in the case of only one data block being present in the STAR file is to
     return only a dataframe, this can be changed by setting 'always_dict=True'
     """
-    star = StarParser(filename, read_n_blocks=read_n_blocks)
-    if len(star.dataframes) == 1 and always_dict is False:
-        return star.first_dataframe
+    parser = StarParser(filename, n_blocks_to_read=read_n_blocks)
+    if len(parser.data_blocks) == 1 and always_dict is False:
+        return list(parser.data_blocks.values())[0]
     else:
-        return star.dataframes
-
-
-def write(data: Union[pd.DataFrame, Dict[str, pd.DataFrame], List[pd.DataFrame]],
-          filename: PathLike,
-          float_format: str = '%.6f', sep: str = '\t', na_rep: str = '',
-          overwrite: bool = False, force_loop: bool = True):
-    """
-    Write dataframes from pandas dataframe(s) to a star file
-
-    data can be a single dataframe, a list of dataframes or a dict of dataframes
-    float format defaults to 6 digits after the decimal point
-    """
-    StarWriter(data, filename=filename, float_format=float_format, overwrite=overwrite,
-               na_rep=na_rep, sep=sep, force_loop=force_loop)
+        return parser.data_blocks
+
+
+def write(
+    data: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
+    filename: PathLike,
+    float_format: str = '%.6f',
+    sep: str = '\t',
+    na_rep: str = '',
+    **kwargs,
+):
+    """Write data blocks as STAR files."""
+    StarWriter(
+        data,
+        filename=filename,
+        float_format=float_format,
+        na_rep=na_rep,
+        separator=sep
+    )
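For orientation, this is how the reworked functional interface behaves after the change above — a minimal sketch; the file paths are placeholders, not files shipped with the repository:

```python
import starfile

# A file with one data block returns that block directly: a DataFrame for a
# loop block, a dict of scalars for a simple block. Multi-block files return
# a dict keyed by block name.
block = starfile.read('particles.star')

# always_dict=True forces the dict return type even for single-block files.
blocks = starfile.read('particles.star', always_dict=True)

# write() accepts a single block, a dict of named blocks, or a list of blocks.
starfile.write({'particles': block}, 'output.star')
```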
diff --git a/starfile/parser.py b/starfile/parser.py
index e1a3c65..dc98a39 100644
--- a/starfile/parser.py
+++ b/starfile/parser.py
@@ -1,217 +1,142 @@
 from __future__ import annotations
 
-from collections import OrderedDict
+from collections import deque
 from io import StringIO
+from linecache import getline
 
+import numpy as np
 import pandas as pd
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Union, Optional
+from typing import TYPE_CHECKING, Union, Optional, Dict, Tuple
 
-from .utils import TextBuffer, TextCrawler
+from starfile.typing import DataBlock
 
 if TYPE_CHECKING:
     from os import PathLike
 
 
 class StarParser:
-    def __init__(self, filename: PathLike, read_n_blocks: Optional[int] = None):
+    filename: Path
+    n_lines_in_file: int
+    n_blocks_to_read: int
+    current_line_number: int
+    data_blocks: Dict[str, DataBlock]
+
+    def __init__(self, filename: PathLike, n_blocks_to_read: Optional[int] = None):
         # set filename, with path checking
+        filename = Path(filename)
+        if not filename.exists():
+            raise FileNotFoundError(filename)
         self.filename = filename
 
-        # initialise attributes for parsing
-        self.text_buffer = TextBuffer()
-        self.crawler = TextCrawler(self.filename)
-        self.read_n_blocks = read_n_blocks
-        self._dataframes = OrderedDict()
-        self._current_dataframe_index = 0
-        self._initialise_n_lines()
+        # setup for parsing
+        self.data_blocks = {}
+        self.n_lines_in_file = count_lines(self.filename)
+        self.n_blocks_to_read = n_blocks_to_read
 
         # parse file
+        self.current_line_number = 0
         self.parse_file()
 
-    def parse_file(self):
-        while self.crawler.current_line_number <= self.n_lines:
-            if len(self.dataframes) == self.read_n_blocks:
-                break
-
-            elif self.crawler.current_line.startswith('data_'):
-                self._parse_data_block()
-
-            if not self.crawler.current_line.startswith('data_'):
-                self.crawler.increment_line_number()
-
-        self.dataframes_to_numeric()
-        return
-
-    def _parse_data_block(self):
-        self.current_block_name = self._block_name_from_current_line()
-
-        while self.crawler.current_line_number <= self.n_lines:
-            self.crawler.increment_line_number()
-            line = self.crawler.current_line
-
-            if line.startswith('loop_'):
-                self._parse_loop_block()
-                return
-
-            elif line.startswith(
-                    'data_') or self.crawler.current_line_number == self.n_lines:
-                self._parse_simple_block_from_buffer()
-                return
-
-            self.text_buffer.add_line(line)
-        return
-
-    def _parse_simple_block_from_buffer(self):
-        data = self._clean_simple_block_in_buffer()
-
-        df = self._cleaned_simple_block_to_dataframe(data)
-        df.name = self._current_data_block_name
-        self._add_dataframe(df)
-
-        self.text_buffer.clear()
-
-    def _parse_loop_block(self):
-        self.crawler.increment_line_number()
-        header = self._parse_loop_header()
-        df = self._parse_loop_data()
-        if df is None:
-            df = pd.DataFrame({h: None for h in header}, index=[0])
-        df.columns = header
-        df.name = self._current_data_block_name
-        self._add_dataframe(df)
-        return
-
-    @property
-    def filename(self):
-        return self._filename
-
-    @filename.setter
-    def filename(self, filename: Union[str, Path]):
-        filename = Path(filename)
-        if filename.exists():
-            self._filename = filename
-        else:
-            raise FileNotFoundError
-
-    @property
-    def n_lines(self):
-        return self._n_lines
-
-    def _initialise_n_lines(self):
-        self._n_lines = self.crawler.count_lines()
-
     @property
-    def dataframes(self):
-        return self._dataframes
+    def current_line(self) -> str:
+        return getline(str(self.filename), self.current_line_number).strip()
 
-    def _add_dataframe(self, df: pd.DataFrame):
-        key = self._get_dataframe_key(df)
-        self.dataframes[key] = df
-        self._increment_dataframe_index()
-
-    @property
-    def current_block_name(self):
-        return self._current_data_block_name
-
-    @current_block_name.setter
-    def current_block_name(self, name: str):
-        self._current_data_block_name = name
-
-    @property
-    def current_dataframe_index(self):
-        return self._current_dataframe_index
-
-    def _increment_dataframe_index(self):
-        self._current_dataframe_index += 1
-
-    def _get_dataframe_key(self, df):
-        name = df.name
-
-        if name == '' or isinstance(name, int) or name in self.dataframes.keys():
-            return self._current_dataframe_index
-        else:
-            return df.name
-
-    def _clean_simple_block_in_buffer(self):
-        clean_datablock = {}
-
-        for line in self.text_buffer.buffer:
-            if line == '' or line.startswith('#'):
-                continue
-
-            heading_name = self.heading_from_line(line)
-            value = line.split()[1]
-            clean_datablock[heading_name] = value
-
-        return clean_datablock
-
-    @staticmethod
-    def _cleaned_simple_block_to_dataframe(data: dict):
-        return pd.DataFrame(data, columns=data.keys(), index=[0])
-
-    def _parse_loop_header(self) -> List[str]:
-        self.text_buffer.clear()
-
-        while self.crawler.current_line.startswith('_'):
-            heading = self.heading_from_line(self.crawler.current_line)
-            self.text_buffer.add_line(heading)
-            self.crawler.increment_line_number()
-
-        return self.text_buffer.buffer
-
-    def _parse_loop_data(self) -> Union[pd.DataFrame, None]:
-        self.text_buffer.clear()
-
-        while self.crawler.current_line_number <= self.n_lines:
-            current_line = self.crawler.current_line
-            if current_line.startswith('data_'):
+    def parse_file(self):
+        while self.current_line_number <= self.n_lines_in_file:
+            if len(self.data_blocks) == self.n_blocks_to_read:
+                break
+            elif self.current_line.startswith('data_'):
+                block_name, block = self._parse_data_block()
+                self.data_blocks[block_name] = block
+            else:
+                self.current_line_number += 1
+
+    def _parse_data_block(self) -> Tuple[str, DataBlock]:
+        # current line starts with 'data_foo'
+        block_name = self.current_line[5:]  # 'data_foo' -> 'foo'
+        self.current_line_number += 1
+
+        # iterate over the file
+        while self.current_line_number <= self.n_lines_in_file:
+            self.current_line_number += 1
+            if self.current_line.startswith('loop_'):
+                return block_name, self._parse_loop_block()
+            elif self.current_line.startswith('_'):  # line is simple block
+                return block_name, self._parse_simple_block()
+
+    def _parse_simple_block(self) -> Dict[str, Union[str, int, float]]:
+        block = {}
+        while self.current_line_number <= self.n_lines_in_file:
+            if self.current_line.startswith('data'):
                 break
-            self.text_buffer.add_line(current_line)
-            self.crawler.increment_line_number()
-
-        # check whether the buffer is empty
-        if self.text_buffer.is_empty:
-            return None
-
-        df = pd.read_csv(
-            StringIO(self.text_buffer.as_str()),
-            delim_whitespace=True,
-            header=None,
-            comment='#'
-        )
+            elif self.current_line.startswith('_'):  # '_foo bar'
+                k, v = self.current_line.split()
+                block[k[1:]] = numericise(v)
+            self.current_line_number += 1
+        return block
+
+    def _parse_loop_block(self) -> pd.DataFrame:
+        # parse loop header
+        loop_column_names = deque()
+        self.current_line_number += 1
+
+        while self.current_line.startswith('_'):
+            column_name = self.current_line.split()[0][1:]
+            loop_column_names.append(column_name)
+            self.current_line_number += 1
+
+        # now parse the loop block data
+        loop_data = deque()
+        while self.current_line_number <= self.n_lines_in_file:
+            if self.current_line.startswith('data_'):
+                break
+            loop_data.append(self.current_line)
+            self.current_line_number += 1
+        loop_data = '\n'.join(loop_data)
+        if not loop_data.endswith('\n'):
+            loop_data += '\n'
+
+        # put string data into a dataframe
+        if loop_data == '\n':
+            n_cols = len(loop_column_names)
+            df = pd.DataFrame(np.zeros(shape=(0, n_cols)))
+        else:
+            df = pd.read_csv(
+                StringIO(loop_data),
+                delim_whitespace=True,
+                header=None,
+                comment='#'
+            )
+        df = df.apply(pd.to_numeric, errors='ignore')
+        df.columns = loop_column_names
         return df
 
-    def dataframes_to_numeric(self):
-        """
-        Converts strings in dataframes into numerical values where possible
-        applying pd.to_numeric causes loss of 'name' attribute of DataFrame,
-        need to extract name and reapply inline
-        """
-        for key, df in self.dataframes.items():
-            name = getattr(df, 'name', None)
-            self.dataframes[key] = df.apply(pd.to_numeric, errors='ignore')
-            if name is not None:
-                self.dataframes[key].name = name
+
+def count_lines(file: Path) -> int:
+    with open(file, 'rb') as f:
+        return sum(1 for _ in f)
 
-    @staticmethod
-    def _block_name_from_line(line: str):
-        return line[5:]
 
-    def _block_name_from_current_line(self):
-        return self._block_name_from_line(self.crawler.current_line)
+def block_name_from_line(line: str) -> str:
+    """'data_general' -> 'general'"""
+    return line[5:]
 
-    @staticmethod
-    def heading_from_line(line: str):
-        return line.split()[0][1:]
 
-    @property
-    def first_dataframe(self):
-        return self.dataframe_at_index(0)
+def heading_from_line(line: str) -> str:
+    """'_rlnSpectralIndex #1' -> 'rlnSpectralIndex'."""
+    return line.split()[0][1:]
 
-    def dataframe_at_index(self, idx: int):
-        return self.dataframes_as_list()[idx]
-
-    def dataframes_as_list(self):
-        return list(self.dataframes.values())
+
+def numericise(value: str) -> Union[str, int, float]:
+    try:
+        # Try to convert the string value to an integer
+        value = int(value)
+    except ValueError:
+        try:
+            # If it's not an integer, try to convert it to a float
+            value = float(value)
+        except ValueError:
+            # If it's not a float either, leave it as a string
+            pass
+    return value
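The module-level `numericise` helper above decides the type of every value in a simple block. Its behaviour, for reference:

```python
from starfile.parser import numericise

numericise('42')            # -> 42 (int)
numericise('-51.52270')     # -> -51.5227 (float)
numericise('0001_sum.mrc')  # -> '0001_sum.mrc' (left as a string)
```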
diff --git a/starfile/typing.py b/starfile/typing.py
new file mode 100644
index 0000000..c039b57
--- /dev/null
+++ b/starfile/typing.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from typing import Union, Dict
+from typing_extensions import TypeAlias
+
+import pandas as pd
+
+DataBlock: TypeAlias = Union[
+    pd.DataFrame,
+    Dict[str, Union[str, int, float]]
+]
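`DataBlock` makes the parser/writer contract explicit: a block is either a loop block (`pd.DataFrame`) or a simple block (a flat dict of scalars). A sketch of how downstream code can branch on it — `describe` is a hypothetical helper, not part of the package:

```python
import pandas as pd

from starfile.typing import DataBlock


def describe(block: DataBlock) -> str:
    """Summarise a data block (hypothetical example)."""
    if isinstance(block, pd.DataFrame):  # loop block
        n_rows, n_cols = block.shape
        return f'loop block: {n_rows} rows x {n_cols} columns'
    return f'simple block: {len(block)} key-value pairs'  # dict of scalars
```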
diff --git a/starfile/writer.py b/starfile/writer.py
index 8a04320..754b5ae 100644
--- a/starfile/writer.py
+++ b/starfile/writer.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from collections import OrderedDict
 from datetime import datetime
 from pathlib import Path
 from typing import TYPE_CHECKING, Union, Dict, List
@@ -9,6 +8,7 @@
 from pkg_resources import get_distribution
 
 from .utils import TextBuffer
+from .typing import DataBlock
 
 if TYPE_CHECKING:
     from os import PathLike
@@ -17,135 +17,155 @@
 
 
 class StarWriter:
-    def __init__(self, dataframes: Union[pd.DataFrame, Dict[pd.DataFrame], List[pd.DataFrame]],
-                 filename: PathLike, overwrite: bool = False, float_format: str = '%.6f',
-                 sep: str = '\t', na_rep: str = '', force_loop: bool = False):
-        self.overwrite = overwrite
-        self.filename = filename
-        self.dataframes = dataframes
+    def __init__(
+        self,
+        data_blocks: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
+        filename: PathLike,
+        float_format: str = '%.6f',
+        separator: str = '\t',
+        na_rep: str = '',
+    ):
+        # coerce data
+        self.data_blocks = self.coerce_data_blocks(data_blocks)
+
+        # write
+        self.filename = Path(filename)
         self.float_format = float_format
-        self.sep = sep
+        self.sep = separator
         self.na_rep = na_rep
-        self.force_loop = force_loop
         self.buffer = TextBuffer()
 
-        self.write_star_file()
-
-    @property
-    def dataframes(self):
-        """
-        Ordered dictionary of pandas dataframes
-        df.name defines the data block name
-        """
-        return self._dataframes
-
-    @dataframes.setter
-    def dataframes(self, dataframes: Union[pd.DataFrame, Dict[pd.DataFrame], List[pd.DataFrame]]):
-        if isinstance(dataframes, pd.DataFrame):
-            self._dataframes = self.coerce_dataframe(dataframes)
-        elif isinstance(dataframes, dict):
-            self._dataframes = self.coerce_dict(dataframes)
-        elif isinstance(dataframes, list):
-            self._dataframes = self.coerce_list(dataframes)
+        self.backup_if_file_exists()
+        self.write()
+
+    def coerce_data_blocks(
+        self,
+        data_blocks: Union[DataBlock, List[DataBlock], Dict[str, DataBlock]]
+    ) -> Dict[str, DataBlock]:
+        if isinstance(data_blocks, pd.DataFrame):
+            return coerce_dataframe(data_blocks)
+        elif isinstance(data_blocks, dict):
+            return coerce_dict(data_blocks)
+        elif isinstance(data_blocks, list):
+            return coerce_list(data_blocks)
         else:
-            raise ValueError(f'Expected a DataFrame, Dict or List object, got {type(dataframes)}')
-
-    @staticmethod
-    def coerce_dataframe(df: pd.DataFrame):
-        name = getattr(df, 'name', '')
-        if name != '':
-            name = 0
-        return {name: df}
-
-    @staticmethod
-    def coerce_dict(dfs: Dict[str, pd.DataFrame]):
-        """
-        This method ensures that dataframe names are updated based on dict keys
-        """
-        for key, df in dfs.items():
-            df.name = str(key)
-        return dfs
-
-    def coerce_list(self, dfs: List[pd.DataFrame]):
-        """
-        This method coerces a list of DataFrames into a dict
-        """
-        return self.coerce_dict(OrderedDict([(idx, df) for idx, df in enumerate(dfs)]))
-
-    @property
-    def filename(self):
-        return self._filename
-
-    @filename.setter
-    def filename(self, filename: Union[Path, str]):
-        self._filename = Path(filename)
-        if not self.file_writable:
-            raise FileExistsError('to overwrite an existing file set overwrite=True')
-
-    @property
-    def file_exists(self):
-        return self.filename.exists()
-
-    @property
-    def file_writable(self):
-        if self.overwrite or (not self.file_exists):
-            return True
-        else:
-            return False
-
-    def write_package_info(self):
-        date = datetime.now().strftime('%d/%m/%Y')
-        time = datetime.now().strftime('%H:%M:%S')
-        line = f'Created by the starfile Python package (version {__version__}) at {time} on' \
-               f' {date}'
-        self.buffer.add_comment(line)
-        self.buffer.add_blank_lines(1)
-        self.buffer.write_as_new_file_and_clear(self.filename)
-
-    def write_star_file(self, filename: str = None):
-        self.write_package_info()
-
-        for _, df in self.dataframes.items():
-            self.write_block(df)
-
-        self.buffer.add_blank_line()
-        self.buffer.append_to_file_and_clear(self.filename)
-
-    def write_loopheader(self, df: pd.DataFrame):
-        self.buffer.add_line('loop_')
-        lines = [f'_{column_name} #{idx}' for idx, column_name in enumerate(df.columns, 1)]
-
-        for line in lines:
-            self.buffer.add_line(line)
-
-        self.buffer.append_to_file_and_clear(self.filename)
-
-    @staticmethod
-    def get_block_name(df: pd.DataFrame):
-        return 'data_' + getattr(df, 'name', '')
-
-    def add_block_name_to_buffer(self, df: pd.DataFrame):
-        self.buffer.add_line(self.get_block_name(df))
-        self.buffer.add_blank_lines(1)
-        self.buffer.append_to_file_and_clear(self.filename)
-
-    def write_block(self, df: pd.DataFrame):
-        self.add_block_name_to_buffer(df)
-
-        if (df.shape[0] == 1) and not self.force_loop:
-            self._write_simple_block(df)
-        elif (df.shape[0] > 1) or self.force_loop:
-            self._write_loop_block(df)
-        self.buffer.add_blank_lines(2)
-        self.buffer.append_to_file_and_clear(self.filename)
-
-    def _write_simple_block(self, df: pd.DataFrame):
-        lines = [f'_{column_name}\t\t\t{df[column_name].iloc[0]}'
-                 for column_name in df.columns]
-        for line in lines:
-            self.buffer.add_line(line)
-        self.buffer.append_to_file_and_clear(self.filename)
-
-    def _write_loop_block(self, df: pd.DataFrame):
-        self.write_loopheader(df)
-        df.to_csv(self.filename, mode='a', sep=self.sep, header=False, index=False,
-                  float_format=self.float_format, na_rep=self.na_rep)
+            raise ValueError(
+                f'Expected {pd.DataFrame}, {Dict[str, pd.DataFrame]} '
+                f'or {List[pd.DataFrame]}, got {type(data_blocks)}'
+            )
+
+    def write(self):
+        write_package_info(self.filename)
+        write_blank_lines(self.filename, n=2)
+        self.write_data_blocks()
+
+    def write_data_blocks(self):
+        for block_name, block in self.data_blocks.items():
+            if isinstance(block, dict):
+                write_simple_block(
+                    file=self.filename,
+                    block_name=block_name,
+                    data=block
+                )
+            elif isinstance(block, pd.DataFrame):
+                write_loop_block(
+                    file=self.filename,
+                    block_name=block_name,
+                    df=block,
+                    float_format=self.float_format,
+                    separator=self.sep,
+                    na_rep=self.na_rep,
+                )
+
+    def backup_if_file_exists(self):
+        if self.filename.exists():
+            new_name = self.filename.name + '~'
+            backup_path = self.filename.resolve().parent / new_name
+            self.filename.rename(backup_path)
+
+
+def coerce_dataframe(df: pd.DataFrame) -> Dict[str, DataBlock]:
+    return {'': df}
+
+
+def coerce_dict(
+    data_blocks: Union[DataBlock, Dict[str, DataBlock]]
+) -> Dict[str, DataBlock]:
+    """Coerce dict into dict of data blocks."""
+    # check if data is already Dict[str, DataBlock]
+    for k, v in data_blocks.items():
+        if isinstance(v, (dict, pd.DataFrame)):
+            return data_blocks
+    # coerce if not
+    return {'': data_blocks}
+
+
+def coerce_list(data_blocks: List[DataBlock]) -> Dict[str, DataBlock]:
+    """Coerce a list of data blocks into a dict keyed by index."""
+    return {f'{idx}': df for idx, df in enumerate(data_blocks)}
+
+
+def write_blank_lines(file: Path, n: int):
+    with open(file, mode='a') as f:
+        f.write('\n' * n)
+
+
+def write_package_info(file: Path):
+    date = datetime.now().strftime('%d/%m/%Y')
+    time = datetime.now().strftime('%H:%M:%S')
+    line = f'# Created by the starfile Python package (version {__version__}) at {time} on {date}'
+    with open(file, mode='w+') as f:
+        f.write(f'{line}\n')
+
+
+def write_simple_block(
+    file: Path,
+    block_name: str,
+    data: Dict[str, Union[str, int, float]]
+):
+    formatted_lines = '\n'.join(
+        [
+            f'_{k}\t\t\t{v}'
+            for k, v
+            in data.items()
+        ]
+    )
+    with open(file, mode='a') as f:
+        f.write(f'data_{block_name}\n\n')
+        f.write(formatted_lines)
+        f.write('\n\n\n')
+
+
+def write_loop_block(
+    file: Path,
+    block_name: str,
+    df: pd.DataFrame,
+    float_format: str = '%.6f',
+    separator: str = '\t',
+    na_rep: str = '',
+):
+    # write header
+    header_lines = [
+        f'_{column_name} #{idx}'
+        for idx, column_name
+        in enumerate(df.columns, 1)
+    ]
+    with open(file, mode='a') as f:
+        f.write(f'data_{block_name}\n\n')
+        f.write('loop_\n')
+        f.write('\n'.join(header_lines))
+        f.write('\n')
+
+    # write data
+    df.to_csv(
+        path_or_buf=file,
+        mode='a',
+        sep=separator,
+        header=False,
+        index=False,
+        float_format=float_format,
+        na_rep=na_rep,
+    )
+    write_blank_lines(file, n=2)
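Putting the writer together: a dict value is written as a simple block (`_key<tabs>value` lines) and a DataFrame as a loop block (`loop_` header, `_name #N` column labels, then delimited rows). A small sketch with illustrative block names and values:

```python
import pandas as pd

import starfile

blocks = {
    'general': {'rlnFinalResolution': 3.2},  # written as a simple block
    'fsc': pd.DataFrame({                    # written as a loop block
        'rlnSpectralIndex': [0, 1],
        'rlnResolution': [0.0, 0.008403],
    }),
}
starfile.write(blocks, 'postprocess_like.star')
```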
diff --git a/tests/constants.py b/tests/constants.py
index e75ca1c..3f52707 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -27,10 +27,3 @@
     'Price': [22000, 25000, 27000, 35000]
 }
 test_df = pd.DataFrame.from_dict(cars)
-
-
-# Attributes of certain files
-loop_simple_columns = ['rlnCoordinateX', 'rlnCoordinateY', 'rlnCoordinateZ',
-                       'rlnMicrographName', 'rlnMagnification', 'rlnDetectorPixelSize',
-                       'rlnCtfMaxResolution', 'rlnImageName', 'rlnCtfImage', 'rlnAngleRot',
-                       'rlnAngleTilt', 'rlnAnglePsi']
diff --git a/tests/data/single_line_end_of_multiblock.star b/tests/data/single_line_end_of_multiblock.star
index a3511f0..409b9fe 100644
--- a/tests/data/single_line_end_of_multiblock.star
+++ b/tests/data/single_line_end_of_multiblock.star
@@ -1,6 +1,6 @@
 # multi-current_line, python engine
 
-data_block_3
+data_block_1
 
 loop_
 _rlnImageName #1
@@ -13,7 +13,7 @@ _rlnCoordinateZ #5
 
 # single-current_line, c engine
 
-data_block_3
+data_block_2
 
 loop_
 _rlnImageName #1
diff --git a/tests/data/single_line_middle_of_multiblock.star b/tests/data/single_line_middle_of_multiblock.star
index 02ed528..783b7fd 100644
--- a/tests/data/single_line_middle_of_multiblock.star
+++ b/tests/data/single_line_middle_of_multiblock.star
@@ -1,6 +1,6 @@
-# single-current_line, python engine
+# single-line, python engine
 
-data_block_3
+data_block_1
 
 loop_
 _rlnImageName #1
@@ -10,9 +10,9 @@ _rlnCoordinateY #4
 _rlnCoordinateZ #5
 000001@0001_sum_particles.mrcs 0001_sum.mrc 587.000000 1268.000000 -51.52270
 
-# single-current_line, c engine
+# single-line, c engine
 
-data_block_3
+data_block_2
 
 loop_
 _rlnImageName #1
diff --git a/tests/test_functional_interface.py b/tests/test_functional_interface.py
index a5d676c..8739cce 100644
--- a/tests/test_functional_interface.py
+++ b/tests/test_functional_interface.py
@@ -29,29 +29,14 @@ def test_write():
     assert output_file.exists()
 
 
-def test_write_fails_to_overwrite_without_flag():
-    output_file = test_data_directory / 'test_overwrite_flag.star'
-    starfile.write(test_df, output_file, overwrite=True)
-
-    assert output_file.exists()
-    with pytest.raises(FileExistsError):
-        starfile.write(test_df, output_file, overwrite=False)
-        starfile.new(test_df, output_file)
-
-
-def test_write_overwrites_with_flag():
-    output_file = test_data_directory / 'test_overwrite_flag.star'
-    starfile.write(test_df, output_file, overwrite=True)
-
+def test_write_overwrites_with_backup():
+    output_file = test_data_directory / 'test_overwrite_backup.star'
+    starfile.write(test_df, output_file)
     assert output_file.exists()
-    starfile.write(test_df, output_file, overwrite=True)
-
-
-def test_write_with_float_format():
-    output_file = test_data_directory / 'test_write_with_float_format.star'
-    test_df['float_col'] = 1.23456789
-    starfile.write(test_df, output_file, float_format='%.3f', overwrite=True)
-    assert output_file.exists()
+
+    starfile.write(test_df, output_file)
+    backup = test_data_directory / 'test_overwrite_backup.star~'
+    assert backup.exists()
 
 
 def test_read_non_existent_file():
diff --git a/tests/test_parsing.py b/tests/test_parsing.py
index c5f0232..dd6cad8 100644
--- a/tests/test_parsing.py
+++ b/tests/test_parsing.py
@@ -17,7 +17,6 @@
     single_line_middle_of_multiblock,
     single_line_end_of_multiblock,
     non_existant_file,
-    loop_simple_columns,
     two_single_line_loop_blocks,
     two_basic_blocks,
     empty_loop,
@@ -42,22 +41,34 @@ def test_read_loop_block():
     """
    Check that loop block is parsed correctly, data has the correct shape
     """
-    s = StarParser(loop_simple)
-
-    # Check the output
-    for idx, key in enumerate(s.dataframes.keys()):
-        # Check that only one object is present
-        assert idx < 1
-
-        # get dataframe
-        df = s.dataframes[key]
-        assert isinstance(df, pd.DataFrame)
-
-        # Check shape of dataframe
-        assert df.shape == (16, 12)
-
-        # check columns
-        assert all(df.columns == loop_simple_columns)
+    parser = StarParser(loop_simple)
+
+    # Check that only one object is present
+    assert len(parser.data_blocks) == 1
+
+    # get dataframe
+    df = list(parser.data_blocks.values())[0]
+    assert isinstance(df, pd.DataFrame)
+
+    # Check shape of dataframe
+    assert df.shape == (16, 12)
+
+    # check columns
+    expected_columns = [
+        'rlnCoordinateX',
+        'rlnCoordinateY',
+        'rlnCoordinateZ',
+        'rlnMicrographName',
+        'rlnMagnification',
+        'rlnDetectorPixelSize',
+        'rlnCtfMaxResolution',
+        'rlnImageName',
+        'rlnCtfImage',
+        'rlnAngleRot',
+        'rlnAngleTilt',
+        'rlnAnglePsi',
+    ]
+    assert all(df.columns == expected_columns)
@@ -65,35 +76,44 @@ def test_read_multiblock_file():
     """
     Check that multiblock STAR files such as postprocess RELION
     files parse properly
     """
-    s = StarParser(postprocess)
-    assert len(s.dataframes) == 3
-
-    for key, df in s.dataframes.items():
-        assert isinstance(df, pd.DataFrame)
-
-    assert s.dataframes['general'].shape == (1, 6)
-    assert all(
-        ['rlnFinalResolution', 'rlnBfactorUsedForSharpening', 'rlnUnfilteredMapHalf1',
-         'rlnUnfilteredMapHalf2', 'rlnMaskName', 'rlnRandomiseFrom']
-        == s.dataframes['general'].columns)
-    assert s.dataframes['fsc'].shape == (49, 7)
-    assert s.dataframes['guinier'].shape == (49, 3)
+    parser = StarParser(postprocess)
+    assert len(parser.data_blocks) == 3
+
+    assert 'general' in parser.data_blocks
+    assert isinstance(parser.data_blocks['general'], dict)
+    assert len(parser.data_blocks['general']) == 6
+    columns = list(parser.data_blocks['general'].keys())
+    expected_columns = [
+        'rlnFinalResolution',
+        'rlnBfactorUsedForSharpening',
+        'rlnUnfilteredMapHalf1',
+        'rlnUnfilteredMapHalf2',
+        'rlnMaskName',
+        'rlnRandomiseFrom',
+    ]
+    assert columns == expected_columns
+
+    assert 'fsc' in parser.data_blocks
+    assert isinstance(parser.data_blocks['fsc'], pd.DataFrame)
+    assert parser.data_blocks['fsc'].shape == (49, 7)
+
+    assert 'guinier' in parser.data_blocks
+    assert isinstance(parser.data_blocks['guinier'], pd.DataFrame)
+    assert parser.data_blocks['guinier'].shape == (49, 3)
 
 
 def test_read_pipeline():
     """
     Check that a pipeline.star file is parsed correctly
     """
-    s = StarParser(pipeline)
-    for key, df in s.dataframes.items():
-        assert isinstance(df, pd.DataFrame)
+    parser = StarParser(pipeline)
 
-    # Check that dataframes have the correct shapes
-    assert s.dataframes['pipeline_general'].shape == (1, 1)
-    assert s.dataframes['pipeline_processes'].shape == (31, 4)
-    assert s.dataframes['pipeline_nodes'].shape == (74, 2)
-    assert s.dataframes['pipeline_input_edges'].shape == (48, 2)
-    assert s.dataframes['pipeline_output_edges'].shape == (72, 2)
+    # Check that data match file contents
+    assert isinstance(parser.data_blocks['pipeline_general'], dict)
+    assert parser.data_blocks['pipeline_processes'].shape == (31, 4)
+    assert parser.data_blocks['pipeline_nodes'].shape == (74, 2)
+    assert parser.data_blocks['pipeline_input_edges'].shape == (48, 2)
+    assert parser.data_blocks['pipeline_output_edges'].shape == (72, 2)
@@ -102,12 +122,12 @@ def test_read_rln31():
     """
     s = StarParser(rln31_style)
 
-    for key, df in s.dataframes.items():
+    for key, df in s.data_blocks.items():
         assert isinstance(df, pd.DataFrame)
 
-    assert isinstance(s.dataframes['block_1'], pd.DataFrame)
-    assert isinstance(s.dataframes['block_2'], pd.DataFrame)
-    assert isinstance(s.dataframes['block_3'], pd.DataFrame)
+    assert isinstance(s.data_blocks['block_1'], pd.DataFrame)
+    assert isinstance(s.data_blocks['block_2'], pd.DataFrame)
+    assert isinstance(s.data_blocks['block_3'], pd.DataFrame)
@@ -116,68 +136,56 @@ def test_read_n_blocks():
     """
     Check that passing read_n_blocks allows reading a specified
     number of data blocks from a star file
     """
     # test 1 block
-    s = StarParser(postprocess, read_n_blocks=1)
-    assert len(s.dataframes) == 1
+    s = StarParser(postprocess, n_blocks_to_read=1)
+    assert len(s.data_blocks) == 1
 
     # test 2 blocks
-    s = StarParser(postprocess, read_n_blocks=2)
-    assert len(s.dataframes) == 2
+    s = StarParser(postprocess, n_blocks_to_read=2)
+    assert len(s.data_blocks) == 2
 
 
 def test_single_line_middle_of_multiblock():
     s = StarParser(single_line_middle_of_multiblock)
-    assert len(s.dataframes) == 2
+    assert len(s.data_blocks) == 2
 
 
 def test_single_line_end_of_multiblock():
     s = StarParser(single_line_end_of_multiblock)
-    assert len(s.dataframes) == 2
+    assert len(s.data_blocks) == 2
 
     # iterate over dataframes, checking keys, names and shapes
-    for idx, (key, df) in enumerate(s.dataframes.items()):
-        assert df.name == 'block_3'
+    for idx, (key, df) in enumerate(s.data_blocks.items()):
         if idx == 0:
-            assert key == 'block_3'
+            assert key == 'block_1'
             assert df.shape == (2, 5)
         if idx == 1:
-            assert key == 1
+            assert key == 'block_2'
             assert df.shape == (1, 5)
 
 
 def test_read_optimiser_2d():
-    s = StarParser(optimiser_2d)
-    assert len(s.dataframes) == 1
-    assert s.dataframes['optimiser_general'].shape == (1, 84)
+    parser = StarParser(optimiser_2d)
+    assert len(parser.data_blocks) == 1
+    assert len(parser.data_blocks['optimiser_general']) == 84
 
 
 def test_read_optimiser_3d():
-    s = StarParser(optimiser_3d)
-    assert len(s.dataframes) == 1
-    assert s.dataframes['optimiser_general'].shape == (1, 84)
+    parser = StarParser(optimiser_3d)
+    assert len(parser.data_blocks) == 1
+    assert len(parser.data_blocks['optimiser_general']) == 84
 
 
 def test_read_sampling_2d():
-    s = StarParser(sampling_2d)
-    assert len(s.dataframes) == 1
-    assert s.dataframes['sampling_general'].shape == (1, 12)
+    parser = StarParser(sampling_2d)
+    assert len(parser.data_blocks) == 1
+    assert len(parser.data_blocks['sampling_general']) == 12
 
 
 def test_read_sampling_3d():
-    s = StarParser(sampling_3d)
-    assert len(s.dataframes) == 2
-    assert s.dataframes['sampling_general'].shape == (1, 15)
-    assert s.dataframes['sampling_directions'].shape == (192, 2)
-
-
-def test_df_as_list():
-    s = StarParser(sampling_3d)
-    assert isinstance(s.dataframes_as_list(), list)
-    assert len(s.dataframes_as_list()) == 2
-
-
-def test_first_dataframe():
-    s = StarParser(sampling_3d)
-    assert isinstance(s.first_dataframe, pd.DataFrame)
+    parser = StarParser(sampling_3d)
+    assert len(parser.data_blocks) == 2
+    assert len(parser.data_blocks['sampling_general']) == 15
+    assert parser.data_blocks['sampling_directions'].shape == (192, 2)
@@ -193,27 +201,39 @@ def test_parsing_speed():
 
 def test_two_single_line_loop_blocks():
     parser = StarParser(two_single_line_loop_blocks)
-    assert len(parser.dataframes) == 2
+    assert len(parser.data_blocks) == 2
 
     np.testing.assert_array_equal(
-        parser.dataframes['block_0'].columns, [f'val{i}' for i in (1, 2, 3)]
+        parser.data_blocks['block_0'].columns, [f'val{i}' for i in (1, 2, 3)]
     )
-    assert parser.dataframes['block_0'].shape == (1, 3)
+    assert parser.data_blocks['block_0'].shape == (1, 3)
 
     np.testing.assert_array_equal(
-        parser.dataframes['block_1'].columns, [f'col{i}' for i in (1, 2, 3)]
+        parser.data_blocks['block_1'].columns, [f'col{i}' for i in (1, 2, 3)]
     )
-    assert parser.dataframes['block_1'].shape == (1, 3)
+    assert parser.data_blocks['block_1'].shape == (1, 3)
 
 
 def test_two_basic_blocks():
     parser = StarParser(two_basic_blocks)
-    assert len(parser.dataframes) == 2
-    for df in parser.dataframes.values():
-        assert df.shape == (1, 3)
+    assert len(parser.data_blocks) == 2
+
+    assert 'block_0' in parser.data_blocks
+    b0 = parser.data_blocks['block_0']
+    assert b0 == {
+        'val1': 1.0,
+        'val2': 2.0,
+        'val3': 3.0,
+    }
+
+    assert 'block_1' in parser.data_blocks
+    b1 = parser.data_blocks['block_1']
+    assert b1 == {
+        'col1': 'A',
+        'col2': 'B',
+        'col3': 'C',
+    }
 
 
 def test_empty_loop_block():
     """Parsing an empty loop block should return an empty dataframe."""
     parser = StarParser(empty_loop)
-    assert len(parser.dataframes) == 1
+    assert len(parser.data_blocks) == 1
diff --git a/tests/test_read_write_round_trip.py b/tests/test_read_write_round_trip.py
index e30a6ce..d28bfb3 100644
--- a/tests/test_read_write_round_trip.py
+++ b/tests/test_read_write_round_trip.py
@@ -1,4 +1,4 @@
-from .constants import two_single_line_loop_blocks
+from .constants import two_single_line_loop_blocks, postprocess
 
 import starfile
 import pandas as pd
@@ -19,3 +19,20 @@ def test_round_trip_two_single_line_loop_blocks(tmp_path):
 
     for expected_df, actual_df in zip(expected.values(), star_after_round_trip.values()):
         pd.testing.assert_frame_equal(expected_df, actual_df)
+
+
+def test_round_trip_postprocess(tmp_path):
+    expected = starfile.read(postprocess)
+
+    # write
+    output_file = tmp_path / 'postprocess.star'
+    starfile.write(expected, output_file)
+
+    # read
+    star_after_round_trip = starfile.read(output_file)
+
+    # assert
+    for _expected, _actual in zip(expected.values(), star_after_round_trip.values()):
+        if isinstance(_actual, pd.DataFrame):
+            pd.testing.assert_frame_equal(_actual, _expected, atol=1e-6)
+        else:
+            assert _actual == _expected
\ No newline at end of file
diff --git a/tests/test_writing.py b/tests/test_writing.py
index 23b5ddc..819f285 100644
--- a/tests/test_writing.py
+++ b/tests/test_writing.py
@@ -6,35 +6,34 @@
 from starfile.parser import StarParser
 from starfile.writer import StarWriter
 
-from .constants import loop_simple, postprocess, pipeline, rln31_style, optimiser_2d, optimiser_3d, sampling_2d, \
-    sampling_3d, test_data_directory, test_df
+from .constants import loop_simple, postprocess, test_data_directory, test_df
 
 
 def test_write_simple_block():
     s = StarParser(postprocess)
     output_file = test_data_directory / 'basic_block.star'
-    StarWriter(s.dataframes, output_file, overwrite=True)
+    StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
 def test_write_loop():
     s = StarParser(loop_simple)
     output_file = test_data_directory / 'loop_block.star'
-    StarWriter(s.dataframes, output_file, overwrite=True)
+    StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
 def test_write_multiblock():
     s = StarParser(postprocess)
     output_file = test_data_directory / 'multiblock.star'
-    StarWriter(s.dataframes, output_file, overwrite=True)
+    StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
 def test_from_single_dataframe():
     output_file = test_data_directory / 'from_df.star'
 
-    StarWriter(test_df, output_file, overwrite=True)
+    StarWriter(test_df, output_file)
     assert output_file.exists()
 
     s = StarParser(output_file)
@@ -44,15 +43,16 @@ def test_create_from_dataframes():
     dfs = [test_df, test_df]
 
     output_file = test_data_directory / 'from_list.star'
-    StarWriter(dfs, output_file, overwrite=True)
+    StarWriter(dfs, output_file)
     assert output_file.exists()
 
     s = StarParser(output_file)
-    assert len(s.dataframes) == 2
+    assert len(s.data_blocks) == 2
+
 
 def test_can_write_non_zero_indexed_one_row_dataframe():
     # see PR #13 - https://github.com/alisterburt/starfile/pull/13
-    df = pd.DataFrame([[1,2,3]], columns=["A", "B", "C"])
+    df = pd.DataFrame([[1, 2, 3]], columns=["A", "B", "C"])
     df.index += 1
 
     with TemporaryDirectory() as directory:
@@ -62,8 +62,9 @@ def test_can_write_non_zero_indexed_one_row_dataframe():
     output = output_file.read()
 
     expected = (
-        "_A\t\t\t1\n"
-        "_B\t\t\t2\n"
-        "_C\t\t\t3\n"
+        "_A #1\n"
+        "_B #2\n"
+        "_C #3\n"
+        "1\t2\t3"
     )
     assert (expected in output)
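Finally, a sketch of the overwrite-with-backup behaviour that replaces the old `overwrite` flag, as pinned down by `test_write_overwrites_with_backup` above (paths illustrative):

```python
from pathlib import Path

import pandas as pd

import starfile

df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
output = Path('cars.star')

starfile.write(df, output)  # first write creates cars.star
starfile.write(df, output)  # second write renames the old file to cars.star~

assert output.exists()
assert output.with_name(output.name + '~').exists()
```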