diff --git a/src/starfile/functions.py b/src/starfile/functions.py index 792f666..4ba0669 100644 --- a/src/starfile/functions.py +++ b/src/starfile/functions.py @@ -55,6 +55,7 @@ def write( na_rep: str = '', quote_character: str = '"', quote_all_strings: bool = False, + include_field_numbers: bool = True, **kwargs ): """Write data to disk in the STAR format. @@ -72,6 +73,11 @@ def write( Separator between values, will be passed to pandas. na_rep: str Representation of null values, will be passed to pandas. + include_field_numbers: bool + Whether field numbers should be included after field names in the ouput file. + Default is True which includes field numbers (i.e. `_rlnImageName #1`) and is + compatible with RELION and Python STOPGAP. False excludes field numbers (i.e. + `_motl_idx`) and is compatible with legacy/MATLAB STOPGAP. """ StarWriter( data, @@ -81,6 +87,7 @@ def write( separator=sep, quote_character=quote_character, quote_all_strings=quote_all_strings, + include_field_numbers=include_field_numbers, ).write() diff --git a/src/starfile/writer.py b/src/starfile/writer.py index be8d784..0418857 100644 --- a/src/starfile/writer.py +++ b/src/starfile/writer.py @@ -27,6 +27,7 @@ def __init__( na_rep: str = '', quote_character: str = '"', quote_all_strings: bool = False, + include_field_numbers: bool = True, ): # coerce data self.data_blocks = self.coerce_data_blocks(data_blocks) @@ -40,6 +41,7 @@ def __init__( self.na_rep = na_rep self.quote_character = quote_character self.quote_all_strings = quote_all_strings + self.include_field_numbers = include_field_numbers self.buffer = TextBuffer() self.backup_if_file_exists() @@ -93,7 +95,8 @@ def data_block_generator(self) -> Generator[str, None, None]: separator=self.sep, na_rep=self.na_rep, quote_character=self.quote_character, - quote_all_strings=self.quote_all_strings + quote_all_strings=self.quote_all_strings, + include_field_numbers=self.include_field_numbers, ): yield line @@ -163,7 +166,8 @@ def loop_block( separator: str = '\t', na_rep: str = '', quote_character: str = '"', - quote_all_strings: bool = False + quote_all_strings: bool = False, + include_field_numbers: bool = True, ) -> Generator[str, None, None]: # Header @@ -171,7 +175,7 @@ def loop_block( yield '' yield 'loop_' for idx, column_name in enumerate(df.columns, 1): - yield f'_{column_name} #{idx}' + yield f'_{column_name} #{idx}' if include_field_numbers else f'_{column_name}' # Data for line in df.map(lambda x: diff --git a/tests/test_writing.py b/tests/test_writing.py index 8e0ee00..135e319 100644 --- a/tests/test_writing.py +++ b/tests/test_writing.py @@ -72,6 +72,24 @@ def test_can_write_non_zero_indexed_one_row_dataframe(): assert (expected in output) +@pytest.mark.parametrize("include_field_numbers, expected", + [ + (True, "_Brand #1\n_Price #2\n"), + (False, "_Brand\n_Price\n"), + ]) +def test_include_exclude_field_numbers(include_field_numbers, expected): + with TemporaryDirectory() as directory: + filename = join_path(directory, "test.star") + StarWriter( + test_df, + filename, + include_field_numbers=include_field_numbers + ).write() + with open(filename) as output_file: + output = output_file.read() + assert (expected in output) + + @pytest.mark.parametrize("quote_character, quote_all_strings, num_quotes", [('"', False, 6), ('"', True, 8),