diff --git a/src/atomate2/utils/file_client.py b/src/atomate2/utils/file_client.py index ba3989ac0e..5e4aca0ef3 100644 --- a/src/atomate2/utils/file_client.py +++ b/src/atomate2/utils/file_client.py @@ -353,7 +353,7 @@ def gzip( path: str | Path, host: str | None = None, compresslevel: int = 6, - force: bool = False, + force: bool | str = False, ): """ Gzip a file. @@ -367,7 +367,12 @@ def gzip( compresslevel : bool Level of compression, 1-9. 9 is default for GzipFile, 6 is default for gzip. force : bool - Overwrite gzipped file if it already exists. + How to handle writing a gzipped file if it already exists. Accepts + either a string or bool: + + - `"force"` or `True`: Overwrite gzipped file if it already exists. + - `"raise"` or `False`: Raise an error if file already exists. + - `"skip"` Skip file if it already exists. """ path = self.abspath(path, host=host) path_gz = path.parent / f"{path.name}.gz" @@ -380,8 +385,21 @@ def gzip( warnings.warn(f"{path} is a directory, skipping...", stacklevel=1) return - if self.exists(path_gz, host=host) and not force: - raise FileExistsError(f"{path_gz} file already exists.") + if self.exists(path_gz, host=host): + if force is False or force == "raise": + raise FileExistsError(f"{path_gz} file already exists") + if force is True or force == "force": + pass + elif force == "skip": + warnings.warn( + f"{path_gz} file already exists, skipping...", stacklevel=2 + ) + return + else: + raise ValueError( + f"Invalid value for force: {force} " + "(must be True, False, 'raise', 'force', or 'skip'))" + ) if host is None: with open(path, "rb") as f_in, GzipFile( @@ -398,7 +416,7 @@ def gunzip( self, path: str | Path, host: str | None = None, - force: bool = False, + force: bool | str = False, ): """ Ungzip a file. @@ -410,7 +428,12 @@ def gunzip( host : str or None A remote file system host on which to perform file operations. force : bool - Overwrite non-gzipped file if it already exists. + How to handle writing a non-gzipped file if it already exists. Accepts + either a string or bool: + + - `"force"` or `True`: Overwrite non-gzipped file if it already exists. + - `"raise"` or `False`: Raise an error if file already exists. + - `"skip"` Skip file if it already exists. """ path = self.abspath(path, host=host) path_nongz = path.with_suffix("") @@ -419,8 +442,21 @@ def gunzip( warnings.warn(f"{path} is not gzipped, skipping...", stacklevel=2) return - if self.exists(path_nongz, host=host) and not force: - raise FileExistsError(f"{path_nongz} file already exists") + if self.exists(path_nongz, host=host): + if force is False or force == "raise": + raise FileExistsError(f"{path_nongz} file already exists") + if force is True or force == "force": + pass + elif force == "skip": + warnings.warn( + f"{path_nongz} file already exists, skipping...", stacklevel=2 + ) + return + else: + raise ValueError( + f"Invalid value for force: {force} " + "(must be True, False, 'raise', 'force', or 'skip'))" + ) if host is None: with open(path_nongz, "wb") as f_out, zopen(path, "rb") as f_in: diff --git a/src/atomate2/vasp/files.py b/src/atomate2/vasp/files.py index 6da912e66d..cd7102c70a 100644 --- a/src/atomate2/vasp/files.py +++ b/src/atomate2/vasp/files.py @@ -29,7 +29,7 @@ def copy_vasp_outputs( src_host: str | None = None, additional_vasp_files: Sequence[str] = (), contcar_to_poscar: bool = True, - force_overwrite: bool = False, + force_overwrite: bool | str = False, file_client: FileClient | None = None, ): """ @@ -53,8 +53,13 @@ def copy_vasp_outputs( Additional files to copy, e.g. ["CHGCAR", "WAVECAR"]. contcar_to_poscar : bool Move CONTCAR to POSCAR (original POSCAR is not copied). - force_overwrite : bool - If True, overwrite existing files during the copy step. + force_overwrite : bool or str + How to handle overwriting existing files during the copy step. Accepts + either a string or bool: + + - `"force"` or `True`: Overwrite existing files if they already exist. + - `"raise"` or `False`: Raise an error if files already exist. + - `"skip"` Skip files they already exist. file_client : .FileClient A file client to use for performing file operations. """ diff --git a/tests/common/test_files.py b/tests/common/test_files.py new file mode 100644 index 0000000000..8204a1ecaf --- /dev/null +++ b/tests/common/test_files.py @@ -0,0 +1,30 @@ +def test_gunzip_force_overwrites(tmp_path): + from atomate2.common.files import gunzip_files, gzip_files + + files = ["file1", "file2", "file3"] + for fname in files: + f = tmp_path / fname + f.write_text(fname) + gzip_files(tmp_path) + + for fname in files: + f = tmp_path / fname + f.write_text(f"{fname} overwritten") + # "file1" in the zipped files and "file1 overwritten" in the unzipped files + gunzip_files(tmp_path, force=True) + + for fname in files: + f = tmp_path / fname + assert f.read_text() == fname + + gzip_files(tmp_path) + + for fname in files: + f = tmp_path / fname + f.write_text(f"{fname} overwritten") + + # "file1" in the zipped files and "file1 overwritten" in the unzipped files + gunzip_files(tmp_path, force="skip") + for fname in files: + f = tmp_path / fname + assert f.read_text() == f"{fname} overwritten"